Skip to main content

awaken_tool_pattern/
tool_id.rs

1use thiserror::Error;
2
3#[derive(Debug, Error, PartialEq, Eq)]
4pub enum ToolIdPatternError {
5    #[error("pattern is empty")]
6    Empty,
7    #[error("pattern ends with a dangling escape (`\\`)")]
8    DanglingEscape,
9}
10
/// Decide whether `tool_id` is fully matched by the glob `pattern`.
///
/// The match is anchored at both ends: every byte of `tool_id` must be
/// accounted for. Pattern syntax:
/// - `*` stands for any run of characters, possibly empty; separators such as
///   `/`, `:` and `_` receive no special treatment.
/// - `\` makes the following character literal (`\*` matches `*`, `\\` matches `\`).
/// - Anything else must appear verbatim; comparison is byte-wise and case-sensitive.
///
/// A trailing lone `\` has nothing to escape and is compared as a literal
/// backslash; [`validate_tool_id_pattern`] rejects such patterns up front.
///
/// Matching remembers only the most recent `*`; on a mismatch that star is
/// made to swallow one more byte of `tool_id` and matching resumes after it.
#[must_use]
pub fn tool_id_match(pattern: &str, tool_id: &str) -> bool {
    let pat = pattern.as_bytes();
    let text = tool_id.as_bytes();
    let mut p = 0usize;
    let mut t = 0usize;
    // (index of the last `*` in `pat`, index in `text` where its swallowed run ends)
    let mut last_star: Option<(usize, usize)> = None;

    while t < text.len() {
        // Number of pattern bytes the token at `p` consumes when it matches `text[t]`.
        let consumed = match &pat[p..] {
            [b'\\', escaped, ..] => {
                if *escaped == text[t] {
                    Some(2)
                } else {
                    None
                }
            }
            [b'*', ..] => {
                last_star = Some((p, t));
                p += 1;
                continue;
            }
            [literal, ..] => {
                if *literal == text[t] {
                    Some(1)
                } else {
                    None
                }
            }
            [] => None,
        };

        if let Some(width) = consumed {
            p += width;
            t += 1;
        } else if let Some((star, resume)) = last_star {
            // Let the star absorb one more byte, then retry just after the star.
            let resume = resume + 1;
            last_star = Some((star, resume));
            p = star + 1;
            t = resume;
        } else {
            return false;
        }
    }

    // Leftover pattern can only match the (empty) rest of the id if it is all stars.
    pat[p..].iter().all(|&b| b == b'*')
}
59
60/// Validate that a tool-id pattern string is syntactically well-formed.
61pub fn validate_tool_id_pattern(pattern: &str) -> Result<(), ToolIdPatternError> {
62    if pattern.is_empty() {
63        return Err(ToolIdPatternError::Empty);
64    }
65    let bytes = pattern.as_bytes();
66    let mut i = 0;
67    while i < bytes.len() {
68        if bytes[i] == b'\\' {
69            if i + 1 >= bytes.len() {
70                return Err(ToolIdPatternError::DanglingEscape);
71            }
72            i += 2;
73        } else {
74            i += 1;
75        }
76    }
77    Ok(())
78}
79
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn literal_matches_exact() {
        assert!(tool_id_match("Bash", "Bash"));
        assert!(!tool_id_match("Bash", "bash"));
        assert!(!tool_id_match("Bash", "Bashx"));
        assert!(!tool_id_match("Bash", ""));
    }

    #[test]
    fn empty_pattern_matches_only_empty_id() {
        assert!(tool_id_match("", ""));
        assert!(!tool_id_match("", "x"));
    }

    #[test]
    fn star_matches_anything() {
        assert!(tool_id_match("*", ""));
        assert!(tool_id_match("*", "Bash"));
        assert!(tool_id_match("*", "mcp:weather/forecast"));
        assert!(tool_id_match("**", "x"));
    }

    #[test]
    fn star_prefix_and_suffix() {
        assert!(tool_id_match("mcp:*", "mcp:weather"));
        assert!(tool_id_match("mcp:*", "mcp:fs/read"));
        assert!(!tool_id_match("mcp:*", "Bash"));
        assert!(tool_id_match("*Tool", "BashTool"));
        assert!(tool_id_match("*Tool", "Tool"));
    }

    #[test]
    fn star_in_middle() {
        assert!(tool_id_match("mcp:*/read", "mcp:fs/read"));
        assert!(!tool_id_match("mcp:*/read", "mcp:fs/write"));
    }

    #[test]
    fn star_backtracks_after_partial_match() {
        // The first `a` of the id initially matches the literal, then must be
        // handed back to the star.
        assert!(tool_id_match("*ab", "aab"));
        assert!(tool_id_match("a*b*c", "axbyc"));
        assert!(!tool_id_match("a*b*c", "axbyb"));
        assert!(tool_id_match("*a*", "bab"));
    }

    #[test]
    fn escape_literal_star() {
        assert!(tool_id_match(r"foo\*bar", "foo*bar"));
        assert!(!tool_id_match(r"foo\*bar", "foobar"));
        assert!(!tool_id_match(r"foo\*bar", "fooXbar"));
    }

    #[test]
    fn escaped_star_after_wildcard() {
        assert!(tool_id_match(r"*\*", "abc*"));
        assert!(!tool_id_match(r"*\*", "abc"));
    }

    #[test]
    fn escape_literal_backslash() {
        assert!(tool_id_match(r"foo\\bar", r"foo\bar"));
        assert!(!tool_id_match(r"foo\\bar", "foobar"));
    }

    #[test]
    fn slash_colon_underscore_are_literal() {
        assert!(tool_id_match("a/b:c_d", "a/b:c_d"));
        assert!(!tool_id_match("a/b:c_d", "a/b:c-d"));
    }

    #[test]
    fn validate_rejects_empty() {
        assert_eq!(validate_tool_id_pattern(""), Err(ToolIdPatternError::Empty));
    }

    #[test]
    fn validate_rejects_dangling_escape() {
        assert_eq!(
            validate_tool_id_pattern(r"foo\"),
            Err(ToolIdPatternError::DanglingEscape)
        );
        assert_eq!(
            validate_tool_id_pattern(r"\"),
            Err(ToolIdPatternError::DanglingEscape)
        );
        // Three trailing backslashes: one escaped pair plus a dangling one.
        assert_eq!(
            validate_tool_id_pattern(r"foo\\\"),
            Err(ToolIdPatternError::DanglingEscape)
        );
    }

    #[test]
    fn validate_accepts_well_formed() {
        for p in ["*", "Bash", "mcp:*", r"foo\*bar", r"foo\\bar", r"foo\\"] {
            assert!(validate_tool_id_pattern(p).is_ok(), "should accept {p}");
        }
    }
}