use http::{header::HeaderName, HeaderMap, HeaderValue};
use serde::{
de::{self, Deserializer, Error as DeError, MapAccess, Unexpected, Visitor},
ser::{Error as SerError, SerializeMap, Serializer},
};
use std::{borrow::Cow, fmt};
/// Serializes `headers` as a map from each header name to the list of all of its values.
///
/// # Errors
/// Returns a custom serializer error if any header value is not valid UTF-8.
pub(crate) fn serialize_multi_value_headers<S>(headers: &HeaderMap, serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    let mut map = serializer.serialize_map(Some(headers.keys_len()))?;
    for key in headers.keys() {
        // Validate and borrow each value as `&str` rather than copying its bytes into a new
        // `String`; the serialized output and the UTF-8 error message are unchanged.
        let values = headers
            .get_all(key)
            .iter()
            .map(|value| std::str::from_utf8(value.as_bytes()))
            .collect::<Result<Vec<&str>, _>>()
            .map_err(S::Error::custom)?;
        map.serialize_entry(key.as_str(), &values)?;
    }
    map.end()
}
/// Serializes `headers` as a map from each header name to a single string value.
///
/// Only the first value stored under each name is written; additional values for the same
/// name are not emitted (use `serialize_multi_value_headers` to keep them all).
///
/// # Errors
/// Returns a custom serializer error if an emitted header value is not valid UTF-8.
pub(crate) fn serialize_headers<S>(headers: &HeaderMap, serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    let mut map = serializer.serialize_map(Some(headers.keys_len()))?;
    for key in headers.keys() {
        // `key` was produced by `headers.keys()`, so indexing cannot hit a missing name.
        // Borrowing the bytes as `&str` avoids allocating a `String` per header.
        let value = std::str::from_utf8(headers[key].as_bytes()).map_err(S::Error::custom)?;
        map.serialize_entry(key.as_str(), value)?;
    }
    map.end()
}
#[cfg(feature = "vpc_lattice")]
/// Serializes `headers` as a map from each header name to all of its values joined with `", "`.
///
/// Values that are not valid UTF-8 are skipped, and a header whose values are all skipped is
/// omitted entirely.
pub(crate) fn serialize_comma_separated_headers<S>(headers: &HeaderMap, serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    // Build the entries up front so the length handed to `serialize_map` is exact. Omitted
    // headers would otherwise make the `keys_len()` hint wrong, which corrupts output for
    // formats that write the map length before its entries.
    let entries: Vec<(&str, String)> = headers
        .keys()
        .filter_map(|key| {
            // Decode as UTF-8 (like the other serializers here) rather than `to_str`, which
            // rejects every non-ASCII byte and would silently drop valid UTF-8 values.
            let values: Vec<&str> = headers
                .get_all(key)
                .iter()
                .filter_map(|value| std::str::from_utf8(value.as_bytes()).ok())
                .collect();
            if values.is_empty() {
                None
            } else {
                Some((key.as_str(), values.join(", ")))
            }
        })
        .collect();

    let mut map = serializer.serialize_map(Some(entries.len()))?;
    for (key, value) in &entries {
        map.serialize_entry(key, value)?;
    }
    map.end()
}
/// Accepted shapes for one header entry's value in human-readable formats.
///
/// With `#[serde(untagged)]`, serde tries the variants in declaration order and keeps the
/// first that matches, so the order below is significant: do not reorder the variants.
#[derive(serde::Deserialize)]
#[serde(untagged)]
enum OneOrMore<'a> {
/// A single value given as a string, e.g. `"text/html"`.
One(Cow<'a, str>),
/// Several values given as a list of strings, e.g. `["a", "b"]`.
Strings(Vec<Cow<'a, str>>),
/// Several values given as a list of raw byte strings.
Bytes(Vec<Cow<'a, [u8]>>),
}
/// Serde visitor that builds a `HeaderMap` from a map of header names to header values.
struct HeaderMapVisitor {
    /// When set, a string value containing `,` is split on commas and each trimmed,
    /// non-empty piece is stored as its own header value.
    split_comma_separated: bool,
    /// Copy of the deserializer's `is_human_readable()`; chooses how entry values are decoded.
    is_human_readable: bool,
}
impl<'de> Visitor<'de> for HeaderMapVisitor {
type Value = HeaderMap;
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str("lots of things can go wrong with HeaderMap")
}
fn visit_unit<E>(self) -> Result<Self::Value, E>
where
E: DeError,
{
Ok(HeaderMap::default())
}
fn visit_none<E>(self) -> Result<Self::Value, E>
where
E: DeError,
{
Ok(HeaderMap::default())
}
fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_map(self)
}
fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
where
M: MapAccess<'de>,
{
let mut map = HeaderMap::with_capacity(access.size_hint().unwrap_or(0));
if !self.is_human_readable {
while let Some((key, arr)) = access.next_entry::<Cow<'_, str>, Vec<Cow<'_, [u8]>>>()? {
let key = HeaderName::from_bytes(key.as_bytes())
.map_err(|_| de::Error::invalid_value(Unexpected::Str(&key), &self))?;
for val in arr {
let val = HeaderValue::from_bytes(&val)
.map_err(|_| de::Error::invalid_value(Unexpected::Bytes(&val), &self))?;
map.append(&key, val);
}
}
} else {
while let Some((key, val)) = access.next_entry::<Cow<'_, str>, OneOrMore<'_>>()? {
let key = HeaderName::from_bytes(key.as_bytes())
.map_err(|_| de::Error::invalid_value(Unexpected::Str(&key), &self))?;
match val {
OneOrMore::One(val) => {
if self.split_comma_separated && val.contains(',') {
split_and_append_header(&mut map, &key, &val, &self)?;
} else {
let val = val
.parse()
.map_err(|_| de::Error::invalid_value(Unexpected::Str(&val), &self))?;
map.insert(key, val);
}
}
OneOrMore::Strings(arr) => {
for val in arr {
if self.split_comma_separated && val.contains(',') {
split_and_append_header(&mut map, &key, &val, &self)?;
} else {
let val = val
.parse()
.map_err(|_| de::Error::invalid_value(Unexpected::Str(&val), &self))?;
map.append(&key, val);
}
}
}
OneOrMore::Bytes(arr) => {
for val in arr {
let val = HeaderValue::from_bytes(&val)
.map_err(|_| de::Error::invalid_value(Unexpected::Bytes(&val), &self))?;
map.append(&key, val);
}
}
};
}
}
Ok(map)
}
}
fn split_and_append_header<E>(
map: &mut HeaderMap,
key: &HeaderName,
value: &str,
visitor: &HeaderMapVisitor,
) -> Result<(), E>
where
E: DeError,
{
for split_val in value.split(',') {
let trimmed_val = split_val.trim();
if !trimmed_val.is_empty() {
let header_val = trimmed_val
.parse()
.map_err(|_| de::Error::invalid_value(Unexpected::Str(trimmed_val), visitor))?;
map.append(key, header_val);
}
}
Ok(())
}
/// Deserializes a `HeaderMap`, treating a null/missing value as an empty map.
///
/// Comma-containing string values are stored as-is (not split).
pub(crate) fn deserialize_headers<'de, D>(de: D) -> Result<HeaderMap, D::Error>
where
    D: Deserializer<'de>,
{
    let visitor = HeaderMapVisitor {
        split_comma_separated: false,
        is_human_readable: de.is_human_readable(),
    };
    de.deserialize_option(visitor)
}
#[cfg(feature = "vpc_lattice")]
/// Deserializes a `HeaderMap`, treating a null/missing value as an empty map and splitting
/// comma-containing string values into separate header values.
pub(crate) fn deserialize_comma_separated_headers<'de, D>(de: D) -> Result<HeaderMap, D::Error>
where
    D: Deserializer<'de>,
{
    let visitor = HeaderMapVisitor {
        split_comma_separated: true,
        is_human_readable: de.is_human_readable(),
    };
    de.deserialize_option(visitor)
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Encodes `value` to a `serde_json::Value` and decodes it back.
    fn json_roundtrip<T>(value: T) -> T
    where
        T: Serialize + for<'de> Deserialize<'de>,
    {
        let encoded = serde_json::to_value(value).unwrap();
        serde_json::from_value(encoded).unwrap()
    }

    #[test]
    fn test_deserialize_missing_http_headers() {
        #[derive(Deserialize)]
        struct Test {
            #[serde(deserialize_with = "deserialize_headers", default)]
            pub headers: HeaderMap,
        }

        let input = serde_json::json!({ "not_headers": {} });
        let parsed: Test = serde_json::from_value(input).unwrap();
        assert_eq!(HeaderMap::new(), parsed.headers);
    }

    #[test]
    fn test_serialize_headers() {
        #[derive(Deserialize, Serialize)]
        struct Test {
            #[serde(deserialize_with = "deserialize_headers", default)]
            #[serde(serialize_with = "serialize_multi_value_headers")]
            headers: HeaderMap,
        }

        let input = serde_json::json!({ "headers": { "Accept": ["*/*"] } });
        let parsed: Test = serde_json::from_value(input).unwrap();
        assert_eq!(&"*/*", parsed.headers.get("Accept").unwrap());

        let reparsed = json_roundtrip(parsed);
        assert_eq!(&"*/*", reparsed.headers.get("Accept").unwrap());
    }

    #[test]
    fn test_null_headers() {
        #[derive(Deserialize)]
        struct Test {
            #[serde(deserialize_with = "deserialize_headers")]
            headers: HeaderMap,
        }

        let parsed: Test = serde_json::from_value(serde_json::json!({ "headers": null })).unwrap();
        assert!(parsed.headers.is_empty());
    }

    #[test]
    fn test_serialize_utf8_headers() {
        #[derive(Deserialize, Serialize)]
        struct Test {
            #[serde(deserialize_with = "deserialize_headers", default)]
            #[serde(serialize_with = "serialize_headers")]
            pub headers: HeaderMap,
            #[serde(deserialize_with = "deserialize_headers", default)]
            #[serde(serialize_with = "serialize_multi_value_headers")]
            pub multi_value_headers: HeaderMap,
        }

        let content_disposition =
            "inline; filename=\"Schillers schönste Szenenanweisungen -Kabale und Liebe.mp4.avif\"";
        let input = serde_json::json!({
            "headers": { "Content-Disposition": content_disposition },
            "multi_value_headers": { "Content-Disposition": content_disposition }
        });

        // Both maps must carry the non-ASCII value unchanged.
        let check = |decoded: &Test| {
            assert_eq!(content_disposition, decoded.headers.get("Content-Disposition").unwrap());
            assert_eq!(
                content_disposition,
                decoded.multi_value_headers.get("Content-Disposition").unwrap()
            );
        };

        let parsed: Test = serde_json::from_value(input).unwrap();
        check(&parsed);
        check(&json_roundtrip(parsed));
    }
}