Skip to main content

rustapi_ws/
upgrade.rs

1//! WebSocket upgrade response
2
3use crate::{WebSocketError, WebSocketStream, WsHeartbeatConfig};
4use http::{header, Response, StatusCode};
5use hyper::upgrade::OnUpgrade;
6use hyper_util::rt::TokioIo;
7use rustapi_core::{IntoResponse, ResponseBody};
8use rustapi_openapi::{Operation, ResponseModifier, ResponseSpec};
9use std::collections::BTreeMap;
10use std::future::Future;
11use std::pin::Pin;
12use tokio_tungstenite::tungstenite::protocol::Role;
13
14/// Type alias for WebSocket upgrade callback
15type UpgradeCallback =
16    Box<dyn FnOnce(WebSocketStream) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send>;
17
// Compression settings consumed by `WebSocketUpgrade::compress`.
22use crate::compression::WsCompressionConfig;
23
24/// WebSocket upgrade response
25///
26/// This type is returned from WebSocket handlers to initiate the upgrade
27/// handshake and establish a WebSocket connection.
28pub struct WebSocketUpgrade {
29    /// The upgrade response
30    response: Response<ResponseBody>,
31    /// Callback to handle the WebSocket connection
32    on_upgrade: Option<UpgradeCallback>,
33    /// SEC-WebSocket-Key from request
34    #[allow(dead_code)]
35    sec_key: String,
36    /// Client requested extensions
37    client_extensions: Option<String>,
38    /// Configured compression
39    compression: Option<WsCompressionConfig>,
40    /// Configured heartbeat
41    heartbeat: Option<WsHeartbeatConfig>,
42    /// OnUpgrade future from hyper
43    on_upgrade_fut: Option<OnUpgrade>,
44}
45
46impl WebSocketUpgrade {
47    /// Create a new WebSocket upgrade from request headers
48    pub(crate) fn new(
49        sec_key: String,
50        client_extensions: Option<String>,
51        on_upgrade_fut: Option<OnUpgrade>,
52    ) -> Self {
53        // Generate accept key
54        let accept_key = generate_accept_key(&sec_key);
55
56        // Build upgrade response
57        let response = Response::builder()
58            .status(StatusCode::SWITCHING_PROTOCOLS)
59            .header(header::UPGRADE, "websocket")
60            .header(header::CONNECTION, "Upgrade")
61            .header("Sec-WebSocket-Accept", accept_key)
62            .body(ResponseBody::empty())
63            .unwrap();
64
65        Self {
66            response,
67            on_upgrade: None,
68            sec_key,
69            client_extensions,
70            compression: None,
71            heartbeat: None,
72            on_upgrade_fut,
73        }
74    }
75
76    /// Enable WebSocket heartbeat
77    pub fn heartbeat(mut self, config: WsHeartbeatConfig) -> Self {
78        self.heartbeat = Some(config);
79        self
80    }
81
82    /// Enable WebSocket compression
83    pub fn compress(mut self, config: WsCompressionConfig) -> Self {
84        self.compression = Some(config);
85
86        if let Some(exts) = &self.client_extensions {
87            if let Some(header_val) = negotiate_permessage_deflate(exts, config) {
88                if let Ok(val) = header::HeaderValue::from_str(&header_val) {
89                    self.response
90                        .headers_mut()
91                        .insert("Sec-WebSocket-Extensions", val);
92                }
93            }
94        }
95        self
96    }
97
98    /// Set the callback to handle the upgraded WebSocket connection
99    ///
100    /// # Example
101    ///
102    /// ```rust,ignore
103    /// ws.on_upgrade(|socket| async move {
104    ///     let (mut sender, mut receiver) = socket.split();
105    ///     while let Some(msg) = receiver.next().await {
106    ///         // Handle messages...
107    ///     }
108    /// })
109    /// ```
110    pub fn on_upgrade<F, Fut>(mut self, callback: F) -> Self
111    where
112        F: FnOnce(WebSocketStream) -> Fut + Send + 'static,
113        Fut: Future<Output = ()> + Send + 'static,
114    {
115        self.on_upgrade = Some(Box::new(move |stream| Box::pin(callback(stream))));
116        self
117    }
118
119    /// Add a protocol to the response
120    pub fn protocol(mut self, protocol: &str) -> Self {
121        // Rebuild response to keep headers clean (or just insert)
122        // More efficient to just insert
123        self.response.headers_mut().insert(
124            "Sec-WebSocket-Protocol",
125            header::HeaderValue::from_str(protocol).unwrap(),
126        );
127        self
128    }
129
130    /// Get the underlying response (for implementing IntoResponse)
131    #[allow(dead_code)]
132    pub(crate) fn into_response_inner(self) -> Response<ResponseBody> {
133        self.response
134    }
135
136    /// Get the on_upgrade callback
137    #[allow(dead_code)]
138    pub(crate) fn take_callback(&mut self) -> Option<UpgradeCallback> {
139        self.on_upgrade.take()
140    }
141}
142
143impl IntoResponse for WebSocketUpgrade {
144    fn into_response(mut self) -> rustapi_core::Response {
145        // If we have the upgrade future and a callback, spawn the upgrade task
146        if let (Some(on_upgrade), Some(callback)) =
147            (self.on_upgrade_fut.take(), self.on_upgrade.take())
148        {
149            let heartbeat = self.heartbeat;
150
151            // TODO: Apply compression config to WebSocketConfig if/when supported by from_raw_socket
152            // Currently tungstenite negotiation logic in handshake is separate from stream config
153
154            tokio::spawn(async move {
155                match on_upgrade.await {
156                    Ok(upgraded) => {
157                        let ws_stream = tokio_tungstenite::WebSocketStream::from_raw_socket(
158                            TokioIo::new(upgraded),
159                            Role::Server,
160                            None,
161                        )
162                        .await;
163
164                        let socket = if let Some(hb_config) = heartbeat {
165                            WebSocketStream::new_managed(ws_stream, hb_config)
166                        } else {
167                            WebSocketStream::new(ws_stream)
168                        };
169
170                        callback(socket).await;
171                    }
172                    Err(e) => {
173                        tracing::error!("WebSocket upgrade failed: {:?}", e);
174                        // Also try to print the source if available
175                        if let Some(source) = std::error::Error::source(&e) {
176                            tracing::error!("Cause: {:?}", source);
177                        }
178                    }
179                }
180            });
181        }
182
183        self.response
184    }
185}
186
187impl ResponseModifier for WebSocketUpgrade {
188    fn update_response(op: &mut Operation) {
189        op.responses.insert(
190            "101".to_string(),
191            ResponseSpec {
192                description: "WebSocket upgrade successful".to_string(),
193                content: BTreeMap::new(),
194                headers: BTreeMap::new(),
195            },
196        );
197    }
198}
199
/// One comma-separated entry of a `Sec-WebSocket-Extensions` header value.
#[derive(Debug)]
struct ParsedExtension {
    /// Extension token (lower-cased by the parser), e.g. `permessage-deflate`.
    name: String,
    /// `;`-separated parameters in header order, as `(key, value)` pairs.
    /// The value is `None` for bare flags with no `=`.
    params: Vec<(String, Option<String>)>,
}
205
/// Validated parameters of a single client `permessage-deflate` offer.
///
/// Each window-bits field has three states:
/// * `None` – the parameter was absent;
/// * `Some(None)` – it was sent without a value;
/// * `Some(Some(bits))` – it carried an explicit value.
#[derive(Debug, Default)]
struct PerMessageDeflateOffer {
    /// `server_no_context_takeover` flag was present.
    server_no_context_takeover: bool,
    /// `client_no_context_takeover` flag was present.
    client_no_context_takeover: bool,
    /// State of the `server_max_window_bits` parameter (see type docs).
    server_max_window_bits: Option<Option<u8>>,
    /// State of the `client_max_window_bits` parameter (see type docs).
    client_max_window_bits: Option<Option<u8>>,
}
213
214fn negotiate_permessage_deflate(
215    client_extensions: &str,
216    config: WsCompressionConfig,
217) -> Option<String> {
218    for ext in parse_extension_offers(client_extensions) {
219        if ext.name != "permessage-deflate" {
220            continue;
221        }
222
223        let Some(offer) = parse_permessage_deflate_offer(&ext) else {
224            continue;
225        };
226        let mut negotiated = vec!["permessage-deflate".to_string()];
227
228        if offer.server_no_context_takeover {
229            negotiated.push("server_no_context_takeover".to_string());
230        }
231        if offer.client_no_context_takeover {
232            negotiated.push("client_no_context_takeover".to_string());
233        }
234
235        if let Some(requested) = offer.server_max_window_bits {
236            let bits = requested
237                .map(|max| config.window_bits.min(max))
238                .unwrap_or(config.window_bits);
239            negotiated.push(format!("server_max_window_bits={}", bits));
240        }
241
242        if let Some(requested) = offer.client_max_window_bits {
243            let bits = requested
244                .map(|max| config.client_window_bits.min(max))
245                .unwrap_or(config.client_window_bits);
246            negotiated.push(format!("client_max_window_bits={}", bits));
247        }
248
249        return Some(negotiated.join("; "));
250    }
251
252    None
253}
254
255fn parse_extension_offers(header_value: &str) -> Vec<ParsedExtension> {
256    let mut offers = Vec::new();
257
258    for raw_extension in header_value.split(',') {
259        let mut parts = raw_extension
260            .split(';')
261            .map(|part| part.trim())
262            .filter(|part| !part.is_empty());
263
264        let Some(name) = parts.next() else {
265            continue;
266        };
267
268        let mut params = Vec::new();
269        for raw_param in parts {
270            let (key, value) = parse_extension_param(raw_param);
271            params.push((key, value));
272        }
273
274        offers.push(ParsedExtension {
275            name: name.to_ascii_lowercase(),
276            params,
277        });
278    }
279
280    offers
281}
282
/// Split one extension parameter (`key` or `key=value`) into a lower-cased key
/// and an optional value.
///
/// Whitespace around the key and the value is trimmed. A value wrapped in exactly
/// one pair of double quotes (the quoted-string form) is unwrapped. Unbalanced or
/// repeated quotes are kept as-is, so later validation rejects malformed values
/// such as `"10` instead of silently accepting them. Backslash escapes inside
/// quoted strings are not decoded.
fn parse_extension_param(raw_param: &str) -> (String, Option<String>) {
    match raw_param.split_once('=') {
        Some((key, value)) => {
            let value = value.trim();
            // Strip a single balanced pair of quotes only.
            let value = value
                .strip_prefix('"')
                .and_then(|inner| inner.strip_suffix('"'))
                .unwrap_or(value);
            (key.trim().to_ascii_lowercase(), Some(value.to_string()))
        }
        None => (raw_param.trim().to_ascii_lowercase(), None),
    }
}
291
292fn parse_permessage_deflate_offer(ext: &ParsedExtension) -> Option<PerMessageDeflateOffer> {
293    let mut offer = PerMessageDeflateOffer::default();
294
295    for (key, value) in &ext.params {
296        match key.as_str() {
297            "server_no_context_takeover" => {
298                if value.is_some() || offer.server_no_context_takeover {
299                    return None;
300                }
301                offer.server_no_context_takeover = true;
302            }
303            "client_no_context_takeover" => {
304                if value.is_some() || offer.client_no_context_takeover {
305                    return None;
306                }
307                offer.client_no_context_takeover = true;
308            }
309            "server_max_window_bits" => {
310                if offer.server_max_window_bits.is_some() {
311                    return None;
312                }
313                let parsed = match value {
314                    Some(v) => Some(parse_window_bits(v)?),
315                    None => None,
316                };
317                offer.server_max_window_bits = Some(parsed);
318            }
319            "client_max_window_bits" => {
320                if offer.client_max_window_bits.is_some() {
321                    return None;
322                }
323                let parsed = match value {
324                    Some(v) => Some(parse_window_bits(v)?),
325                    None => None,
326                };
327                offer.client_max_window_bits = Some(parsed);
328            }
329            _ => {
330                // Ignore unknown permessage-deflate params for compatibility.
331            }
332        }
333    }
334
335    Some(offer)
336}
337
/// Parse a `*_max_window_bits` parameter value.
///
/// RFC 7692 requires a plain decimal integer without leading zeros. Input with a
/// sign (`+12`), leading zeros (`012`) or any non-digit character is rejected;
/// `str::parse::<u8>` alone would accept the first two forms.
///
/// Only 9..=15 is accepted. The RFC also permits 8; presumably it is excluded
/// because zlib's raw deflate does not support an 8-bit window (TODO confirm).
fn parse_window_bits(value: &str) -> Option<u8> {
    let is_plain_decimal = !value.is_empty()
        && !value.starts_with('0')
        && value.bytes().all(|b| b.is_ascii_digit());
    if !is_plain_decimal {
        return None;
    }

    value
        .parse::<u8>()
        .ok()
        .filter(|bits| (9..=15).contains(bits))
}
346
347/// Generate the Sec-WebSocket-Accept key from the client's Sec-WebSocket-Key
348fn generate_accept_key(key: &str) -> String {
349    use base64::Engine;
350    use sha1::{Digest, Sha1};
351
352    const GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
353
354    let mut hasher = Sha1::new();
355    hasher.update(key.as_bytes());
356    hasher.update(GUID.as_bytes());
357    let hash = hasher.finalize();
358
359    base64::engine::general_purpose::STANDARD.encode(hash)
360}
361
362/// Validate that a request is a valid WebSocket upgrade request
363pub(crate) fn validate_upgrade_request(
364    method: &http::Method,
365    headers: &http::HeaderMap,
366) -> Result<String, WebSocketError> {
367    // Must be GET
368    if method != http::Method::GET {
369        return Err(WebSocketError::invalid_upgrade("Method must be GET"));
370    }
371
372    // Must have Upgrade: websocket header
373    let upgrade = headers
374        .get(header::UPGRADE)
375        .and_then(|v| v.to_str().ok())
376        .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Upgrade header"))?;
377
378    if !upgrade.eq_ignore_ascii_case("websocket") {
379        return Err(WebSocketError::invalid_upgrade(
380            "Upgrade header must be 'websocket'",
381        ));
382    }
383
384    // Must have Connection: Upgrade header
385    let connection = headers
386        .get(header::CONNECTION)
387        .and_then(|v| v.to_str().ok())
388        .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Connection header"))?;
389
390    let has_upgrade = connection
391        .split(',')
392        .any(|s| s.trim().eq_ignore_ascii_case("upgrade"));
393
394    if !has_upgrade {
395        return Err(WebSocketError::invalid_upgrade(
396            "Connection header must contain 'Upgrade'",
397        ));
398    }
399
400    // Must have Sec-WebSocket-Key header
401    let sec_key = headers
402        .get("Sec-WebSocket-Key")
403        .and_then(|v| v.to_str().ok())
404        .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Sec-WebSocket-Key header"))?;
405
406    // Must have Sec-WebSocket-Version: 13
407    let version = headers
408        .get("Sec-WebSocket-Version")
409        .and_then(|v| v.to_str().ok())
410        .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Sec-WebSocket-Version header"))?;
411
412    if version != "13" {
413        return Err(WebSocketError::invalid_upgrade(
414            "Sec-WebSocket-Version must be 13",
415        ));
416    }
417
418    Ok(sec_key.to_string())
419}
420
#[cfg(test)]
mod tests {
    use super::*;
    use crate::WsCompressionConfig;

    #[test]
    fn test_accept_key_generation() {
        // Sample handshake values from RFC 6455.
        assert_eq!(
            generate_accept_key("dGhlIHNhbXBsZSBub25jZQ=="),
            "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
        );
    }

    #[test]
    fn test_permessage_deflate_negotiates_context_takeover_and_window_bits() {
        let offer = "permessage-deflate; server_no_context_takeover; client_no_context_takeover; server_max_window_bits=12; client_max_window_bits";
        let config = WsCompressionConfig::new()
            .window_bits(13)
            .client_window_bits(10);

        let response = negotiate_permessage_deflate(offer, config)
            .expect("expected successful negotiation");

        for expected in [
            "permessage-deflate",
            "server_no_context_takeover",
            "client_no_context_takeover",
            "server_max_window_bits=12",
            "client_max_window_bits=10",
        ] {
            assert!(
                response.contains(expected),
                "missing `{expected}` in `{response}`"
            );
        }
    }

    #[test]
    fn test_permessage_deflate_skips_invalid_offer_and_uses_next_offer() {
        // The first offer carries an out-of-range window size and must be skipped.
        let offers = "permessage-deflate; server_max_window_bits=7, permessage-deflate; client_max_window_bits";
        let config = WsCompressionConfig::new()
            .window_bits(11)
            .client_window_bits(11);

        let response = negotiate_permessage_deflate(offers, config)
            .expect("expected fallback to second valid offer");

        assert!(response.contains("permessage-deflate"));
        assert!(response.contains("client_max_window_bits=11"));
        assert!(!response.contains("server_max_window_bits=7"));
    }

    #[test]
    fn test_permessage_deflate_returns_none_when_not_offered() {
        let response =
            negotiate_permessage_deflate("x-webkit-deflate-frame", WsCompressionConfig::default());
        assert!(response.is_none());
    }
}