Skip to main content

claude_wrapper/
tool_pattern.rs

1//! Tool permission patterns for `--allowed-tools` / `--disallowed-tools`.
2//!
3//! The Claude CLI accepts three pattern shapes in its tool lists:
4//!
5//! - Bare name: `Bash`, `Read`, `Write`.
6//! - Name with argument glob: `Bash(git log:*)`, `Write(src/*.rs)`.
7//! - MCP pattern: `mcp__server__tool` or `mcp__server__*`.
8//!
//! [`ToolPattern`] models all three. The typed constructors
//! ([`ToolPattern::tool`], [`ToolPattern::tool_with_args`],
//! [`ToolPattern::all`], [`ToolPattern::mcp`]) always render the
//! expected shape, but they insert their inputs without validation, so
//! a comma or parenthesis inside an argument still yields a malformed
//! pattern. [`ToolPattern::parse`] validates the shape of a raw string
//! and returns [`PatternError`] on malformed input.
14//!
//! For back-compat, `From<&str>` / `From<String>` accept any string
//! and store it without validation (only surrounding whitespace is
//! trimmed) -- callers passing raw CLI strings through
//! [`QueryCommand::allowed_tool`](crate::QueryCommand::allowed_tool)
//! keep working without changes. Use [`ToolPattern::parse`] directly
//! when you want to catch typos before the CLI invocation.
20//!
21//! # Example
22//!
23//! ```
24//! use claude_wrapper::ToolPattern;
25//!
26//! let p = ToolPattern::tool_with_args("Bash", "git log:*");
27//! assert_eq!(p.as_str(), "Bash(git log:*)");
28//!
29//! let p = ToolPattern::all("Write");
30//! assert_eq!(p.as_str(), "Write(*)");
31//!
32//! let p = ToolPattern::mcp("my-server", "*");
33//! assert_eq!(p.as_str(), "mcp__my-server__*");
34//! ```
35
36use std::fmt;
37
/// One tool permission pattern, stored in exactly the textual form it
/// takes inside a comma-joined `--allowed-tools` / `--disallowed-tools`
/// value.
///
/// The shapes it can hold are listed in the [module docs](crate::tool_pattern).
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct ToolPattern {
    /// Rendered pattern text, emitted as-is on the command line.
    repr: String,
}
46
47/// Errors from parsing a raw string with [`ToolPattern::parse`].
48#[derive(Debug, thiserror::Error, PartialEq, Eq)]
49pub enum PatternError {
50    /// Input was empty or all whitespace.
51    #[error("tool pattern must not be empty")]
52    Empty,
53
54    /// The tool name part was empty (e.g. `(args)` with nothing before).
55    #[error("tool pattern is missing a name before '('")]
56    MissingName,
57
58    /// Parentheses were unbalanced or appeared out of order.
59    #[error("tool pattern has unbalanced parentheses: {0:?}")]
60    UnbalancedParens(String),
61
62    /// Contained a character that the CLI disallows in pattern values
63    /// (comma splits the argv, control chars can break the shell).
64    #[error("tool pattern contains an illegal character: {0:?}")]
65    IllegalChar(String),
66}
67
68impl ToolPattern {
69    /// A bare tool name, e.g. `ToolPattern::tool("Bash")` -> `Bash`.
70    ///
71    /// No validation beyond trimming whitespace; the CLI is the
72    /// source of truth for which tool names exist.
73    pub fn tool(name: impl Into<String>) -> Self {
74        Self {
75            repr: name.into().trim().to_string(),
76        }
77    }
78
79    /// A tool with an argument glob, rendered `Name(args)`.
80    ///
81    /// ```
82    /// # use claude_wrapper::ToolPattern;
83    /// assert_eq!(
84    ///     ToolPattern::tool_with_args("Bash", "git log:*").as_str(),
85    ///     "Bash(git log:*)"
86    /// );
87    /// ```
88    pub fn tool_with_args(name: impl Into<String>, args: impl Into<String>) -> Self {
89        Self {
90            repr: format!("{}({})", name.into().trim(), args.into()),
91        }
92    }
93
94    /// Shorthand for [`ToolPattern::tool_with_args`] with `*` as the
95    /// argument pattern -- "any args to this tool."
96    ///
97    /// ```
98    /// # use claude_wrapper::ToolPattern;
99    /// assert_eq!(ToolPattern::all("Write").as_str(), "Write(*)");
100    /// ```
101    pub fn all(name: impl Into<String>) -> Self {
102        Self::tool_with_args(name, "*")
103    }
104
105    /// An MCP pattern: `mcp__{server}__{tool}`. Pass `"*"` as the tool
106    /// to match any tool from the server.
107    ///
108    /// ```
109    /// # use claude_wrapper::ToolPattern;
110    /// assert_eq!(
111    ///     ToolPattern::mcp("my-server", "do_thing").as_str(),
112    ///     "mcp__my-server__do_thing"
113    /// );
114    /// assert_eq!(
115    ///     ToolPattern::mcp("my-server", "*").as_str(),
116    ///     "mcp__my-server__*"
117    /// );
118    /// ```
119    pub fn mcp(server: impl Into<String>, tool: impl Into<String>) -> Self {
120        Self {
121            repr: format!("mcp__{}__{}", server.into(), tool.into()),
122        }
123    }
124
125    /// Parse and validate a raw CLI-format pattern string.
126    ///
127    /// Validation is shape-level only (non-empty, balanced parens, no
128    /// comma or control chars). Tool names are not checked against any
129    /// allowlist because the CLI's tool inventory evolves independently.
130    pub fn parse(s: impl AsRef<str>) -> Result<Self, PatternError> {
131        let trimmed = s.as_ref().trim();
132        if trimmed.is_empty() {
133            return Err(PatternError::Empty);
134        }
135
136        for ch in trimmed.chars() {
137            if ch == ',' || ch.is_control() {
138                return Err(PatternError::IllegalChar(trimmed.to_string()));
139            }
140        }
141
142        if let Some(open) = trimmed.find('(') {
143            if !trimmed.ends_with(')') {
144                return Err(PatternError::UnbalancedParens(trimmed.to_string()));
145            }
146            // Exactly one '(' and one ')'.
147            if trimmed.matches('(').count() != 1 || trimmed.matches(')').count() != 1 {
148                return Err(PatternError::UnbalancedParens(trimmed.to_string()));
149            }
150            if open == 0 {
151                return Err(PatternError::MissingName);
152            }
153        } else if trimmed.contains(')') {
154            return Err(PatternError::UnbalancedParens(trimmed.to_string()));
155        }
156
157        Ok(Self {
158            repr: trimmed.to_string(),
159        })
160    }
161
162    /// The rendered pattern string, as it will appear in the CLI arg.
163    pub fn as_str(&self) -> &str {
164        &self.repr
165    }
166}
167
168impl fmt::Display for ToolPattern {
169    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
170        f.write_str(&self.repr)
171    }
172}
173
174impl AsRef<str> for ToolPattern {
175    fn as_ref(&self) -> &str {
176        &self.repr
177    }
178}
179
180impl From<&str> for ToolPattern {
181    fn from(s: &str) -> Self {
182        Self {
183            repr: s.trim().to_string(),
184        }
185    }
186}
187
188impl From<String> for ToolPattern {
189    fn from(s: String) -> Self {
190        let trimmed = s.trim();
191        if trimmed.len() == s.len() {
192            Self { repr: s }
193        } else {
194            Self {
195                repr: trimmed.to_string(),
196            }
197        }
198    }
199}
200
201impl From<&String> for ToolPattern {
202    fn from(s: &String) -> Self {
203        Self::from(s.as_str())
204    }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tool_strips_whitespace() {
        assert_eq!(ToolPattern::tool("  Bash  ").as_str(), "Bash");
    }

    #[test]
    fn tool_with_args_renders_parens() {
        let p = ToolPattern::tool_with_args("Bash", "git log:*");
        assert_eq!(p.as_str(), "Bash(git log:*)");
    }

    #[test]
    fn all_wildcards_args() {
        assert_eq!(ToolPattern::all("Write").as_str(), "Write(*)");
    }

    #[test]
    fn mcp_patterns() {
        assert_eq!(ToolPattern::mcp("srv", "do_it").as_str(), "mcp__srv__do_it");
        assert_eq!(ToolPattern::mcp("srv", "*").as_str(), "mcp__srv__*");
    }

    #[test]
    fn typed_constructors_round_trip_through_parse() {
        // Well-formed constructor output must always be accepted by `parse`.
        for p in [
            ToolPattern::tool("Bash"),
            ToolPattern::tool_with_args("Bash", "git log:*"),
            ToolPattern::all("Write"),
            ToolPattern::mcp("srv", "*"),
        ] {
            assert_eq!(ToolPattern::parse(p.as_str()).unwrap(), p);
        }
    }

    #[test]
    fn parse_accepts_bare_name() {
        assert_eq!(ToolPattern::parse("Bash").unwrap().as_str(), "Bash");
    }

    #[test]
    fn parse_accepts_name_with_args() {
        assert_eq!(
            ToolPattern::parse("Bash(git log:*)").unwrap().as_str(),
            "Bash(git log:*)"
        );
    }

    #[test]
    fn parse_accepts_mcp() {
        assert_eq!(
            ToolPattern::parse("mcp__srv__*").unwrap().as_str(),
            "mcp__srv__*"
        );
    }

    #[test]
    fn parse_trims_whitespace() {
        assert_eq!(ToolPattern::parse("  Read  ").unwrap().as_str(), "Read");
    }

    #[test]
    fn parse_rejects_empty() {
        assert_eq!(ToolPattern::parse("").unwrap_err(), PatternError::Empty);
        assert_eq!(ToolPattern::parse("   ").unwrap_err(), PatternError::Empty);
    }

    #[test]
    fn parse_rejects_unbalanced_parens() {
        assert!(matches!(
            ToolPattern::parse("Bash(git log"),
            Err(PatternError::UnbalancedParens(_))
        ));
        assert!(matches!(
            ToolPattern::parse("Bashgit log)"),
            Err(PatternError::UnbalancedParens(_))
        ));
        assert!(matches!(
            ToolPattern::parse("Bash((nested))"),
            Err(PatternError::UnbalancedParens(_))
        ));
    }

    #[test]
    fn parse_rejects_close_paren_before_open() {
        assert!(matches!(
            ToolPattern::parse("Ba)sh(x)"),
            Err(PatternError::UnbalancedParens(_))
        ));
    }

    #[test]
    fn parse_reports_unbalanced_before_missing_name() {
        // An unterminated group is reported as unbalanced even when the
        // name is also missing.
        assert!(matches!(
            ToolPattern::parse("(args"),
            Err(PatternError::UnbalancedParens(_))
        ));
    }

    #[test]
    fn parse_rejects_missing_name() {
        assert_eq!(
            ToolPattern::parse("(args)").unwrap_err(),
            PatternError::MissingName
        );
    }

    #[test]
    fn parse_rejects_comma() {
        assert!(matches!(
            ToolPattern::parse("Bash,Read"),
            Err(PatternError::IllegalChar(_))
        ));
    }

    #[test]
    fn parse_rejects_control_chars() {
        // Must be mid-string: leading/trailing whitespace is trimmed.
        assert!(matches!(
            ToolPattern::parse("Ba\nsh"),
            Err(PatternError::IllegalChar(_))
        ));
    }

    #[test]
    fn from_str_is_loose() {
        // Skips validation so back-compat callers keep working.
        let p: ToolPattern = "anything goes".into();
        assert_eq!(p.as_str(), "anything goes");
    }

    #[test]
    fn from_string_trims_whitespace() {
        let p = ToolPattern::from(String::from("  Read  "));
        assert_eq!(p.as_str(), "Read");
        let untouched = ToolPattern::from(String::from("Read"));
        assert_eq!(untouched.as_str(), "Read");
    }

    #[test]
    fn from_string_ref_trims_whitespace() {
        let raw = String::from(" Bash(ls) ");
        assert_eq!(ToolPattern::from(&raw).as_str(), "Bash(ls)");
    }

    #[test]
    fn as_ref_matches_as_str() {
        let p = ToolPattern::all("Write");
        let borrowed: &str = p.as_ref();
        assert_eq!(borrowed, p.as_str());
    }

    #[test]
    fn display_matches_as_str() {
        let p = ToolPattern::tool_with_args("Bash", "ls");
        assert_eq!(format!("{p}"), p.as_str());
    }
}