awaken_tool_pattern/
tool_id.rs1use thiserror::Error;
2
3#[derive(Debug, Error, PartialEq, Eq)]
4pub enum ToolIdPatternError {
5 #[error("pattern is empty")]
6 Empty,
7 #[error("pattern ends with a dangling escape (`\\`)")]
8 DanglingEscape,
9}
10
/// Reports whether `tool_id` matches `pattern` from start to end.
///
/// Pattern syntax, applied byte by byte and case-sensitively:
///
/// - `*` matches any run of bytes, including an empty one.
/// - `\` followed by a byte matches that byte literally, so `\*` matches a
///   literal `*` and `\\` matches a literal `\`.
/// - Every other byte (including `/`, `:` and `_`) matches only itself.
///
/// A lone `\` at the very end of `pattern` (which [`validate_tool_id_pattern`]
/// rejects) is compared here as an ordinary literal backslash.
///
/// On a mismatch, the most recently seen `*` absorbs one more byte of
/// `tool_id` and matching resumes just after that star. Earlier stars are not
/// revisited.
#[must_use]
pub fn tool_id_match(pattern: &str, tool_id: &str) -> bool {
    let pat = pattern.as_bytes();
    let val = tool_id.as_bytes();
    let mut p = 0usize;
    let mut v = 0usize;
    // (pattern index just past the latest `*`, value index that star currently stops at)
    let mut resume: Option<(usize, usize)> = None;

    while v < val.len() {
        // Number of pattern bytes the current token consumes when it matches `val[v]`,
        // or `None` when it does not match (or the pattern is exhausted).
        let consumed = match pat.get(p).copied() {
            Some(b'*') => {
                resume = Some((p + 1, v));
                p += 1;
                continue;
            }
            // Escape pair: only the escaped byte is compared; a mismatch here must
            // NOT fall through to comparing the backslash itself.
            Some(b'\\') if p + 1 < pat.len() => {
                if pat[p + 1] == val[v] {
                    Some(2)
                } else {
                    None
                }
            }
            Some(byte) if byte == val[v] => Some(1),
            _ => None,
        };

        match (consumed, resume.as_mut()) {
            (Some(width), _) => {
                p += width;
                v += 1;
            }
            (None, Some((after_star, anchor))) => {
                // Let the latest star swallow one more value byte and retry.
                *anchor += 1;
                p = *after_star;
                v = *anchor;
            }
            (None, None) => return false,
        }
    }

    // The value is used up; whatever pattern remains may only be stars.
    pat[p..].iter().all(|&byte| byte == b'*')
}
59
60pub fn validate_tool_id_pattern(pattern: &str) -> Result<(), ToolIdPatternError> {
62 if pattern.is_empty() {
63 return Err(ToolIdPatternError::Empty);
64 }
65 let bytes = pattern.as_bytes();
66 let mut i = 0;
67 while i < bytes.len() {
68 if bytes[i] == b'\\' {
69 if i + 1 >= bytes.len() {
70 return Err(ToolIdPatternError::DanglingEscape);
71 }
72 i += 2;
73 } else {
74 i += 1;
75 }
76 }
77 Ok(())
78}
79
#[cfg(test)]
mod tests {
    // Unit tests for the tool-id glob matcher and its pattern validator.
    use super::*;

    #[test]
    fn literal_matches_exact() {
        assert!(tool_id_match("Bash", "Bash"));
        assert!(!tool_id_match("Bash", "bash"));
        assert!(!tool_id_match("Bash", "Bashx"));
    }

    #[test]
    fn star_matches_anything() {
        assert!(tool_id_match("*", ""));
        assert!(tool_id_match("*", "Bash"));
        assert!(tool_id_match("*", "mcp:weather/forecast"));
    }

    #[test]
    fn star_prefix_and_suffix() {
        assert!(tool_id_match("mcp:*", "mcp:weather"));
        assert!(tool_id_match("mcp:*", "mcp:fs/read"));
        assert!(!tool_id_match("mcp:*", "Bash"));
        assert!(tool_id_match("*Tool", "BashTool"));
        assert!(tool_id_match("*Tool", "Tool"));
    }

    #[test]
    fn star_in_middle() {
        assert!(tool_id_match("mcp:*/read", "mcp:fs/read"));
        assert!(!tool_id_match("mcp:*/read", "mcp:fs/write"));
    }

    #[test]
    fn empty_pattern_matches_only_empty_id() {
        assert!(tool_id_match("", ""));
        assert!(!tool_id_match("", "Bash"));
    }

    #[test]
    fn multiple_stars_backtrack() {
        assert!(tool_id_match("a*b*c", "aXbYbZc"));
        assert!(!tool_id_match("a*b*c", "aXbYbZ"));
        // The first candidate for `a` is wrong; the star must backtrack.
        assert!(tool_id_match("*ab", "aab"));
        assert!(tool_id_match("**", "x"));
        assert!(tool_id_match("mcp:*/*", "mcp:fs/read"));
    }

    #[test]
    fn escape_literal_star() {
        assert!(tool_id_match(r"foo\*bar", "foo*bar"));
        assert!(!tool_id_match(r"foo\*bar", "foobar"));
        assert!(!tool_id_match(r"foo\*bar", "fooXbar"));
    }

    #[test]
    fn escaped_star_is_never_a_wildcard() {
        // A trailing `\*` must not be swallowed by the trailing-star rule.
        assert!(tool_id_match(r"mcp\*", "mcp*"));
        assert!(!tool_id_match(r"mcp\*", "mcp"));
        assert!(!tool_id_match(r"mcp\*", "mcpX"));
        assert!(!tool_id_match(r"\*", "anything"));
    }

    #[test]
    fn escape_literal_backslash() {
        assert!(tool_id_match(r"foo\\bar", r"foo\bar"));
        assert!(!tool_id_match(r"foo\\bar", "foobar"));
    }

    #[test]
    fn dangling_escape_is_matched_as_literal_backslash() {
        // validate_tool_id_pattern rejects this pattern, but matching still
        // gives a defined answer: the lone `\` compares as a literal byte.
        assert!(tool_id_match(r"foo\", r"foo\"));
        assert!(!tool_id_match(r"foo\", "foo"));
    }

    #[test]
    fn multibyte_utf8_is_matched_bytewise() {
        assert!(tool_id_match("café*", "café-au-lait"));
        assert!(tool_id_match("*é", "café"));
        assert!(!tool_id_match("*é", "cafe"));
    }

    #[test]
    fn slash_colon_underscore_are_literal() {
        assert!(tool_id_match("a/b:c_d", "a/b:c_d"));
        assert!(!tool_id_match("a/b:c_d", "a/b:c-d"));
    }

    #[test]
    fn validate_rejects_empty() {
        assert_eq!(validate_tool_id_pattern(""), Err(ToolIdPatternError::Empty));
    }

    #[test]
    fn validate_rejects_dangling_escape() {
        assert_eq!(
            validate_tool_id_pattern(r"foo\"),
            Err(ToolIdPatternError::DanglingEscape)
        );
    }

    #[test]
    fn validate_pairs_escapes_left_to_right() {
        // `\\` is a complete escaped backslash; a third `\` is left dangling.
        assert!(validate_tool_id_pattern(r"\\").is_ok());
        assert_eq!(
            validate_tool_id_pattern(r"\\\"),
            Err(ToolIdPatternError::DanglingEscape)
        );
        assert!(validate_tool_id_pattern(r"a\*").is_ok());
    }

    #[test]
    fn validate_accepts_well_formed() {
        for p in ["*", "Bash", "mcp:*", r"foo\*bar", r"foo\\bar"] {
            assert!(validate_tool_id_pattern(p).is_ok(), "should accept {p}");
        }
    }
}