#![warn(missing_docs)]
#![doc(html_logo_url = "https://actix.rs/img/logo.png")]
#![doc(html_favicon_url = "https://actix.rs/favicon.ico")]
#![cfg_attr(docsrs, feature(doc_cfg))]
pub use actix_http::ws::{CloseCode, CloseReason, Item, Message, ProtocolError};
use actix_http::{
body::{BodyStream, MessageBody},
ws::handshake,
};
use actix_web::{http::header, web, HttpRequest, HttpResponse};
use tokio::sync::{mpsc::channel, oneshot};
mod aggregated;
pub mod codec;
mod session;
mod stream;
pub use self::{
aggregated::{AggregatedMessage, AggregatedMessageStream},
session::{Closed, Session},
stream::{MessageStream, StreamingBody},
};
/// Performs the WebSocket handshake for `req` and starts handling WebSocket traffic.
///
/// This is [`handle_with_protocols`] with an empty list of supported subprotocols, so the
/// handshake response never carries a `Sec-WebSocket-Protocol` header.
///
/// Returns a tuple of:
/// - the handshake [`HttpResponse`] to send back to the client; its body streams the
///   outgoing messages,
/// - a [`Session`] that feeds that outgoing response body,
/// - a [`MessageStream`] wrapping the request `body`, i.e. the incoming side of the connection.
///
/// # Errors
///
/// Returns an error if the WebSocket handshake check on `req` fails, or if the streaming body
/// cannot be attached to the handshake response.
pub fn handle(
    req: &HttpRequest,
    body: web::Payload,
) -> Result<(HttpResponse, Session, MessageStream), actix_web::Error> {
    handle_with_protocols(req, body, &[])
}
/// Performs the WebSocket handshake for `req`, negotiating a subprotocol from `protocols`, and
/// starts handling WebSocket traffic.
///
/// The client's `Sec-WebSocket-Protocol` offers are checked in the order the client sent them.
/// The first one that also appears in `protocols` is echoed back in the handshake response. When
/// nothing overlaps (or `protocols` is empty), the header is omitted.
///
/// Returns a tuple of:
/// - the handshake [`HttpResponse`] to send back to the client; its body streams the
///   outgoing messages,
/// - a [`Session`] that feeds that outgoing response body,
/// - a [`MessageStream`] wrapping the request `body`, i.e. the incoming side of the connection.
///
/// # Errors
///
/// Returns an error if the WebSocket handshake check on `req` fails, or if the streaming body
/// cannot be attached to the handshake response.
pub fn handle_with_protocols(
    req: &HttpRequest,
    body: web::Payload,
    protocols: &[&str],
) -> Result<(HttpResponse, Session, MessageStream), actix_web::Error> {
    let mut response = handshake_with_protocols(req, protocols)?;

    // Bounded channel: `Session` holds the sender, and the receiver becomes the response body.
    let (tx, rx) = channel(32);

    // The outgoing body holds the sender and the incoming stream holds the receiver, so the body
    // side can signal the message stream when the connection is closed.
    let (connection_closed_tx, connection_closed_rx) = oneshot::channel();

    Ok((
        response
            .message_body(
                BodyStream::new(
                    StreamingBody::new(rx).with_connection_close_signal(connection_closed_tx),
                )
                .boxed(),
            )?
            .into(),
        Session::new(tx),
        MessageStream::new(body.into_inner()).with_connection_close_signal(connection_closed_rx),
    ))
}
/// Validates the handshake request and builds the handshake response.
///
/// If [`select_protocol`] finds a subprotocol shared by the client's offer and `protocols`, it is
/// set as the response's `Sec-WebSocket-Protocol` header. Otherwise the builder is returned
/// without that header.
fn handshake_with_protocols(
    req: &HttpRequest,
    protocols: &[&str],
) -> Result<actix_http::ResponseBuilder, actix_http::ws::HandshakeError> {
    let mut builder = handshake(req.head())?;

    let Some(chosen) = select_protocol(req, protocols) else {
        // No overlap with the client's offer: leave the header out entirely.
        return Ok(builder);
    };

    builder.insert_header((header::SEC_WEBSOCKET_PROTOCOL, chosen));
    Ok(builder)
}
/// Picks the subprotocol to echo back in the handshake response.
///
/// How it works:
/// - Visits every `Sec-WebSocket-Protocol` header on `req`, in order.
/// - Skips any header value that cannot be read as a string (`to_str` fails).
/// - Splits each remaining value on commas, trims each entry and drops empty ones.
/// - Returns the first entry, in the client's order, that also appears in `protocols`.
///
/// Returns `None` when no entry matches.
fn select_protocol<'a>(req: &'a HttpRequest, protocols: &[&str]) -> Option<&'a str> {
    let is_supported = |offered: &str| protocols.iter().any(|supported| *supported == offered);

    req.headers()
        .get_all(header::SEC_WEBSOCKET_PROTOCOL)
        .filter_map(|raw| raw.to_str().ok())
        .flat_map(|offer_list| offer_list.split(',').map(str::trim))
        .filter(|offered| !offered.is_empty())
        .find(|&offered| is_supported(offered))
}
#[cfg(test)]
mod tests {
    use actix_web::{
        http::header::{self, HeaderValue},
        test::TestRequest,
        HttpRequest,
    };

    use super::handshake_with_protocols;

    /// Returns a request builder carrying every header a valid WebSocket handshake needs, but no
    /// `Sec-WebSocket-Protocol` header.
    fn ws_request_builder() -> TestRequest {
        TestRequest::default()
            .insert_header((header::UPGRADE, HeaderValue::from_static("websocket")))
            .insert_header((header::CONNECTION, HeaderValue::from_static("upgrade")))
            .insert_header((
                header::SEC_WEBSOCKET_VERSION,
                HeaderValue::from_static("13"),
            ))
            .insert_header((
                header::SEC_WEBSOCKET_KEY,
                HeaderValue::from_static("x3JJHMbDL1EzLkh9GBhXDw=="),
            ))
    }

    /// Builds a handshake request, optionally with a single `Sec-WebSocket-Protocol` header.
    fn ws_request(protocols: Option<&'static str>) -> HttpRequest {
        let mut req = ws_request_builder();
        if let Some(protocols) = protocols {
            req = req.insert_header((header::SEC_WEBSOCKET_PROTOCOL, protocols));
        }
        req.to_http_request()
    }

    /// Runs the handshake and returns the `Sec-WebSocket-Protocol` header the server chose, if any.
    fn selected_protocol(req: &HttpRequest, supported: &[&str]) -> Option<HeaderValue> {
        let response = handshake_with_protocols(req, supported).unwrap().finish();
        response
            .headers()
            .get(header::SEC_WEBSOCKET_PROTOCOL)
            .cloned()
    }

    #[test]
    fn handshake_selects_first_supported_client_protocol() {
        // The client's preference order wins over the order of the server's list.
        let req = ws_request(Some("p1,p2,p3"));
        assert_eq!(
            selected_protocol(&req, &["p3", "p2"]),
            Some(HeaderValue::from_static("p2")),
        );
    }

    #[test]
    fn handshake_omits_protocol_header_without_overlap() {
        let req = ws_request(Some("p1,p2,p3"));
        assert_eq!(selected_protocol(&req, &["graphql"]), None);
    }

    #[test]
    fn handshake_omits_protocol_header_when_client_requests_none() {
        let req = ws_request(None);
        assert_eq!(selected_protocol(&req, &["p1"]), None);
    }

    #[test]
    fn handshake_trims_whitespace_and_skips_empty_entries() {
        let req = ws_request(Some(" , ,  p1 , p2 "));
        assert_eq!(
            selected_protocol(&req, &["p1"]),
            Some(HeaderValue::from_static("p1")),
        );
    }

    #[test]
    fn handshake_supports_multiple_protocol_headers() {
        let req = ws_request_builder()
            .append_header((header::SEC_WEBSOCKET_PROTOCOL, "p1"))
            .append_header((header::SEC_WEBSOCKET_PROTOCOL, "p2"))
            .to_http_request();
        assert_eq!(
            selected_protocol(&req, &["p2"]),
            Some(HeaderValue::from_static("p2")),
        );
    }
}