ankurah_websocket_client/
client.rs

1use crate::sender::WebsocketPeerSender;
2use ankurah_core::{connector::PeerSender, policy::PolicyAgent, storage::StorageEngine, Node};
3use ankurah_proto as proto;
4use ankurah_signals::{Mut, Read, Wait};
5use anyhow::Result;
6use futures_util::{SinkExt, StreamExt};
7use std::{
8    sync::{
9        atomic::{AtomicBool, Ordering},
10        Arc,
11    },
12    time::Duration,
13};
14use strum::Display;
15use thiserror::Error;
16use tokio::{select, sync::Notify, task::JoinHandle, time::sleep};
17use tokio_tungstenite::{connect_async, tungstenite::Message};
18use tracing::{debug, error, info, warn};
19
/// Connection state for the websocket client
///
/// Published through the client's reactive `connection_state` signal; observers
/// obtain it via `WebsocketClient::state()`.
#[derive(Debug, Clone, PartialEq, Display)]
pub enum ConnectionState {
    /// No connection is being attempted. This is the initial state and the
    /// state set when the connection loop exits.
    Disconnected,
    /// A connection attempt to `url` is in progress (set at the start of every
    /// attempt, including reconnects).
    #[strum(serialize = "Connecting")]
    Connecting {
        url: String,
    },
    /// The server at `url` has sent its presence and has been registered as a
    /// peer of the local node.
    #[strum(serialize = "Connected")]
    Connected {
        url: String,
        server_presence: proto::Presence,
    },
    /// The most recent connection attempt failed; the connection loop sets this
    /// before backing off and retrying.
    #[strum(serialize = "Error")]
    Error(ConnectionError),
}
36
/// Error reported through `ConnectionState::Error` and returned by
/// `WebsocketClient::wait_connected`.
#[derive(Debug, Clone, PartialEq, Error)]
pub enum ConnectionError {
    /// Catch-all failure carrying the stringified underlying error (the
    /// connection loop stores `e.to_string()` here).
    #[error("General connection error: {0}")]
    General(String),
}
42
/// Delay before the first reconnection attempt after a failed connection; the
/// connection loop also resets its back-off to this value once a connection
/// completes normally.
const INITIAL_BACKOFF: Duration = Duration::from_secs(1);
/// Upper bound for the reconnection delay, which doubles after every failure.
const MAX_BACKOFF: Duration = Duration::from_secs(30);
45
/// State shared between the public `WebsocketClient` handle and the spawned
/// connection-loop task (held behind an `Arc`).
struct Inner<SE, PA>
where
    SE: StorageEngine + Send + Sync + 'static,
    PA: PolicyAgent + Send + Sync + 'static,
{
    /// Local node: supplies the presence sent to the server and receives peer
    /// registrations and incoming peer messages.
    node: Node<SE, PA>,
    /// Normalized websocket URL (the output of `normalize_url`).
    server_url: String,
    /// Reactive connection state, written by the connection loop.
    connection_state: Mut<ConnectionState>,
    /// Lock-free mirror of "currently connected", read by `is_connected()`.
    connected: AtomicBool,
    /// Wakes the connection loop when shutdown is requested (signalled with
    /// `notify_waiters()`).
    shutdown: Notify,
    /// Set before `shutdown` is signalled so the loop can also detect a
    /// shutdown request by polling this flag.
    shutdown_requested: AtomicBool,
}
58
/// A WebSocket client for connecting Ankurah nodes
///
/// Owns a background task that connects to the server and reconnects with
/// back-off. The task is stopped by `shutdown()` or when the client is dropped.
pub struct WebsocketClient<SE, PA>
where
    SE: StorageEngine + Send + Sync + 'static,
    PA: PolicyAgent + Send + Sync + 'static,
{
    /// State shared with the connection-loop task.
    inner: Arc<Inner<SE, PA>>,
    /// Join handle of the connection-loop task; `None` once it has been taken
    /// by `shutdown()` or `Drop`.
    task: std::sync::Mutex<Option<JoinHandle<()>>>,
}
68
69impl<SE, PA> WebsocketClient<SE, PA>
70where
71    SE: StorageEngine + Send + Sync + 'static,
72    PA: PolicyAgent + Send + Sync + 'static,
73{
74    /// Create a new WebSocket client and start connecting to the server
75    pub async fn new(node: Node<SE, PA>, server_url: &str) -> anyhow::Result<Self> {
76        let ws_url = Self::normalize_url(server_url);
77        info!("Creating WebSocket client for {}", ws_url);
78
79        let inner = Arc::new(Inner {
80            node,
81            server_url: ws_url,
82            connection_state: Mut::new(ConnectionState::Disconnected),
83            connected: AtomicBool::new(false),
84            shutdown: Notify::new(),
85            shutdown_requested: AtomicBool::new(false),
86        });
87
88        let task = tokio::spawn(Self::run_connection_loop(inner.clone()));
89        Ok(Self { inner, task: std::sync::Mutex::new(Some(task)) })
90    }
91
92    fn normalize_url(url: &str) -> String {
93        match url {
94            u if u.starts_with("ws://") || u.starts_with("wss://") => format!("{}/ws", u),
95            u if u.starts_with("http://") => format!("ws://{}/ws", &u[7..]),
96            u if u.starts_with("https://") => format!("wss://{}/ws", &u[8..]),
97            u => format!("wss://{}/ws", u),
98        }
99    }
100
101    /// Get the connection state as a reactive signal
102    pub fn state(&self) -> Read<ConnectionState> { self.inner.connection_state.read() }
103
104    /// Check if currently connected to the server
105    pub fn is_connected(&self) -> bool { self.inner.connected.load(Ordering::Acquire) }
106
107    /// Gracefully shutdown the WebSocket connection
108    pub async fn shutdown(self) -> anyhow::Result<()> {
109        info!("Shutting down WebSocket client");
110
111        if let Some(task) = self.task.lock().unwrap().take() {
112            self.inner.shutdown_requested.store(true, Ordering::Release);
113            self.inner.shutdown.notify_waiters();
114
115            match task.await {
116                Ok(()) => info!("WebSocket client shutdown completed"),
117                Err(e) => warn!("Connection task join error during shutdown: {}", e),
118            }
119        } else {
120            info!("WebSocket client already shut down");
121        }
122        Ok(())
123    }
124
125    /// Wait for the client to establish a connection to the server (signal-based)
126    pub async fn wait_connected(&self) -> Result<(), ConnectionError> {
127        // Wait for either Connected or Error state, returning appropriate Result
128        self.state()
129            .wait_for(|state| match state {
130                ConnectionState::Connected { .. } => Some(Ok(())),
131                ConnectionState::Error(e) => Some(Err(e.clone())),
132                _ => None, // Continue waiting for Connecting/Disconnected states
133            })
134            .await
135    }
136
137    /// Get the node ID of the connected server (if connected)
138    pub fn server_node_id(&self) -> Option<proto::EntityId> {
139        use ankurah_signals::Get;
140        match self.state().get() {
141            ConnectionState::Connected { server_presence, .. } => Some(server_presence.node_id),
142            _ => None,
143        }
144    }
145
146    /// Main connection loop with automatic reconnection
147    async fn run_connection_loop(inner: Arc<Inner<SE, PA>>) {
148        let mut backoff = INITIAL_BACKOFF;
149        info!("Starting websocket connection loop to {}", inner.server_url);
150
151        loop {
152            select! {
153                _ = inner.shutdown.notified() => {
154                    info!("Websocket connection shutting down");
155                    break;
156                }
157                result = Self::connect_once(&inner) => {
158                    match result {
159                        Ok(()) => {
160                            info!("Connection to {} completed normally", inner.server_url);
161                            backoff = INITIAL_BACKOFF;
162                            if inner.shutdown_requested.load(Ordering::Acquire) {
163                                info!("Shutdown requested, stopping reconnection attempts");
164                                break;
165                            }
166                        }
167                        Err(e) => {
168                            error!("Connection to {} failed: {}", inner.server_url, e);
169                            inner.connection_state.set(ConnectionState::Error(ConnectionError::General(e.to_string())));
170                            inner.connected.store(false, Ordering::Release);
171
172                            info!("Retrying connection in {:?}", backoff);
173                            select! {
174                                _ = inner.shutdown.notified() => break,
175                                _ = sleep(backoff) => {}
176                            }
177                            backoff = (backoff * 2).min(MAX_BACKOFF);
178                        }
179                    }
180                }
181            }
182        }
183
184        inner.connection_state.set(ConnectionState::Disconnected);
185        inner.connected.store(false, Ordering::Release);
186    }
187
188    /// Attempt a single connection
189    async fn connect_once(inner: &Arc<Inner<SE, PA>>) -> Result<()> {
190        info!("Attempting to connect to {}", inner.server_url);
191        inner.connection_state.set(ConnectionState::Connecting { url: inner.server_url.clone() });
192
193        let (ws_stream, _) = connect_async(inner.server_url.as_str()).await?;
194        info!("WebSocket handshake completed with {}", inner.server_url);
195
196        let (mut sink, mut stream) = ws_stream.split();
197        debug!("Starting connection handling");
198
199        // Send our presence immediately
200        let presence = proto::Message::Presence(proto::Presence {
201            node_id: inner.node.id,
202            durable: inner.node.durable,
203            system_root: inner.node.system.root(),
204        });
205
206        sink.send(Message::Binary(bincode::serialize(&presence)?.into())).await?;
207        debug!("Sent client presence");
208
209        let mut peer_sender: Option<WebsocketPeerSender> = None;
210        let mut outgoing_rx: Option<tokio::sync::mpsc::UnboundedReceiver<proto::NodeMessage>> = None;
211
212        loop {
213            select! {
214                _ = inner.shutdown.notified() => {
215                    debug!("Connection received shutdown signal");
216                    break;
217                }
218                msg = async {
219                    match &mut outgoing_rx {
220                        Some(rx) => rx.recv().await,
221                        None => std::future::pending().await,
222                    }
223                } => {
224                    if Self::handle_outgoing_message(&mut sink, msg).await.is_err() {
225                        break;
226                    }
227                }
228                msg = stream.next() => {
229                    match Self::handle_incoming_message(inner, msg, &mut peer_sender, &mut outgoing_rx, &mut sink).await? {
230                        MessageResult::Continue => continue,
231                        MessageResult::Break => break,
232                    }
233                }
234            }
235        }
236
237        // Cleanup
238        inner.connected.store(false, Ordering::Release);
239        if let Some(sender) = peer_sender {
240            inner.node.deregister_peer(sender.recipient_node_id());
241            debug!("Deregistered peer {}", sender.recipient_node_id());
242        }
243        Ok(())
244    }
245
246    async fn handle_outgoing_message(
247        sink: &mut futures_util::stream::SplitSink<
248            tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
249            Message,
250        >,
251        msg: Option<proto::NodeMessage>,
252    ) -> Result<()> {
253        if let Some(node_message) = msg {
254            let proto_message = proto::Message::PeerMessage(node_message);
255            match bincode::serialize(&proto_message) {
256                Ok(data) => {
257                    sink.send(Message::Binary(data.into())).await?;
258                }
259                Err(e) => error!("Failed to serialize outgoing message: {}", e),
260            }
261        }
262        Ok(())
263    }
264
265    async fn handle_incoming_message(
266        inner: &Arc<Inner<SE, PA>>,
267        msg: Option<Result<Message, tokio_tungstenite::tungstenite::Error>>,
268        peer_sender: &mut Option<WebsocketPeerSender>,
269        outgoing_rx: &mut Option<tokio::sync::mpsc::UnboundedReceiver<proto::NodeMessage>>,
270        sink: &mut futures_util::stream::SplitSink<
271            tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
272            Message,
273        >,
274    ) -> Result<MessageResult> {
275        match msg {
276            Some(Ok(Message::Binary(data))) => match bincode::deserialize(&data) {
277                Ok(proto::Message::Presence(server_presence)) => {
278                    Self::handle_server_presence(inner, server_presence, peer_sender, outgoing_rx).await;
279                    Ok(MessageResult::Continue)
280                }
281                Ok(proto::Message::PeerMessage(node_msg)) => {
282                    Self::handle_peer_message(inner, node_msg).await;
283                    Ok(MessageResult::Continue)
284                }
285                Err(e) => {
286                    warn!("Failed to deserialize message: {}", e);
287                    Ok(MessageResult::Continue)
288                }
289            },
290            Some(Ok(Message::Close(_))) => {
291                info!("WebSocket connection closed by server");
292                Ok(MessageResult::Break)
293            }
294            Some(Ok(Message::Ping(data))) => {
295                debug!("Received ping, sending pong");
296                if let Err(e) = sink.send(Message::Pong(data)).await {
297                    warn!("Failed to send pong: {}", e);
298                    return Err(e.into());
299                }
300                Ok(MessageResult::Continue)
301            }
302            Some(Ok(Message::Pong(_))) => {
303                debug!("Received pong");
304                Ok(MessageResult::Continue)
305            }
306            Some(Ok(Message::Text(text))) => {
307                debug!("Received unexpected text message: {}", text);
308                Ok(MessageResult::Continue)
309            }
310            Some(Ok(_)) => {
311                debug!("Received other message type");
312                Ok(MessageResult::Continue)
313            }
314            Some(Err(e)) => {
315                error!("WebSocket error: {}", e);
316                Err(e.into())
317            }
318            None => {
319                info!("WebSocket stream closed");
320                Ok(MessageResult::Break)
321            }
322        }
323    }
324
325    async fn handle_server_presence(
326        inner: &Arc<Inner<SE, PA>>,
327        server_presence: proto::Presence,
328        peer_sender: &mut Option<WebsocketPeerSender>,
329        outgoing_rx: &mut Option<tokio::sync::mpsc::UnboundedReceiver<proto::NodeMessage>>,
330    ) {
331        info!("Received server presence: {}", server_presence.node_id);
332
333        let (sender, rx) = WebsocketPeerSender::new(server_presence.node_id);
334        inner.node.register_peer(server_presence.clone(), Box::new(sender.clone()));
335
336        *outgoing_rx = Some(rx);
337        *peer_sender = Some(sender);
338
339        inner.connection_state.set(ConnectionState::Connected { url: inner.server_url.to_string(), server_presence });
340        inner.connected.store(true, Ordering::Release);
341        info!("Successfully connected to server {}", inner.server_url);
342    }
343
344    async fn handle_peer_message(inner: &Arc<Inner<SE, PA>>, node_msg: proto::NodeMessage) {
345        debug!("Received peer message");
346        let node = inner.node.clone();
347        tokio::spawn(async move {
348            if let Err(e) = node.handle_message(node_msg).await {
349                warn!("Error handling peer message: {}", e);
350            }
351        });
352    }
353}
354
/// Outcome of handling one incoming websocket item, telling the per-connection
/// loop what to do next.
#[derive(Debug)]
enum MessageResult {
    /// Keep the connection open and read the next message.
    Continue,
    /// Leave the per-connection loop (the server closed the connection or the
    /// stream ended).
    Break,
}
360
361impl<SE, PA> Drop for WebsocketClient<SE, PA>
362where
363    SE: StorageEngine + Send + Sync + 'static,
364    PA: PolicyAgent + Send + Sync + 'static,
365{
366    fn drop(&mut self) {
367        if let Some(task) = self.task.lock().unwrap().take() {
368            debug!("WebSocket client dropped, requesting shutdown");
369            self.inner.shutdown_requested.store(true, Ordering::Release);
370            self.inner.shutdown.notify_waiters();
371            task.abort();
372        }
373    }
374}