#[cfg(test)]
mod tests;
// Calls the `check_headers` function with the `Connection: upgrade` and `Upgrade: websocket`
// expectations placed in front of the caller-supplied `(name, expected_value)` entries and
// evaluates to the resulting array.
//
// The expansion uses `?`, so it can only be invoked inside a function whose error type can be
// built from `crate::Error`.
macro_rules! check_headers {
  ($headers:expr, $($header:expr),*) => {{
    let rslt = check_headers(
      [
        (KnownHeaderName::Connection, Some(b"upgrade")),
        (KnownHeaderName::Upgrade, Some(b"websocket")),
        $($header,)*
      ],
      $headers
    )?;
    // Both headers are mandatory in an opening handshake: propagate the error instead of
    // discarding it, otherwise a peer that never asked for (or agreed to) an upgrade passes.
    let _ = check_header_value(rslt[0])?;
    let _ = check_header_value(rslt[1])?;
    rslt
  }};
}
use crate::{
codec::{Base64Alphabet, base64_encode, base64_encoded_len},
crypto::{Hash, Sha1DigestGlobal},
http::{GenericHeader as _, GenericRequest as _, HttpError, KnownHeaderName, Method},
misc::{LeaseMut, SuffixWriterFbvm, UriRef, bytes_split1},
rng::Rng,
stream::Stream,
web_socket::{
Compression, WebSocket, WebSocketAcceptor, WebSocketBuffer, WebSocketConnector, WebSocketError,
compression::NegotiatedCompression,
},
};
use httparse::{EMPTY_HEADER, Header, Request, Response, Status};
// Number of header slots handed to `httparse` when parsing a handshake request or response.
const MAX_READ_HEADER_LEN: usize = 64;
// Number of bytes reserved in the network buffer before the handshake is read.
const MAX_READ_LEN: usize = 2 * 1024;
// Value of `Sec-WebSocket-Extensions` used to agree on unmasked frames.
const NO_MASKING: &str = "no-masking";
// Value of the `Connection` header written in the server response.
const UPGRADE: &str = "Upgrade";
// Only `Sec-WebSocket-Version` value accepted from clients.
const VERSION: &str = "13";
// Value of the `Upgrade` header written in the server response.
const WEBSOCKET: &str = "websocket";
impl<C, E, R, RNG, WB> WebSocketAcceptor<C, R, RNG, WB>
where
  C: Compression<false>,
  E: From<crate::Error>,
  R: FnOnce(&Request<'_, '_>) -> Result<(), E>,
  RNG: Rng,
  WB: LeaseMut<WebSocketBuffer>,
{
  /// Performs the server side of the WebSocket opening handshake over `stream`.
  ///
  /// Bytes are read until `httparse` reports a complete request. The user callback is then
  /// invoked with the parsed request, the method and the handshake headers are checked and a
  /// `101 Switching Protocols` response is written back.
  ///
  /// # Errors
  ///
  /// * `UnexpectedStreamReadEOF` if the stream yields zero bytes before a full request arrives.
  /// * `UnexpectedHttpMethod` if the request method is not `GET`.
  /// * `MissingHeader` if `Sec-WebSocket-Key` or a `Sec-WebSocket-Version` of `13` is absent.
  /// * Any error coming from the user callback, the parser, the compression negotiation or the
  ///   stream itself.
  #[inline]
  pub async fn accept<S>(
    mut self,
    mut stream: S,
  ) -> Result<WebSocket<C::NegotiatedCompression, RNG, S, WB, false>, E>
  where
    S: Stream,
  {
    self.wsb.lease_mut().clear();
    let nb = &mut self.wsb.lease_mut().network_buffer;
    nb.reserve(MAX_READ_LEN)?;
    // Number of request bytes stored so far at the beginning of the buffer.
    let mut read = 0;
    loop {
      let read_buffer = nb.all_mut().get_mut(read..).unwrap_or_default();
      let local_read = stream.read(read_buffer).await?;
      if local_read == 0 {
        return Err(crate::Error::UnexpectedStreamReadEOF.into());
      }
      read = read.wrapping_add(local_read);
      let mut req_buffer = [EMPTY_HEADER; MAX_READ_HEADER_LEN];
      let mut req = Request::new(&mut req_buffer);
      // Parse exactly the bytes received so far, i.e. the same region that was just written
      // through `all_mut`, mirroring what the client side does with its response.
      match req.parse(nb.all().get(..read).unwrap_or_default()).map_err(From::from)? {
        Status::Complete(_) => {
          (self.req)(&req)?;
          if !req.method().trim_ascii().eq_ignore_ascii_case(b"get") {
            return Err(
              crate::Error::from(HttpError::UnexpectedHttpMethod { expected: Method::Get }).into(),
            );
          }
          let mut key_buffer = [0; 30];
          let [_, _, c, d, e] = check_headers!(
            req.headers,
            (KnownHeaderName::SecWebsocketExtensions, None),
            (KnownHeaderName::SecWebsocketKey, None),
            (KnownHeaderName::SecWebsocketVersion, Some(VERSION.as_bytes()))
          );
          // Unmasked frames are only used when both this acceptor and the client asked for it.
          self.no_masking &= check_header_value(c).is_ok_and(has_no_masking);
          let key = check_header_value(d)?;
          let _ = check_header_value(e)?;
          let nc = self.compression.negotiate(req.headers.iter())?;
          let swa = derived_key(&mut key_buffer, key);
          let mut headers_buffer = [EMPTY_HEADER; 3];
          headers_buffer[0] = Header { name: "Connection", value: UPGRADE.as_bytes() };
          headers_buffer[1] = Header { name: "Sec-WebSocket-Accept", value: swa };
          headers_buffer[2] = Header { name: "Upgrade", value: WEBSOCKET.as_bytes() };
          let mut res = Response::new(&mut headers_buffer);
          res.code = Some(101);
          res.version = Some(req.version().into());
          {
            let mut sw = nb.suffix_writer();
            build_res(&mut sw, res.headers, &nc, self.no_masking)?;
            stream.write_all(sw.curr_bytes()).await?;
          }
          nb.clear();
          return Ok(WebSocket::new(nc, self.no_masking, self.rng, stream, self.wsb));
        }
        // Not enough data yet, keep reading.
        Status::Partial => {}
      }
    }
  }
}
impl<'headers, C, E, H, R, RNG, WB> WebSocketConnector<C, H, R, RNG, WB>
where
  C: Compression<true>,
  E: From<crate::Error>,
  H: IntoIterator<Item = (&'headers str, &'headers str)>,
  R: FnOnce(&Response<'_, '_>) -> Result<(), E>,
  RNG: Rng,
  WB: LeaseMut<WebSocketBuffer>,
{
  /// Performs the client side of the WebSocket opening handshake over `stream`.
  ///
  /// An upgrade request for `uri` is written, then bytes are read until `httparse` reports a
  /// complete response. The response must carry status `101`; afterwards the user callback is
  /// invoked and the handshake headers are checked. Bytes received after the response head
  /// are kept in the network buffer for the returned connection.
  ///
  /// # Errors
  ///
  /// * `UnexpectedStreamReadEOF` if the stream yields zero bytes before a full response arrives.
  /// * `MissingSwitchingProtocols` if the status code is not `101`.
  /// * `MissingHeader` if `Sec-WebSocket-Accept` is absent or does not match the value derived
  ///   from the key that was sent.
  /// * Any error coming from the user callback, the parser, the compression negotiation or the
  ///   stream itself.
  #[inline]
  pub async fn connect<S>(
    mut self,
    mut stream: S,
    uri: &UriRef<'_>,
  ) -> Result<WebSocket<C::NegotiatedCompression, RNG, S, WB, true>, E>
  where
    S: Stream,
  {
    self.wsb.lease_mut().clear();
    let key_buffer = &mut [0; 26];
    let key = {
      let nb = &mut self.wsb.lease_mut().network_buffer;
      nb.reserve(MAX_READ_LEN)?;
      {
        let mut sw = nb.suffix_writer();
        let key = build_req(
          &self.compression,
          &mut sw,
          self.headers,
          key_buffer,
          self.no_masking,
          &mut self.rng,
          uri,
        )?;
        stream.write_all(sw.curr_bytes()).await?;
        key
      }
    };
    // Number of response bytes stored so far at the beginning of the buffer.
    let mut read = 0;
    let (nc, len) = loop {
      let nb = &mut self.wsb.lease_mut().network_buffer;
      let local_read = stream.read(nb.all_mut().get_mut(read..).unwrap_or_default()).await?;
      if local_read == 0 {
        return Err(crate::Error::UnexpectedStreamReadEOF.into());
      }
      read = read.wrapping_add(local_read);
      let mut httparse_headers = [EMPTY_HEADER; MAX_READ_HEADER_LEN];
      let mut res = Response::new(&mut httparse_headers);
      let len = match res.parse(nb.all().get(..read).unwrap_or_default()).map_err(From::from)? {
        Status::Complete(len) => len,
        Status::Partial => continue,
      };
      if res.code != Some(101) {
        return Err(
          crate::Error::from(WebSocketError::MissingSwitchingProtocols { found: res.code }).into(),
        );
      }
      (self.res)(&res)?;
      let [_, _, c, d] = check_headers!(
        res.headers,
        (KnownHeaderName::SecWebsocketAccept, Some(derived_key(&mut [0; 30], key))),
        (KnownHeaderName::SecWebsocketExtensions, None)
      );
      // The accept value proves that the server processed our key: a missing or mismatching
      // header must abort the handshake instead of being ignored.
      let _ = check_header_value(c)?;
      // Unmasked frames are only used when both this connector and the server asked for it.
      self.no_masking &= check_header_value(d).is_ok_and(has_no_masking);
      break (self.compression.negotiate(res.headers.iter())?, len);
    };
    // `len` bytes belong to the response head, whatever follows is already connection data.
    self.wsb.lease_mut().network_buffer.set_indices(0, len, read.wrapping_sub(len))?;
    Ok(WebSocket::new(nc, self.no_masking, self.rng, stream, self.wsb))
  }
}
/// Base64-encodes `input` with the standard alphabet into `output` and returns the written
/// bytes, or an empty slice if the encoder reports a failure.
///
/// Compilation fails when `O` is smaller than the length reported by `base64_encoded_len` for
/// `I` input bytes.
fn base64_from_array<'output, const I: usize, const O: usize>(
  input: &[u8; I],
  output: &'output mut [u8; O],
) -> &'output [u8] {
  const {
    let required = match base64_encoded_len(I, false) {
      Some(len) => len,
      None => 0,
    };
    assert!(O >= required);
  }
  let encoded = base64_encode(Base64Alphabet::Standard, input, output);
  encoded.map_or(&[][..], |text| text.as_bytes())
}
/// Writes the client upgrade request for `uri` into `sw` and returns the freshly generated
/// `Sec-WebSocket-Key` value, which is stored in `key_buffer`.
///
/// Caller-supplied `headers` and the headers of `compression` are appended after the
/// mandatory handshake headers. The `no-masking` extension is only offered when `no_masking`
/// is set.
fn build_req<'headers, 'kb, C>(
  compression: &C,
  sw: &mut SuffixWriterFbvm<'_>,
  headers: impl IntoIterator<Item = (&'headers str, &'headers str)>,
  key_buffer: &'kb mut [u8; 26],
  no_masking: bool,
  rng: &mut impl Rng,
  uri: &UriRef<'_>,
) -> crate::Result<&'kb [u8]>
where
  C: Compression<true>,
{
  let key = gen_key(key_buffer, rng);

  // Request line.
  let target = uri.relative_reference_slash();
  sw.extend_from_slices_group_rn(&[b"GET ", target.as_bytes(), b" HTTP/1.1"])?;

  // Fixed handshake headers. Ports 80 and 443 are left out of `Host`.
  sw.extend_from_slice_rn(b"Connection: Upgrade")?;
  let host: &[u8] = if matches!(uri.port(), Some(80 | 443)) {
    uri.hostname().as_bytes()
  } else {
    uri.host().as_bytes()
  };
  sw.extend_from_slices_group_rn(&[b"Host: ", host])?;
  sw.extend_from_slices_group_rn(&[b"Sec-WebSocket-Key: ", key])?;
  if no_masking {
    sw.extend_from_slice_rn(b"Sec-WebSocket-Extensions: no-masking")?;
  }
  sw.extend_from_slice_rn(b"Sec-WebSocket-Version: 13")?;
  sw.extend_from_slice_rn(b"Upgrade: websocket")?;

  // Caller and compression headers, followed by the empty line that ends the head.
  for (header_name, header_value) in headers {
    let line = [header_name.as_bytes(), b": ", header_value.as_bytes()];
    sw.extend_from_slices_group_rn(&line)?;
  }
  compression.write_req_headers(sw)?;
  sw.extend_from_slice_rn(b"")?;
  Ok(key)
}
/// Writes the `101 Switching Protocols` response into `sw`.
///
/// The status line is followed by every entry of `headers`, the `no-masking` extension header
/// when `no_masking` is set, the headers of the negotiated compression and the terminating
/// empty line.
fn build_res<NC>(
  sw: &mut SuffixWriterFbvm<'_>,
  headers: &[Header<'_>],
  nc: &NC,
  no_masking: bool,
) -> crate::Result<()>
where
  NC: NegotiatedCompression,
{
  sw.extend_from_slice_rn(b"HTTP/1.1 101 Switching Protocols")?;
  headers.iter().try_for_each(|header| {
    sw.extend_from_slices_group_rn(&[header.name(), b": ", header.value()])
  })?;
  if no_masking {
    let extension_name = KnownHeaderName::SecWebsocketExtensions.into();
    sw.extend_from_slices_group_rn(&[extension_name, b": ", NO_MASKING.as_bytes()])?;
  }
  nc.write_res_headers(sw)?;
  sw.extend_from_slice_rn(b"")
    ?;
  Ok(())
}
fn check_header_value((name, value): (KnownHeaderName, Option<&[u8]>)) -> crate::Result<&[u8]> {
let Some(elem) = value else {
return Err(crate::Error::from(HttpError::MissingHeader(name)));
};
Ok(elem)
}
/// Looks up every `(name, expected_value)` entry of `array` in `headers`.
///
/// Header names are trimmed and compared ASCII case-insensitively. Only the first header that
/// satisfies an entry is recorded:
///
/// * With `Some(expected)`, the header value is split on commas and the first element equal to
///   `expected` (trimmed, ASCII case-insensitive) is stored.
/// * With `None`, the whole trimmed header value is stored.
///
/// The returned array has the same order as `array`. Every slot carries the name that was
/// requested for it, so an unmatched slot (`None`) can be reported as that header missing.
fn check_headers<'headers, const N: usize>(
  array: [(KnownHeaderName, Option<&[u8]>); N],
  headers: &'headers [Header<'_>],
) -> crate::Result<[(KnownHeaderName, Option<&'headers [u8]>); N]> {
  // Seed each slot with its own header name instead of a shared placeholder, otherwise a
  // missing header would be reported under the wrong name.
  let mut rslt: [(KnownHeaderName, Option<&'headers [u8]>); N] =
    array.map(|(name, _)| (name, None));
  for header in headers {
    let trimmed_name = header.name().trim_ascii();
    let trimmed_value = header.value().trim_ascii();
    for ((name, value_opt), rslt_elem) in array.into_iter().zip(&mut rslt) {
      // Slots that were already satisfied by a previous header are left untouched.
      let has_name = rslt_elem.1.is_none() && trimmed_name.eq_ignore_ascii_case(name.into());
      if has_name {
        if let Some(value) = value_opt {
          for sub_value in bytes_split1(trimmed_value, b',') {
            if sub_value.trim_ascii().eq_ignore_ascii_case(value) {
              *rslt_elem = (name, Some(sub_value));
              break;
            }
          }
          // This header satisfied the entry, move on to the next header.
          if rslt_elem.1.is_some() {
            break;
          }
        } else {
          *rslt_elem = (name, Some(trimmed_value));
        }
      }
    }
  }
  Ok(rslt)
}
/// Computes the `Sec-WebSocket-Accept` value for `key`: the SHA-1 digest of `key` and the
/// WebSocket GUID, Base64-encoded into `buffer`.
fn derived_key<'buffer>(buffer: &'buffer mut [u8; 30], key: &[u8]) -> &'buffer [u8] {
  base64_from_array(
    &Sha1DigestGlobal::digest([key, b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"]),
    buffer,
  )
}
/// Generates a `Sec-WebSocket-Key` value by Base64-encoding the bytes returned by
/// `rng.u8_16()` into `buffer`.
fn gen_key<'buffer>(buffer: &'buffer mut [u8; 26], rng: &mut impl Rng) -> &'buffer [u8] {
  let nonce = rng.u8_16();
  base64_from_array(&nonce, buffer)
}
/// Tells whether `el` is the `no-masking` extension token, ignoring ASCII case.
const fn has_no_masking(el: &[u8]) -> bool {
  let expected = NO_MASKING.as_bytes();
  expected.eq_ignore_ascii_case(el)
}