ankurah_websocket_client/
client.rs1use crate::sender::WebsocketPeerSender;
2use ankurah_core::{connector::PeerSender, policy::PolicyAgent, storage::StorageEngine, Node};
3use ankurah_proto as proto;
4use ankurah_signals::{Mut, Read, Wait};
5use anyhow::Result;
6use futures_util::{SinkExt, StreamExt};
7use std::{
8 sync::{
9 atomic::{AtomicBool, Ordering},
10 Arc,
11 },
12 time::Duration,
13};
14use strum::Display;
15use thiserror::Error;
16use tokio::{select, sync::Notify, task::JoinHandle, time::sleep};
17use tokio_tungstenite::{connect_async, tungstenite::Message};
18use tracing::{debug, error, info, warn};
19
20#[derive(Debug, Clone, PartialEq, Display)]
22pub enum ConnectionState {
23 Disconnected,
24 #[strum(serialize = "Connecting")]
25 Connecting {
26 url: String,
27 },
28 #[strum(serialize = "Connected")]
29 Connected {
30 url: String,
31 server_presence: proto::Presence,
32 },
33 #[strum(serialize = "Error")]
34 Error(ConnectionError),
35}
36
37#[derive(Debug, Clone, PartialEq, Error)]
38pub enum ConnectionError {
39 #[error("General connection error: {0}")]
40 General(String),
41}
42
/// Delay before the first reconnection attempt after a failed connection.
const INITIAL_BACKOFF: Duration = Duration::from_secs(1);
/// Upper bound for the reconnection delay, which doubles after every failure.
const MAX_BACKOFF: Duration = Duration::from_secs(30);
45
46struct Inner<SE, PA>
47where
48 SE: StorageEngine + Send + Sync + 'static,
49 PA: PolicyAgent + Send + Sync + 'static,
50{
51 node: Node<SE, PA>,
52 server_url: String,
53 connection_state: Mut<ConnectionState>,
54 connected: AtomicBool,
55 shutdown: Notify,
56 shutdown_requested: AtomicBool,
57}
58
59pub struct WebsocketClient<SE, PA>
61where
62 SE: StorageEngine + Send + Sync + 'static,
63 PA: PolicyAgent + Send + Sync + 'static,
64{
65 inner: Arc<Inner<SE, PA>>,
66 task: std::sync::Mutex<Option<JoinHandle<()>>>,
67}
68
69impl<SE, PA> WebsocketClient<SE, PA>
70where
71 SE: StorageEngine + Send + Sync + 'static,
72 PA: PolicyAgent + Send + Sync + 'static,
73{
74 pub async fn new(node: Node<SE, PA>, server_url: &str) -> anyhow::Result<Self> {
76 let ws_url = Self::normalize_url(server_url);
77 info!("Creating WebSocket client for {}", ws_url);
78
79 let inner = Arc::new(Inner {
80 node,
81 server_url: ws_url,
82 connection_state: Mut::new(ConnectionState::Disconnected),
83 connected: AtomicBool::new(false),
84 shutdown: Notify::new(),
85 shutdown_requested: AtomicBool::new(false),
86 });
87
88 let task = tokio::spawn(Self::run_connection_loop(inner.clone()));
89 Ok(Self { inner, task: std::sync::Mutex::new(Some(task)) })
90 }
91
/// Converts a user-supplied server address into the websocket endpoint URL.
///
/// * `ws://` and `wss://` addresses keep their scheme.
/// * `http://` becomes `ws://` and `https://` becomes `wss://`.
/// * An address without a recognised scheme is assumed to be `wss://`.
///
/// `/ws` is appended in every case. Trailing slashes on the input are removed
/// first, so `https://host/` yields `wss://host/ws` rather than
/// `wss://host//ws`.
fn normalize_url(url: &str) -> String {
    // `strip_prefix` both detects the scheme and yields the remainder, which
    // avoids hand-counted byte offsets.
    let (scheme, rest) = if let Some(rest) = url.strip_prefix("ws://") {
        ("ws", rest)
    } else if let Some(rest) = url.strip_prefix("wss://") {
        ("wss", rest)
    } else if let Some(rest) = url.strip_prefix("http://") {
        ("ws", rest)
    } else if let Some(rest) = url.strip_prefix("https://") {
        ("wss", rest)
    } else {
        ("wss", url)
    };

    format!("{}://{}/ws", scheme, rest.trim_end_matches('/'))
}
100
101 pub fn state(&self) -> Read<ConnectionState> { self.inner.connection_state.read() }
103
104 pub fn is_connected(&self) -> bool { self.inner.connected.load(Ordering::Acquire) }
106
107 pub async fn shutdown(self) -> anyhow::Result<()> {
109 info!("Shutting down WebSocket client");
110
111 if let Some(task) = self.task.lock().unwrap().take() {
112 self.inner.shutdown_requested.store(true, Ordering::Release);
113 self.inner.shutdown.notify_waiters();
114
115 match task.await {
116 Ok(()) => info!("WebSocket client shutdown completed"),
117 Err(e) => warn!("Connection task join error during shutdown: {}", e),
118 }
119 } else {
120 info!("WebSocket client already shut down");
121 }
122 Ok(())
123 }
124
125 pub async fn wait_connected(&self) -> Result<(), ConnectionError> {
127 self.state()
129 .wait_for(|state| match state {
130 ConnectionState::Connected { .. } => Some(Ok(())),
131 ConnectionState::Error(e) => Some(Err(e.clone())),
132 _ => None, })
134 .await
135 }
136
137 pub fn server_node_id(&self) -> Option<proto::EntityId> {
139 use ankurah_signals::Get;
140 match self.state().get() {
141 ConnectionState::Connected { server_presence, .. } => Some(server_presence.node_id),
142 _ => None,
143 }
144 }
145
146 async fn run_connection_loop(inner: Arc<Inner<SE, PA>>) {
148 let mut backoff = INITIAL_BACKOFF;
149 info!("Starting websocket connection loop to {}", inner.server_url);
150
151 loop {
152 select! {
153 _ = inner.shutdown.notified() => {
154 info!("Websocket connection shutting down");
155 break;
156 }
157 result = Self::connect_once(&inner) => {
158 match result {
159 Ok(()) => {
160 info!("Connection to {} completed normally", inner.server_url);
161 backoff = INITIAL_BACKOFF;
162 if inner.shutdown_requested.load(Ordering::Acquire) {
163 info!("Shutdown requested, stopping reconnection attempts");
164 break;
165 }
166 }
167 Err(e) => {
168 error!("Connection to {} failed: {}", inner.server_url, e);
169 inner.connection_state.set(ConnectionState::Error(ConnectionError::General(e.to_string())));
170 inner.connected.store(false, Ordering::Release);
171
172 info!("Retrying connection in {:?}", backoff);
173 select! {
174 _ = inner.shutdown.notified() => break,
175 _ = sleep(backoff) => {}
176 }
177 backoff = (backoff * 2).min(MAX_BACKOFF);
178 }
179 }
180 }
181 }
182 }
183
184 inner.connection_state.set(ConnectionState::Disconnected);
185 inner.connected.store(false, Ordering::Release);
186 }
187
188 async fn connect_once(inner: &Arc<Inner<SE, PA>>) -> Result<()> {
190 info!("Attempting to connect to {}", inner.server_url);
191 inner.connection_state.set(ConnectionState::Connecting { url: inner.server_url.clone() });
192
193 let (ws_stream, _) = connect_async(inner.server_url.as_str()).await?;
194 info!("WebSocket handshake completed with {}", inner.server_url);
195
196 let (mut sink, mut stream) = ws_stream.split();
197 debug!("Starting connection handling");
198
199 let presence = proto::Message::Presence(proto::Presence {
201 node_id: inner.node.id,
202 durable: inner.node.durable,
203 system_root: inner.node.system.root(),
204 });
205
206 sink.send(Message::Binary(bincode::serialize(&presence)?.into())).await?;
207 debug!("Sent client presence");
208
209 let mut peer_sender: Option<WebsocketPeerSender> = None;
210 let mut outgoing_rx: Option<tokio::sync::mpsc::UnboundedReceiver<proto::NodeMessage>> = None;
211
212 loop {
213 select! {
214 _ = inner.shutdown.notified() => {
215 debug!("Connection received shutdown signal");
216 break;
217 }
218 msg = async {
219 match &mut outgoing_rx {
220 Some(rx) => rx.recv().await,
221 None => std::future::pending().await,
222 }
223 } => {
224 if Self::handle_outgoing_message(&mut sink, msg).await.is_err() {
225 break;
226 }
227 }
228 msg = stream.next() => {
229 match Self::handle_incoming_message(inner, msg, &mut peer_sender, &mut outgoing_rx, &mut sink).await? {
230 MessageResult::Continue => continue,
231 MessageResult::Break => break,
232 }
233 }
234 }
235 }
236
237 inner.connected.store(false, Ordering::Release);
239 if let Some(sender) = peer_sender {
240 inner.node.deregister_peer(sender.recipient_node_id());
241 debug!("Deregistered peer {}", sender.recipient_node_id());
242 }
243 Ok(())
244 }
245
246 async fn handle_outgoing_message(
247 sink: &mut futures_util::stream::SplitSink<
248 tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
249 Message,
250 >,
251 msg: Option<proto::NodeMessage>,
252 ) -> Result<()> {
253 if let Some(node_message) = msg {
254 let proto_message = proto::Message::PeerMessage(node_message);
255 match bincode::serialize(&proto_message) {
256 Ok(data) => {
257 sink.send(Message::Binary(data.into())).await?;
258 }
259 Err(e) => error!("Failed to serialize outgoing message: {}", e),
260 }
261 }
262 Ok(())
263 }
264
265 async fn handle_incoming_message(
266 inner: &Arc<Inner<SE, PA>>,
267 msg: Option<Result<Message, tokio_tungstenite::tungstenite::Error>>,
268 peer_sender: &mut Option<WebsocketPeerSender>,
269 outgoing_rx: &mut Option<tokio::sync::mpsc::UnboundedReceiver<proto::NodeMessage>>,
270 sink: &mut futures_util::stream::SplitSink<
271 tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
272 Message,
273 >,
274 ) -> Result<MessageResult> {
275 match msg {
276 Some(Ok(Message::Binary(data))) => match bincode::deserialize(&data) {
277 Ok(proto::Message::Presence(server_presence)) => {
278 Self::handle_server_presence(inner, server_presence, peer_sender, outgoing_rx).await;
279 Ok(MessageResult::Continue)
280 }
281 Ok(proto::Message::PeerMessage(node_msg)) => {
282 Self::handle_peer_message(inner, node_msg).await;
283 Ok(MessageResult::Continue)
284 }
285 Err(e) => {
286 warn!("Failed to deserialize message: {}", e);
287 Ok(MessageResult::Continue)
288 }
289 },
290 Some(Ok(Message::Close(_))) => {
291 info!("WebSocket connection closed by server");
292 Ok(MessageResult::Break)
293 }
294 Some(Ok(Message::Ping(data))) => {
295 debug!("Received ping, sending pong");
296 if let Err(e) = sink.send(Message::Pong(data)).await {
297 warn!("Failed to send pong: {}", e);
298 return Err(e.into());
299 }
300 Ok(MessageResult::Continue)
301 }
302 Some(Ok(Message::Pong(_))) => {
303 debug!("Received pong");
304 Ok(MessageResult::Continue)
305 }
306 Some(Ok(Message::Text(text))) => {
307 debug!("Received unexpected text message: {}", text);
308 Ok(MessageResult::Continue)
309 }
310 Some(Ok(_)) => {
311 debug!("Received other message type");
312 Ok(MessageResult::Continue)
313 }
314 Some(Err(e)) => {
315 error!("WebSocket error: {}", e);
316 Err(e.into())
317 }
318 None => {
319 info!("WebSocket stream closed");
320 Ok(MessageResult::Break)
321 }
322 }
323 }
324
325 async fn handle_server_presence(
326 inner: &Arc<Inner<SE, PA>>,
327 server_presence: proto::Presence,
328 peer_sender: &mut Option<WebsocketPeerSender>,
329 outgoing_rx: &mut Option<tokio::sync::mpsc::UnboundedReceiver<proto::NodeMessage>>,
330 ) {
331 info!("Received server presence: {}", server_presence.node_id);
332
333 let (sender, rx) = WebsocketPeerSender::new(server_presence.node_id);
334 inner.node.register_peer(server_presence.clone(), Box::new(sender.clone()));
335
336 *outgoing_rx = Some(rx);
337 *peer_sender = Some(sender);
338
339 inner.connection_state.set(ConnectionState::Connected { url: inner.server_url.to_string(), server_presence });
340 inner.connected.store(true, Ordering::Release);
341 info!("Successfully connected to server {}", inner.server_url);
342 }
343
344 async fn handle_peer_message(inner: &Arc<Inner<SE, PA>>, node_msg: proto::NodeMessage) {
345 debug!("Received peer message");
346 let node = inner.node.clone();
347 tokio::spawn(async move {
348 if let Err(e) = node.handle_message(node_msg).await {
349 warn!("Error handling peer message: {}", e);
350 }
351 });
352 }
353}
354
/// Tells the connection loop what to do after one incoming item was processed.
#[derive(Debug)]
enum MessageResult {
    /// Keep reading from the connection.
    Continue,
    /// Leave the message loop; the connection has ended.
    Break,
}
360
361impl<SE, PA> Drop for WebsocketClient<SE, PA>
362where
363 SE: StorageEngine + Send + Sync + 'static,
364 PA: PolicyAgent + Send + Sync + 'static,
365{
366 fn drop(&mut self) {
367 if let Some(task) = self.task.lock().unwrap().take() {
368 debug!("WebSocket client dropped, requesting shutdown");
369 self.inner.shutdown_requested.store(true, Ordering::Release);
370 self.inner.shutdown.notify_waiters();
371 task.abort();
372 }
373 }
374}