1use async_trait::async_trait;
7use cersei_types::Message;
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10
11#[async_trait]
14pub trait Hook: Send + Sync {
15 fn events(&self) -> &[HookEvent];
17
18 async fn on_event(&self, ctx: &HookContext) -> HookAction;
20
21 fn name(&self) -> &str {
23 "unnamed-hook"
24 }
25}
26
/// Points in the agent lifecycle at which [`Hook`]s can be invoked.
///
/// Serialized via serde with PascalCase names, which coincide with the variant
/// identifiers (e.g. `"PreToolUse"`).
///
/// NOTE(review): where each event is actually emitted is decided by the callers of
/// `run_hooks`, which are not visible here; the per-variant notes below are inferred
/// from the names and should be confirmed.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "PascalCase")]
pub enum HookEvent {
    /// Presumably fired before a tool call executes.
    PreToolUse,
    /// Presumably fired after a tool call completes.
    PostToolUse,
    /// Presumably fired before a model turn is requested.
    PreModelTurn,
    /// Presumably fired after a model turn completes.
    PostModelTurn,
    /// Presumably fired when the run stops.
    Stop,
    /// Presumably fired when an error occurs.
    Error,
    /// Presumably fired once some number of turns has elapsed — confirm the trigger.
    TurnsElapsed,
}
44
/// Snapshot of run state handed to [`Hook::on_event`].
///
/// Note that `ShellHook` forwards only `event`, `tool_name` and `turn` to its command;
/// the remaining fields are available to in-process hooks.
#[derive(Debug, Clone)]
pub struct HookContext {
    /// The event being dispatched; `run_hooks` only invokes hooks whose
    /// [`Hook::events`] contain this value.
    pub event: HookEvent,
    /// Name of the tool involved, if any (presumably only set for tool events — confirm).
    pub tool_name: Option<String>,
    /// Tool input as JSON, if any (presumably only set for tool events — confirm).
    pub tool_input: Option<Value>,
    /// Tool output text, if any (presumably only set once the tool has run — confirm).
    pub tool_result: Option<String>,
    /// Whether the tool reported an error, if known (presumably post-tool only — confirm).
    pub tool_is_error: Option<bool>,
    /// Turn counter for the current run (numbering base not visible here — confirm).
    pub turn: u32,
    /// Running cost total; the field name indicates the unit is US dollars.
    pub cumulative_cost_usd: f64,
    /// Message count at dispatch time (presumably conversation length — confirm).
    pub message_count: usize,
}
58
59impl HookContext {
60 pub fn cumulative_cost_usd(&self) -> f64 {
61 self.cumulative_cost_usd
62 }
63}
64
/// Outcome returned by a [`Hook`] for one event.
///
/// `run_hooks` treats every variant other than `Continue` as decisive: it stops
/// calling further hooks and returns that action to its caller.
#[derive(Debug, Clone)]
pub enum HookAction {
    /// Proceed normally; `run_hooks` moves on to the next subscribed hook.
    Continue,
    /// Veto carrying a human-readable reason (`ShellHook` returns this when a blocking
    /// command exits non-zero). How the veto is enforced is up to the caller — confirm.
    Block(String),
    /// Replacement JSON value (presumably substitutes the pending tool input — confirm
    /// with the caller of `run_hooks`).
    ModifyInput(Value),
    /// A message to hand back to the caller (presumably injected into the conversation —
    /// confirm with the caller of `run_hooks`).
    InjectMessage(Message),
}
78
/// A [`Hook`] that runs an external command through the platform shell
/// (`sh -c`, or `cmd /C` on Windows).
pub struct ShellHook {
    /// Command line passed to the shell as a single argument.
    pub command: String,
    /// Events this hook subscribes to (returned by [`Hook::events`]).
    pub hook_events: Vec<HookEvent>,
    /// If true, a non-zero exit yields [`HookAction::Block`]; otherwise the failure is
    /// only logged and the run continues.
    pub blocking: bool,
    // Cached display name ("shell:" + first 40 chars of `command`), computed once in
    // `ShellHook::new`. NOTE(review): it is not refreshed if the public `command`
    // field is reassigned afterwards.
    hook_name: String,
}
88
89impl ShellHook {
90 pub fn new(command: impl Into<String>, events: &[HookEvent], blocking: bool) -> Self {
91 let cmd = command.into();
92 let name = format!("shell:{}", cmd.chars().take(40).collect::<String>());
93 Self {
94 command: cmd,
95 hook_events: events.to_vec(),
96 blocking,
97 hook_name: name,
98 }
99 }
100}
101
102#[async_trait]
103impl Hook for ShellHook {
104 fn events(&self) -> &[HookEvent] {
105 &self.hook_events
106 }
107
108 fn name(&self) -> &str {
109 &self.hook_name
110 }
111
112 async fn on_event(&self, ctx: &HookContext) -> HookAction {
113 let sh = if cfg!(windows) { "cmd" } else { "sh" };
114 let flag = if cfg!(windows) { "/C" } else { "-c" };
115
116 let ctx_json = serde_json::to_string(&serde_json::json!({
117 "event": format!("{:?}", ctx.event),
118 "tool_name": ctx.tool_name,
119 "turn": ctx.turn,
120 }))
121 .unwrap_or_default();
122
123 let output = match std::process::Command::new(sh)
124 .args([flag, &self.command])
125 .env("CERSEI_HOOK_CONTEXT", &ctx_json)
126 .stdin(std::process::Stdio::null())
127 .stdout(std::process::Stdio::piped())
128 .stderr(std::process::Stdio::piped())
129 .output()
130 {
131 Ok(o) => o,
132 Err(e) => {
133 tracing::warn!(command = %self.command, error = %e, "Shell hook failed to spawn");
134 return HookAction::Continue;
135 }
136 };
137
138 if output.status.success() {
139 return HookAction::Continue;
140 }
141
142 let stderr = String::from_utf8_lossy(&output.stderr);
143 let stdout = String::from_utf8_lossy(&output.stdout);
144 let body = if !stderr.trim().is_empty() {
145 stderr.to_string()
146 } else {
147 stdout.to_string()
148 };
149
150 if self.blocking {
151 HookAction::Block(format!("Hook '{}' failed: {}", self.command, body.trim()))
152 } else {
153 tracing::warn!(command = %self.command, body = %body.trim(), "Shell hook returned non-zero");
154 HookAction::Continue
155 }
156 }
157}
158
159pub async fn run_hooks(hooks: &[std::sync::Arc<dyn Hook>], ctx: &HookContext) -> HookAction {
163 for hook in hooks {
164 if hook.events().contains(&ctx.event) {
165 let action = hook.on_event(ctx).await;
166 match &action {
167 HookAction::Continue => continue,
168 _ => return action,
169 }
170 }
171 }
172 HookAction::Continue
173}