viz_core/types/
websocket.rs

1//! `WebSocket` Extractor
2
3use std::{borrow::Cow, future::Future, str};
4
5use hyper::upgrade::{OnUpgrade, Upgraded};
6use tokio_tungstenite::tungstenite::protocol::Role;
7
8use crate::{
9    Body, FromRequest, IntoResponse, Io, Request, RequestExt, Response, Result, StatusCode,
10    header::{SEC_WEBSOCKET_PROTOCOL, UPGRADE},
11    headers::{
12        Connection, HeaderMapExt, HeaderValue, SecWebsocketAccept, SecWebsocketKey,
13        SecWebsocketVersion, Upgrade,
14    },
15};
16
17mod error;
18
19pub use error::WebSocketError;
20pub use tokio_tungstenite::tungstenite::protocol::{Message, WebSocketConfig};
21
22/// A wrapper around an underlying raw stream which implements the `WebSocket` protocol.
23pub type WebSocketStream<T = Io<Upgraded>> = tokio_tungstenite::WebSocketStream<T>;
24
25/// Then `WebSocket` provides the API for creating and managing a [`WebSocket`][mdn] connection,
26/// as well as for sending and receiving data on the connection.
27///
28/// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/WebSocket>
29#[derive(Debug)]
30pub struct WebSocket {
31    key: SecWebsocketKey,
32    on_upgrade: Option<OnUpgrade>,
33    protocols: Option<Box<[Cow<'static, str>]>>,
34    sec_websocket_protocol: Option<HeaderValue>,
35}
36
37impl WebSocket {
38    const NAME: &'static [u8] = b"websocket";
39
40    /// The specifies one or more protocols that you wish to use.
41    ///
42    /// In order of preference. The first one that is supported by the server will be
43    /// selected and responsed.
44    #[must_use]
45    pub fn protocols<I>(mut self, protocols: I) -> Self
46    where
47        I: IntoIterator,
48        I::Item: Into<Cow<'static, str>>,
49    {
50        self.protocols = Some(
51            protocols
52                .into_iter()
53                .map(Into::into)
54                .collect::<Vec<_>>()
55                .into(),
56        );
57        self
58    }
59
60    /// Finish the upgrade, passing a function and a [`WebSocketConfig`] to handle the `WebSocket`.
61    ///
62    /// # Panics
63    ///
64    /// When missing `OnUpgrade`
65    #[must_use]
66    pub fn on_upgrade_with_config<F, Fut>(
67        mut self,
68        callback: F,
69        config: Option<WebSocketConfig>,
70    ) -> Response
71    where
72        F: FnOnce(WebSocketStream) -> Fut + Send + 'static,
73        Fut: Future<Output = ()> + Send + 'static,
74    {
75        let on_upgrade = self.on_upgrade.take().expect("missing OnUpgrade");
76
77        tokio::task::spawn(async move {
78            let Ok(upgraded) = on_upgrade.await else {
79                return;
80            };
81
82            let socket =
83                WebSocketStream::from_raw_socket(Io::new(upgraded), Role::Server, config).await;
84
85            (callback)(socket).await;
86        });
87
88        self.into_response()
89    }
90
91    /// Finish the upgrade, passing a function to handle the `WebSocket`.
92    pub fn on_upgrade<F, Fut>(self, callback: F) -> Response
93    where
94        F: FnOnce(WebSocketStream) -> Fut + Send + 'static,
95        Fut: Future<Output = ()> + Send + 'static,
96    {
97        self.on_upgrade_with_config(callback, None)
98    }
99}
100
101impl FromRequest for WebSocket {
102    type Error = WebSocketError;
103
104    async fn extract(req: &mut Request) -> Result<Self, Self::Error> {
105        // check connection header
106        req.header_typed::<Connection>()
107            .ok_or(WebSocketError::MissingConnectUpgrade)
108            .and_then(|h| {
109                if h.contains(UPGRADE) {
110                    Ok(())
111                } else {
112                    Err(WebSocketError::InvalidConnectUpgrade)
113                }
114            })?;
115
116        // check upgrade header
117        req.headers()
118            .get(UPGRADE)
119            .ok_or(WebSocketError::MissingUpgrade)
120            .and_then(|h| {
121                if h.as_bytes().eq_ignore_ascii_case(Self::NAME) {
122                    Ok(())
123                } else {
124                    Err(WebSocketError::InvalidUpgrade)
125                }
126            })?;
127
128        // check sec-websocket-version header
129        req.header_typed::<SecWebsocketVersion>()
130            .ok_or(WebSocketError::MissingWebSocketVersion)
131            .and_then(|h| {
132                if h == SecWebsocketVersion::V13 {
133                    Ok(())
134                } else {
135                    Err(WebSocketError::InvalidWebSocketVersion)
136                }
137            })?;
138
139        let key = req
140            .header_typed::<SecWebsocketKey>()
141            .ok_or(WebSocketError::MissingWebSocketKey)?;
142
143        let on_upgrade = req.extensions_mut().remove::<OnUpgrade>();
144
145        if on_upgrade.is_none() {
146            Err(WebSocketError::ConnectionNotUpgradable)?;
147        }
148
149        let sec_websocket_protocol = req.headers().get(SEC_WEBSOCKET_PROTOCOL).cloned();
150
151        Ok(Self {
152            key,
153            on_upgrade,
154            protocols: None,
155            sec_websocket_protocol,
156        })
157    }
158}
159
160impl IntoResponse for WebSocket {
161    fn into_response(self) -> Response {
162        let protocol = self
163            .sec_websocket_protocol
164            .as_ref()
165            .and_then(|req_protocols| {
166                let req_protocols = req_protocols.to_str().ok()?;
167                let protocols = self.protocols.as_ref()?;
168                req_protocols
169                    .split(',')
170                    .map(str::trim)
171                    .find(|req_p| protocols.iter().any(|p| p == req_p))
172                    .and_then(|v| HeaderValue::from_str(v).ok())
173            });
174
175        let mut res = Response::new(Body::Empty);
176
177        *res.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
178        res.headers_mut().typed_insert(Connection::upgrade());
179        res.headers_mut().typed_insert(Upgrade::websocket());
180        res.headers_mut()
181            .typed_insert(SecWebsocketAccept::from(self.key));
182
183        if let Some(protocol) = protocol {
184            res.headers_mut().insert(SEC_WEBSOCKET_PROTOCOL, protocol);
185        }
186
187        res
188    }
189}