Skip to main content

a3s_code_core/hooks/
matcher.rs

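//! Hook Matchers
//!
//! Matchers decide which events trigger a hook by comparing event attributes
//! (session, tool, file path, Bash command, skill name) against configured
//! values and patterns.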
4
5use super::events::HookEvent;
6use serde::{Deserialize, Serialize};
7
8/// Hook matcher for filtering events
9#[derive(Debug, Clone, Default, Serialize, Deserialize)]
10pub struct HookMatcher {
11    /// Match specific tool name (exact match)
12    #[serde(skip_serializing_if = "Option::is_none")]
13    pub tool: Option<String>,
14
15    /// Match file path pattern (glob)
16    #[serde(skip_serializing_if = "Option::is_none")]
17    pub path_pattern: Option<String>,
18
19    /// Match command pattern (regex for Bash commands)
20    #[serde(skip_serializing_if = "Option::is_none")]
21    pub command_pattern: Option<String>,
22
23    /// Match session ID (exact match)
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub session_id: Option<String>,
26
27    /// Match skill name (supports glob patterns)
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub skill: Option<String>,
30}
31
32impl HookMatcher {
33    /// Create an empty matcher (matches all)
34    pub fn new() -> Self {
35        Self::default()
36    }
37
38    /// Create a matcher for a specific tool
39    pub fn tool(name: impl Into<String>) -> Self {
40        Self {
41            tool: Some(name.into()),
42            ..Default::default()
43        }
44    }
45
46    /// Create a matcher for a file path pattern
47    pub fn path(pattern: impl Into<String>) -> Self {
48        Self {
49            path_pattern: Some(pattern.into()),
50            ..Default::default()
51        }
52    }
53
54    /// Create a matcher for a command pattern
55    pub fn command(pattern: impl Into<String>) -> Self {
56        Self {
57            command_pattern: Some(pattern.into()),
58            ..Default::default()
59        }
60    }
61
62    /// Create a matcher for a specific session
63    pub fn session(id: impl Into<String>) -> Self {
64        Self {
65            session_id: Some(id.into()),
66            ..Default::default()
67        }
68    }
69
70    /// Create a matcher for a specific skill (supports glob patterns)
71    pub fn skill(name: impl Into<String>) -> Self {
72        Self {
73            skill: Some(name.into()),
74            ..Default::default()
75        }
76    }
77
78    /// Add tool filter
79    pub fn with_tool(mut self, name: impl Into<String>) -> Self {
80        self.tool = Some(name.into());
81        self
82    }
83
84    /// Add path pattern filter
85    pub fn with_path(mut self, pattern: impl Into<String>) -> Self {
86        self.path_pattern = Some(pattern.into());
87        self
88    }
89
90    /// Add command pattern filter
91    pub fn with_command(mut self, pattern: impl Into<String>) -> Self {
92        self.command_pattern = Some(pattern.into());
93        self
94    }
95
96    /// Add session filter
97    pub fn with_session(mut self, id: impl Into<String>) -> Self {
98        self.session_id = Some(id.into());
99        self
100    }
101
102    /// Add skill filter (supports glob patterns)
103    pub fn with_skill(mut self, name: impl Into<String>) -> Self {
104        self.skill = Some(name.into());
105        self
106    }
107
108    /// Check if an event matches this matcher
109    pub fn matches(&self, event: &HookEvent) -> bool {
110        // Check session ID
111        if let Some(ref session_id) = self.session_id {
112            if event.session_id() != session_id {
113                return false;
114            }
115        }
116
117        // Check tool name
118        if let Some(ref tool_pattern) = self.tool {
119            if let Some(tool_name) = event.tool_name() {
120                if tool_name != tool_pattern {
121                    return false;
122                }
123            } else {
124                // Event doesn't have a tool, but we're filtering by tool
125                return false;
126            }
127        }
128
129        // Check path pattern (in tool args)
130        if let Some(ref path_pattern) = self.path_pattern {
131            if !self.matches_path_pattern(event, path_pattern) {
132                return false;
133            }
134        }
135
136        // Check command pattern (in Bash args)
137        if let Some(ref command_pattern) = self.command_pattern {
138            if !self.matches_command_pattern(event, command_pattern) {
139                return false;
140            }
141        }
142
143        // Check skill name (supports glob patterns)
144        if let Some(ref skill_pattern) = self.skill {
145            if let Some(skill_name) = event.skill_name() {
146                if !self.glob_match(skill_pattern, skill_name) {
147                    return false;
148                }
149            } else {
150                // Event doesn't have a skill, but we're filtering by skill
151                return false;
152            }
153        }
154
155        true
156    }
157
158    /// Check if event matches a path pattern (glob)
159    fn matches_path_pattern(&self, event: &HookEvent, pattern: &str) -> bool {
160        let args = match event.tool_args() {
161            Some(args) => args,
162            None => return false,
163        };
164
165        // Look for common path fields
166        let path = args
167            .get("file_path")
168            .or_else(|| args.get("path"))
169            .and_then(|v| v.as_str());
170
171        match path {
172            Some(p) => self.glob_match(pattern, p),
173            None => false,
174        }
175    }
176
177    /// Check if event matches a command pattern (regex)
178    fn matches_command_pattern(&self, event: &HookEvent, pattern: &str) -> bool {
179        // Only applies to Bash tool
180        if event.tool_name() != Some("Bash") && event.tool_name() != Some("bash") {
181            return false;
182        }
183
184        let args = match event.tool_args() {
185            Some(args) => args,
186            None => return false,
187        };
188
189        let command = args.get("command").and_then(|v| v.as_str());
190
191        match command {
192            Some(cmd) => {
193                // Use regex matching
194                if let Ok(re) = regex::Regex::new(pattern) {
195                    re.is_match(cmd)
196                } else {
197                    // Fallback to contains if regex is invalid
198                    cmd.contains(pattern)
199                }
200            }
201            None => false,
202        }
203    }
204
205    /// Simple glob matching (supports * and **)
206    ///
207    /// Matching rules:
208    /// - `*` matches any characters in filename (excluding `/`)
209    /// - `**` matches any path (including `/`)
210    /// - `*.ext` matches any file ending with `.ext` (any depth)
211    /// - `dir/**/*.ext` matches `.ext` files at any depth under dir
212    fn glob_match(&self, pattern: &str, text: &str) -> bool {
213        // Special handling: if pattern starts with * and has no /,
214        // match file suffix. e.g., "*.rs" should match "src/main.rs"
215        if pattern.starts_with('*') && !pattern.contains('/') {
216            let suffix = &pattern[1..]; // Remove leading *
217            return text.ends_with(suffix);
218        }
219
220        // Convert glob to regex
221        let regex_pattern = pattern
222            .replace('.', r"\.")
223            .replace("**/", "__DOUBLE_STAR_SLASH__")
224            .replace("**", "__DOUBLE_STAR__")
225            .replace('*', "[^/]*")
226            .replace("__DOUBLE_STAR_SLASH__", "(?:.*/)?") // **/ matches zero or more directories
227            .replace("__DOUBLE_STAR__", ".*");
228
229        let regex_pattern = format!("^{}$", regex_pattern);
230
231        if let Ok(re) = regex::Regex::new(&regex_pattern) {
232            re.is_match(text)
233        } else {
234            text == pattern
235        }
236    }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hooks::events::{PreToolUseEvent, SkillLoadEvent, SkillUnloadEvent};
    use serde_json::{json, Value};

    /// Build a `PreToolUse` event in `/workspace` with an empty tool history.
    fn make_pre_tool_event(session_id: &str, tool: &str, args: Value) -> HookEvent {
        HookEvent::PreToolUse(PreToolUseEvent {
            session_id: session_id.into(),
            tool: tool.into(),
            args,
            working_directory: String::from("/workspace"),
            recent_tools: Vec::new(),
        })
    }

    /// Shorthand: a `Bash` call in session `s1` running `command`.
    fn bash(command: &str) -> HookEvent {
        make_pre_tool_event("s1", "Bash", json!({ "command": command }))
    }

    /// Shorthand: a `Write` call in session `s1` targeting `file_path`.
    fn write_to(file_path: &str) -> HookEvent {
        make_pre_tool_event("s1", "Write", json!({ "file_path": file_path }))
    }

    /// Build a skill-load event for `skill_name` exposing a single tool.
    fn make_skill_load_event(skill_name: &str) -> HookEvent {
        HookEvent::SkillLoad(SkillLoadEvent {
            skill_name: skill_name.into(),
            tool_names: vec![String::from("tool1")],
            version: None,
            description: None,
            loaded_at: 0,
        })
    }

    /// Build a skill-unload event for `skill_name` that was loaded for one second.
    fn make_skill_unload_event(skill_name: &str) -> HookEvent {
        HookEvent::SkillUnload(SkillUnloadEvent {
            skill_name: skill_name.into(),
            tool_names: vec![String::from("tool1")],
            duration_ms: 1000,
        })
    }

    #[test]
    fn test_empty_matcher_matches_all() {
        let event = make_pre_tool_event("s1", "Bash", json!({}));
        assert!(HookMatcher::new().matches(&event));
    }

    #[test]
    fn test_tool_matcher() {
        let matcher = HookMatcher::tool("Bash");
        for (tool, expected) in [("Bash", true), ("Read", false)] {
            let event = make_pre_tool_event("s1", tool, json!({}));
            assert_eq!(matcher.matches(&event), expected, "tool {}", tool);
        }
    }

    #[test]
    fn test_session_matcher() {
        let matcher = HookMatcher::session("session-1");
        for (session, expected) in [("session-1", true), ("session-2", false)] {
            let event = make_pre_tool_event(session, "Bash", json!({}));
            assert_eq!(matcher.matches(&event), expected, "session {}", session);
        }
    }

    #[test]
    fn test_path_pattern_matcher() {
        let matcher = HookMatcher::path("*.rs");
        assert!(matcher.matches(&write_to("src/main.rs")));
        assert!(!matcher.matches(&write_to("src/main.py")));
    }

    #[test]
    fn test_path_pattern_double_star() {
        let matcher = HookMatcher::path("src/**/*.rs");
        for path in ["src/deep/nested/file.rs", "src/file.rs"] {
            assert!(matcher.matches(&write_to(path)), "{} should match", path);
        }
    }

    #[test]
    fn test_command_pattern_matcher() {
        let matcher = HookMatcher::command(r"rm\s+-rf");
        assert!(matcher.matches(&bash("rm -rf /tmp/test")));
        assert!(!matcher.matches(&bash("echo hello")));
    }

    #[test]
    fn test_combined_matchers() {
        let matcher = HookMatcher::new().with_tool("Bash").with_command("rm");
        let read_event = make_pre_tool_event("s1", "Read", json!({ "path": "file.txt" }));

        assert!(matcher.matches(&bash("rm file.txt")));
        assert!(!matcher.matches(&bash("echo hello"))); // Bash, but no `rm`
        assert!(!matcher.matches(&read_event)); // not Bash at all
    }

    #[test]
    fn test_command_pattern_not_bash() {
        // A command pattern only ever applies to the Bash tool.
        let event = make_pre_tool_event("s1", "Read", json!({ "path": "echo.txt" }));
        assert!(!HookMatcher::command("echo").matches(&event));
    }

    #[test]
    fn test_builder_pattern() {
        let matcher = HookMatcher::tool("Write")
            .with_path("*.env")
            .with_session("secure-session");

        assert_eq!(matcher.tool.as_deref(), Some("Write"));
        assert_eq!(matcher.path_pattern.as_deref(), Some("*.env"));
        assert_eq!(matcher.session_id.as_deref(), Some("secure-session"));
    }

    #[test]
    fn test_matcher_serialization() {
        let original = HookMatcher::tool("Bash").with_command("rm.*");

        let encoded = serde_json::to_string(&original).unwrap();
        for needle in ["Bash", "rm.*"] {
            assert!(encoded.contains(needle), "missing {}", needle);
        }

        let decoded: HookMatcher = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.tool.as_deref(), Some("Bash"));
        assert_eq!(decoded.command_pattern.as_deref(), Some("rm.*"));
    }

    #[test]
    fn test_path_with_alternative_field() {
        // The `path` key is honoured as well as `file_path`.
        let event = make_pre_tool_event("s1", "Read", json!({ "path": "readme.txt" }));
        assert!(HookMatcher::path("*.txt").matches(&event));
    }

    #[test]
    fn test_skill_matcher() {
        let matcher = HookMatcher::skill("my-skill");
        assert!(matcher.matches(&make_skill_load_event("my-skill")));
        assert!(!matcher.matches(&make_skill_load_event("other-skill")));
    }

    #[test]
    fn test_skill_matcher_pattern() {
        // Skill names accept glob patterns.
        let matcher = HookMatcher::skill("test-*");
        let cases = [("test-skill", true), ("test-other", true), ("other-skill", false)];
        for (name, expected) in cases {
            let event = make_skill_load_event(name);
            assert_eq!(matcher.matches(&event), expected, "skill {}", name);
        }
    }

    #[test]
    fn test_skill_matcher_unload_event() {
        let matcher = HookMatcher::skill("my-skill");
        assert!(matcher.matches(&make_skill_unload_event("my-skill")));
        assert!(!matcher.matches(&make_skill_unload_event("other-skill")));
    }

    #[test]
    fn test_skill_matcher_non_skill_event() {
        // Events without a skill name never satisfy a skill filter.
        let tool_event = make_pre_tool_event("s1", "Bash", json!({}));
        assert!(!HookMatcher::skill("my-skill").matches(&tool_event));
    }

    #[test]
    fn test_skill_matcher_with_builder() {
        let matcher = HookMatcher::new().with_skill("test-*");
        assert_eq!(matcher.skill.as_deref(), Some("test-*"));
        assert!(matcher.matches(&make_skill_load_event("test-skill")));
    }

    #[test]
    fn test_skill_matcher_serialization() {
        let encoded = serde_json::to_string(&HookMatcher::skill("my-skill")).unwrap();
        assert!(encoded.contains("my-skill"));
        assert!(encoded.contains("skill"));

        let decoded: HookMatcher = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.skill.as_deref(), Some("my-skill"));
    }
}