use bytes::BytesMut;
use crate::{Parsing, extension::Extension};
use crate::connection::{self, Mode};
use futures::prelude::*;
use sha1::Sha1;
use smallvec::SmallVec;
use std::str;
use super::{
Error,
KEY,
MAX_NUM_HEADERS,
SEC_WEBSOCKET_EXTENSIONS,
SEC_WEBSOCKET_PROTOCOL,
append_extensions,
configure_extensions,
expect_ascii_header,
with_first_header
};
/// Minimum free capacity (8 KiB) ensured in the buffer before each read of the server response.
const BLOCK_SIZE: usize = 8192;
/// Client side of a WebSocket opening handshake performed over `socket`.
///
/// Configure it with the setter methods (`set_origin`, `add_protocol`,
/// `add_extension`, ...) and then call `handshake`.
#[derive(Debug)]
pub struct Client<'a, T> {
/// Underlying I/O stream the request is written to and the response read from.
socket: T,
/// Value sent in the `Host` request header.
host: &'a str,
/// Request target written after `GET ` in the request line.
resource: &'a str,
/// Optional value for the `Origin` request header (omitted when `None`).
origin: Option<&'a str>,
/// Storage for the base64-encoded `Sec-WebSocket-Key` of the current request;
/// only the first `nonce_offset` bytes are meaningful.
nonce: [u8; 32],
/// Number of valid bytes at the start of `nonce`.
nonce_offset: usize,
/// Sub-protocols offered in `Sec-WebSocket-Protocol`, in the order they were added.
protocols: SmallVec<[&'a str; 4]>,
/// Extensions passed to `append_extensions` when encoding the request and to
/// `configure_extensions` when decoding the response.
extensions: SmallVec<[Box<dyn Extension + Send>; 4]>,
/// Buffer used first for the outgoing request and then for the incoming response.
buffer: crate::Buffer
}
impl<'a, T: AsyncRead + AsyncWrite + Unpin> Client<'a, T> {
pub fn new(socket: T, host: &'a str, resource: &'a str) -> Self {
Client {
socket,
host,
resource,
origin: None,
nonce: [0; 32],
nonce_offset: 0,
protocols: SmallVec::new(),
extensions: SmallVec::new(),
buffer: crate::Buffer::new()
}
}
pub fn set_buffer(&mut self, b: BytesMut) -> &mut Self {
self.buffer = crate::Buffer::from(b);
self
}
pub fn take_buffer(&mut self) -> BytesMut {
self.buffer.take().into_bytes()
}
pub fn set_origin(&mut self, o: &'a str) -> &mut Self {
self.origin = Some(o);
self
}
pub fn add_protocol(&mut self, p: &'a str) -> &mut Self {
self.protocols.push(p);
self
}
pub fn add_extension(&mut self, e: Box<dyn Extension + Send>) -> &mut Self {
self.extensions.push(e);
self
}
pub fn drain_extensions(&mut self) -> impl Iterator<Item = Box<dyn Extension + Send>> + '_ {
self.extensions.drain(..)
}
pub async fn handshake(&mut self) -> Result<ServerResponse, Error> {
self.buffer.clear();
self.encode_request();
self.socket.write_all(self.buffer.as_ref()).await?;
self.socket.flush().await?;
self.buffer.clear();
loop {
if self.buffer.remaining_mut() < BLOCK_SIZE {
self.buffer.reserve(BLOCK_SIZE)
}
self.buffer.read_from(&mut self.socket).await?;
if let Parsing::Done { value, offset } = self.decode_response()? {
self.buffer.split_to(offset);
return Ok(value)
}
}
}
pub fn into_builder(mut self) -> connection::Builder<T> {
let mut builder = connection::Builder::new(self.socket, Mode::Client);
builder.set_buffer(self.buffer.into_bytes());
builder.add_extensions(self.extensions.drain(..));
builder
}
pub fn into_inner(self) -> T {
self.socket
}
fn encode_request(&mut self) {
let nonce: [u8; 16] = rand::random();
self.nonce_offset = base64::encode_config_slice(&nonce, base64::STANDARD, &mut self.nonce);
self.buffer.extend_from_slice(b"GET ");
self.buffer.extend_from_slice(self.resource.as_bytes());
self.buffer.extend_from_slice(b" HTTP/1.1");
self.buffer.extend_from_slice(b"\r\nHost: ");
self.buffer.extend_from_slice(self.host.as_bytes());
self.buffer.extend_from_slice(b"\r\nUpgrade: websocket\r\nConnection: upgrade");
self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Key: ");
self.buffer.extend_from_slice(&self.nonce[.. self.nonce_offset]);
if let Some(o) = &self.origin {
self.buffer.extend_from_slice(b"\r\nOrigin: ");
self.buffer.extend_from_slice(o.as_bytes())
}
if let Some((last, prefix)) = self.protocols.split_last() {
self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Protocol: ");
for p in prefix {
self.buffer.extend_from_slice(p.as_bytes());
self.buffer.extend_from_slice(b",")
}
self.buffer.extend_from_slice(last.as_bytes())
}
append_extensions(&self.extensions, &mut self.buffer);
self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Version: 13\r\n\r\n")
}
fn decode_response(&mut self) -> Result<Parsing<ServerResponse>, Error> {
let mut header_buf = [httparse::EMPTY_HEADER; MAX_NUM_HEADERS];
let mut response = httparse::Response::new(&mut header_buf);
let offset = match response.parse(self.buffer.as_ref()) {
Ok(httparse::Status::Complete(off)) => off,
Ok(httparse::Status::Partial) => return Ok(Parsing::NeedMore(())),
Err(e) => return Err(Error::Http(Box::new(e)))
};
if response.version != Some(1) {
return Err(Error::UnsupportedHttpVersion)
}
match response.code {
Some(101) => (),
Some(code@(301 ..= 303)) | Some(code@307) | Some(code@308) => { let location = with_first_header(response.headers, "Location", |loc| {
Ok(String::from(std::str::from_utf8(loc)?))
})?;
let response = ServerResponse::Redirect { status_code: code, location };
return Ok(Parsing::Done { value: response, offset })
}
other => {
let response = ServerResponse::Rejected { status_code: other.unwrap_or(0) };
return Ok(Parsing::Done { value: response, offset })
}
}
expect_ascii_header(response.headers, "Upgrade", "websocket")?;
expect_ascii_header(response.headers, "Connection", "upgrade")?;
let nonce = &self.nonce[.. self.nonce_offset];
with_first_header(&response.headers, "Sec-WebSocket-Accept", |theirs| {
let mut digest = Sha1::new();
digest.update(nonce);
digest.update(KEY);
let ours = base64::encode(&digest.digest().bytes());
if ours.as_bytes() != theirs {
return Err(Error::InvalidSecWebSocketAccept)
}
Ok(())
})?;
for h in response.headers.iter()
.filter(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_EXTENSIONS))
{
configure_extensions(&mut self.extensions, std::str::from_utf8(h.value)?)?
}
let mut selected_proto = None;
if let Some(tp) = response.headers.iter()
.find(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_PROTOCOL))
{
if let Some(&p) = self.protocols.iter().find(|x| x.as_bytes() == tp.value) {
selected_proto = Some(String::from(p))
} else {
return Err(Error::UnsolicitedProtocol)
}
}
let response = ServerResponse::Accepted { protocol: selected_proto };
Ok(Parsing::Done { value: response, offset })
}
}
/// Outcome of a client handshake, decoded from the server's HTTP response.
#[derive(Debug)]
pub enum ServerResponse {
/// The server answered with status 101 and the response passed validation.
Accepted {
/// The sub-protocol the server selected from those the client offered, if any.
protocol: Option<String>
},
/// The server answered with a redirect status (301-303, 307 or 308).
Redirect {
/// The HTTP status code received.
status_code: u16,
/// Value of the first `Location` response header.
location: String
},
/// The server answered with any other status code.
Rejected {
/// The HTTP status code received (0 if the response carried none).
status_code: u16
}
}