// arbiter_session/store.rs
//! In-memory session store with TTL-based cleanup.

3use chrono::Utc;
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7use uuid::Uuid;
8
9use crate::error::SessionError;
10use crate::model::{DataSensitivity, SessionId, SessionStatus, TaskSession};
11
12/// Request to create a new task session.
13pub struct CreateSessionRequest {
14    /// The agent ID for this session.
15    pub agent_id: Uuid,
16    /// Delegation chain snapshot (serialized).
17    pub delegation_chain_snapshot: Vec<String>,
18    /// Declared intent for the session.
19    pub declared_intent: String,
20    /// Tools authorized by policy evaluation.
21    pub authorized_tools: Vec<String>,
22    /// Session time limit.
23    pub time_limit: chrono::Duration,
24    /// Maximum number of tool calls.
25    pub call_budget: u64,
26    /// Per-minute rate limit. `None` means no rate limit.
27    pub rate_limit_per_minute: Option<u64>,
28    /// Duration of the rate-limit window in seconds. Defaults to 60.
29    pub rate_limit_window_secs: u64,
30    /// Data sensitivity ceiling.
31    pub data_sensitivity_ceiling: DataSensitivity,
32}
33
34/// In-memory session store with TTL-based cleanup.
35#[derive(Clone)]
36pub struct SessionStore {
37    sessions: Arc<RwLock<HashMap<SessionId, TaskSession>>>,
38}
39
40impl SessionStore {
41    /// Create a new empty session store.
42    pub fn new() -> Self {
43        Self {
44            sessions: Arc::new(RwLock::new(HashMap::new())),
45        }
46    }
47
48    /// Create a new task session and return it.
49    pub async fn create(&self, req: CreateSessionRequest) -> TaskSession {
50        let session = TaskSession {
51            session_id: Uuid::new_v4(),
52            agent_id: req.agent_id,
53            delegation_chain_snapshot: req.delegation_chain_snapshot,
54            declared_intent: req.declared_intent,
55            authorized_tools: req.authorized_tools,
56            time_limit: req.time_limit,
57            call_budget: req.call_budget,
58            calls_made: 0,
59            rate_limit_per_minute: req.rate_limit_per_minute,
60            rate_window_start: Utc::now(),
61            rate_window_calls: 0,
62            rate_limit_window_secs: req.rate_limit_window_secs,
63            data_sensitivity_ceiling: req.data_sensitivity_ceiling,
64            created_at: Utc::now(),
65            status: SessionStatus::Active,
66        };
67
68        tracing::info!(
69            session_id = %session.session_id,
70            agent_id = %session.agent_id,
71            intent = %session.declared_intent,
72            budget = session.call_budget,
73            "created task session"
74        );
75
76        let mut sessions = self.sessions.write().await;
77        sessions.insert(session.session_id, session.clone());
78        session
79    }
80
81    /// Record a tool call against the session, checking all constraints.
82    ///
83    /// Returns the updated session on success, or an error if:
84    /// - Session not found
85    /// - Session expired (would return 408)
86    /// - Budget exceeded (would return 429)
87    /// - Tool not authorized (would return 403)
88    pub async fn use_session(
89        &self,
90        session_id: SessionId,
91        tool_name: &str,
92    ) -> Result<TaskSession, SessionError> {
93        let mut sessions = self.sessions.write().await;
94        let session = sessions
95            .get_mut(&session_id)
96            .ok_or(SessionError::NotFound(session_id))?;
97
98        if session.status == SessionStatus::Closed {
99            return Err(SessionError::AlreadyClosed(session_id));
100        }
101
102        // Check expiry.
103        if session.is_expired() {
104            session.status = SessionStatus::Expired;
105            return Err(SessionError::Expired(session_id));
106        }
107
108        // Check budget.
109        if session.is_budget_exceeded() {
110            return Err(SessionError::BudgetExceeded {
111                session_id,
112                limit: session.call_budget,
113                used: session.calls_made,
114            });
115        }
116
117        // Check tool authorization.
118        if !session.is_tool_authorized(tool_name) {
119            return Err(SessionError::ToolNotAuthorized {
120                session_id,
121                tool: tool_name.into(),
122            });
123        }
124
125        // Check rate limit.
126        if session.check_rate_limit() {
127            return Err(SessionError::RateLimited {
128                session_id,
129                limit_per_minute: session.rate_limit_per_minute.unwrap_or(0),
130            });
131        }
132
133        // All checks passed. Increment counter.
134        session.calls_made += 1;
135
136        tracing::debug!(
137            session_id = %session_id,
138            tool = tool_name,
139            calls = session.calls_made,
140            budget = session.call_budget,
141            "session tool call recorded"
142        );
143
144        Ok(session.clone())
145    }
146
147    /// Atomically validate and record a batch of tool calls against the session.
148    ///
149    /// Acquires the write lock once, validates ALL tools against
150    /// the whitelist and budget, and only increments `calls_made` by the full
151    /// batch count if every tool passes. If any tool fails validation, no
152    /// budget is consumed for any of them.
153    pub async fn use_session_batch(
154        &self,
155        session_id: SessionId,
156        tool_names: &[&str],
157    ) -> Result<TaskSession, SessionError> {
158        let mut sessions = self.sessions.write().await;
159        let session = sessions
160            .get_mut(&session_id)
161            .ok_or(SessionError::NotFound(session_id))?;
162
163        if session.status == SessionStatus::Closed {
164            return Err(SessionError::AlreadyClosed(session_id));
165        }
166
167        // Check expiry.
168        if session.is_expired() {
169            session.status = SessionStatus::Expired;
170            return Err(SessionError::Expired(session_id));
171        }
172
173        let batch_size = tool_names.len() as u64;
174
175        // Check budget for the entire batch.
176        if session.calls_made + batch_size > session.call_budget {
177            return Err(SessionError::BudgetExceeded {
178                session_id,
179                limit: session.call_budget,
180                used: session.calls_made,
181            });
182        }
183
184        // Check tool authorization for every tool before consuming any budget.
185        for tool_name in tool_names {
186            if !session.is_tool_authorized(tool_name) {
187                return Err(SessionError::ToolNotAuthorized {
188                    session_id,
189                    tool: (*tool_name).into(),
190                });
191            }
192        }
193
194        // Check rate limit for the entire batch.
195        // We check whether adding batch_size calls would exceed the limit,
196        // without mutating state until we know it's safe.
197        if let Some(limit) = session.rate_limit_per_minute {
198            let now = chrono::Utc::now();
199            let elapsed = now - session.rate_window_start;
200            if elapsed >= chrono::Duration::seconds(session.rate_limit_window_secs as i64) {
201                // New window; will be reset below after all checks pass.
202            } else if session.rate_window_calls + batch_size > limit {
203                return Err(SessionError::RateLimited {
204                    session_id,
205                    limit_per_minute: limit,
206                });
207            }
208        }
209
210        // All checks passed. Atomically increment counters.
211        // Update rate limit window.
212        if let Some(_limit) = session.rate_limit_per_minute {
213            let now = chrono::Utc::now();
214            let elapsed = now - session.rate_window_start;
215            if elapsed >= chrono::Duration::seconds(session.rate_limit_window_secs as i64) {
216                session.rate_window_start = now;
217                session.rate_window_calls = batch_size;
218            } else {
219                session.rate_window_calls += batch_size;
220            }
221        }
222
223        session.calls_made += batch_size;
224
225        tracing::debug!(
226            session_id = %session_id,
227            batch_size = batch_size,
228            calls = session.calls_made,
229            budget = session.call_budget,
230            "session batch tool calls recorded"
231        );
232
233        Ok(session.clone())
234    }
235
236    /// Close a session, preventing further use.
237    pub async fn close(&self, session_id: SessionId) -> Result<TaskSession, SessionError> {
238        let mut sessions = self.sessions.write().await;
239        let session = sessions
240            .get_mut(&session_id)
241            .ok_or(SessionError::NotFound(session_id))?;
242
243        if session.status == SessionStatus::Closed {
244            return Err(SessionError::AlreadyClosed(session_id));
245        }
246
247        session.status = SessionStatus::Closed;
248        tracing::info!(session_id = %session_id, "session closed");
249        Ok(session.clone())
250    }
251
252    /// Get a session by ID without modifying it.
253    pub async fn get(&self, session_id: SessionId) -> Result<TaskSession, SessionError> {
254        let sessions = self.sessions.read().await;
255        sessions
256            .get(&session_id)
257            .cloned()
258            .ok_or(SessionError::NotFound(session_id))
259    }
260
261    /// List all sessions currently in the store (active, expired, and closed).
262    pub async fn list_all(&self) -> Vec<TaskSession> {
263        let sessions = self.sessions.read().await;
264        sessions.values().cloned().collect()
265    }
266
267    /// Count the number of active sessions for a given agent.
268    ///
269    /// P0: Used to enforce per-agent concurrent session caps.
270    pub async fn count_active_for_agent(&self, agent_id: uuid::Uuid) -> u64 {
271        let sessions = self.sessions.read().await;
272        sessions
273            .values()
274            .filter(|s| s.agent_id == agent_id && s.status == SessionStatus::Active)
275            .count() as u64
276    }
277
278    /// Close all active sessions belonging to a specific agent.
279    ///
280    /// When an agent is deactivated via cascade_deactivate,
281    /// all its sessions must be immediately closed.
282    pub async fn close_sessions_for_agent(&self, agent_id: uuid::Uuid) -> usize {
283        let mut sessions = self.sessions.write().await;
284        let mut closed = 0usize;
285        for session in sessions.values_mut() {
286            if session.agent_id == agent_id && session.status == SessionStatus::Active {
287                session.status = SessionStatus::Closed;
288                closed += 1;
289                tracing::info!(
290                    session_id = %session.session_id,
291                    agent_id = %agent_id,
292                    "closed session due to agent deactivation"
293                );
294            }
295        }
296        closed
297    }
298
299    /// Remove expired sessions from the store. Returns the number removed.
300    pub async fn cleanup_expired(&self) -> usize {
301        let mut sessions = self.sessions.write().await;
302        let before = sessions.len();
303        // Also clean up closed sessions, not just expired ones.
304        // Previously, closed sessions accumulated indefinitely, growing the store without bound.
305        sessions.retain(|_, s| {
306            if s.is_expired() {
307                tracing::debug!(session_id = %s.session_id, "cleaning up expired session");
308                false
309            } else if s.status == SessionStatus::Closed {
310                tracing::debug!(session_id = %s.session_id, "cleaning up closed session");
311                false
312            } else {
313                true
314            }
315        });
316        let removed = before - sessions.len();
317        if removed > 0 {
318            tracing::info!(removed, "cleaned up expired/closed sessions");
319        }
320        removed
321    }
322}
323
324impl Default for SessionStore {
325    fn default() -> Self {
326        Self::new()
327    }
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline request: two authorized tools, budget 5, one-hour limit, no rate limit.
    fn test_create_request() -> CreateSessionRequest {
        CreateSessionRequest {
            agent_id: Uuid::new_v4(),
            delegation_chain_snapshot: vec![],
            declared_intent: "read and analyze files".into(),
            authorized_tools: vec!["read_file".into(), "list_dir".into()],
            time_limit: chrono::Duration::hours(1),
            call_budget: 5,
            rate_limit_per_minute: None,
            rate_limit_window_secs: 60,
            data_sensitivity_ceiling: DataSensitivity::Internal,
        }
    }

    #[tokio::test]
    async fn create_and_use_session() {
        let store = SessionStore::new();
        let session = store.create(test_create_request()).await;

        assert_eq!(session.calls_made, 0);
        assert!(session.is_active());

        let updated = store
            .use_session(session.session_id, "read_file")
            .await
            .unwrap();
        assert_eq!(updated.calls_made, 1);
    }

    #[tokio::test]
    async fn budget_enforcement() {
        let store = SessionStore::new();
        let mut req = test_create_request();
        req.call_budget = 2;
        let session = store.create(req).await;

        // Use up the budget.
        store
            .use_session(session.session_id, "read_file")
            .await
            .unwrap();
        store
            .use_session(session.session_id, "read_file")
            .await
            .unwrap();

        // Third call should fail.
        let result = store.use_session(session.session_id, "read_file").await;
        assert!(matches!(result, Err(SessionError::BudgetExceeded { .. })));
    }

    #[tokio::test]
    async fn tool_whitelist_enforcement() {
        let store = SessionStore::new();
        let session = store.create(test_create_request()).await;

        // Authorized tool works.
        store
            .use_session(session.session_id, "read_file")
            .await
            .unwrap();

        // Unauthorized tool is rejected.
        let result = store.use_session(session.session_id, "delete_file").await;
        assert!(matches!(
            result,
            Err(SessionError::ToolNotAuthorized { .. })
        ));
    }

    #[tokio::test]
    async fn session_expiry() {
        let store = SessionStore::new();
        let mut req = test_create_request();
        // Zero duration: the session expires as soon as any time has passed.
        req.time_limit = chrono::Duration::zero();
        let session = store.create(req).await;

        // Sleep to guarantee the clock advances past the zero-length lifetime.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        let result = store.use_session(session.session_id, "read_file").await;
        assert!(matches!(result, Err(SessionError::Expired(_))));
    }

    #[tokio::test]
    async fn close_and_reuse() {
        let store = SessionStore::new();
        let session = store.create(test_create_request()).await;

        store.close(session.session_id).await.unwrap();

        let result = store.use_session(session.session_id, "read_file").await;
        assert!(matches!(result, Err(SessionError::AlreadyClosed(_))));
    }

    /// Closing an already-closed session is an error, not a silent no-op.
    #[tokio::test]
    async fn double_close_rejected() {
        let store = SessionStore::new();
        let session = store.create(test_create_request()).await;

        store.close(session.session_id).await.unwrap();

        let result = store.close(session.session_id).await;
        assert!(
            matches!(result, Err(SessionError::AlreadyClosed(_))),
            "second close must fail, got {result:?}"
        );
    }

    #[tokio::test]
    async fn cleanup_expired_sessions() {
        let store = SessionStore::new();

        // Create an already-expired session.
        let mut req = test_create_request();
        req.time_limit = chrono::Duration::zero();
        let expired = store.create(req).await;

        // Create a valid session.
        let valid = store.create(test_create_request()).await;

        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        let removed = store.cleanup_expired().await;
        assert_eq!(removed, 1);

        // Exactly the expired session is gone; the valid one must survive.
        assert!(matches!(
            store.get(expired.session_id).await,
            Err(SessionError::NotFound(_))
        ));
        assert!(
            store.get(valid.session_id).await.is_ok(),
            "valid session must not be removed by cleanup"
        );
    }

    #[tokio::test]
    async fn session_not_found() {
        let store = SessionStore::new();
        let fake_id = Uuid::new_v4();
        let result = store.use_session(fake_id, "anything").await;
        assert!(matches!(result, Err(SessionError::NotFound(_))));
    }

    #[tokio::test]
    async fn rate_limit_enforcement() {
        let store = SessionStore::new();
        let mut req = test_create_request();
        req.rate_limit_per_minute = Some(3);
        req.call_budget = 100; // high budget, rate limit should trigger first
        let session = store.create(req).await;

        // First 3 calls succeed (within rate limit).
        for _ in 0..3 {
            store
                .use_session(session.session_id, "read_file")
                .await
                .unwrap();
        }

        // 4th call hits rate limit.
        let result = store.use_session(session.session_id, "read_file").await;
        assert!(
            matches!(result, Err(SessionError::RateLimited { .. })),
            "expected RateLimited, got {result:?}"
        );
    }

    #[tokio::test]
    async fn no_rate_limit_when_unset() {
        let store = SessionStore::new();
        let mut req = test_create_request();
        req.rate_limit_per_minute = None;
        req.call_budget = 100;
        let session = store.create(req).await;

        // All calls succeed without rate limiting.
        for _ in 0..10 {
            store
                .use_session(session.session_id, "read_file")
                .await
                .unwrap();
        }
    }

    /// A fully valid batch consumes exactly one budget unit per tool.
    #[tokio::test]
    async fn batch_success_consumes_full_batch() {
        let store = SessionStore::new();
        let session = store.create(test_create_request()).await;

        let updated = store
            .use_session_batch(session.session_id, &["read_file", "list_dir"])
            .await
            .unwrap();
        assert_eq!(updated.calls_made, 2);

        let stored = store.get(session.session_id).await.unwrap();
        assert_eq!(stored.calls_made, 2, "store must reflect the committed batch");
    }

    /// batch with one unauthorized tool must consume zero budget.
    #[tokio::test]
    async fn batch_validation_atomicity() {
        let store = SessionStore::new();
        let mut req = test_create_request();
        req.call_budget = 10;
        req.authorized_tools = vec!["read_file".into(), "list_dir".into()];
        let session = store.create(req).await;

        // Batch contains one unauthorized tool ("delete_file").
        let result = store
            .use_session_batch(session.session_id, &["read_file", "delete_file"])
            .await;
        assert!(
            matches!(result, Err(SessionError::ToolNotAuthorized { .. })),
            "expected ToolNotAuthorized, got {result:?}"
        );

        // Budget must remain untouched.
        let s = store.get(session.session_id).await.unwrap();
        assert_eq!(
            s.calls_made, 0,
            "no budget should be consumed on batch failure"
        );
    }

    #[tokio::test]
    async fn batch_budget_enforcement() {
        let store = SessionStore::new();
        let mut req = test_create_request();
        req.call_budget = 3;
        req.authorized_tools = vec!["read_file".into()];
        let session = store.create(req).await;

        // Batch of 4 exceeds budget of 3.
        let result = store
            .use_session_batch(
                session.session_id,
                &["read_file", "read_file", "read_file", "read_file"],
            )
            .await;
        assert!(
            matches!(result, Err(SessionError::BudgetExceeded { .. })),
            "expected BudgetExceeded, got {result:?}"
        );

        // Budget must remain at 0.
        let s = store.get(session.session_id).await.unwrap();
        assert_eq!(
            s.calls_made, 0,
            "no budget should be consumed on batch failure"
        );
    }

    #[tokio::test]
    async fn batch_rate_limit_enforcement() {
        let store = SessionStore::new();
        let mut req = test_create_request();
        req.call_budget = 100;
        req.rate_limit_per_minute = Some(3);
        req.authorized_tools = vec!["read_file".into()];
        let session = store.create(req).await;

        // Batch of 4 exceeds rate limit of 3.
        let result = store
            .use_session_batch(
                session.session_id,
                &["read_file", "read_file", "read_file", "read_file"],
            )
            .await;
        assert!(
            matches!(result, Err(SessionError::RateLimited { .. })),
            "expected RateLimited, got {result:?}"
        );
    }

    /// A batch against a closed session is rejected like a single call.
    #[tokio::test]
    async fn batch_on_closed_session_rejected() {
        let store = SessionStore::new();
        let session = store.create(test_create_request()).await;

        store.close(session.session_id).await.unwrap();

        let result = store
            .use_session_batch(session.session_id, &["read_file"])
            .await;
        assert!(
            matches!(result, Err(SessionError::AlreadyClosed(_))),
            "expected AlreadyClosed, got {result:?}"
        );
    }

    #[tokio::test]
    async fn empty_batch_succeeds() {
        let store = SessionStore::new();
        let session = store.create(test_create_request()).await;

        // Empty batch should succeed without consuming budget.
        let result = store
            .use_session_batch(session.session_id, &[])
            .await
            .unwrap();
        assert_eq!(result.calls_made, 0, "empty batch must not consume budget");
    }

    /// cleanup should also remove closed sessions.
    #[tokio::test]
    async fn cleanup_also_removes_closed() {
        let store = SessionStore::new();
        let session = store.create(test_create_request()).await;

        // Close it.
        store.close(session.session_id).await.unwrap();

        // Cleanup should remove the closed session.
        let removed = store.cleanup_expired().await;
        assert_eq!(removed, 1, "closed session should be cleaned up");

        // It should be gone.
        let result = store.get(session.session_id).await;
        assert!(
            matches!(result, Err(SessionError::NotFound(_))),
            "closed session should be removed after cleanup"
        );
    }

    /// A session created with call_budget=0 should immediately fail on use.
    #[tokio::test]
    async fn zero_budget_session() {
        let store = SessionStore::new();
        let mut req = test_create_request();
        req.call_budget = 0;
        let session = store.create(req).await;

        let result = store.use_session(session.session_id, "read_file").await;
        assert!(
            matches!(result, Err(SessionError::BudgetExceeded { .. })),
            "zero-budget session must reject the first call, got {result:?}"
        );
    }

    /// Agent deactivation must close all agent sessions and only those.
    #[tokio::test]
    async fn deactivation_closes_agent_sessions() {
        let store = SessionStore::new();
        let agent_id = Uuid::new_v4();
        let other_agent = Uuid::new_v4();

        for _ in 0..3 {
            let mut req = test_create_request();
            req.agent_id = agent_id;
            store.create(req).await;
        }
        let mut other_req = test_create_request();
        other_req.agent_id = other_agent;
        let other_session = store.create(other_req).await;

        assert_eq!(store.count_active_for_agent(agent_id).await, 3);

        let closed = store.close_sessions_for_agent(agent_id).await;
        assert_eq!(closed, 3);

        // Guard against a vacuous pass: the agent's sessions must still be
        // listed, and every one of them must be closed.
        let agent_sessions: Vec<TaskSession> = store
            .list_all()
            .await
            .into_iter()
            .filter(|s| s.agent_id == agent_id)
            .collect();
        assert_eq!(agent_sessions.len(), 3);
        assert!(agent_sessions
            .iter()
            .all(|s| s.status == SessionStatus::Closed));

        assert_eq!(store.count_active_for_agent(agent_id).await, 0);
        assert_eq!(store.count_active_for_agent(other_agent).await, 1);

        // Already-closed sessions are not counted again.
        assert_eq!(store.close_sessions_for_agent(agent_id).await, 0);

        let other = store.get(other_session.session_id).await.unwrap();
        assert_eq!(other.status, SessionStatus::Active);
    }

    /// Concurrent budget enforcement.
    /// Spawn 10 tasks each calling use_session once on a session with budget=5.
    /// Exactly 5 must succeed and 5 must fail with BudgetExceeded.
    #[tokio::test]
    async fn concurrent_budget_enforcement() {
        use std::sync::atomic::{AtomicU64, Ordering};

        let store = SessionStore::new();
        let mut req = test_create_request();
        req.call_budget = 5;
        req.authorized_tools = vec!["read_file".into()];
        let session = store.create(req).await;

        let successes = Arc::new(AtomicU64::new(0));
        let failures = Arc::new(AtomicU64::new(0));

        let mut handles = Vec::new();
        for _ in 0..10 {
            let store = store.clone();
            let sid = session.session_id;
            let s = successes.clone();
            let f = failures.clone();
            handles.push(tokio::spawn(async move {
                match store.use_session(sid, "read_file").await {
                    Ok(_) => {
                        s.fetch_add(1, Ordering::Relaxed);
                    }
                    Err(SessionError::BudgetExceeded { .. }) => {
                        f.fetch_add(1, Ordering::Relaxed);
                    }
                    Err(e) => panic!("unexpected error: {e:?}"),
                }
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        assert_eq!(
            successes.load(Ordering::Relaxed),
            5,
            "exactly 5 calls should succeed"
        );
        assert_eq!(
            failures.load(Ordering::Relaxed),
            5,
            "exactly 5 calls should fail with BudgetExceeded"
        );
    }
}