1use llmtrace_core::{AgentAction, AgentActionType, SecurityFinding, SecuritySeverity};
30use std::collections::{HashMap, VecDeque};
31use std::fmt;
32use std::sync::RwLock;
33use std::time::{Duration, Instant};
34
/// Broad classification of what an agent tool does.
///
/// Used to group tools in [`ToolRegistry::lookup_by_category`] and rendered
/// into finding metadata through its `Display` impl, which produces
/// snake_case names (and `custom:<name>` for [`ToolCategory::Custom`]).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ToolCategory {
    /// Reads from a data source, e.g. the built-in `data_lookup` tool.
    DataRetrieval,
    /// Reaches the web, e.g. the built-in `web_search`, `web_browse` and
    /// `api_call` tools.
    WebAccess,
    /// Reads or modifies files, e.g. `file_read`, `file_write`, `file_delete`.
    FileSystem,
    /// Queries or writes a database, e.g. `database_query`, `database_write`.
    Database,
    /// Runs code or shell commands, e.g. `shell_exec`, `code_exec`.
    CodeExecution,
    /// Sends outbound messages, e.g. `send_email`, `send_message`.
    Communication,
    /// Changes system configuration, e.g. `system_config`.
    SystemAdmin,
    /// Caller-defined category identified by a free-form name.
    Custom(String),
}
62
63impl fmt::Display for ToolCategory {
64 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65 match self {
66 Self::DataRetrieval => write!(f, "data_retrieval"),
67 Self::WebAccess => write!(f, "web_access"),
68 Self::FileSystem => write!(f, "file_system"),
69 Self::Database => write!(f, "database"),
70 Self::CodeExecution => write!(f, "code_execution"),
71 Self::Communication => write!(f, "communication"),
72 Self::SystemAdmin => write!(f, "system_admin"),
73 Self::Custom(name) => write!(f, "custom:{}", name),
74 }
75 }
76}
77
/// Static description of a tool an agent may invoke, built with
/// [`ToolDefinition::new`] plus the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ToolDefinition {
    /// Identifier the registry stores the tool under; it is matched against
    /// `AgentAction::name` during validation.
    pub id: String,
    /// Human-readable display name.
    pub name: String,
    /// Functional category of the tool.
    pub category: ToolCategory,
    /// Risk score in `[0.0, 1.0]` (clamped by `with_risk_score`); scores
    /// strictly above `0.8` are reported as high-risk by the registry.
    pub risk_score: f64,
    /// Whether use of the tool should be confirmed by a user; the registry
    /// emits an informational finding when this is set.
    pub requires_approval: bool,
    /// Optional per-tool call limit.
    /// NOTE(review): nothing in this file enforces it (`ActionRateLimiter`
    /// is keyed by action type, not tool id) — confirm where it is consumed
    /// and over what time window.
    pub rate_limit: Option<u32>,
    /// Permission strings (e.g. `exec:shell`) associated with the tool.
    /// NOTE(review): not checked by the registry's validation in this file —
    /// confirm where they are enforced.
    pub required_permissions: Vec<String>,
    /// Free-form description of what the tool does.
    pub description: String,
}
106
107impl ToolDefinition {
108 pub fn new(id: &str, name: &str, category: ToolCategory) -> Self {
113 Self {
114 id: id.to_string(),
115 name: name.to_string(),
116 category,
117 risk_score: 0.5,
118 requires_approval: false,
119 rate_limit: None,
120 required_permissions: Vec::new(),
121 description: String::new(),
122 }
123 }
124
125 pub fn with_risk_score(mut self, score: f64) -> Self {
127 self.risk_score = score.clamp(0.0, 1.0);
128 self
129 }
130
131 pub fn with_requires_approval(mut self, requires: bool) -> Self {
133 self.requires_approval = requires;
134 self
135 }
136
137 pub fn with_rate_limit(mut self, limit: u32) -> Self {
139 self.rate_limit = Some(limit);
140 self
141 }
142
143 pub fn with_permission(mut self, permission: String) -> Self {
145 self.required_permissions.push(permission);
146 self
147 }
148
149 pub fn with_description(mut self, description: String) -> Self {
151 self.description = description;
152 self
153 }
154}
155
/// Catalogue of known tools, keyed by [`ToolDefinition::id`].
///
/// All methods take `&self`; the map is guarded by an `RwLock`, so a single
/// registry can be shared between threads (e.g. behind an `Arc`) and
/// registered into concurrently.
pub struct ToolRegistry {
    /// Registered tool definitions, keyed by tool id.
    tools: RwLock<HashMap<String, ToolDefinition>>,
}
172
173impl ToolRegistry {
174 pub fn new() -> Self {
176 Self {
177 tools: RwLock::new(HashMap::new()),
178 }
179 }
180
181 pub fn with_defaults() -> Self {
184 let registry = Self::new();
185 let defaults = vec![
186 ToolDefinition::new("web_search", "Web Search", ToolCategory::WebAccess)
187 .with_risk_score(0.3)
188 .with_description("Search the web for information".to_string()),
189 ToolDefinition::new("web_browse", "Web Browse", ToolCategory::WebAccess)
190 .with_risk_score(0.4)
191 .with_description("Browse a web page and extract content".to_string()),
192 ToolDefinition::new("file_read", "File Read", ToolCategory::FileSystem)
193 .with_risk_score(0.3)
194 .with_description("Read contents of a file".to_string()),
195 ToolDefinition::new("file_write", "File Write", ToolCategory::FileSystem)
196 .with_risk_score(0.6)
197 .with_requires_approval(true)
198 .with_description("Write content to a file".to_string())
199 .with_permission("file:write".to_string()),
200 ToolDefinition::new("file_delete", "File Delete", ToolCategory::FileSystem)
201 .with_risk_score(0.8)
202 .with_requires_approval(true)
203 .with_description("Delete a file from the filesystem".to_string())
204 .with_permission("file:delete".to_string()),
205 ToolDefinition::new("shell_exec", "Shell Execute", ToolCategory::CodeExecution)
206 .with_risk_score(0.9)
207 .with_requires_approval(true)
208 .with_rate_limit(10)
209 .with_description("Execute a shell command".to_string())
210 .with_permission("exec:shell".to_string()),
211 ToolDefinition::new("code_exec", "Code Execute", ToolCategory::CodeExecution)
212 .with_risk_score(0.85)
213 .with_requires_approval(true)
214 .with_rate_limit(20)
215 .with_description("Execute code in a sandboxed environment".to_string())
216 .with_permission("exec:code".to_string()),
217 ToolDefinition::new("send_email", "Send Email", ToolCategory::Communication)
218 .with_risk_score(0.7)
219 .with_requires_approval(true)
220 .with_rate_limit(5)
221 .with_description("Send an email message".to_string())
222 .with_permission("comms:email".to_string()),
223 ToolDefinition::new("send_message", "Send Message", ToolCategory::Communication)
224 .with_risk_score(0.6)
225 .with_requires_approval(true)
226 .with_rate_limit(10)
227 .with_description("Send a chat or messaging platform message".to_string())
228 .with_permission("comms:message".to_string()),
229 ToolDefinition::new("database_query", "Database Query", ToolCategory::Database)
230 .with_risk_score(0.5)
231 .with_description("Execute a read-only database query".to_string())
232 .with_permission("db:read".to_string()),
233 ToolDefinition::new("database_write", "Database Write", ToolCategory::Database)
234 .with_risk_score(0.7)
235 .with_requires_approval(true)
236 .with_description("Execute a database write operation".to_string())
237 .with_permission("db:write".to_string()),
238 ToolDefinition::new("api_call", "API Call", ToolCategory::WebAccess)
239 .with_risk_score(0.4)
240 .with_description("Make an HTTP API call".to_string()),
241 ToolDefinition::new("data_lookup", "Data Lookup", ToolCategory::DataRetrieval)
242 .with_risk_score(0.1)
243 .with_description("Look up data from a knowledge base".to_string()),
244 ToolDefinition::new(
245 "system_config",
246 "System Configuration",
247 ToolCategory::SystemAdmin,
248 )
249 .with_risk_score(0.95)
250 .with_requires_approval(true)
251 .with_rate_limit(5)
252 .with_description("Modify system configuration".to_string())
253 .with_permission("admin:config".to_string()),
254 ];
255 for tool in defaults {
256 registry.register(tool);
257 }
258 registry
259 }
260
261 pub fn register(&self, tool: ToolDefinition) {
265 let mut tools = self.tools.write().expect("tool registry lock poisoned");
266 tools.insert(tool.id.clone(), tool);
267 }
268
269 pub fn unregister(&self, id: &str) -> bool {
273 let mut tools = self.tools.write().expect("tool registry lock poisoned");
274 tools.remove(id).is_some()
275 }
276
277 pub fn get(&self, id: &str) -> Option<ToolDefinition> {
279 let tools = self.tools.read().expect("tool registry lock poisoned");
280 tools.get(id).cloned()
281 }
282
283 pub fn is_registered(&self, id: &str) -> bool {
285 let tools = self.tools.read().expect("tool registry lock poisoned");
286 tools.contains_key(id)
287 }
288
289 pub fn lookup_by_category(&self, category: &ToolCategory) -> Vec<ToolDefinition> {
294 let tools = self.tools.read().expect("tool registry lock poisoned");
295 tools
296 .values()
297 .filter(|t| &t.category == category)
298 .cloned()
299 .collect()
300 }
301
302 pub fn len(&self) -> usize {
304 let tools = self.tools.read().expect("tool registry lock poisoned");
305 tools.len()
306 }
307
308 pub fn is_empty(&self) -> bool {
310 self.len() == 0
311 }
312
313 pub fn validate_action(&self, action: &AgentAction) -> Vec<SecurityFinding> {
320 let mut findings = Vec::new();
321 let tool_name = &action.name;
322
323 let tools = self.tools.read().expect("tool registry lock poisoned");
324
325 match tools.get(tool_name) {
326 None => {
327 if action.action_type == AgentActionType::ToolCall
331 || action.action_type == AgentActionType::SkillInvocation
332 {
333 findings.push(
334 SecurityFinding::new(
335 SecuritySeverity::High,
336 "unregistered_tool".to_string(),
337 format!("Unregistered tool used: {}", tool_name),
338 0.9,
339 )
340 .with_location("agent_action.tool_call".to_string())
341 .with_metadata("tool_name".to_string(), tool_name.clone()),
342 );
343 }
344 }
345 Some(tool) => {
346 if tool.risk_score > 0.8 {
347 findings.push(
348 SecurityFinding::new(
349 SecuritySeverity::High,
350 "high_risk_tool".to_string(),
351 format!(
352 "High-risk tool used: {} (risk score: {:.2})",
353 tool_name, tool.risk_score
354 ),
355 tool.risk_score,
356 )
357 .with_location("agent_action.tool_call".to_string())
358 .with_metadata("tool_name".to_string(), tool_name.clone())
359 .with_metadata("risk_score".to_string(), format!("{:.2}", tool.risk_score))
360 .with_metadata("category".to_string(), tool.category.to_string()),
361 );
362 }
363 if tool.requires_approval {
364 findings.push(
365 SecurityFinding::new(
366 SecuritySeverity::Info,
367 "tool_requires_approval".to_string(),
368 format!("Tool requires user approval: {}", tool_name),
369 1.0,
370 )
371 .with_location("agent_action.tool_call".to_string())
372 .with_metadata("tool_name".to_string(), tool_name.clone())
373 .with_alert_required(false),
374 );
375 }
376 }
377 }
378
379 findings
380 }
381}
382
383impl Default for ToolRegistry {
384 fn default() -> Self {
385 Self::new()
386 }
387}
388
389impl fmt::Debug for ToolRegistry {
390 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
391 let tools = self.tools.read().expect("tool registry lock poisoned");
392 f.debug_struct("ToolRegistry")
393 .field("tool_count", &tools.len())
394 .field("tool_ids", &tools.keys().collect::<Vec<_>>())
395 .finish()
396 }
397}
398
/// Error returned by [`ActionRateLimiter::check_rate_limit`] when an action
/// type has already used up its allowance in the current sliding window.
#[derive(Debug, Clone)]
pub struct RateLimitExceeded {
    /// Action type whose limit was reached.
    pub action_type: String,
    /// Maximum number of actions of this type allowed per window.
    pub limit: u32,
    /// Length of the sliding window the limit applies to.
    pub window: Duration,
}
413
414impl fmt::Display for RateLimitExceeded {
415 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
416 write!(
417 f,
418 "rate limit exceeded for action type '{}': {} calls per {:?}",
419 self.action_type, self.limit, self.window
420 )
421 }
422}
423
// Makes `RateLimitExceeded` usable as a standard error (e.g. boxed as
// `dyn Error` or propagated with `?`). The message comes from the `Display`
// impl; the default `source()` is kept, so there is no underlying cause.
impl std::error::Error for RateLimitExceeded {}
425
426impl RateLimitExceeded {
427 pub fn to_security_finding(&self) -> SecurityFinding {
429 SecurityFinding::new(
430 SecuritySeverity::Medium,
431 "action_rate_limit_exceeded".to_string(),
432 format!(
433 "Rate limit exceeded for action type '{}': limit is {} calls per {:?}",
434 self.action_type, self.limit, self.window
435 ),
436 0.95,
437 )
438 .with_metadata("action_type".to_string(), self.action_type.clone())
439 .with_metadata("limit".to_string(), self.limit.to_string())
440 .with_metadata("window_secs".to_string(), self.window.as_secs().to_string())
441 }
442}
443
/// Sliding-window rate limiter keyed by action type.
///
/// Each action type has its own queue of call timestamps. A new call is
/// admitted while fewer than that type's limit fall inside the last `window`.
/// Types without an entry in `limits` use `default_limit`. The timestamp map
/// sits behind an `RwLock`, so the limiter can be shared across threads.
pub struct ActionRateLimiter {
    /// Timestamps of recorded actions per action type, in insertion order;
    /// expired entries are pruned from the front.
    windows: RwLock<HashMap<String, VecDeque<Instant>>>,
    /// Limit for action types that have no override in `limits`.
    default_limit: u32,
    /// Per-action-type limit overrides.
    limits: HashMap<String, u32>,
    /// Length of the sliding window.
    window: Duration,
}
468
469impl ActionRateLimiter {
470 pub fn new(default_limit: u32, window: Duration) -> Self {
472 Self {
473 windows: RwLock::new(HashMap::new()),
474 default_limit,
475 limits: HashMap::new(),
476 window,
477 }
478 }
479
480 pub fn with_limits(default_limit: u32, window: Duration, limits: HashMap<String, u32>) -> Self {
482 Self {
483 windows: RwLock::new(HashMap::new()),
484 default_limit,
485 limits,
486 window,
487 }
488 }
489
490 fn effective_limit(&self, action_type: &str) -> u32 {
492 self.limits
493 .get(action_type)
494 .copied()
495 .unwrap_or(self.default_limit)
496 }
497
498 fn prune(deque: &mut VecDeque<Instant>, cutoff: Instant) {
501 while let Some(&front) = deque.front() {
502 if front < cutoff {
503 deque.pop_front();
504 } else {
505 break;
506 }
507 }
508 }
509
510 pub fn check_rate_limit(
515 &self,
516 action_type: &str,
517 ) -> std::result::Result<(), RateLimitExceeded> {
518 let limit = self.effective_limit(action_type);
519 let now = Instant::now();
520 let cutoff = now - self.window;
521
522 let mut windows = self.windows.write().expect("rate limiter lock poisoned");
523 let deque = windows.entry(action_type.to_string()).or_default();
524 Self::prune(deque, cutoff);
525
526 if deque.len() >= limit as usize {
527 Err(RateLimitExceeded {
528 action_type: action_type.to_string(),
529 limit,
530 window: self.window,
531 })
532 } else {
533 deque.push_back(now);
534 Ok(())
535 }
536 }
537
538 pub fn record_action(&self, action_type: &str) {
543 let now = Instant::now();
544 let cutoff = now - self.window;
545
546 let mut windows = self.windows.write().expect("rate limiter lock poisoned");
547 let deque = windows.entry(action_type.to_string()).or_default();
548 Self::prune(deque, cutoff);
549 deque.push_back(now);
550 }
551
552 pub fn remaining(&self, action_type: &str) -> u32 {
555 let limit = self.effective_limit(action_type);
556 let now = Instant::now();
557 let cutoff = now - self.window;
558
559 let mut windows = self.windows.write().expect("rate limiter lock poisoned");
560 let deque = windows.entry(action_type.to_string()).or_default();
561 Self::prune(deque, cutoff);
562
563 limit.saturating_sub(deque.len() as u32)
564 }
565
566 pub fn reset(&self, action_type: &str) {
568 let mut windows = self.windows.write().expect("rate limiter lock poisoned");
569 windows.remove(action_type);
570 }
571
572 pub fn window_duration(&self) -> Duration {
574 self.window
575 }
576
577 pub fn default_limit(&self) -> u32 {
579 self.default_limit
580 }
581}
582
583impl fmt::Debug for ActionRateLimiter {
584 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
585 f.debug_struct("ActionRateLimiter")
586 .field("default_limit", &self.default_limit)
587 .field("window", &self.window)
588 .field("overrides", &self.limits)
589 .finish()
590 }
591}
592
// Unit tests for `ToolCategory`, `ToolDefinition`, `ToolRegistry`,
// `RateLimitExceeded` and `ActionRateLimiter`, plus an end-to-end check that
// combines registry validation with rate limiting.
#[cfg(test)]
mod tests {
    use super::*;
    use llmtrace_core::{AgentAction, AgentActionType};
    use std::thread;
    use std::time::Duration;

    // ---------------------------------------------------------------------
    // ToolCategory
    // ---------------------------------------------------------------------

    /// Every category renders its snake_case label; custom ones are prefixed.
    #[test]
    fn test_tool_category_display() {
        assert_eq!(ToolCategory::DataRetrieval.to_string(), "data_retrieval");
        assert_eq!(ToolCategory::WebAccess.to_string(), "web_access");
        assert_eq!(ToolCategory::FileSystem.to_string(), "file_system");
        assert_eq!(ToolCategory::Database.to_string(), "database");
        assert_eq!(ToolCategory::CodeExecution.to_string(), "code_execution");
        assert_eq!(ToolCategory::Communication.to_string(), "communication");
        assert_eq!(ToolCategory::SystemAdmin.to_string(), "system_admin");
        assert_eq!(
            ToolCategory::Custom("my_cat".to_string()).to_string(),
            "custom:my_cat"
        );
    }

    /// Equality distinguishes variants and compares custom names by value.
    #[test]
    fn test_tool_category_equality() {
        assert_eq!(ToolCategory::WebAccess, ToolCategory::WebAccess);
        assert_ne!(ToolCategory::WebAccess, ToolCategory::FileSystem);
        assert_eq!(
            ToolCategory::Custom("x".to_string()),
            ToolCategory::Custom("x".to_string())
        );
        assert_ne!(
            ToolCategory::Custom("x".to_string()),
            ToolCategory::Custom("y".to_string())
        );
    }

    // ---------------------------------------------------------------------
    // ToolDefinition
    // ---------------------------------------------------------------------

    /// `ToolDefinition::new` fills in the documented default values.
    #[test]
    fn test_tool_definition_new_defaults() {
        let tool = ToolDefinition::new("test", "Test", ToolCategory::DataRetrieval);
        assert_eq!(tool.id, "test");
        assert_eq!(tool.name, "Test");
        assert_eq!(tool.category, ToolCategory::DataRetrieval);
        assert!((tool.risk_score - 0.5).abs() < f64::EPSILON);
        assert!(!tool.requires_approval);
        assert!(tool.rate_limit.is_none());
        assert!(tool.required_permissions.is_empty());
        assert!(tool.description.is_empty());
    }

    /// Builder methods set every field; `with_permission` accumulates.
    #[test]
    fn test_tool_definition_builder() {
        let tool = ToolDefinition::new("exec", "Execute", ToolCategory::CodeExecution)
            .with_risk_score(0.95)
            .with_requires_approval(true)
            .with_rate_limit(10)
            .with_permission("exec:shell".to_string())
            .with_permission("exec:code".to_string())
            .with_description("Run shell commands".to_string());

        assert_eq!(tool.id, "exec");
        assert!((tool.risk_score - 0.95).abs() < f64::EPSILON);
        assert!(tool.requires_approval);
        assert_eq!(tool.rate_limit, Some(10));
        assert_eq!(tool.required_permissions.len(), 2);
        assert_eq!(tool.description, "Run shell commands");
    }

    /// Out-of-range risk scores are clamped into `[0.0, 1.0]`.
    #[test]
    fn test_tool_definition_risk_score_clamped() {
        let tool_high =
            ToolDefinition::new("t", "T", ToolCategory::DataRetrieval).with_risk_score(1.5);
        assert!((tool_high.risk_score - 1.0).abs() < f64::EPSILON);

        let tool_low =
            ToolDefinition::new("t", "T", ToolCategory::DataRetrieval).with_risk_score(-0.5);
        assert!(tool_low.risk_score.abs() < f64::EPSILON);
    }

    // ---------------------------------------------------------------------
    // ToolRegistry: basic operations
    // ---------------------------------------------------------------------

    /// A freshly constructed registry holds no tools.
    #[test]
    fn test_registry_new_is_empty() {
        let reg = ToolRegistry::new();
        assert!(reg.is_empty());
        assert_eq!(reg.len(), 0);
    }

    /// `Default` yields an empty registry (not the built-in catalogue).
    #[test]
    fn test_registry_default_is_empty() {
        let reg = ToolRegistry::default();
        assert!(reg.is_empty());
    }

    /// A registered tool can be looked up by id.
    #[test]
    fn test_registry_register_and_get() {
        let reg = ToolRegistry::new();
        let tool =
            ToolDefinition::new("search", "Search", ToolCategory::WebAccess).with_risk_score(0.3);
        reg.register(tool);

        assert!(reg.is_registered("search"));
        assert!(!reg.is_registered("unknown"));
        assert_eq!(reg.len(), 1);

        let got = reg.get("search").unwrap();
        assert_eq!(got.id, "search");
        assert!((got.risk_score - 0.3).abs() < f64::EPSILON);
    }

    /// Looking up an unknown id returns `None`.
    #[test]
    fn test_registry_get_nonexistent() {
        let reg = ToolRegistry::new();
        assert!(reg.get("nope").is_none());
    }

    /// Registering the same id twice replaces the earlier definition.
    #[test]
    fn test_registry_register_overwrites() {
        let reg = ToolRegistry::new();
        let tool_v1 =
            ToolDefinition::new("t", "V1", ToolCategory::DataRetrieval).with_risk_score(0.1);
        reg.register(tool_v1);
        assert_eq!(reg.get("t").unwrap().name, "V1");

        let tool_v2 = ToolDefinition::new("t", "V2", ToolCategory::Database).with_risk_score(0.9);
        reg.register(tool_v2);
        assert_eq!(reg.get("t").unwrap().name, "V2");
        assert_eq!(reg.len(), 1);
    }

    /// `unregister` removes only the named tool and reports whether it existed.
    #[test]
    fn test_registry_unregister() {
        let reg = ToolRegistry::new();
        reg.register(ToolDefinition::new("a", "A", ToolCategory::DataRetrieval));
        reg.register(ToolDefinition::new("b", "B", ToolCategory::DataRetrieval));
        assert_eq!(reg.len(), 2);

        assert!(reg.unregister("a"));
        assert!(!reg.is_registered("a"));
        assert!(reg.is_registered("b"));
        assert_eq!(reg.len(), 1);

        assert!(!reg.unregister("nonexistent"));
    }

    /// `lookup_by_category` returns exactly the tools in the given category.
    #[test]
    fn test_registry_lookup_by_category() {
        let reg = ToolRegistry::new();
        reg.register(
            ToolDefinition::new("web1", "Web 1", ToolCategory::WebAccess).with_risk_score(0.3),
        );
        reg.register(
            ToolDefinition::new("web2", "Web 2", ToolCategory::WebAccess).with_risk_score(0.4),
        );
        reg.register(
            ToolDefinition::new("file1", "File 1", ToolCategory::FileSystem).with_risk_score(0.5),
        );

        let web_tools = reg.lookup_by_category(&ToolCategory::WebAccess);
        assert_eq!(web_tools.len(), 2);
        assert!(web_tools
            .iter()
            .all(|t| t.category == ToolCategory::WebAccess));

        let fs_tools = reg.lookup_by_category(&ToolCategory::FileSystem);
        assert_eq!(fs_tools.len(), 1);

        let db_tools = reg.lookup_by_category(&ToolCategory::Database);
        assert!(db_tools.is_empty());
    }

    // ---------------------------------------------------------------------
    // ToolRegistry: built-in defaults
    // ---------------------------------------------------------------------

    /// `with_defaults` registers the expected built-in tools.
    #[test]
    fn test_registry_with_defaults_populated() {
        let reg = ToolRegistry::with_defaults();
        assert!(!reg.is_empty());
        assert!(reg.is_registered("web_search"));
        assert!(reg.is_registered("file_read"));
        assert!(reg.is_registered("file_write"));
        assert!(reg.is_registered("shell_exec"));
        assert!(reg.is_registered("send_email"));
        assert!(reg.is_registered("database_query"));
        assert!(reg.is_registered("system_config"));
    }

    /// Built-in risk scores and approval flags match expectations.
    #[test]
    fn test_registry_defaults_risk_scores() {
        let reg = ToolRegistry::with_defaults();
        let search = reg.get("web_search").unwrap();
        assert!(search.risk_score <= 0.5, "web_search should be low risk");

        let shell = reg.get("shell_exec").unwrap();
        assert!(shell.risk_score > 0.8, "shell_exec should be high risk");
        assert!(
            shell.requires_approval,
            "shell_exec should require approval"
        );
    }

    /// Built-in tools carry the expected categories.
    #[test]
    fn test_registry_defaults_categories() {
        let reg = ToolRegistry::with_defaults();
        assert_eq!(
            reg.get("web_search").unwrap().category,
            ToolCategory::WebAccess
        );
        assert_eq!(
            reg.get("file_read").unwrap().category,
            ToolCategory::FileSystem
        );
        assert_eq!(
            reg.get("shell_exec").unwrap().category,
            ToolCategory::CodeExecution
        );
        assert_eq!(
            reg.get("send_email").unwrap().category,
            ToolCategory::Communication
        );
        assert_eq!(
            reg.get("database_query").unwrap().category,
            ToolCategory::Database
        );
        assert_eq!(
            reg.get("data_lookup").unwrap().category,
            ToolCategory::DataRetrieval
        );
        assert_eq!(
            reg.get("system_config").unwrap().category,
            ToolCategory::SystemAdmin
        );
    }

    // ---------------------------------------------------------------------
    // ToolRegistry::validate_action
    // ---------------------------------------------------------------------

    /// An unknown tool in a `ToolCall` yields one `High` `unregistered_tool` finding.
    #[test]
    fn test_validate_unregistered_tool_call() {
        let reg = ToolRegistry::new();
        let action = AgentAction::new(AgentActionType::ToolCall, "unknown_tool".to_string());
        let findings = reg.validate_action(&action);
        assert_eq!(findings.len(), 1);
        assert_eq!(findings[0].finding_type, "unregistered_tool");
        assert_eq!(findings[0].severity, SecuritySeverity::High);
    }

    /// `SkillInvocation` is treated like a tool call for the unregistered check.
    #[test]
    fn test_validate_unregistered_skill_invocation() {
        let reg = ToolRegistry::new();
        let action = AgentAction::new(
            AgentActionType::SkillInvocation,
            "unknown_skill".to_string(),
        );
        let findings = reg.validate_action(&action);
        assert_eq!(findings.len(), 1);
        assert_eq!(findings[0].finding_type, "unregistered_tool");
    }

    /// Command executions are not tool calls and are never "unregistered".
    #[test]
    fn test_validate_unregistered_command_not_flagged() {
        let reg = ToolRegistry::new();
        let action = AgentAction::new(AgentActionType::CommandExecution, "ls -la".to_string());
        let findings = reg.validate_action(&action);
        assert!(
            findings.is_empty(),
            "CommandExecution should not trigger unregistered_tool"
        );
    }

    /// Web access actions are not checked for registration.
    #[test]
    fn test_validate_unregistered_web_access_not_flagged() {
        let reg = ToolRegistry::new();
        let action = AgentAction::new(
            AgentActionType::WebAccess,
            "https://example.com".to_string(),
        );
        let findings = reg.validate_action(&action);
        assert!(findings.is_empty());
    }

    /// File access actions are not checked for registration.
    #[test]
    fn test_validate_unregistered_file_access_not_flagged() {
        let reg = ToolRegistry::new();
        let action = AgentAction::new(AgentActionType::FileAccess, "/tmp/file.txt".to_string());
        let findings = reg.validate_action(&action);
        assert!(findings.is_empty());
    }

    /// A registered low-risk tool that needs no approval produces no findings.
    #[test]
    fn test_validate_registered_low_risk_tool() {
        let reg = ToolRegistry::new();
        reg.register(
            ToolDefinition::new("safe_tool", "Safe", ToolCategory::DataRetrieval)
                .with_risk_score(0.1),
        );
        let action = AgentAction::new(AgentActionType::ToolCall, "safe_tool".to_string());
        let findings = reg.validate_action(&action);
        assert!(
            findings.is_empty(),
            "Low-risk registered tool should produce no findings"
        );
    }

    /// A risk score above 0.8 yields a `High` `high_risk_tool` finding.
    #[test]
    fn test_validate_high_risk_tool() {
        let reg = ToolRegistry::new();
        reg.register(
            ToolDefinition::new("danger", "Danger", ToolCategory::CodeExecution)
                .with_risk_score(0.85),
        );
        let action = AgentAction::new(AgentActionType::ToolCall, "danger".to_string());
        let findings = reg.validate_action(&action);
        assert!(findings.iter().any(|f| f.finding_type == "high_risk_tool"));
        assert!(findings
            .iter()
            .any(|f| f.severity == SecuritySeverity::High));
    }

    /// The high-risk threshold is exclusive: exactly 0.8 is not flagged.
    #[test]
    fn test_validate_tool_at_boundary_risk_score() {
        let reg = ToolRegistry::new();
        reg.register(
            ToolDefinition::new("border", "Border", ToolCategory::FileSystem).with_risk_score(0.8),
        );
        let action = AgentAction::new(AgentActionType::ToolCall, "border".to_string());
        let findings = reg.validate_action(&action);
        assert!(
            !findings.iter().any(|f| f.finding_type == "high_risk_tool"),
            "risk_score == 0.8 is NOT > 0.8, should not trigger"
        );
    }

    /// Approval-required tools yield an `Info` finding that does not alert.
    #[test]
    fn test_validate_tool_requires_approval() {
        let reg = ToolRegistry::new();
        reg.register(
            ToolDefinition::new("approve_me", "Approve", ToolCategory::Communication)
                .with_risk_score(0.5)
                .with_requires_approval(true),
        );
        let action = AgentAction::new(AgentActionType::ToolCall, "approve_me".to_string());
        let findings = reg.validate_action(&action);
        assert!(findings
            .iter()
            .any(|f| f.finding_type == "tool_requires_approval"));
        let approval_finding = findings
            .iter()
            .find(|f| f.finding_type == "tool_requires_approval")
            .unwrap();
        assert_eq!(approval_finding.severity, SecuritySeverity::Info);
        assert!(!approval_finding.requires_alert);
    }

    /// A high-risk, approval-required tool yields exactly both findings.
    #[test]
    fn test_validate_high_risk_and_requires_approval() {
        let reg = ToolRegistry::new();
        reg.register(
            ToolDefinition::new("risky_approval", "Risky", ToolCategory::SystemAdmin)
                .with_risk_score(0.95)
                .with_requires_approval(true),
        );
        let action = AgentAction::new(AgentActionType::ToolCall, "risky_approval".to_string());
        let findings = reg.validate_action(&action);
        assert!(findings.iter().any(|f| f.finding_type == "high_risk_tool"));
        assert!(findings
            .iter()
            .any(|f| f.finding_type == "tool_requires_approval"));
        assert_eq!(findings.len(), 2);
    }

    /// A low-risk built-in tool is neither unregistered nor high-risk.
    #[test]
    fn test_validate_with_defaults_known_tool() {
        let reg = ToolRegistry::with_defaults();
        let action = AgentAction::new(AgentActionType::ToolCall, "web_search".to_string());
        let findings = reg.validate_action(&action);
        assert!(
            !findings
                .iter()
                .any(|f| f.finding_type == "unregistered_tool"),
            "web_search is registered in defaults"
        );
        assert!(
            !findings.iter().any(|f| f.finding_type == "high_risk_tool"),
            "web_search is low risk"
        );
    }

    /// The built-in `shell_exec` is high-risk and requires approval.
    #[test]
    fn test_validate_with_defaults_shell_exec() {
        let reg = ToolRegistry::with_defaults();
        let action = AgentAction::new(AgentActionType::ToolCall, "shell_exec".to_string());
        let findings = reg.validate_action(&action);
        assert!(
            findings.iter().any(|f| f.finding_type == "high_risk_tool"),
            "shell_exec has risk > 0.8"
        );
        assert!(findings
            .iter()
            .any(|f| f.finding_type == "tool_requires_approval"));
    }

    // ---------------------------------------------------------------------
    // ToolRegistry: concurrency
    // ---------------------------------------------------------------------

    /// Ten threads registering distinct tools concurrently all succeed.
    #[test]
    fn test_registry_concurrent_access() {
        let reg = std::sync::Arc::new(ToolRegistry::new());
        let mut handles = Vec::new();

        for i in 0..10 {
            let reg_clone = reg.clone();
            handles.push(thread::spawn(move || {
                let id = format!("tool_{}", i);
                reg_clone.register(ToolDefinition::new(&id, &id, ToolCategory::DataRetrieval));
                assert!(reg_clone.is_registered(&id));
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(reg.len(), 10);
    }

    // ---------------------------------------------------------------------
    // RateLimitExceeded
    // ---------------------------------------------------------------------

    /// The display message names the action type and the limit.
    #[test]
    fn test_rate_limit_exceeded_display() {
        let err = RateLimitExceeded {
            action_type: "tool_call".to_string(),
            limit: 10,
            window: Duration::from_secs(60),
        };
        let msg = err.to_string();
        assert!(msg.contains("tool_call"));
        assert!(msg.contains("10"));
    }

    /// Conversion to a finding sets type, severity and metadata.
    #[test]
    fn test_rate_limit_exceeded_to_security_finding() {
        let err = RateLimitExceeded {
            action_type: "web_access".to_string(),
            limit: 5,
            window: Duration::from_secs(60),
        };
        let finding = err.to_security_finding();
        assert_eq!(finding.finding_type, "action_rate_limit_exceeded");
        assert_eq!(finding.severity, SecuritySeverity::Medium);
        assert_eq!(
            finding.metadata.get("action_type"),
            Some(&"web_access".to_string())
        );
        assert_eq!(finding.metadata.get("limit"), Some(&"5".to_string()));
    }

    // ---------------------------------------------------------------------
    // ActionRateLimiter
    // ---------------------------------------------------------------------

    /// Calls up to the limit are admitted.
    #[test]
    fn test_rate_limiter_allows_within_limit() {
        let limiter = ActionRateLimiter::new(5, Duration::from_secs(60));
        for _ in 0..5 {
            assert!(limiter.check_rate_limit("tool_call").is_ok());
        }
    }

    /// The call after the limit is rejected with a descriptive error.
    #[test]
    fn test_rate_limiter_denies_over_limit() {
        let limiter = ActionRateLimiter::new(3, Duration::from_secs(60));
        for _ in 0..3 {
            assert!(limiter.check_rate_limit("tool_call").is_ok());
        }
        let result = limiter.check_rate_limit("tool_call");
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.action_type, "tool_call");
        assert_eq!(err.limit, 3);
    }

    /// Each action type has its own independent budget.
    #[test]
    fn test_rate_limiter_independent_action_types() {
        let limiter = ActionRateLimiter::new(2, Duration::from_secs(60));
        assert!(limiter.check_rate_limit("type_a").is_ok());
        assert!(limiter.check_rate_limit("type_a").is_ok());
        assert!(limiter.check_rate_limit("type_a").is_err());

        assert!(limiter.check_rate_limit("type_b").is_ok());
        assert!(limiter.check_rate_limit("type_b").is_ok());
        assert!(limiter.check_rate_limit("type_b").is_err());
    }

    /// `remaining` decreases by one per admitted call.
    #[test]
    fn test_rate_limiter_remaining() {
        let limiter = ActionRateLimiter::new(5, Duration::from_secs(60));
        assert_eq!(limiter.remaining("test"), 5);

        limiter.check_rate_limit("test").unwrap();
        assert_eq!(limiter.remaining("test"), 4);

        limiter.check_rate_limit("test").unwrap();
        limiter.check_rate_limit("test").unwrap();
        assert_eq!(limiter.remaining("test"), 2);
    }

    /// `reset` restores the full budget for an action type.
    #[test]
    fn test_rate_limiter_reset() {
        let limiter = ActionRateLimiter::new(3, Duration::from_secs(60));
        for _ in 0..3 {
            limiter.check_rate_limit("test").unwrap();
        }
        assert!(limiter.check_rate_limit("test").is_err());
        assert_eq!(limiter.remaining("test"), 0);

        limiter.reset("test");
        assert_eq!(limiter.remaining("test"), 3);
        assert!(limiter.check_rate_limit("test").is_ok());
    }

    /// `record_action` consumes budget without checking the limit.
    #[test]
    fn test_rate_limiter_record_action() {
        let limiter = ActionRateLimiter::new(3, Duration::from_secs(60));
        limiter.record_action("test");
        limiter.record_action("test");
        assert_eq!(limiter.remaining("test"), 1);

        limiter.record_action("test");
        assert!(limiter.check_rate_limit("test").is_err());
    }

    /// Per-type overrides take precedence; other types use the default limit.
    #[test]
    fn test_rate_limiter_with_overrides() {
        let mut limits = HashMap::new();
        limits.insert("strict".to_string(), 1);
        limits.insert("relaxed".to_string(), 100);

        let limiter = ActionRateLimiter::with_limits(10, Duration::from_secs(60), limits);

        assert!(limiter.check_rate_limit("strict").is_ok());
        assert!(limiter.check_rate_limit("strict").is_err());

        assert_eq!(limiter.remaining("relaxed"), 100);

        assert_eq!(limiter.remaining("other"), 10);
    }

    /// Once the 50 ms window has elapsed, earlier calls no longer count.
    #[test]
    fn test_rate_limiter_sliding_window_expiry() {
        let limiter = ActionRateLimiter::new(2, Duration::from_millis(50));
        assert!(limiter.check_rate_limit("test").is_ok());
        assert!(limiter.check_rate_limit("test").is_ok());
        assert!(limiter.check_rate_limit("test").is_err());

        thread::sleep(Duration::from_millis(60));

        assert!(limiter.check_rate_limit("test").is_ok());
    }

    /// Accessors return the values passed to the constructor.
    #[test]
    fn test_rate_limiter_accessors() {
        let limiter = ActionRateLimiter::new(42, Duration::from_secs(120));
        assert_eq!(limiter.default_limit(), 42);
        assert_eq!(limiter.window_duration(), Duration::from_secs(120));
    }

    /// `Debug` output names the type and includes the default limit.
    #[test]
    fn test_rate_limiter_debug() {
        let limiter = ActionRateLimiter::new(10, Duration::from_secs(60));
        let debug_str = format!("{:?}", limiter);
        assert!(debug_str.contains("ActionRateLimiter"));
        assert!(debug_str.contains("10"));
    }

    /// 10 threads x 10 calls against a limit of 100 exhaust the budget exactly.
    #[test]
    fn test_rate_limiter_concurrent_access() {
        let limiter = std::sync::Arc::new(ActionRateLimiter::new(100, Duration::from_secs(60)));
        let mut handles = Vec::new();

        for _ in 0..10 {
            let limiter_clone = limiter.clone();
            handles.push(thread::spawn(move || {
                for _ in 0..10 {
                    let _ = limiter_clone.check_rate_limit("concurrent");
                }
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(limiter.remaining("concurrent"), 0);
    }

    /// Registry `Debug` output names the type and reports the tool count.
    #[test]
    fn test_registry_debug() {
        let reg = ToolRegistry::new();
        reg.register(ToolDefinition::new("a", "A", ToolCategory::DataRetrieval));
        let debug_str = format!("{:?}", reg);
        assert!(debug_str.contains("ToolRegistry"));
        assert!(debug_str.contains("tool_count"));
    }

    // ---------------------------------------------------------------------
    // End-to-end
    // ---------------------------------------------------------------------

    /// Validation findings and rate-limit findings combine. The rate limiter
    /// is keyed by the action type's string form, so repeated `shell_exec`
    /// calls exhaust the limit of 2 on the third check.
    #[test]
    fn test_end_to_end_validation_and_rate_limit() {
        let registry = ToolRegistry::with_defaults();
        let limiter = ActionRateLimiter::new(2, Duration::from_secs(60));

        let action = AgentAction::new(AgentActionType::ToolCall, "shell_exec".to_string());
        // Registry checks first.
        let mut all_findings = registry.validate_action(&action);

        // Then the rate limit; an exceeded limit is appended as a finding.
        match limiter.check_rate_limit(&action.action_type.to_string()) {
            Ok(()) => {}
            Err(err) => all_findings.push(err.to_security_finding()),
        }

        assert!(all_findings
            .iter()
            .any(|f| f.finding_type == "high_risk_tool"));
        assert!(all_findings
            .iter()
            .any(|f| f.finding_type == "tool_requires_approval"));
        assert!(
            !all_findings
                .iter()
                .any(|f| f.finding_type == "action_rate_limit_exceeded"),
            "First call should not be rate limited"
        );

        // Second call uses up the remaining budget.
        let _ = limiter.check_rate_limit(&action.action_type.to_string());

        // Third call exceeds the limit of 2.
        let result = limiter.check_rate_limit(&action.action_type.to_string());
        assert!(result.is_err());
        let rate_finding = result.unwrap_err().to_security_finding();
        assert_eq!(rate_finding.finding_type, "action_rate_limit_exceeded");
        assert_eq!(rate_finding.severity, SecuritySeverity::Medium);
    }
}