use std::collections::HashMap;
use std::sync::Arc;
use crate::builtin::BuiltinTool;
use crate::error::ToolError;
/// Where the implementation of a registered tool lives.
///
/// Only the `Builtin` variant is actually run by `ToolExecutor::execute`; the `Mcp`
/// and `Wasm` variants are recorded by name and currently produce an
/// `ExecutionFailed` error when executed.
#[derive(Clone)]
pub enum ToolBackend {
    /// An in-process tool implementing [`BuiltinTool`], shared behind an `Arc`.
    Builtin(Arc<dyn BuiltinTool>),
    /// A tool exposed by an MCP server.
    Mcp {
        /// Name of the MCP server that exposes this tool.
        server_name: String,
    },
    /// A tool provided by a WASM module.
    Wasm {
        /// Name of the WASM module that provides this tool.
        module_name: String,
    },
}
/// Name-indexed registry of tools, plus the entry point for executing them.
pub struct ToolExecutor {
    /// Maps a tool name to its backend. Registering an existing name again
    /// replaces the earlier backend.
    tools: HashMap<String, ToolBackend>,
    /// Optional shared MCP tool registry, attached via `with_mcp_registry`.
    mcp_registry: Option<Arc<axocoatl_mcp::McpToolRegistry>>,
}
impl ToolExecutor {
pub fn new() -> Self {
Self {
tools: HashMap::new(),
mcp_registry: None,
}
}
pub fn with_mcp_registry(mut self, registry: Arc<axocoatl_mcp::McpToolRegistry>) -> Self {
self.mcp_registry = Some(registry);
self
}
pub fn register_builtin(&mut self, name: impl Into<String>, tool: Arc<dyn BuiltinTool>) {
self.tools.insert(name.into(), ToolBackend::Builtin(tool));
}
pub fn register_mcp(&mut self, name: impl Into<String>, server_name: impl Into<String>) {
self.tools.insert(
name.into(),
ToolBackend::Mcp {
server_name: server_name.into(),
},
);
}
pub fn register_wasm(&mut self, name: impl Into<String>, module_name: impl Into<String>) {
self.tools.insert(
name.into(),
ToolBackend::Wasm {
module_name: module_name.into(),
},
);
}
pub async fn execute(
&self,
tool_name: &str,
arguments: serde_json::Value,
) -> Result<serde_json::Value, ToolError> {
let backend = self
.tools
.get(tool_name)
.ok_or_else(|| ToolError::NotFound(tool_name.to_string()))?;
match backend {
ToolBackend::Builtin(tool) => tool.execute(arguments).await,
ToolBackend::Mcp { server_name } => {
Err(ToolError::ExecutionFailed {
tool: tool_name.to_string(),
reason: format!(
"MCP tool '{}' on server '{}': persistent connections not yet implemented. \
Tools are discovered but execution requires keeping the MCP client alive.",
tool_name, server_name
),
})
}
ToolBackend::Wasm { module_name } => {
Err(ToolError::ExecutionFailed {
tool: tool_name.to_string(),
reason: format!("WASM execution of '{module_name}' not yet wired"),
})
}
}
}
pub fn tool_names(&self) -> Vec<String> {
self.tools.keys().cloned().collect()
}
pub fn get_concurrency_policy(&self, tool_name: &str) -> Option<axocoatl_llm::ConcurrencyPolicy> {
match self.tools.get(tool_name) {
Some(ToolBackend::Builtin(_)) => Some(axocoatl_llm::ConcurrencyPolicy::Safe),
Some(ToolBackend::Mcp { .. }) => Some(axocoatl_llm::ConcurrencyPolicy::Safe),
Some(ToolBackend::Wasm { .. }) => Some(axocoatl_llm::ConcurrencyPolicy::Safe),
None => None,
}
}
pub fn as_llm_tools(&self) -> Vec<axocoatl_llm::ToolDefinition> {
self.tools
.iter()
.filter_map(|(name, backend)| match backend {
ToolBackend::Builtin(tool) => Some(axocoatl_llm::ToolDefinition {
name: name.clone(),
description: tool.description().to_string(),
parameters: tool.parameters_schema(),
concurrency: axocoatl_llm::ConcurrencyPolicy::Safe,
}),
_ => None, })
.collect()
}
}
impl Default for ToolExecutor {
fn default() -> Self {
Self::new()
}
}
impl ToolExecutor {
    /// Runs a batch of tool calls through `ConcurrentToolDispatcher::dispatch`.
    ///
    /// `policy_lookup` is handed to the dispatcher unchanged. Presumably the
    /// dispatcher uses it to decide which calls may run in parallel; confirm in
    /// `crate::concurrent`.
    pub async fn execute_concurrent(
        self: &Arc<Self>,
        tool_calls: &[axocoatl_llm::ToolCall],
        policy_lookup: impl Fn(&str) -> axocoatl_llm::ConcurrencyPolicy,
    ) -> Vec<crate::concurrent::ToolResult> {
        let pending =
            crate::concurrent::ConcurrentToolDispatcher::dispatch(self, tool_calls, policy_lookup);
        pending.await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builtin::*;

    #[tokio::test]
    async fn register_and_execute_builtin() {
        let mut executor = ToolExecutor::new();
        executor.register_builtin("echo", Arc::new(EchoTool));
        let result = executor
            .execute("echo", serde_json::json!({"text": "hello"}))
            .await
            .unwrap();
        assert_eq!(result["text"], "hello");
    }

    #[tokio::test]
    async fn unknown_tool_returns_error() {
        let executor = ToolExecutor::new();
        let result = executor.execute("nonexistent", serde_json::json!({})).await;
        assert!(matches!(result, Err(ToolError::NotFound(_))));
    }

    /// MCP execution is not implemented yet, so it must surface as `ExecutionFailed`.
    #[tokio::test]
    async fn mcp_tool_reports_execution_failure() {
        let mut executor = ToolExecutor::new();
        executor.register_mcp("remote", "server-a");
        let result = executor.execute("remote", serde_json::json!({})).await;
        assert!(matches!(
            &result,
            Err(ToolError::ExecutionFailed { tool, .. }) if tool == "remote"
        ));
    }

    /// WASM execution is not wired yet, so it must surface as `ExecutionFailed`.
    #[tokio::test]
    async fn wasm_tool_reports_execution_failure() {
        let mut executor = ToolExecutor::new();
        executor.register_wasm("sandboxed", "module-a");
        let result = executor.execute("sandboxed", serde_json::json!({})).await;
        assert!(matches!(
            &result,
            Err(ToolError::ExecutionFailed { tool, .. }) if tool == "sandboxed"
        ));
    }

    #[test]
    fn as_llm_tools_includes_builtins() {
        let mut executor = ToolExecutor::new();
        executor.register_builtin("echo", Arc::new(EchoTool));
        executor.register_builtin("json_keys", Arc::new(JsonKeysTool));
        let tools = executor.as_llm_tools();
        assert_eq!(tools.len(), 2);
        // Sort locally so the test does not depend on the executor's output order.
        let mut names: Vec<&str> = tools.iter().map(|t| t.name.as_str()).collect();
        names.sort_unstable();
        assert_eq!(names, ["echo", "json_keys"]);
    }

    #[test]
    fn as_llm_tools_skips_non_builtin_backends() {
        let mut executor = ToolExecutor::new();
        executor.register_builtin("echo", Arc::new(EchoTool));
        executor.register_mcp("remote", "server-a");
        executor.register_wasm("sandboxed", "module-a");

        let tools = executor.as_llm_tools();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "echo");

        // `tool_names` still lists every backend kind.
        let mut names = executor.tool_names();
        names.sort_unstable();
        assert_eq!(names, ["echo", "remote", "sandboxed"]);
    }

    #[test]
    fn concurrency_policy_for_known_and_unknown_tools() {
        let mut executor = ToolExecutor::new();
        executor.register_builtin("echo", Arc::new(EchoTool));
        assert!(matches!(
            executor.get_concurrency_policy("echo"),
            Some(axocoatl_llm::ConcurrencyPolicy::Safe)
        ));
        assert!(executor.get_concurrency_policy("missing").is_none());
    }

    #[test]
    fn re_registering_a_name_replaces_the_backend() {
        let mut executor = ToolExecutor::new();
        executor.register_mcp("tool", "server-a");
        executor.register_builtin("tool", Arc::new(EchoTool));
        assert_eq!(executor.tool_names(), ["tool"]);
        // Now a builtin, so it is advertised to the LLM.
        assert_eq!(executor.as_llm_tools().len(), 1);
    }
}