rouille_ng/websocket/mod.rs
1// Copyright (c) 2016 The Rouille developers
2// Licensed under the Apache License, Version 2.0
3// <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT
5// license <LICENSE-MIT or http://opensource.org/licenses/MIT>,
6// at your option. All files in the project carrying such
7// notice may not be copied, modified, or distributed except
8// according to those terms.
9
10//! Support for websockets.
11//!
12//! Using websockets is done with the following steps:
13//!
14//! - The websocket client (usually the browser through some Javascript) must send a request to the
15//! server to initiate the process. Examples for how to do this in Javascript are out of scope
16//! of this documentation but should be easy to find on the web.
17//! - The server written with rouille_ng must answer that request with the `start()` function defined
18//! in this module. This function returns an error if the request is not a websocket
19//! initialization request.
20//! - The `start()` function also returns a `Receiver<Websocket>` object. Once that `Receiver`
21//! contains a value, the connection has been initiated.
//! - You can then use the `Websocket` object to communicate with the client by exchanging
//! `Message`s (see the `Websocket` documentation for the available methods).
24//!
25//! # Subprotocols
26//!
27//! The websocket connection will produce either text or binary messages. But these messages do not
28//! have a meaning per se, and must also be interpreted in some way. The way messages are
29//! interpreted during a websocket connection is called a *subprotocol*.
30//!
31//! When you call `start()` you have to indicate which subprotocol the connection is going to use.
32//! This subprotocol must match one of the subprotocols that were passed by the client during its
33//! request, otherwise `start()` will return an error. It is also possible to pass `None`, in which
34//! case the subprotocol is unknown to both the client and the server.
35//!
36//! There are usually three ways to handle subprotocols on the server-side:
37//!
38//! - You don't really care about subprotocols because you use websockets for your own needs. You
39//! can just pass `None` to `start()`. The connection will thus never fail unless the client
40//! decides to.
41//! - Your route only handles one subprotocol. Just pass this subprotocol to `start()` and you will
42//! get an error (which you can handle for example with `try_or_400!`) if it's not supported by
43//! the client.
44//! - Your route supports multiple subprotocols. This is the most complex situation as you will
45//! have to enumerate the protocols with `requested_protocols()` and choose one.
46//!
47//! # Example
48//!
49//! ```
50//! # #[macro_use] extern crate rouille_ng;
51//! use std::sync::Mutex;
52//! use std::sync::mpsc::Receiver;
53//!
54//! use rouille_ng::Request;
55//! use rouille_ng::Response;
56//! use rouille_ng::websocket;
57//! # fn main() {}
58//!
59//! fn handle_request(request: &Request, websockets: &Mutex<Vec<Receiver<websocket::Websocket>>>)
60//! -> Response
61//! {
62//! let (response, websocket) = try_or_400!(websocket::start(request, Some("my-subprotocol")));
63//! websockets.lock().unwrap().push(websocket);
64//! response
65//! }
66//! ```
67
68pub use self::websocket::Message;
69pub use self::websocket::SendError;
70pub use self::websocket::Websocket;
71
72use base64;
73use sha1::Sha1;
74use std::borrow::Cow;
75use std::error;
76use std::fmt;
77use std::sync::mpsc;
78use std::vec::IntoIter as VecIntoIter;
79
80use Request;
81use Response;
82
83mod low_level;
84mod websocket;
85
/// Error that can happen when attempting to start a websocket connection with `start()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WebsocketError {
    /// The request does not match a websocket request.
    ///
    /// `start()` returns this variant unless all of the following hold:
    /// - The method is `GET`.
    /// - The `Connection` header contains `upgrade` (compared case-insensitively).
    /// - The `Upgrade` header contains `websocket` (compared case-insensitively).
    /// - The `Sec-WebSocket-Version` header is exactly `13`.
    /// - A `Sec-WebSocket-Key` header is present.
    ///
    /// The HTTP version and the `Host` header are currently not verified.
    InvalidWebsocketRequest,

    /// The subprotocol passed to `start()` was not listed by the client in its
    /// `Sec-WebSocket-Protocol` header.
    WrongSubprotocol,
}

impl error::Error for WebsocketError {}

impl fmt::Display for WebsocketError {
    #[inline]
    fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> {
        let description = match *self {
            WebsocketError::InvalidWebsocketRequest => {
                "the request does not match a websocket request"
            }
            WebsocketError::WrongSubprotocol => {
                "the subprotocol passed to the function was not requested by the client"
            }
        };

        // Written verbatim; width/padding flags of the outer formatter are not applied.
        fmt.write_str(description)
    }
}
120}
121
122/// Builds a `Response` that initiates the websocket protocol.
123pub fn start<S>(
124 request: &Request,
125 subprotocol: Option<S>,
126) -> Result<(Response, mpsc::Receiver<Websocket>), WebsocketError>
127where
128 S: Into<Cow<'static, str>>,
129{
130 let subprotocol = subprotocol.map(|s| s.into());
131
132 if request.method() != "GET" {
133 return Err(WebsocketError::InvalidWebsocketRequest);
134 }
135
136 // TODO:
137 /*if request.http_version() < &HTTPVersion(1, 1) {
138 return Err(WebsocketError::InvalidWebsocketRequest);
139 }*/
140
141 match request.header("Connection") {
142 Some(ref h) if h.to_ascii_lowercase().contains("upgrade") => (),
143 _ => return Err(WebsocketError::InvalidWebsocketRequest),
144 }
145
146 match request.header("Upgrade") {
147 Some(ref h) if h.to_ascii_lowercase().contains("websocket") => (),
148 _ => return Err(WebsocketError::InvalidWebsocketRequest),
149 }
150
151 // TODO: there are some version shanigans to handle
152 // see https://tools.ietf.org/html/rfc6455#section-4.4
153 match request.header("Sec-WebSocket-Version") {
154 Some(h) if h == "13" => (),
155 _ => return Err(WebsocketError::InvalidWebsocketRequest),
156 }
157
158 if let Some(ref sp) = subprotocol {
159 if !requested_protocols(request).any(|p| &p == sp) {
160 return Err(WebsocketError::WrongSubprotocol);
161 }
162 }
163
164 let key = {
165 let in_key = match request.header("Sec-WebSocket-Key") {
166 Some(h) => h,
167 None => return Err(WebsocketError::InvalidWebsocketRequest),
168 };
169
170 convert_key(&in_key)
171 };
172
173 let (tx, rx) = mpsc::channel();
174
175 let mut response = Response::text("");
176 response.status_code = 101;
177 response
178 .headers
179 .push(("Upgrade".into(), "websocket".into()));
180 if let Some(sp) = subprotocol {
181 response.headers.push(("Sec-Websocket-Protocol".into(), sp));
182 }
183 response
184 .headers
185 .push(("Sec-Websocket-Accept".into(), key.into()));
186 response.upgrade = Some(Box::new(tx) as Box<_>);
187 Ok((response, rx))
188}
189
190/// Returns a list of the websocket protocols requested by the client.
191///
192/// # Example
193///
194/// ```
195/// use rouille_ng::websocket;
196///
197/// # let request: rouille_ng::Request = return;
198/// for protocol in websocket::requested_protocols(&request) {
199/// // ...
200/// }
201/// ```
202// TODO: return references to the request
203pub fn requested_protocols(request: &Request) -> RequestedProtocolsIter {
204 match request.header("Sec-WebSocket-Protocol") {
205 None => RequestedProtocolsIter {
206 iter: Vec::new().into_iter(),
207 },
208 Some(h) => {
209 let iter = h
210 .split(',')
211 .map(|s| s.trim())
212 .filter(|s| !s.is_empty())
213 .map(|s| s.to_owned())
214 .collect::<Vec<_>>()
215 .into_iter();
216 RequestedProtocolsIter { iter }
217 }
218 }
219}
220
221/// Iterator to the list of protocols requested by the user.
222pub struct RequestedProtocolsIter {
223 iter: VecIntoIter<String>,
224}
225
226impl Iterator for RequestedProtocolsIter {
227 type Item = String;
228
229 #[inline]
230 fn next(&mut self) -> Option<String> {
231 self.iter.next()
232 }
233
234 #[inline]
235 fn size_hint(&self) -> (usize, Option<usize>) {
236 self.iter.size_hint()
237 }
238}
239
240impl ExactSizeIterator for RequestedProtocolsIter {}
241
242/// Turns a `Sec-WebSocket-Key` into a `Sec-WebSocket-Accept`.
243fn convert_key(input: &str) -> String {
244 let mut sha1 = Sha1::new();
245 sha1.update(input.as_bytes());
246 sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
247
248 base64::encode_config(&sha1.digest().bytes(), base64::STANDARD)
249}