Skip to main content

arbiter_session/
any_store.rs

1//! Unified session store enum dispatching to either in-memory or storage-backed.
2
3use crate::error::SessionError;
4use crate::model::{SessionId, TaskSession};
5use crate::storage_store::StorageBackedSessionStore;
6use crate::store::{CreateSessionRequest, SessionStore};
7
8/// A session store that dispatches to either in-memory or storage-backed.
9#[derive(Clone)]
10pub enum AnySessionStore {
11    InMemory(SessionStore),
12    StorageBacked(StorageBackedSessionStore),
13}
14
15impl AnySessionStore {
16    /// Create a new task session and return it.
17    pub async fn create(&self, req: CreateSessionRequest) -> TaskSession {
18        match self {
19            AnySessionStore::InMemory(s) => s.create(req).await,
20            AnySessionStore::StorageBacked(s) => s.create(req).await,
21        }
22    }
23
24    /// Atomically check per-agent session cap and create if under the limit.
25    pub async fn create_if_under_cap(
26        &self,
27        req: CreateSessionRequest,
28        max_sessions: u64,
29    ) -> Result<TaskSession, crate::error::SessionError> {
30        match self {
31            AnySessionStore::InMemory(s) => s.create_if_under_cap(req, max_sessions).await,
32            AnySessionStore::StorageBacked(s) => s.create_if_under_cap(req, max_sessions).await,
33        }
34    }
35
36    /// Record a tool call against the session.
37    pub async fn use_session(
38        &self,
39        session_id: SessionId,
40        tool_name: &str,
41        requesting_agent_id: Option<uuid::Uuid>,
42    ) -> Result<TaskSession, SessionError> {
43        match self {
44            AnySessionStore::InMemory(s) => {
45                s.use_session(session_id, tool_name, requesting_agent_id)
46                    .await
47            }
48            AnySessionStore::StorageBacked(s) => {
49                s.use_session(session_id, tool_name, requesting_agent_id)
50                    .await
51            }
52        }
53    }
54
55    /// Atomically validate and record a batch of tool calls against the session.
56    ///
57    /// Validates ALL tools and budget atomically under a single
58    /// lock acquisition. No budget is consumed unless every tool in the batch
59    /// passes validation.
60    pub async fn use_session_batch(
61        &self,
62        session_id: SessionId,
63        tool_names: &[&str],
64        requesting_agent_id: Option<uuid::Uuid>,
65    ) -> Result<TaskSession, SessionError> {
66        match self {
67            AnySessionStore::InMemory(s) => {
68                s.use_session_batch(session_id, tool_names, requesting_agent_id)
69                    .await
70            }
71            AnySessionStore::StorageBacked(s) => {
72                s.use_session_batch(session_id, tool_names, requesting_agent_id)
73                    .await
74            }
75        }
76    }
77
78    /// Close a session.
79    pub async fn close(&self, session_id: SessionId) -> Result<TaskSession, SessionError> {
80        match self {
81            AnySessionStore::InMemory(s) => s.close(session_id).await,
82            AnySessionStore::StorageBacked(s) => s.close(session_id).await,
83        }
84    }
85
86    /// Get a session by ID.
87    pub async fn get(&self, session_id: SessionId) -> Result<TaskSession, SessionError> {
88        match self {
89            AnySessionStore::InMemory(s) => s.get(session_id).await,
90            AnySessionStore::StorageBacked(s) => s.get(session_id).await,
91        }
92    }
93
94    /// List all sessions.
95    pub async fn list_all(&self) -> Vec<TaskSession> {
96        match self {
97            AnySessionStore::InMemory(s) => s.list_all().await,
98            AnySessionStore::StorageBacked(s) => s.list_all().await,
99        }
100    }
101
102    /// Count the number of active sessions for a given agent.
103    ///
104    /// P0: Used to enforce per-agent concurrent session caps.
105    pub async fn count_active_for_agent(&self, agent_id: uuid::Uuid) -> u64 {
106        match self {
107            AnySessionStore::InMemory(s) => s.count_active_for_agent(agent_id).await,
108            AnySessionStore::StorageBacked(s) => s.count_active_for_agent(agent_id).await,
109        }
110    }
111
112    /// Close all active sessions belonging to a specific agent.
113    ///
114    /// Called during agent deactivation.
115    pub async fn close_sessions_for_agent(&self, agent_id: uuid::Uuid) -> usize {
116        match self {
117            AnySessionStore::InMemory(s) => s.close_sessions_for_agent(agent_id).await,
118            AnySessionStore::StorageBacked(s) => s.close_sessions_for_agent(agent_id).await,
119        }
120    }
121
122    /// Remove expired sessions. Returns the number removed.
123    pub async fn cleanup_expired(&self) -> usize {
124        match self {
125            AnySessionStore::InMemory(s) => s.cleanup_expired().await,
126            AnySessionStore::StorageBacked(s) => s.cleanup_expired().await,
127        }
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{DataSensitivity, SessionStatus};

    /// The only tool the fixture request authorizes.
    const TOOL: &str = "read_file";

    /// Builds the session request used by the dispatch test: a fresh random
    /// agent, one authorized tool, a one-hour time limit and a budget of ten
    /// calls.
    fn sample_request() -> CreateSessionRequest {
        CreateSessionRequest {
            agent_id: uuid::Uuid::new_v4(),
            delegation_chain_snapshot: Vec::new(),
            declared_intent: "test intent".into(),
            authorized_tools: vec![TOOL.into()],
            authorized_credentials: Vec::new(),
            time_limit: chrono::Duration::hours(1),
            call_budget: 10,
            rate_limit_per_minute: None,
            rate_limit_window_secs: 60,
            data_sensitivity_ceiling: DataSensitivity::Internal,
        }
    }

    /// Walks one session through its whole lifecycle via the `InMemory`
    /// variant to confirm each wrapper method reaches the wrapped store.
    #[tokio::test]
    async fn any_store_in_memory_dispatch() {
        let any_store = AnySessionStore::InMemory(SessionStore::new());

        // A freshly created session is active and has recorded no calls.
        let created = any_store.create(sample_request()).await;
        assert_eq!(created.calls_made, 0);
        assert!(created.is_active());

        let sid = created.session_id;

        // One authorized tool call bumps the call counter.
        let after_call = any_store.use_session(sid, TOOL, None).await.unwrap();
        assert_eq!(after_call.calls_made, 1);

        // Reading the session back reflects the recorded call.
        let looked_up = any_store.get(sid).await.unwrap();
        assert_eq!(looked_up.calls_made, 1);
        assert_eq!(looked_up.declared_intent, "test intent");

        // The store holds exactly the one session created above.
        let sessions = any_store.list_all().await;
        assert_eq!(sessions.len(), 1);

        // The same session is counted as active for its agent.
        let active = any_store.count_active_for_agent(created.agent_id).await;
        assert_eq!(active, 1);

        // Closing reports the closed status.
        let closed = any_store.close(sid).await.unwrap();
        assert_eq!(closed.status, SessionStatus::Closed);

        // Further use of a closed session must be rejected.
        let rejected = any_store.use_session(sid, TOOL, None).await;
        assert!(matches!(rejected, Err(SessionError::AlreadyClosed(_))));
    }
}