use super::backends::SessionBackend;
use super::session::Session;
use serde::{Deserialize, Serialize};
use std::time::SystemTime;
use subtle::ConstantTimeEq;
use uuid::Uuid;
/// Record persisted in the session by `CsrfSessionManager`.
///
/// Holds the token string together with the moment it was generated.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CsrfTokenData {
    /// Token value returned by `generate_token` and compared by `validate_token`.
    pub token: String,
    /// Wall-clock time (`SystemTime::now()`) captured when the token was generated.
    pub created_at: SystemTime,
}
/// Issues, stores and validates a per-session CSRF token.
///
/// The token is kept in the session under `session_key`, so managers built with
/// different keys do not interfere with each other.
#[derive(Debug, Clone)]
pub struct CsrfSessionManager {
    /// Session key under which the token record is stored.
    session_key: String,
}
impl CsrfSessionManager {
    /// Session key used by [`CsrfSessionManager::new`].
    pub const DEFAULT_SESSION_KEY: &'static str = "_csrf_token";

    /// Creates a manager that stores its token under [`Self::DEFAULT_SESSION_KEY`].
    pub fn new() -> Self {
        Self::with_key(Self::DEFAULT_SESSION_KEY)
    }

    /// Creates a manager that stores its token under `session_key`.
    ///
    /// Accepts anything convertible into a `String`, so both `"key"` and
    /// `"key".to_string()` are valid arguments.
    pub fn with_key(session_key: impl Into<String>) -> Self {
        Self {
            session_key: session_key.into(),
        }
    }

    /// Generates a fresh random (UUID v4) token and stores it in `session`
    /// with the current time, then returns the token string.
    ///
    /// Any token previously stored under this manager's key is replaced.
    ///
    /// # Errors
    ///
    /// Propagates the `serde_json::Error` returned by `Session::set`
    /// (presumably a serialization failure).
    pub fn generate_token<B: SessionBackend>(
        &self,
        session: &mut Session<B>,
    ) -> Result<String, serde_json::Error> {
        let token = Uuid::new_v4().to_string();
        let token_data = CsrfTokenData {
            token: token.clone(),
            created_at: SystemTime::now(),
        };
        session.set(&self.session_key, token_data)?;
        Ok(token)
    }

    /// Returns the token currently stored in `session`, or `Ok(None)` if none is stored.
    ///
    /// # Errors
    ///
    /// Propagates the `serde_json::Error` returned by `Session::get`
    /// (presumably when the stored value cannot be decoded as a `CsrfTokenData`).
    pub fn get_token<B: SessionBackend>(
        &self,
        session: &mut Session<B>,
    ) -> Result<Option<String>, serde_json::Error> {
        let token_data: Option<CsrfTokenData> = session.get(&self.session_key)?;
        Ok(token_data.map(|data| data.token))
    }

    /// Checks `submitted_token` against the token stored in `session`.
    ///
    /// Returns `Ok(false)` when no token is stored. The comparison is done with
    /// `subtle::ConstantTimeEq`. Token age is *not* checked here; use
    /// [`Self::validate_token_with_max_age`] to also enforce expiry.
    ///
    /// # Errors
    ///
    /// Same as [`Self::get_token`].
    pub fn validate_token<B: SessionBackend>(
        &self,
        session: &mut Session<B>,
        submitted_token: &str,
    ) -> Result<bool, serde_json::Error> {
        match self.get_token(session)? {
            Some(token) => Ok(Self::tokens_match(&token, submitted_token)),
            None => Ok(false),
        }
    }

    /// Like [`Self::validate_token`], but also rejects tokens older than `max_age`.
    ///
    /// Age is measured from the `created_at` timestamp recorded by
    /// [`Self::generate_token`]. If that timestamp lies in the future (the wall
    /// clock moved backwards), the token is treated as expired, so validation
    /// fails closed.
    ///
    /// # Errors
    ///
    /// Same as [`Self::get_token`].
    pub fn validate_token_with_max_age<B: SessionBackend>(
        &self,
        session: &mut Session<B>,
        submitted_token: &str,
        max_age: std::time::Duration,
    ) -> Result<bool, serde_json::Error> {
        let token_data: Option<CsrfTokenData> = session.get(&self.session_key)?;
        let data = match token_data {
            Some(data) => data,
            None => return Ok(false),
        };
        // `elapsed` errors when `created_at` is later than now; treat that as expired.
        let fresh = matches!(data.created_at.elapsed(), Ok(age) if age <= max_age);
        Ok(fresh && Self::tokens_match(&data.token, submitted_token))
    }

    /// Replaces the stored token with a newly generated one and returns it.
    ///
    /// After rotation the previous token no longer validates.
    ///
    /// # Errors
    ///
    /// Same as [`Self::generate_token`].
    pub fn rotate_token<B: SessionBackend>(
        &self,
        session: &mut Session<B>,
    ) -> Result<String, serde_json::Error> {
        self.generate_token(session)
    }

    /// Removes the stored token (if any) from `session`.
    pub fn clear_token<B: SessionBackend>(&self, session: &mut Session<B>) {
        session.delete(&self.session_key);
    }

    /// Returns the stored token, generating and storing a new one if none exists.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Self::get_token`] or [`Self::generate_token`].
    pub fn get_or_create_token<B: SessionBackend>(
        &self,
        session: &mut Session<B>,
    ) -> Result<String, serde_json::Error> {
        match self.get_token(session)? {
            Some(token) => Ok(token),
            None => self.generate_token(session),
        }
    }

    /// Byte-wise comparison of two tokens using `subtle::ConstantTimeEq`.
    fn tokens_match(stored: &str, submitted: &str) -> bool {
        stored.as_bytes().ct_eq(submitted.as_bytes()).into()
    }
}
impl Default for CsrfSessionManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sessions::InMemorySessionBackend;

    // The tests are synchronous in practice but keep `#[tokio::test]` in case the
    // session backend expects a runtime to be present.

    #[tokio::test]
    async fn test_csrf_manager_new() {
        let csrf = CsrfSessionManager::new();
        assert_eq!(csrf.session_key, "_csrf_token");
    }

    #[tokio::test]
    async fn test_default_matches_new() {
        assert_eq!(
            CsrfSessionManager::default().session_key,
            CsrfSessionManager::new().session_key
        );
    }

    #[tokio::test]
    async fn test_csrf_manager_with_key() {
        let csrf = CsrfSessionManager::with_key("custom_key".to_string());
        assert_eq!(csrf.session_key, "custom_key");
    }

    #[tokio::test]
    async fn test_generate_token() {
        let backend = InMemorySessionBackend::new();
        let mut session = Session::new(backend);
        let csrf = CsrfSessionManager::new();
        let token = csrf.generate_token(&mut session).unwrap();
        assert!(!token.is_empty());
    }

    #[tokio::test]
    async fn test_get_token() {
        let backend = InMemorySessionBackend::new();
        let mut session = Session::new(backend);
        let csrf = CsrfSessionManager::new();
        assert!(csrf.get_token(&mut session).unwrap().is_none());
        let generated = csrf.generate_token(&mut session).unwrap();
        let stored = csrf.get_token(&mut session).unwrap();
        assert_eq!(stored, Some(generated));
    }

    #[tokio::test]
    async fn test_validate_token() {
        let backend = InMemorySessionBackend::new();
        let mut session = Session::new(backend);
        let csrf = CsrfSessionManager::new();
        let token = csrf.generate_token(&mut session).unwrap();
        assert!(csrf.validate_token(&mut session, &token).unwrap());
        assert!(!csrf.validate_token(&mut session, "wrong_token").unwrap());
    }

    #[tokio::test]
    async fn test_validate_token_rejects_empty_submission() {
        let backend = InMemorySessionBackend::new();
        let mut session = Session::new(backend);
        let csrf = CsrfSessionManager::new();
        csrf.generate_token(&mut session).unwrap();
        assert!(!csrf.validate_token(&mut session, "").unwrap());
    }

    #[tokio::test]
    async fn test_validate_token_no_token_in_session() {
        let backend = InMemorySessionBackend::new();
        let mut session = Session::new(backend);
        let csrf = CsrfSessionManager::new();
        assert!(!csrf.validate_token(&mut session, "any_token").unwrap());
    }

    #[tokio::test]
    async fn test_rotate_token() {
        let backend = InMemorySessionBackend::new();
        let mut session = Session::new(backend);
        let csrf = CsrfSessionManager::new();
        let old_token = csrf.generate_token(&mut session).unwrap();
        let new_token = csrf.rotate_token(&mut session).unwrap();
        assert_ne!(old_token, new_token);
        assert!(!csrf.validate_token(&mut session, &old_token).unwrap());
        assert!(csrf.validate_token(&mut session, &new_token).unwrap());
    }

    #[tokio::test]
    async fn test_clear_token() {
        let backend = InMemorySessionBackend::new();
        let mut session = Session::new(backend);
        let csrf = CsrfSessionManager::new();
        let token = csrf.generate_token(&mut session).unwrap();
        csrf.clear_token(&mut session);
        assert!(csrf.get_token(&mut session).unwrap().is_none());
        assert!(!csrf.validate_token(&mut session, &token).unwrap());
    }

    #[tokio::test]
    async fn test_get_or_create_token() {
        let backend = InMemorySessionBackend::new();
        let mut session = Session::new(backend);
        let csrf = CsrfSessionManager::new();
        let token1 = csrf.get_or_create_token(&mut session).unwrap();
        let token2 = csrf.get_or_create_token(&mut session).unwrap();
        assert_eq!(token1, token2);
    }

    #[tokio::test]
    async fn test_get_or_create_token_creates_if_missing() {
        let backend = InMemorySessionBackend::new();
        let mut session = Session::new(backend);
        let csrf = CsrfSessionManager::new();
        assert!(csrf.get_token(&mut session).unwrap().is_none());
        let token = csrf.get_or_create_token(&mut session).unwrap();
        assert!(!token.is_empty());
        assert_eq!(csrf.get_token(&mut session).unwrap(), Some(token));
    }

    #[tokio::test]
    async fn test_custom_keys_are_isolated() {
        let backend = InMemorySessionBackend::new();
        let mut session = Session::new(backend);
        let default_csrf = CsrfSessionManager::new();
        let custom_csrf = CsrfSessionManager::with_key("custom_key".to_string());
        let token = custom_csrf.generate_token(&mut session).unwrap();
        assert!(default_csrf.get_token(&mut session).unwrap().is_none());
        assert!(!default_csrf.validate_token(&mut session, &token).unwrap());
        assert!(custom_csrf.validate_token(&mut session, &token).unwrap());
    }
}