1use axum::body::Body;
18use axum::extract::Request;
19use axum::http::{HeaderValue, Method};
20use axum::middleware::Next;
21use axum::response::Response;
22
/// CORS policy consumed by the `apply_cors` middleware.
#[derive(Debug, Clone)]
pub struct CorsConfig {
    /// Origins allowed to make cross-origin requests; a `"*"` entry allows any origin.
    pub allow_origins: Vec<String>,
    /// Value emitted verbatim as `access-control-allow-methods`.
    pub allow_methods: String,
    /// Value emitted verbatim as `access-control-allow-headers`.
    pub allow_headers: String,
    /// When `true`, `access-control-allow-credentials: true` is emitted and a `"*"` entry
    /// echoes the request origin instead of sending a literal `*`.
    pub allow_credentials: bool,
    /// Value of `access-control-max-age` on successful preflight responses, in seconds.
    pub max_age_secs: u32,
}
33
34impl Default for CorsConfig {
35 fn default() -> Self {
36 Self {
37 allow_origins: vec![],
38 allow_methods: "GET, POST, PUT, PATCH, DELETE, OPTIONS".into(),
39 allow_headers: "content-type, authorization, x-request-id, x-tenant-id, \
40 idempotency-key, traceparent"
41 .into(),
42 allow_credentials: false,
43 max_age_secs: 600,
44 }
45 }
46}
47
48impl CorsConfig {
49 pub fn for_origins<I, S>(origins: I) -> Self
52 where
53 I: IntoIterator<Item = S>,
54 S: Into<String>,
55 {
56 Self {
57 allow_origins: origins.into_iter().map(Into::into).collect(),
58 allow_credentials: true,
59 ..Default::default()
60 }
61 }
62
63 fn allowed_origin(&self, origin: &str) -> Option<String> {
65 let any = self.allow_origins.iter().any(|o| o == "*");
66 if any && !self.allow_credentials {
67 return Some("*".to_owned());
68 }
69 if any || self.allow_origins.iter().any(|o| o == origin) {
70 return Some(origin.to_owned());
71 }
72 None
73 }
74}
75
76pub(crate) async fn apply_cors(cfg: &'static CorsConfig, req: Request, next: Next) -> Response {
77 let origin = req
78 .headers()
79 .get("origin")
80 .and_then(|v| v.to_str().ok())
81 .map(str::to_owned);
82
83 let allowed = origin.as_deref().and_then(|o| cfg.allowed_origin(o));
84
85 let is_preflight = req.method() == Method::OPTIONS
87 && req.headers().contains_key("access-control-request-method");
88 if is_preflight {
89 let Some(echo) = allowed else {
90 return Response::builder()
92 .status(403)
93 .body(Body::empty())
94 .expect("static preflight denial");
95 };
96 let mut resp = Response::builder()
97 .status(204)
98 .body(Body::empty())
99 .expect("static preflight response");
100 set_cors_headers(resp.headers_mut(), cfg, &echo);
101 resp.headers_mut().insert(
102 "access-control-max-age",
103 HeaderValue::from_str(&cfg.max_age_secs.to_string()).expect("numeric"),
104 );
105 return resp;
106 }
107
108 let mut resp = next.run(req).await;
109 if let Some(echo) = allowed {
110 set_cors_headers(resp.headers_mut(), cfg, &echo);
111 }
112 resp
113}
114
115fn set_cors_headers(headers: &mut axum::http::HeaderMap, cfg: &CorsConfig, origin: &str) {
116 if let Ok(v) = HeaderValue::from_str(origin) {
117 headers.insert("access-control-allow-origin", v);
118 }
119 if let Ok(v) = HeaderValue::from_str(&cfg.allow_methods) {
120 headers.insert("access-control-allow-methods", v);
121 }
122 if let Ok(v) = HeaderValue::from_str(&cfg.allow_headers) {
123 headers.insert("access-control-allow-headers", v);
124 }
125 if cfg.allow_credentials {
126 headers.insert(
127 "access-control-allow-credentials",
128 HeaderValue::from_static("true"),
129 );
130 }
131 headers.append("vary", HeaderValue::from_static("origin"));
132}