use crate::error::Result;
use crate::types::{
AsyncHookOutput, HookContext, HookInput, HookOutput, PermissionResult, SyncHookOutput,
ToolPermissionContext,
};
use async_trait::async_trait;
use serde_json::Value;
use std::future::Future;
use std::pin::Pin;
/// An asynchronous hook handler.
///
/// Each invocation receives a [`HookInput`], an optional tool-use id, and a [`HookContext`].
/// It produces a [`HookOutput`] or an error. The `Send + Sync` supertraits allow a single
/// implementation to be shared across threads.
#[async_trait]
pub trait HookCallback: Send + Sync {
    /// Handle one hook invocation.
    ///
    /// NOTE(review): `tool_use_id` is `None` when the caller supplies no id. Presumably this
    /// happens for events not tied to a specific tool call; confirm against the caller.
    async fn call(
        &self,
        input: HookInput,
        tool_use_id: Option<String>,
        context: HookContext,
    ) -> Result<HookOutput>;
}
/// An asynchronous permission check for a tool invocation.
///
/// Each call receives the tool name, the tool's input as a JSON [`Value`], and a
/// [`ToolPermissionContext`]. It returns a [`PermissionResult`] or an error. The
/// `Send + Sync` supertraits allow a single implementation to be shared across threads.
#[async_trait]
pub trait PermissionCallback: Send + Sync {
    /// Decide whether the tool `tool_name` may run with `input`.
    async fn call(
        &self,
        tool_name: String,
        input: Value,
        context: ToolPermissionContext,
    ) -> Result<PermissionResult>;
}
/// Adapts a closure that returns a boxed future into a [`HookCallback`].
///
/// The struct definition carries no trait bounds. The closure signature is enforced only by
/// the `impl` blocks that need it: the constructor and the `HookCallback` implementation.
/// Keeping bounds off the definition avoids repeating them wherever the type is merely named.
pub struct ClosureHook<F> {
    /// The wrapped closure, invoked once per hook call.
    func: F,
}
impl<F> ClosureHook<F>
where
F: Fn(
HookInput,
Option<String>,
HookContext,
) -> Pin<Box<dyn Future<Output = Result<HookOutput>> + Send>>
+ Send
+ Sync,
{
pub fn new(func: F) -> Self {
Self { func }
}
}
/// Delegates each hook invocation to the wrapped closure.
#[async_trait]
impl<F> HookCallback for ClosureHook<F>
where
    F: Send + Sync,
    F: Fn(HookInput, Option<String>, HookContext)
        -> Pin<Box<dyn Future<Output = Result<HookOutput>> + Send>>,
{
    /// Calls the closure with the given arguments and awaits the future it returns.
    async fn call(
        &self,
        input: HookInput,
        tool_use_id: Option<String>,
        context: HookContext,
    ) -> Result<HookOutput> {
        // The closure already hands back a boxed future; just drive it to completion.
        let pending = (self.func)(input, tool_use_id, context);
        pending.await
    }
}
/// Adapts a closure that returns a boxed future into a [`PermissionCallback`].
///
/// The struct definition carries no trait bounds. The closure signature is enforced only by
/// the `impl` blocks that need it: the constructor and the `PermissionCallback`
/// implementation. Keeping bounds off the definition avoids repeating them wherever the type
/// is merely named.
pub struct ClosurePermission<F> {
    /// The wrapped closure, invoked once per permission check.
    func: F,
}
impl<F> ClosurePermission<F>
where
F: Fn(
String,
Value,
ToolPermissionContext,
) -> Pin<Box<dyn Future<Output = Result<PermissionResult>> + Send>>
+ Send
+ Sync,
{
pub fn new(func: F) -> Self {
Self { func }
}
}
/// Delegates each permission check to the wrapped closure.
#[async_trait]
impl<F> PermissionCallback for ClosurePermission<F>
where
    F: Send + Sync,
    F: Fn(String, Value, ToolPermissionContext)
        -> Pin<Box<dyn Future<Output = Result<PermissionResult>> + Send>>,
{
    /// Calls the closure with the given arguments and awaits the future it returns.
    async fn call(
        &self,
        tool_name: String,
        input: Value,
        context: ToolPermissionContext,
    ) -> Result<PermissionResult> {
        // The closure already hands back a boxed future; just drive it to completion.
        let pending = (self.func)(tool_name, input, context);
        pending.await
    }
}
/// Convenience constructors for common [`HookOutput`] values.
pub mod hooks {
    use super::*;

    /// Builds a [`HookOutput::Sync`] from the four fields the public helpers vary.
    ///
    /// `suppress_output`, `reason`, and `hook_specific_output` are always `None`.
    fn sync_output(
        keep_going: bool,
        stop_reason: Option<String>,
        decision: Option<String>,
        system_message: Option<String>,
    ) -> HookOutput {
        HookOutput::Sync(Box::new(SyncHookOutput {
            continue_: Some(keep_going),
            suppress_output: None,
            stop_reason,
            decision,
            system_message,
            reason: None,
            hook_specific_output: None,
        }))
    }

    /// Sync output with `continue_ = Some(true)` and every other field `None`.
    #[must_use]
    pub fn allow() -> HookOutput {
        sync_output(true, None, None, None)
    }

    /// Sync output that stops execution.
    ///
    /// Sets `continue_ = Some(false)` and `decision = Some("block")`, and stores `reason`
    /// in `stop_reason`.
    ///
    /// NOTE(review): the separate `reason` field is left as `None`. Confirm that consumers
    /// read the explanation from `stop_reason`.
    #[must_use]
    pub fn block(reason: impl Into<String>) -> HookOutput {
        sync_output(false, Some(reason.into()), Some("block".to_string()), None)
    }

    /// Like [`allow`], but also sets `system_message` to `message`.
    #[must_use]
    pub fn allow_with_message(message: impl Into<String>) -> HookOutput {
        sync_output(true, None, None, Some(message.into()))
    }

    /// Async output with `async_ = true` and `async_timeout` set to `timeout_ms`.
    ///
    /// The unit is presumably milliseconds, going by the parameter name; confirm with the
    /// consumer.
    #[must_use]
    pub fn defer(timeout_ms: Option<u32>) -> HookOutput {
        HookOutput::Async(AsyncHookOutput {
            async_: true,
            async_timeout: timeout_ms,
        })
    }
}
/// Convenience constructors for common [`PermissionResult`] values.
pub mod permissions {
    use super::*;
    use crate::types::{PermissionResultAllow, PermissionResultDeny};

    /// Builds an `Allow` result with `behavior = "allow"` and the given input override.
    fn allowed(updated_input: Option<Value>) -> PermissionResult {
        PermissionResult::Allow(PermissionResultAllow {
            behavior: "allow".to_string(),
            updated_input,
            updated_permissions: None,
        })
    }

    /// Builds a `Deny` result with `behavior = "deny"`, the message, and the interrupt flag.
    fn denied(message: String, interrupt: bool) -> PermissionResult {
        PermissionResult::Deny(PermissionResultDeny {
            behavior: "deny".to_string(),
            message,
            interrupt,
        })
    }

    /// Allows the tool call unchanged (`updated_input` and `updated_permissions` are `None`).
    #[must_use]
    pub fn allow() -> PermissionResult {
        allowed(None)
    }

    /// Allows the tool call and sets `updated_input` to the supplied value.
    #[must_use]
    pub fn allow_with_input(updated_input: Value) -> PermissionResult {
        allowed(Some(updated_input))
    }

    /// Denies the tool call with `message`, setting `interrupt = false`.
    #[must_use]
    pub fn deny(message: impl Into<String>) -> PermissionResult {
        denied(message.into(), false)
    }

    /// Denies the tool call with `message`, setting `interrupt = true`.
    #[must_use]
    pub fn deny_and_interrupt(message: impl Into<String>) -> PermissionResult {
        denied(message.into(), true)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that every `hooks` helper produces the expected variant and field values.
    #[test]
    fn test_hook_helpers() {
        match hooks::allow() {
            HookOutput::Sync(out) => {
                assert_eq!(out.continue_, Some(true));
                assert!(out.decision.is_none());
                assert!(out.stop_reason.is_none());
                assert!(out.system_message.is_none());
            }
            _ => panic!("hooks::allow() should produce HookOutput::Sync"),
        }

        match hooks::block("Dangerous operation") {
            HookOutput::Sync(out) => {
                assert_eq!(out.continue_, Some(false));
                assert_eq!(out.decision.as_deref(), Some("block"));
                assert_eq!(out.stop_reason.as_deref(), Some("Dangerous operation"));
            }
            _ => panic!("hooks::block() should produce HookOutput::Sync"),
        }

        match hooks::allow_with_message("heads up") {
            HookOutput::Sync(out) => {
                assert_eq!(out.continue_, Some(true));
                assert_eq!(out.system_message.as_deref(), Some("heads up"));
                assert!(out.decision.is_none());
            }
            _ => panic!("hooks::allow_with_message() should produce HookOutput::Sync"),
        }

        match hooks::defer(Some(5000)) {
            HookOutput::Async(out) => {
                assert!(out.async_);
                assert_eq!(out.async_timeout, Some(5000));
            }
            _ => panic!("hooks::defer() should produce HookOutput::Async"),
        }
    }

    /// Checks that every `permissions` helper produces the expected variant and field values.
    #[test]
    fn test_permission_helpers() {
        match permissions::allow() {
            PermissionResult::Allow(ok) => {
                assert_eq!(ok.behavior, "allow");
                assert!(ok.updated_input.is_none());
            }
            _ => panic!("permissions::allow() should produce PermissionResult::Allow"),
        }

        let replacement = Value::Bool(true);
        match permissions::allow_with_input(replacement.clone()) {
            PermissionResult::Allow(ok) => {
                assert_eq!(ok.behavior, "allow");
                assert_eq!(ok.updated_input, Some(replacement));
            }
            _ => panic!("permissions::allow_with_input() should produce PermissionResult::Allow"),
        }

        match permissions::deny("Not allowed") {
            PermissionResult::Deny(no) => {
                assert_eq!(no.behavior, "deny");
                assert_eq!(no.message, "Not allowed");
                assert!(!no.interrupt);
            }
            _ => panic!("permissions::deny() should produce PermissionResult::Deny"),
        }

        match permissions::deny_and_interrupt("Stop now") {
            PermissionResult::Deny(no) => {
                assert_eq!(no.behavior, "deny");
                assert_eq!(no.message, "Stop now");
                assert!(no.interrupt);
            }
            _ => panic!("permissions::deny_and_interrupt() should produce PermissionResult::Deny"),
        }
    }
}