arbiter_session/
any_store.rs1use crate::error::SessionError;
4use crate::model::{SessionId, TaskSession};
5use crate::storage_store::StorageBackedSessionStore;
6use crate::store::{CreateSessionRequest, SessionStore};
7
8#[derive(Clone)]
10pub enum AnySessionStore {
11 InMemory(SessionStore),
12 StorageBacked(StorageBackedSessionStore),
13}
14
15impl AnySessionStore {
16 pub async fn create(&self, req: CreateSessionRequest) -> TaskSession {
18 match self {
19 AnySessionStore::InMemory(s) => s.create(req).await,
20 AnySessionStore::StorageBacked(s) => s.create(req).await,
21 }
22 }
23
24 pub async fn use_session(
26 &self,
27 session_id: SessionId,
28 tool_name: &str,
29 ) -> Result<TaskSession, SessionError> {
30 match self {
31 AnySessionStore::InMemory(s) => s.use_session(session_id, tool_name).await,
32 AnySessionStore::StorageBacked(s) => s.use_session(session_id, tool_name).await,
33 }
34 }
35
36 pub async fn use_session_batch(
42 &self,
43 session_id: SessionId,
44 tool_names: &[&str],
45 ) -> Result<TaskSession, SessionError> {
46 match self {
47 AnySessionStore::InMemory(s) => s.use_session_batch(session_id, tool_names).await,
48 AnySessionStore::StorageBacked(s) => s.use_session_batch(session_id, tool_names).await,
49 }
50 }
51
52 pub async fn close(&self, session_id: SessionId) -> Result<TaskSession, SessionError> {
54 match self {
55 AnySessionStore::InMemory(s) => s.close(session_id).await,
56 AnySessionStore::StorageBacked(s) => s.close(session_id).await,
57 }
58 }
59
60 pub async fn get(&self, session_id: SessionId) -> Result<TaskSession, SessionError> {
62 match self {
63 AnySessionStore::InMemory(s) => s.get(session_id).await,
64 AnySessionStore::StorageBacked(s) => s.get(session_id).await,
65 }
66 }
67
68 pub async fn list_all(&self) -> Vec<TaskSession> {
70 match self {
71 AnySessionStore::InMemory(s) => s.list_all().await,
72 AnySessionStore::StorageBacked(s) => s.list_all().await,
73 }
74 }
75
76 pub async fn count_active_for_agent(&self, agent_id: uuid::Uuid) -> u64 {
80 match self {
81 AnySessionStore::InMemory(s) => s.count_active_for_agent(agent_id).await,
82 AnySessionStore::StorageBacked(s) => s.count_active_for_agent(agent_id).await,
83 }
84 }
85
86 pub async fn close_sessions_for_agent(&self, agent_id: uuid::Uuid) -> usize {
90 match self {
91 AnySessionStore::InMemory(s) => s.close_sessions_for_agent(agent_id).await,
92 AnySessionStore::StorageBacked(s) => s.close_sessions_for_agent(agent_id).await,
93 }
94 }
95
96 pub async fn cleanup_expired(&self) -> usize {
98 match self {
99 AnySessionStore::InMemory(s) => s.cleanup_expired().await,
100 AnySessionStore::StorageBacked(s) => s.cleanup_expired().await,
101 }
102 }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::DataSensitivity;

    /// Builds a request for a one-hour, ten-call session that may only use `read_file`.
    fn read_file_request() -> CreateSessionRequest {
        CreateSessionRequest {
            agent_id: uuid::Uuid::new_v4(),
            delegation_chain_snapshot: vec![],
            declared_intent: "test intent".into(),
            authorized_tools: vec!["read_file".into()],
            time_limit: chrono::Duration::hours(1),
            call_budget: 10,
            rate_limit_per_minute: None,
            rate_limit_window_secs: 60,
            data_sensitivity_ceiling: DataSensitivity::Internal,
        }
    }

    /// Walks one session through create -> use -> get -> list -> count -> close via the
    /// `InMemory` variant. It checks that each forwarded call reaches the backend and
    /// that a closed session rejects further use.
    #[tokio::test]
    async fn any_store_in_memory_dispatch() {
        let store = AnySessionStore::InMemory(SessionStore::new());

        let created = store.create(read_file_request()).await;
        assert_eq!(created.calls_made, 0);
        assert!(created.is_active());
        let id = created.session_id;

        let used = store.use_session(id, "read_file").await.unwrap();
        assert_eq!(used.calls_made, 1);

        let reloaded = store.get(id).await.unwrap();
        assert_eq!(reloaded.calls_made, 1);
        assert_eq!(reloaded.declared_intent, "test intent");

        assert_eq!(store.list_all().await.len(), 1);
        assert_eq!(store.count_active_for_agent(created.agent_id).await, 1);

        let closed = store.close(id).await.unwrap();
        assert_eq!(closed.status, crate::model::SessionStatus::Closed);

        // Once closed, the session must refuse further tool use.
        let rejected = store.use_session(id, "read_file").await;
        assert!(matches!(rejected, Err(SessionError::AlreadyClosed(_))));
    }
}