#![allow(dead_code, unused_variables)]
use core::fmt;
use std::io::Cursor;
use xitca_unsafe_collection::bytes::BytesStr;
use crate::{
bytes::{BufMut, Bytes, BytesMut, buf::Limit},
http::{
HeaderMap, Method, StatusCode, Uri,
header::{self, HeaderName},
uri,
},
};
use super::{
super::{error::Error, hpack},
head::{Head, Kind},
priority::Priority,
stream_id::StreamId,
};
/// Write target handed to the frame encoders in this file: a `BytesMut` behind
/// a `Limit` wrapper. The encoders consult `remaining_mut()` on it to decide how
/// much of a header block fits into the current frame.
type EncodeBuf<'a> = Limit<&'a mut BytesMut>;
/// A HEADERS frame: the stream it belongs to, its header block and its flag
/// byte.
///
/// `P` is the pseudo-header representation stored inside the header block and
/// defaults to [`Pseudo`].
#[derive(Eq, PartialEq)]
pub struct Headers<P = Pseudo> {
/// Identifier of the stream this frame is sent on.
stream_id: StreamId,
/// Pseudo-headers plus regular header fields.
header_block: HeaderBlock<P>,
/// Frame flags (`END_STREAM`, `END_HEADERS`, `PADDED`, `PRIORITY`).
flags: HeadersFlag,
}
/// Flag byte of a HEADERS frame, kept as a raw bit set built from the
/// `END_STREAM`, `END_HEADERS`, `PADDED` and `PRIORITY` constants.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct HeadersFlag(u8);
/// A CONTINUATION frame: carries the part of an already HPACK-encoded header
/// block that did not fit into the previously written frame.
#[derive(Debug)]
pub struct Continuation {
/// Stream the header block belongs to.
stream_id: StreamId,
/// Encoded bytes that still have to be written.
header_block: EncodingHeaderBlock,
}
/// Set of pseudo-header fields of a header block. Every field is optional;
/// `None` means the pseudo-header is absent.
#[derive(Debug, Default, Eq, PartialEq)]
pub struct Pseudo {
/// `:method` pseudo-header.
pub method: Option<Method>,
/// `:scheme` pseudo-header.
pub scheme: Option<BytesStr>,
/// `:authority` pseudo-header.
pub authority: Option<BytesStr>,
/// `:path` pseudo-header.
pub path: Option<BytesStr>,
/// `:protocol` pseudo-header (passed to [`Pseudo::request`] for extended CONNECT).
pub protocol: Option<BytesStr>,
/// `:status` pseudo-header.
pub status: Option<StatusCode>,
}
/// Pseudo-header set of a response: only the status code, which is always
/// present (the `Default` value uses `StatusCode::default()`).
#[derive(Debug, Default, Eq, PartialEq)]
pub struct ResponsePseudo {
/// `:status` pseudo-header.
status: StatusCode,
}
/// Decoded (or to-be-encoded) header block: pseudo-headers plus regular fields.
#[derive(Debug, PartialEq, Eq)]
struct HeaderBlock<P = Pseudo> {
/// Regular header fields.
fields: HeaderMap,
/// Set by `load` once the accumulated header list size reaches the
/// configured maximum; headers decoded past that point are not stored.
is_over_size: bool,
/// Pseudo-header fields.
pseudo: P,
}
/// Header block that has already been HPACK-encoded and is waiting to be
/// written into one or more frames.
#[derive(Debug)]
struct EncodingHeaderBlock {
/// Encoded bytes not yet written to a frame.
hpack: Bytes,
}
// Bit values of the HEADERS frame flag byte.
/// No further frames follow on the stream.
const END_STREAM: u8 = 0x1;
/// The header block is complete; no CONTINUATION frame follows.
const END_HEADERS: u8 = 0x4;
/// The payload starts with a pad-length byte and ends with padding.
const PADDED: u8 = 0x8;
/// The payload carries priority fields.
const PRIORITY: u8 = 0x20;
/// Mask of every flag bit `HeadersFlag::load` keeps.
const ALL: u8 = END_STREAM | END_HEADERS | PADDED | PRIORITY;
impl<P> Headers<P>
where
P: _Pseudo + Default,
{
pub fn new(stream_id: StreamId, pseudo: P, fields: HeaderMap) -> Self {
Self {
stream_id,
header_block: HeaderBlock {
fields,
is_over_size: false,
pseudo,
},
flags: HeadersFlag::default(),
}
}
pub fn load_hpack(
&mut self,
src: &mut BytesMut,
max_header_list_size: usize,
decoder: &mut hpack::Decoder,
) -> Result<(), Error> {
self.header_block.load(src, max_header_list_size, decoder)
}
pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> {
let flags = HeadersFlag::load(head.flag());
let mut pad = 0;
if head.stream_id().is_zero() {
return Err(Error::InvalidStreamId);
}
if flags.is_padded() {
if src.is_empty() {
return Err(Error::TooMuchPadding);
}
pad = src[0] as usize;
let _ = src.split_to(1);
}
if flags.is_priority() {
Priority::_load(head, &mut src)?;
}
if pad > 0 {
if pad > src.len() {
return Err(Error::TooMuchPadding);
}
let len = src.len() - pad;
src.truncate(len);
}
let headers = Headers {
stream_id: head.stream_id(),
header_block: HeaderBlock {
fields: HeaderMap::new(),
is_over_size: false,
pseudo: P::default(),
},
flags,
};
Ok((headers, src))
}
pub fn encode(self, encoder: &mut hpack::Encoder, dst: &mut EncodeBuf<'_>) -> Option<Continuation> {
debug_assert!(self.flags.is_end_headers());
let head = self.head();
self.header_block.into_encoding(encoder).encode(&head, dst, |_| {})
}
pub fn stream_id(&self) -> StreamId {
self.stream_id
}
pub fn is_end_headers(&self) -> bool {
self.flags.is_end_headers()
}
pub fn set_end_headers(&mut self) {
self.flags.set_end_headers();
}
pub fn is_end_stream(&self) -> bool {
self.flags.is_end_stream()
}
pub fn set_end_stream(&mut self) {
self.flags.set_end_stream()
}
pub fn is_over_size(&self) -> bool {
self.header_block.is_over_size
}
pub fn into_parts(self) -> (P, HeaderMap) {
(self.header_block.pseudo, self.header_block.fields)
}
pub fn fields(&self) -> &HeaderMap {
&self.header_block.fields
}
pub fn into_fields(self) -> HeaderMap {
self.header_block.fields
}
fn head(&self) -> Head {
Head::new(Kind::Headers, self.flags.into(), self.stream_id)
}
}
impl Headers<ResponsePseudo> {
pub(crate) fn is_informational(&self) -> bool {
self.header_block.pseudo.status.is_informational()
}
}
impl Headers<()> {
    /// Builds a trailers frame: no pseudo-headers, default flags plus
    /// `END_STREAM`.
    pub fn trailers(stream_id: StreamId, fields: HeaderMap) -> Self {
        let header_block = HeaderBlock {
            fields,
            is_over_size: false,
            pseudo: (),
        };
        let mut frame = Self {
            stream_id,
            header_block,
            flags: HeadersFlag::default(),
        };
        frame.flags.set_end_stream();
        frame
    }
}
impl fmt::Debug for Headers {
    /// Prints the stream id and flags, plus the `:protocol` pseudo-header when
    /// one is set. Other pseudo-headers and the regular fields are left out.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut out = f.debug_struct("Headers");
        out.field("stream_id", &self.stream_id);
        out.field("flags", &self.flags);
        if let Some(protocol) = self.header_block.pseudo.protocol.as_ref() {
            out.field("protocol", protocol);
        }
        out.finish()
    }
}
/// Parses an unsigned decimal number made of ASCII digits only.
///
/// Inputs longer than 19 bytes are rejected up front; 19 digits is the longest
/// run that always fits in a `u64`, so the accumulation below cannot overflow.
/// An empty slice yields `Ok(0)`.
///
/// # Errors
///
/// Returns `Err(())` when `src` is longer than 19 bytes or contains a byte that
/// is not an ASCII digit.
pub fn parse_u64(src: &[u8]) -> Result<u64, ()> {
    const MAX_DIGITS: usize = 19;

    if src.len() > MAX_DIGITS {
        return Err(());
    }

    src.iter().try_fold(0u64, |acc, &byte| {
        if byte.is_ascii_digit() {
            Ok(acc * 10 + u64::from(byte - b'0'))
        } else {
            Err(())
        }
    })
}
impl Continuation {
fn head(&self) -> Head {
Head::new(Kind::Continuation, END_HEADERS, self.stream_id)
}
pub fn encode(self, dst: &mut EncodeBuf<'_>) -> Option<Continuation> {
let head = self.head();
self.header_block.encode(&head, dst, |_| {})
}
}
impl Pseudo {
    /// Builds the pseudo-headers of a request from its method and URI.
    ///
    /// * A CONNECT request without `protocol` carries neither `:scheme` nor
    ///   `:path`.
    /// * Otherwise `:path` is the URI's path-and-query; when that is missing
    ///   or empty it becomes `*` for OPTIONS and `/` for every other method,
    ///   and `:scheme` is taken from the URI when it has one.
    /// * `:authority` is set whenever the URI has an authority component.
    pub fn request(method: Method, uri: Uri, protocol: Option<BytesStr>) -> Self {
        let parts = uri::Parts::from(uri);

        let plain_connect = method == Method::CONNECT && protocol.is_none();
        let (scheme, path) = if plain_connect {
            (None, None)
        } else {
            let raw_path = parts.path_and_query.as_ref().map(|pq| pq.as_str());
            let path = match raw_path {
                Some(raw) if !raw.is_empty() => BytesStr::from(raw),
                _ if method == Method::OPTIONS => BytesStr::from_static("*"),
                _ => BytesStr::from_static("/"),
            };
            (parts.scheme, Some(path))
        };

        let mut pseudo = Pseudo {
            method: Some(method),
            path,
            protocol,
            ..Default::default()
        };

        if let Some(scheme) = scheme {
            pseudo.set_scheme(scheme);
        }
        if let Some(authority) = parts.authority.as_ref() {
            pseudo.set_authority(BytesStr::from(authority.as_str()));
        }

        pseudo
    }

    /// Builds the pseudo-headers of a response with the given status.
    pub fn response(status: StatusCode) -> ResponsePseudo {
        ResponsePseudo { status }
    }

    /// Sets `:scheme`. The two common schemes reuse static strings instead of
    /// copying the scheme text.
    pub fn set_scheme(&mut self, scheme: uri::Scheme) {
        let name = scheme.as_str();
        let value = if name == "http" {
            BytesStr::from_static("http")
        } else if name == "https" {
            BytesStr::from_static("https")
        } else {
            BytesStr::from(name)
        };
        self.scheme = Some(value);
    }

    /// Sets `:authority`.
    pub fn set_authority(&mut self, authority: BytesStr) {
        self.authority = Some(authority);
    }
}
impl EncodingHeaderBlock {
/// Writes one frame into `dst`: the frame head, whatever `f` adds to the
/// payload, then as many encoded header bytes as still fit.
///
/// When the encoded block does not fit entirely, the remainder is returned
/// as a [`Continuation`] and the `END_HEADERS` bit of the frame just written
/// is cleared. Returns `None` when the whole block was written.
///
/// # Panics
///
/// Panics if the payload length does not fit into the 3-byte length field.
fn encode<F>(mut self, head: &Head, dst: &mut EncodeBuf<'_>, f: F) -> Option<Continuation>
where
F: FnOnce(&mut EncodeBuf<'_>),
{
// Remember where the head starts so length and flags can be patched later.
let head_pos = dst.get_ref().len();
// The real payload length is unknown yet; write the head with length 0.
head.encode(0, dst);
let payload_pos = dst.get_ref().len();
f(dst);
let continuation = if self.hpack.len() > dst.remaining_mut() {
// Fill the frame; the rest of `self.hpack` goes into a continuation.
dst.put_slice(&self.hpack.split_to(dst.remaining_mut()));
Some(Continuation {
stream_id: head.stream_id(),
header_block: self,
})
} else {
dst.put_slice(&self.hpack);
None
};
// Patch the length field (first 3 bytes of the head) with the actual
// payload size, big-endian.
let payload_len = (dst.get_ref().len() - payload_pos) as u64;
let payload_len_be = payload_len.to_be_bytes();
assert!(payload_len_be[0..5].iter().all(|b| *b == 0));
(dst.get_mut()[head_pos..head_pos + 3]).copy_from_slice(&payload_len_be[5..]);
if continuation.is_some() {
// More frames follow, so END_HEADERS (flag byte at offset 4 of the
// head) must not be set on this one.
debug_assert!(dst.get_ref()[head_pos + 4] & END_HEADERS == END_HEADERS);
dst.get_mut()[head_pos + 4] -= END_HEADERS;
}
continuation
}
}
impl HeadersFlag {
pub fn empty() -> HeadersFlag {
HeadersFlag(0)
}
pub fn load(bits: u8) -> HeadersFlag {
HeadersFlag(bits & ALL)
}
pub fn is_end_stream(&self) -> bool {
self.0 & END_STREAM == END_STREAM
}
pub fn set_end_stream(&mut self) {
self.0 |= END_STREAM;
}
pub fn is_end_headers(&self) -> bool {
self.0 & END_HEADERS == END_HEADERS
}
pub fn set_end_headers(&mut self) {
self.0 |= END_HEADERS;
}
pub fn is_padded(&self) -> bool {
self.0 & PADDED == PADDED
}
pub fn is_priority(&self) -> bool {
self.0 & PRIORITY == PRIORITY
}
}
impl Default for HeadersFlag {
fn default() -> Self {
HeadersFlag(END_HEADERS)
}
}
impl From<HeadersFlag> for u8 {
fn from(src: HeadersFlag) -> u8 {
src.0
}
}
impl fmt::Debug for HeadersFlag {
    /// Prints the state of every flag this type tracks.
    ///
    /// `priority` is listed alongside the other three so that a frame with
    /// the PRIORITY bit set does not print the same as one without it.
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        fmt.debug_struct("HeadersFlag")
            .field("end_headers", &self.is_end_headers())
            .field("end_stream", &self.is_end_stream())
            .field("padded", &self.is_padded())
            .field("priority", &self.is_priority())
            .finish()
    }
}
impl<P> HeaderBlock<P>
where
P: _Pseudo,
{
/// Decodes the HPACK bytes in `src` and adds the resulting headers to this
/// block.
///
/// Regular fields are appended to `fields` and pseudo-headers are handed to
/// `P::parse`, as long as the running header list size stays below
/// `max_header_list_size`; once it does not, `is_over_size` is set and the
/// header is not stored.
///
/// # Errors
///
/// Returns the decoder's error if decoding fails, or
/// `Error::MalformedMessage` if a connection-specific field (`connection`,
/// `transfer-encoding`, `upgrade`, `keep-alive`, `proxy-connection`, or `te`
/// with a value other than `trailers`) was seen or `P::parse` flagged the
/// message as malformed.
fn load(
&mut self,
src: &mut BytesMut,
max_header_list_size: usize,
decoder: &mut hpack::Decoder,
) -> Result<(), Error> {
// `reg` records that a regular field has been seen; it starts out true
// when fields from an earlier call are already present. It is passed to
// `P::parse` for every pseudo-header.
let mut reg = !self.fields.is_empty();
let mut malformed = false;
// Start from what is already stored so the limit spans repeated calls.
let mut headers_size = self.calculate_header_list_size();
let mut cursor = Cursor::new(src);
// A malformed header only sets a flag: the decoder callback has no way to
// abort, so the whole block is decoded before the error is reported.
let res = decoder.decode(&mut cursor, |header| {
use hpack::Header::*;
match header {
Field { name, value } => {
if name == header::CONNECTION
|| name == header::TRANSFER_ENCODING
|| name == header::UPGRADE
|| name == "keep-alive"
|| name == "proxy-connection"
|| (name == header::TE && value != "trailers")
{
malformed = true;
} else {
reg = true;
headers_size += decoded_header_size(name.as_str().len(), value.len());
if headers_size < max_header_list_size {
self.fields.append(name, value);
} else if !self.is_over_size {
self.is_over_size = true;
}
}
}
// Everything that is not a regular field is a pseudo-header.
header => self.pseudo.parse(
header,
max_header_list_size,
&mut self.is_over_size,
reg,
&mut malformed,
&mut headers_size,
),
}
});
res?;
if malformed {
return Err(Error::MalformedMessage);
}
Ok(())
}
/// HPACK-encodes the pseudo-headers followed by the regular fields and
/// returns the encoded bytes, ready to be split into frames.
fn into_encoding(self, encoder: &mut hpack::Encoder) -> EncodingHeaderBlock {
let mut hpack = BytesMut::new();
// Pseudo-headers come first, then the regular fields.
let headers = self.pseudo.into_iter().chain(
self.fields
.into_iter()
.map(|(name, value)| hpack::Header::Field { name, value }),
);
encoder.encode(headers, &mut hpack);
EncodingHeaderBlock { hpack: hpack.freeze() }
}
/// Size of the headers currently stored: the pseudo-header size reported by
/// `P::as_header_size` plus `decoded_header_size` of every regular field.
fn calculate_header_list_size(&self) -> usize {
self.pseudo.as_header_size()
+ self
.fields
.iter()
.map(|(name, value)| decoded_header_size(name.as_str().len(), value.len()))
.sum::<usize>()
}
}
/// Behaviour a pseudo-header container must provide so that `HeaderBlock` can
/// encode it, decode into it and account for its size.
pub trait _Pseudo {
/// Consumes the container and yields its pseudo-headers as HPACK headers,
/// in the order they are to be encoded.
fn into_iter(self) -> impl Iterator<Item = hpack::Header<Option<HeaderName>>> + Send;
/// Handles one decoded pseudo-header.
///
/// * `header` - the decoded header; `HeaderBlock::load` never passes a
///   regular `Field` here.
/// * `max_header_list_size` - limit the running `headers_size` is compared
///   against.
/// * `is_over_size` - set when the limit is reached.
/// * `reg` - whether a regular field has already been seen in this block.
/// * `malformed` - set when the header makes the message malformed.
/// * `headers_size` - running header list size, updated by the
///   implementation.
fn parse(
&mut self,
header: hpack::Header,
max_header_list_size: usize,
is_over_size: &mut bool,
reg: bool,
malformed: &mut bool,
headers_size: &mut usize,
);
/// Size the currently stored pseudo-headers contribute to the header list
/// size.
fn as_header_size(&self) -> usize;
}
impl _Pseudo for Pseudo {
fn into_iter(self) -> impl Iterator<Item = hpack::Header<Option<HeaderName>>> {
struct Iter(Pseudo);
impl Iterator for Iter {
type Item = hpack::Header<Option<HeaderName>>;
fn next(&mut self) -> Option<Self::Item> {
use super::super::hpack::Header::*;
if let Some(method) = self.0.method.take() {
return Some(Method(method));
}
if let Some(scheme) = self.0.scheme.take() {
return Some(Scheme(scheme));
}
if let Some(authority) = self.0.authority.take() {
return Some(Authority(authority));
}
if let Some(path) = self.0.path.take() {
return Some(Path(path));
}
if let Some(protocol) = self.0.protocol.take() {
return Some(Protocol(protocol));
}
if let Some(status) = self.0.status.take() {
return Some(Status(status));
}
None
}
}
Iter(self)
}
fn parse(
&mut self,
header: hpack::Header,
max_header_list_size: usize,
is_over_size: &mut bool,
reg: bool,
malformed: &mut bool,
headers_size: &mut usize,
) {
use hpack::Header::*;
macro_rules! set_pseudo {
($field:ident, $val:expr) => {{
if reg {
*malformed = true;
} else if self.$field.is_some() {
*malformed = true;
} else {
let __val = $val;
*headers_size += decoded_header_size(stringify!($field).len() + 1, __val.as_str().len());
if *headers_size < max_header_list_size {
self.$field = Some(__val);
} else if !*is_over_size {
*is_over_size = true;
}
}
}};
}
match header {
Authority(v) => set_pseudo!(authority, v),
Method(v) => set_pseudo!(method, v),
Scheme(v) => set_pseudo!(scheme, v),
Path(v) => set_pseudo!(path, v),
Protocol(v) => set_pseudo!(protocol, v),
Status(_) => *malformed = true,
_ => unreachable!(),
}
}
fn as_header_size(&self) -> usize {
macro_rules! pseudo_size {
($name:ident) => {{
self.$name
.as_ref()
.map(|m| decoded_header_size(stringify!($name).len() + 1, m.as_str().len()))
.unwrap_or(0)
}};
}
pseudo_size!(method)
+ pseudo_size!(scheme)
+ pseudo_size!(status)
+ pseudo_size!(authority)
+ pseudo_size!(path)
+ pseudo_size!(protocol)
}
}
impl _Pseudo for ResponsePseudo {
    /// Yields the single `:status` pseudo-header.
    fn into_iter(self) -> impl Iterator<Item = hpack::Header<Option<HeaderName>>> {
        core::iter::once_with(move || hpack::Header::Status(self.status))
    }

    /// Stores a decoded response pseudo-header.
    ///
    /// `:status` is accepted unless a regular field was already seen (`reg`),
    /// in which case the message is flagged as malformed. Any other
    /// pseudo-header (`:method`, `:path`, ...) also flags the message as
    /// malformed: those are request pseudo-headers and the peer controls
    /// whether they show up here, so they must not abort the process.
    fn parse(
        &mut self,
        header: hpack::Header,
        max_header_list_size: usize,
        is_over_size: &mut bool,
        reg: bool,
        malformed: &mut bool,
        headers_size: &mut usize,
    ) {
        use hpack::Header;

        match header {
            Header::Status(status) => {
                if reg {
                    *malformed = true;
                } else {
                    // NOTE(review): `HeaderBlock::calculate_header_list_size`
                    // already includes `as_header_size()`, which is never zero
                    // for this type, so the status looks like it is counted
                    // twice against the limit - confirm whether that is
                    // intended. A repeated `:status` is also not detected here.
                    *headers_size += self.as_header_size();
                    if *headers_size < max_header_list_size {
                        self.status = status;
                    } else if !*is_over_size {
                        *is_over_size = true;
                    }
                }
            }
            // `HeaderBlock::load` handles regular fields itself and never
            // forwards them to `parse`.
            Header::Field { .. } => unreachable!(),
            // Request pseudo-headers inside a response.
            _ => *malformed = true,
        }
    }

    /// Size of the `:status` pseudo-header (name including the leading `:`).
    fn as_header_size(&self) -> usize {
        decoded_header_size("status".len() + 1, self.status.as_str().len())
    }
}
impl _Pseudo for () {
    /// There are no pseudo-headers to encode.
    fn into_iter(self) -> impl Iterator<Item = hpack::Header<Option<HeaderName>>> {
        core::iter::empty()
    }

    /// Ignores every pseudo-header: nothing is stored, counted or flagged.
    fn parse(
        &mut self,
        _header: hpack::Header,
        _max_header_list_size: usize,
        _is_over_size: &mut bool,
        _reg: bool,
        _malformed: &mut bool,
        _headers_size: &mut usize,
    ) {
    }

    /// Contributes nothing to the header list size.
    fn as_header_size(&self) -> usize {
        0
    }
}
/// Size one header contributes to the header list size: the length of its name
/// plus the length of its value plus a fixed per-entry overhead of 32 bytes.
fn decoded_header_size(name: usize, value: usize) -> usize {
    const ENTRY_OVERHEAD: usize = 32;

    let text_len = name + value;
    text_len + ENTRY_OVERHEAD
}
// Unit tests for frame encoding across a CONTINUATION boundary and for the
// pseudo-headers produced by `Pseudo::request`.
#[cfg(test)]
mod test {
use core::iter::FromIterator;
use crate::{
h2::proto::{
frame,
hpack::{Encoder, huffman},
},
http::HeaderValue,
};
use super::*;
// Encodes three `hello` fields through a buffer limited to
// `HEADER_LEN + 8` bytes, so the header block is cut in the middle of a
// value and has to be resumed in a CONTINUATION frame.
#[test]
fn test_nameless_header_at_resume() {
let mut encoder = Encoder::default();
let mut dst = BytesMut::new();
let headers = Headers::new(
StreamId::ZERO,
(),
HeaderMap::from_iter(vec![
(HeaderName::from_static("hello"), HeaderValue::from_static("world")),
(HeaderName::from_static("hello"), HeaderValue::from_static("zomg")),
(HeaderName::from_static("hello"), HeaderValue::from_static("sup")),
]),
);
let continuation = headers
.encode(&mut encoder, &mut (&mut dst).limit(frame::HEADER_LEN + 8))
.unwrap();
assert_eq!(17, dst.len());
// Head of the first frame: payload length 8 and a zero flag byte, i.e.
// END_HEADERS was cleared because a continuation follows.
assert_eq!([0, 0, 8, 1, 0, 0, 0, 0, 0], &dst[0..9]);
assert_eq!(&[0x40, 0x80 | 4], &dst[9..11]);
assert_eq!("hello", huff_decode(&dst[11..15]));
assert_eq!(0x80 | 4, dst[15]);
// Only the first byte of the encoded "world" fit into the first frame.
let mut world = dst[16..17].to_owned();
dst.clear();
assert!(
continuation
.encode(&mut (&mut dst).limit(frame::HEADER_LEN + 16))
.is_none()
);
// The rest of "world" opens the continuation payload.
world.extend_from_slice(&dst[9..12]);
assert_eq!("world", huff_decode(&world));
assert_eq!(24, dst.len());
// Head of the second frame: payload length 15, flag byte 4 (END_HEADERS).
assert_eq!([0, 0, 15, 9, 4, 0, 0, 0, 0], &dst[0..9]);
assert_eq!(&[15, 47, 0x80 | 3], &dst[12..15]);
assert_eq!("zomg", huff_decode(&dst[15..18]));
assert_eq!(&[15, 47, 0x80 | 3], &dst[18..21]);
assert_eq!("sup", huff_decode(&dst[21..]));
}
// Huffman-decodes `src`, panicking on invalid input.
fn huff_decode(src: &[u8]) -> BytesMut {
let mut buf = BytesMut::new();
huffman::decode(src, &mut buf).unwrap()
}
// A CONNECT request without a protocol carries only method and authority,
// whatever scheme or path the URI has.
#[test]
fn test_connect_request_pseudo_headers_omits_path_and_scheme() {
assert_eq!(
Pseudo::request(Method::CONNECT, Uri::from_static("https://example.com:8443"), None),
Pseudo {
method: Method::CONNECT.into(),
authority: BytesStr::from_static("example.com:8443").into(),
..Default::default()
}
);
assert_eq!(
Pseudo::request(Method::CONNECT, Uri::from_static("https://example.com/test"), None),
Pseudo {
method: Method::CONNECT.into(),
authority: BytesStr::from_static("example.com").into(),
..Default::default()
}
);
assert_eq!(
Pseudo::request(Method::CONNECT, Uri::from_static("example.com:8443"), None),
Pseudo {
method: Method::CONNECT.into(),
authority: BytesStr::from_static("example.com:8443").into(),
..Default::default()
}
);
}
// A CONNECT request with a protocol (extended CONNECT) keeps scheme and
// path; an empty path becomes "/".
#[test]
fn test_extended_connect_request_pseudo_headers_includes_path_and_scheme() {
assert_eq!(
Pseudo::request(
Method::CONNECT,
Uri::from_static("https://example.com:8443"),
Some(BytesStr::from_static("the-bread-protocol"))
),
Pseudo {
method: Method::CONNECT.into(),
authority: BytesStr::from_static("example.com:8443").into(),
scheme: BytesStr::from_static("https").into(),
path: BytesStr::from_static("/").into(),
protocol: Some(BytesStr::from_static("the-bread-protocol")),
..Default::default()
}
);
assert_eq!(
Pseudo::request(
Method::CONNECT,
Uri::from_static("https://example.com:8443/test"),
Some(BytesStr::from_static("the-bread-protocol"))
),
Pseudo {
method: Method::CONNECT.into(),
authority: BytesStr::from_static("example.com:8443").into(),
scheme: BytesStr::from_static("https").into(),
path: BytesStr::from_static("/test").into(),
protocol: Some(BytesStr::from_static("the-bread-protocol")),
..Default::default()
}
);
assert_eq!(
Pseudo::request(
Method::CONNECT,
Uri::from_static("http://example.com/a/b/c"),
Some(BytesStr::from_static("the-bread-protocol"))
),
Pseudo {
method: Method::CONNECT.into(),
authority: BytesStr::from_static("example.com").into(),
scheme: BytesStr::from_static("http").into(),
path: BytesStr::from_static("/a/b/c").into(),
protocol: Some(BytesStr::from_static("the-bread-protocol")),
..Default::default()
}
);
}
// OPTIONS with an empty path uses "*" as `:path`.
#[test]
fn test_options_request_with_empty_path_has_asterisk_as_pseudo_path() {
assert_eq!(
Pseudo::request(Method::OPTIONS, Uri::from_static("example.com:8080"), None,),
Pseudo {
method: Method::OPTIONS.into(),
authority: BytesStr::from_static("example.com:8080").into(),
path: BytesStr::from_static("*").into(),
..Default::default()
}
);
}
// Every other method keeps scheme and path; an empty path becomes "/".
#[test]
fn test_non_option_and_non_connect_requests_include_path_and_scheme() {
let methods = [
Method::GET,
Method::POST,
Method::PUT,
Method::DELETE,
Method::HEAD,
Method::PATCH,
Method::TRACE,
];
for method in methods {
assert_eq!(
Pseudo::request(method.clone(), Uri::from_static("http://example.com:8080"), None,),
Pseudo {
method: method.clone().into(),
authority: BytesStr::from_static("example.com:8080").into(),
scheme: BytesStr::from_static("http").into(),
path: BytesStr::from_static("/").into(),
..Default::default()
}
);
assert_eq!(
Pseudo::request(method.clone(), Uri::from_static("https://example.com/a/b/c"), None,),
Pseudo {
method: method.into(),
authority: BytesStr::from_static("example.com").into(),
scheme: BytesStr::from_static("https").into(),
path: BytesStr::from_static("/a/b/c").into(),
..Default::default()
}
);
}
}
}