use async_trait::async_trait;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use crate::error::ClawError;
use crate::options::HookEvent;
use crate::permissions::PermissionDecision;
/// Asynchronous callback that yields a [`PermissionDecision`] for a tool invocation.
///
/// Implementors must be `Send + Sync`; [`ControlHandlers`] stores the handler as an
/// `Arc<dyn CanUseToolHandler>`.
#[async_trait]
pub trait CanUseToolHandler: Send + Sync {
    /// Produces a decision for the tool named `tool_name` given its JSON `tool_input`.
    ///
    /// # Errors
    ///
    /// Returns a [`ClawError`] when the implementation fails to produce a decision.
    async fn can_use_tool(
        &self,
        tool_name: &str,
        tool_input: &Value,
    ) -> Result<PermissionDecision, ClawError>;
}
/// Asynchronous callback invoked with a [`HookEvent`] and a JSON payload.
///
/// Implementors must be `Send + Sync`; [`ControlHandlers`] stores each handler as an
/// `Arc<dyn HookHandler>` keyed by a hook id.
#[async_trait]
pub trait HookHandler: Send + Sync {
    /// Handles `hook_event` with the JSON payload `hook_input` and returns a JSON value.
    ///
    /// # Errors
    ///
    /// Returns a [`ClawError`] when the implementation fails.
    async fn call(&self, hook_event: HookEvent, hook_input: Value) -> Result<Value, ClawError>;
}
/// Asynchronous callback that receives an MCP message addressed to a named server.
///
/// Implementors must be `Send + Sync`; [`ControlHandlers`] stores the handler as an
/// `Arc<dyn McpMessageHandler>`.
#[async_trait]
pub trait McpMessageHandler: Send + Sync {
    /// Handles the JSON `message` for the server called `server_name` and returns a
    /// JSON value.
    ///
    /// # Errors
    ///
    /// Returns a [`ClawError`] when the implementation fails.
    async fn handle(&self, server_name: &str, message: Value) -> Result<Value, ClawError>;
}
/// Registry holding the optional callbacks defined in this module.
///
/// The derived [`Default`] yields an empty registry: no permission handler, no hook
/// callbacks, and no MCP message handler.
#[derive(Default)]
pub struct ControlHandlers {
    /// Permission callback; `None` until one is registered.
    pub(crate) can_use_tool: Option<Arc<dyn CanUseToolHandler>>,
    /// Hook callbacks, keyed by hook id.
    pub(crate) hook_callbacks: HashMap<String, Arc<dyn HookHandler>>,
    /// MCP message callback; `None` until one is registered.
    pub(crate) mcp_message: Option<Arc<dyn McpMessageHandler>>,
}
impl ControlHandlers {
pub fn new() -> Self {
Self::default()
}
pub fn register_can_use_tool(&mut self, handler: Arc<dyn CanUseToolHandler>) {
self.can_use_tool = Some(handler);
}
pub fn register_hook(&mut self, hook_id: String, handler: Arc<dyn HookHandler>) {
self.hook_callbacks.insert(hook_id, handler);
}
pub fn register_mcp_message(&mut self, handler: Arc<dyn McpMessageHandler>) {
self.mcp_message = Some(handler);
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for the handler traits and the `ControlHandlers` registry.

    use super::*;
    use crate::permissions::PermissionDecision;
    use serde_json::json;

    /// Permission mock: allows the `"Read"` tool without rewriting its input and denies
    /// every other tool without interrupting.
    #[derive(Debug)]
    struct MockCanUseToolHandler;

    #[async_trait]
    impl CanUseToolHandler for MockCanUseToolHandler {
        async fn can_use_tool(
            &self,
            tool_name: &str,
            _tool_input: &Value,
        ) -> Result<PermissionDecision, ClawError> {
            let decision = match tool_name {
                "Read" => PermissionDecision::Allow {
                    updated_input: None,
                },
                _ => PermissionDecision::Deny { interrupt: false },
            };
            Ok(decision)
        }
    }

    /// Hook mock: wraps the received payload under an `"echo"` key.
    struct MockHookHandler;

    #[async_trait]
    impl HookHandler for MockHookHandler {
        async fn call(
            &self,
            _hook_event: HookEvent,
            hook_input: Value,
        ) -> Result<Value, ClawError> {
            let echoed = json!({ "echo": hook_input });
            Ok(echoed)
        }
    }

    /// MCP mock: reports the server name it was called with under a `"server"` key.
    struct MockMcpHandler;

    #[async_trait]
    impl McpMessageHandler for MockMcpHandler {
        async fn handle(&self, server_name: &str, _message: Value) -> Result<Value, ClawError> {
            let reply = json!({ "server": server_name });
            Ok(reply)
        }
    }

    #[tokio::test]
    async fn test_can_use_tool_handler() {
        let handler = MockCanUseToolHandler;
        let empty_input = json!({});

        let read_decision = handler.can_use_tool("Read", &empty_input).await.unwrap();
        assert!(matches!(read_decision, PermissionDecision::Allow { .. }));

        let bash_decision = handler.can_use_tool("Bash", &empty_input).await.unwrap();
        assert!(matches!(bash_decision, PermissionDecision::Deny { .. }));
    }

    #[tokio::test]
    async fn test_hook_handler() {
        let payload = json!({ "foo": "bar" });
        let result = MockHookHandler
            .call(HookEvent::PreToolUse, payload)
            .await
            .unwrap();
        assert_eq!(result["echo"]["foo"], "bar");
    }

    #[tokio::test]
    async fn test_mcp_handler() {
        let result = MockMcpHandler
            .handle("test_server", json!({}))
            .await
            .unwrap();
        assert_eq!(result["server"], "test_server");
    }

    #[test]
    fn test_handlers_registry_default() {
        // Destructure so each slot of the fresh registry is checked by name.
        let ControlHandlers {
            can_use_tool,
            hook_callbacks,
            mcp_message,
        } = ControlHandlers::new();
        assert!(can_use_tool.is_none());
        assert!(hook_callbacks.is_empty());
        assert!(mcp_message.is_none());
    }

    #[test]
    fn test_handlers_register_can_use_tool() {
        let mut registry = ControlHandlers::new();
        registry.register_can_use_tool(Arc::new(MockCanUseToolHandler));
        assert!(registry.can_use_tool.is_some());
    }

    #[test]
    fn test_handlers_register_hook() {
        let hook_ids = ["hook1", "hook2"];
        let mut registry = ControlHandlers::new();
        for hook_id in hook_ids {
            registry.register_hook(hook_id.to_string(), Arc::new(MockHookHandler));
        }
        assert_eq!(registry.hook_callbacks.len(), 2);
        for hook_id in hook_ids {
            assert!(registry.hook_callbacks.contains_key(hook_id));
        }
    }

    #[test]
    fn test_handlers_register_mcp_message() {
        let mut registry = ControlHandlers::new();
        registry.register_mcp_message(Arc::new(MockMcpHandler));
        assert!(registry.mcp_message.is_some());
    }

    #[tokio::test]
    async fn test_permission_decision_allow_with_updated_input() {
        /// Always allows, substituting a fixed replacement input.
        #[derive(Debug)]
        struct SanitizingHandler;

        #[async_trait]
        impl CanUseToolHandler for SanitizingHandler {
            async fn can_use_tool(
                &self,
                _tool_name: &str,
                _tool_input: &Value,
            ) -> Result<PermissionDecision, ClawError> {
                let replacement = json!({ "command": "echo safe" });
                Ok(PermissionDecision::Allow {
                    updated_input: Some(replacement),
                })
            }
        }

        let original_input = json!({ "command": "rm -rf /" });
        let decision = SanitizingHandler
            .can_use_tool("Bash", &original_input)
            .await
            .unwrap();
        match decision {
            PermissionDecision::Deny { .. } => panic!("Expected Allow"),
            PermissionDecision::Allow { updated_input } => {
                assert!(updated_input.is_some());
                let sanitized = updated_input.unwrap();
                assert_eq!(sanitized["command"], "echo safe");
            }
        }
    }

    #[tokio::test]
    async fn test_permission_decision_deny_with_interrupt() {
        /// Always denies and requests an interrupt.
        #[derive(Debug)]
        struct InterruptingHandler;

        #[async_trait]
        impl CanUseToolHandler for InterruptingHandler {
            async fn can_use_tool(
                &self,
                _tool_name: &str,
                _tool_input: &Value,
            ) -> Result<PermissionDecision, ClawError> {
                Ok(PermissionDecision::Deny { interrupt: true })
            }
        }

        let empty_input = json!({});
        let decision = InterruptingHandler
            .can_use_tool("Bash", &empty_input)
            .await
            .unwrap();
        match decision {
            PermissionDecision::Allow { .. } => panic!("Expected Deny"),
            PermissionDecision::Deny { interrupt } => assert!(interrupt),
        }
    }
}