Skip to main content

actix_ws/
lib.rs

1//! WebSockets for Actix Web, without actors.
2//!
3//! For usage, see documentation on [`handle()`] and [`handle_with_protocols()`].
4
5#![warn(missing_docs)]
6#![doc(html_logo_url = "https://actix.rs/img/logo.png")]
7#![doc(html_favicon_url = "https://actix.rs/favicon.ico")]
8#![cfg_attr(docsrs, feature(doc_cfg))]
9
10pub use actix_http::ws::{CloseCode, CloseReason, Item, Message, ProtocolError};
11use actix_http::{
12    body::{BodyStream, MessageBody},
13    ws::handshake,
14};
15use actix_web::{http::header, web, HttpRequest, HttpResponse};
16use tokio::sync::{mpsc::channel, oneshot};
17
18mod aggregated;
19pub mod codec;
20mod session;
21mod stream;
22
23pub use self::{
24    aggregated::{AggregatedMessage, AggregatedMessageStream},
25    session::{Closed, Session},
26    stream::{MessageStream, StreamingBody},
27};
28
29/// Begin handling websocket traffic
30///
31/// To negotiate sub-protocols via `Sec-WebSocket-Protocol`, use [`handle_with_protocols`].
32///
33/// ```no_run
34/// use std::io;
35/// use actix_web::{middleware::Logger, web, App, HttpRequest, HttpServer, Responder};
36/// use actix_ws::Message;
37///
38/// async fn ws(req: HttpRequest, body: web::Payload) -> actix_web::Result<impl Responder> {
39///     let (response, mut session, mut msg_stream) = actix_ws::handle(&req, body)?;
40///
41///     actix_web::rt::spawn(async move {
42///         while let Some(Ok(msg)) = msg_stream.recv().await {
43///             match msg {
44///                 Message::Ping(bytes) => {
45///                     if session.pong(&bytes).await.is_err() {
46///                         return;
47///                     }
48///                 }
49///
50///                 Message::Text(msg) => println!("Got text: {msg}"),
51///                 _ => break,
52///             }
53///         }
54///
55///         let _ = session.close(None).await;
56///     });
57///
58///     Ok(response)
59/// }
60///
61/// #[tokio::main(flavor = "current_thread")]
62/// async fn main() -> io::Result<()> {
63///     HttpServer::new(move || {
64///         App::new()
65///             .route("/ws", web::get().to(ws))
66///             .wrap(Logger::default())
67///     })
68///     .bind(("127.0.0.1", 8080))?
69///     .run()
70///     .await
71/// }
72/// ```
73pub fn handle(
74    req: &HttpRequest,
75    body: web::Payload,
76) -> Result<(HttpResponse, Session, MessageStream), actix_web::Error> {
77    handle_with_protocols(req, body, &[])
78}
79
80/// Begin handling websocket traffic with optional sub-protocol negotiation.
81///
82/// The first protocol offered by the client in the `Sec-WebSocket-Protocol` header that also
83/// appears in `protocols` is returned in the handshake response.
84///
85/// If there is no overlap, no `Sec-WebSocket-Protocol` header is set in the response.
86pub fn handle_with_protocols(
87    req: &HttpRequest,
88    body: web::Payload,
89    protocols: &[&str],
90) -> Result<(HttpResponse, Session, MessageStream), actix_web::Error> {
91    let mut response = handshake_with_protocols(req, protocols)?;
92    let (tx, rx) = channel(32);
93    let (connection_closed_tx, connection_closed_rx) = oneshot::channel();
94
95    Ok((
96        response
97            .message_body(
98                BodyStream::new(
99                    StreamingBody::new(rx).with_connection_close_signal(connection_closed_tx),
100                )
101                .boxed(),
102            )?
103            .into(),
104        Session::new(tx),
105        MessageStream::new(body.into_inner()).with_connection_close_signal(connection_closed_rx),
106    ))
107}
108
109fn handshake_with_protocols(
110    req: &HttpRequest,
111    protocols: &[&str],
112) -> Result<actix_http::ResponseBuilder, actix_http::ws::HandshakeError> {
113    let mut response = handshake(req.head())?;
114
115    if let Some(protocol) = select_protocol(req, protocols) {
116        response.insert_header((header::SEC_WEBSOCKET_PROTOCOL, protocol));
117    }
118
119    Ok(response)
120}
121
122fn select_protocol<'a>(req: &'a HttpRequest, protocols: &[&str]) -> Option<&'a str> {
123    for requested_protocols in req.headers().get_all(header::SEC_WEBSOCKET_PROTOCOL) {
124        let Ok(requested_protocols) = requested_protocols.to_str() else {
125            continue;
126        };
127
128        for requested_protocol in requested_protocols.split(',').map(str::trim) {
129            if requested_protocol.is_empty() {
130                continue;
131            }
132
133            if protocols
134                .iter()
135                .any(|supported_protocol| *supported_protocol == requested_protocol)
136            {
137                return Some(requested_protocol);
138            }
139        }
140    }
141
142    None
143}
144
#[cfg(test)]
mod tests {
    use actix_web::{
        http::header::{self, HeaderValue},
        test::TestRequest,
        HttpRequest,
    };

    use super::handshake_with_protocols;

    /// Builds a request carrying every header required for a valid WebSocket handshake, but no
    /// `Sec-WebSocket-Protocol` header.
    fn ws_request_builder() -> TestRequest {
        TestRequest::default()
            .insert_header((header::UPGRADE, HeaderValue::from_static("websocket")))
            .insert_header((header::CONNECTION, HeaderValue::from_static("upgrade")))
            .insert_header((
                header::SEC_WEBSOCKET_VERSION,
                HeaderValue::from_static("13"),
            ))
            .insert_header((
                header::SEC_WEBSOCKET_KEY,
                HeaderValue::from_static("x3JJHMbDL1EzLkh9GBhXDw=="),
            ))
    }

    /// Builds a handshake request, optionally with a single `Sec-WebSocket-Protocol` header.
    fn ws_request(protocols: Option<&'static str>) -> HttpRequest {
        let mut req = ws_request_builder();

        if let Some(protocols) = protocols {
            req = req.insert_header((header::SEC_WEBSOCKET_PROTOCOL, protocols));
        }

        req.to_http_request()
    }

    /// Runs the handshake (which must succeed) and returns the negotiated protocol header.
    fn negotiated(req: &HttpRequest, supported: &[&str]) -> Option<HeaderValue> {
        handshake_with_protocols(req, supported)
            .unwrap()
            .finish()
            .headers()
            .get(header::SEC_WEBSOCKET_PROTOCOL)
            .cloned()
    }

    #[test]
    fn handshake_selects_first_supported_client_protocol() {
        let req = ws_request(Some("p1,p2,p3"));

        assert_eq!(
            negotiated(&req, &["p3", "p2"]),
            Some(HeaderValue::from_static("p2")),
        );
    }

    #[test]
    fn handshake_omits_protocol_header_without_overlap() {
        let req = ws_request(Some("p1,p2,p3"));

        assert!(negotiated(&req, &["graphql"]).is_none());
    }

    #[test]
    fn handshake_supports_multiple_protocol_headers() {
        let req = ws_request_builder()
            .append_header((header::SEC_WEBSOCKET_PROTOCOL, "p1"))
            .append_header((header::SEC_WEBSOCKET_PROTOCOL, "p2"))
            .to_http_request();

        assert_eq!(
            negotiated(&req, &["p2"]),
            Some(HeaderValue::from_static("p2")),
        );
    }

    #[test]
    fn handshake_trims_whitespace_and_skips_empty_entries() {
        let req = ws_request(Some("p1, , p2"));

        assert_eq!(
            negotiated(&req, &["p2"]),
            Some(HeaderValue::from_static("p2")),
        );
    }

    #[test]
    fn handshake_omits_protocol_header_when_client_offers_none() {
        let req = ws_request(None);

        assert!(negotiated(&req, &["p1"]).is_none());
    }

    #[test]
    fn handshake_omits_protocol_header_when_server_supports_none() {
        let req = ws_request(Some("p1,p2"));

        assert!(negotiated(&req, &[]).is_none());
    }

    #[test]
    fn handshake_rejects_non_websocket_request() {
        let req = TestRequest::default().to_http_request();

        assert!(handshake_with_protocols(&req, &["p1"]).is_err());
    }
}