Skip to main content

llmtrace_security/
tool_registry.rs

1//! Tool registry and action-type rate limiting for agent security.
2//!
3//! Provides a thread-safe [`ToolRegistry`] that classifies tools by security level
4//! (category, risk score, required permissions) and an [`ActionRateLimiter`] that
5//! enforces per-action-type sliding-window rate limits.
6//!
7//! # Example
8//!
9//! ```
10//! use llmtrace_security::tool_registry::{ToolRegistry, ActionRateLimiter, ToolDefinition, ToolCategory};
11//!
12//! // Pre-populated registry with sensible defaults
13//! let registry = ToolRegistry::with_defaults();
14//! assert!(registry.is_registered("web_search"));
15//!
16//! // Custom tool
17//! let tool = ToolDefinition::new("my_tool", "My Tool", ToolCategory::DataRetrieval)
18//!     .with_risk_score(0.1)
19//!     .with_description("A safe read-only tool".to_string());
20//! let mut registry = ToolRegistry::new();
21//! registry.register(tool);
22//! assert!(registry.is_registered("my_tool"));
23//!
24//! // Rate limiter
25//! let limiter = ActionRateLimiter::new(60, std::time::Duration::from_secs(60));
26//! assert!(limiter.check_rate_limit("tool_call").is_ok());
27//! ```
28
29use llmtrace_core::{AgentAction, AgentActionType, SecurityFinding, SecuritySeverity};
30use std::collections::{HashMap, VecDeque};
31use std::fmt;
32use std::sync::RwLock;
33use std::time::{Duration, Instant};
34
35// ---------------------------------------------------------------------------
36// ToolCategory
37// ---------------------------------------------------------------------------
38
/// Security category for a tool.
///
/// Categories group tools by the kind of operation they perform; the
/// parenthesised note on each variant is the indicative risk level for
/// that kind of operation.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ToolCategory {
    /// Read-only data retrieval (low risk).
    DataRetrieval,
    /// Web browsing, HTTP requests (medium risk).
    WebAccess,
    /// File system operations (medium-high risk).
    FileSystem,
    /// Database operations (medium-high risk).
    Database,
    /// Code execution, shell commands (high risk).
    CodeExecution,
    /// Communication — sending emails, messages (high risk).
    Communication,
    /// System administration (critical risk).
    SystemAdmin,
    /// Custom user-defined category.
    Custom(String),
}

impl fmt::Display for ToolCategory {
    /// Writes a stable `snake_case` label for the built-in categories and
    /// `custom:<name>` for [`ToolCategory::Custom`].
    ///
    /// Width, fill and alignment flags of the caller's format spec are not
    /// applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            ToolCategory::DataRetrieval => "data_retrieval",
            ToolCategory::WebAccess => "web_access",
            ToolCategory::FileSystem => "file_system",
            ToolCategory::Database => "database",
            ToolCategory::CodeExecution => "code_execution",
            ToolCategory::Communication => "communication",
            ToolCategory::SystemAdmin => "system_admin",
            ToolCategory::Custom(name) => {
                // Emit the namespace prefix first; the user-supplied name follows.
                f.write_str("custom:")?;
                name.as_str()
            }
        };
        f.write_str(label)
    }
}
77
78// ---------------------------------------------------------------------------
79// ToolDefinition
80// ---------------------------------------------------------------------------
81
82/// Definition of a tool with security metadata.
83///
84/// Each tool registered in the [`ToolRegistry`] carries metadata that the
85/// security engine uses to assess risk, enforce rate limits, and decide
86/// whether user approval is required.
87#[derive(Debug, Clone)]
88pub struct ToolDefinition {
89    /// Unique tool identifier (e.g., `"web_search"`, `"file_read"`).
90    pub id: String,
91    /// Human-readable name.
92    pub name: String,
93    /// Security category.
94    pub category: ToolCategory,
95    /// Risk score 0.0–1.0 (0 = safe, 1 = dangerous).
96    pub risk_score: f64,
97    /// Whether this tool requires explicit user approval before execution.
98    pub requires_approval: bool,
99    /// Maximum calls per minute (`None` = unlimited).
100    pub rate_limit: Option<u32>,
101    /// List of permission strings required to use this tool.
102    pub required_permissions: Vec<String>,
103    /// Description of what this tool does.
104    pub description: String,
105}
106
107impl ToolDefinition {
108    /// Create a new tool definition with sensible defaults.
109    ///
110    /// The risk score defaults to `0.5`, approval is not required, and
111    /// no per-tool rate limit is set.
112    pub fn new(id: &str, name: &str, category: ToolCategory) -> Self {
113        Self {
114            id: id.to_string(),
115            name: name.to_string(),
116            category,
117            risk_score: 0.5,
118            requires_approval: false,
119            rate_limit: None,
120            required_permissions: Vec::new(),
121            description: String::new(),
122        }
123    }
124
125    /// Set the risk score (clamped to 0.0–1.0).
126    pub fn with_risk_score(mut self, score: f64) -> Self {
127        self.risk_score = score.clamp(0.0, 1.0);
128        self
129    }
130
131    /// Set whether this tool requires user approval.
132    pub fn with_requires_approval(mut self, requires: bool) -> Self {
133        self.requires_approval = requires;
134        self
135    }
136
137    /// Set a per-tool rate limit (calls per minute).
138    pub fn with_rate_limit(mut self, limit: u32) -> Self {
139        self.rate_limit = Some(limit);
140        self
141    }
142
143    /// Add a required permission string.
144    pub fn with_permission(mut self, permission: String) -> Self {
145        self.required_permissions.push(permission);
146        self
147    }
148
149    /// Set the tool description.
150    pub fn with_description(mut self, description: String) -> Self {
151        self.description = description;
152        self
153    }
154}
155
156// ---------------------------------------------------------------------------
157// ToolRegistry
158// ---------------------------------------------------------------------------
159
160/// Thread-safe registry of tool definitions.
161///
162/// Tools are stored in a `RwLock<HashMap<String, ToolDefinition>>` keyed by
163/// the tool's unique identifier. The registry supports concurrent reads and
164/// exclusive writes, making it safe to share across async tasks.
165///
166/// Use [`ToolRegistry::with_defaults`] to create a registry pre-populated
167/// with common agent tools.
168pub struct ToolRegistry {
169    /// Tool definitions keyed by tool ID.
170    tools: RwLock<HashMap<String, ToolDefinition>>,
171}
172
173impl ToolRegistry {
174    /// Create an empty tool registry.
175    pub fn new() -> Self {
176        Self {
177            tools: RwLock::new(HashMap::new()),
178        }
179    }
180
181    /// Create a tool registry pre-populated with sensible defaults for
182    /// common agent tool names.
183    pub fn with_defaults() -> Self {
184        let registry = Self::new();
185        let defaults = vec![
186            ToolDefinition::new("web_search", "Web Search", ToolCategory::WebAccess)
187                .with_risk_score(0.3)
188                .with_description("Search the web for information".to_string()),
189            ToolDefinition::new("web_browse", "Web Browse", ToolCategory::WebAccess)
190                .with_risk_score(0.4)
191                .with_description("Browse a web page and extract content".to_string()),
192            ToolDefinition::new("file_read", "File Read", ToolCategory::FileSystem)
193                .with_risk_score(0.3)
194                .with_description("Read contents of a file".to_string()),
195            ToolDefinition::new("file_write", "File Write", ToolCategory::FileSystem)
196                .with_risk_score(0.6)
197                .with_requires_approval(true)
198                .with_description("Write content to a file".to_string())
199                .with_permission("file:write".to_string()),
200            ToolDefinition::new("file_delete", "File Delete", ToolCategory::FileSystem)
201                .with_risk_score(0.8)
202                .with_requires_approval(true)
203                .with_description("Delete a file from the filesystem".to_string())
204                .with_permission("file:delete".to_string()),
205            ToolDefinition::new("shell_exec", "Shell Execute", ToolCategory::CodeExecution)
206                .with_risk_score(0.9)
207                .with_requires_approval(true)
208                .with_rate_limit(10)
209                .with_description("Execute a shell command".to_string())
210                .with_permission("exec:shell".to_string()),
211            ToolDefinition::new("code_exec", "Code Execute", ToolCategory::CodeExecution)
212                .with_risk_score(0.85)
213                .with_requires_approval(true)
214                .with_rate_limit(20)
215                .with_description("Execute code in a sandboxed environment".to_string())
216                .with_permission("exec:code".to_string()),
217            ToolDefinition::new("send_email", "Send Email", ToolCategory::Communication)
218                .with_risk_score(0.7)
219                .with_requires_approval(true)
220                .with_rate_limit(5)
221                .with_description("Send an email message".to_string())
222                .with_permission("comms:email".to_string()),
223            ToolDefinition::new("send_message", "Send Message", ToolCategory::Communication)
224                .with_risk_score(0.6)
225                .with_requires_approval(true)
226                .with_rate_limit(10)
227                .with_description("Send a chat or messaging platform message".to_string())
228                .with_permission("comms:message".to_string()),
229            ToolDefinition::new("database_query", "Database Query", ToolCategory::Database)
230                .with_risk_score(0.5)
231                .with_description("Execute a read-only database query".to_string())
232                .with_permission("db:read".to_string()),
233            ToolDefinition::new("database_write", "Database Write", ToolCategory::Database)
234                .with_risk_score(0.7)
235                .with_requires_approval(true)
236                .with_description("Execute a database write operation".to_string())
237                .with_permission("db:write".to_string()),
238            ToolDefinition::new("api_call", "API Call", ToolCategory::WebAccess)
239                .with_risk_score(0.4)
240                .with_description("Make an HTTP API call".to_string()),
241            ToolDefinition::new("data_lookup", "Data Lookup", ToolCategory::DataRetrieval)
242                .with_risk_score(0.1)
243                .with_description("Look up data from a knowledge base".to_string()),
244            ToolDefinition::new(
245                "system_config",
246                "System Configuration",
247                ToolCategory::SystemAdmin,
248            )
249            .with_risk_score(0.95)
250            .with_requires_approval(true)
251            .with_rate_limit(5)
252            .with_description("Modify system configuration".to_string())
253            .with_permission("admin:config".to_string()),
254        ];
255        for tool in defaults {
256            registry.register(tool);
257        }
258        registry
259    }
260
261    /// Register a tool definition.
262    ///
263    /// If a tool with the same ID already exists, it is replaced.
264    pub fn register(&self, tool: ToolDefinition) {
265        let mut tools = self.tools.write().expect("tool registry lock poisoned");
266        tools.insert(tool.id.clone(), tool);
267    }
268
269    /// Unregister a tool by its ID.
270    ///
271    /// Returns `true` if the tool was present and removed.
272    pub fn unregister(&self, id: &str) -> bool {
273        let mut tools = self.tools.write().expect("tool registry lock poisoned");
274        tools.remove(id).is_some()
275    }
276
277    /// Get a cloned copy of a tool definition by ID.
278    pub fn get(&self, id: &str) -> Option<ToolDefinition> {
279        let tools = self.tools.read().expect("tool registry lock poisoned");
280        tools.get(id).cloned()
281    }
282
283    /// Check whether a tool is registered.
284    pub fn is_registered(&self, id: &str) -> bool {
285        let tools = self.tools.read().expect("tool registry lock poisoned");
286        tools.contains_key(id)
287    }
288
289    /// Look up all tools in a given category.
290    ///
291    /// Returns cloned definitions because the internal lock cannot be held
292    /// across the returned references.
293    pub fn lookup_by_category(&self, category: &ToolCategory) -> Vec<ToolDefinition> {
294        let tools = self.tools.read().expect("tool registry lock poisoned");
295        tools
296            .values()
297            .filter(|t| &t.category == category)
298            .cloned()
299            .collect()
300    }
301
302    /// Return the number of registered tools.
303    pub fn len(&self) -> usize {
304        let tools = self.tools.read().expect("tool registry lock poisoned");
305        tools.len()
306    }
307
308    /// Return whether the registry is empty.
309    pub fn is_empty(&self) -> bool {
310        self.len() == 0
311    }
312
313    /// Validate an agent action against the registry.
314    ///
315    /// Produces [`SecurityFinding`]s for:
316    /// - **Unregistered tool usage** (`"unregistered_tool"`, severity `High`)
317    /// - **High-risk tool** with `risk_score > 0.8` (`"high_risk_tool"`, severity `High`)
318    /// - **Tool requiring approval** (`"tool_requires_approval"`, severity `Info`)
319    pub fn validate_action(&self, action: &AgentAction) -> Vec<SecurityFinding> {
320        let mut findings = Vec::new();
321        let tool_name = &action.name;
322
323        let tools = self.tools.read().expect("tool registry lock poisoned");
324
325        match tools.get(tool_name) {
326            None => {
327                // Only flag tool calls and skill invocations as unregistered.
328                // Other action types (command, web, file) use their own name
329                // semantics and are validated by the regex analyzer.
330                if action.action_type == AgentActionType::ToolCall
331                    || action.action_type == AgentActionType::SkillInvocation
332                {
333                    findings.push(
334                        SecurityFinding::new(
335                            SecuritySeverity::High,
336                            "unregistered_tool".to_string(),
337                            format!("Unregistered tool used: {}", tool_name),
338                            0.9,
339                        )
340                        .with_location("agent_action.tool_call".to_string())
341                        .with_metadata("tool_name".to_string(), tool_name.clone()),
342                    );
343                }
344            }
345            Some(tool) => {
346                if tool.risk_score > 0.8 {
347                    findings.push(
348                        SecurityFinding::new(
349                            SecuritySeverity::High,
350                            "high_risk_tool".to_string(),
351                            format!(
352                                "High-risk tool used: {} (risk score: {:.2})",
353                                tool_name, tool.risk_score
354                            ),
355                            tool.risk_score,
356                        )
357                        .with_location("agent_action.tool_call".to_string())
358                        .with_metadata("tool_name".to_string(), tool_name.clone())
359                        .with_metadata("risk_score".to_string(), format!("{:.2}", tool.risk_score))
360                        .with_metadata("category".to_string(), tool.category.to_string()),
361                    );
362                }
363                if tool.requires_approval {
364                    findings.push(
365                        SecurityFinding::new(
366                            SecuritySeverity::Info,
367                            "tool_requires_approval".to_string(),
368                            format!("Tool requires user approval: {}", tool_name),
369                            1.0,
370                        )
371                        .with_location("agent_action.tool_call".to_string())
372                        .with_metadata("tool_name".to_string(), tool_name.clone())
373                        .with_alert_required(false),
374                    );
375                }
376            }
377        }
378
379        findings
380    }
381}
382
383impl Default for ToolRegistry {
384    fn default() -> Self {
385        Self::new()
386    }
387}
388
389impl fmt::Debug for ToolRegistry {
390    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
391        let tools = self.tools.read().expect("tool registry lock poisoned");
392        f.debug_struct("ToolRegistry")
393            .field("tool_count", &tools.len())
394            .field("tool_ids", &tools.keys().collect::<Vec<_>>())
395            .finish()
396    }
397}
398
399// ---------------------------------------------------------------------------
400// RateLimitExceeded
401// ---------------------------------------------------------------------------
402
/// Error returned when an action-type rate limit is exceeded.
///
/// Produced by [`ActionRateLimiter::check_rate_limit`] when an action type
/// has already used its full allowance within the current window.
#[derive(Debug, Clone)]
pub struct RateLimitExceeded {
    /// The action type that exceeded its rate limit.
    pub action_type: String,
    /// The configured limit (calls per window).
    pub limit: u32,
    /// Duration of the rate-limit window.
    pub window: Duration,
}

impl fmt::Display for RateLimitExceeded {
    /// Formats as
    /// `rate limit exceeded for action type '<type>': <limit> calls per <window>`,
    /// where `<window>` is the `Debug` rendering of the `Duration` (e.g. `60s`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let RateLimitExceeded {
            action_type,
            limit,
            window,
        } = self;
        write!(
            f,
            "rate limit exceeded for action type '{}': {} calls per {:?}",
            action_type, limit, window
        )
    }
}

// Marker impl: lets the error be boxed as `dyn Error` and propagated with `?`.
impl std::error::Error for RateLimitExceeded {}
425
426impl RateLimitExceeded {
427    /// Convert this rate limit violation into a [`SecurityFinding`].
428    pub fn to_security_finding(&self) -> SecurityFinding {
429        SecurityFinding::new(
430            SecuritySeverity::Medium,
431            "action_rate_limit_exceeded".to_string(),
432            format!(
433                "Rate limit exceeded for action type '{}': limit is {} calls per {:?}",
434                self.action_type, self.limit, self.window
435            ),
436            0.95,
437        )
438        .with_metadata("action_type".to_string(), self.action_type.clone())
439        .with_metadata("limit".to_string(), self.limit.to_string())
440        .with_metadata("window_secs".to_string(), self.window.as_secs().to_string())
441    }
442}
443
444// ---------------------------------------------------------------------------
445// ActionRateLimiter
446// ---------------------------------------------------------------------------
447
448/// Per-action-type sliding-window rate limiter.
449///
450/// Tracks timestamps of recent action invocations and enforces a maximum
451/// number of calls within a configurable time window. Each action type
452/// (e.g. `"tool_call"`, `"web_access"`) has its own independent window.
453///
454/// # Thread Safety
455///
456/// All methods acquire the internal `RwLock` and are safe to call from
457/// multiple threads or async tasks.
458pub struct ActionRateLimiter {
459    /// Per-action-type sliding windows: `action_type -> VecDeque<Instant>`.
460    windows: RwLock<HashMap<String, VecDeque<Instant>>>,
461    /// Default rate limit (calls per window) for action types without an override.
462    default_limit: u32,
463    /// Per-action-type limit overrides.
464    limits: HashMap<String, u32>,
465    /// Duration of the sliding window.
466    window: Duration,
467}
468
469impl ActionRateLimiter {
470    /// Create a new rate limiter with the given default limit and window duration.
471    pub fn new(default_limit: u32, window: Duration) -> Self {
472        Self {
473            windows: RwLock::new(HashMap::new()),
474            default_limit,
475            limits: HashMap::new(),
476            window,
477        }
478    }
479
480    /// Create a rate limiter with per-action-type overrides.
481    pub fn with_limits(default_limit: u32, window: Duration, limits: HashMap<String, u32>) -> Self {
482        Self {
483            windows: RwLock::new(HashMap::new()),
484            default_limit,
485            limits,
486            window,
487        }
488    }
489
490    /// Get the effective limit for an action type.
491    fn effective_limit(&self, action_type: &str) -> u32 {
492        self.limits
493            .get(action_type)
494            .copied()
495            .unwrap_or(self.default_limit)
496    }
497
498    /// Prune expired entries from a deque, keeping only timestamps within
499    /// the current window relative to `now`.
500    fn prune(deque: &mut VecDeque<Instant>, cutoff: Instant) {
501        while let Some(&front) = deque.front() {
502            if front < cutoff {
503                deque.pop_front();
504            } else {
505                break;
506            }
507        }
508    }
509
510    /// Check the rate limit for an action type and record the action if allowed.
511    ///
512    /// Returns `Ok(())` if the action is within the limit, or
513    /// `Err(RateLimitExceeded)` if the limit has been reached.
514    pub fn check_rate_limit(
515        &self,
516        action_type: &str,
517    ) -> std::result::Result<(), RateLimitExceeded> {
518        let limit = self.effective_limit(action_type);
519        let now = Instant::now();
520        let cutoff = now - self.window;
521
522        let mut windows = self.windows.write().expect("rate limiter lock poisoned");
523        let deque = windows.entry(action_type.to_string()).or_default();
524        Self::prune(deque, cutoff);
525
526        if deque.len() >= limit as usize {
527            Err(RateLimitExceeded {
528                action_type: action_type.to_string(),
529                limit,
530                window: self.window,
531            })
532        } else {
533            deque.push_back(now);
534            Ok(())
535        }
536    }
537
538    /// Record an action without checking the rate limit.
539    ///
540    /// Useful for tracking actions that have already been validated by
541    /// other means.
542    pub fn record_action(&self, action_type: &str) {
543        let now = Instant::now();
544        let cutoff = now - self.window;
545
546        let mut windows = self.windows.write().expect("rate limiter lock poisoned");
547        let deque = windows.entry(action_type.to_string()).or_default();
548        Self::prune(deque, cutoff);
549        deque.push_back(now);
550    }
551
552    /// Return the number of remaining calls allowed for an action type
553    /// within the current window.
554    pub fn remaining(&self, action_type: &str) -> u32 {
555        let limit = self.effective_limit(action_type);
556        let now = Instant::now();
557        let cutoff = now - self.window;
558
559        let mut windows = self.windows.write().expect("rate limiter lock poisoned");
560        let deque = windows.entry(action_type.to_string()).or_default();
561        Self::prune(deque, cutoff);
562
563        limit.saturating_sub(deque.len() as u32)
564    }
565
566    /// Reset the sliding window for a specific action type.
567    pub fn reset(&self, action_type: &str) {
568        let mut windows = self.windows.write().expect("rate limiter lock poisoned");
569        windows.remove(action_type);
570    }
571
572    /// Return the configured window duration.
573    pub fn window_duration(&self) -> Duration {
574        self.window
575    }
576
577    /// Return the default rate limit.
578    pub fn default_limit(&self) -> u32 {
579        self.default_limit
580    }
581}
582
583impl fmt::Debug for ActionRateLimiter {
584    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
585        f.debug_struct("ActionRateLimiter")
586            .field("default_limit", &self.default_limit)
587            .field("window", &self.window)
588            .field("overrides", &self.limits)
589            .finish()
590    }
591}
592
593// ===========================================================================
594// Tests
595// ===========================================================================
596
597#[cfg(test)]
598mod tests {
599    use super::*;
600    use llmtrace_core::{AgentAction, AgentActionType};
601    use std::thread;
602    use std::time::Duration;
603
604    // ---------------------------------------------------------------
605    // ToolCategory
606    // ---------------------------------------------------------------
607
608    #[test]
609    fn test_tool_category_display() {
610        assert_eq!(ToolCategory::DataRetrieval.to_string(), "data_retrieval");
611        assert_eq!(ToolCategory::WebAccess.to_string(), "web_access");
612        assert_eq!(ToolCategory::FileSystem.to_string(), "file_system");
613        assert_eq!(ToolCategory::Database.to_string(), "database");
614        assert_eq!(ToolCategory::CodeExecution.to_string(), "code_execution");
615        assert_eq!(ToolCategory::Communication.to_string(), "communication");
616        assert_eq!(ToolCategory::SystemAdmin.to_string(), "system_admin");
617        assert_eq!(
618            ToolCategory::Custom("my_cat".to_string()).to_string(),
619            "custom:my_cat"
620        );
621    }
622
623    #[test]
624    fn test_tool_category_equality() {
625        assert_eq!(ToolCategory::WebAccess, ToolCategory::WebAccess);
626        assert_ne!(ToolCategory::WebAccess, ToolCategory::FileSystem);
627        assert_eq!(
628            ToolCategory::Custom("x".to_string()),
629            ToolCategory::Custom("x".to_string())
630        );
631        assert_ne!(
632            ToolCategory::Custom("x".to_string()),
633            ToolCategory::Custom("y".to_string())
634        );
635    }
636
637    // ---------------------------------------------------------------
638    // ToolDefinition
639    // ---------------------------------------------------------------
640
641    #[test]
642    fn test_tool_definition_new_defaults() {
643        let tool = ToolDefinition::new("test", "Test", ToolCategory::DataRetrieval);
644        assert_eq!(tool.id, "test");
645        assert_eq!(tool.name, "Test");
646        assert_eq!(tool.category, ToolCategory::DataRetrieval);
647        assert!((tool.risk_score - 0.5).abs() < f64::EPSILON);
648        assert!(!tool.requires_approval);
649        assert!(tool.rate_limit.is_none());
650        assert!(tool.required_permissions.is_empty());
651        assert!(tool.description.is_empty());
652    }
653
654    #[test]
655    fn test_tool_definition_builder() {
656        let tool = ToolDefinition::new("exec", "Execute", ToolCategory::CodeExecution)
657            .with_risk_score(0.95)
658            .with_requires_approval(true)
659            .with_rate_limit(10)
660            .with_permission("exec:shell".to_string())
661            .with_permission("exec:code".to_string())
662            .with_description("Run shell commands".to_string());
663
664        assert_eq!(tool.id, "exec");
665        assert!((tool.risk_score - 0.95).abs() < f64::EPSILON);
666        assert!(tool.requires_approval);
667        assert_eq!(tool.rate_limit, Some(10));
668        assert_eq!(tool.required_permissions.len(), 2);
669        assert_eq!(tool.description, "Run shell commands");
670    }
671
672    #[test]
673    fn test_tool_definition_risk_score_clamped() {
674        let tool_high =
675            ToolDefinition::new("t", "T", ToolCategory::DataRetrieval).with_risk_score(1.5);
676        assert!((tool_high.risk_score - 1.0).abs() < f64::EPSILON);
677
678        let tool_low =
679            ToolDefinition::new("t", "T", ToolCategory::DataRetrieval).with_risk_score(-0.5);
680        assert!(tool_low.risk_score.abs() < f64::EPSILON);
681    }
682
683    // ---------------------------------------------------------------
684    // ToolRegistry — basic operations
685    // ---------------------------------------------------------------
686
687    #[test]
688    fn test_registry_new_is_empty() {
689        let reg = ToolRegistry::new();
690        assert!(reg.is_empty());
691        assert_eq!(reg.len(), 0);
692    }
693
694    #[test]
695    fn test_registry_default_is_empty() {
696        let reg = ToolRegistry::default();
697        assert!(reg.is_empty());
698    }
699
700    #[test]
701    fn test_registry_register_and_get() {
702        let reg = ToolRegistry::new();
703        let tool =
704            ToolDefinition::new("search", "Search", ToolCategory::WebAccess).with_risk_score(0.3);
705        reg.register(tool);
706
707        assert!(reg.is_registered("search"));
708        assert!(!reg.is_registered("unknown"));
709        assert_eq!(reg.len(), 1);
710
711        let got = reg.get("search").unwrap();
712        assert_eq!(got.id, "search");
713        assert!((got.risk_score - 0.3).abs() < f64::EPSILON);
714    }
715
716    #[test]
717    fn test_registry_get_nonexistent() {
718        let reg = ToolRegistry::new();
719        assert!(reg.get("nope").is_none());
720    }
721
722    #[test]
723    fn test_registry_register_overwrites() {
724        let reg = ToolRegistry::new();
725        let tool_v1 =
726            ToolDefinition::new("t", "V1", ToolCategory::DataRetrieval).with_risk_score(0.1);
727        reg.register(tool_v1);
728        assert_eq!(reg.get("t").unwrap().name, "V1");
729
730        let tool_v2 = ToolDefinition::new("t", "V2", ToolCategory::Database).with_risk_score(0.9);
731        reg.register(tool_v2);
732        assert_eq!(reg.get("t").unwrap().name, "V2");
733        assert_eq!(reg.len(), 1);
734    }
735
736    #[test]
737    fn test_registry_unregister() {
738        let reg = ToolRegistry::new();
739        reg.register(ToolDefinition::new("a", "A", ToolCategory::DataRetrieval));
740        reg.register(ToolDefinition::new("b", "B", ToolCategory::DataRetrieval));
741        assert_eq!(reg.len(), 2);
742
743        assert!(reg.unregister("a"));
744        assert!(!reg.is_registered("a"));
745        assert!(reg.is_registered("b"));
746        assert_eq!(reg.len(), 1);
747
748        assert!(!reg.unregister("nonexistent"));
749    }
750
751    #[test]
752    fn test_registry_lookup_by_category() {
753        let reg = ToolRegistry::new();
754        reg.register(
755            ToolDefinition::new("web1", "Web 1", ToolCategory::WebAccess).with_risk_score(0.3),
756        );
757        reg.register(
758            ToolDefinition::new("web2", "Web 2", ToolCategory::WebAccess).with_risk_score(0.4),
759        );
760        reg.register(
761            ToolDefinition::new("file1", "File 1", ToolCategory::FileSystem).with_risk_score(0.5),
762        );
763
764        let web_tools = reg.lookup_by_category(&ToolCategory::WebAccess);
765        assert_eq!(web_tools.len(), 2);
766        assert!(web_tools
767            .iter()
768            .all(|t| t.category == ToolCategory::WebAccess));
769
770        let fs_tools = reg.lookup_by_category(&ToolCategory::FileSystem);
771        assert_eq!(fs_tools.len(), 1);
772
773        let db_tools = reg.lookup_by_category(&ToolCategory::Database);
774        assert!(db_tools.is_empty());
775    }
776
777    // ---------------------------------------------------------------
778    // ToolRegistry — with_defaults
779    // ---------------------------------------------------------------
780
781    #[test]
782    fn test_registry_with_defaults_populated() {
783        let reg = ToolRegistry::with_defaults();
784        assert!(!reg.is_empty());
785        assert!(reg.is_registered("web_search"));
786        assert!(reg.is_registered("file_read"));
787        assert!(reg.is_registered("file_write"));
788        assert!(reg.is_registered("shell_exec"));
789        assert!(reg.is_registered("send_email"));
790        assert!(reg.is_registered("database_query"));
791        assert!(reg.is_registered("system_config"));
792    }
793
794    #[test]
795    fn test_registry_defaults_risk_scores() {
796        let reg = ToolRegistry::with_defaults();
797        let search = reg.get("web_search").unwrap();
798        assert!(search.risk_score <= 0.5, "web_search should be low risk");
799
800        let shell = reg.get("shell_exec").unwrap();
801        assert!(shell.risk_score > 0.8, "shell_exec should be high risk");
802        assert!(
803            shell.requires_approval,
804            "shell_exec should require approval"
805        );
806    }
807
808    #[test]
809    fn test_registry_defaults_categories() {
810        let reg = ToolRegistry::with_defaults();
811        assert_eq!(
812            reg.get("web_search").unwrap().category,
813            ToolCategory::WebAccess
814        );
815        assert_eq!(
816            reg.get("file_read").unwrap().category,
817            ToolCategory::FileSystem
818        );
819        assert_eq!(
820            reg.get("shell_exec").unwrap().category,
821            ToolCategory::CodeExecution
822        );
823        assert_eq!(
824            reg.get("send_email").unwrap().category,
825            ToolCategory::Communication
826        );
827        assert_eq!(
828            reg.get("database_query").unwrap().category,
829            ToolCategory::Database
830        );
831        assert_eq!(
832            reg.get("data_lookup").unwrap().category,
833            ToolCategory::DataRetrieval
834        );
835        assert_eq!(
836            reg.get("system_config").unwrap().category,
837            ToolCategory::SystemAdmin
838        );
839    }
840
841    // ---------------------------------------------------------------
842    // ToolRegistry — validate_action
843    // ---------------------------------------------------------------
844
845    #[test]
846    fn test_validate_unregistered_tool_call() {
847        let reg = ToolRegistry::new();
848        let action = AgentAction::new(AgentActionType::ToolCall, "unknown_tool".to_string());
849        let findings = reg.validate_action(&action);
850        assert_eq!(findings.len(), 1);
851        assert_eq!(findings[0].finding_type, "unregistered_tool");
852        assert_eq!(findings[0].severity, SecuritySeverity::High);
853    }
854
855    #[test]
856    fn test_validate_unregistered_skill_invocation() {
857        let reg = ToolRegistry::new();
858        let action = AgentAction::new(
859            AgentActionType::SkillInvocation,
860            "unknown_skill".to_string(),
861        );
862        let findings = reg.validate_action(&action);
863        assert_eq!(findings.len(), 1);
864        assert_eq!(findings[0].finding_type, "unregistered_tool");
865    }
866
867    #[test]
868    fn test_validate_unregistered_command_not_flagged() {
869        let reg = ToolRegistry::new();
870        let action = AgentAction::new(AgentActionType::CommandExecution, "ls -la".to_string());
871        let findings = reg.validate_action(&action);
872        assert!(
873            findings.is_empty(),
874            "CommandExecution should not trigger unregistered_tool"
875        );
876    }
877
878    #[test]
879    fn test_validate_unregistered_web_access_not_flagged() {
880        let reg = ToolRegistry::new();
881        let action = AgentAction::new(
882            AgentActionType::WebAccess,
883            "https://example.com".to_string(),
884        );
885        let findings = reg.validate_action(&action);
886        assert!(findings.is_empty());
887    }
888
889    #[test]
890    fn test_validate_unregistered_file_access_not_flagged() {
891        let reg = ToolRegistry::new();
892        let action = AgentAction::new(AgentActionType::FileAccess, "/tmp/file.txt".to_string());
893        let findings = reg.validate_action(&action);
894        assert!(findings.is_empty());
895    }
896
897    #[test]
898    fn test_validate_registered_low_risk_tool() {
899        let reg = ToolRegistry::new();
900        reg.register(
901            ToolDefinition::new("safe_tool", "Safe", ToolCategory::DataRetrieval)
902                .with_risk_score(0.1),
903        );
904        let action = AgentAction::new(AgentActionType::ToolCall, "safe_tool".to_string());
905        let findings = reg.validate_action(&action);
906        assert!(
907            findings.is_empty(),
908            "Low-risk registered tool should produce no findings"
909        );
910    }
911
912    #[test]
913    fn test_validate_high_risk_tool() {
914        let reg = ToolRegistry::new();
915        reg.register(
916            ToolDefinition::new("danger", "Danger", ToolCategory::CodeExecution)
917                .with_risk_score(0.85),
918        );
919        let action = AgentAction::new(AgentActionType::ToolCall, "danger".to_string());
920        let findings = reg.validate_action(&action);
921        assert!(findings.iter().any(|f| f.finding_type == "high_risk_tool"));
922        assert!(findings
923            .iter()
924            .any(|f| f.severity == SecuritySeverity::High));
925    }
926
927    #[test]
928    fn test_validate_tool_at_boundary_risk_score() {
929        let reg = ToolRegistry::new();
930        reg.register(
931            ToolDefinition::new("border", "Border", ToolCategory::FileSystem).with_risk_score(0.8),
932        );
933        let action = AgentAction::new(AgentActionType::ToolCall, "border".to_string());
934        let findings = reg.validate_action(&action);
935        assert!(
936            !findings.iter().any(|f| f.finding_type == "high_risk_tool"),
937            "risk_score == 0.8 is NOT > 0.8, should not trigger"
938        );
939    }
940
941    #[test]
942    fn test_validate_tool_requires_approval() {
943        let reg = ToolRegistry::new();
944        reg.register(
945            ToolDefinition::new("approve_me", "Approve", ToolCategory::Communication)
946                .with_risk_score(0.5)
947                .with_requires_approval(true),
948        );
949        let action = AgentAction::new(AgentActionType::ToolCall, "approve_me".to_string());
950        let findings = reg.validate_action(&action);
951        assert!(findings
952            .iter()
953            .any(|f| f.finding_type == "tool_requires_approval"));
954        let approval_finding = findings
955            .iter()
956            .find(|f| f.finding_type == "tool_requires_approval")
957            .unwrap();
958        assert_eq!(approval_finding.severity, SecuritySeverity::Info);
959        assert!(!approval_finding.requires_alert);
960    }
961
962    #[test]
963    fn test_validate_high_risk_and_requires_approval() {
964        let reg = ToolRegistry::new();
965        reg.register(
966            ToolDefinition::new("risky_approval", "Risky", ToolCategory::SystemAdmin)
967                .with_risk_score(0.95)
968                .with_requires_approval(true),
969        );
970        let action = AgentAction::new(AgentActionType::ToolCall, "risky_approval".to_string());
971        let findings = reg.validate_action(&action);
972        assert!(findings.iter().any(|f| f.finding_type == "high_risk_tool"));
973        assert!(findings
974            .iter()
975            .any(|f| f.finding_type == "tool_requires_approval"));
976        assert_eq!(findings.len(), 2);
977    }
978
979    #[test]
980    fn test_validate_with_defaults_known_tool() {
981        let reg = ToolRegistry::with_defaults();
982        let action = AgentAction::new(AgentActionType::ToolCall, "web_search".to_string());
983        let findings = reg.validate_action(&action);
984        // web_search is low risk, no approval needed
985        assert!(
986            !findings
987                .iter()
988                .any(|f| f.finding_type == "unregistered_tool"),
989            "web_search is registered in defaults"
990        );
991        assert!(
992            !findings.iter().any(|f| f.finding_type == "high_risk_tool"),
993            "web_search is low risk"
994        );
995    }
996
997    #[test]
998    fn test_validate_with_defaults_shell_exec() {
999        let reg = ToolRegistry::with_defaults();
1000        let action = AgentAction::new(AgentActionType::ToolCall, "shell_exec".to_string());
1001        let findings = reg.validate_action(&action);
1002        assert!(
1003            findings.iter().any(|f| f.finding_type == "high_risk_tool"),
1004            "shell_exec has risk > 0.8"
1005        );
1006        assert!(findings
1007            .iter()
1008            .any(|f| f.finding_type == "tool_requires_approval"));
1009    }
1010
1011    // ---------------------------------------------------------------
1012    // ToolRegistry — thread safety
1013    // ---------------------------------------------------------------
1014
1015    #[test]
1016    fn test_registry_concurrent_access() {
1017        let reg = std::sync::Arc::new(ToolRegistry::new());
1018        let mut handles = Vec::new();
1019
1020        for i in 0..10 {
1021            let reg_clone = reg.clone();
1022            handles.push(thread::spawn(move || {
1023                let id = format!("tool_{}", i);
1024                reg_clone.register(ToolDefinition::new(&id, &id, ToolCategory::DataRetrieval));
1025                assert!(reg_clone.is_registered(&id));
1026            }));
1027        }
1028
1029        for handle in handles {
1030            handle.join().unwrap();
1031        }
1032
1033        assert_eq!(reg.len(), 10);
1034    }
1035
1036    // ---------------------------------------------------------------
1037    // RateLimitExceeded
1038    // ---------------------------------------------------------------
1039
1040    #[test]
1041    fn test_rate_limit_exceeded_display() {
1042        let err = RateLimitExceeded {
1043            action_type: "tool_call".to_string(),
1044            limit: 10,
1045            window: Duration::from_secs(60),
1046        };
1047        let msg = err.to_string();
1048        assert!(msg.contains("tool_call"));
1049        assert!(msg.contains("10"));
1050    }
1051
1052    #[test]
1053    fn test_rate_limit_exceeded_to_security_finding() {
1054        let err = RateLimitExceeded {
1055            action_type: "web_access".to_string(),
1056            limit: 5,
1057            window: Duration::from_secs(60),
1058        };
1059        let finding = err.to_security_finding();
1060        assert_eq!(finding.finding_type, "action_rate_limit_exceeded");
1061        assert_eq!(finding.severity, SecuritySeverity::Medium);
1062        assert_eq!(
1063            finding.metadata.get("action_type"),
1064            Some(&"web_access".to_string())
1065        );
1066        assert_eq!(finding.metadata.get("limit"), Some(&"5".to_string()));
1067    }
1068
1069    // ---------------------------------------------------------------
1070    // ActionRateLimiter — basic
1071    // ---------------------------------------------------------------
1072
1073    #[test]
1074    fn test_rate_limiter_allows_within_limit() {
1075        let limiter = ActionRateLimiter::new(5, Duration::from_secs(60));
1076        for _ in 0..5 {
1077            assert!(limiter.check_rate_limit("tool_call").is_ok());
1078        }
1079    }
1080
1081    #[test]
1082    fn test_rate_limiter_denies_over_limit() {
1083        let limiter = ActionRateLimiter::new(3, Duration::from_secs(60));
1084        for _ in 0..3 {
1085            assert!(limiter.check_rate_limit("tool_call").is_ok());
1086        }
1087        let result = limiter.check_rate_limit("tool_call");
1088        assert!(result.is_err());
1089        let err = result.unwrap_err();
1090        assert_eq!(err.action_type, "tool_call");
1091        assert_eq!(err.limit, 3);
1092    }
1093
1094    #[test]
1095    fn test_rate_limiter_independent_action_types() {
1096        let limiter = ActionRateLimiter::new(2, Duration::from_secs(60));
1097        assert!(limiter.check_rate_limit("type_a").is_ok());
1098        assert!(limiter.check_rate_limit("type_a").is_ok());
1099        assert!(limiter.check_rate_limit("type_a").is_err());
1100
1101        // type_b should be independent
1102        assert!(limiter.check_rate_limit("type_b").is_ok());
1103        assert!(limiter.check_rate_limit("type_b").is_ok());
1104        assert!(limiter.check_rate_limit("type_b").is_err());
1105    }
1106
1107    #[test]
1108    fn test_rate_limiter_remaining() {
1109        let limiter = ActionRateLimiter::new(5, Duration::from_secs(60));
1110        assert_eq!(limiter.remaining("test"), 5);
1111
1112        limiter.check_rate_limit("test").unwrap();
1113        assert_eq!(limiter.remaining("test"), 4);
1114
1115        limiter.check_rate_limit("test").unwrap();
1116        limiter.check_rate_limit("test").unwrap();
1117        assert_eq!(limiter.remaining("test"), 2);
1118    }
1119
1120    #[test]
1121    fn test_rate_limiter_reset() {
1122        let limiter = ActionRateLimiter::new(3, Duration::from_secs(60));
1123        for _ in 0..3 {
1124            limiter.check_rate_limit("test").unwrap();
1125        }
1126        assert!(limiter.check_rate_limit("test").is_err());
1127        assert_eq!(limiter.remaining("test"), 0);
1128
1129        limiter.reset("test");
1130        assert_eq!(limiter.remaining("test"), 3);
1131        assert!(limiter.check_rate_limit("test").is_ok());
1132    }
1133
1134    #[test]
1135    fn test_rate_limiter_record_action() {
1136        let limiter = ActionRateLimiter::new(3, Duration::from_secs(60));
1137        limiter.record_action("test");
1138        limiter.record_action("test");
1139        assert_eq!(limiter.remaining("test"), 1);
1140
1141        limiter.record_action("test");
1142        // Now at limit; check should fail
1143        assert!(limiter.check_rate_limit("test").is_err());
1144    }
1145
1146    #[test]
1147    fn test_rate_limiter_with_overrides() {
1148        let mut limits = HashMap::new();
1149        limits.insert("strict".to_string(), 1);
1150        limits.insert("relaxed".to_string(), 100);
1151
1152        let limiter = ActionRateLimiter::with_limits(10, Duration::from_secs(60), limits);
1153
1154        assert!(limiter.check_rate_limit("strict").is_ok());
1155        assert!(limiter.check_rate_limit("strict").is_err());
1156
1157        assert_eq!(limiter.remaining("relaxed"), 100);
1158
1159        // default (no override)
1160        assert_eq!(limiter.remaining("other"), 10);
1161    }
1162
1163    #[test]
1164    fn test_rate_limiter_sliding_window_expiry() {
1165        // Use a very short window to test expiry
1166        let limiter = ActionRateLimiter::new(2, Duration::from_millis(50));
1167        assert!(limiter.check_rate_limit("test").is_ok());
1168        assert!(limiter.check_rate_limit("test").is_ok());
1169        assert!(limiter.check_rate_limit("test").is_err());
1170
1171        // Wait for the window to expire
1172        thread::sleep(Duration::from_millis(60));
1173
1174        // Should be allowed again
1175        assert!(limiter.check_rate_limit("test").is_ok());
1176    }
1177
1178    #[test]
1179    fn test_rate_limiter_accessors() {
1180        let limiter = ActionRateLimiter::new(42, Duration::from_secs(120));
1181        assert_eq!(limiter.default_limit(), 42);
1182        assert_eq!(limiter.window_duration(), Duration::from_secs(120));
1183    }
1184
1185    #[test]
1186    fn test_rate_limiter_debug() {
1187        let limiter = ActionRateLimiter::new(10, Duration::from_secs(60));
1188        let debug_str = format!("{:?}", limiter);
1189        assert!(debug_str.contains("ActionRateLimiter"));
1190        assert!(debug_str.contains("10"));
1191    }
1192
1193    // ---------------------------------------------------------------
1194    // ActionRateLimiter — thread safety
1195    // ---------------------------------------------------------------
1196
1197    #[test]
1198    fn test_rate_limiter_concurrent_access() {
1199        let limiter = std::sync::Arc::new(ActionRateLimiter::new(100, Duration::from_secs(60)));
1200        let mut handles = Vec::new();
1201
1202        for _ in 0..10 {
1203            let limiter_clone = limiter.clone();
1204            handles.push(thread::spawn(move || {
1205                for _ in 0..10 {
1206                    let _ = limiter_clone.check_rate_limit("concurrent");
1207                }
1208            }));
1209        }
1210
1211        for handle in handles {
1212            handle.join().unwrap();
1213        }
1214
1215        // All 100 calls should have been recorded
1216        assert_eq!(limiter.remaining("concurrent"), 0);
1217    }
1218
1219    // ---------------------------------------------------------------
    // Debug output and integration: registry + rate limiter
1221    // ---------------------------------------------------------------
1222
1223    #[test]
1224    fn test_registry_debug() {
1225        let reg = ToolRegistry::new();
1226        reg.register(ToolDefinition::new("a", "A", ToolCategory::DataRetrieval));
1227        let debug_str = format!("{:?}", reg);
1228        assert!(debug_str.contains("ToolRegistry"));
1229        assert!(debug_str.contains("tool_count"));
1230    }
1231
1232    #[test]
1233    fn test_end_to_end_validation_and_rate_limit() {
1234        let registry = ToolRegistry::with_defaults();
1235        let limiter = ActionRateLimiter::new(2, Duration::from_secs(60));
1236
1237        // First call to shell_exec — high risk + requires approval
1238        let action = AgentAction::new(AgentActionType::ToolCall, "shell_exec".to_string());
1239        let mut all_findings = registry.validate_action(&action);
1240
1241        // Check rate limit
1242        match limiter.check_rate_limit(&action.action_type.to_string()) {
1243            Ok(()) => {}
1244            Err(err) => all_findings.push(err.to_security_finding()),
1245        }
1246
1247        assert!(all_findings
1248            .iter()
1249            .any(|f| f.finding_type == "high_risk_tool"));
1250        assert!(all_findings
1251            .iter()
1252            .any(|f| f.finding_type == "tool_requires_approval"));
1253        assert!(
1254            !all_findings
1255                .iter()
1256                .any(|f| f.finding_type == "action_rate_limit_exceeded"),
1257            "First call should not be rate limited"
1258        );
1259
1260        // Second call
1261        let _ = limiter.check_rate_limit(&action.action_type.to_string());
1262
1263        // Third call — should be rate limited
1264        let result = limiter.check_rate_limit(&action.action_type.to_string());
1265        assert!(result.is_err());
1266        let rate_finding = result.unwrap_err().to_security_finding();
1267        assert_eq!(rate_finding.finding_type, "action_rate_limit_exceeded");
1268        assert_eq!(rate_finding.severity, SecuritySeverity::Medium);
1269    }
1270}