coro_core/tools/
base.rs

1//! Base tool traits and structures
2
3use crate::error::{Result, ToolError};
4use async_trait::async_trait;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use uuid::Uuid;
8
9/// Trait for all tools
10#[async_trait]
11pub trait Tool: Send + Sync {
12    /// Get the name of the tool
13    fn name(&self) -> &str;
14
15    /// Get the description of the tool
16    fn description(&self) -> &str;
17
18    /// Get the JSON schema for the tool's parameters
19    fn parameters_schema(&self) -> serde_json::Value;
20
21    /// Execute the tool with the given parameters
22    async fn execute(&self, call: ToolCall) -> Result<ToolResult>;
23
24    /// Check if the tool requires special permissions
25    fn requires_confirmation(&self) -> bool {
26        false
27    }
28
29    /// Get examples of how to use this tool
30    fn examples(&self) -> Vec<ToolExample> {
31        Vec::new()
32    }
33}
34
35/// A call to a tool
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct ToolCall {
38    /// Unique identifier for this tool call
39    pub id: String,
40
41    /// Name of the tool to call
42    pub name: String,
43
44    /// Parameters to pass to the tool
45    pub parameters: serde_json::Value,
46
47    /// Optional metadata
48    pub metadata: Option<HashMap<String, serde_json::Value>>,
49}
50
51/// Result of a tool execution
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct ToolResult {
54    /// ID of the tool call this is a result for
55    pub tool_call_id: String,
56
57    /// Whether the execution was successful
58    pub success: bool,
59
60    /// Result content
61    pub content: String,
62
63    /// Optional structured data
64    pub data: Option<serde_json::Value>,
65
66    /// Execution duration in milliseconds
67    pub duration_ms: Option<u64>,
68
69    /// Optional metadata
70    pub metadata: Option<HashMap<String, serde_json::Value>>,
71}
72
73/// Example usage of a tool
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct ToolExample {
76    /// Description of what this example does
77    pub description: String,
78
79    /// Example parameters
80    pub parameters: serde_json::Value,
81
82    /// Expected result description
83    pub expected_result: String,
84}
85
86/// Tool executor that manages tool execution
87pub struct ToolExecutor {
88    tools: HashMap<String, Box<dyn Tool>>,
89}
90
91impl ToolCall {
92    /// Create a new tool call
93    pub fn new<S: Into<String>>(name: S, parameters: serde_json::Value) -> Self {
94        Self {
95            id: Uuid::new_v4().to_string(),
96            name: name.into(),
97            parameters,
98            metadata: None,
99        }
100    }
101
102    /// Get a parameter value by key
103    pub fn get_parameter<T>(&self, key: &str) -> Result<T>
104    where
105        T: for<'de> Deserialize<'de>,
106    {
107        let value = self
108            .parameters
109            .get(key)
110            .ok_or_else(|| ToolError::InvalidParameters {
111                message: format!("Missing parameter: {}", key),
112            })?;
113
114        serde_json::from_value(value.clone()).map_err(|_| {
115            ToolError::InvalidParameters {
116                message: format!("Invalid parameter type for: {}", key),
117            }
118            .into()
119        })
120    }
121
122    /// Get a parameter value by key with a default
123    pub fn get_parameter_or<T>(&self, key: &str, default: T) -> T
124    where
125        T: for<'de> Deserialize<'de> + Clone,
126    {
127        self.get_parameter(key).unwrap_or(default)
128    }
129}
130
131impl ToolResult {
132    /// Create a successful result
133    pub fn success<S: Into<String>>(tool_call_id: S, content: S) -> Self {
134        Self {
135            tool_call_id: tool_call_id.into(),
136            success: true,
137            content: content.into(),
138            data: None,
139            duration_ms: None,
140            metadata: None,
141        }
142    }
143
144    /// Create an error result
145    pub fn error<S: Into<String>>(tool_call_id: S, error: S) -> Self {
146        Self {
147            tool_call_id: tool_call_id.into(),
148            success: false,
149            content: format!("Error: {}", error.into()),
150            data: None,
151            duration_ms: None,
152            metadata: None,
153        }
154    }
155
156    /// Set structured data
157    pub fn with_data(mut self, data: serde_json::Value) -> Self {
158        self.data = Some(data);
159        self
160    }
161
162    /// Set execution duration
163    pub fn with_duration(mut self, duration_ms: u64) -> Self {
164        self.duration_ms = Some(duration_ms);
165        self
166    }
167
168    /// Set metadata
169    pub fn with_metadata(mut self, metadata: HashMap<String, serde_json::Value>) -> Self {
170        self.metadata = Some(metadata);
171        self
172    }
173}
174
175impl ToolExecutor {
176    /// Create a new tool executor
177    pub fn new() -> Self {
178        Self {
179            tools: HashMap::new(),
180        }
181    }
182
183    /// Register a tool
184    pub fn register_tool(&mut self, tool: Box<dyn Tool>) {
185        self.tools.insert(tool.name().to_string(), tool);
186    }
187
188    /// Get a tool by name
189    pub fn get_tool(&self, name: &str) -> Option<&dyn Tool> {
190        self.tools.get(name).map(|t| t.as_ref())
191    }
192
193    /// List all available tools
194    pub fn list_tools(&self) -> Vec<&str> {
195        self.tools.keys().map(|s| s.as_str()).collect()
196    }
197
198    /// Execute a tool call
199    pub async fn execute(&self, call: ToolCall) -> Result<ToolResult> {
200        let tool = self
201            .get_tool(&call.name)
202            .ok_or_else(|| ToolError::NotFound {
203                name: call.name.clone(),
204            })?;
205
206        let start_time = std::time::Instant::now();
207        let call_id = call.id.clone();
208        let result = tool.execute(call).await;
209        let duration = start_time.elapsed().as_millis() as u64;
210
211        match result {
212            Ok(mut result) => {
213                result.duration_ms = Some(duration);
214                Ok(result)
215            }
216            Err(e) => Ok(ToolResult::error(&call_id, &e.to_string()).with_duration(duration)),
217        }
218    }
219
220    /// Get tool definitions for LLM function calling
221    pub fn get_tool_definitions(&self) -> Vec<crate::llm::ToolDefinition> {
222        self.tools
223            .values()
224            .map(|tool| crate::llm::ToolDefinition {
225                tool_type: "function".to_string(),
226                function: crate::llm::FunctionDefinition {
227                    name: tool.name().to_string(),
228                    description: tool.description().to_string(),
229                    parameters: tool.parameters_schema(),
230                },
231            })
232            .collect()
233    }
234}
235
236impl Default for ToolExecutor {
237    fn default() -> Self {
238        Self::new()
239    }
240}