Skip to main content

jerrycan_core/
cors.rs

1//! CORS (spec §v2.2). Lives in core because preflight must be answered BEFORE
2//! routing (an `OPTIONS` to a method-mismatched route is rejected 405 before
3//! any middleware runs), so CORS is a pre-routing + response-decoration concern
4//! integrated into `route_policy`/dispatch in later tasks — not a `Middleware`.
5
6use crate::response::{JcBody, Response};
7use http::{HeaderValue, Method, StatusCode, header};
8use std::time::Duration;
9
/// The set of origins that are permitted to make cross-origin requests.
#[derive(Clone, Debug)]
pub enum CorsOrigins {
    /// Wildcard policy: every origin is accepted. The response builders in
    /// this file echo the request's `Origin` value back rather than emitting a
    /// literal `*`. Combining this variant with credentials is rejected by
    /// `CorsConfig::validate`.
    Any,
    /// Exact-match allowlist of origin strings (scheme + host + optional
    /// port). Entries are compared verbatim — no normalization is applied.
    List(Vec<String>),
}

impl CorsOrigins {
    /// Shorthand for [`CorsOrigins::Any`].
    pub fn any() -> Self {
        CorsOrigins::Any
    }

    /// Build a [`CorsOrigins::List`] from any iterable of string-like items,
    /// preserving iteration order.
    pub fn list<I, S>(origins: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let allowlist: Vec<String> = origins.into_iter().map(|entry| entry.into()).collect();
        CorsOrigins::List(allowlist)
    }
}
32
33/// CORS policy. Build with `CorsConfig::new(origins)`, chain options, install
34/// with `App::cors(config)`.
35#[derive(Clone, Debug)]
36pub struct CorsConfig {
37    origins: CorsOrigins,
38    methods: Vec<http::Method>, // empty => reflect the route's real methods on preflight
39    headers: Vec<String>,       // empty => reflect Access-Control-Request-Headers
40    expose: Vec<String>,
41    allow_credentials: bool,
42    max_age: Option<Duration>,
43}
44
45impl CorsConfig {
46    pub fn new(origins: CorsOrigins) -> Self {
47        Self {
48            origins,
49            methods: Vec::new(),
50            headers: Vec::new(),
51            expose: Vec::new(),
52            allow_credentials: false,
53            max_age: None,
54        }
55    }
56    pub fn allow_credentials(mut self, yes: bool) -> Self {
57        self.allow_credentials = yes;
58        self
59    }
60    pub fn max_age(mut self, d: Duration) -> Self {
61        self.max_age = Some(d);
62        self
63    }
64    pub fn allow_methods<I: IntoIterator<Item = http::Method>>(mut self, m: I) -> Self {
65        self.methods = m.into_iter().collect();
66        self
67    }
68    pub fn allow_headers<I: IntoIterator<Item = S>, S: Into<String>>(mut self, h: I) -> Self {
69        self.headers = h.into_iter().map(Into::into).collect();
70        self
71    }
72    pub fn expose_headers<I: IntoIterator<Item = S>, S: Into<String>>(mut self, h: I) -> Self {
73        self.expose = h.into_iter().map(Into::into).collect();
74        self
75    }
76
77    /// Public reader (the builder method `allow_credentials(bool)` can't share the name).
78    pub fn allow_credentials_enabled(&self) -> bool {
79        self.allow_credentials
80    }
81
82    /// True if `origin` is permitted. `Any` matches everything; `List` is exact.
83    pub fn allows_origin(&self, origin: &str) -> bool {
84        match &self.origins {
85            CorsOrigins::Any => true,
86            CorsOrigins::List(list) => list.iter().any(|o| o == origin),
87        }
88    }
89
90    /// Configured `allow_methods` (empty => reflect the route's real methods).
91    pub(crate) fn cfg_methods(&self) -> &[http::Method] {
92        &self.methods
93    }
94    /// Configured `allow_headers` (empty => reflect `Access-Control-Request-Headers`).
95    pub(crate) fn cfg_headers(&self) -> &[String] {
96        &self.headers
97    }
98    /// Configured `Access-Control-Max-Age`, if set.
99    pub(crate) fn cfg_max_age(&self) -> Option<std::time::Duration> {
100        self.max_age
101    }
102    /// Whether `Access-Control-Allow-Credentials: true` is emitted.
103    pub(crate) fn credentials(&self) -> bool {
104        self.allow_credentials
105    }
106    /// Configured `expose_headers` (empty => no `Access-Control-Expose-Headers`).
107    pub(crate) fn cfg_expose(&self) -> &[String] {
108        &self.expose
109    }
110
111    /// Validate at build time: `*` + credentials is forbidden by the Fetch spec
112    /// and is a footgun, so it is a build error, not a runtime surprise.
113    pub(crate) fn validate(&self) -> crate::Result<()> {
114        if self.allow_credentials && matches!(self.origins, CorsOrigins::Any) {
115            return Err(crate::Error::internal(
116                "CORS misconfiguration: allow_credentials(true) cannot be combined with CorsOrigins::any() — list explicit origins",
117            ));
118        }
119        Ok(())
120    }
121}
122
123/// Is this request a CORS preflight? (`OPTIONS` + `Origin` +
124/// `Access-Control-Request-Method`). All three are required by the Fetch spec;
125/// a bare `OPTIONS` (no origin/no ACRM) is a normal request, not a preflight.
126pub(crate) fn is_preflight(parts: &http::request::Parts) -> bool {
127    parts.method == Method::OPTIONS
128        && parts.headers.contains_key(header::ORIGIN)
129        && parts
130            .headers
131            .contains_key(header::ACCESS_CONTROL_REQUEST_METHOD)
132}
133
134/// Build the CORS preflight `204` for an ALLOWED origin. `allowed_methods` are
135/// the route's real methods (used when the config doesn't pin `allow_methods`).
136/// Returns a bare `204` carrying ONLY CORS headers — deliberately no security
137/// headers, since `cache-control: no-store` would fight `Access-Control-Max-Age`.
138pub(crate) fn preflight_response(
139    config: &CorsConfig,
140    origin: &str,
141    request_headers: Option<&str>,
142    allowed_methods: &[Method],
143) -> Response {
144    let mut r = http::Response::new(JcBody::empty());
145    *r.status_mut() = StatusCode::NO_CONTENT;
146    let h = r.headers_mut();
147    if let Ok(v) = HeaderValue::from_str(origin) {
148        h.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, v);
149        h.insert(header::VARY, HeaderValue::from_static("Origin"));
150    }
151    if config.credentials() {
152        h.insert(
153            header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
154            HeaderValue::from_static("true"),
155        );
156    }
157    let methods = if config.cfg_methods().is_empty() {
158        allowed_methods
159    } else {
160        config.cfg_methods()
161    };
162    let methods_joined = methods
163        .iter()
164        .map(Method::as_str)
165        .collect::<Vec<_>>()
166        .join(", ");
167    if let Ok(v) = HeaderValue::from_str(&methods_joined) {
168        h.insert(header::ACCESS_CONTROL_ALLOW_METHODS, v);
169    }
170    let allow_headers = if config.cfg_headers().is_empty() {
171        request_headers.map(str::to_string)
172    } else {
173        Some(config.cfg_headers().join(", "))
174    };
175    if let Some(hdrs) = allow_headers
176        && let Ok(v) = HeaderValue::from_str(&hdrs)
177    {
178        h.insert(header::ACCESS_CONTROL_ALLOW_HEADERS, v);
179    }
180    if let Some(age) = config.cfg_max_age()
181        && let Ok(v) = HeaderValue::from_str(&age.as_secs().to_string())
182    {
183        h.insert(header::ACCESS_CONTROL_MAX_AGE, v);
184    }
185    r
186}
187
188/// Decorate an actual (non-preflight) response with CORS headers for an allowed
189/// origin. Insert-if-absent so handler-set values win; APPEND `Vary: Origin`
190/// (don't clobber a content-negotiation Vary). No-op for a same-origin request
191/// (no Origin) or a disallowed origin. `origin` is the request's Origin header
192/// value (already extracted); the matching/echo is fallible and skips on bad bytes.
193pub(crate) fn apply_cors(res: &mut Response, origin: Option<&HeaderValue>, config: &CorsConfig) {
194    let Some(origin) = origin.and_then(|v| v.to_str().ok()) else {
195        return;
196    };
197    if !config.allows_origin(origin) {
198        return;
199    }
200    let Ok(origin_val) = HeaderValue::from_str(origin) else {
201        return;
202    };
203    let h = res.headers_mut();
204    if !h.contains_key(header::ACCESS_CONTROL_ALLOW_ORIGIN) {
205        h.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin_val);
206    }
207    if config.credentials() && !h.contains_key(header::ACCESS_CONTROL_ALLOW_CREDENTIALS) {
208        h.insert(
209            header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
210            HeaderValue::from_static("true"),
211        );
212    }
213    if !config.cfg_expose().is_empty()
214        && !h.contains_key(header::ACCESS_CONTROL_EXPOSE_HEADERS)
215        && let Ok(v) = HeaderValue::from_str(&config.cfg_expose().join(", "))
216    {
217        h.insert(header::ACCESS_CONTROL_EXPOSE_HEADERS, v);
218    }
219    // Vary: Origin — append unless already present (caches must not serve a
220    // wrong-origin response). Check existing Vary values case-insensitively.
221    let has_origin_vary = h.get_all(header::VARY).iter().any(|v| {
222        v.to_str()
223            .map(|s| {
224                s.split(',')
225                    .any(|p| p.trim().eq_ignore_ascii_case("origin"))
226            })
227            .unwrap_or(false)
228    });
229    if !has_origin_vary {
230        h.append(header::VARY, HeaderValue::from_static("Origin"));
231    }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// An allowlist config matches only its listed origin, and the
    /// credentials flag set through the builder is readable afterwards.
    #[test]
    fn config_builder_shapes_origins_and_credentials() {
        let allowlist = CorsOrigins::list(["https://app.example"]);
        let config = CorsConfig::new(allowlist)
            .allow_credentials(true)
            .max_age(std::time::Duration::from_secs(600));

        assert!(config.allows_origin("https://app.example"));
        assert!(!config.allows_origin("https://evil.example"));
        assert!(config.allow_credentials_enabled());
    }

    /// The wildcard policy accepts an arbitrary origin.
    #[test]
    fn any_origin_allows_everything() {
        let config = CorsConfig::new(CorsOrigins::any());

        assert!(config.allows_origin("https://whatever.example"));
    }
}