use crate::agents::{CsrfToken, GetOrCreateToken};
use crate::auth::session::SessionId;
use crate::state::ActonHtmxState;
use acton_reactive::prelude::AgentHandleInterface;
use axum::{
extract::{FromRef, FromRequestParts},
http::{request::Parts, StatusCode},
};
use std::time::Duration;
/// Request extractor that carries the CSRF token for the current session.
///
/// Handlers obtain it by listing `CsrfTokenExtractor` as an argument; the
/// token is fetched through this type's [`FromRequestParts`] implementation.
#[derive(Clone, Debug)]
pub struct CsrfTokenExtractor {
    /// Token obtained from the CSRF manager for this request's session.
    token: CsrfToken,
}
impl CsrfTokenExtractor {
    /// Borrows the extracted [`CsrfToken`] itself.
    #[must_use]
    pub const fn value(&self) -> &CsrfToken {
        let Self { token } = self;
        token
    }

    /// Borrows the extracted token as a string slice.
    #[must_use]
    pub fn token(&self) -> &str {
        self.value().as_str()
    }
}
impl<S> FromRequestParts<S> for CsrfTokenExtractor
where
S: Send + Sync,
ActonHtmxState: axum::extract::FromRef<S>,
{
type Rejection = (StatusCode, String);
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let state = ActonHtmxState::from_ref(state);
let session_id = parts
.extensions
.get::<SessionId>()
.cloned()
.ok_or_else(|| {
(
StatusCode::INTERNAL_SERVER_ERROR,
"Session not found - ensure SessionMiddleware is applied".to_string(),
)
})?;
let (request, rx) = GetOrCreateToken::new(session_id);
state.csrf_manager().send(request).await;
let timeout = Duration::from_millis(100);
let token = tokio::time::timeout(timeout, rx)
.await
.map_err(|_| {
(
StatusCode::INTERNAL_SERVER_ERROR,
"CSRF token retrieval timeout".to_string(),
)
})?
.map_err(|_| {
(
StatusCode::INTERNAL_SERVER_ERROR,
"CSRF token retrieval error".to_string(),
)
})?;
Ok(Self { token })
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Both accessors expose exactly the token the extractor was built from.
    #[test]
    fn test_csrf_token_extractor_creation() {
        let generated = CsrfToken::generate();
        let extractor = CsrfTokenExtractor {
            token: generated.clone(),
        };

        assert_eq!(generated.as_str(), extractor.token());
        assert_eq!(&generated, extractor.value());
    }

    /// The derived `Debug` output names the type.
    #[test]
    fn test_csrf_token_extractor_debug() {
        let extractor = CsrfTokenExtractor {
            token: CsrfToken::generate(),
        };

        let rendered = format!("{extractor:?}");
        assert!(rendered.contains("CsrfTokenExtractor"));
    }

    /// A clone carries the same token string as the original.
    #[test]
    fn test_csrf_token_extractor_clone() {
        let extractor = CsrfTokenExtractor {
            token: CsrfToken::generate(),
        };

        let duplicate = extractor.clone();
        assert_eq!(duplicate.token(), extractor.token());
    }
}