use std::hint::unreachable_unchecked;
use base64::{engine::general_purpose::STANDARD, Engine};
use bytes::{Buf, BytesMut};
use httparse::{Header, Response};
use tokio_util::codec::{Decoder, Encoder};
use crate::{sha::digest, upgrade::Error};
/// HTTP status code `101 Switching Protocols`; the decoder rejects an upgrade
/// response carrying any other status code.
const SWITCHING_PROTOCOLS: u16 = 101;
fn header<'a, 'header: 'a>(
headers: &'a [Header<'header>],
name: &'static str,
) -> Result<&'header [u8], Error> {
let header = headers
.iter()
.find(|header| header.name.eq_ignore_ascii_case(name))
.ok_or(Error::MissingHeader(name))?;
Ok(header.value)
}
/// Codec for the WebSocket upgrade handshake whose `Decoder` implementation
/// parses the HTTP response to the upgrade request and validates it.
pub struct Codec {
/// The 20-byte value that the decoded `Sec-WebSocket-Accept` response header
/// must equal; computed by `digest` from the key passed to `Codec::new`.
ws_accept: [u8; 20],
}
impl Codec {
    /// Creates a codec that validates upgrade responses against `key`.
    ///
    /// `key` is passed through `digest` once, up front, and the resulting
    /// 20 bytes are stored as the accept value the response must carry.
    #[must_use]
    pub fn new(key: &[u8]) -> Self {
        let ws_accept = digest(key);

        Self { ws_accept }
    }
}
impl Decoder for Codec {
type Error = crate::Error;
type Item = ();
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
let mut headers = [httparse::EMPTY_HEADER; 25];
let mut response = Response::new(&mut headers);
let status = response.parse(src).map_err(Error::Parsing)?;
if !status.is_complete() {
return Ok(None);
}
let response_len = status.unwrap();
let code = response.code.unwrap();
if code != SWITCHING_PROTOCOLS {
return Err(crate::Error::Upgrade(Error::DidNotSwitchProtocols(code)));
}
let ws_accept_header = header(response.headers, "Sec-WebSocket-Accept")?;
let mut ws_accept = [0; 20];
STANDARD
.decode_slice_unchecked(ws_accept_header, &mut ws_accept)
.map_err(|_| Error::WrongWebsocketAccept)?;
if self.ws_accept != ws_accept {
return Err(crate::Error::Upgrade(Error::WrongWebsocketAccept));
}
src.advance(response_len);
Ok(Some(()))
}
}
/// Placeholder `Encoder` implementation; this codec never encodes anything.
impl Encoder<()> for Codec {
    type Error = crate::Error;

    /// Encoding is not supported by this codec.
    ///
    /// # Panics
    ///
    /// Always panics. This is a safe, callable method, so reaching it must not
    /// be undefined behavior; a panic keeps misuse well-defined.
    fn encode(&mut self, _item: (), _dst: &mut BytesMut) -> Result<(), Self::Error> {
        unreachable!("the upgrade response codec is decode-only and never encodes items")
    }
}