bonsaidb_server/server/
websockets.rs1use bonsaidb_core::networking::{Payload, CURRENT_PROTOCOL_VERSION};
2use futures::{SinkExt, StreamExt};
3use tokio::io::{AsyncRead, AsyncWrite};
4use tokio_tungstenite::tungstenite::Message;
5
6use crate::server::connected_client::OwnedClient;
7use crate::server::shutdown::{ShutdownState, ShutdownStateWatcher};
8use crate::{Backend, CustomServer, Error, Transport};
9
10impl<B: Backend> CustomServer<B> {
11 pub async fn listen_for_websockets_on<T: tokio::net::ToSocketAddrs + Send + Sync>(
13 &self,
14 addr: T,
15 with_tls: bool,
16 ) -> Result<(), Error> {
17 if with_tls {
18 self.listen_for_secure_tcp_on(addr, ()).await
19 } else {
20 self.listen_for_tcp_on(addr, ()).await
21 }
22 }
23
24 pub(crate) async fn handle_raw_websocket_connection<
25 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
26 >(
27 &self,
28 connection: S,
29 peer_address: std::net::SocketAddr,
30 ) -> Result<(), Error> {
31 let stream = tokio_tungstenite::accept_hdr_async(connection, VersionChecker).await?;
32 self.handle_websocket(stream, peer_address).await;
33 Ok(())
34 }
35
36 #[cfg(feature = "hyper")]
39 pub fn upgrade_websocket(
40 &self,
41 peer_address: std::net::SocketAddr,
42 mut request: hyper::Request<hyper::Body>,
43 ) -> hyper::Response<hyper::Body> {
44 use hyper::header::{
45 HeaderValue, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, UPGRADE,
46 };
47 use hyper::StatusCode;
48 use tokio_tungstenite::tungstenite::protocol::Role;
49 use tokio_tungstenite::WebSocketStream;
50
51 let mut response = hyper::Response::new(hyper::Body::empty());
52 if !request.headers().contains_key(UPGRADE) {
55 *response.status_mut() = StatusCode::BAD_REQUEST;
56 return response;
57 }
58
59 let Some(sec_websocket_key) = request.headers_mut().remove(SEC_WEBSOCKET_KEY) else {
60 *response.status_mut() = StatusCode::BAD_REQUEST;
61 return response;
62 };
63
64 let task_self = self.clone();
65 tokio::spawn(async move {
66 match hyper::upgrade::on(&mut request).await {
67 Ok(upgraded) => {
68 let ws = WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await;
69 task_self.handle_websocket(ws, peer_address).await;
70 }
71 Err(err) => {
72 log::error!("Error upgrading websocket: {:?}", err);
73 }
74 }
75 });
76
77 *response.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
78 response
79 .headers_mut()
80 .insert(UPGRADE, HeaderValue::from_static("websocket"));
81 response
82 .headers_mut()
83 .insert(CONNECTION, HeaderValue::from_static("upgrade"));
84 response.headers_mut().insert(
85 SEC_WEBSOCKET_ACCEPT,
86 compute_websocket_accept_header(sec_websocket_key.as_bytes()),
87 );
88
89 response
90 }
91
92 pub async fn handle_websocket<
94 S: futures::Stream<Item = Result<tokio_tungstenite::tungstenite::Message, E>>
95 + futures::Sink<tokio_tungstenite::tungstenite::Message>
96 + Send
97 + 'static,
98 E: std::fmt::Debug + Send,
99 >(
100 &self,
101 connection: S,
102 peer_address: std::net::SocketAddr,
103 ) {
104 let mut shutdown = self
105 .data
106 .shutdown
107 .watcher()
108 .await
109 .expect("watcher shut down");
110
111 let (mut sender, mut receiver) = connection.split();
112 let (response_sender, response_receiver) = flume::unbounded();
113 let (message_sender, message_receiver) = flume::unbounded();
114
115 let (api_response_sender, api_response_receiver) = flume::unbounded();
116 let Some(client) = self
117 .initialize_client(Transport::WebSocket, peer_address, api_response_sender)
118 .await
119 else {
120 return;
121 };
122 let task_sender = response_sender.clone();
123 tokio::spawn(async move {
124 while let Ok((session_id, name, value)) = api_response_receiver.recv_async().await {
125 if task_sender
126 .send(Payload {
127 id: None,
128 session_id,
129 name,
130 value: Ok(value),
131 })
132 .is_err()
133 {
134 break;
135 }
136 }
137 });
138
139 tokio::spawn(async move {
140 while let Ok(response) = message_receiver.recv_async().await {
141 if sender.send(response).await.is_err() {
142 break;
143 }
144 }
145
146 Result::<(), Error>::Ok(())
147 });
148
149 let task_sender = message_sender.clone();
150 tokio::spawn(async move {
151 while let Ok(response) = response_receiver.recv_async().await {
152 if task_sender
153 .send(Message::Binary(bincode::serialize(&response)?))
154 .is_err()
155 {
156 break;
157 }
158 }
159
160 Result::<(), Error>::Ok(())
161 });
162
163 let (request_sender, request_receiver) =
164 flume::bounded::<Payload>(self.data.client_simultaneous_request_limit);
165
166 self.spawn_client_request_handler(client, request_receiver, response_sender, &shutdown);
167
168 loop {
169 tokio::select! {
170 payload = receiver.next() => {
171 if let Some(payload) = payload {
172 match payload {
173 Ok(Message::Binary(binary)) => match bincode::deserialize::<Payload>(&binary) {
174 Ok(payload) => drop(request_sender.send_async(payload).await),
175 Err(err) => {
176 log::error!("[server] error decoding message: {:?}", err);
177 break;
178 }
179 },
180 Ok(Message::Close(_)) => break,
181 Ok(Message::Ping(payload)) => {
182 drop(message_sender.send(Message::Pong(payload)));
183 }
184 other => {
185 log::error!("[server] unexpected message: {:?}", other);
186 break;
187 }
188 }
189 } else {
190 return;
191 }
192 },
193 shutdown = shutdown.wait_for_shutdown() => {
194 if matches!(shutdown, ShutdownState::Shutdown) {
195 return;
196 }
197 }
198 }
199 }
200 }
201
202 fn spawn_client_request_handler(
203 &self,
204 client: OwnedClient<B>,
205 request_receiver: flume::Receiver<Payload>,
206 response_sender: flume::Sender<Payload>,
207 shutdown: &ShutdownStateWatcher,
208 ) {
209 tokio::spawn({
210 let task_self = self.clone();
211 let shutdown = shutdown.clone();
212 async move {
213 task_self
214 .handle_client_requests(
215 client.clone(),
216 request_receiver,
217 response_sender,
218 shutdown,
219 )
220 .await;
221 }
222 });
223 }
224}
225
/// Derives the `Sec-WebSocket-Accept` header value for an upgrade response.
///
/// The value is the standard base64 encoding of the SHA-1 digest of `key`,
/// the client's `Sec-WebSocket-Key`, followed by the fixed `WebSocket` GUID.
#[cfg(feature = "hyper")]
fn compute_websocket_accept_header(key: &[u8]) -> hyper::header::HeaderValue {
    use base64::engine::general_purpose::STANDARD as BASE64;
    use base64::Engine;
    use sha1::{Digest, Sha1};

    // Fixed GUID that RFC 6455 appends to the client-supplied key.
    const WEBSOCKET_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    let mut hasher = Sha1::default();
    // Order matters: the key first, then the GUID.
    for chunk in [key, WEBSOCKET_GUID] {
        hasher.update(chunk);
    }

    let accept = BASE64.encode(hasher.finalize());
    hyper::header::HeaderValue::from_str(&accept).expect("base64 is a valid value")
}
238
239struct VersionChecker;
240
241impl tokio_tungstenite::tungstenite::handshake::server::Callback for VersionChecker {
242 fn on_request(
243 self,
244 request: &tokio_tungstenite::tungstenite::handshake::server::Request,
245 mut response: tokio_tungstenite::tungstenite::handshake::server::Response,
246 ) -> Result<
247 tokio_tungstenite::tungstenite::handshake::server::Response,
248 tokio_tungstenite::tungstenite::handshake::server::ErrorResponse,
249 > {
250 if let Some(protocols) = request.headers().get("Sec-WebSocket-Protocol") {
251 if let Ok(protocols) = protocols.to_str() {
252 for protocol in protocols.split(',').map(str::trim) {
253 if protocol == CURRENT_PROTOCOL_VERSION {
254 response.headers_mut().insert(
255 "Sec-WebSocket-Protocol",
256 CURRENT_PROTOCOL_VERSION.try_into().unwrap(),
257 );
258 return Ok(response);
259 }
260 }
261 }
262 }
263
264 let mut err = tokio_tungstenite::tungstenite::handshake::server::ErrorResponse::new(None);
265 *err.status_mut() = 406_u16.try_into().unwrap();
266 Err(err)
267 }
268}