// acton_htmx/agents/session_manager.rs

//! Session Manager Agent
//!
//! Actor-based session management using acton-reactive.
//! Implements a hybrid in-memory + Redis storage strategy (the Redis backend is
//! optional and only available with the `redis` feature).
//!
//! This module uses unified message patterns that support both:
//! 1. **Agent-to-Agent**: Using `reply_envelope` for inter-agent communication
//! 2. **Web Handler**: Using optional oneshot channels for request-reply from Axum handlers
//!
//! Messages with optional `response_tx` fields can be used from both contexts.

12use crate::agents::request_reply::{create_request_reply, send_response, ResponseChannel};
13use crate::agents::default_agent_config;
14use crate::auth::session::{FlashMessage, SessionData, SessionId};
15use acton_reactive::prelude::*;
16use chrono::{DateTime, Duration, Utc};
17use std::cmp::Reverse;
18use std::collections::{BinaryHeap, HashMap};
19use tokio::sync::oneshot;
20
21// Type alias for the ManagedAgent builder type
22type SessionAgentBuilder = ManagedAgent<Idle, SessionManagerAgent>;
23
24#[cfg(feature = "redis")]
25use deadpool_redis::Pool as RedisPool;
26
27/// Session manager agent model
28#[derive(Debug, Default, Clone)]
29pub struct SessionManagerAgent {
30    /// In-memory session storage
31    sessions: HashMap<SessionId, SessionData>,
32    /// Expiry queue for cleanup (min-heap by expiration time)
33    expiry_queue: BinaryHeap<Reverse<(DateTime<Utc>, SessionId)>>,
34    /// Optional Redis backend for distributed sessions
35    #[cfg(feature = "redis")]
36    redis: Option<RedisPool>,
37}

// ============================================================================
// Unified Messages (support both web handlers and agent-to-agent)
// ============================================================================

43/// Load a session by ID
44///
45/// Supports both web handler (with response_tx) and agent-to-agent (reply_envelope) patterns.
46#[derive(Clone, Debug)]
47pub struct LoadSession {
48    /// The session ID to load
49    pub session_id: SessionId,
50    /// Optional response channel for web handlers
51    pub response_tx: Option<ResponseChannel<Option<SessionData>>>,
52}
53
54impl LoadSession {
55    /// Create a new load session message for agent-to-agent communication
56    #[must_use]
57    pub const fn new(session_id: SessionId) -> Self {
58        Self {
59            session_id,
60            response_tx: None,
61        }
62    }
63
64    /// Create a new load session request with response channel for web handlers
65    #[must_use]
66    pub fn with_response(session_id: SessionId) -> (Self, oneshot::Receiver<Option<SessionData>>) {
67        let (response_tx, rx) = create_request_reply();
68        let request = Self {
69            session_id,
70            response_tx: Some(response_tx),
71        };
72        (request, rx)
73    }
74}
75
76/// Save session data
77///
78/// Supports both web handler (with response_tx) and agent-to-agent patterns.
79#[derive(Clone, Debug)]
80pub struct SaveSession {
81    /// The session ID to save
82    pub session_id: SessionId,
83    /// The session data to persist
84    pub data: SessionData,
85    /// Optional response channel for confirmation
86    pub response_tx: Option<ResponseChannel<bool>>,
87}
88
89impl SaveSession {
90    /// Create a new save session message (fire-and-forget)
91    #[must_use]
92    pub const fn new(session_id: SessionId, data: SessionData) -> Self {
93        Self {
94            session_id,
95            data,
96            response_tx: None,
97        }
98    }
99
100    /// Create a new save session request with confirmation
101    #[must_use]
102    pub fn with_confirmation(
103        session_id: SessionId,
104        data: SessionData,
105    ) -> (Self, oneshot::Receiver<bool>) {
106        let (response_tx, rx) = create_request_reply();
107        let request = Self {
108            session_id,
109            data,
110            response_tx: Some(response_tx),
111        };
112        (request, rx)
113    }
114}
115
116/// Get and clear flash messages from a session
117///
118/// Supports both web handler (with response_tx) and agent-to-agent (reply_envelope) patterns.
119#[derive(Clone, Debug)]
120pub struct TakeFlashes {
121    /// The session ID to retrieve flashes from
122    pub session_id: SessionId,
123    /// Optional response channel for web handlers
124    pub response_tx: Option<ResponseChannel<Vec<FlashMessage>>>,
125}
126
127impl TakeFlashes {
128    /// Create a new take flashes message for agent-to-agent communication
129    #[must_use]
130    pub const fn new(session_id: SessionId) -> Self {
131        Self {
132            session_id,
133            response_tx: None,
134        }
135    }
136
137    /// Create a new take flashes request with response channel for web handlers
138    #[must_use]
139    pub fn with_response(session_id: SessionId) -> (Self, oneshot::Receiver<Vec<FlashMessage>>) {
140        let (response_tx, rx) = create_request_reply();
141        let request = Self {
142            session_id,
143            response_tx: Some(response_tx),
144        };
145        (request, rx)
146    }
147}
148
149/// Message to delete a session by ID
150#[derive(Clone, Debug)]
151pub struct DeleteSession {
152    /// The session ID to delete
153    pub session_id: SessionId,
154}
155
/// Fire-and-forget request to purge sessions whose expiry time has passed.
#[derive(Debug, Clone)]
pub struct CleanupExpired;
159
160/// Message to add a flash message to a session
161#[derive(Clone, Debug)]
162pub struct AddFlash {
163    /// The session ID to add the flash to
164    pub session_id: SessionId,
165    /// The flash message to add
166    pub message: FlashMessage,
167}
168
169impl SessionManagerAgent {
170    /// Spawn session manager agent without Redis backend
171    ///
172    /// Uses in-memory storage only. Suitable for development or single-instance deployments.
173    ///
174    /// # Errors
175    ///
176    /// Returns error if agent initialization fails
177    pub async fn spawn(runtime: &mut AgentRuntime) -> anyhow::Result<AgentHandle> {
178        let config = default_agent_config("session_manager")?;
179        let builder = runtime.new_agent_with_config::<Self>(config).await;
180        Self::configure_handlers(builder).await
181    }
182
183    /// Spawn session manager with Redis backend
184    ///
185    /// Uses Redis for distributed session storage with in-memory caching.
186    ///
187    /// # Errors
188    ///
189    /// Returns error if agent initialization fails
190    #[cfg(feature = "redis")]
191    pub async fn spawn_with_redis(
192        runtime: &mut AgentRuntime,
193        redis_pool: RedisPool,
194    ) -> anyhow::Result<AgentHandle> {
195        let config = default_agent_config("session_manager")?;
196        let mut builder = runtime.new_agent_with_config::<Self>(config).await;
197        builder.model.redis = Some(redis_pool);
198        Self::configure_handlers(builder).await
199    }
200
201    /// Configure all message handlers for the session manager
202    async fn configure_handlers(mut builder: SessionAgentBuilder) -> anyhow::Result<AgentHandle> {
203        builder
204            // ================================================================
205            // Unified Handlers (support both web and agent-to-agent patterns)
206            // ================================================================
207            .act_on::<LoadSession>(|agent, envelope| {
208                let session_id = envelope.message().session_id.clone();
209                let response_tx = envelope.message().response_tx.clone();
210                let session = agent.model.sessions.get(&session_id).cloned();
211                let reply_envelope = envelope.reply_envelope();
212
213                Box::pin(async move {
214                    // Use validate_and_touch to combine expiry check and touch
215                    let result = session.and_then(|mut data| {
216                        if data.validate_and_touch(Duration::hours(24)) {
217                            Some(data)
218                        } else {
219                            None
220                        }
221                    });
222
223                    // Send response to web handler if channel provided
224                    if let Some(tx) = response_tx {
225                        let _ = send_response(tx, result.clone()).await;
226                    }
227
228                    // Always send reply envelope for agent-to-agent
229                    let _: () = reply_envelope.send(result).await;
230                })
231            })
232            .mutate_on::<SaveSession>(|agent, envelope| {
233                let session_id = envelope.message().session_id.clone();
234                let data = envelope.message().data.clone();
235                let response_tx = envelope.message().response_tx.clone();
236
237                agent
238                    .model
239                    .sessions
240                    .insert(session_id.clone(), data.clone());
241                agent
242                    .model
243                    .expiry_queue
244                    .push(Reverse((data.expires_at, session_id)));
245
246                AgentReply::from_async(async move {
247                    // Send confirmation to web handler if channel provided
248                    if let Some(tx) = response_tx {
249                        let _ = send_response(tx, true).await;
250                    }
251                })
252            })
253            .mutate_on::<TakeFlashes>(|agent, envelope| {
254                let session_id = envelope.message().session_id.clone();
255                let response_tx = envelope.message().response_tx.clone();
256                let reply_envelope = envelope.reply_envelope();
257
258                // Take and clear flash messages atomically
259                let messages = agent
260                    .model
261                    .sessions
262                    .get_mut(&session_id)
263                    .map(|session| std::mem::take(&mut session.flash_messages))
264                    .unwrap_or_default();
265
266                AgentReply::from_async(async move {
267                    // Send response to web handler if channel provided
268                    if let Some(tx) = response_tx {
269                        let _ = send_response(tx, messages.clone()).await;
270                    }
271
272                    // Always send reply envelope for agent-to-agent
273                    let _: () = reply_envelope.send(messages).await;
274                })
275            })
276            .mutate_on::<DeleteSession>(|agent, envelope| {
277                agent.model.sessions.remove(&envelope.message().session_id);
278                AgentReply::immediate()
279            })
280            .mutate_on::<CleanupExpired>(|agent, _envelope| {
281                let now = Utc::now();
282                let mut expired = Vec::new();
283
284                loop {
285                    let should_pop = agent
286                        .model
287                        .expiry_queue
288                        .peek()
289                        .is_some_and(|Reverse((expiry, _))| *expiry <= now);
290
291                    if should_pop {
292                        if let Some(Reverse((_, session_id))) = agent.model.expiry_queue.pop() {
293                            expired.push(session_id);
294                        }
295                    } else {
296                        break;
297                    }
298                }
299
300                for session_id in expired {
301                    agent.model.sessions.remove(&session_id);
302                }
303
304                AgentReply::immediate()
305            })
306            .mutate_on::<AddFlash>(|agent, envelope| {
307                let session_id = envelope.message().session_id.clone();
308                let message = envelope.message().message.clone();
309
310                if let Some(session) = agent.model.sessions.get_mut(&session_id) {
311                    session.flash_messages.push(message);
312                }
313
314                AgentReply::immediate()
315            });
316
317        Ok(builder.start().await)
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// Sleep so the agent can drain previously sent fire-and-forget messages,
    /// which provide no completion signal of their own.
    async fn settle(millis: u64) {
        tokio::time::sleep(tokio::time::Duration::from_millis(millis)).await;
    }

    /// Wait at most one second for the reply on `rx`.
    ///
    /// # Panics
    ///
    /// Panics with `timeout_msg` on timeout, or with "Channel closed" if the
    /// sending half was dropped without replying.
    async fn recv_reply<T>(rx: oneshot::Receiver<T>, timeout_msg: &str) -> T {
        tokio::time::timeout(tokio::time::Duration::from_secs(1), rx)
            .await
            .expect(timeout_msg)
            .expect("Channel closed")
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_session_manager_creation() {
        let mut runtime = ActonApp::launch();
        let result = SessionManagerAgent::spawn(&mut runtime).await;
        assert!(result.is_ok());
        runtime.shutdown_all().await.expect("Failed to shutdown");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_session_save_and_load_with_verification() {
        let mut runtime = ActonApp::launch();
        let session_manager = SessionManagerAgent::spawn(&mut runtime).await.unwrap();

        let session_id = SessionId::generate();
        let mut data = SessionData::new();
        data.set("test_key".to_string(), "test_value".to_string())
            .unwrap();

        session_manager
            .send(SaveSession::new(session_id.clone(), data))
            .await;
        settle(50).await;

        // Load via the web-handler path so the result can be inspected.
        let (request, rx) = LoadSession::with_response(session_id);
        session_manager.send(request).await;

        let loaded = recv_reply(rx, "Timeout waiting for response")
            .await
            .expect("Session should exist");
        let loaded_value: Option<String> = loaded.get("test_key").unwrap();
        assert_eq!(loaded_value, Some("test_value".to_string()));

        runtime.shutdown_all().await.expect("Failed to shutdown");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_session_not_found() {
        let mut runtime = ActonApp::launch();
        let session_manager = SessionManagerAgent::spawn(&mut runtime).await.unwrap();

        // Nothing was ever saved under this freshly generated ID.
        let (request, rx) = LoadSession::with_response(SessionId::generate());
        session_manager.send(request).await;

        let result = recv_reply(rx, "Timeout waiting for response").await;
        assert!(result.is_none(), "Session should not exist");

        runtime.shutdown_all().await.expect("Failed to shutdown");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_session_delete_with_verification() {
        let mut runtime = ActonApp::launch();
        let session_manager = SessionManagerAgent::spawn(&mut runtime).await.unwrap();

        let session_id = SessionId::generate();

        session_manager
            .send(SaveSession::new(session_id.clone(), SessionData::new()))
            .await;
        settle(50).await;

        // The session must exist before we delete it.
        let (request, rx) = LoadSession::with_response(session_id.clone());
        session_manager.send(request).await;
        let before = recv_reply(rx, "Timeout").await;
        assert!(before.is_some(), "Session should exist before deletion");

        session_manager
            .send(DeleteSession {
                session_id: session_id.clone(),
            })
            .await;
        settle(50).await;

        let (request, rx) = LoadSession::with_response(session_id);
        session_manager.send(request).await;
        let after = recv_reply(rx, "Timeout").await;
        assert!(after.is_none(), "Session should not exist after deletion");

        runtime.shutdown_all().await.expect("Failed to shutdown");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_flash_messages_with_verification() {
        let mut runtime = ActonApp::launch();
        let session_manager = SessionManagerAgent::spawn(&mut runtime).await.unwrap();

        let session_id = SessionId::generate();

        // Flashes are only recorded on existing sessions, so save one first.
        session_manager
            .send(SaveSession::new(session_id.clone(), SessionData::new()))
            .await;
        settle(50).await;

        for message in [
            FlashMessage::success("Success message"),
            FlashMessage::error("Error message"),
        ] {
            session_manager
                .send(AddFlash {
                    session_id: session_id.clone(),
                    message,
                })
                .await;
        }
        settle(50).await;

        // Take the flashes; `TakeFlashes` also clears them from the session.
        let (request, rx) = TakeFlashes::with_response(session_id.clone());
        session_manager.send(request).await;

        let flashes = recv_reply(rx, "Timeout waiting for response").await;
        assert_eq!(flashes.len(), 2, "Should have 2 flash messages");
        assert_eq!(flashes[0].message, "Success message");
        assert_eq!(flashes[1].message, "Error message");

        // A second take must come back empty.
        let (request, rx) = TakeFlashes::with_response(session_id);
        session_manager.send(request).await;

        let flashes = recv_reply(rx, "Timeout").await;
        assert_eq!(flashes.len(), 0, "Flashes should be cleared after taking");

        runtime.shutdown_all().await.expect("Failed to shutdown");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_session_expiry_cleanup() {
        let mut runtime = ActonApp::launch();
        let session_manager = SessionManagerAgent::spawn(&mut runtime).await.unwrap();

        let session_id = SessionId::generate();
        let mut data = SessionData::new();
        // Already expired one hour ago.
        data.expires_at = Utc::now() - Duration::hours(1);

        session_manager
            .send(SaveSession::new(session_id.clone(), data))
            .await;
        settle(50).await;

        session_manager.send(CleanupExpired).await;
        settle(50).await;

        let (request, rx) = LoadSession::with_response(session_id);
        session_manager.send(request).await;

        let result = recv_reply(rx, "Timeout").await;
        assert!(result.is_none(), "Expired session should not be returned");

        runtime.shutdown_all().await.expect("Failed to shutdown");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_session_touch_extends_expiry() {
        let mut runtime = ActonApp::launch();
        let session_manager = SessionManagerAgent::spawn(&mut runtime).await.unwrap();

        let session_id = SessionId::generate();
        let mut data = SessionData::new();
        let original_expiry = Utc::now() + Duration::hours(1);
        data.expires_at = original_expiry;

        session_manager
            .send(SaveSession::new(session_id.clone(), data))
            .await;
        settle(50).await;

        // Loading touches the session, which should push its expiry out.
        let (request, rx) = LoadSession::with_response(session_id);
        session_manager.send(request).await;

        let loaded_data = recv_reply(rx, "Timeout")
            .await
            .expect("Session should exist");
        assert!(
            loaded_data.expires_at > original_expiry,
            "Expiry should be extended after touch"
        );

        runtime.shutdown_all().await.expect("Failed to shutdown");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_save_with_confirmation() {
        let mut runtime = ActonApp::launch();
        let session_manager = SessionManagerAgent::spawn(&mut runtime).await.unwrap();

        let (request, rx) =
            SaveSession::with_confirmation(SessionId::generate(), SessionData::new());
        session_manager.send(request).await;

        let confirmed = recv_reply(rx, "Timeout waiting for confirmation").await;
        assert!(confirmed, "Save should be confirmed");

        runtime.shutdown_all().await.expect("Failed to shutdown");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_concurrent_flash_messages() {
        let mut runtime = ActonApp::launch();
        let session_manager = SessionManagerAgent::spawn(&mut runtime).await.unwrap();

        let session_id = SessionId::generate();

        session_manager
            .send(SaveSession::new(session_id.clone(), SessionData::new()))
            .await;
        settle(50).await;

        // Fire ten AddFlash messages from independent tasks.
        let handles: Vec<_> = (0..10)
            .map(|i| {
                let sm = session_manager.clone();
                let sid = session_id.clone();
                tokio::spawn(async move {
                    sm.send(AddFlash {
                        session_id: sid,
                        message: FlashMessage::info(format!("Message {i}")),
                    })
                    .await;
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap();
        }
        settle(100).await;

        let (request, rx) = TakeFlashes::with_response(session_id);
        session_manager.send(request).await;

        let flashes = recv_reply(rx, "Timeout").await;
        assert_eq!(
            flashes.len(),
            10,
            "Should have all 10 flash messages despite concurrent adds"
        );

        runtime.shutdown_all().await.expect("Failed to shutdown");
    }
}