1use crate::error::Result;
7use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::Arc;
10
/// Permission level attached to a tool.
///
/// Each variant has a canonical lowercase name (`"allow"`, `"ask"`, `"deny"`)
/// that [`ToolPermissionLevel::from_permission_level`] parses and
/// [`ToolPermissionLevel::as_str`] produces.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolPermissionLevel {
    /// The tool may be executed.
    Allow,
    /// The tool needs user approval before it may be executed.
    Ask,
    /// The tool must not be executed.
    Deny,
}

impl ToolPermissionLevel {
    /// Every variant, so string parsing can reuse the names from `as_str`.
    const ALL: [Self; 3] = [Self::Allow, Self::Ask, Self::Deny];

    /// Parses a permission level from its canonical name.
    ///
    /// The comparison is exact: only `"allow"`, `"ask"` and `"deny"` are
    /// recognised (no trimming, no case folding). Any other input yields
    /// `None`.
    pub fn from_permission_level(level: &str) -> Option<Self> {
        Self::ALL
            .iter()
            .copied()
            .find(|candidate| candidate.as_str() == level)
    }

    /// Returns the canonical lowercase name of this level.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Allow => "allow",
            Self::Ask => "ask",
            Self::Deny => "deny",
        }
    }
}
42
43#[derive(Debug, Clone)]
45pub struct ToolPermissionDecision {
46 pub tool_id: String,
48 pub level: ToolPermissionLevel,
50 pub agent_id: Option<String>,
52 pub is_override: bool,
54 pub reason: String,
56}
57
58impl ToolPermissionDecision {
59 pub fn new(tool_id: String, level: ToolPermissionLevel, reason: String) -> Self {
61 Self {
62 tool_id,
63 level,
64 agent_id: None,
65 is_override: false,
66 reason,
67 }
68 }
69
70 pub fn with_agent(mut self, agent_id: String) -> Self {
72 self.agent_id = Some(agent_id);
73 self
74 }
75
76 pub fn as_override(mut self) -> Self {
78 self.is_override = true;
79 self
80 }
81}
82
83pub trait ToolPermissionChecker: Send + Sync {
87 fn check_permission(&self, tool_id: &str, agent_id: Option<&str>) -> Result<ToolPermissionDecision>;
98
99 fn is_allowed(&self, tool_id: &str, agent_id: Option<&str>) -> Result<bool> {
101 let decision = self.check_permission(tool_id, agent_id)?;
102 Ok(decision.level == ToolPermissionLevel::Allow)
103 }
104
105 fn requires_prompt(&self, tool_id: &str, agent_id: Option<&str>) -> Result<bool> {
107 let decision = self.check_permission(tool_id, agent_id)?;
108 Ok(decision.level == ToolPermissionLevel::Ask)
109 }
110
111 fn is_denied(&self, tool_id: &str, agent_id: Option<&str>) -> Result<bool> {
113 let decision = self.check_permission(tool_id, agent_id)?;
114 Ok(decision.level == ToolPermissionLevel::Deny)
115 }
116}
117
118#[derive(Debug, Clone)]
120pub struct ToolPermissionPrompt {
121 pub tool_id: String,
123 pub tool_name: String,
125 pub tool_description: String,
127 pub parameters: HashMap<String, Value>,
129 pub agent_id: Option<String>,
131}
132
133impl ToolPermissionPrompt {
134 pub fn new(
136 tool_id: String,
137 tool_name: String,
138 tool_description: String,
139 parameters: HashMap<String, Value>,
140 ) -> Self {
141 Self {
142 tool_id,
143 tool_name,
144 tool_description,
145 parameters,
146 agent_id: None,
147 }
148 }
149
150 pub fn with_agent(mut self, agent_id: String) -> Self {
152 self.agent_id = Some(agent_id);
153 self
154 }
155
156 pub fn format_message(&self) -> String {
158 let mut msg = format!(
159 "Tool Execution Request\n\
160 =====================\n\
161 Tool: {} ({})\n\
162 Description: {}\n",
163 self.tool_name, self.tool_id, self.tool_description
164 );
165
166 if let Some(agent_id) = &self.agent_id {
167 msg.push_str(&format!("Requested by: {}\n", agent_id));
168 }
169
170 if !self.parameters.is_empty() {
171 msg.push_str("\nParameters:\n");
172 for (key, value) in &self.parameters {
173 msg.push_str(&format!(" {}: {}\n", key, value));
174 }
175 }
176
177 msg.push_str("\nAllow execution? (yes/no)");
178 msg
179 }
180}
181
/// A user's answer to a permission request.
///
/// NOTE(review): nothing in this file constructs or consumes this type, so the
/// exact point at which each variant is produced should be confirmed against
/// the caller.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum UserPermissionDecision {
    /// The user approved the request.
    Approved,
    /// The user denied the request.
    Denied,
    /// The request was cancelled without an explicit approval or denial.
    Cancelled,
}
192
193pub struct ToolPermissionEnforcer {
198 checker: Arc<dyn ToolPermissionChecker>,
199}
200
201impl ToolPermissionEnforcer {
202 pub fn new(checker: Arc<dyn ToolPermissionChecker>) -> Self {
204 Self { checker }
205 }
206
207 pub fn can_execute(&self, tool_id: &str, agent_id: Option<&str>) -> Result<bool> {
209 self.checker.is_allowed(tool_id, agent_id)
210 }
211
212 pub fn requires_user_prompt(&self, tool_id: &str, agent_id: Option<&str>) -> Result<bool> {
214 self.checker.requires_prompt(tool_id, agent_id)
215 }
216
217 pub fn is_execution_denied(&self, tool_id: &str, agent_id: Option<&str>) -> Result<bool> {
219 self.checker.is_denied(tool_id, agent_id)
220 }
221
222 pub fn get_decision(&self, tool_id: &str, agent_id: Option<&str>) -> Result<ToolPermissionDecision> {
224 self.checker.check_permission(tool_id, agent_id)
225 }
226
227 pub fn log_decision(&self, decision: &ToolPermissionDecision) {
229 let agent_info = decision
230 .agent_id
231 .as_ref()
232 .map(|id| format!(" (agent: {})", id))
233 .unwrap_or_default();
234
235 let override_info = if decision.is_override {
236 " [OVERRIDE]"
237 } else {
238 ""
239 };
240
241 tracing::info!(
242 "Tool permission: {} - {} for tool {}{}{}",
243 decision.level.as_str(),
244 decision.reason,
245 decision.tool_id,
246 agent_info,
247 override_info
248 );
249 }
250
251 pub fn log_denial(&self, tool_id: &str, agent_id: Option<&str>, reason: &str) {
253 let agent_info = agent_id
254 .map(|id| format!(" (agent: {})", id))
255 .unwrap_or_default();
256
257 tracing::warn!(
258 "Tool execution denied: {} for tool {}{}",
259 reason,
260 tool_id,
261 agent_info
262 );
263 }
264}
265
266pub struct PermissionAwareToolExecution {
270 enforcer: Arc<ToolPermissionEnforcer>,
271}
272
273impl PermissionAwareToolExecution {
274 pub fn new(enforcer: Arc<ToolPermissionEnforcer>) -> Self {
276 Self { enforcer }
277 }
278
279 pub async fn check_and_execute<F, T>(
281 &self,
282 tool_id: &str,
283 agent_id: Option<&str>,
284 execute_fn: F,
285 ) -> Result<T>
286 where
287 F: FnOnce() -> Result<T>,
288 {
289 let decision = self.enforcer.get_decision(tool_id, agent_id)?;
291 self.enforcer.log_decision(&decision);
292
293 match decision.level {
294 ToolPermissionLevel::Allow => {
295 execute_fn()
297 }
298 ToolPermissionLevel::Ask => {
299 Err(crate::error::Error::PermissionDenied(format!(
302 "Tool execution requires user approval: {}",
303 tool_id
304 )))
305 }
306 ToolPermissionLevel::Deny => {
307 self.enforcer.log_denial(tool_id, agent_id, "Permission denied");
308 Err(crate::error::Error::PermissionDenied(format!(
309 "Tool execution denied: {}",
310 tool_id
311 )))
312 }
313 }
314 }
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Checker whose answer depends only on the tool id: ids containing
    /// `"allowed"` are allowed, ids containing `"ask"` require a prompt, and
    /// everything else is denied. A missing agent is recorded as `"unknown"`.
    struct MockPermissionChecker;

    impl ToolPermissionChecker for MockPermissionChecker {
        fn check_permission(
            &self,
            tool_id: &str,
            agent_id: Option<&str>,
        ) -> Result<ToolPermissionDecision> {
            let level = if tool_id.contains("allowed") {
                ToolPermissionLevel::Allow
            } else if tool_id.contains("ask") {
                ToolPermissionLevel::Ask
            } else {
                ToolPermissionLevel::Deny
            };

            Ok(ToolPermissionDecision::new(
                tool_id.to_string(),
                level,
                "Mock decision".to_string(),
            )
            .with_agent(agent_id.unwrap_or("unknown").to_string()))
        }
    }

    /// Builds an enforcer backed by the mock checker.
    fn mock_enforcer() -> ToolPermissionEnforcer {
        let checker: Arc<dyn ToolPermissionChecker> = Arc::new(MockPermissionChecker);
        ToolPermissionEnforcer::new(checker)
    }

    /// Builds an execution wrapper backed by the mock checker.
    fn mock_execution() -> PermissionAwareToolExecution {
        PermissionAwareToolExecution::new(Arc::new(mock_enforcer()))
    }

    #[test]
    fn test_tool_permission_level_conversion() {
        assert_eq!(
            ToolPermissionLevel::from_permission_level("allow"),
            Some(ToolPermissionLevel::Allow)
        );
        assert_eq!(
            ToolPermissionLevel::from_permission_level("ask"),
            Some(ToolPermissionLevel::Ask)
        );
        assert_eq!(
            ToolPermissionLevel::from_permission_level("deny"),
            Some(ToolPermissionLevel::Deny)
        );
        assert_eq!(ToolPermissionLevel::from_permission_level("invalid"), None);
    }

    #[test]
    fn test_tool_permission_level_as_str() {
        assert_eq!(ToolPermissionLevel::Allow.as_str(), "allow");
        assert_eq!(ToolPermissionLevel::Ask.as_str(), "ask");
        assert_eq!(ToolPermissionLevel::Deny.as_str(), "deny");
    }

    #[test]
    fn test_tool_permission_decision_creation() {
        let decision = ToolPermissionDecision::new(
            "tool-1".to_string(),
            ToolPermissionLevel::Allow,
            "Test reason".to_string(),
        );

        assert_eq!(decision.tool_id, "tool-1");
        assert_eq!(decision.level, ToolPermissionLevel::Allow);
        assert_eq!(decision.reason, "Test reason");
        assert!(decision.agent_id.is_none());
        assert!(!decision.is_override);
    }

    #[test]
    fn test_tool_permission_decision_with_agent() {
        let decision = ToolPermissionDecision::new(
            "tool-1".to_string(),
            ToolPermissionLevel::Allow,
            "Test reason".to_string(),
        )
        .with_agent("agent-1".to_string())
        .as_override();

        assert_eq!(decision.agent_id, Some("agent-1".to_string()));
        assert!(decision.is_override);
    }

    #[test]
    fn test_tool_permission_prompt_creation() {
        let mut params = HashMap::new();
        params.insert("param1".to_string(), serde_json::json!("value1"));

        let prompt = ToolPermissionPrompt::new(
            "tool-1".to_string(),
            "Tool 1".to_string(),
            "A test tool".to_string(),
            params,
        );

        assert_eq!(prompt.tool_id, "tool-1");
        assert_eq!(prompt.tool_name, "Tool 1");
        assert_eq!(prompt.tool_description, "A test tool");
        assert!(prompt.agent_id.is_none());
    }

    #[test]
    fn test_tool_permission_prompt_format_message() {
        let mut params = HashMap::new();
        params.insert("param1".to_string(), serde_json::json!("value1"));

        let prompt = ToolPermissionPrompt::new(
            "tool-1".to_string(),
            "Tool 1".to_string(),
            "A test tool".to_string(),
            params,
        )
        .with_agent("agent-1".to_string());

        let message = prompt.format_message();
        assert!(message.contains("Tool 1"));
        assert!(message.contains("tool-1"));
        assert!(message.contains("A test tool"));
        assert!(message.contains("agent-1"));
        assert!(message.contains("param1"));
    }

    #[test]
    fn test_tool_permission_prompt_format_message_without_optional_sections() {
        let prompt = ToolPermissionPrompt::new(
            "tool-1".to_string(),
            "Tool 1".to_string(),
            "A test tool".to_string(),
            HashMap::new(),
        );

        // With no agent and no parameters, both optional sections are skipped.
        let message = prompt.format_message();
        assert!(!message.contains("Requested by:"));
        assert!(!message.contains("Parameters:"));
        assert!(message.ends_with("Allow execution? (yes/no)"));
    }

    #[test]
    fn test_tool_permission_enforcer_creation() {
        let enforcer = mock_enforcer();

        assert!(enforcer.can_execute("allowed-tool", None).is_ok());
    }

    #[test]
    fn test_tool_permission_enforcer_can_execute() {
        let enforcer = mock_enforcer();

        assert!(enforcer.can_execute("allowed-tool", None).unwrap());
        assert!(!enforcer.can_execute("ask-tool", None).unwrap());
        assert!(!enforcer.can_execute("denied-tool", None).unwrap());
    }

    #[test]
    fn test_tool_permission_enforcer_requires_prompt() {
        let enforcer = mock_enforcer();

        assert!(enforcer.requires_user_prompt("ask-tool", None).unwrap());
        assert!(!enforcer.requires_user_prompt("allowed-tool", None).unwrap());
        assert!(!enforcer.requires_user_prompt("denied-tool", None).unwrap());
    }

    #[test]
    fn test_tool_permission_enforcer_is_denied() {
        let enforcer = mock_enforcer();

        assert!(enforcer.is_execution_denied("denied-tool", None).unwrap());
        assert!(!enforcer.is_execution_denied("allowed-tool", None).unwrap());
        assert!(!enforcer.is_execution_denied("ask-tool", None).unwrap());
    }

    #[test]
    fn test_tool_permission_enforcer_get_decision_passes_agent_through() {
        let enforcer = mock_enforcer();

        let with_agent = enforcer.get_decision("ask-tool", Some("agent-7")).unwrap();
        assert_eq!(with_agent.tool_id, "ask-tool");
        assert_eq!(with_agent.level, ToolPermissionLevel::Ask);
        assert_eq!(with_agent.agent_id.as_deref(), Some("agent-7"));

        // The mock records a missing agent as "unknown".
        let without_agent = enforcer.get_decision("allowed-tool", None).unwrap();
        assert_eq!(without_agent.level, ToolPermissionLevel::Allow);
        assert_eq!(without_agent.agent_id.as_deref(), Some("unknown"));
    }

    #[tokio::test]
    async fn test_permission_aware_tool_execution_allowed() {
        let execution = mock_execution();

        let result = execution
            .check_and_execute("allowed-tool", None, || Ok(42))
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_permission_aware_tool_execution_ask_requires_approval() {
        let execution = mock_execution();
        let executed = AtomicBool::new(false);

        let result = execution
            .check_and_execute("ask-tool", None, || {
                executed.store(true, Ordering::SeqCst);
                Ok(42)
            })
            .await;

        // The closure must not run while approval is still outstanding.
        assert!(!executed.load(Ordering::SeqCst));
        match result {
            Err(crate::error::Error::PermissionDenied(message)) => {
                assert!(message.contains("requires user approval"));
                assert!(message.contains("ask-tool"));
            }
            _ => panic!("expected a PermissionDenied error for an Ask-level tool"),
        }
    }

    #[tokio::test]
    async fn test_permission_aware_tool_execution_denied() {
        let execution = mock_execution();
        let executed = AtomicBool::new(false);

        let result = execution
            .check_and_execute("denied-tool", None, || {
                executed.store(true, Ordering::SeqCst);
                Ok(42)
            })
            .await;

        assert!(result.is_err());
        // A denied tool must never reach the closure.
        assert!(!executed.load(Ordering::SeqCst));
    }
}