use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use thiserror::Error;
/// Type-level description of one hook event.
///
/// An implementing marker type fixes what a [`Hook`] for that event receives
/// (`Input`, passed by mutable reference) and what it hands back (`Output`).
/// Markers must be thread-safe and `'static` so they can parameterise boxed,
/// shareable hook trait objects.
pub trait HookEventKind
where
    Self: Send + Sync + 'static,
{
    /// Value returned by hooks for this event.
    type Output;
    /// Value hooks for this event are invoked with.
    type Input;
}
// Zero-sized event markers. Each one selects a hook event through its
// `HookEventKind` impl. The derives make the markers printable and trivially
// constructible or copyable, as public types should be, without changing the
// unit-struct syntax callers already use.

/// Event marker: hooks receive a `crate::Message` and return `OnPromptSubmitResult`.
#[derive(Debug, Clone, Copy, Default)]
pub struct OnPromptSubmit;

/// Event marker: hooks receive `Vec<crate::Message>` and return `PreLlmRequestResult`.
#[derive(Debug, Clone, Copy, Default)]
pub struct PreLlmRequest;

/// Event marker: hooks receive a `ToolCallContext` and return `PreToolCallResult`.
#[derive(Debug, Clone, Copy, Default)]
pub struct PreToolCall;

/// Event marker: hooks receive a `PostToolCallContext` and return `PostToolCallResult`.
#[derive(Debug, Clone, Copy, Default)]
pub struct PostToolCall;

/// Event marker: hooks receive `Vec<crate::Message>` and return `OnTurnEndResult`.
#[derive(Debug, Clone, Copy, Default)]
pub struct OnTurnEnd;

/// Event marker: hooks receive a `String` and return `()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct OnAbort;
/// Outcome returned by an [`OnPromptSubmit`] hook.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum OnPromptSubmitResult {
    /// Let the submitted prompt proceed.
    Continue,
    /// Cancel the submission; the string is a reason (how it is surfaced is up
    /// to the caller — NOTE(review): confirm).
    Cancel(String),
}
/// Outcome returned by a [`PreLlmRequest`] hook.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum PreLlmRequestResult {
    /// Allow the request to go ahead.
    Continue,
    /// Cancel the request, carrying a reason string.
    Cancel(String),
}
/// Outcome returned by a [`PreToolCall`] hook.
///
/// NOTE(review): the exact effect of each variant is decided by the dispatcher
/// that consumes it; the descriptions below follow the variant names.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum PreToolCallResult {
    /// Proceed with the tool call.
    Continue,
    /// Do not run this tool call.
    Skip,
    /// Abort, carrying a reason string.
    Abort(String),
    /// Pause before running the call.
    Pause,
}
/// Outcome returned by a [`PostToolCall`] hook.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum PostToolCallResult {
    /// Keep going after the tool call.
    Continue,
    /// Abort, carrying a reason string.
    Abort(String),
}
/// Outcome returned by an [`OnTurnEnd`] hook.
///
/// Only `Debug` and `Clone` are derived. NOTE(review): presumably because
/// `crate::Message` does not implement `PartialEq`; confirm.
#[derive(Clone, Debug)]
pub enum OnTurnEndResult {
    /// End the turn.
    Finish,
    /// Keep going with these additional messages.
    ContinueWithMessages(Vec<crate::Message>),
    /// The turn is paused.
    Paused,
}
use std::sync::Arc;
use crate::tool::{Tool, ToolMeta};
/// Input handed (by `&mut`) to [`PreToolCall`] hooks.
///
/// Bundles the requested call with tool metadata and a shared handle to a tool
/// implementation. NOTE(review): presumably the tool named by `call.name`; confirm
/// where this struct is built.
pub struct ToolCallContext {
/// The requested invocation (id, tool name, JSON input).
pub call: ToolCall,
/// Tool metadata; its contents are defined in `crate::tool`.
pub meta: ToolMeta,
/// Shared, reference-counted handle to the tool implementation.
pub tool: Arc<dyn Tool>,
}
/// Input handed (by `&mut`) to [`PostToolCall`] hooks.
///
/// Same as the pre-call context plus the [`ToolResult`]. Because hooks receive
/// it mutably, they can rewrite `result` in place.
pub struct PostToolCallContext {
/// The invocation that was made.
pub call: ToolCall,
/// The result recorded for `call`.
pub result: ToolResult,
/// Tool metadata; its contents are defined in `crate::tool`.
pub meta: ToolMeta,
/// Shared, reference-counted handle to the tool implementation.
pub tool: Arc<dyn Tool>,
}
// Wiring from each event marker to the value its hooks are called with
// (`Input`) and the value they must return (`Output`).

impl HookEventKind for OnPromptSubmit {
    type Output = OnPromptSubmitResult;
    type Input = crate::Message;
}

impl HookEventKind for PreLlmRequest {
    type Output = PreLlmRequestResult;
    type Input = Vec<crate::Message>;
}

impl HookEventKind for PreToolCall {
    type Output = PreToolCallResult;
    type Input = ToolCallContext;
}

impl HookEventKind for PostToolCall {
    type Output = PostToolCallResult;
    type Input = PostToolCallContext;
}

impl HookEventKind for OnTurnEnd {
    type Output = OnTurnEndResult;
    type Input = Vec<crate::Message>;
}

// Abort hooks get a `String` (presumably the abort reason — confirm) and
// return nothing.
impl HookEventKind for OnAbort {
    type Output = ();
    type Input = String;
}
/// One tool invocation: an id, the tool's name, and its JSON arguments.
///
/// Field order is kept as-is because serde serializes fields in declaration order.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolCall {
    /// Identifier of this call. NOTE(review): presumably provider-assigned; confirm.
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// Arguments for the tool, as an arbitrary JSON value.
    pub input: Value,
}
/// Result of running a tool, keyed by the id of the call it answers.
///
/// Field order is kept as-is because serde serializes fields in declaration order.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolResult {
    /// Id of the call this result belongs to. NOTE(review): presumably `ToolCall::id`.
    pub tool_use_id: String,
    /// Text produced by the tool (or an error description when `is_error`).
    pub content: String,
    /// Whether this result reports a failure; deserializes to `false` when absent.
    #[serde(default)]
    pub is_error: bool,
}
impl ToolResult {
    /// Builds a non-error result (`is_error == false`) for `tool_use_id`.
    pub fn success(tool_use_id: impl Into<String>, content: impl Into<String>) -> Self {
        // Convert in the same order as the struct fields are listed.
        let tool_use_id: String = tool_use_id.into();
        let content: String = content.into();
        Self {
            tool_use_id,
            content,
            is_error: false,
        }
    }

    /// Builds an error result (`is_error == true`) whose `content` describes the failure.
    pub fn error(tool_use_id: impl Into<String>, content: impl Into<String>) -> Self {
        let tool_use_id: String = tool_use_id.into();
        let content: String = content.into();
        Self {
            tool_use_id,
            content,
            is_error: true,
        }
    }
}
#[derive(Debug, Error)]
pub enum HookError {
#[error("Aborted: {0}")]
Aborted(String),
#[error("Hook error: {0}")]
Internal(String),
}
/// An asynchronous handler for the hook event `E`.
///
/// The handler gets mutable access to the event's input, so it may change it
/// in place, and returns the event's output.
#[async_trait]
pub trait Hook<E>: Send + Sync
where
    E: HookEventKind,
{
    /// Runs the hook against `input`.
    ///
    /// # Errors
    ///
    /// Returns [`HookError::Aborted`] or [`HookError::Internal`] on failure.
    async fn call(&self, input: &mut E::Input) -> Result<E::Output, HookError>;
}
/// Collection of registered hooks, grouped by event kind.
///
/// Each field holds the boxed [`Hook`] trait objects for one event. The fields
/// are `pub(crate)`, so the rest of the crate reads and fills them directly.
/// NOTE(review): the order in which hooks run is decided by the code that walks
/// these vectors (presumably insertion order); confirm.
pub struct HookRegistry {
/// Hooks for [`OnPromptSubmit`].
pub(crate) on_prompt_submit: Vec<Box<dyn Hook<OnPromptSubmit>>>,
/// Hooks for [`PreLlmRequest`].
pub(crate) pre_llm_request: Vec<Box<dyn Hook<PreLlmRequest>>>,
/// Hooks for [`PreToolCall`].
pub(crate) pre_tool_call: Vec<Box<dyn Hook<PreToolCall>>>,
/// Hooks for [`PostToolCall`].
pub(crate) post_tool_call: Vec<Box<dyn Hook<PostToolCall>>>,
/// Hooks for [`OnTurnEnd`].
pub(crate) on_turn_end: Vec<Box<dyn Hook<OnTurnEnd>>>,
/// Hooks for [`OnAbort`].
pub(crate) on_abort: Vec<Box<dyn Hook<OnAbort>>>,
}
impl Default for HookRegistry {
fn default() -> Self {
Self::new()
}
}
impl HookRegistry {
    /// Creates a registry with no hooks for any event.
    pub fn new() -> Self {
        Self {
            on_prompt_submit: vec![],
            pre_llm_request: vec![],
            pre_tool_call: vec![],
            post_tool_call: vec![],
            on_turn_end: vec![],
            on_abort: vec![],
        }
    }
}