use crate::error::{Error, ErrorKind, ResultExt};
use std::{
borrow::Cow, collections::HashSet, convert::Infallible, fmt, str::FromStr, sync::LazyLock,
};
/// Header names whose values may be shown verbatim when a `Headers` collection is
/// debug-formatted.
///
/// The `Debug` impl for `Headers` replaces the value of any header not in this set with a
/// redaction marker. Every entry is lowercase, matching how `HeaderName` stores names.
pub static DEFAULT_ALLOWED_HEADER_NAMES: LazyLock<HashSet<Cow<'static, str>>> =
    LazyLock::new(|| {
        const ALLOWED: &[&str] = &[
            "accept",
            "cache-control",
            "connection",
            "content-length",
            "content-type",
            "date",
            "etag",
            "expires",
            "if-match",
            "if-modified-since",
            "if-none-match",
            "if-unmodified-since",
            "last-modified",
            "ms-cv",
            "pragma",
            "request-id",
            "retry-after",
            "server",
            "traceparent",
            "tracestate",
            "transfer-encoding",
            "user-agent",
            "www-authenticate",
            "x-ms-request-id",
            "x-ms-client-request-id",
            "x-ms-return-client-request-id",
        ];
        ALLOWED.iter().copied().map(Cow::Borrowed).collect()
    });
/// Converts a value into zero or more `(HeaderName, HeaderValue)` pairs.
pub trait AsHeaders {
    /// Iterator over the produced header pairs.
    type Iter: Iterator<Item = (HeaderName, HeaderValue)>;

    /// Error returned when the value cannot be represented as headers.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Produces the header pairs for `self`.
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the conversion fails.
    fn as_headers(&self) -> Result<Self::Iter, Self::Error>;
}
/// Any single [`Header`] is a one-element header collection, and producing it cannot fail.
impl<T: Header> AsHeaders for T {
    type Error = Infallible;
    type Iter = std::vec::IntoIter<(HeaderName, HeaderValue)>;

    fn as_headers(&self) -> Result<Self::Iter, Self::Error> {
        let name = self.name();
        let value = self.value();
        Ok(vec![(name, value)].into_iter())
    }
}
/// An absent value contributes no headers; a present one delegates to the inner value.
impl<T> AsHeaders for Option<T>
where
    T: AsHeaders<Iter = std::vec::IntoIter<(HeaderName, HeaderValue)>>,
{
    type Error = T::Error;
    type Iter = T::Iter;

    fn as_headers(&self) -> Result<Self::Iter, T::Error> {
        let Some(inner) = self else {
            return Ok(Vec::new().into_iter());
        };
        inner.as_headers()
    }
}
/// Parses a typed value out of a [`Headers`] collection.
pub trait FromHeaders: Sized {
    /// Error returned when the relevant header(s) are present but cannot be parsed.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Names of the header(s) this type is read from.
    ///
    /// `Headers::get` lists these in its "required header(s) not found" error.
    fn header_names() -> &'static [&'static str];

    /// Attempts to build `Self` from `headers`.
    ///
    /// `Ok(None)` signals that the header(s) were not found; `Headers::get` turns that
    /// into an error while `Headers::get_optional` passes it through.
    fn from_headers(headers: &Headers) -> Result<Option<Self>, Self::Error>;
}
/// A single typed header that can report its own name and value.
pub trait Header {
    /// Returns the name this header is sent under.
    fn name(&self) -> HeaderName;

    /// Returns the value this header is sent with.
    fn value(&self) -> HeaderValue;
}
/// A collection of HTTP headers keyed by [`HeaderName`].
///
/// Lookups use `HeaderName`'s `PartialEq`, which compares names ASCII-case-insensitively.
/// At most one value is kept per name: inserting an existing name replaces its value.
#[derive(Clone, PartialEq, Eq, Default)]
pub struct Headers(std::collections::HashMap<HeaderName, HeaderValue>);

impl Headers {
    /// Creates an empty collection.
    pub fn new() -> Self {
        Self::default()
    }

    /// Gets a required typed header `H`.
    ///
    /// # Errors
    ///
    /// Returns an [`ErrorKind::DataConversion`] error if `H::from_headers` reports the
    /// header(s) as missing (the message lists [`FromHeaders::header_names`]), or wraps
    /// the error returned by `H::from_headers` if parsing fails.
    pub fn get<H: FromHeaders>(&self) -> crate::Result<H> {
        match H::from_headers(self) {
            Ok(Some(x)) => Ok(x),
            Ok(None) => Err(crate::Error::with_message_fn(
                ErrorKind::DataConversion,
                || {
                    let required_headers = H::header_names();
                    format!(
                        "required header(s) not found: {}",
                        required_headers.join(", ")
                    )
                },
            )),
            Err(e) => Err(crate::Error::new(ErrorKind::DataConversion, e)),
        }
    }

    /// Gets an optional typed header `H`; returns `Ok(None)` when `H::from_headers` does.
    ///
    /// # Errors
    ///
    /// Propagates `H::Error` unchanged.
    pub fn get_optional<H: FromHeaders>(&self) -> Result<Option<H>, H::Error> {
        H::from_headers(self)
    }

    /// Returns an owned copy of the value stored under `key`, or `None` if it is absent.
    pub fn get_optional_string(&self, key: &HeaderName) -> Option<String> {
        // Direct lookup: going through `get_as(key).ok()` built a formatted
        // "header not found" error on every miss only to discard it. Parsing into
        // `String` is infallible, so the result is otherwise identical.
        self.0.get(key).map(|v| v.as_str().to_owned())
    }

    /// Returns the value stored under `key` as a string slice.
    ///
    /// # Errors
    ///
    /// Returns an [`ErrorKind::DataConversion`] error if `key` is absent.
    pub fn get_str(&self, key: &HeaderName) -> crate::Result<&str> {
        self.get_with(key, |s| crate::Result::Ok(s.as_str()))
    }

    /// Returns the value stored under `key` as a string slice, or `None` if it is absent.
    pub fn get_optional_str(&self, key: &HeaderName) -> Option<&str> {
        // Direct lookup for the same reason as `get_optional_string`: avoid building an
        // error value that `.ok()` would immediately throw away.
        self.0.get(key).map(HeaderValue::as_str)
    }

    /// Parses the value stored under `key` with [`FromStr`].
    ///
    /// # Errors
    ///
    /// Returns an [`ErrorKind::DataConversion`] error if `key` is absent or parsing fails.
    pub fn get_as<V, E>(&self, key: &HeaderName) -> crate::Result<V>
    where
        V: FromStr<Err = E>,
        E: std::error::Error + Send + Sync + 'static,
    {
        self.get_with(key, |s| s.as_str().parse())
    }

    /// Parses the value stored under `key` with [`FromStr`]; `Ok(None)` if it is absent.
    ///
    /// # Errors
    ///
    /// Returns an [`ErrorKind::DataConversion`] error if the value is present but
    /// parsing fails.
    pub fn get_optional_as<V, E>(&self, key: &HeaderName) -> crate::Result<Option<V>>
    where
        V: FromStr<Err = E>,
        E: std::error::Error + Send + Sync + 'static,
    {
        self.get_optional_with(key, |s| s.as_str().parse())
    }

    /// Converts the value stored under `key` with a custom `parser`.
    ///
    /// # Errors
    ///
    /// Returns an [`ErrorKind::DataConversion`] error if `key` is absent or `parser` fails.
    pub fn get_with<'a, V, F, E>(&'a self, key: &HeaderName, parser: F) -> crate::Result<V>
    where
        F: FnOnce(&'a HeaderValue) -> Result<V, E>,
        E: std::error::Error + Send + Sync + 'static,
    {
        self.get_optional_with(key, parser)?.ok_or_else(|| {
            Error::with_message_fn(ErrorKind::DataConversion, || {
                format!("header not found {}", key.as_str())
            })
        })
    }

    /// Converts the value stored under `key` with a custom `parser`; `Ok(None)` if absent.
    ///
    /// # Errors
    ///
    /// If `parser` fails, its error is wrapped as [`ErrorKind::DataConversion`] with a
    /// message naming the key and the target type. The value itself is formatted with
    /// `HeaderValue`'s `Debug` impl, which is redacted.
    pub fn get_optional_with<'a, V, F, E>(
        &'a self,
        key: &HeaderName,
        parser: F,
    ) -> crate::Result<Option<V>>
    where
        F: FnOnce(&'a HeaderValue) -> Result<V, E>,
        E: std::error::Error + Send + Sync + 'static,
    {
        self.0
            .get(key)
            .map(|v: &HeaderValue| {
                parser(v).with_context_fn(ErrorKind::DataConversion, || {
                    let ty = std::any::type_name::<V>();
                    format!("unable to parse header '{key:?}: {v:?}' into {ty}",)
                })
            })
            .transpose()
    }

    /// Inserts `value` under `key`, replacing any existing value for that name.
    pub fn insert<K, V>(&mut self, key: K, value: V)
    where
        K: Into<HeaderName>,
        V: Into<HeaderValue>,
    {
        self.0.insert(key.into(), value.into());
    }

    /// Inserts every pair produced by `header.as_headers()`, replacing existing values.
    ///
    /// # Errors
    ///
    /// Returns `H::Error` if `as_headers` fails; nothing is inserted in that case.
    pub fn add<H>(&mut self, header: H) -> Result<(), H::Error>
    where
        H: AsHeaders,
    {
        for (key, value) in header.as_headers()? {
            self.insert(key, value);
        }
        Ok(())
    }

    /// Iterates over `(name, value)` pairs in unspecified (hash map) order.
    pub fn iter(&self) -> impl Iterator<Item = (&HeaderName, &HeaderValue)> {
        self.0.iter()
    }

    /// Removes `key` and returns its value, or `None` if it was not present.
    pub fn remove<K>(&mut self, key: K) -> Option<HeaderValue>
    where
        K: Into<HeaderName>,
    {
        self.0.remove(&key.into())
    }
}
/// Formats the headers as a map. The value of any header whose name is not in
/// `DEFAULT_ALLOWED_HEADER_NAMES` is replaced with `super::REDACTED_PATTERN`.
impl fmt::Debug for Headers {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_map();
        for (name, value) in &self.0 {
            let key = name.as_str();
            let shown = if DEFAULT_ALLOWED_HEADER_NAMES.contains(key) {
                value.as_str()
            } else {
                super::REDACTED_PATTERN
            };
            out.entry(&key, &shown);
        }
        out.finish()
    }
}
impl IntoIterator for Headers {
type Item = (HeaderName, HeaderValue);
type IntoIter = std::collections::hash_map::IntoIter<HeaderName, HeaderValue>;
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
impl From<std::collections::HashMap<HeaderName, HeaderValue>> for Headers {
fn from(c: std::collections::HashMap<HeaderName, HeaderValue>) -> Self {
Self(c)
}
}
/// The name of an HTTP header.
///
/// Names never contain ASCII uppercase letters. The `from_static*` constructors and
/// `From<&'static str>` reject them, while `From<String>` lowercases its input first.
/// The `PartialEq`, `Hash` and `Ord` impls below rely on this invariant to stay
/// consistent with one another.
#[derive(Clone, Debug, Eq)]
pub struct HeaderName {
    name: Cow<'static, str>,
    pub(crate) is_standard: bool,
}

impl HeaderName {
    /// Creates a non-standard header name from a `'static` string without allocating.
    ///
    /// # Panics
    ///
    /// Panics if `s` contains an ASCII uppercase letter. When evaluated in a `const`
    /// context this becomes a compile-time error.
    pub const fn from_static(s: &'static str) -> Self {
        ensure_no_uppercase(s);
        Self {
            name: Cow::Borrowed(s),
            is_standard: false,
        }
    }

    /// Like [`HeaderName::from_static`], but sets the `is_standard` flag.
    ///
    /// # Panics
    ///
    /// Panics if `s` contains an ASCII uppercase letter.
    pub const fn from_static_standard(s: &'static str) -> Self {
        ensure_no_uppercase(s);
        Self {
            name: Cow::Borrowed(s),
            is_standard: true,
        }
    }

    /// Builds a non-standard name from a borrowed or owned string.
    ///
    /// # Panics
    ///
    /// Panics if the string contains an ASCII uppercase letter.
    fn from_cow<C>(c: C) -> Self
    where
        C: Into<Cow<'static, str>>,
    {
        let c = c.into();
        // Only ASCII uppercase is rejected. Equality folds ASCII case only, so that is
        // all the Eq/Hash invariant needs, and it matches `ensure_no_uppercase`.
        // A Unicode check (`is_lowercase() || !is_alphabetic()`) would also reject
        // uncased letters such as CJK, which `to_lowercase` cannot change, so
        // `From<String>` would panic on them.
        assert!(
            !c.bytes().any(|b| b.is_ascii_uppercase()),
            "header names must be lowercase: {c}"
        );
        Self {
            name: c,
            is_standard: false,
        }
    }

    /// Returns the name as a string slice.
    pub fn as_str(&self) -> &str {
        self.name.as_ref()
    }

    /// Returns the `is_standard` flag. In this module only `from_static_standard` sets it.
    pub fn is_standard(&self) -> bool {
        self.is_standard
    }
}

impl PartialEq for HeaderName {
    /// Compares names ASCII-case-insensitively; `is_standard` does not participate.
    fn eq(&self, other: &Self) -> bool {
        self.name.eq_ignore_ascii_case(&other.name)
    }
}

impl PartialEq<&str> for HeaderName {
    /// Compares against a plain string ASCII-case-insensitively.
    fn eq(&self, other: &&str) -> bool {
        self.name.eq_ignore_ascii_case(other)
    }
}

impl std::hash::Hash for HeaderName {
    /// Hashes the stored name only. Names contain no ASCII uppercase, so this agrees
    /// with the case-insensitive `PartialEq`.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        std::hash::Hash::hash(&self.name, state);
    }
}

impl PartialOrd for HeaderName {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for HeaderName {
    /// Orders by name only. A derived `Ord` would also compare `is_standard`, so it
    /// would report two `==` names as unequal, violating the `Ord`/`Eq` contract.
    /// Names contain no ASCII uppercase, so a plain string comparison matches `eq`.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.name.cmp(&other.name)
    }
}

/// Panics if `s` contains an ASCII uppercase letter (`A`–`Z`).
///
/// This is a `const fn` so that `HeaderName::from_static*` can run in `const` items,
/// where an invalid literal becomes a compile-time error.
const fn ensure_no_uppercase(s: &str) {
    let bytes = s.as_bytes();
    let mut i = 0;
    // Iterators and `for` loops are not available in `const fn`, hence the manual loop.
    while i < bytes.len() {
        assert!(
            !bytes[i].is_ascii_uppercase(),
            "header names must not contain uppercase letters"
        );
        i += 1;
    }
}

impl From<&'static str> for HeaderName {
    /// Wraps a `'static` string without allocating.
    ///
    /// # Panics
    ///
    /// Panics if `s` contains an ASCII uppercase letter; use `From<String>` to normalize.
    fn from(s: &'static str) -> Self {
        Self::from_cow(s)
    }
}

impl From<String> for HeaderName {
    /// Lowercases `s` (Unicode-aware `to_lowercase`) before storing it.
    fn from(s: String) -> Self {
        Self::from_cow(s.to_lowercase())
    }
}
/// The value of an HTTP header.
///
/// Its `Debug` output is the fixed text `HeaderValue`, so the contents are never printed.
#[derive(Clone, PartialEq, Eq)]
pub struct HeaderValue(Cow<'static, str>);

impl HeaderValue {
    /// Wraps a `'static` string without allocating.
    pub const fn from_static(s: &'static str) -> Self {
        Self(Cow::Borrowed(s))
    }

    /// Builds a value from anything convertible into a `Cow<'static, str>`.
    pub fn from_cow<C>(c: C) -> Self
    where
        C: Into<Cow<'static, str>>,
    {
        let inner: Cow<'static, str> = c.into();
        Self(inner)
    }

    /// Borrows the value as a string slice.
    pub fn as_str(&self) -> &str {
        &self.0
    }
}

impl fmt::Debug for HeaderValue {
    /// Always writes the literal `HeaderValue`, never the contents.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "HeaderValue")
    }
}

impl From<&'static str> for HeaderValue {
    fn from(s: &'static str) -> Self {
        Self(Cow::Borrowed(s))
    }
}

impl From<String> for HeaderValue {
    fn from(s: String) -> Self {
        Self(Cow::Owned(s))
    }
}

impl From<&String> for HeaderValue {
    /// Copies the string into an owned value.
    fn from(s: &String) -> Self {
        Self(Cow::Owned(s.clone()))
    }
}
#[cfg(test)]
mod tests {
    use super::{FromHeaders, HeaderName, Headers};
    use crate::error::ErrorKind;
    use url::Url;

    /// Minimal `FromHeaders` implementation: parses `content-location` as a URL.
    #[derive(Debug)]
    struct ContentLocationForTest(Url);

    impl FromHeaders for ContentLocationForTest {
        type Error = url::ParseError;

        fn header_names() -> &'static [&'static str] {
            &["content-location"]
        }

        fn from_headers(headers: &Headers) -> Result<Option<Self>, Self::Error> {
            headers
                .get_optional_str(&HeaderName::from("content-location"))
                .map(|raw| raw.parse().map(ContentLocationForTest))
                .transpose()
        }
    }

    /// Builds a collection containing exactly one header.
    fn headers_with(name: &'static str, value: &'static str) -> Headers {
        let mut headers = Headers::new();
        headers.insert(name, value);
        headers
    }

    #[test]
    fn headers_get_optional_returns_ok_some_if_header_present_and_valid() {
        let headers = headers_with("content-location", "https://example.com");
        let location = headers
            .get_optional::<ContentLocationForTest>()
            .unwrap()
            .unwrap();
        assert_eq!("https://example.com/", location.0.as_str());
    }

    #[test]
    fn headers_get_optional_returns_ok_none_if_header_not_present() {
        let location = Headers::new()
            .get_optional::<ContentLocationForTest>()
            .unwrap();
        assert!(location.is_none());
    }

    #[test]
    fn headers_get_optional_returns_err_if_conversion_fails() {
        let headers = headers_with("content-location", "not a URL");
        let error = headers
            .get_optional::<ContentLocationForTest>()
            .expect_err("parsing an invalid URL should fail");
        assert_eq!(url::ParseError::RelativeUrlWithoutBase, error);
    }

    #[test]
    fn headers_get_returns_ok_if_header_present_and_valid() {
        let headers = headers_with("content-location", "https://example.com");
        let location = headers.get::<ContentLocationForTest>().unwrap();
        assert_eq!("https://example.com/", location.0.as_str());
    }

    #[test]
    fn headers_get_returns_err_if_header_not_present() {
        let error = Headers::new().get::<ContentLocationForTest>().unwrap_err();
        assert_eq!(&ErrorKind::DataConversion, error.kind());
        assert_eq!(
            "required header(s) not found: content-location",
            error.to_string()
        );
    }

    #[test]
    fn headers_get_returns_err_if_header_requiring_multiple_headers_not_present() {
        // Always reports both of its headers as missing.
        #[derive(Debug)]
        struct HasTwoHeaders;

        impl FromHeaders for HasTwoHeaders {
            type Error = std::convert::Infallible;

            fn header_names() -> &'static [&'static str] {
                &["header-a", "header-b"]
            }

            fn from_headers(_headers: &Headers) -> Result<Option<Self>, Self::Error> {
                Ok(None)
            }
        }

        let error = Headers::new().get::<HasTwoHeaders>().unwrap_err();
        assert_eq!(&ErrorKind::DataConversion, error.kind());
        assert_eq!(
            "required header(s) not found: header-a, header-b",
            error.to_string()
        );
    }

    #[test]
    fn headers_get_returns_err_if_conversion_fails() {
        let headers = headers_with("content-location", "not a URL");
        let error = headers.get::<ContentLocationForTest>().unwrap_err();
        assert_eq!(&ErrorKind::DataConversion, error.kind());
        let inner: Box<url::ParseError> = error.into_inner().unwrap().downcast().unwrap();
        assert_eq!(Box::new(url::ParseError::RelativeUrlWithoutBase), inner);
    }

    #[test]
    fn headers_remove_existing_header_returns_value() {
        let mut headers = headers_with("test-header", "test-value");
        let key = HeaderName::from("test-header");
        assert_eq!(headers.get_optional_str(&key), Some("test-value"));

        let removed = headers
            .remove("test-header")
            .expect("header was inserted above");
        assert_eq!(removed.as_str(), "test-value");
        assert_eq!(headers.get_optional_str(&key), None);
    }

    #[test]
    fn headers_remove_nonexistent_header_returns_none() {
        let mut headers = Headers::new();
        assert_eq!(headers.remove("nonexistent-header"), None);
    }

    #[test]
    fn headers_remove_works_with_different_key_types() {
        let mut headers = Headers::new();

        headers.insert("test-header", "test-value");
        let by_str = headers.remove("test-header").expect("removed via &str");
        assert_eq!(by_str.as_str(), "test-value");

        headers.insert("test-header", "test-value");
        let by_name = headers
            .remove(HeaderName::from("test-header"))
            .expect("removed via HeaderName");
        assert_eq!(by_name.as_str(), "test-value");

        headers.insert("test-header", "test-value");
        let by_string = headers
            .remove("test-header".to_string())
            .expect("removed via String");
        assert_eq!(by_string.as_str(), "test-value");
    }
}