bonsaidb_server/server/
websockets.rs

1use bonsaidb_core::networking::{Payload, CURRENT_PROTOCOL_VERSION};
2use futures::{SinkExt, StreamExt};
3use tokio::io::{AsyncRead, AsyncWrite};
4use tokio_tungstenite::tungstenite::Message;
5
6use crate::server::connected_client::OwnedClient;
7use crate::server::shutdown::{ShutdownState, ShutdownStateWatcher};
8use crate::{Backend, CustomServer, Error, Transport};
9
10impl<B: Backend> CustomServer<B> {
11    /// Listens for websocket connections on `addr`.
12    pub async fn listen_for_websockets_on<T: tokio::net::ToSocketAddrs + Send + Sync>(
13        &self,
14        addr: T,
15        with_tls: bool,
16    ) -> Result<(), Error> {
17        if with_tls {
18            self.listen_for_secure_tcp_on(addr, ()).await
19        } else {
20            self.listen_for_tcp_on(addr, ()).await
21        }
22    }
23
24    pub(crate) async fn handle_raw_websocket_connection<
25        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
26    >(
27        &self,
28        connection: S,
29        peer_address: std::net::SocketAddr,
30    ) -> Result<(), Error> {
31        let stream = tokio_tungstenite::accept_hdr_async(connection, VersionChecker).await?;
32        self.handle_websocket(stream, peer_address).await;
33        Ok(())
34    }
35
36    /// Handles upgrading an HTTP connection to the `WebSocket` protocol based
37    /// on the upgrade `request`. Requires feature `hyper` to be enabled.
38    #[cfg(feature = "hyper")]
39    pub fn upgrade_websocket(
40        &self,
41        peer_address: std::net::SocketAddr,
42        mut request: hyper::Request<hyper::Body>,
43    ) -> hyper::Response<hyper::Body> {
44        use hyper::header::{
45            HeaderValue, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, UPGRADE,
46        };
47        use hyper::StatusCode;
48        use tokio_tungstenite::tungstenite::protocol::Role;
49        use tokio_tungstenite::WebSocketStream;
50
51        let mut response = hyper::Response::new(hyper::Body::empty());
52        // Send a 400 to any request that doesn't have
53        // an `Upgrade` header.
54        if !request.headers().contains_key(UPGRADE) {
55            *response.status_mut() = StatusCode::BAD_REQUEST;
56            return response;
57        }
58
59        let Some(sec_websocket_key) = request.headers_mut().remove(SEC_WEBSOCKET_KEY) else {
60            *response.status_mut() = StatusCode::BAD_REQUEST;
61            return response;
62        };
63
64        let task_self = self.clone();
65        tokio::spawn(async move {
66            match hyper::upgrade::on(&mut request).await {
67                Ok(upgraded) => {
68                    let ws = WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await;
69                    task_self.handle_websocket(ws, peer_address).await;
70                }
71                Err(err) => {
72                    log::error!("Error upgrading websocket: {:?}", err);
73                }
74            }
75        });
76
77        *response.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
78        response
79            .headers_mut()
80            .insert(UPGRADE, HeaderValue::from_static("websocket"));
81        response
82            .headers_mut()
83            .insert(CONNECTION, HeaderValue::from_static("upgrade"));
84        response.headers_mut().insert(
85            SEC_WEBSOCKET_ACCEPT,
86            compute_websocket_accept_header(sec_websocket_key.as_bytes()),
87        );
88
89        response
90    }
91
92    /// Handles an established `tokio-tungstenite` `WebSocket` stream.
93    pub async fn handle_websocket<
94        S: futures::Stream<Item = Result<tokio_tungstenite::tungstenite::Message, E>>
95            + futures::Sink<tokio_tungstenite::tungstenite::Message>
96            + Send
97            + 'static,
98        E: std::fmt::Debug + Send,
99    >(
100        &self,
101        connection: S,
102        peer_address: std::net::SocketAddr,
103    ) {
104        let mut shutdown = self
105            .data
106            .shutdown
107            .watcher()
108            .await
109            .expect("watcher shut down");
110
111        let (mut sender, mut receiver) = connection.split();
112        let (response_sender, response_receiver) = flume::unbounded();
113        let (message_sender, message_receiver) = flume::unbounded();
114
115        let (api_response_sender, api_response_receiver) = flume::unbounded();
116        let Some(client) = self
117            .initialize_client(Transport::WebSocket, peer_address, api_response_sender)
118            .await
119        else {
120            return;
121        };
122        let task_sender = response_sender.clone();
123        tokio::spawn(async move {
124            while let Ok((session_id, name, value)) = api_response_receiver.recv_async().await {
125                if task_sender
126                    .send(Payload {
127                        id: None,
128                        session_id,
129                        name,
130                        value: Ok(value),
131                    })
132                    .is_err()
133                {
134                    break;
135                }
136            }
137        });
138
139        tokio::spawn(async move {
140            while let Ok(response) = message_receiver.recv_async().await {
141                if sender.send(response).await.is_err() {
142                    break;
143                }
144            }
145
146            Result::<(), Error>::Ok(())
147        });
148
149        let task_sender = message_sender.clone();
150        tokio::spawn(async move {
151            while let Ok(response) = response_receiver.recv_async().await {
152                if task_sender
153                    .send(Message::Binary(bincode::serialize(&response)?))
154                    .is_err()
155                {
156                    break;
157                }
158            }
159
160            Result::<(), Error>::Ok(())
161        });
162
163        let (request_sender, request_receiver) =
164            flume::bounded::<Payload>(self.data.client_simultaneous_request_limit);
165
166        self.spawn_client_request_handler(client, request_receiver, response_sender, &shutdown);
167
168        loop {
169            tokio::select! {
170                payload = receiver.next() => {
171                    if let Some(payload) = payload {
172                        match payload {
173                            Ok(Message::Binary(binary)) => match bincode::deserialize::<Payload>(&binary) {
174                                Ok(payload) => drop(request_sender.send_async(payload).await),
175                                Err(err) => {
176                                    log::error!("[server] error decoding message: {:?}", err);
177                                    break;
178                                }
179                            },
180                            Ok(Message::Close(_)) => break,
181                            Ok(Message::Ping(payload)) => {
182                                drop(message_sender.send(Message::Pong(payload)));
183                            }
184                            other => {
185                                log::error!("[server] unexpected message: {:?}", other);
186                                break;
187                            }
188                        }
189                    } else {
190                        return;
191                    }
192                },
193                shutdown = shutdown.wait_for_shutdown() => {
194                    if matches!(shutdown, ShutdownState::Shutdown) {
195                        return;
196                    }
197                }
198            }
199        }
200    }
201
202    fn spawn_client_request_handler(
203        &self,
204        client: OwnedClient<B>,
205        request_receiver: flume::Receiver<Payload>,
206        response_sender: flume::Sender<Payload>,
207        shutdown: &ShutdownStateWatcher,
208    ) {
209        tokio::spawn({
210            let task_self = self.clone();
211            let shutdown = shutdown.clone();
212            async move {
213                task_self
214                    .handle_client_requests(
215                        client.clone(),
216                        request_receiver,
217                        response_sender,
218                        shutdown,
219                    )
220                    .await;
221            }
222        });
223    }
224}
225
/// Derives the `Sec-WebSocket-Accept` response header value from the
/// client's `Sec-WebSocket-Key`.
///
/// The value is `base64(SHA-1(key ++ GUID))`, where GUID is the fixed string
/// from RFC 6455 §1.3.
#[cfg(feature = "hyper")]
fn compute_websocket_accept_header(key: &[u8]) -> hyper::header::HeaderValue {
    use base64::engine::general_purpose::STANDARD as BASE64;
    use base64::Engine;
    use sha1::{Digest, Sha1};

    const WEBSOCKET_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    let mut hasher = Sha1::new();
    hasher.update(key);
    hasher.update(WEBSOCKET_GUID);
    let accept = BASE64.encode(hasher.finalize());

    // Standard base64 output is plain ASCII, which is always a valid header value.
    hyper::header::HeaderValue::from_str(&accept).expect("base64 is a valid value")
}
238
239struct VersionChecker;
240
241impl tokio_tungstenite::tungstenite::handshake::server::Callback for VersionChecker {
242    fn on_request(
243        self,
244        request: &tokio_tungstenite::tungstenite::handshake::server::Request,
245        mut response: tokio_tungstenite::tungstenite::handshake::server::Response,
246    ) -> Result<
247        tokio_tungstenite::tungstenite::handshake::server::Response,
248        tokio_tungstenite::tungstenite::handshake::server::ErrorResponse,
249    > {
250        if let Some(protocols) = request.headers().get("Sec-WebSocket-Protocol") {
251            if let Ok(protocols) = protocols.to_str() {
252                for protocol in protocols.split(',').map(str::trim) {
253                    if protocol == CURRENT_PROTOCOL_VERSION {
254                        response.headers_mut().insert(
255                            "Sec-WebSocket-Protocol",
256                            CURRENT_PROTOCOL_VERSION.try_into().unwrap(),
257                        );
258                        return Ok(response);
259                    }
260                }
261            }
262        }
263
264        let mut err = tokio_tungstenite::tungstenite::handshake::server::ErrorResponse::new(None);
265        *err.status_mut() = 406_u16.try_into().unwrap();
266        Err(err)
267    }
268}