use axum::{
extract::Request,
http::{HeaderMap, Method, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
Json,
};
use serde_json::json;
use std::sync::OnceLock;
/// Origins accepted when the `ALLOWED_ORIGINS` environment variable is not set:
/// `localhost` and `127.0.0.1` on ports 3000 and 5173, plus the `mockforge.dev`
/// and `app.mockforge.dev` HTTPS origins. `is_origin_allowed` compares entries
/// for exact string equality with the origin derived from a request.
const DEFAULT_ALLOWED_ORIGINS: &[&str] = &[
"http://localhost:3000",
"http://localhost:5173",
"http://127.0.0.1:3000",
"http://127.0.0.1:5173",
"https://app.mockforge.dev",
"https://mockforge.dev",
];
/// Process-wide allow-list, populated once by `get_allowed_origins`.
static ALLOWED_ORIGINS: OnceLock<Vec<String>> = OnceLock::new();
/// Reports whether CSRF protection is active.
///
/// Reads `CSRF_ENABLED` on every call. Protection is turned off only when the
/// variable is set to `false` (letter case ignored). If the variable is unset,
/// not valid UTF-8, or holds any other value, protection stays on.
fn is_csrf_enabled() -> bool {
    let Ok(raw) = std::env::var("CSRF_ENABLED") else {
        // Fail safe: no explicit opt-out means the check runs.
        return true;
    };
    raw.to_lowercase() != "false"
}
/// Returns the process-wide origin allow-list, building it on first use.
///
/// The list comes from the comma-separated `ALLOWED_ORIGINS` environment variable.
/// If that variable is unset (or not valid UTF-8), `DEFAULT_ALLOWED_ORIGINS` is used.
///
/// Each configured entry is normalised before use:
/// - surrounding whitespace is trimmed;
/// - trailing `/` characters are removed;
/// - ASCII letters are lower-cased.
///
/// Browsers serialise an origin as lower-case `scheme://host[:port]` with no
/// trailing slash, so an entry such as `" https://App.Example.com/ "` would
/// otherwise never match. Empty entries (from trailing commas or an empty
/// variable) are dropped.
///
/// The result is cached in a `OnceLock`, so changes to `ALLOWED_ORIGINS` after
/// the first call have no effect.
fn get_allowed_origins() -> &'static Vec<String> {
    ALLOWED_ORIGINS.get_or_init(|| match std::env::var("ALLOWED_ORIGINS") {
        Ok(raw) => raw
            .split(',')
            .map(|entry| entry.trim().trim_end_matches('/').to_ascii_lowercase())
            .filter(|entry| !entry.is_empty())
            .collect(),
        Err(_) => DEFAULT_ALLOWED_ORIGINS
            .iter()
            .map(|origin| origin.to_string())
            .collect(),
    })
}
/// Reports whether `origin` (expected to be `scheme://host[:port]`) is on the allow-list.
///
/// An allow-list entry matches in one of two ways:
/// - it is exactly equal to `origin`; or
/// - it is a wildcard of the form `*.example.com`.
///
/// A wildcard matches an `http://` or `https://` origin whose host is one or
/// more dot-separated hostname labels followed by `.example.com`, for example
/// `https://app.example.com` or `http://a.b.example.com`. A wildcard never
/// matches the bare apex (`https://example.com`); list that origin separately.
fn is_origin_allowed(origin: &str) -> bool {
    /// Returns `true` if `origin` is `http(s)://<labels><suffix>`, where `<labels>` is
    /// a non-empty run of dot-separated, non-empty hostname labels.
    fn matches_wildcard(origin: &str, suffix: &str) -> bool {
        let Some(host) = origin
            .strip_prefix("https://")
            .or_else(|| origin.strip_prefix("http://"))
        else {
            return false;
        };
        let Some(subdomain) = host.strip_suffix(suffix) else {
            return false;
        };
        // Only hostname characters may precede the suffix. A bare suffix comparison
        // would also accept "https://.example.com", "https://a..example.com" or
        // "https://evil.com?.example.com".
        !subdomain.is_empty()
            && subdomain.split('.').all(|label| {
                !label.is_empty()
                    && label
                        .bytes()
                        .all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'_')
            })
    }

    let allowed = get_allowed_origins();
    if allowed.iter().any(|entry| entry == origin) {
        return true;
    }
    allowed
        .iter()
        .filter_map(|entry| entry.strip_prefix('*'))
        .filter(|suffix| suffix.starts_with('.'))
        .any(|suffix| matches_wildcard(origin, suffix))
}
/// Determines the origin a request claims to come from.
///
/// The `Origin` header is preferred. It is ignored, and `Referer` is used
/// instead, when it is:
/// - empty;
/// - the literal `"null"` (an opaque origin); or
/// - not valid UTF-8.
///
/// A `Referer` value is reduced to its `scheme://authority` prefix. The
/// authority ends at the first `/`, `?` or `#` after `://`.
///
/// Returns `None` when neither header yields a value, or when the `Referer`
/// contains no `://`.
fn extract_origin(headers: &HeaderMap) -> Option<String> {
    let usable_origin = headers
        .get("Origin")
        .and_then(|value| value.to_str().ok())
        .filter(|value| !value.is_empty() && *value != "null");
    if let Some(origin) = usable_origin {
        return Some(origin.to_string());
    }

    let referer = headers.get("Referer")?.to_str().ok()?;
    let authority_start = referer.find("://")? + 3;
    // Stop at a query or fragment as well as a path. Cutting only at '/' would
    // let e.g. "https://evil.com?x=.example.com" leak its query into the origin,
    // where it could satisfy a wildcard suffix match.
    let authority_len = referer[authority_start..]
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(referer.len() - authority_start);
    Some(referer[..authority_start + authority_len].to_string())
}
/// Returns `true` for the HTTP methods this middleware guards: `POST`, `PUT`,
/// `PATCH` and `DELETE`. Any other method (e.g. `GET`, `HEAD`, `OPTIONS`)
/// skips the origin check.
fn is_state_changing_method(method: &Method) -> bool {
    [Method::POST, Method::PUT, Method::PATCH, Method::DELETE].contains(method)
}
/// Axum middleware that blocks cross-site state-changing requests (CSRF protection).
///
/// The request is forwarded to `next` unchanged when any of these holds:
/// - protection is disabled (`CSRF_ENABLED=false`, letter case ignored);
/// - the method is not `POST`, `PUT`, `PATCH` or `DELETE`;
/// - an `Authorization` header is present;
/// - the origin derived from `Origin`/`Referer` is on the allow-list;
/// - neither an `Origin` nor a usable `Referer` header is present.
///
/// # Errors
///
/// Returns `Err` holding a `403 Forbidden` JSON response in two cases:
/// - the derived origin is not on the allow-list; or
/// - an `Origin` header is present but unusable (`null`, empty, non-UTF-8) and
///   no `Referer` yields an origin to check instead.
pub async fn csrf_middleware(
    headers: HeaderMap,
    request: Request,
    next: Next,
) -> Result<Response, Response> {
    if !is_csrf_enabled() {
        return Ok(next.run(request).await);
    }
    if !is_state_changing_method(request.method()) {
        return Ok(next.run(request).await);
    }
    // NOTE(review): any Authorization header skips the origin check. Browsers
    // attach cached HTTP Basic/Digest credentials to cross-site requests
    // automatically, so this is only safe if those schemes are not accepted.
    // Confirm with the auth layer.
    if headers.contains_key("Authorization") {
        return Ok(next.run(request).await);
    }

    match extract_origin(&headers) {
        Some(origin) if is_origin_allowed(&origin) => Ok(next.run(request).await),
        Some(origin) => {
            tracing::warn!(
                origin = %origin,
                path = %request.uri().path(),
                "CSRF check failed: origin not allowed"
            );
            Err(csrf_error_response().into_response())
        }
        // An Origin header was sent, but it is "null", empty or non-UTF-8, and
        // Referer gave nothing to check. Browsers send `Origin: null` from
        // sandboxed iframes and similar opaque contexts. Combined with
        // `referrerpolicy="no-referrer"`, such a request would otherwise pass
        // through the "no headers" branch below.
        None if headers.contains_key("Origin") => {
            tracing::warn!(
                path = %request.uri().path(),
                "CSRF check failed: opaque or invalid Origin header"
            );
            Err(csrf_error_response().into_response())
        }
        None => {
            tracing::debug!(
                path = %request.uri().path(),
                "Request without Origin/Referer header"
            );
            Ok(next.run(request).await)
        }
    }
}
/// Builds the `403 Forbidden` response returned when the CSRF origin check fails.
///
/// The JSON body has the shape `{"error": {"code": "CSRF_VALIDATION_FAILED", "message": ...}}`.
fn csrf_error_response() -> impl IntoResponse {
    let body = json!({
        "error": {
            "code": "CSRF_VALIDATION_FAILED",
            "message": "Cross-site request forgery validation failed. Please try again."
        }
    });
    (StatusCode::FORBIDDEN, Json(body))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `HeaderMap` from `(name, value)` pairs; panics on an invalid value.
    fn header_map(pairs: &[(&'static str, &str)]) -> HeaderMap {
        let mut map = HeaderMap::new();
        for &(name, value) in pairs {
            map.insert(name, value.parse().unwrap());
        }
        map
    }

    #[test]
    fn test_is_state_changing_method() {
        let mutating = [Method::POST, Method::PUT, Method::PATCH, Method::DELETE];
        let read_only = [Method::GET, Method::HEAD, Method::OPTIONS];
        for method in &mutating {
            assert!(is_state_changing_method(method), "{} should be guarded", method);
        }
        for method in &read_only {
            assert!(!is_state_changing_method(method), "{} should not be guarded", method);
        }
    }

    #[test]
    fn test_is_origin_allowed_exact() {
        for origin in ["http://localhost:3000", "https://app.mockforge.dev"].iter() {
            assert!(is_origin_allowed(origin), "{} should be allowed", origin);
        }
        for origin in ["https://evil.com", "http://localhost:9999"].iter() {
            assert!(!is_origin_allowed(origin), "{} should be rejected", origin);
        }
    }

    #[test]
    fn test_extract_origin_from_header() {
        let headers = header_map(&[("Origin", "https://app.mockforge.dev")]);
        assert_eq!(
            extract_origin(&headers).as_deref(),
            Some("https://app.mockforge.dev")
        );
    }

    #[test]
    fn test_extract_origin_from_referer() {
        let headers = header_map(&[("Referer", "https://app.mockforge.dev/some/path")]);
        assert_eq!(
            extract_origin(&headers).as_deref(),
            Some("https://app.mockforge.dev")
        );
    }

    #[test]
    fn test_extract_origin_prefers_origin_header() {
        let headers = header_map(&[
            ("Origin", "https://origin.example.com"),
            ("Referer", "https://referer.example.com/path"),
        ]);
        assert_eq!(
            extract_origin(&headers).as_deref(),
            Some("https://origin.example.com")
        );
    }

    #[test]
    fn test_extract_origin_empty() {
        let headers = header_map(&[]);
        assert_eq!(extract_origin(&headers), None);
    }

    #[test]
    fn test_extract_origin_null() {
        let headers = header_map(&[("Origin", "null")]);
        assert_eq!(extract_origin(&headers), None);
    }
}