1use aidale_core::error::AiError;
7use aidale_core::plugin::{Plugin, PluginPhase};
8use aidale_core::types::*;
9use async_trait::async_trait;
10use std::collections::HashMap;
11use std::sync::Arc;
12
13#[async_trait]
15pub trait ToolExecutor: Send + Sync {
16 async fn execute(
18 &self,
19 name: &str,
20 arguments: &serde_json::Value,
21 ) -> Result<serde_json::Value, AiError>;
22}
23
24type ToolExecutorFn = Arc<
27 dyn Fn(
28 serde_json::Value,
29 ) -> std::pin::Pin<
30 Box<dyn std::future::Future<Output = Result<serde_json::Value, AiError>> + Send>,
31 > + Send
32 + Sync,
33>;
34
35pub struct FunctionTool {
36 name: String,
37 description: String,
38 parameters: serde_json::Value,
39 executor: ToolExecutorFn,
40}
41
42impl FunctionTool {
43 pub fn new<F, Fut>(
45 name: impl Into<String>,
46 description: impl Into<String>,
47 parameters: serde_json::Value,
48 executor: F,
49 ) -> Self
50 where
51 F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
52 Fut: std::future::Future<Output = Result<serde_json::Value, AiError>> + Send + 'static,
53 {
54 Self {
55 name: name.into(),
56 description: description.into(),
57 parameters,
58 executor: Arc::new(move |args| Box::pin(executor(args))),
59 }
60 }
61
62 pub fn definition(&self) -> Tool {
64 Tool {
65 name: self.name.clone(),
66 description: self.description.clone(),
67 parameters: self.parameters.clone(),
68 }
69 }
70}
71
72#[async_trait]
73impl ToolExecutor for FunctionTool {
74 async fn execute(
75 &self,
76 name: &str,
77 arguments: &serde_json::Value,
78 ) -> Result<serde_json::Value, AiError> {
79 if name != self.name {
80 return Err(AiError::plugin(
81 "ToolUsePlugin",
82 format!("Tool {} not found", name),
83 ));
84 }
85
86 (self.executor)(arguments.clone()).await
87 }
88}
89
90pub struct ToolRegistry {
92 tools: HashMap<String, Arc<dyn ToolExecutor>>,
93}
94
95impl ToolRegistry {
96 pub fn new() -> Self {
98 Self {
99 tools: HashMap::new(),
100 }
101 }
102
103 pub fn register(&mut self, name: impl Into<String>, tool: Arc<dyn ToolExecutor>) {
105 self.tools.insert(name.into(), tool);
106 }
107
108 pub fn definitions(&self) -> Vec<Tool> {
110 self.tools
111 .iter()
112 .map(|(name, tool)| {
113 if let Some(func_tool) = (tool as &dyn std::any::Any).downcast_ref::<FunctionTool>()
116 {
117 func_tool.definition()
118 } else {
119 Tool {
120 name: name.clone(),
121 description: format!("Tool: {}", name),
122 parameters: serde_json::json!({}),
123 }
124 }
125 })
126 .collect()
127 }
128
129 pub async fn execute(
131 &self,
132 name: &str,
133 arguments: &serde_json::Value,
134 ) -> Result<serde_json::Value, AiError> {
135 let tool = self
136 .tools
137 .get(name)
138 .ok_or_else(|| AiError::plugin("ToolUsePlugin", format!("Tool {} not found", name)))?;
139
140 tool.execute(name, arguments).await
141 }
142}
143
144impl Default for ToolRegistry {
145 fn default() -> Self {
146 Self::new()
147 }
148}
149
/// Tunable behaviour of [`ToolUsePlugin`].
#[derive(Debug, Clone)]
pub struct ToolUsePluginConfig {
    /// Whether tool calls in a model result should be processed automatically.
    pub auto_execute: bool,
    /// Maximum number of tool-call rounds.
    /// NOTE(review): not read anywhere in this file — confirm where it is enforced.
    pub max_rounds: usize,
}

impl Default for ToolUsePluginConfig {
    /// Auto-execution enabled, with at most three rounds.
    fn default() -> Self {
        let auto_execute = true;
        let max_rounds = 3;
        ToolUsePluginConfig {
            auto_execute,
            max_rounds,
        }
    }
}
167
168pub struct ToolUsePlugin {
170 registry: Arc<ToolRegistry>,
171 config: ToolUsePluginConfig,
172}
173
174impl std::fmt::Debug for ToolUsePlugin {
175 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
176 f.debug_struct("ToolUsePlugin")
177 .field("config", &self.config)
178 .field("tool_count", &self.registry.tools.len())
179 .finish()
180 }
181}
182
183impl ToolUsePlugin {
184 pub fn new(registry: Arc<ToolRegistry>) -> Self {
186 Self {
187 registry,
188 config: ToolUsePluginConfig::default(),
189 }
190 }
191
192 pub fn with_config(registry: Arc<ToolRegistry>, config: ToolUsePluginConfig) -> Self {
194 Self { registry, config }
195 }
196
197 fn add_tools_to_params(&self, mut params: TextParams) -> TextParams {
199 let tools = self.registry.definitions();
200 if !tools.is_empty() {
201 params.tools = Some(tools);
202 }
203 params
204 }
205
206 async fn process_tool_calls(&self, result: TextResult) -> Result<TextResult, AiError> {
208 if result.finish_reason != FinishReason::ToolCalls {
210 return Ok(result);
211 }
212
213 if !self.config.auto_execute {
214 return Ok(result);
215 }
216
217 let tool_calls = result.tool_calls.as_ref();
219 if tool_calls.is_none() || tool_calls.unwrap().is_empty() {
220 return Ok(result);
221 }
222
223 tracing::debug!("Processing tool calls (auto_execute=true)");
227
228 Ok(result)
229 }
230}
231
232#[async_trait]
233impl Plugin for ToolUsePlugin {
234 fn name(&self) -> &str {
235 "tool_use"
236 }
237
238 fn enforce(&self) -> PluginPhase {
239 PluginPhase::Pre
240 }
241
242 async fn transform_params(
243 &self,
244 params: TextParams,
245 _ctx: &RequestContext,
246 ) -> Result<TextParams, AiError> {
247 Ok(self.add_tools_to_params(params))
248 }
249
250 async fn transform_result(
251 &self,
252 result: TextResult,
253 _ctx: &RequestContext,
254 ) -> Result<TextResult, AiError> {
255 self.process_tool_calls(result).await
256 }
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `FunctionTool` that resolves to its arguments unchanged.
    fn echo_tool(name: &str, description: &str) -> FunctionTool {
        FunctionTool::new(
            name,
            description,
            serde_json::json!({"type": "object"}),
            |args: serde_json::Value| async move { Ok(args) },
        )
    }

    #[tokio::test]
    async fn test_function_tool() {
        let tool = echo_tool("test", "A test tool");
        let payload = serde_json::json!({"key": "value"});

        let echoed = tool.execute("test", &payload).await.unwrap();

        assert_eq!(echoed, payload);
    }

    #[tokio::test]
    async fn test_tool_registry() {
        let mut registry = ToolRegistry::new();
        registry.register("add", Arc::new(echo_tool("add", "Add two numbers")));

        let defs = registry.definitions();

        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].name, "add");
    }
}