use axum::{
extract::Request,
http::{StatusCode, header},
middleware::Next,
response::{IntoResponse, Response},
};
use rand::{Rng, distributions::Alphanumeric};
/// Creates a new random CSRF token.
///
/// The token is 32 characters long, each drawn independently from the
/// `Alphanumeric` distribution using `rand::thread_rng()`.
pub fn generate_csrf_token() -> String {
    const TOKEN_LEN: usize = 32;

    let mut rng = rand::thread_rng();
    let mut token = String::with_capacity(TOKEN_LEN);
    for _ in 0..TOKEN_LEN {
        // `Alphanumeric` yields raw ASCII bytes; widen each one to a `char`.
        let byte: u8 = rng.sample(Alphanumeric);
        token.push(char::from(byte));
    }
    token
}
/// Looks up the `_token` field in a `key=value&key=value` request body.
///
/// The bytes are decoded as UTF-8 lossily and split on `&`. The first pair
/// whose key is exactly `_token` wins. Its value is returned verbatim, with no
/// percent-decoding, and is cut off at a second `=` if one appears. Pairs
/// without any `=` are skipped. Returns `None` when no such pair exists.
fn extract_token_from_body(bytes: &[u8]) -> Option<String> {
    String::from_utf8_lossy(bytes)
        .split('&')
        .find_map(|pair| {
            let mut pieces = pair.split('=');
            match (pieces.next(), pieces.next()) {
                (Some("_token"), Some(value)) => Some(value.to_owned()),
                _ => None,
            }
        })
}
pub async fn csrf_middleware(req: Request, next: Next) -> Response {
let method = req.method();
if method == axum::http::Method::GET {
let has_cookie = req
.headers()
.get(header::COOKIE)
.and_then(|v| v.to_str().ok())
.map(|cookie_str| cookie_str.contains("rullst_csrf="))
.unwrap_or(false);
if !has_cookie {
let token = generate_csrf_token();
let mut response = next.run(req).await;
if let Ok(cookie_val) = header::HeaderValue::from_str(&format!(
"rullst_csrf={}; Path=/; SameSite=Lax; HttpOnly",
token
)) {
response
.headers_mut()
.append(header::SET_COOKIE, cookie_val);
}
return response;
}
return next.run(req).await;
}
let csrf_cookie = req
.headers()
.get(header::COOKIE)
.and_then(|v| v.to_str().ok())
.and_then(|cookie_str| {
for cookie in cookie_str.split(';') {
let trimmed = cookie.trim();
if let Some(stripped) = trimmed.strip_prefix("rullst_csrf=") {
return Some(stripped.to_string());
}
}
None
});
let Some(cookie_token) = csrf_cookie else {
return (StatusCode::FORBIDDEN, "CSRF token cookie missing").into_response();
};
let header_token = req
.headers()
.get("X-CSRF-Token")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
if let Some(token) = header_token {
if token == cookie_token {
return next.run(req).await;
}
return (StatusCode::FORBIDDEN, "Invalid CSRF token").into_response();
}
let (parts, body) = req.into_parts();
let bytes = match axum::body::to_bytes(body, 1024 * 1024).await {
Ok(b) => b,
Err(_) => return (StatusCode::BAD_REQUEST, "Failed to read request body").into_response(),
};
let body_token = extract_token_from_body(&bytes);
let reconstructed_req = Request::from_parts(parts, axum::body::Body::from(bytes));
if let Some(token) = body_token
&& token == cookie_token
{
return next.run(reconstructed_req).await;
}
(StatusCode::FORBIDDEN, "Invalid or missing CSRF token").into_response()
}
/// Adds a fixed set of security response headers to every response.
///
/// Headers are written with `HeaderMap::insert`, so they replace any value the
/// inner handler already set for the same header name.
pub async fn headers_middleware(req: Request, next: Next) -> Response {
    let mut response = next.run(req).await;
    let security_headers = [
        (header::X_FRAME_OPTIONS, "DENY"),
        (header::X_CONTENT_TYPE_OPTIONS, "nosniff"),
        // `1; mode=block` turned on the legacy browser XSS auditor, which has been
        // removed from modern browsers and could itself be abused to create XSS in
        // older ones. OWASP and MDN recommend explicitly disabling it with `0`.
        (header::X_XSS_PROTECTION, "0"),
        (header::REFERRER_POLICY, "strict-origin-when-cross-origin"),
        (
            header::STRICT_TRANSPORT_SECURITY,
            "max-age=31536000; includeSubDomains",
        ),
    ];
    let headers = response.headers_mut();
    for (name, value) in security_headers {
        headers.insert(name, header::HeaderValue::from_static(value));
    }
    response
}