use crate::extensions::DeserializeExt;
use serde::{Deserialize, Deserializer};
use serde_json::Value;
use std::collections::HashSet;
use std::sync::OnceLock;
/// JSON Pointer path of the comma-separated header list for the `sanitize` filter.
const SANITIZE_PTR: &str = "/sanitize";
/// JSON Pointer path of the comma-separated header list for the `include` filter.
const INCLUDE_PTR: &str = "/include";
/// JSON Pointer path of the comma-separated header list for the `exclude` filter.
const EXCLUDE_PTR: &str = "/exclude";
/// Process-wide [`HeadersFilter`]; held in a [`OnceLock`] so it can be set at most once.
// NOTE(review): the initialization site is not in this file section — confirm where it is set.
pub static HEADERS_FILTER: OnceLock<HeadersFilter> = OnceLock::new();
/// The header names a single filter rule applies to.
///
/// Derives `PartialEq`/`Eq` so parsed configurations can be compared by value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Filter {
    /// Wildcard covering every header name (parsed from the config value `*`).
    All,
    /// An explicit set of header names. When built by `from_str_to_filter`,
    /// each name is trimmed and ASCII-lowercased.
    Set(HashSet<String>),
}

/// Header filtering configuration, deserialized from the `include`,
/// `exclude` and `sanitize` keys of a config value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HeadersFilter {
    /// Headers selected by the `include` key.
    pub include: Filter,
    /// Headers selected by the `exclude` key.
    pub exclude: Filter,
    /// Headers selected by the `sanitize` key.
    // NOTE(review): presumably these headers have their values masked — confirm at the use site.
    pub sanitize: Filter,
}
impl<'de> Deserialize<'de> for HeadersFilter {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let config = Value::deserialize(deserializer)?;
let sanitize: Option<String> = config
.pointer_and_deserialize::<_, D::Error>(SANITIZE_PTR)
.ok();
let include: Option<String> = config
.pointer_and_deserialize::<_, D::Error>(INCLUDE_PTR)
.ok();
let exclude: Option<String> = config
.pointer_and_deserialize::<_, D::Error>(EXCLUDE_PTR)
.ok();
Ok(HeadersFilter {
include: from_str_to_filter(include),
exclude: from_str_to_filter(exclude),
sanitize: from_str_to_filter(sanitize),
})
}
}
/// Parses an optional comma-separated list of header names into a [`Filter`].
///
/// * `None` yields an empty [`Filter::Set`].
/// * A value that is exactly `*` after trimming yields [`Filter::All`].
/// * Otherwise each comma-separated entry is trimmed and ASCII-lowercased.
///   Empty entries (from `""`, `"a,,b"`, a trailing comma, …) are skipped, so
///   an empty string behaves like an absent value instead of inserting `""`.
fn from_str_to_filter(raw: Option<String>) -> Filter {
    let Some(raw) = raw else {
        return Filter::Set(HashSet::new());
    };
    let raw = raw.trim();
    if raw == "*" {
        return Filter::All;
    }
    Filter::Set(
        raw.split(',')
            .map(str::trim)
            .filter(|field| !field.is_empty())
            .map(str::to_ascii_lowercase)
            .collect(),
    )
}