rouille_maint_in/websocket/mod.rs
1// Copyright (c) 2016 The Rouille developers
2// Licensed under the Apache License, Version 2.0
3// <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT
5// license <LICENSE-MIT or http://opensource.org/licenses/MIT>,
6// at your option. All files in the project carrying such
7// notice may not be copied, modified, or distributed except
8// according to those terms.
9
10//! Support for websockets.
11//!
12//! Using websockets is done with the following steps:
13//!
//! - The websocket client (usually the browser, through some JavaScript) must send a request to
//!   the server to initiate the process. Examples of how to do this in JavaScript are out of
//!   scope for this documentation but should be easy to find on the web.
17//! - The server written with rouille must answer that request with the `start()` function defined
18//! in this module. This function returns an error if the request is not a websocket
19//! initialization request.
20//! - The `start()` function also returns a `Receiver<Websocket>` object. Once that `Receiver`
21//! contains a value, the connection has been initiated.
22//! - You can then use the `Websocket` object to communicate with the client through the `Read`
23//! and `Write` traits.
24//!
25//! # Subprotocols
26//!
27//! The websocket connection will produce either text or binary messages. But these messages do not
28//! have a meaning per se, and must also be interpreted in some way. The way messages are
29//! interpreted during a websocket connection is called a *subprotocol*.
30//!
31//! When you call `start()` you have to indicate which subprotocol the connection is going to use.
32//! This subprotocol must match one of the subprotocols that were passed by the client during its
33//! request, otherwise `start()` will return an error. It is also possible to pass `None`, in which
34//! case the subprotocol is unknown to both the client and the server.
35//!
36//! There are usually three ways to handle subprotocols on the server-side:
37//!
38//! - You don't really care about subprotocols because you use websockets for your own needs. You
39//! can just pass `None` to `start()`. The connection will thus never fail unless the client
40//! decides to.
41//! - Your route only handles one subprotocol. Just pass this subprotocol to `start()` and you will
42//! get an error (which you can handle for example with `try_or_400!`) if it's not supported by
43//! the client.
44//! - Your route supports multiple subprotocols. This is the most complex situation as you will
45//! have to enumerate the protocols with `requested_protocols()` and choose one.
46//!
47//! # Example
48//!
49//! ```
50//! # #[macro_use] extern crate rouille_maint_in as rouille;
51//! use std::sync::Mutex;
52//! use std::sync::mpsc::Receiver;
53//!
54//! use rouille::Request;
55//! use rouille::Response;
56//! use rouille::websocket;
57//! # fn main() {}
58//!
59//! fn handle_request(request: &Request, websockets: &Mutex<Vec<Receiver<websocket::Websocket>>>)
60//! -> Response
61//! {
62//! let (response, websocket) = try_or_400!(websocket::start(request, Some("my-subprotocol")));
63//! websockets.lock().unwrap().push(websocket);
64//! response
65//! }
66//! ```
67
68pub use self::websocket::Message;
69pub use self::websocket::SendError;
70pub use self::websocket::Websocket;
71
72use base64;
73use std::borrow::Cow;
74use std::error;
75use std::fmt;
76use std::sync::mpsc;
77use std::vec::IntoIter as VecIntoIter;
78use sha1::Sha1;
79
80use crate::Request;
81use crate::Response;
82
83mod low_level;
84mod websocket;
85
/// Error that can happen when attempting to start a websocket connection.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WebsocketError {
    /// The request does not match a websocket request.
    ///
    /// `start()` returns this error unless all of the following hold:
    /// - The method is `GET`.
    /// - The `Connection` header contains `upgrade` (compared case-insensitively).
    /// - The `Upgrade` header contains `websocket` (compared case-insensitively).
    /// - The `Sec-WebSocket-Version` header is exactly `13`.
    /// - A `Sec-WebSocket-Key` header is present.
    ///
    /// The HTTP version and the `Host` header are not currently checked, even though
    /// RFC 6455 requires HTTP/1.1 or later and a `Host` header.
    InvalidWebsocketRequest,

    /// The subprotocol passed to the function was not requested by the client.
    WrongSubprotocol,
}

impl error::Error for WebsocketError {}

impl fmt::Display for WebsocketError {
    #[inline]
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        let description = match self {
            WebsocketError::InvalidWebsocketRequest => {
                "the request does not match a websocket request"
            }
            WebsocketError::WrongSubprotocol => {
                "the subprotocol passed to the function was not requested by the client"
            }
        };

        // Writes the fixed message directly; like the previous `write!(fmt, "{}", ..)`,
        // this ignores any width/fill options on the outer formatter.
        fmt.write_str(description)
    }
}
121
122/// Builds a `Response` that initiates the websocket protocol.
123pub fn start<S>(request: &Request, subprotocol: Option<S>)
124 -> Result<(Response, mpsc::Receiver<Websocket>), WebsocketError>
125 where S: Into<Cow<'static, str>>
126{
127 let subprotocol = subprotocol.map(|s| s.into());
128
129 if request.method() != "GET" {
130 return Err(WebsocketError::InvalidWebsocketRequest);
131 }
132
133 // TODO:
134 /*if request.http_version() < &HTTPVersion(1, 1) {
135 return Err(WebsocketError::InvalidWebsocketRequest);
136 }*/
137
138 match request.header("Connection") {
139 Some(ref h) if h.to_ascii_lowercase().contains("upgrade") => (),
140 _ => return Err(WebsocketError::InvalidWebsocketRequest),
141 }
142
143 match request.header("Upgrade") {
144 Some(ref h) if h.to_ascii_lowercase().contains("websocket") => (),
145 _ => return Err(WebsocketError::InvalidWebsocketRequest),
146 }
147
148 // TODO: there are some version shanigans to handle
149 // see https://tools.ietf.org/html/rfc6455#section-4.4
150 match request.header("Sec-WebSocket-Version") {
151 Some(h) if h == "13" => (),
152 _ => return Err(WebsocketError::InvalidWebsocketRequest),
153 }
154
155 if let Some(ref sp) = subprotocol {
156 if !requested_protocols(request).any(|p| &p == sp) {
157 return Err(WebsocketError::WrongSubprotocol);
158 }
159 }
160
161 let key = {
162 let in_key = match request.header("Sec-WebSocket-Key") {
163 Some(h) => h,
164 None => return Err(WebsocketError::InvalidWebsocketRequest),
165 };
166
167 convert_key(&in_key)
168 };
169
170 let (tx, rx) = mpsc::channel();
171
172 let mut response = Response::text("");
173 response.status_code = 101;
174 response.headers.push(("Upgrade".into(), "websocket".into()));
175 if let Some(sp) = subprotocol {
176 response.headers.push(("Sec-Websocket-Protocol".into(), sp));
177 }
178 response.headers.push(("Sec-Websocket-Accept".into(), key.into()));
179 response.upgrade = Some(Box::new(tx) as Box<_>);
180 Ok((response, rx))
181}
182
183/// Returns a list of the websocket protocols requested by the client.
184///
185/// # Example
186///
187/// ```
188/// # use rouille_maint_in as rouille;
189/// use rouille::websocket;
190///
191/// # let request: rouille::Request = return;
192/// for protocol in websocket::requested_protocols(&request) {
193/// // ...
194/// }
195/// ```
196// TODO: return references to the request
197pub fn requested_protocols(request: &Request) -> RequestedProtocolsIter {
198 match request.header("Sec-WebSocket-Protocol") {
199 None => RequestedProtocolsIter { iter: Vec::new().into_iter() },
200 Some(h) => {
201 let iter = h.split(',')
202 .map(|s| s.trim())
203 .filter(|s| !s.is_empty())
204 .map(|s| s.to_owned())
205 .collect::<Vec<_>>().into_iter();
206 RequestedProtocolsIter { iter }
207 }
208 }
209}
210
/// Iterator to the list of protocols requested by the client.
///
/// Created by `requested_protocols()`. Yields each protocol as an owned `String`.
#[derive(Debug, Clone)]
pub struct RequestedProtocolsIter {
    iter: VecIntoIter<String>,
}

impl Iterator for RequestedProtocolsIter {
    type Item = String;

    #[inline]
    fn next(&mut self) -> Option<String> {
        self.iter.next()
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        // Exact, because it comes straight from the underlying `vec::IntoIter`.
        self.iter.size_hint()
    }
}

impl ExactSizeIterator for RequestedProtocolsIter {}

// `vec::IntoIter` keeps returning `None` once exhausted, so this wrapper does too.
impl std::iter::FusedIterator for RequestedProtocolsIter {}
232
233/// Turns a `Sec-WebSocket-Key` into a `Sec-WebSocket-Accept`.
234fn convert_key(input: &str) -> String {
235 let mut sha1 = Sha1::new();
236 sha1.update(input.as_bytes());
237 sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
238
239 base64::encode_config(&sha1.digest().bytes(), base64::STANDARD)
240}