use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use crate::session::ToolRegistry;
use crate::tool::{Tool, ToolSchema};
/// A fixed set of tools, created with [`StaticToolRegistry::builder`] or
/// [`StaticToolRegistry::empty`].
pub struct StaticToolRegistry {
/// Tool schemas in the order their names were first inserted through the builder;
/// the builder keeps at most one schema per name.
schemas: Vec<ToolSchema>,
/// Tools keyed by their schema name.
by_name: HashMap<String, Arc<dyn Tool>>,
}
impl StaticToolRegistry {
pub fn builder() -> StaticToolRegistryBuilder {
StaticToolRegistryBuilder::default()
}
pub fn empty() -> Self {
Self {
schemas: Vec::new(),
by_name: HashMap::new(),
}
}
}
impl ToolRegistry for StaticToolRegistry {
    /// Returns an owned copy of every registered schema, in registration order.
    fn schemas(&self) -> Vec<ToolSchema> {
        self.schemas.to_vec()
    }

    /// Looks up a tool by its schema name, handing back a new `Arc` handle to it.
    fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
        self.by_name.get(name).map(Arc::clone)
    }
}
/// Accumulates tools for a [`StaticToolRegistry`]; obtain one from
/// [`StaticToolRegistry::builder`] and finish it with `build`.
#[derive(Default)]
pub struct StaticToolRegistryBuilder {
/// Schemas in first-insertion order, at most one per name.
schemas: Vec<ToolSchema>,
/// The most recently inserted tool for each schema name.
by_name: HashMap<String, Arc<dyn Tool>>,
}
impl StaticToolRegistryBuilder {
    /// Registers `tool` under the name taken from its schema and returns the builder.
    ///
    /// Inserting a name that is already registered replaces both the stored schema
    /// (keeping its original position) and the tool handle. The last insert for a name
    /// therefore wins, while schema order reflects when each name first appeared.
    pub fn insert(mut self, tool: Arc<dyn Tool>) -> Self {
        let schema = tool.schema().clone();
        // Only the name is needed twice (schema list + map key), so clone just that.
        let name = schema.name.clone();
        // Replace in place so `schemas` never holds two entries with the same name.
        match self.schemas.iter_mut().find(|existing| existing.name == name) {
            Some(slot) => *slot = schema,
            None => self.schemas.push(schema),
        }
        self.by_name.insert(name, tool);
        self
    }

    /// Consumes the builder and produces the [`StaticToolRegistry`] holding every
    /// inserted tool.
    pub fn build(self) -> StaticToolRegistry {
        StaticToolRegistry {
            schemas: self.schemas,
            by_name: self.by_name,
        }
    }
}
/// A two-layer registry: lookups try `session` first and fall back to `process`,
/// so a session tool shadows a process tool with the same name.
pub struct CompositeRegistry {
/// Higher-priority layer.
session: Arc<dyn ToolRegistry>,
/// Fallback layer; its entries are hidden when `session` has the same name.
process: Arc<dyn ToolRegistry>,
}
impl CompositeRegistry {
pub fn new(session: Arc<dyn ToolRegistry>, process: Arc<dyn ToolRegistry>) -> Self {
Self { session, process }
}
}
impl ToolRegistry for CompositeRegistry {
    /// Returns all session schemas, followed by the process schemas whose names the
    /// session layer does not already provide.
    ///
    /// Duplicates within a single layer are passed through untouched.
    fn schemas(&self) -> Vec<ToolSchema> {
        let mut merged = self.session.schemas();
        let fallback = self.process.schemas();
        let unshadowed: Vec<ToolSchema> = {
            // Scoped so the borrow of `merged` ends before we extend it.
            let taken: HashSet<&str> = merged.iter().map(|s| s.name.as_str()).collect();
            fallback
                .into_iter()
                .filter(|schema| !taken.contains(schema.name.as_str()))
                .collect()
        };
        merged.extend(unshadowed);
        merged
    }

    /// Resolves `name` in the session layer and falls back to the process layer.
    fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
        match self.session.get(name) {
            Some(tool) => Some(tool),
            None => self.process.get(name),
        }
    }
}
/// Result of expanding a tool allowlist with `match_tool_allowlist`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct AllowlistMatch {
    /// Matched tool names, de-duplicated, in first-match order. Never contains the
    /// spawn-agent tool name.
    pub tools: Vec<String>,
    /// Whether any allowlist pattern matched the spawn-agent tool name.
    pub spawn_agent: bool,
}
pub fn match_tool_allowlist(
base: &Arc<dyn ToolRegistry>,
allow: &[String],
) -> Result<AllowlistMatch, String> {
let schemas = base.schemas();
let pool_names: Vec<&str> = schemas
.iter()
.map(|s| s.name.as_str())
.filter(|n| *n != crate::tool::SPAWN_AGENT_TOOL_NAME)
.collect();
let mut tools: Vec<String> = Vec::new();
let mut seen: HashSet<String> = HashSet::new();
let mut spawn_agent = false;
for pattern in allow {
let matcher = globset::Glob::new(pattern)
.map_err(|e| format!("invalid tool pattern `{pattern}`: {e}"))?
.compile_matcher();
let mut hit = false;
for name in &pool_names {
if matcher.is_match(name) {
hit = true;
if seen.insert((*name).to_string()) {
tools.push((*name).to_string());
}
}
}
if matcher.is_match(crate::tool::SPAWN_AGENT_TOOL_NAME) {
hit = true;
spawn_agent = true;
}
if !hit {
return Err(pattern.clone());
}
}
Ok(AllowlistMatch { tools, spawn_agent })
}
/// Builds a registry that exposes only the tools of `base` selected by `allow`.
///
/// The allowlist is resolved with `match_tool_allowlist`. Each matched name is looked up
/// in `base` and inserted, in match order, into a fresh [`StaticToolRegistry`]. Names that
/// `base.get` cannot resolve are skipped.
///
/// The spawn-agent tool is never included here: `match_tool_allowlist` leaves it out of
/// `tools`, and this function ignores the returned `spawn_agent` flag.
///
/// # Errors
///
/// Propagates the error string from `match_tool_allowlist` unchanged.
pub fn filter_registry_by_allowlist(
    base: &Arc<dyn ToolRegistry>,
    allow: &[String],
) -> Result<Arc<dyn ToolRegistry>, String> {
    let selection = match_tool_allowlist(base, allow)?;
    let filtered = selection
        .tools
        .iter()
        .filter_map(|name| base.get(name))
        .fold(StaticToolRegistry::builder(), |acc, tool| acc.insert(tool))
        .build();
    Ok(Arc::new(filtered))
}
#[cfg(test)]
mod tests;