Skip to main content

cersei_hooks/
lib.rs

//! cersei-hooks: Hook/middleware system for the Cersei SDK.
//!
//! Hooks intercept events in the agent lifecycle (pre/post tool use, model turns, etc.)
//! and can block an operation, modify a tool's input, or inject messages into the
//! conversation.
5
6use async_trait::async_trait;
7use cersei_types::Message;
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10
11// ─── Hook trait ──────────────────────────────────────────────────────────────
12
13#[async_trait]
14pub trait Hook: Send + Sync {
15    /// Which events this hook handles.
16    fn events(&self) -> &[HookEvent];
17
18    /// Called when a matching event fires. Returns an action to control flow.
19    async fn on_event(&self, ctx: &HookContext) -> HookAction;
20
21    /// Optional name for logging/debugging.
22    fn name(&self) -> &str {
23        "unnamed-hook"
24    }
25}
26
27// ─── Hook events ─────────────────────────────────────────────────────────────
28
/// Points in the agent lifecycle at which hooks can be invoked.
///
/// A hook subscribes to events through [`Hook::events`]. Because the variant
/// names are already PascalCase, `rename_all = "PascalCase"` makes each variant
/// (de)serialize as its own identifier, e.g. `"PreToolUse"`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "PascalCase")]
pub enum HookEvent {
    /// Before a tool is used. This is the only event for which
    /// [`HookAction::Block`] and [`HookAction::ModifyInput`] are documented
    /// to apply.
    PreToolUse,
    /// After a tool is used. NOTE(review): presumably the event for which
    /// `HookContext::tool_result` / `tool_is_error` are populated — confirm.
    PostToolUse,
    /// Before a model turn. NOTE(review): exact firing point is not visible
    /// in this crate — confirm against the agent runtime.
    PreModelTurn,
    /// After a model turn. NOTE(review): exact firing point is not visible
    /// in this crate — confirm against the agent runtime.
    PostModelTurn,
    /// NOTE(review): presumably fired when the agent loop stops — confirm
    /// against the runtime that emits events.
    Stop,
    /// NOTE(review): presumably fired when the agent hits an error — confirm
    /// against the runtime that emits events.
    Error,
    /// Fires every N turns (agent-configured cadence, default 10). Used by
    /// the skills-nudge hook to trigger background skill review/creation
    /// without blocking the agent loop. `HookContext::turn` carries the
    /// current turn counter.
    TurnsElapsed,
}
44
45// ─── Hook context ────────────────────────────────────────────────────────────
46
/// Snapshot of agent state handed to [`Hook::on_event`] when an event fires.
#[derive(Debug, Clone)]
pub struct HookContext {
    /// The event being dispatched; [`run_hooks`] matches it against each
    /// hook's [`Hook::events`].
    pub event: HookEvent,
    /// Name of the tool involved, if any. NOTE(review): presumably `None` for
    /// events that do not concern a tool call — confirm.
    pub tool_name: Option<String>,
    /// JSON input of the tool call, if any (the data a PreToolUse hook may
    /// replace via [`HookAction::ModifyInput`]).
    pub tool_input: Option<Value>,
    /// Tool output as text. NOTE(review): presumably only set for
    /// `PostToolUse` — confirm.
    pub tool_result: Option<String>,
    /// Whether the tool call reported an error. NOTE(review): presumably only
    /// set alongside `tool_result` — confirm.
    pub tool_is_error: Option<bool>,
    /// Current turn counter (read by `TurnsElapsed` hooks).
    pub turn: u32,
    /// Cost accumulated so far, in USD per the field name. NOTE(review): scope
    /// (whole session vs. single run) is not visible here — confirm.
    pub cumulative_cost_usd: f64,
    /// Number of messages. NOTE(review): presumably the conversation length at
    /// dispatch time — confirm.
    pub message_count: usize,
}
58
59impl HookContext {
60    pub fn cumulative_cost_usd(&self) -> f64 {
61        self.cumulative_cost_usd
62    }
63}
64
65// ─── Hook actions ────────────────────────────────────────────────────────────
66
/// What a hook asks the caller to do after handling an event.
///
/// [`run_hooks`] returns the first action that is not [`HookAction::Continue`].
///
/// NOTE(review): the "PreToolUse only" restrictions below are documentation,
/// not enforced by the type — e.g. a blocking `ShellHook` returns `Block` for
/// whatever event it is subscribed to. How the runtime treats such actions on
/// other events is not visible here.
#[derive(Debug, Clone)]
pub enum HookAction {
    /// Continue normally; [`run_hooks`] moves on to the next subscribed hook.
    Continue,
    /// Block the operation (PreToolUse only). The string is a human-readable
    /// reason.
    Block(String),
    /// Replace the tool input with modified data (PreToolUse only).
    ModifyInput(Value),
    /// Inject a message into the conversation.
    InjectMessage(Message),
}
78
79// ─── Shell hook (compat with cc-core HookEntry) ──────────────────────────────
80
/// A hook that runs a shell command from settings.json configuration.
///
/// Each time a subscribed event fires, the command is run through the platform
/// shell (`sh -c` on Unix, `cmd /C` on Windows).
pub struct ShellHook {
    /// Shell command line to execute.
    pub command: String,
    /// Events that trigger this hook (returned by [`Hook::events`]).
    pub hook_events: Vec<HookEvent>,
    /// When `true`, a non-zero exit status blocks the operation; otherwise the
    /// failure is only logged.
    pub blocking: bool,
    /// Display name derived from the command and returned by [`Hook::name`].
    /// NOTE(review): computed once in `ShellHook::new`, so it goes stale if the
    /// public `command` field is reassigned afterwards.
    hook_name: String,
}
88
89impl ShellHook {
90    pub fn new(command: impl Into<String>, events: &[HookEvent], blocking: bool) -> Self {
91        let cmd = command.into();
92        let name = format!("shell:{}", cmd.chars().take(40).collect::<String>());
93        Self {
94            command: cmd,
95            hook_events: events.to_vec(),
96            blocking,
97            hook_name: name,
98        }
99    }
100}
101
102#[async_trait]
103impl Hook for ShellHook {
104    fn events(&self) -> &[HookEvent] {
105        &self.hook_events
106    }
107
108    fn name(&self) -> &str {
109        &self.hook_name
110    }
111
112    async fn on_event(&self, ctx: &HookContext) -> HookAction {
113        let sh = if cfg!(windows) { "cmd" } else { "sh" };
114        let flag = if cfg!(windows) { "/C" } else { "-c" };
115
116        let ctx_json = serde_json::to_string(&serde_json::json!({
117            "event": format!("{:?}", ctx.event),
118            "tool_name": ctx.tool_name,
119            "turn": ctx.turn,
120        }))
121        .unwrap_or_default();
122
123        let output = match std::process::Command::new(sh)
124            .args([flag, &self.command])
125            .env("CERSEI_HOOK_CONTEXT", &ctx_json)
126            .stdin(std::process::Stdio::null())
127            .stdout(std::process::Stdio::piped())
128            .stderr(std::process::Stdio::piped())
129            .output()
130        {
131            Ok(o) => o,
132            Err(e) => {
133                tracing::warn!(command = %self.command, error = %e, "Shell hook failed to spawn");
134                return HookAction::Continue;
135            }
136        };
137
138        if output.status.success() {
139            return HookAction::Continue;
140        }
141
142        let stderr = String::from_utf8_lossy(&output.stderr);
143        let stdout = String::from_utf8_lossy(&output.stdout);
144        let body = if !stderr.trim().is_empty() {
145            stderr.to_string()
146        } else {
147            stdout.to_string()
148        };
149
150        if self.blocking {
151            HookAction::Block(format!("Hook '{}' failed: {}", self.command, body.trim()))
152        } else {
153            tracing::warn!(command = %self.command, body = %body.trim(), "Shell hook returned non-zero");
154            HookAction::Continue
155        }
156    }
157}
158
159// ─── Hook runner ─────────────────────────────────────────────────────────────
160
161/// Execute all matching hooks for a given event, returning the first non-Continue action.
162pub async fn run_hooks(hooks: &[std::sync::Arc<dyn Hook>], ctx: &HookContext) -> HookAction {
163    for hook in hooks {
164        if hook.events().contains(&ctx.event) {
165            let action = hook.on_event(ctx).await;
166            match &action {
167                HookAction::Continue => continue,
168                _ => return action,
169            }
170        }
171    }
172    HookAction::Continue
173}