Skip to main content

arcly_http_core/web/
cors.rs

//! Config-driven CORS, applied as one stateless layer at launch.
//!
//! Off by default (`LaunchConfig::cors = None`) — there is zero per-request
//! cost when disabled because the layer simply isn't mounted. When enabled,
//! the per-request work is a header lookup against a boot-frozen origin list.
//!
//! Semantics follow the WHATWG Fetch spec pragmatically:
//! - A preflight (`OPTIONS` + `Access-Control-Request-Method`) from an allowed
//!   origin short-circuits with `204`, the `Access-Control-Allow-*` headers and
//!   `Access-Control-Max-Age` — the request never reaches routing, guards, or
//!   body read. A preflight from a disallowed origin gets a bare `403` with no
//!   CORS headers, so the browser blocks the actual request.
//! - Actual requests from an allowed origin get `Access-Control-Allow-*`
//!   response headers plus `Vary: Origin`.
//! - `"*"` in `allow_origins` allows any origin. Combined with
//!   `allow_credentials`, the *specific* origin is echoed back instead (the
//!   spec forbids `*` with credentials).
16
17use axum::body::Body;
18use axum::extract::Request;
19use axum::http::{HeaderValue, Method};
20use axum::middleware::Next;
21use axum::response::Response;
22
/// CORS policy, frozen at launch.
///
/// Build one with [`CorsConfig::for_origins`], or start from
/// [`CorsConfig::default`] and chain the builder methods. The struct is
/// `#[non_exhaustive]`, so code outside this crate cannot construct it with a
/// struct literal (nor with `..Default::default()` update syntax).
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct CorsConfig {
    /// Exact origins (scheme + host + port), or `"*"` for any.
    pub allow_origins: Vec<String>,
    /// Comma-separated method list sent as `Access-Control-Allow-Methods`.
    pub allow_methods: String,
    /// Comma-separated header list sent as `Access-Control-Allow-Headers`.
    pub allow_headers: String,
    /// Whether to send `Access-Control-Allow-Credentials: true`. When set, a
    /// `"*"` origin entry echoes the request's origin instead of `*`.
    pub allow_credentials: bool,
    /// How long (in seconds) browsers may cache a preflight result; sent as
    /// `Access-Control-Max-Age`.
    pub max_age_secs: u32,
}
34
35impl Default for CorsConfig {
36    fn default() -> Self {
37        Self {
38            allow_origins: vec![],
39            allow_methods: "GET, POST, PUT, PATCH, DELETE, OPTIONS".into(),
40            allow_headers: "content-type, authorization, x-request-id, x-tenant-id, \
41                            idempotency-key, traceparent"
42                .into(),
43            allow_credentials: false,
44            max_age_secs: 600,
45        }
46    }
47}
48
49impl CorsConfig {
50    /// Convenience: allow exactly these origins with credentials enabled —
51    /// the common SPA setup.
52    pub fn for_origins<I, S>(origins: I) -> Self
53    where
54        I: IntoIterator<Item = S>,
55        S: Into<String>,
56    {
57        Self {
58            allow_origins: origins.into_iter().map(Into::into).collect(),
59            allow_credentials: true,
60            ..Default::default()
61        }
62    }
63
64    pub fn allow_origins<I, S>(mut self, origins: I) -> Self
65    where
66        I: IntoIterator<Item = S>,
67        S: Into<String>,
68    {
69        self.allow_origins = origins.into_iter().map(Into::into).collect();
70        self
71    }
72    pub fn allow_methods(mut self, v: impl Into<String>) -> Self {
73        self.allow_methods = v.into();
74        self
75    }
76    pub fn allow_headers(mut self, v: impl Into<String>) -> Self {
77        self.allow_headers = v.into();
78        self
79    }
80    pub fn allow_credentials(mut self, v: bool) -> Self {
81        self.allow_credentials = v;
82        self
83    }
84    pub fn max_age_secs(mut self, v: u32) -> Self {
85        self.max_age_secs = v;
86        self
87    }
88
89    /// The origin value to echo for this request, if allowed.
90    fn allowed_origin(&self, origin: &str) -> Option<String> {
91        let any = self.allow_origins.iter().any(|o| o == "*");
92        if any && !self.allow_credentials {
93            return Some("*".to_owned());
94        }
95        if any || self.allow_origins.iter().any(|o| o == origin) {
96            return Some(origin.to_owned());
97        }
98        None
99    }
100}
101
102#[doc(hidden)]
103pub async fn apply_cors(cfg: std::sync::Arc<CorsConfig>, req: Request, next: Next) -> Response {
104    let origin = req
105        .headers()
106        .get("origin")
107        .and_then(|v| v.to_str().ok())
108        .map(str::to_owned);
109
110    let allowed = origin.as_deref().and_then(|o| cfg.allowed_origin(o));
111
112    // Preflight: answer before routing/auth/body — but only for allowed origins.
113    let is_preflight = req.method() == Method::OPTIONS
114        && req.headers().contains_key("access-control-request-method");
115    if is_preflight {
116        let Some(echo) = allowed else {
117            // Disallowed origin: no CORS headers; the browser blocks it.
118            return Response::builder()
119                .status(403)
120                .body(Body::empty())
121                .expect("static preflight denial");
122        };
123        let mut resp = Response::builder()
124            .status(204)
125            .body(Body::empty())
126            .expect("static preflight response");
127        set_cors_headers(resp.headers_mut(), &cfg, &echo);
128        resp.headers_mut().insert(
129            "access-control-max-age",
130            HeaderValue::from_str(&cfg.max_age_secs.to_string()).expect("numeric"),
131        );
132        return resp;
133    }
134
135    let mut resp = next.run(req).await;
136    if let Some(echo) = allowed {
137        set_cors_headers(resp.headers_mut(), &cfg, &echo);
138    }
139    resp
140}
141
142fn set_cors_headers(headers: &mut axum::http::HeaderMap, cfg: &CorsConfig, origin: &str) {
143    if let Ok(v) = HeaderValue::from_str(origin) {
144        headers.insert("access-control-allow-origin", v);
145    }
146    if let Ok(v) = HeaderValue::from_str(&cfg.allow_methods) {
147        headers.insert("access-control-allow-methods", v);
148    }
149    if let Ok(v) = HeaderValue::from_str(&cfg.allow_headers) {
150        headers.insert("access-control-allow-headers", v);
151    }
152    if cfg.allow_credentials {
153        headers.insert(
154            "access-control-allow-credentials",
155            HeaderValue::from_static("true"),
156        );
157    }
158    headers.append("vary", HeaderValue::from_static("origin"));
159}