1use std::collections::HashMap;
7use std::net::SocketAddr;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use crate::core::Result;
12use crate::core::ResumaError;
13use axum::http::{header, HeaderMap, HeaderValue, Request};
14use axum::response::Response;
15use once_cell::sync::Lazy;
16use parking_lot::RwLock;
17
/// Newtype around the `String` value of a Content-Security-Policy nonce.
#[derive(Debug, Clone)]
pub struct CspNonce(pub String);
21
/// Name of the cookie that carries the CSRF token (read by `validate_csrf`,
/// written by `csrf_set_cookie`).
pub const CSRF_COOKIE: &str = "__resuma-csrf";
/// Name of the request header that `validate_csrf` checks first for the token.
pub const CSRF_HEADER: &str = "x-resuma-csrf";
/// NOTE(review): presumably the form field that supplies the `form_csrf`
/// fallback token; it is not referenced elsewhere in this file — confirm
/// against callers.
pub const CSRF_FIELD: &str = "_csrf";
28
29static CONFIG: Lazy<RwLock<SecurityConfig>> = Lazy::new(|| RwLock::new(SecurityConfig::from_env()));
30
31static RATE_BUCKETS: Lazy<RwLock<HashMap<String, Vec<Instant>>>> =
32 Lazy::new(|| RwLock::new(HashMap::new()));
33
/// Runtime switches and limits for the security helpers in this module.
#[derive(Clone, Debug)]
pub struct SecurityConfig {
    /// When `false`, `validate_csrf` accepts every request.
    pub csrf: bool,
    /// When `false`, `validate_origin` accepts every request.
    pub origin_check: bool,
    /// When `true`, `x-forwarded-proto`, `x-forwarded-for` and `x-real-ip`
    /// are honoured by `request_is_https` / `client_ip_from_parts`.
    pub trust_proxy: bool,
    /// Request body size limit in bytes (`from_env` defaults it to 1 MiB).
    /// It is not enforced inside this module.
    pub body_limit_bytes: usize,
    /// Per-minute budget for actions (`from_env` defaults it to 120).
    pub actions_per_minute: u32,
    /// Per-minute budget for submits (`from_env` defaults it to 60).
    pub submits_per_minute: u32,
    /// Set by `from_env` to the same value as `production`.
    pub hide_benchmark: bool,
    /// `true` when `RESUMA_ENV` is `production` or `prod` (see `from_env`).
    pub production: bool,
}
54
55impl Default for SecurityConfig {
56 fn default() -> Self {
57 Self::from_env()
58 }
59}
60
61impl SecurityConfig {
62 pub fn from_env() -> Self {
63 let production = matches!(
64 std::env::var("RESUMA_ENV").as_deref(),
65 Ok("production") | Ok("prod")
66 );
67 let trust_proxy = matches!(
68 std::env::var("RESUMA_TRUST_PROXY").as_deref(),
69 Ok("1") | Ok("true") | Ok("TRUE")
70 );
71 Self {
72 csrf: !env_flag_off("RESUMA_CSRF"),
73 origin_check: !env_flag_off("RESUMA_ORIGIN_CHECK"),
74 trust_proxy,
75 body_limit_bytes: std::env::var("RESUMA_BODY_LIMIT")
76 .ok()
77 .and_then(|v| v.parse().ok())
78 .unwrap_or(1024 * 1024),
79 actions_per_minute: std::env::var("RESUMA_RATE_ACTIONS")
80 .ok()
81 .and_then(|v| v.parse().ok())
82 .unwrap_or(120),
83 submits_per_minute: std::env::var("RESUMA_RATE_SUBMITS")
84 .ok()
85 .and_then(|v| v.parse().ok())
86 .unwrap_or(60),
87 hide_benchmark: production,
88 production,
89 }
90 }
91}
92
/// Reports whether the environment variable `name` explicitly switches a
/// feature off.
///
/// The recognised "off" spellings are `0`, `false` and `off`, compared
/// ASCII-case-insensitively so that `False`, `OFF`, `Off` … behave the same
/// as the lowercase forms. An unset (or non-unicode) variable, and any other
/// value, counts as "not off".
fn env_flag_off(name: &str) -> bool {
    const OFF_VALUES: [&str; 3] = ["0", "false", "off"];
    std::env::var(name)
        .map(|raw| OFF_VALUES.iter().any(|off| raw.eq_ignore_ascii_case(off)))
        .unwrap_or(false)
}
99
100pub fn configure(config: SecurityConfig) {
102 *CONFIG.write() = config;
103}
104
105pub fn config() -> SecurityConfig {
106 CONFIG.read().clone()
107}
108
109pub fn random_token() -> String {
111 let mut bytes = [0u8; 16];
112 getrandom::getrandom(&mut bytes).expect("OS random number generator");
113 bytes.iter().map(|b| format!("{b:02x}")).collect()
114}
115
116pub fn csrf_token() -> String {
117 random_token()
118}
119
120pub fn request_is_https<B>(req: &Request<B>) -> bool {
122 let cfg = config();
123 if cfg.trust_proxy {
124 if let Some(proto) = req
125 .headers()
126 .get("x-forwarded-proto")
127 .and_then(|v| v.to_str().ok())
128 {
129 if proto.eq_ignore_ascii_case("https") {
130 return true;
131 }
132 }
133 }
134 req.uri().scheme_str() == Some("https")
135}
136
137pub fn client_ip<B>(req: &Request<B>) -> String {
139 client_ip_from_parts(req.headers(), connect_addr(req))
140}
141
142pub fn client_ip_from_parts(headers: &HeaderMap, connect: Option<SocketAddr>) -> String {
143 let cfg = config();
144 if cfg.trust_proxy {
145 if let Some(xff) = headers.get("x-forwarded-for").and_then(|v| v.to_str().ok()) {
146 if let Some(first) = xff.split(',').next() {
147 let ip = first.trim();
148 if !ip.is_empty() {
149 return ip.to_string();
150 }
151 }
152 }
153 if let Some(xri) = headers.get("x-real-ip").and_then(|v| v.to_str().ok()) {
154 if !xri.is_empty() {
155 return xri.to_string();
156 }
157 }
158 }
159 connect
160 .map(|a| a.ip().to_string())
161 .unwrap_or_else(|| "unknown".to_string())
162}
163
164fn connect_addr<B>(req: &Request<B>) -> Option<SocketAddr> {
165 req.extensions()
166 .get::<axum::extract::ConnectInfo<SocketAddr>>()
167 .map(|ci| ci.0)
168}
169
170pub fn check_rate_limit(ip: &str, bucket: &str, limit_per_minute: u32) -> Result<()> {
172 if limit_per_minute == 0 {
173 return Ok(());
174 }
175 let key = format!("{bucket}:{ip}");
176 let now = Instant::now();
177 let window = Duration::from_secs(60);
178 let mut map = RATE_BUCKETS.write();
179 let entries = map.entry(key).or_default();
180 entries.retain(|t| now.duration_since(*t) < window);
181 if entries.len() as u32 >= limit_per_minute {
182 return Err(ResumaError::RateLimited);
183 }
184 entries.push(now);
185 Ok(())
186}
187
188fn header_str(headers: &HeaderMap, name: &str) -> Option<String> {
189 headers
190 .get(name)
191 .and_then(|v| v.to_str().ok())
192 .map(|s| s.to_string())
193}
194
195fn cookie_value(headers: &HeaderMap, name: &str) -> Option<String> {
196 let cookie = header_str(headers, header::COOKIE.as_str())?;
197 for part in cookie.split(';') {
198 let part = part.trim();
199 if let Some((k, v)) = part.split_once('=') {
200 if k.trim() == name {
201 return Some(v.trim().to_string());
202 }
203 }
204 }
205 None
206}
207
208pub fn validate_csrf(headers: &HeaderMap, form_csrf: Option<&str>) -> Result<()> {
210 let cfg = config();
211 if !cfg.csrf {
212 return Ok(());
213 }
214 let cookie = cookie_value(headers, CSRF_COOKIE).ok_or(ResumaError::InvalidCsrf)?;
215 let header = header_str(headers, CSRF_HEADER);
216 let token = header
217 .as_deref()
218 .or(form_csrf)
219 .ok_or(ResumaError::InvalidCsrf)?;
220 if token != cookie || token.len() < 16 {
221 return Err(ResumaError::InvalidCsrf);
222 }
223 Ok(())
224}
225
226pub fn validate_origin(headers: &HeaderMap, host: &str) -> Result<()> {
228 let cfg = config();
229 if !cfg.origin_check {
230 return Ok(());
231 }
232 let host = host.split(':').next().unwrap_or(host).to_lowercase();
233
234 if let Some(origin) = header_str(headers, header::ORIGIN.as_str()) {
235 if !origin_matches_host(&origin, &host) {
236 return Err(ResumaError::Forbidden("cross-origin request".into()));
237 }
238 return Ok(());
239 }
240
241 if let Some(referer) = header_str(headers, header::REFERER.as_str()) {
242 if !referer_host_matches(&referer, &host) {
243 return Err(ResumaError::Forbidden("invalid referer".into()));
244 }
245 }
246 Ok(())
247}
248
249fn origin_matches_host(origin: &str, host: &str) -> bool {
250 origin
251 .strip_prefix("http://")
252 .or_else(|| origin.strip_prefix("https://"))
253 .and_then(|rest| rest.split('/').next())
254 .map(|authority| authority.split(':').next().unwrap_or(authority))
257 .map(|h| {
258 h.eq_ignore_ascii_case(host)
259 || h.strip_prefix("www.").unwrap_or(h) == host.strip_prefix("www.").unwrap_or(host)
260 })
261 .unwrap_or(false)
262}
263
264fn referer_host_matches(referer: &str, host: &str) -> bool {
265 referer
266 .strip_prefix("http://")
267 .or_else(|| referer.strip_prefix("https://"))
268 .and_then(|rest| rest.split('/').next())
269 .map(|authority| authority.split(':').next().unwrap_or(authority))
270 .map(|h| h.eq_ignore_ascii_case(host))
271 .unwrap_or(false)
272}
273
274pub fn csrf_set_cookie(token: &str, https: bool) -> HeaderValue {
276 let secure = if https { "; Secure" } else { "" };
277 HeaderValue::from_str(&format!(
278 "{CSRF_COOKIE}={token}; Path=/; SameSite=Strict; HttpOnly{secure}"
279 ))
280 .unwrap_or_else(|_| HeaderValue::from_static("invalid"))
281}
282
/// Per-response inputs for [`apply_security_headers`].
#[derive(Clone, Debug, Default)]
pub struct SecurityHeaderOptions {
    /// Nonce to embed in the `script-src` / `style-src` CSP directives;
    /// `None` selects the nonce-less policy.
    pub csp_nonce: Option<String>,
    /// Whether the response is served over HTTPS (adds HSTS and
    /// `upgrade-insecure-requests`).
    pub https: bool,
}
289
290pub fn apply_security_headers(mut response: Response, opts: &SecurityHeaderOptions) -> Response {
292 let headers = response.headers_mut();
293 if opts.https {
294 insert_header(
295 headers,
296 header::STRICT_TRANSPORT_SECURITY,
297 "max-age=63072000; includeSubDomains; preload",
298 );
299 }
300 insert_header(headers, header::X_FRAME_OPTIONS, "DENY");
301 insert_header(headers, header::X_CONTENT_TYPE_OPTIONS, "nosniff");
302 insert_header(
303 headers,
304 header::HeaderName::from_static("x-xss-protection"),
305 "0",
306 );
307 insert_header(
308 headers,
309 header::REFERRER_POLICY,
310 "strict-origin-when-cross-origin",
311 );
312 insert_header(
313 headers,
314 header::HeaderName::from_static("permissions-policy"),
315 "camera=(), microphone=(), geolocation=()",
316 );
317 insert_header(
318 headers,
319 header::HeaderName::from_static("cross-origin-opener-policy"),
320 "same-origin",
321 );
322 insert_header(
323 headers,
324 header::HeaderName::from_static("cross-origin-resource-policy"),
325 "same-origin",
326 );
327 insert_header(
328 headers,
329 header::HeaderName::from_static("x-dns-prefetch-control"),
330 "off",
331 );
332
333 let csp = if let Some(nonce) = &opts.csp_nonce {
334 let mut policy = format!(
338 "default-src 'self'; script-src 'self' 'nonce-{nonce}' 'unsafe-eval'; style-src 'self' 'nonce-{nonce}'; img-src 'self' data:; font-src 'self'; connect-src 'self'; object-src 'none'; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
339 );
340 if opts.https {
341 policy.push_str("; upgrade-insecure-requests");
342 }
343 policy
344 } else {
345 let mut policy = "default-src 'self'; script-src 'self' 'unsafe-eval'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; font-src 'self'; connect-src 'self'; object-src 'none'; frame-ancestors 'none'; base-uri 'self'; form-action 'self'".to_string();
346 if opts.https {
347 policy.push_str("; upgrade-insecure-requests");
348 }
349 policy
350 };
351 insert_header(headers, header::CONTENT_SECURITY_POLICY, &csp);
352 response
353}
354
355fn insert_header(headers: &mut axum::http::HeaderMap, name: header::HeaderName, value: &str) {
356 if let Ok(v) = HeaderValue::from_str(value) {
357 headers.insert(name, v);
358 }
359}
360
361pub fn guard_mutation(
363 headers: &HeaderMap,
364 host: &str,
365 ip: &str,
366 bucket: &str,
367 limit: u32,
368 form_csrf: Option<&str>,
369) -> Result<()> {
370 check_rate_limit(ip, bucket, limit)?;
371 validate_origin(headers, host)?;
372 validate_csrf(headers, form_csrf)?;
373 Ok(())
374}
375
376pub fn http_status(err: &ResumaError) -> axum::http::StatusCode {
378 axum::http::StatusCode::from_u16(err.status_code())
379 .unwrap_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR)
380}
381
382#[derive(Clone, Default)]
384pub struct SecurityState {
385 pub config: Arc<SecurityConfig>,
386}
387
388impl SecurityState {
389 pub fn new(config: SecurityConfig) -> Self {
390 Self {
391 config: Arc::new(config),
392 }
393 }
394
395 pub fn current() -> Self {
396 Self::new(config())
397 }
398}
399
#[cfg(test)]
mod tests {
    use super::*;

    /// Same host must be accepted whatever the port, including the `www.`
    /// variant of the host.
    #[test]
    fn origin_matches_ignoring_port() {
        let accepted = [
            ("http://localhost:3000", "localhost"),
            ("http://127.0.0.1:3939", "127.0.0.1"),
            ("https://example.com", "example.com"),
            ("https://example.com:8443", "example.com"),
            ("https://www.example.com:443", "example.com"),
        ];
        for (origin, host) in accepted {
            assert!(origin_matches_host(origin, host));
        }
    }

    /// Foreign hosts must be rejected.
    #[test]
    fn origin_rejects_other_hosts() {
        let rejected = [
            ("http://evil.test:3000", "localhost"),
            ("https://attacker.example", "example.com"),
        ];
        for (origin, host) in rejected {
            assert!(!origin_matches_host(origin, host));
        }
    }

    /// Referer comparison ignores port and path but not the host.
    #[test]
    fn referer_matches_ignoring_port() {
        let cases = [
            ("http://localhost:3000/items", "localhost", true),
            ("https://example.com:8443/x", "example.com", true),
            ("http://evil.test:3000/x", "localhost", false),
        ];
        for (referer, host, expected) in cases {
            assert_eq!(referer_host_matches(referer, host), expected);
        }
    }

    /// A `Host` value that still carries its port must match the origin.
    #[test]
    fn validate_origin_allows_same_host_with_port() {
        let mut headers = HeaderMap::new();
        headers.insert(
            header::ORIGIN,
            HeaderValue::from_static("http://localhost:3000"),
        );
        let outcome = validate_origin(&headers, "localhost:3000");
        assert!(outcome.is_ok());
    }

    /// The nonce policy must carry both the nonce and `'unsafe-eval'`.
    #[test]
    fn csp_allows_runtime_compiled_handlers() {
        let opts = SecurityHeaderOptions {
            csp_nonce: Some("abc123".into()),
            https: false,
        };
        let res = apply_security_headers(Response::new(axum::body::Body::empty()), &opts);
        let csp = res
            .headers()
            .get(header::CONTENT_SECURITY_POLICY)
            .and_then(|value| value.to_str().ok())
            .unwrap();

        for needle in ["'nonce-abc123'", "'unsafe-eval'"] {
            assert!(csp.contains(needle));
        }
    }
}