Skip to main content

arbiter_session/
any_store.rs

//! Unified session store enum that dispatches to either an in-memory or a storage-backed session store.
2
3use crate::error::SessionError;
4use crate::model::{SessionId, TaskSession};
5use crate::storage_store::StorageBackedSessionStore;
6use crate::store::{CreateSessionRequest, SessionStore};
7
8/// A session store that dispatches to either in-memory or storage-backed.
9#[derive(Clone)]
10pub enum AnySessionStore {
11    InMemory(SessionStore),
12    StorageBacked(StorageBackedSessionStore),
13}
14
15impl AnySessionStore {
16    /// Create a new task session and return it.
17    pub async fn create(&self, req: CreateSessionRequest) -> TaskSession {
18        match self {
19            AnySessionStore::InMemory(s) => s.create(req).await,
20            AnySessionStore::StorageBacked(s) => s.create(req).await,
21        }
22    }
23
24    /// Record a tool call against the session.
25    pub async fn use_session(
26        &self,
27        session_id: SessionId,
28        tool_name: &str,
29    ) -> Result<TaskSession, SessionError> {
30        match self {
31            AnySessionStore::InMemory(s) => s.use_session(session_id, tool_name).await,
32            AnySessionStore::StorageBacked(s) => s.use_session(session_id, tool_name).await,
33        }
34    }
35
36    /// Atomically validate and record a batch of tool calls against the session.
37    ///
38    /// Validates ALL tools and budget atomically under a single
39    /// lock acquisition. No budget is consumed unless every tool in the batch
40    /// passes validation.
41    pub async fn use_session_batch(
42        &self,
43        session_id: SessionId,
44        tool_names: &[&str],
45    ) -> Result<TaskSession, SessionError> {
46        match self {
47            AnySessionStore::InMemory(s) => s.use_session_batch(session_id, tool_names).await,
48            AnySessionStore::StorageBacked(s) => s.use_session_batch(session_id, tool_names).await,
49        }
50    }
51
52    /// Close a session.
53    pub async fn close(&self, session_id: SessionId) -> Result<TaskSession, SessionError> {
54        match self {
55            AnySessionStore::InMemory(s) => s.close(session_id).await,
56            AnySessionStore::StorageBacked(s) => s.close(session_id).await,
57        }
58    }
59
60    /// Get a session by ID.
61    pub async fn get(&self, session_id: SessionId) -> Result<TaskSession, SessionError> {
62        match self {
63            AnySessionStore::InMemory(s) => s.get(session_id).await,
64            AnySessionStore::StorageBacked(s) => s.get(session_id).await,
65        }
66    }
67
68    /// List all sessions.
69    pub async fn list_all(&self) -> Vec<TaskSession> {
70        match self {
71            AnySessionStore::InMemory(s) => s.list_all().await,
72            AnySessionStore::StorageBacked(s) => s.list_all().await,
73        }
74    }
75
76    /// Count the number of active sessions for a given agent.
77    ///
78    /// P0: Used to enforce per-agent concurrent session caps.
79    pub async fn count_active_for_agent(&self, agent_id: uuid::Uuid) -> u64 {
80        match self {
81            AnySessionStore::InMemory(s) => s.count_active_for_agent(agent_id).await,
82            AnySessionStore::StorageBacked(s) => s.count_active_for_agent(agent_id).await,
83        }
84    }
85
86    /// Close all active sessions belonging to a specific agent.
87    ///
88    /// Called during agent deactivation.
89    pub async fn close_sessions_for_agent(&self, agent_id: uuid::Uuid) -> usize {
90        match self {
91            AnySessionStore::InMemory(s) => s.close_sessions_for_agent(agent_id).await,
92            AnySessionStore::StorageBacked(s) => s.close_sessions_for_agent(agent_id).await,
93        }
94    }
95
96    /// Remove expired sessions. Returns the number removed.
97    pub async fn cleanup_expired(&self) -> usize {
98        match self {
99            AnySessionStore::InMemory(s) => s.cleanup_expired().await,
100            AnySessionStore::StorageBacked(s) => s.cleanup_expired().await,
101        }
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{DataSensitivity, SessionStatus};

    // Request used by the lifecycle test: a fresh agent, a single authorized
    // tool (`read_file`), a one-hour time limit and a ten-call budget.
    fn read_file_request() -> CreateSessionRequest {
        CreateSessionRequest {
            agent_id: uuid::Uuid::new_v4(),
            delegation_chain_snapshot: vec![],
            declared_intent: "test intent".into(),
            authorized_tools: vec!["read_file".into()],
            time_limit: chrono::Duration::hours(1),
            call_budget: 10,
            rate_limit_per_minute: None,
            rate_limit_window_secs: 60,
            data_sensitivity_ceiling: DataSensitivity::Internal,
        }
    }

    /// Drives one session through create → use → get → list → count → close
    /// via the in-memory variant, then checks that a closed session rejects calls.
    #[tokio::test]
    async fn any_store_in_memory_dispatch() {
        let store = AnySessionStore::InMemory(SessionStore::new());

        // A freshly created session starts active with no calls recorded.
        let session = store.create(read_file_request()).await;
        let id = session.session_id;
        assert_eq!(session.calls_made, 0);
        assert!(session.is_active());

        // One permitted call is reflected in the returned snapshot...
        let after_call = store.use_session(id, "read_file").await.unwrap();
        assert_eq!(after_call.calls_made, 1);

        // ...and in a subsequent lookup.
        let looked_up = store.get(id).await.unwrap();
        assert_eq!(looked_up.calls_made, 1);
        assert_eq!(looked_up.declared_intent, "test intent");

        assert_eq!(store.list_all().await.len(), 1);
        assert_eq!(store.count_active_for_agent(session.agent_id).await, 1);

        let closed = store.close(id).await.unwrap();
        assert_eq!(closed.status, SessionStatus::Closed);

        // A closed session must refuse further tool calls.
        let rejected = store.use_session(id, "read_file").await;
        assert!(matches!(rejected, Err(SessionError::AlreadyClosed(_))));
    }
}