Skip to main content

codex_cli_sdk/
hooks.rs

1//! Hook system for observing and reacting to SDK events.
2//!
3//! Hooks fire on stream events **after** the Codex CLI has already acted.
4//! `Block`/`Abort` affect what the SDK consumer sees, not what the CLI executes.
5//! Pre-execution gating requires [`ApprovalPolicy::OnRequest`](crate::ApprovalPolicy)
6//! combined with an [`ApprovalCallback`](crate::ApprovalCallback).
7//!
8//! # Design
9//!
//! - **First-match dispatch**: hooks are evaluated in order; the first matching hook whose
//!   callback finishes in time handles the event.
//! - **Configurable timeout behavior**: if a hook callback exceeds its timeout, the behavior is
//!   controlled per-hook via [`HookTimeoutBehavior`]. The default, `FailOpen`, skips the timed-out
//!   hook and tries the next matching hook (the event passes through if none remains).
//!   Use `FailClosed` for security/gating hooks where a timeout must not silently allow the event.
14//! - **Async callbacks**: hooks can perform I/O (logging, webhooks, etc.).
15//!
16//! # Processing order
17//!
18//! Hooks run **before** the [`EventCallback`](crate::callback::EventCallback).
19//! For each event the pipeline is:
20//!
21//! 1. Hooks (async, semantically classified, first-match)
22//! 2. `EventCallback` (sync, raw event, observe/filter/transform)
23//!
24//! If a hook returns [`HookDecision::Block`] or [`HookDecision::Abort`], the
25//! `EventCallback` is **not** invoked for that event.
26//!
27//! # Example
28//!
29//! ```rust
30//! use std::sync::Arc;
31//! use std::time::Duration;
32//! use codex_cli_sdk::hooks::{HookMatcher, HookEvent, HookDecision, HookOutput, HookTimeoutBehavior};
33//!
34//! let hook = HookMatcher {
35//!     event: HookEvent::CommandStarted,
36//!     command_filter: Some("rm".into()),
37//!     callback: Arc::new(|input, _ctx| {
38//!         Box::pin(async move {
39//!             eprintln!("Blocked dangerous command: {:?}", input.command);
40//!             HookOutput {
41//!                 decision: HookDecision::Block,
42//!                 reason: Some("dangerous command".into()),
43//!                 replacement_event: None,
44//!             }
45//!         })
46//!     }),
47//!     timeout: Some(Duration::from_secs(5)),
48//!     on_timeout: HookTimeoutBehavior::FailClosed, // block on timeout for security hooks
49//! };
50//! ```
51
52use crate::types::events::ThreadEvent;
53use crate::types::items::ThreadItem;
54use std::future::Future;
55use std::pin::Pin;
56use std::sync::Arc;
57use std::time::Duration;
58
59// ── Hook event classification ─────────────────────────────────
60
/// The semantic event type that hooks match against.
///
/// Produced from a raw [`ThreadEvent`] by [`classify_hook_event`]. The enum is
/// fieldless, so it is `Copy` and can be passed around by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HookEvent {
    /// A command execution item started (shell command about to run or running).
    CommandStarted,
    /// A command execution item finished successfully
    /// (see [`classify_hook_event`] for the exact rule).
    CommandCompleted,
    /// A command execution item finished unsuccessfully
    /// (see [`classify_hook_event`] for the exact rule).
    CommandFailed,
    /// A file change item completed.
    FileChangeCompleted,
    /// An agent message item completed.
    AgentMessage,
    /// The turn completed successfully.
    TurnCompleted,
    /// The turn failed.
    TurnFailed,
}
79
80// ── Hook callback types ───────────────────────────────────────
81
82/// Input provided to a hook callback.
83#[derive(Debug, Clone)]
84pub struct HookInput {
85    /// The classified hook event.
86    pub hook_event: HookEvent,
87    /// The shell command (if applicable).
88    pub command: Option<String>,
89    /// The exit code (if applicable).
90    pub exit_code: Option<i32>,
91    /// The message text (if applicable, e.g. agent message or error).
92    pub message_text: Option<String>,
93    /// The raw event that triggered this hook.
94    pub raw_event: ThreadEvent,
95}
96
/// Thread-level context passed to a hook callback alongside its [`HookInput`].
#[derive(Clone, Debug)]
pub struct HookContext {
    /// Identifier of the current thread, when it is known.
    pub thread_id: Option<String>,
    /// How many turns have completed so far.
    pub turn_count: u32,
}
105
106/// Output returned by a hook callback.
107#[derive(Debug, Clone)]
108pub struct HookOutput {
109    /// The decision for how to handle the event.
110    pub decision: HookDecision,
111    /// Optional human-readable reason for the decision.
112    pub reason: Option<String>,
113    /// Replacement event (only used with `HookDecision::Modify`).
114    pub replacement_event: Option<ThreadEvent>,
115}
116
117impl Default for HookOutput {
118    fn default() -> Self {
119        Self {
120            decision: HookDecision::Allow,
121            reason: None,
122            replacement_event: None,
123        }
124    }
125}
126
/// Policy applied when a hook callback does not finish within its timeout.
///
/// Chosen per hook through `HookMatcher::on_timeout`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum HookTimeoutBehavior {
    /// Treat the timed-out hook as if it had not matched (the default).
    ///
    /// Dispatch moves on to the next matching hook; if none remains, the
    /// event passes through. Suited to observability hooks (logging, metrics)
    /// where a slow callback must not stall the event stream.
    #[default]
    FailOpen,

    /// Suppress the event, exactly as if the hook had returned
    /// [`HookDecision::Block`].
    ///
    /// Suited to security or gating hooks, where a timeout must never quietly
    /// let a potentially dangerous event through.
    FailClosed,
}
143
/// Decision made by a hook about the event it was invoked for.
///
/// The enum is fieldless, so it is `Copy`. Any data that accompanies a decision,
/// such as a replacement event, travels in [`HookOutput`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HookDecision {
    /// Allow the event to pass through unchanged.
    Allow,
    /// Block (suppress) the event — consumer won't see it.
    Block,
    /// Replace the event with [`HookOutput::replacement_event`].
    Modify,
    /// Abort — terminate the stream entirely.
    Abort,
}
156
157/// Async hook callback.
158pub type HookCallback = Arc<
159    dyn Fn(HookInput, HookContext) -> Pin<Box<dyn Future<Output = HookOutput> + Send>>
160        + Send
161        + Sync,
162>;
163
164// ── Hook matcher ──────────────────────────────────────────────
165
166/// A hook registration: matches events and invokes a callback.
167#[derive(Clone)]
168pub struct HookMatcher {
169    /// Which event type this hook matches.
170    pub event: HookEvent,
171    /// Optional command substring filter (only relevant for Command* events).
172    pub command_filter: Option<String>,
173    /// The async callback to invoke when matched.
174    pub callback: HookCallback,
175    /// Per-hook timeout override. Falls back to `ThreadOptions::default_hook_timeout`.
176    pub timeout: Option<Duration>,
177    /// What to do if the callback exceeds its timeout.
178    ///
179    /// Defaults to [`HookTimeoutBehavior::FailOpen`] (event passes through).
180    /// Set to [`HookTimeoutBehavior::FailClosed`] for security/gating hooks.
181    pub on_timeout: HookTimeoutBehavior,
182}
183
184impl std::fmt::Debug for HookMatcher {
185    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
186        f.debug_struct("HookMatcher")
187            .field("event", &self.event)
188            .field("command_filter", &self.command_filter)
189            .field("timeout", &self.timeout)
190            .field("on_timeout", &self.on_timeout)
191            .finish()
192    }
193}
194
195// ── Classification ────────────────────────────────────────────
196
197/// Classify a `ThreadEvent` into a `HookEvent`, if applicable.
198///
199/// Returns `None` for events that don't map to any hook (e.g. `ThreadStarted`,
200/// `TurnStarted`, approval requests, streaming deltas).
201pub fn classify_hook_event(event: &ThreadEvent) -> Option<HookEvent> {
202    use crate::types::items::CommandExecutionStatus;
203
204    match event {
205        ThreadEvent::ItemStarted {
206            item: ThreadItem::CommandExecution { .. },
207        } => Some(HookEvent::CommandStarted),
208
209        ThreadEvent::ItemCompleted {
210            item: ThreadItem::CommandExecution { status, .. },
211        } => match status {
212            CommandExecutionStatus::Completed => Some(HookEvent::CommandCompleted),
213            CommandExecutionStatus::Failed => Some(HookEvent::CommandFailed),
214            CommandExecutionStatus::InProgress => None,
215        },
216
217        ThreadEvent::ItemCompleted {
218            item: ThreadItem::FileChange { .. },
219        } => Some(HookEvent::FileChangeCompleted),
220
221        ThreadEvent::ItemCompleted {
222            item: ThreadItem::AgentMessage { .. },
223        } => Some(HookEvent::AgentMessage),
224
225        ThreadEvent::TurnCompleted { .. } => Some(HookEvent::TurnCompleted),
226        ThreadEvent::TurnFailed { .. } => Some(HookEvent::TurnFailed),
227
228        _ => None,
229    }
230}
231
232/// Build a `HookInput` from a classified event.
233pub fn build_hook_input(hook_event: HookEvent, event: &ThreadEvent) -> HookInput {
234    let (command, exit_code, message_text) = match event {
235        ThreadEvent::ItemStarted {
236            item: ThreadItem::CommandExecution { command, .. },
237        }
238        | ThreadEvent::ItemCompleted {
239            item: ThreadItem::CommandExecution { command, .. },
240        } => {
241            let exit_code = match event {
242                ThreadEvent::ItemCompleted {
243                    item: ThreadItem::CommandExecution { exit_code, .. },
244                } => *exit_code,
245                _ => None,
246            };
247            (Some(command.clone()), exit_code, None)
248        }
249
250        ThreadEvent::ItemCompleted {
251            item: ThreadItem::AgentMessage { text, .. },
252        } => (None, None, Some(text.clone())),
253
254        ThreadEvent::TurnFailed { error } => (None, None, Some(error.message.clone())),
255
256        _ => (None, None, None),
257    };
258
259    HookInput {
260        hook_event,
261        command,
262        exit_code,
263        message_text,
264        raw_event: event.clone(),
265    }
266}
267
268// ── Dispatch ──────────────────────────────────────────────────
269
270/// Dispatch an event through the hook chain (first-match).
271///
272/// Returns the `HookOutput` from the first matching hook, or `None` if no hook matched.
273/// On timeout, the hook is skipped (fail-open) and dispatch continues to the next hook.
274pub async fn dispatch_hook(
275    event: &ThreadEvent,
276    hooks: &[HookMatcher],
277    context: &HookContext,
278    default_timeout: Duration,
279) -> Option<HookOutput> {
280    let hook_event = classify_hook_event(event)?;
281    let input = build_hook_input(hook_event.clone(), event);
282
283    for hook in hooks {
284        if hook.event != hook_event {
285            continue;
286        }
287
288        // Apply command filter if present.
289        if let Some(ref filter) = hook.command_filter {
290            match &input.command {
291                Some(cmd) if cmd.contains(filter.as_str()) => {}
292                _ => continue,
293            }
294        }
295
296        let timeout = hook.timeout.unwrap_or(default_timeout);
297        let future = (hook.callback)(input.clone(), context.clone());
298
299        match tokio::time::timeout(timeout, future).await {
300            Ok(output) => return Some(output),
301            Err(_) => {
302                tracing::warn!(
303                    "Hook timed out after {:?} for {:?} — {:?}",
304                    timeout,
305                    hook.event,
306                    hook.on_timeout,
307                );
308                match hook.on_timeout {
309                    HookTimeoutBehavior::FailOpen => continue,
310                    HookTimeoutBehavior::FailClosed => {
311                        return Some(HookOutput {
312                            decision: HookDecision::Block,
313                            reason: Some(format!("hook timeout after {timeout:?} (fail-closed)")),
314                            replacement_event: None,
315                        });
316                    }
317                }
318            }
319        }
320    }
321
322    None
323}
324
325// ── Tests ─────────────────────────────────────────────────────
326
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::events::Usage;
    use crate::types::items::CommandExecutionStatus;

    // ── Fixtures ──────────────────────────────────────────────

    /// A command-execution item with a fixed id/output and the given command,
    /// exit code and status.
    fn command_item(
        cmd: &str,
        exit_code: Option<i32>,
        status: CommandExecutionStatus,
    ) -> ThreadItem {
        ThreadItem::CommandExecution {
            id: "cmd-1".into(),
            command: cmd.into(),
            aggregated_output: String::new(),
            exit_code,
            status,
        }
    }

    fn make_command_started(cmd: &str) -> ThreadEvent {
        ThreadEvent::ItemStarted {
            item: command_item(cmd, None, CommandExecutionStatus::InProgress),
        }
    }

    fn make_command_completed(cmd: &str, code: i32) -> ThreadEvent {
        ThreadEvent::ItemCompleted {
            item: command_item(cmd, Some(code), CommandExecutionStatus::Completed),
        }
    }

    fn make_command_failed(cmd: &str, code: i32) -> ThreadEvent {
        ThreadEvent::ItemCompleted {
            item: command_item(cmd, Some(code), CommandExecutionStatus::Failed),
        }
    }

    fn make_turn_completed() -> ThreadEvent {
        ThreadEvent::TurnCompleted {
            usage: Usage {
                input_tokens: 100,
                cached_input_tokens: 0,
                output_tokens: 50,
            },
        }
    }

    fn make_context() -> HookContext {
        HookContext {
            thread_id: Some("thread-1".into()),
            turn_count: 0,
        }
    }

    /// A hook whose callback immediately returns `decision` (no reason, no replacement).
    fn decision_hook(
        event: HookEvent,
        command_filter: Option<&str>,
        decision: HookDecision,
    ) -> HookMatcher {
        HookMatcher {
            event,
            command_filter: command_filter.map(String::from),
            callback: Arc::new(move |_input, _ctx| {
                let decision = decision.clone();
                Box::pin(async move {
                    HookOutput {
                        decision,
                        reason: None,
                        replacement_event: None,
                    }
                })
            }),
            timeout: None,
            on_timeout: HookTimeoutBehavior::default(),
        }
    }

    /// A hook whose callback would return `Allow` but sleeps far past its 10 ms
    /// timeout, so only `on_timeout` decides the outcome.
    fn slow_hook(event: HookEvent, on_timeout: HookTimeoutBehavior) -> HookMatcher {
        HookMatcher {
            event,
            command_filter: None,
            callback: Arc::new(|_input, _ctx| {
                Box::pin(async {
                    tokio::time::sleep(Duration::from_secs(10)).await;
                    HookOutput::default()
                })
            }),
            timeout: Some(Duration::from_millis(10)),
            on_timeout,
        }
    }

    /// Dispatch with a fixed context and a generous default timeout.
    async fn dispatch(event: &ThreadEvent, hooks: &[HookMatcher]) -> Option<HookOutput> {
        dispatch_hook(event, hooks, &make_context(), Duration::from_secs(5)).await
    }

    // ── Classification ────────────────────────────────────────

    #[test]
    fn classify_command_started() {
        let event = make_command_started("ls -la");
        assert_eq!(classify_hook_event(&event), Some(HookEvent::CommandStarted));
    }

    #[test]
    fn classify_command_completed() {
        let event = make_command_completed("ls", 0);
        assert_eq!(
            classify_hook_event(&event),
            Some(HookEvent::CommandCompleted)
        );
    }

    #[test]
    fn classify_command_failed() {
        let event = make_command_failed("false", 1);
        assert_eq!(classify_hook_event(&event), Some(HookEvent::CommandFailed));
    }

    #[test]
    fn classify_turn_completed() {
        let event = make_turn_completed();
        assert_eq!(classify_hook_event(&event), Some(HookEvent::TurnCompleted));
    }

    #[test]
    fn classify_unmatched_returns_none() {
        let event = ThreadEvent::TurnStarted;
        assert_eq!(classify_hook_event(&event), None);
    }

    // ── Input extraction ──────────────────────────────────────

    #[test]
    fn build_input_extracts_command() {
        let event = make_command_started("git status");
        let input = build_hook_input(HookEvent::CommandStarted, &event);
        assert_eq!(input.command, Some("git status".into()));
        assert_eq!(input.exit_code, None);
    }

    #[test]
    fn build_input_extracts_exit_code() {
        let event = make_command_completed("ls", 1);
        let input = build_hook_input(HookEvent::CommandCompleted, &event);
        assert_eq!(input.command, Some("ls".into()));
        assert_eq!(input.exit_code, Some(1));
    }

    // ── Dispatch ──────────────────────────────────────────────

    #[tokio::test]
    async fn dispatch_first_match() {
        let hooks = [
            decision_hook(HookEvent::CommandStarted, None, HookDecision::Block),
            decision_hook(HookEvent::CommandStarted, None, HookDecision::Allow),
        ];
        let result = dispatch(&make_command_started("ls"), &hooks).await;
        // The first matching hook wins; the second is never consulted.
        assert_eq!(result.map(|o| o.decision), Some(HookDecision::Block));
    }

    #[tokio::test]
    async fn dispatch_command_filter() {
        let hooks = [decision_hook(
            HookEvent::CommandStarted,
            Some("rm"),
            HookDecision::Block,
        )];
        // Should NOT match "ls"
        let result = dispatch(&make_command_started("ls -la"), &hooks).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn dispatch_command_filter_matches() {
        let hooks = [decision_hook(
            HookEvent::CommandStarted,
            Some("rm"),
            HookDecision::Block,
        )];
        let result = dispatch(&make_command_started("rm -rf /tmp/test"), &hooks).await;
        assert_eq!(result.map(|o| o.decision), Some(HookDecision::Block));
    }

    #[tokio::test]
    async fn dispatch_no_match_returns_none() {
        let hooks = [decision_hook(HookEvent::TurnCompleted, None, HookDecision::Allow)];
        let result = dispatch(&make_command_started("ls"), &hooks).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn dispatch_without_hooks_returns_none() {
        let result = dispatch(&make_command_started("ls"), &[]).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn dispatch_timeout_fails_open() {
        let hooks = [slow_hook(HookEvent::CommandStarted, HookTimeoutBehavior::FailOpen)];
        let result = dispatch(&make_command_started("ls"), &hooks).await;
        // Timeout → fail-open → no match (since it's the only hook)
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn dispatch_timeout_fail_open_falls_through_to_next_hook() {
        let hooks = [
            slow_hook(HookEvent::CommandStarted, HookTimeoutBehavior::FailOpen),
            decision_hook(HookEvent::CommandStarted, None, HookDecision::Abort),
        ];
        let result = dispatch(&make_command_started("ls"), &hooks).await;
        assert_eq!(result.map(|o| o.decision), Some(HookDecision::Abort));
    }

    #[tokio::test]
    async fn dispatch_timeout_fail_closed_blocks() {
        let hooks = [slow_hook(HookEvent::CommandStarted, HookTimeoutBehavior::FailClosed)];
        let output = dispatch(&make_command_started("dangerous-cmd"), &hooks)
            .await
            .expect("a fail-closed timeout must produce an output");

        // The callback would have allowed the event, so Block can only come from the timeout.
        assert_eq!(output.decision, HookDecision::Block);
        assert!(output.reason.as_deref().unwrap_or("").contains("timeout"));
    }

    #[tokio::test]
    async fn dispatch_all_four_decisions() {
        for decision in [
            HookDecision::Allow,
            HookDecision::Block,
            HookDecision::Modify,
            HookDecision::Abort,
        ] {
            let hooks = [decision_hook(HookEvent::TurnCompleted, None, decision.clone())];
            let result = dispatch(&make_turn_completed(), &hooks).await;
            assert_eq!(result.map(|o| o.decision), Some(decision));
        }
    }
}