arbiter_session/
any_store.rs1use crate::error::SessionError;
4use crate::model::{SessionId, TaskSession};
5use crate::storage_store::StorageBackedSessionStore;
6use crate::store::{CreateSessionRequest, SessionStore};
7
8#[derive(Clone)]
10pub enum AnySessionStore {
11 InMemory(SessionStore),
12 StorageBacked(StorageBackedSessionStore),
13}
14
15impl AnySessionStore {
16 pub async fn create(&self, req: CreateSessionRequest) -> TaskSession {
18 match self {
19 AnySessionStore::InMemory(s) => s.create(req).await,
20 AnySessionStore::StorageBacked(s) => s.create(req).await,
21 }
22 }
23
24 pub async fn create_if_under_cap(
26 &self,
27 req: CreateSessionRequest,
28 max_sessions: u64,
29 ) -> Result<TaskSession, crate::error::SessionError> {
30 match self {
31 AnySessionStore::InMemory(s) => s.create_if_under_cap(req, max_sessions).await,
32 AnySessionStore::StorageBacked(s) => s.create_if_under_cap(req, max_sessions).await,
33 }
34 }
35
36 pub async fn use_session(
38 &self,
39 session_id: SessionId,
40 tool_name: &str,
41 requesting_agent_id: Option<uuid::Uuid>,
42 ) -> Result<TaskSession, SessionError> {
43 match self {
44 AnySessionStore::InMemory(s) => {
45 s.use_session(session_id, tool_name, requesting_agent_id)
46 .await
47 }
48 AnySessionStore::StorageBacked(s) => {
49 s.use_session(session_id, tool_name, requesting_agent_id)
50 .await
51 }
52 }
53 }
54
55 pub async fn use_session_batch(
61 &self,
62 session_id: SessionId,
63 tool_names: &[&str],
64 requesting_agent_id: Option<uuid::Uuid>,
65 ) -> Result<TaskSession, SessionError> {
66 match self {
67 AnySessionStore::InMemory(s) => {
68 s.use_session_batch(session_id, tool_names, requesting_agent_id)
69 .await
70 }
71 AnySessionStore::StorageBacked(s) => {
72 s.use_session_batch(session_id, tool_names, requesting_agent_id)
73 .await
74 }
75 }
76 }
77
78 pub async fn close(&self, session_id: SessionId) -> Result<TaskSession, SessionError> {
80 match self {
81 AnySessionStore::InMemory(s) => s.close(session_id).await,
82 AnySessionStore::StorageBacked(s) => s.close(session_id).await,
83 }
84 }
85
86 pub async fn get(&self, session_id: SessionId) -> Result<TaskSession, SessionError> {
88 match self {
89 AnySessionStore::InMemory(s) => s.get(session_id).await,
90 AnySessionStore::StorageBacked(s) => s.get(session_id).await,
91 }
92 }
93
94 pub async fn list_all(&self) -> Vec<TaskSession> {
96 match self {
97 AnySessionStore::InMemory(s) => s.list_all().await,
98 AnySessionStore::StorageBacked(s) => s.list_all().await,
99 }
100 }
101
102 pub async fn count_active_for_agent(&self, agent_id: uuid::Uuid) -> u64 {
106 match self {
107 AnySessionStore::InMemory(s) => s.count_active_for_agent(agent_id).await,
108 AnySessionStore::StorageBacked(s) => s.count_active_for_agent(agent_id).await,
109 }
110 }
111
112 pub async fn close_sessions_for_agent(&self, agent_id: uuid::Uuid) -> usize {
116 match self {
117 AnySessionStore::InMemory(s) => s.close_sessions_for_agent(agent_id).await,
118 AnySessionStore::StorageBacked(s) => s.close_sessions_for_agent(agent_id).await,
119 }
120 }
121
122 pub async fn cleanup_expired(&self) -> usize {
124 match self {
125 AnySessionStore::InMemory(s) => s.cleanup_expired().await,
126 AnySessionStore::StorageBacked(s) => s.cleanup_expired().await,
127 }
128 }
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{DataSensitivity, SessionStatus};

    /// The single tool the test session is authorized for and invokes.
    const TOOL: &str = "read_file";

    /// Builds the request used by the dispatch test: a fresh random agent,
    /// one authorized tool, a one-hour limit and a budget of ten calls.
    fn sample_request() -> CreateSessionRequest {
        CreateSessionRequest {
            agent_id: uuid::Uuid::new_v4(),
            delegation_chain_snapshot: vec![],
            declared_intent: "test intent".into(),
            authorized_tools: vec![TOOL.into()],
            authorized_credentials: vec![],
            time_limit: chrono::Duration::hours(1),
            call_budget: 10,
            rate_limit_per_minute: None,
            rate_limit_window_secs: 60,
            data_sensitivity_ceiling: DataSensitivity::Internal,
        }
    }

    /// Walks one session through create, use, get, list, count and close on
    /// the `InMemory` variant, checking that each call reaches the backend.
    #[tokio::test]
    async fn any_store_in_memory_dispatch() {
        let store = AnySessionStore::InMemory(SessionStore::new());

        // A freshly created session is active and has made no calls.
        let created = store.create(sample_request()).await;
        assert_eq!(created.calls_made, 0);
        assert!(created.is_active());

        let id = created.session_id;

        // One tool use bumps the call counter on the returned session.
        let after_use = store.use_session(id, TOOL, None).await.unwrap();
        assert_eq!(after_use.calls_made, 1);

        // The counter is also visible through a fresh lookup.
        let fetched = store.get(id).await.unwrap();
        assert_eq!(fetched.calls_made, 1);
        assert_eq!(fetched.declared_intent, "test intent");

        // Exactly one session exists, and it counts as active for its agent.
        assert_eq!(store.list_all().await.len(), 1);
        assert_eq!(store.count_active_for_agent(created.agent_id).await, 1);

        let closed = store.close(id).await.unwrap();
        assert_eq!(closed.status, SessionStatus::Closed);

        // A closed session must reject further tool use.
        let rejected = store.use_session(id, TOOL, None).await;
        assert!(matches!(rejected, Err(SessionError::AlreadyClosed(_))));
    }
}