lnbits_rs/api/websocket.rs
//! WebSocket subscription to LNbits payment notifications.
2
3use futures_util::{SinkExt, StreamExt};
4use serde::Deserialize;
5use tokio_tungstenite::connect_async;
6use tokio_tungstenite::tungstenite::protocol::Message;
7
8use crate::LNBitsClient;
9
10#[derive(Debug, Deserialize)]
11struct WebSocketPayment {
12 payment_hash: String,
13 amount: i64,
14}
15
16#[derive(Debug, Deserialize)]
17struct WebSocketMessage {
18 payment: Option<WebSocketPayment>,
19}
20
21impl LNBitsClient {
22 /// Subscribe to websocket updates
23 pub async fn subscribe_to_websocket(&self) -> anyhow::Result<()> {
24 // Create a new channel for this connection
25 // This ensures old receivers will get None and new receivers will work
26 let (new_sender, new_receiver) = tokio::sync::mpsc::channel(8);
27
28 // Replace the sender and receiver with the new ones
29 *self.sender.lock().await = new_sender;
30 *self.receiver.lock().await = new_receiver;
31
32 let base_url = self
33 .lnbits_url
34 .to_string()
35 .trim_end_matches('/')
36 .replace("http", "ws");
37 let ws_url = format!("{}/api/v1/ws/{}", base_url, self.invoice_read_key);
38
39 let (ws_stream, _) = connect_async(ws_url).await?;
40 let (mut write, mut read) = ws_stream.split();
41
42 let sender = self.sender.clone();
43
44 // Handle incoming messages with timeout detection
45 tokio::spawn(async move {
46 let mut last_message_time = std::time::Instant::now();
47 let timeout_duration = std::time::Duration::from_secs(60); // 60 second timeout
48
49 loop {
50 // Use timeout to detect dead connections
51 let message_result =
52 tokio::time::timeout(std::time::Duration::from_secs(30), read.next()).await;
53
54 match message_result {
55 Ok(Some(message)) => {
56 last_message_time = std::time::Instant::now();
57 match message {
58 Ok(msg) => {
59 match msg {
60 Message::Text(text) => {
61 tracing::trace!("Received websocket message: {}", text);
62
63 // Parse the message
64 if let Ok(message) =
65 serde_json::from_str::<WebSocketMessage>(&text)
66 {
67 if let Some(payment) = message.payment {
68 if payment.amount > 0 {
69 tracing::info!(
70 "Payment received: {}",
71 payment.payment_hash
72 );
73 let sender = sender.lock().await;
74 if let Err(err) =
75 sender.send(payment.payment_hash).await
76 {
77 log::error!(
78 "Failed to send payment hash: {}",
79 err
80 );
81 }
82 }
83 }
84 }
85 }
86 Message::Ping(_) | Message::Pong(_) => {
87 // Keepalive messages
88 tracing::trace!("Received ping/pong");
89 }
90 Message::Close(_) => {
91 tracing::warn!("WebSocket closed by server");
92 break;
93 }
94 _ => {}
95 }
96 }
97 Err(e) => {
98 tracing::error!("Error receiving websocket message: {}", e);
99 break;
100 }
101 }
102 }
103 Ok(None) => {
104 // Stream ended
105 tracing::warn!("WebSocket stream ended");
106 break;
107 }
108 Err(_) => {
109 // Timeout - check if we've exceeded the overall timeout
110 if last_message_time.elapsed() > timeout_duration {
111 tracing::warn!(
112 "WebSocket timeout - no messages received for {:?}",
113 timeout_duration
114 );
115 break;
116 }
117 // Send a ping to keep connection alive and detect dead connections
118 if let Err(e) = write.send(Message::Ping(vec![].into())).await {
119 tracing::error!("Failed to send ping: {}", e);
120 break;
121 }
122 tracing::trace!("Sent ping to keep connection alive");
123 }
124 }
125 }
126
127 tracing::info!("WebSocket task ending, sender will be dropped");
128 // Task ends, sender gets dropped, receiver will get None
129 });
130
131 Ok(())
132 }
133}