modo/middleware/cors.rs
1use http::{HeaderName, HeaderValue, Method};
2use serde::Deserialize;
3use tower_http::cors::{AllowOrigin, CorsLayer};
4
5/// Configuration for CORS middleware.
6///
7/// When `origins` is empty (the default), the layer permits any origin
8/// (`Access-Control-Allow-Origin: *`) and forces `allow_credentials` to
9/// `false` — the CORS spec forbids `*` with credentials.
10///
11/// When one or more origins are specified, only those exact values are
12/// reflected back.
13#[non_exhaustive]
14#[derive(Debug, Clone, Deserialize)]
15#[serde(default)]
16pub struct CorsConfig {
17 /// Allowed origin URLs (e.g. `["https://example.com"]`).
18 /// Empty means allow any origin.
19 pub origins: Vec<String>,
20 /// Allowed HTTP methods.
21 pub methods: Vec<String>,
22 /// Allowed request headers.
23 pub headers: Vec<String>,
24 /// Value for `Access-Control-Max-Age` in seconds.
25 pub max_age_secs: u64,
26 /// Whether to set `Access-Control-Allow-Credentials: true`.
27 /// Ignored when `origins` is empty (forced to `false`).
28 pub allow_credentials: bool,
29}
30
31impl Default for CorsConfig {
32 fn default() -> Self {
33 Self {
34 origins: vec![],
35 methods: vec!["GET", "POST", "PUT", "DELETE", "PATCH"]
36 .into_iter()
37 .map(String::from)
38 .collect(),
39 headers: vec!["Content-Type", "Authorization"]
40 .into_iter()
41 .map(String::from)
42 .collect(),
43 max_age_secs: 86400,
44 allow_credentials: true,
45 }
46 }
47}
48
49/// Returns a [`CorsLayer`] configured from static origin values.
50///
51/// When `config.origins` is empty, any origin is allowed and credentials
52/// are disabled. Otherwise only the listed origins are reflected.
53///
54/// # Example
55///
56/// ```rust,no_run
57/// use modo::middleware::{cors, CorsConfig};
58///
59/// let mut config = CorsConfig::default();
60/// config.origins = vec!["https://example.com".to_string()];
61/// let layer = cors(&config);
62/// ```
63pub fn cors(config: &CorsConfig) -> CorsLayer {
64 let origins: Vec<HeaderValue> = config
65 .origins
66 .iter()
67 .filter_map(|o| o.parse().ok())
68 .collect();
69
70 let methods: Vec<Method> = config
71 .methods
72 .iter()
73 .filter_map(|m| m.parse().ok())
74 .collect();
75
76 let headers: Vec<HeaderName> = config
77 .headers
78 .iter()
79 .filter_map(|h| h.parse().ok())
80 .collect();
81
82 let mut layer = CorsLayer::new()
83 .allow_methods(methods)
84 .allow_headers(headers)
85 .max_age(std::time::Duration::from_secs(config.max_age_secs));
86
87 if origins.is_empty() {
88 // CORS spec forbids Access-Control-Allow-Origin: * with credentials
89 layer = layer
90 .allow_origin(tower_http::cors::Any)
91 .allow_credentials(false);
92 } else {
93 layer = layer.allow_origin(origins);
94 if config.allow_credentials {
95 layer = layer.allow_credentials(true);
96 }
97 }
98
99 layer
100}
101
102/// Returns a [`CorsLayer`] that delegates origin decisions to `predicate`.
103///
104/// Use this when the set of allowed origins is dynamic (e.g. loaded from a
105/// database) or when you need pattern matching such as subdomain wildcards.
106///
107/// # Example
108///
109/// ```rust,no_run
110/// use modo::middleware::{cors_with, subdomains, CorsConfig};
111///
112/// let config = CorsConfig::default();
113/// let layer = cors_with(&config, subdomains("example.com"));
114/// ```
115pub fn cors_with<F>(config: &CorsConfig, predicate: F) -> CorsLayer
116where
117 F: Fn(&HeaderValue, &http::request::Parts) -> bool + Clone + Send + Sync + 'static,
118{
119 let methods: Vec<Method> = config
120 .methods
121 .iter()
122 .filter_map(|m| m.parse().ok())
123 .collect();
124
125 let headers: Vec<HeaderName> = config
126 .headers
127 .iter()
128 .filter_map(|h| h.parse().ok())
129 .collect();
130
131 let mut layer = CorsLayer::new()
132 .allow_origin(AllowOrigin::predicate(predicate))
133 .allow_methods(methods)
134 .allow_headers(headers)
135 .max_age(std::time::Duration::from_secs(config.max_age_secs));
136
137 if config.allow_credentials {
138 layer = layer.allow_credentials(true);
139 }
140
141 layer
142}
143
144/// Returns a predicate that matches origins against an exact list of URLs.
145///
146/// # Example
147///
148/// ```rust,no_run
149/// use modo::middleware::{cors_with, urls, CorsConfig};
150///
151/// let config = CorsConfig::default();
152/// let layer = cors_with(&config, urls(&["https://example.com".to_string()]));
153/// ```
154pub fn urls(
155 origins: &[String],
156) -> impl Fn(&HeaderValue, &http::request::Parts) -> bool + Clone + use<> {
157 let allowed: Vec<String> = origins.to_vec();
158 move |origin: &HeaderValue, _parts: &http::request::Parts| {
159 origin
160 .to_str()
161 .map(|o| allowed.iter().any(|a| a == o))
162 .unwrap_or(false)
163 }
164}
165
166/// Returns a predicate that matches any subdomain of `domain` (including the
167/// domain itself). Both `http://` and `https://` schemes are accepted.
168///
169/// # Example
170///
171/// ```rust,no_run
172/// use modo::middleware::{cors_with, subdomains, CorsConfig};
173///
174/// let config = CorsConfig::default();
175/// // Matches https://example.com, https://api.example.com, etc.
176/// let layer = cors_with(&config, subdomains("example.com"));
177/// ```
178pub fn subdomains(
179 domain: &str,
180) -> impl Fn(&HeaderValue, &http::request::Parts) -> bool + Clone + use<> {
181 let suffix = format!(".{domain}");
182 let exact = domain.to_string();
183 move |origin: &HeaderValue, _parts: &http::request::Parts| {
184 origin
185 .to_str()
186 .map(|o| {
187 if let Some(host) = o
188 .strip_prefix("https://")
189 .or_else(|| o.strip_prefix("http://"))
190 {
191 host == exact || host.ends_with(&suffix)
192 } else {
193 false
194 }
195 })
196 .unwrap_or(false)
197 }
198}