1#![deny(unsafe_code)]
7#![deny(missing_docs)]
8#![deny(clippy::all)]
9#![warn(unreachable_pub)]
10#![deny(clippy::unwrap_used)]
11#![cfg_attr(test, allow(clippy::unwrap_used))]
12
13mod bash;
14mod edit_file;
15mod glob;
16mod grep;
17mod instructions;
18mod list_directory;
19mod read_file;
20pub mod spark;
21mod spark_tool;
22mod subagent_spawner;
23mod system_prompt;
24mod task;
25mod truncate;
26mod write_file;
27
28pub use bash::BashTool;
29pub use edit_file::EditFileTool;
30pub use glob::GlobTool;
31pub use grep::GrepTool;
32pub use instructions::load_project_instructions;
33pub use list_directory::ListDirectoryTool;
34pub use read_file::ReadFileTool;
35pub use spark::SparkConfig;
36pub use spark_tool::SparkTool;
37pub use subagent_spawner::{SubAgentRequest, SubAgentResult, SubAgentSpawner};
38pub use system_prompt::build_system_prompt;
39pub use task::TaskTool;
40pub use truncate::truncate_at_char_boundary;
41pub use write_file::WriteFileTool;
42
43use astrid_llm::LlmToolDefinition;
44use serde_json::Value;
45use std::collections::HashMap;
46use std::path::PathBuf;
47use std::sync::Arc;
48use tokio::sync::RwLock;
49
/// Upper bound on the size of tool output kept by [`truncate_output`].
///
/// Despite the name, the limit is compared against `String::len`, so it is
/// measured in UTF-8 bytes rather than Unicode characters.
const MAX_OUTPUT_CHARS: usize = 30_000;
52
53#[async_trait::async_trait]
55pub trait BuiltinTool: Send + Sync {
56 fn name(&self) -> &'static str;
58
59 fn description(&self) -> &'static str;
61
62 fn input_schema(&self) -> Value;
64
65 async fn execute(&self, args: Value, ctx: &ToolContext) -> ToolResult;
67}
68
69pub struct ToolContext {
71 pub workspace_root: PathBuf,
73 pub cwd: Arc<RwLock<PathBuf>>,
75 pub spark_file: Option<PathBuf>,
77 subagent_spawner: RwLock<Option<Arc<dyn SubAgentSpawner>>>,
79}
80
81impl ToolContext {
82 #[must_use]
84 pub fn new(workspace_root: PathBuf, spark_file: Option<PathBuf>) -> Self {
85 let cwd = Arc::new(RwLock::new(workspace_root.clone()));
86 Self {
87 workspace_root,
88 cwd,
89 spark_file,
90 subagent_spawner: RwLock::new(None),
91 }
92 }
93
94 #[must_use]
100 pub fn with_shared_cwd(
101 workspace_root: PathBuf,
102 cwd: Arc<RwLock<PathBuf>>,
103 spark_file: Option<PathBuf>,
104 ) -> Self {
105 Self {
106 workspace_root,
107 cwd,
108 spark_file,
109 subagent_spawner: RwLock::new(None),
110 }
111 }
112
113 pub async fn set_subagent_spawner(&self, spawner: Option<Arc<dyn SubAgentSpawner>>) {
115 *self.subagent_spawner.write().await = spawner;
116 }
117
118 pub async fn subagent_spawner(&self) -> Option<Arc<dyn SubAgentSpawner>> {
120 self.subagent_spawner.read().await.clone()
121 }
122}
123
124#[derive(Debug, thiserror::Error)]
126pub enum ToolError {
127 #[error("I/O error: {0}")]
129 Io(#[from] std::io::Error),
130
131 #[error("Invalid arguments: {0}")]
133 InvalidArguments(String),
134
135 #[error("Execution failed: {0}")]
137 ExecutionFailed(String),
138
139 #[error("Path not found: {0}")]
141 PathNotFound(String),
142
143 #[error("Timeout after {0}ms")]
145 Timeout(u64),
146
147 #[error("{0}")]
149 Other(String),
150}
151
152pub type ToolResult = Result<String, ToolError>;
154
155pub struct ToolRegistry {
157 tools: HashMap<String, Box<dyn BuiltinTool>>,
158}
159
160impl ToolRegistry {
161 #[must_use]
163 pub fn new() -> Self {
164 Self {
165 tools: HashMap::new(),
166 }
167 }
168
169 #[must_use]
171 pub fn with_defaults() -> Self {
172 let mut registry = Self::new();
173 registry.register(Box::new(ReadFileTool));
174 registry.register(Box::new(WriteFileTool));
175 registry.register(Box::new(EditFileTool));
176 registry.register(Box::new(GlobTool));
177 registry.register(Box::new(GrepTool));
178 registry.register(Box::new(BashTool));
179 registry.register(Box::new(ListDirectoryTool));
180 registry.register(Box::new(TaskTool));
181 registry.register(Box::new(SparkTool));
182 registry
183 }
184
185 pub fn register(&mut self, tool: Box<dyn BuiltinTool>) {
187 self.tools.insert(tool.name().to_string(), tool);
188 }
189
190 #[must_use]
192 pub fn get(&self, name: &str) -> Option<&dyn BuiltinTool> {
193 self.tools.get(name).map(AsRef::as_ref)
194 }
195
196 #[must_use]
198 pub fn is_builtin(name: &str) -> bool {
199 !name.contains(':')
200 }
201
202 #[must_use]
204 pub fn all_definitions(&self) -> Vec<LlmToolDefinition> {
205 self.tools
206 .values()
207 .map(|t| {
208 LlmToolDefinition::new(t.name())
209 .with_description(t.description())
210 .with_schema(t.input_schema())
211 })
212 .collect()
213 }
214}
215
216impl Default for ToolRegistry {
217 fn default() -> Self {
218 Self::new()
219 }
220}
221
222#[must_use]
226pub fn truncate_output(mut output: String) -> String {
227 if output.len() <= MAX_OUTPUT_CHARS {
228 return output;
229 }
230 let mut end = MAX_OUTPUT_CHARS;
231 while end > 0 && !output.is_char_boundary(end) {
232 end = end.saturating_sub(1);
233 }
234 output.truncate(end);
235 output.push_str("\n\n... (output truncated — exceeded 30000 character limit)");
236 output
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    /// Names that `ToolRegistry::with_defaults` is expected to register.
    const DEFAULT_TOOL_NAMES: [&str; 9] = [
        "read_file",
        "write_file",
        "edit_file",
        "glob",
        "grep",
        "bash",
        "list_directory",
        "task",
        "spark",
    ];

    #[test]
    fn test_is_builtin() {
        for plain in ["read_file", "bash"] {
            assert!(ToolRegistry::is_builtin(plain));
        }
        assert!(!ToolRegistry::is_builtin("filesystem:read_file"));
    }

    #[test]
    fn test_registry_with_defaults() {
        let registry = ToolRegistry::with_defaults();
        for name in DEFAULT_TOOL_NAMES {
            assert!(registry.get(name).is_some());
        }
        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn test_all_definitions() {
        let defs = ToolRegistry::with_defaults().all_definitions();
        assert_eq!(defs.len(), DEFAULT_TOOL_NAMES.len());
        assert!(defs
            .iter()
            .all(|def| !def.name.contains(':') && def.description.is_some()));
    }

    #[test]
    fn test_truncate_output_small() {
        let input = String::from("hello");
        let output = truncate_output(input.clone());
        assert_eq!(output, input);
    }

    #[test]
    fn test_truncate_output_large() {
        let output = truncate_output("x".repeat(40_000));
        assert!(output.len() < 40_000);
        assert!(output.contains("output truncated"));
    }

    #[test]
    fn test_truncate_output_multibyte_boundary() {
        // A 4-byte char straddles the byte limit; the cut must land just
        // before it instead of splitting it.
        let prefix = "x".repeat(MAX_OUTPUT_CHARS - 2);
        let input = format!("{prefix}🦀{}", "y".repeat(100));
        let output = truncate_output(input);
        assert!(output.len() < MAX_OUTPUT_CHARS + 100);
        assert!(output.starts_with(&prefix));
        assert!(output.contains("output truncated"));
    }
}