api_tools/server/axum/layers/cors.rs

1//! CORS layer for Axum
2
3use axum::http::{HeaderName, HeaderValue, Method};
4use tower_http::cors::{AllowOrigin, Any, CorsLayer};
5
6/// CORS configuration
7///
8/// # Example
9///
10/// ```rust
11/// use axum::http::{header, HeaderName, HeaderValue, Method};
12/// use api_tools::server::axum::layers::cors::CorsConfig;
13///
14/// let cors_config = CorsConfig {
15///     allow_origin: "*",
16///     allow_methods: vec![Method::GET, Method::POST, Method::PUT, Method::PATCH, Method::DELETE],
17///     allow_headers: vec![header::AUTHORIZATION, header::ACCEPT, header::CONTENT_TYPE, header::ORIGIN],
18/// };
19/// ```
20pub struct CorsConfig<'a> {
21    pub allow_origin: &'a str,
22    pub allow_methods: Vec<Method>,
23    pub allow_headers: Vec<HeaderName>,
24}
25
26/// CORS layer
27///
28/// This function creates a CORS layer for Axum with the specified configuration.
29///
30/// # Example
31///
32/// ```rust
33/// use axum::http::{header, HeaderName, HeaderValue, Method};
34/// use api_tools::server::axum::layers::cors::{cors, CorsConfig};
35///
36/// let cors_config = CorsConfig {
37///     allow_origin: "*",
38///     allow_methods: vec![Method::GET, Method::POST, Method::PUT, Method::PATCH, Method::DELETE],
39///     allow_headers: vec![header::AUTHORIZATION, header::ACCEPT, header::CONTENT_TYPE, header::ORIGIN],
40/// };
41///
42/// let layer = cors(cors_config);
43/// ```
44pub fn cors(config: CorsConfig) -> CorsLayer {
45    let allow_origin = config.allow_origin;
46
47    let layer = CorsLayer::new()
48        .allow_methods(config.allow_methods)
49        .allow_headers(config.allow_headers);
50
51    if allow_origin == "*" {
52        layer.allow_origin(Any)
53    } else {
54        let origins = allow_origin
55            .split(',')
56            .filter(|url| *url != "*" && !url.is_empty())
57            .filter_map(|url| url.parse().ok())
58            .collect::<Vec<HeaderValue>>();
59
60        if origins.is_empty() {
61            layer.allow_origin(Any)
62        } else {
63            layer
64                .allow_origin(AllowOrigin::predicate(move |origin: &HeaderValue, _| {
65                    origins.contains(origin)
66                }))
67                .allow_credentials(true)
68        }
69    }
70}