use serde::{Deserialize, Deserializer, Serialize};
use std::collections::HashMap;
#[derive(Debug, Default, Clone, Serialize)]
pub struct Headers {
inner: HashMap<String, String>,
}
impl<'de> Deserialize<'de> for Headers {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
#[serde(untagged)]
enum HeadersRepr {
Transparent(HashMap<String, String>),
Wrapped { inner: HashMap<String, String> },
}
let raw = match HeadersRepr::deserialize(deserializer)? {
HeadersRepr::Transparent(map) => map,
HeadersRepr::Wrapped { inner } => inner,
};
let mut headers = Headers::new();
for (key, value) in raw {
headers.insert(key, value);
}
Ok(headers)
}
}
impl Headers {
fn normalize_key(key: &str) -> String {
key.to_ascii_lowercase()
}
pub fn new() -> Self {
Self::default()
}
pub fn insert(&mut self, key: impl Into<String>, value: impl Into<String>) {
let key = key.into();
self.inner.insert(Self::normalize_key(&key), value.into());
}
pub fn merge(&mut self, overlay: Headers) {
for (k, v) in overlay.inner {
self.inner.insert(Self::normalize_key(&k), v);
}
}
pub fn merge_with(&mut self, overlay: &Headers) {
for (k, v) in &overlay.inner {
self.inner.insert(Self::normalize_key(k), v.clone());
}
}
pub fn get(&self, key: &str) -> Option<&String> {
self.inner.get(&Self::normalize_key(key))
}
pub fn remove(&mut self, key: &str) -> Option<String> {
self.inner.remove(&Self::normalize_key(key))
}
pub fn is_empty(&self) -> bool {
self.inner.is_empty()
}
pub fn len(&self) -> usize {
self.inner.len()
}
pub fn to_reqwest_headers(&self) -> reqwest::header::HeaderMap {
let mut headers = reqwest::header::HeaderMap::new();
for (k, v) in &self.inner {
if let (Ok(name), Ok(value)) = (
reqwest::header::HeaderName::from_bytes(k.as_bytes()),
reqwest::header::HeaderValue::from_str(v),
) {
headers.insert(name, value);
}
}
headers
}
pub fn iter(&self) -> impl Iterator<Item = (&String, &String)> {
self.inner.iter()
}
}
impl From<(String, String)> for Headers {
fn from((key, value): (String, String)) -> Self {
let mut headers = Headers::new();
headers.insert(key, value);
headers
}
}
impl From<(&str, &str)> for Headers {
fn from((key, value): (&str, &str)) -> Self {
let mut headers = Headers::new();
headers.insert(key, value);
headers
}
}
impl From<Vec<(String, String)>> for Headers {
fn from(vec: Vec<(String, String)>) -> Self {
let mut headers = Headers::new();
for (k, v) in vec {
headers.insert(k, v);
}
headers
}
}
impl<const N: usize> From<[(String, String); N]> for Headers {
fn from(arr: [(String, String); N]) -> Self {
let mut headers = Headers::new();
for (k, v) in arr {
headers.insert(k, v);
}
headers
}
}
impl<const N: usize> From<[(&str, &str); N]> for Headers {
fn from(arr: [(&str, &str); N]) -> Self {
let mut headers = Headers::new();
for (k, v) in arr {
headers.insert(k, v);
}
headers
}
}
impl IntoIterator for Headers {
type Item = (String, String);
type IntoIter = std::collections::hash_map::IntoIter<String, String>;
fn into_iter(self) -> Self::IntoIter {
self.inner.into_iter()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_headers_basic() {
let mut headers = Headers::new();
headers.insert("x-api-key", "test-key");
assert_eq!(headers.get("x-api-key"), Some(&"test-key".to_string()));
}
#[test]
fn test_headers_merge() {
let mut headers1 = Headers::new();
headers1.insert("key1", "value1");
let mut headers2 = Headers::new();
headers2.insert("key2", "value2");
headers1.merge(headers2);
assert_eq!(headers1.len(), 2);
}
#[test]
fn test_headers_from_tuple() {
let headers: Headers = ("x-api-key", "test").into();
assert_eq!(headers.get("x-api-key"), Some(&"test".to_string()));
}
#[test]
fn test_headers_from_array() {
let headers: Headers = [("key1", "val1"), ("key2", "val2")].into();
assert_eq!(headers.len(), 2);
}
#[test]
fn test_headers_are_case_insensitive() {
let mut headers = Headers::new();
headers.insert("ChatGPT-Account-Id", "acct_test_123");
assert_eq!(
headers.get("chatgpt-account-id"),
Some(&"acct_test_123".to_string())
);
assert_eq!(
headers.get("ChatGPT-Account-Id"),
Some(&"acct_test_123".to_string())
);
let removed = headers.remove("CHATGPT-ACCOUNT-ID");
assert_eq!(removed.as_deref(), Some("acct_test_123"));
assert!(headers.get("chatgpt-account-id").is_none());
}
#[test]
fn test_headers_deserialization_remains_case_insensitive() {
let headers: Headers =
serde_json::from_str(r#"{"inner":{"ChatGPT-Account-Id":"acct_test_123"}}"#)
.expect("deserialize headers");
assert_eq!(
headers.get("chatgpt-account-id"),
Some(&"acct_test_123".to_string())
);
}
}