1use crate::agents::request_reply::{create_request_reply, send_response, ResponseChannel};
17use crate::agents::default_agent_config;
18use crate::auth::session::SessionId;
19use acton_reactive::prelude::*;
20use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
21use chrono::{DateTime, Duration, Utc};
22use rand::Rng;
23use std::collections::HashMap;
24use tokio::sync::oneshot;
25
/// Agent type, parameterised with the `Idle` state, on which
/// `CsrfManagerAgent::configure_handlers` registers the message handlers
/// before calling `start()`.
type CsrfAgentBuilder = ManagedAgent<Idle, CsrfManagerAgent>;
28
/// CSRF token held as its string representation.
///
/// Tokens built by `CsrfToken::generate` are 32 random bytes encoded as
/// URL-safe base64 without padding; `CsrfToken::from_string` wraps any
/// string unchanged.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct CsrfToken(String);
32
33impl CsrfToken {
34 #[must_use]
36 pub fn generate() -> Self {
37 let mut rng = rand::rng();
38 let mut bytes = [0u8; 32];
39 rng.fill(&mut bytes);
40 Self(URL_SAFE_NO_PAD.encode(bytes))
41 }
42
43 #[must_use]
45 pub fn as_str(&self) -> &str {
46 &self.0
47 }
48
49 #[must_use]
51 pub const fn from_string(s: String) -> Self {
52 Self(s)
53 }
54}
55
56impl std::fmt::Display for CsrfToken {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 write!(f, "{}", self.0)
59 }
60}
61
/// A stored token together with the instant after which it is no longer
/// accepted.
#[derive(Clone, Debug)]
struct CsrfTokenData {
    /// The token issued for the session.
    token: CsrfToken,
    /// UTC instant after which `is_expired` reports `true`.
    expires_at: DateTime<Utc>,
}
70
71impl CsrfTokenData {
72 #[must_use]
74 fn new(token: CsrfToken) -> Self {
75 let expires_at = Utc::now() + Duration::hours(24);
76 Self { token, expires_at }
77 }
78
79 #[must_use]
81 fn is_expired(&self) -> bool {
82 Utc::now() > self.expires_at
83 }
84}
85
/// Agent state for CSRF management: at most one token entry per session.
#[derive(Debug, Default, Clone)]
pub struct CsrfManagerAgent {
    /// Current token (with expiry) for each session that has one.
    tokens: HashMap<SessionId, CsrfTokenData>,
}
92
/// Message asking the manager for the session's current CSRF token.
///
/// The handler returns the stored token when it exists and has not expired;
/// otherwise it generates and stores a new one.
#[derive(Clone, Debug)]
pub struct GetOrCreateToken {
    /// Session whose token is requested.
    pub session_id: SessionId,
    /// Channel on which the handler additionally sends the token, if present.
    /// `None` when the message was built with `agent_message`.
    pub response_tx: Option<ResponseChannel<CsrfToken>>,
}
108
109impl GetOrCreateToken {
110 #[must_use]
112 pub fn new(session_id: SessionId) -> (Self, oneshot::Receiver<CsrfToken>) {
113 let (response_tx, rx) = create_request_reply();
114 let request = Self {
115 session_id,
116 response_tx: Some(response_tx),
117 };
118 (request, rx)
119 }
120
121 #[must_use]
123 pub const fn agent_message(session_id: SessionId) -> Self {
124 Self {
125 session_id,
126 response_tx: None,
127 }
128 }
129}
130
/// Message asking the manager to check `token` against the token stored for
/// `session_id`.
///
/// When the check succeeds the handler replaces the stored token with a
/// newly generated one, so the same token does not validate twice.
#[derive(Clone, Debug)]
pub struct ValidateToken {
    /// Session the token is claimed to belong to.
    pub session_id: SessionId,
    /// Token presented by the caller.
    pub token: CsrfToken,
    /// Channel on which the handler additionally sends the verdict, if
    /// present. `None` when the message was built with `agent_message`.
    pub response_tx: Option<ResponseChannel<bool>>,
}
144
145impl ValidateToken {
146 #[must_use]
148 pub fn new(session_id: SessionId, token: CsrfToken) -> (Self, oneshot::Receiver<bool>) {
149 let (response_tx, rx) = create_request_reply();
150 let request = Self {
151 session_id,
152 token,
153 response_tx: Some(response_tx),
154 };
155 (request, rx)
156 }
157
158 #[must_use]
160 pub const fn agent_message(session_id: SessionId, token: CsrfToken) -> Self {
161 Self {
162 session_id,
163 token,
164 response_tx: None,
165 }
166 }
167}
168
/// Message asking the manager to remove any token stored for `session_id`.
#[derive(Clone, Debug)]
pub struct DeleteToken {
    /// Session whose token entry should be removed.
    pub session_id: SessionId,
}
175
impl DeleteToken {
    /// Builds a delete request for `session_id`.
    #[must_use]
    pub const fn new(session_id: SessionId) -> Self {
        Self { session_id }
    }
}
183
/// Message asking the manager to drop every stored token whose expiry has
/// passed.
#[derive(Clone, Debug)]
pub struct CleanupExpired;
187
188impl CsrfManagerAgent {
189 pub async fn spawn(runtime: &mut AgentRuntime) -> anyhow::Result<AgentHandle> {
195 let config = default_agent_config("csrf_manager")?;
196 let builder = runtime.new_agent_with_config::<Self>(config).await;
197 Self::configure_handlers(builder).await
198 }
199
200 async fn configure_handlers(mut builder: CsrfAgentBuilder) -> anyhow::Result<AgentHandle> {
202 builder
203 .mutate_on::<GetOrCreateToken>(|agent, envelope| {
205 let session_id = envelope.message().session_id.clone();
206 let response_tx = envelope.message().response_tx.clone();
207 let reply_envelope = envelope.reply_envelope();
208
209 let token = Self::get_or_create_token_internal(&mut agent.model, &session_id);
210
211 AgentReply::from_async(async move {
212 if let Some(tx) = response_tx {
214 let _ = send_response(tx, token.clone()).await;
215 }
216 let _: () = reply_envelope.send(token).await;
218 })
219 })
220 .mutate_on::<ValidateToken>(|agent, envelope| {
222 let session_id = envelope.message().session_id.clone();
223 let token = envelope.message().token.clone();
224 let response_tx = envelope.message().response_tx.clone();
225 let reply_envelope = envelope.reply_envelope();
226
227 let valid = Self::validate_and_rotate_token(&mut agent.model, &session_id, &token);
228
229 AgentReply::from_async(async move {
230 if let Some(tx) = response_tx {
232 let _ = send_response(tx, valid).await;
233 }
234 let _: () = reply_envelope.send(valid).await;
236 })
237 })
238 .mutate_on::<DeleteToken>(|agent, envelope| {
240 let session_id = envelope.message().session_id.clone();
241 agent.model.tokens.remove(&session_id);
242 AgentReply::immediate()
243 })
244 .mutate_on::<CleanupExpired>(|agent, _envelope| {
246 agent.model.tokens.retain(|_session_id, data| !data.is_expired());
247 tracing::debug!(
248 "Cleaned up expired CSRF tokens, {} tokens remaining",
249 agent.model.tokens.len()
250 );
251 AgentReply::immediate()
252 });
253
254 Ok(builder.start().await)
255 }
256
257 fn get_or_create_token_internal(model: &mut Self, session_id: &SessionId) -> CsrfToken {
259 if let Some(data) = model.tokens.get(session_id) {
260 if !data.is_expired() {
261 return data.token.clone();
262 }
263 }
264
265 let new_token = CsrfToken::generate();
267 model
268 .tokens
269 .insert(session_id.clone(), CsrfTokenData::new(new_token.clone()));
270 new_token
271 }
272
273 fn validate_and_rotate_token(
275 model: &mut Self,
276 session_id: &SessionId,
277 token: &CsrfToken,
278 ) -> bool {
279 let valid = model
280 .tokens
281 .get(session_id)
282 .filter(|data| !data.is_expired() && &data.token == token)
283 .is_some();
284
285 if valid {
286 let new_token = CsrfToken::generate();
287 model
288 .tokens
289 .insert(session_id.clone(), CsrfTokenData::new(new_token));
290 }
291
292 valid
293 }
294}
295
// Unit tests for the token types plus integration-style tests that spawn the
// agent and exchange messages with it.
#[cfg(test)]
mod tests {
    use super::*;

    // Two generated tokens differ and have the expected encoded length.
    #[test]
    fn test_csrf_token_generation() {
        let token1 = CsrfToken::generate();
        let token2 = CsrfToken::generate();

        assert_ne!(token1, token2);

        // 32 bytes encode to 43 characters in unpadded base64.
        assert_eq!(token1.as_str().len(), 43);
    }

    // `Display` prints exactly what `as_str` returns.
    #[test]
    fn test_csrf_token_display() {
        let token = CsrfToken::generate();
        let as_string = format!("{token}");
        assert_eq!(as_string, token.as_str());
    }

    // `from_string` stores the given string unchanged.
    #[test]
    fn test_csrf_token_from_string() {
        let original = "test_token_value";
        let token = CsrfToken::from_string(original.to_string());
        assert_eq!(token.as_str(), original);
    }

    // Freshly created token data keeps the token and is not yet expired.
    #[test]
    fn test_csrf_token_data_creation() {
        let token = CsrfToken::generate();
        let data = CsrfTokenData::new(token.clone());

        assert_eq!(data.token, token);
        assert!(!data.is_expired());
        assert!(data.expires_at > Utc::now());
    }

    // Token data whose expiry lies in the past reports itself as expired.
    #[test]
    fn test_csrf_token_data_expiration() {
        let token = CsrfToken::generate();
        let mut data = CsrfTokenData::new(token);

        // Force the expiry one hour into the past.
        data.expires_at = Utc::now() - Duration::hours(1);

        assert!(data.is_expired());
    }

    // Spawning the agent succeeds.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_csrf_manager_spawn() {
        let mut runtime = ActonApp::launch();
        let result = CsrfManagerAgent::spawn(&mut runtime).await;
        assert!(result.is_ok());
    }

    // Requesting a token twice for the same session yields the same token.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_get_or_create_token() {
        let mut runtime = ActonApp::launch();
        let handle = CsrfManagerAgent::spawn(&mut runtime).await.unwrap();

        let session_id = SessionId::generate();
        let (request, rx) = GetOrCreateToken::new(session_id.clone());

        handle.send(request).await;

        let token1 = rx.await.expect("Failed to receive token");

        // Second request for the same session.
        let (request2, rx2) = GetOrCreateToken::new(session_id);
        handle.send(request2).await;

        let token2 = rx2.await.expect("Failed to receive token");

        assert_eq!(token1, token2);
    }

    // A correct token validates once; after rotation the same token is
    // rejected.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_validate_token_success() {
        let mut runtime = ActonApp::launch();
        let handle = CsrfManagerAgent::spawn(&mut runtime).await.unwrap();

        let session_id = SessionId::generate();

        // Obtain the session's token.
        let (request, rx) = GetOrCreateToken::new(session_id.clone());
        handle.send(request).await;
        let token = rx.await.expect("Failed to receive token");

        // First validation with the issued token succeeds.
        let (validate_request, validate_rx) =
            ValidateToken::new(session_id.clone(), token.clone());
        handle.send(validate_request).await;
        let valid = validate_rx.await.expect("Failed to receive validation result");

        assert!(valid);

        // The stored token was rotated, so the old one no longer matches.
        let (validate_request2, validate_rx2) = ValidateToken::new(session_id, token);
        handle.send(validate_request2).await;
        let valid2 = validate_rx2
            .await
            .expect("Failed to receive validation result");

        assert!(!valid2);
    }

    // A token that was never issued for the session is rejected.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_validate_token_failure() {
        let mut runtime = ActonApp::launch();
        let handle = CsrfManagerAgent::spawn(&mut runtime).await.unwrap();

        let session_id = SessionId::generate();

        // Make sure the session has a stored token.
        let (request, rx) = GetOrCreateToken::new(session_id.clone());
        handle.send(request).await;
        let _token = rx.await.expect("Failed to receive token");

        // Validate with an unrelated, freshly generated token.
        let wrong_token = CsrfToken::generate();
        let (validate_request, validate_rx) = ValidateToken::new(session_id, wrong_token);
        handle.send(validate_request).await;
        let valid = validate_rx.await.expect("Failed to receive validation result");

        assert!(!valid);
    }

    // After a delete, the previously issued token no longer validates.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_delete_token() {
        let mut runtime = ActonApp::launch();
        let handle = CsrfManagerAgent::spawn(&mut runtime).await.unwrap();

        let session_id = SessionId::generate();

        // Obtain the session's token.
        let (request, rx) = GetOrCreateToken::new(session_id.clone());
        handle.send(request).await;
        let token = rx.await.expect("Failed to receive token");

        // Remove it. NOTE(review): the test relies on this message being
        // handled before the validation sent next — confirm the agent
        // processes messages from one handle in send order.
        let delete_request = DeleteToken::new(session_id.clone());
        handle.send(delete_request).await;

        // Validation must now fail because no token is stored.
        let (validate_request, validate_rx) = ValidateToken::new(session_id, token);
        handle.send(validate_request).await;
        let valid = validate_rx.await.expect("Failed to receive validation result");

        assert!(!valid);
    }
}