// viz_core/types/websocket.rs
websocket.rs1use std::{borrow::Cow, future::Future, str};
4
5use hyper::upgrade::{OnUpgrade, Upgraded};
6use tokio_tungstenite::tungstenite::protocol::Role;
7
8use crate::{
9 Body, FromRequest, IntoResponse, Io, Request, RequestExt, Response, Result, StatusCode,
10 header::{SEC_WEBSOCKET_PROTOCOL, UPGRADE},
11 headers::{
12 Connection, HeaderMapExt, HeaderValue, SecWebsocketAccept, SecWebsocketKey,
13 SecWebsocketVersion, Upgrade,
14 },
15};
16
17mod error;
18
19pub use error::WebSocketError;
20pub use tokio_tungstenite::tungstenite::protocol::{Message, WebSocketConfig};
21
22pub type WebSocketStream<T = Io<Upgraded>> = tokio_tungstenite::WebSocketStream<T>;
24
25#[derive(Debug)]
30pub struct WebSocket {
31 key: SecWebsocketKey,
32 on_upgrade: Option<OnUpgrade>,
33 protocols: Option<Box<[Cow<'static, str>]>>,
34 sec_websocket_protocol: Option<HeaderValue>,
35}
36
37impl WebSocket {
38 const NAME: &'static [u8] = b"websocket";
39
40 #[must_use]
45 pub fn protocols<I>(mut self, protocols: I) -> Self
46 where
47 I: IntoIterator,
48 I::Item: Into<Cow<'static, str>>,
49 {
50 self.protocols = Some(
51 protocols
52 .into_iter()
53 .map(Into::into)
54 .collect::<Vec<_>>()
55 .into(),
56 );
57 self
58 }
59
60 #[must_use]
66 pub fn on_upgrade_with_config<F, Fut>(
67 mut self,
68 callback: F,
69 config: Option<WebSocketConfig>,
70 ) -> Response
71 where
72 F: FnOnce(WebSocketStream) -> Fut + Send + 'static,
73 Fut: Future<Output = ()> + Send + 'static,
74 {
75 let on_upgrade = self.on_upgrade.take().expect("missing OnUpgrade");
76
77 tokio::task::spawn(async move {
78 let Ok(upgraded) = on_upgrade.await else {
79 return;
80 };
81
82 let socket =
83 WebSocketStream::from_raw_socket(Io::new(upgraded), Role::Server, config).await;
84
85 (callback)(socket).await;
86 });
87
88 self.into_response()
89 }
90
91 pub fn on_upgrade<F, Fut>(self, callback: F) -> Response
93 where
94 F: FnOnce(WebSocketStream) -> Fut + Send + 'static,
95 Fut: Future<Output = ()> + Send + 'static,
96 {
97 self.on_upgrade_with_config(callback, None)
98 }
99}
100
101impl FromRequest for WebSocket {
102 type Error = WebSocketError;
103
104 async fn extract(req: &mut Request) -> Result<Self, Self::Error> {
105 req.header_typed::<Connection>()
107 .ok_or(WebSocketError::MissingConnectUpgrade)
108 .and_then(|h| {
109 if h.contains(UPGRADE) {
110 Ok(())
111 } else {
112 Err(WebSocketError::InvalidConnectUpgrade)
113 }
114 })?;
115
116 req.headers()
118 .get(UPGRADE)
119 .ok_or(WebSocketError::MissingUpgrade)
120 .and_then(|h| {
121 if h.as_bytes().eq_ignore_ascii_case(Self::NAME) {
122 Ok(())
123 } else {
124 Err(WebSocketError::InvalidUpgrade)
125 }
126 })?;
127
128 req.header_typed::<SecWebsocketVersion>()
130 .ok_or(WebSocketError::MissingWebSocketVersion)
131 .and_then(|h| {
132 if h == SecWebsocketVersion::V13 {
133 Ok(())
134 } else {
135 Err(WebSocketError::InvalidWebSocketVersion)
136 }
137 })?;
138
139 let key = req
140 .header_typed::<SecWebsocketKey>()
141 .ok_or(WebSocketError::MissingWebSocketKey)?;
142
143 let on_upgrade = req.extensions_mut().remove::<OnUpgrade>();
144
145 if on_upgrade.is_none() {
146 Err(WebSocketError::ConnectionNotUpgradable)?;
147 }
148
149 let sec_websocket_protocol = req.headers().get(SEC_WEBSOCKET_PROTOCOL).cloned();
150
151 Ok(Self {
152 key,
153 on_upgrade,
154 protocols: None,
155 sec_websocket_protocol,
156 })
157 }
158}
159
160impl IntoResponse for WebSocket {
161 fn into_response(self) -> Response {
162 let protocol = self
163 .sec_websocket_protocol
164 .as_ref()
165 .and_then(|req_protocols| {
166 let req_protocols = req_protocols.to_str().ok()?;
167 let protocols = self.protocols.as_ref()?;
168 req_protocols
169 .split(',')
170 .map(str::trim)
171 .find(|req_p| protocols.iter().any(|p| p == req_p))
172 .and_then(|v| HeaderValue::from_str(v).ok())
173 });
174
175 let mut res = Response::new(Body::Empty);
176
177 *res.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
178 res.headers_mut().typed_insert(Connection::upgrade());
179 res.headers_mut().typed_insert(Upgrade::websocket());
180 res.headers_mut()
181 .typed_insert(SecWebsocketAccept::from(self.key));
182
183 if let Some(protocol) = protocol {
184 res.headers_mut().insert(SEC_WEBSOCKET_PROTOCOL, protocol);
185 }
186
187 res
188 }
189}