rouille_ng/websocket/
mod.rs

1// Copyright (c) 2016 The Rouille developers
2// Licensed under the Apache License, Version 2.0
3// <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT
5// license <LICENSE-MIT or http://opensource.org/licenses/MIT>,
6// at your option. All files in the project carrying such
7// notice may not be copied, modified, or distributed except
8// according to those terms.
9
10//! Support for websockets.
11//!
12//! Using websockets is done with the following steps:
13//!
//! - The websocket client (usually the browser through some JavaScript) must send a request to the
//!   server to initiate the process. Examples of how to do this in JavaScript are out of scope
//!   of this documentation but should be easy to find on the web.
17//! - The server written with rouille_ng must answer that request with the `start()` function defined
18//!   in this module. This function returns an error if the request is not a websocket
19//!   initialization request.
20//! - The `start()` function also returns a `Receiver<Websocket>` object. Once that `Receiver`
21//!   contains a value, the connection has been initiated.
22//! - You can then use the `Websocket` object to communicate with the client through the `Read`
23//!   and `Write` traits.
24//!
25//! # Subprotocols
26//!
27//! The websocket connection will produce either text or binary messages. But these messages do not
28//! have a meaning per se, and must also be interpreted in some way. The way messages are
29//! interpreted during a websocket connection is called a *subprotocol*.
30//!
31//! When you call `start()` you have to indicate which subprotocol the connection is going to use.
32//! This subprotocol must match one of the subprotocols that were passed by the client during its
33//! request, otherwise `start()` will return an error. It is also possible to pass `None`, in which
34//! case the subprotocol is unknown to both the client and the server.
35//!
36//! There are usually three ways to handle subprotocols on the server-side:
37//!
38//! - You don't really care about subprotocols because you use websockets for your own needs. You
39//!   can just pass `None` to `start()`. The connection will thus never fail unless the client
40//!   decides to.
41//! - Your route only handles one subprotocol. Just pass this subprotocol to `start()` and you will
42//!   get an error (which you can handle for example with `try_or_400!`) if it's not supported by
43//!   the client.
44//! - Your route supports multiple subprotocols. This is the most complex situation as you will
45//!   have to enumerate the protocols with `requested_protocols()` and choose one.
46//!
47//! # Example
48//!
49//! ```
50//! # #[macro_use] extern crate rouille_ng;
51//! use std::sync::Mutex;
52//! use std::sync::mpsc::Receiver;
53//!
54//! use rouille_ng::Request;
55//! use rouille_ng::Response;
56//! use rouille_ng::websocket;
57//! # fn main() {}
58//!
59//! fn handle_request(request: &Request, websockets: &Mutex<Vec<Receiver<websocket::Websocket>>>)
60//!                   -> Response
61//! {
62//!     let (response, websocket) = try_or_400!(websocket::start(request, Some("my-subprotocol")));
63//!     websockets.lock().unwrap().push(websocket);
64//!     response
65//! }
66//! ```
67
68pub use self::websocket::Message;
69pub use self::websocket::SendError;
70pub use self::websocket::Websocket;
71
72use base64;
73use sha1::Sha1;
74use std::borrow::Cow;
75use std::error;
76use std::fmt;
77use std::sync::mpsc;
78use std::vec::IntoIter as VecIntoIter;
79
80use Request;
81use Response;
82
83mod low_level;
84mod websocket;
85
/// Error that can happen when attempting to start a websocket.
///
/// This is the error type returned by `start()`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WebsocketError {
    /// The request does not match a websocket request.
    ///
    /// `start()` reports this variant unless all of the following hold:
    /// - The method is `GET`.
    /// - The `Connection` header contains `upgrade` (compared case-insensitively).
    /// - The `Upgrade` header contains `websocket` (compared case-insensitively).
    /// - The `Sec-WebSocket-Version` header is exactly `13`.
    /// - A `Sec-WebSocket-Key` header is present.
    ///
    /// The HTTP version and the `Host` header are not verified at the moment.
    InvalidWebsocketRequest,

    /// The subprotocol passed to the function was not requested by the client.
    WrongSubprotocol,
}

// No method is overridden: neither variant wraps an underlying error, so the
// defaults provided by the trait are sufficient.
impl error::Error for WebsocketError {}

impl fmt::Display for WebsocketError {
    /// Writes a short, lowercase, human-readable description of the error.
    #[inline]
    fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> {
        // The messages are fixed strings, so they can be written directly
        // without going through the formatting machinery.
        fmt.write_str(match *self {
            WebsocketError::InvalidWebsocketRequest => {
                "the request does not match a websocket request"
            }
            WebsocketError::WrongSubprotocol => {
                "the subprotocol passed to the function was not requested by the client"
            }
        })
    }
}
121
122/// Builds a `Response` that initiates the websocket protocol.
123pub fn start<S>(
124    request: &Request,
125    subprotocol: Option<S>,
126) -> Result<(Response, mpsc::Receiver<Websocket>), WebsocketError>
127where
128    S: Into<Cow<'static, str>>,
129{
130    let subprotocol = subprotocol.map(|s| s.into());
131
132    if request.method() != "GET" {
133        return Err(WebsocketError::InvalidWebsocketRequest);
134    }
135
136    // TODO:
137    /*if request.http_version() < &HTTPVersion(1, 1) {
138        return Err(WebsocketError::InvalidWebsocketRequest);
139    }*/
140
141    match request.header("Connection") {
142        Some(ref h) if h.to_ascii_lowercase().contains("upgrade") => (),
143        _ => return Err(WebsocketError::InvalidWebsocketRequest),
144    }
145
146    match request.header("Upgrade") {
147        Some(ref h) if h.to_ascii_lowercase().contains("websocket") => (),
148        _ => return Err(WebsocketError::InvalidWebsocketRequest),
149    }
150
151    // TODO: there are some version shanigans to handle
152    // see https://tools.ietf.org/html/rfc6455#section-4.4
153    match request.header("Sec-WebSocket-Version") {
154        Some(h) if h == "13" => (),
155        _ => return Err(WebsocketError::InvalidWebsocketRequest),
156    }
157
158    if let Some(ref sp) = subprotocol {
159        if !requested_protocols(request).any(|p| &p == sp) {
160            return Err(WebsocketError::WrongSubprotocol);
161        }
162    }
163
164    let key = {
165        let in_key = match request.header("Sec-WebSocket-Key") {
166            Some(h) => h,
167            None => return Err(WebsocketError::InvalidWebsocketRequest),
168        };
169
170        convert_key(&in_key)
171    };
172
173    let (tx, rx) = mpsc::channel();
174
175    let mut response = Response::text("");
176    response.status_code = 101;
177    response
178        .headers
179        .push(("Upgrade".into(), "websocket".into()));
180    if let Some(sp) = subprotocol {
181        response.headers.push(("Sec-Websocket-Protocol".into(), sp));
182    }
183    response
184        .headers
185        .push(("Sec-Websocket-Accept".into(), key.into()));
186    response.upgrade = Some(Box::new(tx) as Box<_>);
187    Ok((response, rx))
188}
189
190/// Returns a list of the websocket protocols requested by the client.
191///
192/// # Example
193///
194/// ```
195/// use rouille_ng::websocket;
196///
197/// # let request: rouille_ng::Request = return;
198/// for protocol in websocket::requested_protocols(&request) {
199///     // ...
200/// }
201/// ```
202// TODO: return references to the request
203pub fn requested_protocols(request: &Request) -> RequestedProtocolsIter {
204    match request.header("Sec-WebSocket-Protocol") {
205        None => RequestedProtocolsIter {
206            iter: Vec::new().into_iter(),
207        },
208        Some(h) => {
209            let iter = h
210                .split(',')
211                .map(|s| s.trim())
212                .filter(|s| !s.is_empty())
213                .map(|s| s.to_owned())
214                .collect::<Vec<_>>()
215                .into_iter();
216            RequestedProtocolsIter { iter }
217        }
218    }
219}
220
/// Iterator over the list of protocols requested by the client.
///
/// Produced by `requested_protocols()`. Each item is an owned protocol name.
#[derive(Debug, Clone)]
pub struct RequestedProtocolsIter {
    /// Owned protocol names that have not been yielded yet.
    iter: VecIntoIter<String>,
}

impl Iterator for RequestedProtocolsIter {
    type Item = String;

    /// Yields the next protocol name, or `None` once the list is exhausted.
    #[inline]
    fn next(&mut self) -> Option<String> {
        self.iter.next()
    }

    /// Forwards the size hint of the underlying vector iterator.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}

// `size_hint` is forwarded from `vec::IntoIter`, which reports its exact remaining length,
// so the default `len()` is accurate.
impl ExactSizeIterator for RequestedProtocolsIter {}
241
242/// Turns a `Sec-WebSocket-Key` into a `Sec-WebSocket-Accept`.
243fn convert_key(input: &str) -> String {
244    let mut sha1 = Sha1::new();
245    sha1.update(input.as_bytes());
246    sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
247
248    base64::encode_config(&sha1.digest().bytes(), base64::STANDARD)
249}