use super::{HttpHeader, HeaderHelper};
use super::{write_header, filter_header};
use super::handshake_check;
use super::MAX_ALLOW_HEADERS;
use super::{HTTP_STATUS_LINE, HTTP_LINE_BREAK, HTTP_HEADER_SP};
use super::static_headers::*;
use crate::bleed::Writer;
use crate::error::HandshakeError;
/// The server side of a WebSocket opening handshake: an `HTTP/1.1 101` reply.
///
/// `sec_accept` borrows the raw `Sec-WebSocket-Accept` header value, and
/// `other_headers` points at caller-owned storage for any additional headers.
/// The const parameter `N` is the number of header slots `decode` hands to
/// httparse (it defaults to `MAX_ALLOW_HEADERS`).
pub struct Response<'h, 'b, const N: usize = MAX_ALLOW_HEADERS>
where
    'b: 'h,
{
    /// Raw value of the `Sec-WebSocket-Accept` header.
    pub sec_accept: &'b [u8],
    /// Headers beyond the required Upgrade / Connection / Sec-WebSocket-Accept trio.
    pub other_headers: &'h mut [HttpHeader<'b>],
}
impl<'h, 'b, const N: usize> HeaderHelper for Response<'h, 'b, N>
where
    'b: 'h,
{
    /// The header capacity is simply this response type's const parameter `N`.
    const SIZE: usize = N;
}
impl<'h, 'b> Response<'h, 'b>
where
    'b: 'h,
{
    /// Builds a response that carries only the given `Sec-WebSocket-Accept`
    /// value and no additional headers.
    #[inline]
    pub const fn new(sec_accept: &'b [u8]) -> Self {
        Self::new_with_headers(sec_accept, &mut [])
    }

    /// Builds a response from a `Sec-WebSocket-Accept` value plus extra headers.
    /// `encode` writes the extra headers after the required ones.
    #[inline]
    pub const fn new_with_headers(
        sec_accept: &'b [u8],
        other_headers: &'h mut [HttpHeader<'b>],
    ) -> Self {
        Self {
            sec_accept,
            other_headers,
        }
    }

    /// Builds an empty response (empty `sec_accept`) backed by `other_headers`.
    /// It is meant to be filled in by `decode`.
    #[inline]
    pub const fn new_storage(other_headers: &'h mut [HttpHeader<'b>]) -> Self {
        Self::new_with_headers(&[], other_headers)
    }
}
impl<'h, 'b: 'h, const N: usize> Response<'h, 'b, N> {
    /// Creates an empty response (empty `sec_accept`) whose `other_headers`
    /// borrows caller-provided storage. Use it with a non-default capacity `N`.
    #[inline]
    pub const fn new_custom_storage(other_headers: &'h mut [HttpHeader<'b>]) -> Self {
        Self {
            sec_accept: &[],
            other_headers,
        }
    }

    /// Serializes this response into `buf`.
    ///
    /// The output is written in this order:
    /// 1. the status line;
    /// 2. the `Upgrade` and `Connection` headers;
    /// 3. the `Sec-WebSocket-Accept` header, using `self.sec_accept`;
    /// 4. every entry of `self.other_headers`;
    /// 5. the terminating blank line.
    ///
    /// Returns the number of bytes written.
    ///
    /// # Errors
    /// Returns `HandshakeError::NotEnoughCapacity` if `buf` cannot hold the
    /// status line or the final line break. Header writes report failure
    /// through `write_header!`.
    pub fn encode(&self, buf: &mut [u8]) -> Result<usize, HandshakeError> {
        let mut w = Writer::new(buf);
        // Checked writes: a too-small `buf` yields an error rather than an
        // out-of-bounds write, regardless of build profile.
        w.write_or_err(HTTP_STATUS_LINE, || HandshakeError::NotEnoughCapacity)?;
        w.write_or_err(HTTP_LINE_BREAK, || HandshakeError::NotEnoughCapacity)?;
        write_header!(w, HEADER_UPGRADE_NAME, HEADER_UPGRADE_VALUE);
        write_header!(w, HEADER_CONNECTION_NAME, HEADER_CONNECTION_VALUE);
        write_header!(w, HEADER_SEC_WEBSOCKET_ACCEPT_NAME, self.sec_accept);
        for hdr in self.other_headers.iter() {
            write_header!(w, hdr)
        }
        w.write_or_err(HTTP_LINE_BREAK, || HandshakeError::NotEnoughCapacity)?;
        Ok(w.pos())
    }

    /// Parses an HTTP/1.1 `101` handshake response from `buf`.
    ///
    /// At most `N` headers are parsed. On success:
    /// - `self.sec_accept` borrows the `Sec-WebSocket-Accept` value from `buf`;
    /// - `self.other_headers` is narrowed to the non-required headers;
    /// - the number of bytes httparse consumed for the response head is returned.
    ///
    /// The `Sec-WebSocket-Accept` value is only checked for presence here.
    /// This method has no key to verify it against.
    ///
    /// # Errors
    /// - `NotEnoughData` if `buf` does not hold a complete response head.
    /// - `HttpVersion` if the HTTP minor version is not 1.
    /// - `HttpSatusCode` if the status code is not 101.
    /// - `Upgrade` / `Connection` / `SecWebSocketAccept` if a required header
    ///   is missing. `Upgrade` / `Connection` are also returned when the
    ///   header value does not match, as checked by `handshake_check!`.
    /// - Any httparse error, converted via `?`.
    ///
    /// # Panics
    /// In debug builds, panics if `self.other_headers` has fewer than `N` slots.
    pub fn decode(&mut self, buf: &'b [u8]) -> Result<usize, HandshakeError> {
        debug_assert!(self.other_headers.len() >= <Self as HeaderHelper>::SIZE);
        let mut headers = [httparse::EMPTY_HEADER; N];
        let mut response = httparse::Response::new(&mut headers);
        let decode_n = match response.parse(buf)? {
            httparse::Status::Complete(n) => n,
            httparse::Status::Partial => return Err(HandshakeError::NotEnoughData),
        };
        // After a `Complete` parse httparse has set both fields. Comparing
        // against `Some(..)` avoids a panic path without changing the result.
        if response.version != Some(1_u8) {
            return Err(HandshakeError::HttpVersion);
        }
        if response.code != Some(101_u16) {
            return Err(HandshakeError::HttpSatusCode);
        }
        let headers = response.headers;
        let mut required_headers = [
            HEADER_UPGRADE,
            HEADER_CONNECTION,
            HEADER_SEC_WEBSOCKET_ACCEPT,
        ];
        // Split the parsed headers into the required trio and `other_headers`.
        filter_header(headers, &mut required_headers, self.other_headers);
        let [upgrade_hdr, connection_hdr, sec_accept_hdr] = required_headers;
        // Slow path, only taken when some required header came back empty:
        // report exactly which one is missing.
        if !required_headers.iter().all(|h| !h.value.is_empty()) {
            handshake_check!(upgrade_hdr, HandshakeError::Upgrade);
            handshake_check!(connection_hdr, HandshakeError::Connection);
            handshake_check!(sec_accept_hdr, HandshakeError::SecWebSocketAccept);
        }
        handshake_check!(upgrade_hdr, HEADER_UPGRADE_VALUE, HandshakeError::Upgrade);
        handshake_check!(
            connection_hdr,
            HEADER_CONNECTION_VALUE,
            HandshakeError::Connection
        );
        self.sec_accept = sec_accept_hdr.value;
        // Assumes `filter_header` packed every non-required header at the
        // front of `other_headers` (TODO confirm against its definition).
        let other_header_len = headers.len() - required_headers.len();
        // Move the `'h` borrow out of `self`, leaving an empty slice behind,
        // so it can be re-sliced without `unsafe`. If the storage were ever
        // shorter than expected, this would panic instead of reading out of
        // bounds.
        let storage = core::mem::take(&mut self.other_headers);
        self.other_headers = &mut storage[..other_header_len];
        Ok(decode_n)
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use super::super::HttpHeader;
    use super::super::test::{make_headers, TEMPLATE_HEADERS};
    use {rand::rng, rand::prelude::*};

    /// Reports whether `headers` holds an entry whose name and value both match exactly.
    fn has_header(headers: &[HttpHeader], name: &[u8], value: &[u8]) -> bool {
        headers
            .iter()
            .any(|hdr| hdr.name == name && hdr.value == value)
    }

    /// Decodes a 101 response carrying `sec_accept`.
    ///
    /// Then checks that the accept value survives the decode, and that
    /// re-encoding produces exactly as many bytes as were consumed.
    fn roundtrip_with_accept(sec_accept: &str) {
        let required = format!(
            "upgrade: websocket\r\n\
             connection: upgrade\r\n\
             sec-websocket-accept: {}",
            sec_accept
        );
        let raw = format!(
            "HTTP/1.1 101 Switching Protocols\r\n{}\r\n",
            make_headers(16, 32, &required)
        );
        let mut storage = HttpHeader::new_storage();
        let mut response = Response::new_storage(&mut storage);
        let consumed = response.decode(raw.as_bytes()).unwrap();
        assert_eq!(consumed, raw.len());
        assert_eq!(response.sec_accept, sec_accept.as_bytes());

        let mut out = vec![0_u8; 0x4000];
        let written = response.encode(&mut out).unwrap();
        assert_eq!(written, consumed);
    }

    #[test]
    fn server_handshake() {
        let expected_others = [
            (&b"host"[..], &b"www.example.com"[..]),
            (&b"sec-websocket-version"[..], &b"13"[..]),
            (&b"sec-websocket-key"[..], &b"dGhlIHNhbXBsZSBub25jZQ=="[..]),
        ];
        for round in 0..64 {
            let hdr_len: usize = rng().random_range(1..128);
            let raw = format!(
                "HTTP/1.1 101 Switching Protocols\r\n{}\r\n",
                make_headers(round, hdr_len, TEMPLATE_HEADERS)
            );
            let mut storage = HttpHeader::new_custom_storage::<1024>();
            let mut response = Response::<1024>::new_custom_storage(&mut storage);
            let consumed = response.decode(raw.as_bytes()).unwrap();
            assert_eq!(consumed, raw.len());
            assert_eq!(response.sec_accept, b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
            for &(name, value) in expected_others.iter() {
                assert!(has_header(&response.other_headers[..], name, value));
            }
            let mut out = vec![0_u8; 0x4000];
            let written = response.encode(&mut out).unwrap();
            assert_eq!(written, consumed);
        }
    }

    #[test]
    fn server_handshake2() {
        for sec_accept in ["aaa", "bbbbbbbbbb", "xxxxxxxxx=="] {
            roundtrip_with_accept(sec_accept);
        }
    }
}