1use crate::{WebSocketError, WebSocketStream, WsHeartbeatConfig};
4use bytes::Bytes;
5use http::{header, Response, StatusCode};
6use http_body_util::Full;
7use hyper::upgrade::OnUpgrade;
8use hyper_util::rt::TokioIo;
9use rustapi_core::IntoResponse;
10use rustapi_openapi::{Operation, ResponseModifier, ResponseSpec};
11use std::future::Future;
12use std::pin::Pin;
13use tokio_tungstenite::tungstenite::protocol::Role;
14
15type UpgradeCallback =
17 Box<dyn FnOnce(WebSocketStream) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send>;
18
19use crate::compression::WsCompressionConfig;
24
25pub struct WebSocketUpgrade {
30 response: Response<Full<Bytes>>,
32 on_upgrade: Option<UpgradeCallback>,
34 #[allow(dead_code)]
36 sec_key: String,
37 client_extensions: Option<String>,
39 compression: Option<WsCompressionConfig>,
41 heartbeat: Option<WsHeartbeatConfig>,
43 on_upgrade_fut: Option<OnUpgrade>,
45}
46
47impl WebSocketUpgrade {
48 pub(crate) fn new(
50 sec_key: String,
51 client_extensions: Option<String>,
52 on_upgrade_fut: Option<OnUpgrade>,
53 ) -> Self {
54 let accept_key = generate_accept_key(&sec_key);
56
57 let response = Response::builder()
59 .status(StatusCode::SWITCHING_PROTOCOLS)
60 .header(header::UPGRADE, "websocket")
61 .header(header::CONNECTION, "Upgrade")
62 .header("Sec-WebSocket-Accept", accept_key)
63 .body(Full::new(Bytes::new()))
64 .unwrap();
65
66 Self {
67 response,
68 on_upgrade: None,
69 sec_key,
70 client_extensions,
71 compression: None,
72 heartbeat: None,
73 on_upgrade_fut,
74 }
75 }
76
77 pub fn heartbeat(mut self, config: WsHeartbeatConfig) -> Self {
79 self.heartbeat = Some(config);
80 self
81 }
82
83 pub fn compress(mut self, config: WsCompressionConfig) -> Self {
85 self.compression = Some(config);
86
87 if let Some(exts) = &self.client_extensions {
89 if exts.contains("permessage-deflate") {
90 let mut header_val = String::from("permessage-deflate");
93
94 header_val.push_str("; server_no_context_takeover");
97 header_val.push_str("; client_no_context_takeover");
98
99 if config.window_bits < 15 {
100 header_val
101 .push_str(&format!("; server_max_window_bits={}", config.window_bits));
102 }
103 if config.client_window_bits < 15 {
104 header_val.push_str(&format!(
105 "; client_max_window_bits={}",
106 config.client_window_bits
107 ));
108 }
109
110 if let Ok(val) = header::HeaderValue::from_str(&header_val) {
111 self.response
112 .headers_mut()
113 .insert("Sec-WebSocket-Extensions", val);
114 }
115 }
116 }
117 self
118 }
119
120 pub fn on_upgrade<F, Fut>(mut self, callback: F) -> Self
133 where
134 F: FnOnce(WebSocketStream) -> Fut + Send + 'static,
135 Fut: Future<Output = ()> + Send + 'static,
136 {
137 self.on_upgrade = Some(Box::new(move |stream| Box::pin(callback(stream))));
138 self
139 }
140
141 pub fn protocol(mut self, protocol: &str) -> Self {
143 self.response.headers_mut().insert(
146 "Sec-WebSocket-Protocol",
147 header::HeaderValue::from_str(protocol).unwrap(),
148 );
149 self
150 }
151
152 #[allow(dead_code)]
154 pub(crate) fn into_response_inner(self) -> Response<Full<Bytes>> {
155 self.response
156 }
157
158 #[allow(dead_code)]
160 pub(crate) fn take_callback(&mut self) -> Option<UpgradeCallback> {
161 self.on_upgrade.take()
162 }
163}
164
165impl IntoResponse for WebSocketUpgrade {
166 fn into_response(mut self) -> http::Response<Full<Bytes>> {
167 if let (Some(on_upgrade), Some(callback)) =
169 (self.on_upgrade_fut.take(), self.on_upgrade.take())
170 {
171 let heartbeat = self.heartbeat;
172
173 tokio::spawn(async move {
177 match on_upgrade.await {
178 Ok(upgraded) => {
179 let ws_stream = tokio_tungstenite::WebSocketStream::from_raw_socket(
180 TokioIo::new(upgraded),
181 Role::Server,
182 None,
183 )
184 .await;
185
186 let socket = if let Some(hb_config) = heartbeat {
187 WebSocketStream::new_managed(ws_stream, hb_config)
188 } else {
189 WebSocketStream::new(ws_stream)
190 };
191
192 callback(socket).await;
193 }
194 Err(e) => {
195 tracing::error!("WebSocket upgrade failed: {:?}", e);
196 if let Some(source) = std::error::Error::source(&e) {
198 tracing::error!("Cause: {:?}", source);
199 }
200 }
201 }
202 });
203 }
204
205 self.response
206 }
207}
208
209impl ResponseModifier for WebSocketUpgrade {
210 fn update_response(op: &mut Operation) {
211 op.responses.insert(
212 "101".to_string(),
213 ResponseSpec {
214 description: "WebSocket upgrade successful".to_string(),
215 content: None,
216 },
217 );
218 }
219}
220
221fn generate_accept_key(key: &str) -> String {
223 use base64::Engine;
224 use sha1::{Digest, Sha1};
225
226 const GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
227
228 let mut hasher = Sha1::new();
229 hasher.update(key.as_bytes());
230 hasher.update(GUID.as_bytes());
231 let hash = hasher.finalize();
232
233 base64::engine::general_purpose::STANDARD.encode(hash)
234}
235
236pub(crate) fn validate_upgrade_request(
238 method: &http::Method,
239 headers: &http::HeaderMap,
240) -> Result<String, WebSocketError> {
241 if method != http::Method::GET {
243 return Err(WebSocketError::invalid_upgrade("Method must be GET"));
244 }
245
246 let upgrade = headers
248 .get(header::UPGRADE)
249 .and_then(|v| v.to_str().ok())
250 .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Upgrade header"))?;
251
252 if !upgrade.eq_ignore_ascii_case("websocket") {
253 return Err(WebSocketError::invalid_upgrade(
254 "Upgrade header must be 'websocket'",
255 ));
256 }
257
258 let connection = headers
260 .get(header::CONNECTION)
261 .and_then(|v| v.to_str().ok())
262 .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Connection header"))?;
263
264 let has_upgrade = connection
265 .split(',')
266 .any(|s| s.trim().eq_ignore_ascii_case("upgrade"));
267
268 if !has_upgrade {
269 return Err(WebSocketError::invalid_upgrade(
270 "Connection header must contain 'Upgrade'",
271 ));
272 }
273
274 let sec_key = headers
276 .get("Sec-WebSocket-Key")
277 .and_then(|v| v.to_str().ok())
278 .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Sec-WebSocket-Key header"))?;
279
280 let version = headers
282 .get("Sec-WebSocket-Version")
283 .and_then(|v| v.to_str().ok())
284 .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Sec-WebSocket-Version header"))?;
285
286 if version != "13" {
287 return Err(WebSocketError::invalid_upgrade(
288 "Sec-WebSocket-Version must be 13",
289 ));
290 }
291
292 Ok(sec_key.to_string())
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample key from RFC 6455 §1.3.
    const SAMPLE_KEY: &str = "dGhlIHNhbXBsZSBub25jZQ==";

    /// Header set of a minimal, valid RFC 6455 handshake request.
    fn handshake_headers() -> http::HeaderMap {
        let mut headers = http::HeaderMap::new();
        headers.insert(header::UPGRADE, header::HeaderValue::from_static("websocket"));
        headers.insert(header::CONNECTION, header::HeaderValue::from_static("Upgrade"));
        headers.insert(
            header::SEC_WEBSOCKET_KEY,
            header::HeaderValue::from_static(SAMPLE_KEY),
        );
        headers.insert(
            header::SEC_WEBSOCKET_VERSION,
            header::HeaderValue::from_static("13"),
        );
        headers
    }

    #[test]
    fn test_accept_key_generation() {
        let accept = generate_accept_key(SAMPLE_KEY);
        assert_eq!(accept, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
    }

    #[test]
    fn valid_handshake_returns_key() {
        let result = validate_upgrade_request(&http::Method::GET, &handshake_headers());
        assert_eq!(result.ok().as_deref(), Some(SAMPLE_KEY));
    }

    #[test]
    fn non_get_method_is_rejected() {
        let result = validate_upgrade_request(&http::Method::POST, &handshake_headers());
        assert!(result.is_err());
    }

    #[test]
    fn missing_upgrade_header_is_rejected() {
        let mut headers = handshake_headers();
        headers.remove(header::UPGRADE);
        assert!(validate_upgrade_request(&http::Method::GET, &headers).is_err());
    }

    #[test]
    fn upgrade_header_is_case_insensitive() {
        let mut headers = handshake_headers();
        headers.insert(header::UPGRADE, header::HeaderValue::from_static("WebSocket"));
        assert!(validate_upgrade_request(&http::Method::GET, &headers).is_ok());
    }

    #[test]
    fn connection_token_list_is_accepted() {
        let mut headers = handshake_headers();
        headers.insert(
            header::CONNECTION,
            header::HeaderValue::from_static("keep-alive, Upgrade"),
        );
        assert!(validate_upgrade_request(&http::Method::GET, &headers).is_ok());
    }

    #[test]
    fn connection_without_upgrade_is_rejected() {
        let mut headers = handshake_headers();
        headers.insert(header::CONNECTION, header::HeaderValue::from_static("keep-alive"));
        assert!(validate_upgrade_request(&http::Method::GET, &headers).is_err());
    }

    #[test]
    fn missing_key_is_rejected() {
        let mut headers = handshake_headers();
        headers.remove(header::SEC_WEBSOCKET_KEY);
        assert!(validate_upgrade_request(&http::Method::GET, &headers).is_err());
    }

    #[test]
    fn unsupported_version_is_rejected() {
        let mut headers = handshake_headers();
        headers.insert(
            header::SEC_WEBSOCKET_VERSION,
            header::HeaderValue::from_static("8"),
        );
        assert!(validate_upgrade_request(&http::Method::GET, &headers).is_err());
    }
}