// ai_agent/tools.rs

//! Tool system for the AI-Native Code Agent.

use std::collections::HashMap;
use std::path::{Component, Path};

use async_trait::async_trait;
use serde::{Deserialize, Serialize};

use crate::errors::ToolError;

/// A capability the agent can invoke by name.
///
/// Tools are stored as `Box<dyn Tool>` in [`ToolRegistry`], so implementors must be
/// `Send + Sync`; `#[async_trait]` is what allows the async `execute` method to be
/// called through that trait object.
#[async_trait]
pub trait Tool: Send + Sync {
    /// Identifier of the tool; [`ToolRegistry::register`] stores the tool under this key.
    fn name(&self) -> &str;
    /// Short human-readable description of what the tool does.
    fn description(&self) -> &str;
    /// Declarations of the parameters the tool reads from its [`ToolArgs`].
    fn parameters(&self) -> Vec<Parameter>;
    /// Runs the tool with the given arguments.
    ///
    /// # Errors
    /// Implementations in this file return `ToolError::InvalidParameters` for missing or
    /// mistyped arguments, `ToolError::PermissionDenied` for refused paths/commands and
    /// `ToolError::ExecutionError` when the underlying I/O fails.
    async fn execute(&self, args: &ToolArgs) -> Result<ToolResult, ToolError>;
}

/// Declaration of one parameter a [`Tool`] accepts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Parameter {
    /// Key under which tools look the value up in [`ToolArgs`].
    pub name: String,
    /// Human-readable explanation of the parameter.
    pub description: String,
    /// Whether the parameter is declared as mandatory.
    pub required: bool,
    /// JSON type the value is expected to have.
    pub parameter_type: ParameterType,
    // NOTE(review): nothing in this file reads `default_value`; the built-in tools apply
    // their own fallbacks (e.g. `working_dir` defaults to "." in `RunCommandTool`).
    /// Value to assume when the parameter is omitted, if any.
    pub default_value: Option<serde_json::Value>,
}

/// JSON type expected for a [`Parameter`] value.
///
/// Because of `rename_all = "lowercase"`, variants (de)serialize as `"string"`,
/// `"number"`, `"boolean"`, `"array"` and `"object"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ParameterType {
    String,
    Number,
    Boolean,
    Array,
    Object,
}

37impl Parameter {
38    pub fn required(name: &str, description: &str) -> Self {
39        Self {
40            name: name.to_string(),
41            description: description.to_string(),
42            required: true,
43            parameter_type: ParameterType::String,
44            default_value: None,
45        }
46    }
47
48    pub fn optional(name: &str, description: &str) -> Self {
49        Self {
50            name: name.to_string(),
51            description: description.to_string(),
52            required: false,
53            parameter_type: ParameterType::String,
54            default_value: None,
55        }
56    }
57}
58
59/// Tool arguments
60#[derive(Debug, Clone)]
61pub struct ToolArgs {
62    args: HashMap<String, serde_json::Value>,
63}
64
65impl ToolArgs {
66    pub fn from_map(args: HashMap<String, serde_json::Value>) -> Self {
67        Self { args }
68    }
69
70    pub fn get_string(&self, key: &str) -> Result<String, ToolError> {
71        self.args.get(key)
72            .and_then(|v| v.as_str())
73            .map(|s| s.to_string())
74            .ok_or_else(|| ToolError::InvalidParameters(format!("Missing or invalid parameter: {}", key)))
75    }
76
77    pub fn get_string_or(&self, key: &str, default: &str) -> String {
78        self.args.get(key)
79            .and_then(|v| v.as_str())
80            .unwrap_or(default)
81            .to_string()
82    }
83}
84
/// Outcome of a tool invocation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
    /// `true` from the `text`/`json` constructors, `false` from `error`.
    pub success: bool,
    /// Main textual output; the `error` constructor leaves it empty.
    pub content: String,
    /// Short description of the outcome; `text` and `error` copy the full message here.
    pub summary: String,
    /// Structured payload; the `json` constructor is the one that sets it.
    pub data: Option<serde_json::Value>,
    /// Error message; the `error` constructor is the one that sets it.
    pub error: Option<String>,
}

95impl ToolResult {
96    pub fn text(content: String) -> Self {
97        Self {
98            success: true,
99            summary: content.clone(),
100            content,
101            data: None,
102            error: None,
103        }
104    }
105
106    pub fn json(data: serde_json::Value) -> Self {
107        Self {
108            success: true,
109            summary: "Operation completed successfully".to_string(),
110            content: "Operation completed successfully".to_string(),
111            data: Some(data),
112            error: None,
113        }
114    }
115
116    pub fn error(error: String) -> Self {
117        Self {
118            success: false,
119            summary: error.clone(),
120            content: String::new(),
121            data: None,
122            error: Some(error),
123        }
124    }
125}
126
/// A request to run the tool registered under `name` with `args`.
#[derive(Debug, Clone)]
pub struct ToolCall {
    /// Registry key of the tool to invoke (see [`ToolRegistry::execute`]).
    pub name: String,
    /// Arguments forwarded unchanged to [`Tool::execute`].
    pub args: ToolArgs,
}

134/// Tool registry
135pub struct ToolRegistry {
136    tools: HashMap<String, Box<dyn Tool>>,
137}
138
139impl ToolRegistry {
140    pub fn new() -> Self {
141        Self {
142            tools: HashMap::new(),
143        }
144    }
145
146    pub fn register<T: Tool + 'static>(&mut self, tool: T) {
147        self.tools.insert(tool.name().to_string(), Box::new(tool));
148    }
149
150    pub async fn execute(&self, tool_call: &ToolCall) -> Result<ToolResult, ToolError> {
151        let tool = self.tools.get(&tool_call.name)
152            .ok_or_else(|| ToolError::ToolNotFound(tool_call.name.clone()))?;
153
154        tool.execute(&tool_call.args).await
155    }
156
157    pub fn get_tool_names(&self) -> Vec<String> {
158        self.tools.keys().cloned().collect()
159    }
160
161    pub fn get_tool(&self, name: &str) -> Option<&dyn Tool> {
162        self.tools.get(name).map(|tool| tool.as_ref())
163    }
164
165    pub fn get_all_tools(&self) -> Vec<&dyn Tool> {
166        self.tools.values().map(|tool| tool.as_ref()).collect()
167    }
168}
169
// Built-in tool implementations

172/// Read file tool
173pub struct ReadFileTool;
174
175#[async_trait]
176impl Tool for ReadFileTool {
177    fn name(&self) -> &str {
178        "read_file"
179    }
180
181    fn description(&self) -> &str {
182        "Read the contents of a file"
183    }
184
185    fn parameters(&self) -> Vec<Parameter> {
186        vec![
187            Parameter::required("path", "File path to read")
188        ]
189    }
190
191    async fn execute(&self, args: &ToolArgs) -> Result<ToolResult, ToolError> {
192        let path = args.get_string("path")?;
193
194        // Safety check
195        if path.contains("..") || path.starts_with("/") {
196            return Err(ToolError::PermissionDenied("Access to this path is not allowed".to_string()));
197        }
198
199        let content = tokio::fs::read_to_string(path)
200            .await
201            .map_err(|e| ToolError::ExecutionError(e.to_string()))?;
202
203        Ok(ToolResult::text(content))
204    }
205}
206
207/// Write file tool
208pub struct WriteFileTool;
209
210#[async_trait]
211impl Tool for WriteFileTool {
212    fn name(&self) -> &str {
213        "write_file"
214    }
215
216    fn description(&self) -> &str {
217        "Write content to a file"
218    }
219
220    fn parameters(&self) -> Vec<Parameter> {
221        vec![
222            Parameter::required("path", "File path to write"),
223            Parameter::required("content", "Content to write"),
224        ]
225    }
226
227    async fn execute(&self, args: &ToolArgs) -> Result<ToolResult, ToolError> {
228        let path = args.get_string("path")?;
229        let content = args.get_string("content")?;
230
231        // Safety check
232        if path.contains("..") || path.starts_with("/") {
233            return Err(ToolError::PermissionDenied("Access to this path is not allowed".to_string()));
234        }
235
236        tokio::fs::write(path, content)
237            .await
238            .map_err(|e| ToolError::ExecutionError(e.to_string()))?;
239
240        Ok(ToolResult::text("File written successfully".to_string()))
241    }
242}
243
244/// List files tool
245pub struct ListFilesTool;
246
247#[async_trait]
248impl Tool for ListFilesTool {
249    fn name(&self) -> &str {
250        "list_files"
251    }
252
253    fn description(&self) -> &str {
254        "List files and directories in a given path"
255    }
256
257    fn parameters(&self) -> Vec<Parameter> {
258        vec![
259            Parameter::required("path", "Directory path to list")
260        ]
261    }
262
263    async fn execute(&self, args: &ToolArgs) -> Result<ToolResult, ToolError> {
264        let path = args.get_string("path")?;
265
266        // Safety check
267        if path.contains("..") || path.starts_with("/") {
268            return Err(ToolError::PermissionDenied("Access to this path is not allowed".to_string()));
269        }
270
271        let mut entries = Vec::new();
272        let mut dir = tokio::fs::read_dir(path)
273            .await
274            .map_err(|e| ToolError::ExecutionError(e.to_string()))?;
275
276        while let Some(entry) = dir.next_entry().await.map_err(|e| ToolError::ExecutionError(e.to_string()))? {
277            let metadata = std::fs::metadata(entry.path()).ok();
278            entries.push((
279                entry.file_name().to_string_lossy().to_string(),
280                metadata.map(|m| m.is_dir()).unwrap_or(false)
281            ));
282        }
283
284        entries.sort_by(|a, b| {
285            // Directories first, then files
286            match (a.1, b.1) {
287                (true, false) => std::cmp::Ordering::Less,
288                (false, true) => std::cmp::Ordering::Greater,
289                _ => a.0.cmp(&b.0),
290            }
291        });
292
293        let list_text = entries.iter()
294            .map(|(name, is_dir)| {
295                let prefix = if *is_dir { "DIR  " } else { "FILE " };
296                format!("{}{}", prefix, name)
297            })
298            .collect::<Vec<_>>()
299            .join("\n");
300
301        Ok(ToolResult::text(list_text))
302    }
303}
304
305/// Run command tool
306pub struct RunCommandTool;
307
308#[async_trait]
309impl Tool for RunCommandTool {
310    fn name(&self) -> &str {
311        "run_command"
312    }
313
314    fn description(&self) -> &str {
315        "Execute a shell command"
316    }
317
318    fn parameters(&self) -> Vec<Parameter> {
319        vec![
320            Parameter::required("command", "Command to execute"),
321            Parameter::optional("working_dir", "Working directory"),
322        ]
323    }
324
325    async fn execute(&self, args: &ToolArgs) -> Result<ToolResult, ToolError> {
326        let command = args.get_string("command")?;
327        let working_dir = args.get_string_or("working_dir", ".");
328
329        // Safety checks for dangerous commands
330        let dangerous_commands = vec![
331            "rm -rf /", "format", "fdisk", "dd if=", "shutdown", "reboot",
332        ];
333
334        for dangerous in &dangerous_commands {
335            if command.contains(dangerous) {
336                return Err(ToolError::PermissionDenied(format!("Command '{}' is not allowed", dangerous)));
337            }
338        }
339
340        let output = tokio::process::Command::new("sh")
341            .arg("-c")
342            .arg(&command)
343            .current_dir(working_dir)
344            .output()
345            .await
346            .map_err(|e| ToolError::ExecutionError(e.to_string()))?;
347
348        if output.status.success() {
349            let stdout = String::from_utf8_lossy(&output.stdout);
350            Ok(ToolResult::text(stdout.to_string()))
351        } else {
352            let stderr = String::from_utf8_lossy(&output.stderr);
353            Ok(ToolResult::error(stderr.to_string()))
354        }
355    }
356}