Skip to main content

ai_agent/
hooks.rs

1// Source: /data/home/swei/claudecode/openclaudecode/src/commands/hooks/hooks.tsx
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::process::Command;
5use tokio::time::{timeout, Duration};
6
/// Names of every hook event the registry recognises.
///
/// `HookRegistry::register` and `HookRegistry::register_from_config` drop any
/// event name that is not listed here. The spellings are the same strings that
/// `HookEvent::as_str` returns.
pub const HOOK_EVENTS: &[&str] = &[
    "PreToolUse",
    "PostToolUse",
    "PostToolUseFailure",
    "Notification",
    "UserPromptSubmit",
    "SessionStart",
    "SessionEnd",
    "Stop",
    "StopFailure",
    "SubagentStart",
    "SubagentStop",
    "PreCompact",
    "PostCompact",
    "PermissionRequest",
    "PermissionDenied",
    "Setup",
    "TeammateIdle",
    "TaskCreated",
    "TaskCompleted",
    "Elicitation",
    "ElicitationResult",
    "ConfigChange",
    "WorktreeCreate",
    "WorktreeRemove",
    "InstructionsLoaded",
    "CwdChanged",
    "FileChanged",
];
37
/// Recognised reasons for a session ending.
///
/// Nothing else in this file reads the list; it is exported for other modules.
pub const EXIT_REASONS: &[&str] = &[
    "clear",
    "resume",
    "logout",
    "prompt_input_exit",
    "other",
    "bypass_permissions_disabled",
];
47
/// Recognised reasons for instructions being loaded.
///
/// Nothing else in this file reads the list; it is exported for other modules.
pub const INSTRUCTIONS_LOAD_REASONS: &[&str] = &[
    "session_start",
    "nested_traversal",
    "path_glob_match",
    "include",
    "compact",
];
56
/// Recognised kinds of instructions memory.
///
/// Nothing else in this file reads the list; it is exported for other modules.
pub const INSTRUCTIONS_MEMORY_TYPES: &[&str] = &["User", "Project", "Local", "Managed"];
59
/// Recognised origins of a configuration change.
///
/// Nothing else in this file reads the list; it is exported for other modules.
pub const CONFIG_CHANGE_SOURCES: &[&str] = &[
    "user_settings",
    "project_settings",
    "local_settings",
    "policy_settings",
    "skills",
];
68
69#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
70#[serde(rename_all = "camelCase")]
71pub enum HookEvent {
72    PreToolUse,
73    PostToolUse,
74    PostToolUseFailure,
75    Notification,
76    UserPromptSubmit,
77    SessionStart,
78    SessionEnd,
79    Stop,
80    StopFailure,
81    SubagentStart,
82    SubagentStop,
83    PreCompact,
84    PostCompact,
85    PermissionRequest,
86    PermissionDenied,
87    Setup,
88    TeammateIdle,
89    TaskCreated,
90    TaskCompleted,
91    Elicitation,
92    ElicitationResult,
93    ConfigChange,
94    WorktreeCreate,
95    WorktreeRemove,
96    InstructionsLoaded,
97    CwdChanged,
98    FileChanged,
99}
100
101impl HookEvent {
102    pub fn as_str(&self) -> &'static str {
103        match self {
104            HookEvent::PreToolUse => "PreToolUse",
105            HookEvent::PostToolUse => "PostToolUse",
106            HookEvent::PostToolUseFailure => "PostToolUseFailure",
107            HookEvent::Notification => "Notification",
108            HookEvent::UserPromptSubmit => "UserPromptSubmit",
109            HookEvent::SessionStart => "SessionStart",
110            HookEvent::SessionEnd => "SessionEnd",
111            HookEvent::Stop => "Stop",
112            HookEvent::StopFailure => "StopFailure",
113            HookEvent::SubagentStart => "SubagentStart",
114            HookEvent::SubagentStop => "SubagentStop",
115            HookEvent::PreCompact => "PreCompact",
116            HookEvent::PostCompact => "PostCompact",
117            HookEvent::PermissionRequest => "PermissionRequest",
118            HookEvent::PermissionDenied => "PermissionDenied",
119            HookEvent::Setup => "Setup",
120            HookEvent::TeammateIdle => "TeammateIdle",
121            HookEvent::TaskCreated => "TaskCreated",
122            HookEvent::TaskCompleted => "TaskCompleted",
123            HookEvent::Elicitation => "Elicitation",
124            HookEvent::ElicitationResult => "ElicitationResult",
125            HookEvent::ConfigChange => "ConfigChange",
126            HookEvent::WorktreeCreate => "WorktreeCreate",
127            HookEvent::WorktreeRemove => "WorktreeRemove",
128            HookEvent::InstructionsLoaded => "InstructionsLoaded",
129            HookEvent::CwdChanged => "CwdChanged",
130            HookEvent::FileChanged => "FileChanged",
131        }
132    }
133
134    pub fn from_str(s: &str) -> Option<Self> {
135        match s {
136            "PreToolUse" => Some(HookEvent::PreToolUse),
137            "PostToolUse" => Some(HookEvent::PostToolUse),
138            "PostToolUseFailure" => Some(HookEvent::PostToolUseFailure),
139            "Notification" => Some(HookEvent::Notification),
140            "UserPromptSubmit" => Some(HookEvent::UserPromptSubmit),
141            "SessionStart" => Some(HookEvent::SessionStart),
142            "SessionEnd" => Some(HookEvent::SessionEnd),
143            "Stop" => Some(HookEvent::Stop),
144            "StopFailure" => Some(HookEvent::StopFailure),
145            "SubagentStart" => Some(HookEvent::SubagentStart),
146            "SubagentStop" => Some(HookEvent::SubagentStop),
147            "PreCompact" => Some(HookEvent::PreCompact),
148            "PostCompact" => Some(HookEvent::PostCompact),
149            "PermissionRequest" => Some(HookEvent::PermissionRequest),
150            "PermissionDenied" => Some(HookEvent::PermissionDenied),
151            "Setup" => Some(HookEvent::Setup),
152            "TeammateIdle" => Some(HookEvent::TeammateIdle),
153            "TaskCreated" => Some(HookEvent::TaskCreated),
154            "TaskCompleted" => Some(HookEvent::TaskCompleted),
155            "Elicitation" => Some(HookEvent::Elicitation),
156            "ElicitationResult" => Some(HookEvent::ElicitationResult),
157            "ConfigChange" => Some(HookEvent::ConfigChange),
158            "WorktreeCreate" => Some(HookEvent::WorktreeCreate),
159            "WorktreeRemove" => Some(HookEvent::WorktreeRemove),
160            "InstructionsLoaded" => Some(HookEvent::InstructionsLoaded),
161            "CwdChanged" => Some(HookEvent::CwdChanged),
162            "FileChanged" => Some(HookEvent::FileChanged),
163            _ => None,
164        }
165    }
166}
167
168/// Hook definition.
169#[derive(Debug, Clone)]
170pub struct HookDefinition {
171    /// Shell command to execute
172    pub command: Option<String>,
173    /// Function handler (stored as async fn pointer)
174    pub timeout: Option<u64>,
175    /// Tool name matcher (regex pattern)
176    pub matcher: Option<String>,
177}
178
179impl<'de> Deserialize<'de> for HookDefinition {
180    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
181    where
182        D: serde::Deserializer<'de>,
183    {
184        #[derive(Deserialize)]
185        #[serde(rename_all = "camelCase")]
186        struct HookDef {
187            command: Option<String>,
188            timeout: Option<u64>,
189            matcher: Option<String>,
190        }
191
192        let def = HookDef::deserialize(deserializer)?;
193        Ok(HookDefinition {
194            command: def.command,
195            timeout: def.timeout.or(Some(30000)),
196            matcher: def.matcher,
197        })
198    }
199}
200
201/// Hook input passed to handlers.
202#[derive(Debug, Clone, Serialize, Deserialize)]
203#[serde(rename_all = "camelCase")]
204pub struct HookInput {
205    pub event: String,
206    #[serde(skip_serializing_if = "Option::is_none")]
207    pub tool_name: Option<String>,
208    #[serde(skip_serializing_if = "Option::is_none")]
209    pub tool_input: Option<serde_json::Value>,
210    #[serde(skip_serializing_if = "Option::is_none")]
211    pub tool_output: Option<serde_json::Value>,
212    #[serde(skip_serializing_if = "Option::is_none")]
213    pub tool_use_id: Option<String>,
214    #[serde(skip_serializing_if = "Option::is_none")]
215    pub session_id: Option<String>,
216    #[serde(skip_serializing_if = "Option::is_none")]
217    pub cwd: Option<String>,
218    #[serde(skip_serializing_if = "Option::is_none")]
219    pub error: Option<String>,
220}
221
222impl HookInput {
223    pub fn new(event: &str) -> Self {
224        Self {
225            event: event.to_string(),
226            tool_name: None,
227            tool_input: None,
228            tool_output: None,
229            tool_use_id: None,
230            session_id: None,
231            cwd: None,
232            error: None,
233        }
234    }
235}
236
237/// Hook output returned by handlers.
238#[derive(Debug, Clone, Serialize, Deserialize)]
239#[serde(rename_all = "camelCase")]
240pub struct HookOutput {
241    #[serde(skip_serializing_if = "Option::is_none")]
242    pub message: Option<String>,
243    #[serde(skip_serializing_if = "Option::is_none")]
244    pub permission_update: Option<PermissionUpdate>,
245    #[serde(skip_serializing_if = "Option::is_none")]
246    pub block: Option<bool>,
247    #[serde(skip_serializing_if = "Option::is_none")]
248    pub notification: Option<Notification>,
249}
250
251#[derive(Debug, Clone, Serialize, Deserialize)]
252#[serde(rename_all = "camelCase")]
253pub struct PermissionUpdate {
254    pub tool: String,
255    pub behavior: String,
256}
257
258#[derive(Debug, Clone, Serialize, Deserialize)]
259#[serde(rename_all = "camelCase")]
260pub struct Notification {
261    pub title: String,
262    pub body: String,
263    #[serde(skip_serializing_if = "Option::is_none")]
264    pub level: Option<String>,
265}
266
267/// Hook configuration (from settings).
268pub type HookConfig = HashMap<String, Vec<HookDefinition>>;
269
270/// Hook registry for managing and executing hooks.
271#[derive(Debug, Default, Clone)]
272pub struct HookRegistry {
273    hooks: HashMap<String, Vec<HookDefinition>>,
274}
275
276impl HookRegistry {
277    /// Create a new empty registry.
278    pub fn new() -> Self {
279        Self {
280            hooks: HashMap::new(),
281        }
282    }
283
284    /// Register hooks from configuration.
285    pub fn register_from_config(&mut self, config: HookConfig) {
286        for (event, definitions) in config {
287            if !HOOK_EVENTS.contains(&event.as_str()) {
288                continue;
289            }
290            let existing = self.hooks.entry(event).or_insert_with(Vec::new);
291            existing.extend(definitions);
292        }
293    }
294
295    /// Register a single hook.
296    pub fn register(&mut self, event: &str, definition: HookDefinition) {
297        if !HOOK_EVENTS.contains(&event) {
298            return;
299        }
300        let existing = self.hooks.entry(event.to_string()).or_insert_with(Vec::new);
301        existing.push(definition);
302    }
303
304    /// Execute hooks for an event.
305    pub async fn execute(&self, event: &str, mut input: HookInput) -> Vec<HookOutput> {
306        let definitions = match self.hooks.get(event) {
307            Some(d) => d,
308            None => return vec![],
309        };
310
311        input.event = event.to_string();
312        let mut results = Vec::new();
313
314        for def in definitions {
315            // Check matcher for tool-specific hooks
316            if let Some(matcher) = &def.matcher {
317                if let Some(tool_name) = &input.tool_name {
318                    if let Ok(re) = regex::Regex::new(matcher) {
319                        if !re.is_match(tool_name) {
320                            continue;
321                        }
322                    }
323                }
324            }
325
326            if let Some(command) = &def.command {
327                match execute_shell_hook(command, &input, def.timeout.unwrap_or(30000)).await {
328                    Ok(output) => {
329                        if let Some(o) = output {
330                            results.push(o);
331                        }
332                    }
333                    Err(e) => {
334                        eprintln!("[Hook] {} hook failed: {}", event, e);
335                    }
336                }
337            }
338            // Note: Function handlers would require storing function pointers,
339            // which is complex in Rust. Shell commands are the primary mechanism.
340        }
341
342        results
343    }
344
345    /// Check if any hooks are registered for an event.
346    pub fn has_hooks(&self, event: &str) -> bool {
347        self.hooks
348            .get(event)
349            .map(|h| !h.is_empty())
350            .unwrap_or(false)
351    }
352
353    /// Clear all hooks.
354    pub fn clear(&mut self) {
355        self.hooks.clear();
356    }
357}
358
359/// Execute a shell command as a hook.
360async fn execute_shell_hook(
361    command: &str,
362    input: &HookInput,
363    timeout_ms: u64,
364) -> Result<Option<HookOutput>, crate::error::AgentError> {
365    let input_json = serde_json::to_string(input).map_err(crate::error::AgentError::Json)?;
366
367    // Clone data needed in the blocking task
368    let cmd_str = command.to_string();
369    let event = input.event.clone();
370    let tool_name = input.tool_name.clone();
371    let session_id = input.session_id.clone();
372    let cwd = input.cwd.clone();
373
374    let result = timeout(
375        Duration::from_millis(timeout_ms),
376        tokio::task::spawn_blocking(move || {
377            let mut cmd = Command::new("bash");
378            cmd.args(["-c", &cmd_str])
379                .env("HOOK_EVENT", &event)
380                .env("HOOK_TOOL_NAME", tool_name.as_deref().unwrap_or(""))
381                .env("HOOK_SESSION_ID", session_id.as_deref().unwrap_or(""))
382                .env("HOOK_CWD", cwd.as_deref().unwrap_or(""))
383                .stdin(std::process::Stdio::piped())
384                .stdout(std::process::Stdio::piped())
385                .stderr(std::process::Stdio::piped());
386
387            let mut child = cmd.spawn()?;
388
389            use std::io::Write;
390            if let Some(mut stdin) = child.stdin.take() {
391                stdin.write_all(input_json.as_bytes())?;
392            }
393
394            let output = child.wait_with_output()?;
395
396            if !output.status.success() {
397                return Ok(None);
398            }
399
400            let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string();
401            if stdout.is_empty() {
402                return Ok(None);
403            }
404
405            // Try to parse as JSON
406            if let Ok(hook_output) = serde_json::from_str::<HookOutput>(&stdout) {
407                Ok(Some(hook_output))
408            } else {
409                // Non-JSON output treated as message
410                Ok(Some(HookOutput {
411                    message: Some(stdout),
412                    permission_update: None,
413                    block: None,
414                    notification: None,
415                }))
416            }
417        }),
418    )
419    .await;
420
421    match result {
422        Ok(Ok(r)) => r,
423        Ok(Err(e)) => {
424            let err = std::io::Error::new(std::io::ErrorKind::Other, e.to_string());
425            Err(crate::error::AgentError::Io(err))
426        }
427        Err(_) => {
428            let err = std::io::Error::new(std::io::ErrorKind::TimedOut, "Hook timeout");
429            Err(crate::error::AgentError::Io(err))
430        }
431    }
432}
433
434/// Create a default hook registry.
435pub fn create_hook_registry(config: Option<HookConfig>) -> HookRegistry {
436    let mut registry = HookRegistry::new();
437    if let Some(c) = config {
438        registry.register_from_config(c);
439    }
440    registry
441}
442
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an `echo test` hook with the given timeout and matcher.
    fn echo_hook(timeout: Option<u64>, matcher: Option<&str>) -> HookDefinition {
        HookDefinition {
            command: Some("echo test".to_string()),
            timeout,
            matcher: matcher.map(String::from),
        }
    }

    #[test]
    fn test_hook_event_as_str() {
        let expected = [
            (HookEvent::PreToolUse, "PreToolUse"),
            (HookEvent::PostToolUse, "PostToolUse"),
            (HookEvent::SessionStart, "SessionStart"),
        ];
        for (event, name) in expected {
            assert_eq!(event.as_str(), name);
        }
    }

    #[test]
    fn test_hook_event_from_str() {
        let parsed = HookEvent::from_str("PreToolUse");
        assert_eq!(parsed, Some(HookEvent::PreToolUse));

        let unknown = HookEvent::from_str("Invalid");
        assert_eq!(unknown, None);
    }

    #[test]
    fn test_hook_events_constant() {
        for name in ["PreToolUse", "PostToolUse", "SessionStart"] {
            assert!(HOOK_EVENTS.contains(&name));
        }
    }

    #[test]
    fn test_hook_registry_new() {
        assert!(!HookRegistry::new().has_hooks("PreToolUse"));
    }

    #[test]
    fn test_hook_registry_register() {
        let mut registry = HookRegistry::new();
        registry.register("PreToolUse", echo_hook(Some(5000), Some("Read.*")));
        assert!(registry.has_hooks("PreToolUse"));
    }

    #[test]
    fn test_hook_registry_clear() {
        let mut registry = HookRegistry::new();
        registry.register("PreToolUse", echo_hook(None, None));
        registry.clear();
        assert!(!registry.has_hooks("PreToolUse"));
    }

    #[test]
    fn test_hook_input_new() {
        assert_eq!(HookInput::new("PreToolUse").event, "PreToolUse");
    }

    #[test]
    fn test_hook_output_serialization() {
        let output = HookOutput {
            message: Some("test message".to_string()),
            permission_update: None,
            block: Some(true),
            notification: None,
        };
        let json = serde_json::to_string(&output).unwrap();
        assert!(json.contains("test message"));
    }

    #[test]
    fn test_create_hook_registry() {
        assert!(!create_hook_registry(None).has_hooks("PreToolUse"));
    }

    #[tokio::test]
    async fn test_execute_no_hooks() {
        let registry = HookRegistry::new();
        let results = registry
            .execute("PreToolUse", HookInput::new("PreToolUse"))
            .await;
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_execute_with_invalid_event() {
        let registry = HookRegistry::new();
        let results = registry
            .execute("InvalidEvent", HookInput::new("InvalidEvent"))
            .await;
        assert!(results.is_empty());
    }
}
544}