use axum::body::Body;
use axum::extract::Request;
use axum::http::{HeaderValue, Method};
use axum::middleware::Next;
use axum::response::Response;
/// Cross-origin resource sharing policy consumed by the CORS middleware.
///
/// Build one with [`CorsConfig::for_origins`] or `CorsConfig::default()` and the
/// builder-style setters.
#[derive(Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub struct CorsConfig {
    /// Origins allowed to make cross-origin requests, compared by exact string
    /// equality against the request's `Origin` header. The entry `"*"` matches
    /// any origin.
    pub allow_origins: Vec<String>,
    /// Value sent verbatim as `Access-Control-Allow-Methods`
    /// (a comma-separated method list).
    pub allow_methods: String,
    /// Value sent verbatim as `Access-Control-Allow-Headers`
    /// (a comma-separated header-name list).
    pub allow_headers: String,
    /// When `true`, `Access-Control-Allow-Credentials: true` is emitted and a
    /// `"*"` entry in `allow_origins` echoes the concrete request origin instead
    /// of sending a literal `*`.
    pub allow_credentials: bool,
    /// Value sent as `Access-Control-Max-Age` on successful preflight responses.
    pub max_age_secs: u32,
}
impl Default for CorsConfig {
    /// Baseline policy: an empty origin allow-list (so no cross-origin request
    /// is approved), credentials disabled, an `Access-Control-Max-Age` of 600
    /// seconds, and a fixed set of common methods and request headers.
    fn default() -> Self {
        let methods = String::from("GET, POST, PUT, PATCH, DELETE, OPTIONS");
        let headers = String::from(concat!(
            "content-type, authorization, x-request-id, x-tenant-id, ",
            "idempotency-key, traceparent"
        ));
        Self {
            allow_origins: Vec::new(),
            allow_methods: methods,
            allow_headers: headers,
            allow_credentials: false,
            max_age_secs: 600,
        }
    }
}
impl CorsConfig {
    /// Creates a config that allows exactly `origins`, with credentials enabled
    /// and every other field taken from `CorsConfig::default()`.
    ///
    /// # Warning
    ///
    /// Because credentials are enabled here, including `"*"` in `origins` makes
    /// the middleware echo back *any* request origin together with
    /// `Access-Control-Allow-Credentials: true`. Every site can then make
    /// credentialed requests. Only list concrete origins unless that is
    /// intentional.
    #[must_use]
    pub fn for_origins<I, S>(origins: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        Self {
            allow_origins: Self::collect_origins(origins),
            allow_credentials: true,
            ..Default::default()
        }
    }

    /// Replaces the origin allow-list (exact string matches; `"*"` = any origin).
    #[must_use]
    pub fn allow_origins<I, S>(mut self, origins: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.allow_origins = Self::collect_origins(origins);
        self
    }

    /// Sets the raw `Access-Control-Allow-Methods` value.
    #[must_use]
    pub fn allow_methods(mut self, v: impl Into<String>) -> Self {
        self.allow_methods = v.into();
        self
    }

    /// Sets the raw `Access-Control-Allow-Headers` value.
    #[must_use]
    pub fn allow_headers(mut self, v: impl Into<String>) -> Self {
        self.allow_headers = v.into();
        self
    }

    /// Enables or disables `Access-Control-Allow-Credentials: true`.
    #[must_use]
    pub fn allow_credentials(mut self, v: bool) -> Self {
        self.allow_credentials = v;
        self
    }

    /// Sets the `Access-Control-Max-Age` value (seconds) sent on preflights.
    #[must_use]
    pub fn max_age_secs(mut self, v: u32) -> Self {
        self.max_age_secs = v;
        self
    }

    /// Converts any iterable of string-like values into the owned origin list.
    fn collect_origins<I, S>(origins: I) -> Vec<String>
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        origins.into_iter().map(Into::into).collect()
    }

    /// Returns the `Access-Control-Allow-Origin` value to send for a request
    /// whose `Origin` header is `origin`, or `None` if the origin is not allowed.
    ///
    /// * `"*"` is listed and credentials are off: returns `"*"`.
    /// * `"*"` is listed and credentials are on: echoes `origin`. A literal `*`
    ///   is not honoured by browsers for credentialed requests, which is
    ///   presumably why the origin is reflected. This admits every origin with
    ///   credentials, so treat it as a deliberate, high-risk setting.
    /// * `origin` exactly matches a listed entry: echoes `origin`.
    /// * Otherwise: `None`.
    fn allowed_origin(&self, origin: &str) -> Option<String> {
        let any = self.allow_origins.iter().any(|o| o == "*");
        if any && !self.allow_credentials {
            return Some("*".to_owned());
        }
        if any || self.allow_origins.iter().any(|o| o == origin) {
            return Some(origin.to_owned());
        }
        None
    }
}
pub(crate) async fn apply_cors(cfg: &'static CorsConfig, req: Request, next: Next) -> Response {
let origin = req
.headers()
.get("origin")
.and_then(|v| v.to_str().ok())
.map(str::to_owned);
let allowed = origin.as_deref().and_then(|o| cfg.allowed_origin(o));
let is_preflight = req.method() == Method::OPTIONS
&& req.headers().contains_key("access-control-request-method");
if is_preflight {
let Some(echo) = allowed else {
return Response::builder()
.status(403)
.body(Body::empty())
.expect("static preflight denial");
};
let mut resp = Response::builder()
.status(204)
.body(Body::empty())
.expect("static preflight response");
set_cors_headers(resp.headers_mut(), cfg, &echo);
resp.headers_mut().insert(
"access-control-max-age",
HeaderValue::from_str(&cfg.max_age_secs.to_string()).expect("numeric"),
);
return resp;
}
let mut resp = next.run(req).await;
if let Some(echo) = allowed {
set_cors_headers(resp.headers_mut(), cfg, &echo);
}
resp
}
/// Writes the CORS response headers for an already-approved `origin` into `headers`.
///
/// `origin` becomes `Access-Control-Allow-Origin`, and the configured method and
/// header lists become `Access-Control-Allow-Methods` / `-Headers`. Each of these
/// three is skipped, not a panic, if its value is not a valid header value.
/// `Access-Control-Allow-Credentials: true` is set only when
/// `cfg.allow_credentials` is enabled. `Vary: origin` is always appended,
/// keeping any `Vary` values already present.
fn set_cors_headers(headers: &mut axum::http::HeaderMap, cfg: &CorsConfig, origin: &str) {
    let dynamic: [(&'static str, &str); 3] = [
        ("access-control-allow-origin", origin),
        ("access-control-allow-methods", cfg.allow_methods.as_str()),
        ("access-control-allow-headers", cfg.allow_headers.as_str()),
    ];
    for (name, raw) in dynamic {
        // Values that cannot be encoded as a header are silently left out.
        let Ok(value) = HeaderValue::from_str(raw) else {
            continue;
        };
        headers.insert(name, value);
    }

    if cfg.allow_credentials {
        headers.insert(
            "access-control-allow-credentials",
            HeaderValue::from_static("true"),
        );
    }

    headers.append("vary", HeaderValue::from_static("origin"));
}