1use actus_controller::Verb;
27use http::{HeaderMap, HeaderName, HeaderValue, Method, header};
28use std::time::Duration;
29
/// Rule deciding which request origins are granted CORS access.
#[derive(Debug, Clone)]
enum OriginRule {
    /// Every origin is accepted.
    Any,
    /// Only origins whose text exactly equals one of these entries are accepted.
    List(Vec<String>),
}
38
39#[derive(Clone, Debug)]
40enum HeaderRule {
41 MirrorRequest,
44 List(Vec<HeaderName>),
46}
47
48#[derive(Clone, Debug)]
50pub struct CorsLayer {
51 origins: OriginRule,
52 methods: Vec<Verb>,
53 headers: HeaderRule,
54 expose: Vec<HeaderName>,
55 credentials: bool,
56 max_age: Option<Duration>,
57}
58
59impl Default for CorsLayer {
60 fn default() -> Self {
61 Self::new()
62 }
63}
64
65impl CorsLayer {
66 pub fn new() -> Self {
71 Self {
72 origins: OriginRule::List(Vec::new()),
73 methods: vec![Verb::GET, Verb::POST],
74 headers: HeaderRule::List(Vec::new()),
75 expose: Vec::new(),
76 credentials: false,
77 max_age: None,
78 }
79 }
80
81 pub fn permissive() -> Self {
85 Self {
86 origins: OriginRule::Any,
87 methods: vec![Verb::GET, Verb::POST, Verb::PUT, Verb::DELETE, Verb::PATCH],
88 headers: HeaderRule::MirrorRequest,
89 expose: Vec::new(),
90 credentials: false,
91 max_age: Some(Duration::from_secs(86_400)),
92 }
93 }
94
95 pub fn allow_any_origin(mut self) -> Self {
98 self.origins = OriginRule::Any;
99 self
100 }
101
102 pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
106 match &mut self.origins {
107 OriginRule::List(list) => list.push(origin.into()),
108 OriginRule::Any => self.origins = OriginRule::List(vec![origin.into()]),
109 }
110 self
111 }
112
113 pub fn allow_methods(mut self, methods: impl IntoIterator<Item = Verb>) -> Self {
116 self.methods = methods.into_iter().collect();
117 self
118 }
119
120 pub fn allow_any_header(mut self) -> Self {
122 self.headers = HeaderRule::MirrorRequest;
123 self
124 }
125
126 pub fn allow_headers(mut self, headers: impl IntoIterator<Item = HeaderName>) -> Self {
128 self.headers = HeaderRule::List(headers.into_iter().collect());
129 self
130 }
131
132 pub fn expose_headers(mut self, headers: impl IntoIterator<Item = HeaderName>) -> Self {
135 self.expose = headers.into_iter().collect();
136 self
137 }
138
139 pub fn allow_credentials(mut self, yes: bool) -> Self {
144 self.credentials = yes;
145 self
146 }
147
148 pub fn max_age(mut self, age: Duration) -> Self {
150 self.max_age = Some(age);
151 self
152 }
153
154 fn allow_origin_value(&self, origin: &str) -> Option<HeaderValue> {
159 let allowed = match &self.origins {
160 OriginRule::Any => true,
161 OriginRule::List(list) => list.iter().any(|o| o == origin),
162 };
163 if allowed {
164 HeaderValue::from_str(origin).ok()
165 } else {
166 None
167 }
168 }
169
170 fn allow_methods_value(&self) -> HeaderValue {
171 let joined = self
172 .methods
173 .iter()
174 .map(Verb::as_str)
175 .collect::<Vec<_>>()
176 .join(", ");
177 HeaderValue::from_str(&joined).unwrap_or_else(|_| HeaderValue::from_static("GET, POST"))
178 }
179
180 fn allow_headers_value(&self, requested: Option<&HeaderValue>) -> Option<HeaderValue> {
181 match &self.headers {
182 HeaderRule::MirrorRequest => requested.cloned(),
183 HeaderRule::List(list) if list.is_empty() => None,
184 HeaderRule::List(list) => {
185 let joined = list
186 .iter()
187 .map(HeaderName::as_str)
188 .collect::<Vec<_>>()
189 .join(", ");
190 HeaderValue::from_str(&joined).ok()
191 }
192 }
193 }
194
195 fn expose_headers_value(&self) -> Option<HeaderValue> {
196 if self.expose.is_empty() {
197 return None;
198 }
199 let joined = self
200 .expose
201 .iter()
202 .map(HeaderName::as_str)
203 .collect::<Vec<_>>()
204 .join(", ");
205 HeaderValue::from_str(&joined).ok()
206 }
207
208 pub(crate) fn is_preflight(method: &Method, headers: &HeaderMap) -> bool {
211 *method == Method::OPTIONS
212 && headers.contains_key(header::ORIGIN)
213 && headers.contains_key(header::ACCESS_CONTROL_REQUEST_METHOD)
214 }
215
216 fn preflight_headers(&self, request_headers: &HeaderMap) -> Vec<(HeaderName, HeaderValue)> {
217 let mut out = Vec::new();
218 let Some(origin) = request_headers
219 .get(header::ORIGIN)
220 .and_then(|v| v.to_str().ok())
221 else {
222 return out;
223 };
224 let Some(allow_origin) = self.allow_origin_value(origin) else {
225 return out;
226 };
227 out.push((header::ACCESS_CONTROL_ALLOW_ORIGIN, allow_origin));
228 out.push((header::VARY, HeaderValue::from_static("Origin")));
229 out.push((
230 header::ACCESS_CONTROL_ALLOW_METHODS,
231 self.allow_methods_value(),
232 ));
233 if let Some(h) =
234 self.allow_headers_value(request_headers.get(header::ACCESS_CONTROL_REQUEST_HEADERS))
235 {
236 out.push((header::ACCESS_CONTROL_ALLOW_HEADERS, h));
237 }
238 if self.credentials {
239 out.push((
240 header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
241 HeaderValue::from_static("true"),
242 ));
243 }
244 if let Some(age) = self.max_age
245 && let Ok(v) = HeaderValue::from_str(&age.as_secs().to_string())
246 {
247 out.push((header::ACCESS_CONTROL_MAX_AGE, v));
248 }
249 out
250 }
251
252 fn response_headers(&self, request_headers: &HeaderMap) -> Vec<(HeaderName, HeaderValue)> {
253 let mut out = Vec::new();
254 let Some(origin) = request_headers
255 .get(header::ORIGIN)
256 .and_then(|v| v.to_str().ok())
257 else {
258 return out;
259 };
260 let Some(allow_origin) = self.allow_origin_value(origin) else {
261 return out;
262 };
263 out.push((header::ACCESS_CONTROL_ALLOW_ORIGIN, allow_origin));
264 out.push((header::VARY, HeaderValue::from_static("Origin")));
265 if self.credentials {
266 out.push((
267 header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
268 HeaderValue::from_static("true"),
269 ));
270 }
271 if let Some(v) = self.expose_headers_value() {
272 out.push((header::ACCESS_CONTROL_EXPOSE_HEADERS, v));
273 }
274 out
275 }
276
277 pub(crate) fn apply(&self, request_headers: &HeaderMap, into: &mut HeaderMap, preflight: bool) {
285 let pairs = if preflight {
286 self.preflight_headers(request_headers)
287 } else {
288 self.response_headers(request_headers)
289 };
290 for (name, value) in pairs {
291 if name == header::VARY {
292 into.append(name, value);
293 } else {
294 into.insert(name, value);
295 }
296 }
297 }
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request header map; a repeated name overwrites the earlier value.
    fn request(pairs: &[(HeaderName, &str)]) -> HeaderMap {
        let mut map = HeaderMap::new();
        pairs.iter().for_each(|(name, value)| {
            map.insert(name.clone(), HeaderValue::from_str(value).unwrap());
        });
        map
    }

    /// Collects, in emission order, every value produced under `wanted` as a string.
    fn values(pairs: &[(HeaderName, HeaderValue)], wanted: &HeaderName) -> Vec<String> {
        pairs
            .iter()
            .filter_map(|(n, v)| (n == wanted).then(|| v.to_str().unwrap().to_owned()))
            .collect()
    }

    #[test]
    fn no_origin_header_is_a_noop() {
        let cors = CorsLayer::permissive();
        let empty = HeaderMap::new();
        assert!(cors.response_headers(&empty).is_empty());
        assert!(cors.preflight_headers(&empty).is_empty());
    }

    #[test]
    fn permissive_echoes_any_origin_with_vary() {
        let req = request(&[(header::ORIGIN, "https://x.example")]);
        let out = CorsLayer::permissive().response_headers(&req);
        assert_eq!(
            values(&out, &header::ACCESS_CONTROL_ALLOW_ORIGIN),
            ["https://x.example"]
        );
        assert_eq!(values(&out, &header::VARY), ["Origin"]);
        assert!(values(&out, &header::ACCESS_CONTROL_ALLOW_CREDENTIALS).is_empty());
    }

    #[test]
    fn allow_list_rejects_unlisted_origin() {
        let cors = CorsLayer::new().allow_origin("https://app.example");

        let rejected = cors.response_headers(&request(&[(header::ORIGIN, "https://evil.example")]));
        assert!(rejected.is_empty());

        let accepted = cors.response_headers(&request(&[(header::ORIGIN, "https://app.example")]));
        assert_eq!(
            values(&accepted, &header::ACCESS_CONTROL_ALLOW_ORIGIN),
            ["https://app.example"]
        );
    }

    #[test]
    fn preflight_advertises_methods_mirrored_headers_and_max_age() {
        let req = request(&[
            (header::ORIGIN, "https://x.example"),
            (header::ACCESS_CONTROL_REQUEST_METHOD, "POST"),
            (
                header::ACCESS_CONTROL_REQUEST_HEADERS,
                "content-type, authorization",
            ),
        ]);
        let out = CorsLayer::permissive().preflight_headers(&req);

        let methods = values(&out, &header::ACCESS_CONTROL_ALLOW_METHODS);
        let first = &methods[0];
        assert!(first.contains("POST") && first.contains("DELETE"));

        assert_eq!(
            values(&out, &header::ACCESS_CONTROL_ALLOW_HEADERS),
            ["content-type, authorization"]
        );
        assert_eq!(values(&out, &header::ACCESS_CONTROL_MAX_AGE), ["86400"]);
    }

    #[test]
    fn credentials_never_sends_star() {
        let req = request(&[(header::ORIGIN, "https://x.example")]);
        let out = CorsLayer::permissive()
            .allow_credentials(true)
            .response_headers(&req);
        assert_eq!(
            values(&out, &header::ACCESS_CONTROL_ALLOW_ORIGIN),
            ["https://x.example"]
        );
        assert_eq!(
            values(&out, &header::ACCESS_CONTROL_ALLOW_CREDENTIALS),
            ["true"]
        );
    }

    #[test]
    fn apply_appends_vary_but_replaces_acao() {
        let mut response = HeaderMap::new();
        response.insert(header::VARY, HeaderValue::from_static("Accept-Encoding"));

        let req = request(&[(header::ORIGIN, "https://x.example")]);
        CorsLayer::permissive().apply(&req, &mut response, false);

        let vary = response
            .get_all(header::VARY)
            .iter()
            .map(|v| v.to_str().unwrap().to_owned())
            .collect::<Vec<_>>();
        assert_eq!(vary, ["Accept-Encoding", "Origin"]);
        assert_eq!(
            response.get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap(),
            "https://x.example"
        );
    }
}