Skip to main content

axum_api_kit/
cors.rs

1use axum::http::{header, HeaderValue, Method};
2use tower_http::cors::{AllowOrigin, CorsLayer};
3
4/// A permissive CORS layer (any origin, method, and header).
5///
6/// Convenient for local development; prefer [`cors_allowing`] in production so credentialed
7/// requests and a known origin allow-list are handled correctly. Requires the `cors` feature.
8pub fn cors_permissive() -> CorsLayer {
9    CorsLayer::permissive()
10}
11
12/// Build a CORS layer that allows the given origins.
13///
14/// Allows the common REST methods (`GET`, `POST`, `PUT`, `PATCH`, `DELETE`, `OPTIONS`), the
15/// `content-type` and `authorization` request headers, and credentials. Origins that fail to
16/// parse as header values are skipped. Requires the `cors` feature.
17///
18/// # Example
19///
20/// ```rust,no_run
21/// use axum::{routing::get, Router};
22/// use axum_api_kit::cors_allowing;
23///
24/// let app: Router = Router::new()
25///     .route("/", get(|| async { "ok" }))
26///     .layer(cors_allowing(["https://app.example.com"]));
27/// ```
28pub fn cors_allowing<I, S>(origins: I) -> CorsLayer
29where
30    I: IntoIterator<Item = S>,
31    S: AsRef<str>,
32{
33    let origins: Vec<HeaderValue> = origins
34        .into_iter()
35        .filter_map(|origin| HeaderValue::from_str(origin.as_ref()).ok())
36        .collect();
37
38    CorsLayer::new()
39        .allow_origin(AllowOrigin::list(origins))
40        .allow_methods([
41            Method::GET,
42            Method::POST,
43            Method::PUT,
44            Method::PATCH,
45            Method::DELETE,
46            Method::OPTIONS,
47        ])
48        .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION])
49        .allow_credentials(true)
50}
51
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::Body, http::Request, response::Response, routing::get, Router};
    use tower::ServiceExt;

    /// The only origin the app under test allows.
    const ALLOWED: &str = "https://app.example.com";

    /// A one-route app wrapped in `cors_allowing([ALLOWED])`.
    fn app() -> Router {
        Router::new()
            .route("/", get(|| async { "ok" }))
            .layer(cors_allowing([ALLOWED]))
    }

    /// Drive a single request through a fresh app.
    async fn send(req: Request<Body>) -> Response {
        app().oneshot(req).await.unwrap()
    }

    /// A simple (non-preflight) `GET /` carrying the given `Origin` header.
    fn get_from(origin: &str) -> Request<Body> {
        Request::builder()
            .uri("/")
            .header(header::ORIGIN, origin)
            .body(Body::empty())
            .unwrap()
    }

    /// Split a comma-separated response header into trimmed, lowercase items.
    fn header_items(res: &Response, name: header::HeaderName) -> Vec<String> {
        res.headers()
            .get(name)
            .and_then(|value| value.to_str().ok())
            .unwrap_or_default()
            .split(',')
            .map(|item| item.trim().to_ascii_lowercase())
            .collect()
    }

    #[tokio::test]
    async fn allows_listed_origin() {
        let res = send(get_from(ALLOWED)).await;

        assert_eq!(
            res.headers()
                .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
                .unwrap(),
            ALLOWED
        );
        // `cors_allowing` enables credentials, so the browser must be told so.
        assert_eq!(
            res.headers()
                .get(header::ACCESS_CONTROL_ALLOW_CREDENTIALS)
                .unwrap(),
            "true"
        );
    }

    #[tokio::test]
    async fn omits_header_for_unlisted_origin() {
        let res = send(get_from("https://evil.example.com")).await;

        assert!(res
            .headers()
            .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
            .is_none());
    }

    #[tokio::test]
    async fn answers_preflight_with_configured_methods_and_headers() {
        let req = Request::builder()
            .method(Method::OPTIONS)
            .uri("/")
            .header(header::ORIGIN, ALLOWED)
            .header(header::ACCESS_CONTROL_REQUEST_METHOD, "PATCH")
            .header(
                header::ACCESS_CONTROL_REQUEST_HEADERS,
                "content-type, authorization",
            )
            .body(Body::empty())
            .unwrap();
        let res = send(req).await;

        // The CORS layer answers the preflight itself; the route only handles GET.
        assert!(res.status().is_success());
        assert_eq!(
            res.headers()
                .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
                .unwrap(),
            ALLOWED
        );

        let methods = header_items(&res, header::ACCESS_CONTROL_ALLOW_METHODS);
        for method in ["get", "post", "put", "patch", "delete", "options"] {
            assert!(
                methods.iter().any(|m| m == method),
                "missing method {method} in {methods:?}"
            );
        }

        let headers = header_items(&res, header::ACCESS_CONTROL_ALLOW_HEADERS);
        for name in ["content-type", "authorization"] {
            assert!(
                headers.iter().any(|h| h == name),
                "missing header {name} in {headers:?}"
            );
        }
    }
}