1use std::fmt;
17
18use crate::types::{
19 ArgMatcher, FieldCondition, MatchOp, PathSegment, ToolCallPattern, ToolMatcher,
20};
21
/// Error describing why a pattern string could not be parsed.
#[derive(Clone, Debug)]
pub struct PatternParseError {
    /// Human-readable description of the problem.
    pub message: String,
    /// Offset into the parsed input at which the problem was reported.
    pub position: usize,
}
28
29impl fmt::Display for PatternParseError {
30 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31 write!(f, "parse error at {}: {}", self.position, self.message)
32 }
33}
34
35impl std::error::Error for PatternParseError {}
36
/// Scanning position over a borrowed input string.
struct Cursor<'a> {
    /// The text being scanned.
    input: &'a str,
    /// Current byte offset into `input`; everything before it has been consumed.
    pos: usize,
}
41
42impl<'a> Cursor<'a> {
43 fn new(input: &'a str) -> Self {
44 Self { input, pos: 0 }
45 }
46
47 fn remaining(&self) -> &'a str {
48 &self.input[self.pos..]
49 }
50
51 fn is_empty(&self) -> bool {
52 self.pos >= self.input.len()
53 }
54
55 fn peek(&self) -> Option<char> {
56 self.remaining().chars().next()
57 }
58
59 fn advance(&mut self, n: usize) {
60 self.pos += n;
61 }
62
63 fn skip_whitespace(&mut self) {
64 while let Some(c) = self.peek() {
65 if c.is_ascii_whitespace() {
66 self.advance(c.len_utf8());
67 } else {
68 break;
69 }
70 }
71 }
72
73 fn expect(&mut self, ch: char) -> Result<(), PatternParseError> {
74 self.skip_whitespace();
75 match self.peek() {
76 Some(c) if c == ch => {
77 self.advance(c.len_utf8());
78 Ok(())
79 }
80 other => Err(self.error(format!(
81 "expected '{}', found {}",
82 ch,
83 match other {
84 Some(c) => format!("'{c}'"),
85 None => "end of input".to_string(),
86 }
87 ))),
88 }
89 }
90
91 fn error(&self, message: impl Into<String>) -> PatternParseError {
92 PatternParseError {
93 message: message.into(),
94 position: self.pos,
95 }
96 }
97}
98
99pub fn parse_pattern(input: &str) -> Result<ToolCallPattern, PatternParseError> {
101 let mut cursor = Cursor::new(input.trim());
102
103 let tool = parse_tool_part(&mut cursor)?;
104 cursor.skip_whitespace();
105
106 let args = if cursor.peek() == Some('(') {
107 cursor.advance(1);
108 let args = parse_arg_part(&mut cursor)?;
109 cursor.expect(')')?;
110 args
111 } else {
112 ArgMatcher::Any
113 };
114
115 cursor.skip_whitespace();
116 if !cursor.is_empty() {
117 return Err(cursor.error(format!("unexpected trailing: '{}'", cursor.remaining())));
118 }
119
120 Ok(ToolCallPattern { tool, args })
121}
122
123fn parse_tool_part(cursor: &mut Cursor<'_>) -> Result<ToolMatcher, PatternParseError> {
124 cursor.skip_whitespace();
125 if cursor.peek() == Some('/') {
126 cursor.advance(1);
127 let start = cursor.pos;
128 let mut depth = 0u32;
129 while let Some(c) = cursor.peek() {
130 match c {
131 '\\' => {
132 cursor.advance(1);
133 if cursor.peek().is_some() {
134 cursor.advance(1);
135 }
136 }
137 '(' => {
138 depth += 1;
139 cursor.advance(1);
140 }
141 ')' => {
142 depth = depth.saturating_sub(1);
143 cursor.advance(1);
144 }
145 '/' if depth == 0 => break,
146 _ => cursor.advance(c.len_utf8()),
147 }
148 }
149 let body = &cursor.input[start..cursor.pos];
150 if body.is_empty() {
151 return Err(cursor.error("empty regex pattern"));
152 }
153 cursor.expect('/')?;
154 let re =
155 regex::Regex::new(body).map_err(|e| cursor.error(format!("invalid regex: {e}")))?;
156 Ok(ToolMatcher::Regex(re))
157 } else {
158 let start = cursor.pos;
159 while let Some(c) = cursor.peek() {
160 if c == '(' || c.is_ascii_whitespace() {
161 break;
162 }
163 cursor.advance(c.len_utf8());
164 }
165 let name = &cursor.input[start..cursor.pos];
166 if name.is_empty() {
167 return Err(cursor.error("empty tool name"));
168 }
169 if has_glob_chars(name) {
170 Ok(ToolMatcher::Glob(name.to_string()))
171 } else {
172 Ok(ToolMatcher::Exact(name.to_string()))
173 }
174 }
175}
176
/// Returns `true` if `s` contains any of the characters `*`, `?` or `[`.
fn has_glob_chars(s: &str) -> bool {
    s.chars().any(|ch| matches!(ch, '*' | '?' | '['))
}
180
181fn parse_arg_part(cursor: &mut Cursor<'_>) -> Result<ArgMatcher, PatternParseError> {
182 cursor.skip_whitespace();
183
184 if cursor.peek() == Some('*') {
185 let after = cursor.remaining().get(1..2);
186 if after.is_none_or(|s| {
187 let c = s.chars().next().unwrap_or(')');
188 c == ')' || c.is_ascii_whitespace()
189 }) {
190 cursor.advance(1);
191 cursor.skip_whitespace();
192 return Ok(ArgMatcher::Any);
193 }
194 }
195
196 if looks_like_field_conditions(cursor.remaining()) {
197 parse_field_conditions(cursor)
198 } else {
199 parse_primary_value(cursor)
200 }
201}
202
/// Heuristically decides whether `s` starts with a field condition.
///
/// After trimming, a path-like prefix is skipped: ASCII alphanumerics, `_`,
/// `.`, bracketed sections `[...]`, and `*` when directly followed by `.` or
/// `[` (a `*` followed by anything else ends the prefix but is still skipped).
/// ASCII whitespace is skipped next, and the result is `true` when what
/// follows begins with `~`, `=`, `!~` or `!=`.
///
/// Note that `-` is not part of the accepted prefix characters here.
fn looks_like_field_conditions(s: &str) -> bool {
    let text = s.trim();
    let raw = text.as_bytes();
    let mut idx = 0;

    loop {
        match raw.get(idx) {
            Some(&b) if b.is_ascii_alphanumeric() || b == b'_' || b == b'.' => idx += 1,
            Some(&b'[') => {
                // Jump past the matching `]`, or to the end if there is none.
                idx = match raw[idx + 1..].iter().position(|&b| b == b']') {
                    Some(offset) => idx + 1 + offset + 1,
                    None => raw.len(),
                };
            }
            Some(&b'*') => {
                idx += 1;
                let continues_path = matches!(raw.get(idx), Some(&b'.') | Some(&b'['));
                if !continues_path {
                    break;
                }
            }
            _ => break,
        }
    }

    idx += raw[idx..]
        .iter()
        .take_while(|b| b.is_ascii_whitespace())
        .count();

    let tail = &text[idx..];
    ["~", "=", "!~", "!="]
        .iter()
        .any(|op| tail.starts_with(op))
}
238
239fn parse_field_conditions(cursor: &mut Cursor<'_>) -> Result<ArgMatcher, PatternParseError> {
240 let mut conditions = Vec::new();
241 loop {
242 cursor.skip_whitespace();
243 conditions.push(parse_single_field_condition(cursor)?);
244 cursor.skip_whitespace();
245 if cursor.peek() == Some(',') {
246 cursor.advance(1);
247 } else {
248 break;
249 }
250 }
251 Ok(ArgMatcher::Fields(conditions))
252}
253
254fn parse_single_field_condition(
255 cursor: &mut Cursor<'_>,
256) -> Result<FieldCondition, PatternParseError> {
257 cursor.skip_whitespace();
258 let path = parse_field_path(cursor)?;
259 cursor.skip_whitespace();
260 let op = parse_match_op(cursor)?;
261 cursor.skip_whitespace();
262 let value = parse_quoted_value(cursor)?;
263 Ok(FieldCondition { path, op, value })
264}
265
266fn parse_field_path(cursor: &mut Cursor<'_>) -> Result<Vec<PathSegment>, PatternParseError> {
267 let mut segments = Vec::new();
268 loop {
269 cursor.skip_whitespace();
270 if cursor.peek() == Some('*') {
271 cursor.advance(1);
272 segments.push(PathSegment::Wildcard);
273 } else {
274 let ident = parse_identifier(cursor)?;
275 segments.push(PathSegment::Field(ident));
276 }
277
278 while cursor.peek() == Some('[') {
279 cursor.advance(1);
280 cursor.skip_whitespace();
281 if cursor.peek() == Some('*') {
282 cursor.advance(1);
283 segments.push(PathSegment::AnyIndex);
284 } else {
285 let idx = parse_usize(cursor)?;
286 segments.push(PathSegment::Index(idx));
287 }
288 cursor.expect(']')?;
289 }
290
291 if cursor.peek() == Some('.') {
292 cursor.advance(1);
293 } else {
294 break;
295 }
296 }
297 Ok(segments)
298}
299
300fn parse_identifier(cursor: &mut Cursor<'_>) -> Result<String, PatternParseError> {
301 let start = cursor.pos;
302 while let Some(c) = cursor.peek() {
303 if c.is_ascii_alphanumeric() || c == '_' || c == '-' {
304 cursor.advance(1);
305 } else {
306 break;
307 }
308 }
309 let ident = &cursor.input[start..cursor.pos];
310 if ident.is_empty() {
311 return Err(cursor.error("expected identifier"));
312 }
313 Ok(ident.to_string())
314}
315
316fn parse_usize(cursor: &mut Cursor<'_>) -> Result<usize, PatternParseError> {
317 let start = cursor.pos;
318 while let Some(c) = cursor.peek() {
319 if c.is_ascii_digit() {
320 cursor.advance(1);
321 } else {
322 break;
323 }
324 }
325 let digits = &cursor.input[start..cursor.pos];
326 digits
327 .parse::<usize>()
328 .map_err(|_| cursor.error(format!("invalid index: '{digits}'")))
329}
330
331fn parse_match_op(cursor: &mut Cursor<'_>) -> Result<MatchOp, PatternParseError> {
332 let remaining = cursor.remaining();
333 if remaining.starts_with("!=~") {
334 cursor.advance(3);
335 Ok(MatchOp::NotRegex)
336 } else if remaining.starts_with("!=") {
337 cursor.advance(2);
338 Ok(MatchOp::NotExact)
339 } else if remaining.starts_with("!~") {
340 cursor.advance(2);
341 Ok(MatchOp::NotGlob)
342 } else if remaining.starts_with("=~") {
343 cursor.advance(2);
344 Ok(MatchOp::Regex)
345 } else if remaining.starts_with('~') {
346 cursor.advance(1);
347 Ok(MatchOp::Glob)
348 } else if remaining.starts_with('=') {
349 cursor.advance(1);
350 Ok(MatchOp::Exact)
351 } else {
352 Err(cursor.error("expected operator: ~, =, =~, !~, !=, or !=~"))
353 }
354}
355
356fn parse_quoted_value(cursor: &mut Cursor<'_>) -> Result<String, PatternParseError> {
357 cursor.skip_whitespace();
358 if cursor.peek() != Some('"') {
359 return Err(cursor.error("expected '\"' to start value"));
360 }
361 cursor.advance(1);
362 let mut value = String::new();
363 loop {
364 match cursor.peek() {
365 None => return Err(cursor.error("unterminated string literal")),
366 Some('"') => {
367 cursor.advance(1);
368 break;
369 }
370 Some('\\') => {
371 cursor.advance(1);
372 match cursor.peek() {
373 Some(c @ ('"' | '\\')) => {
374 value.push(c);
375 cursor.advance(1);
376 }
377 Some(c) => {
378 value.push('\\');
379 value.push(c);
380 cursor.advance(c.len_utf8());
381 }
382 None => return Err(cursor.error("unterminated escape sequence")),
383 }
384 }
385 Some(c) => {
386 value.push(c);
387 cursor.advance(c.len_utf8());
388 }
389 }
390 }
391 Ok(value)
392}
393
394fn parse_primary_value(cursor: &mut Cursor<'_>) -> Result<ArgMatcher, PatternParseError> {
395 cursor.skip_whitespace();
396 let start = cursor.pos;
397
398 let mut depth = 0u32;
399 while let Some(c) = cursor.peek() {
400 match c {
401 '(' => {
402 depth += 1;
403 cursor.advance(1);
404 }
405 ')' if depth > 0 => {
406 depth -= 1;
407 cursor.advance(1);
408 }
409 ')' => break,
410 _ => cursor.advance(c.len_utf8()),
411 }
412 }
413
414 let value = cursor.input[start..cursor.pos].trim();
415 if value.is_empty() {
416 return Err(cursor.error("empty primary pattern"));
417 }
418 Ok(ArgMatcher::Primary {
419 op: MatchOp::Glob,
420 value: value.to_string(),
421 })
422}
423
424impl serde::Serialize for ToolCallPattern {
429 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
430 serializer.serialize_str(&self.to_string())
431 }
432}
433
434impl<'de> serde::Deserialize<'de> for ToolCallPattern {
435 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
436 struct PatternVisitor;
437
438 impl<'de> serde::de::Visitor<'de> for PatternVisitor {
439 type Value = ToolCallPattern;
440
441 fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
442 f.write_str("a tool call pattern string like \"Bash(npm *)\"")
443 }
444
445 fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Self::Value, E> {
446 parse_pattern(v).map_err(serde::de::Error::custom)
447 }
448 }
449
450 deserializer.deserialize_str(PatternVisitor)
451 }
452}
453
/// Unit tests for the pattern parser and its serde integration.
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the field conditions of a parsed pattern, failing the test if
    /// the arguments were parsed as anything else. Using this instead of a bare
    /// `if let` keeps a test from passing silently on the wrong variant.
    fn field_conditions(pattern: &ToolCallPattern) -> &[FieldCondition] {
        match &pattern.args {
            ArgMatcher::Fields(conditions) => conditions.as_slice(),
            other => panic!("expected Fields, got {other:?}"),
        }
    }

    #[test]
    fn parse_exact_tool_only() {
        let p = parse_pattern("Bash").unwrap();
        assert_eq!(p.tool, ToolMatcher::Exact("Bash".into()));
        assert_eq!(p.args, ArgMatcher::Any);
    }

    #[test]
    fn parse_glob_tool_only() {
        let p = parse_pattern("mcp__github__*").unwrap();
        assert_eq!(p.tool, ToolMatcher::Glob("mcp__github__*".into()));
    }

    #[test]
    fn parse_primary_glob() {
        let p = parse_pattern("Bash(npm *)").unwrap();
        assert_eq!(
            p.args,
            ArgMatcher::Primary {
                op: MatchOp::Glob,
                value: "npm *".into()
            }
        );
    }

    #[test]
    fn parse_named_field_glob() {
        let p = parse_pattern(r#"Edit(file_path ~ "src/**/*.rs")"#).unwrap();
        let conditions = field_conditions(&p);
        assert_eq!(conditions[0].op, MatchOp::Glob);
        assert_eq!(conditions[0].value, "src/**/*.rs");
    }

    #[test]
    fn serde_round_trip() {
        let p = ToolCallPattern::tool_with_primary("Bash", "npm *");
        let json_val = serde_json::to_string(&p).unwrap();
        assert_eq!(json_val, r#""Bash(npm *)""#);
        let decoded: ToolCallPattern = serde_json::from_str(&json_val).unwrap();
        assert_eq!(decoded, p);
    }

    #[test]
    fn error_empty_input() {
        assert!(parse_pattern("").is_err());
    }

    #[test]
    fn error_unmatched_paren() {
        assert!(parse_pattern("Bash(npm *").is_err());
    }

    #[test]
    fn parse_regex_tool() {
        let p = parse_pattern(r"/mcp__(github|gitlab)__.*/").unwrap();
        assert!(matches!(p.tool, ToolMatcher::Regex(_)));
        assert_eq!(p.args, ArgMatcher::Any);
    }

    #[test]
    fn parse_regex_tool_with_escape() {
        let p = parse_pattern(r"/foo\/bar/").unwrap();
        if let ToolMatcher::Regex(re) = &p.tool {
            assert_eq!(re.as_str(), r"foo\/bar");
        } else {
            panic!("expected Regex");
        }
    }

    #[test]
    fn error_empty_regex() {
        assert!(parse_pattern("//").is_err());
    }

    #[test]
    fn error_invalid_regex() {
        assert!(parse_pattern("/[invalid/").is_err());
    }

    #[test]
    fn parse_explicit_any_args() {
        let p = parse_pattern("Bash(*)").unwrap();
        assert_eq!(p.args, ArgMatcher::Any);
    }

    #[test]
    fn parse_named_field_exact() {
        let p = parse_pattern(r#"Bash(command = "ls")"#).unwrap();
        let conditions = field_conditions(&p);
        assert_eq!(conditions[0].op, MatchOp::Exact);
        assert_eq!(conditions[0].value, "ls");
    }

    #[test]
    fn parse_named_field_regex() {
        let p = parse_pattern(r#"Bash(command =~ "(?i)rm")"#).unwrap();
        assert_eq!(field_conditions(&p)[0].op, MatchOp::Regex);
    }

    #[test]
    fn parse_negated_operators() {
        let p1 = parse_pattern(r#"T(f !~ "pat")"#).unwrap();
        let p2 = parse_pattern(r#"T(f != "val")"#).unwrap();
        let p3 = parse_pattern(r#"T(f !=~ "re")"#).unwrap();
        assert_eq!(field_conditions(&p1)[0].op, MatchOp::NotGlob);
        assert_eq!(field_conditions(&p2)[0].op, MatchOp::NotExact);
        assert_eq!(field_conditions(&p3)[0].op, MatchOp::NotRegex);
    }

    #[test]
    fn parse_multi_field_conditions() {
        let p = parse_pattern(r#"Tool(f1 ~ "a", f2 = "b")"#).unwrap();
        let conditions = field_conditions(&p);
        assert_eq!(conditions.len(), 2);
        assert_eq!(conditions[0].op, MatchOp::Glob);
        assert_eq!(conditions[1].op, MatchOp::Exact);
    }

    #[test]
    fn parse_nested_field_path() {
        let p = parse_pattern(r#"Tool(a.b[*].c ~ "pat")"#).unwrap();
        let path = &field_conditions(&p)[0].path;
        assert_eq!(path.len(), 4);
        assert_eq!(path[0], PathSegment::Field("a".into()));
        assert_eq!(path[1], PathSegment::Field("b".into()));
        assert_eq!(path[2], PathSegment::AnyIndex);
        assert_eq!(path[3], PathSegment::Field("c".into()));
    }

    #[test]
    fn parse_specific_index_path() {
        let p = parse_pattern(r#"Tool(items[0] = "val")"#).unwrap();
        assert_eq!(field_conditions(&p)[0].path[1], PathSegment::Index(0));
    }

    #[test]
    fn parse_wildcard_path_segment() {
        let p = parse_pattern(r#"Tool(*.id = "val")"#).unwrap();
        let conditions = field_conditions(&p);
        assert_eq!(conditions[0].path[0], PathSegment::Wildcard);
        assert_eq!(conditions[0].path[1], PathSegment::Field("id".into()));
    }

    #[test]
    fn parse_escaped_quote_in_value() {
        let p = parse_pattern(r#"T(f = "say \"hello\"")"#).unwrap();
        assert_eq!(field_conditions(&p)[0].value, r#"say "hello""#);
    }

    #[test]
    fn parse_escaped_backslash_in_value() {
        let p = parse_pattern(r#"T(f = "path\\file")"#).unwrap();
        assert_eq!(field_conditions(&p)[0].value, r"path\file");
    }

    #[test]
    fn parse_non_special_escape_in_value() {
        let p = parse_pattern(r#"T(f = "hello\nworld")"#).unwrap();
        // `\n` is not a recognised escape, so the backslash stays in the value.
        assert_eq!(field_conditions(&p)[0].value, "hello\\nworld");
    }

    #[test]
    fn error_trailing_chars() {
        assert!(parse_pattern("Bash extra").is_err());
    }

    #[test]
    fn error_unterminated_string() {
        assert!(parse_pattern(r#"T(f = "unterminated)"#).is_err());
    }

    #[test]
    fn error_unterminated_escape() {
        assert!(parse_pattern(r#"T(f = "end\"#).is_err());
    }

    #[test]
    fn error_missing_quote() {
        assert!(parse_pattern(r#"T(f = noquote)"#).is_err());
    }

    #[test]
    fn error_bad_operator() {
        // The first condition makes the argument list parse as field
        // conditions, so the `>` in the second one reaches the operator parser.
        let err = parse_pattern(r#"T(f = "a", g > "b")"#).unwrap_err();
        assert!(err.message.contains("expected operator"));
    }

    #[test]
    fn parse_pattern_error_display() {
        let err = parse_pattern("").unwrap_err();
        assert!(err.to_string().contains("parse error at"));
    }

    #[test]
    fn serde_deserialize_invalid() {
        let result: Result<ToolCallPattern, _> = serde_json::from_str(r#""""#);
        assert!(result.is_err());
    }

    #[test]
    fn parse_glob_tool_question_mark() {
        let p = parse_pattern("Bas?").unwrap();
        assert_eq!(p.tool, ToolMatcher::Glob("Bas?".into()));
    }

    #[test]
    fn parse_glob_tool_bracket() {
        let p = parse_pattern("Bas[hH]").unwrap();
        assert_eq!(p.tool, ToolMatcher::Glob("Bas[hH]".into()));
    }
}