ferro_rs/middleware/
cors.rs1use crate::http::{HttpResponse, Request, Response};
29use crate::middleware::{Middleware, Next};
30use async_trait::async_trait;
31
/// CORS middleware configuration.
///
/// Build with [`Cors::new`] (no origins allowed) or [`Cors::permissive`] (any origin),
/// then refine it with the `allow_*` and `max_age` builder methods.
#[derive(Debug, Clone)]
pub struct Cors {
    /// Which request origins receive CORS headers.
    origins: Origins,
    /// Sent as `Access-Control-Allow-Methods`, joined with `", "`.
    methods: Vec<String>,
    /// Sent as `Access-Control-Allow-Headers`, joined with `", "`.
    headers: Vec<String>,
    /// Sent as `Access-Control-Max-Age`, in seconds.
    max_age: u32,
}

/// Origin policy used by [`Cors`].
#[derive(Debug, Clone, PartialEq, Eq)]
enum Origins {
    /// Every request is answered with `Access-Control-Allow-Origin: *`.
    Any,
    /// Only these exact origin strings are allowed; a matching origin is echoed back.
    List(Vec<String>),
}
48
49impl Cors {
50 pub fn permissive() -> Self {
54 Self {
55 origins: Origins::Any,
56 methods: vec!["GET".into(), "POST".into(), "OPTIONS".into()],
57 headers: vec!["Content-Type".into(), "Accept".into()],
58 max_age: 86400,
59 }
60 }
61
62 pub fn new() -> Self {
66 Self {
67 origins: Origins::List(Vec::new()),
68 methods: vec!["GET".into(), "POST".into(), "OPTIONS".into()],
69 headers: vec!["Content-Type".into(), "Accept".into()],
70 max_age: 86400,
71 }
72 }
73
74 pub fn allow_origins<I, S>(mut self, origins: I) -> Self
82 where
83 I: IntoIterator<Item = S>,
84 S: Into<String>,
85 {
86 self.origins = Origins::List(origins.into_iter().map(Into::into).collect());
87 self
88 }
89
90 pub fn allow_methods<I, S>(mut self, methods: I) -> Self
92 where
93 I: IntoIterator<Item = S>,
94 S: Into<String>,
95 {
96 self.methods = methods.into_iter().map(Into::into).collect();
97 self
98 }
99
100 pub fn allow_headers<I, S>(mut self, headers: I) -> Self
102 where
103 I: IntoIterator<Item = S>,
104 S: Into<String>,
105 {
106 self.headers = headers.into_iter().map(Into::into).collect();
107 self
108 }
109
110 pub fn max_age(mut self, seconds: u32) -> Self {
112 self.max_age = seconds;
113 self
114 }
115
116 fn allowed_origin(&self, request_origin: Option<&str>) -> Option<String> {
118 match &self.origins {
119 Origins::Any => Some("*".into()),
120 Origins::List(list) => {
121 let origin = request_origin?;
122 if list.iter().any(|o| o == origin) {
123 Some(origin.to_string())
124 } else {
125 None
126 }
127 }
128 }
129 }
130
131 fn apply(&self, response: HttpResponse, origin: &str) -> HttpResponse {
133 response
134 .header("Access-Control-Allow-Origin", origin)
135 .header("Access-Control-Allow-Methods", self.methods.join(", "))
136 .header("Access-Control-Allow-Headers", self.headers.join(", "))
137 .header("Access-Control-Max-Age", self.max_age.to_string())
138 }
139}
140
141impl Default for Cors {
142 fn default() -> Self {
143 Self::new()
144 }
145}
146
147#[async_trait]
148impl Middleware for Cors {
149 async fn handle(&self, request: Request, next: Next) -> Response {
150 let request_origin = request.header("Origin").map(|s| s.to_string());
151 let origin = self.allowed_origin(request_origin.as_deref());
152
153 if request.method() == "OPTIONS" {
155 let response = HttpResponse::new().status(204);
156 return match origin {
157 Some(ref o) => Ok(self.apply(response, o)),
158 None => Ok(response),
159 };
160 }
161
162 let response = next(request).await;
163
164 match origin {
165 Some(ref o) => match response {
166 Ok(r) => Ok(self.apply(r, o)),
167 Err(r) => Err(self.apply(r, o)),
168 },
169 None => response,
170 }
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::*;

    /// `permissive()` answers with `*`, whether or not an Origin header is present.
    #[test]
    fn test_permissive_allows_any_origin() {
        let permissive = Cors::permissive();
        assert!(matches!(permissive.origins, Origins::Any));

        let wildcard = Some(String::from("*"));
        assert_eq!(permissive.allowed_origin(Some("https://example.com")), wildcard);
        assert_eq!(permissive.allowed_origin(None), wildcard);
    }

    /// Listed origins are echoed back; unlisted or missing origins get nothing.
    #[test]
    fn test_allow_origins_list() {
        let cors = Cors::new().allow_origins(vec!["https://a.com", "https://b.com"]);

        for &listed in &["https://a.com", "https://b.com"] {
            assert_eq!(cors.allowed_origin(Some(listed)), Some(listed.to_string()));
        }

        assert_eq!(cors.allowed_origin(Some("https://c.com")), None);
        assert_eq!(cors.allowed_origin(None), None);
    }

    /// Each builder method overwrites its corresponding field.
    #[test]
    fn test_builder_methods() {
        let configured = Cors::new()
            .allow_origins(vec!["https://x.com"])
            .allow_methods(vec!["GET", "POST", "PUT"])
            .allow_headers(vec!["Authorization", "Content-Type"])
            .max_age(3600);

        let expected_methods = vec!["GET", "POST", "PUT"];
        let expected_headers = vec!["Authorization", "Content-Type"];
        assert_eq!(configured.methods, expected_methods);
        assert_eq!(configured.headers, expected_headers);
        assert_eq!(configured.max_age, 3600);
    }

    /// `Default` starts with an empty allowlist, so no origin is accepted.
    #[test]
    fn test_default_is_restrictive() {
        let restrictive = Cors::default();
        let verdict = restrictive.allowed_origin(Some("https://any.com"));
        assert!(verdict.is_none());
    }
}