rustapi_ws/
upgrade.rs

1//! WebSocket upgrade response
2
3use crate::{WebSocketError, WebSocketStream, WsHeartbeatConfig};
4use http::{header, Response, StatusCode};
5use hyper::upgrade::OnUpgrade;
6use hyper_util::rt::TokioIo;
7use rustapi_core::{IntoResponse, ResponseBody};
8use rustapi_openapi::{Operation, ResponseModifier, ResponseSpec};
9use std::future::Future;
10use std::pin::Pin;
11use tokio_tungstenite::tungstenite::protocol::Role;
12
13/// Type alias for WebSocket upgrade callback
14type UpgradeCallback =
15    Box<dyn FnOnce(WebSocketStream) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send>;
16
17/// WebSocket upgrade response
18///
19/// This type is returned from WebSocket handlers to initiate the upgrade
20/// handshake and establish a WebSocket connection.
21use crate::compression::WsCompressionConfig;
22
23/// WebSocket upgrade response
24///
25/// This type is returned from WebSocket handlers to initiate the upgrade
26/// handshake and establish a WebSocket connection.
27pub struct WebSocketUpgrade {
28    /// The upgrade response
29    response: Response<ResponseBody>,
30    /// Callback to handle the WebSocket connection
31    on_upgrade: Option<UpgradeCallback>,
32    /// SEC-WebSocket-Key from request
33    #[allow(dead_code)]
34    sec_key: String,
35    /// Client requested extensions
36    client_extensions: Option<String>,
37    /// Configured compression
38    compression: Option<WsCompressionConfig>,
39    /// Configured heartbeat
40    heartbeat: Option<WsHeartbeatConfig>,
41    /// OnUpgrade future from hyper
42    on_upgrade_fut: Option<OnUpgrade>,
43}
44
45impl WebSocketUpgrade {
46    /// Create a new WebSocket upgrade from request headers
47    pub(crate) fn new(
48        sec_key: String,
49        client_extensions: Option<String>,
50        on_upgrade_fut: Option<OnUpgrade>,
51    ) -> Self {
52        // Generate accept key
53        let accept_key = generate_accept_key(&sec_key);
54
55        // Build upgrade response
56        let response = Response::builder()
57            .status(StatusCode::SWITCHING_PROTOCOLS)
58            .header(header::UPGRADE, "websocket")
59            .header(header::CONNECTION, "Upgrade")
60            .header("Sec-WebSocket-Accept", accept_key)
61            .body(ResponseBody::empty())
62            .unwrap();
63
64        Self {
65            response,
66            on_upgrade: None,
67            sec_key,
68            client_extensions,
69            compression: None,
70            heartbeat: None,
71            on_upgrade_fut,
72        }
73    }
74
75    /// Enable WebSocket heartbeat
76    pub fn heartbeat(mut self, config: WsHeartbeatConfig) -> Self {
77        self.heartbeat = Some(config);
78        self
79    }
80
81    /// Enable WebSocket compression
82    pub fn compress(mut self, config: WsCompressionConfig) -> Self {
83        self.compression = Some(config);
84
85        // Simple negotiation: if client supports it, we enable it
86        if let Some(exts) = &self.client_extensions {
87            if exts.contains("permessage-deflate") {
88                // We currently use a simple negotiation strategy
89                // TODO: Parse parameters and negotiate window bits
90                let mut header_val = String::from("permessage-deflate");
91
92                // Add server/client_no_context_takeover to reduce memory usage at cost of compression ratio
93                // This is a common default for many servers
94                header_val.push_str("; server_no_context_takeover");
95                header_val.push_str("; client_no_context_takeover");
96
97                if config.window_bits < 15 {
98                    header_val
99                        .push_str(&format!("; server_max_window_bits={}", config.window_bits));
100                }
101                if config.client_window_bits < 15 {
102                    header_val.push_str(&format!(
103                        "; client_max_window_bits={}",
104                        config.client_window_bits
105                    ));
106                }
107
108                if let Ok(val) = header::HeaderValue::from_str(&header_val) {
109                    self.response
110                        .headers_mut()
111                        .insert("Sec-WebSocket-Extensions", val);
112                }
113            }
114        }
115        self
116    }
117
118    /// Set the callback to handle the upgraded WebSocket connection
119    ///
120    /// # Example
121    ///
122    /// ```rust,ignore
123    /// ws.on_upgrade(|socket| async move {
124    ///     let (mut sender, mut receiver) = socket.split();
125    ///     while let Some(msg) = receiver.next().await {
126    ///         // Handle messages...
127    ///     }
128    /// })
129    /// ```
130    pub fn on_upgrade<F, Fut>(mut self, callback: F) -> Self
131    where
132        F: FnOnce(WebSocketStream) -> Fut + Send + 'static,
133        Fut: Future<Output = ()> + Send + 'static,
134    {
135        self.on_upgrade = Some(Box::new(move |stream| Box::pin(callback(stream))));
136        self
137    }
138
139    /// Add a protocol to the response
140    pub fn protocol(mut self, protocol: &str) -> Self {
141        // Rebuild response to keep headers clean (or just insert)
142        // More efficient to just insert
143        self.response.headers_mut().insert(
144            "Sec-WebSocket-Protocol",
145            header::HeaderValue::from_str(protocol).unwrap(),
146        );
147        self
148    }
149
150    /// Get the underlying response (for implementing IntoResponse)
151    #[allow(dead_code)]
152    pub(crate) fn into_response_inner(self) -> Response<ResponseBody> {
153        self.response
154    }
155
156    /// Get the on_upgrade callback
157    #[allow(dead_code)]
158    pub(crate) fn take_callback(&mut self) -> Option<UpgradeCallback> {
159        self.on_upgrade.take()
160    }
161}
162
163impl IntoResponse for WebSocketUpgrade {
164    fn into_response(mut self) -> rustapi_core::Response {
165        // If we have the upgrade future and a callback, spawn the upgrade task
166        if let (Some(on_upgrade), Some(callback)) =
167            (self.on_upgrade_fut.take(), self.on_upgrade.take())
168        {
169            let heartbeat = self.heartbeat;
170
171            // TODO: Apply compression config to WebSocketConfig if/when supported by from_raw_socket
172            // Currently tungstenite negotiation logic in handshake is separate from stream config
173
174            tokio::spawn(async move {
175                match on_upgrade.await {
176                    Ok(upgraded) => {
177                        let ws_stream = tokio_tungstenite::WebSocketStream::from_raw_socket(
178                            TokioIo::new(upgraded),
179                            Role::Server,
180                            None,
181                        )
182                        .await;
183
184                        let socket = if let Some(hb_config) = heartbeat {
185                            WebSocketStream::new_managed(ws_stream, hb_config)
186                        } else {
187                            WebSocketStream::new(ws_stream)
188                        };
189
190                        callback(socket).await;
191                    }
192                    Err(e) => {
193                        tracing::error!("WebSocket upgrade failed: {:?}", e);
194                        // Also try to print the source if available
195                        if let Some(source) = std::error::Error::source(&e) {
196                            tracing::error!("Cause: {:?}", source);
197                        }
198                    }
199                }
200            });
201        }
202
203        self.response
204    }
205}
206
207impl ResponseModifier for WebSocketUpgrade {
208    fn update_response(op: &mut Operation) {
209        op.responses.insert(
210            "101".to_string(),
211            ResponseSpec {
212                description: "WebSocket upgrade successful".to_string(),
213                content: None,
214            },
215        );
216    }
217}
218
219/// Generate the Sec-WebSocket-Accept key from the client's Sec-WebSocket-Key
220fn generate_accept_key(key: &str) -> String {
221    use base64::Engine;
222    use sha1::{Digest, Sha1};
223
224    const GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
225
226    let mut hasher = Sha1::new();
227    hasher.update(key.as_bytes());
228    hasher.update(GUID.as_bytes());
229    let hash = hasher.finalize();
230
231    base64::engine::general_purpose::STANDARD.encode(hash)
232}
233
234/// Validate that a request is a valid WebSocket upgrade request
235pub(crate) fn validate_upgrade_request(
236    method: &http::Method,
237    headers: &http::HeaderMap,
238) -> Result<String, WebSocketError> {
239    // Must be GET
240    if method != http::Method::GET {
241        return Err(WebSocketError::invalid_upgrade("Method must be GET"));
242    }
243
244    // Must have Upgrade: websocket header
245    let upgrade = headers
246        .get(header::UPGRADE)
247        .and_then(|v| v.to_str().ok())
248        .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Upgrade header"))?;
249
250    if !upgrade.eq_ignore_ascii_case("websocket") {
251        return Err(WebSocketError::invalid_upgrade(
252            "Upgrade header must be 'websocket'",
253        ));
254    }
255
256    // Must have Connection: Upgrade header
257    let connection = headers
258        .get(header::CONNECTION)
259        .and_then(|v| v.to_str().ok())
260        .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Connection header"))?;
261
262    let has_upgrade = connection
263        .split(',')
264        .any(|s| s.trim().eq_ignore_ascii_case("upgrade"));
265
266    if !has_upgrade {
267        return Err(WebSocketError::invalid_upgrade(
268            "Connection header must contain 'Upgrade'",
269        ));
270    }
271
272    // Must have Sec-WebSocket-Key header
273    let sec_key = headers
274        .get("Sec-WebSocket-Key")
275        .and_then(|v| v.to_str().ok())
276        .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Sec-WebSocket-Key header"))?;
277
278    // Must have Sec-WebSocket-Version: 13
279    let version = headers
280        .get("Sec-WebSocket-Version")
281        .and_then(|v| v.to_str().ok())
282        .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Sec-WebSocket-Version header"))?;
283
284    if version != "13" {
285        return Err(WebSocketError::invalid_upgrade(
286            "Sec-WebSocket-Version must be 13",
287        ));
288    }
289
290    Ok(sec_key.to_string())
291}
292
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample key from RFC 6455.
    const SAMPLE_KEY: &str = "dGhlIHNhbXBsZSBub25jZQ==";

    /// Builds the headers of a minimal, valid RFC 6455 opening handshake.
    fn handshake_headers() -> http::HeaderMap {
        let mut headers = http::HeaderMap::new();
        headers.insert(header::UPGRADE, header::HeaderValue::from_static("websocket"));
        headers.insert(header::CONNECTION, header::HeaderValue::from_static("Upgrade"));
        headers.insert(
            header::SEC_WEBSOCKET_KEY,
            header::HeaderValue::from_static(SAMPLE_KEY),
        );
        headers.insert(
            header::SEC_WEBSOCKET_VERSION,
            header::HeaderValue::from_static("13"),
        );
        headers
    }

    #[test]
    fn test_accept_key_generation() {
        // Example from RFC 6455
        assert_eq!(generate_accept_key(SAMPLE_KEY), "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
    }

    #[test]
    fn test_valid_upgrade_request_returns_key() {
        let headers = handshake_headers();
        let key = validate_upgrade_request(&http::Method::GET, &headers).ok();
        assert_eq!(key.as_deref(), Some(SAMPLE_KEY));
    }

    #[test]
    fn test_connection_header_token_list_is_accepted() {
        let mut headers = handshake_headers();
        headers.insert(
            header::CONNECTION,
            header::HeaderValue::from_static("keep-alive, Upgrade"),
        );
        assert!(validate_upgrade_request(&http::Method::GET, &headers).is_ok());
    }

    #[test]
    fn test_invalid_upgrade_requests_are_rejected() {
        assert!(validate_upgrade_request(&http::Method::POST, &handshake_headers()).is_err());

        let mut wrong_upgrade = handshake_headers();
        wrong_upgrade.insert(header::UPGRADE, header::HeaderValue::from_static("h2c"));
        assert!(validate_upgrade_request(&http::Method::GET, &wrong_upgrade).is_err());

        let mut no_key = handshake_headers();
        no_key.remove(header::SEC_WEBSOCKET_KEY);
        assert!(validate_upgrade_request(&http::Method::GET, &no_key).is_err());

        let mut old_version = handshake_headers();
        old_version.insert(
            header::SEC_WEBSOCKET_VERSION,
            header::HeaderValue::from_static("8"),
        );
        assert!(validate_upgrade_request(&http::Method::GET, &old_version).is_err());
    }
}