1pub mod agent;
10pub mod bash;
11pub mod edit;
12#[cfg(test)]
13mod edit_tests;
14pub mod file;
15pub mod git;
16pub mod native;
17pub mod search;
18
19#[cfg(feature = "ares")]
20pub mod ares_bridge;
21
22use async_trait::async_trait;
23use serde_json::Value;
24use std::collections::HashMap;
25use std::sync::Arc;
26
27#[derive(Debug, Clone)]
29pub struct ToolDefinition {
30 pub name: String,
32 pub description: String,
34 pub parameters: Value,
36}
37
38#[async_trait]
40pub trait Tool: Send + Sync {
41 fn name(&self) -> &str;
43
44 fn description(&self) -> &str;
46
47 fn parameters_schema(&self) -> Value;
49
50 async fn execute(&self, args: Value) -> crate::Result<Value>;
52
53 fn thulp_definition(&self) -> thulp_core::ToolDefinition {
57 let params = thulp_core::ToolDefinition::parse_mcp_input_schema(&self.parameters_schema())
58 .unwrap_or_default();
59 thulp_core::ToolDefinition::builder(self.name())
60 .description(self.description())
61 .parameters(params)
62 .build()
63 }
64
65 fn validate_args(&self, args: &Value) -> std::result::Result<(), String> {
68 self.thulp_definition()
69 .validate_args(args)
70 .map_err(|e| e.to_string())
71 }
72
73 fn to_definition(&self) -> ToolDefinition {
75 ToolDefinition {
76 name: self.name().to_string(),
77 description: self.description().to_string(),
78 parameters: self.parameters_schema(),
79 }
80 }
81}
82
/// Visibility tier of a registered tool.
///
/// - `Core` tools are always included by `ToolRegistry::get_definitions` and
///   `ToolRegistry::select_for_query`.
/// - `Standard` tools are always included by `get_definitions`. They are always
///   candidates in `select_for_query`, ranked by keyword score.
/// - `Extended` tools are hidden from `get_definitions` until activated with
///   `ToolRegistry::activate`. `select_for_query` offers them only when they score
///   above zero for the query.
///
/// `Eq` and `Hash` are derived so tiers can be compared exhaustively and used as
/// `HashMap`/`HashSet` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ToolTier {
    /// Always offered.
    Core,
    /// Offered by default.
    Standard,
    /// Offered only once activated or when relevant to the query.
    Extended,
}
95
96pub struct ToolRegistry {
101 tools: HashMap<String, Arc<dyn Tool>>,
102 tiers: HashMap<String, ToolTier>,
103 activated: std::sync::Mutex<std::collections::HashSet<String>>,
105 tool_text_cache: HashMap<String, String>,
107}
108
109impl ToolRegistry {
110 pub fn new() -> Self {
112 Self {
113 tools: HashMap::new(),
114 tiers: HashMap::new(),
115 activated: std::sync::Mutex::new(std::collections::HashSet::new()),
116 tool_text_cache: HashMap::new(),
117 }
118 }
119
120 pub fn with_defaults(workspace_root: std::path::PathBuf) -> Self {
126 let mut registry = Self::new();
127 use ToolTier::*;
128
129 registry.register_with_tier(Arc::new(bash::BashTool::new(workspace_root.clone())), Core);
131 registry.register_with_tier(Arc::new(file::ReadFileTool::new(workspace_root.clone())), Core);
132 registry.register_with_tier(Arc::new(file::WriteFileTool::new(workspace_root.clone())), Core);
133 registry.register_with_tier(Arc::new(edit::EditFileTool::new(workspace_root.clone())), Core);
134 registry.register_with_tier(Arc::new(native::AstGrepTool::new(workspace_root.clone())), Core);
135 registry.register_with_tier(Arc::new(native::GlobSearchTool::new(workspace_root.clone())), Core);
136 registry.register_with_tier(Arc::new(native::GrepSearchTool::new(workspace_root.clone())), Core);
137
138 registry.register_with_tier(Arc::new(file::ListDirectoryTool::new(workspace_root.clone())), Standard);
140 registry.register_with_tier(Arc::new(edit::EditFileLinesTool::new(workspace_root.clone())), Standard);
141 registry.register_with_tier(Arc::new(edit::InsertAfterTool::new(workspace_root.clone())), Standard);
142 registry.register_with_tier(Arc::new(edit::AppendFileTool::new(workspace_root.clone())), Standard);
143 registry.register_with_tier(Arc::new(git::GitStatusTool::new(workspace_root.clone())), Standard);
144 registry.register_with_tier(Arc::new(git::GitDiffTool::new(workspace_root.clone())), Standard);
145 registry.register_with_tier(Arc::new(git::GitAddTool::new(workspace_root.clone())), Standard);
146 registry.register_with_tier(Arc::new(git::GitCommitTool::new(workspace_root.clone())), Standard);
147 registry.register_with_tier(Arc::new(git::GitLogTool::new(workspace_root.clone())), Standard);
148 registry.register_with_tier(Arc::new(git::GitBlameTool::new(workspace_root.clone())), Standard);
149 registry.register_with_tier(Arc::new(git::GitBranchTool::new(workspace_root.clone())), Standard);
150 registry.register_with_tier(Arc::new(git::GitCheckoutTool::new(workspace_root.clone())), Standard);
151 registry.register_with_tier(Arc::new(git::GitStashTool::new(workspace_root.clone())), Standard);
152 registry.register_with_tier(Arc::new(agent::SpawnAgentsTool::new(workspace_root.clone())), Standard);
153 registry.register_with_tier(Arc::new(agent::SpawnAgentTool::new(workspace_root.clone())), Standard);
154
155 registry.register_with_tier(Arc::new(native::RipgrepTool::new(workspace_root.clone())), Extended);
157 registry.register_with_tier(Arc::new(native::FdTool::new(workspace_root.clone())), Extended);
158 registry.register_with_tier(Arc::new(native::SdTool::new(workspace_root.clone())), Extended);
159 registry.register_with_tier(Arc::new(native::ErdTool::new(workspace_root.clone())), Extended);
160 registry.register_with_tier(Arc::new(native::MiseTool::new(workspace_root.clone())), Extended);
161 registry.register_with_tier(Arc::new(native::ZoxideTool::new(workspace_root.clone())), Extended);
162 registry.register_with_tier(Arc::new(native::LspTool::new(workspace_root)), Extended);
163
164 registry
165 }
166
167 pub fn register(&mut self, tool: Arc<dyn Tool>) {
169 self.register_with_tier(tool, ToolTier::Standard);
170 }
171
172 pub fn register_with_tier(&mut self, tool: Arc<dyn Tool>, tier: ToolTier) {
174 let name = tool.name().to_string();
175 let cached_text = format!("{} {}", name, tool.description()).to_lowercase();
176 self.tool_text_cache.insert(name.clone(), cached_text);
177 self.tiers.insert(name.clone(), tier);
178 self.tools.insert(name, tool);
179 }
180
181 pub fn get(&self, name: &str) -> Option<&Arc<dyn Tool>> {
183 self.tools.get(name)
184 }
185
186 pub fn has_tool(&self, name: &str) -> bool {
188 self.tools.contains_key(name)
189 }
190
191 pub async fn execute(&self, name: &str, args: Value) -> crate::Result<Value> {
193 match self.tools.get(name) {
194 Some(tool) => tool.execute(args).await,
195 None => Err(crate::PawanError::NotFound(format!(
196 "Tool not found: {}",
197 name
198 ))),
199 }
200 }
201
202 pub fn get_definitions(&self) -> Vec<ToolDefinition> {
205 let activated = self.activated.lock().unwrap_or_else(|e| e.into_inner());
206 self.tools.iter()
207 .filter(|(name, _)| {
208 match self.tiers.get(name.as_str()).copied().unwrap_or(ToolTier::Standard) {
209 ToolTier::Core | ToolTier::Standard => true,
210 ToolTier::Extended => activated.contains(name.as_str()),
211 }
212 })
213 .map(|(_, tool)| tool.to_definition())
214 .collect()
215 }
216
217 pub fn select_for_query(&self, query: &str, max_tools: usize) -> Vec<ToolDefinition> {
223 let query_lower = query.to_lowercase();
224 let query_words: Vec<&str> = query_lower.split_whitespace().collect();
225
226 let mut scored: Vec<(i32, String)> = Vec::new();
227
228 for name in self.tools.keys() {
229 let tier = self.tiers.get(name.as_str()).copied().unwrap_or(ToolTier::Standard);
230
231 if tier == ToolTier::Core { continue; }
233
234 let tool_text = self.tool_text_cache.get(name.as_str())
236 .map(|s| s.as_str())
237 .unwrap_or("");
238 let mut score: i32 = 0;
239
240 for word in &query_words {
241 if word.len() < 3 { continue; } if tool_text.contains(word) { score += 2; }
243 }
244
245 let search_words = ["search", "find", "web", "query", "look", "google", "bing", "wikipedia"];
247 let git_words = ["git", "commit", "branch", "diff", "status", "log", "stash", "checkout", "blame"];
248 let file_words = ["file", "read", "write", "edit", "append", "insert", "directory", "list"];
249 let code_words = ["refactor", "rename", "replace", "ast", "lsp", "symbol", "function", "struct"];
250 let tool_words = ["install", "mise", "tool", "runtime", "build", "test", "cargo"];
251
252 for word in &query_words {
253 if search_words.contains(word) && tool_text.contains("search") { score += 3; }
254 if git_words.contains(word) && tool_text.contains("git") { score += 3; }
255 if file_words.contains(word) && (tool_text.contains("file") || tool_text.contains("edit")) { score += 3; }
256 if code_words.contains(word) && (tool_text.contains("ast") || tool_text.contains("lsp")) { score += 3; }
257 if tool_words.contains(word) && tool_text.contains("mise") { score += 3; }
258 }
259
260 if name.starts_with("mcp_") {
262 score += 1;
263 if name.contains("search") || name.contains("web") {
264 let web_words = ["web", "search", "internet", "online", "find", "look up", "google"];
265 if web_words.iter().any(|w| query_lower.contains(w)) {
266 score += 10; }
268 }
269 }
270
271 let activated = self.activated.lock().unwrap_or_else(|e| e.into_inner());
273 if tier == ToolTier::Extended && activated.contains(name.as_str()) { score += 2; }
274
275 if score > 0 || tier == ToolTier::Standard {
276 scored.push((score, name.clone()));
277 }
278 }
279
280 scored.sort_by(|a, b| b.0.cmp(&a.0));
282
283 let mut result: Vec<ToolDefinition> = self.tools.iter()
285 .filter(|(name, _)| {
286 self.tiers.get(name.as_str()).copied().unwrap_or(ToolTier::Standard) == ToolTier::Core
287 })
288 .map(|(_, tool)| tool.to_definition())
289 .collect();
290
291 let remaining_slots = max_tools.saturating_sub(result.len());
292 for (_, name) in scored.into_iter().take(remaining_slots) {
293 if let Some(tool) = self.tools.get(&name) {
294 result.push(tool.to_definition());
295 }
296 }
297
298 result
299 }
300
301 pub fn get_all_definitions(&self) -> Vec<ToolDefinition> {
303 self.tools.values().map(|t| t.to_definition()).collect()
304 }
305
306 pub fn activate(&self, name: &str) {
308 if self.tools.contains_key(name) {
309 self.activated.lock().unwrap_or_else(|e| e.into_inner()).insert(name.to_string());
310 }
311 }
312
313 pub fn tool_names(&self) -> Vec<&str> {
315 self.tools.keys().map(|s| s.as_str()).collect()
316 }
317}
318
319impl Default for ToolRegistry {
320 fn default() -> Self {
321 Self::new()
322 }
323}