Skip to main content

axocoatl_tools/
executor.rs

1//! Unified tool executor — routes calls to built-in tools, MCP servers, or WASM sandboxes.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use crate::builtin::BuiltinTool;
7use crate::error::ToolError;
8
9/// A registered tool with its execution backend.
10#[derive(Clone)]
11pub enum ToolBackend {
12    /// Built-in Rust tool (runs in-process).
13    Builtin(Arc<dyn BuiltinTool>),
14    /// MCP tool on a named server.
15    Mcp { server_name: String },
16    /// WASM tool in sandbox.
17    Wasm { module_name: String },
18}
19
20/// Routes tool calls to the appropriate backend.
21pub struct ToolExecutor {
22    tools: HashMap<String, ToolBackend>,
23    mcp_registry: Option<Arc<tokio::sync::RwLock<axocoatl_mcp::McpToolRegistry>>>,
24}
25
26impl ToolExecutor {
27    pub fn new() -> Self {
28        Self {
29            tools: HashMap::new(),
30            mcp_registry: None,
31        }
32    }
33
34    /// Set the MCP tool registry for routing MCP tool calls.
35    pub fn with_mcp_registry(
36        mut self,
37        registry: Arc<tokio::sync::RwLock<axocoatl_mcp::McpToolRegistry>>,
38    ) -> Self {
39        self.mcp_registry = Some(registry);
40        self
41    }
42
43    /// Register a built-in tool.
44    pub fn register_builtin(&mut self, name: impl Into<String>, tool: Arc<dyn BuiltinTool>) {
45        self.tools.insert(name.into(), ToolBackend::Builtin(tool));
46    }
47
48    /// Register an MCP tool (from a connected server).
49    pub fn register_mcp(&mut self, name: impl Into<String>, server_name: impl Into<String>) {
50        self.tools.insert(
51            name.into(),
52            ToolBackend::Mcp {
53                server_name: server_name.into(),
54            },
55        );
56    }
57
58    /// Register a WASM tool.
59    pub fn register_wasm(&mut self, name: impl Into<String>, module_name: impl Into<String>) {
60        self.tools.insert(
61            name.into(),
62            ToolBackend::Wasm {
63                module_name: module_name.into(),
64            },
65        );
66    }
67
68    /// Execute a tool by name.
69    pub async fn execute(
70        &self,
71        tool_name: &str,
72        arguments: serde_json::Value,
73    ) -> Result<serde_json::Value, ToolError> {
74        let backend = self
75            .tools
76            .get(tool_name)
77            .ok_or_else(|| ToolError::NotFound(tool_name.to_string()))?;
78
79        match backend {
80            ToolBackend::Builtin(tool) => tool.execute(arguments).await,
81            ToolBackend::Mcp { server_name } => {
82                // MCP tool execution requires a persistent connection (not yet implemented).
83                // The registry currently disconnects after discovery.
84                // For now, return a descriptive error.
85                Err(ToolError::ExecutionFailed {
86                    tool: tool_name.to_string(),
87                    reason: format!(
88                        "MCP tool '{}' on server '{}': persistent connections not yet implemented. \
89                         Tools are discovered but execution requires keeping the MCP client alive.",
90                        tool_name, server_name
91                    ),
92                })
93            }
94            ToolBackend::Wasm { module_name } => {
95                // TODO: Route to WasmtimeSandbox for execution
96                Err(ToolError::ExecutionFailed {
97                    tool: tool_name.to_string(),
98                    reason: format!("WASM execution of '{module_name}' not yet wired"),
99                })
100            }
101        }
102    }
103
104    /// List all registered tool names.
105    pub fn tool_names(&self) -> Vec<String> {
106        self.tools.keys().cloned().collect()
107    }
108
109    /// Get the concurrency policy for a tool by name.
110    pub fn get_concurrency_policy(
111        &self,
112        tool_name: &str,
113    ) -> Option<axocoatl_llm::ConcurrencyPolicy> {
114        match self.tools.get(tool_name) {
115            Some(ToolBackend::Builtin(_)) => Some(axocoatl_llm::ConcurrencyPolicy::Safe),
116            Some(ToolBackend::Mcp { .. }) => Some(axocoatl_llm::ConcurrencyPolicy::Safe),
117            Some(ToolBackend::Wasm { .. }) => Some(axocoatl_llm::ConcurrencyPolicy::Safe),
118            None => None,
119        }
120    }
121
122    /// Convert registered tools to LLM-compatible tool definitions.
123    pub fn as_llm_tools(&self) -> Vec<axocoatl_llm::ToolDefinition> {
124        self.tools
125            .iter()
126            .filter_map(|(name, backend)| match backend {
127                ToolBackend::Builtin(tool) => Some(axocoatl_llm::ToolDefinition {
128                    name: name.clone(),
129                    description: tool.description().to_string(),
130                    parameters: tool.parameters_schema(),
131                    concurrency: axocoatl_llm::ConcurrencyPolicy::Safe,
132                }),
133                _ => None, // MCP/WASM tools get their schemas from their registries
134            })
135            .collect()
136    }
137}
138
139impl Default for ToolExecutor {
140    fn default() -> Self {
141        Self::new()
142    }
143}
144
145/// Convenience: execute a batch of tool calls concurrently.
146/// This is a thin wrapper around ConcurrentToolDispatcher::dispatch.
147impl ToolExecutor {
148    pub async fn execute_concurrent(
149        self: &Arc<Self>,
150        tool_calls: &[axocoatl_llm::ToolCall],
151        policy_lookup: impl Fn(&str) -> axocoatl_llm::ConcurrencyPolicy,
152    ) -> Vec<crate::concurrent::ToolResult> {
153        crate::concurrent::ConcurrentToolDispatcher::dispatch(self, tool_calls, policy_lookup).await
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builtin::*;

    #[tokio::test]
    async fn register_and_execute_builtin() {
        let mut executor = ToolExecutor::new();
        executor.register_builtin("echo", Arc::new(EchoTool));

        let result = executor
            .execute("echo", serde_json::json!({"text": "hello"}))
            .await
            .unwrap();

        assert_eq!(result["text"], "hello");
    }

    #[tokio::test]
    async fn unknown_tool_returns_error() {
        let executor = ToolExecutor::new();
        let result = executor.execute("nonexistent", serde_json::json!({})).await;
        assert!(matches!(result, Err(ToolError::NotFound(_))));
    }

    #[tokio::test]
    async fn mcp_tool_execution_reports_execution_failed() {
        let mut executor = ToolExecutor::new();
        executor.register_mcp("remote_search", "search-server");

        let result = executor
            .execute("remote_search", serde_json::json!({}))
            .await;

        assert!(matches!(
            &result,
            Err(ToolError::ExecutionFailed { tool, .. }) if tool == "remote_search"
        ));
    }

    #[tokio::test]
    async fn wasm_tool_execution_reports_execution_failed() {
        let mut executor = ToolExecutor::new();
        executor.register_wasm("sandboxed", "sandboxed_module");

        let result = executor.execute("sandboxed", serde_json::json!({})).await;

        assert!(matches!(
            &result,
            Err(ToolError::ExecutionFailed { tool, .. }) if tool == "sandboxed"
        ));
    }

    #[test]
    fn concurrency_policy_for_known_and_unknown_tools() {
        let mut executor = ToolExecutor::new();
        executor.register_builtin("echo", Arc::new(EchoTool));
        executor.register_wasm("sandboxed", "sandboxed_module");

        assert!(matches!(
            executor.get_concurrency_policy("echo"),
            Some(axocoatl_llm::ConcurrencyPolicy::Safe)
        ));
        assert!(matches!(
            executor.get_concurrency_policy("sandboxed"),
            Some(axocoatl_llm::ConcurrencyPolicy::Safe)
        ));
        assert!(executor.get_concurrency_policy("missing").is_none());
    }

    #[test]
    fn tool_names_lists_every_backend() {
        let mut executor = ToolExecutor::new();
        executor.register_builtin("echo", Arc::new(EchoTool));
        executor.register_mcp("remote", "server");
        executor.register_wasm("sandboxed", "module");

        // Sort locally so the assertion does not depend on the returned order.
        let mut names = executor.tool_names();
        names.sort();
        assert_eq!(names, ["echo", "remote", "sandboxed"]);
    }

    #[test]
    fn as_llm_tools_includes_builtins() {
        let mut executor = ToolExecutor::new();
        executor.register_builtin("echo", Arc::new(EchoTool));
        executor.register_builtin("json_keys", Arc::new(JsonKeysTool));

        let tools = executor.as_llm_tools();
        assert_eq!(tools.len(), 2);

        let mut names: Vec<&str> = tools.iter().map(|t| t.name.as_str()).collect();
        names.sort_unstable();
        assert_eq!(names, ["echo", "json_keys"]);
    }

    #[test]
    fn as_llm_tools_skips_mcp_and_wasm() {
        let mut executor = ToolExecutor::new();
        executor.register_builtin("echo", Arc::new(EchoTool));
        executor.register_mcp("remote", "server");
        executor.register_wasm("sandboxed", "module");

        let tools = executor.as_llm_tools();
        let names: Vec<&str> = tools.iter().map(|t| t.name.as_str()).collect();
        assert_eq!(names, ["echo"]);
    }

    #[test]
    fn re_registering_a_name_replaces_the_backend() {
        let mut executor = ToolExecutor::new();
        executor.register_mcp("tool", "server");
        executor.register_builtin("tool", Arc::new(EchoTool));

        assert_eq!(executor.tool_names().len(), 1);
        // The MCP registration was replaced by a builtin, which is now exported.
        assert_eq!(executor.as_llm_tools().len(), 1);
    }
}