Skip to main content

actus_server/
websocket.rs

1//! WebSocket support (RFC 6455). Behind the `websocket` feature.
2//!
3//! A route handler that wants to serve a WebSocket validates the request as
4//! usual (origin, auth, subprotocol — whatever it needs), then returns
5//! [`upgrade`]`(...)` instead of an ordinary reply. The server completes the
6//! handshake (`101 Switching Protocols`), upgrades the connection, and runs
7//! the closure you supplied on the resulting [`WebSocket`] (a `Stream` of
8//! incoming [`Message`]s and a `Sink` for outgoing ones).
9//!
10//! ```ignore
11//! use actus::prelude::*;
12//! use futures_util::{SinkExt, StreamExt};
13//!
14//! pub async fn echo(&self, _params: &Params) -> Reply {
15//!     Ok(actus::ws::upgrade(|mut socket| async move {
16//!         while let Some(Ok(msg)) = socket.next().await {
17//!             if msg.is_text() || msg.is_binary() {
18//!                 let _ = socket.send(msg).await;
19//!             }
20//!         }
21//!     }))
22//! }
23//! ```
24//!
25//! Mount it like any other route (`GET "echo" => echo()`). If the request
26//! reaching such a handler isn't actually a WebSocket handshake, the server
27//! responds `426 Upgrade Required` instead of attempting the upgrade.
28//!
29//! ## Timing of the upgrade capture
30//!
31//! The server captures `OnUpgrade` (and derives `Sec-WebSocket-Accept`) for
32//! any request whose method/headers look like an RFC 6455 handshake —
33//! *before* middleware or routing runs. That means a handshake-shaped request
34//! that ends up short-circuited by middleware, rejected on auth, or routed to
35//! a non-WS handler still pays the small cost of capturing the upgrade
36//! future; the captured `OnUpgrade` is then dropped harmlessly (hyper handles
37//! a dropped upgrade future). This shape is deliberate: it keeps the upgrade
38//! capture out of the handler's hot path and avoids threading a "give me the
39//! upgrade now" callback through the reply machinery.
40
41use actus_reply::ReplyData;
42use futures_util::future::BoxFuture;
43use http::{HeaderMap, HeaderValue, Method, header};
44use hyper::upgrade::{OnUpgrade, Upgraded};
45use hyper_util::rt::TokioIo;
46use std::future::Future;
47use tokio_tungstenite::WebSocketStream;
48use tokio_tungstenite::tungstenite::handshake::derive_accept_key;
49use tokio_tungstenite::tungstenite::protocol::Role;
50
51/// Re-export of `tungstenite`'s message type — text / binary / ping / pong /
52/// close frames.
53pub use tokio_tungstenite::tungstenite::Message;
54
55/// A server-side WebSocket connection: a [`Stream`](futures_util::Stream) of
56/// incoming [`Message`]s and a [`Sink`](futures_util::Sink) for outgoing ones.
57/// Obtained inside the closure passed to [`upgrade`].
58pub type WebSocket = WebSocketStream<TokioIo<Upgraded>>;
59
60/// The boxed one-shot task carried by `ReplyData::Upgrade` — run once the
61/// connection has been upgraded.
62pub(crate) type UpgradeTask = Box<dyn FnOnce(WebSocket) -> BoxFuture<'static, ()> + Send>;
63
64/// Return this from a route handler to turn the response into a WebSocket
65/// upgrade. `handler` runs on the upgraded connection after the handshake.
66/// See the [module docs](self).
67pub fn upgrade<F, Fut>(handler: F) -> ReplyData
68where
69    F: FnOnce(WebSocket) -> Fut + Send + 'static,
70    Fut: Future<Output = ()> + Send + 'static,
71{
72    let task: UpgradeTask = Box::new(move |ws| -> BoxFuture<'static, ()> { Box::pin(handler(ws)) });
73    ReplyData::Upgrade(Box::new(task))
74}
75
76/// True iff `(method, headers)` carry an RFC 6455 WebSocket handshake.
77pub(crate) fn is_upgrade_request(method: &Method, headers: &HeaderMap) -> bool {
78    fn list_contains(headers: &HeaderMap, name: header::HeaderName, needle: &str) -> bool {
79        headers
80            .get(name)
81            .and_then(|v| v.to_str().ok())
82            .is_some_and(|v| v.split(',').any(|t| t.trim().eq_ignore_ascii_case(needle)))
83    }
84    *method == Method::GET
85        && list_contains(headers, header::CONNECTION, "upgrade")
86        && list_contains(headers, header::UPGRADE, "websocket")
87        && headers.contains_key(header::SEC_WEBSOCKET_KEY)
88        && headers
89            .get(header::SEC_WEBSOCKET_VERSION)
90            .and_then(|v| v.to_str().ok())
91            == Some("13")
92}
93
94/// The `Sec-WebSocket-Accept` value derived from the request's
95/// `Sec-WebSocket-Key`, ready to put on the `101` response.
96pub(crate) fn accept_key(request_headers: &HeaderMap) -> Option<HeaderValue> {
97    let key = request_headers.get(header::SEC_WEBSOCKET_KEY)?;
98    HeaderValue::from_str(&derive_accept_key(key.as_bytes())).ok()
99}
100
101/// Await the connection upgrade and run `task` on the WebSocket. Spawned by
102/// the server after it has sent the `101` response.
103pub(crate) async fn run_upgrade(on_upgrade: OnUpgrade, task: UpgradeTask) {
104    match on_upgrade.await {
105        Ok(upgraded) => {
106            let socket =
107                WebSocketStream::from_raw_socket(TokioIo::new(upgraded), Role::Server, None).await;
108            task(socket).await;
109        }
110        Err(e) => tracing::warn!("websocket upgrade failed: {}", e),
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `HeaderMap` from `(name, value)` pairs. A repeated name
    /// overwrites the earlier value (`insert` semantics).
    fn headers(pairs: &[(header::HeaderName, &str)]) -> HeaderMap {
        pairs
            .iter()
            .map(|(name, value)| (name.clone(), HeaderValue::from_str(value).unwrap()))
            .fold(HeaderMap::new(), |mut map, (name, value)| {
                map.insert(name, value);
                map
            })
    }

    #[test]
    fn recognizes_a_valid_handshake() {
        let handshake = headers(&[
            (header::CONNECTION, "keep-alive, Upgrade"),
            (header::UPGRADE, "websocket"),
            (header::SEC_WEBSOCKET_KEY, "dGhlIHNhbXBsZSBub25jZQ=="),
            (header::SEC_WEBSOCKET_VERSION, "13"),
        ]);
        assert!(is_upgrade_request(&Method::GET, &handshake));

        // A non-GET method disqualifies the request...
        assert!(!is_upgrade_request(&Method::POST, &handshake));
        // ...and so does a GET that carries only part of the header set.
        let partial = headers(&[(header::UPGRADE, "websocket")]);
        assert!(!is_upgrade_request(&Method::GET, &partial));
    }

    #[test]
    fn derives_the_rfc6455_accept_key() {
        // Worked example from RFC 6455 §1.3.
        let with_key = headers(&[(header::SEC_WEBSOCKET_KEY, "dGhlIHNhbXBsZSBub25jZQ==")]);
        assert_eq!(accept_key(&with_key).unwrap(), "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");

        // Without a key there is nothing to derive.
        assert!(accept_key(&HeaderMap::new()).is_none());
    }
}
147        let h = headers(&[(header::SEC_WEBSOCKET_KEY, "dGhlIHNhbXBsZSBub25jZQ==")]);
148        assert_eq!(accept_key(&h).unwrap(), "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
149        assert!(accept_key(&HeaderMap::new()).is_none());
150    }
151}