Skip to main content

astrid_tools/
lib.rs

1#![deny(unsafe_code)]
2#![warn(missing_docs)]
3#![deny(clippy::all)]
4#![warn(unreachable_pub)]
//! Built-in coding tools for the Astrid agent runtime.
//!
//! Provides 8 tools as direct Rust function calls (not MCP) for the hot-path
//! coding operations: read, write, and edit files; search with glob and grep;
//! list directories; execute shell commands; and delegate tasks to sub-agents.
9
10mod bash;
11mod edit_file;
12mod glob;
13mod grep;
14mod instructions;
15mod list_directory;
16mod read_file;
17mod subagent_spawner;
18mod system_prompt;
19mod task;
20mod write_file;
21
22pub use bash::BashTool;
23pub use edit_file::EditFileTool;
24pub use glob::GlobTool;
25pub use grep::GrepTool;
26pub use instructions::load_project_instructions;
27pub use list_directory::ListDirectoryTool;
28pub use read_file::ReadFileTool;
29pub use subagent_spawner::{SubAgentRequest, SubAgentResult, SubAgentSpawner};
30pub use system_prompt::build_system_prompt;
31pub use task::TaskTool;
32pub use write_file::WriteFileTool;
33
34use astrid_llm::LlmToolDefinition;
35use serde_json::Value;
36use std::collections::HashMap;
37use std::path::PathBuf;
38use std::sync::Arc;
39use tokio::sync::RwLock;
40
/// Upper bound, in characters, on tool output handed back to the LLM;
/// anything longer is cut by [`truncate_output`].
const MAX_OUTPUT_CHARS: usize = 30_000;
43
44/// A built-in tool that executes directly in-process.
45#[async_trait::async_trait]
46pub trait BuiltinTool: Send + Sync {
47    /// Tool name (no colons — distinguishes from MCP "server:tool" format).
48    fn name(&self) -> &'static str;
49
50    /// Human-readable description for the LLM.
51    fn description(&self) -> &'static str;
52
53    /// JSON schema for tool input parameters.
54    fn input_schema(&self) -> Value;
55
56    /// Execute the tool with the given arguments.
57    async fn execute(&self, args: Value, ctx: &ToolContext) -> ToolResult;
58}
59
60/// Shared context available to all built-in tools.
61pub struct ToolContext {
62    /// Workspace root directory.
63    pub workspace_root: PathBuf,
64    /// Current working directory (persists across bash invocations).
65    pub cwd: Arc<RwLock<PathBuf>>,
66    /// Sub-agent spawner (set by runtime before each turn, cleared after).
67    subagent_spawner: RwLock<Option<Arc<dyn SubAgentSpawner>>>,
68}
69
70impl ToolContext {
71    /// Create a new tool context.
72    #[must_use]
73    pub fn new(workspace_root: PathBuf) -> Self {
74        let cwd = Arc::new(RwLock::new(workspace_root.clone()));
75        Self {
76            workspace_root,
77            cwd,
78            subagent_spawner: RwLock::new(None),
79        }
80    }
81
82    /// Create a per-turn tool context that shares the `cwd` with other turns
83    /// but has its own independent spawner slot.
84    ///
85    /// This prevents concurrent sessions from racing on the spawner field
86    /// while still sharing the working directory state.
87    #[must_use]
88    pub fn with_shared_cwd(workspace_root: PathBuf, cwd: Arc<RwLock<PathBuf>>) -> Self {
89        Self {
90            workspace_root,
91            cwd,
92            subagent_spawner: RwLock::new(None),
93        }
94    }
95
96    /// Set the sub-agent spawner (called by runtime at turn start).
97    pub async fn set_subagent_spawner(&self, spawner: Option<Arc<dyn SubAgentSpawner>>) {
98        *self.subagent_spawner.write().await = spawner;
99    }
100
101    /// Get the sub-agent spawner (called by `TaskTool`).
102    pub async fn subagent_spawner(&self) -> Option<Arc<dyn SubAgentSpawner>> {
103        self.subagent_spawner.read().await.clone()
104    }
105}
106
107/// Tool execution errors.
108#[derive(Debug, thiserror::Error)]
109pub enum ToolError {
110    /// I/O error.
111    #[error("I/O error: {0}")]
112    Io(#[from] std::io::Error),
113
114    /// Invalid arguments.
115    #[error("Invalid arguments: {0}")]
116    InvalidArguments(String),
117
118    /// Execution failed.
119    #[error("Execution failed: {0}")]
120    ExecutionFailed(String),
121
122    /// Path not found.
123    #[error("Path not found: {0}")]
124    PathNotFound(String),
125
126    /// Timeout.
127    #[error("Timeout after {0}ms")]
128    Timeout(u64),
129
130    /// Other error.
131    #[error("{0}")]
132    Other(String),
133}
134
135/// Result type for tool execution.
136pub type ToolResult = Result<String, ToolError>;
137
138/// Registry of built-in tools for lookup and LLM definition export.
139pub struct ToolRegistry {
140    tools: HashMap<String, Box<dyn BuiltinTool>>,
141}
142
143impl ToolRegistry {
144    /// Create an empty registry.
145    #[must_use]
146    pub fn new() -> Self {
147        Self {
148            tools: HashMap::new(),
149        }
150    }
151
152    /// Create a registry with all default tools registered.
153    #[must_use]
154    pub fn with_defaults() -> Self {
155        let mut registry = Self::new();
156        registry.register(Box::new(ReadFileTool));
157        registry.register(Box::new(WriteFileTool));
158        registry.register(Box::new(EditFileTool));
159        registry.register(Box::new(GlobTool));
160        registry.register(Box::new(GrepTool));
161        registry.register(Box::new(BashTool));
162        registry.register(Box::new(ListDirectoryTool));
163        registry.register(Box::new(TaskTool));
164        registry
165    }
166
167    /// Register a tool.
168    pub fn register(&mut self, tool: Box<dyn BuiltinTool>) {
169        self.tools.insert(tool.name().to_string(), tool);
170    }
171
172    /// Get a tool by name.
173    #[must_use]
174    pub fn get(&self, name: &str) -> Option<&dyn BuiltinTool> {
175        self.tools.get(name).map(AsRef::as_ref)
176    }
177
178    /// Check if a name refers to a built-in tool (no colon = built-in).
179    #[must_use]
180    pub fn is_builtin(name: &str) -> bool {
181        !name.contains(':')
182    }
183
184    /// Export all tool definitions for the LLM.
185    #[must_use]
186    pub fn all_definitions(&self) -> Vec<LlmToolDefinition> {
187        self.tools
188            .values()
189            .map(|t| {
190                LlmToolDefinition::new(t.name())
191                    .with_description(t.description())
192                    .with_schema(t.input_schema())
193            })
194            .collect()
195    }
196}
197
198impl Default for ToolRegistry {
199    fn default() -> Self {
200        Self::new()
201    }
202}
203
204/// Truncate output to stay within LLM context limits.
205///
206/// If `output` exceeds [`MAX_OUTPUT_CHARS`], it is truncated and a notice is appended.
207#[must_use]
208pub fn truncate_output(output: String) -> String {
209    if output.len() <= MAX_OUTPUT_CHARS {
210        return output;
211    }
212    let mut truncated = output[..MAX_OUTPUT_CHARS].to_string();
213    truncated.push_str("\n\n... (output truncated — exceeded 30000 character limit)");
214    truncated
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_builtin() {
        assert!(ToolRegistry::is_builtin("read_file"));
        assert!(ToolRegistry::is_builtin("bash"));
        assert!(!ToolRegistry::is_builtin("filesystem:read_file"));
    }

    #[test]
    fn test_registry_with_defaults() {
        let registry = ToolRegistry::with_defaults();
        assert!(registry.get("read_file").is_some());
        assert!(registry.get("write_file").is_some());
        assert!(registry.get("edit_file").is_some());
        assert!(registry.get("glob").is_some());
        assert!(registry.get("grep").is_some());
        assert!(registry.get("bash").is_some());
        assert!(registry.get("list_directory").is_some());
        assert!(registry.get("task").is_some());
        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn test_all_definitions() {
        let registry = ToolRegistry::with_defaults();
        let defs = registry.all_definitions();
        assert_eq!(defs.len(), 8);
        for def in &defs {
            assert!(!def.name.contains(':'));
            assert!(def.description.is_some());
        }
    }

    #[test]
    fn test_truncate_output_small() {
        let small = "hello".to_string();
        assert_eq!(truncate_output(small.clone()), small);
    }

    #[test]
    fn test_truncate_output_large() {
        let large = "x".repeat(40_000);
        let result = truncate_output(large);
        assert!(result.len() < 40_000);
        assert!(result.contains("output truncated"));
    }

    // Boundary: output exactly at the limit must pass through untouched.
    #[test]
    fn test_truncate_output_exact_limit_untouched() {
        let exact = "x".repeat(MAX_OUTPUT_CHARS);
        assert_eq!(truncate_output(exact.clone()), exact);
    }

    // Boundary: one character over the limit keeps the full leading content
    // and appends the notice.
    #[test]
    fn test_truncate_output_one_over_keeps_prefix() {
        let over = "x".repeat(MAX_OUTPUT_CHARS + 1);
        let result = truncate_output(over);
        assert!(result.starts_with(&"x".repeat(MAX_OUTPUT_CHARS)));
        assert!(!result.starts_with(&"x".repeat(MAX_OUTPUT_CHARS + 1)));
        assert!(result.contains("output truncated"));
    }
}