1#[cfg(all(not(feature = "std"), feature = "alloc"))]
32use alloc::{string::String, vec::Vec};
33use core::fmt;
34#[cfg(feature = "serde")]
35use serde::{Deserialize, Serialize};
36
/// The value carried by an `Access-Control-Allow-Origin` header.
///
/// [`CorsOrigin::Any`] renders as the wildcard `*`, while
/// [`CorsOrigin::Origin`] renders as the wrapped string unchanged.
#[derive(Clone, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub enum CorsOrigin {
    /// Any origin is allowed (rendered as `*`).
    Any,
    /// Exactly one origin is allowed, e.g. `https://example.com`.
    Origin(String),
}
55
56impl fmt::Display for CorsOrigin {
57 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
58 match self {
59 Self::Any => f.write_str("*"),
60 Self::Origin(url) => f.write_str(url),
61 }
62 }
63}
64
65#[derive(Debug, Clone, PartialEq, Eq, Default)]
74#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
75#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
76#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
77#[non_exhaustive]
78pub struct CorsHeaders {
79 pub allow_origin: Option<CorsOrigin>,
81 pub allow_methods: Option<Vec<String>>,
83 pub allow_headers: Option<Vec<String>>,
85 pub expose_headers: Option<Vec<String>>,
87 pub max_age: Option<u64>,
89 pub allow_credentials: Option<bool>,
91}
92
93impl CorsHeaders {
94 #[must_use]
96 pub fn new() -> Self {
97 Self::default()
98 }
99
100 #[must_use]
113 pub fn allow_origin(mut self, origin: CorsOrigin) -> Self {
114 self.allow_origin = Some(origin);
115 self
116 }
117
118 #[must_use]
128 pub fn allow_methods<I>(mut self, methods: I) -> Self
129 where
130 I: IntoIterator,
131 I::Item: Into<String>,
132 {
133 self.allow_methods = Some(methods.into_iter().map(Into::into).collect());
134 self
135 }
136
137 #[must_use]
139 pub fn allow_headers<I>(mut self, headers: I) -> Self
140 where
141 I: IntoIterator,
142 I::Item: Into<String>,
143 {
144 self.allow_headers = Some(headers.into_iter().map(Into::into).collect());
145 self
146 }
147
148 #[must_use]
150 pub fn expose_headers<I>(mut self, headers: I) -> Self
151 where
152 I: IntoIterator,
153 I::Item: Into<String>,
154 {
155 self.expose_headers = Some(headers.into_iter().map(Into::into).collect());
156 self
157 }
158
159 #[must_use]
161 pub fn max_age(mut self, seconds: u64) -> Self {
162 self.max_age = Some(seconds);
163 self
164 }
165
166 #[must_use]
172 pub fn allow_credentials(mut self, allow: bool) -> Self {
173 self.allow_credentials = Some(allow);
174 self
175 }
176
177 #[must_use]
198 pub fn preflight<M, H>(origin: CorsOrigin, methods: M, headers: H) -> Self
199 where
200 M: IntoIterator,
201 M::Item: Into<String>,
202 H: IntoIterator,
203 H::Item: Into<String>,
204 {
205 Self::new()
206 .allow_origin(origin)
207 .allow_methods(methods)
208 .allow_headers(headers)
209 }
210
211 #[must_use]
219 pub fn allow_methods_header(&self) -> Option<String> {
220 self.allow_methods.as_ref().map(|m| m.join(", "))
221 }
222
223 #[must_use]
227 pub fn allow_headers_header(&self) -> Option<String> {
228 self.allow_headers.as_ref().map(|h| h.join(", "))
229 }
230
231 #[must_use]
235 pub fn expose_headers_header(&self) -> Option<String> {
236 self.expose_headers.as_ref().map(|h| h.join(", "))
237 }
238}
239
/// `axum` integration: lets a [`CorsHeaders`] value be returned as part of a
/// handler response, e.g. `(cors, StatusCode::NO_CONTENT)`.
#[cfg(feature = "axum")]
mod axum_support {
    use super::CorsHeaders;
    use axum::http::header::{
        ACCESS_CONTROL_ALLOW_CREDENTIALS, ACCESS_CONTROL_ALLOW_HEADERS,
        ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN,
        ACCESS_CONTROL_EXPOSE_HEADERS, ACCESS_CONTROL_MAX_AGE,
    };
    use axum::http::{HeaderMap, HeaderName, HeaderValue};
    use axum::response::{IntoResponseParts, ResponseParts};

    /// Inserts `value` under `name`, replacing any existing entry.
    ///
    /// A value that is not a legal HTTP header value (for example, one that
    /// contains a control character) is skipped rather than reported.
    /// `into_response_parts` is infallible, so this is a deliberate
    /// best-effort choice.
    fn insert_if_valid(headers: &mut HeaderMap, name: HeaderName, value: Option<String>) {
        if let Some(v) = value.and_then(|s| HeaderValue::from_str(&s).ok()) {
            headers.insert(name, v);
        }
    }

    impl IntoResponseParts for CorsHeaders {
        type Error = std::convert::Infallible;

        /// Writes every configured CORS header into the response.
        ///
        /// Unset fields produce no header.
        fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
            // Render the list-valued headers through the same accessors the
            // public API exposes, so the wire format has a single definition.
            let origin = self.allow_origin.as_ref().map(ToString::to_string);
            let methods = self.allow_methods_header();
            let allowed_headers = self.allow_headers_header();
            let exposed_headers = self.expose_headers_header();

            let headers = res.headers_mut();
            insert_if_valid(headers, ACCESS_CONTROL_ALLOW_ORIGIN, origin);
            insert_if_valid(headers, ACCESS_CONTROL_ALLOW_METHODS, methods);
            insert_if_valid(headers, ACCESS_CONTROL_ALLOW_HEADERS, allowed_headers);
            insert_if_valid(headers, ACCESS_CONTROL_EXPOSE_HEADERS, exposed_headers);

            if let Some(max_age) = self.max_age {
                // An integer is always a valid header value; no fallible parse needed.
                headers.insert(ACCESS_CONTROL_MAX_AGE, HeaderValue::from(max_age));
            }

            // The CORS protocol only recognises the exact value `true`. Any
            // other value, including `false`, means "no credentials", so in
            // that case the header is omitted instead of sending `false`.
            if self.allow_credentials == Some(true) {
                headers.insert(
                    ACCESS_CONTROL_ALLOW_CREDENTIALS,
                    HeaderValue::from_static("true"),
                );
            }

            Ok(res)
        }
    }
}
291
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_all_none() {
        let empty = CorsHeaders::new();
        assert_eq!(empty.allow_origin, None);
        assert_eq!(empty.allow_methods, None);
        assert_eq!(empty.allow_headers, None);
        assert_eq!(empty.expose_headers, None);
        assert_eq!(empty.max_age, None);
        assert_eq!(empty.allow_credentials, None);
    }

    #[test]
    fn builder_allow_origin_any() {
        let origin = CorsHeaders::new()
            .allow_origin(CorsOrigin::Any)
            .allow_origin
            .expect("origin was just set");
        assert_eq!(origin.to_string(), "*");
    }

    #[test]
    fn builder_allow_origin_specific() {
        let exact = CorsOrigin::Origin("https://example.com".into());
        let origin = CorsHeaders::new()
            .allow_origin(exact)
            .allow_origin
            .expect("origin was just set");
        assert_eq!(origin.to_string(), "https://example.com");
    }

    #[test]
    fn builder_allow_methods() {
        let methods = CorsHeaders::new()
            .allow_methods(["GET", "POST", "DELETE"])
            .allow_methods
            .expect("methods were just set");
        assert_eq!(methods.len(), 3);
        for wanted in ["GET", "POST"] {
            assert!(methods.iter().any(|m| m == wanted));
        }
    }

    #[test]
    fn builder_allow_headers() {
        let listed = CorsHeaders::new()
            .allow_headers(["Content-Type", "Authorization"])
            .allow_headers
            .expect("headers were just set");
        assert!(listed.iter().any(|h| h == "Content-Type"));
    }

    #[test]
    fn builder_expose_headers() {
        let cors = CorsHeaders::new().expose_headers(["X-Request-Id"]);
        assert_eq!(cors.expose_headers_header().as_deref(), Some("X-Request-Id"));
    }

    #[test]
    fn builder_max_age() {
        assert_eq!(CorsHeaders::new().max_age(3600).max_age, Some(3600));
    }

    #[test]
    fn builder_allow_credentials() {
        let flag = CorsHeaders::new().allow_credentials(true).allow_credentials;
        assert_eq!(flag, Some(true));
    }

    #[test]
    fn header_value_accessors() {
        let cors = CorsHeaders::new()
            .allow_methods(["GET", "POST"])
            .allow_headers(["Content-Type"]);
        assert_eq!(cors.allow_methods_header().as_deref(), Some("GET, POST"));
        assert_eq!(cors.allow_headers_header().as_deref(), Some("Content-Type"));
        assert_eq!(cors.expose_headers_header(), None);
    }

    #[test]
    fn preflight_constructor() {
        let p = CorsHeaders::preflight(
            CorsOrigin::Origin("https://app.example.com".into()),
            ["GET", "POST"],
            ["Content-Type", "Authorization"],
        );
        assert!(p.allow_origin.is_some());
        assert_eq!(p.allow_methods.as_ref().map(|m| m.len()), Some(2));
        assert_eq!(p.allow_headers.as_ref().map(|h| h.len()), Some(2));
        assert_eq!(p.allow_credentials, None);
    }

    #[test]
    fn cors_origin_display() {
        let cases = [
            (CorsOrigin::Any, "*"),
            (CorsOrigin::Origin("https://x.com".into()), "https://x.com"),
        ];
        for (origin, expected) in cases {
            assert_eq!(origin.to_string(), expected);
        }
    }

    #[cfg(feature = "axum")]
    #[test]
    fn into_response_parts_sets_headers() {
        use axum::response::IntoResponse;

        let cors = CorsHeaders::new()
            .allow_origin(CorsOrigin::Any)
            .allow_methods(["GET"])
            .max_age(600);
        let response = (cors, axum::http::StatusCode::NO_CONTENT).into_response();

        // Look up a header and return its text, or `None` if absent/non-UTF-8.
        let header_text = |name: &str| {
            response
                .headers()
                .get(name)
                .and_then(|v| v.to_str().ok())
                .map(str::to_owned)
        };

        assert_eq!(
            header_text("access-control-allow-origin").as_deref(),
            Some("*")
        );
        assert_eq!(
            header_text("access-control-max-age").as_deref(),
            Some("600")
        );
    }
}