#[cfg(not(target_arch = "wasm32"))]
use crate::error::DispatchError;
use async_trait::async_trait;
use meerkat_core::AgentToolDispatcher;
use meerkat_core::error::ToolError;
use meerkat_core::ops::ToolAccessPolicy;
use meerkat_core::types::{ToolCallView, ToolDef, ToolResult};
use std::collections::HashSet;
use std::sync::Arc;
#[cfg(not(target_arch = "wasm32"))]
use crate::registry::ToolRegistry;
#[cfg(not(target_arch = "wasm32"))]
use meerkat_core::error::ToolValidationError;
#[cfg(not(target_arch = "wasm32"))]
use serde_json::Value;
#[cfg(not(target_arch = "wasm32"))]
use std::time::Duration;
/// A dispatcher that exposes no tools; every dispatched call is answered with
/// `ToolError::NotFound`.
#[derive(Clone, Copy, Debug, Default)]
pub struct EmptyToolDispatcher;
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl AgentToolDispatcher for EmptyToolDispatcher {
    /// Advertises no tools at all.
    fn tools(&self) -> Arc<[Arc<ToolDef>]> {
        let no_tools: [Arc<ToolDef>; 0] = [];
        Arc::from(no_tools)
    }

    /// Rejects every call, reporting the requested tool name as not found.
    async fn dispatch(&self, call: ToolCallView<'_>) -> Result<ToolResult, ToolError> {
        let name = call.name.to_owned();
        Err(ToolError::NotFound { name })
    }
}
#[cfg(not(target_arch = "wasm32"))]
/// Validates tool calls against a [`ToolRegistry`] and forwards the ones that pass
/// to a router dispatcher, bounding each routed call by a timeout.
pub struct ToolDispatcher {
    /// Provides the tool list and validates call arguments.
    registry: ToolRegistry,
    /// Executes calls that passed validation.
    router: Arc<dyn AgentToolDispatcher>,
    /// Maximum time a single routed call may take.
    default_timeout: Duration,
}
#[cfg(not(target_arch = "wasm32"))]
impl ToolDispatcher {
    /// Creates a dispatcher that validates calls with `registry` and forwards them to
    /// `router`. The per-call timeout defaults to 30 seconds.
    pub fn new(registry: ToolRegistry, router: Arc<dyn AgentToolDispatcher>) -> Self {
        Self {
            registry,
            router,
            default_timeout: Duration::from_secs(30),
        }
    }

    /// Replaces the per-call timeout (builder style).
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.default_timeout = timeout;
        self
    }

    /// Parses `call`'s arguments as JSON, validates them with the registry, then
    /// dispatches the call through the router, bounded by the configured timeout.
    ///
    /// # Errors
    ///
    /// - Argument JSON that fails to parse is reported as
    ///   `ToolValidationError::invalid_arguments`, converted into [`DispatchError`].
    /// - Registry validation failures are propagated (converted into [`DispatchError`]).
    /// - [`DispatchError::Timeout`] if the router does not finish within the timeout.
    /// - Any `ToolError` returned by the router, converted into [`DispatchError`].
    pub async fn dispatch_call(&self, call: ToolCallView<'_>) -> Result<ToolResult, DispatchError> {
        let args: Value = serde_json::from_str(call.args.get())
            .map_err(|e| ToolValidationError::invalid_arguments(call.name, e.to_string()))?;
        self.registry.validate(call.name, &args)?;
        let timeout = self.default_timeout;
        let result = tokio::time::timeout(timeout, self.router.dispatch(call))
            .await
            .map_err(|_| DispatchError::Timeout {
                // `as_millis` returns u128; saturate instead of silently truncating
                // when a caller configures an enormous timeout via `with_timeout`.
                timeout_ms: timeout.as_millis().min(u128::from(u64::MAX)) as u64,
            })??;
        Ok(result)
    }
}
#[cfg(not(target_arch = "wasm32"))]
#[async_trait]
impl AgentToolDispatcher for ToolDispatcher {
    /// Returns the tool definitions listed by the registry.
    fn tools(&self) -> Arc<[Arc<ToolDef>]> {
        Arc::from(self.registry.list())
    }

    /// Parses and validates `call`, then forwards it to the router, bounded by the
    /// configured timeout.
    ///
    /// # Errors
    ///
    /// - [`ToolError::InvalidArguments`] if the arguments are not valid JSON, or if
    ///   registry validation reports invalid arguments.
    /// - [`ToolError::NotFound`] if registry validation reports the tool as not found.
    /// - The error built by `ToolError::timeout` if the router does not finish in time.
    /// - Any error returned by the router itself.
    async fn dispatch(&self, call: ToolCallView<'_>) -> Result<ToolResult, ToolError> {
        let args: Value =
            serde_json::from_str(call.args.get()).map_err(|e| ToolError::InvalidArguments {
                name: call.name.to_string(),
                reason: e.to_string(),
            })?;
        self.registry
            .validate(call.name, &args)
            .map_err(|e| match e {
                ToolValidationError::NotFound { name } => ToolError::NotFound { name },
                ToolValidationError::InvalidArguments { name, reason } => {
                    ToolError::InvalidArguments { name, reason }
                }
            })?;
        // Capture the name before `call` is handed to the router.
        let name = call.name;
        let timeout = self.default_timeout;
        // `as_millis` returns u128; saturate rather than silently truncate for huge timeouts.
        let timeout_ms = timeout.as_millis().min(u128::from(u64::MAX)) as u64;
        tokio::time::timeout(timeout, self.router.dispatch(call))
            .await
            .map_err(|_| ToolError::timeout(name, timeout_ms))?
    }
}
/// Wraps another dispatcher and restricts which of its tools are visible and
/// callable, according to a [`ToolAccessPolicy`].
///
/// The set of permitted names is computed once, at construction, from the inner
/// dispatcher's tool list at that moment.
pub struct FilteredDispatcher {
    /// Dispatcher that receives calls passing the filter.
    inner: Arc<dyn AgentToolDispatcher>,
    /// Tool names permitted by the policy (snapshot taken at construction).
    allowed_names: HashSet<String>,
}
impl FilteredDispatcher {
    /// Builds a filter over `inner` using `policy`:
    ///
    /// - `Inherit`: every tool `inner` currently exposes is permitted.
    /// - `AllowList`: only listed names that `inner` exposes are permitted.
    /// - `DenyList`: every tool `inner` exposes except the listed names is permitted.
    ///
    /// NOTE: `inner.tools()` is read only here, so tools that `inner` starts
    /// exposing later are never permitted.
    pub fn new(inner: Arc<dyn AgentToolDispatcher>, policy: &ToolAccessPolicy) -> Self {
        let upstream = inner.tools();
        let candidates = upstream.iter().map(|def| def.name.clone());
        let allowed_names: HashSet<String> = match policy {
            ToolAccessPolicy::Inherit => candidates.collect(),
            ToolAccessPolicy::AllowList(allow) => {
                let permitted: HashSet<&str> = allow.iter().map(String::as_str).collect();
                candidates
                    .filter(|name| permitted.contains(name.as_str()))
                    .collect()
            }
            ToolAccessPolicy::DenyList(deny) => {
                let forbidden: HashSet<&str> = deny.iter().map(String::as_str).collect();
                candidates
                    .filter(|name| !forbidden.contains(name.as_str()))
                    .collect()
            }
        };
        Self {
            inner,
            allowed_names,
        }
    }

    /// Reports whether a tool called `name` passes the policy.
    pub fn is_allowed(&self, name: &str) -> bool {
        self.allowed_names.contains(name)
    }

    /// Borrows the full set of permitted tool names.
    pub fn allowed_names(&self) -> &HashSet<String> {
        &self.allowed_names
    }
}
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl AgentToolDispatcher for FilteredDispatcher {
    /// Returns the inner dispatcher's current tools, keeping only permitted names.
    fn tools(&self) -> Arc<[Arc<ToolDef>]> {
        let upstream = self.inner.tools();
        let visible: Vec<Arc<ToolDef>> = upstream
            .iter()
            .filter(|def| self.allowed_names.contains(def.name.as_str()))
            .map(Arc::clone)
            .collect();
        Arc::from(visible)
    }

    /// Forwards `call` to the inner dispatcher when its tool is permitted.
    ///
    /// # Errors
    ///
    /// Returns [`ToolError::NotFound`] for tools the policy blocks, so a blocked tool
    /// looks the same as a missing one. Otherwise returns the inner dispatcher's result.
    async fn dispatch(&self, call: ToolCallView<'_>) -> Result<ToolResult, ToolError> {
        if self.allowed_names.contains(call.name) {
            self.inner.dispatch(call).await
        } else {
            Err(ToolError::NotFound {
                name: call.name.to_string(),
            })
        }
    }
}
#[cfg(test)]
#[allow(clippy::panic)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Test double that exposes a fixed set of tool names and echoes successful calls.
    struct MockDispatcher {
        tool_names: Vec<&'static str>,
    }

    impl MockDispatcher {
        fn new(names: Vec<&'static str>) -> Self {
            Self { tool_names: names }
        }
    }

    /// Builds a `ToolCallView` with a fixed call id.
    fn make_call<'a>(name: &'a str, args_raw: &'a serde_json::value::RawValue) -> ToolCallView<'a> {
        ToolCallView {
            id: "test-1",
            name,
            args: args_raw,
        }
    }

    #[async_trait]
    impl AgentToolDispatcher for MockDispatcher {
        fn tools(&self) -> Arc<[Arc<ToolDef>]> {
            self.tool_names
                .iter()
                .map(|name| {
                    Arc::new(ToolDef {
                        name: (*name).to_string(),
                        description: format!("{name} tool"),
                        input_schema: json!({"type": "object"}),
                    })
                })
                .collect::<Vec<_>>()
                .into()
        }

        async fn dispatch(&self, call: ToolCallView<'_>) -> Result<ToolResult, ToolError> {
            if self.tool_names.contains(&call.name) {
                Ok(ToolResult {
                    tool_use_id: call.id.to_string(),
                    content: json!({"called": call.name}).to_string(),
                    is_error: false,
                })
            } else {
                Err(ToolError::NotFound {
                    name: call.name.to_string(),
                })
            }
        }
    }

    #[tokio::test]
    async fn test_empty_dispatcher_has_no_tools_and_rejects_calls() {
        let dispatcher = EmptyToolDispatcher;
        assert!(dispatcher.tools().is_empty());
        let args_raw = serde_json::value::RawValue::from_string(json!({}).to_string()).unwrap();
        match dispatcher.dispatch(make_call("shell", &args_raw)).await {
            Err(ToolError::NotFound { name }) => assert_eq!(name, "shell"),
            other => panic!("Expected NotFound error, got: {other:?}"),
        }
    }

    #[test]
    fn test_filtered_dispatcher_inherit_passes_all_tools() {
        let inner = Arc::new(MockDispatcher::new(vec!["shell", "task_list", "wait"]));
        let filtered = FilteredDispatcher::new(inner, &ToolAccessPolicy::Inherit);
        let tool_names: Vec<_> = filtered.tools().iter().map(|t| t.name.clone()).collect();
        assert_eq!(tool_names.len(), 3);
        assert!(filtered.is_allowed("shell"));
        assert!(filtered.is_allowed("task_list"));
        assert!(filtered.is_allowed("wait"));
    }

    #[test]
    fn test_filtered_dispatcher_allow_list_only_includes_specified_tools() {
        let inner = Arc::new(MockDispatcher::new(vec!["shell", "task_list", "wait"]));
        let policy = ToolAccessPolicy::AllowList(vec!["task_list".to_string()]);
        let filtered = FilteredDispatcher::new(inner, &policy);
        let tool_names: Vec<_> = filtered.tools().iter().map(|t| t.name.clone()).collect();
        assert_eq!(tool_names.len(), 1);
        assert_eq!(tool_names[0], "task_list");
        assert!(!filtered.is_allowed("shell"));
        assert!(filtered.is_allowed("task_list"));
        assert!(!filtered.is_allowed("wait"));
    }

    #[test]
    fn test_filtered_dispatcher_allow_list_ignores_unknown_tools() {
        let inner = Arc::new(MockDispatcher::new(vec!["shell", "wait"]));
        let policy =
            ToolAccessPolicy::AllowList(vec!["wait".to_string(), "nonexistent".to_string()]);
        let filtered = FilteredDispatcher::new(inner, &policy);
        // An allow-list entry that the inner dispatcher does not expose must not
        // become callable through the filter.
        assert_eq!(filtered.allowed_names().len(), 1);
        assert!(filtered.is_allowed("wait"));
        assert!(!filtered.is_allowed("nonexistent"));
        assert!(!filtered.is_allowed("shell"));
    }

    #[test]
    fn test_filtered_dispatcher_deny_list_excludes_specified_tools() {
        let inner = Arc::new(MockDispatcher::new(vec!["shell", "task_list", "wait"]));
        let policy = ToolAccessPolicy::DenyList(vec!["shell".to_string()]);
        let filtered = FilteredDispatcher::new(inner, &policy);
        let tool_names: Vec<_> = filtered.tools().iter().map(|t| t.name.clone()).collect();
        assert_eq!(tool_names.len(), 2);
        assert!(!filtered.is_allowed("shell"));
        assert!(filtered.is_allowed("task_list"));
        assert!(filtered.is_allowed("wait"));
    }

    #[tokio::test]
    async fn test_filtered_dispatcher_dispatch_blocked_tool_returns_not_found() {
        let inner = Arc::new(MockDispatcher::new(vec!["shell", "task_list"]));
        let policy = ToolAccessPolicy::DenyList(vec!["shell".to_string()]);
        let filtered = FilteredDispatcher::new(inner, &policy);
        let args_raw = serde_json::value::RawValue::from_string(json!({}).to_string()).unwrap();
        let result = filtered.dispatch(make_call("task_list", &args_raw)).await;
        assert!(result.is_ok());
        let result = filtered.dispatch(make_call("shell", &args_raw)).await;
        match result {
            Err(ToolError::NotFound { name }) => assert_eq!(name, "shell"),
            other => panic!("Expected NotFound error, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_regression_tool_access_policy_must_be_enforced() {
        let inner = Arc::new(MockDispatcher::new(vec![
            "shell",
            "agent_spawn",
            "task_list",
            "wait",
        ]));
        let policy =
            ToolAccessPolicy::DenyList(vec!["shell".to_string(), "agent_spawn".to_string()]);
        let filtered = FilteredDispatcher::new(inner, &policy);
        let visible_tools: Vec<_> = filtered.tools().iter().map(|t| t.name.clone()).collect();
        assert_eq!(visible_tools.len(), 2);
        assert!(visible_tools.contains(&"task_list".to_string()));
        assert!(visible_tools.contains(&"wait".to_string()));
        assert!(
            !visible_tools.contains(&"shell".to_string()),
            "shell should not be visible in tools list"
        );
        assert!(
            !visible_tools.contains(&"agent_spawn".to_string()),
            "agent_spawn should not be visible in tools list"
        );
        let args_raw = serde_json::value::RawValue::from_string(json!({}).to_string()).unwrap();
        let shell_result = filtered.dispatch(make_call("shell", &args_raw)).await;
        assert!(
            matches!(shell_result, Err(ToolError::NotFound { .. })),
            "shell dispatch should fail with NotFound"
        );
    }
}