1use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::sync::RwLock;
13use uuid::Uuid;
14
15use arbiter_storage::{
16 SessionStore as StorageSessionStore, StorageError, StoredDataSensitivity, StoredSession,
17 StoredSessionStatus,
18};
19
20use crate::error::SessionError;
21use crate::model::{DataSensitivity, SessionId, SessionStatus, TaskSession};
22use crate::store::CreateSessionRequest;
23
/// Session store that serves sessions from an in-memory cache and mirrors
/// changes to a [`StorageSessionStore`] backend.
///
/// Both fields are `Arc`s, so cloning the store yields a handle that shares
/// the same cache and the same backend.
#[derive(Clone)]
pub struct StorageBackedSessionStore {
    /// In-memory sessions keyed by session id, guarded by an async `RwLock`.
    cache: Arc<RwLock<HashMap<SessionId, TaskSession>>>,
    /// Persistence backend that receives inserts, updates and expiry deletes.
    storage: Arc<dyn StorageSessionStore>,
}
34
35impl StorageBackedSessionStore {
36 pub async fn new(storage: Arc<dyn StorageSessionStore>) -> Result<Self, StorageError> {
40 let store = Self {
41 cache: Arc::new(RwLock::new(HashMap::new())),
42 storage,
43 };
44
45 store.reload_from_storage().await?;
47
48 Ok(store)
49 }
50
51 async fn reload_from_storage(&self) -> Result<(), StorageError> {
53 let stored_sessions = self.storage.list_sessions().await?;
54 let mut cache = self.cache.write().await;
55 cache.clear();
56 for stored in stored_sessions {
57 if let Ok(session) = stored_to_domain(stored) {
58 cache.insert(session.session_id, session);
59 }
60 }
61 tracing::info!(sessions = cache.len(), "session cache warmed from storage");
62 Ok(())
63 }
64
65 pub async fn create(&self, req: CreateSessionRequest) -> TaskSession {
67 let session = TaskSession {
68 session_id: Uuid::new_v4(),
69 agent_id: req.agent_id,
70 delegation_chain_snapshot: req.delegation_chain_snapshot,
71 declared_intent: req.declared_intent,
72 authorized_tools: req.authorized_tools,
73 authorized_credentials: req.authorized_credentials,
74 time_limit: req.time_limit,
75 call_budget: req.call_budget,
76 calls_made: 0,
77 rate_limit_per_minute: req.rate_limit_per_minute,
78 rate_window_start: chrono::Utc::now(),
79 rate_window_calls: 0,
80 rate_limit_window_secs: req.rate_limit_window_secs,
81 data_sensitivity_ceiling: req.data_sensitivity_ceiling,
82 created_at: chrono::Utc::now(),
83 status: SessionStatus::Active,
84 };
85
86 tracing::info!(
87 session_id = %session.session_id,
88 agent_id = %session.agent_id,
89 intent = %session.declared_intent,
90 budget = session.call_budget,
91 "created task session (storage-backed)"
92 );
93
94 let stored = domain_to_stored(&session);
96 if let Err(e) = self.storage.insert_session(&stored).await {
97 tracing::error!(error = %e, "failed to persist session to storage");
98 }
99
100 let mut cache = self.cache.write().await;
101 cache.insert(session.session_id, session.clone());
102 session
103 }
104
105 pub async fn create_if_under_cap(
107 &self,
108 req: CreateSessionRequest,
109 max_sessions: u64,
110 ) -> Result<TaskSession, SessionError> {
111 let mut cache = self.cache.write().await;
112
113 let active_count = cache
114 .values()
115 .filter(|s| s.agent_id == req.agent_id && s.status == SessionStatus::Active)
116 .count() as u64;
117
118 if active_count >= max_sessions {
119 return Err(SessionError::TooManySessions {
120 agent_id: req.agent_id.to_string(),
121 max: max_sessions,
122 current: active_count,
123 });
124 }
125
126 let session = TaskSession {
127 session_id: uuid::Uuid::new_v4(),
128 agent_id: req.agent_id,
129 delegation_chain_snapshot: req.delegation_chain_snapshot,
130 declared_intent: req.declared_intent,
131 authorized_tools: req.authorized_tools,
132 authorized_credentials: req.authorized_credentials,
133 time_limit: req.time_limit,
134 call_budget: req.call_budget,
135 calls_made: 0,
136 rate_limit_per_minute: req.rate_limit_per_minute,
137 rate_window_start: chrono::Utc::now(),
138 rate_window_calls: 0,
139 rate_limit_window_secs: req.rate_limit_window_secs,
140 data_sensitivity_ceiling: req.data_sensitivity_ceiling,
141 created_at: chrono::Utc::now(),
142 status: SessionStatus::Active,
143 };
144
145 let stored = domain_to_stored(&session);
147 if let Err(e) = self.storage.insert_session(&stored).await {
148 tracing::error!(error = %e, "failed to persist session to storage");
149 }
150
151 cache.insert(session.session_id, session.clone());
152 Ok(session)
153 }
154
155 pub async fn use_session(
157 &self,
158 session_id: SessionId,
159 tool_name: &str,
160 requesting_agent_id: Option<uuid::Uuid>,
161 ) -> Result<TaskSession, SessionError> {
162 let mut cache = self.cache.write().await;
163 let session = cache
164 .get_mut(&session_id)
165 .ok_or(SessionError::NotFound(session_id))?;
166
167 if let Some(agent_id) = requesting_agent_id
169 && agent_id != session.agent_id
170 {
171 return Err(SessionError::AgentMismatch {
172 session_id,
173 expected: session.agent_id,
174 actual: agent_id,
175 });
176 }
177
178 if session.status == SessionStatus::Closed {
179 return Err(SessionError::AlreadyClosed(session_id));
180 }
181
182 if session.is_expired() {
183 session.status = SessionStatus::Expired;
184 let stored = domain_to_stored(session);
186 if let Err(e) = self.storage.update_session(&stored).await {
187 tracing::error!(error = %e, "failed to persist expired session status");
188 }
189 return Err(SessionError::Expired(session_id));
190 }
191
192 if session.is_budget_exceeded() {
193 return Err(SessionError::BudgetExceeded {
194 session_id,
195 limit: session.call_budget,
196 used: session.calls_made,
197 });
198 }
199
200 if !session.is_tool_authorized(tool_name) {
201 return Err(SessionError::ToolNotAuthorized {
202 session_id,
203 tool: tool_name.into(),
204 });
205 }
206
207 if session.check_rate_limit() {
208 return Err(SessionError::RateLimited {
209 session_id,
210 limit_per_minute: session.rate_limit_per_minute.unwrap_or(0),
211 });
212 }
213
214 session.calls_made += 1;
215
216 tracing::debug!(
217 session_id = %session_id,
218 tool = tool_name,
219 calls = session.calls_made,
220 budget = session.call_budget,
221 "session tool call recorded (storage-backed)"
222 );
223
224 let result = session.clone();
225
226 let stored = domain_to_stored(&result);
229 if let Err(e) = self.storage.update_session(&stored).await {
230 tracing::error!(error = %e, "failed to persist session update");
231 return Err(SessionError::StorageWriteThrough {
232 session_id,
233 detail: e.to_string(),
234 });
235 }
236
237 Ok(result)
238 }
239
240 pub async fn use_session_batch(
247 &self,
248 session_id: SessionId,
249 tool_names: &[&str],
250 requesting_agent_id: Option<uuid::Uuid>,
251 ) -> Result<TaskSession, SessionError> {
252 let mut cache = self.cache.write().await;
253 let session = cache
254 .get_mut(&session_id)
255 .ok_or(SessionError::NotFound(session_id))?;
256
257 if let Some(agent_id) = requesting_agent_id
259 && agent_id != session.agent_id
260 {
261 return Err(SessionError::AgentMismatch {
262 session_id,
263 expected: session.agent_id,
264 actual: agent_id,
265 });
266 }
267
268 if session.status == SessionStatus::Closed {
269 return Err(SessionError::AlreadyClosed(session_id));
270 }
271
272 if session.is_expired() {
273 session.status = SessionStatus::Expired;
274 let stored = domain_to_stored(session);
276 if let Err(e) = self.storage.update_session(&stored).await {
277 tracing::error!(error = %e, "failed to persist expired session status");
278 }
279 return Err(SessionError::Expired(session_id));
280 }
281
282 let batch_size = tool_names.len() as u64;
283
284 if session.calls_made + batch_size > session.call_budget {
286 return Err(SessionError::BudgetExceeded {
287 session_id,
288 limit: session.call_budget,
289 used: session.calls_made,
290 });
291 }
292
293 for tool_name in tool_names {
295 if !session.is_tool_authorized(tool_name) {
296 return Err(SessionError::ToolNotAuthorized {
297 session_id,
298 tool: (*tool_name).into(),
299 });
300 }
301 }
302
303 if let Some(limit) = session.rate_limit_per_minute {
305 let now = chrono::Utc::now();
306 let elapsed = now - session.rate_window_start;
307 if elapsed >= chrono::Duration::seconds(session.rate_limit_window_secs as i64) {
308 } else if session.rate_window_calls + batch_size > limit {
310 return Err(SessionError::RateLimited {
311 session_id,
312 limit_per_minute: limit,
313 });
314 }
315 }
316
317 if let Some(_limit) = session.rate_limit_per_minute {
319 let now = chrono::Utc::now();
320 let elapsed = now - session.rate_window_start;
321 if elapsed >= chrono::Duration::seconds(session.rate_limit_window_secs as i64) {
322 session.rate_window_start = now;
323 session.rate_window_calls = batch_size;
324 } else {
325 session.rate_window_calls += batch_size;
326 }
327 }
328
329 session.calls_made += batch_size;
330
331 tracing::debug!(
332 session_id = %session_id,
333 batch_size = batch_size,
334 calls = session.calls_made,
335 budget = session.call_budget,
336 "session batch tool calls recorded (storage-backed)"
337 );
338
339 let result = session.clone();
340
341 let stored = domain_to_stored(&result);
343 if let Err(e) = self.storage.update_session(&stored).await {
344 tracing::error!(error = %e, "failed to persist session batch update");
345 return Err(SessionError::StorageWriteThrough {
346 session_id,
347 detail: e.to_string(),
348 });
349 }
350
351 Ok(result)
352 }
353
354 pub async fn close(&self, session_id: SessionId) -> Result<TaskSession, SessionError> {
356 let mut cache = self.cache.write().await;
357 let session = cache
358 .get_mut(&session_id)
359 .ok_or(SessionError::NotFound(session_id))?;
360
361 if session.status == SessionStatus::Closed {
362 return Err(SessionError::AlreadyClosed(session_id));
363 }
364
365 session.status = SessionStatus::Closed;
366 tracing::info!(session_id = %session_id, "session closed (storage-backed)");
367
368 let result = session.clone();
369
370 let stored = domain_to_stored(&result);
372 if let Err(e) = self.storage.update_session(&stored).await {
373 tracing::error!(error = %e, "failed to persist session close");
374 }
375
376 Ok(result)
377 }
378
379 pub async fn get(&self, session_id: SessionId) -> Result<TaskSession, SessionError> {
381 let cache = self.cache.read().await;
382 cache
383 .get(&session_id)
384 .cloned()
385 .ok_or(SessionError::NotFound(session_id))
386 }
387
388 pub async fn list_all(&self) -> Vec<TaskSession> {
390 let cache = self.cache.read().await;
391 cache.values().cloned().collect()
392 }
393
394 pub async fn count_active_for_agent(&self, agent_id: uuid::Uuid) -> u64 {
398 let cache = self.cache.read().await;
399 cache
400 .values()
401 .filter(|s| s.agent_id == agent_id && s.status == SessionStatus::Active)
402 .count() as u64
403 }
404
405 pub async fn close_sessions_for_agent(&self, agent_id: uuid::Uuid) -> usize {
409 let to_persist: Vec<StoredSession>;
413 let closed: usize;
414 {
415 let mut cache = self.cache.write().await;
416 let mut count = 0usize;
417 let mut stored_sessions = Vec::new();
418 for session in cache.values_mut() {
419 if session.agent_id == agent_id && session.status == SessionStatus::Active {
420 session.status = SessionStatus::Closed;
421 stored_sessions.push(domain_to_stored(session));
422 count += 1;
423 }
424 }
425 to_persist = stored_sessions;
426 closed = count;
427 } for stored in &to_persist {
431 if let Err(e) = self.storage.update_session(stored).await {
432 tracing::error!(
433 error = %e,
434 session_id = %stored.session_id,
435 "failed to persist session closure during agent deactivation"
436 );
437 }
438 }
439 closed
440 }
441
442 pub async fn cleanup_expired(&self) -> usize {
444 let mut cache = self.cache.write().await;
445 let before = cache.len();
446 cache.retain(|_, s| {
447 if s.is_expired() {
448 tracing::debug!(session_id = %s.session_id, "cleaning up expired session (storage-backed)");
449 false
450 } else {
451 true
452 }
453 });
454 let removed = before - cache.len();
455
456 if let Err(e) = self.storage.delete_expired_sessions().await {
458 tracing::error!(error = %e, "failed to clean up expired sessions in storage");
459 }
460
461 if removed > 0 {
462 tracing::info!(removed, "cleaned up expired sessions (storage-backed)");
463 }
464 removed
465 }
466}
467
468fn domain_to_stored(session: &TaskSession) -> StoredSession {
471 StoredSession {
472 session_id: session.session_id,
473 agent_id: session.agent_id,
474 delegation_chain_snapshot: session.delegation_chain_snapshot.clone(),
475 declared_intent: session.declared_intent.clone(),
476 authorized_tools: session.authorized_tools.clone(),
477 time_limit_secs: session.time_limit.num_seconds(),
478 call_budget: session.call_budget,
479 calls_made: session.calls_made,
480 rate_limit_per_minute: session.rate_limit_per_minute,
481 rate_window_start: session.rate_window_start,
482 rate_window_calls: session.rate_window_calls,
483 rate_limit_window_secs: session.rate_limit_window_secs,
484 data_sensitivity_ceiling: sensitivity_to_stored(session.data_sensitivity_ceiling),
485 created_at: session.created_at,
486 status: status_to_stored(session.status),
487 }
488}
489
/// Upper bound, in seconds, applied to a stored session's `time_limit_secs`
/// when it is converted back to the domain model (86400 s = 24 hours).
const MAX_SESSION_TIME_LIMIT_SECS: i64 = 86400;
493
494fn stored_to_domain(stored: StoredSession) -> Result<TaskSession, String> {
495 let clamped_time_limit = stored.time_limit_secs.min(MAX_SESSION_TIME_LIMIT_SECS);
497 if stored.time_limit_secs > MAX_SESSION_TIME_LIMIT_SECS {
498 tracing::warn!(
499 session_id = %stored.session_id,
500 stored = stored.time_limit_secs,
501 clamped = clamped_time_limit,
502 "session time_limit_secs exceeded maximum on reload, clamping"
503 );
504 }
505
506 Ok(TaskSession {
507 session_id: stored.session_id,
508 agent_id: stored.agent_id,
509 delegation_chain_snapshot: stored.delegation_chain_snapshot,
510 declared_intent: stored.declared_intent,
511 authorized_tools: stored.authorized_tools,
512 authorized_credentials: vec![], time_limit: chrono::Duration::seconds(clamped_time_limit),
514 call_budget: stored.call_budget,
515 calls_made: stored.calls_made,
516 rate_limit_per_minute: stored.rate_limit_per_minute,
517 rate_window_start: stored.rate_window_start,
518 rate_window_calls: stored.rate_window_calls,
519 rate_limit_window_secs: stored.rate_limit_window_secs,
520 data_sensitivity_ceiling: stored_to_sensitivity(stored.data_sensitivity_ceiling),
521 created_at: stored.created_at,
522 status: stored_to_status(stored.status),
523 })
524}
525
526fn status_to_stored(status: SessionStatus) -> StoredSessionStatus {
527 match status {
528 SessionStatus::Active => StoredSessionStatus::Active,
529 SessionStatus::Closed => StoredSessionStatus::Closed,
530 SessionStatus::Expired => StoredSessionStatus::Expired,
531 }
532}
533
534fn stored_to_status(status: StoredSessionStatus) -> SessionStatus {
535 match status {
536 StoredSessionStatus::Active => SessionStatus::Active,
537 StoredSessionStatus::Closed => SessionStatus::Closed,
538 StoredSessionStatus::Expired => SessionStatus::Expired,
539 }
540}
541
542fn sensitivity_to_stored(sensitivity: DataSensitivity) -> StoredDataSensitivity {
543 match sensitivity {
544 DataSensitivity::Public => StoredDataSensitivity::Public,
545 DataSensitivity::Internal => StoredDataSensitivity::Internal,
546 DataSensitivity::Confidential => StoredDataSensitivity::Confidential,
547 DataSensitivity::Restricted => StoredDataSensitivity::Restricted,
548 }
549}
550
551fn stored_to_sensitivity(sensitivity: StoredDataSensitivity) -> DataSensitivity {
552 match sensitivity {
553 StoredDataSensitivity::Public => DataSensitivity::Public,
554 StoredDataSensitivity::Internal => DataSensitivity::Internal,
555 StoredDataSensitivity::Confidential => DataSensitivity::Confidential,
556 StoredDataSensitivity::Restricted => DataSensitivity::Restricted,
557 }
558}
559
// Unit tests for `StorageBackedSessionStore`, run against an in-memory
// `StorageSessionStore` test double.
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::collections::HashMap;
    use std::sync::Arc;
    use tokio::sync::RwLock;

    /// In-memory `StorageSessionStore` used by the tests. Clones share the
    /// same map, so a test can inspect what the store under test persisted.
    #[derive(Clone)]
    struct MockStorage {
        sessions: Arc<RwLock<HashMap<Uuid, StoredSession>>>,
    }

    impl MockStorage {
        /// Creates an empty mock backend.
        fn new() -> Self {
            Self {
                sessions: Arc::new(RwLock::new(HashMap::new())),
            }
        }
    }

    #[async_trait]
    impl StorageSessionStore for MockStorage {
        /// Stores (or overwrites) the session under its id.
        async fn insert_session(&self, session: &StoredSession) -> Result<(), StorageError> {
            let mut map = self.sessions.write().await;
            map.insert(session.session_id, session.clone());
            Ok(())
        }

        /// Returns a copy of the stored session, or `SessionNotFound`.
        async fn get_session(&self, session_id: Uuid) -> Result<StoredSession, StorageError> {
            let map = self.sessions.read().await;
            map.get(&session_id)
                .cloned()
                .ok_or(StorageError::SessionNotFound(session_id))
        }

        /// Upserts the session; the mock does not require it to exist already.
        async fn update_session(&self, session: &StoredSession) -> Result<(), StorageError> {
            let mut map = self.sessions.write().await;
            map.insert(session.session_id, session.clone());
            Ok(())
        }

        /// Removes sessions whose time limit has elapsed or whose status is
        /// `Expired`, returning how many were removed.
        async fn delete_expired_sessions(&self) -> Result<usize, StorageError> {
            let mut map = self.sessions.write().await;
            let before = map.len();
            let now = chrono::Utc::now();
            map.retain(|_, s| {
                let created = s.created_at;
                let limit = chrono::Duration::seconds(s.time_limit_secs);
                let elapsed = now - created;
                // Keep only sessions still within their limit and not marked expired.
                elapsed <= limit && s.status != StoredSessionStatus::Expired
            });
            Ok(before - map.len())
        }

        /// Returns a copy of every stored session.
        async fn list_sessions(&self) -> Result<Vec<StoredSession>, StorageError> {
            let map = self.sessions.read().await;
            Ok(map.values().cloned().collect())
        }
    }

    /// Baseline request: random agent, two authorized tools, one-hour limit,
    /// budget of five calls, no rate limit. Tests override fields as needed.
    fn test_create_request() -> CreateSessionRequest {
        CreateSessionRequest {
            agent_id: Uuid::new_v4(),
            delegation_chain_snapshot: vec![],
            declared_intent: "read and analyze files".into(),
            authorized_tools: vec!["read_file".into(), "list_dir".into()],
            authorized_credentials: vec![],
            time_limit: chrono::Duration::hours(1),
            call_budget: 5,
            rate_limit_per_minute: None,
            rate_limit_window_secs: 60,
            data_sensitivity_ceiling: DataSensitivity::Internal,
        }
    }

    /// Builds a store over a fresh mock and returns both, so tests can check
    /// the backend directly.
    async fn make_store() -> (StorageBackedSessionStore, MockStorage) {
        let mock = MockStorage::new();
        let store = StorageBackedSessionStore::new(Arc::new(mock.clone()))
            .await
            .expect("failed to create storage-backed store");
        (store, mock)
    }

    // A new session starts with zero calls; one use is reflected in the cache.
    #[tokio::test]
    async fn create_and_use_session() {
        let (store, _mock) = make_store().await;
        let session = store.create(test_create_request()).await;

        assert_eq!(session.calls_made, 0);
        assert!(session.is_active());

        let updated = store
            .use_session(session.session_id, "read_file", None)
            .await
            .unwrap();
        assert_eq!(updated.calls_made, 1);

        let fetched = store.get(session.session_id).await.unwrap();
        assert_eq!(fetched.calls_made, 1);
    }

    // With a budget of two, the third call is rejected with `BudgetExceeded`.
    #[tokio::test]
    async fn budget_enforcement() {
        let (store, _mock) = make_store().await;
        let mut req = test_create_request();
        req.call_budget = 2;
        let session = store.create(req).await;

        store
            .use_session(session.session_id, "read_file", None)
            .await
            .unwrap();
        store
            .use_session(session.session_id, "read_file", None)
            .await
            .unwrap();

        let result = store
            .use_session(session.session_id, "read_file", None)
            .await;
        assert!(matches!(result, Err(SessionError::BudgetExceeded { .. })));
    }

    // A tool outside `authorized_tools` is rejected with `ToolNotAuthorized`.
    #[tokio::test]
    async fn tool_whitelist_enforcement() {
        let (store, _mock) = make_store().await;
        let session = store.create(test_create_request()).await;

        store
            .use_session(session.session_id, "read_file", None)
            .await
            .unwrap();

        let result = store
            .use_session(session.session_id, "delete_file", None)
            .await;
        assert!(matches!(
            result,
            Err(SessionError::ToolNotAuthorized { .. })
        ));
    }

    // A zero time limit makes the session expired once any time has passed.
    #[tokio::test]
    async fn session_expiry() {
        let (store, _mock) = make_store().await;
        let mut req = test_create_request();
        req.time_limit = chrono::Duration::zero();
        let session = store.create(req).await;

        // Let the clock advance past the zero-length limit.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        let result = store
            .use_session(session.session_id, "read_file", None)
            .await;
        assert!(matches!(result, Err(SessionError::Expired(_))));
    }

    // Using a closed session fails with `AlreadyClosed`.
    #[tokio::test]
    async fn close_and_reuse() {
        let (store, _mock) = make_store().await;
        let session = store.create(test_create_request()).await;

        store.close(session.session_id).await.unwrap();

        let result = store
            .use_session(session.session_id, "read_file", None)
            .await;
        assert!(matches!(result, Err(SessionError::AlreadyClosed(_))));
    }

    // An unknown session id yields `NotFound`.
    #[tokio::test]
    async fn session_not_found() {
        let (store, _mock) = make_store().await;
        let fake_id = Uuid::new_v4();
        let result = store.use_session(fake_id, "anything", None).await;
        assert!(matches!(result, Err(SessionError::NotFound(_))));
    }

    // A batch containing one unauthorized tool is rejected as a whole and
    // consumes no budget.
    #[tokio::test]
    async fn batch_validation_atomicity() {
        let (store, _mock) = make_store().await;
        let mut req = test_create_request();
        req.call_budget = 10;
        req.authorized_tools = vec!["read_file".into(), "list_dir".into()];
        let session = store.create(req).await;

        let result = store
            .use_session_batch(session.session_id, &["read_file", "delete_file"], None)
            .await;
        assert!(
            matches!(result, Err(SessionError::ToolNotAuthorized { .. })),
            "expected ToolNotAuthorized, got {result:?}"
        );

        let s = store.get(session.session_id).await.unwrap();
        assert_eq!(
            s.calls_made, 0,
            "no budget should be consumed on batch failure"
        );
    }

    // A batch larger than the remaining budget is rejected and consumes nothing.
    #[tokio::test]
    async fn batch_budget_enforcement() {
        let (store, _mock) = make_store().await;
        let mut req = test_create_request();
        req.call_budget = 3;
        req.authorized_tools = vec!["read_file".into()];
        let session = store.create(req).await;

        let result = store
            .use_session_batch(
                session.session_id,
                &["read_file", "read_file", "read_file", "read_file"],
                None,
            )
            .await;
        assert!(
            matches!(result, Err(SessionError::BudgetExceeded { .. })),
            "expected BudgetExceeded, got {result:?}"
        );

        let s = store.get(session.session_id).await.unwrap();
        assert_eq!(
            s.calls_made, 0,
            "no budget should be consumed on batch failure"
        );
    }

    // A batch of four against a rate limit of three is rejected with `RateLimited`.
    #[tokio::test]
    async fn batch_rate_limit_enforcement() {
        let (store, _mock) = make_store().await;
        let mut req = test_create_request();
        req.call_budget = 100;
        req.rate_limit_per_minute = Some(3);
        req.authorized_tools = vec!["read_file".into()];
        let session = store.create(req).await;

        let result = store
            .use_session_batch(
                session.session_id,
                &["read_file", "read_file", "read_file", "read_file"],
                None,
            )
            .await;
        assert!(
            matches!(result, Err(SessionError::RateLimited { .. })),
            "expected RateLimited, got {result:?}"
        );
    }

    // Of one zero-limit session and one one-hour session, only the former is
    // removed by cleanup.
    #[tokio::test]
    async fn cleanup_expired_sessions() {
        let (store, _mock) = make_store().await;

        let mut req = test_create_request();
        req.time_limit = chrono::Duration::zero();
        store.create(req).await;

        store.create(test_create_request()).await;

        // Let the clock advance past the zero-length limit.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        let removed = store.cleanup_expired().await;
        assert_eq!(removed, 1);
    }

    // Only the target agent's sessions are counted; another agent's is ignored.
    #[tokio::test]
    async fn count_active_for_agent() {
        let (store, _mock) = make_store().await;
        let agent_id = Uuid::new_v4();

        for _ in 0..3 {
            let mut req = test_create_request();
            req.agent_id = agent_id;
            store.create(req).await;
        }

        // Session for a different (random) agent.
        store.create(test_create_request()).await;

        let count = store.count_active_for_agent(agent_id).await;
        assert_eq!(count, 3);
    }

    // Create, use and close must each be visible in the backing storage.
    #[tokio::test]
    async fn storage_write_through() {
        let (store, mock) = make_store().await;
        let session = store.create(test_create_request()).await;

        let stored = mock
            .get_session(session.session_id)
            .await
            .expect("session should exist in storage after create");
        assert_eq!(stored.calls_made, 0);

        store
            .use_session(session.session_id, "read_file", None)
            .await
            .unwrap();
        let stored = mock.get_session(session.session_id).await.unwrap();
        assert_eq!(stored.calls_made, 1, "storage must reflect the tool call");

        store.close(session.session_id).await.unwrap();
        let stored = mock.get_session(session.session_id).await.unwrap();
        assert_eq!(
            stored.status,
            StoredSessionStatus::Closed,
            "storage must reflect the closed status"
        );
    }
}