use http::{
HeaderMap, Method, Request as HttpRequest, Response as HttpResponse, StatusCode, Version,
};
use httparse::{Status, EMPTY_HEADER};
use std::{
io::{Read, Write},
marker::PhantomData,
result::Result as StdResult,
};
use crate::{
error::{Error, ProtocolError, Result},
handshake::{
core::{derive_accept_key, HandshakeRole, MidHandshake, ProcessingResult},
headers::{FromHttparse, MAX_HEADERS},
machine::{HandshakeMachine, StageResult, TryParse},
},
protocol::{
config::WebSocketConfig,
websocket::{OperationMode, WebSocket},
},
};
/// Handshake request as seen by the server: an HTTP request with an empty (`()`) body.
pub type Request = HttpRequest<()>;
/// Handshake response sent by the server: an HTTP response with an empty (`()`) body.
pub type Response = HttpResponse<()>;
/// Response used to reject a handshake. When the body is `Some`, the server handshake
/// writes the string's bytes right after the response headers.
pub type ErrorResponse = HttpResponse<Option<String>>;
fn create_parts<T>(req: &HttpRequest<T>) -> Result<http::response::Builder> {
if req.method() != Method::GET {
return Err(Error::Protocol(ProtocolError::InvalidHttpMethod));
}
if req.version() < Version::HTTP_11 {
return Err(Error::Protocol(ProtocolError::InvalidHttpVersion));
}
let headers = req.headers();
if !headers
.get("Connection")
.and_then(|h| h.to_str().ok())
.map(|v| v.split([',', ' ']).any(|s| s.eq_ignore_ascii_case("Upgrade")))
.unwrap_or(false)
{
return Err(Error::Protocol(ProtocolError::MissingConnectionUpgradeHeader));
}
if !headers
.get("Upgrade")
.and_then(|h| h.to_str().ok())
.map(|v| v.eq_ignore_ascii_case("websocket"))
.unwrap_or(false)
{
return Err(Error::Protocol(ProtocolError::MissingUpgradeHeader));
}
if !headers.get("Sec-WebSocket-Version").map(|h| h == "13").unwrap_or(false) {
return Err(Error::Protocol(ProtocolError::MissingVersionHeader));
}
let key =
headers.get("Sec-WebSocket-Key").ok_or(Error::Protocol(ProtocolError::MissingKeyHeader))?;
let builder = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.version(req.version())
.header("Connection", "Upgrade")
.header("Upgrade", "websocket")
.header("Sec-WebSocket-Accept", derive_accept_key(key.as_bytes()));
Ok(builder)
}
/// Builds the body-less `101 Switching Protocols` response for `req`.
///
/// # Errors
///
/// Fails if `create_parts` rejects the request or if the response builder cannot be
/// finalized.
pub fn create_response(req: &Request) -> Result<Response> {
    let builder = create_parts(req)?;
    let response = builder.body(())?;
    Ok(response)
}
/// Builds the `101 Switching Protocols` response for `req`, using `generate_body` to
/// produce the response body.
///
/// `generate_body` is only invoked once the request has passed validation.
///
/// # Errors
///
/// Fails if `create_parts` rejects the request or if the response builder cannot be
/// finalized.
pub fn create_response_with_body<T1, T2>(
    req: &HttpRequest<T1>,
    generate_body: impl FnOnce() -> T2,
) -> Result<HttpResponse<T2>> {
    let builder = create_parts(req)?;
    let response = builder.body(generate_body())?;
    Ok(response)
}
/// Serializes the status line and headers of `res` into `w`, followed by the empty line
/// that terminates the header section. The body is not written.
///
/// # Errors
///
/// Fails if writing to `w` fails or if a header value cannot be rendered via `to_str`.
pub fn write_response<T>(mut w: impl Write, res: &HttpResponse<T>) -> Result<()> {
    // Status line, e.g. `<version> <status>` terminated by CRLF.
    let version = res.version();
    let status = res.status();
    writeln!(w, "{:?} {}\r", version, status)?;

    // One `name: value` line per header entry.
    res.headers().iter().try_for_each(|(name, value)| -> Result<()> {
        let text = value.to_str()?;
        writeln!(w, "{}: {}\r", name, text)?;
        Ok(())
    })?;

    // Blank line marking the end of the headers.
    writeln!(w, "\r")?;
    Ok(())
}
impl TryParse for Request {
    /// Attempts to parse a complete HTTP request from `data`.
    ///
    /// Returns `Ok(None)` while the request is still incomplete, or
    /// `Ok(Some((consumed, request)))` where `consumed` is the byte count reported by
    /// the parser for the complete request.
    fn try_parse(data: &[u8]) -> Result<Option<(usize, Self)>> {
        let mut header_slots = [EMPTY_HEADER; MAX_HEADERS];
        let mut parsed = httparse::Request::new(&mut header_slots);

        let consumed = match parsed.parse(data)? {
            Status::Partial => return Ok(None),
            Status::Complete(len) => len,
        };

        let request = Request::from_httparse(parsed)?;
        Ok(Some((consumed, request)))
    }
}
impl<'b: 'h, 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
fn from_httparse(raw: httparse::Request<'h, 'b>) -> Result<Self> {
if raw.method != Some("GET") {
return Err(Error::Protocol(ProtocolError::InvalidHttpMethod));
}
if raw.version != Some(1) {
return Err(Error::Protocol(ProtocolError::InvalidHttpVersion));
}
let mut req = Request::new(());
*req.method_mut() = Method::GET;
*req.uri_mut() = raw.path.expect("Bug: no path in header").parse()?;
*req.version_mut() = Version::HTTP_11;
*req.headers_mut() = HeaderMap::from_httparse(raw.headers)?;
Ok(req)
}
}
/// Hook the server handshake invokes once a valid upgrade request has been read.
///
/// The callback is consumed (`self` by value) when it is called.
pub trait Callback: Sized {
/// Receives the parsed request and the response the server has prepared for it.
///
/// Returning `Ok(response)` makes the handshake send that response and continue.
/// Returning `Err(response)` makes the handshake send the error response instead and
/// then fail; the error response must not have a success status.
fn on_request(self, req: &Request, res: Response) -> StdResult<Response, ErrorResponse>;
}
/// Lets any `FnOnce(&Request, Response) -> Result<Response, ErrorResponse>` closure or
/// function be used directly as a [`Callback`].
impl<F> Callback for F
where
F: FnOnce(&Request, Response) -> StdResult<Response, ErrorResponse>,
{
// Simply forwards to the wrapped callable.
fn on_request(self, req: &Request, res: Response) -> StdResult<Response, ErrorResponse> {
self(req, res)
}
}
/// [`Callback`] implementation that does nothing: every request is accepted and the
/// prepared response is passed through unchanged.
#[derive(Clone, Copy, Debug)]
pub struct NoCallback;
impl Callback for NoCallback {
// Always accepts, returning the response exactly as it was handed in.
fn on_request(self, _req: &Request, res: Response) -> StdResult<Response, ErrorResponse> {
Ok(res)
}
}
/// Server role of the WebSocket handshake, generic over the stream type `S` and the
/// request callback type `C`.
#[allow(missing_copy_implementations)]
#[derive(Debug)]
pub struct ServerHandshake<S, C> {
/// Callback to run on the incoming request; taken (left as `None`) when it is invoked.
callback: Option<C>,
/// Optional configuration handed to the `WebSocket` created when the handshake completes.
config: Option<WebSocketConfig>,
/// Rejection produced by the callback; kept until it has been written to the stream,
/// after which the handshake fails with it.
error_response: Option<ErrorResponse>,
/// Ties the otherwise unused stream type parameter `S` to the struct.
_marker: PhantomData<S>,
}
impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
    /// Begins a server handshake on `stream`.
    ///
    /// The returned [`MidHandshake`] starts in the reading stage. `callback` is stored
    /// for the incoming request and `config` is kept for the resulting WebSocket.
    pub fn start(stream: S, callback: C, config: Option<WebSocketConfig>) -> MidHandshake<Self> {
        let machine = HandshakeMachine::start_read(stream);
        let role = Self {
            callback: Some(callback),
            config,
            error_response: None,
            _marker: PhantomData,
        };
        MidHandshake { machine, role }
    }
}
impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
type IncomingData = Request;
type InternalStream = S;
type FinalResult = WebSocket<S>;
fn stage_finished(
&mut self,
finish: StageResult<Self::IncomingData, Self::InternalStream>,
) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>> {
match finish {
StageResult::DoneReading { result, stream, tail } => {
if !tail.is_empty() {
return Err(Error::Protocol(ProtocolError::JunkAfterRequest));
}
let response = create_response(&result)?;
let callback_result = if let Some(callback) = self.callback.take() {
callback.on_request(&result, response)
} else {
Ok(response)
};
match callback_result {
Ok(resp) => {
let mut output = vec![];
write_response(&mut output, &resp)?;
Ok(ProcessingResult::Continue(HandshakeMachine::start_write(
stream, output,
)))
}
Err(resp) => {
if resp.status().is_success() {
return Err(Error::Protocol(ProtocolError::CustomResponseSuccessful));
}
self.error_response = Some(resp);
let resp_ref = self.error_response.as_ref().unwrap();
let mut output = vec![];
write_response(&mut output, resp_ref)?;
if let Some(body) = resp_ref.body() {
output.extend_from_slice(body.as_bytes());
}
Ok(ProcessingResult::Continue(HandshakeMachine::start_write(
stream, output,
)))
}
}
}
StageResult::DoneWriting(stream) => {
if let Some(err) = self.error_response.take() {
let (parts, body) = err.into_parts();
return Err(Error::Http(HttpResponse::from_parts(
parts,
body.map(|s| s.into_bytes()),
)));
}
Ok(ProcessingResult::Done(WebSocket::new(
stream,
OperationMode::Server,
self.config,
)))
}
}
}
}