use std::sync::Arc;
use crate::error::Result;
use crate::types::{
CanUseToolCallback, PermissionResult, PermissionResultAllow, ToolName, ToolPermissionContext,
};
/// Decides whether a tool invocation is permitted, based on a deny-list, an optional
/// allow-list, and an optional user-supplied callback.
pub struct PermissionManager {
/// Optional callback consulted by `can_use_tool` once both list checks have passed.
callback: Option<CanUseToolCallback>,
/// When `Some`, only tools contained in this list get past the allow-list check;
/// `None` disables that check.
allowed_tools: Option<Vec<ToolName>>,
/// Tools that `can_use_tool` always denies; checked before the allow-list.
disallowed_tools: Vec<ToolName>,
}
impl PermissionManager {
    /// Creates a manager with no callback, no allow-list, and an empty deny-list.
    pub fn new() -> Self {
        Self {
            callback: None,
            allowed_tools: None,
            disallowed_tools: Vec::new(),
        }
    }

    /// Installs `callback`, replacing any previously stored one.
    pub fn set_callback(&mut self, callback: CanUseToolCallback) {
        self.callback = Some(callback);
    }

    /// Replaces the allow-list. Passing `None` removes the allow-list check.
    pub fn set_allowed_tools(&mut self, tools: Option<Vec<ToolName>>) {
        self.allowed_tools = tools;
    }

    /// Replaces the deny-list.
    pub fn set_disallowed_tools(&mut self, tools: Vec<ToolName>) {
        self.disallowed_tools = tools;
    }

    /// Decides whether `tool_name` may be used.
    ///
    /// The checks run in this order:
    /// 1. A tool on the deny-list is denied.
    /// 2. If an allow-list is set and the tool is not on it, the tool is denied.
    /// 3. If a callback is set, its result is returned as-is.
    /// 4. Otherwise the tool is allowed with no updated input or permissions.
    ///
    /// Denials produced by steps 1 and 2 carry `interrupt: false`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the callback yields; the list checks themselves never fail.
    pub async fn can_use_tool(
        &self,
        tool_name: ToolName,
        tool_input: serde_json::Value,
        context: ToolPermissionContext,
    ) -> Result<PermissionResult> {
        // Work out the denial message (if any) first; the deny-list wins over the allow-list.
        let denial = if self.disallowed_tools.contains(&tool_name) {
            Some(format!("Tool {} is disallowed", tool_name.as_str()))
        } else if matches!(&self.allowed_tools, Some(allowed) if !allowed.contains(&tool_name)) {
            Some(format!(
                "Tool {} is not in allowed list",
                tool_name.as_str()
            ))
        } else {
            None
        };

        if let Some(message) = denial {
            return Ok(PermissionResult::Deny(crate::types::PermissionResultDeny {
                message,
                interrupt: false,
            }));
        }

        match &self.callback {
            Some(ask) => ask(tool_name, tool_input, context).await,
            None => Ok(PermissionResult::Allow(PermissionResultAllow {
                updated_input: None,
                updated_permissions: None,
            })),
        }
    }

    /// Adapts an async closure into a [`CanUseToolCallback`].
    ///
    /// The closure is wrapped in an `Arc`, and the future it returns is boxed and pinned
    /// on every call.
    pub fn callback<F, Fut>(f: F) -> CanUseToolCallback
    where
        F: Fn(ToolName, serde_json::Value, ToolPermissionContext) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = Result<PermissionResult>> + Send + 'static,
    {
        Arc::new(move |name, input, ctx| Box::pin(f(name, input, ctx)))
    }
}
impl Default for PermissionManager {
fn default() -> Self {
Self::new()
}
}
/// Builder that collects the same three settings as `PermissionManager` and hands them
/// over unchanged in `build`.
pub struct PermissionManagerBuilder {
/// Callback to install on the built manager, if any.
callback: Option<CanUseToolCallback>,
/// Allow-list for the built manager; stays `None` unless `allowed_tools` is called.
allowed_tools: Option<Vec<ToolName>>,
/// Deny-list for the built manager; empty unless `disallowed_tools` is called.
disallowed_tools: Vec<ToolName>,
}
impl PermissionManagerBuilder {
    /// Starts a builder with no callback, no allow-list, and an empty deny-list.
    pub fn new() -> Self {
        Self {
            callback: None,
            allowed_tools: None,
            disallowed_tools: Vec::new(),
        }
    }

    /// Sets the callback, replacing any earlier one.
    pub fn callback(self, callback: CanUseToolCallback) -> Self {
        Self {
            callback: Some(callback),
            ..self
        }
    }

    /// Sets the allow-list to exactly `tools`.
    pub fn allowed_tools(self, tools: Vec<ToolName>) -> Self {
        Self {
            allowed_tools: Some(tools),
            ..self
        }
    }

    /// Sets the deny-list to exactly `tools`.
    pub fn disallowed_tools(self, tools: Vec<ToolName>) -> Self {
        Self {
            disallowed_tools: tools,
            ..self
        }
    }

    /// Consumes the builder and produces a [`PermissionManager`] with the collected settings.
    pub fn build(self) -> PermissionManager {
        let Self {
            callback,
            allowed_tools,
            disallowed_tools,
        } = self;

        PermissionManager {
            callback,
            allowed_tools,
            disallowed_tools,
        }
    }
}
impl Default for PermissionManagerBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs a permission check for `tool` with an empty JSON object as input and no
    /// suggestions, panicking if the check itself returns an error.
    async fn check(manager: &PermissionManager, tool: ToolName) -> PermissionResult {
        let context = ToolPermissionContext {
            suggestions: vec![],
        };
        manager
            .can_use_tool(tool, serde_json::json!({}), context)
            .await
            .unwrap()
    }

    // With nothing configured, every tool is allowed.
    #[tokio::test]
    async fn test_permission_manager_default_allow() {
        let manager = PermissionManager::new();

        let outcome = check(&manager, ToolName::new("test_tool")).await;

        assert!(
            matches!(outcome, PermissionResult::Allow(_)),
            "Expected allow"
        );
    }

    // A tool on the deny-list is rejected.
    #[tokio::test]
    async fn test_permission_manager_disallowed() {
        let mut manager = PermissionManager::new();
        manager.set_disallowed_tools(vec![ToolName::new("bad_tool")]);

        let outcome = check(&manager, ToolName::new("bad_tool")).await;

        assert!(
            matches!(outcome, PermissionResult::Deny(_)),
            "Expected deny"
        );
    }

    // With an allow-list set, listed tools pass and unlisted tools are rejected.
    #[tokio::test]
    async fn test_permission_manager_allowed_list() {
        let mut manager = PermissionManager::new();
        manager.set_allowed_tools(Some(vec![ToolName::new("good_tool")]));

        let listed = check(&manager, ToolName::new("good_tool")).await;
        assert!(
            matches!(listed, PermissionResult::Allow(_)),
            "Expected allow"
        );

        let unlisted = check(&manager, ToolName::new("other_tool")).await;
        assert!(
            matches!(unlisted, PermissionResult::Deny(_)),
            "Expected deny"
        );
    }
}