1use crate::error::{Result, ToolError};
4use async_trait::async_trait;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use uuid::Uuid;
8
9#[async_trait]
11pub trait Tool: Send + Sync {
12 fn name(&self) -> &str;
14
15 fn description(&self) -> &str;
17
18 fn parameters_schema(&self) -> serde_json::Value;
20
21 async fn execute(&self, call: ToolCall) -> Result<ToolResult>;
23
24 fn requires_confirmation(&self) -> bool {
26 false
27 }
28
29 fn examples(&self) -> Vec<ToolExample> {
31 Vec::new()
32 }
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct ToolCall {
38 pub id: String,
40
41 pub name: String,
43
44 pub parameters: serde_json::Value,
46
47 pub metadata: Option<HashMap<String, serde_json::Value>>,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct ToolResult {
54 pub tool_call_id: String,
56
57 pub success: bool,
59
60 pub content: String,
62
63 pub data: Option<serde_json::Value>,
65
66 pub duration_ms: Option<u64>,
68
69 pub metadata: Option<HashMap<String, serde_json::Value>>,
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct ToolExample {
76 pub description: String,
78
79 pub parameters: serde_json::Value,
81
82 pub expected_result: String,
84}
85
86pub struct ToolExecutor {
88 tools: HashMap<String, Box<dyn Tool>>,
89}
90
91impl ToolCall {
92 pub fn new<S: Into<String>>(name: S, parameters: serde_json::Value) -> Self {
94 Self {
95 id: Uuid::new_v4().to_string(),
96 name: name.into(),
97 parameters,
98 metadata: None,
99 }
100 }
101
102 pub fn get_parameter<T>(&self, key: &str) -> Result<T>
104 where
105 T: for<'de> Deserialize<'de>,
106 {
107 let value = self
108 .parameters
109 .get(key)
110 .ok_or_else(|| ToolError::InvalidParameters {
111 message: format!("Missing parameter: {}", key),
112 })?;
113
114 serde_json::from_value(value.clone()).map_err(|_| {
115 ToolError::InvalidParameters {
116 message: format!("Invalid parameter type for: {}", key),
117 }
118 .into()
119 })
120 }
121
122 pub fn get_parameter_or<T>(&self, key: &str, default: T) -> T
124 where
125 T: for<'de> Deserialize<'de> + Clone,
126 {
127 self.get_parameter(key).unwrap_or(default)
128 }
129}
130
131impl ToolResult {
132 pub fn success<S: Into<String>>(tool_call_id: S, content: S) -> Self {
134 Self {
135 tool_call_id: tool_call_id.into(),
136 success: true,
137 content: content.into(),
138 data: None,
139 duration_ms: None,
140 metadata: None,
141 }
142 }
143
144 pub fn error<S: Into<String>>(tool_call_id: S, error: S) -> Self {
146 Self {
147 tool_call_id: tool_call_id.into(),
148 success: false,
149 content: format!("Error: {}", error.into()),
150 data: None,
151 duration_ms: None,
152 metadata: None,
153 }
154 }
155
156 pub fn with_data(mut self, data: serde_json::Value) -> Self {
158 self.data = Some(data);
159 self
160 }
161
162 pub fn with_duration(mut self, duration_ms: u64) -> Self {
164 self.duration_ms = Some(duration_ms);
165 self
166 }
167
168 pub fn with_metadata(mut self, metadata: HashMap<String, serde_json::Value>) -> Self {
170 self.metadata = Some(metadata);
171 self
172 }
173}
174
175impl ToolExecutor {
176 pub fn new() -> Self {
178 Self {
179 tools: HashMap::new(),
180 }
181 }
182
183 pub fn register_tool(&mut self, tool: Box<dyn Tool>) {
185 self.tools.insert(tool.name().to_string(), tool);
186 }
187
188 pub fn get_tool(&self, name: &str) -> Option<&dyn Tool> {
190 self.tools.get(name).map(|t| t.as_ref())
191 }
192
193 pub fn list_tools(&self) -> Vec<&str> {
195 self.tools.keys().map(|s| s.as_str()).collect()
196 }
197
198 pub async fn execute(&self, call: ToolCall) -> Result<ToolResult> {
200 let tool = self
201 .get_tool(&call.name)
202 .ok_or_else(|| ToolError::NotFound {
203 name: call.name.clone(),
204 })?;
205
206 let start_time = std::time::Instant::now();
207 let call_id = call.id.clone();
208 let result = tool.execute(call).await;
209 let duration = start_time.elapsed().as_millis() as u64;
210
211 match result {
212 Ok(mut result) => {
213 result.duration_ms = Some(duration);
214 Ok(result)
215 }
216 Err(e) => Ok(ToolResult::error(&call_id, &e.to_string()).with_duration(duration)),
217 }
218 }
219
220 pub fn get_tool_definitions(&self) -> Vec<crate::llm::ToolDefinition> {
222 self.tools
223 .values()
224 .map(|tool| crate::llm::ToolDefinition {
225 tool_type: "function".to_string(),
226 function: crate::llm::FunctionDefinition {
227 name: tool.name().to_string(),
228 description: tool.description().to_string(),
229 parameters: tool.parameters_schema(),
230 },
231 })
232 .collect()
233 }
234}
235
236impl Default for ToolExecutor {
237 fn default() -> Self {
238 Self::new()
239 }
240}