acton_htmx/extractors/
csrf.rs

1//! CSRF token extractors for Axum handlers
2//!
3//! Provides extractors that allow handlers to access CSRF tokens for rendering in templates.
4
5use crate::agents::{CsrfToken, GetOrCreateToken};
6use crate::auth::session::SessionId;
7use crate::state::ActonHtmxState;
8use acton_reactive::prelude::AgentHandleInterface;
9use axum::{
10    extract::{FromRef, FromRequestParts},
11    http::{request::Parts, StatusCode},
12};
13use std::time::Duration;
14
15/// Extractor for CSRF token
16///
17/// Retrieves or creates a CSRF token for the current session.
18/// Requires SessionMiddleware to be applied first.
19///
20/// # Example
21///
22/// ```rust,ignore
23/// use acton_htmx::extractors::CsrfTokenExtractor;
24/// use axum::{response::Html, extract::State};
25///
26/// async fn render_form(csrf: CsrfTokenExtractor) -> Html<String> {
27///     let token = csrf.token();
28///     Html(format!(
29///         r#"<form method="post">
30///             <input type="hidden" name="_csrf_token" value="{token}">
31///             <button type="submit">Submit</button>
32///         </form>"#
33///     ))
34/// }
35/// ```
36#[derive(Debug, Clone)]
37pub struct CsrfTokenExtractor {
38    token: CsrfToken,
39}
40
41impl CsrfTokenExtractor {
42    /// Get the CSRF token as a string
43    #[must_use]
44    pub fn token(&self) -> &str {
45        self.token.as_str()
46    }
47
48    /// Get the CSRF token value
49    #[must_use]
50    pub const fn value(&self) -> &CsrfToken {
51        &self.token
52    }
53}
54
55impl<S> FromRequestParts<S> for CsrfTokenExtractor
56where
57    S: Send + Sync,
58    ActonHtmxState: axum::extract::FromRef<S>,
59{
60    type Rejection = (StatusCode, String);
61
62    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
63        // Extract state
64        let state = ActonHtmxState::from_ref(state);
65
66        // Get session ID from extensions (set by SessionMiddleware)
67        let session_id = parts
68            .extensions
69            .get::<SessionId>()
70            .cloned()
71            .ok_or_else(|| {
72                (
73                    StatusCode::INTERNAL_SERVER_ERROR,
74                    "Session not found - ensure SessionMiddleware is applied".to_string(),
75                )
76            })?;
77
78        // Get or create CSRF token from CSRF manager
79        let (request, rx) = GetOrCreateToken::new(session_id);
80        state.csrf_manager().send(request).await;
81
82        // Wait for response with timeout
83        let timeout = Duration::from_millis(100);
84        let token = tokio::time::timeout(timeout, rx)
85            .await
86            .map_err(|_| {
87                (
88                    StatusCode::INTERNAL_SERVER_ERROR,
89                    "CSRF token retrieval timeout".to_string(),
90                )
91            })?
92            .map_err(|_| {
93                (
94                    StatusCode::INTERNAL_SERVER_ERROR,
95                    "CSRF token retrieval error".to_string(),
96                )
97            })?;
98
99        Ok(Self { token })
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps the given token in an extractor without going through a request.
    fn extractor_with(token: CsrfToken) -> CsrfTokenExtractor {
        CsrfTokenExtractor { token }
    }

    #[test]
    fn test_csrf_token_extractor_creation() {
        let expected = CsrfToken::generate();
        let extractor = extractor_with(expected.clone());

        assert_eq!(extractor.token(), expected.as_str());
        assert_eq!(extractor.value(), &expected);
    }

    #[test]
    fn test_csrf_token_extractor_debug() {
        let rendered = format!("{:?}", extractor_with(CsrfToken::generate()));
        assert!(rendered.contains("CsrfTokenExtractor"));
    }

    #[test]
    fn test_csrf_token_extractor_clone() {
        let original = extractor_with(CsrfToken::generate());
        let duplicate = original.clone();

        assert_eq!(original.token(), duplicate.token());
    }
}