Skip to main content

arbiter_session/
storage_store.rs

//! Storage-backed session store with an in-memory write-through cache.
//!
//! Design decision: an in-memory cache serves hot-path reads (keeping request
//! latency low), and every write is also sent to persistent storage (for
//! durability). On startup, the cache is populated from storage.
//!
//! REQ-001: Session state survives process restart.
//! REQ-007: Storage sits behind an async trait, so backends are swappable.
9
10use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::sync::RwLock;
13use uuid::Uuid;
14
15use arbiter_storage::{
16    SessionStore as StorageSessionStore, StorageError, StoredDataSensitivity, StoredSession,
17    StoredSessionStatus,
18};
19
20use crate::error::SessionError;
21use crate::model::{DataSensitivity, SessionId, SessionStatus, TaskSession};
22use crate::store::CreateSessionRequest;
23
24/// A session store backed by persistent storage with an in-memory cache.
25///
26/// All reads hit the cache first. All writes go to both the cache and
27/// the underlying storage backend. On construction, the cache is warmed
28/// from storage to handle process restarts.
29#[derive(Clone)]
30pub struct StorageBackedSessionStore {
31    cache: Arc<RwLock<HashMap<SessionId, TaskSession>>>,
32    storage: Arc<dyn StorageSessionStore>,
33}
34
35impl StorageBackedSessionStore {
36    /// Create a new storage-backed session store.
37    ///
38    /// Loads all existing sessions from storage into the in-memory cache.
39    pub async fn new(storage: Arc<dyn StorageSessionStore>) -> Result<Self, StorageError> {
40        let store = Self {
41            cache: Arc::new(RwLock::new(HashMap::new())),
42            storage,
43        };
44
45        // Warm cache from storage.
46        store.reload_from_storage().await?;
47
48        Ok(store)
49    }
50
51    /// Reload the in-memory cache from storage.
52    async fn reload_from_storage(&self) -> Result<(), StorageError> {
53        let stored_sessions = self.storage.list_sessions().await?;
54        let mut cache = self.cache.write().await;
55        cache.clear();
56        for stored in stored_sessions {
57            if let Ok(session) = stored_to_domain(stored) {
58                cache.insert(session.session_id, session);
59            }
60        }
61        tracing::info!(sessions = cache.len(), "session cache warmed from storage");
62        Ok(())
63    }
64
65    /// Create a new task session and return it.
66    pub async fn create(&self, req: CreateSessionRequest) -> TaskSession {
67        let session = TaskSession {
68            session_id: Uuid::new_v4(),
69            agent_id: req.agent_id,
70            delegation_chain_snapshot: req.delegation_chain_snapshot,
71            declared_intent: req.declared_intent,
72            authorized_tools: req.authorized_tools,
73            time_limit: req.time_limit,
74            call_budget: req.call_budget,
75            calls_made: 0,
76            rate_limit_per_minute: req.rate_limit_per_minute,
77            rate_window_start: chrono::Utc::now(),
78            rate_window_calls: 0,
79            rate_limit_window_secs: req.rate_limit_window_secs,
80            data_sensitivity_ceiling: req.data_sensitivity_ceiling,
81            created_at: chrono::Utc::now(),
82            status: SessionStatus::Active,
83        };
84
85        tracing::info!(
86            session_id = %session.session_id,
87            agent_id = %session.agent_id,
88            intent = %session.declared_intent,
89            budget = session.call_budget,
90            "created task session (storage-backed)"
91        );
92
93        // Write to storage first, then cache.
94        let stored = domain_to_stored(&session);
95        if let Err(e) = self.storage.insert_session(&stored).await {
96            tracing::error!(error = %e, "failed to persist session to storage");
97        }
98
99        let mut cache = self.cache.write().await;
100        cache.insert(session.session_id, session.clone());
101        session
102    }
103
104    /// Record a tool call against the session, checking all constraints.
105    pub async fn use_session(
106        &self,
107        session_id: SessionId,
108        tool_name: &str,
109    ) -> Result<TaskSession, SessionError> {
110        let mut cache = self.cache.write().await;
111        let session = cache
112            .get_mut(&session_id)
113            .ok_or(SessionError::NotFound(session_id))?;
114
115        if session.status == SessionStatus::Closed {
116            return Err(SessionError::AlreadyClosed(session_id));
117        }
118
119        if session.is_expired() {
120            session.status = SessionStatus::Expired;
121            // Write-through: update storage.
122            let stored = domain_to_stored(session);
123            if let Err(e) = self.storage.update_session(&stored).await {
124                tracing::error!(error = %e, "failed to persist expired session status");
125            }
126            return Err(SessionError::Expired(session_id));
127        }
128
129        if session.is_budget_exceeded() {
130            return Err(SessionError::BudgetExceeded {
131                session_id,
132                limit: session.call_budget,
133                used: session.calls_made,
134            });
135        }
136
137        if !session.is_tool_authorized(tool_name) {
138            return Err(SessionError::ToolNotAuthorized {
139                session_id,
140                tool: tool_name.into(),
141            });
142        }
143
144        if session.check_rate_limit() {
145            return Err(SessionError::RateLimited {
146                session_id,
147                limit_per_minute: session.rate_limit_per_minute.unwrap_or(0),
148            });
149        }
150
151        session.calls_made += 1;
152
153        tracing::debug!(
154            session_id = %session_id,
155            tool = tool_name,
156            calls = session.calls_made,
157            budget = session.call_budget,
158            "session tool call recorded (storage-backed)"
159        );
160
161        let result = session.clone();
162
163        // Write-through: update storage.
164        let stored = domain_to_stored(&result);
165        if let Err(e) = self.storage.update_session(&stored).await {
166            tracing::error!(error = %e, "failed to persist session update");
167        }
168
169        Ok(result)
170    }
171
172    /// Atomically validate and record a batch of tool calls against the session.
173    ///
174    /// Acquires the write lock once, validates ALL tools against
175    /// the whitelist and budget, and only increments `calls_made` by the full
176    /// batch count if every tool passes. If any tool fails validation, no
177    /// budget is consumed for any of them.
178    pub async fn use_session_batch(
179        &self,
180        session_id: SessionId,
181        tool_names: &[&str],
182    ) -> Result<TaskSession, SessionError> {
183        let mut cache = self.cache.write().await;
184        let session = cache
185            .get_mut(&session_id)
186            .ok_or(SessionError::NotFound(session_id))?;
187
188        if session.status == SessionStatus::Closed {
189            return Err(SessionError::AlreadyClosed(session_id));
190        }
191
192        if session.is_expired() {
193            session.status = SessionStatus::Expired;
194            // Write-through: update storage.
195            let stored = domain_to_stored(session);
196            if let Err(e) = self.storage.update_session(&stored).await {
197                tracing::error!(error = %e, "failed to persist expired session status");
198            }
199            return Err(SessionError::Expired(session_id));
200        }
201
202        let batch_size = tool_names.len() as u64;
203
204        // Check budget for the entire batch.
205        if session.calls_made + batch_size > session.call_budget {
206            return Err(SessionError::BudgetExceeded {
207                session_id,
208                limit: session.call_budget,
209                used: session.calls_made,
210            });
211        }
212
213        // Check tool authorization for every tool before consuming any budget.
214        for tool_name in tool_names {
215            if !session.is_tool_authorized(tool_name) {
216                return Err(SessionError::ToolNotAuthorized {
217                    session_id,
218                    tool: (*tool_name).into(),
219                });
220            }
221        }
222
223        // Check rate limit for the entire batch.
224        if let Some(limit) = session.rate_limit_per_minute {
225            let now = chrono::Utc::now();
226            let elapsed = now - session.rate_window_start;
227            if elapsed >= chrono::Duration::seconds(session.rate_limit_window_secs as i64) {
228                // New window; will be reset below after all checks pass.
229            } else if session.rate_window_calls + batch_size > limit {
230                return Err(SessionError::RateLimited {
231                    session_id,
232                    limit_per_minute: limit,
233                });
234            }
235        }
236
237        // All checks passed. Atomically increment counters.
238        if let Some(_limit) = session.rate_limit_per_minute {
239            let now = chrono::Utc::now();
240            let elapsed = now - session.rate_window_start;
241            if elapsed >= chrono::Duration::seconds(session.rate_limit_window_secs as i64) {
242                session.rate_window_start = now;
243                session.rate_window_calls = batch_size;
244            } else {
245                session.rate_window_calls += batch_size;
246            }
247        }
248
249        session.calls_made += batch_size;
250
251        tracing::debug!(
252            session_id = %session_id,
253            batch_size = batch_size,
254            calls = session.calls_made,
255            budget = session.call_budget,
256            "session batch tool calls recorded (storage-backed)"
257        );
258
259        let result = session.clone();
260
261        // Write-through: update storage.
262        let stored = domain_to_stored(&result);
263        if let Err(e) = self.storage.update_session(&stored).await {
264            tracing::error!(error = %e, "failed to persist session batch update");
265        }
266
267        Ok(result)
268    }
269
270    /// Close a session, preventing further use.
271    pub async fn close(&self, session_id: SessionId) -> Result<TaskSession, SessionError> {
272        let mut cache = self.cache.write().await;
273        let session = cache
274            .get_mut(&session_id)
275            .ok_or(SessionError::NotFound(session_id))?;
276
277        if session.status == SessionStatus::Closed {
278            return Err(SessionError::AlreadyClosed(session_id));
279        }
280
281        session.status = SessionStatus::Closed;
282        tracing::info!(session_id = %session_id, "session closed (storage-backed)");
283
284        let result = session.clone();
285
286        // Write-through: update storage.
287        let stored = domain_to_stored(&result);
288        if let Err(e) = self.storage.update_session(&stored).await {
289            tracing::error!(error = %e, "failed to persist session close");
290        }
291
292        Ok(result)
293    }
294
295    /// Get a session by ID without modifying it.
296    pub async fn get(&self, session_id: SessionId) -> Result<TaskSession, SessionError> {
297        let cache = self.cache.read().await;
298        cache
299            .get(&session_id)
300            .cloned()
301            .ok_or(SessionError::NotFound(session_id))
302    }
303
304    /// List all sessions currently in the store.
305    pub async fn list_all(&self) -> Vec<TaskSession> {
306        let cache = self.cache.read().await;
307        cache.values().cloned().collect()
308    }
309
310    /// Count the number of active sessions for a given agent.
311    ///
312    /// P0: Used to enforce per-agent concurrent session caps.
313    pub async fn count_active_for_agent(&self, agent_id: uuid::Uuid) -> u64 {
314        let cache = self.cache.read().await;
315        cache
316            .values()
317            .filter(|s| s.agent_id == agent_id && s.status == SessionStatus::Active)
318            .count() as u64
319    }
320
321    /// Close all active sessions belonging to a specific agent.
322    ///
323    /// Called during agent deactivation.
324    pub async fn close_sessions_for_agent(&self, agent_id: uuid::Uuid) -> usize {
325        let mut cache = self.cache.write().await;
326        let mut closed = 0usize;
327        for session in cache.values_mut() {
328            if session.agent_id == agent_id && session.status == SessionStatus::Active {
329                session.status = SessionStatus::Closed;
330                let stored = domain_to_stored(session);
331                if let Err(e) = self.storage.update_session(&stored).await {
332                    tracing::error!(error = %e, "failed to persist session closure during agent deactivation");
333                }
334                closed += 1;
335            }
336        }
337        closed
338    }
339
340    /// Remove expired sessions from cache and storage. Returns the number removed.
341    pub async fn cleanup_expired(&self) -> usize {
342        let mut cache = self.cache.write().await;
343        let before = cache.len();
344        cache.retain(|_, s| {
345            if s.is_expired() {
346                tracing::debug!(session_id = %s.session_id, "cleaning up expired session (storage-backed)");
347                false
348            } else {
349                true
350            }
351        });
352        let removed = before - cache.len();
353
354        // Also clean up storage.
355        if let Err(e) = self.storage.delete_expired_sessions().await {
356            tracing::error!(error = %e, "failed to clean up expired sessions in storage");
357        }
358
359        if removed > 0 {
360            tracing::info!(removed, "cleaned up expired sessions (storage-backed)");
361        }
362        removed
363    }
364}
365
// ── Conversion helpers ──────────────────────────────────────────────
367
368fn domain_to_stored(session: &TaskSession) -> StoredSession {
369    StoredSession {
370        session_id: session.session_id,
371        agent_id: session.agent_id,
372        delegation_chain_snapshot: session.delegation_chain_snapshot.clone(),
373        declared_intent: session.declared_intent.clone(),
374        authorized_tools: session.authorized_tools.clone(),
375        time_limit_secs: session.time_limit.num_seconds(),
376        call_budget: session.call_budget,
377        calls_made: session.calls_made,
378        rate_limit_per_minute: session.rate_limit_per_minute,
379        rate_window_start: session.rate_window_start,
380        rate_window_calls: session.rate_window_calls,
381        rate_limit_window_secs: session.rate_limit_window_secs,
382        data_sensitivity_ceiling: sensitivity_to_stored(session.data_sensitivity_ceiling),
383        created_at: session.created_at,
384        status: status_to_stored(session.status),
385    }
386}
387
388fn stored_to_domain(stored: StoredSession) -> Result<TaskSession, String> {
389    Ok(TaskSession {
390        session_id: stored.session_id,
391        agent_id: stored.agent_id,
392        delegation_chain_snapshot: stored.delegation_chain_snapshot,
393        declared_intent: stored.declared_intent,
394        authorized_tools: stored.authorized_tools,
395        time_limit: chrono::Duration::seconds(stored.time_limit_secs),
396        call_budget: stored.call_budget,
397        calls_made: stored.calls_made,
398        rate_limit_per_minute: stored.rate_limit_per_minute,
399        rate_window_start: stored.rate_window_start,
400        rate_window_calls: stored.rate_window_calls,
401        rate_limit_window_secs: stored.rate_limit_window_secs,
402        data_sensitivity_ceiling: stored_to_sensitivity(stored.data_sensitivity_ceiling),
403        created_at: stored.created_at,
404        status: stored_to_status(stored.status),
405    })
406}
407
408fn status_to_stored(status: SessionStatus) -> StoredSessionStatus {
409    match status {
410        SessionStatus::Active => StoredSessionStatus::Active,
411        SessionStatus::Closed => StoredSessionStatus::Closed,
412        SessionStatus::Expired => StoredSessionStatus::Expired,
413    }
414}
415
416fn stored_to_status(status: StoredSessionStatus) -> SessionStatus {
417    match status {
418        StoredSessionStatus::Active => SessionStatus::Active,
419        StoredSessionStatus::Closed => SessionStatus::Closed,
420        StoredSessionStatus::Expired => SessionStatus::Expired,
421    }
422}
423
424fn sensitivity_to_stored(sensitivity: DataSensitivity) -> StoredDataSensitivity {
425    match sensitivity {
426        DataSensitivity::Public => StoredDataSensitivity::Public,
427        DataSensitivity::Internal => StoredDataSensitivity::Internal,
428        DataSensitivity::Confidential => StoredDataSensitivity::Confidential,
429        DataSensitivity::Restricted => StoredDataSensitivity::Restricted,
430    }
431}
432
433fn stored_to_sensitivity(sensitivity: StoredDataSensitivity) -> DataSensitivity {
434    match sensitivity {
435        StoredDataSensitivity::Public => DataSensitivity::Public,
436        StoredDataSensitivity::Internal => DataSensitivity::Internal,
437        StoredDataSensitivity::Confidential => DataSensitivity::Confidential,
438        StoredDataSensitivity::Restricted => DataSensitivity::Restricted,
439    }
440}
441
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::collections::HashMap;
    use std::sync::Arc;
    use tokio::sync::RwLock;

    /// In-memory mock implementing `StorageSessionStore` (arbiter_storage::SessionStore).
    ///
    /// Clones share the same underlying map, so a clone can be used both as
    /// the store's backend and as an inspection handle in assertions.
    #[derive(Clone)]
    struct MockStorage {
        sessions: Arc<RwLock<HashMap<Uuid, StoredSession>>>,
    }

    impl MockStorage {
        fn new() -> Self {
            Self {
                sessions: Arc::new(RwLock::new(HashMap::new())),
            }
        }
    }

    #[async_trait]
    impl StorageSessionStore for MockStorage {
        async fn insert_session(&self, session: &StoredSession) -> Result<(), StorageError> {
            let mut map = self.sessions.write().await;
            map.insert(session.session_id, session.clone());
            Ok(())
        }

        async fn get_session(&self, session_id: Uuid) -> Result<StoredSession, StorageError> {
            let map = self.sessions.read().await;
            map.get(&session_id)
                .cloned()
                .ok_or(StorageError::SessionNotFound(session_id))
        }

        async fn update_session(&self, session: &StoredSession) -> Result<(), StorageError> {
            let mut map = self.sessions.write().await;
            map.insert(session.session_id, session.clone());
            Ok(())
        }

        async fn delete_expired_sessions(&self) -> Result<usize, StorageError> {
            let mut map = self.sessions.write().await;
            let before = map.len();
            let now = chrono::Utc::now();
            map.retain(|_, s| {
                let created = s.created_at;
                let limit = chrono::Duration::seconds(s.time_limit_secs);
                let elapsed = now - created;
                elapsed <= limit && s.status != StoredSessionStatus::Expired
            });
            Ok(before - map.len())
        }

        async fn list_sessions(&self) -> Result<Vec<StoredSession>, StorageError> {
            let map = self.sessions.read().await;
            Ok(map.values().cloned().collect())
        }
    }

    fn test_create_request() -> CreateSessionRequest {
        CreateSessionRequest {
            agent_id: Uuid::new_v4(),
            delegation_chain_snapshot: vec![],
            declared_intent: "read and analyze files".into(),
            authorized_tools: vec!["read_file".into(), "list_dir".into()],
            time_limit: chrono::Duration::hours(1),
            call_budget: 5,
            rate_limit_per_minute: None,
            rate_limit_window_secs: 60,
            data_sensitivity_ceiling: DataSensitivity::Internal,
        }
    }

    async fn make_store() -> (StorageBackedSessionStore, MockStorage) {
        let mock = MockStorage::new();
        let store = StorageBackedSessionStore::new(Arc::new(mock.clone()))
            .await
            .expect("failed to create storage-backed store");
        (store, mock)
    }

    #[tokio::test]
    async fn create_and_use_session() {
        let (store, _mock) = make_store().await;
        let session = store.create(test_create_request()).await;

        assert_eq!(session.calls_made, 0);
        assert!(session.is_active());

        let updated = store
            .use_session(session.session_id, "read_file")
            .await
            .unwrap();
        assert_eq!(updated.calls_made, 1);

        // Verify get returns same data.
        let fetched = store.get(session.session_id).await.unwrap();
        assert_eq!(fetched.calls_made, 1);
    }

    #[tokio::test]
    async fn budget_enforcement() {
        let (store, _mock) = make_store().await;
        let mut req = test_create_request();
        req.call_budget = 2;
        let session = store.create(req).await;

        store
            .use_session(session.session_id, "read_file")
            .await
            .unwrap();
        store
            .use_session(session.session_id, "read_file")
            .await
            .unwrap();

        let result = store.use_session(session.session_id, "read_file").await;
        assert!(matches!(result, Err(SessionError::BudgetExceeded { .. })));
    }

    #[tokio::test]
    async fn tool_whitelist_enforcement() {
        let (store, _mock) = make_store().await;
        let session = store.create(test_create_request()).await;

        store
            .use_session(session.session_id, "read_file")
            .await
            .unwrap();

        let result = store.use_session(session.session_id, "delete_file").await;
        assert!(matches!(
            result,
            Err(SessionError::ToolNotAuthorized { .. })
        ));
    }

    #[tokio::test]
    async fn session_expiry() {
        let (store, _mock) = make_store().await;
        let mut req = test_create_request();
        req.time_limit = chrono::Duration::zero();
        let session = store.create(req).await;

        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        let result = store.use_session(session.session_id, "read_file").await;
        assert!(matches!(result, Err(SessionError::Expired(_))));
    }

    #[tokio::test]
    async fn close_and_reuse() {
        let (store, _mock) = make_store().await;
        let session = store.create(test_create_request()).await;

        store.close(session.session_id).await.unwrap();

        let result = store.use_session(session.session_id, "read_file").await;
        assert!(matches!(result, Err(SessionError::AlreadyClosed(_))));
    }

    #[tokio::test]
    async fn session_not_found() {
        let (store, _mock) = make_store().await;
        let fake_id = Uuid::new_v4();
        let result = store.use_session(fake_id, "anything").await;
        assert!(matches!(result, Err(SessionError::NotFound(_))));
    }

    /// batch with one bad tool must consume zero budget.
    #[tokio::test]
    async fn batch_validation_atomicity() {
        let (store, _mock) = make_store().await;
        let mut req = test_create_request();
        req.call_budget = 10;
        req.authorized_tools = vec!["read_file".into(), "list_dir".into()];
        let session = store.create(req).await;

        // Batch contains one unauthorized tool ("delete_file").
        let result = store
            .use_session_batch(session.session_id, &["read_file", "delete_file"])
            .await;
        assert!(
            matches!(result, Err(SessionError::ToolNotAuthorized { .. })),
            "expected ToolNotAuthorized, got {result:?}"
        );

        // Budget must remain untouched.
        let s = store.get(session.session_id).await.unwrap();
        assert_eq!(
            s.calls_made, 0,
            "no budget should be consumed on batch failure"
        );
    }

    #[tokio::test]
    async fn batch_budget_enforcement() {
        let (store, _mock) = make_store().await;
        let mut req = test_create_request();
        req.call_budget = 3;
        req.authorized_tools = vec!["read_file".into()];
        let session = store.create(req).await;

        // Batch of 4 exceeds budget of 3.
        let result = store
            .use_session_batch(
                session.session_id,
                &["read_file", "read_file", "read_file", "read_file"],
            )
            .await;
        assert!(
            matches!(result, Err(SessionError::BudgetExceeded { .. })),
            "expected BudgetExceeded, got {result:?}"
        );

        // Budget must remain at 0 (no partial consumption).
        let s = store.get(session.session_id).await.unwrap();
        assert_eq!(
            s.calls_made, 0,
            "no budget should be consumed on batch failure"
        );
    }

    #[tokio::test]
    async fn batch_rate_limit_enforcement() {
        let (store, _mock) = make_store().await;
        let mut req = test_create_request();
        req.call_budget = 100;
        req.rate_limit_per_minute = Some(3);
        req.authorized_tools = vec!["read_file".into()];
        let session = store.create(req).await;

        // Batch of 4 exceeds rate limit of 3.
        let result = store
            .use_session_batch(
                session.session_id,
                &["read_file", "read_file", "read_file", "read_file"],
            )
            .await;
        assert!(
            matches!(result, Err(SessionError::RateLimited { .. })),
            "expected RateLimited, got {result:?}"
        );
    }

    /// Successive batches in the same window accumulate against the rate limit.
    #[tokio::test]
    async fn batch_rate_limit_accumulates_within_window() {
        let (store, _mock) = make_store().await;
        let mut req = test_create_request();
        req.call_budget = 100;
        req.rate_limit_per_minute = Some(3);
        let session = store.create(req).await;

        store
            .use_session_batch(session.session_id, &["read_file", "read_file"])
            .await
            .unwrap();

        // 2 already used + 2 more = 4 > 3.
        let result = store
            .use_session_batch(session.session_id, &["read_file", "read_file"])
            .await;
        assert!(
            matches!(result, Err(SessionError::RateLimited { .. })),
            "expected RateLimited, got {result:?}"
        );

        // Only the first batch consumed budget.
        let s = store.get(session.session_id).await.unwrap();
        assert_eq!(s.calls_made, 2);
    }

    /// A fully valid batch consumes its whole size and is written through.
    #[tokio::test]
    async fn batch_success_consumes_budget_and_persists() {
        let (store, mock) = make_store().await;
        let session = store.create(test_create_request()).await;

        let updated = store
            .use_session_batch(session.session_id, &["read_file", "list_dir"])
            .await
            .unwrap();
        assert_eq!(updated.calls_made, 2);

        let stored = mock.get_session(session.session_id).await.unwrap();
        assert_eq!(stored.calls_made, 2, "storage must reflect the batch");
    }

    #[tokio::test]
    async fn cleanup_expired_sessions() {
        let (store, _mock) = make_store().await;

        // Create an already-expired session.
        let mut req = test_create_request();
        req.time_limit = chrono::Duration::zero();
        store.create(req).await;

        // Create a valid session.
        store.create(test_create_request()).await;

        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        let removed = store.cleanup_expired().await;
        assert_eq!(removed, 1);
    }

    #[tokio::test]
    async fn count_active_for_agent() {
        let (store, _mock) = make_store().await;
        let agent_id = Uuid::new_v4();

        // Create 3 sessions for the same agent.
        for _ in 0..3 {
            let mut req = test_create_request();
            req.agent_id = agent_id;
            store.create(req).await;
        }

        // Create 1 session for a different agent.
        store.create(test_create_request()).await;

        let count = store.count_active_for_agent(agent_id).await;
        assert_eq!(count, 3);
    }

    /// Agent deactivation closes only that agent's sessions, in cache and storage.
    #[tokio::test]
    async fn close_sessions_for_agent_closes_only_that_agent() {
        let (store, mock) = make_store().await;
        let agent_id = Uuid::new_v4();

        let mut ids = Vec::new();
        for _ in 0..2 {
            let mut req = test_create_request();
            req.agent_id = agent_id;
            ids.push(store.create(req).await.session_id);
        }
        let other = store.create(test_create_request()).await;

        assert_eq!(store.close_sessions_for_agent(agent_id).await, 2);
        assert_eq!(store.count_active_for_agent(agent_id).await, 0);
        assert_eq!(store.count_active_for_agent(other.agent_id).await, 1);

        for id in ids {
            let stored = mock.get_session(id).await.unwrap();
            assert_eq!(
                stored.status,
                StoredSessionStatus::Closed,
                "storage must reflect the closure"
            );
        }
    }

    /// Verify that mutations are written through to the mock storage backend.
    #[tokio::test]
    async fn storage_write_through() {
        let (store, mock) = make_store().await;
        let session = store.create(test_create_request()).await;

        // After create, session must exist in storage.
        let stored = mock
            .get_session(session.session_id)
            .await
            .expect("session should exist in storage after create");
        assert_eq!(stored.calls_made, 0);

        // After use_session, storage must reflect the increment.
        store
            .use_session(session.session_id, "read_file")
            .await
            .unwrap();
        let stored = mock.get_session(session.session_id).await.unwrap();
        assert_eq!(stored.calls_made, 1, "storage must reflect the tool call");

        // After close, storage must reflect the new status.
        store.close(session.session_id).await.unwrap();
        let stored = mock.get_session(session.session_id).await.unwrap();
        assert_eq!(
            stored.status,
            StoredSessionStatus::Closed,
            "storage must reflect the closed status"
        );
    }

    /// REQ-001: a new store built over the same backend sees prior sessions.
    #[tokio::test]
    async fn sessions_survive_restart() {
        let (store, mock) = make_store().await;
        let session = store.create(test_create_request()).await;
        store
            .use_session(session.session_id, "read_file")
            .await
            .unwrap();
        drop(store);

        // Simulate a process restart: rebuild the store from the same backend.
        let restarted = StorageBackedSessionStore::new(Arc::new(mock.clone()))
            .await
            .expect("failed to rebuild store from storage");

        let reloaded = restarted.get(session.session_id).await.unwrap();
        assert_eq!(reloaded.calls_made, 1);
        assert_eq!(reloaded.agent_id, session.agent_id);
        assert!(reloaded.is_active());
    }
}