actori_http/ws/
mod.rs

//! WebSocket protocol support.
//!
//! To set up a `WebSocket` connection, first perform the WebSocket handshake
//! (see `handshake` / `verify_handshake`); on success, communicate with the
//! peer using the types exported by this module (`Codec`, `Dispatcher`, ...).
6use std::io;
7
8use derive_more::{Display, From};
9use http::{header, Method, StatusCode};
10
11use crate::error::ResponseError;
12use crate::message::RequestHead;
13use crate::response::{Response, ResponseBuilder};
14
15mod codec;
16mod dispatcher;
17mod frame;
18mod mask;
19mod proto;
20
21pub use self::codec::{Codec, Frame, Item, Message};
22pub use self::dispatcher::Dispatcher;
23pub use self::frame::Parser;
24pub use self::proto::{hash_key, CloseCode, CloseReason, OpCode};
25
26/// Websocket protocol errors
27#[derive(Debug, Display, From)]
28pub enum ProtocolError {
29    /// Received an unmasked frame from client
30    #[display(fmt = "Received an unmasked frame from client")]
31    UnmaskedFrame,
32    /// Received a masked frame from server
33    #[display(fmt = "Received a masked frame from server")]
34    MaskedFrame,
35    /// Encountered invalid opcode
36    #[display(fmt = "Invalid opcode: {}", _0)]
37    InvalidOpcode(u8),
38    /// Invalid control frame length
39    #[display(fmt = "Invalid control frame length: {}", _0)]
40    InvalidLength(usize),
41    /// Bad web socket op code
42    #[display(fmt = "Bad web socket op code")]
43    BadOpCode,
44    /// A payload reached size limit.
45    #[display(fmt = "A payload reached size limit.")]
46    Overflow,
47    /// Continuation is not started
48    #[display(fmt = "Continuation is not started.")]
49    ContinuationNotStarted,
50    /// Received new continuation but it is already started
51    #[display(fmt = "Received new continuation but it is already started")]
52    ContinuationStarted,
53    /// Unknown continuation fragment
54    #[display(fmt = "Unknown continuation fragment.")]
55    ContinuationFragment(OpCode),
56    /// Io error
57    #[display(fmt = "io error: {}", _0)]
58    Io(io::Error),
59}
60
// Empty impl: `ProtocolError` relies entirely on the default `ResponseError`
// methods; no protocol-specific HTTP response is customised here.
impl ResponseError for ProtocolError {}
62
63/// Websocket handshake errors
64#[derive(PartialEq, Debug, Display)]
65pub enum HandshakeError {
66    /// Only get method is allowed
67    #[display(fmt = "Method not allowed")]
68    GetMethodRequired,
69    /// Upgrade header if not set to websocket
70    #[display(fmt = "Websocket upgrade is expected")]
71    NoWebsocketUpgrade,
72    /// Connection header is not set to upgrade
73    #[display(fmt = "Connection upgrade is expected")]
74    NoConnectionUpgrade,
75    /// Websocket version header is not set
76    #[display(fmt = "Websocket version header is required")]
77    NoVersionHeader,
78    /// Unsupported websocket version
79    #[display(fmt = "Unsupported version")]
80    UnsupportedVersion,
81    /// Websocket key is not set or wrong
82    #[display(fmt = "Unknown websocket key")]
83    BadWebsocketKey,
84}
85
86impl ResponseError for HandshakeError {
87    fn error_response(&self) -> Response {
88        match *self {
89            HandshakeError::GetMethodRequired => Response::MethodNotAllowed()
90                .header(header::ALLOW, "GET")
91                .finish(),
92            HandshakeError::NoWebsocketUpgrade => Response::BadRequest()
93                .reason("No WebSocket UPGRADE header found")
94                .finish(),
95            HandshakeError::NoConnectionUpgrade => Response::BadRequest()
96                .reason("No CONNECTION upgrade")
97                .finish(),
98            HandshakeError::NoVersionHeader => Response::BadRequest()
99                .reason("Websocket version header is required")
100                .finish(),
101            HandshakeError::UnsupportedVersion => Response::BadRequest()
102                .reason("Unsupported version")
103                .finish(),
104            HandshakeError::BadWebsocketKey => {
105                Response::BadRequest().reason("Handshake error").finish()
106            }
107        }
108    }
109}
110
111/// Verify `WebSocket` handshake request and create handshake reponse.
112// /// `protocols` is a sequence of known protocols. On successful handshake,
113// /// the returned response headers contain the first protocol in this list
114// /// which the server also knows.
115pub fn handshake(req: &RequestHead) -> Result<ResponseBuilder, HandshakeError> {
116    verify_handshake(req)?;
117    Ok(handshake_response(req))
118}
119
120/// Verify `WebSocket` handshake request.
121// /// `protocols` is a sequence of known protocols. On successful handshake,
122// /// the returned response headers contain the first protocol in this list
123// /// which the server also knows.
124pub fn verify_handshake(req: &RequestHead) -> Result<(), HandshakeError> {
125    // WebSocket accepts only GET
126    if req.method != Method::GET {
127        return Err(HandshakeError::GetMethodRequired);
128    }
129
130    // Check for "UPGRADE" to websocket header
131    let has_hdr = if let Some(hdr) = req.headers().get(header::UPGRADE) {
132        if let Ok(s) = hdr.to_str() {
133            s.to_ascii_lowercase().contains("websocket")
134        } else {
135            false
136        }
137    } else {
138        false
139    };
140    if !has_hdr {
141        return Err(HandshakeError::NoWebsocketUpgrade);
142    }
143
144    // Upgrade connection
145    if !req.upgrade() {
146        return Err(HandshakeError::NoConnectionUpgrade);
147    }
148
149    // check supported version
150    if !req.headers().contains_key(header::SEC_WEBSOCKET_VERSION) {
151        return Err(HandshakeError::NoVersionHeader);
152    }
153    let supported_ver = {
154        if let Some(hdr) = req.headers().get(header::SEC_WEBSOCKET_VERSION) {
155            hdr == "13" || hdr == "8" || hdr == "7"
156        } else {
157            false
158        }
159    };
160    if !supported_ver {
161        return Err(HandshakeError::UnsupportedVersion);
162    }
163
164    // check client handshake for validity
165    if !req.headers().contains_key(header::SEC_WEBSOCKET_KEY) {
166        return Err(HandshakeError::BadWebsocketKey);
167    }
168    Ok(())
169}
170
171/// Create websocket's handshake response
172///
173/// This function returns handshake `Response`, ready to send to peer.
174pub fn handshake_response(req: &RequestHead) -> ResponseBuilder {
175    let key = {
176        let key = req.headers().get(header::SEC_WEBSOCKET_KEY).unwrap();
177        proto::hash_key(key.as_ref())
178    };
179
180    Response::build(StatusCode::SWITCHING_PROTOCOLS)
181        .upgrade("websocket")
182        .header(header::TRANSFER_ENCODING, "chunked")
183        .header(header::SEC_WEBSOCKET_ACCEPT, key.as_str())
184        .take()
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::TestRequest;
    use http::{header, Method};

    /// Each malformed request is rejected with the matching error, and a
    /// complete request produces a `101 Switching Protocols` response.
    #[test]
    fn test_handshake() {
        let value = header::HeaderValue::from_static;

        // (request, expected error), ordered by the check that rejects it.
        let rejected = vec![
            (
                TestRequest::default().method(Method::POST).finish(),
                HandshakeError::GetMethodRequired,
            ),
            (
                TestRequest::default().finish(),
                HandshakeError::NoWebsocketUpgrade,
            ),
            (
                TestRequest::default()
                    .header(header::UPGRADE, value("test"))
                    .finish(),
                HandshakeError::NoWebsocketUpgrade,
            ),
            (
                TestRequest::default()
                    .header(header::UPGRADE, value("websocket"))
                    .finish(),
                HandshakeError::NoConnectionUpgrade,
            ),
            (
                TestRequest::default()
                    .header(header::UPGRADE, value("websocket"))
                    .header(header::CONNECTION, value("upgrade"))
                    .finish(),
                HandshakeError::NoVersionHeader,
            ),
            (
                TestRequest::default()
                    .header(header::UPGRADE, value("websocket"))
                    .header(header::CONNECTION, value("upgrade"))
                    .header(header::SEC_WEBSOCKET_VERSION, value("5"))
                    .finish(),
                HandshakeError::UnsupportedVersion,
            ),
            (
                TestRequest::default()
                    .header(header::UPGRADE, value("websocket"))
                    .header(header::CONNECTION, value("upgrade"))
                    .header(header::SEC_WEBSOCKET_VERSION, value("13"))
                    .finish(),
                HandshakeError::BadWebsocketKey,
            ),
        ];
        for (req, expected) in rejected {
            assert_eq!(expected, verify_handshake(req.head()).unwrap_err());
        }

        let accepted = TestRequest::default()
            .header(header::UPGRADE, value("websocket"))
            .header(header::CONNECTION, value("upgrade"))
            .header(header::SEC_WEBSOCKET_VERSION, value("13"))
            .header(header::SEC_WEBSOCKET_KEY, value("13"))
            .finish();
        assert_eq!(
            StatusCode::SWITCHING_PROTOCOLS,
            handshake_response(accepted.head()).finish().status()
        );
    }

    /// Every handshake error maps to the expected HTTP status.
    #[test]
    fn test_wserror_http_response() {
        let expectations = [
            (HandshakeError::GetMethodRequired, StatusCode::METHOD_NOT_ALLOWED),
            (HandshakeError::NoWebsocketUpgrade, StatusCode::BAD_REQUEST),
            (HandshakeError::NoConnectionUpgrade, StatusCode::BAD_REQUEST),
            (HandshakeError::NoVersionHeader, StatusCode::BAD_REQUEST),
            (HandshakeError::UnsupportedVersion, StatusCode::BAD_REQUEST),
            (HandshakeError::BadWebsocketKey, StatusCode::BAD_REQUEST),
        ];
        for (err, status) in expectations.iter() {
            let resp: Response = err.error_response();
            assert_eq!(resp.status(), *status);
        }
    }
}