Skip to main content

rustapi_ws/
upgrade.rs

1//! WebSocket upgrade response
2
3use crate::{WebSocketError, WebSocketStream, WsHeartbeatConfig};
4use http::{header, Response, StatusCode};
5use hyper::upgrade::OnUpgrade;
6use hyper_util::rt::TokioIo;
7use rustapi_core::{IntoResponse, ResponseBody};
8use rustapi_openapi::{Operation, ResponseModifier, ResponseSpec};
9use std::collections::BTreeMap;
10use std::future::Future;
11use std::pin::Pin;
12use tokio_tungstenite::tungstenite::protocol::Role;
13
14/// Type alias for WebSocket upgrade callback
15type UpgradeCallback =
16    Box<dyn FnOnce(WebSocketStream) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send>;
17
18/// WebSocket upgrade response
19///
20/// This type is returned from WebSocket handlers to initiate the upgrade
21/// handshake and establish a WebSocket connection.
22use crate::compression::WsCompressionConfig;
23
24/// WebSocket upgrade response
25///
26/// This type is returned from WebSocket handlers to initiate the upgrade
27/// handshake and establish a WebSocket connection.
28pub struct WebSocketUpgrade {
29    /// The upgrade response
30    response: Response<ResponseBody>,
31    /// Callback to handle the WebSocket connection
32    on_upgrade: Option<UpgradeCallback>,
33    /// SEC-WebSocket-Key from request
34    #[allow(dead_code)]
35    sec_key: String,
36    /// Client requested extensions
37    client_extensions: Option<String>,
38    /// Configured compression
39    compression: Option<WsCompressionConfig>,
40    /// Configured heartbeat
41    heartbeat: Option<WsHeartbeatConfig>,
42    /// OnUpgrade future from hyper
43    on_upgrade_fut: Option<OnUpgrade>,
44}
45
46impl WebSocketUpgrade {
47    /// Create a new WebSocket upgrade from request headers
48    pub(crate) fn new(
49        sec_key: String,
50        client_extensions: Option<String>,
51        on_upgrade_fut: Option<OnUpgrade>,
52    ) -> Self {
53        // Generate accept key
54        let accept_key = generate_accept_key(&sec_key);
55
56        // Build upgrade response
57        let response = Response::builder()
58            .status(StatusCode::SWITCHING_PROTOCOLS)
59            .header(header::UPGRADE, "websocket")
60            .header(header::CONNECTION, "Upgrade")
61            .header("Sec-WebSocket-Accept", accept_key)
62            .body(ResponseBody::empty())
63            .unwrap();
64
65        Self {
66            response,
67            on_upgrade: None,
68            sec_key,
69            client_extensions,
70            compression: None,
71            heartbeat: None,
72            on_upgrade_fut,
73        }
74    }
75
76    /// Enable WebSocket heartbeat
77    pub fn heartbeat(mut self, config: WsHeartbeatConfig) -> Self {
78        self.heartbeat = Some(config);
79        self
80    }
81
82    /// Enable WebSocket compression
83    pub fn compress(mut self, config: WsCompressionConfig) -> Self {
84        self.compression = Some(config);
85
86        // Simple negotiation: if client supports it, we enable it
87        if let Some(exts) = &self.client_extensions {
88            if exts.contains("permessage-deflate") {
89                // We currently use a simple negotiation strategy
90                // TODO: Parse parameters and negotiate window bits
91                let mut header_val = String::from("permessage-deflate");
92
93                // Add server/client_no_context_takeover to reduce memory usage at cost of compression ratio
94                // This is a common default for many servers
95                header_val.push_str("; server_no_context_takeover");
96                header_val.push_str("; client_no_context_takeover");
97
98                if config.window_bits < 15 {
99                    header_val
100                        .push_str(&format!("; server_max_window_bits={}", config.window_bits));
101                }
102                if config.client_window_bits < 15 {
103                    header_val.push_str(&format!(
104                        "; client_max_window_bits={}",
105                        config.client_window_bits
106                    ));
107                }
108
109                if let Ok(val) = header::HeaderValue::from_str(&header_val) {
110                    self.response
111                        .headers_mut()
112                        .insert("Sec-WebSocket-Extensions", val);
113                }
114            }
115        }
116        self
117    }
118
119    /// Set the callback to handle the upgraded WebSocket connection
120    ///
121    /// # Example
122    ///
123    /// ```rust,ignore
124    /// ws.on_upgrade(|socket| async move {
125    ///     let (mut sender, mut receiver) = socket.split();
126    ///     while let Some(msg) = receiver.next().await {
127    ///         // Handle messages...
128    ///     }
129    /// })
130    /// ```
131    pub fn on_upgrade<F, Fut>(mut self, callback: F) -> Self
132    where
133        F: FnOnce(WebSocketStream) -> Fut + Send + 'static,
134        Fut: Future<Output = ()> + Send + 'static,
135    {
136        self.on_upgrade = Some(Box::new(move |stream| Box::pin(callback(stream))));
137        self
138    }
139
140    /// Add a protocol to the response
141    pub fn protocol(mut self, protocol: &str) -> Self {
142        // Rebuild response to keep headers clean (or just insert)
143        // More efficient to just insert
144        self.response.headers_mut().insert(
145            "Sec-WebSocket-Protocol",
146            header::HeaderValue::from_str(protocol).unwrap(),
147        );
148        self
149    }
150
151    /// Get the underlying response (for implementing IntoResponse)
152    #[allow(dead_code)]
153    pub(crate) fn into_response_inner(self) -> Response<ResponseBody> {
154        self.response
155    }
156
157    /// Get the on_upgrade callback
158    #[allow(dead_code)]
159    pub(crate) fn take_callback(&mut self) -> Option<UpgradeCallback> {
160        self.on_upgrade.take()
161    }
162}
163
164impl IntoResponse for WebSocketUpgrade {
165    fn into_response(mut self) -> rustapi_core::Response {
166        // If we have the upgrade future and a callback, spawn the upgrade task
167        if let (Some(on_upgrade), Some(callback)) =
168            (self.on_upgrade_fut.take(), self.on_upgrade.take())
169        {
170            let heartbeat = self.heartbeat;
171
172            // TODO: Apply compression config to WebSocketConfig if/when supported by from_raw_socket
173            // Currently tungstenite negotiation logic in handshake is separate from stream config
174
175            tokio::spawn(async move {
176                match on_upgrade.await {
177                    Ok(upgraded) => {
178                        let ws_stream = tokio_tungstenite::WebSocketStream::from_raw_socket(
179                            TokioIo::new(upgraded),
180                            Role::Server,
181                            None,
182                        )
183                        .await;
184
185                        let socket = if let Some(hb_config) = heartbeat {
186                            WebSocketStream::new_managed(ws_stream, hb_config)
187                        } else {
188                            WebSocketStream::new(ws_stream)
189                        };
190
191                        callback(socket).await;
192                    }
193                    Err(e) => {
194                        tracing::error!("WebSocket upgrade failed: {:?}", e);
195                        // Also try to print the source if available
196                        if let Some(source) = std::error::Error::source(&e) {
197                            tracing::error!("Cause: {:?}", source);
198                        }
199                    }
200                }
201            });
202        }
203
204        self.response
205    }
206}
207
208impl ResponseModifier for WebSocketUpgrade {
209    fn update_response(op: &mut Operation) {
210        op.responses.insert(
211            "101".to_string(),
212            ResponseSpec {
213                description: "WebSocket upgrade successful".to_string(),
214                content: BTreeMap::new(),
215                headers: BTreeMap::new(),
216            },
217        );
218    }
219}
220
221/// Generate the Sec-WebSocket-Accept key from the client's Sec-WebSocket-Key
222fn generate_accept_key(key: &str) -> String {
223    use base64::Engine;
224    use sha1::{Digest, Sha1};
225
226    const GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
227
228    let mut hasher = Sha1::new();
229    hasher.update(key.as_bytes());
230    hasher.update(GUID.as_bytes());
231    let hash = hasher.finalize();
232
233    base64::engine::general_purpose::STANDARD.encode(hash)
234}
235
236/// Validate that a request is a valid WebSocket upgrade request
237pub(crate) fn validate_upgrade_request(
238    method: &http::Method,
239    headers: &http::HeaderMap,
240) -> Result<String, WebSocketError> {
241    // Must be GET
242    if method != http::Method::GET {
243        return Err(WebSocketError::invalid_upgrade("Method must be GET"));
244    }
245
246    // Must have Upgrade: websocket header
247    let upgrade = headers
248        .get(header::UPGRADE)
249        .and_then(|v| v.to_str().ok())
250        .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Upgrade header"))?;
251
252    if !upgrade.eq_ignore_ascii_case("websocket") {
253        return Err(WebSocketError::invalid_upgrade(
254            "Upgrade header must be 'websocket'",
255        ));
256    }
257
258    // Must have Connection: Upgrade header
259    let connection = headers
260        .get(header::CONNECTION)
261        .and_then(|v| v.to_str().ok())
262        .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Connection header"))?;
263
264    let has_upgrade = connection
265        .split(',')
266        .any(|s| s.trim().eq_ignore_ascii_case("upgrade"));
267
268    if !has_upgrade {
269        return Err(WebSocketError::invalid_upgrade(
270            "Connection header must contain 'Upgrade'",
271        ));
272    }
273
274    // Must have Sec-WebSocket-Key header
275    let sec_key = headers
276        .get("Sec-WebSocket-Key")
277        .and_then(|v| v.to_str().ok())
278        .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Sec-WebSocket-Key header"))?;
279
280    // Must have Sec-WebSocket-Version: 13
281    let version = headers
282        .get("Sec-WebSocket-Version")
283        .and_then(|v| v.to_str().ok())
284        .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Sec-WebSocket-Version header"))?;
285
286    if version != "13" {
287        return Err(WebSocketError::invalid_upgrade(
288            "Sec-WebSocket-Version must be 13",
289        ));
290    }
291
292    Ok(sec_key.to_string())
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    /// Client key from the RFC 6455 worked example.
    const RFC_KEY: &str = "dGhlIHNhbXBsZSBub25jZQ==";

    /// Headers of a well-formed client opening handshake.
    fn handshake_headers() -> http::HeaderMap {
        let mut headers = http::HeaderMap::new();
        headers.insert(header::UPGRADE, header::HeaderValue::from_static("websocket"));
        headers.insert(header::CONNECTION, header::HeaderValue::from_static("Upgrade"));
        headers.insert("sec-websocket-key", header::HeaderValue::from_static(RFC_KEY));
        headers.insert("sec-websocket-version", header::HeaderValue::from_static("13"));
        headers
    }

    #[test]
    fn test_accept_key_generation() {
        // Example from RFC 6455
        let accept = generate_accept_key(RFC_KEY);
        assert_eq!(accept, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
    }

    #[test]
    fn test_validate_accepts_well_formed_handshake() {
        let key = validate_upgrade_request(&http::Method::GET, &handshake_headers()).ok();
        assert_eq!(key.as_deref(), Some(RFC_KEY));
    }

    #[test]
    fn test_validate_accepts_connection_token_list() {
        // Browsers commonly send `Connection: keep-alive, Upgrade`.
        let mut headers = handshake_headers();
        headers.insert(
            header::CONNECTION,
            header::HeaderValue::from_static("keep-alive, Upgrade"),
        );
        assert!(validate_upgrade_request(&http::Method::GET, &headers).is_ok());
    }

    #[test]
    fn test_validate_rejects_non_get() {
        assert!(validate_upgrade_request(&http::Method::POST, &handshake_headers()).is_err());
    }

    #[test]
    fn test_validate_rejects_missing_key_and_bad_version() {
        let mut missing_key = handshake_headers();
        missing_key.remove("sec-websocket-key");
        assert!(validate_upgrade_request(&http::Method::GET, &missing_key).is_err());

        let mut old_version = handshake_headers();
        old_version.insert("sec-websocket-version", header::HeaderValue::from_static("8"));
        assert!(validate_upgrade_request(&http::Method::GET, &old_version).is_err());
    }
}