Skip to main content

edgeguard/
cors.rs

1//! CORS (Cross-Origin Resource Sharing).
2//!
3//! A drop-in front door is frequently deployed in front of an app whose browser frontend lives
4//! on a *different* origin (a separate static host, a preview deployment, `localhost:5173` during
5//! development). Browsers block those cross-origin `fetch`/`XHR` calls unless the server answers
6//! with the right `Access-Control-*` headers, so EdgeGuard grows a small, explicit CORS policy.
7//!
8//! Two responsibilities, both driven by [`CorsPolicy`] (held in the hot-swappable [`Runtime`],
9//! `None` when `cors.enabled = false`):
10//!   1. **Preflight** — answer a browser's `OPTIONS` preflight (`Origin` +
11//!      `Access-Control-Request-Method`) directly with `204` + the allow headers. This happens
12//!      *before* authentication in the request pipeline, because a preflight carries no
13//!      credentials; gating it behind auth would make every cross-origin call fail.
14//!   2. **Decoration** — add `Access-Control-Allow-Origin` (and friends) to the *actual*
15//!      response so the browser exposes it to the calling page.
16//!
//! Security note: a wildcard origin (`"*"`) cannot be combined with `allow_credentials = true`
//! — the Fetch spec forbids it, and browsers fail any credentialed request whose response
//! carries `Access-Control-Allow-Origin: *` — so [`CorsPolicy::build`] rejects the combination
//! at startup/reload rather than emitting a policy that silently doesn't work.
20
21use anyhow::{Context, Result};
22use axum::{
23    body::Body,
24    http::{header, HeaderMap, HeaderValue, Response, StatusCode},
25};
26
27use crate::config::{parse_duration, CorsCfg};
28
29/// A compiled CORS policy. Built once from [`CorsCfg`]; the string header values are
30/// precomputed so the request path only does cheap lookups and inserts.
31pub struct CorsPolicy {
32    /// `allow_origins` contained `"*"`. With credentials this is rejected at build, so when this
33    /// is true credentials are necessarily off and we can emit the cacheable literal `*`.
34    any_origin: bool,
35    /// Explicit allowed origins, lowercased for a case-insensitive compare.
36    origins: Vec<String>,
37    /// Precomputed `Access-Control-Allow-Methods` value.
38    allow_methods: HeaderValue,
39    /// Precomputed `Access-Control-Allow-Headers`; `None` => reflect the request's
40    /// `Access-Control-Request-Headers`.
41    allow_headers: Option<HeaderValue>,
42    /// Precomputed `Access-Control-Expose-Headers`; `None` => don't send it.
43    expose_headers: Option<HeaderValue>,
44    allow_credentials: bool,
45    /// `Access-Control-Max-Age` in seconds; `None` => omit the header.
46    max_age: Option<HeaderValue>,
47}
48
/// `Access-Control-Allow-Methods` value advertised on preflights when `cors.allow_methods` is
/// left empty in the config.
const DEFAULT_METHODS: &str = "GET, POST, PUT, PATCH, DELETE, OPTIONS, HEAD";
50
51impl CorsPolicy {
52    /// Compile the policy, or `Ok(None)` when CORS is disabled. Fails fast on an incoherent
53    /// policy (credentialed wildcard, enabled-but-no-origins, bad `max_age`) so the mistake
54    /// surfaces at startup/reload like any other bad config — not as silently-missing CORS
55    /// headers at request time.
56    pub fn build(cfg: &CorsCfg) -> Result<Option<CorsPolicy>> {
57        if !cfg.enabled {
58            return Ok(None);
59        }
60        anyhow::ensure!(
61            !cfg.allow_origins.is_empty(),
62            "cors.enabled = true requires at least one cors.allow_origins entry (use [\"*\"] for any)"
63        );
64        let any_origin = cfg.allow_origins.iter().any(|o| o.trim() == "*");
65        anyhow::ensure!(
66            !(any_origin && cfg.allow_credentials),
67            "cors.allow_credentials = true cannot be combined with a \"*\" origin (the Fetch spec \
68             forbids credentialed wildcard CORS); list explicit origins instead"
69        );
70
71        let origins = cfg
72            .allow_origins
73            .iter()
74            .map(|o| o.trim())
75            .filter(|o| *o != "*")
76            .map(|o| o.to_ascii_lowercase())
77            .collect();
78
79        let methods = if cfg.allow_methods.is_empty() {
80            DEFAULT_METHODS.to_string()
81        } else {
82            cfg.allow_methods.join(", ")
83        };
84        let allow_methods =
85            HeaderValue::from_str(&methods).context("cors.allow_methods has an invalid value")?;
86
87        let allow_headers = if cfg.allow_headers.is_empty() {
88            None
89        } else {
90            Some(
91                HeaderValue::from_str(&cfg.allow_headers.join(", "))
92                    .context("cors.allow_headers has an invalid value")?,
93            )
94        };
95        let expose_headers = if cfg.expose_headers.is_empty() {
96            None
97        } else {
98            Some(
99                HeaderValue::from_str(&cfg.expose_headers.join(", "))
100                    .context("cors.expose_headers has an invalid value")?,
101            )
102        };
103
104        let secs = parse_duration(&cfg.max_age)
105            .context("cors.max_age")?
106            .as_secs();
107        let max_age = (secs > 0)
108            .then(|| HeaderValue::from_str(&secs.to_string()).expect("digits are a valid header"));
109
110        Ok(Some(CorsPolicy {
111            any_origin,
112            origins,
113            allow_methods,
114            allow_headers,
115            expose_headers,
116            allow_credentials: cfg.allow_credentials,
117            max_age,
118        }))
119    }
120
121    /// The `Access-Control-Allow-Origin` value to send for a request from `origin`, or `None`
122    /// when the origin isn't allowed (the browser then blocks the page from reading the
123    /// response, which is the desired outcome).
124    fn allow_origin_value(&self, origin: &str) -> Option<HeaderValue> {
125        let origin = origin.trim();
126        if self.any_origin {
127            return Some(HeaderValue::from_static("*"));
128        }
129        let lower = origin.to_ascii_lowercase();
130        if self.origins.contains(&lower) {
131            HeaderValue::from_str(origin).ok()
132        } else {
133            None
134        }
135    }
136
137    /// Common to preflight and actual responses: set `Allow-Origin`, the credentials flag, and —
138    /// when the allowed origin echoes the request (an explicit list, not the constant `*`) — a
139    /// `Vary: Origin` so a shared cache can't serve one origin's headers to another.
140    fn set_origin(&self, h: &mut HeaderMap, allow: HeaderValue) {
141        h.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, allow);
142        if self.allow_credentials {
143            h.insert(
144                header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
145                HeaderValue::from_static("true"),
146            );
147        }
148        if !self.any_origin {
149            append_vary_origin(h);
150        }
151    }
152
153    /// If `headers` describe a CORS **preflight** (an `OPTIONS` with `Origin` +
154    /// `Access-Control-Request-Method`), build the `204` response to answer it with. Returns
155    /// `None` when it isn't a preflight, so the caller falls through to normal handling. When the
156    /// origin isn't allowed we still return a `204`, just without the CORS headers — the browser
157    /// then refuses the cross-origin call.
158    ///
159    /// The caller must only invoke this for `OPTIONS` requests; the `Access-Control-Request-Method`
160    /// presence check distinguishes a real preflight from a plain `OPTIONS`.
161    pub fn preflight_response(&self, headers: &HeaderMap) -> Option<Response<Body>> {
162        let origin = headers.get(header::ORIGIN)?.to_str().ok()?;
163        headers.get(header::ACCESS_CONTROL_REQUEST_METHOD)?;
164
165        let mut resp = Response::new(Body::empty());
166        *resp.status_mut() = StatusCode::NO_CONTENT;
167
168        if let Some(allow) = self.allow_origin_value(origin) {
169            let h = resp.headers_mut();
170            self.set_origin(h, allow);
171            h.insert(
172                header::ACCESS_CONTROL_ALLOW_METHODS,
173                self.allow_methods.clone(),
174            );
175            // Advertised request headers: the configured list, or reflect what the browser asked
176            // for (so a permissive default doesn't have to enumerate every header).
177            let allow_headers = self.allow_headers.clone().or_else(|| {
178                headers
179                    .get(header::ACCESS_CONTROL_REQUEST_HEADERS)
180                    .filter(|v| !v.is_empty())
181                    .cloned()
182            });
183            if let Some(v) = allow_headers {
184                h.insert(header::ACCESS_CONTROL_ALLOW_HEADERS, v);
185            }
186            if let Some(age) = &self.max_age {
187                h.insert(header::ACCESS_CONTROL_MAX_AGE, age.clone());
188            }
189        }
190        Some(resp)
191    }
192
193    /// Add the CORS headers to an *actual* (non-preflight) response, based on the request's
194    /// `Origin`. A no-op when the request has no `Origin` (not a cross-origin browser request) or
195    /// the origin isn't allowed.
196    pub fn decorate(&self, req_headers: &HeaderMap, resp: &mut Response<Body>) {
197        if let Some(origin) = req_headers
198            .get(header::ORIGIN)
199            .and_then(|v| v.to_str().ok())
200        {
201            self.decorate_origin(origin, resp);
202        }
203    }
204
205    /// Like [`decorate`](Self::decorate), but given the request `Origin` directly. A no-op when the
206    /// origin isn't allowed. Idempotent, so it's safe to call on a response that may already carry
207    /// CORS headers (e.g. a preflight). Used to decorate **every** response — including
208    /// EdgeGuard-generated `401`/`403`/`429` — so an allowed browser origin sees the real status
209    /// rather than a generic CORS failure.
210    pub fn decorate_origin(&self, origin: &str, resp: &mut Response<Body>) {
211        let Some(allow) = self.allow_origin_value(origin) else {
212            return;
213        };
214        let h = resp.headers_mut();
215        self.set_origin(h, allow);
216        if let Some(expose) = &self.expose_headers {
217            h.insert(header::ACCESS_CONTROL_EXPOSE_HEADERS, expose.clone());
218        }
219    }
220}
221
222/// Append `Origin` to the response `Vary` header without duplicating it. Multiple `Vary` values
223/// are valid, but de-duping keeps the output tidy and avoids unbounded growth across hops.
224fn append_vary_origin(h: &mut HeaderMap) {
225    let already = h.get_all(header::VARY).iter().any(|v| {
226        v.to_str()
227            .map(|s| {
228                s.split(',')
229                    .any(|t| t.trim().eq_ignore_ascii_case("origin"))
230            })
231            .unwrap_or(false)
232    });
233    if !already {
234        h.append(header::VARY, HeaderValue::from_static("Origin"));
235    }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderName;

    /// Compile `cfg`, panicking if it is rejected or turns out to be disabled.
    fn compile(cfg: CorsCfg) -> CorsPolicy {
        CorsPolicy::build(&cfg).unwrap().unwrap()
    }

    /// An enabled config allowing exactly `origins`, with every other field defaulted.
    fn enabled(origins: &[&str]) -> CorsCfg {
        CorsCfg {
            enabled: true,
            allow_origins: origins.iter().map(|&o| o.into()).collect(),
            ..Default::default()
        }
    }

    /// Request headers carrying `origin` plus each `(name, value)` pair in `extra`.
    fn request_headers(origin: &str, extra: &[(&'static str, &str)]) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(header::ORIGIN, HeaderValue::from_str(origin).unwrap());
        for &(name, value) in extra {
            headers.insert(
                HeaderName::from_static(name),
                HeaderValue::from_str(value).unwrap(),
            );
        }
        headers
    }

    #[test]
    fn disabled_builds_to_none() {
        let built = CorsPolicy::build(&CorsCfg::default()).unwrap();
        assert!(built.is_none());
    }

    #[test]
    fn enabled_without_origins_is_rejected() {
        assert!(CorsPolicy::build(&enabled(&[])).is_err());
    }

    #[test]
    fn credentialed_wildcard_is_rejected() {
        let cfg = CorsCfg {
            allow_credentials: true,
            ..enabled(&["*"])
        };
        assert!(CorsPolicy::build(&cfg).is_err());
    }

    #[test]
    fn wildcard_returns_star_and_no_vary() {
        let policy = compile(enabled(&["*"]));
        let mut resp = Response::new(Body::empty());
        policy.decorate(&request_headers("https://anything.example", &[]), &mut resp);

        let headers = resp.headers();
        assert_eq!(headers.get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap(), "*");
        assert!(headers.get(header::VARY).is_none());
    }

    #[test]
    fn explicit_origin_echoes_allowed_and_blocks_others() {
        let policy = compile(CorsCfg {
            allow_credentials: true,
            ..enabled(&["https://app.example.com"])
        });

        // Listed origin: echoed back, credentials flag set, and `Vary: Origin` present.
        let mut allowed = Response::new(Body::empty());
        policy.decorate(&request_headers("https://app.example.com", &[]), &mut allowed);
        let headers = allowed.headers();
        assert_eq!(
            headers.get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap(),
            "https://app.example.com"
        );
        assert_eq!(
            headers.get(header::ACCESS_CONTROL_ALLOW_CREDENTIALS).unwrap(),
            "true"
        );
        assert_eq!(headers.get(header::VARY).unwrap(), "Origin");

        // Unlisted origin: no `Allow-Origin`, so the browser withholds the response.
        let mut denied = Response::new(Body::empty());
        policy.decorate(&request_headers("https://evil.example", &[]), &mut denied);
        assert!(denied
            .headers()
            .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
            .is_none());
    }

    #[test]
    fn preflight_reflects_requested_headers_when_unset() {
        let policy = compile(enabled(&["https://app.example.com"]));
        let preflight = request_headers(
            "https://app.example.com",
            &[
                ("access-control-request-method", "POST"),
                ("access-control-request-headers", "x-custom, content-type"),
            ],
        );

        let resp = policy
            .preflight_response(&preflight)
            .expect("is a preflight");
        assert_eq!(resp.status(), StatusCode::NO_CONTENT);

        let headers = resp.headers();
        assert_eq!(
            headers.get(header::ACCESS_CONTROL_ALLOW_METHODS).unwrap(),
            DEFAULT_METHODS
        );
        assert_eq!(
            headers.get(header::ACCESS_CONTROL_ALLOW_HEADERS).unwrap(),
            "x-custom, content-type"
        );
        assert_eq!(headers.get(header::ACCESS_CONTROL_MAX_AGE).unwrap(), "600");
    }

    #[test]
    fn plain_options_is_not_a_preflight() {
        let policy = compile(enabled(&["*"]));
        // Without Access-Control-Request-Method this is an ordinary OPTIONS, so the caller
        // should fall through to normal handling.
        let headers = request_headers("https://app.example.com", &[]);
        assert!(policy.preflight_response(&headers).is_none());
    }
}