use super::{HttpHeader, HeaderHelper};
use super::{write_header, filter_header};
use super::handshake_check;
use super::MAX_ALLOW_HEADERS;
use super::{HTTP_METHOD, HTTP_VERSION, HTTP_LINE_BREAK, HTTP_HEADER_SP};
use super::static_headers::*;
use crate::bleed::Writer;
use crate::error::HandshakeError;
/// Client handshake request: the HTTP request that carries the WebSocket
/// upgrade headers.
///
/// * `'b` is the lifetime of the byte buffers the fields borrow from
///   (for example the buffer handed to `decode`).
/// * `'h` is the lifetime of the caller-provided storage for additional headers.
/// * `N` is exposed as `HeaderHelper::SIZE` and sizes the scratch header
///   array used while decoding; it defaults to `MAX_ALLOW_HEADERS`.
pub struct Request<'h, 'b, const N: usize = MAX_ALLOW_HEADERS>
where
    'b: 'h,
{
    /// Request target placed in the request line.
    pub path: &'b [u8],

    /// Value of the host header.
    pub host: &'b [u8],

    /// Value of the sec-websocket-key header.
    pub sec_key: &'b [u8],

    /// Headers other than the five this type handles itself.
    ///
    /// `encode` writes them after the fixed headers; `decode` uses the slice
    /// as storage and then narrows it to the entries it accounted for.
    pub other_headers: &'h mut [HttpHeader<'b>],
}
/// Publishes the const parameter `N` as the header capacity of this request type.
impl<'h, 'b, const N: usize> HeaderHelper for Request<'h, 'b, N>
where
    'b: 'h,
{
    /// Number of header slots associated with this request type.
    const SIZE: usize = N;
}
/// Constructors for the default header capacity (`MAX_ALLOW_HEADERS`).
impl<'h, 'b: 'h> Request<'h, 'b> {
    /// Builds a request from its three mandatory parts, with no additional
    /// headers (the `other_headers` slice is empty).
    #[inline]
    pub const fn new(path: &'b [u8], host: &'b [u8], sec_key: &'b [u8]) -> Self {
        Self::new_with_headers(path, host, sec_key, &mut [])
    }

    /// Builds a request from its three mandatory parts plus a slice of
    /// additional headers.
    #[inline]
    pub const fn new_with_headers(
        path: &'b [u8],
        host: &'b [u8],
        sec_key: &'b [u8],
        other_headers: &'h mut [HttpHeader<'b>],
    ) -> Self {
        Self {
            other_headers,
            sec_key,
            host,
            path,
        }
    }

    /// Builds a blank request (empty `path`, `host` and `sec_key`) around
    /// the given header storage.
    #[inline]
    pub const fn new_storage(other_headers: &'h mut [HttpHeader<'b>]) -> Self {
        Self::new_with_headers(&[], &[], &[], other_headers)
    }
}
impl<'h, 'b: 'h, const N: usize> Request<'h, 'b, N> {
#[inline]
pub const fn new_custom_storage(other_headers: &'h mut [HttpHeader<'b>]) -> Self {
Self {
path: &[],
host: &[],
sec_key: &[],
other_headers,
}
}
pub fn encode(&self, buf: &mut [u8]) -> Result<usize, HandshakeError> {
debug_assert!(buf.len() > 80);
let mut w = Writer::new(buf);
unsafe {
w.write_unchecked(HTTP_METHOD);
w.write_byte_unchecked(0x20);
w.write_unchecked(self.path);
w.write_byte_unchecked(0x20);
w.write_unchecked(HTTP_VERSION);
w.write_unchecked(HTTP_LINE_BREAK);
}
write_header!(w, HEADER_HOST_NAME, self.host);
write_header!(w, HEADER_UPGRADE_NAME, HEADER_UPGRADE_VALUE);
write_header!(w, HEADER_CONNECTION_NAME, HEADER_CONNECTION_VALUE);
write_header!(w, HEADER_SEC_WEBSOCKET_KEY_NAME, self.sec_key);
write_header!(
w,
HEADER_SEC_WEBSOCKET_VERSION_NAME,
HEADER_SEC_WEBSOCKET_VERSION_VALUE
);
for hdr in self.other_headers.iter() {
write_header!(w, hdr)
}
w.write_or_err(HTTP_LINE_BREAK, || HandshakeError::NotEnoughCapacity)?;
Ok(w.pos())
}
pub fn decode(&mut self, buf: &'b [u8]) -> Result<usize, HandshakeError> {
debug_assert!(self.other_headers.len() >= <Self as HeaderHelper>::SIZE);
let mut headers = [httparse::EMPTY_HEADER; N];
let mut request = httparse::Request::new(&mut headers);
let decode_n = match request.parse(buf)? {
httparse::Status::Complete(n) => n,
httparse::Status::Partial => return Err(HandshakeError::NotEnoughData),
};
if request.method.unwrap().as_bytes() != HTTP_METHOD {
return Err(HandshakeError::HttpMethod);
}
if request.version.unwrap() != 1_u8 {
return Err(HandshakeError::HttpVersion);
}
let headers = request.headers;
let mut required_headers = [
HEADER_HOST,
HEADER_UPGRADE,
HEADER_CONNECTION,
HEADER_SEC_WEBSOCKET_KEY,
HEADER_SEC_WEBSOCKET_VERSION,
];
filter_header(headers, &mut required_headers, self.other_headers);
let [host_hdr, upgrade_hdr, connection_hdr, sec_key_hdr, sec_version_hdr] =
required_headers;
if !required_headers.iter().all(|h| !h.value.is_empty()) {
handshake_check!(host_hdr, HandshakeError::HttpHost);
handshake_check!(upgrade_hdr, HandshakeError::Upgrade);
handshake_check!(connection_hdr, HandshakeError::Connection);
handshake_check!(sec_key_hdr, HandshakeError::SecWebSocketKey);
handshake_check!(sec_version_hdr, HandshakeError::SecWebSocketVersion);
}
handshake_check!(upgrade_hdr, HEADER_UPGRADE_VALUE, HandshakeError::Upgrade);
handshake_check!(
connection_hdr,
HEADER_CONNECTION_VALUE,
HandshakeError::Connection
);
handshake_check!(
sec_version_hdr,
HEADER_SEC_WEBSOCKET_VERSION_VALUE,
HandshakeError::SecWebSocketVersion
);
self.path = request.path.unwrap().as_bytes();
self.host = host_hdr.value;
self.sec_key = sec_key_hdr.value;
let other_header_len = headers.len() - required_headers.len();
let other_headers: &'h mut [HttpHeader<'b>] =
unsafe { &mut *(self.other_headers as *mut _) };
self.other_headers = unsafe { other_headers.get_unchecked_mut(0..other_header_len) };
Ok(decode_n)
}
}
#[cfg(test)]
mod test {
    use super::*;
    use super::super::HttpHeader;
    use super::super::test::{make_headers, TEMPLATE_HEADERS};
    use rand::prelude::*;

    /// Decodes requests built from the shared header template, checks the
    /// extracted fields, and verifies that re-encoding yields the same length.
    #[test]
    fn client_handshake() {
        let mut rng = thread_rng();

        for i in 0..64 {
            let hdr_len: usize = rng.gen_range(1..128);
            let header_lines = make_headers(i, hdr_len, TEMPLATE_HEADERS);
            let headers = format!("GET / HTTP/1.1\r\n{}\r\n", header_lines);

            let mut other_headers = HttpHeader::new_custom_storage::<1024>();
            let mut request = Request::<1024>::new_custom_storage(&mut other_headers);

            let decode_n = request.decode(headers.as_bytes()).unwrap();
            assert_eq!(decode_n, headers.len());

            assert_eq!(request.path, b"/");
            assert_eq!(request.host, b"www.example.com");
            assert_eq!(request.sec_key, b"dGhlIHNhbXBsZSBub25jZQ==");

            // The non-required template header must end up in `other_headers`.
            assert!(request.other_headers.iter().any(|hdr| {
                hdr.name == b"sec-websocket-accept"
                    && hdr.value == b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
            }));

            let mut buf: Vec<u8> = vec![0; 0x4000];
            let encode_n = request.encode(&mut buf).unwrap();
            assert_eq!(encode_n, decode_n);
        }
    }

    /// Round-trips requests with varying host, path and key values using the
    /// default-capacity storage.
    #[test]
    fn client_handshake2() {
        /// Builds a request for the given values, decodes it, checks the
        /// extracted fields, and compares the re-encoded length.
        fn run(host: &str, path: &str, sec_key: &str) {
            let required = format!(
                "host: {0}\r\n\
                 sec-websocket-key: {1}\r\n\
                 upgrade: websocket\r\n\
                 connection: upgrade\r\n\
                 sec-websocket-version: 13",
                host, sec_key
            );
            let headers = format!(
                "GET {1} HTTP/1.1\r\n{0}\r\n",
                make_headers(16, 32, &required),
                path
            );

            let mut other_headers = HttpHeader::new_storage();
            let mut request = Request::new_storage(&mut other_headers);

            let decode_n = request.decode(headers.as_bytes()).unwrap();
            assert_eq!(decode_n, headers.len());

            assert_eq!(request.host, host.as_bytes());
            assert_eq!(request.path, path.as_bytes());
            assert_eq!(request.sec_key, sec_key.as_bytes());

            let mut buf: Vec<u8> = vec![0; 0x4000];
            let encode_n = request.encode(&mut buf).unwrap();
            assert_eq!(encode_n, decode_n);
        }

        run("host", "/path", "key");
        run("www.abc.com", "/path/to", "xxxxxx");
        run("wwww.www.ww.w", "/path/to/to/path", "xxxxxxyyyy");
    }
}