use std::collections::{HashMap, HashSet};
use std::sync::{Arc, RwLock};
use crate::serde_types::Value;
use crate::{ToolDefinition, ToolError, ToolErrorKind};
/// Callback that executes a dynamically registered tool.
///
/// It receives the call arguments and yields either the tool's result or a
/// [`ToolError`]. The `Send + Sync` bounds allow a registry that stores these
/// handlers to be shared across threads.
pub type DynamicHandler =
    Arc<dyn Fn(&Value) -> Result<Value, ToolError> + Send + Sync + 'static>;
/// A tool source whose contents can change at runtime and whose tools can be
/// switched on or off per tenant.
pub trait DynamicToolProvider {
    /// Registers `handler` under `name`, advertised with `definition`.
    fn add_tool(&mut self, name: &str, definition: ToolDefinition, handler: DynamicHandler);

    /// Unregisters the tool called `name`; returns `true` if a tool was removed.
    fn remove_tool(&mut self, name: &str) -> bool;

    /// Lists the definitions of the tools visible to `tenant`
    /// (`None` requests the view without any tenant).
    fn visible_tools(&self, tenant: Option<&str>) -> Vec<ToolDefinition>;

    /// Makes the tool called `name` available to `tenant`.
    fn enable_for(&mut self, name: &str, tenant: &str);

    /// Makes the tool called `name` unavailable to `tenant`.
    fn disable_for(&mut self, name: &str, tenant: &str);
}
/// Runtime-mutable tool registry with optional per-tenant allowlists.
///
/// All state lives behind an `Arc<RwLock<..>>`, so cloning the registry is cheap
/// and every clone observes and mutates the same set of tools.
#[derive(Clone, Default)]
pub struct DynamicToolRegistry {
    /// Shared, lock-protected registry state.
    inner: Arc<RwLock<Inner>>,
}
/// Lock-protected state of a `DynamicToolRegistry`.
#[derive(Default)]
struct Inner {
    /// Registered tools, keyed by the name passed to `add_tool`.
    tools: HashMap<String, RegisteredTool>,
    /// Tenant id -> names of the tools that tenant may use. A tenant with no
    /// entry is unrestricted; a tenant with an entry (even an empty one) is
    /// limited to the listed names.
    per_tenant: HashMap<String, HashSet<String>>,
}
/// A tool's advertised definition together with the handler that executes it.
struct RegisteredTool {
    /// Definition handed out by `definition` and `visible_tools`.
    definition: ToolDefinition,
    /// Invoked by `call_for_tenant` with the call arguments.
    handler: DynamicHandler,
}
impl DynamicToolRegistry {
    /// Creates an empty registry with no tools and no tenant restrictions.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the names of every registered tool, ignoring tenant restrictions.
    ///
    /// The order follows `HashMap` iteration and is unspecified. Returns an empty
    /// list if the registry lock is poisoned.
    pub fn list_global(&self) -> Vec<String> {
        self.inner
            .read()
            .map(|guard| guard.tools.keys().cloned().collect())
            .unwrap_or_default()
    }

    /// Returns `true` if a tool is registered under `name`
    /// (`false` if the registry lock is poisoned).
    pub fn contains(&self, name: &str) -> bool {
        self.inner
            .read()
            .map(|guard| guard.tools.contains_key(name))
            .unwrap_or(false)
    }

    /// Returns a clone of the definition registered under `name`, if any
    /// (`None` if the registry lock is poisoned).
    pub fn definition(&self, name: &str) -> Option<ToolDefinition> {
        self.inner
            .read()
            .ok()
            .and_then(|guard| guard.tools.get(name).map(|t| t.definition.clone()))
    }

    /// Removes every tool and every tenant allowlist.
    ///
    /// Does nothing if the registry lock is poisoned.
    pub fn clear(&self) {
        if let Ok(mut guard) = self.inner.write() {
            guard.tools.clear();
            guard.per_tenant.clear();
        }
    }

    /// Invokes the tool registered under `name` with `args`, enforcing `tenant`'s
    /// allowlist when `tenant` is `Some` and that tenant has one.
    ///
    /// # Errors
    ///
    /// - an internal error if the registry lock is poisoned;
    /// - a not-found error if no tool is registered under `name`;
    /// - a not-found error whose message contains "is not enabled for tenant" if
    ///   the tenant's allowlist does not include `name`;
    /// - whatever error the tool's handler returns.
    pub fn call_for_tenant(
        &self,
        name: &str,
        tenant: Option<&str>,
        args: &Value,
    ) -> Result<Value, ToolError> {
        let guard = self
            .inner
            .read()
            .map_err(|e| ToolError::internal_error(format!("registry poisoned: {}", e)))?;
        let tool = guard
            .tools
            .get(name)
            .ok_or_else(|| ToolError::not_found(format!("tool `{}` is not registered", name)))?;
        if let Some(t) = tenant {
            if let Some(allowed) = guard.per_tenant.get(t) {
                if !allowed.contains(name) {
                    return Err(ToolError::not_found(format!(
                        "tool `{}` is not enabled for tenant `{}`",
                        name, t
                    )));
                }
            }
        }
        // Take our own handle on the handler and release the read lock before
        // running it. Handlers may hold a clone of this registry. Re-entering the
        // lock from the same thread (e.g. `add_tool`/`remove_tool`, or even a nested
        // read) may otherwise deadlock or panic.
        let handler = Arc::clone(&tool.handler);
        drop(guard);
        handler(args)
    }
}
impl DynamicToolProvider for DynamicToolRegistry {
    /// Registers `handler` under `name`, replacing any tool already registered
    /// under that name.
    ///
    /// Tenant allowlists and lookups use `name`; `definition.name` is not checked
    /// against it. Does nothing if the registry lock is poisoned.
    fn add_tool(&mut self, name: &str, definition: ToolDefinition, handler: DynamicHandler) {
        let Ok(mut guard) = self.inner.write() else {
            return;
        };
        guard
            .tools
            .insert(name.to_string(), RegisteredTool { definition, handler });
    }

    /// Unregisters `name` and strips it from every tenant allowlist.
    ///
    /// Returns `true` if a tool was registered under `name`. It returns `false` if
    /// not, or if the registry lock is poisoned.
    ///
    /// Only allowlists that become empty *because of this removal* are dropped.
    /// A tenant whose allowlist was already empty has been denied every tool, and
    /// must stay that way. Dropping its entry would silently make it unrestricted.
    ///
    /// NOTE(review): a tenant whose allowlist contained only `name` does become
    /// unrestricted here. Confirm that is intended.
    fn remove_tool(&mut self, name: &str) -> bool {
        let Ok(mut guard) = self.inner.write() else {
            return false;
        };
        guard.per_tenant.retain(|_tenant, allowed| {
            let had_tool = allowed.remove(name);
            !(had_tool && allowed.is_empty())
        });
        guard.tools.remove(name).is_some()
    }

    /// Adds `name` to `tenant`'s allowlist, creating the allowlist if needed.
    ///
    /// Does nothing if no tool is registered under `name` or the lock is poisoned.
    ///
    /// NOTE(review): creating the allowlist restricts a previously unrestricted
    /// tenant to just `name`. Confirm this is the intended way to opt a tenant into
    /// an allowlist.
    fn enable_for(&mut self, name: &str, tenant: &str) {
        let Ok(mut guard) = self.inner.write() else {
            return;
        };
        if !guard.tools.contains_key(name) {
            return;
        }
        guard
            .per_tenant
            .entry(tenant.to_string())
            .or_default()
            .insert(name.to_string());
    }

    /// Prevents `tenant` from seeing or calling `name`, leaving its other tools alone.
    ///
    /// - If `tenant` already has an allowlist, `name` is removed from it.
    /// - If `tenant` is unrestricted, it gets an allowlist of every currently
    ///   registered tool except `name`. Previously an empty set was created here,
    ///   which denied the tenant *every* tool. Tools registered later are not
    ///   visible to this tenant until enabled explicitly.
    /// - If `tenant` is unrestricted and `name` is not registered, nothing changes.
    ///
    /// Does nothing if the registry lock is poisoned.
    fn disable_for(&mut self, name: &str, tenant: &str) {
        let Ok(mut guard) = self.inner.write() else {
            return;
        };
        let inner = &mut *guard;
        if let Some(allowed) = inner.per_tenant.get_mut(tenant) {
            allowed.remove(name);
            return;
        }
        if !inner.tools.contains_key(name) {
            return;
        }
        let allowed: HashSet<String> = inner
            .tools
            .keys()
            .filter(|key| key.as_str() != name)
            .cloned()
            .collect();
        inner.per_tenant.insert(tenant.to_string(), allowed);
    }

    /// Returns the definitions of the tools `tenant` may use.
    ///
    /// With `None`, or a tenant that has no allowlist, every tool is returned.
    /// Filtering uses the registry key, matching the check in `call_for_tenant`,
    /// so a tool is listed exactly when the tenant may call it. The order is
    /// unspecified. Returns an empty list if the registry lock is poisoned.
    fn visible_tools(&self, tenant: Option<&str>) -> Vec<ToolDefinition> {
        let Ok(guard) = self.inner.read() else {
            return Vec::new();
        };
        let allowlist = tenant.and_then(|t| guard.per_tenant.get(t));
        guard
            .tools
            .iter()
            .filter(|(key, _)| match allowlist {
                Some(allowed) => allowed.contains(key.as_str()),
                None => true,
            })
            .map(|(_, tool)| tool.definition.clone())
            .collect()
    }
}
impl crate::ToolProvider for DynamicToolRegistry {
    /// Always empty: every tool in this registry is added at runtime, so there
    /// is no compile-time set to expose.
    fn tool_definitions() -> &'static [ToolDefinition] {
        &[]
    }
}
impl crate::ToolCaller for DynamicToolRegistry {
    /// Invokes `name` without a tenant, so no per-tenant allowlist is consulted.
    fn call_tool(&self, name: &str, args: &Value) -> Result<Value, ToolError> {
        let no_tenant: Option<&str> = None;
        Self::call_for_tenant(self, name, no_tenant, args)
    }
}
// Empty impl: the registry relies entirely on whatever default items
// `CapabilityManifestProvider` provides.
impl crate::CapabilityManifestProvider for DynamicToolRegistry {}
/// Tag naming the "tenant denied" failure category.
///
/// NOTE(review): not referenced by the code in this module; presumably consumed
/// by callers. Verify before changing.
pub const TENANT_DENIED_KIND_HINT: &str = "tenant-denied";

/// Reports whether `err` is the denial produced when a tenant's allowlist does
/// not include the requested tool.
///
/// Tenant denials are reported with kind `NotFound`, so they are told apart from
/// "tool not registered" by their message text. Any other `NotFound` error whose
/// message contains the same phrase will also match.
pub fn is_tenant_denied(err: &ToolError) -> bool {
    const DENIAL_PHRASE: &str = "is not enabled for tenant";
    if err.kind != ToolErrorKind::NotFound {
        return false;
    }
    err.message.contains(DENIAL_PHRASE)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ToolCaller;
    use crate::ToolProvider;
    use serde_json::json;

    /// Handler that ignores its arguments and always returns `a + b`.
    fn add_handler(a: i64, b: i64) -> DynamicHandler {
        Arc::new(move |_args| Ok(json!(a + b)))
    }

    /// Handler that ignores its arguments and always returns JSON `null`.
    fn null_handler() -> DynamicHandler {
        Arc::new(|_| Ok(json!(null)))
    }

    #[test]
    fn add_and_call() {
        let mut registry = DynamicToolRegistry::new();
        let schema = r#"{"type":"object"}"#;
        registry.add_tool("add", ToolDefinition::new("add", "Add", schema), add_handler(1, 2));

        let sum = registry.call_tool("add", &json!({})).unwrap();
        assert_eq!(sum, json!(3));
        assert!(registry.contains("add"));
        assert_eq!(registry.list_global(), vec![String::from("add")]);
    }

    #[test]
    fn remove_tool_returns_true_then_false() {
        let mut registry = DynamicToolRegistry::new();
        registry.add_tool("x", ToolDefinition::new("x", "x", "{}"), null_handler());

        assert!(registry.remove_tool("x"), "first removal should report success");
        assert!(!registry.remove_tool("x"), "second removal should find nothing");
    }

    #[test]
    fn enable_disable_for_tenant() {
        let mut registry = DynamicToolRegistry::new();
        registry.add_tool(
            "t",
            ToolDefinition::new("t", "t", "{}"),
            Arc::new(|_| Ok(json!("ok"))),
        );
        let args = json!({});

        assert!(registry.call_for_tenant("t", Some("a"), &args).is_ok());

        registry.disable_for("t", "a");
        let denial = registry.call_for_tenant("t", Some("a"), &args).unwrap_err();
        assert!(is_tenant_denied(&denial));
        assert!(registry.call_for_tenant("t", Some("b"), &args).is_ok());

        registry.enable_for("t", "a");
        assert!(registry.call_for_tenant("t", Some("a"), &args).is_ok());
    }

    #[test]
    fn enable_for_unknown_tool_is_noop() {
        let mut registry = DynamicToolRegistry::new();
        registry.enable_for("nope", "a");
        assert!(!registry.contains("nope"));
    }

    #[test]
    fn remove_tool_clears_per_tenant_set() {
        let mut registry = DynamicToolRegistry::new();
        registry.add_tool("t", ToolDefinition::new("t", "t", "{}"), null_handler());
        registry.enable_for("t", "a");
        assert!(registry.remove_tool("t"));

        registry.add_tool("t", ToolDefinition::new("t", "t", "{}"), null_handler());
        assert!(registry.call_for_tenant("t", Some("a"), &json!({})).is_ok());
    }

    #[test]
    fn visible_tools_filters_by_tenant() {
        let mut registry = DynamicToolRegistry::new();
        for tool_name in ["a", "b"] {
            registry.add_tool(
                tool_name,
                ToolDefinition::new(tool_name, tool_name, "{}"),
                null_handler(),
            );
        }
        assert_eq!(registry.visible_tools(None).len(), 2);

        registry.disable_for("a", "alice");
        registry.disable_for("b", "alice");
        assert!(registry.visible_tools(Some("alice")).is_empty());

        registry.enable_for("b", "alice");
        let visible = registry.visible_tools(Some("alice"));
        assert_eq!(visible.len(), 1);
        assert_eq!(visible[0].name, "b");
    }

    #[test]
    fn tool_provider_static_slice_is_empty() {
        let statics = DynamicToolRegistry::tool_definitions();
        assert!(statics.is_empty());
    }

    #[test]
    fn clear_resets_state() {
        let mut registry = DynamicToolRegistry::new();
        registry.add_tool("x", ToolDefinition::new("x", "x", "{}"), null_handler());
        registry.enable_for("x", "a");

        registry.clear();

        assert!(registry.list_global().is_empty());
        assert!(registry.visible_tools(Some("a")).is_empty());
    }

    #[test]
    fn missing_tool_returns_not_found() {
        let registry = DynamicToolRegistry::new();
        let error = registry.call_tool("ghost", &json!({})).unwrap_err();
        assert_eq!(error.kind, ToolErrorKind::NotFound);
    }

    #[test]
    fn handler_error_propagates() {
        let mut registry = DynamicToolRegistry::new();
        registry.add_tool(
            "boom",
            ToolDefinition::new("boom", "boom", "{}"),
            Arc::new(|_| Err(ToolError::internal_error("kaboom"))),
        );

        let error = registry.call_tool("boom", &json!({})).unwrap_err();
        assert_eq!(error.kind, ToolErrorKind::InternalError);
        assert!(error.message.contains("kaboom"));
    }

    #[test]
    fn registry_is_send_sync() {
        fn require_send_sync<T: Send + Sync>() {}
        require_send_sync::<DynamicToolRegistry>();
    }
}