Skip to main content

ai_agents_tools/security/
config.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4#[derive(Debug, Clone, Serialize, Deserialize)]
5pub struct ToolSecurityConfig {
6    #[serde(default)]
7    pub enabled: bool,
8    #[serde(default = "default_tool_timeout")]
9    pub default_timeout_ms: u64,
10    #[serde(default)]
11    pub tools: HashMap<String, ToolPolicyConfig>,
12}
13
14impl Default for ToolSecurityConfig {
15    fn default() -> Self {
16        Self {
17            enabled: false,
18            default_timeout_ms: default_tool_timeout(),
19            tools: HashMap::new(),
20        }
21    }
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct ToolPolicyConfig {
26    #[serde(default = "default_true")]
27    pub enabled: bool,
28    #[serde(default)]
29    pub require_confirmation: bool,
30    #[serde(default)]
31    pub confirmation_message: Option<String>,
32    #[serde(default)]
33    pub rate_limit: Option<u32>,
34    #[serde(default)]
35    pub timeout_ms: Option<u64>,
36    #[serde(default)]
37    pub allowed_domains: Vec<String>,
38    #[serde(default)]
39    pub blocked_domains: Vec<String>,
40    #[serde(default)]
41    pub allowed_paths: Vec<String>,
42}
43
44impl Default for ToolPolicyConfig {
45    fn default() -> Self {
46        Self {
47            enabled: true,
48            require_confirmation: false,
49            confirmation_message: None,
50            rate_limit: None,
51            timeout_ms: None,
52            allowed_domains: Vec::new(),
53            blocked_domains: Vec::new(),
54            allowed_paths: Vec::new(),
55        }
56    }
57}
58
/// Outcome of a security check.
///
/// `Warn` counts as allowed by [`SecurityCheckResult::is_allowed`].
/// `RequireConfirmation` is deliberately neither allowed nor blocked, so
/// callers must handle it explicitly; use
/// [`SecurityCheckResult::requires_confirmation`] to detect it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SecurityCheckResult {
    /// The checked operation may proceed.
    Allow,
    /// The checked operation must not proceed; `reason` explains why.
    Block { reason: String },
    /// The checked operation may proceed; `message` carries a warning
    /// (presumably for display or logging).
    Warn { message: String },
    /// The checked operation needs confirmation first; `message` is the
    /// confirmation text.
    RequireConfirmation { message: String },
}

impl SecurityCheckResult {
    /// Returns `true` for `Allow` and `Warn`.
    ///
    /// `RequireConfirmation` yields `false` here (it is not allowed until
    /// confirmed) and also `false` from [`Self::is_blocked`].
    #[must_use]
    pub fn is_allowed(&self) -> bool {
        matches!(self, Self::Allow | Self::Warn { .. })
    }

    /// Returns `true` only for `Block`.
    #[must_use]
    pub fn is_blocked(&self) -> bool {
        matches!(self, Self::Block { .. })
    }

    /// Returns `true` only for `RequireConfirmation`, the one variant for
    /// which both [`Self::is_allowed`] and [`Self::is_blocked`] are `false`.
    #[must_use]
    pub fn requires_confirmation(&self) -> bool {
        matches!(self, Self::RequireConfirmation { .. })
    }
}
79
/// Default for `ToolSecurityConfig::default_timeout_ms`, used both by the
/// serde field default and by `ToolSecurityConfig::default()`.
///
/// The field is in milliseconds, so this is 30 seconds.
fn default_tool_timeout() -> u64 {
    const THIRTY_SECONDS_IN_MS: u64 = 30_000;
    THIRTY_SECONDS_IN_MS
}
83
/// Serde default helper for `ToolPolicyConfig::enabled`.
///
/// `#[serde(default = "...")]` takes a function path rather than a literal
/// value, which is why this trivial function exists.
fn default_true() -> bool {
    true
}
87
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = ToolSecurityConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.default_timeout_ms, 30000);
        assert!(config.tools.is_empty());
    }

    /// An empty document must produce the same values as the `Default`
    /// impls; this guards against the serde field defaults drifting from them.
    #[test]
    fn test_empty_yaml_matches_default() {
        let config: ToolSecurityConfig =
            serde_yaml::from_str("{}").expect("empty mapping should parse");
        let expected = ToolSecurityConfig::default();
        assert_eq!(config.enabled, expected.enabled);
        assert_eq!(config.default_timeout_ms, expected.default_timeout_ms);
        assert!(config.tools.is_empty());

        let policy: ToolPolicyConfig =
            serde_yaml::from_str("{}").expect("empty mapping should parse");
        let expected = ToolPolicyConfig::default();
        assert_eq!(policy.enabled, expected.enabled);
        assert_eq!(policy.require_confirmation, expected.require_confirmation);
        assert_eq!(policy.confirmation_message, expected.confirmation_message);
        assert_eq!(policy.rate_limit, expected.rate_limit);
        assert_eq!(policy.timeout_ms, expected.timeout_ms);
        assert_eq!(policy.allowed_domains, expected.allowed_domains);
        assert_eq!(policy.blocked_domains, expected.blocked_domains);
        assert_eq!(policy.allowed_paths, expected.allowed_paths);
    }

    #[test]
    fn test_yaml_parsing() {
        let yaml = r#"
enabled: true
default_timeout_ms: 10000
tools:
  http:
    rate_limit: 10
    blocked_domains:
      - evil.com
    allowed_domains:
      - api.example.com
  file_write:
    require_confirmation: true
    confirmation_message: "Are you sure you want to write this file?"
    allowed_paths:
      - /tmp/
"#;
        let config: ToolSecurityConfig =
            serde_yaml::from_str(yaml).expect("sample config should parse");
        assert!(config.enabled);
        assert_eq!(config.default_timeout_ms, 10000);
        assert!(config.tools.contains_key("http"));
        assert!(config.tools.contains_key("file_write"));

        let http = config.tools.get("http").expect("http policy present");
        // `enabled` is omitted for this tool, so the serde default applies.
        assert!(http.enabled, "omitted `enabled` should default to true");
        assert!(!http.require_confirmation);
        assert_eq!(http.rate_limit, Some(10));
        assert_eq!(http.timeout_ms, None);
        assert_eq!(http.blocked_domains, vec!["evil.com"]);
        assert_eq!(http.allowed_domains, vec!["api.example.com"]);
        assert!(http.allowed_paths.is_empty());

        let file_write = config
            .tools
            .get("file_write")
            .expect("file_write policy present");
        assert!(file_write.enabled, "omitted `enabled` should default to true");
        assert!(file_write.require_confirmation);
        assert_eq!(
            file_write.confirmation_message.as_deref(),
            Some("Are you sure you want to write this file?")
        );
        assert_eq!(file_write.allowed_paths, vec!["/tmp/"]);
        assert!(file_write.rate_limit.is_none());
    }

    #[test]
    fn test_security_check_result() {
        let allow = SecurityCheckResult::Allow;
        assert!(allow.is_allowed());
        assert!(!allow.is_blocked());

        let block = SecurityCheckResult::Block {
            reason: "test".into(),
        };
        assert!(!block.is_allowed());
        assert!(block.is_blocked());

        let warn = SecurityCheckResult::Warn {
            message: "warning".into(),
        };
        assert!(warn.is_allowed());
        assert!(!warn.is_blocked());

        // Confirmation-required results are neither allowed nor blocked;
        // callers have to handle them explicitly.
        let confirm = SecurityCheckResult::RequireConfirmation {
            message: "confirm?".into(),
        };
        assert!(!confirm.is_allowed());
        assert!(!confirm.is_blocked());
    }

    #[test]
    fn test_tool_policy_defaults() {
        let policy = ToolPolicyConfig::default();
        assert!(policy.enabled);
        assert!(!policy.require_confirmation);
        assert!(policy.confirmation_message.is_none());
        assert!(policy.rate_limit.is_none());
        assert!(policy.timeout_ms.is_none());
        assert!(policy.allowed_domains.is_empty());
        assert!(policy.blocked_domains.is_empty());
        assert!(policy.allowed_paths.is_empty());
    }
}