Skip to main content

gemini_cli_sdk/
hooks.rs

use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;

use serde::{Deserialize, Serialize};
use serde_json::Value;
7
8/// Hook events that can be intercepted.
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
10#[serde(rename_all = "snake_case")]
11pub enum HookEvent {
12    /// Before a tool is executed. Can modify input or deny.
13    PreToolUse,
14    /// After a tool executes successfully.
15    PostToolUse,
16    /// After a tool execution fails.
17    PostToolUseFailure,
18    /// Before a user prompt is sent to the agent.
19    UserPromptSubmit,
20    /// When the agent stops (turn ends or process exits).
21    Stop,
22    /// Subagent stopped (not supported by Gemini CLI).
23    SubagentStop,
24    /// Before context compaction (not supported by Gemini CLI).
25    PreCompact,
26    /// Generic notification (not supported by Gemini CLI).
27    Notification,
28}
29
30impl HookEvent {
31    /// Returns true if this event is supported by Gemini CLI.
32    pub fn is_supported(&self) -> bool {
33        matches!(
34            self,
35            HookEvent::PreToolUse
36            | HookEvent::PostToolUse
37            | HookEvent::PostToolUseFailure
38            | HookEvent::UserPromptSubmit
39            | HookEvent::Stop
40        )
41    }
42}
43
44/// A registered hook — matches events by type and optional tool name pattern.
45#[derive(Clone)]
46pub struct HookMatcher {
47    pub event: HookEvent,
48    pub tool_name: Option<String>,
49    pub callback: HookCallback,
50    pub timeout: Option<Duration>,
51}
52
53/// Input provided to a hook callback.
54#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct HookInput {
56    pub event: HookEvent,
57    pub tool_name: Option<String>,
58    pub tool_input: Option<Value>,
59    pub tool_output: Option<Value>,
60    pub prompt: Option<String>,
61    pub session_id: String,
62    #[serde(flatten)]
63    pub extra: Value,
64}
65
/// Ambient information passed to every hook callback alongside its input.
#[derive(Clone, Debug)]
pub struct HookContext {
    /// Identifier of the active session.
    pub session_id: String,
    /// Presumably the agent's current working directory, stored as a string.
    pub cwd: String,
}
72
73/// Output from a hook callback — can modify the execution flow.
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct HookOutput {
76    pub decision: HookDecision,
77    #[serde(default)]
78    pub updated_input: Option<Value>,
79    #[serde(default)]
80    pub message: Option<String>,
81}
82
83/// Hook decision — continue, modify, or block.
84#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
85#[serde(rename_all = "snake_case")]
86pub enum HookDecision {
87    /// Continue with the operation (optionally with modified input).
88    Continue,
89    /// Block the operation.
90    Block,
91    /// Skip hook processing for this event.
92    Skip,
93}
94
95impl Default for HookOutput {
96    fn default() -> Self {
97        Self {
98            decision: HookDecision::Continue,
99            updated_input: None,
100            message: None,
101        }
102    }
103}
104
105/// Callback type for hooks.
106pub type HookCallback = Arc<
107    dyn Fn(HookInput, HookContext) -> Pin<Box<dyn Future<Output = HookOutput> + Send>>
108        + Send
109        + Sync,
110>;
111
112/// Execute matching hooks for an event, in registration order.
113pub(crate) async fn execute_hooks(
114    hooks: &[HookMatcher],
115    input: HookInput,
116    context: &HookContext,
117    default_timeout: Duration,
118) -> HookOutput {
119    for hook in hooks {
120        if hook.event != input.event {
121            continue;
122        }
123
124        // Check tool name filter
125        if let Some(pattern) = &hook.tool_name {
126            if let Some(tool_name) = &input.tool_name {
127                if !tool_name_matches(tool_name, pattern) {
128                    continue;
129                }
130            } else {
131                // Pattern set but no tool name — skip
132                continue;
133            }
134        }
135
136        let timeout = hook.timeout.unwrap_or(default_timeout);
137        let result = tokio::time::timeout(
138            timeout,
139            (hook.callback)(input.clone(), context.clone()),
140        )
141        .await;
142
143        match result {
144            Ok(output) => {
145                if output.decision != HookDecision::Skip {
146                    return output;
147                }
148            }
149            Err(_) => {
150                tracing::warn!("Hook timed out for event {:?}", input.event);
151            }
152        }
153    }
154
155    HookOutput::default()
156}
157
/// Returns whether `name` satisfies the tool-name `pattern`.
///
/// A pattern ending in `*` matches any name that starts with everything before
/// that final `*` (so `"Edit*"` matches `"EditFile"`, and `"*"` matches every
/// name). Any other pattern must equal `name` exactly. Only a single trailing
/// `*` is special; a `*` anywhere else is compared literally.
// NOTE(review): this function is called by `execute_hooks`; the reason for the
// `dead_code` allowance is not visible here (presumably it silences a warning
// when `execute_hooks` has no caller in the crate) — confirm before removing.
#[allow(dead_code)]
fn tool_name_matches(name: &str, pattern: &str) -> bool {
    match pattern.strip_suffix('*') {
        Some(prefix) => name.starts_with(prefix),
        None => name == pattern,
    }
}
167
#[cfg(test)]
mod tests {
    use super::*;

    fn make_input(event: HookEvent) -> HookInput {
        HookInput {
            event,
            tool_name: None,
            tool_input: None,
            tool_output: None,
            prompt: None,
            session_id: "test-session".to_string(),
            extra: serde_json::Value::Null,
        }
    }

    /// Input for `event` that names `tool`.
    fn tool_input(event: HookEvent, tool: &str) -> HookInput {
        let mut input = make_input(event);
        input.tool_name = Some(tool.to_string());
        input
    }

    fn make_context() -> HookContext {
        HookContext {
            session_id: "test-session".to_string(),
            cwd: "/tmp".to_string(),
        }
    }

    /// Callback that immediately answers `decision` with no payload.
    fn respond(decision: HookDecision) -> HookCallback {
        Arc::new(move |_input, _ctx| {
            let decision = decision.clone();
            Box::pin(async move {
                HookOutput {
                    decision,
                    updated_input: None,
                    message: None,
                }
            })
        })
    }

    /// Callback that sleeps far longer than any timeout used in these tests.
    fn stall() -> HookCallback {
        Arc::new(|_input, _ctx| {
            Box::pin(async {
                tokio::time::sleep(Duration::from_secs(10)).await;
                HookOutput::default()
            })
        })
    }

    /// Hook for `event` with an optional tool-name pattern and no per-hook timeout.
    fn hook(event: HookEvent, tool_name: Option<&str>, callback: HookCallback) -> HookMatcher {
        HookMatcher {
            event,
            tool_name: tool_name.map(str::to_string),
            callback,
            timeout: None,
        }
    }

    /// Runs `hooks` with a generous default timeout and returns only the decision.
    async fn decide(hooks: &[HookMatcher], input: HookInput) -> HookDecision {
        execute_hooks(hooks, input, &make_context(), Duration::from_secs(5))
            .await
            .decision
    }

    #[test]
    fn test_hook_event_is_supported() {
        assert!(HookEvent::PreToolUse.is_supported());
        assert!(HookEvent::PostToolUse.is_supported());
        assert!(HookEvent::PostToolUseFailure.is_supported());
        assert!(HookEvent::UserPromptSubmit.is_supported());
        assert!(HookEvent::Stop.is_supported());
        assert!(!HookEvent::SubagentStop.is_supported());
        assert!(!HookEvent::PreCompact.is_supported());
        assert!(!HookEvent::Notification.is_supported());
    }

    #[test]
    fn test_tool_name_exact_match() {
        assert!(tool_name_matches("EditFile", "EditFile"));
        assert!(!tool_name_matches("EditFile", "ReadFile"));
    }

    #[test]
    fn test_tool_name_glob_match() {
        assert!(tool_name_matches("EditFile", "Edit*"));
        assert!(tool_name_matches("EditBlock", "Edit*"));
        assert!(!tool_name_matches("ReadFile", "Edit*"));
    }

    #[test]
    fn test_tool_name_star_matches_everything() {
        assert!(tool_name_matches("Anything", "*"));
        assert!(tool_name_matches("", "*"));
    }

    #[tokio::test]
    async fn test_execute_hooks_no_match() {
        let hooks: Vec<HookMatcher> = Vec::new();
        assert_eq!(
            decide(&hooks, make_input(HookEvent::PreToolUse)).await,
            HookDecision::Continue
        );
    }

    #[tokio::test]
    async fn test_execute_hooks_matching() {
        let hooks = vec![hook(
            HookEvent::PreToolUse,
            None,
            Arc::new(|_input, _ctx| {
                Box::pin(async {
                    HookOutput {
                        decision: HookDecision::Block,
                        updated_input: None,
                        message: Some("blocked".to_string()),
                    }
                })
            }),
        )];
        let output = execute_hooks(
            &hooks,
            make_input(HookEvent::PreToolUse),
            &make_context(),
            Duration::from_secs(5),
        )
        .await;
        assert_eq!(output.decision, HookDecision::Block);
        // The winning hook's output is returned verbatim.
        assert_eq!(output.message.as_deref(), Some("blocked"));
    }

    #[tokio::test]
    async fn test_execute_hooks_wrong_event() {
        let hooks = vec![hook(HookEvent::PostToolUse, None, respond(HookDecision::Block))];
        assert_eq!(
            decide(&hooks, make_input(HookEvent::PreToolUse)).await,
            HookDecision::Continue
        );
    }

    #[tokio::test]
    async fn test_execute_hooks_tool_name_filter() {
        let hooks = vec![hook(
            HookEvent::PreToolUse,
            Some("EditFile"),
            respond(HookDecision::Block),
        )];

        // Matching tool name
        assert_eq!(
            decide(&hooks, tool_input(HookEvent::PreToolUse, "EditFile")).await,
            HookDecision::Block
        );

        // Non-matching tool name
        assert_eq!(
            decide(&hooks, tool_input(HookEvent::PreToolUse, "ReadFile")).await,
            HookDecision::Continue
        );
    }

    #[tokio::test]
    async fn test_execute_hooks_pattern_requires_tool_name() {
        // A filtered hook never fires for inputs that carry no tool name,
        // even when the pattern would match any name.
        let hooks = vec![hook(HookEvent::PreToolUse, Some("*"), respond(HookDecision::Block))];
        assert_eq!(
            decide(&hooks, make_input(HookEvent::PreToolUse)).await,
            HookDecision::Continue
        );
    }

    #[tokio::test]
    async fn test_execute_hooks_glob_filter() {
        let hooks = vec![hook(
            HookEvent::PreToolUse,
            Some("Edit*"),
            respond(HookDecision::Block),
        )];
        assert_eq!(
            decide(&hooks, tool_input(HookEvent::PreToolUse, "EditBlock")).await,
            HookDecision::Block
        );
    }

    #[tokio::test]
    async fn test_execute_hooks_timeout() {
        let mut slow = hook(HookEvent::PreToolUse, None, stall());
        slow.timeout = Some(Duration::from_millis(10));
        // Timed out hook skipped → default Continue
        assert_eq!(
            decide(&[slow], make_input(HookEvent::PreToolUse)).await,
            HookDecision::Continue
        );
    }

    #[tokio::test]
    async fn test_execute_hooks_default_timeout_applies() {
        // No per-hook timeout: the caller-supplied default bounds the callback.
        let hooks = vec![hook(HookEvent::PreToolUse, None, stall())];
        let output = execute_hooks(
            &hooks,
            make_input(HookEvent::PreToolUse),
            &make_context(),
            Duration::from_millis(10),
        )
        .await;
        assert_eq!(output.decision, HookDecision::Continue);
    }

    #[tokio::test]
    async fn test_execute_hooks_timeout_falls_through() {
        // A timed-out hook does not stop later hooks from deciding.
        let mut slow = hook(HookEvent::PreToolUse, None, stall());
        slow.timeout = Some(Duration::from_millis(10));
        let hooks = vec![
            slow,
            hook(HookEvent::PreToolUse, None, respond(HookDecision::Block)),
        ];
        assert_eq!(
            decide(&hooks, make_input(HookEvent::PreToolUse)).await,
            HookDecision::Block
        );
    }

    #[tokio::test]
    async fn test_execute_hooks_skip_advances() {
        // First hook returns Skip, second returns Block
        let hooks = vec![
            hook(HookEvent::PreToolUse, None, respond(HookDecision::Skip)),
            hook(HookEvent::PreToolUse, None, respond(HookDecision::Block)),
        ];
        assert_eq!(
            decide(&hooks, make_input(HookEvent::PreToolUse)).await,
            HookDecision::Block
        );
    }

    #[tokio::test]
    async fn test_execute_hooks_first_decision_wins() {
        // `Continue` is final: a later `Block` hook is not consulted.
        let hooks = vec![
            hook(HookEvent::PreToolUse, None, respond(HookDecision::Continue)),
            hook(HookEvent::PreToolUse, None, respond(HookDecision::Block)),
        ];
        assert_eq!(
            decide(&hooks, make_input(HookEvent::PreToolUse)).await,
            HookDecision::Continue
        );
    }
}