use std::borrow::Cow;
use std::convert::TryFrom;
use std::error::Error;
use std::fmt;
use std::fmt::{Display, Formatter};
use std::str::FromStr;
use http::header::{HeaderMap, HeaderName, HeaderValue, ValueIter};
use aws_smithy_types::date_time::Format;
use aws_smithy_types::primitive::Parse;
use aws_smithy_types::DateTime;
#[derive(Debug, Eq, PartialEq)]
#[non_exhaustive]
pub struct ParseError {
message: Option<Cow<'static, str>>,
}
impl ParseError {
#[allow(clippy::new_without_default)]
pub fn new() -> Self {
Self { message: None }
}
pub fn new_with_message(message: impl Into<Cow<'static, str>>) -> Self {
Self {
message: Some(message.into()),
}
}
}
impl Display for ParseError {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "Output failed to parse in headers")?;
if let Some(message) = &self.message {
write!(f, ". {}", message)?;
}
Ok(())
}
}
impl Error for ParseError {}
pub fn many_dates(
values: ValueIter<HeaderValue>,
format: Format,
) -> Result<Vec<DateTime>, ParseError> {
let mut out = vec![];
for header in values {
let mut header = header
.to_str()
.map_err(|_| ParseError::new_with_message("header was not valid utf-8 string"))?;
while !header.is_empty() {
let (v, next) = DateTime::read(header, format, ',').map_err(|err| {
ParseError::new_with_message(format!("header could not be parsed as date: {}", err))
})?;
out.push(v);
header = next;
}
}
Ok(out)
}
pub fn headers_for_prefix<'a>(
headers: &'a http::HeaderMap,
key: &'a str,
) -> impl Iterator<Item = (&'a str, &'a HeaderName)> {
let lower_key = key.to_ascii_lowercase();
headers
.keys()
.filter(move |k| k.as_str().starts_with(&lower_key))
.map(move |h| (&h.as_str()[key.len()..], h))
}
pub fn read_many_from_str<T: FromStr>(
values: ValueIter<HeaderValue>,
) -> Result<Vec<T>, ParseError> {
read_many(values, |v: &str| {
v.parse()
.map_err(|_err| ParseError::new_with_message("failed during FromString conversion"))
})
}
pub fn read_many_primitive<T: Parse>(values: ValueIter<HeaderValue>) -> Result<Vec<T>, ParseError> {
read_many(values, |v: &str| {
T::parse_smithy_primitive(v).map_err(|primitive| {
ParseError::new_with_message(format!(
"failed reading a list of primitives: {}",
primitive
))
})
})
}
fn read_many<T>(
values: ValueIter<HeaderValue>,
f: impl Fn(&str) -> Result<T, ParseError>,
) -> Result<Vec<T>, ParseError> {
let mut out = vec![];
for header in values {
let mut header = header.as_bytes();
while !header.is_empty() {
let (v, next) = read_one(header, &f)?;
out.push(v);
header = next;
}
}
Ok(out)
}
pub fn one_or_none<T: FromStr>(
mut values: ValueIter<HeaderValue>,
) -> Result<Option<T>, ParseError> {
let first = match values.next() {
Some(v) => v,
None => return Ok(None),
};
let value = std::str::from_utf8(first.as_bytes())
.map_err(|_| ParseError::new_with_message("invalid utf-8"))?;
match values.next() {
None => T::from_str(value.trim())
.map_err(|_| ParseError::new())
.map(Some),
Some(_) => Err(ParseError::new_with_message(
"expected a single value but found multiple",
)),
}
}
pub fn set_request_header_if_absent<V>(
request: http::request::Builder,
key: HeaderName,
value: V,
) -> http::request::Builder
where
HeaderValue: TryFrom<V>,
<HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
{
if !request
.headers_ref()
.map(|map| map.contains_key(&key))
.unwrap_or(false)
{
request.header(key, value)
} else {
request
}
}
pub fn set_response_header_if_absent<V>(
response: http::response::Builder,
key: HeaderName,
value: V,
) -> http::response::Builder
where
HeaderValue: TryFrom<V>,
<HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
{
if !response
.headers_ref()
.map(|map| map.contains_key(&key))
.unwrap_or(false)
{
response.header(key, value)
} else {
response
}
}
mod parse_multi_header {
use super::ParseError;
use std::borrow::Cow;
fn trim(s: Cow<'_, str>) -> Cow<'_, str> {
match s {
Cow::Owned(s) => Cow::Owned(s.trim().into()),
Cow::Borrowed(s) => Cow::Borrowed(s.trim()),
}
}
fn replace<'a>(value: Cow<'a, str>, pattern: &str, replacement: &str) -> Cow<'a, str> {
if value.contains(pattern) {
Cow::Owned(value.replace(pattern, replacement))
} else {
value
}
}
pub(crate) fn read_value(input: &[u8]) -> Result<(Cow<'_, str>, &[u8]), ParseError> {
for (index, &byte) in input.iter().enumerate() {
let current_slice = &input[index..];
match byte {
b' ' | b'\t' => { }
b'"' => return read_quoted_value(¤t_slice[1..]),
_ => {
let (value, rest) = read_unquoted_value(current_slice)?;
return Ok((trim(value), rest));
}
}
}
Ok((Cow::Borrowed(""), &[]))
}
fn read_unquoted_value(input: &[u8]) -> Result<(Cow<'_, str>, &[u8]), ParseError> {
let next_delim = input.iter().position(|&b| b == b',').unwrap_or(input.len());
let (first, next) = input.split_at(next_delim);
let first = std::str::from_utf8(first)
.map_err(|_| ParseError::new_with_message("header was not valid utf8"))?;
Ok((Cow::Borrowed(first), then_comma(next).unwrap()))
}
fn read_quoted_value(input: &[u8]) -> Result<(Cow<'_, str>, &[u8]), ParseError> {
for index in 0..input.len() {
match input[index] {
b'"' if index == 0 || input[index - 1] != b'\\' => {
let mut inner =
Cow::Borrowed(std::str::from_utf8(&input[0..index]).map_err(|_| {
ParseError::new_with_message("header was not valid utf8")
})?);
inner = replace(inner, "\\\"", "\"");
inner = replace(inner, "\\\\", "\\");
let rest = then_comma(&input[(index + 1)..])?;
return Ok((inner, rest));
}
_ => {}
}
}
Err(ParseError::new_with_message(
"header value had quoted value without end quote",
))
}
fn then_comma(s: &[u8]) -> Result<&[u8], ParseError> {
if s.is_empty() {
Ok(s)
} else if s.starts_with(b",") {
Ok(&s[1..])
} else {
Err(ParseError::new_with_message("expected delimiter `,`"))
}
}
}
fn read_one<'a, T>(
s: &'a [u8],
f: &impl Fn(&str) -> Result<T, ParseError>,
) -> Result<(T, &'a [u8]), ParseError> {
let (value, rest) = parse_multi_header::read_value(s)?;
Ok((f(&value)?, rest))
}
pub fn quote_header_value<'a>(value: impl Into<Cow<'a, str>>) -> Cow<'a, str> {
let value = value.into();
if value.trim().len() != value.len()
|| value.contains('"')
|| value.contains(',')
|| value.contains('(')
|| value.contains(')')
{
Cow::Owned(format!(
"\"{}\"",
value.replace('\\', "\\\\").replace('"', "\\\"")
))
} else {
value
}
}
pub fn append_merge_header_maps(
mut lhs: HeaderMap<HeaderValue>,
rhs: HeaderMap<HeaderValue>,
) -> HeaderMap<HeaderValue> {
let mut last_header_name_seen = None;
for (header_name, header_value) in rhs.into_iter() {
match (&mut last_header_name_seen, header_name) {
(_, Some(header_name)) => {
lhs.append(header_name.clone(), header_value);
last_header_name_seen = Some(header_name);
}
(Some(header_name), None) => {
lhs.append(header_name.clone(), header_value);
}
(None, None) => unreachable!(),
};
}
lhs
}
#[cfg(test)]
mod test {
use std::collections::HashMap;
use aws_smithy_types::{date_time::Format, DateTime};
use http::header::{HeaderMap, HeaderName, HeaderValue};
use crate::header::{
append_merge_header_maps, headers_for_prefix, many_dates, read_many_from_str,
read_many_primitive, set_request_header_if_absent, set_response_header_if_absent,
ParseError,
};
use super::quote_header_value;
#[test]
fn put_on_request_if_absent() {
let builder = http::Request::builder().header("foo", "bar");
let builder = set_request_header_if_absent(builder, HeaderName::from_static("foo"), "baz");
let builder =
set_request_header_if_absent(builder, HeaderName::from_static("other"), "value");
let req = builder.body(()).expect("valid request");
assert_eq!(
req.headers().get_all("foo").iter().collect::<Vec<_>>(),
vec!["bar"]
);
assert_eq!(
req.headers().get_all("other").iter().collect::<Vec<_>>(),
vec!["value"]
);
}
#[test]
fn put_on_response_if_absent() {
let builder = http::Response::builder().header("foo", "bar");
let builder = set_response_header_if_absent(builder, HeaderName::from_static("foo"), "baz");
let builder =
set_response_header_if_absent(builder, HeaderName::from_static("other"), "value");
let response = builder.body(()).expect("valid response");
assert_eq!(
response.headers().get_all("foo").iter().collect::<Vec<_>>(),
vec!["bar"]
);
assert_eq!(
response
.headers()
.get_all("other")
.iter()
.collect::<Vec<_>>(),
vec!["value"]
);
}
#[test]
fn parse_floats() {
let test_request = http::Request::builder()
.header("X-Float-Multi", "0.0,Infinity,-Infinity,5555.5")
.header("X-Float-Error", "notafloat")
.body(())
.unwrap();
assert_eq!(
read_many_primitive::<f32>(test_request.headers().get_all("X-Float-Multi").iter())
.expect("valid"),
vec![0.0, f32::INFINITY, f32::NEG_INFINITY, 5555.5]
);
assert_eq!(
read_many_primitive::<f32>(test_request.headers().get_all("X-Float-Error").iter())
.expect_err("invalid"),
ParseError::new_with_message(
"failed reading a list of primitives: failed to parse input as f32"
)
)
}
#[test]
fn test_many_dates() {
let test_request = http::Request::builder()
.header("Empty", "")
.header("SingleHttpDate", "Wed, 21 Oct 2015 07:28:00 GMT")
.header(
"MultipleHttpDates",
"Wed, 21 Oct 2015 07:28:00 GMT,Thu, 22 Oct 2015 07:28:00 GMT",
)
.header("SingleEpochSeconds", "1234.5678")
.header("MultipleEpochSeconds", "1234.5678,9012.3456")
.body(())
.unwrap();
let read = |name: &str, format: Format| {
many_dates(test_request.headers().get_all(name).iter(), format)
};
let read_valid = |name: &str, format: Format| read(name, format).expect("valid");
assert_eq!(
read_valid("Empty", Format::DateTime),
Vec::<DateTime>::new()
);
assert_eq!(
read_valid("SingleHttpDate", Format::HttpDate),
vec![DateTime::from_secs_and_nanos(1445412480, 0)]
);
assert_eq!(
read_valid("MultipleHttpDates", Format::HttpDate),
vec![
DateTime::from_secs_and_nanos(1445412480, 0),
DateTime::from_secs_and_nanos(1445498880, 0)
]
);
assert_eq!(
read_valid("SingleEpochSeconds", Format::EpochSeconds),
vec![DateTime::from_secs_and_nanos(1234, 567_800_000)]
);
assert_eq!(
read_valid("MultipleEpochSeconds", Format::EpochSeconds),
vec![
DateTime::from_secs_and_nanos(1234, 567_800_000),
DateTime::from_secs_and_nanos(9012, 345_600_000)
]
);
}
#[test]
fn read_many_strings() {
let test_request = http::Request::builder()
.header("Empty", "")
.header("Foo", " foo")
.header("FooTrailing", "foo ")
.header("FooInQuotes", "\" foo \"")
.header("CommaInQuotes", "\"foo,bar\",baz")
.header("CommaInQuotesTrailing", "\"foo,bar\",baz ")
.header("QuoteInQuotes", "\"foo\\\",bar\",\"\\\"asdf\\\"\",baz")
.header(
"QuoteInQuotesWithSpaces",
"\"foo\\\",bar\", \"\\\"asdf\\\"\", baz",
)
.header("JunkFollowingQuotes", "\"\\\"asdf\\\"\"baz")
.header("EmptyQuotes", "\"\",baz")
.header("EscapedSlashesInQuotes", "foo, \"(foo\\\\bar)\"")
.body(())
.unwrap();
let read =
|name: &str| read_many_from_str::<String>(test_request.headers().get_all(name).iter());
let read_valid = |name: &str| read(name).expect("valid");
assert_eq!(read_valid("Empty"), Vec::<String>::new());
assert_eq!(read_valid("Foo"), vec!["foo"]);
assert_eq!(read_valid("FooTrailing"), vec!["foo"]);
assert_eq!(read_valid("FooInQuotes"), vec![" foo "]);
assert_eq!(read_valid("CommaInQuotes"), vec!["foo,bar", "baz"]);
assert_eq!(read_valid("CommaInQuotesTrailing"), vec!["foo,bar", "baz"]);
assert_eq!(
read_valid("QuoteInQuotes"),
vec!["foo\",bar", "\"asdf\"", "baz"]
);
assert_eq!(
read_valid("QuoteInQuotesWithSpaces"),
vec!["foo\",bar", "\"asdf\"", "baz"]
);
assert!(read("JunkFollowingQuotes").is_err());
assert_eq!(read_valid("EmptyQuotes"), vec!["", "baz"]);
assert_eq!(
read_valid("EscapedSlashesInQuotes"),
vec!["foo", "(foo\\bar)"]
);
}
#[test]
fn read_many_bools() {
let test_request = http::Request::builder()
.header("X-Bool-Multi", "true,false")
.header("X-Bool-Multi", "true")
.header("X-Bool", "true")
.header("X-Bool-Invalid", "truth,falsy")
.header("X-Bool-Single", "true,false,true,true")
.header("X-Bool-Quoted", "true,\"false\",true,true")
.body(())
.unwrap();
assert_eq!(
read_many_primitive::<bool>(test_request.headers().get_all("X-Bool-Multi").iter())
.expect("valid"),
vec![true, false, true]
);
assert_eq!(
read_many_primitive::<bool>(test_request.headers().get_all("X-Bool").iter()).unwrap(),
vec![true]
);
assert_eq!(
read_many_primitive::<bool>(test_request.headers().get_all("X-Bool-Single").iter())
.unwrap(),
vec![true, false, true, true]
);
assert_eq!(
read_many_primitive::<bool>(test_request.headers().get_all("X-Bool-Quoted").iter())
.unwrap(),
vec![true, false, true, true]
);
read_many_primitive::<bool>(test_request.headers().get_all("X-Bool-Invalid").iter())
.expect_err("invalid");
}
#[test]
fn check_read_many_i16() {
let test_request = http::Request::builder()
.header("X-Multi", "123,456")
.header("X-Multi", "789")
.header("X-Num", "777")
.header("X-Num-Invalid", "12ef3")
.header("X-Num-Single", "1,2,3,-4,5")
.header("X-Num-Quoted", "1, \"2\",3,\"-4\",5")
.body(())
.unwrap();
assert_eq!(
read_many_primitive::<i16>(test_request.headers().get_all("X-Multi").iter())
.expect("valid"),
vec![123, 456, 789]
);
assert_eq!(
read_many_primitive::<i16>(test_request.headers().get_all("X-Num").iter()).unwrap(),
vec![777]
);
assert_eq!(
read_many_primitive::<i16>(test_request.headers().get_all("X-Num-Single").iter())
.unwrap(),
vec![1, 2, 3, -4, 5]
);
assert_eq!(
read_many_primitive::<i16>(test_request.headers().get_all("X-Num-Quoted").iter())
.unwrap(),
vec![1, 2, 3, -4, 5]
);
read_many_primitive::<i16>(test_request.headers().get_all("X-Num-Invalid").iter())
.expect_err("invalid");
}
#[test]
fn test_prefix_headers() {
let test_request = http::Request::builder()
.header("X-Prefix-A", "123,456")
.header("X-Prefix-B", "789")
.header("X-Prefix-C", "777")
.header("X-Prefix-C", "777")
.body(())
.unwrap();
let resp: Result<HashMap<String, Vec<i16>>, ParseError> =
headers_for_prefix(test_request.headers(), "X-Prefix-")
.map(|(key, header_name)| {
let values = test_request.headers().get_all(header_name);
read_many_primitive(values.iter()).map(|v| (key.to_string(), v))
})
.collect();
let resp = resp.expect("valid");
assert_eq!(resp.get("a"), Some(&vec![123_i16, 456_i16]));
}
#[test]
fn test_quote_header_value() {
assert_eq!("", "e_header_value(""));
assert_eq!("foo", "e_header_value("foo"));
assert_eq!("\" foo\"", "e_header_value(" foo"));
assert_eq!("foo bar", "e_header_value("foo bar"));
assert_eq!("\"foo,bar\"", "e_header_value("foo,bar"));
assert_eq!("\",\"", "e_header_value(","));
assert_eq!("\"\\\"foo\\\"\"", "e_header_value("\"foo\""));
assert_eq!("\"\\\"f\\\\oo\\\"\"", "e_header_value("\"f\\oo\""));
assert_eq!("\"(\"", "e_header_value("("));
assert_eq!("\")\"", "e_header_value(")"));
}
#[test]
fn test_append_merge_header_maps_with_shared_key() {
let header_name = HeaderName::from_static("some_key");
let left_header_value = HeaderValue::from_static("lhs value");
let right_header_value = HeaderValue::from_static("rhs value");
let mut left_hand_side_headers = HeaderMap::new();
left_hand_side_headers.insert(header_name.clone(), left_header_value.clone());
let mut right_hand_side_headers = HeaderMap::new();
right_hand_side_headers.insert(header_name.clone(), right_header_value.clone());
let merged_header_map =
append_merge_header_maps(left_hand_side_headers, right_hand_side_headers);
let actual_merged_values: Vec<_> = merged_header_map
.get_all(header_name.clone())
.into_iter()
.collect();
let expected_merged_values = vec![left_header_value, right_header_value];
assert_eq!(actual_merged_values, expected_merged_values);
}
#[test]
fn test_append_merge_header_maps_with_multiple_values_in_left_hand_map() {
let header_name = HeaderName::from_static("some_key");
let left_header_value_1 = HeaderValue::from_static("lhs value 1");
let left_header_value_2 = HeaderValue::from_static("lhs_value 2");
let right_header_value = HeaderValue::from_static("rhs value");
let mut left_hand_side_headers = HeaderMap::new();
left_hand_side_headers.insert(header_name.clone(), left_header_value_1.clone());
left_hand_side_headers.append(header_name.clone(), left_header_value_2.clone());
let mut right_hand_side_headers = HeaderMap::new();
right_hand_side_headers.insert(header_name.clone(), right_header_value.clone());
let merged_header_map =
append_merge_header_maps(left_hand_side_headers, right_hand_side_headers);
let actual_merged_values: Vec<_> = merged_header_map
.get_all(header_name.clone())
.into_iter()
.collect();
let expected_merged_values =
vec![left_header_value_1, left_header_value_2, right_header_value];
assert_eq!(actual_merged_values, expected_merged_values);
}
#[test]
fn test_append_merge_header_maps_with_empty_left_hand_map() {
let header_name = HeaderName::from_static("some_key");
let right_header_value_1 = HeaderValue::from_static("rhs value 1");
let right_header_value_2 = HeaderValue::from_static("rhs_value 2");
let left_hand_side_headers = HeaderMap::new();
let mut right_hand_side_headers = HeaderMap::new();
right_hand_side_headers.insert(header_name.clone(), right_header_value_1.clone());
right_hand_side_headers.append(header_name.clone(), right_header_value_2.clone());
let merged_header_map =
append_merge_header_maps(left_hand_side_headers, right_hand_side_headers);
let actual_merged_values: Vec<_> = merged_header_map
.get_all(header_name.clone())
.into_iter()
.collect();
let expected_merged_values = vec![right_header_value_1, right_header_value_2];
assert_eq!(actual_merged_values, expected_merged_values);
}
}