acton_htmx/extractors/
csrf.rs1use crate::agents::{CsrfToken, GetOrCreateToken};
6use crate::auth::session::SessionId;
7use crate::state::ActonHtmxState;
8use acton_reactive::prelude::AgentHandleInterface;
9use axum::{
10 extract::{FromRef, FromRequestParts},
11 http::{request::Parts, StatusCode},
12};
13use std::time::Duration;
14
15#[derive(Debug, Clone)]
37pub struct CsrfTokenExtractor {
38 token: CsrfToken,
39}
40
41impl CsrfTokenExtractor {
42 #[must_use]
44 pub fn token(&self) -> &str {
45 self.token.as_str()
46 }
47
48 #[must_use]
50 pub const fn value(&self) -> &CsrfToken {
51 &self.token
52 }
53}
54
55impl<S> FromRequestParts<S> for CsrfTokenExtractor
56where
57 S: Send + Sync,
58 ActonHtmxState: axum::extract::FromRef<S>,
59{
60 type Rejection = (StatusCode, String);
61
62 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
63 let state = ActonHtmxState::from_ref(state);
65
66 let session_id = parts
68 .extensions
69 .get::<SessionId>()
70 .cloned()
71 .ok_or_else(|| {
72 (
73 StatusCode::INTERNAL_SERVER_ERROR,
74 "Session not found - ensure SessionMiddleware is applied".to_string(),
75 )
76 })?;
77
78 let (request, rx) = GetOrCreateToken::new(session_id);
80 state.csrf_manager().send(request).await;
81
82 let timeout = Duration::from_millis(100);
84 let token = tokio::time::timeout(timeout, rx)
85 .await
86 .map_err(|_| {
87 (
88 StatusCode::INTERNAL_SERVER_ERROR,
89 "CSRF token retrieval timeout".to_string(),
90 )
91 })?
92 .map_err(|_| {
93 (
94 StatusCode::INTERNAL_SERVER_ERROR,
95 "CSRF token retrieval error".to_string(),
96 )
97 })?;
98
99 Ok(Self { token })
100 }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an extractor around a freshly generated token and also returns
    /// a copy of that token for comparison.
    fn fresh_extractor() -> (CsrfTokenExtractor, CsrfToken) {
        let generated = CsrfToken::generate();
        let wrapper = CsrfTokenExtractor {
            token: generated.clone(),
        };
        (wrapper, generated)
    }

    #[test]
    fn test_csrf_token_extractor_creation() {
        let (extractor, expected) = fresh_extractor();

        assert_eq!(extractor.token(), expected.as_str());
        assert_eq!(extractor.value(), &expected);
    }

    #[test]
    fn test_csrf_token_extractor_debug() {
        let (extractor, _) = fresh_extractor();

        let rendered = format!("{extractor:?}");
        assert!(rendered.contains("CsrfTokenExtractor"));
    }

    #[test]
    fn test_csrf_token_extractor_clone() {
        let (extractor, _) = fresh_extractor();
        let duplicate = extractor.clone();

        assert_eq!(extractor.token(), duplicate.token());
    }
}