use rand::{distributions::Alphanumeric, Rng};
use axum::{
extract::Request,
middleware::Next,
response::{Response, IntoResponse},
http::{StatusCode, header},
};
/// Produces a fresh random CSRF token.
///
/// The token is 32 characters long. Each character is sampled from `rand`'s
/// `Alphanumeric` distribution (ASCII letters and digits) using the
/// thread-local RNG returned by `rand::thread_rng()`.
pub fn generate_csrf_token() -> String {
    const TOKEN_LEN: usize = 32;
    let mut rng = rand::thread_rng();
    (0..TOKEN_LEN)
        .map(|_| char::from(rng.sample(Alphanumeric)))
        .collect()
}
/// Looks for a `_token` field in an `application/x-www-form-urlencoded` body.
///
/// The body is decoded as UTF-8 lossily and split into `key=value` pairs on `&`.
/// Only the first `=` of each pair separates key from value, so values that
/// themselves contain `=` are kept intact. Both key and value are form-decoded
/// (`+` becomes a space, `%XX` becomes the encoded byte) before use.
///
/// Returns the decoded value of the first `_token` pair, or `None` if no pair
/// with that key (and an `=`) exists.
fn extract_token_from_body(bytes: &[u8]) -> Option<String> {
    let body = String::from_utf8_lossy(bytes);
    body.split('&')
        .filter_map(|pair| pair.split_once('='))
        .find(|(key, _)| decode_form_component(key) == "_token")
        .map(|(_, value)| decode_form_component(value))
}

/// Decodes one key or value of an `application/x-www-form-urlencoded` body.
///
/// `+` maps to a space and a valid `%XX` hex escape maps to that byte. A `%`
/// that is not followed by two hex digits is kept literally, as the WHATWG
/// urlencoded parser does. Invalid UTF-8 in the result is replaced lossily.
fn decode_form_component(raw: &str) -> String {
    let bytes = raw.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'+' => {
                decoded.push(b' ');
                i += 1;
            }
            b'%' => {
                let hi = bytes.get(i + 1).and_then(|&b| char::from(b).to_digit(16));
                let lo = bytes.get(i + 2).and_then(|&b| char::from(b).to_digit(16));
                if let (Some(hi), Some(lo)) = (hi, lo) {
                    // Both digits are < 16, so the combined value always fits in a u8.
                    decoded.push((hi * 16 + lo) as u8);
                    i += 3;
                } else {
                    decoded.push(b'%');
                    i += 1;
                }
            }
            other => {
                decoded.push(other);
                i += 1;
            }
        }
    }
    String::from_utf8_lossy(&decoded).into_owned()
}
pub async fn csrf_middleware(req: Request, next: Next) -> Response {
let method = req.method();
if method == axum::http::Method::GET {
let has_cookie = req.headers()
.get(header::COOKIE)
.and_then(|v| v.to_str().ok())
.map(|cookie_str| cookie_str.contains("rullst_csrf="))
.unwrap_or(false);
if !has_cookie {
let token = generate_csrf_token();
let mut response = next.run(req).await;
if let Ok(cookie_val) = header::HeaderValue::from_str(&format!(
"rullst_csrf={}; Path=/; SameSite=Lax; HttpOnly",
token
)) {
response.headers_mut().append(header::SET_COOKIE, cookie_val);
}
return response;
}
return next.run(req).await;
}
let csrf_cookie = req.headers()
.get(header::COOKIE)
.and_then(|v| v.to_str().ok())
.and_then(|cookie_str| {
for cookie in cookie_str.split(';') {
let trimmed = cookie.trim();
if trimmed.starts_with("rullst_csrf=") {
return Some(trimmed["rullst_csrf=".len()..].to_string());
}
}
None
});
let Some(cookie_token) = csrf_cookie else {
return (StatusCode::FORBIDDEN, "CSRF token cookie missing").into_response();
};
let header_token = req.headers()
.get("X-CSRF-Token")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
if let Some(token) = header_token {
if token == cookie_token {
return next.run(req).await;
}
return (StatusCode::FORBIDDEN, "Invalid CSRF token").into_response();
}
let (parts, body) = req.into_parts();
let bytes = match axum::body::to_bytes(body, 1024 * 1024).await {
Ok(b) => b,
Err(_) => return (StatusCode::BAD_REQUEST, "Failed to read request body").into_response(),
};
let body_token = extract_token_from_body(&bytes);
let reconstructed_req = Request::from_parts(parts, axum::body::Body::from(bytes));
if let Some(token) = body_token {
if token == cookie_token {
return next.run(reconstructed_req).await;
}
}
(StatusCode::FORBIDDEN, "Invalid or missing CSRF token").into_response()
}
/// Adds a fixed set of security-related headers to every response.
///
/// `HeaderMap::insert` is used, so these values replace any the inner handler
/// set for the same header names.
pub async fn headers_middleware(req: Request, next: Next) -> Response {
    let mut response = next.run(req).await;
    let headers = response.headers_mut();
    headers.insert(header::X_FRAME_OPTIONS, header::HeaderValue::from_static("DENY"));
    headers.insert(
        header::X_CONTENT_TYPE_OPTIONS,
        header::HeaderValue::from_static("nosniff"),
    );
    // "0" turns off the legacy browser XSS auditor. Its "1; mode=block" mode is
    // deprecated and has been abused for cross-site information leaks (see the
    // OWASP headers guidance). Use a Content-Security-Policy for XSS protection.
    headers.insert(header::X_XSS_PROTECTION, header::HeaderValue::from_static("0"));
    headers.insert(
        header::REFERRER_POLICY,
        header::HeaderValue::from_static("strict-origin-when-cross-origin"),
    );
    headers.insert(
        header::STRICT_TRANSPORT_SECURITY,
        header::HeaderValue::from_static("max-age=31536000; includeSubDomains"),
    );
    response
}