rustapi_ws/
upgrade.rs

1//! WebSocket upgrade response
2
3use crate::{WebSocketError, WebSocketStream, WsHeartbeatConfig};
4use bytes::Bytes;
5use http::{header, Response, StatusCode};
6use http_body_util::Full;
7use hyper::upgrade::OnUpgrade;
8use hyper_util::rt::TokioIo;
9use rustapi_core::IntoResponse;
10use rustapi_openapi::{Operation, ResponseModifier, ResponseSpec};
11use std::future::Future;
12use std::pin::Pin;
13use tokio_tungstenite::tungstenite::protocol::Role;
14
15/// Type alias for WebSocket upgrade callback
16type UpgradeCallback =
17    Box<dyn FnOnce(WebSocketStream) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send>;
18
19/// WebSocket upgrade response
20///
21/// This type is returned from WebSocket handlers to initiate the upgrade
22/// handshake and establish a WebSocket connection.
23use crate::compression::WsCompressionConfig;
24
25/// WebSocket upgrade response
26///
27/// This type is returned from WebSocket handlers to initiate the upgrade
28/// handshake and establish a WebSocket connection.
29pub struct WebSocketUpgrade {
30    /// The upgrade response
31    response: Response<Full<Bytes>>,
32    /// Callback to handle the WebSocket connection
33    on_upgrade: Option<UpgradeCallback>,
34    /// SEC-WebSocket-Key from request
35    #[allow(dead_code)]
36    sec_key: String,
37    /// Client requested extensions
38    client_extensions: Option<String>,
39    /// Configured compression
40    compression: Option<WsCompressionConfig>,
41    /// Configured heartbeat
42    heartbeat: Option<WsHeartbeatConfig>,
43    /// OnUpgrade future from hyper
44    on_upgrade_fut: Option<OnUpgrade>,
45}
46
47impl WebSocketUpgrade {
48    /// Create a new WebSocket upgrade from request headers
49    pub(crate) fn new(
50        sec_key: String,
51        client_extensions: Option<String>,
52        on_upgrade_fut: Option<OnUpgrade>,
53    ) -> Self {
54        // Generate accept key
55        let accept_key = generate_accept_key(&sec_key);
56
57        // Build upgrade response
58        let response = Response::builder()
59            .status(StatusCode::SWITCHING_PROTOCOLS)
60            .header(header::UPGRADE, "websocket")
61            .header(header::CONNECTION, "Upgrade")
62            .header("Sec-WebSocket-Accept", accept_key)
63            .body(Full::new(Bytes::new()))
64            .unwrap();
65
66        Self {
67            response,
68            on_upgrade: None,
69            sec_key,
70            client_extensions,
71            compression: None,
72            heartbeat: None,
73            on_upgrade_fut,
74        }
75    }
76
77    /// Enable WebSocket heartbeat
78    pub fn heartbeat(mut self, config: WsHeartbeatConfig) -> Self {
79        self.heartbeat = Some(config);
80        self
81    }
82
83    /// Enable WebSocket compression
84    pub fn compress(mut self, config: WsCompressionConfig) -> Self {
85        self.compression = Some(config);
86
87        // Simple negotiation: if client supports it, we enable it
88        if let Some(exts) = &self.client_extensions {
89            if exts.contains("permessage-deflate") {
90                // We currently use a simple negotiation strategy
91                // TODO: Parse parameters and negotiate window bits
92                let mut header_val = String::from("permessage-deflate");
93
94                // Add server/client_no_context_takeover to reduce memory usage at cost of compression ratio
95                // This is a common default for many servers
96                header_val.push_str("; server_no_context_takeover");
97                header_val.push_str("; client_no_context_takeover");
98
99                if config.window_bits < 15 {
100                    header_val
101                        .push_str(&format!("; server_max_window_bits={}", config.window_bits));
102                }
103                if config.client_window_bits < 15 {
104                    header_val.push_str(&format!(
105                        "; client_max_window_bits={}",
106                        config.client_window_bits
107                    ));
108                }
109
110                if let Ok(val) = header::HeaderValue::from_str(&header_val) {
111                    self.response
112                        .headers_mut()
113                        .insert("Sec-WebSocket-Extensions", val);
114                }
115            }
116        }
117        self
118    }
119
120    /// Set the callback to handle the upgraded WebSocket connection
121    ///
122    /// # Example
123    ///
124    /// ```rust,ignore
125    /// ws.on_upgrade(|socket| async move {
126    ///     let (mut sender, mut receiver) = socket.split();
127    ///     while let Some(msg) = receiver.next().await {
128    ///         // Handle messages...
129    ///     }
130    /// })
131    /// ```
132    pub fn on_upgrade<F, Fut>(mut self, callback: F) -> Self
133    where
134        F: FnOnce(WebSocketStream) -> Fut + Send + 'static,
135        Fut: Future<Output = ()> + Send + 'static,
136    {
137        self.on_upgrade = Some(Box::new(move |stream| Box::pin(callback(stream))));
138        self
139    }
140
141    /// Add a protocol to the response
142    pub fn protocol(mut self, protocol: &str) -> Self {
143        // Rebuild response to keep headers clean (or just insert)
144        // More efficient to just insert
145        self.response.headers_mut().insert(
146            "Sec-WebSocket-Protocol",
147            header::HeaderValue::from_str(protocol).unwrap(),
148        );
149        self
150    }
151
152    /// Get the underlying response (for implementing IntoResponse)
153    #[allow(dead_code)]
154    pub(crate) fn into_response_inner(self) -> Response<Full<Bytes>> {
155        self.response
156    }
157
158    /// Get the on_upgrade callback
159    #[allow(dead_code)]
160    pub(crate) fn take_callback(&mut self) -> Option<UpgradeCallback> {
161        self.on_upgrade.take()
162    }
163}
164
165impl IntoResponse for WebSocketUpgrade {
166    fn into_response(mut self) -> http::Response<Full<Bytes>> {
167        // If we have the upgrade future and a callback, spawn the upgrade task
168        if let (Some(on_upgrade), Some(callback)) =
169            (self.on_upgrade_fut.take(), self.on_upgrade.take())
170        {
171            let heartbeat = self.heartbeat;
172
173            // TODO: Apply compression config to WebSocketConfig if/when supported by from_raw_socket
174            // Currently tungstenite negotiation logic in handshake is separate from stream config
175
176            tokio::spawn(async move {
177                match on_upgrade.await {
178                    Ok(upgraded) => {
179                        let ws_stream = tokio_tungstenite::WebSocketStream::from_raw_socket(
180                            TokioIo::new(upgraded),
181                            Role::Server,
182                            None,
183                        )
184                        .await;
185
186                        let socket = if let Some(hb_config) = heartbeat {
187                            WebSocketStream::new_managed(ws_stream, hb_config)
188                        } else {
189                            WebSocketStream::new(ws_stream)
190                        };
191
192                        callback(socket).await;
193                    }
194                    Err(e) => {
195                        tracing::error!("WebSocket upgrade failed: {:?}", e);
196                        // Also try to print the source if available
197                        if let Some(source) = std::error::Error::source(&e) {
198                            tracing::error!("Cause: {:?}", source);
199                        }
200                    }
201                }
202            });
203        }
204
205        self.response
206    }
207}
208
209impl ResponseModifier for WebSocketUpgrade {
210    fn update_response(op: &mut Operation) {
211        op.responses.insert(
212            "101".to_string(),
213            ResponseSpec {
214                description: "WebSocket upgrade successful".to_string(),
215                content: None,
216            },
217        );
218    }
219}
220
221/// Generate the Sec-WebSocket-Accept key from the client's Sec-WebSocket-Key
222fn generate_accept_key(key: &str) -> String {
223    use base64::Engine;
224    use sha1::{Digest, Sha1};
225
226    const GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
227
228    let mut hasher = Sha1::new();
229    hasher.update(key.as_bytes());
230    hasher.update(GUID.as_bytes());
231    let hash = hasher.finalize();
232
233    base64::engine::general_purpose::STANDARD.encode(hash)
234}
235
236/// Validate that a request is a valid WebSocket upgrade request
237pub(crate) fn validate_upgrade_request(
238    method: &http::Method,
239    headers: &http::HeaderMap,
240) -> Result<String, WebSocketError> {
241    // Must be GET
242    if method != http::Method::GET {
243        return Err(WebSocketError::invalid_upgrade("Method must be GET"));
244    }
245
246    // Must have Upgrade: websocket header
247    let upgrade = headers
248        .get(header::UPGRADE)
249        .and_then(|v| v.to_str().ok())
250        .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Upgrade header"))?;
251
252    if !upgrade.eq_ignore_ascii_case("websocket") {
253        return Err(WebSocketError::invalid_upgrade(
254            "Upgrade header must be 'websocket'",
255        ));
256    }
257
258    // Must have Connection: Upgrade header
259    let connection = headers
260        .get(header::CONNECTION)
261        .and_then(|v| v.to_str().ok())
262        .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Connection header"))?;
263
264    let has_upgrade = connection
265        .split(',')
266        .any(|s| s.trim().eq_ignore_ascii_case("upgrade"));
267
268    if !has_upgrade {
269        return Err(WebSocketError::invalid_upgrade(
270            "Connection header must contain 'Upgrade'",
271        ));
272    }
273
274    // Must have Sec-WebSocket-Key header
275    let sec_key = headers
276        .get("Sec-WebSocket-Key")
277        .and_then(|v| v.to_str().ok())
278        .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Sec-WebSocket-Key header"))?;
279
280    // Must have Sec-WebSocket-Version: 13
281    let version = headers
282        .get("Sec-WebSocket-Version")
283        .and_then(|v| v.to_str().ok())
284        .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Sec-WebSocket-Version header"))?;
285
286    if version != "13" {
287        return Err(WebSocketError::invalid_upgrade(
288            "Sec-WebSocket-Version must be 13",
289        ));
290    }
291
292    Ok(sec_key.to_string())
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    /// The worked handshake example from RFC 6455 §1.3: the sample client key
    /// must map to the accept value given in the RFC.
    #[test]
    fn test_accept_key_generation() {
        let (client_key, expected_accept) =
            ("dGhlIHNhbXBsZSBub25jZQ==", "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
        assert_eq!(generate_accept_key(client_key), expected_accept);
    }
}