Skip to main content

actus_server/
cors.rs

1//! Cross-Origin Resource Sharing (CORS).
2//!
3//! Build a [`CorsLayer`] and hand it to [`crate::Server::with_cors`]. The
4//! server then answers preflight (`OPTIONS`) requests itself and stamps the
5//! `Access-Control-*` headers onto every cross-origin response — including
6//! error responses, so the browser can read 4xx/5xx bodies.
7//!
8//! ```ignore
9//! use actus::prelude::*;
10//! use std::time::Duration;
11//!
12//! // Development: anything goes.
13//! Server::new(router).with_cors(CorsLayer::permissive());
14//!
15//! // Production: pin it down.
16//! Server::new(router).with_cors(
17//!     CorsLayer::new()
18//!         .allow_origin("https://app.example.com")
19//!         .allow_methods([Verb::GET, Verb::POST, Verb::DELETE])
20//!         .allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION])
21//!         .allow_credentials(true)
22//!         .max_age(Duration::from_secs(3600)),
23//! );
24//! ```
25
26use actus_controller::Verb;
27use http::{HeaderMap, HeaderName, HeaderValue, Method, header};
28use std::time::Duration;
29
/// Which request origins a [`CorsLayer`] accepts.
#[derive(Debug, Clone)]
enum OriginRule {
    /// Accept every origin. Responses reflect the request's own `Origin`
    /// value instead of the `*` wildcard, which keeps this rule usable
    /// together with credentials.
    Any,
    /// Accept only the origins in this list, compared exactly
    /// (case-sensitively) against the request's `Origin` header.
    List(Vec<String>),
}
38
39#[derive(Clone, Debug)]
40enum HeaderRule {
41    /// Mirror the browser's `Access-Control-Request-Headers` verbatim (always
42    /// valid, including with credentials); send nothing if it asked for none.
43    MirrorRequest,
44    /// An explicit allow-list.
45    List(Vec<HeaderName>),
46}
47
48/// A CORS policy. See the [module docs](self).
49#[derive(Clone, Debug)]
50pub struct CorsLayer {
51    origins: OriginRule,
52    methods: Vec<Verb>,
53    headers: HeaderRule,
54    expose: Vec<HeaderName>,
55    credentials: bool,
56    max_age: Option<Duration>,
57}
58
59impl Default for CorsLayer {
60    fn default() -> Self {
61        Self::new()
62    }
63}
64
65impl CorsLayer {
66    /// A *closed* policy — no origin allowed yet. Add some with
67    /// [`allow_origin`](Self::allow_origin) / [`allow_any_origin`](Self::allow_any_origin)
68    /// (and methods/headers/credentials as your API needs). Defaults: methods
69    /// GET + POST, no extra headers, no credentials, no preflight cache.
70    pub fn new() -> Self {
71        Self {
72            origins: OriginRule::List(Vec::new()),
73            methods: vec![Verb::GET, Verb::POST],
74            headers: HeaderRule::List(Vec::new()),
75            expose: Vec::new(),
76            credentials: false,
77            max_age: None,
78        }
79    }
80
81    /// A *wide-open* policy — any origin, the common verbs
82    /// (GET/POST/PUT/DELETE/PATCH), any request header, no credentials, a
83    /// one-day preflight cache. Handy for development; tighten it for production.
84    pub fn permissive() -> Self {
85        Self {
86            origins: OriginRule::Any,
87            methods: vec![Verb::GET, Verb::POST, Verb::PUT, Verb::DELETE, Verb::PATCH],
88            headers: HeaderRule::MirrorRequest,
89            expose: Vec::new(),
90            credentials: false,
91            max_age: Some(Duration::from_secs(86_400)),
92        }
93    }
94
95    /// Allow any origin. The response still echoes the concrete `Origin`, so
96    /// this composes correctly with [`allow_credentials`](Self::allow_credentials).
97    pub fn allow_any_origin(mut self) -> Self {
98        self.origins = OriginRule::Any;
99        self
100    }
101
102    /// Add an allowed origin (exact match, e.g. `"https://app.example.com"`).
103    /// Calling this after [`allow_any_origin`](Self::allow_any_origin) narrows
104    /// back to a list.
105    pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
106        match &mut self.origins {
107            OriginRule::List(list) => list.push(origin.into()),
108            OriginRule::Any => self.origins = OriginRule::List(vec![origin.into()]),
109        }
110        self
111    }
112
113    /// Set the methods advertised in preflight responses
114    /// (`Access-Control-Allow-Methods`). Default: GET, POST.
115    pub fn allow_methods(mut self, methods: impl IntoIterator<Item = Verb>) -> Self {
116        self.methods = methods.into_iter().collect();
117        self
118    }
119
120    /// Allow any request header — mirrors `Access-Control-Request-Headers`.
121    pub fn allow_any_header(mut self) -> Self {
122        self.headers = HeaderRule::MirrorRequest;
123        self
124    }
125
126    /// Set the allowed request headers explicitly (`Access-Control-Allow-Headers`).
127    pub fn allow_headers(mut self, headers: impl IntoIterator<Item = HeaderName>) -> Self {
128        self.headers = HeaderRule::List(headers.into_iter().collect());
129        self
130    }
131
132    /// Response headers JS may read beyond the CORS-safelisted ones
133    /// (`Access-Control-Expose-Headers`).
134    pub fn expose_headers(mut self, headers: impl IntoIterator<Item = HeaderName>) -> Self {
135        self.expose = headers.into_iter().collect();
136        self
137    }
138
139    /// Whether credentialed requests (cookies, HTTP auth) are permitted
140    /// (`Access-Control-Allow-Credentials`). When on, the wildcard `*` token
141    /// is never sent — the policy always echoes concrete origin/header values,
142    /// which is what the spec requires.
143    pub fn allow_credentials(mut self, yes: bool) -> Self {
144        self.credentials = yes;
145        self
146    }
147
148    /// How long a browser may cache a preflight result (`Access-Control-Max-Age`).
149    pub fn max_age(mut self, age: Duration) -> Self {
150        self.max_age = Some(age);
151        self
152    }
153
154    // -------- internal: used by `Server` --------
155
156    /// `Some(value)` to send as `Access-Control-Allow-Origin` for `origin`
157    /// (always the concrete origin), or `None` if it isn't allowed.
158    fn allow_origin_value(&self, origin: &str) -> Option<HeaderValue> {
159        let allowed = match &self.origins {
160            OriginRule::Any => true,
161            OriginRule::List(list) => list.iter().any(|o| o == origin),
162        };
163        if allowed {
164            HeaderValue::from_str(origin).ok()
165        } else {
166            None
167        }
168    }
169
170    fn allow_methods_value(&self) -> HeaderValue {
171        let joined = self
172            .methods
173            .iter()
174            .map(Verb::as_str)
175            .collect::<Vec<_>>()
176            .join(", ");
177        HeaderValue::from_str(&joined).unwrap_or_else(|_| HeaderValue::from_static("GET, POST"))
178    }
179
180    fn allow_headers_value(&self, requested: Option<&HeaderValue>) -> Option<HeaderValue> {
181        match &self.headers {
182            HeaderRule::MirrorRequest => requested.cloned(),
183            HeaderRule::List(list) if list.is_empty() => None,
184            HeaderRule::List(list) => {
185                let joined = list
186                    .iter()
187                    .map(HeaderName::as_str)
188                    .collect::<Vec<_>>()
189                    .join(", ");
190                HeaderValue::from_str(&joined).ok()
191            }
192        }
193    }
194
195    fn expose_headers_value(&self) -> Option<HeaderValue> {
196        if self.expose.is_empty() {
197            return None;
198        }
199        let joined = self
200            .expose
201            .iter()
202            .map(HeaderName::as_str)
203            .collect::<Vec<_>>()
204            .join(", ");
205        HeaderValue::from_str(&joined).ok()
206    }
207
208    /// True iff `(method, headers)` look like a CORS preflight: `OPTIONS` with
209    /// both `Origin` and `Access-Control-Request-Method`.
210    pub(crate) fn is_preflight(method: &Method, headers: &HeaderMap) -> bool {
211        *method == Method::OPTIONS
212            && headers.contains_key(header::ORIGIN)
213            && headers.contains_key(header::ACCESS_CONTROL_REQUEST_METHOD)
214    }
215
216    fn preflight_headers(&self, request_headers: &HeaderMap) -> Vec<(HeaderName, HeaderValue)> {
217        let mut out = Vec::new();
218        let Some(origin) = request_headers
219            .get(header::ORIGIN)
220            .and_then(|v| v.to_str().ok())
221        else {
222            return out;
223        };
224        let Some(allow_origin) = self.allow_origin_value(origin) else {
225            return out;
226        };
227        out.push((header::ACCESS_CONTROL_ALLOW_ORIGIN, allow_origin));
228        out.push((header::VARY, HeaderValue::from_static("Origin")));
229        out.push((
230            header::ACCESS_CONTROL_ALLOW_METHODS,
231            self.allow_methods_value(),
232        ));
233        if let Some(h) =
234            self.allow_headers_value(request_headers.get(header::ACCESS_CONTROL_REQUEST_HEADERS))
235        {
236            out.push((header::ACCESS_CONTROL_ALLOW_HEADERS, h));
237        }
238        if self.credentials {
239            out.push((
240                header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
241                HeaderValue::from_static("true"),
242            ));
243        }
244        if let Some(age) = self.max_age
245            && let Ok(v) = HeaderValue::from_str(&age.as_secs().to_string())
246        {
247            out.push((header::ACCESS_CONTROL_MAX_AGE, v));
248        }
249        out
250    }
251
252    fn response_headers(&self, request_headers: &HeaderMap) -> Vec<(HeaderName, HeaderValue)> {
253        let mut out = Vec::new();
254        let Some(origin) = request_headers
255            .get(header::ORIGIN)
256            .and_then(|v| v.to_str().ok())
257        else {
258            return out;
259        };
260        let Some(allow_origin) = self.allow_origin_value(origin) else {
261            return out;
262        };
263        out.push((header::ACCESS_CONTROL_ALLOW_ORIGIN, allow_origin));
264        out.push((header::VARY, HeaderValue::from_static("Origin")));
265        if self.credentials {
266            out.push((
267                header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
268                HeaderValue::from_static("true"),
269            ));
270        }
271        if let Some(v) = self.expose_headers_value() {
272            out.push((header::ACCESS_CONTROL_EXPOSE_HEADERS, v));
273        }
274        out
275    }
276
277    /// Add the CORS headers to `into`. With `preflight = true` this is the
278    /// response to a preflight `OPTIONS` (otherwise an empty `204`); with
279    /// `false` it's an ordinary response. No-op when the request had no
280    /// `Origin`, or it isn't allowed (the browser then blocks the call).
281    ///
282    /// `Vary` is *appended* (so an existing `Vary: Accept-Encoding` isn't
283    /// clobbered); the `Access-Control-*` headers are set outright.
284    pub(crate) fn apply(&self, request_headers: &HeaderMap, into: &mut HeaderMap, preflight: bool) {
285        let pairs = if preflight {
286            self.preflight_headers(request_headers)
287        } else {
288            self.response_headers(request_headers)
289        };
290        for (name, value) in pairs {
291            if name == header::VARY {
292                into.append(name, value);
293            } else {
294                into.insert(name, value);
295            }
296        }
297    }
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a request `HeaderMap` from `(name, value)` pairs.
    fn headers(pairs: &[(HeaderName, &str)]) -> HeaderMap {
        let mut h = HeaderMap::new();
        for (name, value) in pairs {
            h.insert(name.clone(), HeaderValue::from_str(value).unwrap());
        }
        h
    }

    /// All values in `pairs` whose header name is `name`, as strings.
    fn names(pairs: &[(HeaderName, HeaderValue)], name: &HeaderName) -> Vec<String> {
        pairs
            .iter()
            .filter(|(n, _)| n == name)
            .map(|(_, v)| v.to_str().unwrap().to_string())
            .collect()
    }

    #[test]
    fn no_origin_header_is_a_noop() {
        assert!(
            CorsLayer::permissive()
                .response_headers(&HeaderMap::new())
                .is_empty()
        );
        assert!(
            CorsLayer::permissive()
                .preflight_headers(&HeaderMap::new())
                .is_empty()
        );
    }

    #[test]
    fn permissive_echoes_any_origin_with_vary() {
        let out = CorsLayer::permissive()
            .response_headers(&headers(&[(header::ORIGIN, "https://x.example")]));
        assert_eq!(
            names(&out, &header::ACCESS_CONTROL_ALLOW_ORIGIN),
            ["https://x.example"]
        );
        assert_eq!(names(&out, &header::VARY), ["Origin"]);
        // no credentials by default
        assert!(names(&out, &header::ACCESS_CONTROL_ALLOW_CREDENTIALS).is_empty());
    }

    #[test]
    fn allow_list_rejects_unlisted_origin() {
        let cors = CorsLayer::new().allow_origin("https://app.example");
        assert!(
            cors.response_headers(&headers(&[(header::ORIGIN, "https://evil.example")]))
                .is_empty()
        );
        assert_eq!(
            names(
                &cors.response_headers(&headers(&[(header::ORIGIN, "https://app.example")])),
                &header::ACCESS_CONTROL_ALLOW_ORIGIN
            ),
            ["https://app.example"]
        );
    }

    #[test]
    fn allow_origin_after_any_narrows_to_list() {
        let cors = CorsLayer::new()
            .allow_any_origin()
            .allow_origin("https://app.example");
        assert!(
            cors.response_headers(&headers(&[(header::ORIGIN, "https://other.example")]))
                .is_empty()
        );
        assert_eq!(
            names(
                &cors.response_headers(&headers(&[(header::ORIGIN, "https://app.example")])),
                &header::ACCESS_CONTROL_ALLOW_ORIGIN
            ),
            ["https://app.example"]
        );
    }

    #[test]
    fn is_preflight_needs_options_origin_and_request_method() {
        let full = headers(&[
            (header::ORIGIN, "https://x.example"),
            (header::ACCESS_CONTROL_REQUEST_METHOD, "PUT"),
        ]);
        assert!(CorsLayer::is_preflight(&Method::OPTIONS, &full));
        assert!(!CorsLayer::is_preflight(&Method::GET, &full));
        assert!(!CorsLayer::is_preflight(
            &Method::OPTIONS,
            &headers(&[(header::ORIGIN, "https://x.example")])
        ));
        assert!(!CorsLayer::is_preflight(
            &Method::OPTIONS,
            &headers(&[(header::ACCESS_CONTROL_REQUEST_METHOD, "PUT")])
        ));
    }

    #[test]
    fn preflight_advertises_methods_mirrored_headers_and_max_age() {
        let out = CorsLayer::permissive().preflight_headers(&headers(&[
            (header::ORIGIN, "https://x.example"),
            (header::ACCESS_CONTROL_REQUEST_METHOD, "POST"),
            (
                header::ACCESS_CONTROL_REQUEST_HEADERS,
                "content-type, authorization",
            ),
        ]));
        let methods = &names(&out, &header::ACCESS_CONTROL_ALLOW_METHODS)[0];
        assert!(methods.contains("POST") && methods.contains("DELETE"));
        assert_eq!(
            names(&out, &header::ACCESS_CONTROL_ALLOW_HEADERS),
            ["content-type, authorization"]
        );
        assert_eq!(names(&out, &header::ACCESS_CONTROL_MAX_AGE), ["86400"]);
    }

    #[test]
    fn preflight_for_rejected_origin_is_empty() {
        let cors = CorsLayer::new().allow_origin("https://app.example");
        let out = cors.preflight_headers(&headers(&[
            (header::ORIGIN, "https://evil.example"),
            (header::ACCESS_CONTROL_REQUEST_METHOD, "POST"),
        ]));
        assert!(out.is_empty());
    }

    #[test]
    fn explicit_header_lists_are_joined_and_defaults_are_closed() {
        let cors = CorsLayer::new()
            .allow_origin("https://x.example")
            .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION])
            .expose_headers([header::ETAG]);
        let req = headers(&[
            (header::ORIGIN, "https://x.example"),
            (header::ACCESS_CONTROL_REQUEST_METHOD, "POST"),
        ]);
        let pre = cors.preflight_headers(&req);
        assert_eq!(
            names(&pre, &header::ACCESS_CONTROL_ALLOW_HEADERS),
            ["content-type, authorization"]
        );
        // `new()` sets no preflight cache and no credentials.
        assert!(names(&pre, &header::ACCESS_CONTROL_MAX_AGE).is_empty());
        assert!(names(&pre, &header::ACCESS_CONTROL_ALLOW_CREDENTIALS).is_empty());
        assert_eq!(
            names(
                &cors.response_headers(&req),
                &header::ACCESS_CONTROL_EXPOSE_HEADERS
            ),
            ["etag"]
        );
    }

    #[test]
    fn credentials_never_sends_star() {
        let cors = CorsLayer::permissive().allow_credentials(true);
        let out = cors.response_headers(&headers(&[(header::ORIGIN, "https://x.example")]));
        assert_eq!(
            names(&out, &header::ACCESS_CONTROL_ALLOW_ORIGIN),
            ["https://x.example"]
        );
        assert_eq!(
            names(&out, &header::ACCESS_CONTROL_ALLOW_CREDENTIALS),
            ["true"]
        );
    }

    #[test]
    fn apply_appends_vary_but_replaces_acao() {
        let cors = CorsLayer::permissive();
        let mut into = HeaderMap::new();
        into.insert(header::VARY, HeaderValue::from_static("Accept-Encoding"));
        cors.apply(
            &headers(&[(header::ORIGIN, "https://x.example")]),
            &mut into,
            false,
        );
        let vary: Vec<_> = into
            .get_all(header::VARY)
            .iter()
            .map(|v| v.to_str().unwrap().to_string())
            .collect();
        assert_eq!(vary, ["Accept-Encoding", "Origin"]);
        assert_eq!(
            into.get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap(),
            "https://x.example"
        );
    }
}