use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Points in the agent lifecycle at which hooks are invoked.
///
/// `Hook::on_event` routes each variant to the matching `on_*` callback.
/// When serialized, variants use their `snake_case` names (for example
/// `BeforeToolCall` <-> `"before_tool_call"`).
///
/// Keep the declaration order stable: serde hands each unit variant's index to
/// the serializer, and non-self-describing formats may encode that index.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[derive(Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum HookEvent {
    /// Routed to `Hook::on_before_agent_start`.
    BeforeAgentStart,
    /// Routed to `Hook::on_after_agent_end`.
    AfterAgentEnd,
    /// Routed to `Hook::on_before_agent_step`.
    BeforeAgentStep,
    /// Routed to `Hook::on_after_agent_step`.
    AfterAgentStep,
    /// Routed to `Hook::on_before_tool_call`.
    BeforeToolCall,
    /// Routed to `Hook::on_after_tool_call`.
    AfterToolCall,
    /// Routed to `Hook::on_before_llm_request`.
    BeforeLlmRequest,
    /// Routed to `Hook::on_after_llm_response`.
    AfterLlmResponse,
    /// Routed to `Hook::on_before_knowledge_search`.
    BeforeKnowledgeSearch,
    /// Routed to `Hook::on_after_knowledge_search`.
    AfterKnowledgeSearch,
}
/// Renders the event as its `snake_case` name, e.g. `before_tool_call`.
///
/// The strings are the `snake_case` forms of the variant names. These are the
/// same names `#[serde(rename_all = "snake_case")]` produces, so logs and
/// serialized payloads agree. The two lists are maintained separately, so a new
/// variant must be added here too.
///
/// Output goes through [`std::fmt::Formatter::pad`]. Width, fill, alignment and
/// precision flags (`{:<24}`, `{:.6}`) therefore apply just as they would to a
/// plain `&str`.
impl std::fmt::Display for HookEvent {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            HookEvent::BeforeAgentStart => "before_agent_start",
            HookEvent::AfterAgentEnd => "after_agent_end",
            HookEvent::BeforeAgentStep => "before_agent_step",
            HookEvent::AfterAgentStep => "after_agent_step",
            HookEvent::BeforeToolCall => "before_tool_call",
            HookEvent::AfterToolCall => "after_tool_call",
            HookEvent::BeforeLlmRequest => "before_llm_request",
            HookEvent::AfterLlmResponse => "after_llm_response",
            HookEvent::BeforeKnowledgeSearch => "before_knowledge_search",
            HookEvent::AfterKnowledgeSearch => "after_knowledge_search",
        };
        // `write!(f, "{}", name)` would discard the caller's width/fill/precision
        // flags; `pad` honours them.
        f.pad(name)
    }
}
/// Payload handed to a [`Hook`] when an event is dispatched.
///
/// Usually built with [`HookContext::new`] followed by the chained builder
/// methods (`data`, `correlation_id`, `metadata`). `PartialEq` is derived so
/// contexts can be compared directly, e.g. with `assert_eq!` in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct HookContext {
    /// The lifecycle point this context describes; `Hook::on_event` dispatches on it.
    pub event: HookEvent,
    /// Event-specific JSON payload; `HookContext::new` initialises it to `Null`.
    pub data: serde_json::Value,
    /// Optional caller-supplied identifier; presumably used to tie related
    /// events together (e.g. a before/after pair) — confirm with callers.
    pub correlation_id: Option<String>,
    /// Free-form extra key/value pairs, filled via the `metadata` builder.
    pub metadata: HashMap<String, serde_json::Value>,
}
impl HookContext {
    /// Creates a context for `event` with a `Null` payload, no correlation id
    /// and empty metadata. Fill it in with the chained builder methods.
    #[must_use]
    pub fn new(event: HookEvent) -> Self {
        Self {
            event,
            data: serde_json::Value::Null,
            correlation_id: None,
            metadata: HashMap::new(),
        }
    }

    /// Replaces the event payload and returns the updated context.
    #[must_use]
    pub fn data(mut self, data: serde_json::Value) -> Self {
        self.data = data;
        self
    }

    /// Sets the correlation id and returns the updated context.
    #[must_use]
    pub fn correlation_id(mut self, id: impl Into<String>) -> Self {
        self.correlation_id = Some(id.into());
        self
    }

    /// Inserts one metadata entry and returns the updated context.
    ///
    /// Any value previously stored under `key` is overwritten.
    #[must_use]
    pub fn metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        self.metadata.insert(key.into(), value);
        self
    }

    /// Looks up the raw payload field `key`.
    ///
    /// Returns `None` when the payload is not a JSON object (for example the
    /// default `Null`) or has no such key.
    pub fn get(&self, key: &str) -> Option<&serde_json::Value> {
        self.data.get(key)
    }

    /// Returns payload field `key` if it is a JSON string.
    pub fn get_str(&self, key: &str) -> Option<&str> {
        self.get(key).and_then(serde_json::Value::as_str)
    }

    /// Returns payload field `key` if it is an integer that fits in `i64`.
    pub fn get_i64(&self, key: &str) -> Option<i64> {
        self.get(key).and_then(serde_json::Value::as_i64)
    }

    /// Returns payload field `key` if it is an integer that fits in `u64`.
    pub fn get_u64(&self, key: &str) -> Option<u64> {
        self.get(key).and_then(serde_json::Value::as_u64)
    }

    /// Returns payload field `key` as `f64` if it is a JSON number
    /// (see `serde_json::Value::as_f64` for the conversion rules).
    pub fn get_f64(&self, key: &str) -> Option<f64> {
        self.get(key).and_then(serde_json::Value::as_f64)
    }

    /// Returns payload field `key` if it is a JSON boolean.
    pub fn get_bool(&self, key: &str) -> Option<bool> {
        self.get(key).and_then(serde_json::Value::as_bool)
    }
}
/// Outcome a hook callback returns for the event it observed.
///
/// The variants fall into three groups, as classified by `should_continue`,
/// `should_skip` and `is_abort`. What "skip" and "abort" do is decided by the
/// code that dispatches hooks (not visible here). `PartialEq` is derived so
/// results can be compared directly, e.g. with `assert_eq!`.
#[derive(Debug, Clone, Default, PartialEq)]
pub enum HookResult {
    /// Proceed normally. This is the `Default` value.
    #[default]
    Continue,
    /// Proceed, carrying a JSON value exposed through `modified_data`
    /// (presumably a replacement payload — confirm with the dispatcher).
    ContinueWith(serde_json::Value),
    /// Classified as a skip by `should_skip`.
    Skip,
    /// Classified as a skip, carrying a JSON value exposed through `modified_data`.
    SkipWith(serde_json::Value),
    /// Abort, with a message exposed through `error_message`.
    Abort(String),
}
impl HookResult {
    /// `true` for [`HookResult::Continue`] and [`HookResult::ContinueWith`].
    pub fn should_continue(&self) -> bool {
        matches!(self, Self::Continue | Self::ContinueWith(..))
    }

    /// `true` for [`HookResult::Skip`] and [`HookResult::SkipWith`].
    pub fn should_skip(&self) -> bool {
        matches!(self, Self::Skip | Self::SkipWith(..))
    }

    /// `true` only for [`HookResult::Abort`].
    pub fn is_abort(&self) -> bool {
        matches!(self, Self::Abort(..))
    }

    /// Borrows the JSON value carried by `ContinueWith` or `SkipWith`;
    /// `None` for every other variant.
    pub fn modified_data(&self) -> Option<&serde_json::Value> {
        if let Self::ContinueWith(payload) | Self::SkipWith(payload) = self {
            Some(payload)
        } else {
            None
        }
    }

    /// Borrows the message carried by `Abort`; `None` for every other variant.
    pub fn error_message(&self) -> Option<&str> {
        if let Self::Abort(reason) = self {
            Some(reason.as_str())
        } else {
            None
        }
    }
}
pub trait Hook: Send + Sync {
fn name(&self) -> &str;
fn priority(&self) -> i32 {
100
}
fn on_before_agent_start(&self, _ctx: &HookContext) -> HookResult {
HookResult::Continue
}
fn on_after_agent_end(&self, _ctx: &HookContext) -> HookResult {
HookResult::Continue
}
fn on_before_agent_step(&self, _ctx: &HookContext) -> HookResult {
HookResult::Continue
}
fn on_after_agent_step(&self, _ctx: &HookContext) -> HookResult {
HookResult::Continue
}
fn on_before_tool_call(&self, _ctx: &HookContext) -> HookResult {
HookResult::Continue
}
fn on_after_tool_call(&self, _ctx: &HookContext) -> HookResult {
HookResult::Continue
}
fn on_before_llm_request(&self, _ctx: &HookContext) -> HookResult {
HookResult::Continue
}
fn on_after_llm_response(&self, _ctx: &HookContext) -> HookResult {
HookResult::Continue
}
fn on_before_knowledge_search(&self, _ctx: &HookContext) -> HookResult {
HookResult::Continue
}
fn on_after_knowledge_search(&self, _ctx: &HookContext) -> HookResult {
HookResult::Continue
}
fn on_event(&self, ctx: &HookContext) -> HookResult {
match ctx.event {
HookEvent::BeforeAgentStart => self.on_before_agent_start(ctx),
HookEvent::AfterAgentEnd => self.on_after_agent_end(ctx),
HookEvent::BeforeAgentStep => self.on_before_agent_step(ctx),
HookEvent::AfterAgentStep => self.on_after_agent_step(ctx),
HookEvent::BeforeToolCall => self.on_before_tool_call(ctx),
HookEvent::AfterToolCall => self.on_after_tool_call(ctx),
HookEvent::BeforeLlmRequest => self.on_before_llm_request(ctx),
HookEvent::AfterLlmResponse => self.on_after_llm_response(ctx),
HookEvent::BeforeKnowledgeSearch => self.on_before_knowledge_search(ctx),
HookEvent::AfterKnowledgeSearch => self.on_after_knowledge_search(ctx),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `HookEvent` variant, for checks that must hold per event.
    const ALL_EVENTS: [HookEvent; 10] = [
        HookEvent::BeforeAgentStart,
        HookEvent::AfterAgentEnd,
        HookEvent::BeforeAgentStep,
        HookEvent::AfterAgentStep,
        HookEvent::BeforeToolCall,
        HookEvent::AfterToolCall,
        HookEvent::BeforeLlmRequest,
        HookEvent::AfterLlmResponse,
        HookEvent::BeforeKnowledgeSearch,
        HookEvent::AfterKnowledgeSearch,
    ];

    /// Skips `before_tool_call` events whose payload names the "blocked" tool.
    struct TestHook;

    impl Hook for TestHook {
        fn name(&self) -> &str {
            "test"
        }

        fn on_before_tool_call(&self, ctx: &HookContext) -> HookResult {
            if ctx.get_str("tool") == Some("blocked") {
                HookResult::Skip
            } else {
                HookResult::Continue
            }
        }
    }

    /// Overrides nothing, so every call exercises the trait's default methods.
    struct NoopHook;

    impl Hook for NoopHook {
        fn name(&self) -> &str {
            "noop"
        }
    }

    #[test]
    fn test_hook_context() {
        let ctx = HookContext::new(HookEvent::BeforeToolCall)
            .data(serde_json::json!({"tool": "calculator"}))
            .correlation_id("corr-123");
        assert_eq!(ctx.event, HookEvent::BeforeToolCall);
        assert_eq!(ctx.get_str("tool"), Some("calculator"));
        assert_eq!(ctx.correlation_id, Some("corr-123".to_string()));
    }

    #[test]
    fn test_hook_context_defaults() {
        let ctx = HookContext::new(HookEvent::AfterToolCall);
        assert!(ctx.data.is_null());
        assert_eq!(ctx.correlation_id, None);
        assert!(ctx.metadata.is_empty());
        // A `Null` payload has no fields, so typed lookups find nothing.
        assert_eq!(ctx.get_str("tool"), None);
        assert_eq!(ctx.get_i64("tool"), None);
    }

    #[test]
    fn test_hook_context_accessors_and_metadata() {
        let ctx = HookContext::new(HookEvent::BeforeLlmRequest)
            .data(serde_json::json!({"tokens": 42, "model": "m1"}))
            .metadata("attempt", serde_json::json!(1))
            .metadata("attempt", serde_json::json!(2));
        assert_eq!(ctx.get_i64("tokens"), Some(42));
        assert_eq!(ctx.get_str("model"), Some("m1"));
        // Typed getters do not coerce between JSON types.
        assert_eq!(ctx.get_i64("model"), None);
        assert_eq!(ctx.get_str("tokens"), None);
        // Re-inserting a metadata key overwrites the earlier value.
        assert_eq!(ctx.metadata.len(), 1);
        assert_eq!(ctx.metadata.get("attempt"), Some(&serde_json::json!(2)));
    }

    #[test]
    fn test_hook_result() {
        assert!(HookResult::Continue.should_continue());
        assert!(HookResult::ContinueWith(serde_json::json!({})).should_continue());
        assert!(HookResult::Skip.should_skip());
        assert!(HookResult::Abort("error".to_string()).is_abort());
        // The continue / skip / abort categories are mutually exclusive.
        assert!(!HookResult::Skip.should_continue());
        assert!(HookResult::SkipWith(serde_json::json!(null)).should_skip());
        assert!(!HookResult::SkipWith(serde_json::json!(null)).should_continue());
        assert!(!HookResult::Abort("error".to_string()).should_continue());
        assert!(!HookResult::Abort("error".to_string()).should_skip());
        assert!(HookResult::default().should_continue());
    }

    #[test]
    fn test_hook_result_payloads() {
        let payload = serde_json::json!({"k": "v"});
        assert_eq!(
            HookResult::ContinueWith(payload.clone()).modified_data(),
            Some(&payload)
        );
        assert_eq!(HookResult::SkipWith(payload.clone()).modified_data(), Some(&payload));
        assert_eq!(HookResult::Continue.modified_data(), None);
        assert_eq!(HookResult::Skip.modified_data(), None);
        assert_eq!(
            HookResult::Abort("boom".to_string()).error_message(),
            Some("boom")
        );
        assert_eq!(HookResult::Skip.error_message(), None);
    }

    #[test]
    fn test_display_matches_serde_names() {
        // `Display` is hand-written while serde derives its names; they must agree.
        for event in ALL_EVENTS {
            let serialized = serde_json::to_value(event).expect("HookEvent serializes");
            assert_eq!(serialized, serde_json::Value::String(event.to_string()));
            let parsed: HookEvent =
                serde_json::from_value(serialized).expect("HookEvent round-trips");
            assert_eq!(parsed, event);
        }
    }

    #[test]
    fn test_hook_dispatch() {
        let hook = TestHook;
        let ctx = HookContext::new(HookEvent::BeforeToolCall)
            .data(serde_json::json!({"tool": "calculator"}));
        assert!(hook.on_event(&ctx).should_continue());
        let ctx = HookContext::new(HookEvent::BeforeToolCall)
            .data(serde_json::json!({"tool": "blocked"}));
        assert!(hook.on_event(&ctx).should_skip());
    }

    #[test]
    fn test_dispatch_ignores_other_events() {
        // `TestHook` only overrides `on_before_tool_call`, so the same payload on
        // a different event must fall through to the default `Continue`.
        let ctx = HookContext::new(HookEvent::AfterToolCall)
            .data(serde_json::json!({"tool": "blocked"}));
        assert!(TestHook.on_event(&ctx).should_continue());
    }

    #[test]
    fn test_default_hook_methods() {
        let hook = NoopHook;
        assert_eq!(hook.name(), "noop");
        assert_eq!(hook.priority(), 100);
        for event in ALL_EVENTS {
            assert!(hook.on_event(&HookContext::new(event)).should_continue());
        }
    }
}