lnbits_rs/api/
websocket.rs

1//! Websocket
2
3use futures_util::{SinkExt, StreamExt};
4use serde::Deserialize;
5use tokio_tungstenite::connect_async;
6use tokio_tungstenite::tungstenite::protocol::Message;
7
8use crate::LNBitsClient;
9
10#[derive(Debug, Deserialize)]
11struct WebSocketPayment {
12    payment_hash: String,
13    amount: i64,
14}
15
16#[derive(Debug, Deserialize)]
17struct WebSocketMessage {
18    payment: Option<WebSocketPayment>,
19}
20
21impl LNBitsClient {
22    /// Subscribe to websocket updates
23    pub async fn subscribe_to_websocket(&self) -> anyhow::Result<()> {
24        // Create a new channel for this connection
25        // This ensures old receivers will get None and new receivers will work
26        let (new_sender, new_receiver) = tokio::sync::mpsc::channel(8);
27
28        // Replace the receiver with the new one
29        *self.receiver.lock().await = new_receiver;
30
31        let base_url = self
32            .lnbits_url
33            .to_string()
34            .trim_end_matches('/')
35            .replace("http", "ws");
36        let ws_url = format!("{}/api/v1/ws/{}", base_url, self.invoice_read_key);
37
38        let (ws_stream, _) = connect_async(ws_url).await?;
39        let (mut write, mut read) = ws_stream.split();
40
41        // Move the sender into the task (don't store it in self.sender)
42        // This ensures when the task ends, the sender is dropped and receiver gets None
43        let sender = new_sender;
44
45        // Handle incoming messages with timeout detection
46        tokio::spawn(async move {
47            let mut last_message_time = std::time::Instant::now();
48            let timeout_duration = std::time::Duration::from_secs(60); // 60 second timeout
49
50            loop {
51                // Use timeout to detect dead connections
52                let message_result =
53                    tokio::time::timeout(std::time::Duration::from_secs(30), read.next()).await;
54
55                match message_result {
56                    Ok(Some(message)) => {
57                        last_message_time = std::time::Instant::now();
58                        match message {
59                            Ok(msg) => {
60                                match msg {
61                                    Message::Text(text) => {
62                                        tracing::trace!("Received websocket message: {}", text);
63
64                                        // Parse the message
65                                        if let Ok(message) =
66                                            serde_json::from_str::<WebSocketMessage>(&text)
67                                        {
68                                            if let Some(payment) = message.payment {
69                                                if payment.amount > 0 {
70                                                    tracing::info!(
71                                                        "Payment received: {}",
72                                                        payment.payment_hash
73                                                    );
74                                                    if let Err(err) =
75                                                        sender.send(payment.payment_hash).await
76                                                    {
77                                                        log::error!(
78                                                            "Failed to send payment hash: {}",
79                                                            err
80                                                        );
81                                                    }
82                                                }
83                                            }
84                                        }
85                                    }
86                                    Message::Ping(payload) => {
87                                        // Server sent us a ping, must respond with pong
88                                        tracing::trace!("Received ping, sending pong");
89                                        if let Err(e) = write.send(Message::Pong(payload)).await {
90                                            tracing::error!("Failed to send pong response: {}", e);
91                                            break;
92                                        }
93                                    }
94                                    Message::Pong(_) => {
95                                        // Response to our ping, just log it
96                                        tracing::trace!("Received pong");
97                                    }
98                                    Message::Close(_) => {
99                                        tracing::warn!("WebSocket closed by server");
100                                        break;
101                                    }
102                                    _ => {}
103                                }
104                            }
105                            Err(e) => {
106                                // Log with both Display and Debug to get full error details
107                                tracing::error!(
108                                    "Error receiving websocket message: {} (Debug: {:?})",
109                                    e,
110                                    e
111                                );
112
113                                // Log specific protocol error details if available
114                                use tokio_tungstenite::tungstenite::Error;
115                                if let Error::Protocol(ref proto_err) = e {
116                                    tracing::error!(
117                                        "WebSocket protocol error details: {:?}",
118                                        proto_err
119                                    );
120                                }
121
122                                break;
123                            }
124                        }
125                    }
126                    Ok(None) => {
127                        // Stream ended
128                        tracing::warn!("WebSocket stream ended");
129                        break;
130                    }
131                    Err(_) => {
132                        // Timeout - check if we've exceeded the overall timeout
133                        if last_message_time.elapsed() > timeout_duration {
134                            tracing::warn!(
135                                "WebSocket timeout - no messages received for {:?}",
136                                timeout_duration
137                            );
138                            break;
139                        }
140                        // Send a ping to keep connection alive and detect dead connections
141                        if let Err(e) = write.send(Message::Ping(vec![].into())).await {
142                            tracing::error!("Failed to send ping: {}", e);
143                            break;
144                        }
145                        tracing::trace!("Sent ping to keep connection alive");
146                    }
147                }
148            }
149
150            tracing::info!("WebSocket task ending, sender will be dropped");
151            // Task ends, sender gets dropped, receiver will get None
152        });
153
154        Ok(())
155    }
156}