1pub mod agent;
10pub mod bash;
11pub mod edit;
12#[cfg(test)]
13mod edit_tests;
14pub mod file;
15pub mod git;
16pub mod native;
17pub mod search;
18
19#[cfg(feature = "ares")]
20pub mod ares_bridge;
21
22use async_trait::async_trait;
23use serde_json::Value;
24use std::collections::HashMap;
25use std::sync::Arc;
26
27#[derive(Debug, Clone)]
29pub struct ToolDefinition {
30 pub name: String,
32 pub description: String,
34 pub parameters: Value,
36}
37
38#[async_trait]
40pub trait Tool: Send + Sync {
41 fn name(&self) -> &str;
43
44 fn description(&self) -> &str;
46
47 fn parameters_schema(&self) -> Value;
49
50 async fn execute(&self, args: Value) -> crate::Result<Value>;
52
53 fn to_definition(&self) -> ToolDefinition {
55 ToolDefinition {
56 name: self.name().to_string(),
57 description: self.description().to_string(),
58 parameters: self.parameters_schema(),
59 }
60 }
61}
62
/// Visibility tier of a registered tool; controls when `ToolRegistry`
/// offers the tool's definition.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so tiers can be used as
/// set/map keys and wherever total equality is required.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ToolTier {
    /// Always included by `get_definitions` and `select_for_query`.
    Core,
    /// Included by `get_definitions`; in `select_for_query` always eligible and
    /// ranked by keyword score. This is the tier `register` assigns.
    Standard,
    /// Included by `get_definitions` only after `activate`; in
    /// `select_for_query` only when its score is positive (activation adds a bonus).
    Extended,
}
75
76pub struct ToolRegistry {
81 tools: HashMap<String, Arc<dyn Tool>>,
82 tiers: HashMap<String, ToolTier>,
83 activated: std::sync::Mutex<std::collections::HashSet<String>>,
85}
86
87impl ToolRegistry {
88 pub fn new() -> Self {
90 Self {
91 tools: HashMap::new(),
92 tiers: HashMap::new(),
93 activated: std::sync::Mutex::new(std::collections::HashSet::new()),
94 }
95 }
96
97 pub fn with_defaults(workspace_root: std::path::PathBuf) -> Self {
103 let mut registry = Self::new();
104 use ToolTier::*;
105
106 registry.register_with_tier(Arc::new(bash::BashTool::new(workspace_root.clone())), Core);
108 registry.register_with_tier(Arc::new(file::ReadFileTool::new(workspace_root.clone())), Core);
109 registry.register_with_tier(Arc::new(file::WriteFileTool::new(workspace_root.clone())), Core);
110 registry.register_with_tier(Arc::new(edit::EditFileTool::new(workspace_root.clone())), Core);
111 registry.register_with_tier(Arc::new(native::AstGrepTool::new(workspace_root.clone())), Core);
112 registry.register_with_tier(Arc::new(native::GlobSearchTool::new(workspace_root.clone())), Core);
113 registry.register_with_tier(Arc::new(native::GrepSearchTool::new(workspace_root.clone())), Core);
114
115 registry.register_with_tier(Arc::new(file::ListDirectoryTool::new(workspace_root.clone())), Standard);
117 registry.register_with_tier(Arc::new(edit::EditFileLinesTool::new(workspace_root.clone())), Standard);
118 registry.register_with_tier(Arc::new(edit::InsertAfterTool::new(workspace_root.clone())), Standard);
119 registry.register_with_tier(Arc::new(edit::AppendFileTool::new(workspace_root.clone())), Standard);
120 registry.register_with_tier(Arc::new(git::GitStatusTool::new(workspace_root.clone())), Standard);
121 registry.register_with_tier(Arc::new(git::GitDiffTool::new(workspace_root.clone())), Standard);
122 registry.register_with_tier(Arc::new(git::GitAddTool::new(workspace_root.clone())), Standard);
123 registry.register_with_tier(Arc::new(git::GitCommitTool::new(workspace_root.clone())), Standard);
124 registry.register_with_tier(Arc::new(git::GitLogTool::new(workspace_root.clone())), Standard);
125 registry.register_with_tier(Arc::new(git::GitBlameTool::new(workspace_root.clone())), Standard);
126 registry.register_with_tier(Arc::new(git::GitBranchTool::new(workspace_root.clone())), Standard);
127 registry.register_with_tier(Arc::new(git::GitCheckoutTool::new(workspace_root.clone())), Standard);
128 registry.register_with_tier(Arc::new(git::GitStashTool::new(workspace_root.clone())), Standard);
129 registry.register_with_tier(Arc::new(agent::SpawnAgentsTool::new(workspace_root.clone())), Standard);
130 registry.register_with_tier(Arc::new(agent::SpawnAgentTool::new(workspace_root.clone())), Standard);
131
132 registry.register_with_tier(Arc::new(native::RipgrepTool::new(workspace_root.clone())), Extended);
134 registry.register_with_tier(Arc::new(native::FdTool::new(workspace_root.clone())), Extended);
135 registry.register_with_tier(Arc::new(native::SdTool::new(workspace_root.clone())), Extended);
136 registry.register_with_tier(Arc::new(native::ErdTool::new(workspace_root.clone())), Extended);
137 registry.register_with_tier(Arc::new(native::MiseTool::new(workspace_root.clone())), Extended);
138 registry.register_with_tier(Arc::new(native::ZoxideTool::new(workspace_root.clone())), Extended);
139 registry.register_with_tier(Arc::new(native::LspTool::new(workspace_root)), Extended);
140
141 registry
142 }
143
144 pub fn register(&mut self, tool: Arc<dyn Tool>) {
146 self.register_with_tier(tool, ToolTier::Standard);
147 }
148
149 pub fn register_with_tier(&mut self, tool: Arc<dyn Tool>, tier: ToolTier) {
151 let name = tool.name().to_string();
152 self.tiers.insert(name.clone(), tier);
153 self.tools.insert(name, tool);
154 }
155
156 pub fn get(&self, name: &str) -> Option<&Arc<dyn Tool>> {
158 self.tools.get(name)
159 }
160
161 pub fn has_tool(&self, name: &str) -> bool {
163 self.tools.contains_key(name)
164 }
165
166 pub async fn execute(&self, name: &str, args: Value) -> crate::Result<Value> {
168 match self.tools.get(name) {
169 Some(tool) => tool.execute(args).await,
170 None => Err(crate::PawanError::NotFound(format!(
171 "Tool not found: {}",
172 name
173 ))),
174 }
175 }
176
177 pub fn get_definitions(&self) -> Vec<ToolDefinition> {
180 let activated = self.activated.lock().unwrap();
181 self.tools.iter()
182 .filter(|(name, _)| {
183 match self.tiers.get(name.as_str()).copied().unwrap_or(ToolTier::Standard) {
184 ToolTier::Core | ToolTier::Standard => true,
185 ToolTier::Extended => activated.contains(name.as_str()),
186 }
187 })
188 .map(|(_, tool)| tool.to_definition())
189 .collect()
190 }
191
192 pub fn select_for_query(&self, query: &str, max_tools: usize) -> Vec<ToolDefinition> {
198 let query_lower = query.to_lowercase();
199 let query_words: Vec<&str> = query_lower.split_whitespace().collect();
200
201 let mut scored: Vec<(i32, String)> = Vec::new();
202
203 for (name, tool) in &self.tools {
204 let tier = self.tiers.get(name.as_str()).copied().unwrap_or(ToolTier::Standard);
205
206 if tier == ToolTier::Core { continue; }
208
209 let tool_text = format!("{} {}", name, tool.description()).to_lowercase();
211 let mut score: i32 = 0;
212
213 for word in &query_words {
214 if word.len() < 3 { continue; } if tool_text.contains(word) { score += 2; }
216 }
217
218 let search_words = ["search", "find", "web", "query", "look", "google", "bing", "wikipedia"];
220 let git_words = ["git", "commit", "branch", "diff", "status", "log", "stash", "checkout", "blame"];
221 let file_words = ["file", "read", "write", "edit", "append", "insert", "directory", "list"];
222 let code_words = ["refactor", "rename", "replace", "ast", "lsp", "symbol", "function", "struct"];
223 let tool_words = ["install", "mise", "tool", "runtime", "build", "test", "cargo"];
224
225 for word in &query_words {
226 if search_words.contains(word) && tool_text.contains("search") { score += 3; }
227 if git_words.contains(word) && tool_text.contains("git") { score += 3; }
228 if file_words.contains(word) && (tool_text.contains("file") || tool_text.contains("edit")) { score += 3; }
229 if code_words.contains(word) && (tool_text.contains("ast") || tool_text.contains("lsp")) { score += 3; }
230 if tool_words.contains(word) && tool_text.contains("mise") { score += 3; }
231 }
232
233 if name.starts_with("mcp_") {
235 score += 1;
236 if name.contains("search") || name.contains("web") {
237 let web_words = ["web", "search", "internet", "online", "find", "look up", "google"];
238 if web_words.iter().any(|w| query_lower.contains(w)) {
239 score += 10; }
241 }
242 }
243
244 let activated = self.activated.lock().unwrap();
246 if tier == ToolTier::Extended && activated.contains(name.as_str()) { score += 2; }
247
248 if score > 0 || tier == ToolTier::Standard {
249 scored.push((score, name.clone()));
250 }
251 }
252
253 scored.sort_by(|a, b| b.0.cmp(&a.0));
255
256 let mut result: Vec<ToolDefinition> = self.tools.iter()
258 .filter(|(name, _)| {
259 self.tiers.get(name.as_str()).copied().unwrap_or(ToolTier::Standard) == ToolTier::Core
260 })
261 .map(|(_, tool)| tool.to_definition())
262 .collect();
263
264 let remaining_slots = max_tools.saturating_sub(result.len());
265 for (_, name) in scored.into_iter().take(remaining_slots) {
266 if let Some(tool) = self.tools.get(&name) {
267 result.push(tool.to_definition());
268 }
269 }
270
271 result
272 }
273
274 pub fn get_all_definitions(&self) -> Vec<ToolDefinition> {
276 self.tools.values().map(|t| t.to_definition()).collect()
277 }
278
279 pub fn activate(&self, name: &str) {
281 if self.tools.contains_key(name) {
282 self.activated.lock().unwrap().insert(name.to_string());
283 }
284 }
285
286 pub fn tool_names(&self) -> Vec<&str> {
288 self.tools.keys().map(|s| s.as_str()).collect()
289 }
290}
291
292impl Default for ToolRegistry {
293 fn default() -> Self {
294 Self::new()
295 }
296}