use std::fmt;
/// A tool pattern stored as its textual form.
///
/// The constructors in this file produce text shaped like `Name`, `Name(args)` or
/// `mcp__server__tool`. Only [`ToolPattern::parse`] validates that shape; the other
/// constructors and the `From` conversions store whatever they are given.
///
/// Equality, hashing and cloning all operate on the stored text.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ToolPattern {
// The pattern text exactly as it is returned by `as_str`, `AsRef<str>` and `Display`.
repr: String,
}
/// Reasons [`ToolPattern::parse`] rejects an input.
///
/// The `String` payloads carry the whole whitespace-trimmed input that was rejected,
/// not just the offending fragment.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum PatternError {
/// The input was empty, or contained only whitespace.
#[error("tool pattern must not be empty")]
Empty,
/// The input starts with `(`, so there is no tool name in front of the arguments.
#[error("tool pattern is missing a name before '('")]
MissingName,
/// The parentheses do not form exactly one `(`...`)` pair closing the pattern.
/// Payload: the trimmed input.
#[error("tool pattern has unbalanced parentheses: {0:?}")]
UnbalancedParens(String),
/// The input contains a comma or a control character.
/// Payload: the trimmed input (the full pattern, not the single character).
#[error("tool pattern contains an illegal character: {0:?}")]
IllegalChar(String),
}
impl ToolPattern {
    /// Builds a pattern consisting of just a tool name, e.g. `Bash`.
    ///
    /// Leading and trailing whitespace in `name` is removed. No validation is
    /// performed; use [`ToolPattern::parse`] when the input must be checked.
    pub fn tool(name: impl Into<String>) -> Self {
        Self {
            repr: name.into().trim().to_string(),
        }
    }

    /// Builds a `name(args)` pattern, e.g. `Bash(git log:*)`.
    ///
    /// `name` is trimmed; `args` is inserted between the parentheses verbatim.
    /// No validation is performed.
    pub fn tool_with_args(name: impl Into<String>, args: impl Into<String>) -> Self {
        Self {
            repr: format!("{}({})", name.into().trim(), args.into()),
        }
    }

    /// Builds the wildcard-argument pattern `name(*)`.
    ///
    /// Shorthand for [`ToolPattern::tool_with_args`] with `"*"` as the arguments.
    pub fn all(name: impl Into<String>) -> Self {
        Self::tool_with_args(name, "*")
    }

    /// Builds an `mcp__<server>__<tool>` pattern.
    ///
    /// Both `server` and `tool` are trimmed, matching how the other constructors
    /// treat names, so stray padding cannot end up inside the rendered pattern.
    /// No validation is performed.
    pub fn mcp(server: impl Into<String>, tool: impl Into<String>) -> Self {
        Self {
            repr: format!("mcp__{}__{}", server.into().trim(), tool.into().trim()),
        }
    }

    /// Parses and validates a pattern, returning it with surrounding whitespace removed.
    ///
    /// Accepted shapes are a bare name (`Bash`, `mcp__srv__*`) or a name followed by
    /// a single parenthesised argument list that ends the pattern (`Bash(git log:*)`).
    ///
    /// # Errors
    ///
    /// Checks run in this order, and the first failure is returned:
    ///
    /// * [`PatternError::Empty`] — the input is empty or whitespace-only.
    /// * [`PatternError::IllegalChar`] — the input contains `,` or a control character.
    /// * [`PatternError::UnbalancedParens`] — there is a `(` but the pattern does not
    ///   end with `)`, or there is more than one `(` or `)`, or there is a `)`
    ///   without any `(`.
    /// * [`PatternError::MissingName`] — the single `(` is the first character.
    pub fn parse(s: impl AsRef<str>) -> Result<Self, PatternError> {
        let trimmed = s.as_ref().trim();
        if trimmed.is_empty() {
            return Err(PatternError::Empty);
        }
        if trimmed.chars().any(|ch| ch == ',' || ch.is_control()) {
            return Err(PatternError::IllegalChar(trimmed.to_string()));
        }
        match trimmed.find('(') {
            Some(open) => {
                // Exactly one '(' and one ')' with ')' as the final character implies
                // the pair is correctly ordered, so no explicit position check is needed.
                let well_formed = trimmed.ends_with(')')
                    && trimmed.matches('(').count() == 1
                    && trimmed.matches(')').count() == 1;
                if !well_formed {
                    return Err(PatternError::UnbalancedParens(trimmed.to_string()));
                }
                // Balance is checked first, so "(args" reports unbalanced parentheses
                // rather than a missing name.
                if open == 0 {
                    return Err(PatternError::MissingName);
                }
            }
            None if trimmed.contains(')') => {
                return Err(PatternError::UnbalancedParens(trimmed.to_string()));
            }
            None => {}
        }
        Ok(Self {
            repr: trimmed.to_string(),
        })
    }

    /// Returns the pattern text.
    pub fn as_str(&self) -> &str {
        &self.repr
    }
}
/// Renders the pattern as its stored text.
///
/// The text is written straight to the formatter, so width, fill and precision
/// flags in the format string have no effect on the output.
impl fmt::Display for ToolPattern {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { repr } = self;
        f.write_str(repr)
    }
}
/// Lets a `ToolPattern` be passed wherever an `impl AsRef<str>` is accepted.
impl AsRef<str> for ToolPattern {
    fn as_ref(&self) -> &str {
        self.repr.as_str()
    }
}
/// Infallible, unvalidated conversion from a string slice.
///
/// Only surrounding whitespace is removed; the remaining text is stored as-is,
/// even if it would be rejected by the validating parser.
impl From<&str> for ToolPattern {
    fn from(s: &str) -> Self {
        let repr = String::from(s.trim());
        Self { repr }
    }
}
/// Infallible, unvalidated conversion from an owned string.
///
/// Surrounding whitespace is removed. When there is none, the incoming `String`
/// is moved into the pattern unchanged instead of being copied.
impl From<String> for ToolPattern {
    fn from(s: String) -> Self {
        // Trimming only ever shortens the text, so equal lengths mean nothing was removed.
        let kept = s.trim().len();
        if kept == s.len() {
            return Self { repr: s };
        }
        Self {
            repr: s.trim().to_owned(),
        }
    }
}
/// Convenience conversion from a borrowed `String`; defers to the `&str` conversion.
impl From<&String> for ToolPattern {
    fn from(s: &String) -> Self {
        let text: &str = s.as_str();
        text.into()
    }
}
// Unit tests for `ToolPattern`: the unvalidated constructors, the validating
// `parse`, and the string conversions.
#[cfg(test)]
mod tests {
use super::*;
// `tool` removes whitespace around the name.
#[test]
fn tool_strips_whitespace() {
assert_eq!(ToolPattern::tool(" Bash ").as_str(), "Bash");
}
// `tool_with_args` renders `name(args)` and keeps the args text as given.
#[test]
fn tool_with_args_renders_parens() {
let p = ToolPattern::tool_with_args("Bash", "git log:*");
assert_eq!(p.as_str(), "Bash(git log:*)");
}
// `all` is the `name(*)` shorthand.
#[test]
fn all_wildcards_args() {
assert_eq!(ToolPattern::all("Write").as_str(), "Write(*)");
}
// `mcp` joins server and tool with the `mcp__` prefix and `__` separator;
// a `*` tool name passes through unchanged.
#[test]
fn mcp_patterns() {
assert_eq!(ToolPattern::mcp("srv", "do_it").as_str(), "mcp__srv__do_it");
assert_eq!(ToolPattern::mcp("srv", "*").as_str(), "mcp__srv__*");
}
// A name without parentheses is a valid pattern.
#[test]
fn parse_accepts_bare_name() {
assert_eq!(ToolPattern::parse("Bash").unwrap().as_str(), "Bash");
}
// A name followed by one parenthesised argument list is valid; spaces, `:` and `*`
// inside the arguments are allowed.
#[test]
fn parse_accepts_name_with_args() {
assert_eq!(
ToolPattern::parse("Bash(git log:*)").unwrap().as_str(),
"Bash(git log:*)"
);
}
// The `mcp__server__tool` form has no parentheses and parses as a bare name.
#[test]
fn parse_accepts_mcp() {
assert_eq!(
ToolPattern::parse("mcp__srv__*").unwrap().as_str(),
"mcp__srv__*"
);
}
// Surrounding whitespace is dropped from the stored pattern.
#[test]
fn parse_trims_whitespace() {
assert_eq!(ToolPattern::parse(" Read ").unwrap().as_str(), "Read");
}
// Empty and whitespace-only inputs are both reported as `Empty`.
#[test]
fn parse_rejects_empty() {
assert_eq!(ToolPattern::parse("").unwrap_err(), PatternError::Empty);
assert_eq!(ToolPattern::parse(" ").unwrap_err(), PatternError::Empty);
}
// Three unbalanced shapes: `(` never closed, `)` without `(`, and more than one pair.
#[test]
fn parse_rejects_unbalanced_parens() {
assert!(matches!(
ToolPattern::parse("Bash(git log"),
Err(PatternError::UnbalancedParens(_))
));
assert!(matches!(
ToolPattern::parse("Bashgit log)"),
Err(PatternError::UnbalancedParens(_))
));
assert!(matches!(
ToolPattern::parse("Bash((nested))"),
Err(PatternError::UnbalancedParens(_))
));
}
// A balanced pattern that starts with `(` has no tool name.
#[test]
fn parse_rejects_missing_name() {
assert_eq!(
ToolPattern::parse("(args)").unwrap_err(),
PatternError::MissingName
);
}
// Commas are rejected anywhere in the pattern.
#[test]
fn parse_rejects_comma() {
assert!(matches!(
ToolPattern::parse("Bash,Read"),
Err(PatternError::IllegalChar(_))
));
}
// Control characters (here a newline inside the name) are rejected.
#[test]
fn parse_rejects_control_chars() {
assert!(matches!(
ToolPattern::parse("Ba\nsh"),
Err(PatternError::IllegalChar(_))
));
}
// `From<&str>` is infallible: it stores the text without running `parse`.
#[test]
fn from_str_is_loose() {
let p: ToolPattern = "anything goes".into();
assert_eq!(p.as_str(), "anything goes");
}
// `Display` output is the same text `as_str` returns.
#[test]
fn display_matches_as_str() {
let p = ToolPattern::tool_with_args("Bash", "ls");
assert_eq!(format!("{p}"), p.as_str());
}
}