1use axum::http::{header, HeaderValue, Method};
2use tower_http::cors::{AllowOrigin, CorsLayer};
3
/// Returns tower-http's fully permissive CORS layer, unchanged.
///
/// NOTE(review): per tower-http's documentation, `CorsLayer::permissive()`
/// admits any origin, method and request header and exposes all response
/// headers. Confirm this is only used for local development or for public,
/// credential-free endpoints. For a restricted, credentialed configuration
/// use [`cors_allowing`] instead.
pub fn cors_permissive() -> CorsLayer {
    CorsLayer::permissive()
}
11
12pub fn cors_allowing<I, S>(origins: I) -> CorsLayer
29where
30 I: IntoIterator<Item = S>,
31 S: AsRef<str>,
32{
33 let origins: Vec<HeaderValue> = origins
34 .into_iter()
35 .filter_map(|origin| HeaderValue::from_str(origin.as_ref()).ok())
36 .collect();
37
38 CorsLayer::new()
39 .allow_origin(AllowOrigin::list(origins))
40 .allow_methods([
41 Method::GET,
42 Method::POST,
43 Method::PUT,
44 Method::PATCH,
45 Method::DELETE,
46 Method::OPTIONS,
47 ])
48 .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION])
49 .allow_credentials(true)
50}
51
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::Body, http::Request, response::Response, routing::get, Router};
    use tower::ServiceExt;

    /// The single origin every test router is configured to allow.
    const ALLOWED: &str = "https://app.example.com";

    /// Sends `request` through a fresh router guarded by `cors_allowing([ALLOWED])`.
    async fn send(request: Request<Body>) -> Response {
        Router::new()
            .route("/", get(|| async { "ok" }))
            .layer(cors_allowing([ALLOWED]))
            .oneshot(request)
            .await
            .expect("axum router service is infallible")
    }

    /// Builds a plain `GET /` request carrying the given `Origin` header.
    fn get_with_origin(origin: &str) -> Request<Body> {
        Request::builder()
            .uri("/")
            .header("origin", origin)
            .body(Body::empty())
            .unwrap()
    }

    #[tokio::test]
    async fn allows_listed_origin() {
        let res = send(get_with_origin(ALLOWED)).await;

        assert_eq!(
            res.headers().get("access-control-allow-origin").unwrap(),
            ALLOWED
        );
        // `cors_allowing` opts into credentialed requests.
        assert_eq!(
            res.headers().get("access-control-allow-credentials").unwrap(),
            "true"
        );
    }

    #[tokio::test]
    async fn omits_header_for_unlisted_origin() {
        let res = send(get_with_origin("https://evil.example.com")).await;

        assert!(res.headers().get("access-control-allow-origin").is_none());
    }

    #[tokio::test]
    async fn answers_preflight_for_listed_origin() {
        let res = send(
            Request::builder()
                .method(Method::OPTIONS)
                .uri("/")
                .header("origin", ALLOWED)
                .header("access-control-request-method", "DELETE")
                .body(Body::empty())
                .unwrap(),
        )
        .await;

        let headers = res.headers();
        assert_eq!(headers.get("access-control-allow-origin").unwrap(), ALLOWED);
        assert_eq!(
            headers.get("access-control-allow-credentials").unwrap(),
            "true"
        );
        // Only check membership: the separator used to join methods is a
        // tower-http implementation detail.
        let methods = headers
            .get("access-control-allow-methods")
            .unwrap()
            .to_str()
            .unwrap();
        assert!(methods.contains("DELETE"), "allow-methods was {methods:?}");
    }
}