aidale_plugin/
tool_use.rs

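//! Tool use plugin for AI models.
//!
//! This plugin advertises the tools held in a [`ToolRegistry`] to the model by
//! attaching their definitions to the request parameters, and inspects results
//! that finish because the model requested tool calls. It is intended to let
//! models without native tool/function calling use tools; automatic execution
//! of the returned tool calls is not implemented yet.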
5
6use aidale_core::error::AiError;
7use aidale_core::plugin::{Plugin, PluginPhase};
8use aidale_core::types::*;
9use async_trait::async_trait;
10use std::collections::HashMap;
11use std::sync::Arc;
12
13/// Tool executor trait
14#[async_trait]
15pub trait ToolExecutor: Send + Sync {
16    /// Execute a tool with the given arguments
17    async fn execute(
18        &self,
19        name: &str,
20        arguments: &serde_json::Value,
21    ) -> Result<serde_json::Value, AiError>;
22}
23
24/// Simple function-based tool executor
25// Type alias to reduce complexity
26type ToolExecutorFn = Arc<
27    dyn Fn(
28            serde_json::Value,
29        ) -> std::pin::Pin<
30            Box<dyn std::future::Future<Output = Result<serde_json::Value, AiError>> + Send>,
31        > + Send
32        + Sync,
33>;
34
35pub struct FunctionTool {
36    name: String,
37    description: String,
38    parameters: serde_json::Value,
39    executor: ToolExecutorFn,
40}
41
42impl FunctionTool {
43    /// Create a new function tool
44    pub fn new<F, Fut>(
45        name: impl Into<String>,
46        description: impl Into<String>,
47        parameters: serde_json::Value,
48        executor: F,
49    ) -> Self
50    where
51        F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
52        Fut: std::future::Future<Output = Result<serde_json::Value, AiError>> + Send + 'static,
53    {
54        Self {
55            name: name.into(),
56            description: description.into(),
57            parameters,
58            executor: Arc::new(move |args| Box::pin(executor(args))),
59        }
60    }
61
62    /// Get tool definition
63    pub fn definition(&self) -> Tool {
64        Tool {
65            name: self.name.clone(),
66            description: self.description.clone(),
67            parameters: self.parameters.clone(),
68        }
69    }
70}
71
72#[async_trait]
73impl ToolExecutor for FunctionTool {
74    async fn execute(
75        &self,
76        name: &str,
77        arguments: &serde_json::Value,
78    ) -> Result<serde_json::Value, AiError> {
79        if name != self.name {
80            return Err(AiError::plugin(
81                "ToolUsePlugin",
82                format!("Tool {} not found", name),
83            ));
84        }
85
86        (self.executor)(arguments.clone()).await
87    }
88}
89
90/// Tool registry that can execute multiple tools
91pub struct ToolRegistry {
92    tools: HashMap<String, Arc<dyn ToolExecutor>>,
93}
94
95impl ToolRegistry {
96    /// Create a new tool registry
97    pub fn new() -> Self {
98        Self {
99            tools: HashMap::new(),
100        }
101    }
102
103    /// Register a tool
104    pub fn register(&mut self, name: impl Into<String>, tool: Arc<dyn ToolExecutor>) {
105        self.tools.insert(name.into(), tool);
106    }
107
108    /// Get all tool definitions
109    pub fn definitions(&self) -> Vec<Tool> {
110        self.tools
111            .iter()
112            .map(|(name, tool)| {
113                // If the tool is a FunctionTool, get its definition
114                // Otherwise, create a basic definition
115                if let Some(func_tool) = (tool as &dyn std::any::Any).downcast_ref::<FunctionTool>()
116                {
117                    func_tool.definition()
118                } else {
119                    Tool {
120                        name: name.clone(),
121                        description: format!("Tool: {}", name),
122                        parameters: serde_json::json!({}),
123                    }
124                }
125            })
126            .collect()
127    }
128
129    /// Execute a tool
130    pub async fn execute(
131        &self,
132        name: &str,
133        arguments: &serde_json::Value,
134    ) -> Result<serde_json::Value, AiError> {
135        let tool = self
136            .tools
137            .get(name)
138            .ok_or_else(|| AiError::plugin("ToolUsePlugin", format!("Tool {} not found", name)))?;
139
140        tool.execute(name, arguments).await
141    }
142}
143
144impl Default for ToolRegistry {
145    fn default() -> Self {
146        Self::new()
147    }
148}
149
/// Tool use plugin configuration.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolUsePluginConfig {
    /// Whether tool calls found in a result should be executed automatically.
    ///
    /// When `false`, results are passed through untouched. Automatic
    /// execution is not implemented yet in this module, so `true` currently
    /// only causes a debug trace when tool calls are seen.
    pub auto_execute: bool,
    /// Maximum number of tool execution rounds.
    ///
    /// Reserved for automatic execution; not read anywhere in this module at
    /// present.
    pub max_rounds: usize,
}

impl Default for ToolUsePluginConfig {
    /// Defaults: `auto_execute = true`, `max_rounds = 3`.
    fn default() -> Self {
        Self {
            auto_execute: true,
            max_rounds: 3,
        }
    }
}
167
168/// Tool use plugin
169pub struct ToolUsePlugin {
170    registry: Arc<ToolRegistry>,
171    config: ToolUsePluginConfig,
172}
173
174impl std::fmt::Debug for ToolUsePlugin {
175    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
176        f.debug_struct("ToolUsePlugin")
177            .field("config", &self.config)
178            .field("tool_count", &self.registry.tools.len())
179            .finish()
180    }
181}
182
183impl ToolUsePlugin {
184    /// Create a new tool use plugin
185    pub fn new(registry: Arc<ToolRegistry>) -> Self {
186        Self {
187            registry,
188            config: ToolUsePluginConfig::default(),
189        }
190    }
191
192    /// Create with custom configuration
193    pub fn with_config(registry: Arc<ToolRegistry>, config: ToolUsePluginConfig) -> Self {
194        Self { registry, config }
195    }
196
197    /// Add tools to request parameters
198    fn add_tools_to_params(&self, mut params: TextParams) -> TextParams {
199        let tools = self.registry.definitions();
200        if !tools.is_empty() {
201            params.tools = Some(tools);
202        }
203        params
204    }
205
206    /// Process tool calls in the result
207    async fn process_tool_calls(&self, result: TextResult) -> Result<TextResult, AiError> {
208        // Check if result contains tool calls
209        if result.finish_reason != FinishReason::ToolCalls {
210            return Ok(result);
211        }
212
213        if !self.config.auto_execute {
214            return Ok(result);
215        }
216
217        // Extract tool calls
218        let tool_calls = result.tool_calls.as_ref();
219        if tool_calls.is_none() || tool_calls.unwrap().is_empty() {
220            return Ok(result);
221        }
222
223        // Execute each tool call
224        // Note: In a real implementation, this would be more sophisticated
225        // and might involve multiple rounds of execution
226        tracing::debug!("Processing tool calls (auto_execute=true)");
227
228        Ok(result)
229    }
230}
231
232#[async_trait]
233impl Plugin for ToolUsePlugin {
234    fn name(&self) -> &str {
235        "tool_use"
236    }
237
238    fn enforce(&self) -> PluginPhase {
239        PluginPhase::Pre
240    }
241
242    async fn transform_params(
243        &self,
244        params: TextParams,
245        _ctx: &RequestContext,
246    ) -> Result<TextParams, AiError> {
247        Ok(self.add_tools_to_params(params))
248    }
249
250    async fn transform_result(
251        &self,
252        result: TextResult,
253        _ctx: &RequestContext,
254    ) -> Result<TextResult, AiError> {
255        self.process_tool_calls(result).await
256    }
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a tool named `name` whose executor echoes its arguments back.
    fn echo_tool(name: &str, description: &str) -> FunctionTool {
        FunctionTool::new(
            name,
            description,
            serde_json::json!({"type": "object"}),
            |args: serde_json::Value| async move { Ok(args) },
        )
    }

    #[tokio::test]
    async fn test_function_tool() {
        let tool = echo_tool("test", "A test tool");

        let result = tool
            .execute("test", &serde_json::json!({"key": "value"}))
            .await
            .unwrap();

        assert_eq!(result, serde_json::json!({"key": "value"}));
    }

    #[tokio::test]
    async fn test_function_tool_rejects_other_name() {
        let tool = echo_tool("test", "A test tool");

        let result = tool.execute("other", &serde_json::json!({})).await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_tool_registry() {
        let mut registry = ToolRegistry::new();

        let tool = Arc::new(echo_tool("add", "Add two numbers"));
        registry.register("add", tool);

        let definitions = registry.definitions();
        assert_eq!(definitions.len(), 1);
        assert_eq!(definitions[0].name, "add");
    }

    #[tokio::test]
    async fn test_tool_registry_execute_dispatches_by_name() {
        let mut registry = ToolRegistry::new();
        let tool = Arc::new(echo_tool("echo", "Echo the arguments"));
        registry.register("echo", tool);

        let args = serde_json::json!({"x": 1});
        assert_eq!(registry.execute("echo", &args).await.unwrap(), args);
        assert!(registry.execute("missing", &args).await.is_err());
    }
}