1use chrono::Utc;
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7use uuid::Uuid;
8
9use crate::error::SessionError;
10use crate::model::{DataSensitivity, SessionId, SessionStatus, TaskSession};
11
/// Parameters for creating a [`TaskSession`] through `SessionStore::create`
/// or `SessionStore::create_if_under_cap`.
pub struct CreateSessionRequest {
    /// Agent that owns the session. `use_session*` calls that supply a
    /// different requesting agent id are rejected with `AgentMismatch`.
    pub agent_id: Uuid,
    /// Copied verbatim into the created session.
    pub delegation_chain_snapshot: Vec<String>,
    /// Free-text purpose of the session; copied into the session and
    /// included in the creation log line.
    pub declared_intent: String,
    /// Tool whitelist; every tool call is checked against it via
    /// `TaskSession::is_tool_authorized`.
    pub authorized_tools: Vec<String>,
    /// Copied into the created session; not consulted anywhere in this file.
    // NOTE(review): the field is moved into the session on creation, so the
    // lint suppression below looks stale — confirm and drop it if unneeded.
    #[allow(dead_code)]
    pub authorized_credentials: Vec<String>,
    /// Maximum lifetime of the session. `SessionStore::create` raises values
    /// below one second to one second.
    pub time_limit: chrono::Duration,
    /// Maximum number of tool calls the session may record.
    pub call_budget: u64,
    /// Maximum tool calls per rate-limit window; `None` disables rate
    /// limiting entirely.
    pub rate_limit_per_minute: Option<u64>,
    /// Length of the rate-limit window, in seconds.
    pub rate_limit_window_secs: u64,
    /// Copied into the created session; presumably caps the sensitivity of
    /// data the session may access — confirm against `crate::model`.
    pub data_sensitivity_ceiling: DataSensitivity,
}
37
/// In-memory registry of [`TaskSession`]s keyed by session id.
///
/// The map lives behind an `Arc<RwLock<..>>`, so `clone()` is cheap and every
/// clone is a handle to the *same* set of sessions — clones can be moved into
/// concurrently running tasks.
#[derive(Clone)]
pub struct SessionStore {
    // Every mutating method in this file goes through the write guard, so
    // check-then-update sequences on a session are not interleaved.
    sessions: Arc<RwLock<HashMap<SessionId, TaskSession>>>,
}
43
44impl SessionStore {
45 pub fn new() -> Self {
47 Self {
48 sessions: Arc::new(RwLock::new(HashMap::new())),
49 }
50 }
51
52 pub async fn create(&self, req: CreateSessionRequest) -> TaskSession {
54 let time_limit = if req.time_limit < chrono::Duration::seconds(1) {
56 tracing::warn!(
57 requested = ?req.time_limit,
58 "session time_limit below minimum, clamping to 1 second"
59 );
60 chrono::Duration::seconds(1)
61 } else {
62 req.time_limit
63 };
64 let session = TaskSession {
65 session_id: Uuid::new_v4(),
66 agent_id: req.agent_id,
67 delegation_chain_snapshot: req.delegation_chain_snapshot,
68 declared_intent: req.declared_intent,
69 authorized_tools: req.authorized_tools,
70 authorized_credentials: req.authorized_credentials,
71 time_limit,
72 call_budget: req.call_budget,
73 calls_made: 0,
74 rate_limit_per_minute: req.rate_limit_per_minute,
75 rate_window_start: Utc::now(),
76 rate_window_calls: 0,
77 rate_limit_window_secs: req.rate_limit_window_secs,
78 data_sensitivity_ceiling: req.data_sensitivity_ceiling,
79 created_at: Utc::now(),
80 status: SessionStatus::Active,
81 };
82
83 tracing::info!(
84 session_id = %session.session_id,
85 agent_id = %session.agent_id,
86 intent = %session.declared_intent,
87 budget = session.call_budget,
88 "created task session"
89 );
90
91 let mut sessions = self.sessions.write().await;
92 sessions.insert(session.session_id, session.clone());
93 session
94 }
95
96 pub async fn create_if_under_cap(
100 &self,
101 req: CreateSessionRequest,
102 max_sessions: u64,
103 ) -> Result<TaskSession, SessionError> {
104 let mut sessions = self.sessions.write().await;
105
106 let active_count = sessions
107 .values()
108 .filter(|s| s.agent_id == req.agent_id && s.status == SessionStatus::Active)
109 .count() as u64;
110
111 if active_count >= max_sessions {
112 return Err(SessionError::TooManySessions {
113 agent_id: req.agent_id.to_string(),
114 max: max_sessions,
115 current: active_count,
116 });
117 }
118
119 let session = TaskSession {
120 session_id: Uuid::new_v4(),
121 agent_id: req.agent_id,
122 delegation_chain_snapshot: req.delegation_chain_snapshot,
123 declared_intent: req.declared_intent,
124 authorized_tools: req.authorized_tools,
125 authorized_credentials: req.authorized_credentials,
126 time_limit: req.time_limit,
127 call_budget: req.call_budget,
128 calls_made: 0,
129 rate_limit_per_minute: req.rate_limit_per_minute,
130 rate_window_start: Utc::now(),
131 rate_window_calls: 0,
132 rate_limit_window_secs: req.rate_limit_window_secs,
133 data_sensitivity_ceiling: req.data_sensitivity_ceiling,
134 created_at: Utc::now(),
135 status: SessionStatus::Active,
136 };
137
138 sessions.insert(session.session_id, session.clone());
139 Ok(session)
140 }
141
142 pub async fn use_session(
150 &self,
151 session_id: SessionId,
152 tool_name: &str,
153 requesting_agent_id: Option<Uuid>,
154 ) -> Result<TaskSession, SessionError> {
155 let mut sessions = self.sessions.write().await;
156 let session = sessions
157 .get_mut(&session_id)
158 .ok_or(SessionError::NotFound(session_id))?;
159
160 if let Some(agent_id) = requesting_agent_id
162 && agent_id != session.agent_id
163 {
164 return Err(SessionError::AgentMismatch {
165 session_id,
166 expected: session.agent_id,
167 actual: agent_id,
168 });
169 }
170
171 if session.status == SessionStatus::Closed {
172 return Err(SessionError::AlreadyClosed(session_id));
173 }
174
175 if session.is_expired() {
177 session.status = SessionStatus::Expired;
178 return Err(SessionError::Expired(session_id));
179 }
180
181 if session.is_budget_exceeded() {
183 return Err(SessionError::BudgetExceeded {
184 session_id,
185 limit: session.call_budget,
186 used: session.calls_made,
187 });
188 }
189
190 if !session.is_tool_authorized(tool_name) {
192 return Err(SessionError::ToolNotAuthorized {
193 session_id,
194 tool: tool_name.into(),
195 });
196 }
197
198 if session.check_rate_limit() {
200 return Err(SessionError::RateLimited {
201 session_id,
202 limit_per_minute: session.rate_limit_per_minute.unwrap_or(0),
203 });
204 }
205
206 session.calls_made += 1;
208
209 tracing::debug!(
210 session_id = %session_id,
211 tool = tool_name,
212 calls = session.calls_made,
213 budget = session.call_budget,
214 "session tool call recorded"
215 );
216
217 Ok(session.clone())
218 }
219
220 pub async fn use_session_batch(
227 &self,
228 session_id: SessionId,
229 tool_names: &[&str],
230 requesting_agent_id: Option<Uuid>,
231 ) -> Result<TaskSession, SessionError> {
232 let mut sessions = self.sessions.write().await;
233 let session = sessions
234 .get_mut(&session_id)
235 .ok_or(SessionError::NotFound(session_id))?;
236
237 if let Some(agent_id) = requesting_agent_id
239 && agent_id != session.agent_id
240 {
241 return Err(SessionError::AgentMismatch {
242 session_id,
243 expected: session.agent_id,
244 actual: agent_id,
245 });
246 }
247
248 if session.status == SessionStatus::Closed {
249 return Err(SessionError::AlreadyClosed(session_id));
250 }
251
252 if session.is_expired() {
254 session.status = SessionStatus::Expired;
255 return Err(SessionError::Expired(session_id));
256 }
257
258 let batch_size = tool_names.len() as u64;
259
260 if session.calls_made + batch_size > session.call_budget {
262 return Err(SessionError::BudgetExceeded {
263 session_id,
264 limit: session.call_budget,
265 used: session.calls_made,
266 });
267 }
268
269 for tool_name in tool_names {
271 if !session.is_tool_authorized(tool_name) {
272 return Err(SessionError::ToolNotAuthorized {
273 session_id,
274 tool: (*tool_name).into(),
275 });
276 }
277 }
278
279 if let Some(limit) = session.rate_limit_per_minute {
283 let now = chrono::Utc::now();
284 let elapsed = now - session.rate_window_start;
285 if elapsed >= chrono::Duration::seconds(session.rate_limit_window_secs as i64) {
286 } else if session.rate_window_calls + batch_size > limit {
288 return Err(SessionError::RateLimited {
289 session_id,
290 limit_per_minute: limit,
291 });
292 }
293 }
294
295 if let Some(_limit) = session.rate_limit_per_minute {
298 let now = chrono::Utc::now();
299 let elapsed = now - session.rate_window_start;
300 if elapsed >= chrono::Duration::seconds(session.rate_limit_window_secs as i64) {
301 session.rate_window_start = now;
302 session.rate_window_calls = batch_size;
303 } else {
304 session.rate_window_calls += batch_size;
305 }
306 }
307
308 session.calls_made += batch_size;
309
310 tracing::debug!(
311 session_id = %session_id,
312 batch_size = batch_size,
313 calls = session.calls_made,
314 budget = session.call_budget,
315 "session batch tool calls recorded"
316 );
317
318 Ok(session.clone())
319 }
320
321 pub async fn close(&self, session_id: SessionId) -> Result<TaskSession, SessionError> {
323 let mut sessions = self.sessions.write().await;
324 let session = sessions
325 .get_mut(&session_id)
326 .ok_or(SessionError::NotFound(session_id))?;
327
328 if session.status == SessionStatus::Closed {
329 return Err(SessionError::AlreadyClosed(session_id));
330 }
331
332 session.status = SessionStatus::Closed;
333 tracing::info!(session_id = %session_id, "session closed");
334 Ok(session.clone())
335 }
336
337 pub async fn get(&self, session_id: SessionId) -> Result<TaskSession, SessionError> {
339 let sessions = self.sessions.read().await;
340 sessions
341 .get(&session_id)
342 .cloned()
343 .ok_or(SessionError::NotFound(session_id))
344 }
345
346 pub async fn list_all(&self) -> Vec<TaskSession> {
348 let sessions = self.sessions.read().await;
349 sessions.values().cloned().collect()
350 }
351
352 pub async fn list_for_agent(&self, agent_id: Uuid) -> Vec<TaskSession> {
356 let sessions = self.sessions.read().await;
357 sessions
358 .values()
359 .filter(|s| s.agent_id == agent_id)
360 .cloned()
361 .collect()
362 }
363
364 pub async fn count_active_for_agent(&self, agent_id: uuid::Uuid) -> u64 {
368 let sessions = self.sessions.read().await;
369 sessions
370 .values()
371 .filter(|s| s.agent_id == agent_id && s.status == SessionStatus::Active)
372 .count() as u64
373 }
374
375 pub async fn close_sessions_for_agent(&self, agent_id: uuid::Uuid) -> usize {
380 let mut sessions = self.sessions.write().await;
381 let mut closed = 0usize;
382 for session in sessions.values_mut() {
383 if session.agent_id == agent_id && session.status == SessionStatus::Active {
384 session.status = SessionStatus::Closed;
385 closed += 1;
386 tracing::info!(
387 session_id = %session.session_id,
388 agent_id = %agent_id,
389 "closed session due to agent deactivation"
390 );
391 }
392 }
393 closed
394 }
395
396 pub async fn cleanup_expired(&self) -> usize {
398 let mut sessions = self.sessions.write().await;
399 let before = sessions.len();
400 sessions.retain(|_, s| {
403 if s.is_expired() {
404 tracing::debug!(session_id = %s.session_id, "cleaning up expired session");
405 false
406 } else if s.status == SessionStatus::Closed {
407 tracing::debug!(session_id = %s.session_id, "cleaning up closed session");
408 false
409 } else {
410 true
411 }
412 });
413 let removed = before - sessions.len();
414 if removed > 0 {
415 tracing::info!(removed, "cleaned up expired/closed sessions");
416 }
417 removed
418 }
419}
420
421impl Default for SessionStore {
422 fn default() -> Self {
423 Self::new()
424 }
425}
426
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Baseline request used by most tests: one-hour lifetime, budget of 5,
    /// `read_file` and `list_dir` whitelisted, no rate limit.
    fn test_create_request() -> CreateSessionRequest {
        CreateSessionRequest {
            agent_id: Uuid::new_v4(),
            delegation_chain_snapshot: vec![],
            declared_intent: "read and analyze files".into(),
            authorized_tools: vec!["read_file".into(), "list_dir".into()],
            authorized_credentials: vec![],
            time_limit: chrono::Duration::hours(1),
            call_budget: 5,
            rate_limit_per_minute: None,
            rate_limit_window_secs: 60,
            data_sensitivity_ceiling: DataSensitivity::Internal,
        }
    }

    /// Makes `n` `read_file` calls on `session_id`, panicking if any of them
    /// is rejected.
    async fn use_read_file(store: &SessionStore, session_id: SessionId, n: usize) {
        for _ in 0..n {
            store
                .use_session(session_id, "read_file", None)
                .await
                .unwrap();
        }
    }

    #[tokio::test]
    async fn create_and_use_session() {
        let store = SessionStore::new();
        let session = store.create(test_create_request()).await;

        assert_eq!(session.calls_made, 0);
        assert!(session.is_active());

        let updated = store
            .use_session(session.session_id, "read_file", None)
            .await
            .unwrap();
        assert_eq!(updated.calls_made, 1);
    }

    #[tokio::test]
    async fn budget_enforcement() {
        let store = SessionStore::new();
        let mut req = test_create_request();
        req.call_budget = 2;
        let session = store.create(req).await;

        use_read_file(&store, session.session_id, 2).await;

        let result = store
            .use_session(session.session_id, "read_file", None)
            .await;
        assert!(matches!(result, Err(SessionError::BudgetExceeded { .. })));
    }

    #[tokio::test]
    async fn tool_whitelist_enforcement() {
        let store = SessionStore::new();
        let session = store.create(test_create_request()).await;

        use_read_file(&store, session.session_id, 1).await;

        let result = store
            .use_session(session.session_id, "delete_file", None)
            .await;
        assert!(matches!(
            result,
            Err(SessionError::ToolNotAuthorized { .. })
        ));
    }

    #[tokio::test]
    async fn session_expiry() {
        let store = SessionStore::new();
        let mut req = test_create_request();
        // Shortest lifetime `create` allows.
        req.time_limit = chrono::Duration::seconds(1);
        let session = store.create(req).await;

        // Sessions are timestamped with `Utc::now()`, so wait in real time.
        tokio::time::sleep(std::time::Duration::from_millis(1100)).await;

        let result = store
            .use_session(session.session_id, "read_file", None)
            .await;
        assert!(matches!(result, Err(SessionError::Expired(_))));
    }

    #[tokio::test]
    async fn close_and_reuse() {
        let store = SessionStore::new();
        let session = store.create(test_create_request()).await;

        store.close(session.session_id).await.unwrap();

        let result = store
            .use_session(session.session_id, "read_file", None)
            .await;
        assert!(matches!(result, Err(SessionError::AlreadyClosed(_))));
    }

    #[tokio::test]
    async fn close_twice_reports_already_closed() {
        let store = SessionStore::new();
        let session = store.create(test_create_request()).await;

        store.close(session.session_id).await.unwrap();
        let result = store.close(session.session_id).await;
        assert!(
            matches!(result, Err(SessionError::AlreadyClosed(_))),
            "second close must fail, got {result:?}"
        );
    }

    #[tokio::test]
    async fn cleanup_expired_sessions() {
        let store = SessionStore::new();

        // One short-lived session that will time out...
        let mut req = test_create_request();
        req.time_limit = chrono::Duration::seconds(1);
        store.create(req).await;

        // ...and one that stays valid.
        store.create(test_create_request()).await;

        tokio::time::sleep(std::time::Duration::from_millis(1100)).await;

        let removed = store.cleanup_expired().await;
        assert_eq!(removed, 1);
    }

    #[tokio::test]
    async fn session_not_found() {
        let store = SessionStore::new();
        let fake_id = Uuid::new_v4();
        let result = store.use_session(fake_id, "anything", None).await;
        assert!(matches!(result, Err(SessionError::NotFound(_))));
    }

    #[tokio::test]
    async fn rate_limit_enforcement() {
        let store = SessionStore::new();
        let mut req = test_create_request();
        req.rate_limit_per_minute = Some(3);
        req.call_budget = 100;
        let session = store.create(req).await;

        use_read_file(&store, session.session_id, 3).await;

        let result = store
            .use_session(session.session_id, "read_file", None)
            .await;
        assert!(
            matches!(result, Err(SessionError::RateLimited { .. })),
            "expected RateLimited, got {result:?}"
        );
    }

    #[tokio::test]
    async fn no_rate_limit_when_unset() {
        let store = SessionStore::new();
        let mut req = test_create_request();
        req.rate_limit_per_minute = None;
        req.call_budget = 100;
        let session = store.create(req).await;

        use_read_file(&store, session.session_id, 10).await;
    }

    /// A batch containing one unauthorized tool must be rejected as a whole.
    #[tokio::test]
    async fn batch_validation_atomicity() {
        let store = SessionStore::new();
        let mut req = test_create_request();
        req.call_budget = 10;
        req.authorized_tools = vec!["read_file".into(), "list_dir".into()];
        let session = store.create(req).await;

        let result = store
            .use_session_batch(session.session_id, &["read_file", "delete_file"], None)
            .await;
        assert!(
            matches!(result, Err(SessionError::ToolNotAuthorized { .. })),
            "expected ToolNotAuthorized, got {result:?}"
        );

        let s = store.get(session.session_id).await.unwrap();
        assert_eq!(
            s.calls_made, 0,
            "no budget should be consumed on batch failure"
        );
    }

    #[tokio::test]
    async fn batch_success_consumes_budget() {
        let store = SessionStore::new();
        let session = store.create(test_create_request()).await;

        let updated = store
            .use_session_batch(session.session_id, &["read_file", "list_dir"], None)
            .await
            .unwrap();
        assert_eq!(updated.calls_made, 2, "each batched call counts once");
    }

    #[tokio::test]
    async fn batch_budget_enforcement() {
        let store = SessionStore::new();
        let mut req = test_create_request();
        req.call_budget = 3;
        req.authorized_tools = vec!["read_file".into()];
        let session = store.create(req).await;

        let result = store
            .use_session_batch(
                session.session_id,
                &["read_file", "read_file", "read_file", "read_file"],
                None,
            )
            .await;
        assert!(
            matches!(result, Err(SessionError::BudgetExceeded { .. })),
            "expected BudgetExceeded, got {result:?}"
        );

        let s = store.get(session.session_id).await.unwrap();
        assert_eq!(
            s.calls_made, 0,
            "no budget should be consumed on batch failure"
        );
    }

    #[tokio::test]
    async fn batch_rate_limit_enforcement() {
        let store = SessionStore::new();
        let mut req = test_create_request();
        req.call_budget = 100;
        req.rate_limit_per_minute = Some(3);
        req.authorized_tools = vec!["read_file".into()];
        let session = store.create(req).await;

        let result = store
            .use_session_batch(
                session.session_id,
                &["read_file", "read_file", "read_file", "read_file"],
                None,
            )
            .await;
        assert!(
            matches!(result, Err(SessionError::RateLimited { .. })),
            "expected RateLimited, got {result:?}"
        );
    }

    #[tokio::test]
    async fn empty_batch_succeeds() {
        let store = SessionStore::new();
        let session = store.create(test_create_request()).await;

        let result = store
            .use_session_batch(session.session_id, &[], None)
            .await
            .unwrap();
        assert_eq!(result.calls_made, 0, "empty batch must not consume budget");
    }

    #[tokio::test]
    async fn cleanup_also_removes_closed() {
        let store = SessionStore::new();
        let session = store.create(test_create_request()).await;

        store.close(session.session_id).await.unwrap();

        let removed = store.cleanup_expired().await;
        assert_eq!(removed, 1, "closed session should be cleaned up");

        let result = store.get(session.session_id).await;
        assert!(
            matches!(result, Err(SessionError::NotFound(_))),
            "closed session should be removed after cleanup"
        );
    }

    #[tokio::test]
    async fn zero_budget_session() {
        let store = SessionStore::new();
        let mut req = test_create_request();
        req.call_budget = 0;
        let session = store.create(req).await;

        let result = store
            .use_session(session.session_id, "read_file", None)
            .await;
        assert!(
            matches!(result, Err(SessionError::BudgetExceeded { .. })),
            "zero-budget session must reject the first call, got {result:?}"
        );
    }

    #[tokio::test]
    async fn create_if_under_cap_rejects_at_cap() {
        let store = SessionStore::new();
        let agent_id = Uuid::new_v4();

        let mut first = test_create_request();
        first.agent_id = agent_id;
        store.create_if_under_cap(first, 1).await.unwrap();

        let mut second = test_create_request();
        second.agent_id = agent_id;
        let result = store.create_if_under_cap(second, 1).await;
        assert!(
            matches!(
                result,
                Err(SessionError::TooManySessions { max: 1, current: 1, .. })
            ),
            "second session must exceed a cap of 1, got {result:?}"
        );
    }

    #[tokio::test]
    async fn closed_sessions_do_not_count_toward_cap() {
        let store = SessionStore::new();
        let agent_id = Uuid::new_v4();

        let mut first = test_create_request();
        first.agent_id = agent_id;
        let session = store.create_if_under_cap(first, 1).await.unwrap();
        store.close(session.session_id).await.unwrap();

        let mut second = test_create_request();
        second.agent_id = agent_id;
        assert!(
            store.create_if_under_cap(second, 1).await.is_ok(),
            "a closed session must free its slot under the cap"
        );
    }

    #[tokio::test]
    async fn deactivation_closes_agent_sessions() {
        let store = SessionStore::new();
        let agent_id = Uuid::new_v4();
        let other_agent = Uuid::new_v4();

        for _ in 0..3 {
            let mut req = test_create_request();
            req.agent_id = agent_id;
            store.create(req).await;
        }
        let mut other_req = test_create_request();
        other_req.agent_id = other_agent;
        let other_session = store.create(other_req).await;

        let closed = store.close_sessions_for_agent(agent_id).await;
        assert_eq!(closed, 3);

        for s in store
            .list_all()
            .await
            .iter()
            .filter(|s| s.agent_id == agent_id)
        {
            assert_eq!(s.status, SessionStatus::Closed);
        }
        let other = store.get(other_session.session_id).await.unwrap();
        assert_eq!(other.status, SessionStatus::Active);
    }

    /// Ten concurrent single calls against a budget of five: the write lock
    /// must serialize them so exactly five succeed.
    #[tokio::test]
    async fn concurrent_budget_enforcement() {
        let store = SessionStore::new();
        let mut req = test_create_request();
        req.call_budget = 5;
        req.authorized_tools = vec!["read_file".into()];
        let session = store.create(req).await;

        let successes = Arc::new(AtomicU64::new(0));
        let failures = Arc::new(AtomicU64::new(0));

        let mut handles = Vec::new();
        for _ in 0..10 {
            let store = store.clone();
            let sid = session.session_id;
            let s = Arc::clone(&successes);
            let f = Arc::clone(&failures);
            handles.push(tokio::spawn(async move {
                match store.use_session(sid, "read_file", None).await {
                    Ok(_) => {
                        s.fetch_add(1, Ordering::Relaxed);
                    }
                    Err(SessionError::BudgetExceeded { .. }) => {
                        f.fetch_add(1, Ordering::Relaxed);
                    }
                    Err(e) => panic!("unexpected error: {e:?}"),
                }
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        assert_eq!(
            successes.load(Ordering::Relaxed),
            5,
            "exactly 5 calls should succeed"
        );
        assert_eq!(
            failures.load(Ordering::Relaxed),
            5,
            "exactly 5 calls should fail with BudgetExceeded"
        );
    }

    #[tokio::test]
    async fn agent_mismatch_rejected() {
        let store = SessionStore::new();
        let session = store.create(test_create_request()).await;
        let attacker_id = Uuid::new_v4();

        let result = store
            .use_session(session.session_id, "read_file", Some(attacker_id))
            .await;
        assert!(
            matches!(result, Err(SessionError::AgentMismatch { .. })),
            "different agent must be rejected, got {result:?}"
        );

        let result = store
            .use_session(session.session_id, "read_file", Some(session.agent_id))
            .await;
        assert!(result.is_ok(), "session owner should succeed");
    }

    #[tokio::test]
    async fn batch_agent_mismatch_rejected() {
        let store = SessionStore::new();
        let session = store.create(test_create_request()).await;
        let attacker_id = Uuid::new_v4();

        let result = store
            .use_session_batch(session.session_id, &["read_file"], Some(attacker_id))
            .await;
        assert!(
            matches!(result, Err(SessionError::AgentMismatch { .. })),
            "batch with wrong agent must be rejected, got {result:?}"
        );
    }
}