axum_socket_io/
lib.rs

1#![warn(missing_docs)]
2#![doc = include_str!("../README.md")]
3
4use axum::{
5    async_trait,
6    body::Bytes,
7    extract::FromRequestParts,
8    http::{header, request::Parts, HeaderMap, HeaderName, HeaderValue, Method, StatusCode},
9};
10use hyper_util::rt::TokioIo;
11use std::future::Future;
12
13pub use web_socket_io::*;
14
15/// Extractor for establishing `SocketIo` connections.
16pub struct SocketIoUpgrade {
17    sec_websocket_key: HeaderValue,
18    on_upgrade: hyper::upgrade::OnUpgrade,
19}
20
21impl SocketIoUpgrade {
22    /// Finalize upgrading the connection and call the provided callback with `SocketIo` instance.
23    ///
24    /// ## Arguments
25    ///
26    /// * `buffer` - The size of the buffer to be used in the `SocketIo` instance.
27    /// * `callback` - A function that will be called with the upgraded `SocketIo` instance.
28    pub fn on_upgrade<C, Fut>(self, buffer: usize, callback: C) -> axum::response::Response
29    where
30        C: FnOnce(SocketIo) -> Fut + Send + 'static,
31        Fut: Future<Output = ()> + Send + 'static,
32    {
33        tokio::spawn(async move {
34            if let Ok(upgraded) = self.on_upgrade.await {
35                let (reader, writer) = tokio::io::split(TokioIo::new(upgraded));
36                callback(SocketIo::new(reader, writer, buffer)).await;
37            }
38        });
39
40        static H_UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
41        static H_WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
42        static H_WS_PROTOCOL: HeaderValue = HeaderValue::from_static("websocket.io-rpc-v0.1");
43
44        axum::response::Response::builder()
45            .status(StatusCode::SWITCHING_PROTOCOLS)
46            .header(header::CONNECTION, H_UPGRADE.clone())
47            .header(header::UPGRADE, H_WEBSOCKET.clone())
48            .header(header::SEC_WEBSOCKET_PROTOCOL, H_WS_PROTOCOL.clone())
49            .header(
50                header::SEC_WEBSOCKET_ACCEPT,
51                sign(self.sec_websocket_key.as_bytes()),
52            )
53            .body(axum::body::Body::empty())
54            .unwrap()
55    }
56}
57
58#[async_trait]
59impl<S> FromRequestParts<S> for SocketIoUpgrade
60where
61    S: Send + Sync,
62{
63    type Rejection = ();
64
65    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
66        if parts.method != Method::GET {
67            return Err(());
68        }
69        if !header_contains(&parts.headers, header::CONNECTION, "upgrade") {
70            return Err(());
71        }
72        if !header_eq(&parts.headers, header::UPGRADE, "websocket") {
73            return Err(());
74        }
75        if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") {
76            return Err(());
77        }
78        if !header_eq(
79            &parts.headers,
80            header::SEC_WEBSOCKET_PROTOCOL,
81            "websocket.io-rpc-v0.1",
82        ) {
83            return Err(());
84        }
85        Ok(Self {
86            sec_websocket_key: parts
87                .headers
88                .get(header::SEC_WEBSOCKET_KEY)
89                .ok_or(())?
90                .clone(),
91
92            on_upgrade: parts
93                .extensions
94                .remove::<hyper::upgrade::OnUpgrade>()
95                .ok_or(())?,
96        })
97    }
98}
99
100fn sign(key: &[u8]) -> HeaderValue {
101    use base64::engine::Engine as _;
102    use sha1::{Digest, Sha1};
103
104    let mut sha1 = Sha1::default();
105    sha1.update(key);
106    sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
107    let b64 = Bytes::from(base64::engine::general_purpose::STANDARD.encode(sha1.finalize()));
108    HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value")
109}
110
111fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
112    if let Some(header) = headers.get(&key) {
113        header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
114    } else {
115        false
116    }
117}
118
119fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
120    let header = if let Some(header) = headers.get(&key) {
121        header
122    } else {
123        return false;
124    };
125    if let Ok(header) = std::str::from_utf8(header.as_bytes()) {
126        header.to_ascii_lowercase().contains(value)
127    } else {
128        false
129    }
130}