use std::time::Duration;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
/// The set of points at which a [`Hook`] can be registered.
///
/// Each variant is paired with exactly one [`HookEvent`] variant by
/// [`HookEvent::hook_point`]. With serde the variants are (de)serialized under
/// their camelCase names because of `rename_all = "camelCase"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum HookPoint {
    /// Point reported for [`HookEvent::Inbound`].
    BeforeInbound,
    /// Point reported for [`HookEvent::ToolCall`].
    BeforeToolCall,
    /// Point reported for [`HookEvent::Outbound`].
    BeforeOutbound,
    /// Point reported for [`HookEvent::SessionStart`].
    OnSessionStart,
    /// Point reported for [`HookEvent::SessionEnd`].
    OnSessionEnd,
    /// Point reported for [`HookEvent::ResponseTransform`].
    TransformResponse,
}
impl HookPoint {
pub fn as_str(&self) -> &'static str {
match self {
HookPoint::BeforeInbound => "beforeInbound",
HookPoint::BeforeToolCall => "beforeToolCall",
HookPoint::BeforeOutbound => "beforeOutbound",
HookPoint::OnSessionStart => "onSessionStart",
HookPoint::OnSessionEnd => "onSessionEnd",
HookPoint::TransformResponse => "transformResponse",
}
}
}
/// The payload handed to hooks, one variant per [`HookPoint`].
///
/// [`HookEvent::hook_point`] gives the point a variant belongs to, and
/// [`HookEvent::apply_modification`] defines which field (if any) a hook's
/// modification replaces.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HookEvent {
    /// Event for [`HookPoint::BeforeInbound`].
    Inbound {
        user_id: String,
        channel: String,
        /// Text replaced wholesale by `apply_modification`.
        content: String,
        thread_id: Option<String>,
    },
    /// Event for [`HookPoint::BeforeToolCall`].
    ToolCall {
        tool_name: String,
        /// JSON value replaced by `apply_modification` when the modification
        /// parses as JSON.
        parameters: serde_json::Value,
        user_id: String,
        context: String,
    },
    /// Event for [`HookPoint::BeforeOutbound`].
    Outbound {
        user_id: String,
        channel: String,
        /// Text replaced wholesale by `apply_modification`.
        content: String,
        thread_id: Option<String>,
    },
    /// Event for [`HookPoint::OnSessionStart`]; has no modifiable field.
    SessionStart {
        user_id: String,
        session_id: String,
    },
    /// Event for [`HookPoint::OnSessionEnd`]; has no modifiable field.
    SessionEnd {
        user_id: String,
        session_id: String,
    },
    /// Event for [`HookPoint::TransformResponse`].
    ResponseTransform {
        user_id: String,
        thread_id: String,
        /// Text replaced wholesale by `apply_modification`.
        response: String,
    },
}
impl HookEvent {
    /// Returns the [`HookPoint`] that corresponds to this event's variant.
    pub fn hook_point(&self) -> HookPoint {
        match self {
            Self::Inbound { .. } => HookPoint::BeforeInbound,
            Self::ToolCall { .. } => HookPoint::BeforeToolCall,
            Self::Outbound { .. } => HookPoint::BeforeOutbound,
            Self::SessionStart { .. } => HookPoint::OnSessionStart,
            Self::SessionEnd { .. } => HookPoint::OnSessionEnd,
            Self::ResponseTransform { .. } => HookPoint::TransformResponse,
        }
    }

    /// Overwrites this event's modifiable payload with `modified`.
    ///
    /// * `Inbound` / `Outbound`: `content` becomes a copy of `modified`.
    /// * `ResponseTransform`: `response` becomes a copy of `modified`.
    /// * `ToolCall`: `modified` is parsed as JSON and, on success, replaces
    ///   `parameters`; on a parse failure a warning is logged and the event
    ///   is left untouched.
    /// * `SessionStart` / `SessionEnd`: nothing is changed.
    pub fn apply_modification(&mut self, modified: &str) {
        match self {
            // Session events expose no field to rewrite, so this is a no-op.
            Self::SessionStart { .. } | Self::SessionEnd { .. } => {}
            Self::Inbound { content, .. } | Self::Outbound { content, .. } => {
                *content = modified.to_owned();
            }
            Self::ResponseTransform { response, .. } => {
                *response = modified.to_owned();
            }
            Self::ToolCall { parameters, .. } => {
                let parse_result: Result<serde_json::Value, _> = serde_json::from_str(modified);
                match parse_result {
                    Ok(new_parameters) => *parameters = new_parameters,
                    // Keep the existing parameters rather than replacing them
                    // with something that is not valid JSON.
                    Err(parse_error) => {
                        tracing::warn!(
                            "Hook returned non-JSON modification for ToolCall, ignoring: {}",
                            parse_error
                        );
                    }
                }
            }
        }
    }
}
/// The successful result of running a hook.
///
/// Build values with [`HookOutcome::ok`], [`HookOutcome::modify`] or
/// [`HookOutcome::reject`].
#[derive(Debug, Clone)]
pub enum HookOutcome {
    /// The hook did not reject the event.
    Continue {
        /// Optional replacement value; `None` means no modification.
        /// NOTE(review): presumably fed to `HookEvent::apply_modification` by
        /// the caller — that code is not in this file, confirm.
        modified: Option<String>,
    },
    /// The hook rejected the event.
    Reject {
        /// Explanation supplied by the hook.
        reason: String,
    },
}
impl HookOutcome {
pub fn ok() -> Self {
HookOutcome::Continue { modified: None }
}
pub fn modify(value: String) -> Self {
HookOutcome::Continue {
modified: Some(value),
}
}
pub fn reject(reason: impl Into<String>) -> Self {
HookOutcome::Reject {
reason: reason.into(),
}
}
}
/// Policy a hook declares through [`Hook::failure_mode`].
///
/// With serde the variants are (de)serialized in snake_case (`fail_open`,
/// `fail_closed`).
///
/// NOTE(review): the code that acts on this value is not in this file.
/// Presumably `FailOpen` lets processing continue when the hook errors and
/// `FailClosed` blocks it — confirm against the hook runner.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum HookFailureMode {
    /// Default returned by [`Hook::failure_mode`].
    FailOpen,
    /// Alternative to `FailOpen`; a hook opts in by overriding
    /// [`Hook::failure_mode`].
    FailClosed,
}
/// Error type returned by [`Hook::execute`].
///
/// The `Display` text of each variant comes from its `#[error(...)]`
/// attribute.
#[derive(Debug, thiserror::Error)]
pub enum HookError {
    /// The hook failed while executing; `reason` describes the failure.
    #[error("Hook execution failed: {reason}")]
    ExecutionFailed {
        reason: String,
    },

    /// The hook timed out; `timeout` is shown in the message with `Debug`
    /// formatting.
    #[error("Hook timed out after {timeout:?}")]
    Timeout {
        timeout: Duration,
    },

    /// The hook rejected the event; `reason` is the explanation.
    #[error("Hook rejected: {reason}")]
    Rejected {
        reason: String,
    },
}
pub struct HookContext {
pub metadata: serde_json::Value,
}
impl Default for HookContext {
fn default() -> Self {
Self {
metadata: serde_json::Value::Null,
}
}
}
/// Interface implemented by every hook.
///
/// Implementors must be `Send + Sync`. `execute` is written as an `async fn`,
/// which the `#[async_trait]` attribute makes possible in a trait.
#[async_trait]
pub trait Hook: Send + Sync {
    /// Name of this hook.
    fn name(&self) -> &str;

    /// The hook points this hook declares.
    fn hook_points(&self) -> &[HookPoint];

    /// Failure policy for this hook; defaults to [`HookFailureMode::FailOpen`].
    fn failure_mode(&self) -> HookFailureMode {
        HookFailureMode::FailOpen
    }

    /// Time budget for this hook; defaults to five seconds.
    ///
    /// NOTE(review): enforcement of the timeout is not visible in this file.
    fn timeout(&self) -> Duration {
        Duration::from_secs(5)
    }

    /// Runs the hook against `event` with the supplied context.
    ///
    /// Returns a [`HookOutcome`] on success or a [`HookError`] on failure.
    async fn execute(
        &self,
        event: &HookEvent,
        ctx: &HookContext,
    ) -> Result<HookOutcome, HookError>;
}