Skip to main content

arbiter_session/
storage_store.rs

//! Storage-backed session store with an in-memory write-through cache.
//!
//! Design decision: reads are served from an in-memory cache to keep request
//! latency low, and every mutation is also written through to persistent
//! storage for durability. On startup, the cache is populated from storage.
//!
//! REQ-001: Session state survives process restart.
//! REQ-007: Storage behind async trait; swappable backends.
9
10use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::sync::RwLock;
13use uuid::Uuid;
14
15use arbiter_storage::{
16    SessionStore as StorageSessionStore, StorageError, StoredDataSensitivity, StoredSession,
17    StoredSessionStatus,
18};
19
20use crate::error::SessionError;
21use crate::model::{DataSensitivity, SessionId, SessionStatus, TaskSession};
22use crate::store::CreateSessionRequest;
23
24/// A session store backed by persistent storage with an in-memory cache.
25///
26/// All reads hit the cache first. All writes go to both the cache and
27/// the underlying storage backend. On construction, the cache is warmed
28/// from storage to handle process restarts.
29#[derive(Clone)]
30pub struct StorageBackedSessionStore {
31    cache: Arc<RwLock<HashMap<SessionId, TaskSession>>>,
32    storage: Arc<dyn StorageSessionStore>,
33}
34
35impl StorageBackedSessionStore {
36    /// Create a new storage-backed session store.
37    ///
38    /// Loads all existing sessions from storage into the in-memory cache.
39    pub async fn new(storage: Arc<dyn StorageSessionStore>) -> Result<Self, StorageError> {
40        let store = Self {
41            cache: Arc::new(RwLock::new(HashMap::new())),
42            storage,
43        };
44
45        // Warm cache from storage.
46        store.reload_from_storage().await?;
47
48        Ok(store)
49    }
50
51    /// Reload the in-memory cache from storage.
52    async fn reload_from_storage(&self) -> Result<(), StorageError> {
53        let stored_sessions = self.storage.list_sessions().await?;
54        let mut cache = self.cache.write().await;
55        cache.clear();
56        for stored in stored_sessions {
57            if let Ok(session) = stored_to_domain(stored) {
58                cache.insert(session.session_id, session);
59            }
60        }
61        tracing::info!(sessions = cache.len(), "session cache warmed from storage");
62        Ok(())
63    }
64
65    /// Create a new task session and return it.
66    pub async fn create(&self, req: CreateSessionRequest) -> TaskSession {
67        let session = TaskSession {
68            session_id: Uuid::new_v4(),
69            agent_id: req.agent_id,
70            delegation_chain_snapshot: req.delegation_chain_snapshot,
71            declared_intent: req.declared_intent,
72            authorized_tools: req.authorized_tools,
73            authorized_credentials: req.authorized_credentials,
74            time_limit: req.time_limit,
75            call_budget: req.call_budget,
76            calls_made: 0,
77            rate_limit_per_minute: req.rate_limit_per_minute,
78            rate_window_start: chrono::Utc::now(),
79            rate_window_calls: 0,
80            rate_limit_window_secs: req.rate_limit_window_secs,
81            data_sensitivity_ceiling: req.data_sensitivity_ceiling,
82            created_at: chrono::Utc::now(),
83            status: SessionStatus::Active,
84        };
85
86        tracing::info!(
87            session_id = %session.session_id,
88            agent_id = %session.agent_id,
89            intent = %session.declared_intent,
90            budget = session.call_budget,
91            "created task session (storage-backed)"
92        );
93
94        // Write to storage first, then cache.
95        let stored = domain_to_stored(&session);
96        if let Err(e) = self.storage.insert_session(&stored).await {
97            tracing::error!(error = %e, "failed to persist session to storage");
98        }
99
100        let mut cache = self.cache.write().await;
101        cache.insert(session.session_id, session.clone());
102        session
103    }
104
105    /// Atomically check per-agent session cap and create if under the limit.
106    pub async fn create_if_under_cap(
107        &self,
108        req: CreateSessionRequest,
109        max_sessions: u64,
110    ) -> Result<TaskSession, SessionError> {
111        let mut cache = self.cache.write().await;
112
113        let active_count = cache
114            .values()
115            .filter(|s| s.agent_id == req.agent_id && s.status == SessionStatus::Active)
116            .count() as u64;
117
118        if active_count >= max_sessions {
119            return Err(SessionError::TooManySessions {
120                agent_id: req.agent_id.to_string(),
121                max: max_sessions,
122                current: active_count,
123            });
124        }
125
126        let session = TaskSession {
127            session_id: uuid::Uuid::new_v4(),
128            agent_id: req.agent_id,
129            delegation_chain_snapshot: req.delegation_chain_snapshot,
130            declared_intent: req.declared_intent,
131            authorized_tools: req.authorized_tools,
132            authorized_credentials: req.authorized_credentials,
133            time_limit: req.time_limit,
134            call_budget: req.call_budget,
135            calls_made: 0,
136            rate_limit_per_minute: req.rate_limit_per_minute,
137            rate_window_start: chrono::Utc::now(),
138            rate_window_calls: 0,
139            rate_limit_window_secs: req.rate_limit_window_secs,
140            data_sensitivity_ceiling: req.data_sensitivity_ceiling,
141            created_at: chrono::Utc::now(),
142            status: SessionStatus::Active,
143        };
144
145        // Write-through to storage.
146        let stored = domain_to_stored(&session);
147        if let Err(e) = self.storage.insert_session(&stored).await {
148            tracing::error!(error = %e, "failed to persist session to storage");
149        }
150
151        cache.insert(session.session_id, session.clone());
152        Ok(session)
153    }
154
155    /// Record a tool call against the session, checking all constraints.
156    pub async fn use_session(
157        &self,
158        session_id: SessionId,
159        tool_name: &str,
160        requesting_agent_id: Option<uuid::Uuid>,
161    ) -> Result<TaskSession, SessionError> {
162        let mut cache = self.cache.write().await;
163        let session = cache
164            .get_mut(&session_id)
165            .ok_or(SessionError::NotFound(session_id))?;
166
167        // Verify agent binding to prevent session fixation.
168        if let Some(agent_id) = requesting_agent_id
169            && agent_id != session.agent_id
170        {
171            return Err(SessionError::AgentMismatch {
172                session_id,
173                expected: session.agent_id,
174                actual: agent_id,
175            });
176        }
177
178        if session.status == SessionStatus::Closed {
179            return Err(SessionError::AlreadyClosed(session_id));
180        }
181
182        if session.is_expired() {
183            session.status = SessionStatus::Expired;
184            // Write-through: update storage.
185            let stored = domain_to_stored(session);
186            if let Err(e) = self.storage.update_session(&stored).await {
187                tracing::error!(error = %e, "failed to persist expired session status");
188            }
189            return Err(SessionError::Expired(session_id));
190        }
191
192        if session.is_budget_exceeded() {
193            return Err(SessionError::BudgetExceeded {
194                session_id,
195                limit: session.call_budget,
196                used: session.calls_made,
197            });
198        }
199
200        if !session.is_tool_authorized(tool_name) {
201            return Err(SessionError::ToolNotAuthorized {
202                session_id,
203                tool: tool_name.into(),
204            });
205        }
206
207        if session.check_rate_limit() {
208            return Err(SessionError::RateLimited {
209                session_id,
210                limit_per_minute: session.rate_limit_per_minute.unwrap_or(0),
211            });
212        }
213
214        session.calls_made += 1;
215
216        tracing::debug!(
217            session_id = %session_id,
218            tool = tool_name,
219            calls = session.calls_made,
220            budget = session.call_budget,
221            "session tool call recorded (storage-backed)"
222        );
223
224        let result = session.clone();
225
226        // Write-through: update storage. Propagate failure so callers know
227        // the budget increment is not durably committed.
228        let stored = domain_to_stored(&result);
229        if let Err(e) = self.storage.update_session(&stored).await {
230            tracing::error!(error = %e, "failed to persist session update");
231            return Err(SessionError::StorageWriteThrough {
232                session_id,
233                detail: e.to_string(),
234            });
235        }
236
237        Ok(result)
238    }
239
240    /// Atomically validate and record a batch of tool calls against the session.
241    ///
242    /// Acquires the write lock once, validates ALL tools against
243    /// the whitelist and budget, and only increments `calls_made` by the full
244    /// batch count if every tool passes. If any tool fails validation, no
245    /// budget is consumed for any of them.
246    pub async fn use_session_batch(
247        &self,
248        session_id: SessionId,
249        tool_names: &[&str],
250        requesting_agent_id: Option<uuid::Uuid>,
251    ) -> Result<TaskSession, SessionError> {
252        let mut cache = self.cache.write().await;
253        let session = cache
254            .get_mut(&session_id)
255            .ok_or(SessionError::NotFound(session_id))?;
256
257        // Verify agent binding to prevent session fixation.
258        if let Some(agent_id) = requesting_agent_id
259            && agent_id != session.agent_id
260        {
261            return Err(SessionError::AgentMismatch {
262                session_id,
263                expected: session.agent_id,
264                actual: agent_id,
265            });
266        }
267
268        if session.status == SessionStatus::Closed {
269            return Err(SessionError::AlreadyClosed(session_id));
270        }
271
272        if session.is_expired() {
273            session.status = SessionStatus::Expired;
274            // Write-through: update storage.
275            let stored = domain_to_stored(session);
276            if let Err(e) = self.storage.update_session(&stored).await {
277                tracing::error!(error = %e, "failed to persist expired session status");
278            }
279            return Err(SessionError::Expired(session_id));
280        }
281
282        let batch_size = tool_names.len() as u64;
283
284        // Check budget for the entire batch.
285        if session.calls_made + batch_size > session.call_budget {
286            return Err(SessionError::BudgetExceeded {
287                session_id,
288                limit: session.call_budget,
289                used: session.calls_made,
290            });
291        }
292
293        // Check tool authorization for every tool before consuming any budget.
294        for tool_name in tool_names {
295            if !session.is_tool_authorized(tool_name) {
296                return Err(SessionError::ToolNotAuthorized {
297                    session_id,
298                    tool: (*tool_name).into(),
299                });
300            }
301        }
302
303        // Check rate limit for the entire batch.
304        if let Some(limit) = session.rate_limit_per_minute {
305            let now = chrono::Utc::now();
306            let elapsed = now - session.rate_window_start;
307            if elapsed >= chrono::Duration::seconds(session.rate_limit_window_secs as i64) {
308                // New window; will be reset below after all checks pass.
309            } else if session.rate_window_calls + batch_size > limit {
310                return Err(SessionError::RateLimited {
311                    session_id,
312                    limit_per_minute: limit,
313                });
314            }
315        }
316
317        // All checks passed. Atomically increment counters.
318        if let Some(_limit) = session.rate_limit_per_minute {
319            let now = chrono::Utc::now();
320            let elapsed = now - session.rate_window_start;
321            if elapsed >= chrono::Duration::seconds(session.rate_limit_window_secs as i64) {
322                session.rate_window_start = now;
323                session.rate_window_calls = batch_size;
324            } else {
325                session.rate_window_calls += batch_size;
326            }
327        }
328
329        session.calls_made += batch_size;
330
331        tracing::debug!(
332            session_id = %session_id,
333            batch_size = batch_size,
334            calls = session.calls_made,
335            budget = session.call_budget,
336            "session batch tool calls recorded (storage-backed)"
337        );
338
339        let result = session.clone();
340
341        // Write-through: update storage. Propagate failure.
342        let stored = domain_to_stored(&result);
343        if let Err(e) = self.storage.update_session(&stored).await {
344            tracing::error!(error = %e, "failed to persist session batch update");
345            return Err(SessionError::StorageWriteThrough {
346                session_id,
347                detail: e.to_string(),
348            });
349        }
350
351        Ok(result)
352    }
353
354    /// Close a session, preventing further use.
355    pub async fn close(&self, session_id: SessionId) -> Result<TaskSession, SessionError> {
356        let mut cache = self.cache.write().await;
357        let session = cache
358            .get_mut(&session_id)
359            .ok_or(SessionError::NotFound(session_id))?;
360
361        if session.status == SessionStatus::Closed {
362            return Err(SessionError::AlreadyClosed(session_id));
363        }
364
365        session.status = SessionStatus::Closed;
366        tracing::info!(session_id = %session_id, "session closed (storage-backed)");
367
368        let result = session.clone();
369
370        // Write-through: update storage.
371        let stored = domain_to_stored(&result);
372        if let Err(e) = self.storage.update_session(&stored).await {
373            tracing::error!(error = %e, "failed to persist session close");
374        }
375
376        Ok(result)
377    }
378
379    /// Get a session by ID without modifying it.
380    pub async fn get(&self, session_id: SessionId) -> Result<TaskSession, SessionError> {
381        let cache = self.cache.read().await;
382        cache
383            .get(&session_id)
384            .cloned()
385            .ok_or(SessionError::NotFound(session_id))
386    }
387
388    /// List all sessions currently in the store.
389    pub async fn list_all(&self) -> Vec<TaskSession> {
390        let cache = self.cache.read().await;
391        cache.values().cloned().collect()
392    }
393
394    /// Count the number of active sessions for a given agent.
395    ///
396    /// P0: Used to enforce per-agent concurrent session caps.
397    pub async fn count_active_for_agent(&self, agent_id: uuid::Uuid) -> u64 {
398        let cache = self.cache.read().await;
399        cache
400            .values()
401            .filter(|s| s.agent_id == agent_id && s.status == SessionStatus::Active)
402            .count() as u64
403    }
404
405    /// Close all active sessions belonging to a specific agent.
406    ///
407    /// Called during agent deactivation.
408    pub async fn close_sessions_for_agent(&self, agent_id: uuid::Uuid) -> usize {
409        // Collect sessions to close while holding the write lock, then release
410        // the lock before performing storage writes. This prevents blocking all
411        // other session operations during sequential async storage writes.
412        let to_persist: Vec<StoredSession>;
413        let closed: usize;
414        {
415            let mut cache = self.cache.write().await;
416            let mut count = 0usize;
417            let mut stored_sessions = Vec::new();
418            for session in cache.values_mut() {
419                if session.agent_id == agent_id && session.status == SessionStatus::Active {
420                    session.status = SessionStatus::Closed;
421                    stored_sessions.push(domain_to_stored(session));
422                    count += 1;
423                }
424            }
425            to_persist = stored_sessions;
426            closed = count;
427        } // write lock released here
428
429        // Persist closures outside the critical section.
430        for stored in &to_persist {
431            if let Err(e) = self.storage.update_session(stored).await {
432                tracing::error!(
433                    error = %e,
434                    session_id = %stored.session_id,
435                    "failed to persist session closure during agent deactivation"
436                );
437            }
438        }
439        closed
440    }
441
442    /// Remove expired sessions from cache and storage. Returns the number removed.
443    pub async fn cleanup_expired(&self) -> usize {
444        let mut cache = self.cache.write().await;
445        let before = cache.len();
446        cache.retain(|_, s| {
447            if s.is_expired() {
448                tracing::debug!(session_id = %s.session_id, "cleaning up expired session (storage-backed)");
449                false
450            } else {
451                true
452            }
453        });
454        let removed = before - cache.len();
455
456        // Also clean up storage.
457        if let Err(e) = self.storage.delete_expired_sessions().await {
458            tracing::error!(error = %e, "failed to clean up expired sessions in storage");
459        }
460
461        if removed > 0 {
462            tracing::info!(removed, "cleaned up expired sessions (storage-backed)");
463        }
464        removed
465    }
466}
467
468// ── Conversion helpers ──────────────────────────────────────────────
469
470fn domain_to_stored(session: &TaskSession) -> StoredSession {
471    StoredSession {
472        session_id: session.session_id,
473        agent_id: session.agent_id,
474        delegation_chain_snapshot: session.delegation_chain_snapshot.clone(),
475        declared_intent: session.declared_intent.clone(),
476        authorized_tools: session.authorized_tools.clone(),
477        time_limit_secs: session.time_limit.num_seconds(),
478        call_budget: session.call_budget,
479        calls_made: session.calls_made,
480        rate_limit_per_minute: session.rate_limit_per_minute,
481        rate_window_start: session.rate_window_start,
482        rate_window_calls: session.rate_window_calls,
483        rate_limit_window_secs: session.rate_limit_window_secs,
484        data_sensitivity_ceiling: sensitivity_to_stored(session.data_sensitivity_ceiling),
485        created_at: session.created_at,
486        status: status_to_stored(session.status),
487    }
488}
489
/// Upper bound on a session's time limit, in seconds (24 hours).
///
/// Enforced again when sessions are reloaded from storage, so a compromised
/// backend cannot stretch a session's lifetime indefinitely.
const MAX_SESSION_TIME_LIMIT_SECS: i64 = 24 * 60 * 60;
493
494fn stored_to_domain(stored: StoredSession) -> Result<TaskSession, String> {
495    // Re-validate time_limit_secs upper bound on reload.
496    let clamped_time_limit = stored.time_limit_secs.min(MAX_SESSION_TIME_LIMIT_SECS);
497    if stored.time_limit_secs > MAX_SESSION_TIME_LIMIT_SECS {
498        tracing::warn!(
499            session_id = %stored.session_id,
500            stored = stored.time_limit_secs,
501            clamped = clamped_time_limit,
502            "session time_limit_secs exceeded maximum on reload, clamping"
503        );
504    }
505
506    Ok(TaskSession {
507        session_id: stored.session_id,
508        agent_id: stored.agent_id,
509        delegation_chain_snapshot: stored.delegation_chain_snapshot,
510        declared_intent: stored.declared_intent,
511        authorized_tools: stored.authorized_tools,
512        authorized_credentials: vec![], // TODO: persist in StoredSession once storage schema is updated
513        time_limit: chrono::Duration::seconds(clamped_time_limit),
514        call_budget: stored.call_budget,
515        calls_made: stored.calls_made,
516        rate_limit_per_minute: stored.rate_limit_per_minute,
517        rate_window_start: stored.rate_window_start,
518        rate_window_calls: stored.rate_window_calls,
519        rate_limit_window_secs: stored.rate_limit_window_secs,
520        data_sensitivity_ceiling: stored_to_sensitivity(stored.data_sensitivity_ceiling),
521        created_at: stored.created_at,
522        status: stored_to_status(stored.status),
523    })
524}
525
526fn status_to_stored(status: SessionStatus) -> StoredSessionStatus {
527    match status {
528        SessionStatus::Active => StoredSessionStatus::Active,
529        SessionStatus::Closed => StoredSessionStatus::Closed,
530        SessionStatus::Expired => StoredSessionStatus::Expired,
531    }
532}
533
534fn stored_to_status(status: StoredSessionStatus) -> SessionStatus {
535    match status {
536        StoredSessionStatus::Active => SessionStatus::Active,
537        StoredSessionStatus::Closed => SessionStatus::Closed,
538        StoredSessionStatus::Expired => SessionStatus::Expired,
539    }
540}
541
542fn sensitivity_to_stored(sensitivity: DataSensitivity) -> StoredDataSensitivity {
543    match sensitivity {
544        DataSensitivity::Public => StoredDataSensitivity::Public,
545        DataSensitivity::Internal => StoredDataSensitivity::Internal,
546        DataSensitivity::Confidential => StoredDataSensitivity::Confidential,
547        DataSensitivity::Restricted => StoredDataSensitivity::Restricted,
548    }
549}
550
551fn stored_to_sensitivity(sensitivity: StoredDataSensitivity) -> DataSensitivity {
552    match sensitivity {
553        StoredDataSensitivity::Public => DataSensitivity::Public,
554        StoredDataSensitivity::Internal => DataSensitivity::Internal,
555        StoredDataSensitivity::Confidential => DataSensitivity::Confidential,
556        StoredDataSensitivity::Restricted => DataSensitivity::Restricted,
557    }
558}
559
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::collections::HashMap;
    use std::sync::Arc;
    use tokio::sync::RwLock;

    /// In-memory mock implementing `StorageSessionStore` (arbiter_storage::SessionStore).
    #[derive(Clone)]
    struct MockStorage {
        sessions: Arc<RwLock<HashMap<Uuid, StoredSession>>>,
    }

    impl MockStorage {
        fn new() -> Self {
            Self {
                sessions: Arc::new(RwLock::new(HashMap::new())),
            }
        }
    }

    #[async_trait]
    impl StorageSessionStore for MockStorage {
        async fn insert_session(&self, session: &StoredSession) -> Result<(), StorageError> {
            let mut map = self.sessions.write().await;
            map.insert(session.session_id, session.clone());
            Ok(())
        }

        async fn get_session(&self, session_id: Uuid) -> Result<StoredSession, StorageError> {
            let map = self.sessions.read().await;
            map.get(&session_id)
                .cloned()
                .ok_or(StorageError::SessionNotFound(session_id))
        }

        async fn update_session(&self, session: &StoredSession) -> Result<(), StorageError> {
            let mut map = self.sessions.write().await;
            map.insert(session.session_id, session.clone());
            Ok(())
        }

        async fn delete_expired_sessions(&self) -> Result<usize, StorageError> {
            let mut map = self.sessions.write().await;
            let before = map.len();
            let now = chrono::Utc::now();
            map.retain(|_, s| {
                let created = s.created_at;
                let limit = chrono::Duration::seconds(s.time_limit_secs);
                let elapsed = now - created;
                elapsed <= limit && s.status != StoredSessionStatus::Expired
            });
            Ok(before - map.len())
        }

        async fn list_sessions(&self) -> Result<Vec<StoredSession>, StorageError> {
            let map = self.sessions.read().await;
            Ok(map.values().cloned().collect())
        }
    }

    fn test_create_request() -> CreateSessionRequest {
        CreateSessionRequest {
            agent_id: Uuid::new_v4(),
            delegation_chain_snapshot: vec![],
            declared_intent: "read and analyze files".into(),
            authorized_tools: vec!["read_file".into(), "list_dir".into()],
            authorized_credentials: vec![],
            time_limit: chrono::Duration::hours(1),
            call_budget: 5,
            rate_limit_per_minute: None,
            rate_limit_window_secs: 60,
            data_sensitivity_ceiling: DataSensitivity::Internal,
        }
    }

    async fn make_store() -> (StorageBackedSessionStore, MockStorage) {
        let mock = MockStorage::new();
        let store = StorageBackedSessionStore::new(Arc::new(mock.clone()))
            .await
            .expect("failed to create storage-backed store");
        (store, mock)
    }

    /// Build a second store over the same backend, simulating a process restart.
    async fn restart(mock: &MockStorage) -> StorageBackedSessionStore {
        StorageBackedSessionStore::new(Arc::new(mock.clone()))
            .await
            .expect("failed to rebuild store from storage")
    }

    #[tokio::test]
    async fn create_and_use_session() {
        let (store, _mock) = make_store().await;
        let session = store.create(test_create_request()).await;

        assert_eq!(session.calls_made, 0);
        assert!(session.is_active());

        let updated = store
            .use_session(session.session_id, "read_file", None)
            .await
            .unwrap();
        assert_eq!(updated.calls_made, 1);

        // Verify get returns same data.
        let fetched = store.get(session.session_id).await.unwrap();
        assert_eq!(fetched.calls_made, 1);
    }

    #[tokio::test]
    async fn budget_enforcement() {
        let (store, _mock) = make_store().await;
        let mut req = test_create_request();
        req.call_budget = 2;
        let session = store.create(req).await;

        store
            .use_session(session.session_id, "read_file", None)
            .await
            .unwrap();
        store
            .use_session(session.session_id, "read_file", None)
            .await
            .unwrap();

        let result = store
            .use_session(session.session_id, "read_file", None)
            .await;
        assert!(matches!(result, Err(SessionError::BudgetExceeded { .. })));
    }

    #[tokio::test]
    async fn tool_whitelist_enforcement() {
        let (store, _mock) = make_store().await;
        let session = store.create(test_create_request()).await;

        store
            .use_session(session.session_id, "read_file", None)
            .await
            .unwrap();

        let result = store
            .use_session(session.session_id, "delete_file", None)
            .await;
        assert!(matches!(
            result,
            Err(SessionError::ToolNotAuthorized { .. })
        ));
    }

    #[tokio::test]
    async fn session_expiry() {
        let (store, _mock) = make_store().await;
        let mut req = test_create_request();
        req.time_limit = chrono::Duration::zero();
        let session = store.create(req).await;

        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        let result = store
            .use_session(session.session_id, "read_file", None)
            .await;
        assert!(matches!(result, Err(SessionError::Expired(_))));
    }

    #[tokio::test]
    async fn close_and_reuse() {
        let (store, _mock) = make_store().await;
        let session = store.create(test_create_request()).await;

        store.close(session.session_id).await.unwrap();

        let result = store
            .use_session(session.session_id, "read_file", None)
            .await;
        assert!(matches!(result, Err(SessionError::AlreadyClosed(_))));
    }

    #[tokio::test]
    async fn session_not_found() {
        let (store, _mock) = make_store().await;
        let fake_id = Uuid::new_v4();
        let result = store.use_session(fake_id, "anything", None).await;
        assert!(matches!(result, Err(SessionError::NotFound(_))));
    }

    /// A request from an agent other than the session's owner is rejected and
    /// consumes no budget; the owning agent is accepted.
    #[tokio::test]
    async fn agent_binding_enforced() {
        let (store, _mock) = make_store().await;
        let session = store.create(test_create_request()).await;

        let result = store
            .use_session(session.session_id, "read_file", Some(Uuid::new_v4()))
            .await;
        assert!(
            matches!(result, Err(SessionError::AgentMismatch { .. })),
            "expected AgentMismatch, got {result:?}"
        );
        assert_eq!(store.get(session.session_id).await.unwrap().calls_made, 0);

        let updated = store
            .use_session(session.session_id, "read_file", Some(session.agent_id))
            .await
            .unwrap();
        assert_eq!(updated.calls_made, 1);
    }

    /// batch with one bad tool must consume zero budget.
    #[tokio::test]
    async fn batch_validation_atomicity() {
        let (store, _mock) = make_store().await;
        let mut req = test_create_request();
        req.call_budget = 10;
        req.authorized_tools = vec!["read_file".into(), "list_dir".into()];
        let session = store.create(req).await;

        // Batch contains one unauthorized tool ("delete_file").
        let result = store
            .use_session_batch(session.session_id, &["read_file", "delete_file"], None)
            .await;
        assert!(
            matches!(result, Err(SessionError::ToolNotAuthorized { .. })),
            "expected ToolNotAuthorized, got {result:?}"
        );

        // Budget must remain untouched.
        let s = store.get(session.session_id).await.unwrap();
        assert_eq!(
            s.calls_made, 0,
            "no budget should be consumed on batch failure"
        );
    }

    #[tokio::test]
    async fn batch_budget_enforcement() {
        let (store, _mock) = make_store().await;
        let mut req = test_create_request();
        req.call_budget = 3;
        req.authorized_tools = vec!["read_file".into()];
        let session = store.create(req).await;

        // Batch of 4 exceeds budget of 3.
        let result = store
            .use_session_batch(
                session.session_id,
                &["read_file", "read_file", "read_file", "read_file"],
                None,
            )
            .await;
        assert!(
            matches!(result, Err(SessionError::BudgetExceeded { .. })),
            "expected BudgetExceeded, got {result:?}"
        );

        // Budget must remain at 0 (no partial consumption).
        let s = store.get(session.session_id).await.unwrap();
        assert_eq!(
            s.calls_made, 0,
            "no budget should be consumed on batch failure"
        );
    }

    #[tokio::test]
    async fn batch_rate_limit_enforcement() {
        let (store, _mock) = make_store().await;
        let mut req = test_create_request();
        req.call_budget = 100;
        req.rate_limit_per_minute = Some(3);
        req.authorized_tools = vec!["read_file".into()];
        let session = store.create(req).await;

        // Batch of 4 exceeds rate limit of 3.
        let result = store
            .use_session_batch(
                session.session_id,
                &["read_file", "read_file", "read_file", "read_file"],
                None,
            )
            .await;
        assert!(
            matches!(result, Err(SessionError::RateLimited { .. })),
            "expected RateLimited, got {result:?}"
        );
    }

    /// A successful batch consumes its full size and is written through.
    #[tokio::test]
    async fn batch_success_consumes_budget_and_persists() {
        let (store, mock) = make_store().await;
        let session = store.create(test_create_request()).await;

        let updated = store
            .use_session_batch(session.session_id, &["read_file", "list_dir"], None)
            .await
            .unwrap();
        assert_eq!(updated.calls_made, 2);

        let stored = mock.get_session(session.session_id).await.unwrap();
        assert_eq!(stored.calls_made, 2, "storage must reflect the batch");
    }

    #[tokio::test]
    async fn cleanup_expired_sessions() {
        let (store, _mock) = make_store().await;

        // Create an already-expired session.
        let mut req = test_create_request();
        req.time_limit = chrono::Duration::zero();
        store.create(req).await;

        // Create a valid session.
        store.create(test_create_request()).await;

        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        let removed = store.cleanup_expired().await;
        assert_eq!(removed, 1);
    }

    #[tokio::test]
    async fn count_active_for_agent() {
        let (store, _mock) = make_store().await;
        let agent_id = Uuid::new_v4();

        // Create 3 sessions for the same agent.
        for _ in 0..3 {
            let mut req = test_create_request();
            req.agent_id = agent_id;
            store.create(req).await;
        }

        // Create 1 session for a different agent.
        store.create(test_create_request()).await;

        let count = store.count_active_for_agent(agent_id).await;
        assert_eq!(count, 3);
    }

    /// `create_if_under_cap` admits sessions up to the cap and then refuses.
    #[tokio::test]
    async fn create_if_under_cap_enforces_limit() {
        let (store, _mock) = make_store().await;
        let agent_id = Uuid::new_v4();

        for _ in 0..2 {
            let mut req = test_create_request();
            req.agent_id = agent_id;
            store.create_if_under_cap(req, 2).await.unwrap();
        }

        let mut req = test_create_request();
        req.agent_id = agent_id;
        let result = store.create_if_under_cap(req, 2).await;
        assert!(
            matches!(
                result,
                Err(SessionError::TooManySessions {
                    max: 2,
                    current: 2,
                    ..
                })
            ),
            "expected TooManySessions, got {result:?}"
        );
    }

    /// Agent deactivation closes only that agent's sessions, in cache and storage.
    #[tokio::test]
    async fn close_sessions_for_agent_scoped_and_persisted() {
        let (store, mock) = make_store().await;
        let agent_id = Uuid::new_v4();

        for _ in 0..2 {
            let mut req = test_create_request();
            req.agent_id = agent_id;
            store.create(req).await;
        }
        let other = store.create(test_create_request()).await;

        let closed = store.close_sessions_for_agent(agent_id).await;
        assert_eq!(closed, 2);
        assert_eq!(store.count_active_for_agent(agent_id).await, 0);
        assert!(store.get(other.session_id).await.unwrap().is_active());

        for stored in mock.list_sessions().await.unwrap() {
            if stored.agent_id == agent_id {
                assert_eq!(stored.status, StoredSessionStatus::Closed);
            }
        }
    }

    /// Verify that mutations are written through to the mock storage backend.
    #[tokio::test]
    async fn storage_write_through() {
        let (store, mock) = make_store().await;
        let session = store.create(test_create_request()).await;

        // After create, session must exist in storage.
        let stored = mock
            .get_session(session.session_id)
            .await
            .expect("session should exist in storage after create");
        assert_eq!(stored.calls_made, 0);

        // After use_session, storage must reflect the increment.
        store
            .use_session(session.session_id, "read_file", None)
            .await
            .unwrap();
        let stored = mock.get_session(session.session_id).await.unwrap();
        assert_eq!(stored.calls_made, 1, "storage must reflect the tool call");

        // After close, storage must reflect the new status.
        store.close(session.session_id).await.unwrap();
        let stored = mock.get_session(session.session_id).await.unwrap();
        assert_eq!(
            stored.status,
            StoredSessionStatus::Closed,
            "storage must reflect the closed status"
        );
    }

    /// REQ-001: consumed budget survives a process restart.
    #[tokio::test]
    async fn state_survives_restart() {
        let (store, mock) = make_store().await;
        let session = store.create(test_create_request()).await;
        store
            .use_session(session.session_id, "read_file", None)
            .await
            .unwrap();

        let restarted = restart(&mock).await;
        let reloaded = restarted.get(session.session_id).await.unwrap();
        assert_eq!(reloaded.calls_made, 1);
        assert_eq!(reloaded.agent_id, session.agent_id);
        assert!(reloaded.status == SessionStatus::Active);
    }

    /// REQ-001: a closed session stays closed after a process restart.
    #[tokio::test]
    async fn closed_status_survives_restart() {
        let (store, mock) = make_store().await;
        let session = store.create(test_create_request()).await;
        store.close(session.session_id).await.unwrap();

        let restarted = restart(&mock).await;
        let result = restarted
            .use_session(session.session_id, "read_file", None)
            .await;
        assert!(matches!(result, Err(SessionError::AlreadyClosed(_))));
    }
}