Skip to main content

awaken_tool_pattern/
parser.rs

1//! Pattern string parser for [`ToolCallPattern`].
2//!
3//! Syntax overview:
4//! ```text
5//! Bash                            exact tool, any args
6//! Bash(*)                         explicit any args
7//! Bash(npm *)                     primary arg glob
8//! Edit(file_path ~ "src/**")      named field glob
9//! Bash(command =~ "(?i)rm")       named field regex
10//! mcp__github__*                  glob tool name
11//! /mcp__(gh|gl)__.*/              regex tool name
12//! Tool(a.b[*].c ~ "pat")         nested field path
13//! Tool(f1 ~ "a", f2 = "b")       multi-field AND
14//! ```
15
16use std::fmt;
17
18use crate::types::{
19    ArgMatcher, FieldCondition, MatchOp, PathSegment, ToolCallPattern, ToolMatcher,
20};
21
/// Error returned when a pattern string cannot be parsed.
///
/// Implements equality so callers and tests can compare parse results
/// directly (for example with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PatternParseError {
    /// Human-readable description of what went wrong.
    pub message: String,
    /// Byte offset of the parser's cursor when the error was raised.
    pub position: usize,
}
28
29impl fmt::Display for PatternParseError {
30    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31        write!(f, "parse error at {}: {}", self.position, self.message)
32    }
33}
34
35impl std::error::Error for PatternParseError {}
36
37struct Cursor<'a> {
38    input: &'a str,
39    pos: usize,
40}
41
42impl<'a> Cursor<'a> {
43    fn new(input: &'a str) -> Self {
44        Self { input, pos: 0 }
45    }
46
47    fn remaining(&self) -> &'a str {
48        &self.input[self.pos..]
49    }
50
51    fn is_empty(&self) -> bool {
52        self.pos >= self.input.len()
53    }
54
55    fn peek(&self) -> Option<char> {
56        self.remaining().chars().next()
57    }
58
59    fn advance(&mut self, n: usize) {
60        self.pos += n;
61    }
62
63    fn skip_whitespace(&mut self) {
64        while let Some(c) = self.peek() {
65            if c.is_ascii_whitespace() {
66                self.advance(c.len_utf8());
67            } else {
68                break;
69            }
70        }
71    }
72
73    fn expect(&mut self, ch: char) -> Result<(), PatternParseError> {
74        self.skip_whitespace();
75        match self.peek() {
76            Some(c) if c == ch => {
77                self.advance(c.len_utf8());
78                Ok(())
79            }
80            other => Err(self.error(format!(
81                "expected '{}', found {}",
82                ch,
83                match other {
84                    Some(c) => format!("'{c}'"),
85                    None => "end of input".to_string(),
86                }
87            ))),
88        }
89    }
90
91    fn error(&self, message: impl Into<String>) -> PatternParseError {
92        PatternParseError {
93            message: message.into(),
94            position: self.pos,
95        }
96    }
97}
98
99/// Parse a pattern string into a [`ToolCallPattern`].
100pub fn parse_pattern(input: &str) -> Result<ToolCallPattern, PatternParseError> {
101    let mut cursor = Cursor::new(input.trim());
102
103    let tool = parse_tool_part(&mut cursor)?;
104    cursor.skip_whitespace();
105
106    let args = if cursor.peek() == Some('(') {
107        cursor.advance(1);
108        let args = parse_arg_part(&mut cursor)?;
109        cursor.expect(')')?;
110        args
111    } else {
112        ArgMatcher::Any
113    };
114
115    cursor.skip_whitespace();
116    if !cursor.is_empty() {
117        return Err(cursor.error(format!("unexpected trailing: '{}'", cursor.remaining())));
118    }
119
120    Ok(ToolCallPattern { tool, args })
121}
122
123fn parse_tool_part(cursor: &mut Cursor<'_>) -> Result<ToolMatcher, PatternParseError> {
124    cursor.skip_whitespace();
125    if cursor.peek() == Some('/') {
126        cursor.advance(1);
127        let start = cursor.pos;
128        let mut depth = 0u32;
129        while let Some(c) = cursor.peek() {
130            match c {
131                '\\' => {
132                    cursor.advance(1);
133                    if cursor.peek().is_some() {
134                        cursor.advance(1);
135                    }
136                }
137                '(' => {
138                    depth += 1;
139                    cursor.advance(1);
140                }
141                ')' => {
142                    depth = depth.saturating_sub(1);
143                    cursor.advance(1);
144                }
145                '/' if depth == 0 => break,
146                _ => cursor.advance(c.len_utf8()),
147            }
148        }
149        let body = &cursor.input[start..cursor.pos];
150        if body.is_empty() {
151            return Err(cursor.error("empty regex pattern"));
152        }
153        cursor.expect('/')?;
154        let re =
155            regex::Regex::new(body).map_err(|e| cursor.error(format!("invalid regex: {e}")))?;
156        Ok(ToolMatcher::Regex(re))
157    } else {
158        let start = cursor.pos;
159        while let Some(c) = cursor.peek() {
160            if c == '(' || c.is_ascii_whitespace() {
161                break;
162            }
163            cursor.advance(c.len_utf8());
164        }
165        let name = &cursor.input[start..cursor.pos];
166        if name.is_empty() {
167            return Err(cursor.error("empty tool name"));
168        }
169        if has_glob_chars(name) {
170            Ok(ToolMatcher::Glob(name.to_string()))
171        } else {
172            Ok(ToolMatcher::Exact(name.to_string()))
173        }
174    }
175}
176
/// Returns `true` if `s` contains any of the glob metacharacters `*`, `?` or `[`.
fn has_glob_chars(s: &str) -> bool {
    s.chars().any(|c| matches!(c, '*' | '?' | '['))
}
180
181fn parse_arg_part(cursor: &mut Cursor<'_>) -> Result<ArgMatcher, PatternParseError> {
182    cursor.skip_whitespace();
183
184    if cursor.peek() == Some('*') {
185        let after = cursor.remaining().get(1..2);
186        if after.is_none_or(|s| {
187            let c = s.chars().next().unwrap_or(')');
188            c == ')' || c.is_ascii_whitespace()
189        }) {
190            cursor.advance(1);
191            cursor.skip_whitespace();
192            return Ok(ArgMatcher::Any);
193        }
194    }
195
196    if looks_like_field_conditions(cursor.remaining()) {
197        parse_field_conditions(cursor)
198    } else {
199        parse_primary_value(cursor)
200    }
201}
202
203fn looks_like_field_conditions(s: &str) -> bool {
204    let s = s.trim();
205    let bytes = s.as_bytes();
206    let mut i = 0;
207    while i < bytes.len() {
208        let c = bytes[i] as char;
209        if c.is_ascii_alphanumeric() || c == '_' || c == '.' {
210            i += 1;
211        } else if c == '[' {
212            i += 1;
213            while i < bytes.len() && bytes[i] != b']' {
214                i += 1;
215            }
216            if i < bytes.len() {
217                i += 1;
218            }
219        } else if c == '*' {
220            i += 1;
221            if i < bytes.len() && (bytes[i] == b'.' || bytes[i] == b'[') {
222                continue;
223            }
224            break;
225        } else {
226            break;
227        }
228    }
229    while i < bytes.len() && (bytes[i] as char).is_ascii_whitespace() {
230        i += 1;
231    }
232    let remaining = &s[i..];
233    remaining.starts_with("~")
234        || remaining.starts_with("=")
235        || remaining.starts_with("!~")
236        || remaining.starts_with("!=")
237}
238
239fn parse_field_conditions(cursor: &mut Cursor<'_>) -> Result<ArgMatcher, PatternParseError> {
240    let mut conditions = Vec::new();
241    loop {
242        cursor.skip_whitespace();
243        conditions.push(parse_single_field_condition(cursor)?);
244        cursor.skip_whitespace();
245        if cursor.peek() == Some(',') {
246            cursor.advance(1);
247        } else {
248            break;
249        }
250    }
251    Ok(ArgMatcher::Fields(conditions))
252}
253
254fn parse_single_field_condition(
255    cursor: &mut Cursor<'_>,
256) -> Result<FieldCondition, PatternParseError> {
257    cursor.skip_whitespace();
258    let path = parse_field_path(cursor)?;
259    cursor.skip_whitespace();
260    let op = parse_match_op(cursor)?;
261    cursor.skip_whitespace();
262    let value = parse_quoted_value(cursor)?;
263    Ok(FieldCondition { path, op, value })
264}
265
266fn parse_field_path(cursor: &mut Cursor<'_>) -> Result<Vec<PathSegment>, PatternParseError> {
267    let mut segments = Vec::new();
268    loop {
269        cursor.skip_whitespace();
270        if cursor.peek() == Some('*') {
271            cursor.advance(1);
272            segments.push(PathSegment::Wildcard);
273        } else {
274            let ident = parse_identifier(cursor)?;
275            segments.push(PathSegment::Field(ident));
276        }
277
278        while cursor.peek() == Some('[') {
279            cursor.advance(1);
280            cursor.skip_whitespace();
281            if cursor.peek() == Some('*') {
282                cursor.advance(1);
283                segments.push(PathSegment::AnyIndex);
284            } else {
285                let idx = parse_usize(cursor)?;
286                segments.push(PathSegment::Index(idx));
287            }
288            cursor.expect(']')?;
289        }
290
291        if cursor.peek() == Some('.') {
292            cursor.advance(1);
293        } else {
294            break;
295        }
296    }
297    Ok(segments)
298}
299
300fn parse_identifier(cursor: &mut Cursor<'_>) -> Result<String, PatternParseError> {
301    let start = cursor.pos;
302    while let Some(c) = cursor.peek() {
303        if c.is_ascii_alphanumeric() || c == '_' || c == '-' {
304            cursor.advance(1);
305        } else {
306            break;
307        }
308    }
309    let ident = &cursor.input[start..cursor.pos];
310    if ident.is_empty() {
311        return Err(cursor.error("expected identifier"));
312    }
313    Ok(ident.to_string())
314}
315
316fn parse_usize(cursor: &mut Cursor<'_>) -> Result<usize, PatternParseError> {
317    let start = cursor.pos;
318    while let Some(c) = cursor.peek() {
319        if c.is_ascii_digit() {
320            cursor.advance(1);
321        } else {
322            break;
323        }
324    }
325    let digits = &cursor.input[start..cursor.pos];
326    digits
327        .parse::<usize>()
328        .map_err(|_| cursor.error(format!("invalid index: '{digits}'")))
329}
330
331fn parse_match_op(cursor: &mut Cursor<'_>) -> Result<MatchOp, PatternParseError> {
332    let remaining = cursor.remaining();
333    if remaining.starts_with("!=~") {
334        cursor.advance(3);
335        Ok(MatchOp::NotRegex)
336    } else if remaining.starts_with("!=") {
337        cursor.advance(2);
338        Ok(MatchOp::NotExact)
339    } else if remaining.starts_with("!~") {
340        cursor.advance(2);
341        Ok(MatchOp::NotGlob)
342    } else if remaining.starts_with("=~") {
343        cursor.advance(2);
344        Ok(MatchOp::Regex)
345    } else if remaining.starts_with('~') {
346        cursor.advance(1);
347        Ok(MatchOp::Glob)
348    } else if remaining.starts_with('=') {
349        cursor.advance(1);
350        Ok(MatchOp::Exact)
351    } else {
352        Err(cursor.error("expected operator: ~, =, =~, !~, !=, or !=~"))
353    }
354}
355
356fn parse_quoted_value(cursor: &mut Cursor<'_>) -> Result<String, PatternParseError> {
357    cursor.skip_whitespace();
358    if cursor.peek() != Some('"') {
359        return Err(cursor.error("expected '\"' to start value"));
360    }
361    cursor.advance(1);
362    let mut value = String::new();
363    loop {
364        match cursor.peek() {
365            None => return Err(cursor.error("unterminated string literal")),
366            Some('"') => {
367                cursor.advance(1);
368                break;
369            }
370            Some('\\') => {
371                cursor.advance(1);
372                match cursor.peek() {
373                    Some(c @ ('"' | '\\')) => {
374                        value.push(c);
375                        cursor.advance(1);
376                    }
377                    Some(c) => {
378                        value.push('\\');
379                        value.push(c);
380                        cursor.advance(c.len_utf8());
381                    }
382                    None => return Err(cursor.error("unterminated escape sequence")),
383                }
384            }
385            Some(c) => {
386                value.push(c);
387                cursor.advance(c.len_utf8());
388            }
389        }
390    }
391    Ok(value)
392}
393
394fn parse_primary_value(cursor: &mut Cursor<'_>) -> Result<ArgMatcher, PatternParseError> {
395    cursor.skip_whitespace();
396    let start = cursor.pos;
397
398    let mut depth = 0u32;
399    while let Some(c) = cursor.peek() {
400        match c {
401            '(' => {
402                depth += 1;
403                cursor.advance(1);
404            }
405            ')' if depth > 0 => {
406                depth -= 1;
407                cursor.advance(1);
408            }
409            ')' => break,
410            _ => cursor.advance(c.len_utf8()),
411        }
412    }
413
414    let value = cursor.input[start..cursor.pos].trim();
415    if value.is_empty() {
416        return Err(cursor.error("empty primary pattern"));
417    }
418    Ok(ArgMatcher::Primary {
419        op: MatchOp::Glob,
420        value: value.to_string(),
421    })
422}
423
424// ---------------------------------------------------------------------------
425// Serde for ToolCallPattern (uses the parser)
426// ---------------------------------------------------------------------------
427
428impl serde::Serialize for ToolCallPattern {
429    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
430        serializer.serialize_str(&self.to_string())
431    }
432}
433
434impl<'de> serde::Deserialize<'de> for ToolCallPattern {
435    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
436        struct PatternVisitor;
437
438        impl<'de> serde::de::Visitor<'de> for PatternVisitor {
439            type Value = ToolCallPattern;
440
441            fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
442                f.write_str("a tool call pattern string like \"Bash(npm *)\"")
443            }
444
445            fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Self::Value, E> {
446                parse_pattern(v).map_err(serde::de::Error::custom)
447            }
448        }
449
450        deserializer.deserialize_str(PatternVisitor)
451    }
452}
453
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the field conditions of `pattern`, failing the test if its
    /// argument matcher is not `ArgMatcher::Fields`.
    fn fields(pattern: &ToolCallPattern) -> &[FieldCondition] {
        match &pattern.args {
            ArgMatcher::Fields(conditions) => conditions.as_slice(),
            other => panic!("expected Fields, got {other:?}"),
        }
    }

    #[test]
    fn parse_exact_tool_only() {
        let p = parse_pattern("Bash").unwrap();
        assert_eq!(p.tool, ToolMatcher::Exact("Bash".into()));
        assert_eq!(p.args, ArgMatcher::Any);
    }

    #[test]
    fn parse_glob_tool_only() {
        let p = parse_pattern("mcp__github__*").unwrap();
        assert_eq!(p.tool, ToolMatcher::Glob("mcp__github__*".into()));
    }

    #[test]
    fn parse_primary_glob() {
        let p = parse_pattern("Bash(npm *)").unwrap();
        assert_eq!(
            p.args,
            ArgMatcher::Primary {
                op: MatchOp::Glob,
                value: "npm *".into()
            }
        );
    }

    #[test]
    fn parse_named_field_glob() {
        let p = parse_pattern(r#"Edit(file_path ~ "src/**/*.rs")"#).unwrap();
        let conditions = fields(&p);
        assert_eq!(conditions[0].op, MatchOp::Glob);
        assert_eq!(conditions[0].value, "src/**/*.rs");
    }

    #[test]
    fn serde_round_trip() {
        let p = ToolCallPattern::tool_with_primary("Bash", "npm *");
        let json_val = serde_json::to_string(&p).unwrap();
        assert_eq!(json_val, r#""Bash(npm *)""#);
        let decoded: ToolCallPattern = serde_json::from_str(&json_val).unwrap();
        assert_eq!(decoded, p);
    }

    #[test]
    fn error_empty_input() {
        assert!(parse_pattern("").is_err());
    }

    #[test]
    fn error_unmatched_paren() {
        assert!(parse_pattern("Bash(npm *").is_err());
    }

    #[test]
    fn parse_regex_tool() {
        let p = parse_pattern(r"/mcp__(github|gitlab)__.*/").unwrap();
        assert!(matches!(p.tool, ToolMatcher::Regex(_)));
        assert_eq!(p.args, ArgMatcher::Any);
    }

    #[test]
    fn parse_regex_tool_with_escape() {
        let p = parse_pattern(r"/foo\/bar/").unwrap();
        if let ToolMatcher::Regex(re) = &p.tool {
            assert_eq!(re.as_str(), r"foo\/bar");
        } else {
            panic!("expected Regex");
        }
    }

    #[test]
    fn error_empty_regex() {
        assert!(parse_pattern("//").is_err());
    }

    #[test]
    fn error_invalid_regex() {
        assert!(parse_pattern("/[invalid/").is_err());
    }

    #[test]
    fn parse_explicit_any_args() {
        let p = parse_pattern("Bash(*)").unwrap();
        assert_eq!(p.args, ArgMatcher::Any);
    }

    #[test]
    fn parse_named_field_exact() {
        let p = parse_pattern(r#"Bash(command = "ls")"#).unwrap();
        let conditions = fields(&p);
        assert_eq!(conditions[0].op, MatchOp::Exact);
        assert_eq!(conditions[0].value, "ls");
    }

    #[test]
    fn parse_named_field_regex() {
        let p = parse_pattern(r#"Bash(command =~ "(?i)rm")"#).unwrap();
        assert_eq!(fields(&p)[0].op, MatchOp::Regex);
    }

    #[test]
    fn parse_negated_operators() {
        // `fields` panics on a non-`Fields` result, so a mis-parse can no
        // longer slip through as a silently skipped assertion.
        let p1 = parse_pattern(r#"T(f !~ "pat")"#).unwrap();
        let p2 = parse_pattern(r#"T(f != "val")"#).unwrap();
        let p3 = parse_pattern(r#"T(f !=~ "re")"#).unwrap();
        assert_eq!(fields(&p1)[0].op, MatchOp::NotGlob);
        assert_eq!(fields(&p2)[0].op, MatchOp::NotExact);
        assert_eq!(fields(&p3)[0].op, MatchOp::NotRegex);
    }

    #[test]
    fn parse_multi_field_conditions() {
        let p = parse_pattern(r#"Tool(f1 ~ "a", f2 = "b")"#).unwrap();
        let conditions = fields(&p);
        assert_eq!(conditions.len(), 2);
        assert_eq!(conditions[0].op, MatchOp::Glob);
        assert_eq!(conditions[1].op, MatchOp::Exact);
    }

    #[test]
    fn parse_nested_field_path() {
        let p = parse_pattern(r#"Tool(a.b[*].c ~ "pat")"#).unwrap();
        let path = &fields(&p)[0].path;
        assert_eq!(path.len(), 4);
        assert_eq!(path[0], PathSegment::Field("a".into()));
        assert_eq!(path[1], PathSegment::Field("b".into()));
        assert_eq!(path[2], PathSegment::AnyIndex);
        assert_eq!(path[3], PathSegment::Field("c".into()));
    }

    #[test]
    fn parse_specific_index_path() {
        let p = parse_pattern(r#"Tool(items[0] = "val")"#).unwrap();
        assert_eq!(fields(&p)[0].path[1], PathSegment::Index(0));
    }

    #[test]
    fn parse_wildcard_path_segment() {
        let p = parse_pattern(r#"Tool(*.id = "val")"#).unwrap();
        let conditions = fields(&p);
        assert_eq!(conditions[0].path[0], PathSegment::Wildcard);
        assert_eq!(conditions[0].path[1], PathSegment::Field("id".into()));
    }

    #[test]
    fn parse_escaped_quote_in_value() {
        let p = parse_pattern(r#"T(f = "say \"hello\"")"#).unwrap();
        assert_eq!(fields(&p)[0].value, r#"say "hello""#);
    }

    #[test]
    fn parse_escaped_backslash_in_value() {
        let p = parse_pattern(r#"T(f = "path\\file")"#).unwrap();
        assert_eq!(fields(&p)[0].value, r"path\file");
    }

    #[test]
    fn parse_non_special_escape_in_value() {
        let p = parse_pattern(r#"T(f = "hello\nworld")"#).unwrap();
        // Non-special escapes preserve the backslash
        assert_eq!(fields(&p)[0].value, "hello\\nworld");
    }

    #[test]
    fn error_trailing_chars() {
        assert!(parse_pattern("Bash extra").is_err());
    }

    #[test]
    fn error_unterminated_string() {
        assert!(parse_pattern(r#"T(f = "unterminated)"#).is_err());
    }

    #[test]
    fn error_unterminated_escape() {
        assert!(parse_pattern(r#"T(f = "end\"#).is_err());
    }

    #[test]
    fn error_missing_quote() {
        assert!(parse_pattern(r#"T(f = noquote)"#).is_err());
    }

    #[test]
    fn error_bad_operator() {
        // The first condition makes the argument list parse as field
        // conditions, so the `?` in the second one reaches the operator parser.
        let err = parse_pattern(r#"T(a = "x", b ? "y")"#).unwrap_err();
        assert!(err.message.contains("expected operator"));
    }

    #[test]
    fn parse_pattern_error_display() {
        let err = parse_pattern("").unwrap_err();
        assert!(err.to_string().contains("parse error at"));
    }

    #[test]
    fn serde_deserialize_invalid() {
        let result: Result<ToolCallPattern, _> = serde_json::from_str(r#""""#);
        assert!(result.is_err());
    }

    #[test]
    fn parse_glob_tool_question_mark() {
        let p = parse_pattern("Bas?").unwrap();
        assert_eq!(p.tool, ToolMatcher::Glob("Bas?".into()));
    }

    #[test]
    fn parse_glob_tool_bracket() {
        let p = parse_pattern("Bas[hH]").unwrap();
        assert_eq!(p.tool, ToolMatcher::Glob("Bas[hH]".into()));
    }
}
711}