1#![warn(missing_docs)]
6#![doc(html_logo_url = "https://actix.rs/img/logo.png")]
7#![doc(html_favicon_url = "https://actix.rs/favicon.ico")]
8#![cfg_attr(docsrs, feature(doc_cfg))]
9
10pub use actix_http::ws::{CloseCode, CloseReason, Item, Message, ProtocolError};
11use actix_http::{
12 body::{BodyStream, MessageBody},
13 ws::handshake,
14};
15use actix_web::{http::header, web, HttpRequest, HttpResponse};
16use tokio::sync::{mpsc::channel, oneshot};
17
18mod aggregated;
19pub mod codec;
20mod session;
21mod stream;
22
23pub use self::{
24 aggregated::{AggregatedMessage, AggregatedMessageStream},
25 session::{Closed, Session},
26 stream::{MessageStream, StreamingBody},
27};
28
29pub fn handle(
74 req: &HttpRequest,
75 body: web::Payload,
76) -> Result<(HttpResponse, Session, MessageStream), actix_web::Error> {
77 handle_with_protocols(req, body, &[])
78}
79
80pub fn handle_with_protocols(
87 req: &HttpRequest,
88 body: web::Payload,
89 protocols: &[&str],
90) -> Result<(HttpResponse, Session, MessageStream), actix_web::Error> {
91 let mut response = handshake_with_protocols(req, protocols)?;
92 let (tx, rx) = channel(32);
93 let (connection_closed_tx, connection_closed_rx) = oneshot::channel();
94
95 Ok((
96 response
97 .message_body(
98 BodyStream::new(
99 StreamingBody::new(rx).with_connection_close_signal(connection_closed_tx),
100 )
101 .boxed(),
102 )?
103 .into(),
104 Session::new(tx),
105 MessageStream::new(body.into_inner()).with_connection_close_signal(connection_closed_rx),
106 ))
107}
108
109fn handshake_with_protocols(
110 req: &HttpRequest,
111 protocols: &[&str],
112) -> Result<actix_http::ResponseBuilder, actix_http::ws::HandshakeError> {
113 let mut response = handshake(req.head())?;
114
115 if let Some(protocol) = select_protocol(req, protocols) {
116 response.insert_header((header::SEC_WEBSOCKET_PROTOCOL, protocol));
117 }
118
119 Ok(response)
120}
121
122fn select_protocol<'a>(req: &'a HttpRequest, protocols: &[&str]) -> Option<&'a str> {
123 for requested_protocols in req.headers().get_all(header::SEC_WEBSOCKET_PROTOCOL) {
124 let Ok(requested_protocols) = requested_protocols.to_str() else {
125 continue;
126 };
127
128 for requested_protocol in requested_protocols.split(',').map(str::trim) {
129 if requested_protocol.is_empty() {
130 continue;
131 }
132
133 if protocols
134 .iter()
135 .any(|supported_protocol| *supported_protocol == requested_protocol)
136 {
137 return Some(requested_protocol);
138 }
139 }
140 }
141
142 None
143}
144
#[cfg(test)]
mod tests {
    use actix_web::{
        http::header::{self, HeaderValue},
        test::TestRequest,
        HttpRequest,
    };

    use super::handshake_with_protocols;

    /// Builds a request carrying the headers a valid WebSocket upgrade needs,
    /// without any `Sec-WebSocket-Protocol` header.
    fn ws_request_builder() -> TestRequest {
        TestRequest::default()
            .insert_header((header::UPGRADE, HeaderValue::from_static("websocket")))
            .insert_header((header::CONNECTION, HeaderValue::from_static("upgrade")))
            .insert_header((
                header::SEC_WEBSOCKET_VERSION,
                HeaderValue::from_static("13"),
            ))
            .insert_header((
                header::SEC_WEBSOCKET_KEY,
                HeaderValue::from_static("x3JJHMbDL1EzLkh9GBhXDw=="),
            ))
    }

    /// Builds a WebSocket upgrade request, optionally with a single
    /// `Sec-WebSocket-Protocol` header whose value is `protocols` verbatim.
    fn ws_request(protocols: Option<&'static str>) -> HttpRequest {
        let mut req = ws_request_builder();

        if let Some(protocols) = protocols {
            req = req.insert_header((header::SEC_WEBSOCKET_PROTOCOL, protocols));
        }

        req.to_http_request()
    }

    /// Runs the handshake (which must succeed) and returns the
    /// `Sec-WebSocket-Protocol` header of the resulting response, if any.
    fn selected_protocol(req: &HttpRequest, supported: &[&str]) -> Option<HeaderValue> {
        handshake_with_protocols(req, supported)
            .unwrap()
            .finish()
            .headers()
            .get(header::SEC_WEBSOCKET_PROTOCOL)
            .cloned()
    }

    #[test]
    fn handshake_selects_first_supported_client_protocol() {
        let req = ws_request(Some("p1,p2,p3"));

        // Client order (p1, p2, p3) decides, not the server's list order.
        assert_eq!(
            selected_protocol(&req, &["p3", "p2"]),
            Some(HeaderValue::from_static("p2")),
        );
    }

    #[test]
    fn handshake_omits_protocol_header_without_overlap() {
        let req = ws_request(Some("p1,p2,p3"));

        assert_eq!(selected_protocol(&req, &["graphql"]), None);
    }

    #[test]
    fn handshake_omits_protocol_header_when_server_supports_none() {
        let req = ws_request(Some("p1,p2"));

        assert_eq!(selected_protocol(&req, &[]), None);
    }

    #[test]
    fn handshake_omits_protocol_header_when_client_requests_none() {
        let req = ws_request(None);

        assert_eq!(selected_protocol(&req, &["p1"]), None);
    }

    #[test]
    fn handshake_trims_whitespace_around_requested_protocols() {
        let req = ws_request(Some(" , ,  p1 ,p2"));

        assert_eq!(
            selected_protocol(&req, &["p1"]),
            Some(HeaderValue::from_static("p1")),
        );
    }

    #[test]
    fn handshake_supports_multiple_protocol_headers() {
        let req = ws_request_builder()
            .append_header((header::SEC_WEBSOCKET_PROTOCOL, "p1"))
            .append_header((header::SEC_WEBSOCKET_PROTOCOL, "p2"))
            .to_http_request();

        assert_eq!(
            selected_protocol(&req, &["p2"]),
            Some(HeaderValue::from_static("p2")),
        );
    }

    #[test]
    fn handshake_rejects_request_without_upgrade_headers() {
        let req = TestRequest::default().to_http_request();

        assert!(handshake_with_protocols(&req, &["p1"]).is_err());
    }
}