api_tools/server/axum/layers/cors.rs
1//! CORS layer for Axum
2
3use axum::http::{HeaderName, HeaderValue, Method};
4use tower_http::cors::{AllowOrigin, Any, CorsLayer};
5
6/// CORS configuration
7///
8/// # Example
9///
10/// ```rust
11/// use axum::http::{header, HeaderName, HeaderValue, Method};
12/// use api_tools::server::axum::layers::cors::CorsConfig;
13///
14/// let cors_config = CorsConfig {
15/// allow_origin: "*",
16/// allow_methods: vec![Method::GET, Method::POST, Method::PUT, Method::PATCH, Method::DELETE],
17/// allow_headers: vec![header::AUTHORIZATION, header::ACCEPT, header::CONTENT_TYPE, header::ORIGIN],
18/// };
19/// ```
20pub struct CorsConfig<'a> {
21 pub allow_origin: &'a str,
22 pub allow_methods: Vec<Method>,
23 pub allow_headers: Vec<HeaderName>,
24}
25
26/// CORS layer
27///
28/// This function creates a CORS layer for Axum with the specified configuration.
29///
30/// # Example
31///
32/// ```rust
33/// use axum::http::{header, HeaderName, HeaderValue, Method};
34/// use api_tools::server::axum::layers::cors::{cors, CorsConfig};
35///
36/// let cors_config = CorsConfig {
37/// allow_origin: "*",
38/// allow_methods: vec![Method::GET, Method::POST, Method::PUT, Method::PATCH, Method::DELETE],
39/// allow_headers: vec![header::AUTHORIZATION, header::ACCEPT, header::CONTENT_TYPE, header::ORIGIN],
40/// };
41///
42/// let layer = cors(cors_config);
43/// ```
44pub fn cors(config: CorsConfig) -> CorsLayer {
45 let allow_origin = config.allow_origin;
46
47 let layer = CorsLayer::new()
48 .allow_methods(config.allow_methods)
49 .allow_headers(config.allow_headers);
50
51 if allow_origin == "*" {
52 layer.allow_origin(Any)
53 } else {
54 let origins = allow_origin
55 .split(',')
56 .filter(|url| *url != "*" && !url.is_empty())
57 .filter_map(|url| url.parse().ok())
58 .collect::<Vec<HeaderValue>>();
59
60 if origins.is_empty() {
61 layer.allow_origin(Any)
62 } else {
63 layer
64 .allow_origin(AllowOrigin::predicate(move |origin: &HeaderValue, _| {
65 origins.contains(origin)
66 }))
67 .allow_credentials(true)
68 }
69 }
70}