use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use crate::core::Result;
use crate::core::ResumaError;
use axum::http::{header, HeaderMap, HeaderValue, Request};
use axum::response::Response;
use once_cell::sync::Lazy;
use parking_lot::RwLock;
/// Newtype wrapping a nonce string.
/// NOTE(review): not constructed in this file; presumably stored in request extensions and
/// later passed as the `'nonce-…'` value to `build_content_security_policy` — confirm.
#[derive(Clone, Debug)]
pub struct CspNonce(pub String);
/// Name of the cookie that carries the CSRF token (written by `csrf_set_cookie`,
/// read by `validate_csrf`).
pub const CSRF_COOKIE: &str = "__resuma-csrf";
/// Request header that `validate_csrf` checks first for the submitted token.
pub const CSRF_HEADER: &str = "x-resuma-csrf";
/// Form field name for the CSRF token.
/// NOTE(review): not read in this file; presumably callers extract it and pass it to
/// `validate_csrf` as `form_csrf` — confirm.
pub const CSRF_FIELD: &str = "_csrf";
/// Process-wide security settings. Initialised from the environment on first access and
/// replaced wholesale by `configure`.
static CONFIG: Lazy<RwLock<SecurityConfig>> = Lazy::new(|| RwLock::new(SecurityConfig::from_env()));
/// Request timestamps used by `check_rate_limit`, keyed by `"{bucket}:{ip}"`.
static RATE_BUCKETS: Lazy<RwLock<HashMap<String, Vec<Instant>>>> =
Lazy::new(|| RwLock::new(HashMap::new()));
/// Content-Security-Policy settings read by `build_content_security_policy` and
/// `apply_security_headers`.
#[derive(Debug, Clone)]
pub struct CspConfig {
/// Emit a CSP header at all.
pub enabled: bool,
/// Send `Content-Security-Policy-Report-Only` instead of the enforcing header.
pub report_only: bool,
/// Add `'strict-dynamic'` to `script-src` (only when a nonce is supplied).
pub strict_dynamic: bool,
/// Add `'unsafe-eval'` to `script-src`.
pub unsafe_eval: bool,
/// Extra `img-src` sources appended after `'self' data: blob:`.
pub img_src: Vec<String>,
/// Extra `script-src` sources appended after the built-in ones.
pub script_src: Vec<String>,
/// Extra sources appended to both `style-src` and `style-src-elem`.
pub style_src: Vec<String>,
/// Extra `connect-src` sources appended after `'self'`.
pub connect_src: Vec<String>,
/// Extra `font-src` sources appended after `'self' https://fonts.gstatic.com data:`.
pub font_src: Vec<String>,
}
impl Default for CspConfig {
fn default() -> Self {
Self::from_env()
}
}
impl CspConfig {
    /// Reads the CSP settings from `RESUMA_CSP*` environment variables.
    ///
    /// * The policy is on unless `RESUMA_CSP` is switched off. When
    ///   `crate::server::dev::dev_mode_enabled()` is true, it is only on if `RESUMA_CSP_DEV`
    ///   is also switched on.
    /// * `strict_dynamic` and `unsafe_eval` are on unless their variable is switched off.
    /// * `report_only` is off unless switched on.
    /// * Each `RESUMA_CSP_*_SRC` variable is a whitespace- or comma-separated list of extra
    ///   sources.
    pub fn from_env() -> Self {
        // Keep the short-circuit: dev mode is only consulted when CSP is not switched off.
        let enabled = if env_flag_off("RESUMA_CSP") {
            false
        } else {
            env_flag_on("RESUMA_CSP_DEV") || !crate::server::dev::dev_mode_enabled()
        };
        let report_only = env_flag_on("RESUMA_CSP_REPORT_ONLY");
        let strict_dynamic = !env_flag_off("RESUMA_CSP_STRICT_DYNAMIC");
        let unsafe_eval = !env_flag_off("RESUMA_CSP_UNSAFE_EVAL");
        Self {
            enabled,
            report_only,
            strict_dynamic,
            unsafe_eval,
            img_src: parse_csp_list_env("RESUMA_CSP_IMG_SRC"),
            script_src: parse_csp_list_env("RESUMA_CSP_SCRIPT_SRC"),
            style_src: parse_csp_list_env("RESUMA_CSP_STYLE_SRC"),
            connect_src: parse_csp_list_env("RESUMA_CSP_CONNECT_SRC"),
            font_src: parse_csp_list_env("RESUMA_CSP_FONT_SRC"),
        }
    }

    /// Environment-derived settings with the policy forced off.
    pub fn disabled() -> Self {
        Self {
            enabled: false,
            ..Self::from_env()
        }
    }

    /// Environment-derived settings with the policy forced on.
    ///
    /// `img_src` is *replaced* by `extra_img_src`; any `RESUMA_CSP_IMG_SRC` value is
    /// discarded.
    pub fn production(extra_img_src: impl IntoIterator<Item = impl Into<String>>) -> Self {
        // Read the environment before consuming the iterator, matching the original order.
        let base = Self::from_env();
        let img_src: Vec<String> = extra_img_src.into_iter().map(Into::into).collect();
        Self {
            enabled: true,
            img_src,
            ..base
        }
    }
}
/// Process-wide security settings; `SecurityConfig::from_env` documents the source variables.
#[derive(Debug, Clone)]
pub struct SecurityConfig {
/// Enforce the token check in `validate_csrf`.
pub csrf: bool,
/// Enforce the Origin/Referer check in `validate_origin`.
pub origin_check: bool,
/// Honour `X-Forwarded-For`, `X-Real-IP` and `X-Forwarded-Proto` from a reverse proxy.
pub trust_proxy: bool,
/// Maximum request body size in bytes (default 1 MiB).
/// NOTE(review): not enforced in this file; presumably applied by a body-limit layer — confirm.
pub body_limit_bytes: usize,
/// Per-minute action limit (default 120); presumably passed to `check_rate_limit` — confirm.
pub actions_per_minute: u32,
/// Per-minute submit limit (default 60); presumably passed to `check_rate_limit` — confirm.
pub submits_per_minute: u32,
/// Initialised to the same value as `production`; its consumers are outside this file.
pub hide_benchmark: bool,
/// Production mode, derived from `RESUMA_ENV`.
pub production: bool,
/// Content-Security-Policy settings.
pub csp: CspConfig,
}
impl Default for SecurityConfig {
fn default() -> Self {
Self::from_env()
}
}
impl SecurityConfig {
pub fn from_env() -> Self {
let production = matches!(
std::env::var("RESUMA_ENV").as_deref(),
Ok("production") | Ok("prod")
);
let trust_proxy = matches!(
std::env::var("RESUMA_TRUST_PROXY").as_deref(),
Ok("1") | Ok("true") | Ok("TRUE")
);
Self {
csrf: !env_flag_off("RESUMA_CSRF"),
origin_check: !env_flag_off("RESUMA_ORIGIN_CHECK"),
trust_proxy,
body_limit_bytes: std::env::var("RESUMA_BODY_LIMIT")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(1024 * 1024),
actions_per_minute: std::env::var("RESUMA_RATE_ACTIONS")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(120),
submits_per_minute: std::env::var("RESUMA_RATE_SUBMITS")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(60),
hide_benchmark: production,
production,
csp: CspConfig::from_env(),
}
}
}
/// Returns `true` when environment variable `name` holds an "on" value.
///
/// Accepted values are `1`, `true`, `on` and `yes`, compared case-insensitively after
/// trimming surrounding whitespace. An unset variable, non-UTF-8 content or any other
/// value yields `false`.
fn env_flag_on(name: &str) -> bool {
    std::env::var(name).map_or(false, |raw| {
        let value = raw.trim();
        ["1", "true", "on", "yes"]
            .iter()
            .any(|on| value.eq_ignore_ascii_case(on))
    })
}
/// Splits environment variable `name` into a list of CSP source tokens.
///
/// Tokens are separated by any mix of whitespace and commas; empty pieces are dropped.
/// An unset (or non-UTF-8) variable yields an empty list.
fn parse_csp_list_env(name: &str) -> Vec<String> {
    let raw = match std::env::var(name) {
        Ok(raw) => raw,
        Err(_) => return Vec::new(),
    };
    let mut sources = Vec::new();
    for token in raw.split(|c: char| c == ',' || c.is_whitespace()) {
        if !token.is_empty() {
            sources.push(token.to_owned());
        }
    }
    sources
}
/// Returns `true` when environment variable `name` holds an "off" value.
///
/// Accepted values are `0`, `false`, `off` and `no`, compared case-insensitively after
/// trimming surrounding whitespace. An unset variable, non-UTF-8 content or any other
/// value yields `false`, so features guarded by `!env_flag_off(..)` stay enabled by default.
fn env_flag_off(name: &str) -> bool {
    std::env::var(name).map_or(false, |raw| {
        let value = raw.trim();
        ["0", "false", "off", "no"]
            .iter()
            .any(|off| value.eq_ignore_ascii_case(off))
    })
}
/// Replaces the process-wide security settings returned by `config()`.
pub fn configure(config: SecurityConfig) {
    let mut current = CONFIG.write();
    *current = config;
}
pub fn config() -> SecurityConfig {
CONFIG.read().clone()
}
/// Returns 16 bytes from the OS random number generator, encoded as 32 lowercase hex digits.
///
/// # Panics
/// Panics if the OS random number generator is unavailable.
pub fn random_token() -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut raw = [0u8; 16];
    getrandom::getrandom(&mut raw).expect("OS random number generator");
    let mut token = String::with_capacity(raw.len() * 2);
    for byte in raw {
        token.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
        token.push(char::from(HEX_DIGITS[usize::from(byte & 0x0f)]));
    }
    token
}
pub fn csrf_token() -> String {
random_token()
}
/// Reports whether the request arrived over HTTPS.
///
/// When `trust_proxy` is on, an `X-Forwarded-Proto: https` header (case-insensitive) counts.
/// Otherwise, or if that header says anything else, the request URI's scheme decides.
pub fn request_is_https<B>(req: &Request<B>) -> bool {
    let trust_proxy = config().trust_proxy;
    let forwarded_https = trust_proxy
        && req
            .headers()
            .get("x-forwarded-proto")
            .and_then(|v| v.to_str().ok())
            .map_or(false, |proto| proto.eq_ignore_ascii_case("https"));
    forwarded_https || req.uri().scheme_str() == Some("https")
}
/// Best-effort client IP for `req`; see `client_ip_from_parts` for the resolution rules.
pub fn client_ip<B>(req: &Request<B>) -> String {
    let peer = connect_addr(req);
    client_ip_from_parts(req.headers(), peer)
}
/// Resolves the client IP used as the rate-limit key.
///
/// When `trust_proxy` is on:
/// 1. The *right-most* `X-Forwarded-For` entry is used. That is the address appended by the
///    proxy directly in front of us. Entries further left are supplied by the client and
///    can be forged to dodge rate limits. This assumes exactly one trusted proxy hop.
/// 2. Otherwise `X-Real-IP` is used.
///
/// Header values are only accepted if they parse as an IP address (an `ip:port` form is also
/// accepted). They are returned in canonical form, so junk strings never become bucket keys.
///
/// Without a usable header, the peer address from `connect` is used. If that is also
/// missing, the result is `"unknown"`.
pub fn client_ip_from_parts(headers: &HeaderMap, connect: Option<SocketAddr>) -> String {
    // Parses "1.2.3.4", "::1" or "1.2.3.4:5678" into a canonical IP string.
    fn parse_ip(raw: &str) -> Option<String> {
        let raw = raw.trim();
        raw.parse::<std::net::IpAddr>()
            .ok()
            .or_else(|| raw.parse::<SocketAddr>().ok().map(|sa| sa.ip()))
            .map(|ip| ip.to_string())
    }
    if config().trust_proxy {
        // Multiple header lines are equivalent to one comma-joined list, so the last line's
        // last entry is the right-most hop.
        let forwarded = headers
            .get_all("x-forwarded-for")
            .iter()
            .last()
            .and_then(|v| v.to_str().ok())
            .and_then(|xff| xff.rsplit(',').next())
            .and_then(parse_ip);
        if let Some(ip) = forwarded {
            return ip;
        }
        let real_ip = headers
            .get("x-real-ip")
            .and_then(|v| v.to_str().ok())
            .and_then(parse_ip);
        if let Some(ip) = real_ip {
            return ip;
        }
    }
    connect
        .map(|a| a.ip().to_string())
        .unwrap_or_else(|| "unknown".to_string())
}
/// Peer socket address recorded by axum's `ConnectInfo` extension, if the server installed it.
fn connect_addr<B>(req: &Request<B>) -> Option<SocketAddr> {
    let info = req
        .extensions()
        .get::<axum::extract::ConnectInfo<SocketAddr>>()?;
    Some(info.0)
}
pub fn check_rate_limit(ip: &str, bucket: &str, limit_per_minute: u32) -> Result<()> {
if limit_per_minute == 0 {
return Ok(());
}
let key = format!("{bucket}:{ip}");
let now = Instant::now();
let window = Duration::from_secs(60);
let mut map = RATE_BUCKETS.write();
let entries = map.entry(key).or_default();
entries.retain(|t| now.duration_since(*t) < window);
if entries.len() as u32 >= limit_per_minute {
return Err(ResumaError::RateLimited);
}
entries.push(now);
Ok(())
}
/// Owned copy of the first `name` header, or `None` if absent or not visible ASCII.
fn header_str(headers: &HeaderMap, name: &str) -> Option<String> {
    let value = headers.get(name)?;
    value.to_str().ok().map(str::to_owned)
}
/// Value of the first cookie called `name` in the request, with surrounding whitespace
/// trimmed.
///
/// Every `Cookie` header line is searched. HTTP/2 clients may split cookies across several
/// `cookie` fields, so reading only the first line can miss the one we need. Lines that are
/// not valid visible ASCII are skipped, as are `name=value` pairs without an `=`.
fn cookie_value(headers: &HeaderMap, name: &str) -> Option<String> {
    headers
        .get_all(header::COOKIE)
        .iter()
        .filter_map(|line| line.to_str().ok())
        .flat_map(|line| line.split(';'))
        .filter_map(|pair| pair.split_once('='))
        .find(|(key, _)| key.trim() == name)
        .map(|(_, value)| value.trim().to_string())
}
/// Double-submit CSRF check, skipped when `SecurityConfig::csrf` is off.
///
/// The submitted token comes from the `CSRF_HEADER` header when that header is present and
/// non-empty, otherwise from `form_csrf`. It must be at least 16 bytes long and byte-equal
/// to the `CSRF_COOKIE` cookie. The equality test runs in constant time, so response timing
/// does not leak how much of a guessed token matched.
///
/// # Errors
/// `ResumaError::InvalidCsrf` when the cookie or token is missing, too short, or mismatched.
pub fn validate_csrf(headers: &HeaderMap, form_csrf: Option<&str>) -> Result<()> {
    // Compares every byte instead of stopping at the first difference.
    fn tokens_equal(a: &[u8], b: &[u8]) -> bool {
        a.len() == b.len() && a.iter().zip(b).fold(0u8, |diff, (x, y)| diff | (x ^ y)) == 0
    }
    let cfg = config();
    if !cfg.csrf {
        return Ok(());
    }
    let cookie = cookie_value(headers, CSRF_COOKIE).ok_or(ResumaError::InvalidCsrf)?;
    let header = header_str(headers, CSRF_HEADER);
    let token = header
        .as_deref()
        .filter(|h| !h.is_empty())
        .or(form_csrf)
        .ok_or(ResumaError::InvalidCsrf)?;
    if token.len() < 16 || !tokens_equal(token.as_bytes(), cookie.as_bytes()) {
        return Err(ResumaError::InvalidCsrf);
    }
    Ok(())
}
/// Same-origin check for state-changing requests, skipped when `origin_check` is off.
///
/// `host` is the request's `Host` value; any `:port` suffix is ignored, including after an
/// IPv6 literal such as `[::1]:3000`.
/// * If an `Origin` header is present, its host must match (`origin_matches_host`).
/// * Otherwise, if a `Referer` header is present, its host must match (`referer_host_matches`).
/// * A request carrying neither header is allowed.
///
/// # Errors
/// `ResumaError::Forbidden` on a cross-origin `Origin` or a foreign `Referer`.
pub fn validate_origin(headers: &HeaderMap, host: &str) -> Result<()> {
    let cfg = config();
    if !cfg.origin_check {
        return Ok(());
    }
    // IPv6 literals contain ':' inside the brackets, so only drop what follows the ']'.
    // Plain `split(':')` would reduce every IPv6 host to "[" and make them all "match".
    let host = match host.strip_prefix('[').and_then(|rest| rest.find(']')) {
        Some(close) => &host[..close + 2],
        None => host.split(':').next().unwrap_or(host),
    };
    let host = host.to_lowercase();
    if let Some(origin) = header_str(headers, header::ORIGIN.as_str()) {
        if !origin_matches_host(&origin, &host) {
            return Err(ResumaError::Forbidden("cross-origin request".into()));
        }
        return Ok(());
    }
    if let Some(referer) = header_str(headers, header::REFERER.as_str()) {
        if !referer_host_matches(&referer, &host) {
            return Err(ResumaError::Forbidden("invalid referer".into()));
        }
    }
    Ok(())
}
/// Lowercase host of an absolute `http`/`https` URL, with userinfo and port removed.
///
/// The authority ends at the first `/`, `\`, `?` or `#`. Backslash is included because
/// browsers treat it as a path separator for these schemes. Anything up to the last `@` is
/// userinfo, so `http://localhost:x@evil.com` yields `evil.com`. IPv6 literals keep their
/// brackets (`[::1]`). Returns `None` for other schemes (including the opaque `null` origin)
/// or an empty host.
fn url_host(url: &str) -> Option<String> {
    let (scheme, rest) = url.split_once("://")?;
    if !(scheme.eq_ignore_ascii_case("http") || scheme.eq_ignore_ascii_case("https")) {
        return None;
    }
    let authority = rest
        .split(|c| matches!(c, '/' | '\\' | '?' | '#'))
        .next()
        .unwrap_or(rest);
    let host_port = authority.rsplit_once('@').map_or(authority, |(_, h)| h);
    let host = if host_port.starts_with('[') {
        &host_port[..=host_port.find(']')?]
    } else {
        host_port.split(':').next().unwrap_or(host_port)
    };
    if host.is_empty() {
        None
    } else {
        Some(host.to_ascii_lowercase())
    }
}

/// `h` without a leading `www.` label.
fn strip_www(h: &str) -> &str {
    h.strip_prefix("www.").unwrap_or(h)
}

/// Whether an `Origin` header value names `host`, ignoring port and case.
///
/// `www.example.com` and `example.com` are treated as the same site.
fn origin_matches_host(origin: &str, host: &str) -> bool {
    url_host(origin).map_or(false, |h| {
        let host = host.to_ascii_lowercase();
        h == host || strip_www(&h) == strip_www(&host)
    })
}

/// Whether a `Referer` URL's host equals `host` exactly, ignoring port and case.
fn referer_host_matches(referer: &str, host: &str) -> bool {
    url_host(referer).map_or(false, |h| h.eq_ignore_ascii_case(host))
}
/// `Set-Cookie` value storing `token` in the `CSRF_COOKIE` cookie.
///
/// The cookie is site-wide (`Path=/`), `SameSite=Strict` and `HttpOnly`, and also `Secure`
/// when `https` is true. If the resulting string is not a valid header value, the literal
/// `invalid` is returned instead.
pub fn csrf_set_cookie(token: &str, https: bool) -> HeaderValue {
    let mut cookie = format!("{CSRF_COOKIE}={token}; Path=/; SameSite=Strict; HttpOnly");
    if https {
        cookie.push_str("; Secure");
    }
    match HeaderValue::from_str(&cookie) {
        Ok(value) => value,
        Err(_) => HeaderValue::from_static("invalid"),
    }
}
/// Per-response inputs to `apply_security_headers`.
#[derive(Debug, Clone, Default)]
pub struct SecurityHeaderOptions {
/// Nonce placed in the CSP `script-src` / `style-src` directives as `'nonce-…'`.
pub csp_nonce: Option<String>,
/// Response is served over HTTPS: adds HSTS and the CSP `upgrade-insecure-requests` directive.
pub https: bool,
}
/// Renders the Content-Security-Policy header value.
///
/// The fixed baseline is `default-src`, `base-uri`, `form-action` = `'self'`,
/// `object-src`/`frame-ancestors` = `'none'`. On top of that:
/// * `script-src`: `'self'`, the nonce (plus `'strict-dynamic'` if enabled),
///   `'unsafe-eval'` if enabled.
/// * `style-src`: `'self'`, the nonce, `'unsafe-inline'`, Google Fonts CSS.
/// * `img-src`, `font-src`, `connect-src`: the built-in sources shown below.
/// * Each list ends with the extra sources configured in `csp`.
/// * `upgrade-insecure-requests` is added when `https` is true.
///
/// Extra sources that are not a single source expression (empty, or containing whitespace,
/// `;`, `,` or control characters) are dropped. Otherwise a misconfigured value could inject
/// additional directives or a second policy into the header.
pub fn build_content_security_policy(nonce: Option<&str>, https: bool, csp: &CspConfig) -> String {
    // `;` separates directives and `,` separates whole policies, so neither may appear
    // inside a source token.
    fn is_single_source(src: &str) -> bool {
        !src.is_empty()
            && !src
                .chars()
                .any(|c| c == ';' || c == ',' || c.is_whitespace() || c.is_control())
    }
    // Configured extra sources that are safe to splice into a source list.
    fn extra_sources(list: &[String]) -> impl Iterator<Item = &str> {
        list.iter().map(String::as_str).filter(|s| is_single_source(s))
    }

    let mut directives: Vec<String> = vec![
        "default-src 'self'".into(),
        "base-uri 'self'".into(),
        "object-src 'none'".into(),
        "frame-ancestors 'none'".into(),
        "form-action 'self'".into(),
    ];

    let mut script_src = vec!["'self'".to_string()];
    if let Some(nonce) = nonce {
        script_src.push(format!("'nonce-{nonce}'"));
        if csp.strict_dynamic {
            script_src.push("'strict-dynamic'".into());
        }
    }
    if csp.unsafe_eval {
        script_src.push("'unsafe-eval'".into());
    }
    script_src.extend(extra_sources(&csp.script_src).map(str::to_string));
    directives.push(format!("script-src {}", script_src.join(" ")));

    let mut style_src = vec!["'self'".to_string()];
    if let Some(nonce) = nonce {
        style_src.push(format!("'nonce-{nonce}'"));
    }
    style_src.push("'unsafe-inline'".into());
    style_src.push("https://fonts.googleapis.com".into());
    style_src.extend(extra_sources(&csp.style_src).map(str::to_string));
    let style_joined = style_src.join(" ");
    directives.push(format!("style-src {style_joined}"));
    directives.push(format!("style-src-elem {style_joined}"));
    directives.push("style-src-attr 'unsafe-inline'".into());

    let mut img_src = vec!["'self'", "data:", "blob:"];
    img_src.extend(extra_sources(&csp.img_src));
    directives.push(format!("img-src {}", img_src.join(" ")));

    let mut font_src = vec!["'self'", "https://fonts.gstatic.com", "data:"];
    font_src.extend(extra_sources(&csp.font_src));
    directives.push(format!("font-src {}", font_src.join(" ")));

    let mut connect_src = vec!["'self'"];
    connect_src.extend(extra_sources(&csp.connect_src));
    directives.push(format!("connect-src {}", connect_src.join(" ")));

    if https {
        directives.push("upgrade-insecure-requests".into());
    }
    directives.join("; ")
}
/// Adds the standard hardening headers to `response`, overwriting existing ones.
///
/// HSTS (two years, `includeSubDomains; preload`) is added only when `opts.https` is true.
/// A fixed set of framing, sniffing, referrer, permissions and cross-origin headers is
/// always added. The CSP header is added when the global CSP config is enabled; it is the
/// report-only variant when `report_only` is set.
pub fn apply_security_headers(mut response: Response, opts: &SecurityHeaderOptions) -> Response {
    let headers = response.headers_mut();
    if opts.https {
        insert_header(
            headers,
            header::STRICT_TRANSPORT_SECURITY,
            "max-age=63072000; includeSubDomains; preload",
        );
    }
    let baseline = [
        (header::X_FRAME_OPTIONS, "DENY"),
        (header::X_CONTENT_TYPE_OPTIONS, "nosniff"),
        (header::HeaderName::from_static("x-xss-protection"), "0"),
        (header::REFERRER_POLICY, "strict-origin-when-cross-origin"),
        (
            header::HeaderName::from_static("permissions-policy"),
            "camera=(), microphone=(), geolocation=()",
        ),
        (
            header::HeaderName::from_static("cross-origin-opener-policy"),
            "same-origin",
        ),
        (
            header::HeaderName::from_static("cross-origin-resource-policy"),
            "same-origin",
        ),
        (header::HeaderName::from_static("x-dns-prefetch-control"), "off"),
    ];
    for (name, value) in baseline {
        insert_header(headers, name, value);
    }
    let csp = config().csp;
    if csp.enabled {
        let policy = build_content_security_policy(opts.csp_nonce.as_deref(), opts.https, &csp);
        let name = match csp.report_only {
            true => header::HeaderName::from_static("content-security-policy-report-only"),
            false => header::CONTENT_SECURITY_POLICY,
        };
        insert_header(headers, name, &policy);
    }
    response
}
/// Sets header `name` to `value`, replacing any existing value.
///
/// A `value` that is not a legal header value is silently skipped rather than panicking.
fn insert_header(headers: &mut axum::http::HeaderMap, name: header::HeaderName, value: &str) {
    match HeaderValue::from_str(value) {
        Ok(parsed) => {
            headers.insert(name, parsed);
        }
        Err(_) => {}
    }
}
/// Runs all checks for a state-changing request, stopping at the first failure.
///
/// The checks run in this order: rate limit (`bucket`, `limit` per minute for `ip`), then
/// same-origin check against `host`, then CSRF token.
///
/// # Errors
/// The error of the first failing check.
pub fn guard_mutation(
    headers: &HeaderMap,
    host: &str,
    ip: &str,
    bucket: &str,
    limit: u32,
    form_csrf: Option<&str>,
) -> Result<()> {
    check_rate_limit(ip, bucket, limit)
        .and_then(|()| validate_origin(headers, host))
        .and_then(|()| validate_csrf(headers, form_csrf))
}
/// HTTP status for `err`; codes outside the valid range map to 500 Internal Server Error.
pub fn http_status(err: &ResumaError) -> axum::http::StatusCode {
    use axum::http::StatusCode;
    match StatusCode::from_u16(err.status_code()) {
        Ok(code) => code,
        Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
    }
}
/// Cheaply clonable, immutable snapshot of a [`SecurityConfig`].
#[derive(Clone, Default)]
pub struct SecurityState {
    pub config: Arc<SecurityConfig>,
}

impl SecurityState {
    /// Wraps `config` in an `Arc`.
    pub fn new(config: SecurityConfig) -> Self {
        let config = Arc::new(config);
        Self { config }
    }

    /// Snapshot of the process-wide settings as they are right now.
    pub fn current() -> Self {
        let snapshot = config();
        Self::new(snapshot)
    }
}
// Unit tests for origin/referer matching and CSP/security-header generation.
#[cfg(test)]
mod tests {
use super::*;
// Origin hosts match the request host regardless of port, and `www.` is treated as equivalent.
#[test]
fn origin_matches_ignoring_port() {
assert!(origin_matches_host("http://localhost:3000", "localhost"));
assert!(origin_matches_host("http://127.0.0.1:3939", "127.0.0.1"));
assert!(origin_matches_host("https://example.com", "example.com"));
assert!(origin_matches_host(
"https://example.com:8443",
"example.com"
));
assert!(origin_matches_host(
"https://www.example.com:443",
"example.com"
));
}
#[test]
fn origin_rejects_other_hosts() {
assert!(!origin_matches_host("http://evil.test:3000", "localhost"));
assert!(!origin_matches_host(
"https://attacker.example",
"example.com"
));
}
#[test]
fn referer_matches_ignoring_port() {
assert!(referer_host_matches(
"http://localhost:3000/items",
"localhost"
));
assert!(referer_host_matches(
"https://example.com:8443/x",
"example.com"
));
assert!(!referer_host_matches(
"http://evil.test:3000/x",
"localhost"
));
}
// Reads the process-wide config (origin_check), which other tests may replace concurrently.
#[test]
fn validate_origin_allows_same_host_with_port() {
let mut headers = HeaderMap::new();
headers.insert(header::ORIGIN, "http://localhost:3000".parse().unwrap());
assert!(validate_origin(&headers, "localhost:3000").is_ok());
}
// NOTE(review): the remaining CspConfig fields come from the environment, so RESUMA_CSP_*
// variables set in the test environment can change the exact strings asserted here.
#[test]
fn csp_allows_runtime_compiled_handlers() {
let csp = build_content_security_policy(
Some("abc123"),
false,
&CspConfig {
enabled: true,
strict_dynamic: true,
unsafe_eval: true,
..CspConfig::from_env()
},
);
assert!(csp.contains("'nonce-abc123'"));
assert!(csp.contains("'strict-dynamic'"));
assert!(csp.contains("'unsafe-eval'"));
assert!(csp.contains("style-src 'self' 'nonce-abc123' 'unsafe-inline'"));
assert!(csp.contains("style-src-elem 'self' 'nonce-abc123' 'unsafe-inline'"));
assert!(csp.contains("style-src-attr 'unsafe-inline'"));
assert!(csp.contains("img-src 'self' data: blob:"));
}
#[test]
fn csp_extra_img_src() {
let policy = build_content_security_policy(
Some("n1"),
false,
&CspConfig {
enabled: true,
img_src: vec!["https://images.pexels.com".into()],
..CspConfig::from_env()
},
);
assert!(policy.contains("img-src 'self' data: blob: https://images.pexels.com"));
}
// Replaces the global CONFIG via `configure` and does not restore it afterwards.
#[test]
fn csp_omitted_when_disabled() {
configure(SecurityConfig {
csp: CspConfig::disabled(),
..SecurityConfig::from_env()
});
let res = Response::new(axum::body::Body::empty());
let res = apply_security_headers(
res,
&SecurityHeaderOptions {
csp_nonce: Some("abc".into()),
https: false,
},
);
assert!(res.headers().get(header::CONTENT_SECURITY_POLICY).is_none());
}
}