Skip to main content

claude_cli_sdk/
hooks.rs

1//! Hook system for intercepting agent lifecycle events.
2//!
3//! Hooks allow SDK consumers to observe and modify tool executions by
4//! registering callbacks for specific lifecycle events (e.g., pre/post tool
5//! use). The hook system is optional — sessions without hooks behave normally.
6//!
7//! # Architecture
8//!
//! Hooks are registered on [`ClientConfig`](crate::config::ClientConfig) as a list of
//! [`HookMatcher`] entries. Each matcher targets a specific [`HookEvent`] and
//! optionally filters by tool name. When an event occurs, the matchers are
//! checked in the order they are listed, and only the first one that matches
//! has its [`HookCallback`] invoked with the event details. If no matcher
//! matches, the event is allowed.
13//!
14//! # Timeout Enforcement
15//!
16//! Each hook callback is subject to a timeout. The effective timeout is
17//! `HookMatcher::timeout` if set, otherwise `ClientConfig::default_hook_timeout`
18//! (default: 30s). If a callback exceeds its timeout, the SDK logs a warning
19//! and defaults to [`HookOutput::allow()`] (fail-open). This ensures a
20//! misbehaving hook never permanently blocks a session.
21//!
22//! # Internal Protocol
23//!
24//! When hooks are registered, the SDK exchanges structured JSON messages with
25//! the CLI via the control protocol. `HookRequest` and `HookResponse` are
26//! the internal wire types (not exposed publicly).
27
28use std::sync::Arc;
29use std::time::Duration;
30
31use serde::{Deserialize, Serialize};
32
33// ── Hook Events ──────────────────────────────────────────────────────────────
34
35/// Lifecycle events that hooks can intercept.
36#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
37#[serde(rename_all = "snake_case")]
38pub enum HookEvent {
39    /// Before a tool is executed. Can modify input or deny execution.
40    PreToolUse,
41    /// After a tool completes successfully.
42    PostToolUse,
43    /// After a tool fails with an error.
44    PostToolUseFailure,
45    /// When a user prompt is submitted (before processing).
46    UserPromptSubmit,
47    /// When the agent session stops.
48    Stop,
49    /// When a subagent session stops.
50    SubagentStop,
51    /// Before context compaction occurs.
52    PreCompact,
53    /// A general notification event.
54    Notification,
55}
56
57// ── Hook Matcher ─────────────────────────────────────────────────────────────
58
59/// A hook registration that pairs a lifecycle event with a callback.
60pub struct HookMatcher {
61    /// The event type this hook matches.
62    pub event: HookEvent,
63    /// Optional tool name filter. If `Some`, only matches the named tool.
64    /// If `None`, matches all tools for the given event.
65    pub tool_name: Option<String>,
66    /// The callback to invoke when the hook matches.
67    pub callback: HookCallback,
68    /// Optional timeout for the callback. If `None`, uses the default.
69    pub timeout: Option<Duration>,
70}
71
72impl std::fmt::Debug for HookMatcher {
73    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74        f.debug_struct("HookMatcher")
75            .field("event", &self.event)
76            .field("tool_name", &self.tool_name)
77            .field("timeout", &self.timeout)
78            .finish()
79    }
80}
81
82impl HookMatcher {
83    /// Create a new hook matcher for the given event with the given callback.
84    pub fn new(event: HookEvent, callback: HookCallback) -> Self {
85        Self {
86            event,
87            tool_name: None,
88            callback,
89            timeout: None,
90        }
91    }
92
93    /// Filter this hook to only match a specific tool name.
94    #[must_use]
95    pub fn for_tool(mut self, name: impl Into<String>) -> Self {
96        self.tool_name = Some(name.into());
97        self
98    }
99
100    /// Set a timeout for this hook's callback.
101    #[must_use]
102    pub fn with_timeout(mut self, timeout: Duration) -> Self {
103        self.timeout = Some(timeout);
104        self
105    }
106
107    /// Returns `true` if this matcher matches the given event and tool name.
108    #[must_use]
109    pub fn matches(&self, event: HookEvent, tool_name: Option<&str>) -> bool {
110        if self.event != event {
111            return false;
112        }
113        match (&self.tool_name, tool_name) {
114            (Some(filter), Some(name)) => filter == name,
115            (Some(_), None) => false,
116            (None, _) => true,
117        }
118    }
119}
120
121// ── Callback types ───────────────────────────────────────────────────────────
122
123use crate::util::BoxFuture;
124
125/// The callback function invoked when a hook matches.
126///
127/// Arguments:
128/// - `HookInput` — details about the event
129/// - `Option<String>` — session ID (if available)
130/// - `HookContext` — additional context
131///
132/// Returns a [`HookOutput`] with the decision and optional modifications.
133pub type HookCallback =
134    Arc<dyn Fn(HookInput, Option<String>, HookContext) -> BoxFuture<HookOutput> + Send + Sync>;
135
136/// Input data provided to a hook callback.
137#[derive(Debug, Clone, Serialize, Deserialize)]
138pub struct HookInput {
139    /// The lifecycle event that triggered this hook.
140    pub hook_event: HookEvent,
141    /// The tool name (if applicable).
142    #[serde(default, skip_serializing_if = "Option::is_none")]
143    pub tool_name: Option<String>,
144    /// The tool's input parameters (if applicable).
145    #[serde(default, skip_serializing_if = "Option::is_none")]
146    pub tool_input: Option<serde_json::Value>,
147    /// The tool's result (for PostToolUse/PostToolUseFailure).
148    #[serde(default, skip_serializing_if = "Option::is_none")]
149    pub tool_result: Option<serde_json::Value>,
150    /// The tool_use_id (if applicable).
151    #[serde(default, skip_serializing_if = "Option::is_none")]
152    pub tool_use_id: Option<String>,
153    /// Extra data for future extension.
154    #[serde(default, skip_serializing_if = "Option::is_none")]
155    pub extra: Option<serde_json::Value>,
156}
157
/// Additional context handed to every hook callback alongside the event input.
///
/// `Default` yields a context with no session. `PartialEq`/`Eq` allow callers
/// and tests to compare contexts directly, consistent with the other public
/// hook types.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct HookContext {
    /// ID of the active session, if one is known.
    pub session_id: Option<String>,
}
164
165/// The output returned by a hook callback.
166#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
167pub struct HookOutput {
168    /// The hook's decision.
169    pub decision: HookDecision,
170    /// Human-readable reason for the decision.
171    #[serde(default, skip_serializing_if = "Option::is_none")]
172    pub reason: Option<String>,
173    /// Updated tool input (only meaningful for `PreToolUse`).
174    #[serde(default, skip_serializing_if = "Option::is_none")]
175    pub updated_input: Option<serde_json::Value>,
176    /// Extra data for future extension.
177    #[serde(default, skip_serializing_if = "Option::is_none")]
178    pub extra: Option<serde_json::Value>,
179}
180
181/// The decision a hook can make about the intercepted event.
182#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
183#[serde(rename_all = "snake_case")]
184pub enum HookDecision {
185    /// Allow the event to proceed as normal.
186    Allow,
187    /// Block the tool execution (PreToolUse only).
188    Block,
189    /// Modify the input and continue (PreToolUse only).
190    Modify,
191    /// Abort the entire session.
192    Abort,
193}
194
195impl HookOutput {
196    /// Convenience: allow the event to proceed.
197    #[must_use]
198    pub fn allow() -> Self {
199        Self {
200            decision: HookDecision::Allow,
201            reason: None,
202            updated_input: None,
203            extra: None,
204        }
205    }
206
207    /// Convenience: block the tool execution.
208    #[must_use]
209    pub fn block(reason: impl Into<String>) -> Self {
210        Self {
211            decision: HookDecision::Block,
212            reason: Some(reason.into()),
213            updated_input: None,
214            extra: None,
215        }
216    }
217
218    /// Convenience: modify the tool input.
219    #[must_use]
220    pub fn modify(updated_input: serde_json::Value) -> Self {
221        Self {
222            decision: HookDecision::Modify,
223            reason: None,
224            updated_input: Some(updated_input),
225            extra: None,
226        }
227    }
228
229    /// Convenience: abort the session.
230    #[must_use]
231    pub fn abort(reason: impl Into<String>) -> Self {
232        Self {
233            decision: HookDecision::Abort,
234            reason: Some(reason.into()),
235            updated_input: None,
236            extra: None,
237        }
238    }
239}
240
241// ── Internal protocol messages ───────────────────────────────────────────────
242
243/// A hook request received from the CLI.
244#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
245pub(crate) struct HookRequest {
246    /// Unique request ID for correlation.
247    pub request_id: String,
248    /// The hook event.
249    pub hook_event: HookEvent,
250    /// Tool name (if applicable).
251    #[serde(default, skip_serializing_if = "Option::is_none")]
252    pub tool_name: Option<String>,
253    /// Tool input (if applicable).
254    #[serde(default, skip_serializing_if = "Option::is_none")]
255    pub tool_input: Option<serde_json::Value>,
256    /// Tool result (if applicable).
257    #[serde(default, skip_serializing_if = "Option::is_none")]
258    pub tool_result: Option<serde_json::Value>,
259    /// Tool use ID (if applicable).
260    #[serde(default, skip_serializing_if = "Option::is_none")]
261    pub tool_use_id: Option<String>,
262}
263
264impl HookRequest {
265    /// Convert this wire request into a [`HookInput`] suitable for callbacks.
266    #[cfg(test)]
267    pub fn into_hook_input(self) -> HookInput {
268        HookInput {
269            hook_event: self.hook_event,
270            tool_name: self.tool_name,
271            tool_input: self.tool_input,
272            tool_result: self.tool_result,
273            tool_use_id: self.tool_use_id,
274            extra: None,
275        }
276    }
277
278    /// Borrow this wire request as a [`HookInput`] suitable for callbacks.
279    pub(crate) fn to_hook_input(&self) -> HookInput {
280        HookInput {
281            hook_event: self.hook_event,
282            tool_name: self.tool_name.clone(),
283            tool_input: self.tool_input.clone(),
284            tool_result: self.tool_result.clone(),
285            tool_use_id: self.tool_use_id.clone(),
286            extra: None,
287        }
288    }
289}
290
291/// A hook response sent back to the CLI.
292///
293/// Shares the same `{kind, request_id, result}` wire envelope pattern as
294/// [`ControlResponse`](crate::permissions::ControlResponse), but they are
295/// kept separate because they carry different result types and serve different
296/// protocol flows (hook lifecycle vs permission handshake). A generic
297/// `ControlEnvelope<T>` was considered but rejected as over-abstraction for
298/// two types.
299#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
300pub(crate) struct HookResponse {
301    /// The kind of response.
302    pub kind: String,
303    /// Must match the `request_id` from the corresponding request.
304    pub request_id: String,
305    /// The result payload.
306    pub result: HookOutput,
307}
308
309impl HookResponse {
310    /// Create a response from a hook output and the originating request ID.
311    pub fn from_output(request_id: String, output: HookOutput) -> Self {
312        Self {
313            kind: "hook_response".into(),
314            request_id,
315            result: output,
316        }
317    }
318}
319
320// ── Hook dispatch ────────────────────────────────────────────────────────────
321
322/// Dispatch a hook request to matching callbacks with timeout enforcement.
323///
324/// If a matching callback times out, the hook defaults to `HookOutput::allow()`
325/// (fail-open). This ensures a misbehaving hook never permanently blocks a
326/// session.
327pub(crate) async fn dispatch_hook(
328    req: &HookRequest,
329    hooks: &[HookMatcher],
330    default_hook_timeout: Duration,
331    session_id: Option<String>,
332) -> HookOutput {
333    let input = req.to_hook_input();
334
335    for matcher in hooks {
336        if !matcher.matches(req.hook_event, req.tool_name.as_deref()) {
337            continue;
338        }
339
340        let effective_timeout = matcher.timeout.unwrap_or(default_hook_timeout);
341        let ctx = HookContext {
342            session_id: session_id.clone(),
343        };
344
345        let fut = (matcher.callback)(input.clone(), session_id.clone(), ctx);
346        match tokio::time::timeout(effective_timeout, fut).await {
347            Ok(output) => return output,
348            Err(_) => {
349                tracing::warn!(
350                    event = ?req.hook_event,
351                    tool = ?req.tool_name,
352                    timeout_secs = effective_timeout.as_secs_f64(),
353                    "hook callback timed out, defaulting to allow (fail-open)"
354                );
355                return HookOutput::allow();
356            }
357        }
358    }
359
360    // No matching hook — allow by default.
361    HookOutput::allow()
362}
363
364// ── Tests ────────────────────────────────────────────────────────────────────
365
#[cfg(test)]
mod tests {
    use super::*;

    // ── Helpers ──────────────────────────────────────────────────────────

    /// A callback that immediately resolves to a clone of `output`.
    fn fixed(output: HookOutput) -> HookCallback {
        Arc::new(move |_, _, _| {
            let out = output.clone();
            Box::pin(async move { out })
        })
    }

    /// A minimal wire request for `event`, optionally naming the tool involved.
    fn request(event: HookEvent, tool: Option<&str>) -> HookRequest {
        HookRequest {
            request_id: "r1".into(),
            hook_event: event,
            tool_name: tool.map(str::to_owned),
            tool_input: None,
            tool_result: None,
            tool_use_id: None,
        }
    }

    // ── Serde, matcher and output tests ──────────────────────────────────

    #[test]
    fn hook_event_round_trip() {
        let events = [
            HookEvent::PreToolUse,
            HookEvent::PostToolUse,
            HookEvent::PostToolUseFailure,
            HookEvent::UserPromptSubmit,
            HookEvent::Stop,
            HookEvent::SubagentStop,
            HookEvent::PreCompact,
            HookEvent::Notification,
        ];
        for event in events {
            let json = serde_json::to_string(&event).unwrap();
            let decoded: HookEvent = serde_json::from_str(&json).unwrap();
            assert_eq!(event, decoded, "round-trip failed for {event:?}");
        }
    }

    #[test]
    fn hook_matcher_matches_any_tool() {
        let matcher = HookMatcher::new(HookEvent::PreToolUse, fixed(HookOutput::allow()));
        assert!(matcher.matches(HookEvent::PreToolUse, Some("bash")));
        assert!(matcher.matches(HookEvent::PreToolUse, Some("read_file")));
        assert!(matcher.matches(HookEvent::PreToolUse, None));
        assert!(!matcher.matches(HookEvent::PostToolUse, Some("bash")));
    }

    #[test]
    fn hook_matcher_matches_specific_tool() {
        let matcher =
            HookMatcher::new(HookEvent::PreToolUse, fixed(HookOutput::allow())).for_tool("bash");
        assert!(matcher.matches(HookEvent::PreToolUse, Some("bash")));
        assert!(!matcher.matches(HookEvent::PreToolUse, Some("read_file")));
        assert!(!matcher.matches(HookEvent::PreToolUse, None));
    }

    #[test]
    fn hook_matcher_with_timeout() {
        let matcher = HookMatcher::new(HookEvent::Stop, fixed(HookOutput::allow()))
            .with_timeout(Duration::from_secs(5));
        assert_eq!(matcher.timeout, Some(Duration::from_secs(5)));
    }

    #[test]
    fn hook_output_allow() {
        let output = HookOutput::allow();
        assert_eq!(output.decision, HookDecision::Allow);
        assert!(output.reason.is_none());
    }

    #[test]
    fn hook_output_block() {
        let output = HookOutput::block("dangerous command");
        assert_eq!(output.decision, HookDecision::Block);
        assert_eq!(output.reason.as_deref(), Some("dangerous command"));
    }

    #[test]
    fn hook_output_modify() {
        let output = HookOutput::modify(serde_json::json!({"safe": true}));
        assert_eq!(output.decision, HookDecision::Modify);
        assert!(output.updated_input.is_some());
    }

    #[test]
    fn hook_output_abort() {
        let output = HookOutput::abort("critical failure");
        assert_eq!(output.decision, HookDecision::Abort);
        assert_eq!(output.reason.as_deref(), Some("critical failure"));
    }

    #[test]
    fn hook_output_round_trip() {
        let output = HookOutput {
            decision: HookDecision::Modify,
            reason: Some("safety".into()),
            updated_input: Some(serde_json::json!({"command": "ls"})),
            extra: None,
        };
        let json = serde_json::to_string(&output).unwrap();
        let decoded: HookOutput = serde_json::from_str(&json).unwrap();
        // HookOutput derives PartialEq, so compare every field at once.
        assert_eq!(output, decoded);
    }

    #[test]
    fn hook_request_round_trip() {
        let req = HookRequest {
            request_id: "hr-1".into(),
            hook_event: HookEvent::PreToolUse,
            tool_name: Some("bash".into()),
            tool_input: Some(serde_json::json!({"command": "echo hello"})),
            tool_result: None,
            tool_use_id: Some("tu-1".into()),
        };
        let json = serde_json::to_string(&req).unwrap();
        let decoded: HookRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(req, decoded);
    }

    #[test]
    fn hook_request_into_hook_input() {
        let req = HookRequest {
            request_id: "hr-1".into(),
            hook_event: HookEvent::PostToolUse,
            tool_name: Some("bash".into()),
            tool_input: None,
            tool_result: Some(serde_json::json!("output")),
            tool_use_id: Some("tu-1".into()),
        };
        let input = req.into_hook_input();
        assert_eq!(input.hook_event, HookEvent::PostToolUse);
        assert_eq!(input.tool_name.as_deref(), Some("bash"));
        assert!(input.tool_result.is_some());
    }

    #[test]
    fn hook_response_from_output() {
        let output = HookOutput::allow();
        let resp = HookResponse::from_output("req-1".into(), output);
        assert_eq!(resp.kind, "hook_response");
        assert_eq!(resp.request_id, "req-1");
        assert_eq!(resp.result.decision, HookDecision::Allow);
    }

    #[test]
    fn hook_response_round_trip() {
        let resp = HookResponse {
            kind: "hook_response".into(),
            request_id: "hr-1".into(),
            result: HookOutput::block("no"),
        };
        let json = serde_json::to_string(&resp).unwrap();
        let decoded: HookResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(resp, decoded);
    }

    #[test]
    fn hook_decision_serde() {
        let decisions = [
            (HookDecision::Allow, r#""allow""#),
            (HookDecision::Block, r#""block""#),
            (HookDecision::Modify, r#""modify""#),
            (HookDecision::Abort, r#""abort""#),
        ];
        for (decision, expected_json) in decisions {
            let json = serde_json::to_string(&decision).unwrap();
            assert_eq!(json, expected_json);
            let decoded: HookDecision = serde_json::from_str(&json).unwrap();
            assert_eq!(decision, decoded);
        }
    }

    #[test]
    fn hook_input_optional_fields() {
        // Minimal input with only the required field.
        let json = r#"{"hook_event":"stop"}"#;
        let input: HookInput = serde_json::from_str(json).unwrap();
        assert_eq!(input.hook_event, HookEvent::Stop);
        assert!(input.tool_name.is_none());
        assert!(input.tool_input.is_none());
        assert!(input.tool_result.is_none());
        assert!(input.tool_use_id.is_none());
        assert!(input.extra.is_none());
    }

    // ── Hook dispatch tests ──────────────────────────────────────────────

    #[tokio::test]
    async fn hook_timeout_defaults_to_config_value() {
        let matchers = vec![HookMatcher::new(
            HookEvent::PreToolUse,
            fixed(HookOutput::block("should arrive")),
        )];
        let req = request(HookEvent::PreToolUse, Some("Bash"));

        let output = dispatch_hook(&req, &matchers, Duration::from_secs(30), None).await;
        assert_eq!(output.decision, HookDecision::Block);
    }

    #[tokio::test]
    async fn hook_timeout_override() {
        let matchers = vec![HookMatcher::new(
            HookEvent::PreToolUse,
            fixed(HookOutput::block("custom timeout")),
        )
        .with_timeout(Duration::from_secs(60))];
        let req = request(HookEvent::PreToolUse, None);

        // Should use the per-matcher timeout (60s), not the default (1ms).
        let output = dispatch_hook(&req, &matchers, Duration::from_millis(1), None).await;
        assert_eq!(output.decision, HookDecision::Block);
    }

    #[tokio::test]
    async fn hook_timeout_fires_returns_allow() {
        // Hook that effectively never finishes.
        let cb: HookCallback = Arc::new(|_, _, _| {
            Box::pin(async {
                tokio::time::sleep(Duration::from_secs(3600)).await;
                HookOutput::block("never reached")
            })
        });
        let matchers = vec![HookMatcher::new(HookEvent::PreToolUse, cb)];
        let req = request(HookEvent::PreToolUse, Some("Bash"));

        let output = dispatch_hook(&req, &matchers, Duration::from_millis(10), None).await;
        // Fail-open: a timed-out hook defaults to Allow.
        assert_eq!(output.decision, HookDecision::Allow);
    }

    #[tokio::test]
    async fn dispatch_without_matching_hook_allows() {
        let matchers = vec![HookMatcher::new(
            HookEvent::PostToolUse,
            fixed(HookOutput::block("wrong event")),
        )];
        let req = request(HookEvent::PreToolUse, Some("Bash"));

        let output = dispatch_hook(&req, &matchers, Duration::from_secs(30), None).await;
        assert_eq!(output, HookOutput::allow());
    }

    #[tokio::test]
    async fn dispatch_first_matching_hook_wins() {
        let matchers = vec![
            HookMatcher::new(HookEvent::PreToolUse, fixed(HookOutput::block("first"))),
            HookMatcher::new(HookEvent::PreToolUse, fixed(HookOutput::block("second"))),
        ];
        let req = request(HookEvent::PreToolUse, Some("Bash"));

        let output = dispatch_hook(&req, &matchers, Duration::from_secs(30), None).await;
        assert_eq!(output.reason.as_deref(), Some("first"));
    }

    #[tokio::test]
    async fn dispatch_skips_hooks_filtered_to_other_tools() {
        let read_only =
            HookMatcher::new(HookEvent::PreToolUse, fixed(HookOutput::block("read only")))
                .for_tool("Read");
        let catch_all = HookMatcher::new(
            HookEvent::PreToolUse,
            fixed(HookOutput::modify(serde_json::json!({"command": "ls"}))),
        );
        let matchers = vec![read_only, catch_all];
        let req = request(HookEvent::PreToolUse, Some("Bash"));

        let output = dispatch_hook(&req, &matchers, Duration::from_secs(30), None).await;
        assert_eq!(output.decision, HookDecision::Modify);
    }

    #[tokio::test]
    async fn dispatch_forwards_session_id_to_callback_and_context() {
        let cb: HookCallback = Arc::new(|_, session_id: Option<String>, ctx: HookContext| {
            Box::pin(async move {
                assert_eq!(session_id, ctx.session_id);
                HookOutput::block(session_id.unwrap_or_default())
            })
        });
        let matchers = vec![HookMatcher::new(HookEvent::Stop, cb)];
        let req = request(HookEvent::Stop, None);

        let output =
            dispatch_hook(&req, &matchers, Duration::from_secs(30), Some("sess-1".into())).await;
        assert_eq!(output.reason.as_deref(), Some("sess-1"));
    }
}