Skip to main content

a3s_code_core/tools/
mod.rs

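//! Extensible Tool System
//!
//! Provides a trait-based abstraction for tools ([`Tool`]), a [`ToolRegistry`] that
//! holds them, and a [`ToolExecutor`] that runs them behind guard-policy checks and
//! file-history snapshots.
//!
//! ## Architecture
//!
//! ```text
//! ToolExecutor
//!   └── ToolRegistry
//!         └── builtin tools (bash, read, write, edit, grep, glob, ls, patch, web_fetch, web_search)
//! ```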
11
12mod builtin;
13mod registry;
14mod process;
15pub mod task;
16mod types;
17
18pub use registry::ToolRegistry;
19pub use task::{
20    parallel_task_params_schema, task_params_schema, ParallelTaskParams, ParallelTaskTool,
21    TaskExecutor, TaskParams, TaskResult,
22};
23pub use types::{Tool, ToolContext, ToolEventSender, ToolOutput, ToolStreamEvent};
24
25use crate::file_history::{self, FileHistory};
26use crate::llm::ToolDefinition;
27use crate::permissions::{PermissionChecker, PermissionDecision};
28use anyhow::Result;
29use serde::{Deserialize, Serialize};
30use std::path::PathBuf;
31use std::sync::Arc;
32
/// Largest tool output, in bytes, kept before truncation kicks in (100 KiB).
pub const MAX_OUTPUT_SIZE: usize = 102_400;

/// Largest number of lines read from a file.
pub const MAX_READ_LINES: usize = 2_000;

/// Longest line kept before it is truncated.
pub const MAX_LINE_LENGTH: usize = 2_000;
41
42/// Tool execution result (legacy format for backward compatibility)
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct ToolResult {
45    pub name: String,
46    pub output: String,
47    pub exit_code: i32,
48    #[serde(skip_serializing_if = "Option::is_none")]
49    pub metadata: Option<serde_json::Value>,
50}
51
52impl ToolResult {
53    pub fn success(name: &str, output: String) -> Self {
54        Self {
55            name: name.to_string(),
56            output,
57            exit_code: 0,
58            metadata: None,
59        }
60    }
61
62    pub fn error(name: &str, message: String) -> Self {
63        Self {
64            name: name.to_string(),
65            output: message,
66            exit_code: 1,
67            metadata: None,
68        }
69    }
70}
71
72impl From<ToolOutput> for ToolResult {
73    fn from(output: ToolOutput) -> Self {
74        Self {
75            name: String::new(),
76            output: output.content,
77            exit_code: if output.success { 0 } else { 1 },
78            metadata: output.metadata,
79        }
80    }
81}
82
83/// Tool executor with workspace sandboxing
84///
85/// This is the main entry point for tool execution. It wraps the ToolRegistry
86/// and provides backward-compatible API. Includes file version history tracking
87/// for write/edit/patch operations.
88///
89/// Defense-in-depth: An optional permission policy can be set to block
90/// denied tools even if the caller bypasses the agent loop's authorization.
91pub struct ToolExecutor {
92    workspace: PathBuf,
93    registry: ToolRegistry,
94    file_history: Arc<FileHistory>,
95    guard_policy: Option<Arc<dyn PermissionChecker>>,
96}
97
98impl ToolExecutor {
99    pub fn new(workspace: String) -> Self {
100        let workspace_path = PathBuf::from(&workspace);
101        let registry = ToolRegistry::new(workspace_path.clone());
102
103        // Register native Rust built-in tools
104        builtin::register_builtins(&registry);
105
106        Self {
107            workspace: workspace_path,
108            registry,
109            file_history: Arc::new(FileHistory::new(500)),
110            guard_policy: None,
111        }
112    }
113
114    pub fn set_guard_policy(&mut self, policy: Arc<dyn PermissionChecker>) {
115        self.guard_policy = Some(policy);
116    }
117
118    fn check_guard(&self, name: &str, args: &serde_json::Value) -> Result<()> {
119        if let Some(checker) = &self.guard_policy {
120            if checker.check(name, args) == PermissionDecision::Deny {
121                anyhow::bail!(
122                    "Defense-in-depth: Tool '{}' is blocked by guard permission policy",
123                    name
124                );
125            }
126        }
127        Ok(())
128    }
129
130    fn check_workspace_boundary(
131        name: &str,
132        args: &serde_json::Value,
133        ctx: &ToolContext,
134    ) -> Result<()> {
135        let path_field = match name {
136            "read" | "write" | "edit" | "patch" => Some("file_path"),
137            "ls" | "grep" | "glob" => Some("path"),
138            _ => None,
139        };
140
141        if let Some(field) = path_field {
142            if let Some(path_str) = args.get(field).and_then(|v| v.as_str()) {
143                let target = if std::path::Path::new(path_str).is_absolute() {
144                    std::path::PathBuf::from(path_str)
145                } else {
146                    ctx.workspace.join(path_str)
147                };
148
149                if let (Ok(canonical_target), Ok(canonical_workspace)) = (
150                    target.canonicalize().or_else(|_| {
151                        target
152                            .parent()
153                            .and_then(|p| p.canonicalize().ok())
154                            .ok_or_else(|| {
155                                std::io::Error::new(
156                                    std::io::ErrorKind::NotFound,
157                                    "parent not found",
158                                )
159                            })
160                    }),
161                    ctx.workspace.canonicalize(),
162                ) {
163                    if !canonical_target.starts_with(&canonical_workspace) {
164                        anyhow::bail!(
165                            "Workspace boundary violation: tool '{}' path '{}' escapes workspace '{}'",
166                            name,
167                            path_str,
168                            ctx.workspace.display()
169                        );
170                    }
171                }
172            }
173        }
174
175        Ok(())
176    }
177
178    pub fn workspace(&self) -> &PathBuf {
179        &self.workspace
180    }
181
182    pub fn registry(&self) -> &ToolRegistry {
183        &self.registry
184    }
185
186    pub fn register_dynamic_tool(&self, tool: Arc<dyn Tool>) {
187        self.registry.register(tool);
188    }
189
190    pub fn unregister_dynamic_tool(&self, name: &str) {
191        self.registry.unregister(name);
192    }
193
194    pub fn file_history(&self) -> &Arc<FileHistory> {
195        &self.file_history
196    }
197
198    fn capture_snapshot(&self, name: &str, args: &serde_json::Value) {
199        if let Some(file_path) = file_history::extract_file_path(name, args) {
200            let resolved = self.workspace.join(&file_path);
201            let path_to_read = if resolved.exists() {
202                resolved
203            } else if std::path::Path::new(&file_path).exists() {
204                std::path::PathBuf::from(&file_path)
205            } else {
206                self.file_history.save_snapshot(&file_path, "", name);
207                return;
208            };
209
210            match std::fs::read_to_string(&path_to_read) {
211                Ok(content) => {
212                    self.file_history.save_snapshot(&file_path, &content, name);
213                    tracing::debug!(
214                        "Captured file snapshot for {} before {} (version {})",
215                        file_path,
216                        name,
217                        self.file_history.list_versions(&file_path).len() - 1,
218                    );
219                }
220                Err(e) => {
221                    tracing::warn!("Failed to capture snapshot for {}: {}", file_path, e);
222                }
223            }
224        }
225    }
226
227    pub async fn execute(&self, name: &str, args: &serde_json::Value) -> Result<ToolResult> {
228        self.check_guard(name, args)?;
229        tracing::info!("Executing tool: {} with args: {}", name, args);
230        self.capture_snapshot(name, args);
231        let result = self.registry.execute(name, args).await;
232        match &result {
233            Ok(r) => tracing::info!("Tool {} completed with exit_code={}", name, r.exit_code),
234            Err(e) => tracing::error!("Tool {} failed: {}", name, e),
235        }
236        result
237    }
238
239    pub async fn execute_with_context(
240        &self,
241        name: &str,
242        args: &serde_json::Value,
243        ctx: &ToolContext,
244    ) -> Result<ToolResult> {
245        self.check_guard(name, args)?;
246        Self::check_workspace_boundary(name, args, ctx)?;
247        tracing::info!("Executing tool: {} with args: {}", name, args);
248        self.capture_snapshot(name, args);
249        let result = self.registry.execute_with_context(name, args, ctx).await;
250        match &result {
251            Ok(r) => tracing::info!("Tool {} completed with exit_code={}", name, r.exit_code),
252            Err(e) => tracing::error!("Tool {} failed: {}", name, e),
253        }
254        result
255    }
256
257    pub fn definitions(&self) -> Vec<ToolDefinition> {
258        self.registry.definitions()
259    }
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    /// Tools that `builtin::register_builtins` is expected to register.
    const BUILTIN_TOOL_NAMES: [&str; 10] = [
        "bash",
        "read",
        "write",
        "edit",
        "grep",
        "glob",
        "ls",
        "patch",
        "web_fetch",
        "web_search",
    ];

    // `ToolExecutor::new` is synchronous, so no Tokio runtime is needed here.
    #[test]
    fn test_tool_executor_creation() {
        let executor = ToolExecutor::new("/tmp".to_string());
        assert_eq!(executor.registry.len(), 10);
    }

    #[tokio::test]
    async fn test_unknown_tool() {
        let executor = ToolExecutor::new("/tmp".to_string());
        let result = executor
            .execute("unknown", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool"));
    }

    #[test]
    fn test_builtin_tools_registered() {
        let executor = ToolExecutor::new("/tmp".to_string());
        let definitions = executor.definitions();

        for &expected in BUILTIN_TOOL_NAMES.iter() {
            assert!(
                definitions.iter().any(|t| t.name == expected),
                "built-in tool '{}' is not registered",
                expected
            );
        }
    }

    #[test]
    fn test_tool_result_success() {
        let result = ToolResult::success("test_tool", "output text".to_string());
        assert_eq!(result.name, "test_tool");
        assert_eq!(result.output, "output text");
        assert_eq!(result.exit_code, 0);
        assert!(result.metadata.is_none());
    }

    #[test]
    fn test_tool_result_error() {
        let result = ToolResult::error("test_tool", "error message".to_string());
        assert_eq!(result.name, "test_tool");
        assert_eq!(result.output, "error message");
        assert_eq!(result.exit_code, 1);
        assert!(result.metadata.is_none());
    }

    #[test]
    fn test_tool_result_from_tool_output_success() {
        let output = ToolOutput {
            content: "success content".to_string(),
            success: true,
            metadata: None,
        };
        let result: ToolResult = output.into();
        assert_eq!(result.output, "success content");
        assert_eq!(result.exit_code, 0);
        assert!(result.metadata.is_none());
    }

    #[test]
    fn test_tool_result_from_tool_output_failure() {
        let output = ToolOutput {
            content: "failure content".to_string(),
            success: false,
            metadata: Some(serde_json::json!({"error": "test"})),
        };
        let result: ToolResult = output.into();
        assert_eq!(result.output, "failure content");
        assert_eq!(result.exit_code, 1);
        assert_eq!(result.metadata, Some(serde_json::json!({"error": "test"})));
    }

    #[test]
    fn test_tool_result_metadata_propagation() {
        let output = ToolOutput::success("content")
            .with_metadata(serde_json::json!({"_load_skill": true, "skill_name": "test"}));
        let result: ToolResult = output.into();
        assert_eq!(result.exit_code, 0);
        let meta = result.metadata.unwrap();
        assert_eq!(meta["_load_skill"], true);
        assert_eq!(meta["skill_name"], "test");
    }

    #[test]
    fn test_tool_executor_workspace() {
        let executor = ToolExecutor::new("/test/workspace".to_string());
        assert_eq!(executor.workspace().to_str().unwrap(), "/test/workspace");
    }

    #[test]
    fn test_tool_executor_registry() {
        let executor = ToolExecutor::new("/tmp".to_string());
        let registry = executor.registry();
        assert_eq!(registry.len(), 10);
    }

    #[test]
    fn test_tool_executor_file_history() {
        let executor = ToolExecutor::new("/tmp".to_string());
        let history = executor.file_history();
        assert_eq!(history.list_versions("nonexistent.txt").len(), 0);
    }

    #[test]
    fn test_max_output_size_constant() {
        assert_eq!(MAX_OUTPUT_SIZE, 100 * 1024);
    }

    #[test]
    fn test_max_read_lines_constant() {
        assert_eq!(MAX_READ_LINES, 2000);
    }

    #[test]
    fn test_max_line_length_constant() {
        assert_eq!(MAX_LINE_LENGTH, 2000);
    }

    #[test]
    fn test_tool_result_clone() {
        let result = ToolResult::success("test", "output".to_string());
        let cloned = result.clone();
        assert_eq!(result.name, cloned.name);
        assert_eq!(result.output, cloned.output);
        assert_eq!(result.exit_code, cloned.exit_code);
        assert_eq!(result.metadata, cloned.metadata);
    }

    #[test]
    fn test_tool_result_debug() {
        let result = ToolResult::success("test", "output".to_string());
        let debug_str = format!("{:?}", result);
        assert!(debug_str.contains("test"));
        assert!(debug_str.contains("output"));
    }
}