ankurah_websocket_server/
server.rs

1use ankurah_core::storage::StorageEngine;
2use ankurah_proto as proto;
3use anyhow::Result;
4use axum::{
5    extract::{
6        ws::{WebSocket, WebSocketUpgrade},
7        State,
8    },
9    response::{IntoResponse, Response},
10    routing::get,
11    Router,
12};
13use axum_extra::{headers, TypedHeader};
14use bincode::deserialize;
15use futures_util::StreamExt;
16use std::net::IpAddr;
17use std::{future::Future, net::SocketAddr, ops::ControlFlow, pin::Pin};
18use tower::ServiceBuilder;
19use tower_http::trace::{DefaultMakeSpan, DefaultOnRequest, DefaultOnResponse, TraceLayer};
20#[cfg(feature = "instrument")]
21use tracing::instrument;
22use tracing::{debug, error, info, warn, Level};
23
24use ankurah_core::{node::Node, policy::PolicyAgent};
25
26use crate::{client_ip::SmartClientIp, OptionalUserAgent};
27
28use super::state::Connection;
29
30pub struct WebsocketServer<SE, PA>
31where
32    SE: StorageEngine + Send + Sync + 'static,
33    PA: PolicyAgent + Send + Sync + 'static,
34{
35    node: Option<Node<SE, PA>>,
36}
37
38impl<SE, PA> WebsocketServer<SE, PA>
39where
40    SE: StorageEngine + Send + Sync + 'static,
41    PA: PolicyAgent + Send + Sync + 'static,
42{
43    pub fn new(node: Node<SE, PA>) -> Self { Self { node: Some(node) } }
44
45    pub fn route_handler(
46        &self,
47    ) -> impl Clone + Send + 'static + Fn(WebSocketUpgrade, SmartClientIp, OptionalUserAgent) -> Pin<Box<dyn Future<Output = Response> + Send>>
48    {
49        let node = self.node.as_ref().expect("websocket server cannot produce a route after being run").clone();
50
51        move |ws: WebSocketUpgrade, SmartClientIp(client_ip): SmartClientIp, OptionalUserAgent(user_agent)| {
52            let node = node.clone();
53            Box::pin(async move { upgrade_connection(ws, client_ip, user_agent, node) })
54        }
55    }
56
57    pub async fn run(&mut self, bind_address: &str) -> Result<()> {
58        let Some(node) = self.node.take() else {
59            return Err(anyhow::anyhow!("Already been run"));
60        };
61        let app = Router::new().route("/ws", get(ws_handler)).with_state(node).layer(
62            ServiceBuilder::new()
63                .layer(
64                    TraceLayer::new_for_http()
65                        .make_span_with(DefaultMakeSpan::new().level(Level::INFO))
66                        .on_request(DefaultOnRequest::new().level(Level::INFO))
67                        .on_response(DefaultOnResponse::new().level(Level::INFO)),
68                )
69                .into_inner(),
70        );
71
72        let listener = tokio::net::TcpListener::bind(bind_address).await?;
73        info!("Websocket server listening on {}", listener.local_addr()?);
74
75        axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>()).await?;
76
77        Ok(())
78    }
79}
80
81async fn ws_handler<SE, PA>(
82    ws: WebSocketUpgrade,
83    SmartClientIp(client_ip): SmartClientIp,
84    user_agent: Option<TypedHeader<headers::UserAgent>>,
85    State(node): State<Node<SE, PA>>,
86) -> impl IntoResponse
87where
88    SE: StorageEngine + Send + Sync + 'static,
89    PA: PolicyAgent + Send + Sync + 'static,
90{
91    let user_agent = user_agent.map(|TypedHeader(user_agent)| user_agent.to_string());
92    upgrade_connection(ws, client_ip, user_agent, node)
93}
94
95fn upgrade_connection<SE, PA>(ws: WebSocketUpgrade, client_ip: IpAddr, user_agent: Option<String>, node: Node<SE, PA>) -> Response
96where
97    SE: StorageEngine + Send + Sync + 'static,
98    PA: PolicyAgent + Send + Sync + 'static,
99{
100    debug!("Websocket server upgrading connection");
101    let user_agent = format_user_agent(user_agent);
102    debug!("`{user_agent}` at {client_ip} connected.");
103    ws.on_upgrade(move |socket| handle_websocket(socket, client_ip, node))
104}
105
/// Returns the client-supplied user agent, or `"Unknown browser"` when none was sent.
fn format_user_agent(user_agent: Option<String>) -> String {
    match user_agent {
        Some(agent) => agent,
        None => "Unknown browser".to_owned(),
    }
}
107
108#[cfg_attr(feature = "instrument", instrument(level = "debug", skip_all, fields(client_ip = %client_ip)))]
109async fn handle_websocket<SE, PA>(socket: WebSocket, client_ip: IpAddr, node: Node<SE, PA>)
110where
111    SE: StorageEngine + Send + Sync + 'static,
112    PA: PolicyAgent + Send + Sync + 'static,
113{
114    info!("Websocket server connected to {}", client_ip);
115
116    let (sender, mut receiver) = socket.split();
117    let mut conn = Connection::Initial(Some(sender));
118
119    // Immediately send server presence after connection
120    if let Err(e) = conn
121        .send(proto::Message::Presence(proto::Presence { node_id: node.id, durable: node.durable, system_root: node.system.root() }))
122        .await
123    {
124        debug!("Error sending presence to {client_ip}: {:?}", e);
125        return;
126    }
127
128    while let Some(msg) = receiver.next().await {
129        if let Ok(msg) = msg {
130            if process_message(msg, client_ip, &mut conn, node.clone()).await.is_break() {
131                break;
132            }
133        } else {
134            debug!("client {client_ip} abruptly disconnected");
135            break;
136        }
137    }
138
139    // Clean up peer registration if we had registered one
140    if let Connection::Established(peer_sender) = conn {
141        use ankurah_core::connector::PeerSender;
142        node.deregister_peer(peer_sender.recipient_node_id());
143    }
144
145    debug!("Websocket context {client_ip} destroyed");
146}
147
148#[cfg_attr(feature = "instrument", instrument(level = "debug", skip_all, fields(client_ip = %client_ip)))]
149async fn process_message<SE, PA>(
150    msg: axum::extract::ws::Message,
151    client_ip: IpAddr,
152    state: &mut Connection,
153    node: Node<SE, PA>,
154) -> ControlFlow<(), ()>
155where
156    SE: StorageEngine + Send + Sync + 'static,
157    PA: PolicyAgent + Send + Sync + 'static,
158{
159    match msg {
160        axum::extract::ws::Message::Binary(d) => {
161            debug!(">>> {} sent {} bytes", client_ip, d.len());
162
163            if let Ok(message) = deserialize::<proto::Message>(&d) {
164                match message {
165                    proto::Message::Presence(presence) => {
166                        match state {
167                            Connection::Initial(sender) => {
168                                if let Some(sender) = sender.take() {
169                                    debug!("Received client presence from {}", client_ip);
170
171                                    use super::sender::WebSocketClientSender;
172                                    // Register peer sender for this client
173                                    let sender = WebSocketClientSender::new(presence.node_id, sender);
174
175                                    node.register_peer(presence, Box::new(sender.clone()));
176                                    *state = Connection::Established(sender);
177                                }
178                            }
179                            _ => warn!("Received presence from {} but already have a peer sender - ignoring", client_ip),
180                        }
181                    }
182                    proto::Message::PeerMessage(msg) => {
183                        if let Connection::Established(_) = state {
184                            tokio::spawn(async move {
185                                if let Err(e) = node.handle_message(msg).await {
186                                    error!("Error handling message from {}: {:?}", client_ip, e);
187                                }
188                            });
189                        } else {
190                            warn!("Received peer message from {} but not connected as a peer", client_ip);
191                        }
192                    }
193                }
194            } else {
195                error!("Failed to deserialize message from {}", client_ip);
196            }
197        }
198        axum::extract::ws::Message::Text(t) => {
199            debug!(">>> {client_ip} sent str: {t:?}");
200        }
201        axum::extract::ws::Message::Close(c) => {
202            if let Some(cf) = c {
203                debug!(">>> {} sent close with code {} and reason `{}`", client_ip, cf.code, cf.reason);
204            } else {
205                debug!(">>> {client_ip} somehow sent close message without CloseFrame");
206            }
207            return ControlFlow::Break(());
208        }
209        axum::extract::ws::Message::Pong(v) => {
210            debug!(">>> {client_ip} sent pong with {v:?}");
211        }
212        axum::extract::ws::Message::Ping(v) => {
213            debug!(">>> {client_ip} sent ping with {v:?}");
214        }
215    }
216    ControlFlow::Continue(())
217}