1#![warn(missing_docs)]
2#![doc = include_str!("../README.md")]
3
4use axum::{
5 async_trait,
6 body::Bytes,
7 extract::FromRequestParts,
8 http::{header, request::Parts, HeaderMap, HeaderName, HeaderValue, Method, StatusCode},
9};
10use hyper_util::rt::TokioIo;
11use std::future::Future;
12
13pub use web_socket_io::*;
14
15pub struct SocketIoUpgrade {
17 sec_websocket_key: HeaderValue,
18 on_upgrade: hyper::upgrade::OnUpgrade,
19}
20
21impl SocketIoUpgrade {
22 pub fn on_upgrade<C, Fut>(self, buffer: usize, callback: C) -> axum::response::Response
29 where
30 C: FnOnce(SocketIo) -> Fut + Send + 'static,
31 Fut: Future<Output = ()> + Send + 'static,
32 {
33 tokio::spawn(async move {
34 if let Ok(upgraded) = self.on_upgrade.await {
35 let (reader, writer) = tokio::io::split(TokioIo::new(upgraded));
36 callback(SocketIo::new(reader, writer, buffer)).await;
37 }
38 });
39
40 static H_UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
41 static H_WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
42 static H_WS_PROTOCOL: HeaderValue = HeaderValue::from_static("websocket.io-rpc-v0.1");
43
44 axum::response::Response::builder()
45 .status(StatusCode::SWITCHING_PROTOCOLS)
46 .header(header::CONNECTION, H_UPGRADE.clone())
47 .header(header::UPGRADE, H_WEBSOCKET.clone())
48 .header(header::SEC_WEBSOCKET_PROTOCOL, H_WS_PROTOCOL.clone())
49 .header(
50 header::SEC_WEBSOCKET_ACCEPT,
51 sign(self.sec_websocket_key.as_bytes()),
52 )
53 .body(axum::body::Body::empty())
54 .unwrap()
55 }
56}
57
58#[async_trait]
59impl<S> FromRequestParts<S> for SocketIoUpgrade
60where
61 S: Send + Sync,
62{
63 type Rejection = ();
64
65 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
66 if parts.method != Method::GET {
67 return Err(());
68 }
69 if !header_contains(&parts.headers, header::CONNECTION, "upgrade") {
70 return Err(());
71 }
72 if !header_eq(&parts.headers, header::UPGRADE, "websocket") {
73 return Err(());
74 }
75 if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") {
76 return Err(());
77 }
78 if !header_eq(
79 &parts.headers,
80 header::SEC_WEBSOCKET_PROTOCOL,
81 "websocket.io-rpc-v0.1",
82 ) {
83 return Err(());
84 }
85 Ok(Self {
86 sec_websocket_key: parts
87 .headers
88 .get(header::SEC_WEBSOCKET_KEY)
89 .ok_or(())?
90 .clone(),
91
92 on_upgrade: parts
93 .extensions
94 .remove::<hyper::upgrade::OnUpgrade>()
95 .ok_or(())?,
96 })
97 }
98}
99
100fn sign(key: &[u8]) -> HeaderValue {
101 use base64::engine::Engine as _;
102 use sha1::{Digest, Sha1};
103
104 let mut sha1 = Sha1::default();
105 sha1.update(key);
106 sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
107 let b64 = Bytes::from(base64::engine::general_purpose::STANDARD.encode(sha1.finalize()));
108 HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value")
109}
110
111fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
112 if let Some(header) = headers.get(&key) {
113 header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
114 } else {
115 false
116 }
117}
118
119fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
120 let header = if let Some(header) = headers.get(&key) {
121 header
122 } else {
123 return false;
124 };
125 if let Ok(header) = std::str::from_utf8(header.as_bytes()) {
126 header.to_ascii_lowercase().contains(value)
127 } else {
128 false
129 }
130}