mod builtin;
mod registry;
pub use builtin::{LoggingHook, SecurityHook};
pub use registry::HookRegistry;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Outcome a hook returns to tell its caller how to proceed.
///
/// Only [`HookResult::Continue`] and [`HookResult::ContinueWith`] count as
/// "continue" according to [`HookResult::should_continue`]; the other three
/// variants each carry a human-readable `reason`.
// NOTE(review): how Skip / Deny / Abort are treated differently is decided by
// the code that dispatches hooks (not visible in this file) — confirm there.
// Variant order is kept stable: non-self-describing serde formats encode the
// variant index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HookResult {
/// Proceed with no extra data.
Continue,
/// Proceed, carrying an arbitrary JSON value back to the caller.
// NOTE(review): whether the caller uses this value to replace/modify the
// hooked input is not established here — confirm against the registry.
ContinueWith(serde_json::Value),
/// Do not proceed; the hook asks for the operation to be skipped.
Skip { reason: String },
/// Do not proceed; the hook refuses the operation (see `is_denied`).
Deny { reason: String },
/// Do not proceed; the hook requests an abort (see `is_aborted`).
Abort { reason: String },
}
impl HookResult {
pub fn should_continue(&self) -> bool {
matches!(self, HookResult::Continue | HookResult::ContinueWith(_))
}
pub fn is_denied(&self) -> bool {
matches!(self, HookResult::Deny { .. })
}
pub fn is_aborted(&self) -> bool {
matches!(self, HookResult::Abort { .. })
}
}
/// Context handed to every hook method alongside its input payload.
///
/// Construct with [`HookContext::new`] and refine with the `with_*` builder
/// methods; all fields are also public for direct access.
#[derive(Debug, Clone)]
pub struct HookContext {
/// Identifier of the agent this context belongs to (always present).
pub agent_id: String,
/// Optional session identifier; `None` unless set via `with_session`.
pub session_id: Option<String>,
/// Optional task identifier; `None` unless set via `with_task`.
pub task_id: Option<String>,
/// Optional working directory, kept as a plain string rather than a path type.
pub working_directory: Option<String>,
/// Free-form extra data keyed by string; empty unless populated via `with_metadata`.
pub metadata: HashMap<String, serde_json::Value>,
}
impl HookContext {
    /// Creates a context for `agent_id` with no session, task, or working
    /// directory, and an empty metadata map.
    #[must_use]
    pub fn new(agent_id: impl Into<String>) -> Self {
        Self {
            agent_id: agent_id.into(),
            session_id: None,
            task_id: None,
            working_directory: None,
            metadata: HashMap::new(),
        }
    }

    /// Sets the session identifier, replacing any previous one.
    ///
    /// Like every `with_*` method this consumes `self` and returns the updated
    /// context, so ignoring the return value silently loses the context —
    /// hence `#[must_use]`.
    #[must_use = "builder methods consume the context and return the updated one"]
    pub fn with_session(mut self, session_id: impl Into<String>) -> Self {
        self.session_id = Some(session_id.into());
        self
    }

    /// Sets the task identifier, replacing any previous one.
    #[must_use = "builder methods consume the context and return the updated one"]
    pub fn with_task(mut self, task_id: impl Into<String>) -> Self {
        self.task_id = Some(task_id.into());
        self
    }

    /// Sets the working directory, replacing any previous one.
    #[must_use = "builder methods consume the context and return the updated one"]
    pub fn with_working_directory(mut self, dir: impl Into<String>) -> Self {
        self.working_directory = Some(dir.into());
        self
    }

    /// Inserts `value` under `key` in the metadata map.
    ///
    /// An existing entry with the same key is overwritten (plain
    /// `HashMap::insert` semantics).
    #[must_use = "builder methods consume the context and return the updated one"]
    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        self.metadata.insert(key.into(), value);
        self
    }
}
/// Payload passed to [`ExecutionHooks::pre_execution`].
// NOTE(review): `task_type` and `priority` are free-form strings; the set of
// accepted values is not defined in this file — confirm with producers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PreExecutionInput {
/// Human-readable description of the task.
pub task_description: String,
/// Task category, as a string.
pub task_type: String,
/// Task priority, as a string.
pub priority: String,
/// Optional additional detail text.
pub details: Option<String>,
}
/// Payload passed to [`ExecutionHooks::post_execution`].
// NOTE(review): `success` and `error` are independent fields, so nothing here
// prevents e.g. `success == true` together with `error: Some(..)`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PostExecutionInput {
/// Description of the task that ran.
pub task_description: String,
/// Whether the task reported success.
pub success: bool,
/// Task output as arbitrary JSON.
pub output: serde_json::Value,
/// Optional error message.
pub error: Option<String>,
/// Elapsed time; presumably milliseconds per the `_ms` suffix — confirm with the producer.
pub duration_ms: u64,
}
/// Payload passed to [`ExecutionHooks::on_error`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnErrorInput {
/// Error message text.
pub error_message: String,
/// Error category, as a free-form string.
pub error_type: String,
/// Whether the producer flagged the error as recoverable.
pub is_recoverable: bool,
/// Optional stack trace text.
pub stack_trace: Option<String>,
}
/// Payload passed to [`ToolHooks::pre_tool_use`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PreToolUseInput {
/// Name of the tool about to be used.
pub tool_name: String,
/// Tool arguments as arbitrary JSON.
pub arguments: serde_json::Value,
/// Optional human-readable description of the call.
pub description: Option<String>,
}
/// Payload passed to [`ToolHooks::post_tool_use`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PostToolUseInput {
/// Name of the tool that was used.
pub tool_name: String,
/// Arguments the tool was called with, as arbitrary JSON.
pub arguments: serde_json::Value,
/// Whether the tool call reported success.
pub success: bool,
/// Tool result as arbitrary JSON.
pub result: serde_json::Value,
/// Optional error message.
pub error: Option<String>,
/// Elapsed time; presumably milliseconds per the `_ms` suffix — confirm with the producer.
pub duration_ms: u64,
}
/// Hooks around task execution: before, after, and on error.
///
/// Each async method receives its payload plus a [`HookContext`] by value and
/// returns a [`HookResult`]. Implementors must be `Send + Sync`.
#[async_trait]
pub trait ExecutionHooks: Send + Sync {
    /// Identifier of this hook implementation.
    fn name(&self) -> &str;

    /// Ordering key for this hook; `0` unless overridden.
    ///
    /// NOTE(review): whether larger values run earlier or later is decided by
    /// the dispatcher, not here — confirm before relying on it.
    fn priority(&self) -> i32 {
        0
    }

    /// Pre-execution hook point, given the task description, type, priority and details.
    async fn pre_execution(&self, input: PreExecutionInput, ctx: HookContext) -> HookResult;

    /// Post-execution hook point, given the task outcome, output, error and duration.
    async fn post_execution(&self, input: PostExecutionInput, ctx: HookContext) -> HookResult;

    /// Error hook point, given the error message, type, recoverability and stack trace.
    async fn on_error(&self, input: OnErrorInput, ctx: HookContext) -> HookResult;
}
/// Hooks around individual tool invocations: before and after use.
///
/// Each async method receives its payload plus a [`HookContext`] by value and
/// returns a [`HookResult`]. Implementors must be `Send + Sync`.
#[async_trait]
pub trait ToolHooks: Send + Sync {
    /// Identifier of this hook implementation.
    fn name(&self) -> &str;

    /// Ordering key for this hook; `0` unless overridden.
    ///
    /// NOTE(review): whether larger values run earlier or later is decided by
    /// the dispatcher, not here — confirm before relying on it.
    fn priority(&self) -> i32 {
        0
    }

    /// Pre-tool-use hook point, given the tool name, arguments and description.
    async fn pre_tool_use(&self, input: PreToolUseInput, ctx: HookContext) -> HookResult;

    /// Post-tool-use hook point, given the tool outcome, result, error and duration.
    async fn post_tool_use(&self, input: PostToolUseInput, ctx: HookContext) -> HookResult;
}
/// Convenience bound for types that implement both [`ExecutionHooks`] and [`ToolHooks`].
///
/// Every such type gets this trait automatically through the blanket impl
/// below, so it never needs a manual impl. Both supertraits declare `name` and
/// `priority`, so method calls through this bound are ambiguous and must be
/// disambiguated, e.g. `ExecutionHooks::name(&hook)` or `ToolHooks::priority(&hook)`.
// No `#[async_trait]` here: the trait declares no methods of its own, so the
// macro has nothing to rewrite.
pub trait AllHooks: ExecutionHooks + ToolHooks {}

impl<T: ExecutionHooks + ToolHooks> AllHooks for T {}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hook_result_checks() {
        let skip = HookResult::Skip {
            reason: "test".into(),
        };
        let deny = HookResult::Deny {
            reason: "test".into(),
        };
        let abort = HookResult::Abort {
            reason: "test".into(),
        };

        assert!(HookResult::Continue.should_continue());
        assert!(HookResult::ContinueWith(serde_json::json!({})).should_continue());
        // Every stopping variant must refuse to continue, not just `Skip`.
        assert!(!skip.should_continue());
        assert!(!deny.should_continue());
        assert!(!abort.should_continue());

        assert!(deny.is_denied());
        assert!(abort.is_aborted());
        // The predicates must not overlap with each other or with the continue variants.
        assert!(!skip.is_denied() && !skip.is_aborted());
        assert!(!deny.is_aborted());
        assert!(!abort.is_denied());
        assert!(!HookResult::Continue.is_denied() && !HookResult::Continue.is_aborted());
    }

    #[test]
    fn test_hook_result_serde_round_trip() {
        let original = HookResult::Deny {
            reason: "blocked".into(),
        };
        let json = serde_json::to_value(&original).expect("HookResult serializes to JSON");
        let restored: HookResult =
            serde_json::from_value(json).expect("HookResult deserializes from its own JSON");
        match restored {
            HookResult::Deny { reason } => assert_eq!(reason, "blocked"),
            other => panic!("expected Deny, got {:?}", other),
        }
    }

    #[test]
    fn test_hook_context_defaults() {
        let ctx = HookContext::new("agent-1");
        assert_eq!(ctx.agent_id, "agent-1");
        assert!(ctx.session_id.is_none());
        assert!(ctx.task_id.is_none());
        assert!(ctx.working_directory.is_none());
        assert!(ctx.metadata.is_empty());
    }

    #[test]
    fn test_hook_context_builder() {
        let ctx = HookContext::new("agent-1")
            .with_session("session-1")
            .with_task("task-1")
            .with_working_directory("/tmp")
            .with_metadata("key", serde_json::json!("value"));
        assert_eq!(ctx.agent_id, "agent-1");
        assert_eq!(ctx.session_id, Some("session-1".to_string()));
        assert_eq!(ctx.task_id, Some("task-1".to_string()));
        assert_eq!(ctx.working_directory, Some("/tmp".to_string()));
        assert_eq!(ctx.metadata.get("key"), Some(&serde_json::json!("value")));
    }

    #[test]
    fn test_hook_context_metadata_overwrites_same_key() {
        let ctx = HookContext::new("agent-1")
            .with_metadata("key", serde_json::json!(1))
            .with_metadata("key", serde_json::json!(2));
        assert_eq!(ctx.metadata.len(), 1);
        assert_eq!(ctx.metadata.get("key"), Some(&serde_json::json!(2)));
    }
}