Skip to main content

arbiter_session/
store.rs

1//! In-memory session store with TTL-based cleanup.
2
3use chrono::Utc;
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7use uuid::Uuid;
8
9use crate::error::SessionError;
10use crate::model::{DataSensitivity, SessionId, SessionStatus, TaskSession};
11
12/// Request to create a new task session.
13pub struct CreateSessionRequest {
14    /// The agent ID for this session.
15    pub agent_id: Uuid,
16    /// Delegation chain snapshot (serialized).
17    pub delegation_chain_snapshot: Vec<String>,
18    /// Declared intent for the session.
19    pub declared_intent: String,
20    /// Tools authorized by policy evaluation.
21    pub authorized_tools: Vec<String>,
22    /// Credential references this session may resolve.
23    /// Empty means no credentials (deny-by-default for credential injection).
24    #[allow(dead_code)]
25    pub authorized_credentials: Vec<String>,
26    /// Session time limit.
27    pub time_limit: chrono::Duration,
28    /// Maximum number of tool calls.
29    pub call_budget: u64,
30    /// Per-minute rate limit. `None` means no rate limit.
31    pub rate_limit_per_minute: Option<u64>,
32    /// Duration of the rate-limit window in seconds. Defaults to 60.
33    pub rate_limit_window_secs: u64,
34    /// Data sensitivity ceiling.
35    pub data_sensitivity_ceiling: DataSensitivity,
36}
37
38/// In-memory session store with TTL-based cleanup.
39#[derive(Clone)]
40pub struct SessionStore {
41    sessions: Arc<RwLock<HashMap<SessionId, TaskSession>>>,
42}
43
44impl SessionStore {
45    /// Create a new empty session store.
46    pub fn new() -> Self {
47        Self {
48            sessions: Arc::new(RwLock::new(HashMap::new())),
49        }
50    }
51
52    /// Create a new task session and return it.
53    pub async fn create(&self, req: CreateSessionRequest) -> TaskSession {
54        // Enforce minimum time limit to prevent zero-duration sessions.
55        let time_limit = if req.time_limit < chrono::Duration::seconds(1) {
56            tracing::warn!(
57                requested = ?req.time_limit,
58                "session time_limit below minimum, clamping to 1 second"
59            );
60            chrono::Duration::seconds(1)
61        } else {
62            req.time_limit
63        };
64        let session = TaskSession {
65            session_id: Uuid::new_v4(),
66            agent_id: req.agent_id,
67            delegation_chain_snapshot: req.delegation_chain_snapshot,
68            declared_intent: req.declared_intent,
69            authorized_tools: req.authorized_tools,
70            authorized_credentials: req.authorized_credentials,
71            time_limit,
72            call_budget: req.call_budget,
73            calls_made: 0,
74            rate_limit_per_minute: req.rate_limit_per_minute,
75            rate_window_start: Utc::now(),
76            rate_window_calls: 0,
77            rate_limit_window_secs: req.rate_limit_window_secs,
78            data_sensitivity_ceiling: req.data_sensitivity_ceiling,
79            created_at: Utc::now(),
80            status: SessionStatus::Active,
81        };
82
83        tracing::info!(
84            session_id = %session.session_id,
85            agent_id = %session.agent_id,
86            intent = %session.declared_intent,
87            budget = session.call_budget,
88            "created task session"
89        );
90
91        let mut sessions = self.sessions.write().await;
92        sessions.insert(session.session_id, session.clone());
93        session
94    }
95
96    /// Atomically check per-agent session cap and create if under the limit.
97    /// Prevents the TOCTOU race where two concurrent requests both pass the
98    /// count check before either creates a session.
99    pub async fn create_if_under_cap(
100        &self,
101        req: CreateSessionRequest,
102        max_sessions: u64,
103    ) -> Result<TaskSession, SessionError> {
104        let mut sessions = self.sessions.write().await;
105
106        let active_count = sessions
107            .values()
108            .filter(|s| s.agent_id == req.agent_id && s.status == SessionStatus::Active)
109            .count() as u64;
110
111        if active_count >= max_sessions {
112            return Err(SessionError::TooManySessions {
113                agent_id: req.agent_id.to_string(),
114                max: max_sessions,
115                current: active_count,
116            });
117        }
118
119        let session = TaskSession {
120            session_id: Uuid::new_v4(),
121            agent_id: req.agent_id,
122            delegation_chain_snapshot: req.delegation_chain_snapshot,
123            declared_intent: req.declared_intent,
124            authorized_tools: req.authorized_tools,
125            authorized_credentials: req.authorized_credentials,
126            time_limit: req.time_limit,
127            call_budget: req.call_budget,
128            calls_made: 0,
129            rate_limit_per_minute: req.rate_limit_per_minute,
130            rate_window_start: Utc::now(),
131            rate_window_calls: 0,
132            rate_limit_window_secs: req.rate_limit_window_secs,
133            data_sensitivity_ceiling: req.data_sensitivity_ceiling,
134            created_at: Utc::now(),
135            status: SessionStatus::Active,
136        };
137
138        sessions.insert(session.session_id, session.clone());
139        Ok(session)
140    }
141
142    /// Record a tool call against the session, checking all constraints.
143    ///
144    /// Returns the updated session on success, or an error if:
145    /// - Session not found
146    /// - Session expired (would return 408)
147    /// - Budget exceeded (would return 429)
148    /// - Tool not authorized (would return 403)
149    pub async fn use_session(
150        &self,
151        session_id: SessionId,
152        tool_name: &str,
153        requesting_agent_id: Option<Uuid>,
154    ) -> Result<TaskSession, SessionError> {
155        let mut sessions = self.sessions.write().await;
156        let session = sessions
157            .get_mut(&session_id)
158            .ok_or(SessionError::NotFound(session_id))?;
159
160        // Verify agent binding to prevent session fixation.
161        if let Some(agent_id) = requesting_agent_id
162            && agent_id != session.agent_id
163        {
164            return Err(SessionError::AgentMismatch {
165                session_id,
166                expected: session.agent_id,
167                actual: agent_id,
168            });
169        }
170
171        if session.status == SessionStatus::Closed {
172            return Err(SessionError::AlreadyClosed(session_id));
173        }
174
175        // Check expiry.
176        if session.is_expired() {
177            session.status = SessionStatus::Expired;
178            return Err(SessionError::Expired(session_id));
179        }
180
181        // Check budget.
182        if session.is_budget_exceeded() {
183            return Err(SessionError::BudgetExceeded {
184                session_id,
185                limit: session.call_budget,
186                used: session.calls_made,
187            });
188        }
189
190        // Check tool authorization.
191        if !session.is_tool_authorized(tool_name) {
192            return Err(SessionError::ToolNotAuthorized {
193                session_id,
194                tool: tool_name.into(),
195            });
196        }
197
198        // Check rate limit.
199        if session.check_rate_limit() {
200            return Err(SessionError::RateLimited {
201                session_id,
202                limit_per_minute: session.rate_limit_per_minute.unwrap_or(0),
203            });
204        }
205
206        // All checks passed. Increment counter.
207        session.calls_made += 1;
208
209        tracing::debug!(
210            session_id = %session_id,
211            tool = tool_name,
212            calls = session.calls_made,
213            budget = session.call_budget,
214            "session tool call recorded"
215        );
216
217        Ok(session.clone())
218    }
219
220    /// Atomically validate and record a batch of tool calls against the session.
221    ///
222    /// Acquires the write lock once, validates ALL tools against
223    /// the whitelist and budget, and only increments `calls_made` by the full
224    /// batch count if every tool passes. If any tool fails validation, no
225    /// budget is consumed for any of them.
226    pub async fn use_session_batch(
227        &self,
228        session_id: SessionId,
229        tool_names: &[&str],
230        requesting_agent_id: Option<Uuid>,
231    ) -> Result<TaskSession, SessionError> {
232        let mut sessions = self.sessions.write().await;
233        let session = sessions
234            .get_mut(&session_id)
235            .ok_or(SessionError::NotFound(session_id))?;
236
237        // Verify agent binding to prevent session fixation.
238        if let Some(agent_id) = requesting_agent_id
239            && agent_id != session.agent_id
240        {
241            return Err(SessionError::AgentMismatch {
242                session_id,
243                expected: session.agent_id,
244                actual: agent_id,
245            });
246        }
247
248        if session.status == SessionStatus::Closed {
249            return Err(SessionError::AlreadyClosed(session_id));
250        }
251
252        // Check expiry.
253        if session.is_expired() {
254            session.status = SessionStatus::Expired;
255            return Err(SessionError::Expired(session_id));
256        }
257
258        let batch_size = tool_names.len() as u64;
259
260        // Check budget for the entire batch.
261        if session.calls_made + batch_size > session.call_budget {
262            return Err(SessionError::BudgetExceeded {
263                session_id,
264                limit: session.call_budget,
265                used: session.calls_made,
266            });
267        }
268
269        // Check tool authorization for every tool before consuming any budget.
270        for tool_name in tool_names {
271            if !session.is_tool_authorized(tool_name) {
272                return Err(SessionError::ToolNotAuthorized {
273                    session_id,
274                    tool: (*tool_name).into(),
275                });
276            }
277        }
278
279        // Check rate limit for the entire batch.
280        // We check whether adding batch_size calls would exceed the limit,
281        // without mutating state until we know it's safe.
282        if let Some(limit) = session.rate_limit_per_minute {
283            let now = chrono::Utc::now();
284            let elapsed = now - session.rate_window_start;
285            if elapsed >= chrono::Duration::seconds(session.rate_limit_window_secs as i64) {
286                // New window; will be reset below after all checks pass.
287            } else if session.rate_window_calls + batch_size > limit {
288                return Err(SessionError::RateLimited {
289                    session_id,
290                    limit_per_minute: limit,
291                });
292            }
293        }
294
295        // All checks passed. Atomically increment counters.
296        // Update rate limit window.
297        if let Some(_limit) = session.rate_limit_per_minute {
298            let now = chrono::Utc::now();
299            let elapsed = now - session.rate_window_start;
300            if elapsed >= chrono::Duration::seconds(session.rate_limit_window_secs as i64) {
301                session.rate_window_start = now;
302                session.rate_window_calls = batch_size;
303            } else {
304                session.rate_window_calls += batch_size;
305            }
306        }
307
308        session.calls_made += batch_size;
309
310        tracing::debug!(
311            session_id = %session_id,
312            batch_size = batch_size,
313            calls = session.calls_made,
314            budget = session.call_budget,
315            "session batch tool calls recorded"
316        );
317
318        Ok(session.clone())
319    }
320
321    /// Close a session, preventing further use.
322    pub async fn close(&self, session_id: SessionId) -> Result<TaskSession, SessionError> {
323        let mut sessions = self.sessions.write().await;
324        let session = sessions
325            .get_mut(&session_id)
326            .ok_or(SessionError::NotFound(session_id))?;
327
328        if session.status == SessionStatus::Closed {
329            return Err(SessionError::AlreadyClosed(session_id));
330        }
331
332        session.status = SessionStatus::Closed;
333        tracing::info!(session_id = %session_id, "session closed");
334        Ok(session.clone())
335    }
336
337    /// Get a session by ID without modifying it.
338    pub async fn get(&self, session_id: SessionId) -> Result<TaskSession, SessionError> {
339        let sessions = self.sessions.read().await;
340        sessions
341            .get(&session_id)
342            .cloned()
343            .ok_or(SessionError::NotFound(session_id))
344    }
345
346    /// List all sessions currently in the store (active, expired, and closed).
347    pub async fn list_all(&self) -> Vec<TaskSession> {
348        let sessions = self.sessions.read().await;
349        sessions.values().cloned().collect()
350    }
351
352    /// List only sessions belonging to a specific agent.
353    /// Use this instead of list_all() when agent-scoped access is needed
354    /// to prevent cross-agent session data exposure.
355    pub async fn list_for_agent(&self, agent_id: Uuid) -> Vec<TaskSession> {
356        let sessions = self.sessions.read().await;
357        sessions
358            .values()
359            .filter(|s| s.agent_id == agent_id)
360            .cloned()
361            .collect()
362    }
363
364    /// Count the number of active sessions for a given agent.
365    ///
366    /// P0: Used to enforce per-agent concurrent session caps.
367    pub async fn count_active_for_agent(&self, agent_id: uuid::Uuid) -> u64 {
368        let sessions = self.sessions.read().await;
369        sessions
370            .values()
371            .filter(|s| s.agent_id == agent_id && s.status == SessionStatus::Active)
372            .count() as u64
373    }
374
375    /// Close all active sessions belonging to a specific agent.
376    ///
377    /// When an agent is deactivated via cascade_deactivate,
378    /// all its sessions must be immediately closed.
379    pub async fn close_sessions_for_agent(&self, agent_id: uuid::Uuid) -> usize {
380        let mut sessions = self.sessions.write().await;
381        let mut closed = 0usize;
382        for session in sessions.values_mut() {
383            if session.agent_id == agent_id && session.status == SessionStatus::Active {
384                session.status = SessionStatus::Closed;
385                closed += 1;
386                tracing::info!(
387                    session_id = %session.session_id,
388                    agent_id = %agent_id,
389                    "closed session due to agent deactivation"
390                );
391            }
392        }
393        closed
394    }
395
396    /// Remove expired sessions from the store. Returns the number removed.
397    pub async fn cleanup_expired(&self) -> usize {
398        let mut sessions = self.sessions.write().await;
399        let before = sessions.len();
400        // Also clean up closed sessions, not just expired ones.
401        // Previously, closed sessions accumulated indefinitely, growing the store without bound.
402        sessions.retain(|_, s| {
403            if s.is_expired() {
404                tracing::debug!(session_id = %s.session_id, "cleaning up expired session");
405                false
406            } else if s.status == SessionStatus::Closed {
407                tracing::debug!(session_id = %s.session_id, "cleaning up closed session");
408                false
409            } else {
410                true
411            }
412        });
413        let removed = before - sessions.len();
414        if removed > 0 {
415            tracing::info!(removed, "cleaned up expired/closed sessions");
416        }
417        removed
418    }
419}
420
421impl Default for SessionStore {
422    fn default() -> Self {
423        Self::new()
424    }
425}
426
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline request: random agent, two read-only tools, one-hour time
    /// limit, budget of 5 calls, no rate limit.
    fn test_create_request() -> CreateSessionRequest {
        CreateSessionRequest {
            agent_id: Uuid::new_v4(),
            delegation_chain_snapshot: vec![],
            declared_intent: "read and analyze files".into(),
            authorized_tools: vec!["read_file".into(), "list_dir".into()],
            authorized_credentials: vec![],
            time_limit: chrono::Duration::hours(1),
            call_budget: 5,
            rate_limit_per_minute: None,
            rate_limit_window_secs: 60,
            data_sensitivity_ceiling: DataSensitivity::Internal,
        }
    }

    /// Baseline request bound to a specific agent.
    fn request_for(agent_id: Uuid) -> CreateSessionRequest {
        CreateSessionRequest {
            agent_id,
            ..test_create_request()
        }
    }

    #[tokio::test]
    async fn create_and_use_session() {
        let store = SessionStore::new();
        let session = store.create(test_create_request()).await;

        assert_eq!(session.calls_made, 0);
        assert!(session.is_active());

        let updated = store
            .use_session(session.session_id, "read_file", None)
            .await
            .unwrap();
        assert_eq!(updated.calls_made, 1);
    }

    #[tokio::test]
    async fn budget_enforcement() {
        let store = SessionStore::new();
        let session = store
            .create(CreateSessionRequest {
                call_budget: 2,
                ..test_create_request()
            })
            .await;

        // Use up the budget.
        for _ in 0..2 {
            store
                .use_session(session.session_id, "read_file", None)
                .await
                .unwrap();
        }

        // Third call should fail.
        let result = store
            .use_session(session.session_id, "read_file", None)
            .await;
        assert!(matches!(result, Err(SessionError::BudgetExceeded { .. })));
    }

    #[tokio::test]
    async fn tool_whitelist_enforcement() {
        let store = SessionStore::new();
        let session = store.create(test_create_request()).await;

        // Authorized tool works.
        store
            .use_session(session.session_id, "read_file", None)
            .await
            .unwrap();

        // Unauthorized tool is rejected.
        let result = store
            .use_session(session.session_id, "delete_file", None)
            .await;
        assert!(matches!(
            result,
            Err(SessionError::ToolNotAuthorized { .. })
        ));
    }

    #[tokio::test]
    async fn session_expiry() {
        let store = SessionStore::new();
        // 1 second is the minimum time limit enforced by create(), so the
        // test has to wait slightly longer than that in real time.
        let session = store
            .create(CreateSessionRequest {
                time_limit: chrono::Duration::seconds(1),
                ..test_create_request()
            })
            .await;

        // Wait for the session to expire.
        tokio::time::sleep(std::time::Duration::from_millis(1100)).await;

        let result = store
            .use_session(session.session_id, "read_file", None)
            .await;
        assert!(matches!(result, Err(SessionError::Expired(_))));
    }

    #[tokio::test]
    async fn close_and_reuse() {
        let store = SessionStore::new();
        let session = store.create(test_create_request()).await;

        store.close(session.session_id).await.unwrap();

        let result = store
            .use_session(session.session_id, "read_file", None)
            .await;
        assert!(matches!(result, Err(SessionError::AlreadyClosed(_))));
    }

    /// Closing the same session twice reports `AlreadyClosed`.
    #[tokio::test]
    async fn double_close_rejected() {
        let store = SessionStore::new();
        let session = store.create(test_create_request()).await;

        store.close(session.session_id).await.unwrap();

        let result = store.close(session.session_id).await;
        assert!(
            matches!(result, Err(SessionError::AlreadyClosed(_))),
            "second close must be rejected, got {result:?}"
        );
    }

    #[tokio::test]
    async fn cleanup_expired_sessions() {
        let store = SessionStore::new();

        // Create a short-lived session (1s minimum).
        store
            .create(CreateSessionRequest {
                time_limit: chrono::Duration::seconds(1),
                ..test_create_request()
            })
            .await;

        // Create a valid session with longer limit.
        store.create(test_create_request()).await;

        // Wait for the short session to expire.
        tokio::time::sleep(std::time::Duration::from_millis(1100)).await;

        let removed = store.cleanup_expired().await;
        assert_eq!(removed, 1);
        assert_eq!(
            store.list_all().await.len(),
            1,
            "the long-lived session must survive cleanup"
        );
    }

    #[tokio::test]
    async fn session_not_found() {
        let store = SessionStore::new();
        let fake_id = Uuid::new_v4();
        let result = store.use_session(fake_id, "anything", None).await;
        assert!(matches!(result, Err(SessionError::NotFound(_))));
    }

    #[tokio::test]
    async fn rate_limit_enforcement() {
        let store = SessionStore::new();
        let session = store
            .create(CreateSessionRequest {
                rate_limit_per_minute: Some(3),
                call_budget: 100, // high budget, rate limit should trigger first
                ..test_create_request()
            })
            .await;

        // First 3 calls succeed (within rate limit).
        for _ in 0..3 {
            store
                .use_session(session.session_id, "read_file", None)
                .await
                .unwrap();
        }

        // 4th call hits rate limit.
        let result = store
            .use_session(session.session_id, "read_file", None)
            .await;
        assert!(
            matches!(result, Err(SessionError::RateLimited { .. })),
            "expected RateLimited, got {result:?}"
        );
    }

    #[tokio::test]
    async fn no_rate_limit_when_unset() {
        let store = SessionStore::new();
        let session = store
            .create(CreateSessionRequest {
                rate_limit_per_minute: None,
                call_budget: 100,
                ..test_create_request()
            })
            .await;

        // All calls succeed without rate limiting.
        for _ in 0..10 {
            store
                .use_session(session.session_id, "read_file", None)
                .await
                .unwrap();
        }
    }

    /// batch with one unauthorized tool must consume zero budget.
    #[tokio::test]
    async fn batch_validation_atomicity() {
        let store = SessionStore::new();
        let session = store
            .create(CreateSessionRequest {
                call_budget: 10,
                authorized_tools: vec!["read_file".into(), "list_dir".into()],
                ..test_create_request()
            })
            .await;

        // Batch contains one unauthorized tool ("delete_file").
        let result = store
            .use_session_batch(session.session_id, &["read_file", "delete_file"], None)
            .await;
        assert!(
            matches!(result, Err(SessionError::ToolNotAuthorized { .. })),
            "expected ToolNotAuthorized, got {result:?}"
        );

        // Budget must remain untouched.
        let s = store.get(session.session_id).await.unwrap();
        assert_eq!(
            s.calls_made, 0,
            "no budget should be consumed on batch failure"
        );
    }

    /// A fully valid batch consumes exactly one budget unit per tool, and the
    /// budget is enforced across successive batches.
    #[tokio::test]
    async fn batch_success_consumes_budget() {
        let store = SessionStore::new();
        let session = store.create(test_create_request()).await; // budget = 5

        let updated = store
            .use_session_batch(session.session_id, &["read_file", "list_dir"], None)
            .await
            .unwrap();
        assert_eq!(updated.calls_made, 2);

        let updated = store
            .use_session_batch(
                session.session_id,
                &["read_file", "read_file", "list_dir"],
                None,
            )
            .await
            .unwrap();
        assert_eq!(updated.calls_made, 5);

        // Budget is now exhausted; even a single-tool batch must fail.
        let result = store
            .use_session_batch(session.session_id, &["read_file"], None)
            .await;
        assert!(
            matches!(result, Err(SessionError::BudgetExceeded { .. })),
            "expected BudgetExceeded, got {result:?}"
        );

        let stored = store.get(session.session_id).await.unwrap();
        assert_eq!(stored.calls_made, 5, "failed batch must not consume budget");
    }

    #[tokio::test]
    async fn batch_budget_enforcement() {
        let store = SessionStore::new();
        let session = store
            .create(CreateSessionRequest {
                call_budget: 3,
                authorized_tools: vec!["read_file".into()],
                ..test_create_request()
            })
            .await;

        // Batch of 4 exceeds budget of 3.
        let result = store
            .use_session_batch(
                session.session_id,
                &["read_file", "read_file", "read_file", "read_file"],
                None,
            )
            .await;
        assert!(
            matches!(result, Err(SessionError::BudgetExceeded { .. })),
            "expected BudgetExceeded, got {result:?}"
        );

        // Budget must remain at 0.
        let s = store.get(session.session_id).await.unwrap();
        assert_eq!(
            s.calls_made, 0,
            "no budget should be consumed on batch failure"
        );
    }

    #[tokio::test]
    async fn batch_rate_limit_enforcement() {
        let store = SessionStore::new();
        let session = store
            .create(CreateSessionRequest {
                call_budget: 100,
                rate_limit_per_minute: Some(3),
                authorized_tools: vec!["read_file".into()],
                ..test_create_request()
            })
            .await;

        // Batch of 4 exceeds rate limit of 3.
        let result = store
            .use_session_batch(
                session.session_id,
                &["read_file", "read_file", "read_file", "read_file"],
                None,
            )
            .await;
        assert!(
            matches!(result, Err(SessionError::RateLimited { .. })),
            "expected RateLimited, got {result:?}"
        );

        // A rate-limited batch must not consume budget either.
        let s = store.get(session.session_id).await.unwrap();
        assert_eq!(
            s.calls_made, 0,
            "no budget should be consumed on batch failure"
        );
    }

    #[tokio::test]
    async fn empty_batch_succeeds() {
        let store = SessionStore::new();
        let session = store.create(test_create_request()).await;

        // Empty batch should succeed without consuming budget.
        let result = store
            .use_session_batch(session.session_id, &[], None)
            .await
            .unwrap();
        assert_eq!(result.calls_made, 0, "empty batch must not consume budget");
    }

    /// cleanup should also remove closed sessions.
    #[tokio::test]
    async fn cleanup_also_removes_closed() {
        let store = SessionStore::new();
        let session = store.create(test_create_request()).await;

        // Close it.
        store.close(session.session_id).await.unwrap();

        // Cleanup should remove the closed session.
        let removed = store.cleanup_expired().await;
        assert_eq!(removed, 1, "closed session should be cleaned up");

        // It should be gone.
        let result = store.get(session.session_id).await;
        assert!(
            matches!(result, Err(SessionError::NotFound(_))),
            "closed session should be removed after cleanup"
        );
    }

    /// A session created with call_budget=0 should immediately fail on use.
    #[tokio::test]
    async fn zero_budget_session() {
        let store = SessionStore::new();
        let session = store
            .create(CreateSessionRequest {
                call_budget: 0,
                ..test_create_request()
            })
            .await;

        let result = store
            .use_session(session.session_id, "read_file", None)
            .await;
        assert!(
            matches!(result, Err(SessionError::BudgetExceeded { .. })),
            "zero-budget session must reject the first call, got {result:?}"
        );
    }

    /// Agent deactivation must close all agent sessions.
    #[tokio::test]
    async fn deactivation_closes_agent_sessions() {
        let store = SessionStore::new();
        let agent_id = Uuid::new_v4();
        let other_agent = Uuid::new_v4();

        for _ in 0..3 {
            store.create(request_for(agent_id)).await;
        }
        let other_session = store.create(request_for(other_agent)).await;

        let closed = store.close_sessions_for_agent(agent_id).await;
        assert_eq!(closed, 3);

        for s in store.list_for_agent(agent_id).await {
            assert_eq!(s.status, SessionStatus::Closed);
        }
        let other = store.get(other_session.session_id).await.unwrap();
        assert_eq!(other.status, SessionStatus::Active);
    }

    /// The atomic cap path rejects creation once the agent has `max_sessions`
    /// active sessions, and the cap is tracked per agent.
    #[tokio::test]
    async fn create_if_under_cap_enforces_limit() {
        let store = SessionStore::new();
        let agent_id = Uuid::new_v4();

        for _ in 0..2 {
            store
                .create_if_under_cap(request_for(agent_id), 2)
                .await
                .unwrap();
        }

        let result = store.create_if_under_cap(request_for(agent_id), 2).await;
        assert!(
            matches!(
                result,
                Err(SessionError::TooManySessions {
                    max: 2,
                    current: 2,
                    ..
                })
            ),
            "third session must be rejected, got {result:?}"
        );
        assert_eq!(store.count_active_for_agent(agent_id).await, 2);

        // A different agent is unaffected by the first agent's sessions.
        let other = store
            .create_if_under_cap(request_for(Uuid::new_v4()), 2)
            .await;
        assert!(other.is_ok(), "cap must be per agent, got {other:?}");
    }

    /// Closing a session frees a slot under the per-agent cap.
    #[tokio::test]
    async fn create_if_under_cap_frees_slot_after_close() {
        let store = SessionStore::new();
        let agent_id = Uuid::new_v4();

        let first = store
            .create_if_under_cap(request_for(agent_id), 1)
            .await
            .unwrap();

        let blocked = store.create_if_under_cap(request_for(agent_id), 1).await;
        assert!(
            matches!(blocked, Err(SessionError::TooManySessions { .. })),
            "cap of 1 must block a second session, got {blocked:?}"
        );

        store.close(first.session_id).await.unwrap();

        let reopened = store.create_if_under_cap(request_for(agent_id), 1).await;
        assert!(
            reopened.is_ok(),
            "closed sessions must not count toward the cap, got {reopened:?}"
        );
    }

    /// `list_for_agent` returns only the requested agent's sessions.
    #[tokio::test]
    async fn list_for_agent_is_scoped() {
        let store = SessionStore::new();
        let agent_id = Uuid::new_v4();

        for _ in 0..2 {
            store.create(request_for(agent_id)).await;
        }
        store.create(request_for(Uuid::new_v4())).await;

        let mine = store.list_for_agent(agent_id).await;
        assert_eq!(mine.len(), 2);
        assert!(mine.iter().all(|s| s.agent_id == agent_id));
        assert_eq!(store.list_all().await.len(), 3);
    }

    /// Closed sessions are not counted as active.
    #[tokio::test]
    async fn count_active_for_agent_ignores_closed() {
        let store = SessionStore::new();
        let agent_id = Uuid::new_v4();

        let first = store.create(request_for(agent_id)).await;
        store.create(request_for(agent_id)).await;
        assert_eq!(store.count_active_for_agent(agent_id).await, 2);

        store.close(first.session_id).await.unwrap();
        assert_eq!(store.count_active_for_agent(agent_id).await, 1);
    }

    /// Concurrent budget enforcement.
    /// Spawn 10 tasks each calling use_session once on a session with budget=5.
    /// Exactly 5 must succeed and 5 must fail with BudgetExceeded.
    #[tokio::test]
    async fn concurrent_budget_enforcement() {
        use std::sync::atomic::{AtomicU64, Ordering};

        let store = SessionStore::new();
        let session = store
            .create(CreateSessionRequest {
                call_budget: 5,
                authorized_tools: vec!["read_file".into()],
                ..test_create_request()
            })
            .await;

        let successes = Arc::new(AtomicU64::new(0));
        let failures = Arc::new(AtomicU64::new(0));

        let mut handles = Vec::new();
        for _ in 0..10 {
            let store = store.clone();
            let sid = session.session_id;
            let s = successes.clone();
            let f = failures.clone();
            handles.push(tokio::spawn(async move {
                match store.use_session(sid, "read_file", None).await {
                    Ok(_) => {
                        s.fetch_add(1, Ordering::Relaxed);
                    }
                    Err(SessionError::BudgetExceeded { .. }) => {
                        f.fetch_add(1, Ordering::Relaxed);
                    }
                    Err(e) => panic!("unexpected error: {e:?}"),
                }
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        assert_eq!(
            successes.load(Ordering::Relaxed),
            5,
            "exactly 5 calls should succeed"
        );
        assert_eq!(
            failures.load(Ordering::Relaxed),
            5,
            "exactly 5 calls should fail with BudgetExceeded"
        );
    }

    /// Session fixation prevention: a different agent must not be able to use
    /// another agent's session by presenting its session ID.
    #[tokio::test]
    async fn agent_mismatch_rejected() {
        let store = SessionStore::new();
        let session = store.create(test_create_request()).await;
        let attacker_id = Uuid::new_v4();

        // Attacker presents a different agent_id than the session owner.
        let result = store
            .use_session(session.session_id, "read_file", Some(attacker_id))
            .await;
        assert!(
            matches!(result, Err(SessionError::AgentMismatch { .. })),
            "different agent must be rejected, got {result:?}"
        );

        // Legitimate agent succeeds.
        let result = store
            .use_session(session.session_id, "read_file", Some(session.agent_id))
            .await;
        assert!(result.is_ok(), "session owner should succeed");
    }

    /// Batch variant of agent mismatch check.
    #[tokio::test]
    async fn batch_agent_mismatch_rejected() {
        let store = SessionStore::new();
        let session = store.create(test_create_request()).await;
        let attacker_id = Uuid::new_v4();

        let result = store
            .use_session_batch(session.session_id, &["read_file"], Some(attacker_id))
            .await;
        assert!(
            matches!(result, Err(SessionError::AgentMismatch { .. })),
            "batch with wrong agent must be rejected, got {result:?}"
        );
    }
}