1use chrono::Utc;
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7use uuid::Uuid;
8
9use crate::error::SessionError;
10use crate::model::{DataSensitivity, SessionId, SessionStatus, TaskSession};
11
12pub struct CreateSessionRequest {
14 pub agent_id: Uuid,
16 pub delegation_chain_snapshot: Vec<String>,
18 pub declared_intent: String,
20 pub authorized_tools: Vec<String>,
22 pub time_limit: chrono::Duration,
24 pub call_budget: u64,
26 pub rate_limit_per_minute: Option<u64>,
28 pub rate_limit_window_secs: u64,
30 pub data_sensitivity_ceiling: DataSensitivity,
32}
33
34#[derive(Clone)]
36pub struct SessionStore {
37 sessions: Arc<RwLock<HashMap<SessionId, TaskSession>>>,
38}
39
40impl SessionStore {
41 pub fn new() -> Self {
43 Self {
44 sessions: Arc::new(RwLock::new(HashMap::new())),
45 }
46 }
47
48 pub async fn create(&self, req: CreateSessionRequest) -> TaskSession {
50 let session = TaskSession {
51 session_id: Uuid::new_v4(),
52 agent_id: req.agent_id,
53 delegation_chain_snapshot: req.delegation_chain_snapshot,
54 declared_intent: req.declared_intent,
55 authorized_tools: req.authorized_tools,
56 time_limit: req.time_limit,
57 call_budget: req.call_budget,
58 calls_made: 0,
59 rate_limit_per_minute: req.rate_limit_per_minute,
60 rate_window_start: Utc::now(),
61 rate_window_calls: 0,
62 rate_limit_window_secs: req.rate_limit_window_secs,
63 data_sensitivity_ceiling: req.data_sensitivity_ceiling,
64 created_at: Utc::now(),
65 status: SessionStatus::Active,
66 };
67
68 tracing::info!(
69 session_id = %session.session_id,
70 agent_id = %session.agent_id,
71 intent = %session.declared_intent,
72 budget = session.call_budget,
73 "created task session"
74 );
75
76 let mut sessions = self.sessions.write().await;
77 sessions.insert(session.session_id, session.clone());
78 session
79 }
80
81 pub async fn use_session(
89 &self,
90 session_id: SessionId,
91 tool_name: &str,
92 ) -> Result<TaskSession, SessionError> {
93 let mut sessions = self.sessions.write().await;
94 let session = sessions
95 .get_mut(&session_id)
96 .ok_or(SessionError::NotFound(session_id))?;
97
98 if session.status == SessionStatus::Closed {
99 return Err(SessionError::AlreadyClosed(session_id));
100 }
101
102 if session.is_expired() {
104 session.status = SessionStatus::Expired;
105 return Err(SessionError::Expired(session_id));
106 }
107
108 if session.is_budget_exceeded() {
110 return Err(SessionError::BudgetExceeded {
111 session_id,
112 limit: session.call_budget,
113 used: session.calls_made,
114 });
115 }
116
117 if !session.is_tool_authorized(tool_name) {
119 return Err(SessionError::ToolNotAuthorized {
120 session_id,
121 tool: tool_name.into(),
122 });
123 }
124
125 if session.check_rate_limit() {
127 return Err(SessionError::RateLimited {
128 session_id,
129 limit_per_minute: session.rate_limit_per_minute.unwrap_or(0),
130 });
131 }
132
133 session.calls_made += 1;
135
136 tracing::debug!(
137 session_id = %session_id,
138 tool = tool_name,
139 calls = session.calls_made,
140 budget = session.call_budget,
141 "session tool call recorded"
142 );
143
144 Ok(session.clone())
145 }
146
147 pub async fn use_session_batch(
154 &self,
155 session_id: SessionId,
156 tool_names: &[&str],
157 ) -> Result<TaskSession, SessionError> {
158 let mut sessions = self.sessions.write().await;
159 let session = sessions
160 .get_mut(&session_id)
161 .ok_or(SessionError::NotFound(session_id))?;
162
163 if session.status == SessionStatus::Closed {
164 return Err(SessionError::AlreadyClosed(session_id));
165 }
166
167 if session.is_expired() {
169 session.status = SessionStatus::Expired;
170 return Err(SessionError::Expired(session_id));
171 }
172
173 let batch_size = tool_names.len() as u64;
174
175 if session.calls_made + batch_size > session.call_budget {
177 return Err(SessionError::BudgetExceeded {
178 session_id,
179 limit: session.call_budget,
180 used: session.calls_made,
181 });
182 }
183
184 for tool_name in tool_names {
186 if !session.is_tool_authorized(tool_name) {
187 return Err(SessionError::ToolNotAuthorized {
188 session_id,
189 tool: (*tool_name).into(),
190 });
191 }
192 }
193
194 if let Some(limit) = session.rate_limit_per_minute {
198 let now = chrono::Utc::now();
199 let elapsed = now - session.rate_window_start;
200 if elapsed >= chrono::Duration::seconds(session.rate_limit_window_secs as i64) {
201 } else if session.rate_window_calls + batch_size > limit {
203 return Err(SessionError::RateLimited {
204 session_id,
205 limit_per_minute: limit,
206 });
207 }
208 }
209
210 if let Some(_limit) = session.rate_limit_per_minute {
213 let now = chrono::Utc::now();
214 let elapsed = now - session.rate_window_start;
215 if elapsed >= chrono::Duration::seconds(session.rate_limit_window_secs as i64) {
216 session.rate_window_start = now;
217 session.rate_window_calls = batch_size;
218 } else {
219 session.rate_window_calls += batch_size;
220 }
221 }
222
223 session.calls_made += batch_size;
224
225 tracing::debug!(
226 session_id = %session_id,
227 batch_size = batch_size,
228 calls = session.calls_made,
229 budget = session.call_budget,
230 "session batch tool calls recorded"
231 );
232
233 Ok(session.clone())
234 }
235
236 pub async fn close(&self, session_id: SessionId) -> Result<TaskSession, SessionError> {
238 let mut sessions = self.sessions.write().await;
239 let session = sessions
240 .get_mut(&session_id)
241 .ok_or(SessionError::NotFound(session_id))?;
242
243 if session.status == SessionStatus::Closed {
244 return Err(SessionError::AlreadyClosed(session_id));
245 }
246
247 session.status = SessionStatus::Closed;
248 tracing::info!(session_id = %session_id, "session closed");
249 Ok(session.clone())
250 }
251
252 pub async fn get(&self, session_id: SessionId) -> Result<TaskSession, SessionError> {
254 let sessions = self.sessions.read().await;
255 sessions
256 .get(&session_id)
257 .cloned()
258 .ok_or(SessionError::NotFound(session_id))
259 }
260
261 pub async fn list_all(&self) -> Vec<TaskSession> {
263 let sessions = self.sessions.read().await;
264 sessions.values().cloned().collect()
265 }
266
267 pub async fn count_active_for_agent(&self, agent_id: uuid::Uuid) -> u64 {
271 let sessions = self.sessions.read().await;
272 sessions
273 .values()
274 .filter(|s| s.agent_id == agent_id && s.status == SessionStatus::Active)
275 .count() as u64
276 }
277
278 pub async fn close_sessions_for_agent(&self, agent_id: uuid::Uuid) -> usize {
283 let mut sessions = self.sessions.write().await;
284 let mut closed = 0usize;
285 for session in sessions.values_mut() {
286 if session.agent_id == agent_id && session.status == SessionStatus::Active {
287 session.status = SessionStatus::Closed;
288 closed += 1;
289 tracing::info!(
290 session_id = %session.session_id,
291 agent_id = %agent_id,
292 "closed session due to agent deactivation"
293 );
294 }
295 }
296 closed
297 }
298
299 pub async fn cleanup_expired(&self) -> usize {
301 let mut sessions = self.sessions.write().await;
302 let before = sessions.len();
303 sessions.retain(|_, s| {
306 if s.is_expired() {
307 tracing::debug!(session_id = %s.session_id, "cleaning up expired session");
308 false
309 } else if s.status == SessionStatus::Closed {
310 tracing::debug!(session_id = %s.session_id, "cleaning up closed session");
311 false
312 } else {
313 true
314 }
315 });
316 let removed = before - sessions.len();
317 if removed > 0 {
318 tracing::info!(removed, "cleaned up expired/closed sessions");
319 }
320 removed
321 }
322}
323
324impl Default for SessionStore {
325 fn default() -> Self {
326 Self::new()
327 }
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline request: one-hour limit, budget of 5, two read-only tools, no rate limit.
    fn test_create_request() -> CreateSessionRequest {
        CreateSessionRequest {
            agent_id: Uuid::new_v4(),
            delegation_chain_snapshot: vec![],
            declared_intent: "read and analyze files".into(),
            authorized_tools: vec!["read_file".into(), "list_dir".into()],
            time_limit: chrono::Duration::hours(1),
            call_budget: 5,
            rate_limit_per_minute: None,
            rate_limit_window_secs: 60,
            data_sensitivity_ceiling: DataSensitivity::Internal,
        }
    }

    /// A new session is active with no calls, and one call increments the counter.
    #[tokio::test]
    async fn create_and_use_session() {
        let store = SessionStore::new();
        let session = store.create(test_create_request()).await;

        assert_eq!(session.calls_made, 0);
        assert!(session.is_active());

        let updated = store
            .use_session(session.session_id, "read_file")
            .await
            .unwrap();
        assert_eq!(updated.calls_made, 1);
    }

    /// Calls beyond the budget are rejected.
    #[tokio::test]
    async fn budget_enforcement() {
        let store = SessionStore::new();
        let mut req = test_create_request();
        req.call_budget = 2;
        let session = store.create(req).await;

        // Use up the whole budget.
        store
            .use_session(session.session_id, "read_file")
            .await
            .unwrap();
        store
            .use_session(session.session_id, "read_file")
            .await
            .unwrap();

        // One past the budget must fail.
        let result = store.use_session(session.session_id, "read_file").await;
        assert!(matches!(result, Err(SessionError::BudgetExceeded { .. })));
    }

    /// Tools outside the authorized list are rejected.
    #[tokio::test]
    async fn tool_whitelist_enforcement() {
        let store = SessionStore::new();
        let session = store.create(test_create_request()).await;

        store
            .use_session(session.session_id, "read_file")
            .await
            .unwrap();

        let result = store.use_session(session.session_id, "delete_file").await;
        assert!(matches!(
            result,
            Err(SessionError::ToolNotAuthorized { .. })
        ));
    }

    /// A session past its time limit reports `Expired`.
    #[tokio::test]
    async fn session_expiry() {
        let store = SessionStore::new();
        let mut req = test_create_request();
        // A zero time limit makes the session expire almost immediately.
        req.time_limit = chrono::Duration::zero();
        let session = store.create(req).await;

        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        let result = store.use_session(session.session_id, "read_file").await;
        assert!(matches!(result, Err(SessionError::Expired(_))));
    }

    /// A closed session cannot be used again.
    #[tokio::test]
    async fn close_and_reuse() {
        let store = SessionStore::new();
        let session = store.create(test_create_request()).await;

        store.close(session.session_id).await.unwrap();

        let result = store.use_session(session.session_id, "read_file").await;
        assert!(matches!(result, Err(SessionError::AlreadyClosed(_))));
    }

    /// Cleanup removes the expired session and keeps the valid one.
    #[tokio::test]
    async fn cleanup_expired_sessions() {
        let store = SessionStore::new();

        let mut req = test_create_request();
        // This one expires immediately.
        req.time_limit = chrono::Duration::zero();
        store.create(req).await;

        // This one stays valid for an hour.
        let valid_req = test_create_request();
        store.create(valid_req).await;

        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        let removed = store.cleanup_expired().await;
        assert_eq!(removed, 1);
    }

    /// Unknown ids yield `NotFound`.
    #[tokio::test]
    async fn session_not_found() {
        let store = SessionStore::new();
        let fake_id = Uuid::new_v4();
        let result = store.use_session(fake_id, "anything").await;
        assert!(matches!(result, Err(SessionError::NotFound(_))));
    }

    /// Single calls beyond the per-window limit are rejected.
    #[tokio::test]
    async fn rate_limit_enforcement() {
        let store = SessionStore::new();
        let mut req = test_create_request();
        req.rate_limit_per_minute = Some(3);
        // Large budget, so only the rate limit can reject.
        req.call_budget = 100;
        let session = store.create(req).await;

        // Fill the window.
        store
            .use_session(session.session_id, "read_file")
            .await
            .unwrap();
        store
            .use_session(session.session_id, "read_file")
            .await
            .unwrap();
        store
            .use_session(session.session_id, "read_file")
            .await
            .unwrap();

        // The fourth call in the same window must be rate limited.
        let result = store.use_session(session.session_id, "read_file").await;
        assert!(
            matches!(result, Err(SessionError::RateLimited { .. })),
            "expected RateLimited, got {result:?}"
        );
    }

    /// Without a rate limit, only the budget restricts calls.
    #[tokio::test]
    async fn no_rate_limit_when_unset() {
        let store = SessionStore::new();
        let mut req = test_create_request();
        req.rate_limit_per_minute = None;
        req.call_budget = 100;
        let session = store.create(req).await;

        for _ in 0..10 {
            store
                .use_session(session.session_id, "read_file")
                .await
                .unwrap();
        }
    }

    /// One unauthorized tool rejects the whole batch without consuming budget.
    #[tokio::test]
    async fn batch_validation_atomicity() {
        let store = SessionStore::new();
        let mut req = test_create_request();
        req.call_budget = 10;
        req.authorized_tools = vec!["read_file".into(), "list_dir".into()];
        let session = store.create(req).await;

        let result = store
            .use_session_batch(session.session_id, &["read_file", "delete_file"])
            .await;
        assert!(
            matches!(result, Err(SessionError::ToolNotAuthorized { .. })),
            "expected ToolNotAuthorized, got {result:?}"
        );

        let s = store.get(session.session_id).await.unwrap();
        assert_eq!(
            s.calls_made, 0,
            "no budget should be consumed on batch failure"
        );
    }

    /// A batch larger than the remaining budget is rejected as a whole.
    #[tokio::test]
    async fn batch_budget_enforcement() {
        let store = SessionStore::new();
        let mut req = test_create_request();
        req.call_budget = 3;
        req.authorized_tools = vec!["read_file".into()];
        let session = store.create(req).await;

        let result = store
            .use_session_batch(
                session.session_id,
                &["read_file", "read_file", "read_file", "read_file"],
            )
            .await;
        assert!(
            matches!(result, Err(SessionError::BudgetExceeded { .. })),
            "expected BudgetExceeded, got {result:?}"
        );

        let s = store.get(session.session_id).await.unwrap();
        assert_eq!(
            s.calls_made, 0,
            "no budget should be consumed on batch failure"
        );
    }

    /// A batch larger than the per-window limit is rate limited.
    #[tokio::test]
    async fn batch_rate_limit_enforcement() {
        let store = SessionStore::new();
        let mut req = test_create_request();
        req.call_budget = 100;
        req.rate_limit_per_minute = Some(3);
        req.authorized_tools = vec!["read_file".into()];
        let session = store.create(req).await;

        let result = store
            .use_session_batch(
                session.session_id,
                &["read_file", "read_file", "read_file", "read_file"],
            )
            .await;
        assert!(
            matches!(result, Err(SessionError::RateLimited { .. })),
            "expected RateLimited, got {result:?}"
        );
    }

    /// A successful batch consumes one budget unit per tool name.
    #[tokio::test]
    async fn batch_success_consumes_budget() {
        let store = SessionStore::new();
        // Budget of 5 from the baseline request.
        let session = store.create(test_create_request()).await;

        let updated = store
            .use_session_batch(session.session_id, &["read_file", "list_dir", "read_file"])
            .await
            .unwrap();
        assert_eq!(updated.calls_made, 3);

        // Three more would exceed the remaining budget of 2.
        let result = store
            .use_session_batch(session.session_id, &["read_file", "read_file", "read_file"])
            .await;
        assert!(
            matches!(
                result,
                Err(SessionError::BudgetExceeded {
                    limit: 5,
                    used: 3,
                    ..
                })
            ),
            "expected BudgetExceeded, got {result:?}"
        );
    }

    /// An empty batch succeeds and consumes nothing.
    #[tokio::test]
    async fn empty_batch_succeeds() {
        let store = SessionStore::new();
        let session = store.create(test_create_request()).await;

        let result = store
            .use_session_batch(session.session_id, &[])
            .await
            .unwrap();
        assert_eq!(result.calls_made, 0, "empty batch must not consume budget");
    }

    /// Cleanup also removes sessions that were explicitly closed.
    #[tokio::test]
    async fn cleanup_also_removes_closed() {
        let store = SessionStore::new();
        let session = store.create(test_create_request()).await;

        store.close(session.session_id).await.unwrap();

        let removed = store.cleanup_expired().await;
        assert_eq!(removed, 1, "closed session should be cleaned up");

        let result = store.get(session.session_id).await;
        assert!(
            matches!(result, Err(SessionError::NotFound(_))),
            "closed session should be removed after cleanup"
        );
    }

    /// Closing an already-closed session is an error.
    #[tokio::test]
    async fn double_close_is_rejected() {
        let store = SessionStore::new();
        let session = store.create(test_create_request()).await;

        let closed = store.close(session.session_id).await.unwrap();
        assert_eq!(closed.status, SessionStatus::Closed);

        let result = store.close(session.session_id).await;
        assert!(matches!(result, Err(SessionError::AlreadyClosed(_))));
    }

    /// A zero budget rejects even the first call.
    #[tokio::test]
    async fn zero_budget_session() {
        let store = SessionStore::new();
        let mut req = test_create_request();
        req.call_budget = 0;
        let session = store.create(req).await;

        let result = store.use_session(session.session_id, "read_file").await;
        assert!(
            matches!(result, Err(SessionError::BudgetExceeded { .. })),
            "zero-budget session must reject the first call, got {result:?}"
        );
    }

    /// Only the agent's own open sessions are counted as active.
    #[tokio::test]
    async fn count_active_ignores_closed_and_other_agents() {
        let store = SessionStore::new();
        let agent_id = Uuid::new_v4();

        let mut req = test_create_request();
        req.agent_id = agent_id;
        let first = store.create(req).await;

        let mut req = test_create_request();
        req.agent_id = agent_id;
        store.create(req).await;

        // Session belonging to some other agent.
        store.create(test_create_request()).await;

        assert_eq!(store.count_active_for_agent(agent_id).await, 2);

        store.close(first.session_id).await.unwrap();
        assert_eq!(store.count_active_for_agent(agent_id).await, 1);
    }

    /// Deactivating an agent closes only that agent's sessions.
    #[tokio::test]
    async fn deactivation_closes_agent_sessions() {
        let store = SessionStore::new();
        let agent_id = Uuid::new_v4();
        let other_agent = Uuid::new_v4();

        for _ in 0..3 {
            let mut req = test_create_request();
            req.agent_id = agent_id;
            store.create(req).await;
        }
        let mut other_req = test_create_request();
        other_req.agent_id = other_agent;
        let other_session = store.create(other_req).await;

        let closed = store.close_sessions_for_agent(agent_id).await;
        assert_eq!(closed, 3);

        let all = store.list_all().await;
        for s in &all {
            if s.agent_id == agent_id {
                assert_eq!(s.status, SessionStatus::Closed);
            }
        }
        let other = store.get(other_session.session_id).await.unwrap();
        assert_eq!(other.status, SessionStatus::Active);
    }

    /// Concurrent callers cannot overspend the budget.
    ///
    /// Ten tasks race against a budget of 5; the write lock must make exactly
    /// five succeed and five fail.
    #[tokio::test]
    async fn concurrent_budget_enforcement() {
        let store = SessionStore::new();
        let mut req = test_create_request();
        req.call_budget = 5;
        req.authorized_tools = vec!["read_file".into()];
        let session = store.create(req).await;

        let successes = Arc::new(std::sync::atomic::AtomicU64::new(0));
        let failures = Arc::new(std::sync::atomic::AtomicU64::new(0));

        let mut handles = Vec::new();
        for _ in 0..10 {
            let store = store.clone();
            let sid = session.session_id;
            let s = successes.clone();
            let f = failures.clone();
            handles.push(tokio::spawn(async move {
                match store.use_session(sid, "read_file").await {
                    Ok(_) => {
                        s.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                    }
                    Err(SessionError::BudgetExceeded { .. }) => {
                        f.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                    }
                    Err(e) => panic!("unexpected error: {e:?}"),
                }
            }));
        }

        // Joining every task makes the Relaxed counters safe to read below.
        for h in handles {
            h.await.unwrap();
        }

        assert_eq!(
            successes.load(std::sync::atomic::Ordering::Relaxed),
            5,
            "exactly 5 calls should succeed"
        );
        assert_eq!(
            failures.load(std::sync::atomic::Ordering::Relaxed),
            5,
            "exactly 5 calls should fail with BudgetExceeded"
        );
    }
}