1use dashmap::DashMap;
19use std::collections::{HashMap, HashSet, VecDeque};
20use std::sync::Arc;
21use std::time::{Duration, Instant};
22use vellaveto_config::ToolManifest;
23use vellaveto_mcp::memory_tracking::MemoryTracker;
24use vellaveto_mcp::rug_pull::ToolAnnotations;
25use vellaveto_types::AgentIdentity;
26
/// Alias of [`ToolAnnotations`]: it is the same type, not a separate compact
/// representation. `SessionState::insert_known_tool` accepts annotations under
/// this name.
pub type ToolAnnotationsCompact = ToolAnnotations;
29
/// Per-session state kept by the session store, keyed by `session_id`.
///
/// Collections that can be fed by remote peers are bounded; use the
/// `insert_*` / `record_*` methods on this type rather than writing the
/// `pub(crate)` fields directly, so the bounds are enforced.
#[derive(Debug)]
pub struct SessionState {
    /// Identifier this session is stored under.
    pub session_id: String,
    /// Creation time; compared against the optional absolute lifetime in `is_expired`.
    pub created_at: Instant,
    /// Time of the most recent `touch`; compared against the inactivity timeout.
    pub last_activity: Instant,
    /// Protocol version recorded for this session (`None` on creation).
    pub protocol_version: Option<String>,
    /// Tool name -> annotations; bounded by `MAX_KNOWN_TOOLS` in `insert_known_tool`.
    pub(crate) known_tools: HashMap<String, ToolAnnotations>,
    /// Number of `touch` calls (saturating at `u64::MAX`).
    pub request_count: u64,
    /// Whether a tools list has been observed (`false` on creation).
    pub tools_list_seen: bool,
    /// OAuth subject; reported as `agent_id` by `StatefulContext::to_evaluation_context`.
    pub oauth_subject: Option<String>,
    /// Tools flagged within this session; bounded by `MAX_FLAGGED_TOOLS`.
    pub(crate) flagged_tools: HashSet<String>,
    /// Pinned tool manifest, if one has been recorded.
    pub pinned_manifest: Option<ToolManifest>,
    /// Per-tool call counts, exposed through `RequestContext::call_counts`.
    pub(crate) call_counts: HashMap<String, u64>,
    /// Ordered action history, exposed through `RequestContext::previous_actions`.
    pub(crate) action_history: VecDeque<String>,
    /// Memory tracker for this session (created fresh in `new`).
    pub memory_tracker: MemoryTracker,
    /// Elicitation counter (starts at 0; presumably incremented by callers — confirm).
    pub elicitation_count: u32,
    /// Sampling counter (starts at 0; presumably incremented by callers — confirm).
    pub sampling_count: u32,
    /// NOTE(review): key/value meaning is not visible in this file — confirm with callers.
    pub(crate) pending_tool_calls: HashMap<String, String>,
    /// Token expiry as Unix seconds; the session expires once wall-clock time reaches it.
    pub token_expires_at: Option<u64>,
    /// Current call chain, exposed through `RequestContext::call_chain`.
    pub current_call_chain: Vec<vellaveto_types::CallChainEntry>,
    /// Agent identity, if one has been established.
    pub agent_identity: Option<AgentIdentity>,
    /// Backend id -> upstream session id; bounded by `MAX_BACKEND_SESSIONS`.
    pub(crate) backend_sessions: HashMap<String, String>,
    /// Backend id -> tool names; at most `MAX_GATEWAY_TOOLS` backends and
    /// `MAX_TOOLS_PER_BACKEND` names per backend.
    pub(crate) gateway_tools: HashMap<String, Vec<String>>,
    /// Risk score, exposed through `RequestContext::risk_score`.
    pub risk_score: Option<vellaveto_types::RiskScore>,
    /// Granted ABAC policy ids; deduplicated and bounded by `MAX_GRANTED_POLICIES`.
    pub(crate) abac_granted_policies: Vec<String>,
    /// Tool id -> discovery record with its own TTL; `record_discovered_tools`
    /// bounds it by `MAX_DISCOVERED_TOOLS_PER_SESSION`.
    pub discovered_tools: HashMap<String, DiscoveredToolSession>,
}
114
// ---- Per-session capacity bounds -------------------------------------------

/// Most discovery records `record_discovered_tools` keeps for one session.
const MAX_DISCOVERED_TOOLS_PER_SESSION: usize = 10_000;

/// Most distinct backends `insert_backend_session` tracks for one session.
const MAX_BACKEND_SESSIONS: usize = 128;
/// Most distinct backends `insert_gateway_tools` tracks for one session.
const MAX_GATEWAY_TOOLS: usize = 128;
/// Tool names kept per backend by `insert_gateway_tools`; extras are cut off.
const MAX_TOOLS_PER_BACKEND: usize = 1_000;

/// Most ABAC policy ids `insert_granted_policy` keeps for one session.
const MAX_GRANTED_POLICIES: usize = 1_024;

/// Most distinct tools `insert_known_tool` keeps for one session.
const MAX_KNOWN_TOOLS: usize = 2_048;

/// Most flagged tool names `insert_flagged_tool` keeps for one session.
const MAX_FLAGGED_TOOLS: usize = 2_048;

// ---- Store-wide flagged-tool registry --------------------------------------

/// Most entries in the store-wide flagged-tool registry.
const MAX_GLOBAL_FLAGGED_TOOLS: usize = 10_000;

/// How long a store-wide tool flag stays in force: 24 hours.
const GLOBAL_FLAGGED_TOOL_TTL: Duration = Duration::from_secs(86_400);
143
/// Entry in the store-wide flagged-tool registry.
///
/// The flag stays in force until `ttl` has elapsed since `flagged_at`.
#[derive(Debug, Clone)]
pub struct GlobalFlaggedToolEntry {
    /// Moment the tool was flagged.
    pub flagged_at: Instant,
    /// How long the flag remains in force after `flagged_at`.
    pub ttl: Duration,
}

impl GlobalFlaggedToolEntry {
    /// `true` once strictly more than `ttl` has passed since `flagged_at`.
    fn is_expired(&self) -> bool {
        self.ttl < self.flagged_at.elapsed()
    }
}
159
/// Record of a tool discovered within one session.
///
/// The record carries its own TTL, so each tool ages independently.
#[derive(Debug, Clone)]
pub struct DiscoveredToolSession {
    /// Identifier of the discovered tool.
    pub tool_id: String,
    /// Moment of (re)discovery.
    pub discovered_at: Instant,
    /// Lifetime of this discovery record.
    pub ttl: Duration,
    /// Set by `SessionState::mark_tool_used`; cleared on rediscovery.
    pub used: bool,
}

impl DiscoveredToolSession {
    /// `true` once strictly more than `ttl` has passed since `discovered_at`.
    pub fn is_expired(&self) -> bool {
        let age = self.discovered_at.elapsed();
        age > self.ttl
    }
}
179
180impl SessionState {
181 pub fn new(session_id: String) -> Self {
182 let now = Instant::now();
183 Self {
184 session_id,
185 created_at: now,
186 last_activity: now,
187 protocol_version: None,
188 known_tools: HashMap::new(),
189 request_count: 0,
190 tools_list_seen: false,
191 oauth_subject: None,
192 flagged_tools: HashSet::new(),
193 pinned_manifest: None,
194 call_counts: HashMap::new(),
195 action_history: VecDeque::new(),
196 memory_tracker: MemoryTracker::new(),
197 elicitation_count: 0,
198 sampling_count: 0,
199 pending_tool_calls: HashMap::new(),
200 token_expires_at: None,
201 current_call_chain: Vec::new(),
202 agent_identity: None,
203 backend_sessions: HashMap::new(),
204 gateway_tools: HashMap::new(),
205 risk_score: None,
206 abac_granted_policies: Vec::new(),
207 discovered_tools: HashMap::new(),
208 }
209 }
210
211 pub fn known_tools(&self) -> &HashMap<String, ToolAnnotations> {
219 &self.known_tools
220 }
221
222 pub fn flagged_tools(&self) -> &HashSet<String> {
224 &self.flagged_tools
225 }
226
227 pub fn backend_sessions(&self) -> &HashMap<String, String> {
229 &self.backend_sessions
230 }
231
232 pub fn gateway_tools(&self) -> &HashMap<String, Vec<String>> {
234 &self.gateway_tools
235 }
236
237 pub fn abac_granted_policies(&self) -> &[String] {
239 &self.abac_granted_policies
240 }
241
242 #[allow(clippy::map_entry)] pub fn insert_backend_session(
246 &mut self,
247 backend_id: String,
248 upstream_session_id: String,
249 ) -> bool {
250 if self.backend_sessions.contains_key(&backend_id) {
251 self.backend_sessions
252 .insert(backend_id, upstream_session_id);
253 return true;
254 }
255 if self.backend_sessions.len() >= MAX_BACKEND_SESSIONS {
256 tracing::warn!(
257 session_id = %self.session_id,
258 capacity = MAX_BACKEND_SESSIONS,
259 "Backend sessions capacity reached; dropping new entry"
260 );
261 return false;
262 }
263 self.backend_sessions
264 .insert(backend_id, upstream_session_id);
265 true
266 }
267
268 pub fn insert_gateway_tools(&mut self, backend_id: String, tools: Vec<String>) -> bool {
271 if !self.gateway_tools.contains_key(&backend_id)
272 && self.gateway_tools.len() >= MAX_GATEWAY_TOOLS
273 {
274 tracing::warn!(
275 session_id = %self.session_id,
276 capacity = MAX_GATEWAY_TOOLS,
277 "Gateway tools capacity reached; dropping new backend entry"
278 );
279 return false;
280 }
281 let bounded_tools: Vec<String> = tools.into_iter().take(MAX_TOOLS_PER_BACKEND).collect();
283 self.gateway_tools.insert(backend_id, bounded_tools);
284 true
285 }
286
287 pub fn insert_granted_policy(&mut self, policy_id: String) {
289 if !self.abac_granted_policies.contains(&policy_id)
290 && self.abac_granted_policies.len() < MAX_GRANTED_POLICIES
291 {
292 self.abac_granted_policies.push(policy_id);
293 }
294 }
295
296 #[allow(clippy::map_entry)] pub fn insert_known_tool(&mut self, name: String, annotations: ToolAnnotationsCompact) -> bool {
300 if self.known_tools.contains_key(&name) {
301 self.known_tools.insert(name, annotations);
302 return true;
303 }
304 if self.known_tools.len() >= MAX_KNOWN_TOOLS {
305 tracing::warn!(
306 session_id = %self.session_id,
307 capacity = MAX_KNOWN_TOOLS,
308 "Known tools capacity reached; dropping new tool"
309 );
310 return false;
311 }
312 self.known_tools.insert(name, annotations);
313 true
314 }
315
316 pub fn insert_flagged_tool(&mut self, name: String) {
318 if self.flagged_tools.len() < MAX_FLAGGED_TOOLS {
319 self.flagged_tools.insert(name);
320 }
321 }
322
323 pub fn record_discovered_tools(&mut self, tool_ids: &[String], ttl: Duration) {
329 let now = Instant::now();
330 for tool_id in tool_ids {
331 if !self.discovered_tools.contains_key(tool_id) {
333 if self.discovered_tools.len() >= MAX_DISCOVERED_TOOLS_PER_SESSION {
334 self.evict_expired_discoveries();
336 }
337 if self.discovered_tools.len() >= MAX_DISCOVERED_TOOLS_PER_SESSION {
338 tracing::warn!(
339 session_id = %self.session_id,
340 capacity = MAX_DISCOVERED_TOOLS_PER_SESSION,
341 "Discovered tools capacity reached; dropping new tool"
342 );
343 continue;
344 }
345 }
346 self.discovered_tools.insert(
347 tool_id.clone(),
348 DiscoveredToolSession {
349 tool_id: tool_id.clone(),
350 discovered_at: now,
351 ttl,
352 used: false,
353 },
354 );
355 }
356 }
357
358 pub fn is_tool_discovery_expired(&self, tool_id: &str) -> Option<bool> {
364 self.discovered_tools.get(tool_id).map(|d| d.is_expired())
365 }
366
367 pub fn mark_tool_used(&mut self, tool_id: &str) -> bool {
371 if let Some(entry) = self.discovered_tools.get_mut(tool_id) {
372 entry.used = true;
373 true
374 } else {
375 false
376 }
377 }
378
379 pub fn evict_expired_discoveries(&mut self) -> usize {
383 let before = self.discovered_tools.len();
384 self.discovered_tools.retain(|_, d| !d.is_expired());
385 before - self.discovered_tools.len()
386 }
387
388 pub fn touch(&mut self) {
390 self.last_activity = Instant::now();
391 self.request_count = self.request_count.saturating_add(1);
393 }
394
395 pub fn is_expired(&self, timeout: Duration, max_lifetime: Option<Duration>) -> bool {
401 if self.last_activity.elapsed() > timeout {
402 return true;
403 }
404 if let Some(max) = max_lifetime {
405 if self.created_at.elapsed() > max {
406 return true;
407 }
408 }
409 if let Some(exp) = self.token_expires_at {
410 let now = std::time::SystemTime::now()
411 .duration_since(std::time::UNIX_EPOCH)
412 .unwrap_or_default()
413 .as_secs();
414 if now >= exp {
415 return true;
416 }
417 }
418 false
419 }
420}
421
422use vellaveto_types::identity::RequestContext;
427
428pub struct StatefulContext<'a> {
443 session: &'a SessionState,
444 previous_actions_cache: std::sync::OnceLock<Vec<String>>,
447}
448
449impl<'a> StatefulContext<'a> {
450 pub fn new(session: &'a SessionState) -> Self {
452 Self {
453 session,
454 previous_actions_cache: std::sync::OnceLock::new(),
455 }
456 }
457}
458
459impl RequestContext for StatefulContext<'_> {
460 fn call_counts(&self) -> &HashMap<String, u64> {
461 &self.session.call_counts
462 }
463
464 fn previous_actions(&self) -> &[String] {
465 self.previous_actions_cache
466 .get_or_init(|| self.session.action_history.iter().cloned().collect())
467 }
468
469 fn call_chain(&self) -> &[vellaveto_types::CallChainEntry] {
470 &self.session.current_call_chain
471 }
472
473 fn agent_identity(&self) -> Option<&AgentIdentity> {
474 self.session.agent_identity.as_ref()
475 }
476
477 fn session_guard_state(&self) -> Option<&str> {
478 None }
480
481 fn risk_score(&self) -> Option<&vellaveto_types::RiskScore> {
482 self.session.risk_score.as_ref()
483 }
484
485 fn to_evaluation_context(&self) -> vellaveto_types::EvaluationContext {
486 vellaveto_types::EvaluationContext {
487 agent_id: self.session.oauth_subject.clone(),
488 agent_identity: self.session.agent_identity.clone(),
489 call_counts: self.session.call_counts.clone(),
490 previous_actions: self.session.action_history.iter().cloned().collect(),
491 call_chain: self.session.current_call_chain.clone(),
492 session_state: None,
493 ..Default::default()
494 }
495 }
496}
497
498const MAX_SESSION_ID_LEN: usize = 128;
502
503pub struct SessionStore {
505 sessions: Arc<DashMap<String, SessionState>>,
506 session_timeout: Duration,
507 max_sessions: usize,
508 max_lifetime: Option<Duration>,
512 global_flagged_tools: Arc<DashMap<String, GlobalFlaggedToolEntry>>,
517}
518
519impl SessionStore {
520 pub fn new(session_timeout: Duration, max_sessions: usize) -> Self {
521 Self {
522 sessions: Arc::new(DashMap::new()),
523 session_timeout,
524 max_sessions,
525 max_lifetime: None,
526 global_flagged_tools: Arc::new(DashMap::new()),
527 }
528 }
529
530 pub fn with_max_lifetime(mut self, lifetime: Duration) -> Self {
533 self.max_lifetime = Some(lifetime);
534 self
535 }
536
537 pub fn get_or_create(&self, client_session_id: Option<&str>) -> String {
543 let client_session_id = client_session_id.filter(|id| id.len() <= MAX_SESSION_ID_LEN);
546
547 if let Some(id) = client_session_id {
549 if let Some(mut session) = self.sessions.get_mut(id) {
550 if !session.is_expired(self.session_timeout, self.max_lifetime) {
551 session.touch();
552 return id.to_string();
553 }
554 drop(session);
556 self.sessions.remove(id);
557 }
558 }
559
560 if self.sessions.len() >= self.max_sessions {
567 self.evict_expired();
568 if self.sessions.len() >= self.max_sessions {
570 self.evict_oldest();
571 }
572 }
573
574 let session_id = uuid::Uuid::new_v4().to_string();
576 self.sessions
577 .insert(session_id.clone(), SessionState::new(session_id.clone()));
578 session_id
579 }
580
581 pub fn get(
583 &self,
584 session_id: &str,
585 ) -> Option<dashmap::mapref::one::Ref<'_, String, SessionState>> {
586 self.sessions.get(session_id)
587 }
588
589 pub fn get_mut(
591 &self,
592 session_id: &str,
593 ) -> Option<dashmap::mapref::one::RefMut<'_, String, SessionState>> {
594 self.sessions.get_mut(session_id)
595 }
596
597 pub fn evict_expired(&self) {
599 self.sessions
600 .retain(|_, session| !session.is_expired(self.session_timeout, self.max_lifetime));
601 }
602
603 fn evict_oldest(&self) {
605 let oldest = self
606 .sessions
607 .iter()
608 .min_by_key(|entry| entry.value().last_activity)
609 .map(|entry| entry.key().clone());
610
611 if let Some(id) = oldest {
612 self.sessions.remove(&id);
613 }
614 }
615
616 pub fn len(&self) -> usize {
618 self.sessions.len()
619 }
620
621 pub fn is_empty(&self) -> bool {
623 self.sessions.is_empty()
624 }
625
626 pub fn remove(&self, session_id: &str) -> bool {
628 self.sessions.remove(session_id).is_some()
629 }
630
631 pub fn flag_tool_globally(&self, tool_name: String) {
641 if self.global_flagged_tools.len() >= MAX_GLOBAL_FLAGGED_TOOLS {
642 self.evict_expired_global_flags();
644 if self.global_flagged_tools.len() >= MAX_GLOBAL_FLAGGED_TOOLS {
645 tracing::warn!(
646 tool = %tool_name,
647 capacity = MAX_GLOBAL_FLAGGED_TOOLS,
648 "Global flagged-tools registry at capacity; dropping new entry"
649 );
650 return;
651 }
652 }
653 self.global_flagged_tools
655 .entry(tool_name)
656 .or_insert_with(|| GlobalFlaggedToolEntry {
657 flagged_at: Instant::now(),
658 ttl: GLOBAL_FLAGGED_TOOL_TTL,
659 });
660 }
661
662 pub fn is_tool_globally_flagged(&self, tool_name: &str) -> bool {
668 self.global_flagged_tools
669 .get(tool_name)
670 .map(|entry| !entry.is_expired())
671 .unwrap_or(false)
672 }
673
674 pub fn evict_expired_global_flags(&self) -> usize {
676 let before = self.global_flagged_tools.len();
677 self.global_flagged_tools
678 .retain(|_, entry| !entry.is_expired());
679 before.saturating_sub(self.global_flagged_tools.len())
680 }
681
682 pub fn global_flagged_tools_len(&self) -> usize {
684 self.global_flagged_tools.len()
685 }
686}
687
#[cfg(test)]
mod tests {
    use super::*;

    // ---- SessionStore: creation, reuse, capacity --------------------------

    #[test]
    fn test_session_creation() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let id = store.get_or_create(None);
        // Hyphenated UUID string length.
        assert_eq!(id.len(), 36);
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_session_reuse() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let id1 = store.get_or_create(None);
        let id2 = store.get_or_create(Some(&id1));
        assert_eq!(id1, id2);
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_session_unknown_id_creates_new() {
        // Unknown client ids are never adopted; a server-generated id is issued.
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let id = store.get_or_create(Some("nonexistent-id"));
        assert_ne!(id, "nonexistent-id");
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_max_sessions_enforced() {
        let store = SessionStore::new(Duration::from_secs(300), 3);
        store.get_or_create(None);
        store.get_or_create(None);
        store.get_or_create(None);
        assert_eq!(store.len(), 3);
        // Fourth create evicts the least recently active session to stay at the cap.
        store.get_or_create(None);
        assert_eq!(store.len(), 3);
    }

    #[test]
    fn test_session_remove() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let id = store.get_or_create(None);
        assert!(store.remove(&id));
        assert_eq!(store.len(), 0);
        // Second removal reports that nothing was there.
        assert!(!store.remove(&id));
    }

    #[test]
    fn test_session_touch_increments_count() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let id = store.get_or_create(None);
        // Reusing the id touches the session; creation itself does not.
        store.get_or_create(Some(&id));
        let session = store.get_mut(&id).unwrap();
        assert_eq!(session.request_count, 1);
    }

    // ---- SessionState fields ----------------------------------------------

    #[test]
    fn test_flagged_tools_insert_and_contains() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let id = store.get_or_create(None);

        {
            // Scope the write guard so it is released before re-acquiring below.
            let mut session = store.get_mut(&id).unwrap();
            session.flagged_tools.insert("evil_tool".to_string());
            session.flagged_tools.insert("suspicious_tool".to_string());
        }

        let session = store.get_mut(&id).unwrap();
        assert!(session.flagged_tools.contains("evil_tool"));
        assert!(session.flagged_tools.contains("suspicious_tool"));
        assert!(!session.flagged_tools.contains("safe_tool"));
        assert_eq!(session.flagged_tools.len(), 2);
    }

    #[test]
    fn test_flagged_tools_empty_by_default() {
        let state = SessionState::new("test-session".to_string());
        assert!(state.flagged_tools.is_empty());
        assert!(state.pending_tool_calls.is_empty());
    }

    #[test]
    fn test_oauth_subject_storage() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let id = store.get_or_create(None);

        {
            // New sessions start without an OAuth subject.
            let session = store.get_mut(&id).unwrap();
            assert!(session.oauth_subject.is_none());
        }

        {
            // Write through the guard; the change must persist in the store.
            let mut session = store.get_mut(&id).unwrap();
            session.oauth_subject = Some("user-42".to_string());
        }

        let session = store.get_mut(&id).unwrap();
        assert_eq!(session.oauth_subject.as_deref(), Some("user-42"));
    }

    #[test]
    fn test_protocol_version_tracking() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let id = store.get_or_create(None);

        {
            let session = store.get_mut(&id).unwrap();
            assert!(session.protocol_version.is_none());
        }

        {
            let mut session = store.get_mut(&id).unwrap();
            session.protocol_version = Some("2025-11-25".to_string());
        }

        let session = store.get_mut(&id).unwrap();
        assert_eq!(session.protocol_version.as_deref(), Some("2025-11-25"));
    }

    #[test]
    fn test_known_tools_mutations() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let id = store.get_or_create(None);

        {
            let mut session = store.get_mut(&id).unwrap();
            session.known_tools.insert(
                "read_file".to_string(),
                ToolAnnotations {
                    read_only_hint: true,
                    destructive_hint: false,
                    idempotent_hint: true,
                    open_world_hint: false,
                    input_schema_hash: None,
                },
            );
        }

        let session = store.get_mut(&id).unwrap();
        assert_eq!(session.known_tools.len(), 1);
        let ann = session.known_tools.get("read_file").unwrap();
        assert!(ann.read_only_hint);
        assert!(!ann.destructive_hint);
    }

    #[test]
    fn test_tool_annotations_default() {
        // Pins the defaults of the external ToolAnnotations type as seen by this crate.
        let ann = ToolAnnotations::default();
        assert!(!ann.read_only_hint);
        assert!(ann.destructive_hint);
        assert!(!ann.idempotent_hint);
        assert!(ann.open_world_hint);
    }

    #[test]
    fn test_tool_annotations_equality() {
        let a = ToolAnnotations {
            read_only_hint: true,
            destructive_hint: false,
            idempotent_hint: true,
            open_world_hint: false,
            input_schema_hash: None,
        };
        let b = ToolAnnotations {
            read_only_hint: true,
            destructive_hint: false,
            idempotent_hint: true,
            open_world_hint: false,
            input_schema_hash: None,
        };
        let c = ToolAnnotations::default();
        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    #[test]
    fn test_tools_list_seen_flag() {
        let state = SessionState::new("test".to_string());
        assert!(!state.tools_list_seen);
    }

    // ---- Expiry: inactivity timeout and absolute lifetime -----------------

    #[test]
    fn test_inactivity_expiry_preserved() {
        let state = SessionState::new("test-inactivity".to_string());
        assert!(!state.is_expired(Duration::from_secs(300), None));
        // A zero timeout is exceeded by any elapsed time.
        assert!(state.is_expired(Duration::from_nanos(0), None));
    }

    #[test]
    fn test_absolute_lifetime_enforced() {
        let state = SessionState::new("test-lifetime".to_string());
        // Expired by lifetime even though the inactivity timeout is generous.
        assert!(state.is_expired(Duration::from_secs(300), Some(Duration::from_nanos(0))));
        assert!(!state.is_expired(Duration::from_secs(300), Some(Duration::from_secs(86400))));
    }

    #[test]
    fn test_none_max_lifetime_no_absolute_limit() {
        let state = SessionState::new("test-no-limit".to_string());
        assert!(!state.is_expired(Duration::from_secs(300), None));
    }

    #[test]
    fn test_eviction_checks_both_timeouts() {
        // Inactivity timeout is long, but the absolute lifetime is zero.
        let store = SessionStore::new(Duration::from_secs(300), 100)
            .with_max_lifetime(Duration::from_nanos(0));

        let _id = store.get_or_create(None);
        assert_eq!(store.len(), 1);

        store.evict_expired();
        assert_eq!(store.len(), 0);
    }

    #[test]
    fn test_with_max_lifetime_builder() {
        let store = SessionStore::new(Duration::from_secs(300), 100)
            .with_max_lifetime(Duration::from_secs(86400));
        let id = store.get_or_create(None);
        assert_eq!(store.len(), 1);
        let id2 = store.get_or_create(Some(&id));
        assert_eq!(id, id2);
    }

    // ---- Client-supplied session id length bound ---------------------------

    #[test]
    fn test_session_id_at_max_length_accepted() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let long_id = "a".repeat(MAX_SESSION_ID_LEN);
        // Unknown id at the limit: passes the length filter but is not adopted.
        let id = store.get_or_create(Some(&long_id));
        assert_ne!(id, long_id);
        assert_eq!(store.len(), 1);

        // Once a session exists under that id, it is reused.
        store
            .sessions
            .insert(long_id.clone(), SessionState::new(long_id.clone()));
        let reused = store.get_or_create(Some(&long_id));
        assert_eq!(reused, long_id);
    }

    #[test]
    fn test_session_id_exceeding_max_length_rejected() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let too_long = "b".repeat(MAX_SESSION_ID_LEN + 1);
        // Even an existing session is not reachable through an oversized id.
        store
            .sessions
            .insert(too_long.clone(), SessionState::new(too_long.clone()));

        let id = store.get_or_create(Some(&too_long));
        assert_ne!(id, too_long, "Oversized session ID must not be reused");
        assert_eq!(id.len(), 36, "Should return a UUID-format session ID");
    }

    #[test]
    fn test_session_id_empty_string_accepted() {
        // An empty id passes the length filter but names no session.
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let id = store.get_or_create(Some(""));
        assert_eq!(id.len(), 36);
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_session_id_exactly_128_chars_boundary() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let exact = "x".repeat(128);
        let id = store.get_or_create(Some(&exact));
        assert_eq!(id.len(), 36);

        let one_over = "x".repeat(129);
        let id2 = store.get_or_create(Some(&one_over));
        assert_eq!(id2.len(), 36);
        assert_eq!(store.len(), 2);
    }

    // ---- StatefulContext / RequestContext ----------------------------------

    #[test]
    fn test_stateful_context_implements_trait() {
        let session = SessionState::new("test-ctx".to_string());
        let ctx = StatefulContext::new(&session);

        // Compile-time check that the trait is object-safe for this type.
        let _: &dyn RequestContext = &ctx;
        assert!(ctx.call_counts().is_empty());
        assert!(ctx.previous_actions().is_empty());
        assert!(ctx.call_chain().is_empty());
        assert!(ctx.agent_identity().is_none());
        assert!(ctx.session_guard_state().is_none());
        assert!(ctx.risk_score().is_none());
    }

    #[test]
    fn test_stateful_context_call_counts() {
        let mut session = SessionState::new("test-counts".to_string());
        session.call_counts.insert("read_file".to_string(), 5);
        session.call_counts.insert("write_file".to_string(), 3);

        let ctx = StatefulContext::new(&session);
        assert_eq!(ctx.call_counts().len(), 2);
        assert_eq!(ctx.call_counts()["read_file"], 5);
        assert_eq!(ctx.call_counts()["write_file"], 3);
    }

    #[test]
    fn test_stateful_context_previous_actions() {
        // History order must be preserved when converted to a slice.
        let mut session = SessionState::new("test-actions".to_string());
        session.action_history.push_back("read_file".to_string());
        session.action_history.push_back("write_file".to_string());
        session.action_history.push_back("execute".to_string());

        let ctx = StatefulContext::new(&session);
        let actions = ctx.previous_actions();
        assert_eq!(actions.len(), 3);
        assert_eq!(actions[0], "read_file");
        assert_eq!(actions[1], "write_file");
        assert_eq!(actions[2], "execute");
    }

    // ---- Per-session tool discovery with TTL -------------------------------

    #[test]
    fn test_discovered_tools_empty_by_default() {
        let state = SessionState::new("test".to_string());
        assert!(state.discovered_tools.is_empty());
    }

    #[test]
    fn test_record_discovered_tools() {
        let mut state = SessionState::new("test".to_string());
        let tools = vec![
            "server:read_file".to_string(),
            "server:write_file".to_string(),
        ];
        state.record_discovered_tools(&tools, Duration::from_secs(300));

        assert_eq!(state.discovered_tools.len(), 2);
        assert!(state.discovered_tools.contains_key("server:read_file"));
        assert!(state.discovered_tools.contains_key("server:write_file"));
    }

    #[test]
    fn test_record_discovered_tools_sets_ttl() {
        let mut state = SessionState::new("test".to_string());
        state.record_discovered_tools(&["server:tool1".to_string()], Duration::from_secs(60));

        let entry = state.discovered_tools.get("server:tool1").unwrap();
        assert_eq!(entry.ttl, Duration::from_secs(60));
        assert!(!entry.used);
    }

    #[test]
    fn test_record_discovered_tools_rediscovery_resets_ttl() {
        let mut state = SessionState::new("test".to_string());
        state.record_discovered_tools(&["server:tool1".to_string()], Duration::from_secs(60));

        state.mark_tool_used("server:tool1");
        assert!(state.discovered_tools.get("server:tool1").unwrap().used);

        // Rediscovery replaces the record: new TTL and `used` cleared.
        state.record_discovered_tools(&["server:tool1".to_string()], Duration::from_secs(120));

        let entry = state.discovered_tools.get("server:tool1").unwrap();
        assert_eq!(entry.ttl, Duration::from_secs(120));
        assert!(!entry.used);
    }

    #[test]
    fn test_is_tool_discovery_expired_unknown_tool() {
        let state = SessionState::new("test".to_string());
        assert_eq!(state.is_tool_discovery_expired("unknown:tool"), None);
    }

    #[test]
    fn test_is_tool_discovery_expired_fresh_tool() {
        let mut state = SessionState::new("test".to_string());
        state.record_discovered_tools(&["server:tool1".to_string()], Duration::from_secs(300));
        assert_eq!(state.is_tool_discovery_expired("server:tool1"), Some(false));
    }

    #[test]
    fn test_is_tool_discovery_expired_zero_ttl() {
        let mut state = SessionState::new("test".to_string());
        // Insert directly to control the timestamp (backdated by 1s).
        state.discovered_tools.insert(
            "server:tool1".to_string(),
            DiscoveredToolSession {
                tool_id: "server:tool1".to_string(),
                discovered_at: Instant::now() - Duration::from_secs(1),
                ttl: Duration::from_nanos(0),
                used: false,
            },
        );
        assert_eq!(state.is_tool_discovery_expired("server:tool1"), Some(true));
    }

    #[test]
    fn test_mark_tool_used_existing() {
        let mut state = SessionState::new("test".to_string());
        state.record_discovered_tools(&["server:tool1".to_string()], Duration::from_secs(300));
        assert!(!state.discovered_tools.get("server:tool1").unwrap().used);

        assert!(state.mark_tool_used("server:tool1"));
        assert!(state.discovered_tools.get("server:tool1").unwrap().used);
    }

    #[test]
    fn test_mark_tool_used_nonexistent() {
        let mut state = SessionState::new("test".to_string());
        assert!(!state.mark_tool_used("unknown:tool"));
    }

    #[test]
    fn test_evict_expired_discoveries_none_expired() {
        let mut state = SessionState::new("test".to_string());
        state.record_discovered_tools(
            &["server:tool1".to_string(), "server:tool2".to_string()],
            Duration::from_secs(300),
        );
        assert_eq!(state.evict_expired_discoveries(), 0);
        assert_eq!(state.discovered_tools.len(), 2);
    }

    #[test]
    fn test_evict_expired_discoveries_some_expired() {
        let mut state = SessionState::new("test".to_string());

        state.record_discovered_tools(&["server:fresh".to_string()], Duration::from_secs(300));

        // Backdated entry whose 1s TTL has long passed; `used` does not protect it.
        state.discovered_tools.insert(
            "server:stale".to_string(),
            DiscoveredToolSession {
                tool_id: "server:stale".to_string(),
                discovered_at: Instant::now() - Duration::from_secs(10),
                ttl: Duration::from_secs(1),
                used: true,
            },
        );

        assert_eq!(state.evict_expired_discoveries(), 1);
        assert_eq!(state.discovered_tools.len(), 1);
        assert!(state.discovered_tools.contains_key("server:fresh"));
        assert!(!state.discovered_tools.contains_key("server:stale"));
    }

    #[test]
    fn test_evict_expired_discoveries_all_expired() {
        let mut state = SessionState::new("test".to_string());
        let past = Instant::now() - Duration::from_secs(10);
        for i in 0..5 {
            state.discovered_tools.insert(
                format!("server:tool{i}"),
                DiscoveredToolSession {
                    tool_id: format!("server:tool{i}"),
                    discovered_at: past,
                    ttl: Duration::from_secs(1),
                    used: false,
                },
            );
        }

        assert_eq!(state.evict_expired_discoveries(), 5);
        assert!(state.discovered_tools.is_empty());
    }

    #[test]
    fn test_discovered_tool_session_is_expired() {
        let fresh = DiscoveredToolSession {
            tool_id: "t".to_string(),
            discovered_at: Instant::now(),
            ttl: Duration::from_secs(300),
            used: false,
        };
        assert!(!fresh.is_expired());

        let stale = DiscoveredToolSession {
            tool_id: "t".to_string(),
            discovered_at: Instant::now() - Duration::from_secs(10),
            ttl: Duration::from_secs(1),
            used: false,
        };
        assert!(stale.is_expired());
    }

    #[test]
    fn test_discovered_tools_survive_session_touch() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let id = store.get_or_create(None);

        {
            let mut session = store.get_mut(&id).unwrap();
            session
                .record_discovered_tools(&["server:tool1".to_string()], Duration::from_secs(300));
        }

        // Reusing the session (which touches it) must not clear discoveries.
        store.get_or_create(Some(&id));

        let session = store.get_mut(&id).unwrap();
        assert_eq!(session.discovered_tools.len(), 1);
        assert!(session.discovered_tools.contains_key("server:tool1"));
    }

    #[test]
    fn test_multiple_tools_independent_ttl() {
        let mut state = SessionState::new("test".to_string());

        // Short-lived entry, already past its TTL.
        state.discovered_tools.insert(
            "server:short".to_string(),
            DiscoveredToolSession {
                tool_id: "server:short".to_string(),
                discovered_at: Instant::now() - Duration::from_secs(5),
                ttl: Duration::from_secs(1),
                used: false,
            },
        );

        state.record_discovered_tools(&["server:long".to_string()], Duration::from_secs(3600));

        assert_eq!(state.is_tool_discovery_expired("server:short"), Some(true));
        assert_eq!(state.is_tool_discovery_expired("server:long"), Some(false));
    }

    #[test]
    fn test_evaluation_context_from_stateful() {
        // The OAuth subject becomes `agent_id`; other fields are copied across.
        let mut session = SessionState::new("test-eval".to_string());
        session.oauth_subject = Some("user-42".to_string());
        session.call_counts.insert("tool_a".to_string(), 7);
        session.action_history.push_back("tool_a".to_string());
        session.agent_identity = Some(AgentIdentity {
            issuer: Some("test-issuer".to_string()),
            subject: Some("agent-sub".to_string()),
            ..Default::default()
        });

        let ctx = StatefulContext::new(&session);
        let eval = ctx.to_evaluation_context();

        assert_eq!(eval.agent_id.as_deref(), Some("user-42"));
        assert_eq!(eval.call_counts["tool_a"], 7);
        assert_eq!(eval.previous_actions, vec!["tool_a".to_string()]);
        assert_eq!(
            eval.agent_identity.as_ref().unwrap().issuer.as_deref(),
            Some("test-issuer")
        );
    }

    // ---- Store-wide flagged-tool registry ----------------------------------

    #[test]
    fn test_global_flagged_tool_basic() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        assert!(!store.is_tool_globally_flagged("evil_tool"));
        assert_eq!(store.global_flagged_tools_len(), 0);

        store.flag_tool_globally("evil_tool".to_string());
        assert!(store.is_tool_globally_flagged("evil_tool"));
        assert!(!store.is_tool_globally_flagged("safe_tool"));
        assert_eq!(store.global_flagged_tools_len(), 1);
    }

    #[test]
    fn test_global_flagged_tool_survives_session_eviction() {
        // Cap of 2 so that two more creates evict the first session.
        let store = SessionStore::new(Duration::from_secs(300), 2);
        let id1 = store.get_or_create(None);

        if let Some(mut s) = store.get_mut(&id1) {
            s.insert_flagged_tool("rug_pulled_tool".to_string());
        }
        store.flag_tool_globally("rug_pulled_tool".to_string());

        let is_flagged = store
            .get_mut(&id1)
            .map(|s| s.flagged_tools.contains("rug_pulled_tool"))
            .unwrap_or(false);
        assert!(is_flagged);

        store.get_or_create(None);
        store.get_or_create(None);
        let session_gone = store.get_mut(&id1).is_none();
        assert!(session_gone, "session should have been evicted");

        // The per-session flag is gone with the session; the global one remains.
        assert!(store.is_tool_globally_flagged("rug_pulled_tool"));
    }

    #[test]
    fn test_global_flagged_tool_expiry() {
        let store = SessionStore::new(Duration::from_secs(300), 100);

        // Backdated 25h: one hour past the 24h global TTL.
        store.global_flagged_tools.insert(
            "expired_tool".to_string(),
            GlobalFlaggedToolEntry {
                flagged_at: Instant::now() - Duration::from_secs(25 * 60 * 60),
                ttl: GLOBAL_FLAGGED_TOOL_TTL,
            },
        );

        // Expired entries are reported as not flagged even before eviction.
        assert!(!store.is_tool_globally_flagged("expired_tool"));

        let evicted = store.evict_expired_global_flags();
        assert_eq!(evicted, 1);
        assert_eq!(store.global_flagged_tools_len(), 0);
    }

    #[test]
    fn test_global_flagged_tool_capacity_bound() {
        let store = SessionStore::new(Duration::from_secs(300), 100);

        for i in 0..MAX_GLOBAL_FLAGGED_TOOLS {
            store.flag_tool_globally(format!("tool_{i}"));
        }
        assert_eq!(store.global_flagged_tools_len(), MAX_GLOBAL_FLAGGED_TOOLS);

        // Full of live entries: a new tool is dropped.
        store.flag_tool_globally("overflow_tool".to_string());
        assert!(!store.is_tool_globally_flagged("overflow_tool"));
        assert_eq!(store.global_flagged_tools_len(), MAX_GLOBAL_FLAGGED_TOOLS);
    }

    #[test]
    fn test_global_flagged_tool_capacity_evicts_expired_first() {
        let store = SessionStore::new(Duration::from_secs(300), 100);

        // Fill the registry with entries that are already past the TTL.
        for i in 0..MAX_GLOBAL_FLAGGED_TOOLS {
            store.global_flagged_tools.insert(
                format!("old_tool_{i}"),
                GlobalFlaggedToolEntry {
                    flagged_at: Instant::now() - Duration::from_secs(25 * 60 * 60),
                    ttl: GLOBAL_FLAGGED_TOOL_TTL,
                },
            );
        }
        assert_eq!(store.global_flagged_tools_len(), MAX_GLOBAL_FLAGGED_TOOLS);

        // Eviction of expired entries makes room for the new one.
        store.flag_tool_globally("fresh_tool".to_string());
        assert!(store.is_tool_globally_flagged("fresh_tool"));
    }

    #[test]
    fn test_global_flagged_tool_no_ttl_reset_on_reflag() {
        let store = SessionStore::new(Duration::from_secs(300), 100);

        // Live entry flagged one hour ago.
        let old_time = Instant::now() - Duration::from_secs(60 * 60);
        store.global_flagged_tools.insert(
            "tool_a".to_string(),
            GlobalFlaggedToolEntry {
                flagged_at: old_time,
                ttl: GLOBAL_FLAGGED_TOOL_TTL,
            },
        );

        // Re-flagging must not extend a live entry's TTL.
        store.flag_tool_globally("tool_a".to_string());
        let entry = store.global_flagged_tools.get("tool_a").unwrap();
        assert_eq!(entry.flagged_at, old_time);
    }

    #[test]
    fn test_global_flagged_tool_unwrap_or_else_fallback() {
        // Lookup pattern: per-session check, falling back to the global registry
        // when the session is missing.
        let store = SessionStore::new(Duration::from_secs(300), 100);
        store.flag_tool_globally("globally_flagged".to_string());

        let is_flagged = store
            .get_mut("nonexistent-session")
            .map(|s| s.flagged_tools.contains("globally_flagged"))
            .unwrap_or_else(|| store.is_tool_globally_flagged("globally_flagged"));

        assert!(is_flagged, "global fallback should catch flagged tool");
    }
}