use bytes::BytesMut;
use crate::{Parsing, extension::Extension};
use crate::connection::{self, Mode};
use futures::prelude::*;
use http::StatusCode;
use sha1::Sha1;
use smallvec::SmallVec;
use std::str;
use super::{
Error,
KEY,
MAX_NUM_HEADERS,
SEC_WEBSOCKET_EXTENSIONS,
SEC_WEBSOCKET_PROTOCOL,
append_extensions,
configure_extensions,
expect_ascii_header,
with_first_header
};
// Growth step for the handshake buffer: `receive_request` reserves this many
// bytes whenever the buffer's remaining capacity drops below it.
const BLOCK_SIZE: usize = 8 * 1024;
// Crate version taken from Cargo at compile time; written into the `Server`
// header of an accepting handshake response.
const SOKETTO_VERSION: &str = env!("CARGO_PKG_VERSION");
/// Server side of the websocket opening handshake.
///
/// Reads and validates the client's HTTP upgrade request from `socket` and
/// writes back either an accepting or a rejecting response.
#[derive(Debug)]
pub struct Server<'a, T> {
/// The socket the handshake is read from and written to.
socket: T,
/// Protocols this server is willing to speak (see `add_protocol`).
protocols: SmallVec<[&'a str; 4]>,
/// Extensions offered for negotiation (see `add_extension`).
extensions: SmallVec<[Box<dyn Extension + Send>; 4]>,
/// Scratch buffer used both for the incoming request and the outgoing response.
buffer: crate::Buffer
}
impl<'a, T: AsyncRead + AsyncWrite + Unpin> Server<'a, T> {
pub fn new(socket: T) -> Self {
Server {
socket,
protocols: SmallVec::new(),
extensions: SmallVec::new(),
buffer: crate::Buffer::new()
}
}
pub fn set_buffer(&mut self, b: BytesMut) -> &mut Self {
self.buffer = crate::Buffer::from(b);
self
}
pub fn take_buffer(&mut self) -> BytesMut {
self.buffer.take().into_bytes()
}
pub fn add_protocol(&mut self, p: &'a str) -> &mut Self {
self.protocols.push(p);
self
}
pub fn add_extension(&mut self, e: Box<dyn Extension + Send>) -> &mut Self {
self.extensions.push(e);
self
}
pub fn drain_extensions(&mut self) -> impl Iterator<Item = Box<dyn Extension + Send>> + '_ {
self.extensions.drain(..)
}
pub async fn receive_request(&mut self) -> Result<ClientRequest<'a>, Error> {
self.buffer.clear();
loop {
if self.buffer.remaining_mut() < BLOCK_SIZE {
self.buffer.reserve(BLOCK_SIZE)
}
self.buffer.read_from(&mut self.socket).await?;
if let Parsing::Done { value, offset } = self.decode_request()? {
self.buffer.split_to(offset);
return Ok(value)
}
}
}
pub async fn send_response(&mut self, r: &Response<'_>) -> Result<(), Error> {
self.buffer.clear();
self.encode_response(r);
self.socket.write_all(self.buffer.as_ref()).await?;
self.socket.flush().await?;
self.buffer.clear();
Ok(())
}
pub fn into_builder(mut self) -> connection::Builder<T> {
let mut builder = connection::Builder::new(self.socket, Mode::Server);
builder.set_buffer(self.buffer.into_bytes());
builder.add_extensions(self.extensions.drain(..));
builder
}
pub fn into_inner(self) -> T {
self.socket
}
fn decode_request(&mut self) -> Result<Parsing<ClientRequest<'a>>, Error> {
let mut header_buf = [httparse::EMPTY_HEADER; MAX_NUM_HEADERS];
let mut request = httparse::Request::new(&mut header_buf);
let offset = match request.parse(self.buffer.as_ref()) {
Ok(httparse::Status::Complete(off)) => off,
Ok(httparse::Status::Partial) => return Ok(Parsing::NeedMore(())),
Err(e) => return Err(Error::Http(Box::new(e)))
};
if request.method != Some("GET") {
return Err(Error::InvalidRequestMethod)
}
if request.version != Some(1) {
return Err(Error::UnsupportedHttpVersion)
}
with_first_header(&request.headers, "Host", |_h| Ok(()))?;
expect_ascii_header(request.headers, "Upgrade", "websocket")?;
expect_ascii_header(request.headers, "Connection", "upgrade")?;
expect_ascii_header(request.headers, "Sec-WebSocket-Version", "13")?;
let ws_key = with_first_header(&request.headers, "Sec-WebSocket-Key", |k| {
Ok(Vec::from(k))
})?;
for h in request.headers.iter()
.filter(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_EXTENSIONS))
{
configure_extensions(&mut self.extensions, std::str::from_utf8(h.value)?)?
}
let mut protocols = SmallVec::new();
for p in request.headers.iter()
.filter(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_PROTOCOL))
{
if let Some(&p) = self.protocols.iter().find(|x| x.as_bytes() == p.value) {
protocols.push(p)
}
}
Ok(Parsing::Done { value: ClientRequest { ws_key, protocols }, offset })
}
fn encode_response(&mut self, response: &Response<'_>) {
match response {
Response::Accept { key, protocol } => {
let mut key_buf = [0; 32];
let accept_value = {
let mut digest = Sha1::new();
digest.update(key);
digest.update(KEY);
let d = digest.digest().bytes();
let n = base64::encode_config_slice(&d, base64::STANDARD, &mut key_buf);
&key_buf[.. n]
};
self.buffer.extend_from_slice(b"HTTP/1.1 101 Switching Protocols");
self.buffer.extend_from_slice(b"\r\nServer: soketto-");
self.buffer.extend_from_slice(SOKETTO_VERSION.as_bytes());
self.buffer.extend_from_slice(b"\r\nUpgrade: websocket\r\nConnection: upgrade");
self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Accept: ");
self.buffer.extend_from_slice(accept_value);
if let Some(p) = protocol {
self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Protocol: ");
self.buffer.extend_from_slice(p.as_bytes())
}
append_extensions(self.extensions.iter().filter(|e| e.is_enabled()), &mut self.buffer);
self.buffer.extend_from_slice(b"\r\n\r\n")
}
Response::Reject { status_code } => {
self.buffer.extend_from_slice(b"HTTP/1.1 ");
let s = StatusCode::from_u16(*status_code)
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
self.buffer.extend_from_slice(s.as_str().as_bytes());
self.buffer.extend_from_slice(b" ");
self.buffer.extend_from_slice(s.canonical_reason().unwrap_or("N/A").as_bytes());
self.buffer.extend_from_slice(b"\r\n\r\n")
}
}
}
}
/// A successfully parsed and validated client handshake request.
#[derive(Debug)]
pub struct ClientRequest<'a> {
/// Raw bytes of the client's `Sec-WebSocket-Key` header value.
ws_key: Vec<u8>,
/// Protocols registered on the server that the client's request matched.
protocols: SmallVec<[&'a str; 4]>
}
impl<'a> ClientRequest<'a> {
pub fn key(&self) -> &[u8] {
&self.ws_key
}
pub fn into_key(self) -> Vec<u8> {
self.ws_key
}
pub fn protocols(&self) -> impl Iterator<Item = &str> {
self.protocols.iter().cloned()
}
}
/// The server's answer to a client handshake request.
#[derive(Debug)]
pub enum Response<'a> {
/// Accept the upgrade; encoded as `101 Switching Protocols`.
Accept {
/// The client's key; hashed together with the crate's `KEY` constant to
/// produce the `Sec-WebSocket-Accept` header value.
key: &'a [u8],
/// Protocol to announce in `Sec-WebSocket-Protocol`; header omitted if `None`.
protocol: Option<&'a str>
},
/// Reject the upgrade with a bare status line.
Reject {
/// HTTP status code to send; a code the `http` crate refuses is sent as 500.
status_code: u16
}
}