Skip to main content

api_bones/
cors.rs

1//! Typed helpers for CORS response headers (`Access-Control-*`).
2//!
3//! [`CorsHeaders`] models the complete set of CORS response headers defined in
4//! the Fetch specification. A fluent builder API makes it easy to construct
5//! both simple and preflight responses.
6//!
7//! # Example
8//!
9//! ```rust
10//! use api_bones::cors::{CorsHeaders, CorsOrigin};
11//!
12//! // Simple CORS response.
13//! let cors = CorsHeaders::new()
14//!     .allow_origin(CorsOrigin::Any)
15//!     .allow_methods(["GET", "POST"])
16//!     .allow_headers(["Content-Type", "Authorization"])
17//!     .max_age(86_400);
18//!
19//! assert_eq!(cors.allow_origin.as_ref().unwrap().to_string(), "*");
20//! assert_eq!(cors.max_age, Some(86_400));
21//!
22//! // Preflight response helper.
23//! let preflight = CorsHeaders::preflight(
24//!     CorsOrigin::Origin("https://example.com".into()),
25//!     ["GET", "POST", "DELETE"],
26//!     ["Content-Type"],
27//! );
28//! assert!(preflight.allow_credentials.is_none() || preflight.allow_credentials == Some(false));
29//! ```
30
31#[cfg(all(not(feature = "std"), feature = "alloc"))]
32use alloc::{string::String, vec::Vec};
33use core::fmt;
34#[cfg(feature = "serde")]
35use serde::{Deserialize, Serialize};
36
37// ---------------------------------------------------------------------------
38// CorsOrigin
39// ---------------------------------------------------------------------------
40
/// The value of the `Access-Control-Allow-Origin` header.
///
/// - [`CorsOrigin::Any`] — the wildcard `*`. Note that `Allow-Credentials: true`
///   is incompatible with the wildcard.
/// - [`CorsOrigin::Origin`] — a single, specific origin such as
///   `https://example.com`, written to the header verbatim.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub enum CorsOrigin {
    /// `Access-Control-Allow-Origin: *`
    Any,
    /// `Access-Control-Allow-Origin: <origin>` — the string is emitted as-is,
    /// without validation or normalisation.
    Origin(String),
}

impl fmt::Display for CorsOrigin {
    /// Render the header value: `*` for [`CorsOrigin::Any`], otherwise the
    /// stored origin string unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Any => f.write_str("*"),
            Self::Origin(url) => f.write_str(url),
        }
    }
}
64
65// ---------------------------------------------------------------------------
66// CorsHeaders
67// ---------------------------------------------------------------------------
68
/// Structured CORS response headers.
///
/// Every field is an `Option`; `None` means the corresponding header is left
/// out of the response. Start from [`CorsHeaders::new`] and chain the builder
/// methods, or use [`CorsHeaders::preflight`] for a typical `OPTIONS` response.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot build
/// it with a struct literal and must go through the constructors above.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[non_exhaustive]
pub struct CorsHeaders {
    /// `Access-Control-Allow-Origin`: the wildcard `*` or one specific origin.
    pub allow_origin: Option<CorsOrigin>,
    /// `Access-Control-Allow-Methods`: method names, rendered as a
    /// comma-separated list (`", "`) when serialised to a header value.
    pub allow_methods: Option<Vec<String>>,
    /// `Access-Control-Allow-Headers`: request header names, rendered as a
    /// comma-separated list (`", "`).
    pub allow_headers: Option<Vec<String>>,
    /// `Access-Control-Expose-Headers`: response header names, rendered as a
    /// comma-separated list (`", "`).
    pub expose_headers: Option<Vec<String>>,
    /// `Access-Control-Max-Age`, in seconds.
    pub max_age: Option<u64>,
    /// `Access-Control-Allow-Credentials`. Must not be combined with an
    /// `Access-Control-Allow-Origin` of `*`; this is not checked here.
    pub allow_credentials: Option<bool>,
}
92
93impl CorsHeaders {
94    /// Create a new, empty `CorsHeaders`.
95    #[must_use]
96    pub fn new() -> Self {
97        Self::default()
98    }
99
100    // -----------------------------------------------------------------------
101    // Builder methods
102    // -----------------------------------------------------------------------
103
104    /// Set `Access-Control-Allow-Origin`.
105    ///
106    /// ```
107    /// use api_bones::cors::{CorsHeaders, CorsOrigin};
108    ///
109    /// let cors = CorsHeaders::new().allow_origin(CorsOrigin::Any);
110    /// assert_eq!(cors.allow_origin.unwrap().to_string(), "*");
111    /// ```
112    #[must_use]
113    pub fn allow_origin(mut self, origin: CorsOrigin) -> Self {
114        self.allow_origin = Some(origin);
115        self
116    }
117
118    /// Set `Access-Control-Allow-Methods` from an iterator of method strings.
119    ///
120    /// ```
121    /// use api_bones::cors::CorsHeaders;
122    ///
123    /// let cors = CorsHeaders::new().allow_methods(["GET", "POST"]);
124    /// let methods = cors.allow_methods.unwrap();
125    /// assert!(methods.contains(&"GET".to_string()));
126    /// ```
127    #[must_use]
128    pub fn allow_methods<I>(mut self, methods: I) -> Self
129    where
130        I: IntoIterator,
131        I::Item: Into<String>,
132    {
133        self.allow_methods = Some(methods.into_iter().map(Into::into).collect());
134        self
135    }
136
137    /// Set `Access-Control-Allow-Headers` from an iterator of header names.
138    #[must_use]
139    pub fn allow_headers<I>(mut self, headers: I) -> Self
140    where
141        I: IntoIterator,
142        I::Item: Into<String>,
143    {
144        self.allow_headers = Some(headers.into_iter().map(Into::into).collect());
145        self
146    }
147
148    /// Set `Access-Control-Expose-Headers` from an iterator of header names.
149    #[must_use]
150    pub fn expose_headers<I>(mut self, headers: I) -> Self
151    where
152        I: IntoIterator,
153        I::Item: Into<String>,
154    {
155        self.expose_headers = Some(headers.into_iter().map(Into::into).collect());
156        self
157    }
158
159    /// Set `Access-Control-Max-Age` (seconds).
160    #[must_use]
161    pub fn max_age(mut self, seconds: u64) -> Self {
162        self.max_age = Some(seconds);
163        self
164    }
165
166    /// Set `Access-Control-Allow-Credentials`.
167    ///
168    /// Note: per the spec, `Allow-Credentials: true` is incompatible with
169    /// `Allow-Origin: *`. This is not enforced at the type level but callers
170    /// should be careful.
171    #[must_use]
172    pub fn allow_credentials(mut self, allow: bool) -> Self {
173        self.allow_credentials = Some(allow);
174        self
175    }
176
177    // -----------------------------------------------------------------------
178    // Convenience constructors
179    // -----------------------------------------------------------------------
180
181    /// Build a preflight (`OPTIONS`) response with sensible defaults.
182    ///
183    /// Sets `Allow-Origin`, `Allow-Methods`, and `Allow-Headers`. Does not set
184    /// `Allow-Credentials` (default: absent, treated as `false` by browsers).
185    ///
186    /// ```
187    /// use api_bones::cors::{CorsHeaders, CorsOrigin};
188    ///
189    /// let preflight = CorsHeaders::preflight(
190    ///     CorsOrigin::Origin("https://example.com".into()),
191    ///     ["GET", "POST"],
192    ///     ["Content-Type"],
193    /// );
194    /// assert!(preflight.allow_methods.is_some());
195    /// assert!(preflight.allow_headers.is_some());
196    /// ```
197    #[must_use]
198    pub fn preflight<M, H>(origin: CorsOrigin, methods: M, headers: H) -> Self
199    where
200        M: IntoIterator,
201        M::Item: Into<String>,
202        H: IntoIterator,
203        H::Item: Into<String>,
204    {
205        Self::new()
206            .allow_origin(origin)
207            .allow_methods(methods)
208            .allow_headers(headers)
209    }
210
211    // -----------------------------------------------------------------------
212    // Header value accessors
213    // -----------------------------------------------------------------------
214
215    /// Render the `Access-Control-Allow-Methods` value as a comma-separated string.
216    ///
217    /// Returns `None` if the field is not set.
218    #[must_use]
219    pub fn allow_methods_header(&self) -> Option<String> {
220        self.allow_methods.as_ref().map(|m| m.join(", "))
221    }
222
223    /// Render the `Access-Control-Allow-Headers` value as a comma-separated string.
224    ///
225    /// Returns `None` if the field is not set.
226    #[must_use]
227    pub fn allow_headers_header(&self) -> Option<String> {
228        self.allow_headers.as_ref().map(|h| h.join(", "))
229    }
230
231    /// Render the `Access-Control-Expose-Headers` value as a comma-separated string.
232    ///
233    /// Returns `None` if the field is not set.
234    #[must_use]
235    pub fn expose_headers_header(&self) -> Option<String> {
236        self.expose_headers.as_ref().map(|h| h.join(", "))
237    }
238}
239
240// ---------------------------------------------------------------------------
241// Axum integration
242// ---------------------------------------------------------------------------
243
244#[cfg(feature = "axum")]
245mod axum_support {
246    use super::CorsHeaders;
247    use axum::http::HeaderValue;
248    use axum::response::{IntoResponseParts, ResponseParts};
249
250    impl IntoResponseParts for CorsHeaders {
251        type Error = std::convert::Infallible;
252
253        fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
254            let headers = res.headers_mut();
255
256            if let Some(origin) = &self.allow_origin
257                && let Ok(v) = HeaderValue::from_str(&origin.to_string())
258            {
259                headers.insert("access-control-allow-origin", v);
260            }
261            if let Some(methods) = &self.allow_methods
262                && let Ok(v) = HeaderValue::from_str(&methods.join(", "))
263            {
264                headers.insert("access-control-allow-methods", v);
265            }
266            if let Some(hdrs) = &self.allow_headers
267                && let Ok(v) = HeaderValue::from_str(&hdrs.join(", "))
268            {
269                headers.insert("access-control-allow-headers", v);
270            }
271            if let Some(expose) = &self.expose_headers
272                && let Ok(v) = HeaderValue::from_str(&expose.join(", "))
273            {
274                headers.insert("access-control-expose-headers", v);
275            }
276            if let Some(max_age) = self.max_age
277                && let Ok(v) = HeaderValue::from_str(&max_age.to_string())
278            {
279                headers.insert("access-control-max-age", v);
280            }
281            if let Some(creds) = self.allow_credentials {
282                let val = if creds { "true" } else { "false" };
283                let v = HeaderValue::from_static(val);
284                headers.insert("access-control-allow-credentials", v);
285            }
286
287            Ok(res)
288        }
289    }
290}
291
292// ---------------------------------------------------------------------------
293// Tests
294// ---------------------------------------------------------------------------
295
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_all_none() {
        // Exhaustive destructuring: adding a field to `CorsHeaders` forces this
        // test to be revisited.
        let CorsHeaders {
            allow_origin,
            allow_methods,
            allow_headers,
            expose_headers,
            max_age,
            allow_credentials,
        } = CorsHeaders::new();
        assert!(allow_origin.is_none());
        assert!(allow_methods.is_none());
        assert!(allow_headers.is_none());
        assert!(expose_headers.is_none());
        assert!(max_age.is_none());
        assert!(allow_credentials.is_none());
    }

    #[test]
    fn builder_allow_origin_any() {
        let origin = CorsHeaders::new()
            .allow_origin(CorsOrigin::Any)
            .allow_origin
            .expect("allow_origin should be set");
        assert_eq!(origin.to_string(), "*");
    }

    #[test]
    fn builder_allow_origin_specific() {
        let origin = CorsHeaders::new()
            .allow_origin(CorsOrigin::Origin("https://example.com".into()))
            .allow_origin
            .expect("allow_origin should be set");
        assert_eq!(origin.to_string(), "https://example.com");
    }

    #[test]
    fn builder_allow_methods() {
        let methods = CorsHeaders::new()
            .allow_methods(["GET", "POST", "DELETE"])
            .allow_methods
            .expect("allow_methods should be set");
        for wanted in ["GET", "POST"] {
            assert!(methods.iter().any(|m| m == wanted), "missing {wanted}");
        }
        assert_eq!(methods.len(), 3);
    }

    #[test]
    fn builder_allow_headers() {
        let hdrs = CorsHeaders::new()
            .allow_headers(["Content-Type", "Authorization"])
            .allow_headers
            .expect("allow_headers should be set");
        assert!(hdrs.iter().any(|h| h == "Content-Type"));
    }

    #[test]
    fn builder_expose_headers() {
        let cors = CorsHeaders::new().expose_headers(["X-Request-Id"]);
        assert_eq!(cors.expose_headers_header().as_deref(), Some("X-Request-Id"));
    }

    #[test]
    fn builder_max_age() {
        assert_eq!(CorsHeaders::new().max_age(3600).max_age, Some(3600));
    }

    #[test]
    fn builder_allow_credentials() {
        assert_eq!(
            CorsHeaders::new().allow_credentials(true).allow_credentials,
            Some(true)
        );
    }

    #[test]
    fn header_value_accessors() {
        let cors = CorsHeaders::new()
            .allow_methods(["GET", "POST"])
            .allow_headers(["Content-Type"]);
        assert_eq!(cors.allow_methods_header().as_deref(), Some("GET, POST"));
        assert_eq!(cors.allow_headers_header().as_deref(), Some("Content-Type"));
        assert_eq!(cors.expose_headers_header(), None);
    }

    #[test]
    fn preflight_constructor() {
        let p = CorsHeaders::preflight(
            CorsOrigin::Origin("https://app.example.com".into()),
            ["GET", "POST"],
            ["Content-Type", "Authorization"],
        );
        assert!(p.allow_origin.is_some());
        assert_eq!(p.allow_methods.as_ref().map(Vec::len), Some(2));
        assert_eq!(p.allow_headers.as_ref().map(Vec::len), Some(2));
        assert_eq!(p.allow_credentials, None);
    }

    #[test]
    fn cors_origin_display() {
        let cases = [
            (CorsOrigin::Any, "*"),
            (CorsOrigin::Origin("https://x.com".into()), "https://x.com"),
        ];
        for (origin, expected) in cases {
            assert_eq!(origin.to_string(), expected);
        }
    }

    #[cfg(feature = "axum")]
    #[test]
    fn into_response_parts_sets_headers() {
        use axum::response::IntoResponse;

        /// Look up `name` and return its value as text, panicking if absent.
        fn value<'a>(headers: &'a axum::http::HeaderMap, name: &str) -> &'a str {
            headers
                .get(name)
                .expect("header should be present")
                .to_str()
                .expect("header value should be visible ASCII")
        }

        let cors = CorsHeaders::new()
            .allow_origin(CorsOrigin::Any)
            .allow_methods(["GET"])
            .max_age(600);
        let response = (cors, axum::http::StatusCode::NO_CONTENT).into_response();
        let headers = response.headers();

        assert_eq!(value(headers, "access-control-allow-origin"), "*");
        assert_eq!(value(headers, "access-control-max-age"), "600");
    }
}