use crate::error::ClawError;
use crate::hooks::{HookContext, HookInput, HookResponse};
use async_trait::async_trait;
#[async_trait]
pub trait HookCallback: Send + Sync {
async fn call(
&self,
input: HookInput,
tool_use_id: Option<&str>,
context: &HookContext,
) -> Result<HookResponse, ClawError>;
}
#[async_trait]
impl<F, Fut> HookCallback for F
where
F: Fn(HookInput, Option<&str>, &HookContext) -> Fut + Send + Sync,
Fut: Future<Output = Result<HookResponse, ClawError>> + Send,
{
async fn call(
&self,
input: HookInput,
tool_use_id: Option<&str>,
context: &HookContext,
) -> Result<HookResponse, ClawError> {
self(input, tool_use_id, context).await
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::hooks::PermissionDecision;
use serde_json::json;
struct TestHook {
allowed: bool,
}
#[async_trait]
impl HookCallback for TestHook {
async fn call(
&self,
_input: HookInput,
_tool_use_id: Option<&str>,
_context: &HookContext,
) -> Result<HookResponse, ClawError> {
if self.allowed {
Ok(HookResponse::allow("Test allowed"))
} else {
Ok(HookResponse::deny("Test denied"))
}
}
}
#[tokio::test]
async fn test_struct_implementation() {
let hook = TestHook { allowed: true };
let input = HookInput::default();
let context = HookContext::default();
let response = hook.call(input, None, &context).await.unwrap();
assert!(matches!(
response.permission_decision,
Some(PermissionDecision::Allow)
));
}
#[tokio::test]
async fn test_closure_implementation() {
async fn my_hook(
input: HookInput,
_tool_use_id: Option<&str>,
_context: &HookContext,
) -> Result<HookResponse, ClawError> {
if let Some(tool_name) = &input.tool_name
&& tool_name == "Bash"
{
return Ok(HookResponse::deny("Bash not allowed"));
}
Ok(HookResponse::allow("OK"))
}
let input = HookInput::tool_use("Bash", json!({}));
let context = HookContext::default();
let response = my_hook(input, None, &context).await.unwrap();
assert!(matches!(
response.permission_decision,
Some(PermissionDecision::Deny)
));
}
#[tokio::test]
async fn test_hook_with_tool_use_id() {
async fn check_tool_id(
_input: HookInput,
tool_use_id: Option<&str>,
_context: &HookContext,
) -> Result<HookResponse, ClawError> {
if let Some(id) = tool_use_id
&& id == "dangerous-id"
{
return Ok(HookResponse::deny("Dangerous tool use ID"));
}
Ok(HookResponse::allow("Safe"))
}
let input = HookInput::default();
let context = HookContext::default();
let response = check_tool_id(input.clone(), Some("dangerous-id"), &context)
.await
.unwrap();
assert!(matches!(
response.permission_decision,
Some(PermissionDecision::Deny)
));
let response = check_tool_id(input, Some("safe-id"), &context)
.await
.unwrap();
assert!(matches!(
response.permission_decision,
Some(PermissionDecision::Allow)
));
}
#[tokio::test]
async fn test_hook_with_context() {
async fn check_tools(
_input: HookInput,
_tool_use_id: Option<&str>,
context: &HookContext,
) -> Result<HookResponse, ClawError> {
if let Some(tools) = &context.available_tools
&& tools.contains(&"Bash".to_string())
{
return Ok(HookResponse::allow("Bash is available"));
}
Ok(HookResponse::deny("Bash not available"))
}
let input = HookInput::default();
let context = HookContext::with_session("test").with_tools(vec!["Bash".to_string()]);
let response = check_tools(input, None, &context).await.unwrap();
assert!(matches!(
response.permission_decision,
Some(PermissionDecision::Allow)
));
}
}