ankurah_websocket_server/
server.rs1use ankurah_core::storage::StorageEngine;
2use ankurah_proto as proto;
3use anyhow::Result;
4use axum::{
5 extract::{
6 ws::{WebSocket, WebSocketUpgrade},
7 State,
8 },
9 response::{IntoResponse, Response},
10 routing::get,
11 Router,
12};
13use axum_extra::{headers, TypedHeader};
14use bincode::deserialize;
15use futures_util::StreamExt;
16use std::net::IpAddr;
17use std::{future::Future, net::SocketAddr, ops::ControlFlow, pin::Pin};
18use tower::ServiceBuilder;
19use tower_http::trace::{DefaultMakeSpan, DefaultOnRequest, DefaultOnResponse, TraceLayer};
20#[cfg(feature = "instrument")]
21use tracing::instrument;
22use tracing::{debug, error, info, warn, Level};
23
24use ankurah_core::{node::Node, policy::PolicyAgent};
25
26use crate::{client_ip::SmartClientIp, OptionalUserAgent};
27
28use super::state::Connection;
29
30pub struct WebsocketServer<SE, PA>
31where
32 SE: StorageEngine + Send + Sync + 'static,
33 PA: PolicyAgent + Send + Sync + 'static,
34{
35 node: Option<Node<SE, PA>>,
36}
37
38impl<SE, PA> WebsocketServer<SE, PA>
39where
40 SE: StorageEngine + Send + Sync + 'static,
41 PA: PolicyAgent + Send + Sync + 'static,
42{
43 pub fn new(node: Node<SE, PA>) -> Self { Self { node: Some(node) } }
44
45 pub fn route_handler(
46 &self,
47 ) -> impl Clone + Send + 'static + Fn(WebSocketUpgrade, SmartClientIp, OptionalUserAgent) -> Pin<Box<dyn Future<Output = Response> + Send>>
48 {
49 let node = self.node.as_ref().expect("websocket server cannot produce a route after being run").clone();
50
51 move |ws: WebSocketUpgrade, SmartClientIp(client_ip): SmartClientIp, OptionalUserAgent(user_agent)| {
52 let node = node.clone();
53 Box::pin(async move { upgrade_connection(ws, client_ip, user_agent, node) })
54 }
55 }
56
57 pub async fn run(&mut self, bind_address: &str) -> Result<()> {
58 let Some(node) = self.node.take() else {
59 return Err(anyhow::anyhow!("Already been run"));
60 };
61 let app = Router::new().route("/ws", get(ws_handler)).with_state(node).layer(
62 ServiceBuilder::new()
63 .layer(
64 TraceLayer::new_for_http()
65 .make_span_with(DefaultMakeSpan::new().level(Level::INFO))
66 .on_request(DefaultOnRequest::new().level(Level::INFO))
67 .on_response(DefaultOnResponse::new().level(Level::INFO)),
68 )
69 .into_inner(),
70 );
71
72 let listener = tokio::net::TcpListener::bind(bind_address).await?;
73 info!("Websocket server listening on {}", listener.local_addr()?);
74
75 axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>()).await?;
76
77 Ok(())
78 }
79}
80
81async fn ws_handler<SE, PA>(
82 ws: WebSocketUpgrade,
83 SmartClientIp(client_ip): SmartClientIp,
84 user_agent: Option<TypedHeader<headers::UserAgent>>,
85 State(node): State<Node<SE, PA>>,
86) -> impl IntoResponse
87where
88 SE: StorageEngine + Send + Sync + 'static,
89 PA: PolicyAgent + Send + Sync + 'static,
90{
91 let user_agent = user_agent.map(|TypedHeader(user_agent)| user_agent.to_string());
92 upgrade_connection(ws, client_ip, user_agent, node)
93}
94
95fn upgrade_connection<SE, PA>(ws: WebSocketUpgrade, client_ip: IpAddr, user_agent: Option<String>, node: Node<SE, PA>) -> Response
96where
97 SE: StorageEngine + Send + Sync + 'static,
98 PA: PolicyAgent + Send + Sync + 'static,
99{
100 debug!("Websocket server upgrading connection");
101 let user_agent = format_user_agent(user_agent);
102 debug!("`{user_agent}` at {client_ip} connected.");
103 ws.on_upgrade(move |socket| handle_websocket(socket, client_ip, node))
104}
105
/// Returns the client's user-agent string, or the literal `"Unknown browser"` when none was sent.
fn format_user_agent(user_agent: Option<String>) -> String {
    // Placeholder shown in logs when the request carried no `User-Agent` header.
    const UNKNOWN_AGENT: &str = "Unknown browser";
    user_agent.unwrap_or_else(|| UNKNOWN_AGENT.to_owned())
}
107
108#[cfg_attr(feature = "instrument", instrument(level = "debug", skip_all, fields(client_ip = %client_ip)))]
109async fn handle_websocket<SE, PA>(socket: WebSocket, client_ip: IpAddr, node: Node<SE, PA>)
110where
111 SE: StorageEngine + Send + Sync + 'static,
112 PA: PolicyAgent + Send + Sync + 'static,
113{
114 info!("Websocket server connected to {}", client_ip);
115
116 let (sender, mut receiver) = socket.split();
117 let mut conn = Connection::Initial(Some(sender));
118
119 if let Err(e) = conn
121 .send(proto::Message::Presence(proto::Presence { node_id: node.id, durable: node.durable, system_root: node.system.root() }))
122 .await
123 {
124 debug!("Error sending presence to {client_ip}: {:?}", e);
125 return;
126 }
127
128 while let Some(msg) = receiver.next().await {
129 if let Ok(msg) = msg {
130 if process_message(msg, client_ip, &mut conn, node.clone()).await.is_break() {
131 break;
132 }
133 } else {
134 debug!("client {client_ip} abruptly disconnected");
135 break;
136 }
137 }
138
139 if let Connection::Established(peer_sender) = conn {
141 use ankurah_core::connector::PeerSender;
142 node.deregister_peer(peer_sender.recipient_node_id());
143 }
144
145 debug!("Websocket context {client_ip} destroyed");
146}
147
148#[cfg_attr(feature = "instrument", instrument(level = "debug", skip_all, fields(client_ip = %client_ip)))]
149async fn process_message<SE, PA>(
150 msg: axum::extract::ws::Message,
151 client_ip: IpAddr,
152 state: &mut Connection,
153 node: Node<SE, PA>,
154) -> ControlFlow<(), ()>
155where
156 SE: StorageEngine + Send + Sync + 'static,
157 PA: PolicyAgent + Send + Sync + 'static,
158{
159 match msg {
160 axum::extract::ws::Message::Binary(d) => {
161 debug!(">>> {} sent {} bytes", client_ip, d.len());
162
163 if let Ok(message) = deserialize::<proto::Message>(&d) {
164 match message {
165 proto::Message::Presence(presence) => {
166 match state {
167 Connection::Initial(sender) => {
168 if let Some(sender) = sender.take() {
169 debug!("Received client presence from {}", client_ip);
170
171 use super::sender::WebSocketClientSender;
172 let sender = WebSocketClientSender::new(presence.node_id, sender);
174
175 node.register_peer(presence, Box::new(sender.clone()));
176 *state = Connection::Established(sender);
177 }
178 }
179 _ => warn!("Received presence from {} but already have a peer sender - ignoring", client_ip),
180 }
181 }
182 proto::Message::PeerMessage(msg) => {
183 if let Connection::Established(_) = state {
184 tokio::spawn(async move {
185 if let Err(e) = node.handle_message(msg).await {
186 error!("Error handling message from {}: {:?}", client_ip, e);
187 }
188 });
189 } else {
190 warn!("Received peer message from {} but not connected as a peer", client_ip);
191 }
192 }
193 }
194 } else {
195 error!("Failed to deserialize message from {}", client_ip);
196 }
197 }
198 axum::extract::ws::Message::Text(t) => {
199 debug!(">>> {client_ip} sent str: {t:?}");
200 }
201 axum::extract::ws::Message::Close(c) => {
202 if let Some(cf) = c {
203 debug!(">>> {} sent close with code {} and reason `{}`", client_ip, cf.code, cf.reason);
204 } else {
205 debug!(">>> {client_ip} somehow sent close message without CloseFrame");
206 }
207 return ControlFlow::Break(());
208 }
209 axum::extract::ws::Message::Pong(v) => {
210 debug!(">>> {client_ip} sent pong with {v:?}");
211 }
212 axum::extract::ws::Message::Ping(v) => {
213 debug!(">>> {client_ip} sent ping with {v:?}");
214 }
215 }
216 ControlFlow::Continue(())
217}