use std::{
convert::TryFrom,
fmt,
iter::{IntoIterator, Iterator},
str::FromStr,
};
use http::{
header::{self, HeaderName, HeaderValue},
uri::{self, Authority, Parts, PathAndQuery, Scheme, Uri},
HeaderMap, Method, StatusCode,
};
use crate::qpack::HeaderField;
/// A header section split into pseudo-header fields (`:method`, `:scheme`,
/// `:authority`, `:path`, `:status`) and regular header fields.
///
/// Built with [`Header::request`], [`Header::response`] or [`Header::trailer`],
/// or converted from a list of [`HeaderField`]s through `TryFrom`.
#[derive(Debug)]
#[cfg_attr(test, derive(PartialEq, Clone))]
pub struct Header {
// Pseudo-header fields, kept apart from the regular ones.
pseudo: Pseudo,
// Regular (non-pseudo) header fields.
fields: HeaderMap,
}
#[allow(clippy::len_without_is_empty)]
impl Header {
pub fn request(method: Method, uri: Uri, fields: HeaderMap) -> Result<Self, HeaderError> {
match (uri.authority(), fields.get("host")) {
(None, None) => Err(HeaderError::MissingAuthority),
(Some(a), Some(h)) if a.as_str() != h => Err(HeaderError::ContradictedAuthority),
_ => Ok(Self {
pseudo: Pseudo::request(method, uri),
fields,
}),
}
}
pub fn response(status: StatusCode, fields: HeaderMap) -> Self {
Self {
pseudo: Pseudo::response(status),
fields,
}
}
pub fn trailer(fields: HeaderMap) -> Self {
Self {
pseudo: Pseudo::default(),
fields,
}
}
pub fn into_request_parts(self) -> Result<(Method, Uri, HeaderMap), HeaderError> {
let mut uri = Uri::builder();
if let Some(path) = self.pseudo.path {
uri = uri.path_and_query(path.as_str().as_bytes());
}
if let Some(scheme) = self.pseudo.scheme {
uri = uri.scheme(scheme.as_str().as_bytes());
}
match (self.pseudo.authority, self.fields.get("host")) {
(None, None) => return Err(HeaderError::MissingAuthority),
(Some(a), None) => uri = uri.authority(a.as_str().as_bytes()),
(None, Some(h)) => uri = uri.authority(h.as_bytes()),
(Some(a), Some(h)) if a.as_str() != h => {
return Err(HeaderError::ContradictedAuthority)
}
(Some(_), Some(h)) => uri = uri.authority(h.as_bytes()),
}
Ok((
self.pseudo.method.ok_or(HeaderError::MissingMethod)?,
uri.build().map_err(HeaderError::InvalidRequest)?,
self.fields,
))
}
pub fn into_response_parts(self) -> Result<(StatusCode, HeaderMap), HeaderError> {
Ok((
self.pseudo.status.ok_or(HeaderError::MissingStatus)?,
self.fields,
))
}
pub fn into_fields(self) -> HeaderMap {
self.fields
}
pub fn len(&self) -> usize {
self.pseudo.len() + self.fields.len()
}
pub fn size(&self) -> usize {
self.pseudo.len() + self.fields.len()
}
#[cfg(test)]
pub(crate) fn authory_mut(&mut self) -> &mut Option<Authority> {
&mut self.pseudo.authority
}
}
impl IntoIterator for Header {
type Item = HeaderField;
type IntoIter = HeaderIter;
fn into_iter(self) -> Self::IntoIter {
HeaderIter {
pseudo: Some(self.pseudo),
last_header_name: None,
fields: self.fields.into_iter(),
}
}
}
/// Owning iterator over the fields of a [`Header`], created by its `IntoIterator` impl.
pub struct HeaderIter {
// Pseudo-headers still to be yielded; set to `None` once they are exhausted.
pseudo: Option<Pseudo>,
// Name of the most recently seen regular field, reused when the map iterator
// yields a value without a name.
last_header_name: Option<HeaderName>,
// Remaining regular fields.
fields: header::IntoIter<HeaderValue>,
}
impl Iterator for HeaderIter {
type Item = HeaderField;
fn next(&mut self) -> Option<Self::Item> {
if let Some(ref mut pseudo) = self.pseudo {
if let Some(method) = pseudo.method.take() {
return Some((":method", method.as_str()).into());
}
if let Some(scheme) = pseudo.scheme.take() {
return Some((":scheme", scheme.as_str().as_bytes()).into());
}
if let Some(authority) = pseudo.authority.take() {
return Some((":authority", authority.as_str().as_bytes()).into());
}
if let Some(path) = pseudo.path.take() {
return Some((":path", path.as_str().as_bytes()).into());
}
if let Some(status) = pseudo.status.take() {
return Some((":status", status.as_str()).into());
}
}
self.pseudo = None;
for (new_header_name, header_value) in self.fields.by_ref() {
if let Some(new) = new_header_name {
self.last_header_name = Some(new);
}
if let (Some(ref n), v) = (&self.last_header_name, header_value) {
return Some((n.as_str(), v.as_bytes()).into());
}
}
None
}
}
/// Builds a [`Header`] from a list of raw fields.
///
/// Names starting with `:` are parsed as pseudo-headers and stored in their dedicated
/// slot; every other field is appended to the regular field map, so repeated regular
/// names keep all their values.
///
/// # Errors
///
/// Returns the first [`HeaderError`] reported by `Field::parse` (invalid or unknown
/// name, invalid value).
impl TryFrom<Vec<HeaderField>> for Header {
    type Error = HeaderError;

    fn try_from(headers: Vec<HeaderField>) -> Result<Self, Self::Error> {
        let mut fields = HeaderMap::with_capacity(headers.len());
        let mut pseudo = Pseudo::default();

        for field in headers {
            let (name, value) = field.into_inner();

            // `Option::replace` stores the new value and returns the previous one, so
            // `is_none()` is true only the first time a given pseudo-header is seen.
            let newly_set = match Field::parse(name, value)? {
                Field::Method(m) => pseudo.method.replace(m).is_none(),
                Field::Scheme(s) => pseudo.scheme.replace(s).is_none(),
                Field::Authority(a) => pseudo.authority.replace(a).is_none(),
                Field::Path(p) => pseudo.path.replace(p).is_none(),
                Field::Status(s) => pseudo.status.replace(s).is_none(),
                Field::Header((n, v)) => {
                    fields.append(n, v);
                    false
                }
            };

            // A repeated pseudo-header overwrites its slot rather than adding a field,
            // so it must not be counted again: `len` has to match the number of
            // pseudo-headers actually stored (and later yielded by `HeaderIter`).
            if newly_set {
                pseudo.len += 1;
            }
        }

        Ok(Header { pseudo, fields })
    }
}
/// A single parsed header field: one of the known pseudo-headers or a regular header.
enum Field {
// `:method`
Method(Method),
// `:scheme`
Scheme(Scheme),
// `:authority`
Authority(Authority),
// `:path`
Path(PathAndQuery),
// `:status`
Status(StatusCode),
// Any field whose name does not start with `:`.
Header((HeaderName, HeaderValue)),
}
impl Field {
    /// Parses a raw `(name, value)` pair into a typed [`Field`].
    ///
    /// Names that do not start with `:` become [`Field::Header`]; names starting with
    /// `:` must be one of `:scheme`, `:authority`, `:path`, `:method` or `:status`.
    ///
    /// # Errors
    ///
    /// * [`HeaderError::InvalidHeaderName`] for an empty name, a name rejected by
    ///   `HeaderName::from_bytes`, or an unknown `:`-prefixed name.
    /// * [`HeaderError::InvalidHeaderValue`] when the value cannot be parsed into the
    ///   type expected for that name.
    fn parse<N, V>(name: N, value: V) -> Result<Self, HeaderError>
    where
        N: AsRef<[u8]>,
        V: AsRef<[u8]>,
    {
        let name = name.as_ref();

        match name.first() {
            None => return Err(HeaderError::InvalidHeaderName("name is empty".into())),
            // Pseudo-header: handled by the dispatch below.
            Some(&b':') => {}
            Some(_) => {
                // The name is validated before the value, so an invalid name is
                // reported in preference to an invalid value.
                let header_name =
                    HeaderName::from_bytes(name).map_err(|_| HeaderError::invalid_name(name))?;
                let header_value = match HeaderValue::from_bytes(value.as_ref()) {
                    Ok(parsed) => parsed,
                    Err(_) => return Err(HeaderError::invalid_value(name, value)),
                };
                return Ok(Field::Header((header_name, header_value)));
            }
        }

        match name {
            b":method" => Method::from_bytes(value.as_ref())
                .map(Field::Method)
                .map_err(|_| HeaderError::invalid_value(name, value)),
            b":status" => StatusCode::from_bytes(value.as_ref())
                .map(Field::Status)
                .map_err(|_| HeaderError::invalid_value(name, value)),
            b":scheme" => try_value(name, value).map(Field::Scheme),
            b":authority" => try_value(name, value).map(Field::Authority),
            b":path" => try_value(name, value).map(Field::Path),
            _ => Err(HeaderError::invalid_name(name)),
        }
    }
}
/// Decodes `value` as UTF-8 and parses it into `R` through its `FromStr` impl.
///
/// # Errors
///
/// Returns [`HeaderError::InvalidHeaderValue`] (built from `name` and `value`) when
/// the bytes are not valid UTF-8 or when `R::from_str` rejects the text.
fn try_value<N, V, R>(name: N, value: V) -> Result<R, HeaderError>
where
    N: AsRef<[u8]>,
    V: AsRef<[u8]>,
    R: FromStr,
{
    let name = name.as_ref();
    let value = value.as_ref();

    // Both failure modes map to the same error, so the intermediate errors are dropped.
    std::str::from_utf8(value)
        .ok()
        .and_then(|text| text.parse::<R>().ok())
        .ok_or_else(|| HeaderError::invalid_value(name, value))
}
/// The set of pseudo-header fields of a header section; each one is optional.
#[derive(Debug, Default)]
#[cfg_attr(test, derive(PartialEq, Clone))]
struct Pseudo {
// `:method`
method: Option<Method>,
// `:scheme`
scheme: Option<Scheme>,
// `:authority`
authority: Option<Authority>,
// `:path`
path: Option<PathAndQuery>,
// `:status`
status: Option<StatusCode>,
// Pseudo-header count, maintained by the constructors and by the `TryFrom`
// conversion of `Header`; reported by `Pseudo::len`.
len: usize,
}
#[allow(clippy::len_without_is_empty)]
impl Pseudo {
fn request(method: Method, uri: Uri) -> Self {
let Parts {
scheme,
authority,
path_and_query,
..
} = uri::Parts::from(uri);
let path = path_and_query.map_or_else(
|| PathAndQuery::from_static("/"),
|path| {
if path.path().is_empty() && method != Method::OPTIONS {
PathAndQuery::from_static("/")
} else {
path
}
},
);
let len = 3 + if authority.is_some() { 1 } else { 0 };
Self {
method: Some(method),
scheme: scheme.or(Some(Scheme::HTTPS)),
authority,
path: Some(path),
status: None,
len,
}
}
fn response(status: StatusCode) -> Self {
Pseudo {
method: None,
scheme: None,
authority: None,
path: None,
status: Some(status),
len: 1,
}
}
fn len(&self) -> usize {
self.len
}
}
/// Errors produced while building, parsing or converting a [`Header`].
#[derive(Debug)]
pub enum HeaderError {
/// A field name was empty, not a valid header name, or an unknown pseudo-header.
InvalidHeaderName(String),
/// A field value could not be parsed into the type expected for its name.
InvalidHeaderValue(String),
/// The request `Uri` could not be built from the pseudo-headers.
InvalidRequest(http::Error),
/// The `:method` pseudo-header is absent from a request.
MissingMethod,
/// The `:status` pseudo-header is absent from a response.
MissingStatus,
/// Neither an authority nor a `host` field is present.
MissingAuthority,
/// The authority and the `host` field are both present but differ.
ContradictedAuthority,
}
impl HeaderError {
fn invalid_name<N>(name: N) -> Self
where
N: AsRef<[u8]>,
{
HeaderError::InvalidHeaderName(format!("{:?}", name.as_ref()))
}
fn invalid_value<N, V>(name: N, value: V) -> Self
where
N: AsRef<[u8]>,
V: AsRef<[u8]>,
{
HeaderError::InvalidHeaderValue(format!(
"{:?} {:?}",
String::from_utf8_lossy(name.as_ref()),
value.as_ref()
))
}
}
// Marks `HeaderError` as a standard error type; every method keeps its default
// implementation.
impl std::error::Error for HeaderError {}
impl fmt::Display for HeaderError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
HeaderError::InvalidHeaderName(h) => write!(f, "invalid header name: {}", h),
HeaderError::InvalidHeaderValue(v) => write!(f, "invalid header value: {}", v),
HeaderError::InvalidRequest(r) => write!(f, "invalid request: {}", r),
HeaderError::MissingMethod => write!(f, "missing method in request headers"),
HeaderError::MissingStatus => write!(f, "missing status in response headers"),
HeaderError::MissingAuthority => write!(f, "missing authority"),
HeaderError::ContradictedAuthority => {
write!(f, "uri and authority field are in contradiction")
}
}
}
}
// Unit tests for authority/host reconciliation in `into_request_parts` and for the
// handling of repeated regular header names.
#[cfg(test)]
mod tests {
use super::*;
use assert_matches::assert_matches;
// With only `:method` present there is neither `:authority` nor `host`, so the
// conversion to request parts must fail with `MissingAuthority`.
#[test]
fn request_has_no_authority_nor_host() {
let headers = Header::try_from(vec![(b":method", Method::GET.as_str()).into()]).unwrap();
assert!(headers.pseudo.authority.is_none());
assert_matches!(
headers.into_request_parts(),
Err(HeaderError::MissingAuthority)
);
}
// `:authority` alone is enough to build the request parts.
#[test]
fn request_has_authority() {
let headers = Header::try_from(vec![
(b":method", Method::GET.as_str()).into(),
(b":authority", b"test.com").into(),
])
.unwrap();
assert_matches!(headers.into_request_parts(), Ok(_));
}
// A `host` field alone is enough too, and it does not populate `:authority`.
#[test]
fn request_has_host() {
let headers = Header::try_from(vec![
(b":method", Method::GET.as_str()).into(),
(b"host", b"test.com").into(),
])
.unwrap();
assert!(headers.pseudo.authority.is_none());
assert_matches!(headers.into_request_parts(), Ok(_));
}
// Matching `:authority` and `host` values are accepted.
#[test]
fn request_has_same_host_and_authority() {
let headers = Header::try_from(vec![
(b":method", Method::GET.as_str()).into(),
(b":authority", b"test.com").into(),
(b"host", b"test.com").into(),
])
.unwrap();
assert_matches!(headers.into_request_parts(), Ok(_));
}
// Differing `:authority` and `host` values are rejected with `ContradictedAuthority`.
#[test]
fn request_has_different_host_and_authority() {
let headers = Header::try_from(vec![
(b":method", Method::GET.as_str()).into(),
(b":authority", b"authority.com").into(),
(b"host", b"host.com").into(),
])
.unwrap();
assert_matches!(
headers.into_request_parts(),
Err(HeaderError::ContradictedAuthority)
);
}
// Two values under the same regular name must both come out of the iterator, in
// insertion order, and an unrelated header must come out exactly once.
#[test]
fn preserves_duplicate_headers() {
let headers = Header::try_from(vec![
(b":method", Method::GET.as_str()).into(),
(b":authority", b"test.com").into(),
(b"set-cookie", b"foo=foo").into(),
(b"set-cookie", b"bar=bar").into(),
(b"other-header", b"other-header-value").into(),
])
.unwrap();
assert_eq!(
headers
.clone()
.into_iter()
.filter(|h| h.name.as_ref() == b"set-cookie")
.collect::<Vec<_>>(),
vec![
HeaderField {
name: std::borrow::Cow::Borrowed(b"set-cookie"),
value: std::borrow::Cow::Borrowed(b"foo=foo")
},
HeaderField {
name: std::borrow::Cow::Borrowed(b"set-cookie"),
value: std::borrow::Cow::Borrowed(b"bar=bar")
}
]
);
assert_eq!(
headers
.into_iter()
.filter(|h| h.name.as_ref() == b"other-header")
.collect::<Vec<_>>(),
vec![HeaderField {
name: std::borrow::Cow::Borrowed(b"other-header"),
value: std::borrow::Cow::Borrowed(b"other-header-value")
},]
);
}
}