Skip to main content

a3s_code_core/hooks/
matcher.rs

1//! Hook Matchers
2//!
3//! Matchers filter which events trigger a hook based on patterns.
4
5use super::events::HookEvent;
6use serde::{Deserialize, Serialize};
7
8/// Hook matcher for filtering events
9#[derive(Debug, Clone, Default, Serialize, Deserialize)]
10pub struct HookMatcher {
11    /// Match specific tool name (exact match)
12    #[serde(skip_serializing_if = "Option::is_none")]
13    pub tool: Option<String>,
14
15    /// Match file path pattern (glob)
16    #[serde(skip_serializing_if = "Option::is_none")]
17    pub path_pattern: Option<String>,
18
19    /// Match command pattern (regex for Bash commands)
20    #[serde(skip_serializing_if = "Option::is_none")]
21    pub command_pattern: Option<String>,
22
23    /// Match session ID (exact match)
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub session_id: Option<String>,
26
27    /// Match skill name (supports glob patterns)
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub skill: Option<String>,
30}
31
32impl HookMatcher {
33    /// Create an empty matcher (matches all)
34    pub fn new() -> Self {
35        Self::default()
36    }
37
38    /// Create a matcher for a specific tool
39    pub fn tool(name: impl Into<String>) -> Self {
40        Self {
41            tool: Some(name.into()),
42            ..Default::default()
43        }
44    }
45
46    /// Create a matcher for a file path pattern
47    pub fn path(pattern: impl Into<String>) -> Self {
48        Self {
49            path_pattern: Some(pattern.into()),
50            ..Default::default()
51        }
52    }
53
54    /// Create a matcher for a command pattern
55    pub fn command(pattern: impl Into<String>) -> Self {
56        Self {
57            command_pattern: Some(pattern.into()),
58            ..Default::default()
59        }
60    }
61
62    /// Create a matcher for a specific session
63    pub fn session(id: impl Into<String>) -> Self {
64        Self {
65            session_id: Some(id.into()),
66            ..Default::default()
67        }
68    }
69
70    /// Create a matcher for a specific skill (supports glob patterns)
71    pub fn skill(name: impl Into<String>) -> Self {
72        Self {
73            skill: Some(name.into()),
74            ..Default::default()
75        }
76    }
77
78    /// Add tool filter
79    pub fn with_tool(mut self, name: impl Into<String>) -> Self {
80        self.tool = Some(name.into());
81        self
82    }
83
84    /// Add path pattern filter
85    pub fn with_path(mut self, pattern: impl Into<String>) -> Self {
86        self.path_pattern = Some(pattern.into());
87        self
88    }
89
90    /// Add command pattern filter
91    pub fn with_command(mut self, pattern: impl Into<String>) -> Self {
92        self.command_pattern = Some(pattern.into());
93        self
94    }
95
96    /// Add session filter
97    pub fn with_session(mut self, id: impl Into<String>) -> Self {
98        self.session_id = Some(id.into());
99        self
100    }
101
102    /// Add skill filter (supports glob patterns)
103    pub fn with_skill(mut self, name: impl Into<String>) -> Self {
104        self.skill = Some(name.into());
105        self
106    }
107
108    /// Check if an event matches this matcher
109    pub fn matches(&self, event: &HookEvent) -> bool {
110        // Check session ID
111        if let Some(ref session_id) = self.session_id {
112            if event.session_id() != session_id {
113                return false;
114            }
115        }
116
117        // Check tool name
118        if let Some(ref tool_pattern) = self.tool {
119            if let Some(tool_name) = event.tool_name() {
120                if tool_name != tool_pattern {
121                    return false;
122                }
123            } else {
124                // Event doesn't have a tool, but we're filtering by tool
125                return false;
126            }
127        }
128
129        // Check path pattern (in tool args)
130        if let Some(ref path_pattern) = self.path_pattern {
131            if !self.matches_path_pattern(event, path_pattern) {
132                return false;
133            }
134        }
135
136        // Check command pattern (in Bash args)
137        if let Some(ref command_pattern) = self.command_pattern {
138            if !self.matches_command_pattern(event, command_pattern) {
139                return false;
140            }
141        }
142
143        // Check skill name (supports glob patterns)
144        if let Some(ref skill_pattern) = self.skill {
145            if let Some(skill_name) = event.skill_name() {
146                if !self.glob_match(skill_pattern, skill_name) {
147                    return false;
148                }
149            } else {
150                // Event doesn't have a skill, but we're filtering by skill
151                return false;
152            }
153        }
154
155        true
156    }
157
158    /// Check if event matches a path pattern (glob)
159    fn matches_path_pattern(&self, event: &HookEvent, pattern: &str) -> bool {
160        let args = match event.tool_args() {
161            Some(args) => args,
162            None => return false,
163        };
164
165        // Look for common path fields
166        let path = args
167            .get("file_path")
168            .or_else(|| args.get("path"))
169            .and_then(|v| v.as_str());
170
171        match path {
172            Some(p) => self.glob_match(pattern, p),
173            None => false,
174        }
175    }
176
177    /// Check if event matches a command pattern (regex)
178    fn matches_command_pattern(&self, event: &HookEvent, pattern: &str) -> bool {
179        // Only applies to Bash tool
180        if event.tool_name() != Some("Bash") && event.tool_name() != Some("bash") {
181            return false;
182        }
183
184        let args = match event.tool_args() {
185            Some(args) => args,
186            None => return false,
187        };
188
189        let command = args.get("command").and_then(|v| v.as_str());
190
191        match command {
192            Some(cmd) => {
193                // Use regex matching
194                if let Ok(re) = regex::Regex::new(pattern) {
195                    re.is_match(cmd)
196                } else {
197                    // Fallback to contains if regex is invalid
198                    cmd.contains(pattern)
199                }
200            }
201            None => false,
202        }
203    }
204
205    /// Simple glob matching (supports * and **)
206    ///
207    /// Matching rules:
208    /// - `*` matches any characters in filename (excluding `/`)
209    /// - `**` matches any path (including `/`)
210    /// - `*.ext` matches any file ending with `.ext` (any depth)
211    /// - `dir/**/*.ext` matches `.ext` files at any depth under dir
212    fn glob_match(&self, pattern: &str, text: &str) -> bool {
213        // Normalize Windows backslashes to forward slashes for consistent matching
214        let text = text.replace('\\', "/");
215
216        // Special handling: if pattern starts with * and has no /,
217        // match file suffix. e.g., "*.rs" should match "src/main.rs"
218        if pattern.starts_with('*') && !pattern.contains('/') {
219            let suffix = &pattern[1..]; // Remove leading *
220            return text.ends_with(suffix);
221        }
222
223        // Convert glob to regex
224        let regex_pattern = pattern
225            .replace('.', r"\.")
226            .replace("**/", "__DOUBLE_STAR_SLASH__")
227            .replace("**", "__DOUBLE_STAR__")
228            .replace('*', "[^/]*")
229            .replace("__DOUBLE_STAR_SLASH__", "(?:.*/)?") // **/ matches zero or more directories
230            .replace("__DOUBLE_STAR__", ".*");
231
232        let regex_pattern = format!("^{}$", regex_pattern);
233
234        if let Ok(re) = regex::Regex::new(&regex_pattern) {
235            re.is_match(&text)
236        } else {
237            text == pattern
238        }
239    }
240}
241
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hooks::events::{PreToolUseEvent, SkillLoadEvent, SkillUnloadEvent};
    use serde_json::{json, Value};

    /// Build a `PreToolUse` event rooted at `/workspace` with no tool history.
    fn pre_tool_use(session_id: &str, tool: &str, args: Value) -> HookEvent {
        let event = PreToolUseEvent {
            session_id: session_id.to_owned(),
            tool: tool.to_owned(),
            args,
            working_directory: String::from("/workspace"),
            recent_tools: Vec::new(),
        };
        HookEvent::PreToolUse(event)
    }

    /// A `Bash` call in session `s1` running `command`.
    fn bash(command: &str) -> HookEvent {
        pre_tool_use("s1", "Bash", json!({ "command": command }))
    }

    /// A `Write` call in session `s1` targeting `file_path`.
    fn write_to(file_path: &str) -> HookEvent {
        pre_tool_use("s1", "Write", json!({ "file_path": file_path }))
    }

    /// A skill-load event for `name` exposing a single tool.
    fn skill_loaded(name: &str) -> HookEvent {
        HookEvent::SkillLoad(SkillLoadEvent {
            skill_name: name.to_owned(),
            tool_names: vec![String::from("tool1")],
            version: None,
            description: None,
            loaded_at: 0,
        })
    }

    /// A skill-unload event for `name` after one second of use.
    fn skill_unloaded(name: &str) -> HookEvent {
        HookEvent::SkillUnload(SkillUnloadEvent {
            skill_name: name.to_owned(),
            tool_names: vec![String::from("tool1")],
            duration_ms: 1000,
        })
    }

    #[test]
    fn test_empty_matcher_matches_all() {
        let event = pre_tool_use("s1", "Bash", json!({}));
        assert!(HookMatcher::new().matches(&event));
    }

    #[test]
    fn test_tool_matcher() {
        let only_bash = HookMatcher::tool("Bash");
        for &(tool, expected) in [("Bash", true), ("Read", false)].iter() {
            let event = pre_tool_use("s1", tool, json!({}));
            assert_eq!(only_bash.matches(&event), expected, "tool {}", tool);
        }
    }

    #[test]
    fn test_session_matcher() {
        let only_first = HookMatcher::session("session-1");
        for &(session, expected) in [("session-1", true), ("session-2", false)].iter() {
            let event = pre_tool_use(session, "Bash", json!({}));
            assert_eq!(only_first.matches(&event), expected, "session {}", session);
        }
    }

    #[test]
    fn test_path_pattern_matcher() {
        let rust_files = HookMatcher::path("*.rs");
        assert!(rust_files.matches(&write_to("src/main.rs")));
        assert!(!rust_files.matches(&write_to("src/main.py")));
    }

    #[test]
    fn test_path_pattern_double_star() {
        let under_src = HookMatcher::path("src/**/*.rs");
        for path in ["src/deep/nested/file.rs", "src/file.rs"].iter() {
            assert!(under_src.matches(&write_to(path)), "expected match for {}", path);
        }
    }

    #[test]
    fn test_command_pattern_matcher() {
        let destructive = HookMatcher::command(r"rm\s+-rf");
        assert!(destructive.matches(&bash("rm -rf /tmp/test")));
        assert!(!destructive.matches(&bash("echo hello")));
    }

    #[test]
    fn test_combined_matchers() {
        let bash_rm = HookMatcher::new().with_tool("Bash").with_command("rm");

        assert!(bash_rm.matches(&bash("rm file.txt")));
        // Right tool, but the command never mentions `rm`.
        assert!(!bash_rm.matches(&bash("echo hello")));
        // Wrong tool altogether.
        let read = pre_tool_use("s1", "Read", json!({ "path": "file.txt" }));
        assert!(!bash_rm.matches(&read));
    }

    #[test]
    fn test_command_pattern_not_bash() {
        // A command filter is only ever satisfied by the Bash tool.
        let read = pre_tool_use("s1", "Read", json!({ "path": "echo.txt" }));
        assert!(!HookMatcher::command("echo").matches(&read));
    }

    #[test]
    fn test_builder_pattern() {
        let built = HookMatcher::tool("Write")
            .with_path("*.env")
            .with_session("secure-session");

        assert_eq!(built.tool.as_deref(), Some("Write"));
        assert_eq!(built.path_pattern.as_deref(), Some("*.env"));
        assert_eq!(built.session_id.as_deref(), Some("secure-session"));
    }

    #[test]
    fn test_matcher_serialization() {
        let original = HookMatcher::tool("Bash").with_command("rm.*");

        let encoded = serde_json::to_string(&original).expect("serialize matcher");
        for needle in ["Bash", "rm.*"].iter() {
            assert!(encoded.contains(needle), "missing {}", needle);
        }

        let decoded: HookMatcher = serde_json::from_str(&encoded).expect("deserialize matcher");
        assert_eq!(decoded.tool.as_deref(), Some("Bash"));
        assert_eq!(decoded.command_pattern.as_deref(), Some("rm.*"));
    }

    #[test]
    fn test_path_with_alternative_field() {
        // The `path` argument is honoured as well as `file_path`.
        let event = pre_tool_use("s1", "Read", json!({ "path": "readme.txt" }));
        assert!(HookMatcher::path("*.txt").matches(&event));
    }

    #[test]
    fn test_skill_matcher() {
        let exact = HookMatcher::skill("my-skill");
        assert!(exact.matches(&skill_loaded("my-skill")));
        assert!(!exact.matches(&skill_loaded("other-skill")));
    }

    #[test]
    fn test_skill_matcher_pattern() {
        let prefixed = HookMatcher::skill("test-*");
        let expectations = [("test-skill", true), ("test-other", true), ("other-skill", false)];
        for &(name, expected) in expectations.iter() {
            assert_eq!(prefixed.matches(&skill_loaded(name)), expected, "skill {}", name);
        }
    }

    #[test]
    fn test_skill_matcher_unload_event() {
        let exact = HookMatcher::skill("my-skill");
        assert!(exact.matches(&skill_unloaded("my-skill")));
        assert!(!exact.matches(&skill_unloaded("other-skill")));
    }

    #[test]
    fn test_skill_matcher_non_skill_event() {
        // Tool events carry no skill name, so a skill filter rejects them.
        let tool_event = pre_tool_use("s1", "Bash", json!({}));
        assert!(!HookMatcher::skill("my-skill").matches(&tool_event));
    }

    #[test]
    fn test_skill_matcher_with_builder() {
        let built = HookMatcher::new().with_skill("test-*");
        assert_eq!(built.skill.as_deref(), Some("test-*"));
        assert!(built.matches(&skill_loaded("test-skill")));
    }

    #[test]
    fn test_skill_matcher_serialization() {
        let encoded =
            serde_json::to_string(&HookMatcher::skill("my-skill")).expect("serialize matcher");
        assert!(encoded.contains("my-skill"));
        assert!(encoded.contains("skill"));

        let decoded: HookMatcher = serde_json::from_str(&encoded).expect("deserialize matcher");
        assert_eq!(decoded.skill.as_deref(), Some("my-skill"));
    }
}