use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use http::HeaderValue;
use http::Method;
use http::StatusCode;
use crate::middleware::IntoMiddleware;
use crate::middleware::Next;
use crate::responder::Responder;
use crate::types::Request;
use crate::types::Response;
/// Configuration for the double-submit-cookie CSRF middleware.
///
/// Create one with `Csrf::new()` (or `Csrf::default()`), adjust it with the chained setters,
/// and convert it into a middleware through its `IntoMiddleware` implementation.
pub struct Csrf {
    /// Name of the cookie that carries the token (default `"csrf_token"`).
    cookie_name: String,
    /// Request header that must echo the cookie's token (default `"x-csrf-token"`).
    header_name: String,
    /// Path prefixes whose state-changing requests skip token validation.
    exempt_paths: Vec<String>,
    /// Whether issued token cookies carry the `Secure` attribute (default `false`).
    secure: bool,
}

impl Default for Csrf {
    fn default() -> Self {
        Csrf {
            cookie_name: String::from("csrf_token"),
            header_name: String::from("x-csrf-token"),
            exempt_paths: vec![],
            secure: false,
        }
    }
}

impl Csrf {
    /// Returns the default configuration: cookie `csrf_token`, header `x-csrf-token`,
    /// no exempt paths, and `secure` disabled.
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the name of the cookie that stores the token.
    pub fn cookie_name(self, name: &str) -> Self {
        Self {
            cookie_name: name.to_owned(),
            ..self
        }
    }

    /// Replaces the name of the request header that must echo the token.
    pub fn header_name(self, name: &str) -> Self {
        Self {
            header_name: name.to_owned(),
            ..self
        }
    }

    /// Adds a path prefix whose requests are not token-checked; may be called repeatedly.
    pub fn exempt(mut self, path: &str) -> Self {
        self.exempt_paths.push(path.to_owned());
        self
    }

    /// Sets whether issued token cookies include the `Secure` attribute.
    pub fn secure(self, secure: bool) -> Self {
        Self { secure, ..self }
    }
}
/// Generates a fresh CSRF token: 32 lowercase hex characters (two 64-bit halves).
///
/// Each half is the output of std's default keyed hasher, built from a new `RandomState`
/// (whose keys std seeds randomly), over the current wall-clock nanoseconds, a process-wide
/// call counter, and the half's index. The clock alone therefore does not determine the
/// token, and the counter keeps inputs distinct even when the clock does not advance
/// between calls.
///
/// NOTE(review): std does not document `RandomState` as a cryptographic RNG; if a CSPRNG
/// crate (e.g. `getrandom`) becomes available to this crate, prefer it here.
fn generate_csrf_token() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    static COUNTER: AtomicU64 = AtomicU64::new(0);

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();
    let seq = COUNTER.fetch_add(1, Ordering::Relaxed);

    // A separate RandomState per half so the two halves use different keys.
    let half = |lane: u64| -> u64 {
        let mut hasher = RandomState::new().build_hasher();
        hasher.write_u128(nanos);
        hasher.write_u64(seq);
        hasher.write_u64(lane);
        hasher.finish()
    };

    format!("{:016x}{:016x}", half(0), half(1))
}
/// Returns `true` for request methods that must pass the CSRF token check.
///
/// Only the RFC 9110 safe methods (`GET`, `HEAD`, `OPTIONS`, `TRACE`) are exempt. Every other
/// method, including `CONNECT` and extension methods such as WebDAV verbs, is treated as
/// state-changing. Unknown methods are therefore checked rather than silently let through.
fn is_unsafe_method(method: &Method) -> bool {
    !matches!(
        *method,
        Method::GET | Method::HEAD | Method::OPTIONS | Method::TRACE
    )
}
impl IntoMiddleware for Csrf {
    /// Turns this configuration into a double-submit-cookie CSRF middleware.
    ///
    /// Per request:
    /// - Methods for which `is_unsafe_method` returns `false`, and requests whose path falls
    ///   under an exempt prefix, are passed to `next` without validation.
    /// - All other requests must carry a non-empty cookie named `cookie_name` whose value
    ///   equals the `header_name` request header. Otherwise they are answered with
    ///   `403 Forbidden` / `"CSRF token mismatch"` and `next` is not called.
    /// - If the request carried no non-empty token cookie, `ensure_csrf_cookie` attaches one to
    ///   the response. An existing token is left in place rather than rotated on every request,
    ///   so concurrent requests and other tabs keep a consistent token.
    fn into_middleware(
        self,
    ) -> impl Fn(Request, Next) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>
           + Clone
           + Send
           + Sync
           + 'static {
        let cookie_name = Arc::new(self.cookie_name);
        let header_name = Arc::new(self.header_name);
        let exempt_paths = Arc::new(self.exempt_paths);
        let secure = self.secure;
        move |req: Request, next: Next| {
            let cookie_name = Arc::clone(&cookie_name);
            let header_name = Arc::clone(&header_name);
            let exempt_paths = Arc::clone(&exempt_paths);
            Box::pin(async move {
                // Read everything needed from `req` up front: `next.run` consumes it. Only
                // owned values are kept so no borrow of `req` lives across the `.await`.
                let cookie_token = req
                    .headers()
                    .get(http::header::COOKIE)
                    .and_then(|v| v.to_str().ok())
                    .and_then(|cookies| find_cookie(cookies, &cookie_name))
                    .filter(|token| !token.is_empty());
                let has_token = cookie_token.is_some();
                let needs_check = is_unsafe_method(req.method())
                    && !path_is_exempt(req.uri().path(), &exempt_paths);

                if needs_check {
                    let header_token = req
                        .headers()
                        .get(header_name.as_str())
                        .and_then(|v| v.to_str().ok());
                    let valid = match (cookie_token.as_deref(), header_token) {
                        (Some(ct), Some(ht)) => tokens_match(ct, ht),
                        _ => false,
                    };
                    if !valid {
                        return (StatusCode::FORBIDDEN, "CSRF token mismatch").into_response();
                    }
                }

                let mut resp = next.run(req).await;
                if !has_token {
                    ensure_csrf_cookie(&mut resp, &cookie_name, secure);
                }
                resp
            })
        }
    }
}

/// Returns the value of cookie `name` from a raw `Cookie` header value, if present.
///
/// Entries are `;`-separated `name=value` pairs; whitespace around names and values is
/// trimmed, entries without `=` are skipped, and the first matching entry wins.
fn find_cookie(cookies: &str, name: &str) -> Option<String> {
    cookies.split(';').find_map(|pair| {
        let (key, value) = pair.trim().split_once('=')?;
        (key.trim() == name).then(|| value.trim().to_string())
    })
}

/// Returns `true` when `path` is covered by one of the exempt prefixes.
///
/// A prefix covers `path` only at a segment boundary: `"/webhook"` covers `/webhook` and
/// `/webhook/stripe` but not `/webhooks-admin`. A prefix that itself ends in `/` (e.g.
/// `"/api/"`) covers everything beneath it.
fn path_is_exempt(path: &str, exempt_paths: &[String]) -> bool {
    exempt_paths.iter().any(|prefix| {
        path.strip_prefix(prefix.as_str()).is_some_and(|rest| {
            rest.is_empty() || rest.starts_with('/') || prefix.ends_with('/')
        })
    })
}

/// Compares two tokens without stopping at the first differing byte.
///
/// Length is compared first (token length is not secret). Equal-length inputs are
/// XOR-folded over every byte, so the comparison time does not depend on how long the
/// matching prefix is. This is best-effort: the optimizer is not formally prevented from
/// changing it.
fn tokens_match(a: &str, b: &str) -> bool {
    a.len() == b.len() && a.bytes().zip(b.bytes()).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
}
/// Appends a `Set-Cookie` header with a freshly generated CSRF token, unless the response
/// already sets a cookie named exactly `cookie_name`.
///
/// The issued cookie has `Path=/` and `SameSite=Strict`, plus `Secure` when `secure` is
/// `true`. No `HttpOnly` attribute is added. The middleware compares this cookie against a
/// request header, so the client must be able to copy the value (presumably from script).
///
/// Best-effort: if the assembled cookie is not a valid header value (e.g. `cookie_name`
/// contains control characters), no cookie is added.
fn ensure_csrf_cookie(resp: &mut Response, cookie_name: &str, secure: bool) {
    // Match `name=` exactly, so a cookie such as `csrf_token_v2=...` set by the handler is not
    // mistaken for the CSRF cookie.
    let already_set = resp
        .headers()
        .get_all(http::header::SET_COOKIE)
        .iter()
        .filter_map(|v| v.to_str().ok())
        .any(|cookie| {
            cookie
                .strip_prefix(cookie_name)
                .is_some_and(|rest| rest.starts_with('='))
        });
    if already_set {
        return;
    }

    let mut cookie = format!(
        "{}={}; Path=/; SameSite=Strict",
        cookie_name,
        generate_csrf_token()
    );
    if secure {
        cookie.push_str("; Secure");
    }
    if let Ok(value) = HeaderValue::from_str(&cookie) {
        resp.headers_mut().append(http::header::SET_COOKIE, value);
    }
}