use crate::client::sync::Client;
use crate::server::upgrade::{validate, HyperIntoWsError, Request, WsUpgrade};
use crate::stream::sync::{AsTcpStream, Stream};
use std::io;
use std::net::TcpStream;
use hyper::buffer::BufReader;
use hyper::header::Headers;
use hyper::http::h1::parse_request;
use hyper::http::h1::Incoming;
use hyper::net::NetworkStream;
use hyper::status::StatusCode;
/// Frame-size limit (100 MiB = 104_857_600 bytes) that `accept` / `accept_with`
/// pass as `max_dataframe_size` to `Client::unchecked_with_limits`.
const DEFAULT_MAX_DATAFRAME_SIZE: usize = 100 * 1024 * 1024;
/// Message-size limit (200 MiB = 209_715_200 bytes) that `accept` / `accept_with`
/// pass as `max_message_size` to `Client::unchecked_with_limits`.
const DEFAULT_MAX_MESSAGE_SIZE: usize = 200 * 1024 * 1024;
/// Bytes that a buffered reader pulled off the connection while the HTTP
/// request was being parsed, captured so they can be handed back to a new
/// reader (via `BufReader::from_parts`) instead of being lost.
///
/// The three fields are exactly the parts exchanged with hyper's `BufReader`
/// (`into_parts` / `take_buf` produce them, `from_parts` consumes them).
#[derive(Debug)]
pub struct Buffer {
    /// Backing storage of the reader's buffer.
    pub buf: Vec<u8>,
    /// Read position inside `buf` (hyper `BufReader` semantics).
    pub pos: usize,
    /// End of the valid data inside `buf` (hyper `BufReader` semantics).
    pub cap: usize,
}
/// A stream together with an HTTP request that has already been read from it.
///
/// Converting this with `IntoWs` only validates the request; it performs no
/// further reads and does not carry a leftover `Buffer`.
pub struct RequestStreamPair<S>(pub S, pub Request)
where
    S: Stream;
/// A pending WebSocket upgrade over stream `S`, together with any bytes that
/// were buffered beyond the end of the HTTP request (`None` when no such
/// buffer is tracked).
pub type Upgrade<S> = WsUpgrade<S, Option<Buffer>>;
impl<S> WsUpgrade<S, Option<Buffer>>
where
S: Stream,
{
pub fn accept(self) -> Result<Client<S>, (S, io::Error)> {
self.internal_accept(None)
}
pub fn accept_with(self, custom_headers: &Headers) -> Result<Client<S>, (S, io::Error)> {
self.internal_accept(Some(custom_headers))
}
pub fn accept_with_limits(self, max_dataframe_size: usize, max_message_size: usize) -> Result<Client<S>, (S, io::Error)> {
self.internal_accept_with_limits(None, max_dataframe_size, max_message_size)
}
pub fn accept_with_headers_and_limits(self, custom_headers: &Headers, max_dataframe_size: usize, max_message_size: usize) -> Result<Client<S>, (S, io::Error)> {
self.internal_accept_with_limits(Some(custom_headers), max_dataframe_size, max_message_size)
}
fn internal_accept(self, headers: Option<&Headers>) -> Result<Client<S>, (S, io::Error)> {
self.internal_accept_with_limits(headers, DEFAULT_MAX_DATAFRAME_SIZE, DEFAULT_MAX_MESSAGE_SIZE)
}
fn internal_accept_with_limits(mut self, headers: Option<&Headers>, max_dataframe_size: usize, max_message_size: usize) -> Result<Client<S>, (S, io::Error)> {
let status = self.prepare_headers(headers);
if let Err(e) = self.send(status) {
return Err((self.stream, e));
}
let stream = match self.buffer {
Some(Buffer { buf, pos, cap }) => BufReader::from_parts(self.stream, buf, pos, cap),
None => BufReader::new(self.stream),
};
Ok(Client::unchecked_with_limits(stream, self.headers, false, true, max_dataframe_size, max_message_size))
}
pub fn reject(self) -> Result<S, (S, io::Error)> {
self.internal_reject(None)
}
pub fn reject_with(self, headers: &Headers) -> Result<S, (S, io::Error)> {
self.internal_reject(Some(headers))
}
fn internal_reject(mut self, headers: Option<&Headers>) -> Result<S, (S, io::Error)> {
if let Some(custom) = headers {
self.headers.extend(custom.iter());
}
match self.send(StatusCode::BadRequest) {
Ok(()) => Ok(self.stream),
Err(e) => Err((self.stream, e)),
}
}
}
impl<S, B> WsUpgrade<S, B>
where
S: Stream + AsTcpStream,
{
pub fn tcp_stream(&self) -> &TcpStream {
self.stream.as_tcp()
}
}
/// Conversion of some request-carrying value into a pending WebSocket upgrade.
///
/// Every implementation in this module validates the HTTP request as a
/// WebSocket handshake. On failure, the `Error` type hands back what was
/// consumed so the caller can still use the connection.
pub trait IntoWs {
    /// Value returned when the request is not a valid WebSocket handshake.
    type Error;
    /// Stream over which the upgraded connection will run.
    type Stream: Stream;

    /// Attempts to turn `self` into an `Upgrade` that can be accepted or rejected.
    fn into_ws(self) -> Result<Upgrade<Self::Stream>, Self::Error>;
}
/// Reads and parses an HTTP request directly from a raw stream.
///
/// The error tuple holds: the stream, the request (only when parsing
/// succeeded but validation failed), the bytes buffered during parsing, and
/// the failure reason.
impl<S: Stream> IntoWs for S {
    type Stream = S;
    type Error = (S, Option<Request>, Option<Buffer>, HyperIntoWsError);

    fn into_ws(self) -> Result<Upgrade<Self::Stream>, Self::Error> {
        let mut reader = BufReader::new(self);
        let parsed = parse_request(&mut reader);

        // Take the reader apart before inspecting the parse result, so that
        // bytes read past the request are kept on every path.
        let (stream, buf, pos, cap) = reader.into_parts();
        let leftover = Some(Buffer { buf, pos, cap });

        let request = match parsed {
            Ok(req) => req,
            Err(err) => return Err((stream, None, leftover, err.into())),
        };

        if let Err(err) = validate(&request.subject.0, request.version, &request.headers) {
            return Err((stream, Some(request), leftover, err));
        }

        Ok(WsUpgrade {
            headers: Headers::new(),
            stream,
            request,
            buffer: leftover,
        })
    }
}
impl<S> IntoWs for RequestStreamPair<S>
where
S: Stream,
{
type Stream = S;
type Error = (S, Request, HyperIntoWsError);
fn into_ws(self) -> Result<Upgrade<Self::Stream>, Self::Error> {
match validate(&self.1.subject.0, self.1.version, &self.1.headers) {
Ok(_) => Ok(WsUpgrade {
headers: Headers::new(),
stream: self.0,
request: self.1,
buffer: None,
}),
Err(e) => Err((self.0, self.1, e)),
}
}
}
/// Wrapper around a request received by a hyper server, so that it can be
/// turned into a WebSocket upgrade through `IntoWs`.
pub struct HyperRequest<'a, 'b>(pub ::hyper::server::Request<'a, 'b>)
where
    'b: 'a;
/// Upgrades a request already received by a hyper server.
///
/// The resulting upgrade borrows the connection's stream. Any bytes that
/// hyper's reader had buffered are moved into the upgrade's `Buffer`. If
/// validation fails, the untouched hyper request is handed back with the reason.
impl<'a, 'b> IntoWs for HyperRequest<'a, 'b> {
type Stream = &'a mut &'b mut dyn NetworkStream;
type Error = (::hyper::server::Request<'a, 'b>, HyperIntoWsError);
fn into_ws(self) -> Result<Upgrade<Self::Stream>, Self::Error> {
// Validate while the request is still whole so it can be returned on error.
if let Err(e) = validate(&self.0.method, self.0.version, &self.0.headers) {
return Err((self.0, e));
}
// Split the request into its parts; the first element (presumably the
// remote address — confirm against hyper's `deconstruct`) is unused.
let (_, method, headers, uri, version, reader) = self.0.deconstruct();
// Strip hyper's body-reader wrapper to reach the buffered reader beneath it.
let reader = reader.into_inner();
// Order matters: the buffered bytes must be taken before `get_mut`, whose
// mutable borrow of the reader is returned inside the upgrade.
let (buf, pos, cap) = reader.take_buf();
let stream = reader.get_mut();
Ok(Upgrade {
headers: Headers::new(),
stream,
buffer: Some(Buffer { buf, pos, cap }),
// Rebuild an `Incoming` request from the deconstructed parts.
request: Incoming {
version,
headers,
subject: (method, uri),
},
})
}
}