use bytes::Bytes;
pub use http::header::*;
use wreq_proto::ext::OnPreserveHeaderCallback;
/// Conversion into a [`HeaderCaseName`]: a header name whose exact spelling is kept.
///
/// Implemented for `&'static str`, `String`, `Bytes`, `HeaderName`, `&HeaderName`,
/// `HeaderCaseName` and `&HeaderCaseName`. The trait is sealed through the private
/// `sealed::Sealed` supertrait, so code outside this module cannot add implementations.
pub trait IntoHeaderCaseName
where
    Self: sealed::Sealed,
{
    /// Performs the conversion, consuming `self`.
    fn into_header_case_name(self) -> HeaderCaseName;
}
/// An ordered collection of header names paired with the exact spelling to use for each.
///
/// Keys are parsed [`HeaderName`]s. Values are the [`HeaderCaseName`] spellings recorded
/// for them, and several spellings may be recorded for one name.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct OrigHeaderMap(HeaderMap<HeaderCaseName>);
impl OrigHeaderMap {
#[inline]
pub fn new() -> Self {
Self(HeaderMap::default())
}
#[inline]
pub fn with_capacity(size: usize) -> Self {
Self(HeaderMap::with_capacity(size))
}
#[inline]
pub fn insert<N>(&mut self, orig: N) -> bool
where
N: IntoHeaderCaseName,
{
let header_case_name = orig.into_header_case_name();
match &header_case_name.inner {
Repr::Cased(bytes) => HeaderName::from_bytes(bytes)
.map(|header_name| self.0.append(header_name, header_case_name))
.unwrap_or(false),
Repr::Standard(header_name) => self.0.append(header_name.clone(), header_case_name),
}
}
#[inline]
pub fn extend(&mut self, iter: OrigHeaderMap) {
self.0.extend(iter.0);
}
#[inline]
pub fn len(&self) -> usize {
self.0.len()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
#[inline]
pub fn iter(&self) -> impl Iterator<Item = (&HeaderName, &HeaderCaseName)> {
self.0.iter()
}
}
impl OnPreserveHeaderCallback for OrigHeaderMap {
    /// Reorders `headers` in place according to this map.
    ///
    /// Headers named in this map come first, in the order the names were inserted here.
    /// Every other header follows, keeping the relative order it already had. All values
    /// of a multi-valued header (e.g. repeated `cookie` entries) are kept, in their
    /// original order. `HeaderMap` keys carry no casing, so only the order changes here.
    ///
    /// Returns early, leaving `headers` untouched, when it holds at most one value or when
    /// this map is empty.
    fn call(&self, headers: &mut HeaderMap) {
        if headers.len() <= 1 || self.0.is_empty() {
            return;
        }

        let mut sorted_headers = HeaderMap::with_capacity(headers.keys_len());

        // Listed headers go first. `keys()` yields each distinct name once, even when
        // several spellings were recorded for it. Values are cloned instead of removed
        // from `headers`: `HeaderMap::remove` swap-removes the entry (the last entry moves
        // into the freed slot), which would scramble the order of the headers drained below.
        for name in self.0.keys() {
            for value in headers.get_all(name) {
                sorted_headers.append(name, value.clone());
            }
        }

        // All remaining headers, in their original order. `drain()` yields `Some(name)`
        // only for the first value of each header and `None` for its later values, so
        // remember the current name. `pending == None` means that header was already
        // placed above and all of its values are skipped.
        let mut pending: Option<HeaderName> = None;
        for (name, value) in headers.drain() {
            if let Some(name) = name {
                pending = if self.0.contains_key(&name) {
                    None
                } else {
                    Some(name)
                };
            }
            if let Some(name) = &pending {
                sorted_headers.append(name, value);
            }
        }

        std::mem::swap(headers, &mut sorted_headers);
    }

    /// Passes every header in `headers` to `dst` as `(name bytes, value)` and leaves
    /// `headers` empty.
    ///
    /// Headers named in this map are visited first, in this map's order. Each is reported
    /// under the first spelling recorded for its name. All other headers follow in their
    /// existing relative order, reported under their `HeaderName`.
    fn call_visit(
        &self,
        headers: &mut HeaderMap,
        dst: &mut dyn FnMut(&dyn AsRef<[u8]>, &http::HeaderValue),
    ) {
        for name in self.0.keys() {
            // Several spellings may be recorded for one name; the first is used for every
            // value of that header.
            if let Some(case_name) = self.0.get(name) {
                for value in headers.get_all(name) {
                    dst(case_name, value);
                }
            }
        }

        // Same drain scheme as in `call`: skip headers already visited above, keep the
        // rest in order. Nothing was removed above, so the order was not disturbed.
        let mut pending: Option<HeaderName> = None;
        for (name, value) in headers.drain() {
            if let Some(name) = name {
                pending = if self.0.contains_key(&name) {
                    None
                } else {
                    Some(name)
                };
            }
            if let Some(name) = &pending {
                dst(name, &value);
            }
        }
    }
}
/// Borrowing iteration over `(name, spelling)` pairs, in the same order as the inner
/// `HeaderMap`'s `iter()`.
impl<'a> IntoIterator for &'a OrigHeaderMap {
    type IntoIter = <&'a HeaderMap<HeaderCaseName> as IntoIterator>::IntoIter;
    type Item = (&'a HeaderName, &'a HeaderCaseName);

    #[inline]
    fn into_iter(self) -> Self::IntoIter {
        let inner: &'a HeaderMap<HeaderCaseName> = &self.0;
        inner.into_iter()
    }
}
impl IntoIterator for OrigHeaderMap {
type Item = (Option<HeaderName>, HeaderCaseName);
type IntoIter = <HeaderMap<HeaderCaseName> as IntoIterator>::IntoIter;
#[inline]
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
// `impl_request_config_value!` is defined elsewhere in the crate. Judging by its name it
// registers `OrigHeaderMap` as a value that can be attached to a request's configuration.
// NOTE(review): confirm against the macro definition.
impl_request_config_value!(OrigHeaderMap);
/// A header name that remembers the exact spelling it was created from.
///
/// It holds either raw caller-supplied bytes, kept verbatim, or an already-parsed
/// [`HeaderName`]. [`AsRef<[u8]>`](AsRef) exposes the spelling.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct HeaderCaseName {
    inner: Repr,
}
/// Internal storage behind [`HeaderCaseName`].
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
enum Repr {
    /// Raw bytes exactly as supplied (from `&'static str`, `String` or `Bytes`).
    Cased(Bytes),
    /// A [`HeaderName`] supplied directly by the caller.
    Standard(HeaderName),
}
impl AsRef<[u8]> for HeaderCaseName {
    /// Returns the name's bytes.
    ///
    /// Cased names return the stored bytes verbatim. Standard names return the
    /// [`HeaderName`]'s own bytes.
    #[inline]
    fn as_ref(&self) -> &[u8] {
        let bytes: &[u8] = match &self.inner {
            Repr::Cased(raw) => raw.as_ref(),
            Repr::Standard(name) => name.as_ref(),
        };
        bytes
    }
}
impl IntoHeaderCaseName for &'static str {
#[inline]
fn into_header_case_name(self) -> HeaderCaseName {
Bytes::from_static(self.as_bytes()).into_header_case_name()
}
}
impl IntoHeaderCaseName for String {
#[inline]
fn into_header_case_name(self) -> HeaderCaseName {
Bytes::from(self).into_header_case_name()
}
}
impl IntoHeaderCaseName for Bytes {
#[inline]
fn into_header_case_name(self) -> HeaderCaseName {
HeaderCaseName {
inner: Repr::Cased(self),
}
}
}
impl IntoHeaderCaseName for &HeaderName {
#[inline]
fn into_header_case_name(self) -> HeaderCaseName {
HeaderCaseName {
inner: Repr::Standard(self.clone()),
}
}
}
impl IntoHeaderCaseName for HeaderName {
#[inline]
fn into_header_case_name(self) -> HeaderCaseName {
HeaderCaseName {
inner: Repr::Standard(self),
}
}
}
impl IntoHeaderCaseName for HeaderCaseName {
#[inline]
fn into_header_case_name(self) -> HeaderCaseName {
self
}
}
impl IntoHeaderCaseName for &HeaderCaseName {
#[inline]
fn into_header_case_name(self) -> HeaderCaseName {
self.clone()
}
}
/// Private module holding the supertrait that seals [`IntoHeaderCaseName`].
mod sealed {
    use bytes::Bytes;
    use http::HeaderName;

    use crate::header::HeaderCaseName;

    /// Marker trait. Only the types listed below implement it.
    pub trait Sealed {}

    // Raw spellings, kept verbatim.
    impl Sealed for &'static str {}
    impl Sealed for String {}
    impl Sealed for Bytes {}

    // Already-parsed header names.
    impl Sealed for HeaderName {}
    impl Sealed for &HeaderName {}

    // Identity and clone conversions.
    impl Sealed for HeaderCaseName {}
    impl Sealed for &HeaderCaseName {}
}
#[cfg(test)]
mod test {
    use http::{HeaderMap, HeaderName, HeaderValue};
    use wreq_proto::ext::OnPreserveHeaderCallback;

    use super::OrigHeaderMap;

    /// Returns every spelling recorded for `name`, in insertion order.
    #[inline]
    pub(crate) fn get_all<'a>(
        orig_headers: &'a OrigHeaderMap,
        name: &HeaderName,
    ) -> impl Iterator<Item = impl AsRef<[u8]> + 'a> + 'a {
        orig_headers.0.get_all(name).into_iter()
    }

    /// Collects the spellings recorded for `name` as owned bytes, for exact comparisons.
    fn spellings(orig_headers: &OrigHeaderMap, name: &str) -> Vec<Vec<u8>> {
        let name: HeaderName = name.parse().unwrap();
        get_all(orig_headers, &name)
            .map(|v| v.as_ref().to_vec())
            .collect()
    }

    #[test]
    fn test_header_order() {
        let mut headers = OrigHeaderMap::new();
        headers.insert("X-Test");
        headers.insert("X-Another");
        headers.insert("x-test2");
        let mut iter = headers.iter();
        assert_eq!(iter.next().unwrap().1.as_ref(), b"X-Test");
        assert_eq!(iter.next().unwrap().1.as_ref(), b"X-Another");
        assert_eq!(iter.next().unwrap().1.as_ref(), b"x-test2");
    }

    #[test]
    fn test_extend_preserves_order() {
        let mut map1 = OrigHeaderMap::new();
        map1.insert("A-Header");
        map1.insert("B-Header");
        let mut map2 = OrigHeaderMap::new();
        map2.insert("C-Header");
        map2.insert("D-Header");
        map1.extend(map2);
        let names: Vec<_> = map1.iter().map(|(_, orig)| orig.as_ref()).collect();
        assert_eq!(
            names,
            vec![b"A-Header", b"B-Header", b"C-Header", b"D-Header"]
        );
    }

    #[test]
    fn test_header_case() {
        let mut headers = OrigHeaderMap::new();
        headers.insert("X-Test");
        headers.insert("x-test");
        // Lookup ignores case; both spellings are kept, in insertion order.
        assert_eq!(
            spellings(&headers, "X-Test"),
            [b"X-Test".to_vec(), b"x-test".to_vec()]
        );
    }

    #[test]
    fn test_header_multiple_cases() {
        let mut headers = OrigHeaderMap::new();
        headers.insert("X-test");
        headers.insert("x-test");
        headers.insert("X-test");
        // Repeated spellings are recorded once per insert, not deduplicated.
        assert_eq!(
            spellings(&headers, "x-test"),
            [b"X-test".to_vec(), b"x-test".to_vec(), b"X-test".to_vec()]
        );
    }

    #[test]
    fn test_sort_headers_preserves_multiple_cookie_values() {
        let mut orig_headers = OrigHeaderMap::new();
        orig_headers.insert("Cookie");
        orig_headers.insert("User-Agent");
        orig_headers.insert("Accept");
        let mut headers = HeaderMap::new();
        headers.append("cookie", HeaderValue::from_static("session=abc123"));
        headers.append("cookie", HeaderValue::from_static("theme=dark"));
        headers.append("cookie", HeaderValue::from_static("lang=en"));
        headers.insert("user-agent", HeaderValue::from_static("Mozilla/5.0"));
        headers.insert("accept", HeaderValue::from_static("text/html"));
        headers.insert("host", HeaderValue::from_static("example.com"));
        let original_cookies: Vec<_> = headers
            .get_all("cookie")
            .iter()
            .map(|v| v.to_str().unwrap().to_string())
            .collect();
        orig_headers.call(&mut headers);
        let sorted_cookies: Vec<_> = headers
            .get_all("cookie")
            .iter()
            .map(|v| v.to_str().unwrap().to_string())
            .collect();
        assert_eq!(
            original_cookies,
            ["session=abc123", "theme=dark", "lang=en"],
            "Should have 3 cookie values"
        );
        assert_eq!(
            sorted_cookies, original_cookies,
            "Cookie values and their relative order should be preserved"
        );
        let header_names: Vec<_> = headers.keys().map(HeaderName::as_str).collect();
        assert_eq!(
            header_names,
            ["cookie", "user-agent", "accept", "host"],
            "Listed headers should come first, in the configured order"
        );
        assert_eq!(
            headers.len(),
            6,
            "Should have 6 total header values (3 cookies + 3 others)"
        );
    }

    #[test]
    fn test_sort_headers_multiple_values_different_headers() {
        let mut orig_headers = OrigHeaderMap::new();
        orig_headers.insert("Accept");
        orig_headers.insert("Cookie");
        let mut headers = HeaderMap::new();
        headers.append("accept", HeaderValue::from_static("text/html"));
        headers.append("accept", HeaderValue::from_static("application/json"));
        headers.append("cookie", HeaderValue::from_static("a=1"));
        headers.append("cookie", HeaderValue::from_static("b=2"));
        headers.insert("host", HeaderValue::from_static("example.com"));
        let total_before = headers.len();
        orig_headers.call(&mut headers);
        assert_eq!(
            headers.len(),
            total_before,
            "Total header count should be preserved"
        );
        assert_eq!(
            headers.get_all("accept").iter().count(),
            2,
            "Accept headers should be preserved"
        );
        assert_eq!(
            headers.get_all("cookie").iter().count(),
            2,
            "Cookie headers should be preserved"
        );
        assert_eq!(
            headers.get_all("host").iter().count(),
            1,
            "Host header should be preserved"
        );
    }

    #[test]
    fn test_call_visit_uses_original_case_and_order() {
        let mut orig_headers = OrigHeaderMap::new();
        orig_headers.insert("User-Agent");
        orig_headers.insert("Cookie");
        let mut headers = HeaderMap::new();
        headers.insert("host", HeaderValue::from_static("example.com"));
        headers.append("cookie", HeaderValue::from_static("a=1"));
        headers.append("cookie", HeaderValue::from_static("b=2"));
        headers.insert("user-agent", HeaderValue::from_static("test"));
        headers.insert("accept", HeaderValue::from_static("text/html"));

        let mut visited: Vec<(Vec<u8>, String)> = Vec::new();
        orig_headers.call_visit(
            &mut headers,
            &mut |name: &dyn AsRef<[u8]>, value: &HeaderValue| {
                visited.push((name.as_ref().to_vec(), value.to_str().unwrap().to_owned()));
            },
        );

        assert_eq!(
            visited,
            vec![
                (b"User-Agent".to_vec(), String::from("test")),
                (b"Cookie".to_vec(), String::from("a=1")),
                (b"Cookie".to_vec(), String::from("b=2")),
                (b"host".to_vec(), String::from("example.com")),
                (b"accept".to_vec(), String::from("text/html")),
            ],
            "Listed headers use their original spelling and come first"
        );
        assert!(headers.is_empty(), "call_visit should consume all headers");
    }
}