use crate::agents::request_reply::{create_request_reply, send_response, ResponseChannel};
use crate::agents::default_agent_config;
use crate::auth::session::SessionId;
use acton_reactive::prelude::*;
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use chrono::{DateTime, Duration, Utc};
use rand::Rng;
use std::collections::HashMap;
use tokio::sync::oneshot;
/// The CSRF manager agent in its `Idle` state: the form on which message
/// handlers are registered before the agent is started.
type CsrfAgentBuilder = ManagedAgent<Idle, CsrfManagerAgent>;
/// An opaque CSRF token value.
///
/// Equality and hashing are derived from the wrapped string.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct CsrfToken(String);

impl CsrfToken {
    /// Creates a new random token.
    ///
    /// 32 bytes are drawn from `rand::rng()` and encoded as URL-safe base64
    /// without padding, which gives a 43-character string.
    #[must_use]
    pub fn generate() -> Self {
        let mut raw = [0u8; 32];
        rand::rng().fill(&mut raw);
        let encoded = URL_SAFE_NO_PAD.encode(raw);
        Self(encoded)
    }

    /// Borrows the token as a string slice.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }

    /// Wraps an already existing string as a token; the string is stored
    /// as-is, without any validation.
    #[must_use]
    pub const fn from_string(s: String) -> Self {
        Self(s)
    }
}
/// Displays the token as its raw string value.
impl std::fmt::Display for CsrfToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Written straight to the output; width/alignment flags are not applied.
        f.write_str(&self.0)
    }
}
/// A stored token together with the instant after which it stops being accepted.
#[derive(Clone, Debug)]
struct CsrfTokenData {
    /// The token currently issued for the session.
    token: CsrfToken,
    /// Expiry instant in UTC.
    expires_at: DateTime<Utc>,
}

impl CsrfTokenData {
    /// Lifetime of a newly stored token, in hours.
    const LIFETIME_HOURS: i64 = 24;

    /// Stores `token` with an expiry `LIFETIME_HOURS` from the current time.
    #[must_use]
    fn new(token: CsrfToken) -> Self {
        Self {
            token,
            expires_at: Utc::now() + Duration::hours(Self::LIFETIME_HOURS),
        }
    }

    /// Returns `true` once the current time is strictly past `expires_at`.
    #[must_use]
    fn is_expired(&self) -> bool {
        self.expires_at < Utc::now()
    }
}
/// State of the CSRF manager agent: the token entry held for each session.
#[derive(Debug, Default, Clone)]
pub struct CsrfManagerAgent {
/// One entry per session id, holding that session's token and its expiry.
tokens: HashMap<SessionId, CsrfTokenData>,
}
/// Message asking the CSRF manager for a session's token, creating one when
/// the session has none or its token has expired.
#[derive(Clone, Debug)]
pub struct GetOrCreateToken {
    /// Session whose token is requested.
    pub session_id: SessionId,
    /// Channel on which the handler sends the token back, when present.
    pub response_tx: Option<ResponseChannel<CsrfToken>>,
}

impl GetOrCreateToken {
    /// Builds a request together with the receiver on which the token arrives.
    #[must_use]
    pub fn new(session_id: SessionId) -> (Self, oneshot::Receiver<CsrfToken>) {
        let (sender, receiver) = create_request_reply();
        (
            Self {
                session_id,
                response_tx: Some(sender),
            },
            receiver,
        )
    }

    /// Builds a request without a response channel; the handler then delivers
    /// the token only through the envelope's reply path.
    #[must_use]
    pub const fn agent_message(session_id: SessionId) -> Self {
        Self {
            response_tx: None,
            session_id,
        }
    }
}
/// Message asking the CSRF manager whether `token` is the one currently
/// stored for `session_id`.
#[derive(Clone, Debug)]
pub struct ValidateToken {
    /// Session the token is claimed to belong to.
    pub session_id: SessionId,
    /// Token presented by the caller.
    pub token: CsrfToken,
    /// Channel on which the handler sends the verdict back, when present.
    pub response_tx: Option<ResponseChannel<bool>>,
}

impl ValidateToken {
    /// Builds a request together with the receiver on which the verdict arrives.
    #[must_use]
    pub fn new(session_id: SessionId, token: CsrfToken) -> (Self, oneshot::Receiver<bool>) {
        let (sender, receiver) = create_request_reply();
        (
            Self {
                session_id,
                token,
                response_tx: Some(sender),
            },
            receiver,
        )
    }

    /// Builds a request without a response channel; the handler then delivers
    /// the verdict only through the envelope's reply path.
    #[must_use]
    pub const fn agent_message(session_id: SessionId, token: CsrfToken) -> Self {
        Self {
            response_tx: None,
            session_id,
            token,
        }
    }
}
/// Message asking the CSRF manager to forget the token stored for a session.
///
/// No reply is produced for this message.
#[derive(Clone, Debug)]
pub struct DeleteToken {
    /// Session whose token entry should be removed.
    pub session_id: SessionId,
}

impl DeleteToken {
    /// Builds a delete request for `session_id`.
    #[must_use]
    pub const fn new(session_id: SessionId) -> Self {
        Self { session_id }
    }
}
/// Message asking the CSRF manager to drop every token entry that has expired.
#[derive(Clone, Debug)]
pub struct CleanupExpired;
impl CsrfManagerAgent {
pub async fn spawn(runtime: &mut AgentRuntime) -> anyhow::Result<AgentHandle> {
let config = default_agent_config("csrf_manager")?;
let builder = runtime.new_agent_with_config::<Self>(config).await;
Self::configure_handlers(builder).await
}
async fn configure_handlers(mut builder: CsrfAgentBuilder) -> anyhow::Result<AgentHandle> {
builder
.mutate_on::<GetOrCreateToken>(|agent, envelope| {
let session_id = envelope.message().session_id.clone();
let response_tx = envelope.message().response_tx.clone();
let reply_envelope = envelope.reply_envelope();
let token = Self::get_or_create_token_internal(&mut agent.model, &session_id);
AgentReply::from_async(async move {
if let Some(tx) = response_tx {
let _ = send_response(tx, token.clone()).await;
}
let _: () = reply_envelope.send(token).await;
})
})
.mutate_on::<ValidateToken>(|agent, envelope| {
let session_id = envelope.message().session_id.clone();
let token = envelope.message().token.clone();
let response_tx = envelope.message().response_tx.clone();
let reply_envelope = envelope.reply_envelope();
let valid = Self::validate_and_rotate_token(&mut agent.model, &session_id, &token);
AgentReply::from_async(async move {
if let Some(tx) = response_tx {
let _ = send_response(tx, valid).await;
}
let _: () = reply_envelope.send(valid).await;
})
})
.mutate_on::<DeleteToken>(|agent, envelope| {
let session_id = envelope.message().session_id.clone();
agent.model.tokens.remove(&session_id);
AgentReply::immediate()
})
.mutate_on::<CleanupExpired>(|agent, _envelope| {
agent.model.tokens.retain(|_session_id, data| !data.is_expired());
tracing::debug!(
"Cleaned up expired CSRF tokens, {} tokens remaining",
agent.model.tokens.len()
);
AgentReply::immediate()
});
Ok(builder.start().await)
}
fn get_or_create_token_internal(model: &mut Self, session_id: &SessionId) -> CsrfToken {
if let Some(data) = model.tokens.get(session_id) {
if !data.is_expired() {
return data.token.clone();
}
}
let new_token = CsrfToken::generate();
model
.tokens
.insert(session_id.clone(), CsrfTokenData::new(new_token.clone()));
new_token
}
fn validate_and_rotate_token(
model: &mut Self,
session_id: &SessionId,
token: &CsrfToken,
) -> bool {
let valid = model
.tokens
.get(session_id)
.filter(|data| !data.is_expired() && &data.token == token)
.is_some();
if valid {
let new_token = CsrfToken::generate();
model
.tokens
.insert(session_id.clone(), CsrfTokenData::new(new_token));
}
valid
}
}
// Unit tests for the token types plus end-to-end tests that drive a spawned
// CSRF manager agent through its request/response messages.
#[cfg(test)]
mod tests {
use super::*;
// Two generated tokens differ, and each is 43 characters long.
#[test]
fn test_csrf_token_generation() {
let token1 = CsrfToken::generate();
let token2 = CsrfToken::generate();
assert_ne!(token1, token2);
// 32 bytes in unpadded base64 take 43 characters.
assert_eq!(token1.as_str().len(), 43); }
// `Display` output equals the raw string returned by `as_str`.
#[test]
fn test_csrf_token_display() {
let token = CsrfToken::generate();
let as_string = format!("{token}");
assert_eq!(as_string, token.as_str());
}
// `from_string` keeps the given string unchanged.
#[test]
fn test_csrf_token_from_string() {
let original = "test_token_value";
let token = CsrfToken::from_string(original.to_string());
assert_eq!(token.as_str(), original);
}
// Freshly created token data holds the token and is not yet expired.
#[test]
fn test_csrf_token_data_creation() {
let token = CsrfToken::generate();
let data = CsrfTokenData::new(token.clone());
assert_eq!(data.token, token);
assert!(!data.is_expired());
assert!(data.expires_at > Utc::now());
}
// Token data whose expiry lies in the past reports itself as expired.
#[test]
fn test_csrf_token_data_expiration() {
let token = CsrfToken::generate();
let mut data = CsrfTokenData::new(token);
// Force the expiry one hour into the past.
data.expires_at = Utc::now() - Duration::hours(1);
assert!(data.is_expired());
}
// Spawning the agent on a fresh runtime succeeds.
#[tokio::test(flavor = "multi_thread")]
async fn test_csrf_manager_spawn() {
let mut runtime = ActonApp::launch();
let result = CsrfManagerAgent::spawn(&mut runtime).await;
assert!(result.is_ok());
}
// Asking twice for the same session returns the same token.
#[tokio::test(flavor = "multi_thread")]
async fn test_get_or_create_token() {
let mut runtime = ActonApp::launch();
let handle = CsrfManagerAgent::spawn(&mut runtime).await.unwrap();
let session_id = SessionId::generate();
let (request, rx) = GetOrCreateToken::new(session_id.clone());
handle.send(request).await;
let token1 = rx.await.expect("Failed to receive token");
let (request2, rx2) = GetOrCreateToken::new(session_id);
handle.send(request2).await;
let token2 = rx2.await.expect("Failed to receive token");
assert_eq!(token1, token2);
}
// A stored token validates once; validation rotates the stored token, so
// presenting the same token a second time is rejected.
#[tokio::test(flavor = "multi_thread")]
async fn test_validate_token_success() {
let mut runtime = ActonApp::launch();
let handle = CsrfManagerAgent::spawn(&mut runtime).await.unwrap();
let session_id = SessionId::generate();
let (request, rx) = GetOrCreateToken::new(session_id.clone());
handle.send(request).await;
let token = rx.await.expect("Failed to receive token");
let (validate_request, validate_rx) =
ValidateToken::new(session_id.clone(), token.clone());
handle.send(validate_request).await;
let valid = validate_rx.await.expect("Failed to receive validation result");
assert!(valid);
// Same token again: it was replaced by the successful validation above.
let (validate_request2, validate_rx2) = ValidateToken::new(session_id, token);
handle.send(validate_request2).await;
let valid2 = validate_rx2
.await
.expect("Failed to receive validation result");
assert!(!valid2);
}
// A token other than the one stored for the session is rejected.
#[tokio::test(flavor = "multi_thread")]
async fn test_validate_token_failure() {
let mut runtime = ActonApp::launch();
let handle = CsrfManagerAgent::spawn(&mut runtime).await.unwrap();
let session_id = SessionId::generate();
let (request, rx) = GetOrCreateToken::new(session_id.clone());
handle.send(request).await;
let _token = rx.await.expect("Failed to receive token");
let wrong_token = CsrfToken::generate();
let (validate_request, validate_rx) = ValidateToken::new(session_id, wrong_token);
handle.send(validate_request).await;
let valid = validate_rx.await.expect("Failed to receive validation result");
assert!(!valid);
}
// After `DeleteToken`, the previously issued token no longer validates.
// NOTE(review): assumes the agent handles messages in the order they were
// sent, so the delete is applied before the validation - confirm.
#[tokio::test(flavor = "multi_thread")]
async fn test_delete_token() {
let mut runtime = ActonApp::launch();
let handle = CsrfManagerAgent::spawn(&mut runtime).await.unwrap();
let session_id = SessionId::generate();
let (request, rx) = GetOrCreateToken::new(session_id.clone());
handle.send(request).await;
let token = rx.await.expect("Failed to receive token");
let delete_request = DeleteToken::new(session_id.clone());
handle.send(delete_request).await;
let (validate_request, validate_rx) = ValidateToken::new(session_id, token);
handle.send(validate_request).await;
let valid = validate_rx.await.expect("Failed to receive validation result");
assert!(!valid);
}
}