axocoatl_tools/
executor.rs1use std::collections::HashMap;
4use std::sync::Arc;
5
6use crate::builtin::BuiltinTool;
7use crate::error::ToolError;
8
9#[derive(Clone)]
11pub enum ToolBackend {
12 Builtin(Arc<dyn BuiltinTool>),
14 Mcp { server_name: String },
16 Wasm { module_name: String },
18}
19
20pub struct ToolExecutor {
22 tools: HashMap<String, ToolBackend>,
23 mcp_registry: Option<Arc<tokio::sync::RwLock<axocoatl_mcp::McpToolRegistry>>>,
24}
25
26impl ToolExecutor {
27 pub fn new() -> Self {
28 Self {
29 tools: HashMap::new(),
30 mcp_registry: None,
31 }
32 }
33
34 pub fn with_mcp_registry(
36 mut self,
37 registry: Arc<tokio::sync::RwLock<axocoatl_mcp::McpToolRegistry>>,
38 ) -> Self {
39 self.mcp_registry = Some(registry);
40 self
41 }
42
43 pub fn register_builtin(&mut self, name: impl Into<String>, tool: Arc<dyn BuiltinTool>) {
45 self.tools.insert(name.into(), ToolBackend::Builtin(tool));
46 }
47
48 pub fn register_mcp(&mut self, name: impl Into<String>, server_name: impl Into<String>) {
50 self.tools.insert(
51 name.into(),
52 ToolBackend::Mcp {
53 server_name: server_name.into(),
54 },
55 );
56 }
57
58 pub fn register_wasm(&mut self, name: impl Into<String>, module_name: impl Into<String>) {
60 self.tools.insert(
61 name.into(),
62 ToolBackend::Wasm {
63 module_name: module_name.into(),
64 },
65 );
66 }
67
68 pub async fn execute(
70 &self,
71 tool_name: &str,
72 arguments: serde_json::Value,
73 ) -> Result<serde_json::Value, ToolError> {
74 let backend = self
75 .tools
76 .get(tool_name)
77 .ok_or_else(|| ToolError::NotFound(tool_name.to_string()))?;
78
79 match backend {
80 ToolBackend::Builtin(tool) => tool.execute(arguments).await,
81 ToolBackend::Mcp { server_name } => {
82 Err(ToolError::ExecutionFailed {
86 tool: tool_name.to_string(),
87 reason: format!(
88 "MCP tool '{}' on server '{}': persistent connections not yet implemented. \
89 Tools are discovered but execution requires keeping the MCP client alive.",
90 tool_name, server_name
91 ),
92 })
93 }
94 ToolBackend::Wasm { module_name } => {
95 Err(ToolError::ExecutionFailed {
97 tool: tool_name.to_string(),
98 reason: format!("WASM execution of '{module_name}' not yet wired"),
99 })
100 }
101 }
102 }
103
104 pub fn tool_names(&self) -> Vec<String> {
106 self.tools.keys().cloned().collect()
107 }
108
109 pub fn get_concurrency_policy(
111 &self,
112 tool_name: &str,
113 ) -> Option<axocoatl_llm::ConcurrencyPolicy> {
114 match self.tools.get(tool_name) {
115 Some(ToolBackend::Builtin(_)) => Some(axocoatl_llm::ConcurrencyPolicy::Safe),
116 Some(ToolBackend::Mcp { .. }) => Some(axocoatl_llm::ConcurrencyPolicy::Safe),
117 Some(ToolBackend::Wasm { .. }) => Some(axocoatl_llm::ConcurrencyPolicy::Safe),
118 None => None,
119 }
120 }
121
122 pub fn as_llm_tools(&self) -> Vec<axocoatl_llm::ToolDefinition> {
124 self.tools
125 .iter()
126 .filter_map(|(name, backend)| match backend {
127 ToolBackend::Builtin(tool) => Some(axocoatl_llm::ToolDefinition {
128 name: name.clone(),
129 description: tool.description().to_string(),
130 parameters: tool.parameters_schema(),
131 concurrency: axocoatl_llm::ConcurrencyPolicy::Safe,
132 }),
133 _ => None, })
135 .collect()
136 }
137}
138
139impl Default for ToolExecutor {
140 fn default() -> Self {
141 Self::new()
142 }
143}
144
145impl ToolExecutor {
148 pub async fn execute_concurrent(
149 self: &Arc<Self>,
150 tool_calls: &[axocoatl_llm::ToolCall],
151 policy_lookup: impl Fn(&str) -> axocoatl_llm::ConcurrencyPolicy,
152 ) -> Vec<crate::concurrent::ToolResult> {
153 crate::concurrent::ConcurrentToolDispatcher::dispatch(self, tool_calls, policy_lookup).await
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builtin::*;

    #[tokio::test]
    async fn register_and_execute_builtin() {
        let mut executor = ToolExecutor::new();
        executor.register_builtin("echo", Arc::new(EchoTool));

        let result = executor
            .execute("echo", serde_json::json!({"text": "hello"}))
            .await
            .unwrap();

        assert_eq!(result["text"], "hello");
    }

    #[tokio::test]
    async fn unknown_tool_returns_error() {
        let executor = ToolExecutor::new();
        let result = executor.execute("nonexistent", serde_json::json!({})).await;
        assert!(matches!(result, Err(ToolError::NotFound(_))));
    }

    /// MCP execution is not implemented yet; it must fail loudly rather than succeed silently.
    #[tokio::test]
    async fn mcp_tool_reports_execution_failed() {
        let mut executor = ToolExecutor::new();
        executor.register_mcp("search", "web");

        let result = executor.execute("search", serde_json::json!({})).await;
        assert!(matches!(result, Err(ToolError::ExecutionFailed { .. })));
    }

    /// WASM execution is not wired yet; it must fail loudly rather than succeed silently.
    #[tokio::test]
    async fn wasm_tool_reports_execution_failed() {
        let mut executor = ToolExecutor::new();
        executor.register_wasm("resize", "image");

        let result = executor.execute("resize", serde_json::json!({})).await;
        assert!(matches!(result, Err(ToolError::ExecutionFailed { .. })));
    }

    #[test]
    fn as_llm_tools_includes_builtins() {
        let mut executor = ToolExecutor::new();
        executor.register_builtin("echo", Arc::new(EchoTool));
        executor.register_builtin("json_keys", Arc::new(JsonKeysTool));

        let tools = executor.as_llm_tools();
        assert_eq!(tools.len(), 2);
    }

    /// Only builtin tools are advertised to the model, but every backend is registered.
    #[test]
    fn as_llm_tools_skips_non_builtin_backends() {
        let mut executor = ToolExecutor::new();
        executor.register_builtin("echo", Arc::new(EchoTool));
        executor.register_mcp("search", "web");
        executor.register_wasm("resize", "image");

        let tools = executor.as_llm_tools();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "echo");
        assert_eq!(executor.tool_names().len(), 3);
    }

    #[test]
    fn concurrency_policy_for_known_and_unknown_tools() {
        let mut executor = ToolExecutor::new();
        executor.register_builtin("echo", Arc::new(EchoTool));

        assert!(matches!(
            executor.get_concurrency_policy("echo"),
            Some(axocoatl_llm::ConcurrencyPolicy::Safe)
        ));
        assert!(executor.get_concurrency_policy("missing").is_none());
    }

    /// Registering an existing name replaces the earlier backend instead of duplicating it.
    #[test]
    fn registering_same_name_replaces_backend() {
        let mut executor = ToolExecutor::new();
        executor.register_builtin("echo", Arc::new(EchoTool));
        executor.register_mcp("echo", "web");

        assert_eq!(executor.tool_names(), vec!["echo".to_string()]);
        // The builtin was replaced by an MCP backend, which is not advertised.
        assert!(executor.as_llm_tools().is_empty());
    }
}