pub use self::websocket::Message;
pub use self::websocket::SendError;
pub use self::websocket::Websocket;
use base64;
use sha1::{Digest, Sha1};
use std::borrow::Cow;
use std::error;
use std::fmt;
use std::sync::mpsc;
use std::vec::IntoIter as VecIntoIter;
use Request;
use Response;
mod low_level;
#[allow(clippy::module_inception)]
mod websocket;
#[derive(Debug)]
pub enum WebsocketError {
InvalidWebsocketRequest,
WrongSubprotocol,
}
impl error::Error for WebsocketError {}
impl fmt::Display for WebsocketError {
#[inline]
fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> {
let description = match *self {
WebsocketError::InvalidWebsocketRequest => {
"the request does not match a websocket request"
}
WebsocketError::WrongSubprotocol => {
"the subprotocol passed to the function was not requested by the client"
}
};
write!(fmt, "{}", description)
}
}
/// Validates a websocket handshake request and builds the matching `101` response.
///
/// `subprotocol`, when provided, is echoed back in the `Sec-Websocket-Protocol` response
/// header; it must be one of the protocols listed by the request (see
/// `requested_protocols`).
///
/// On success, returns the response together with the receiving half of a channel whose
/// sending half has been stored in the response's `upgrade` field.
/// NOTE(review): presumably the `Websocket` is delivered on that receiver once the response
/// has been sent and the connection upgraded — confirm against the code consuming `upgrade`.
///
/// # Errors
///
/// * `WebsocketError::InvalidWebsocketRequest` if the method is not `GET`, if the
///   `Connection` header does not contain `upgrade` or the `Upgrade` header does not contain
///   `websocket` (both compared case-insensitively), if `Sec-WebSocket-Version` is not
///   exactly `13`, or if `Sec-WebSocket-Key` is missing.
/// * `WebsocketError::WrongSubprotocol` if `subprotocol` was not requested by the client.
pub fn start<S>(
    request: &Request,
    subprotocol: Option<S>,
) -> Result<(Response, mpsc::Receiver<Websocket>), WebsocketError>
where
    S: Into<Cow<'static, str>>,
{
    let subprotocol: Option<Cow<'static, str>> = subprotocol.map(Into::into);

    // True when the header exists and its lowercased value contains `token`.
    let header_mentions = |name: &str, token: &str| {
        request
            .header(name)
            .map_or(false, |value| value.to_ascii_lowercase().contains(token))
    };

    // Checked left to right with short-circuiting, in the same order as the protocol
    // requirements are usually listed: method, Connection, Upgrade, version.
    let is_upgrade_request = request.method() == "GET"
        && header_mentions("Connection", "upgrade")
        && header_mentions("Upgrade", "websocket")
        && request
            .header("Sec-WebSocket-Version")
            .map_or(false, |version| version == "13");
    if !is_upgrade_request {
        return Err(WebsocketError::InvalidWebsocketRequest);
    }

    // The subprotocol is validated before the key so that a bad subprotocol is reported
    // even when the key is also missing.
    if let Some(ref wanted) = subprotocol {
        if requested_protocols(request).all(|offered| &offered != wanted) {
            return Err(WebsocketError::WrongSubprotocol);
        }
    }

    let client_key = request
        .header("Sec-WebSocket-Key")
        .ok_or(WebsocketError::InvalidWebsocketRequest)?;
    let accept_key = convert_key(client_key);

    let (tx, rx) = mpsc::channel();

    let mut response = Response::text("");
    response.status_code = 101;
    response
        .headers
        .push(("Upgrade".into(), "websocket".into()));
    if let Some(chosen) = subprotocol {
        response
            .headers
            .push(("Sec-Websocket-Protocol".into(), chosen));
    }
    response
        .headers
        .push(("Sec-Websocket-Accept".into(), accept_key.into()));
    response.upgrade = Some(Box::new(tx) as Box<_>);

    Ok((response, rx))
}
/// Returns the subprotocols listed in the request's `Sec-WebSocket-Protocol` header.
///
/// The header value is split on commas; each entry is trimmed and empty entries are
/// dropped. The entries are yielded in the order they appear in the header. If the header
/// is absent, the returned iterator is empty.
pub fn requested_protocols(request: &Request) -> RequestedProtocolsIter {
    let protocols: Vec<String> = request
        .header("Sec-WebSocket-Protocol")
        .map(|header| {
            header
                .split(',')
                .map(|entry| entry.trim())
                .filter(|entry| !entry.is_empty())
                .map(String::from)
                .collect()
        })
        .unwrap_or_default();

    RequestedProtocolsIter {
        iter: protocols.into_iter(),
    }
}
/// Iterator over the subprotocol names requested by a client.
///
/// Wraps an owning iterator over a `Vec<String>`, so the number of remaining items is
/// always known exactly.
pub struct RequestedProtocolsIter {
    // Already-parsed protocol names, yielded front to back.
    iter: VecIntoIter<String>,
}

impl Iterator for RequestedProtocolsIter {
    type Item = String;

    /// Yields the next protocol name, or `None` once all have been consumed.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        self.iter.next()
    }

    /// Reports the exact number of remaining names as both lower and upper bound.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.iter.len();
        (remaining, Some(remaining))
    }
}

impl ExactSizeIterator for RequestedProtocolsIter {}
/// Derives the `Sec-Websocket-Accept` value from the client's `Sec-WebSocket-Key`.
///
/// The key is hashed with SHA-1 together with a fixed GUID suffix, and the resulting digest
/// is returned encoded as standard base64.
fn convert_key(input: &str) -> String {
    // Fixed GUID appended to the client key before hashing (websocket handshake constant).
    const HANDSHAKE_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    let mut hasher = Sha1::new();
    hasher.update(input.as_bytes());
    hasher.update(HANDSHAKE_GUID);
    let digest = hasher.finalize();

    base64::encode_config(&digest, base64::STANDARD)
}