rustapi_ws/
extractor.rs

1//! WebSocket extractor
2
3use crate::upgrade::{validate_upgrade_request, WebSocketUpgrade};
4use hyper::upgrade::OnUpgrade;
5use rustapi_core::{ApiError, FromRequest, Request, Result};
6use rustapi_openapi::{Operation, OperationModifier};
7
8/// WebSocket extractor for upgrading HTTP connections to WebSocket
9///
10/// Use this extractor in your handler to initiate a WebSocket upgrade.
11/// The extractor validates the upgrade request and returns a `WebSocket`
12/// that can be used to set up the connection handler.
13///
14/// # Example
15///
16/// ```rust,ignore
17/// use rustapi_ws::{WebSocket, Message};
18///
19/// async fn ws_handler(ws: WebSocket) -> impl IntoResponse {
20///     ws.on_upgrade(|socket| async move {
21///         let (mut sender, mut receiver) = socket.split();
22///         
23///         while let Some(Ok(msg)) = receiver.next().await {
24///             match msg {
25///                 Message::Text(text) => {
26///                     // Echo back
27///                     let _ = sender.send(Message::text(format!("Echo: {}", text))).await;
28///                 }
29///                 Message::Close(_) => break,
30///                 _ => {}
31///             }
32///         }
33///     })
34/// }
35/// ```
36pub struct WebSocket {
37    sec_key: String,
38    protocols: Vec<String>,
39    extensions: Option<String>,
40    on_upgrade: Option<OnUpgrade>,
41}
42
43impl WebSocket {
44    /// Create a WebSocket upgrade response with a handler
45    ///
46    /// The provided callback will be called with the established WebSocket
47    /// stream once the upgrade is complete.
48    pub fn on_upgrade<F, Fut>(mut self, callback: F) -> WebSocketUpgrade
49    where
50        F: FnOnce(crate::WebSocketStream) -> Fut + Send + 'static,
51        Fut: std::future::Future<Output = ()> + Send + 'static,
52    {
53        let upgrade = WebSocketUpgrade::new(self.sec_key, self.extensions, self.on_upgrade.take());
54
55        // If protocols were requested, select the first one
56        let upgrade = if let Some(protocol) = self.protocols.first() {
57            upgrade.protocol(protocol)
58        } else {
59            upgrade
60        };
61
62        upgrade.on_upgrade(callback)
63    }
64
65    /// Get the requested protocols
66    pub fn protocols(&self) -> &[String] {
67        &self.protocols
68    }
69
70    /// Check if a specific protocol was requested
71    pub fn has_protocol(&self, protocol: &str) -> bool {
72        self.protocols.iter().any(|p| p == protocol)
73    }
74}
75
76impl FromRequest for WebSocket {
77    async fn from_request(req: &mut Request) -> Result<Self> {
78        let headers = req.headers();
79        let method = req.method();
80
81        // Validate the upgrade request
82        // Note: we clone sec_key to avoid keeping borrow of headers
83        let sec_key = validate_upgrade_request(method, headers)
84            .map_err(ApiError::from)?
85            .to_string();
86
87        // Parse requested protocols
88        let protocols = headers
89            .get("Sec-WebSocket-Protocol")
90            .and_then(|v| v.to_str().ok())
91            .map(|s| s.split(',').map(|p| p.trim().to_string()).collect())
92            .unwrap_or_default();
93
94        // Get extensions
95        let extensions = headers
96            .get("Sec-WebSocket-Extensions")
97            .and_then(|v| v.to_str().ok())
98            .map(|s| s.to_string());
99
100        // Capture OnUpgrade future
101        let on_upgrade = req.extensions_mut().remove::<OnUpgrade>();
102
103        // IMPORTANT: Consume the request body to ensure hyper allows the upgrade.
104        if let Some(stream) = req.take_stream() {
105            use http_body_util::BodyExt;
106            let _ = stream.collect().await;
107        }
108
109        Ok(Self {
110            sec_key,
111            protocols,
112            extensions,
113            on_upgrade,
114        })
115    }
116}
117
118impl OperationModifier for WebSocket {
119    fn update_operation(_op: &mut Operation) {
120        // WebSocket endpoints don't have regular request body parameters
121        // The upgrade is indicated by the response
122    }
123}