Skip to main content

ferro_rs/middleware/
cors.rs

1//! CORS middleware for Ferro framework
2//!
3//! Adds Cross-Origin Resource Sharing headers to responses and handles
4//! OPTIONS preflight requests. Suitable for public API endpoints consumed
5//! by browser clients on different origins (e.g. custom-domain frontends).
6//!
7//! # Example
8//!
9//! ```rust,ignore
10//! use ferro::middleware::Cors;
11//!
12//! // Permissive: allow all origins
13//! group!("/api/v1").middleware(Cors::permissive()).routes(|r| {
14//!     r.get("/products", products::index);
15//! });
16//!
17//! // Restrictive: allow specific origins only
18//! group!("/api/v1").middleware(
19//!     Cors::new()
20//!         .allow_origins(vec!["https://example.com", "https://app.example.com"])
21//!         .allow_methods(vec!["GET", "POST"])
22//!         .allow_headers(vec!["Content-Type", "Authorization"]),
23//! ).routes(|r| {
24//!     r.get("/products", products::index);
25//! });
26//! ```
27
28use crate::http::{HttpResponse, Request, Response};
29use crate::middleware::{Middleware, Next};
30use async_trait::async_trait;
31
/// CORS middleware
///
/// Appends CORS response headers and short-circuits OPTIONS preflight requests
/// with a 204 No Content response. Use [`Cors::permissive()`] for open APIs or
/// [`Cors::new()`] + builder methods for origin-restricted endpoints.
#[derive(Debug, Clone)]
pub struct Cors {
    /// Which request origins receive an `Access-Control-Allow-Origin` header.
    origins: Origins,
    /// Values joined into the `Access-Control-Allow-Methods` header.
    methods: Vec<String>,
    /// Values joined into the `Access-Control-Allow-Headers` header.
    headers: Vec<String>,
    /// Value of the `Access-Control-Max-Age` header, in seconds.
    max_age: u32,
}

/// Origin policy used by [`Cors`].
#[derive(Debug, Clone, PartialEq, Eq)]
enum Origins {
    /// Every origin is allowed; responses carry the wildcard `*`.
    Any,
    /// Only the listed origins are allowed; a matching request origin is echoed back.
    List(Vec<String>),
}
48
49impl Cors {
50    /// Create a new `Cors` with permissive defaults
51    ///
52    /// Allows all origins (`*`), GET/POST/OPTIONS methods, and Content-Type/Accept headers.
53    pub fn permissive() -> Self {
54        Self {
55            origins: Origins::Any,
56            methods: vec!["GET".into(), "POST".into(), "OPTIONS".into()],
57            headers: vec!["Content-Type".into(), "Accept".into()],
58            max_age: 86400,
59        }
60    }
61
62    /// Create a new `Cors` builder with no allowed origins
63    ///
64    /// Call [`allow_origins`](Self::allow_origins) to configure allowed origins.
65    pub fn new() -> Self {
66        Self {
67            origins: Origins::List(Vec::new()),
68            methods: vec!["GET".into(), "POST".into(), "OPTIONS".into()],
69            headers: vec!["Content-Type".into(), "Accept".into()],
70            max_age: 86400,
71        }
72    }
73
74    /// Set the list of allowed origins
75    ///
76    /// # Example
77    ///
78    /// ```rust,ignore
79    /// Cors::new().allow_origins(vec!["https://example.com", "https://app.example.com"])
80    /// ```
81    pub fn allow_origins<I, S>(mut self, origins: I) -> Self
82    where
83        I: IntoIterator<Item = S>,
84        S: Into<String>,
85    {
86        self.origins = Origins::List(origins.into_iter().map(Into::into).collect());
87        self
88    }
89
90    /// Set the allowed HTTP methods
91    pub fn allow_methods<I, S>(mut self, methods: I) -> Self
92    where
93        I: IntoIterator<Item = S>,
94        S: Into<String>,
95    {
96        self.methods = methods.into_iter().map(Into::into).collect();
97        self
98    }
99
100    /// Set the allowed request headers
101    pub fn allow_headers<I, S>(mut self, headers: I) -> Self
102    where
103        I: IntoIterator<Item = S>,
104        S: Into<String>,
105    {
106        self.headers = headers.into_iter().map(Into::into).collect();
107        self
108    }
109
110    /// Set the preflight cache duration in seconds (default: 86400)
111    pub fn max_age(mut self, seconds: u32) -> Self {
112        self.max_age = seconds;
113        self
114    }
115
116    /// Determine the `Access-Control-Allow-Origin` value for a given request origin
117    fn allowed_origin(&self, request_origin: Option<&str>) -> Option<String> {
118        match &self.origins {
119            Origins::Any => Some("*".into()),
120            Origins::List(list) => {
121                let origin = request_origin?;
122                if list.iter().any(|o| o == origin) {
123                    Some(origin.to_string())
124                } else {
125                    None
126                }
127            }
128        }
129    }
130
131    /// Apply CORS headers to a response
132    fn apply(&self, response: HttpResponse, origin: &str) -> HttpResponse {
133        response
134            .header("Access-Control-Allow-Origin", origin)
135            .header("Access-Control-Allow-Methods", self.methods.join(", "))
136            .header("Access-Control-Allow-Headers", self.headers.join(", "))
137            .header("Access-Control-Max-Age", self.max_age.to_string())
138    }
139}
140
141impl Default for Cors {
142    fn default() -> Self {
143        Self::new()
144    }
145}
146
147#[async_trait]
148impl Middleware for Cors {
149    async fn handle(&self, request: Request, next: Next) -> Response {
150        let request_origin = request.header("Origin").map(|s| s.to_string());
151        let origin = self.allowed_origin(request_origin.as_deref());
152
153        // Preflight: respond immediately without reaching the handler
154        if request.method() == "OPTIONS" {
155            let response = HttpResponse::new().status(204);
156            return match origin {
157                Some(ref o) => Ok(self.apply(response, o)),
158                None => Ok(response),
159            };
160        }
161
162        let response = next(request).await;
163
164        match origin {
165            Some(ref o) => match response {
166                Ok(r) => Ok(self.apply(r, o)),
167                Err(r) => Err(self.apply(r, o)),
168            },
169            None => response,
170        }
171    }
172}
173
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the owned value `allowed_origin` hands back.
    fn owned(value: &str) -> Option<String> {
        Some(value.to_owned())
    }

    #[test]
    fn test_permissive_allows_any_origin() {
        let policy = Cors::permissive();
        assert!(matches!(policy.origins, Origins::Any));
        // The wildcard comes back whether or not the request carries an Origin.
        for requested in [Some("https://example.com"), None] {
            assert_eq!(policy.allowed_origin(requested), owned("*"));
        }
    }

    #[test]
    fn test_allow_origins_list() {
        let policy = Cors::new().allow_origins(vec!["https://a.com", "https://b.com"]);
        // Listed origins are echoed back verbatim.
        for listed in ["https://a.com", "https://b.com"] {
            assert_eq!(policy.allowed_origin(Some(listed)), owned(listed));
        }
        // Unlisted or missing origins are rejected.
        for rejected in [Some("https://c.com"), None] {
            assert_eq!(policy.allowed_origin(rejected), None);
        }
    }

    #[test]
    fn test_builder_methods() {
        let policy = Cors::new()
            .allow_origins(vec!["https://x.com"])
            .allow_methods(vec!["GET", "POST", "PUT"])
            .allow_headers(vec!["Authorization", "Content-Type"])
            .max_age(3600);

        assert_eq!(policy.methods, ["GET", "POST", "PUT"]);
        assert_eq!(policy.headers, ["Authorization", "Content-Type"]);
        assert_eq!(policy.max_age, 3600);
    }

    #[test]
    fn test_default_is_restrictive() {
        // With no origins configured, nothing should be allowed.
        let policy: Cors = Default::default();
        assert!(policy.allowed_origin(Some("https://any.com")).is_none());
    }
}