use std::str::FromStr;
use base64::{Engine, engine::general_purpose::STANDARD};
use bytes::{Buf, BytesMut};
use http::{HeaderMap, header::SET_COOKIE};
use httparse::Request;
use tokio_util::codec::Decoder;
use crate::{sha::digest, upgrade::Error};
/// Fixed prefix of every `101 Switching Protocols` handshake response.
///
/// It ends with the `Sec-WebSocket-Accept: ` header name so the caller can append
/// the base64 accept value directly, followed by `\r\n`.
const SWITCHING_PROTOCOLS_BODY: &[u8] = concat!(
    "HTTP/1.1 101 Switching Protocols\r\n",
    "Upgrade: websocket\r\n",
    "Connection: Upgrade\r\n",
    "Sec-WebSocket-Accept: ",
)
.as_bytes();
/// Reports whether `needle` occurs as a contiguous run anywhere in `haystack`,
/// comparing ASCII letters case-insensitively (other bytes must match exactly).
///
/// An empty `needle` matches every haystack, including an empty one.
fn contains_ignore_ascii_case(haystack: &[u8], needle: &[u8]) -> bool {
    // `windows(0)` panics, so the empty needle is answered before it is reached.
    // A needle longer than the haystack yields no windows and therefore `false`.
    needle.is_empty()
        || haystack
            .windows(needle.len())
            .any(|window| window.eq_ignore_ascii_case(needle))
}
struct ClientRequest {
ws_accept: [u8; 20],
}
impl ClientRequest {
pub fn parse<'a, F>(header: F) -> Result<Self, Error>
where
F: Fn(&'static str) -> Option<&'a str> + 'a,
{
let find_header = |name| header(name).ok_or(super::Error::MissingHeader(name));
let check_header = |name, expected, err| {
let actual = find_header(name)?;
if actual.eq_ignore_ascii_case(expected) {
Ok(())
} else {
Err(err)
}
};
let check_header_contains = |name, expected: &str, err| {
let actual = find_header(name)?;
if contains_ignore_ascii_case(actual.as_bytes(), expected.as_bytes()) {
Ok(())
} else {
Err(err)
}
};
check_header("Upgrade", "websocket", Error::UpgradeNotWebSocket)?;
check_header_contains("Connection", "Upgrade", Error::ConnectionNotUpgrade)?;
check_header(
"Sec-WebSocket-Version",
"13",
Error::UnsupportedWebSocketVersion,
)?;
let key = find_header("Sec-WebSocket-Key")?;
let ws_accept = digest(key.as_bytes());
Ok(Self { ws_accept })
}
#[must_use]
pub fn ws_accept(&self) -> String {
STANDARD.encode(self.ws_accept)
}
}
pub struct Codec<'a> {
pub response_headers: &'a HeaderMap,
}
impl Decoder for Codec<'_> {
type Error = crate::Error;
type Item = (http::Request<()>, Vec<u8>);
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
let mut headers = [httparse::EMPTY_HEADER; 64];
let mut request = Request::new(&mut headers);
let status = request.parse(src).map_err(Error::Parsing)?;
if !status.is_complete() {
return Ok(None);
}
let request_len = status.unwrap();
let mut builder = http::request::Builder::new();
if let Some(m) = request.method {
let method =
http::method::Method::from_bytes(m.as_bytes()).expect("httparse method is valid");
builder = builder.method(method);
}
if let Some(uri) = request.path {
builder = builder.uri(uri);
}
match request.version {
Some(0) => builder = builder.version(http::Version::HTTP_10),
Some(1) => builder = builder.version(http::Version::HTTP_11),
_ => Err(Error::Parsing(httparse::Error::Version))?,
}
let mut header_map = http::HeaderMap::with_capacity(request.headers.len());
for header in request.headers {
let name = http::HeaderName::from_str(header.name)
.map_err(|_| Error::Parsing(httparse::Error::HeaderName))?;
let value = http::HeaderValue::from_bytes(header.value)
.map_err(|_| Error::Parsing(httparse::Error::HeaderValue))?;
header_map.insert(name, value);
}
let mut request = builder
.body(())
.expect("httparse sees the request as valid");
*request.headers_mut() = header_map;
let ws_accept =
ClientRequest::parse(|name| request.headers().get(name).and_then(|h| h.to_str().ok()))?
.ws_accept();
src.advance(request_len);
let mut resp = Vec::with_capacity(SWITCHING_PROTOCOLS_BODY.len() + ws_accept.len() + 4);
resp.extend_from_slice(SWITCHING_PROTOCOLS_BODY);
resp.extend_from_slice(ws_accept.as_bytes());
resp.extend_from_slice(b"\r\n");
for name in self.response_headers.keys() {
let values = self.response_headers.get_all(name).iter();
if name == SET_COOKIE {
for value in values {
resp.extend_from_slice(name.as_str().as_bytes());
resp.extend_from_slice(b": ");
resp.extend_from_slice(value.as_bytes());
resp.extend_from_slice(b"\r\n");
}
} else {
resp.extend_from_slice(name.as_str().as_bytes());
resp.extend_from_slice(b": ");
let mut values = values.peekable();
while let Some(value) = values.next() {
resp.extend_from_slice(value.as_bytes());
if values.peek().is_some() {
resp.push(b',');
}
}
resp.extend_from_slice(b"\r\n");
}
}
resp.extend_from_slice(b"\r\n");
Ok(Some((request, resp)))
}
}