1use crate::{WebSocketError, WebSocketStream, WsHeartbeatConfig};
4use http::{header, Response, StatusCode};
5use hyper::upgrade::OnUpgrade;
6use hyper_util::rt::TokioIo;
7use rustapi_core::{IntoResponse, ResponseBody};
8use rustapi_openapi::{Operation, ResponseModifier, ResponseSpec};
9use std::collections::BTreeMap;
10use std::future::Future;
11use std::pin::Pin;
12use tokio_tungstenite::tungstenite::protocol::Role;
13
14type UpgradeCallback =
16 Box<dyn FnOnce(WebSocketStream) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send>;
17
18use crate::compression::WsCompressionConfig;
23
24pub struct WebSocketUpgrade {
29 response: Response<ResponseBody>,
31 on_upgrade: Option<UpgradeCallback>,
33 #[allow(dead_code)]
35 sec_key: String,
36 client_extensions: Option<String>,
38 compression: Option<WsCompressionConfig>,
40 heartbeat: Option<WsHeartbeatConfig>,
42 on_upgrade_fut: Option<OnUpgrade>,
44}
45
46impl WebSocketUpgrade {
47 pub(crate) fn new(
49 sec_key: String,
50 client_extensions: Option<String>,
51 on_upgrade_fut: Option<OnUpgrade>,
52 ) -> Self {
53 let accept_key = generate_accept_key(&sec_key);
55
56 let response = Response::builder()
58 .status(StatusCode::SWITCHING_PROTOCOLS)
59 .header(header::UPGRADE, "websocket")
60 .header(header::CONNECTION, "Upgrade")
61 .header("Sec-WebSocket-Accept", accept_key)
62 .body(ResponseBody::empty())
63 .unwrap();
64
65 Self {
66 response,
67 on_upgrade: None,
68 sec_key,
69 client_extensions,
70 compression: None,
71 heartbeat: None,
72 on_upgrade_fut,
73 }
74 }
75
76 pub fn heartbeat(mut self, config: WsHeartbeatConfig) -> Self {
78 self.heartbeat = Some(config);
79 self
80 }
81
82 pub fn compress(mut self, config: WsCompressionConfig) -> Self {
84 self.compression = Some(config);
85
86 if let Some(exts) = &self.client_extensions {
87 if let Some(header_val) = negotiate_permessage_deflate(exts, config) {
88 if let Ok(val) = header::HeaderValue::from_str(&header_val) {
89 self.response
90 .headers_mut()
91 .insert("Sec-WebSocket-Extensions", val);
92 }
93 }
94 }
95 self
96 }
97
98 pub fn on_upgrade<F, Fut>(mut self, callback: F) -> Self
111 where
112 F: FnOnce(WebSocketStream) -> Fut + Send + 'static,
113 Fut: Future<Output = ()> + Send + 'static,
114 {
115 self.on_upgrade = Some(Box::new(move |stream| Box::pin(callback(stream))));
116 self
117 }
118
119 pub fn protocol(mut self, protocol: &str) -> Self {
121 self.response.headers_mut().insert(
124 "Sec-WebSocket-Protocol",
125 header::HeaderValue::from_str(protocol).unwrap(),
126 );
127 self
128 }
129
130 #[allow(dead_code)]
132 pub(crate) fn into_response_inner(self) -> Response<ResponseBody> {
133 self.response
134 }
135
136 #[allow(dead_code)]
138 pub(crate) fn take_callback(&mut self) -> Option<UpgradeCallback> {
139 self.on_upgrade.take()
140 }
141}
142
143impl IntoResponse for WebSocketUpgrade {
144 fn into_response(mut self) -> rustapi_core::Response {
145 if let (Some(on_upgrade), Some(callback)) =
147 (self.on_upgrade_fut.take(), self.on_upgrade.take())
148 {
149 let heartbeat = self.heartbeat;
150
151 tokio::spawn(async move {
155 match on_upgrade.await {
156 Ok(upgraded) => {
157 let ws_stream = tokio_tungstenite::WebSocketStream::from_raw_socket(
158 TokioIo::new(upgraded),
159 Role::Server,
160 None,
161 )
162 .await;
163
164 let socket = if let Some(hb_config) = heartbeat {
165 WebSocketStream::new_managed(ws_stream, hb_config)
166 } else {
167 WebSocketStream::new(ws_stream)
168 };
169
170 callback(socket).await;
171 }
172 Err(e) => {
173 tracing::error!("WebSocket upgrade failed: {:?}", e);
174 if let Some(source) = std::error::Error::source(&e) {
176 tracing::error!("Cause: {:?}", source);
177 }
178 }
179 }
180 });
181 }
182
183 self.response
184 }
185}
186
187impl ResponseModifier for WebSocketUpgrade {
188 fn update_response(op: &mut Operation) {
189 op.responses.insert(
190 "101".to_string(),
191 ResponseSpec {
192 description: "WebSocket upgrade successful".to_string(),
193 content: BTreeMap::new(),
194 headers: BTreeMap::new(),
195 },
196 );
197 }
198}
199
/// One comma-separated entry of an extensions header value, as produced by
/// `parse_extension_offers`.
#[derive(Debug)]
struct ParsedExtension {
    /// Extension token, ASCII-lowercased.
    name: String,
    /// `;`-separated parameters in header order: the trimmed, lowercased key and
    /// an optional value (trimmed, surrounding `"` characters stripped); `None`
    /// when the parameter had no `=`.
    params: Vec<(String, Option<String>)>,
}
205
/// Parameters of a single parsed `permessage-deflate` offer (RFC 7692).
#[derive(Debug, Default)]
struct PerMessageDeflateOffer {
    /// The `server_no_context_takeover` flag was present.
    server_no_context_takeover: bool,
    /// The `client_no_context_takeover` flag was present.
    client_no_context_takeover: bool,
    /// `None`: parameter absent; `Some(None)`: present without a value;
    /// `Some(Some(bits))`: present with a window size accepted by
    /// `parse_window_bits`.
    server_max_window_bits: Option<Option<u8>>,
    /// Same encoding as `server_max_window_bits`, for the client side.
    client_max_window_bits: Option<Option<u8>>,
}
213
214fn negotiate_permessage_deflate(
215 client_extensions: &str,
216 config: WsCompressionConfig,
217) -> Option<String> {
218 for ext in parse_extension_offers(client_extensions) {
219 if ext.name != "permessage-deflate" {
220 continue;
221 }
222
223 let Some(offer) = parse_permessage_deflate_offer(&ext) else {
224 continue;
225 };
226 let mut negotiated = vec!["permessage-deflate".to_string()];
227
228 if offer.server_no_context_takeover {
229 negotiated.push("server_no_context_takeover".to_string());
230 }
231 if offer.client_no_context_takeover {
232 negotiated.push("client_no_context_takeover".to_string());
233 }
234
235 if let Some(requested) = offer.server_max_window_bits {
236 let bits = requested
237 .map(|max| config.window_bits.min(max))
238 .unwrap_or(config.window_bits);
239 negotiated.push(format!("server_max_window_bits={}", bits));
240 }
241
242 if let Some(requested) = offer.client_max_window_bits {
243 let bits = requested
244 .map(|max| config.client_window_bits.min(max))
245 .unwrap_or(config.client_window_bits);
246 negotiated.push(format!("client_max_window_bits={}", bits));
247 }
248
249 return Some(negotiated.join("; "));
250 }
251
252 None
253}
254
255fn parse_extension_offers(header_value: &str) -> Vec<ParsedExtension> {
256 let mut offers = Vec::new();
257
258 for raw_extension in header_value.split(',') {
259 let mut parts = raw_extension
260 .split(';')
261 .map(|part| part.trim())
262 .filter(|part| !part.is_empty());
263
264 let Some(name) = parts.next() else {
265 continue;
266 };
267
268 let mut params = Vec::new();
269 for raw_param in parts {
270 let (key, value) = parse_extension_param(raw_param);
271 params.push((key, value));
272 }
273
274 offers.push(ParsedExtension {
275 name: name.to_ascii_lowercase(),
276 params,
277 });
278 }
279
280 offers
281}
282
/// Splits one extension parameter (`key` or `key=value`) at the first `=`.
///
/// The key is trimmed and ASCII-lowercased. The value, when present, is trimmed
/// and has surrounding `"` characters stripped.
fn parse_extension_param(raw_param: &str) -> (String, Option<String>) {
    let (key, value) = match raw_param.split_once('=') {
        Some((key, value)) => (key, Some(value.trim().trim_matches('"').to_owned())),
        None => (raw_param, None),
    };
    (key.trim().to_ascii_lowercase(), value)
}
291
292fn parse_permessage_deflate_offer(ext: &ParsedExtension) -> Option<PerMessageDeflateOffer> {
293 let mut offer = PerMessageDeflateOffer::default();
294
295 for (key, value) in &ext.params {
296 match key.as_str() {
297 "server_no_context_takeover" => {
298 if value.is_some() || offer.server_no_context_takeover {
299 return None;
300 }
301 offer.server_no_context_takeover = true;
302 }
303 "client_no_context_takeover" => {
304 if value.is_some() || offer.client_no_context_takeover {
305 return None;
306 }
307 offer.client_no_context_takeover = true;
308 }
309 "server_max_window_bits" => {
310 if offer.server_max_window_bits.is_some() {
311 return None;
312 }
313 let parsed = match value {
314 Some(v) => Some(parse_window_bits(v)?),
315 None => None,
316 };
317 offer.server_max_window_bits = Some(parsed);
318 }
319 "client_max_window_bits" => {
320 if offer.client_max_window_bits.is_some() {
321 return None;
322 }
323 let parsed = match value {
324 Some(v) => Some(parse_window_bits(v)?),
325 None => None,
326 };
327 offer.client_max_window_bits = Some(parsed);
328 }
329 _ => {
330 }
332 }
333 }
334
335 Some(offer)
336}
337
/// Parses a `*_max_window_bits` parameter value.
///
/// RFC 7692 §7.1.2 specifies a decimal integer without leading zeroes. Rust's
/// `u8::from_str` also accepts forms such as `"+12"` and `"012"`, so the
/// lexical form is checked first.
///
/// Only `9..=15` is accepted. RFC 7692 also allows 8; presumably 8 is declined
/// because common deflate backends do not support an 8-bit window — confirm.
/// Returns `None` for anything else.
fn parse_window_bits(value: &str) -> Option<u8> {
    let well_formed = !value.is_empty()
        && !value.starts_with('0')
        && value.bytes().all(|b| b.is_ascii_digit());
    if !well_formed {
        return None;
    }

    value
        .parse::<u8>()
        .ok()
        .filter(|bits| (9..=15).contains(bits))
}
346
347fn generate_accept_key(key: &str) -> String {
349 use base64::Engine;
350 use sha1::{Digest, Sha1};
351
352 const GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
353
354 let mut hasher = Sha1::new();
355 hasher.update(key.as_bytes());
356 hasher.update(GUID.as_bytes());
357 let hash = hasher.finalize();
358
359 base64::engine::general_purpose::STANDARD.encode(hash)
360}
361
362pub(crate) fn validate_upgrade_request(
364 method: &http::Method,
365 headers: &http::HeaderMap,
366) -> Result<String, WebSocketError> {
367 if method != http::Method::GET {
369 return Err(WebSocketError::invalid_upgrade("Method must be GET"));
370 }
371
372 let upgrade = headers
374 .get(header::UPGRADE)
375 .and_then(|v| v.to_str().ok())
376 .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Upgrade header"))?;
377
378 if !upgrade.eq_ignore_ascii_case("websocket") {
379 return Err(WebSocketError::invalid_upgrade(
380 "Upgrade header must be 'websocket'",
381 ));
382 }
383
384 let connection = headers
386 .get(header::CONNECTION)
387 .and_then(|v| v.to_str().ok())
388 .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Connection header"))?;
389
390 let has_upgrade = connection
391 .split(',')
392 .any(|s| s.trim().eq_ignore_ascii_case("upgrade"));
393
394 if !has_upgrade {
395 return Err(WebSocketError::invalid_upgrade(
396 "Connection header must contain 'Upgrade'",
397 ));
398 }
399
400 let sec_key = headers
402 .get("Sec-WebSocket-Key")
403 .and_then(|v| v.to_str().ok())
404 .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Sec-WebSocket-Key header"))?;
405
406 let version = headers
408 .get("Sec-WebSocket-Version")
409 .and_then(|v| v.to_str().ok())
410 .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Sec-WebSocket-Version header"))?;
411
412 if version != "13" {
413 return Err(WebSocketError::invalid_upgrade(
414 "Sec-WebSocket-Version must be 13",
415 ));
416 }
417
418 Ok(sec_key.to_string())
419}
420
#[cfg(test)]
mod tests {
    use super::*;
    use crate::WsCompressionConfig;

    /// Headers of a minimal, well-formed RFC 6455 client handshake.
    fn valid_handshake_headers() -> http::HeaderMap {
        let mut headers = http::HeaderMap::new();
        headers.insert(header::UPGRADE, header::HeaderValue::from_static("websocket"));
        headers.insert(
            header::CONNECTION,
            header::HeaderValue::from_static("keep-alive, Upgrade"),
        );
        headers.insert(
            header::SEC_WEBSOCKET_KEY,
            header::HeaderValue::from_static("dGhlIHNhbXBsZSBub25jZQ=="),
        );
        headers.insert(
            header::SEC_WEBSOCKET_VERSION,
            header::HeaderValue::from_static("13"),
        );
        headers
    }

    #[test]
    fn test_accept_key_generation() {
        // Example from RFC 6455 §1.3.
        let key = "dGhlIHNhbXBsZSBub25jZQ==";
        let accept = generate_accept_key(key);
        assert_eq!(accept, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
    }

    #[test]
    fn test_validate_upgrade_request_accepts_valid_handshake() {
        let key = validate_upgrade_request(&http::Method::GET, &valid_handshake_headers()).ok();
        assert_eq!(key.as_deref(), Some("dGhlIHNhbXBsZSBub25jZQ=="));
    }

    #[test]
    fn test_validate_upgrade_request_rejects_bad_method_and_version() {
        assert!(validate_upgrade_request(&http::Method::POST, &valid_handshake_headers()).is_err());

        let mut headers = valid_handshake_headers();
        headers.insert(
            header::SEC_WEBSOCKET_VERSION,
            header::HeaderValue::from_static("8"),
        );
        assert!(validate_upgrade_request(&http::Method::GET, &headers).is_err());

        let mut headers = valid_handshake_headers();
        headers.remove(header::SEC_WEBSOCKET_KEY);
        assert!(validate_upgrade_request(&http::Method::GET, &headers).is_err());
    }

    #[test]
    fn test_window_bits_range() {
        assert_eq!(parse_window_bits("9"), Some(9));
        assert_eq!(parse_window_bits("15"), Some(15));
        assert_eq!(parse_window_bits("8"), None);
        assert_eq!(parse_window_bits("16"), None);
        assert_eq!(parse_window_bits("abc"), None);
    }

    #[test]
    fn test_extension_names_and_params_are_case_insensitive() {
        let offers = parse_extension_offers("PerMessage-Deflate; Client_Max_Window_Bits=\"10\"");
        assert_eq!(offers.len(), 1);
        assert_eq!(offers[0].name, "permessage-deflate");
        assert_eq!(
            offers[0].params,
            vec![(
                String::from("client_max_window_bits"),
                Some(String::from("10"))
            )]
        );
    }

    #[test]
    fn test_permessage_deflate_negotiates_context_takeover_and_window_bits() {
        let config = WsCompressionConfig::new()
            .window_bits(13)
            .client_window_bits(10);

        let negotiated = negotiate_permessage_deflate(
            "permessage-deflate; server_no_context_takeover; client_no_context_takeover; server_max_window_bits=12; client_max_window_bits",
            config,
        )
        .expect("expected successful negotiation");

        assert!(negotiated.contains("permessage-deflate"));
        assert!(negotiated.contains("server_no_context_takeover"));
        assert!(negotiated.contains("client_no_context_takeover"));
        assert!(negotiated.contains("server_max_window_bits=12"));
        assert!(negotiated.contains("client_max_window_bits=10"));
    }

    #[test]
    fn test_permessage_deflate_skips_invalid_offer_and_uses_next_offer() {
        let config = WsCompressionConfig::new()
            .window_bits(11)
            .client_window_bits(11);

        let negotiated = negotiate_permessage_deflate(
            "permessage-deflate; server_max_window_bits=7, permessage-deflate; client_max_window_bits",
            config,
        )
        .expect("expected fallback to second valid offer");

        assert!(negotiated.contains("permessage-deflate"));
        assert!(negotiated.contains("client_max_window_bits=11"));
        assert!(!negotiated.contains("server_max_window_bits=7"));
    }

    #[test]
    fn test_permessage_deflate_rejects_duplicate_parameters() {
        let negotiated = negotiate_permessage_deflate(
            "permessage-deflate; server_no_context_takeover; server_no_context_takeover",
            WsCompressionConfig::default(),
        );
        assert!(negotiated.is_none());
    }

    #[test]
    fn test_permessage_deflate_returns_none_when_not_offered() {
        let config = WsCompressionConfig::default();
        let negotiated = negotiate_permessage_deflate("x-webkit-deflate-frame", config);
        assert!(negotiated.is_none());
    }
}