Skip to main content

modo/middleware/
cors.rs

1use http::{HeaderName, HeaderValue, Method};
2use serde::Deserialize;
3use tower_http::cors::{AllowOrigin, CorsLayer};
4
5/// Configuration for CORS middleware.
6///
7/// When `origins` is empty (the default), the layer permits any origin
8/// (`Access-Control-Allow-Origin: *`) and forces `allow_credentials` to
9/// `false` — the CORS spec forbids `*` with credentials.
10///
11/// When one or more origins are specified, only those exact values are
12/// reflected back.
13#[non_exhaustive]
14#[derive(Debug, Clone, Deserialize)]
15#[serde(default)]
16pub struct CorsConfig {
17    /// Allowed origin URLs (e.g. `["https://example.com"]`).
18    /// Empty means allow any origin.
19    pub origins: Vec<String>,
20    /// Allowed HTTP methods.
21    pub methods: Vec<String>,
22    /// Allowed request headers.
23    pub headers: Vec<String>,
24    /// Value for `Access-Control-Max-Age` in seconds.
25    pub max_age_secs: u64,
26    /// Whether to set `Access-Control-Allow-Credentials: true`.
27    /// Ignored when `origins` is empty (forced to `false`).
28    pub allow_credentials: bool,
29}
30
31impl Default for CorsConfig {
32    fn default() -> Self {
33        Self {
34            origins: vec![],
35            methods: vec!["GET", "POST", "PUT", "DELETE", "PATCH"]
36                .into_iter()
37                .map(String::from)
38                .collect(),
39            headers: vec!["Content-Type", "Authorization"]
40                .into_iter()
41                .map(String::from)
42                .collect(),
43            max_age_secs: 86400,
44            allow_credentials: true,
45        }
46    }
47}
48
49/// Returns a [`CorsLayer`] configured from static origin values.
50///
51/// When `config.origins` is empty, any origin is allowed and credentials
52/// are disabled. Otherwise only the listed origins are reflected.
53///
54/// # Example
55///
56/// ```rust,no_run
57/// use modo::middleware::{cors, CorsConfig};
58///
59/// let mut config = CorsConfig::default();
60/// config.origins = vec!["https://example.com".to_string()];
61/// let layer = cors(&config);
62/// ```
63pub fn cors(config: &CorsConfig) -> CorsLayer {
64    let origins: Vec<HeaderValue> = config
65        .origins
66        .iter()
67        .filter_map(|o| o.parse().ok())
68        .collect();
69
70    let methods: Vec<Method> = config
71        .methods
72        .iter()
73        .filter_map(|m| m.parse().ok())
74        .collect();
75
76    let headers: Vec<HeaderName> = config
77        .headers
78        .iter()
79        .filter_map(|h| h.parse().ok())
80        .collect();
81
82    let mut layer = CorsLayer::new()
83        .allow_methods(methods)
84        .allow_headers(headers)
85        .max_age(std::time::Duration::from_secs(config.max_age_secs));
86
87    if origins.is_empty() {
88        // CORS spec forbids Access-Control-Allow-Origin: * with credentials
89        layer = layer
90            .allow_origin(tower_http::cors::Any)
91            .allow_credentials(false);
92    } else {
93        layer = layer.allow_origin(origins);
94        if config.allow_credentials {
95            layer = layer.allow_credentials(true);
96        }
97    }
98
99    layer
100}
101
102/// Returns a [`CorsLayer`] that delegates origin decisions to `predicate`.
103///
104/// Use this when the set of allowed origins is dynamic (e.g. loaded from a
105/// database) or when you need pattern matching such as subdomain wildcards.
106///
107/// # Example
108///
109/// ```rust,no_run
110/// use modo::middleware::{cors_with, subdomains, CorsConfig};
111///
112/// let config = CorsConfig::default();
113/// let layer = cors_with(&config, subdomains("example.com"));
114/// ```
115pub fn cors_with<F>(config: &CorsConfig, predicate: F) -> CorsLayer
116where
117    F: Fn(&HeaderValue, &http::request::Parts) -> bool + Clone + Send + Sync + 'static,
118{
119    let methods: Vec<Method> = config
120        .methods
121        .iter()
122        .filter_map(|m| m.parse().ok())
123        .collect();
124
125    let headers: Vec<HeaderName> = config
126        .headers
127        .iter()
128        .filter_map(|h| h.parse().ok())
129        .collect();
130
131    let mut layer = CorsLayer::new()
132        .allow_origin(AllowOrigin::predicate(predicate))
133        .allow_methods(methods)
134        .allow_headers(headers)
135        .max_age(std::time::Duration::from_secs(config.max_age_secs));
136
137    if config.allow_credentials {
138        layer = layer.allow_credentials(true);
139    }
140
141    layer
142}
143
144/// Returns a predicate that matches origins against an exact list of URLs.
145///
146/// # Example
147///
148/// ```rust,no_run
149/// use modo::middleware::{cors_with, urls, CorsConfig};
150///
151/// let config = CorsConfig::default();
152/// let layer = cors_with(&config, urls(&["https://example.com".to_string()]));
153/// ```
154pub fn urls(
155    origins: &[String],
156) -> impl Fn(&HeaderValue, &http::request::Parts) -> bool + Clone + use<> {
157    let allowed: Vec<String> = origins.to_vec();
158    move |origin: &HeaderValue, _parts: &http::request::Parts| {
159        origin
160            .to_str()
161            .map(|o| allowed.iter().any(|a| a == o))
162            .unwrap_or(false)
163    }
164}
165
166/// Returns a predicate that matches any subdomain of `domain` (including the
167/// domain itself). Both `http://` and `https://` schemes are accepted.
168///
169/// # Example
170///
171/// ```rust,no_run
172/// use modo::middleware::{cors_with, subdomains, CorsConfig};
173///
174/// let config = CorsConfig::default();
175/// // Matches https://example.com, https://api.example.com, etc.
176/// let layer = cors_with(&config, subdomains("example.com"));
177/// ```
178pub fn subdomains(
179    domain: &str,
180) -> impl Fn(&HeaderValue, &http::request::Parts) -> bool + Clone + use<> {
181    let suffix = format!(".{domain}");
182    let exact = domain.to_string();
183    move |origin: &HeaderValue, _parts: &http::request::Parts| {
184        origin
185            .to_str()
186            .map(|o| {
187                if let Some(host) = o
188                    .strip_prefix("https://")
189                    .or_else(|| o.strip_prefix("http://"))
190                {
191                    host == exact || host.ends_with(&suffix)
192                } else {
193                    false
194                }
195            })
196            .unwrap_or(false)
197    }
198}