Skip to main content

resuma/server/
security.rs

1//! Security primitives for Resuma HTTP servers — CSRF, rate limiting, headers, origin checks.
2//!
3//! Enabled by default on `ResumaApp::serve()` and `FlowApp::serve()`. Configure via
4//! [`SecurityConfig`] or environment variables (see `docs/SECURITY.md`).
5
6use std::collections::HashMap;
7use std::net::SocketAddr;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use crate::core::Result;
12use crate::core::ResumaError;
13use axum::http::{header, HeaderMap, HeaderValue, Request};
14use axum::response::Response;
15use once_cell::sync::Lazy;
16use parking_lot::RwLock;
17
/// Content-Security-Policy nonce generated for a single request.
///
/// It is placed in the response extensions once the HTML has been rendered.
#[derive(Debug, Clone)]
pub struct CspNonce(pub String);
21
/// Name of the cookie that carries the double-submit CSRF token.
pub const CSRF_COOKIE: &str = "__resuma-csrf";
/// Request header that must echo the CSRF token on POSTs (actions and submits).
pub const CSRF_HEADER: &str = "x-resuma-csrf";
/// Form field that echoes the CSRF token for progressive-enhancement submits.
pub const CSRF_FIELD: &str = "_csrf";
28
29static CONFIG: Lazy<RwLock<SecurityConfig>> = Lazy::new(|| RwLock::new(SecurityConfig::from_env()));
30
31static RATE_BUCKETS: Lazy<RwLock<HashMap<String, Vec<Instant>>>> =
32    Lazy::new(|| RwLock::new(HashMap::new()));
33
/// Global security configuration (shared by ResumaApp and FlowApp).
///
/// Usually built by [`SecurityConfig::from_env`] (which `Default` also uses)
/// and installed process-wide with [`configure`]. The type is plain data, so
/// equality is derived: callers can check whether a candidate configuration
/// differs from the active one before replacing it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SecurityConfig {
    /// Require CSRF token on `POST /_resuma/action/*` and `POST /_resuma/submit/*`.
    ///
    /// From the environment: on unless `RESUMA_CSRF` is `0`, `false`, `FALSE` or `off`.
    pub csrf: bool,
    /// Validate `Origin` / `Referer` on mutating requests (same-origin).
    ///
    /// From the environment: on unless `RESUMA_ORIGIN_CHECK` is `0`, `false`, `FALSE` or `off`.
    pub origin_check: bool,
    /// Trust `X-Forwarded-For` / `X-Forwarded-Proto` (set `RESUMA_TRUST_PROXY=1` behind Fly/nginx).
    ///
    /// From the environment: on only when `RESUMA_TRUST_PROXY` is `1`, `true` or `TRUE`.
    pub trust_proxy: bool,
    /// Max POST body size in bytes (`RESUMA_BODY_LIMIT`, default 1 MiB).
    pub body_limit_bytes: usize,
    /// Max action RPC calls per client IP per minute (`RESUMA_RATE_ACTIONS`, default 120).
    pub actions_per_minute: u32,
    /// Max form submits per client IP per minute (`RESUMA_RATE_SUBMITS`, default 60).
    pub submits_per_minute: u32,
    /// Hide `/_resuma/benchmark.json` in production (mirrors `production` when read from env).
    pub hide_benchmark: bool,
    /// Sanitize error messages returned to clients.
    ///
    /// From the environment: on when `RESUMA_ENV` is `production` or `prod`.
    pub production: bool,
}
54
55impl Default for SecurityConfig {
56    fn default() -> Self {
57        Self::from_env()
58    }
59}
60
61impl SecurityConfig {
62    pub fn from_env() -> Self {
63        let production = matches!(
64            std::env::var("RESUMA_ENV").as_deref(),
65            Ok("production") | Ok("prod")
66        );
67        let trust_proxy = matches!(
68            std::env::var("RESUMA_TRUST_PROXY").as_deref(),
69            Ok("1") | Ok("true") | Ok("TRUE")
70        );
71        Self {
72            csrf: !env_flag_off("RESUMA_CSRF"),
73            origin_check: !env_flag_off("RESUMA_ORIGIN_CHECK"),
74            trust_proxy,
75            body_limit_bytes: std::env::var("RESUMA_BODY_LIMIT")
76                .ok()
77                .and_then(|v| v.parse().ok())
78                .unwrap_or(1024 * 1024),
79            actions_per_minute: std::env::var("RESUMA_RATE_ACTIONS")
80                .ok()
81                .and_then(|v| v.parse().ok())
82                .unwrap_or(120),
83            submits_per_minute: std::env::var("RESUMA_RATE_SUBMITS")
84                .ok()
85                .and_then(|v| v.parse().ok())
86                .unwrap_or(60),
87            hide_benchmark: production,
88            production,
89        }
90    }
91}
92
/// True when environment variable `name` holds an explicit "off" value:
/// exactly `0`, `false`, `FALSE` or `off`.
///
/// An unset or non-UTF-8 variable does not count as off.
fn env_flag_off(name: &str) -> bool {
    const OFF_VALUES: [&str; 4] = ["0", "false", "FALSE", "off"];
    std::env::var(name).map_or(false, |value| OFF_VALUES.contains(&value.as_str()))
}
99
100/// Install global security config (call before `serve()` to override env defaults).
101pub fn configure(config: SecurityConfig) {
102    *CONFIG.write() = config;
103}
104
105pub fn config() -> SecurityConfig {
106    CONFIG.read().clone()
107}
108
109/// Cryptographically random token (32 hex chars).
110pub fn random_token() -> String {
111    let mut bytes = [0u8; 16];
112    getrandom::getrandom(&mut bytes).expect("OS random number generator");
113    bytes.iter().map(|b| format!("{b:02x}")).collect()
114}
115
116pub fn csrf_token() -> String {
117    random_token()
118}
119
120/// True when the request arrived over HTTPS (direct TLS or `X-Forwarded-Proto`).
121pub fn request_is_https<B>(req: &Request<B>) -> bool {
122    let cfg = config();
123    if cfg.trust_proxy {
124        if let Some(proto) = req
125            .headers()
126            .get("x-forwarded-proto")
127            .and_then(|v| v.to_str().ok())
128        {
129            if proto.eq_ignore_ascii_case("https") {
130                return true;
131            }
132        }
133    }
134    req.uri().scheme_str() == Some("https")
135}
136
137/// Best-effort client IP for rate limiting.
138pub fn client_ip<B>(req: &Request<B>) -> String {
139    client_ip_from_parts(req.headers(), connect_addr(req))
140}
141
142pub fn client_ip_from_parts(headers: &HeaderMap, connect: Option<SocketAddr>) -> String {
143    let cfg = config();
144    if cfg.trust_proxy {
145        if let Some(xff) = headers.get("x-forwarded-for").and_then(|v| v.to_str().ok()) {
146            if let Some(first) = xff.split(',').next() {
147                let ip = first.trim();
148                if !ip.is_empty() {
149                    return ip.to_string();
150                }
151            }
152        }
153        if let Some(xri) = headers.get("x-real-ip").and_then(|v| v.to_str().ok()) {
154            if !xri.is_empty() {
155                return xri.to_string();
156            }
157        }
158    }
159    connect
160        .map(|a| a.ip().to_string())
161        .unwrap_or_else(|| "unknown".to_string())
162}
163
164fn connect_addr<B>(req: &Request<B>) -> Option<SocketAddr> {
165    req.extensions()
166        .get::<axum::extract::ConnectInfo<SocketAddr>>()
167        .map(|ci| ci.0)
168}
169
170/// Sliding-window rate limit. Returns `Err(RateLimited)` when exceeded.
171pub fn check_rate_limit(ip: &str, bucket: &str, limit_per_minute: u32) -> Result<()> {
172    if limit_per_minute == 0 {
173        return Ok(());
174    }
175    let key = format!("{bucket}:{ip}");
176    let now = Instant::now();
177    let window = Duration::from_secs(60);
178    let mut map = RATE_BUCKETS.write();
179    let entries = map.entry(key).or_default();
180    entries.retain(|t| now.duration_since(*t) < window);
181    if entries.len() as u32 >= limit_per_minute {
182        return Err(ResumaError::RateLimited);
183    }
184    entries.push(now);
185    Ok(())
186}
187
188fn header_str(headers: &HeaderMap, name: &str) -> Option<String> {
189    headers
190        .get(name)
191        .and_then(|v| v.to_str().ok())
192        .map(|s| s.to_string())
193}
194
195fn cookie_value(headers: &HeaderMap, name: &str) -> Option<String> {
196    let cookie = header_str(headers, header::COOKIE.as_str())?;
197    for part in cookie.split(';') {
198        let part = part.trim();
199        if let Some((k, v)) = part.split_once('=') {
200            if k.trim() == name {
201                return Some(v.trim().to_string());
202            }
203        }
204    }
205    None
206}
207
208/// Validate double-submit CSRF: header (or form field) must match cookie.
209pub fn validate_csrf(headers: &HeaderMap, form_csrf: Option<&str>) -> Result<()> {
210    let cfg = config();
211    if !cfg.csrf {
212        return Ok(());
213    }
214    let cookie = cookie_value(headers, CSRF_COOKIE).ok_or(ResumaError::InvalidCsrf)?;
215    let header = header_str(headers, CSRF_HEADER);
216    let token = header
217        .as_deref()
218        .or(form_csrf)
219        .ok_or(ResumaError::InvalidCsrf)?;
220    if token != cookie || token.len() < 16 {
221        return Err(ResumaError::InvalidCsrf);
222    }
223    Ok(())
224}
225
226/// Reject cross-origin POST when `Origin`/`Referer` do not match the host.
227pub fn validate_origin(headers: &HeaderMap, host: &str) -> Result<()> {
228    let cfg = config();
229    if !cfg.origin_check {
230        return Ok(());
231    }
232    let host = host.split(':').next().unwrap_or(host).to_lowercase();
233
234    if let Some(origin) = header_str(headers, header::ORIGIN.as_str()) {
235        if !origin_matches_host(&origin, &host) {
236            return Err(ResumaError::Forbidden("cross-origin request".into()));
237        }
238        return Ok(());
239    }
240
241    if let Some(referer) = header_str(headers, header::REFERER.as_str()) {
242        if !referer_host_matches(&referer, &host) {
243            return Err(ResumaError::Forbidden("invalid referer".into()));
244        }
245    }
246    Ok(())
247}
248
249fn origin_matches_host(origin: &str, host: &str) -> bool {
250    origin
251        .strip_prefix("http://")
252        .or_else(|| origin.strip_prefix("https://"))
253        .and_then(|rest| rest.split('/').next())
254        // Browsers include the port in `Origin` (e.g. `http://localhost:3000`);
255        // `host` arrives without it, so compare hostnames only.
256        .map(|authority| authority.split(':').next().unwrap_or(authority))
257        .map(|h| {
258            h.eq_ignore_ascii_case(host)
259                || h.strip_prefix("www.").unwrap_or(h) == host.strip_prefix("www.").unwrap_or(host)
260        })
261        .unwrap_or(false)
262}
263
264fn referer_host_matches(referer: &str, host: &str) -> bool {
265    referer
266        .strip_prefix("http://")
267        .or_else(|| referer.strip_prefix("https://"))
268        .and_then(|rest| rest.split('/').next())
269        .map(|authority| authority.split(':').next().unwrap_or(authority))
270        .map(|h| h.eq_ignore_ascii_case(host))
271        .unwrap_or(false)
272}
273
274/// Build `Set-Cookie` for CSRF double-submit.
275pub fn csrf_set_cookie(token: &str, https: bool) -> HeaderValue {
276    let secure = if https { "; Secure" } else { "" };
277    HeaderValue::from_str(&format!(
278        "{CSRF_COOKIE}={token}; Path=/; SameSite=Strict; HttpOnly{secure}"
279    ))
280    .unwrap_or_else(|_| HeaderValue::from_static("invalid"))
281}
282
/// Options passed to [`apply_security_headers`].
#[derive(Debug, Clone, Default)]
pub struct SecurityHeaderOptions {
    /// Per-request nonce added to `script-src` and `style-src`. When `None`,
    /// inline styles are allowed through `'unsafe-inline'` instead.
    pub csp_nonce: Option<String>,
    /// The response is served over HTTPS; enables HSTS and `upgrade-insecure-requests`.
    pub https: bool,
}
289
290/// Apply standard security headers (Helmet-style baseline).
291pub fn apply_security_headers(mut response: Response, opts: &SecurityHeaderOptions) -> Response {
292    let headers = response.headers_mut();
293    if opts.https {
294        insert_header(
295            headers,
296            header::STRICT_TRANSPORT_SECURITY,
297            "max-age=63072000; includeSubDomains; preload",
298        );
299    }
300    insert_header(headers, header::X_FRAME_OPTIONS, "DENY");
301    insert_header(headers, header::X_CONTENT_TYPE_OPTIONS, "nosniff");
302    insert_header(
303        headers,
304        header::HeaderName::from_static("x-xss-protection"),
305        "0",
306    );
307    insert_header(
308        headers,
309        header::REFERRER_POLICY,
310        "strict-origin-when-cross-origin",
311    );
312    insert_header(
313        headers,
314        header::HeaderName::from_static("permissions-policy"),
315        "camera=(), microphone=(), geolocation=()",
316    );
317    insert_header(
318        headers,
319        header::HeaderName::from_static("cross-origin-opener-policy"),
320        "same-origin",
321    );
322    insert_header(
323        headers,
324        header::HeaderName::from_static("cross-origin-resource-policy"),
325        "same-origin",
326    );
327    insert_header(
328        headers,
329        header::HeaderName::from_static("x-dns-prefetch-control"),
330        "off",
331    );
332
333    let csp = if let Some(nonce) = &opts.csp_nonce {
334        // The current resumability runtime compiles small inline handlers,
335        // effects, and visible tasks with `new Function`. Keep CSP honest so
336        // enabled Resuma features work under the default security headers.
337        let mut policy = format!(
338            "default-src 'self'; script-src 'self' 'nonce-{nonce}' 'unsafe-eval'; style-src 'self' 'nonce-{nonce}'; img-src 'self' data:; font-src 'self'; connect-src 'self'; object-src 'none'; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
339        );
340        if opts.https {
341            policy.push_str("; upgrade-insecure-requests");
342        }
343        policy
344    } else {
345        let mut policy = "default-src 'self'; script-src 'self' 'unsafe-eval'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; font-src 'self'; connect-src 'self'; object-src 'none'; frame-ancestors 'none'; base-uri 'self'; form-action 'self'".to_string();
346        if opts.https {
347            policy.push_str("; upgrade-insecure-requests");
348        }
349        policy
350    };
351    insert_header(headers, header::CONTENT_SECURITY_POLICY, &csp);
352    response
353}
354
355fn insert_header(headers: &mut axum::http::HeaderMap, name: header::HeaderName, value: &str) {
356    if let Ok(v) = HeaderValue::from_str(value) {
357        headers.insert(name, v);
358    }
359}
360
361/// Guard mutating API requests (CSRF + origin + rate limit).
362pub fn guard_mutation(
363    headers: &HeaderMap,
364    host: &str,
365    ip: &str,
366    bucket: &str,
367    limit: u32,
368    form_csrf: Option<&str>,
369) -> Result<()> {
370    check_rate_limit(ip, bucket, limit)?;
371    validate_origin(headers, host)?;
372    validate_csrf(headers, form_csrf)?;
373    Ok(())
374}
375
376/// Map [`ResumaError`] to an HTTP status code.
377pub fn http_status(err: &ResumaError) -> axum::http::StatusCode {
378    axum::http::StatusCode::from_u16(err.status_code())
379        .unwrap_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR)
380}
381
382/// Shared state for security-aware routers.
383#[derive(Clone, Default)]
384pub struct SecurityState {
385    pub config: Arc<SecurityConfig>,
386}
387
388impl SecurityState {
389    pub fn new(config: SecurityConfig) -> Self {
390        Self {
391            config: Arc::new(config),
392        }
393    }
394
395    pub fn current() -> Self {
396        Self::new(config())
397    }
398}
399
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn origin_matches_ignoring_port() {
        // Browsers put the port in `Origin`; `host` reaches the matcher without one.
        let same_site = [
            ("http://localhost:3000", "localhost"),
            ("http://127.0.0.1:3939", "127.0.0.1"),
            ("https://example.com", "example.com"),
            ("https://example.com:8443", "example.com"),
            ("https://www.example.com:443", "example.com"),
        ];
        for (origin, host) in same_site {
            assert!(origin_matches_host(origin, host), "{origin} should match {host}");
        }
    }

    #[test]
    fn origin_rejects_other_hosts() {
        let foreign = [
            ("http://evil.test:3000", "localhost"),
            ("https://attacker.example", "example.com"),
        ];
        for (origin, host) in foreign {
            assert!(!origin_matches_host(origin, host), "{origin} must not match {host}");
        }
    }

    #[test]
    fn referer_matches_ignoring_port() {
        let cases = [
            ("http://localhost:3000/items", "localhost", true),
            ("https://example.com:8443/x", "example.com", true),
            ("http://evil.test:3000/x", "localhost", false),
        ];
        for (referer, host, expected) in cases {
            assert_eq!(
                referer_host_matches(referer, host),
                expected,
                "referer {referer} vs host {host}"
            );
        }
    }

    #[test]
    fn validate_origin_allows_same_host_with_port() {
        let mut headers = HeaderMap::new();
        headers.insert(header::ORIGIN, HeaderValue::from_static("http://localhost:3000"));
        // Here `host` keeps its port, exactly as it arrives in the HTTP `Host` header.
        assert!(validate_origin(&headers, "localhost:3000").is_ok());
    }

    #[test]
    fn csp_allows_runtime_compiled_handlers() {
        let opts = SecurityHeaderOptions {
            csp_nonce: Some("abc123".into()),
            https: false,
        };
        let res = apply_security_headers(Response::new(axum::body::Body::empty()), &opts);
        let csp = res
            .headers()
            .get(header::CONTENT_SECURITY_POLICY)
            .and_then(|v| v.to_str().ok())
            .expect("CSP header should be present");

        for needle in ["'nonce-abc123'", "'unsafe-eval'"] {
            assert!(csp.contains(needle), "CSP is missing {needle}: {csp}");
        }
    }
}