use crate::cookie::{Cookie, SameSite};
use crate::error::Result;
use crate::middleware::{BoxedMiddleware, Next};
use crate::response::{Response, ResponseBuilder};
use crate::Context;
use base64::Engine;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
#[derive(Debug, Clone)]
pub struct Csrf {
cookie_name: String,
header_name: String,
secure: bool,
same_site: SameSite,
}
impl Default for Csrf {
fn default() -> Self {
Self {
cookie_name: "csrf_token".to_string(),
header_name: "x-csrf-token".to_string(),
secure: true,
same_site: SameSite::Lax,
}
}
}
impl Csrf {
pub fn new() -> Self {
Self::default()
}
pub fn cookie_name(mut self, name: impl Into<String>) -> Self {
self.cookie_name = name.into();
self
}
pub fn header_name(mut self, name: impl Into<String>) -> Self {
self.header_name = name.into();
self
}
pub fn secure(mut self, v: bool) -> Self {
self.secure = v;
self
}
pub fn same_site(mut self, v: SameSite) -> Self {
self.same_site = v;
self
}
pub fn build(self) -> BoxedMiddleware {
let cfg = Arc::new(self);
Arc::new(move |ctx: Context, next: Next| {
let cfg = cfg.clone();
Box::pin(async move {
let existing = ctx.cookie(&cfg.cookie_name);
if is_unsafe(ctx.req.method()) {
let valid = match (&existing, ctx.req.header(&cfg.header_name)) {
(Some(c), Some(h)) => ct_eq(c.as_bytes(), h.as_bytes()),
_ => false,
};
if !valid {
return Ok(forbidden());
}
}
if existing.is_none() {
let token = generate_token();
let cookie = Cookie::new(cfg.cookie_name.clone(), token)
.secure(cfg.secure)
.same_site(cfg.same_site)
.path("/");
let _ = ctx.set_cookie(cookie).await;
}
next(ctx).await
}) as Pin<Box<dyn Future<Output = Result<Response>> + Send>>
})
}
}
pub fn csrf() -> BoxedMiddleware {
Csrf::new().build()
}
fn is_unsafe(method: &hyper::Method) -> bool {
!matches!(
*method,
hyper::Method::GET | hyper::Method::HEAD | hyper::Method::OPTIONS | hyper::Method::TRACE
)
}
fn ct_eq(a: &[u8], b: &[u8]) -> bool {
if a.len() != b.len() {
return false;
}
let mut diff = 0u8;
for (x, y) in a.iter().zip(b.iter()) {
diff |= x ^ y;
}
diff == 0
}
fn generate_token() -> String {
let mut bytes = [0u8; 32];
getrandom::fill(&mut bytes).expect("OS RNG failure generating CSRF token");
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
}
fn forbidden() -> Response {
ResponseBuilder::new()
.status(403)
.text("CSRF token missing or invalid")
.build()
.unwrap_or_else(|_| crate::response::helpers::text("Forbidden").unwrap())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn ct_eq_basic() {
assert!(ct_eq(b"abc", b"abc"));
assert!(!ct_eq(b"abc", b"abd"));
assert!(!ct_eq(b"abc", b"abcd"));
}
#[test]
fn unsafe_methods() {
assert!(is_unsafe(&hyper::Method::POST));
assert!(is_unsafe(&hyper::Method::DELETE));
assert!(!is_unsafe(&hyper::Method::GET));
assert!(!is_unsafe(&hyper::Method::OPTIONS));
}
}