ai_agents_tools/security/
config.rs1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4#[derive(Debug, Clone, Serialize, Deserialize)]
5pub struct ToolSecurityConfig {
6 #[serde(default)]
7 pub enabled: bool,
8 #[serde(default = "default_tool_timeout")]
9 pub default_timeout_ms: u64,
10 #[serde(default)]
11 pub tools: HashMap<String, ToolPolicyConfig>,
12}
13
14impl Default for ToolSecurityConfig {
15 fn default() -> Self {
16 Self {
17 enabled: false,
18 default_timeout_ms: default_tool_timeout(),
19 tools: HashMap::new(),
20 }
21 }
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct ToolPolicyConfig {
26 #[serde(default = "default_true")]
27 pub enabled: bool,
28 #[serde(default)]
29 pub require_confirmation: bool,
30 #[serde(default)]
31 pub confirmation_message: Option<String>,
32 #[serde(default)]
33 pub rate_limit: Option<u32>,
34 #[serde(default)]
35 pub timeout_ms: Option<u64>,
36 #[serde(default)]
37 pub allowed_domains: Vec<String>,
38 #[serde(default)]
39 pub blocked_domains: Vec<String>,
40 #[serde(default)]
41 pub allowed_paths: Vec<String>,
42}
43
44impl Default for ToolPolicyConfig {
45 fn default() -> Self {
46 Self {
47 enabled: true,
48 require_confirmation: false,
49 confirmation_message: None,
50 rate_limit: None,
51 timeout_ms: None,
52 allowed_domains: Vec::new(),
53 blocked_domains: Vec::new(),
54 allowed_paths: Vec::new(),
55 }
56 }
57}
58
/// Outcome of a security check on a tool call.
///
/// The variants split three ways under the predicates below:
/// - [`Allow`](Self::Allow) and [`Warn`](Self::Warn) satisfy [`is_allowed`](Self::is_allowed);
/// - [`Block`](Self::Block) satisfies [`is_blocked`](Self::is_blocked);
/// - [`RequireConfirmation`](Self::RequireConfirmation) satisfies neither and
///   is detected with [`requires_confirmation`](Self::requires_confirmation).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SecurityCheckResult {
    /// The check passed.
    Allow,
    /// The check failed; `reason` describes why.
    Block { reason: String },
    /// The check passed with a caveat described by `message`.
    Warn { message: String },
    /// The call needs confirmation first; `message` is presumably the prompt
    /// shown to the user — confirm with the caller.
    RequireConfirmation { message: String },
}

impl SecurityCheckResult {
    /// Returns `true` for [`Allow`](Self::Allow) and [`Warn`](Self::Warn).
    ///
    /// [`RequireConfirmation`](Self::RequireConfirmation) returns `false`.
    #[must_use]
    pub fn is_allowed(&self) -> bool {
        matches!(self, Self::Allow | Self::Warn { .. })
    }

    /// Returns `true` only for [`Block`](Self::Block).
    #[must_use]
    pub fn is_blocked(&self) -> bool {
        matches!(self, Self::Block { .. })
    }

    /// Returns `true` only for [`RequireConfirmation`](Self::RequireConfirmation).
    ///
    /// This is the single state for which both `is_allowed()` and
    /// `is_blocked()` are `false`. Without this predicate, callers using only
    /// the other two had to match on the variant by hand to tell it apart.
    #[must_use]
    pub fn requires_confirmation(&self) -> bool {
        matches!(self, Self::RequireConfirmation { .. })
    }
}
79
/// Serde default for `ToolSecurityConfig::default_timeout_ms`: 30 seconds,
/// expressed in milliseconds (the field's unit).
const fn default_tool_timeout() -> u64 {
    const TIMEOUT_SECS: u64 = 30;
    TIMEOUT_SECS * 1_000
}
83
/// Serde default for boolean fields that are on unless the config says
/// otherwise (used by `ToolPolicyConfig::enabled`).
const fn default_true() -> bool {
    true
}
87
/// Unit tests for the config types: built-in defaults, YAML deserialization
/// (including the serde default paths for omitted keys), and the
/// `SecurityCheckResult` predicates.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = ToolSecurityConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.default_timeout_ms, 30000);
        assert!(config.tools.is_empty());
    }

    #[test]
    fn test_yaml_parsing() {
        let yaml = r#"
enabled: true
default_timeout_ms: 10000
tools:
  http:
    rate_limit: 10
    blocked_domains:
      - evil.com
    allowed_domains:
      - api.example.com
  file_write:
    require_confirmation: true
    confirmation_message: "Are you sure you want to write this file?"
    allowed_paths:
      - /tmp/
"#;
        let config: ToolSecurityConfig =
            serde_yaml::from_str(yaml).expect("fixture YAML must parse");
        assert!(config.enabled);
        assert_eq!(config.default_timeout_ms, 10000);
        assert_eq!(config.tools.len(), 2);
        assert!(config.tools.contains_key("http"));
        assert!(config.tools.contains_key("file_write"));

        let http = config.tools.get("http").unwrap();
        // `enabled` is omitted for this tool, so `default_true` must apply.
        assert!(http.enabled);
        assert!(!http.require_confirmation);
        assert_eq!(http.rate_limit, Some(10));
        assert!(http.timeout_ms.is_none());
        assert_eq!(http.blocked_domains, vec!["evil.com"]);
        assert_eq!(http.allowed_domains, vec!["api.example.com"]);
        assert!(http.allowed_paths.is_empty());

        let file_write = config.tools.get("file_write").unwrap();
        assert!(file_write.enabled);
        assert!(file_write.require_confirmation);
        assert_eq!(
            file_write.confirmation_message.as_deref(),
            Some("Are you sure you want to write this file?")
        );
        assert_eq!(file_write.allowed_paths, vec!["/tmp/"]);
        assert!(file_write.rate_limit.is_none());
        assert!(file_write.allowed_domains.is_empty());
        assert!(file_write.blocked_domains.is_empty());
    }

    #[test]
    fn test_yaml_empty_mapping_matches_default() {
        // Every top-level key is optional; the serde defaults must agree with
        // `ToolSecurityConfig::default()`.
        let parsed: ToolSecurityConfig =
            serde_yaml::from_str("{}").expect("empty mapping must parse");
        let expected = ToolSecurityConfig::default();
        assert_eq!(parsed.enabled, expected.enabled);
        assert_eq!(parsed.default_timeout_ms, expected.default_timeout_ms);
        assert!(parsed.tools.is_empty());
    }

    #[test]
    fn test_yaml_empty_policy_matches_default() {
        // An empty per-tool mapping must deserialize to the same values as
        // `ToolPolicyConfig::default()`.
        let config: ToolSecurityConfig =
            serde_yaml::from_str("tools:\n  noop: {}\n").expect("policy YAML must parse");
        let parsed = &config.tools["noop"];
        let expected = ToolPolicyConfig::default();
        assert_eq!(parsed.enabled, expected.enabled);
        assert_eq!(parsed.require_confirmation, expected.require_confirmation);
        assert_eq!(parsed.confirmation_message, expected.confirmation_message);
        assert_eq!(parsed.rate_limit, expected.rate_limit);
        assert_eq!(parsed.timeout_ms, expected.timeout_ms);
        assert_eq!(parsed.allowed_domains, expected.allowed_domains);
        assert_eq!(parsed.blocked_domains, expected.blocked_domains);
        assert_eq!(parsed.allowed_paths, expected.allowed_paths);
    }

    #[test]
    fn test_security_check_result() {
        let allow = SecurityCheckResult::Allow;
        assert!(allow.is_allowed());
        assert!(!allow.is_blocked());

        let block = SecurityCheckResult::Block {
            reason: "test".into(),
        };
        assert!(!block.is_allowed());
        assert!(block.is_blocked());

        let warn = SecurityCheckResult::Warn {
            message: "warning".into(),
        };
        assert!(warn.is_allowed());
        assert!(!warn.is_blocked());

        // Pending confirmation is neither allowed nor blocked.
        let confirm = SecurityCheckResult::RequireConfirmation {
            message: "confirm".into(),
        };
        assert!(!confirm.is_allowed());
        assert!(!confirm.is_blocked());
    }

    #[test]
    fn test_tool_policy_defaults() {
        let policy = ToolPolicyConfig::default();
        assert!(policy.enabled);
        assert!(!policy.require_confirmation);
        assert!(policy.confirmation_message.is_none());
        assert!(policy.rate_limit.is_none());
        assert!(policy.timeout_ms.is_none());
        assert!(policy.allowed_domains.is_empty());
        assert!(policy.blocked_domains.is_empty());
        assert!(policy.allowed_paths.is_empty());
    }
}