1use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use crate::errors::ToolError;
7
8#[async_trait]
10pub trait Tool: Send + Sync {
11 fn name(&self) -> &str;
12 fn description(&self) -> &str;
13 fn parameters(&self) -> Vec<Parameter>;
14 async fn execute(&self, args: &ToolArgs) -> Result<ToolResult, ToolError>;
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct Parameter {
20 pub name: String,
21 pub description: String,
22 pub required: bool,
23 pub parameter_type: ParameterType,
24 pub default_value: Option<serde_json::Value>,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
28#[serde(rename_all = "lowercase")]
29pub enum ParameterType {
30 String,
31 Number,
32 Boolean,
33 Array,
34 Object,
35}
36
37impl Parameter {
38 pub fn required(name: &str, description: &str) -> Self {
39 Self {
40 name: name.to_string(),
41 description: description.to_string(),
42 required: true,
43 parameter_type: ParameterType::String,
44 default_value: None,
45 }
46 }
47
48 pub fn optional(name: &str, description: &str) -> Self {
49 Self {
50 name: name.to_string(),
51 description: description.to_string(),
52 required: false,
53 parameter_type: ParameterType::String,
54 default_value: None,
55 }
56 }
57}
58
59#[derive(Debug, Clone)]
61pub struct ToolArgs {
62 args: HashMap<String, serde_json::Value>,
63}
64
65impl ToolArgs {
66 pub fn from_map(args: HashMap<String, serde_json::Value>) -> Self {
67 Self { args }
68 }
69
70 pub fn get_string(&self, key: &str) -> Result<String, ToolError> {
71 self.args.get(key)
72 .and_then(|v| v.as_str())
73 .map(|s| s.to_string())
74 .ok_or_else(|| ToolError::InvalidParameters(format!("Missing or invalid parameter: {}", key)))
75 }
76
77 pub fn get_string_or(&self, key: &str, default: &str) -> String {
78 self.args.get(key)
79 .and_then(|v| v.as_str())
80 .unwrap_or(default)
81 .to_string()
82 }
83}
84
85#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct ToolResult {
88 pub success: bool,
89 pub content: String,
90 pub summary: String,
91 pub data: Option<serde_json::Value>,
92 pub error: Option<String>,
93}
94
95impl ToolResult {
96 pub fn text(content: String) -> Self {
97 Self {
98 success: true,
99 summary: content.clone(),
100 content,
101 data: None,
102 error: None,
103 }
104 }
105
106 pub fn json(data: serde_json::Value) -> Self {
107 Self {
108 success: true,
109 summary: "Operation completed successfully".to_string(),
110 content: "Operation completed successfully".to_string(),
111 data: Some(data),
112 error: None,
113 }
114 }
115
116 pub fn error(error: String) -> Self {
117 Self {
118 success: false,
119 summary: error.clone(),
120 content: String::new(),
121 data: None,
122 error: Some(error),
123 }
124 }
125}
126
127#[derive(Debug, Clone)]
129pub struct ToolCall {
130 pub name: String,
131 pub args: ToolArgs,
132}
133
134pub struct ToolRegistry {
136 tools: HashMap<String, Box<dyn Tool>>,
137}
138
139impl ToolRegistry {
140 pub fn new() -> Self {
141 Self {
142 tools: HashMap::new(),
143 }
144 }
145
146 pub fn register<T: Tool + 'static>(&mut self, tool: T) {
147 self.tools.insert(tool.name().to_string(), Box::new(tool));
148 }
149
150 pub async fn execute(&self, tool_call: &ToolCall) -> Result<ToolResult, ToolError> {
151 let tool = self.tools.get(&tool_call.name)
152 .ok_or_else(|| ToolError::ToolNotFound(tool_call.name.clone()))?;
153
154 tool.execute(&tool_call.args).await
155 }
156
157 pub fn get_tool_names(&self) -> Vec<String> {
158 self.tools.keys().cloned().collect()
159 }
160
161 pub fn get_tool(&self, name: &str) -> Option<&dyn Tool> {
162 self.tools.get(name).map(|tool| tool.as_ref())
163 }
164
165 pub fn get_all_tools(&self) -> Vec<&dyn Tool> {
166 self.tools.values().map(|tool| tool.as_ref()).collect()
167 }
168}
169
170pub struct ReadFileTool;
174
175#[async_trait]
176impl Tool for ReadFileTool {
177 fn name(&self) -> &str {
178 "read_file"
179 }
180
181 fn description(&self) -> &str {
182 "Read the contents of a file"
183 }
184
185 fn parameters(&self) -> Vec<Parameter> {
186 vec![
187 Parameter::required("path", "File path to read")
188 ]
189 }
190
191 async fn execute(&self, args: &ToolArgs) -> Result<ToolResult, ToolError> {
192 let path = args.get_string("path")?;
193
194 if path.contains("..") || path.starts_with("/") {
196 return Err(ToolError::PermissionDenied("Access to this path is not allowed".to_string()));
197 }
198
199 let content = tokio::fs::read_to_string(path)
200 .await
201 .map_err(|e| ToolError::ExecutionError(e.to_string()))?;
202
203 Ok(ToolResult::text(content))
204 }
205}
206
207pub struct WriteFileTool;
209
210#[async_trait]
211impl Tool for WriteFileTool {
212 fn name(&self) -> &str {
213 "write_file"
214 }
215
216 fn description(&self) -> &str {
217 "Write content to a file"
218 }
219
220 fn parameters(&self) -> Vec<Parameter> {
221 vec![
222 Parameter::required("path", "File path to write"),
223 Parameter::required("content", "Content to write"),
224 ]
225 }
226
227 async fn execute(&self, args: &ToolArgs) -> Result<ToolResult, ToolError> {
228 let path = args.get_string("path")?;
229 let content = args.get_string("content")?;
230
231 if path.contains("..") || path.starts_with("/") {
233 return Err(ToolError::PermissionDenied("Access to this path is not allowed".to_string()));
234 }
235
236 tokio::fs::write(path, content)
237 .await
238 .map_err(|e| ToolError::ExecutionError(e.to_string()))?;
239
240 Ok(ToolResult::text("File written successfully".to_string()))
241 }
242}
243
244pub struct ListFilesTool;
246
247#[async_trait]
248impl Tool for ListFilesTool {
249 fn name(&self) -> &str {
250 "list_files"
251 }
252
253 fn description(&self) -> &str {
254 "List files and directories in a given path"
255 }
256
257 fn parameters(&self) -> Vec<Parameter> {
258 vec![
259 Parameter::required("path", "Directory path to list")
260 ]
261 }
262
263 async fn execute(&self, args: &ToolArgs) -> Result<ToolResult, ToolError> {
264 let path = args.get_string("path")?;
265
266 if path.contains("..") || path.starts_with("/") {
268 return Err(ToolError::PermissionDenied("Access to this path is not allowed".to_string()));
269 }
270
271 let mut entries = Vec::new();
272 let mut dir = tokio::fs::read_dir(path)
273 .await
274 .map_err(|e| ToolError::ExecutionError(e.to_string()))?;
275
276 while let Some(entry) = dir.next_entry().await.map_err(|e| ToolError::ExecutionError(e.to_string()))? {
277 let metadata = std::fs::metadata(entry.path()).ok();
278 entries.push((
279 entry.file_name().to_string_lossy().to_string(),
280 metadata.map(|m| m.is_dir()).unwrap_or(false)
281 ));
282 }
283
284 entries.sort_by(|a, b| {
285 match (a.1, b.1) {
287 (true, false) => std::cmp::Ordering::Less,
288 (false, true) => std::cmp::Ordering::Greater,
289 _ => a.0.cmp(&b.0),
290 }
291 });
292
293 let list_text = entries.iter()
294 .map(|(name, is_dir)| {
295 let prefix = if *is_dir { "DIR " } else { "FILE " };
296 format!("{}{}", prefix, name)
297 })
298 .collect::<Vec<_>>()
299 .join("\n");
300
301 Ok(ToolResult::text(list_text))
302 }
303}
304
305pub struct RunCommandTool;
307
308#[async_trait]
309impl Tool for RunCommandTool {
310 fn name(&self) -> &str {
311 "run_command"
312 }
313
314 fn description(&self) -> &str {
315 "Execute a shell command"
316 }
317
318 fn parameters(&self) -> Vec<Parameter> {
319 vec![
320 Parameter::required("command", "Command to execute"),
321 Parameter::optional("working_dir", "Working directory"),
322 ]
323 }
324
325 async fn execute(&self, args: &ToolArgs) -> Result<ToolResult, ToolError> {
326 let command = args.get_string("command")?;
327 let working_dir = args.get_string_or("working_dir", ".");
328
329 let dangerous_commands = vec![
331 "rm -rf /", "format", "fdisk", "dd if=", "shutdown", "reboot",
332 ];
333
334 for dangerous in &dangerous_commands {
335 if command.contains(dangerous) {
336 return Err(ToolError::PermissionDenied(format!("Command '{}' is not allowed", dangerous)));
337 }
338 }
339
340 let output = tokio::process::Command::new("sh")
341 .arg("-c")
342 .arg(&command)
343 .current_dir(working_dir)
344 .output()
345 .await
346 .map_err(|e| ToolError::ExecutionError(e.to_string()))?;
347
348 if output.status.success() {
349 let stdout = String::from_utf8_lossy(&output.stdout);
350 Ok(ToolResult::text(stdout.to_string()))
351 } else {
352 let stderr = String::from_utf8_lossy(&output.stderr);
353 Ok(ToolResult::error(stderr.to_string()))
354 }
355 }
356}