claude_wrapper/
tool_pattern.rs1use std::fmt;
37
/// A tool pattern, held in its rendered textual form.
///
/// Typical rendered forms are a bare name (`Bash`), a name followed by
/// parenthesised arguments (`Bash(git log:*)`), or an `mcp__`-prefixed name
/// (`mcp__srv__do_it`).
///
/// Cloning, equality and hashing all operate on that text.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct ToolPattern {
    /// The rendered pattern text.
    repr: String,
}
46
47#[derive(Debug, thiserror::Error, PartialEq, Eq)]
49pub enum PatternError {
50 #[error("tool pattern must not be empty")]
52 Empty,
53
54 #[error("tool pattern is missing a name before '('")]
56 MissingName,
57
58 #[error("tool pattern has unbalanced parentheses: {0:?}")]
60 UnbalancedParens(String),
61
62 #[error("tool pattern contains an illegal character: {0:?}")]
65 IllegalChar(String),
66}
67
68impl ToolPattern {
69 pub fn tool(name: impl Into<String>) -> Self {
74 Self {
75 repr: name.into().trim().to_string(),
76 }
77 }
78
79 pub fn tool_with_args(name: impl Into<String>, args: impl Into<String>) -> Self {
89 Self {
90 repr: format!("{}({})", name.into().trim(), args.into()),
91 }
92 }
93
94 pub fn all(name: impl Into<String>) -> Self {
102 Self::tool_with_args(name, "*")
103 }
104
105 pub fn mcp(server: impl Into<String>, tool: impl Into<String>) -> Self {
120 Self {
121 repr: format!("mcp__{}__{}", server.into(), tool.into()),
122 }
123 }
124
125 pub fn parse(s: impl AsRef<str>) -> Result<Self, PatternError> {
131 let trimmed = s.as_ref().trim();
132 if trimmed.is_empty() {
133 return Err(PatternError::Empty);
134 }
135
136 for ch in trimmed.chars() {
137 if ch == ',' || ch.is_control() {
138 return Err(PatternError::IllegalChar(trimmed.to_string()));
139 }
140 }
141
142 if let Some(open) = trimmed.find('(') {
143 if !trimmed.ends_with(')') {
144 return Err(PatternError::UnbalancedParens(trimmed.to_string()));
145 }
146 if trimmed.matches('(').count() != 1 || trimmed.matches(')').count() != 1 {
148 return Err(PatternError::UnbalancedParens(trimmed.to_string()));
149 }
150 if open == 0 {
151 return Err(PatternError::MissingName);
152 }
153 } else if trimmed.contains(')') {
154 return Err(PatternError::UnbalancedParens(trimmed.to_string()));
155 }
156
157 Ok(Self {
158 repr: trimmed.to_string(),
159 })
160 }
161
162 pub fn as_str(&self) -> &str {
164 &self.repr
165 }
166}
167
168impl fmt::Display for ToolPattern {
169 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
170 f.write_str(&self.repr)
171 }
172}
173
174impl AsRef<str> for ToolPattern {
175 fn as_ref(&self) -> &str {
176 &self.repr
177 }
178}
179
180impl From<&str> for ToolPattern {
181 fn from(s: &str) -> Self {
182 Self {
183 repr: s.trim().to_string(),
184 }
185 }
186}
187
188impl From<String> for ToolPattern {
189 fn from(s: String) -> Self {
190 let trimmed = s.trim();
191 if trimmed.len() == s.len() {
192 Self { repr: s }
193 } else {
194 Self {
195 repr: trimmed.to_string(),
196 }
197 }
198 }
199}
200
201impl From<&String> for ToolPattern {
202 fn from(s: &String) -> Self {
203 Self::from(s.as_str())
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input`, panicking if it is rejected, and returns the rendered text.
    fn parsed(input: &str) -> String {
        ToolPattern::parse(input)
            .expect("pattern should parse")
            .as_str()
            .to_owned()
    }

    // ----- infallible constructors -----

    #[test]
    fn tool_strips_whitespace() {
        let pattern = ToolPattern::tool(" Bash ");
        assert_eq!(pattern.as_str(), "Bash");
    }

    #[test]
    fn tool_with_args_renders_parens() {
        let rendered = ToolPattern::tool_with_args("Bash", "git log:*");
        assert_eq!(rendered.as_str(), "Bash(git log:*)");
    }

    #[test]
    fn all_wildcards_args() {
        let pattern = ToolPattern::all("Write");
        assert_eq!(pattern.as_str(), "Write(*)");
    }

    #[test]
    fn mcp_patterns() {
        let cases = [("do_it", "mcp__srv__do_it"), ("*", "mcp__srv__*")];
        for (tool, expected) in cases {
            assert_eq!(ToolPattern::mcp("srv", tool).as_str(), expected);
        }
    }

    // ----- parse: accepted inputs -----

    #[test]
    fn parse_accepts_bare_name() {
        assert_eq!(parsed("Bash"), "Bash");
    }

    #[test]
    fn parse_accepts_name_with_args() {
        assert_eq!(parsed("Bash(git log:*)"), "Bash(git log:*)");
    }

    #[test]
    fn parse_accepts_mcp() {
        assert_eq!(parsed("mcp__srv__*"), "mcp__srv__*");
    }

    #[test]
    fn parse_trims_whitespace() {
        assert_eq!(parsed(" Read "), "Read");
    }

    // ----- parse: rejected inputs -----

    #[test]
    fn parse_rejects_empty() {
        for input in ["", " "] {
            assert_eq!(ToolPattern::parse(input), Err(PatternError::Empty));
        }
    }

    #[test]
    fn parse_rejects_unbalanced_parens() {
        for input in ["Bash(git log", "Bashgit log)", "Bash((nested))"] {
            let outcome = ToolPattern::parse(input);
            assert!(
                matches!(outcome, Err(PatternError::UnbalancedParens(_))),
                "input: {input:?}"
            );
        }
    }

    #[test]
    fn parse_rejects_missing_name() {
        let outcome = ToolPattern::parse("(args)");
        assert_eq!(outcome, Err(PatternError::MissingName));
    }

    #[test]
    fn parse_rejects_comma() {
        let outcome = ToolPattern::parse("Bash,Read");
        assert!(matches!(outcome, Err(PatternError::IllegalChar(_))));
    }

    #[test]
    fn parse_rejects_control_chars() {
        let outcome = ToolPattern::parse("Ba\nsh");
        assert!(matches!(outcome, Err(PatternError::IllegalChar(_))));
    }

    // ----- conversions and formatting -----

    #[test]
    fn from_str_is_loose() {
        // `From<&str>` performs no validation, so free-form text is kept.
        let pattern = ToolPattern::from("anything goes");
        assert_eq!(pattern.as_str(), "anything goes");
    }

    #[test]
    fn display_matches_as_str() {
        let pattern = ToolPattern::tool_with_args("Bash", "ls");
        assert_eq!(pattern.to_string(), pattern.as_str());
    }
}