use std::{str::FromStr, time::Duration};
use axum::http::{header, HeaderName, HeaderValue};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tower_http::cors::{Any, CorsLayer};
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum CorsError {
#[error("Invalid CORS origin")]
InvalidOrigin(header::InvalidHeaderValue),
#[error("Invalid CORS header")]
InvalidHeader(header::InvalidHeaderName),
}
#[derive(Clone, Debug, Default, PartialEq)]
pub(crate) enum AnyOr<T> {
#[default]
Any,
Some(Vec<T>),
}
#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)]
#[non_exhaustive]
pub struct CorsConfig {
origins: AnyOr<String>,
#[serde(default)]
credentials: bool,
#[serde(default)]
headers: Option<AnyOr<String>>,
#[serde(
default,
skip_serializing_if = "Option::is_none",
with = "humantime_serde"
)]
max_age: Option<Duration>,
}
impl CorsConfig {
pub fn make_layer(&self) -> Result<CorsLayer, CorsError> {
let mut layer = CorsLayer::new();
layer = match &self.origins {
AnyOr::Any => layer.allow_origin(Any),
AnyOr::Some(origins) => {
let headers = origins
.iter()
.map(|origin| HeaderValue::from_str(origin))
.collect::<Result<Vec<_>, _>>()
.map_err(CorsError::InvalidOrigin)?;
layer.allow_origin(headers)
}
};
layer = match self.credentials {
true => layer.allow_credentials(true),
false => layer,
};
layer = match &self.headers {
None => layer,
Some(AnyOr::Any) => layer.allow_headers(Any),
Some(AnyOr::Some(headers)) => {
let headers = headers
.iter()
.map(|origin| HeaderName::from_str(origin))
.collect::<Result<Vec<_>, _>>()
.map_err(CorsError::InvalidHeader)?;
layer.allow_headers(headers)
}
};
if let Some(max_age) = self.max_age {
layer = layer.max_age(max_age);
}
Ok(layer)
}
}
mod serde_impls {
    use std::{fmt, marker::PhantomData};

    use serde::{de, Deserializer, Serializer};

    use super::*;

    /// Upper bound on elements preallocated from a format-supplied size hint.
    ///
    /// Some formats read the hint straight from the input, so trusting it
    /// unbounded would let a tiny document request an enormous allocation.
    const MAX_PREALLOC: usize = 4096;

    /// Visitor accepting the string `"any"` (in any ASCII case), a unit/null
    /// value (also mapped to `Any`), or a sequence of `T`.
    #[doc(hidden)]
    struct AnyOrVisitor<T> {
        marker: PhantomData<T>,
    }

    impl<'de, T: Deserialize<'de>> de::Visitor<'de> for AnyOrVisitor<T> {
        type Value = AnyOr<T>;

        fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("'any' string or a sequence")
        }

        fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
        where
            A: de::SeqAccess<'de>,
        {
            let capacity = seq.size_hint().unwrap_or(0).min(MAX_PREALLOC);
            let mut vec = Vec::with_capacity(capacity);
            while let Some(el) = seq.next_element()? {
                vec.push(el);
            }
            Ok(AnyOr::Some(vec))
        }

        fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
        where
            E: de::Error,
        {
            // Compare in place: no allocation, and the error echoes the input verbatim.
            if v.eq_ignore_ascii_case("any") {
                Ok(AnyOr::Any)
            } else {
                Err(E::custom(format_args!("expecting string 'any', got '{v}'")))
            }
        }

        fn visit_unit<E>(self) -> Result<Self::Value, E>
        where
            E: de::Error,
        {
            Ok(AnyOr::Any)
        }
    }

    impl<'de, T: Deserialize<'de>> Deserialize<'de> for AnyOr<T> {
        /// Requires a self-describing format: the input's shape decides the variant.
        fn deserialize<D>(deser: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            let visitor = AnyOrVisitor {
                marker: PhantomData,
            };
            deser.deserialize_any(visitor)
        }
    }

    impl<T: Serialize> Serialize for AnyOr<T> {
        /// `Any` is written as the lowercase string `"any"`, a list as a sequence.
        fn serialize<S>(&self, ser: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            match self {
                Self::Any => ser.serialize_str("any"),
                Self::Some(vals) => ser.collect_seq(vals),
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use serde_json::{from_str, json, to_value};

    use super::*;

    /// Fixture exercising `AnyOr` as a plain, defaulted and optional field.
    #[derive(Debug, Default, Deserialize, Serialize, PartialEq)]
    struct TestData {
        #[serde(default)]
        string_param: AnyOr<String>,
        #[serde(default)]
        int_param: AnyOr<i32>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        opt_int_param: Option<AnyOr<i32>>,
    }

    #[test]
    fn anyor_de_all_any() {
        let serialized = r#"{
            "string_param": "any",
            "int_param": "Any",
            "opt_int_param": "ANY"
        }"#;
        let deserialized: TestData = from_str(serialized).unwrap();
        assert_eq!(
            deserialized,
            TestData {
                string_param: AnyOr::Any,
                int_param: AnyOr::Any,
                opt_int_param: Some(AnyOr::Any),
            }
        );
    }

    #[test]
    fn anyor_de_all_or() {
        let serialized = r#"{
            "string_param": ["v1", "v2"],
            "int_param": [11, 22],
            "opt_int_param": [33, 44]
        }"#;
        let deserialized: TestData = from_str(serialized).unwrap();
        assert_eq!(
            deserialized,
            TestData {
                string_param: AnyOr::Some(vec!["v1".into(), "v2".into()]),
                int_param: AnyOr::Some(vec![11, 22]),
                opt_int_param: Some(AnyOr::Some(vec![33, 44])),
            }
        );
    }

    #[test]
    fn anyor_de_default() {
        let serialized = "{}";
        let deserialized: TestData = from_str(serialized).unwrap();
        assert_eq!(
            deserialized,
            TestData {
                string_param: AnyOr::Any,
                int_param: AnyOr::Any,
                opt_int_param: None,
            }
        );
    }

    #[test]
    fn anyor_de_null_is_any() {
        // `null` reaches the visitor's `visit_unit`, which maps to `Any`.
        let serialized = r#"{
            "string_param": null,
            "int_param": null
        }"#;
        let deserialized: TestData = from_str(serialized).unwrap();
        assert_eq!(
            deserialized,
            TestData {
                string_param: AnyOr::Any,
                int_param: AnyOr::Any,
                opt_int_param: None,
            }
        );
    }

    #[test]
    fn anyor_de_invalid_any() {
        let serialized = r#"{
            "string_param": "whatever",
            "int_param": [11, 22],
            "opt_int_param": [33, 44]
        }"#;
        assert!(from_str::<TestData>(serialized).is_err());
    }

    #[test]
    fn anyor_ser_all_any() {
        let deserialized = TestData {
            string_param: AnyOr::Any,
            int_param: AnyOr::Any,
            opt_int_param: Some(AnyOr::Any),
        };
        let serialized = to_value(deserialized).unwrap();
        assert_eq!(
            serialized,
            json!({
                "string_param": "any",
                "int_param": "any",
                "opt_int_param": "any"
            })
        );
    }

    #[test]
    fn anyor_ser_all_or() {
        let deserialized = TestData {
            string_param: AnyOr::Some(vec!["v1".into(), "v2".into()]),
            int_param: AnyOr::Some(vec![11, 22]),
            opt_int_param: Some(AnyOr::Some(vec![33, 44])),
        };
        let serialized = to_value(deserialized).unwrap();
        assert_eq!(
            serialized,
            json!({
                "string_param": ["v1", "v2"],
                "int_param": [11, 22],
                "opt_int_param": [33, 44]
            })
        );
    }

    #[test]
    fn anyor_ser_default() {
        let deserialized = TestData {
            string_param: AnyOr::default(),
            int_param: AnyOr::default(),
            opt_int_param: Option::default(),
        };
        let serialized = to_value(deserialized).unwrap();
        assert_eq!(
            serialized,
            json!({
                "string_param": "any",
                "int_param": "any"
            })
        );
    }

    #[test]
    fn cors_config_de() {
        let serialized = r#"{
            "origins": ["https://example.com"],
            "headers": ["content-type"],
            "max_age": "1h"
        }"#;
        let deserialized: CorsConfig = from_str(serialized).unwrap();
        assert_eq!(
            deserialized,
            CorsConfig {
                origins: AnyOr::Some(vec!["https://example.com".into()]),
                credentials: false,
                headers: Some(AnyOr::Some(vec!["content-type".into()])),
                max_age: Some(Duration::from_secs(60 * 60)),
            }
        );
    }

    #[test]
    fn cors_config_de_requires_origins() {
        assert!(from_str::<CorsConfig>("{}").is_err());
    }

    #[test]
    fn make_layer_default_is_ok() {
        assert!(CorsConfig::default().make_layer().is_ok());
    }

    #[test]
    fn make_layer_invalid_origin() {
        let config = CorsConfig {
            origins: AnyOr::Some(vec!["https://bad\norigin".into()]),
            ..CorsConfig::default()
        };
        assert!(matches!(config.make_layer(), Err(CorsError::InvalidOrigin(_))));
    }

    #[test]
    fn make_layer_invalid_header() {
        let config = CorsConfig {
            headers: Some(AnyOr::Some(vec!["bad header".into()])),
            ..CorsConfig::default()
        };
        assert!(matches!(config.make_layer(), Err(CorsError::InvalidHeader(_))));
    }
}