use std::collections::HashSet;
use std::fmt;
/// Errors produced while checking a tool against a [`ToolRestriction`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolRestrictionError {
    /// The requested tool is not covered by the allow-list.
    ToolNotAllowed {
        /// Tool specification that was rejected, exactly as it was passed in.
        tool: String,
        /// Allow-list entries at the time of the check. When produced by
        /// [`ToolRestriction::validate_tool`] the entries are sorted.
        allowed: Vec<String>,
    },
    /// A tool specification could not be interpreted; carries the offending text.
    InvalidSpec(String),
}

impl fmt::Display for ToolRestrictionError {
    /// Renders a human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::ToolNotAllowed { tool, allowed } => write!(
                f,
                "Tool '{}' is not allowed. Allowed tools: {:?}",
                tool, allowed
            ),
            Self::InvalidSpec(spec) => write!(f, "Invalid tool specification: '{}'", spec),
        }
    }
}

impl std::error::Error for ToolRestrictionError {}

/// An optional allow-list of tool specifications.
///
/// `None` means "unrestricted": every tool is allowed. `Some(set)` allows only
/// tools covered by an entry of the set (an empty set allows nothing). Entries
/// take one of these forms:
///
/// * `*` — allows every tool;
/// * `Name` — allows `Name` itself and any parameterised call `Name(...)`;
/// * `Name(args)` — allows exactly that specification;
/// * `Name(prefix*)` — allows `Name(params)` when `params` starts with `prefix`.
///
/// The derived [`Default`] value is unrestricted.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct ToolRestriction {
    allowed_tools: Option<HashSet<String>>,
}

impl ToolRestriction {
    /// Creates a restriction from an optional list of allowed specifications.
    ///
    /// `None` yields an unrestricted value; duplicates in the list are collapsed.
    pub fn new(allowed_tools: Option<Vec<String>>) -> Self {
        Self {
            allowed_tools: allowed_tools.map(|tools| tools.into_iter().collect()),
        }
    }

    /// Creates a restriction that allows every tool.
    pub fn unrestricted() -> Self {
        Self {
            allowed_tools: None,
        }
    }

    /// Returns `true` when `tool_name` is permitted by this restriction.
    ///
    /// `tool_name` may be a bare name (`Read`) or a parameterised call
    /// (`Bash(python:script.py)`). See the type-level documentation for how
    /// allow-list entries are interpreted.
    pub fn is_tool_allowed(&self, tool_name: &str) -> bool {
        match &self.allowed_tools {
            None => true,
            Some(allowed) => {
                // Cheap set lookups first: global wildcard or an exact entry.
                if allowed.contains("*") || allowed.contains(tool_name) {
                    return true;
                }

                // A bare `Name` entry also covers every `Name(...)` call.
                let (base_tool, params) = Self::parse_tool_spec(tool_name);
                if allowed.contains(&base_tool) {
                    return true;
                }

                // Wildcard entries only ever match parameterised calls.
                let params = match params {
                    Some(params) => params,
                    None => return false,
                };

                allowed
                    .iter()
                    .filter_map(|spec| Self::parse_tool_spec_with_pattern(spec))
                    .any(|(base, pattern)| {
                        base == base_tool && Self::pattern_matches(&pattern, &params)
                    })
            }
        }
    }

    /// Checks `tool_name` and reports a descriptive error when it is rejected.
    ///
    /// # Errors
    ///
    /// Returns [`ToolRestrictionError::ToolNotAllowed`] carrying the rejected
    /// name and the sorted allow-list when [`Self::is_tool_allowed`] is `false`.
    pub fn validate_tool(&self, tool_name: &str) -> Result<(), ToolRestrictionError> {
        if self.is_tool_allowed(tool_name) {
            Ok(())
        } else {
            Err(ToolRestrictionError::ToolNotAllowed {
                tool: tool_name.to_string(),
                allowed: self.get_allowed_tools().unwrap_or_default(),
            })
        }
    }

    /// Returns the allow-list entries, or `None` when unrestricted.
    ///
    /// The entries are sorted so the result (and any error message built from
    /// it) does not depend on hash-set iteration order.
    pub fn get_allowed_tools(&self) -> Option<Vec<String>> {
        self.allowed_tools.as_ref().map(|set| {
            let mut tools: Vec<String> = set.iter().cloned().collect();
            tools.sort_unstable();
            tools
        })
    }

    /// Returns `true` when no allow-list is set, i.e. every tool is allowed.
    pub fn is_unrestricted(&self) -> bool {
        self.allowed_tools.is_none()
    }

    /// Splits `Name(args)` into `("Name", Some("args"))`.
    ///
    /// The split happens at the first `(`; the input must end with `)`.
    /// Anything else is returned unchanged with `None` parameters.
    fn parse_tool_spec(tool_spec: &str) -> (String, Option<String>) {
        match tool_spec
            .strip_suffix(')')
            .and_then(|inner| inner.split_once('('))
        {
            Some((base, args)) => (base.to_string(), Some(args.to_string())),
            None => (tool_spec.to_string(), None),
        }
    }

    /// Parses an allow-list entry of the form `Name(pattern)` whose pattern
    /// contains a `*`, returning `(name, pattern)`; any other entry yields `None`.
    fn parse_tool_spec_with_pattern(spec: &str) -> Option<(String, String)> {
        let (base, pattern) = spec.strip_suffix(')')?.split_once('(')?;
        if pattern.contains('*') {
            Some((base.to_string(), pattern.to_string()))
        } else {
            None
        }
    }

    /// Tests `params` against `pattern`.
    ///
    /// Only a trailing `*` acts as a wildcard (prefix match; a lone `*` has an
    /// empty prefix and therefore matches everything). A pattern without a
    /// trailing `*` must equal `params` exactly.
    fn pattern_matches(pattern: &str, params: &str) -> bool {
        match pattern.strip_suffix('*') {
            Some(prefix) => params.starts_with(prefix),
            None => pattern == params,
        }
    }

    /// Adds `tool` to the allow-list.
    ///
    /// If the restriction was unrestricted, this first creates an empty
    /// allow-list, so afterwards only `tool` is covered.
    pub fn add_tool(&mut self, tool: String) {
        self.allowed_tools
            .get_or_insert_with(HashSet::new)
            .insert(tool);
    }

    /// Removes the entry equal to `tool` from the allow-list, if present.
    ///
    /// Does nothing when the restriction is unrestricted.
    pub fn remove_tool(&mut self, tool: &str) {
        if let Some(allowed) = self.allowed_tools.as_mut() {
            allowed.remove(tool);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a restriction whose allow-list contains exactly `tools`.
    fn restricted_to(tools: &[&str]) -> ToolRestriction {
        let owned: Vec<String> = tools.iter().map(|tool| tool.to_string()).collect();
        ToolRestriction::new(Some(owned))
    }

    // No allow-list: everything passes and the value reports itself unrestricted.
    #[test]
    fn test_unrestricted() {
        let open = ToolRestriction::unrestricted();
        assert!(open.is_tool_allowed("AnyTool"));
        assert!(open.is_tool_allowed("AnotherTool"));
        assert!(open.is_unrestricted());
    }

    // Plain names are allowed only when listed.
    #[test]
    fn test_specific_tools() {
        let policy = restricted_to(&["Read", "Grep"]);
        assert!(policy.is_tool_allowed("Read"));
        assert!(policy.is_tool_allowed("Grep"));
        assert!(!policy.is_tool_allowed("Write"));
        assert!(!policy.is_tool_allowed("Bash"));
    }

    // `Name(prefix*)` entries match parameterised calls by prefix.
    #[test]
    fn test_tool_with_parameters() {
        let policy = restricted_to(&["Bash(python:*)", "Read"]);
        assert!(policy.is_tool_allowed("Bash(python:script.py)"));
        assert!(policy.is_tool_allowed("Bash(python:-m pytest)"));
        assert!(!policy.is_tool_allowed("Bash(node:script.js)"));
        assert!(!policy.is_tool_allowed("Bash(ls -la)"));
        assert!(policy.is_tool_allowed("Read"));
    }

    // A bare `*` entry allows every tool, parameterised or not.
    #[test]
    fn test_wildcard() {
        let policy = restricted_to(&["*"]);
        assert!(policy.is_tool_allowed("AnyTool"));
        assert!(policy.is_tool_allowed("Bash(anything)"));
        assert!(policy.is_tool_allowed("Read"));
    }

    // `validate_tool` mirrors `is_tool_allowed` and names the rejected tool.
    #[test]
    fn test_validate_tool() {
        let policy = restricted_to(&["Read"]);
        assert!(policy.validate_tool("Read").is_ok());

        let outcome = policy.validate_tool("Write");
        assert!(outcome.is_err());
        match outcome {
            Err(ToolRestrictionError::ToolNotAllowed { tool, .. }) => {
                assert_eq!(tool, "Write");
            }
            _ => panic!("Expected ToolNotAllowed error"),
        }
    }

    // Specs split into a base name and optional parameters.
    #[test]
    fn test_parse_tool_spec() {
        let with_params = ToolRestriction::parse_tool_spec("Bash(python:script.py)");
        let expected_with_params = ("Bash".to_string(), Some("python:script.py".to_string()));
        assert_eq!(with_params, expected_with_params);

        let bare = ToolRestriction::parse_tool_spec("Read");
        let expected_bare = ("Read".to_string(), None);
        assert_eq!(bare, expected_bare);
    }

    // Trailing-`*` patterns are prefix matches; a lone `*` matches anything.
    #[test]
    fn test_pattern_matches() {
        let matches = ToolRestriction::pattern_matches;
        assert!(matches("python:*", "python:script.py"));
        assert!(matches("python:*", "python:-m pytest"));
        assert!(!matches("python:*", "node:script.js"));
        assert!(matches("*", "anything"));
        assert!(matches("*", ""));
    }

    // Adding an entry makes a previously rejected tool acceptable.
    #[test]
    fn test_add_tool() {
        let mut policy = restricted_to(&["Read"]);
        assert!(!policy.is_tool_allowed("Grep"));
        policy.add_tool(String::from("Grep"));
        assert!(policy.is_tool_allowed("Grep"));
    }

    // Removing an entry rejects that tool and leaves the others alone.
    #[test]
    fn test_remove_tool() {
        let mut policy = restricted_to(&["Read", "Grep"]);
        assert!(policy.is_tool_allowed("Grep"));
        policy.remove_tool("Grep");
        assert!(!policy.is_tool_allowed("Grep"));
        assert!(policy.is_tool_allowed("Read"));
    }

    // `get_allowed_tools` is `None` when unrestricted, else lists every entry.
    #[test]
    fn test_get_allowed_tools() {
        let open = ToolRestriction::unrestricted();
        assert!(open.get_allowed_tools().is_none());

        let listed = restricted_to(&["Read", "Grep"]).get_allowed_tools();
        assert!(listed.is_some());
        assert_eq!(listed.unwrap().len(), 2);
    }

    // The default value behaves like `unrestricted()`.
    #[test]
    fn test_default() {
        let policy = ToolRestriction::default();
        assert!(policy.is_unrestricted());
        assert!(policy.is_tool_allowed("AnyTool"));
    }

    // An empty allow-list rejects everything.
    #[test]
    fn test_empty_allowed_list() {
        let policy = restricted_to(&[]);
        assert!(!policy.is_tool_allowed("Read"));
        assert!(!policy.is_tool_allowed("AnyTool"));
    }
}