use bytes::BytesMut;
use crate::extension::{Param, Extension};
use http::StatusCode;
use rand::Rng;
use sha1::Sha1;
use smallvec::SmallVec;
use std::{borrow::{Borrow, Cow}, io, fmt, str};
use tokio_codec::{Decoder, Encoder};
/// Crate version, advertised in the `Server` header of accepting responses.
const SOKETTO_VERSION: &str = env!("CARGO_PKG_VERSION");
/// Fixed GUID appended to the client key before SHA-1 hashing when the
/// `Sec-WebSocket-Accept` value is computed.
const KEY: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
/// Capacity of the header buffer handed to `httparse` for each message.
const MAX_NUM_HEADERS: usize = 32;
/// Name of the header used to negotiate extensions.
const SEC_WEBSOCKET_EXTENSIONS: &str = "Sec-WebSocket-Extensions";
/// Name of the header used to negotiate sub-protocols.
const SEC_WEBSOCKET_PROTOCOL: &str = "Sec-WebSocket-Protocol";
/// Client side of the websocket opening handshake.
///
/// Encoding writes the HTTP upgrade request; decoding parses the server's answer.
#[derive(Debug)]
pub struct Client<'a> {
    /// Value sent in the `Host` header.
    host: Cow<'a, str>,
    /// Request target placed on the request line.
    resource: Cow<'a, str>,
    /// Optional value of the `Origin` header.
    origin: Option<Cow<'a, str>>,
    /// Base64 value sent as `Sec-WebSocket-Key`.
    nonce: String,
    /// Sub-protocols offered to the server.
    protocols: SmallVec<[Cow<'a, str>; 4]>,
    /// Extensions offered to the server and configured from its answer.
    extensions: SmallVec<[Box<dyn Extension + Send>; 4]>
}

impl<'a> Client<'a> {
    /// Creates a client for `resource` on `host` with a freshly generated key.
    pub fn new<H, R>(host: H, resource: R) -> Self
    where
        H: Into<Cow<'a, str>>,
        R: Into<Cow<'a, str>>
    {
        // 16 random bytes, base64-encoded, make up the Sec-WebSocket-Key.
        let mut random_bytes = [0u8; 16];
        rand::thread_rng().fill(&mut random_bytes);
        let key = base64::encode(&random_bytes);
        Client {
            host: host.into(),
            resource: resource.into(),
            origin: None,
            nonce: key,
            protocols: SmallVec::new(),
            extensions: SmallVec::new()
        }
    }

    /// The `Sec-WebSocket-Key` value this client sends.
    pub fn ws_key(&self) -> &str {
        self.nonce.as_str()
    }

    /// Sets the `Origin` header value, replacing any earlier one.
    pub fn set_origin(&mut self, o: impl Into<Cow<'a, str>>) -> &mut Self {
        let origin = o.into();
        self.origin = Some(origin);
        self
    }

    /// Adds a sub-protocol to the list offered to the server.
    pub fn add_protocol(&mut self, p: impl Into<Cow<'a, str>>) -> &mut Self {
        let protocol = p.into();
        self.protocols.push(protocol);
        self
    }

    /// Adds an extension to the list offered to the server.
    pub fn add_extension(&mut self, e: Box<dyn Extension + Send>) -> &mut Self {
        let extension = e;
        self.extensions.push(extension);
        self
    }

    /// Removes and yields all registered extensions.
    pub fn drain_extensions(&mut self) -> impl Iterator<Item = Box<dyn Extension + Send>> {
        self.extensions.drain()
    }
}
impl<'a> Encoder for Client<'a> {
    type Item = ();
    type Error = Error;

    /// Serialises the HTTP/1.1 `GET` upgrade request into `buf`.
    ///
    /// Writes `Host`, `Upgrade`, `Connection` and `Sec-WebSocket-Key`, then the
    /// optional `Origin`, then the offered sub-protocols as one comma-separated
    /// `Sec-WebSocket-Protocol` header, then the offered extensions. It finishes
    /// with `Sec-WebSocket-Version: 13` and the blank line that ends the head.
    /// Always returns `Ok(())`.
    fn encode(&mut self, _: Self::Item, buf: &mut BytesMut) -> Result<(), Self::Error> {
        // Request line.
        buf.extend_from_slice(b"GET ");
        buf.extend_from_slice(self.resource.as_bytes());
        buf.extend_from_slice(b" HTTP/1.1");

        // Mandatory handshake headers.
        buf.extend_from_slice(b"\r\nHost: ");
        buf.extend_from_slice(self.host.as_bytes());
        buf.extend_from_slice(b"\r\nUpgrade: websocket\r\nConnection: upgrade");
        buf.extend_from_slice(b"\r\nSec-WebSocket-Key: ");
        buf.extend_from_slice(self.nonce.as_bytes());

        if let Some(origin) = self.origin.as_ref() {
            buf.extend_from_slice(b"\r\nOrigin: ");
            buf.extend_from_slice(origin.as_bytes())
        }

        if !self.protocols.is_empty() {
            buf.extend_from_slice(b"\r\nSec-WebSocket-Protocol: ");
            for (index, protocol) in self.protocols.iter().enumerate() {
                if index > 0 {
                    buf.extend_from_slice(b",")
                }
                buf.extend_from_slice(protocol.as_bytes())
            }
        }

        append_extensions(&self.extensions, buf);

        // The version header is last and is followed by the empty line.
        buf.extend_from_slice(b"\r\nSec-WebSocket-Version: 13\r\n\r\n");
        Ok(())
    }
}
/// Outcome of decoding the server's handshake response.
#[derive(Debug)]
pub enum Response<'a> {
    /// The server switched protocols (status `101`).
    Accepted(Accepted<'a>),
    /// The server redirected the client.
    Redirect(Redirect)
}

/// Details of an accepted handshake.
#[derive(Debug)]
pub struct Accepted<'a> {
    /// Sub-protocol selected by the server, if any.
    protocol: Option<Cow<'a, str>>,
}

impl<'a> Accepted<'a> {
    /// The sub-protocol the server selected, or `None` if it selected none.
    pub fn protocol(&self) -> Option<&str> {
        self.protocol.as_ref().map(|selected| &**selected)
    }
}

/// A redirect answer to the handshake request.
#[derive(Debug)]
pub struct Redirect {
    /// HTTP status code of the redirect response.
    status_code: u16,
    /// Value of the response's `Location` header.
    location: String,
}

impl fmt::Display for Redirect {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(
            f,
            "redirect: code = {code}, location = \"{location}\"",
            code = self.status_code,
            location = self.location
        )
    }
}

impl Redirect {
    /// The HTTP status code of the redirect.
    pub fn status_code(&self) -> u16 {
        let code = self.status_code;
        code
    }

    /// The target given in the `Location` header.
    pub fn location(&self) -> &str {
        self.location.as_str()
    }
}
impl<'a> Decoder for Client<'a> {
type Item = Response<'a>;
type Error = Error;
fn decode(&mut self, bytes: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
let mut header_buf = [httparse::EMPTY_HEADER; MAX_NUM_HEADERS];
let mut response = httparse::Response::new(&mut header_buf);
let offset = match response.parse(bytes) {
Ok(httparse::Status::Complete(off)) => off,
Ok(httparse::Status::Partial) => return Ok(None),
Err(e) => return Err(Error::Http(Box::new(e)))
};
if response.version != Some(1) {
return Err(Error::UnsupportedHttpVersion)
}
match response.code {
Some(101) => (),
Some(code@(301 ..= 303)) | Some(code@307) | Some(code@308) => { let location = with_header(&response.headers, "Location", |loc| {
Ok(String::from(std::str::from_utf8(loc)?))
})?;
bytes.split_to(offset); let response = Redirect { status_code: code, location };
return Ok(Some(Response::Redirect(response)))
}
other => return Err(Error::UnexpectedStatusCode(other.unwrap_or(0)))
}
expect_ascii_header(&response.headers, "Upgrade", "websocket")?;
expect_ascii_header(&response.headers, "Connection", "upgrade")?;
let nonce: &str = self.nonce.borrow();
with_header(&response.headers, "Sec-WebSocket-Accept", move |theirs| {
let mut digest = Sha1::new();
digest.update(nonce.as_bytes());
digest.update(KEY);
let ours = base64::encode(&digest.digest().bytes());
if ours.as_bytes() != theirs {
return Err(Error::InvalidSecWebSocketAccept)
}
Ok(())
})?;
for h in response.headers.iter().filter(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_EXTENSIONS)) {
configure_extensions(&mut self.extensions, std::str::from_utf8(h.value)?)?
}
let their_proto = response.headers
.iter()
.find(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_PROTOCOL));
let mut selected_proto = None;
if let Some(tp) = their_proto {
if let Some(p) = self.protocols.iter().find(|x| x.as_bytes() == tp.value) {
selected_proto = Some(p.clone())
} else {
return Err(Error::UnsolicitedProtocol)
}
}
bytes.split_to(offset);
let response = Accepted { protocol: selected_proto };
Ok(Some(Response::Accepted(response)))
}
}
/// Server side of the websocket opening handshake.
///
/// Decoding parses a client's upgrade request; encoding writes the accept or
/// reject answer.
#[derive(Debug, Default)]
pub struct Server<'a> {
    /// Sub-protocols this server is willing to speak.
    protocols: SmallVec<[Cow<'a, str>; 4]>,
    /// Extensions this server supports.
    extensions: SmallVec<[Box<dyn Extension + Send>; 4]>
}

impl<'a> Server<'a> {
    /// Creates a server with no sub-protocols and no extensions.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers a sub-protocol the server supports.
    pub fn add_protocol(&mut self, p: impl Into<Cow<'a, str>>) -> &mut Self {
        let protocol = p.into();
        self.protocols.push(protocol);
        self
    }

    /// Registers an extension the server supports.
    pub fn add_extension(&mut self, e: Box<dyn Extension + Send>) -> &mut Self {
        let extension = e;
        self.extensions.push(extension);
        self
    }

    /// Removes and yields all registered extensions.
    pub fn drain_extensions(&mut self) -> impl Iterator<Item = Box<dyn Extension + Send>> {
        self.extensions.drain()
    }
}
/// A client handshake request as decoded by `Server`.
#[derive(Debug)]
pub struct Request<'a> {
    /// Raw bytes of the client's `Sec-WebSocket-Key` header.
    ws_key: SmallVec<[u8; 32]>,
    /// Offered sub-protocols that the server also registered.
    protocols: SmallVec<[Cow<'a, str>; 4]>
}

impl<'a> Request<'a> {
    /// The client's `Sec-WebSocket-Key` value.
    pub fn key(&self) -> &[u8] {
        &self.ws_key[..]
    }

    /// The mutually supported sub-protocols, in request order.
    pub fn protocols(&self) -> impl Iterator<Item = &str> {
        self.protocols.iter().map(|protocol| &**protocol)
    }
}
impl<'a> Decoder for Server<'a> {
    type Item = Request<'a>;
    type Error = Error;

    /// Parses a client's handshake request from `bytes`.
    ///
    /// Returns `Ok(None)` while the request head is incomplete. Once the head is
    /// complete, this method:
    ///
    /// - validates the request;
    /// - configures every extension named in `Sec-WebSocket-Extensions`;
    /// - collects the offered sub-protocols that this server also registered;
    /// - removes the request head from `bytes`.
    ///
    /// # Errors
    ///
    /// Fails in these cases:
    ///
    /// - HTTP parse errors;
    /// - a method other than `GET`, or an HTTP version other than 1.1;
    /// - a missing `Host` or `Sec-WebSocket-Key` header;
    /// - a missing or mismatching `Upgrade`, `Connection` or `Sec-WebSocket-Version` header;
    /// - non-UTF-8 extension or sub-protocol header values;
    /// - extension configuration errors.
    fn decode(&mut self, bytes: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        let mut header_buf = [httparse::EMPTY_HEADER; MAX_NUM_HEADERS];
        let mut request = httparse::Request::new(&mut header_buf);

        let offset = match request.parse(bytes) {
            Ok(httparse::Status::Complete(off)) => off,
            Ok(httparse::Status::Partial) => return Ok(None),
            Err(e) => return Err(Error::Http(Box::new(e)))
        };

        if request.method != Some("GET") {
            return Err(Error::InvalidRequestMethod)
        }
        if request.version != Some(1) {
            return Err(Error::UnsupportedHttpVersion)
        }

        // `Host` only has to be present; its value is not inspected.
        with_header(&request.headers, "Host", |_h| Ok(()))?;
        expect_ascii_header(&request.headers, "Upgrade", "websocket")?;
        expect_ascii_header(&request.headers, "Connection", "upgrade")?;
        expect_ascii_header(&request.headers, "Sec-WebSocket-Version", "13")?;

        let ws_key = with_header(&request.headers, "Sec-WebSocket-Key", |k| {
            Ok(SmallVec::from(k))
        })?;

        for h in request.headers.iter().filter(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_EXTENSIONS)) {
            configure_extensions(&mut self.extensions, std::str::from_utf8(h.value)?)?
        }

        // Sub-protocols may be offered in separate headers or as a comma-separated
        // list in a single header (this is what `Client` sends). Each list element
        // is matched on its own, so a multi-protocol offer can still be negotiated.
        let mut protocols = SmallVec::new();
        let protocol_headers = request.headers
            .iter()
            .filter(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_PROTOCOL));
        for h in protocol_headers {
            for offered in std::str::from_utf8(h.value)?.split(',') {
                let offered = offered.trim();
                if let Some(x) = self.protocols.iter().find(|x| x.as_bytes() == offered.as_bytes()) {
                    protocols.push(x.clone())
                }
            }
        }

        bytes.split_to(offset);
        Ok(Some(Request { ws_key, protocols }))
    }
}
/// Positive answer to a client's handshake request.
#[derive(Debug)]
pub struct Accept<'a> {
    /// The client's `Sec-WebSocket-Key`, from which `Sec-WebSocket-Accept` is derived.
    key: Cow<'a, [u8]>,
    /// Sub-protocol to confirm to the client, if any.
    protocol: Option<Cow<'a, str>>
}

impl<'a> Accept<'a> {
    /// Creates an answer for the given client key with no sub-protocol.
    pub fn new(key: impl Into<Cow<'a, [u8]>>) -> Self {
        let key = key.into();
        Accept { key, protocol: None }
    }

    /// Sets the sub-protocol echoed in `Sec-WebSocket-Protocol`, replacing any earlier one.
    pub fn set_protocol(&mut self, p: impl Into<Cow<'a, str>>) -> &mut Self {
        let chosen = p.into();
        self.protocol = Some(chosen);
        self
    }
}

/// Negative answer to a client's handshake request.
#[derive(Debug)]
pub struct Reject {
    /// HTTP status code to respond with.
    code: u16
}

impl Reject {
    /// Creates a rejection carrying the given HTTP status code.
    pub fn new(code: u16) -> Self {
        Self { code }
    }
}
impl<'a> Encoder for Server<'a> {
    type Item = Result<Accept<'a>, Reject>;
    type Error = Error;

    /// Serialises the server's answer to a handshake request into `buf`.
    ///
    /// `Ok(accept)` produces a `101 Switching Protocols` response. It carries
    /// the derived `Sec-WebSocket-Accept` value, the optional sub-protocol and
    /// every enabled extension.
    ///
    /// `Err(reject)` produces a bare status line followed by an empty header
    /// section. Codes that `StatusCode::from_u16` rejects are sent as `500`.
    /// Always returns `Ok(())`.
    fn encode(&mut self, answer: Self::Item, buf: &mut BytesMut) -> Result<(), Self::Error> {
        let accept = match answer {
            Err(reject) => {
                let status = StatusCode::from_u16(reject.code)
                    .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
                let reason = status.canonical_reason().unwrap_or("N/A");
                buf.extend_from_slice(b"HTTP/1.1 ");
                buf.extend_from_slice(status.as_str().as_bytes());
                buf.extend_from_slice(b" ");
                buf.extend_from_slice(reason.as_bytes());
                buf.extend_from_slice(b"\r\n\r\n");
                return Ok(())
            }
            Ok(accept) => accept
        };

        // Sec-WebSocket-Accept = base64(SHA-1(client key ++ KEY)). A 20-byte
        // digest encodes to 28 base64 characters, which fits the 32-byte buffer.
        let mut encoded = [0u8; 32];
        let mut sha = Sha1::new();
        sha.update(&*accept.key);
        sha.update(KEY);
        let digest = sha.digest().bytes();
        let len = base64::encode_config_slice(&digest, base64::STANDARD, &mut encoded);

        buf.extend_from_slice(b"HTTP/1.1 101 Switching Protocols");
        buf.extend_from_slice(b"\r\nServer: soketto-");
        buf.extend_from_slice(SOKETTO_VERSION.as_bytes());
        buf.extend_from_slice(b"\r\nUpgrade: websocket\r\nConnection: upgrade");
        buf.extend_from_slice(b"\r\nSec-WebSocket-Accept: ");
        buf.extend_from_slice(&encoded[..len]);

        if let Some(protocol) = accept.protocol {
            buf.extend_from_slice(b"\r\nSec-WebSocket-Protocol: ");
            buf.extend_from_slice(protocol.as_bytes())
        }

        // Only extensions that were enabled during negotiation are announced.
        append_extensions(self.extensions.iter().filter(|e| e.is_enabled()), buf);
        buf.extend_from_slice(b"\r\n\r\n");
        Ok(())
    }
}
fn expect_ascii_header(headers: &[httparse::Header], name: &str, ours: &str) -> Result<(), Error> {
with_header(headers, name, move |theirs| {
let s = str::from_utf8(theirs)?;
if s.eq_ignore_ascii_case(ours) {
Ok(())
} else {
Err(Error::UnexpectedHeader(name.into()))
}
})
}
fn with_header<F, R>(headers: &[httparse::Header], name: &str, f: F) -> Result<R, Error>
where
F: Fn(&[u8]) -> Result<R, Error>
{
if let Some(h) = headers.iter().find(move |h| h.name.eq_ignore_ascii_case(name)) {
f(h.value)
} else {
Err(Error::HeaderNotFound(name.into()))
}
}
fn configure_extensions(extensions: &mut [Box<dyn Extension + Send>], line: &str) -> Result<(), Error> {
for e in line.split(',') {
let mut ext_parts = e.split(';');
if let Some(name) = ext_parts.next() {
let name = name.trim();
if let Some(ext) = extensions.iter_mut().find(|x| x.name().eq_ignore_ascii_case(name)) {
let mut params = SmallVec::<[Param; 4]>::new();
for p in ext_parts {
let mut key_value = p.split('=');
if let Some(key) = key_value.next().map(str::trim) {
let val = key_value.next().map(|v| v.trim().trim_matches('"'));
let mut p = Param::new(key);
p.set_value(val);
params.push(p)
}
}
ext.configure(¶ms).map_err(Error::Extension)?
}
}
}
Ok(())
}
/// Appends a `Sec-WebSocket-Extensions` header listing `extensions`.
///
/// Each extension is written as `name;param;param=value`, and entries are
/// separated by `", "`. Nothing is written when `extensions` yields no items.
fn append_extensions<'a, I>(extensions: I, buf: &mut BytesMut)
where
    I: IntoIterator<Item = &'a Box<dyn Extension + Send>>
{
    for (index, ext) in extensions.into_iter().enumerate() {
        // The header name precedes the first entry; later entries get a separator.
        if index == 0 {
            buf.extend_from_slice(b"\r\nSec-WebSocket-Extensions: ")
        } else {
            buf.extend_from_slice(b", ")
        }
        buf.extend_from_slice(ext.name().as_bytes());
        for param in ext.params() {
            buf.extend_from_slice(b";");
            buf.extend_from_slice(param.name().as_bytes());
            if let Some(value) = param.value() {
                buf.extend_from_slice(b"=");
                buf.extend_from_slice(value.as_bytes())
            }
        }
    }
}
/// Errors that can occur while encoding or decoding the handshake.
#[derive(Debug)]
pub enum Error {
    /// An I/O error.
    Io(io::Error),
    /// The HTTP version of the message was not 1.1.
    UnsupportedHttpVersion,
    /// The client's request method was not `GET`.
    InvalidRequestMethod,
    /// The response status code was unexpected (`0` if none was parsed).
    UnexpectedStatusCode(u16),
    /// A required header, named by the payload, was missing.
    HeaderNotFound(String),
    /// A header, named by the payload, had an unexpected value.
    UnexpectedHeader(String),
    /// The server's `Sec-WebSocket-Accept` did not match the expected value.
    InvalidSecWebSocketAccept,
    /// An extension that was not requested was returned.
    UnsolicitedExtension,
    /// A sub-protocol that was not offered was returned.
    UnsolicitedProtocol,
    /// An error reported by an extension.
    Extension(crate::BoxError),
    /// An error reported by the HTTP parser.
    Http(crate::BoxError),
    /// A header value was not valid UTF-8.
    Utf8(std::str::Utf8Error),
    #[doc(hidden)]
    __Nonexhaustive
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Error::Io(e) => write!(f, "i/o error: {}", e),
Error::Http(e) => write!(f, "http parser error: {}", e),
Error::HeaderNotFound(n) => write!(f, "header {} not found", n),
Error::UnexpectedHeader(n) => write!(f, "header {} had unexpected value", n),
Error::Utf8(e) => write!(f, "utf-8 decoding error: {}", e),
Error::UnexpectedStatusCode(c) => write!(f, "unexpected response status: {}", c),
Error::Extension(e) => write!(f, "extension error: {}", e),
Error::UnsupportedHttpVersion => f.write_str("http version was not 1.1"),
Error::InvalidRequestMethod => f.write_str("handshake not a GET request"),
Error::InvalidSecWebSocketAccept => f.write_str("websocket key mismatch"),
Error::UnsolicitedExtension => f.write_str("unsolicited extension returned"),
Error::UnsolicitedProtocol => f.write_str("unsolicited protocol returned"),
Error::__Nonexhaustive => f.write_str("__Nonexhaustive")
}
}
}
impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Error::Io(e) => Some(e),
Error::Utf8(e) => Some(e),
Error::Http(e) => Some(&**e),
Error::Extension(e) => Some(&**e),
Error::HeaderNotFound(_)
| Error::UnexpectedHeader(_)
| Error::UnexpectedStatusCode(_)
| Error::UnsupportedHttpVersion
| Error::InvalidRequestMethod
| Error::InvalidSecWebSocketAccept
| Error::UnsolicitedExtension
| Error::UnsolicitedProtocol
| Error::__Nonexhaustive => None
}
}
}
impl From<io::Error> for Error {
fn from(e: io::Error) -> Self {
Error::Io(e)
}
}
impl From<str::Utf8Error> for Error {
fn from(e: str::Utf8Error) -> Self {
Error::Utf8(e)
}
}