1use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::sync::RwLock;
13use uuid::Uuid;
14
15use arbiter_storage::{
16 SessionStore as StorageSessionStore, StorageError, StoredDataSensitivity, StoredSession,
17 StoredSessionStatus,
18};
19
20use crate::error::SessionError;
21use crate::model::{DataSensitivity, SessionId, SessionStatus, TaskSession};
22use crate::store::CreateSessionRequest;
23
24#[derive(Clone)]
30pub struct StorageBackedSessionStore {
31 cache: Arc<RwLock<HashMap<SessionId, TaskSession>>>,
32 storage: Arc<dyn StorageSessionStore>,
33}
34
35impl StorageBackedSessionStore {
36 pub async fn new(storage: Arc<dyn StorageSessionStore>) -> Result<Self, StorageError> {
40 let store = Self {
41 cache: Arc::new(RwLock::new(HashMap::new())),
42 storage,
43 };
44
45 store.reload_from_storage().await?;
47
48 Ok(store)
49 }
50
51 async fn reload_from_storage(&self) -> Result<(), StorageError> {
53 let stored_sessions = self.storage.list_sessions().await?;
54 let mut cache = self.cache.write().await;
55 cache.clear();
56 for stored in stored_sessions {
57 if let Ok(session) = stored_to_domain(stored) {
58 cache.insert(session.session_id, session);
59 }
60 }
61 tracing::info!(sessions = cache.len(), "session cache warmed from storage");
62 Ok(())
63 }
64
65 pub async fn create(&self, req: CreateSessionRequest) -> TaskSession {
67 let session = TaskSession {
68 session_id: Uuid::new_v4(),
69 agent_id: req.agent_id,
70 delegation_chain_snapshot: req.delegation_chain_snapshot,
71 declared_intent: req.declared_intent,
72 authorized_tools: req.authorized_tools,
73 time_limit: req.time_limit,
74 call_budget: req.call_budget,
75 calls_made: 0,
76 rate_limit_per_minute: req.rate_limit_per_minute,
77 rate_window_start: chrono::Utc::now(),
78 rate_window_calls: 0,
79 rate_limit_window_secs: req.rate_limit_window_secs,
80 data_sensitivity_ceiling: req.data_sensitivity_ceiling,
81 created_at: chrono::Utc::now(),
82 status: SessionStatus::Active,
83 };
84
85 tracing::info!(
86 session_id = %session.session_id,
87 agent_id = %session.agent_id,
88 intent = %session.declared_intent,
89 budget = session.call_budget,
90 "created task session (storage-backed)"
91 );
92
93 let stored = domain_to_stored(&session);
95 if let Err(e) = self.storage.insert_session(&stored).await {
96 tracing::error!(error = %e, "failed to persist session to storage");
97 }
98
99 let mut cache = self.cache.write().await;
100 cache.insert(session.session_id, session.clone());
101 session
102 }
103
104 pub async fn use_session(
106 &self,
107 session_id: SessionId,
108 tool_name: &str,
109 ) -> Result<TaskSession, SessionError> {
110 let mut cache = self.cache.write().await;
111 let session = cache
112 .get_mut(&session_id)
113 .ok_or(SessionError::NotFound(session_id))?;
114
115 if session.status == SessionStatus::Closed {
116 return Err(SessionError::AlreadyClosed(session_id));
117 }
118
119 if session.is_expired() {
120 session.status = SessionStatus::Expired;
121 let stored = domain_to_stored(session);
123 if let Err(e) = self.storage.update_session(&stored).await {
124 tracing::error!(error = %e, "failed to persist expired session status");
125 }
126 return Err(SessionError::Expired(session_id));
127 }
128
129 if session.is_budget_exceeded() {
130 return Err(SessionError::BudgetExceeded {
131 session_id,
132 limit: session.call_budget,
133 used: session.calls_made,
134 });
135 }
136
137 if !session.is_tool_authorized(tool_name) {
138 return Err(SessionError::ToolNotAuthorized {
139 session_id,
140 tool: tool_name.into(),
141 });
142 }
143
144 if session.check_rate_limit() {
145 return Err(SessionError::RateLimited {
146 session_id,
147 limit_per_minute: session.rate_limit_per_minute.unwrap_or(0),
148 });
149 }
150
151 session.calls_made += 1;
152
153 tracing::debug!(
154 session_id = %session_id,
155 tool = tool_name,
156 calls = session.calls_made,
157 budget = session.call_budget,
158 "session tool call recorded (storage-backed)"
159 );
160
161 let result = session.clone();
162
163 let stored = domain_to_stored(&result);
165 if let Err(e) = self.storage.update_session(&stored).await {
166 tracing::error!(error = %e, "failed to persist session update");
167 }
168
169 Ok(result)
170 }
171
172 pub async fn use_session_batch(
179 &self,
180 session_id: SessionId,
181 tool_names: &[&str],
182 ) -> Result<TaskSession, SessionError> {
183 let mut cache = self.cache.write().await;
184 let session = cache
185 .get_mut(&session_id)
186 .ok_or(SessionError::NotFound(session_id))?;
187
188 if session.status == SessionStatus::Closed {
189 return Err(SessionError::AlreadyClosed(session_id));
190 }
191
192 if session.is_expired() {
193 session.status = SessionStatus::Expired;
194 let stored = domain_to_stored(session);
196 if let Err(e) = self.storage.update_session(&stored).await {
197 tracing::error!(error = %e, "failed to persist expired session status");
198 }
199 return Err(SessionError::Expired(session_id));
200 }
201
202 let batch_size = tool_names.len() as u64;
203
204 if session.calls_made + batch_size > session.call_budget {
206 return Err(SessionError::BudgetExceeded {
207 session_id,
208 limit: session.call_budget,
209 used: session.calls_made,
210 });
211 }
212
213 for tool_name in tool_names {
215 if !session.is_tool_authorized(tool_name) {
216 return Err(SessionError::ToolNotAuthorized {
217 session_id,
218 tool: (*tool_name).into(),
219 });
220 }
221 }
222
223 if let Some(limit) = session.rate_limit_per_minute {
225 let now = chrono::Utc::now();
226 let elapsed = now - session.rate_window_start;
227 if elapsed >= chrono::Duration::seconds(session.rate_limit_window_secs as i64) {
228 } else if session.rate_window_calls + batch_size > limit {
230 return Err(SessionError::RateLimited {
231 session_id,
232 limit_per_minute: limit,
233 });
234 }
235 }
236
237 if let Some(_limit) = session.rate_limit_per_minute {
239 let now = chrono::Utc::now();
240 let elapsed = now - session.rate_window_start;
241 if elapsed >= chrono::Duration::seconds(session.rate_limit_window_secs as i64) {
242 session.rate_window_start = now;
243 session.rate_window_calls = batch_size;
244 } else {
245 session.rate_window_calls += batch_size;
246 }
247 }
248
249 session.calls_made += batch_size;
250
251 tracing::debug!(
252 session_id = %session_id,
253 batch_size = batch_size,
254 calls = session.calls_made,
255 budget = session.call_budget,
256 "session batch tool calls recorded (storage-backed)"
257 );
258
259 let result = session.clone();
260
261 let stored = domain_to_stored(&result);
263 if let Err(e) = self.storage.update_session(&stored).await {
264 tracing::error!(error = %e, "failed to persist session batch update");
265 }
266
267 Ok(result)
268 }
269
270 pub async fn close(&self, session_id: SessionId) -> Result<TaskSession, SessionError> {
272 let mut cache = self.cache.write().await;
273 let session = cache
274 .get_mut(&session_id)
275 .ok_or(SessionError::NotFound(session_id))?;
276
277 if session.status == SessionStatus::Closed {
278 return Err(SessionError::AlreadyClosed(session_id));
279 }
280
281 session.status = SessionStatus::Closed;
282 tracing::info!(session_id = %session_id, "session closed (storage-backed)");
283
284 let result = session.clone();
285
286 let stored = domain_to_stored(&result);
288 if let Err(e) = self.storage.update_session(&stored).await {
289 tracing::error!(error = %e, "failed to persist session close");
290 }
291
292 Ok(result)
293 }
294
295 pub async fn get(&self, session_id: SessionId) -> Result<TaskSession, SessionError> {
297 let cache = self.cache.read().await;
298 cache
299 .get(&session_id)
300 .cloned()
301 .ok_or(SessionError::NotFound(session_id))
302 }
303
304 pub async fn list_all(&self) -> Vec<TaskSession> {
306 let cache = self.cache.read().await;
307 cache.values().cloned().collect()
308 }
309
310 pub async fn count_active_for_agent(&self, agent_id: uuid::Uuid) -> u64 {
314 let cache = self.cache.read().await;
315 cache
316 .values()
317 .filter(|s| s.agent_id == agent_id && s.status == SessionStatus::Active)
318 .count() as u64
319 }
320
321 pub async fn close_sessions_for_agent(&self, agent_id: uuid::Uuid) -> usize {
325 let mut cache = self.cache.write().await;
326 let mut closed = 0usize;
327 for session in cache.values_mut() {
328 if session.agent_id == agent_id && session.status == SessionStatus::Active {
329 session.status = SessionStatus::Closed;
330 let stored = domain_to_stored(session);
331 if let Err(e) = self.storage.update_session(&stored).await {
332 tracing::error!(error = %e, "failed to persist session closure during agent deactivation");
333 }
334 closed += 1;
335 }
336 }
337 closed
338 }
339
340 pub async fn cleanup_expired(&self) -> usize {
342 let mut cache = self.cache.write().await;
343 let before = cache.len();
344 cache.retain(|_, s| {
345 if s.is_expired() {
346 tracing::debug!(session_id = %s.session_id, "cleaning up expired session (storage-backed)");
347 false
348 } else {
349 true
350 }
351 });
352 let removed = before - cache.len();
353
354 if let Err(e) = self.storage.delete_expired_sessions().await {
356 tracing::error!(error = %e, "failed to clean up expired sessions in storage");
357 }
358
359 if removed > 0 {
360 tracing::info!(removed, "cleaned up expired sessions (storage-backed)");
361 }
362 removed
363 }
364}
365
366fn domain_to_stored(session: &TaskSession) -> StoredSession {
369 StoredSession {
370 session_id: session.session_id,
371 agent_id: session.agent_id,
372 delegation_chain_snapshot: session.delegation_chain_snapshot.clone(),
373 declared_intent: session.declared_intent.clone(),
374 authorized_tools: session.authorized_tools.clone(),
375 time_limit_secs: session.time_limit.num_seconds(),
376 call_budget: session.call_budget,
377 calls_made: session.calls_made,
378 rate_limit_per_minute: session.rate_limit_per_minute,
379 rate_window_start: session.rate_window_start,
380 rate_window_calls: session.rate_window_calls,
381 rate_limit_window_secs: session.rate_limit_window_secs,
382 data_sensitivity_ceiling: sensitivity_to_stored(session.data_sensitivity_ceiling),
383 created_at: session.created_at,
384 status: status_to_stored(session.status),
385 }
386}
387
388fn stored_to_domain(stored: StoredSession) -> Result<TaskSession, String> {
389 Ok(TaskSession {
390 session_id: stored.session_id,
391 agent_id: stored.agent_id,
392 delegation_chain_snapshot: stored.delegation_chain_snapshot,
393 declared_intent: stored.declared_intent,
394 authorized_tools: stored.authorized_tools,
395 time_limit: chrono::Duration::seconds(stored.time_limit_secs),
396 call_budget: stored.call_budget,
397 calls_made: stored.calls_made,
398 rate_limit_per_minute: stored.rate_limit_per_minute,
399 rate_window_start: stored.rate_window_start,
400 rate_window_calls: stored.rate_window_calls,
401 rate_limit_window_secs: stored.rate_limit_window_secs,
402 data_sensitivity_ceiling: stored_to_sensitivity(stored.data_sensitivity_ceiling),
403 created_at: stored.created_at,
404 status: stored_to_status(stored.status),
405 })
406}
407
408fn status_to_stored(status: SessionStatus) -> StoredSessionStatus {
409 match status {
410 SessionStatus::Active => StoredSessionStatus::Active,
411 SessionStatus::Closed => StoredSessionStatus::Closed,
412 SessionStatus::Expired => StoredSessionStatus::Expired,
413 }
414}
415
416fn stored_to_status(status: StoredSessionStatus) -> SessionStatus {
417 match status {
418 StoredSessionStatus::Active => SessionStatus::Active,
419 StoredSessionStatus::Closed => SessionStatus::Closed,
420 StoredSessionStatus::Expired => SessionStatus::Expired,
421 }
422}
423
424fn sensitivity_to_stored(sensitivity: DataSensitivity) -> StoredDataSensitivity {
425 match sensitivity {
426 DataSensitivity::Public => StoredDataSensitivity::Public,
427 DataSensitivity::Internal => StoredDataSensitivity::Internal,
428 DataSensitivity::Confidential => StoredDataSensitivity::Confidential,
429 DataSensitivity::Restricted => StoredDataSensitivity::Restricted,
430 }
431}
432
433fn stored_to_sensitivity(sensitivity: StoredDataSensitivity) -> DataSensitivity {
434 match sensitivity {
435 StoredDataSensitivity::Public => DataSensitivity::Public,
436 StoredDataSensitivity::Internal => DataSensitivity::Internal,
437 StoredDataSensitivity::Confidential => DataSensitivity::Confidential,
438 StoredDataSensitivity::Restricted => DataSensitivity::Restricted,
439 }
440}
441
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::collections::HashMap;
    use std::sync::Arc;
    use tokio::sync::RwLock;

    /// In-memory storage backend used to observe write-through behavior.
    #[derive(Clone)]
    struct MockStorage {
        sessions: Arc<RwLock<HashMap<Uuid, StoredSession>>>,
    }

    impl MockStorage {
        fn new() -> Self {
            Self {
                sessions: Arc::new(RwLock::new(HashMap::new())),
            }
        }
    }

    #[async_trait]
    impl StorageSessionStore for MockStorage {
        async fn insert_session(&self, session: &StoredSession) -> Result<(), StorageError> {
            let mut map = self.sessions.write().await;
            map.insert(session.session_id, session.clone());
            Ok(())
        }

        async fn get_session(&self, session_id: Uuid) -> Result<StoredSession, StorageError> {
            let map = self.sessions.read().await;
            map.get(&session_id)
                .cloned()
                .ok_or(StorageError::SessionNotFound(session_id))
        }

        async fn update_session(&self, session: &StoredSession) -> Result<(), StorageError> {
            let mut map = self.sessions.write().await;
            map.insert(session.session_id, session.clone());
            Ok(())
        }

        async fn delete_expired_sessions(&self) -> Result<usize, StorageError> {
            let mut map = self.sessions.write().await;
            let before = map.len();
            let now = chrono::Utc::now();
            map.retain(|_, s| {
                let created = s.created_at;
                let limit = chrono::Duration::seconds(s.time_limit_secs);
                let elapsed = now - created;
                elapsed <= limit && s.status != StoredSessionStatus::Expired
            });
            Ok(before - map.len())
        }

        async fn list_sessions(&self) -> Result<Vec<StoredSession>, StorageError> {
            let map = self.sessions.read().await;
            Ok(map.values().cloned().collect())
        }
    }

    fn test_create_request() -> CreateSessionRequest {
        CreateSessionRequest {
            agent_id: Uuid::new_v4(),
            delegation_chain_snapshot: vec![],
            declared_intent: "read and analyze files".into(),
            authorized_tools: vec!["read_file".into(), "list_dir".into()],
            time_limit: chrono::Duration::hours(1),
            call_budget: 5,
            rate_limit_per_minute: None,
            rate_limit_window_secs: 60,
            data_sensitivity_ceiling: DataSensitivity::Internal,
        }
    }

    async fn make_store() -> (StorageBackedSessionStore, MockStorage) {
        let mock = MockStorage::new();
        let store = StorageBackedSessionStore::new(Arc::new(mock.clone()))
            .await
            .expect("failed to create storage-backed store");
        (store, mock)
    }

    #[tokio::test]
    async fn create_and_use_session() {
        let (store, _mock) = make_store().await;
        let session = store.create(test_create_request()).await;

        assert_eq!(session.calls_made, 0);
        assert!(session.is_active());

        let updated = store
            .use_session(session.session_id, "read_file")
            .await
            .unwrap();
        assert_eq!(updated.calls_made, 1);

        let fetched = store.get(session.session_id).await.unwrap();
        assert_eq!(fetched.calls_made, 1);
    }

    #[tokio::test]
    async fn budget_enforcement() {
        let (store, _mock) = make_store().await;
        let mut req = test_create_request();
        req.call_budget = 2;
        let session = store.create(req).await;

        store
            .use_session(session.session_id, "read_file")
            .await
            .unwrap();
        store
            .use_session(session.session_id, "read_file")
            .await
            .unwrap();

        let result = store.use_session(session.session_id, "read_file").await;
        assert!(matches!(result, Err(SessionError::BudgetExceeded { .. })));
    }

    #[tokio::test]
    async fn tool_whitelist_enforcement() {
        let (store, _mock) = make_store().await;
        let session = store.create(test_create_request()).await;

        store
            .use_session(session.session_id, "read_file")
            .await
            .unwrap();

        let result = store.use_session(session.session_id, "delete_file").await;
        assert!(matches!(
            result,
            Err(SessionError::ToolNotAuthorized { .. })
        ));
    }

    #[tokio::test]
    async fn session_expiry() {
        let (store, _mock) = make_store().await;
        let mut req = test_create_request();
        req.time_limit = chrono::Duration::zero();
        let session = store.create(req).await;

        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        let result = store.use_session(session.session_id, "read_file").await;
        assert!(matches!(result, Err(SessionError::Expired(_))));
    }

    #[tokio::test]
    async fn close_and_reuse() {
        let (store, _mock) = make_store().await;
        let session = store.create(test_create_request()).await;

        store.close(session.session_id).await.unwrap();

        let result = store.use_session(session.session_id, "read_file").await;
        assert!(matches!(result, Err(SessionError::AlreadyClosed(_))));
    }

    #[tokio::test]
    async fn close_twice_reports_already_closed() {
        let (store, _mock) = make_store().await;
        let session = store.create(test_create_request()).await;

        store.close(session.session_id).await.unwrap();

        let result = store.close(session.session_id).await;
        assert!(matches!(result, Err(SessionError::AlreadyClosed(_))));
    }

    #[tokio::test]
    async fn session_not_found() {
        let (store, _mock) = make_store().await;
        let fake_id = Uuid::new_v4();
        let result = store.use_session(fake_id, "anything").await;
        assert!(matches!(result, Err(SessionError::NotFound(_))));
    }

    #[tokio::test]
    async fn batch_validation_atomicity() {
        let (store, _mock) = make_store().await;
        let mut req = test_create_request();
        req.call_budget = 10;
        req.authorized_tools = vec!["read_file".into(), "list_dir".into()];
        let session = store.create(req).await;

        let result = store
            .use_session_batch(session.session_id, &["read_file", "delete_file"])
            .await;
        assert!(
            matches!(result, Err(SessionError::ToolNotAuthorized { .. })),
            "expected ToolNotAuthorized, got {result:?}"
        );

        let s = store.get(session.session_id).await.unwrap();
        assert_eq!(
            s.calls_made, 0,
            "no budget should be consumed on batch failure"
        );
    }

    #[tokio::test]
    async fn batch_success_consumes_budget() {
        let (store, mock) = make_store().await;
        let session = store.create(test_create_request()).await;

        let updated = store
            .use_session_batch(session.session_id, &["read_file", "list_dir"])
            .await
            .unwrap();
        assert_eq!(updated.calls_made, 2);

        let cached = store.get(session.session_id).await.unwrap();
        assert_eq!(cached.calls_made, 2);

        let stored = mock.get_session(session.session_id).await.unwrap();
        assert_eq!(stored.calls_made, 2, "storage must reflect the batch");
    }

    #[tokio::test]
    async fn batch_budget_enforcement() {
        let (store, _mock) = make_store().await;
        let mut req = test_create_request();
        req.call_budget = 3;
        req.authorized_tools = vec!["read_file".into()];
        let session = store.create(req).await;

        let result = store
            .use_session_batch(
                session.session_id,
                &["read_file", "read_file", "read_file", "read_file"],
            )
            .await;
        assert!(
            matches!(result, Err(SessionError::BudgetExceeded { .. })),
            "expected BudgetExceeded, got {result:?}"
        );

        let s = store.get(session.session_id).await.unwrap();
        assert_eq!(
            s.calls_made, 0,
            "no budget should be consumed on batch failure"
        );
    }

    #[tokio::test]
    async fn batch_rate_limit_enforcement() {
        let (store, _mock) = make_store().await;
        let mut req = test_create_request();
        req.call_budget = 100;
        req.rate_limit_per_minute = Some(3);
        req.authorized_tools = vec!["read_file".into()];
        let session = store.create(req).await;

        let result = store
            .use_session_batch(
                session.session_id,
                &["read_file", "read_file", "read_file", "read_file"],
            )
            .await;
        assert!(
            matches!(result, Err(SessionError::RateLimited { .. })),
            "expected RateLimited, got {result:?}"
        );
    }

    #[tokio::test]
    async fn cleanup_expired_sessions() {
        let (store, _mock) = make_store().await;

        // One session that expires immediately...
        let mut req = test_create_request();
        req.time_limit = chrono::Duration::zero();
        store.create(req).await;

        // ...and one that stays alive for an hour.
        store.create(test_create_request()).await;

        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        let removed = store.cleanup_expired().await;
        assert_eq!(removed, 1);
    }

    #[tokio::test]
    async fn count_active_for_agent() {
        let (store, _mock) = make_store().await;
        let agent_id = Uuid::new_v4();

        for _ in 0..3 {
            let mut req = test_create_request();
            req.agent_id = agent_id;
            store.create(req).await;
        }

        // A session for a different agent must not be counted.
        store.create(test_create_request()).await;

        let count = store.count_active_for_agent(agent_id).await;
        assert_eq!(count, 3);
    }

    #[tokio::test]
    async fn close_sessions_for_agent_closes_only_that_agent() {
        let (store, mock) = make_store().await;
        let agent_id = Uuid::new_v4();

        let mut ids = Vec::new();
        for _ in 0..2 {
            let mut req = test_create_request();
            req.agent_id = agent_id;
            ids.push(store.create(req).await.session_id);
        }
        let other = store.create(test_create_request()).await;

        let closed = store.close_sessions_for_agent(agent_id).await;
        assert_eq!(closed, 2);
        assert_eq!(store.count_active_for_agent(agent_id).await, 0);
        assert_eq!(store.count_active_for_agent(other.agent_id).await, 1);

        for id in ids {
            let stored = mock.get_session(id).await.unwrap();
            assert_eq!(
                stored.status,
                StoredSessionStatus::Closed,
                "storage must reflect the agent-wide closure"
            );
        }
    }

    #[tokio::test]
    async fn reload_warms_cache_from_storage() {
        let (store, mock) = make_store().await;
        let session = store.create(test_create_request()).await;
        store
            .use_session(session.session_id, "read_file")
            .await
            .unwrap();

        // A second store over the same backend must see the persisted state.
        let reloaded = StorageBackedSessionStore::new(Arc::new(mock.clone()))
            .await
            .expect("failed to create reloaded store");
        let fetched = reloaded
            .get(session.session_id)
            .await
            .expect("session should be loaded from storage");
        assert_eq!(fetched.calls_made, 1);
        assert_eq!(fetched.agent_id, session.agent_id);
    }

    #[tokio::test]
    async fn storage_write_through() {
        let (store, mock) = make_store().await;
        let session = store.create(test_create_request()).await;

        let stored = mock
            .get_session(session.session_id)
            .await
            .expect("session should exist in storage after create");
        assert_eq!(stored.calls_made, 0);

        store
            .use_session(session.session_id, "read_file")
            .await
            .unwrap();
        let stored = mock.get_session(session.session_id).await.unwrap();
        assert_eq!(stored.calls_made, 1, "storage must reflect the tool call");

        store.close(session.session_id).await.unwrap();
        let stored = mock.get_session(session.session_id).await.unwrap();
        assert_eq!(
            stored.status,
            StoredSessionStatus::Closed,
            "storage must reflect the closed status"
        );
    }
}