1pub mod agent;
10pub mod bash;
11#[cfg(feature = "deagle")]
12pub mod deagle;
13pub mod edit;
14#[cfg(test)]
15mod edit_tests;
16pub mod file;
17pub mod git;
18pub mod lsp_tool;
19pub mod mise;
20pub mod native_search;
21pub mod native;
22pub mod search;
23pub mod ares_bridge;
24
25use async_trait::async_trait;
26use serde_json::Value;
27use std::collections::HashMap;
28use std::sync::Arc;
29
30pub use thulp_core::ToolDefinition;
37
38#[async_trait]
40pub trait Tool: Send + Sync {
41 fn name(&self) -> &str;
43
44 fn description(&self) -> &str;
46
47 fn mutating(&self) -> bool {
52 false }
54
55 fn parameters_schema(&self) -> Value;
57
58 async fn execute(&self, args: Value) -> crate::Result<Value>;
60 fn thulp_definition(&self) -> thulp_core::ToolDefinition {
63 let params = thulp_core::ToolDefinition::parse_mcp_input_schema(&self.parameters_schema())
64 .unwrap_or_default();
65 thulp_core::ToolDefinition::builder(self.name())
66 .description(self.description())
67 .parameters(params)
68 .build()
69 }
70
71 fn validate_args(&self, args: &Value) -> std::result::Result<(), String> {
74 self.thulp_definition()
75 .validate_args(args)
76 .map_err(|e| e.to_string())
77 }
78}
79
/// Visibility tier assigned to a tool when it is registered.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` so the tier can be used
/// as a map/set key and in generic code that requires total equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ToolTier {
    /// Always listed, and always returned by query-based selection.
    Core,
    /// Listed by default; the tier used when none is specified.
    Standard,
    /// Hidden from the default listing until the tool has been activated.
    Extended,
}
92
93pub struct ToolRegistry {
98 tools: HashMap<String, Arc<dyn Tool>>,
99 tiers: HashMap<String, ToolTier>,
100 activated: std::sync::Mutex<std::collections::HashSet<String>>,
102 tool_text_cache: HashMap<String, String>,
104}
105
106impl ToolRegistry {
107 pub fn new() -> Self {
109 Self {
110 tools: HashMap::new(),
111 tiers: HashMap::new(),
112 activated: std::sync::Mutex::new(std::collections::HashSet::new()),
113 tool_text_cache: HashMap::new(),
114 }
115 }
116
117 pub fn with_defaults(workspace_root: std::path::PathBuf) -> Self {
123 let mut registry = Self::new();
124 use ToolTier::*;
125
126 registry.register_with_tier(Arc::new(bash::BashTool::new(workspace_root.clone())), Core);
128 registry.register_with_tier(Arc::new(file::ReadFileTool::new(workspace_root.clone())), Core);
129 registry.register_with_tier(Arc::new(file::WriteFileTool::new(workspace_root.clone())), Core);
130 registry.register_with_tier(Arc::new(edit::EditFileTool::new(workspace_root.clone())), Core);
131 registry.register_with_tier(Arc::new(native::AstGrepTool::new(workspace_root.clone())), Core);
132 registry.register_with_tier(Arc::new(native::GlobSearchTool::new(workspace_root.clone())), Core);
133 registry.register_with_tier(Arc::new(native::GrepSearchTool::new(workspace_root.clone())), Core);
134
135 registry.register_with_tier(Arc::new(file::ListDirectoryTool::new(workspace_root.clone())), Standard);
137 registry.register_with_tier(Arc::new(edit::EditFileLinesTool::new(workspace_root.clone())), Standard);
138 registry.register_with_tier(Arc::new(edit::InsertAfterTool::new(workspace_root.clone())), Standard);
139 registry.register_with_tier(Arc::new(edit::AppendFileTool::new(workspace_root.clone())), Standard);
140 registry.register_with_tier(Arc::new(git::GitStatusTool::new(workspace_root.clone())), Standard);
141 registry.register_with_tier(Arc::new(git::GitDiffTool::new(workspace_root.clone())), Standard);
142 registry.register_with_tier(Arc::new(git::GitAddTool::new(workspace_root.clone())), Standard);
143 registry.register_with_tier(Arc::new(git::GitCommitTool::new(workspace_root.clone())), Standard);
144 registry.register_with_tier(Arc::new(git::GitLogTool::new(workspace_root.clone())), Standard);
145 registry.register_with_tier(Arc::new(git::GitBlameTool::new(workspace_root.clone())), Standard);
146 registry.register_with_tier(Arc::new(git::GitBranchTool::new(workspace_root.clone())), Standard);
147 registry.register_with_tier(Arc::new(git::GitCheckoutTool::new(workspace_root.clone())), Standard);
148 registry.register_with_tier(Arc::new(git::GitStashTool::new(workspace_root.clone())), Standard);
149 registry.register_with_tier(Arc::new(agent::SpawnAgentsTool::new(workspace_root.clone())), Standard);
150 registry.register_with_tier(Arc::new(agent::SpawnAgentTool::new(workspace_root.clone())), Standard);
151
152 registry.register_with_tier(Arc::new(native::RipgrepTool::new(workspace_root.clone())), Extended);
154 registry.register_with_tier(Arc::new(native::FdTool::new(workspace_root.clone())), Extended);
155 registry.register_with_tier(Arc::new(native::SdTool::new(workspace_root.clone())), Extended);
156 registry.register_with_tier(Arc::new(native::ErdTool::new(workspace_root.clone())), Extended);
157 registry.register_with_tier(Arc::new(native::MiseTool::new(workspace_root.clone())), Extended);
158 registry.register_with_tier(Arc::new(native::ZoxideTool::new(workspace_root.clone())), Extended);
159 registry.register_with_tier(Arc::new(native::LspTool::new(workspace_root.clone())), Extended);
160
161 #[cfg(feature = "deagle")]
163 {
164 registry.register_with_tier(Arc::new(deagle::DeagleSearchTool::new(workspace_root.clone())), Extended);
165 registry.register_with_tier(Arc::new(deagle::DeagleKeywordTool::new(workspace_root.clone())), Extended);
166 registry.register_with_tier(Arc::new(deagle::DeagleSgTool::new(workspace_root.clone())), Extended);
167 registry.register_with_tier(Arc::new(deagle::DeagleStatsTool::new(workspace_root.clone())), Extended);
168 registry.register_with_tier(Arc::new(deagle::DeagleMapTool::new(workspace_root)), Extended);
169 }
170
171 registry
172 }
173
174 pub fn register(&mut self, tool: Arc<dyn Tool>) {
176 self.register_with_tier(tool, ToolTier::Standard);
177 }
178
179 pub fn register_with_tier(&mut self, tool: Arc<dyn Tool>, tier: ToolTier) {
181 let name = tool.name().to_string();
182 let cached_text = format!("{} {}", name, tool.description()).to_lowercase();
183 self.tool_text_cache.insert(name.clone(), cached_text);
184 self.tiers.insert(name.clone(), tier);
185 self.tools.insert(name, tool);
186 }
187
188 pub fn get(&self, name: &str) -> Option<&Arc<dyn Tool>> {
190 self.tools.get(name)
191 }
192
193 pub fn has_tool(&self, name: &str) -> bool {
195 self.tools.contains_key(name)
196 }
197
198 pub async fn execute(&self, name: &str, args: Value) -> crate::Result<Value> {
200 match self.tools.get(name) {
201 Some(tool) => tool.execute(args).await,
202 None => Err(crate::PawanError::NotFound(format!(
203 "Tool not found: {}",
204 name
205 ))),
206 }
207 }
208
209 pub fn get_definitions(&self) -> Vec<ToolDefinition> {
212 let activated = self.activated.lock().unwrap_or_else(|e| e.into_inner());
213 self.tools.iter()
214 .filter(|(name, _)| {
215 match self.tiers.get(name.as_str()).copied().unwrap_or(ToolTier::Standard) {
216 ToolTier::Core | ToolTier::Standard => true,
217 ToolTier::Extended => activated.contains(name.as_str()),
218 }
219 })
220 .map(|(_, tool)| tool.thulp_definition())
221 .collect()
222 }
223
224 pub fn select_for_query(&self, query: &str, max_tools: usize) -> Vec<ToolDefinition> {
230 let query_lower = query.to_lowercase();
231 let query_words: Vec<&str> = query_lower.split_whitespace().collect();
232
233 let mut scored: Vec<(i32, String)> = Vec::new();
234
235 for name in self.tools.keys() {
236 let tier = self.tiers.get(name.as_str()).copied().unwrap_or(ToolTier::Standard);
237
238 if tier == ToolTier::Core { continue; }
240
241 let tool_text = self.tool_text_cache.get(name.as_str())
243 .map(|s| s.as_str())
244 .unwrap_or("");
245 let mut score: i32 = 0;
246
247 for word in &query_words {
248 if word.len() < 3 { continue; } if tool_text.contains(word) { score += 2; }
250 }
251
252 let search_words = ["search", "find", "web", "query", "look", "google", "bing", "wikipedia"];
254 let git_words = ["git", "commit", "branch", "diff", "status", "log", "stash", "checkout", "blame"];
255 let file_words = ["file", "read", "write", "edit", "append", "insert", "directory", "list"];
256 let code_words = ["refactor", "rename", "replace", "ast", "lsp", "symbol", "function", "struct"];
257 let tool_words = ["install", "mise", "tool", "runtime", "build", "test", "cargo"];
258
259 for word in &query_words {
260 if search_words.contains(word) && tool_text.contains("search") { score += 3; }
261 if git_words.contains(word) && tool_text.contains("git") { score += 3; }
262 if file_words.contains(word) && (tool_text.contains("file") || tool_text.contains("edit")) { score += 3; }
263 if code_words.contains(word) && (tool_text.contains("ast") || tool_text.contains("lsp")) { score += 3; }
264 if tool_words.contains(word) && tool_text.contains("mise") { score += 3; }
265 }
266
267 if name.starts_with("mcp_") {
269 score += 1;
270 if name.contains("search") || name.contains("web") {
271 let web_words = ["web", "search", "internet", "online", "find", "look up", "google"];
272 if web_words.iter().any(|w| query_lower.contains(w)) {
273 score += 10; }
275 }
276 }
277
278 let activated = self.activated.lock().unwrap_or_else(|e| e.into_inner());
280 if tier == ToolTier::Extended && activated.contains(name.as_str()) { score += 2; }
281
282 if score > 0 || tier == ToolTier::Standard {
283 scored.push((score, name.clone()));
284 }
285 }
286
287 scored.sort_by(|a, b| b.0.cmp(&a.0));
289
290 let mut result: Vec<ToolDefinition> = self.tools.iter()
292 .filter(|(name, _)| {
293 self.tiers.get(name.as_str()).copied().unwrap_or(ToolTier::Standard) == ToolTier::Core
294 })
295 .map(|(_, tool)| tool.thulp_definition())
296 .collect();
297
298 let remaining_slots = max_tools.saturating_sub(result.len());
299 for (_, name) in scored.into_iter().take(remaining_slots) {
300 if let Some(tool) = self.tools.get(&name) {
301 result.push(tool.thulp_definition());
302 }
303 }
304
305 result
306 }
307
308 pub fn get_all_definitions(&self) -> Vec<ToolDefinition> {
310 self.tools.values().map(|t| t.thulp_definition()).collect()
311 }
312
313 pub fn activate(&self, name: &str) {
315 if self.tools.contains_key(name) {
316 self.activated.lock().unwrap_or_else(|e| e.into_inner()).insert(name.to_string());
317 }
318 }
319
320 pub fn tool_names(&self) -> Vec<&str> {
322 self.tools.keys().map(|s| s.as_str()).collect()
323 }
324
325 pub fn query_tools(&self, query: &str) -> Vec<thulp_core::ToolDefinition> {
342 let criteria = match thulp_query::parse_query(query) {
343 Ok(c) => c,
344 Err(e) => {
345 tracing::warn!(query = %query, error = %e, "failed to parse tool query");
346 return Vec::new();
347 }
348 };
349
350 self.tools
351 .values()
352 .map(|tool| tool.thulp_definition())
353 .filter(|def| criteria.matches(def))
354 .collect()
355 }
356}
357
358impl Default for ToolRegistry {
359 fn default() -> Self {
360 Self::new()
361 }
362}
363
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Number of tools `with_defaults` registers only when the `deagle`
    /// feature is enabled.
    const DEAGLE_TOOL_COUNT: usize = 5;

    /// Collect the names of `defs` so tests can assert on membership.
    fn names_of(defs: &[ToolDefinition]) -> Vec<String> {
        defs.iter().map(|d| d.name.clone()).collect()
    }

    /// A fresh registry has no tools and every lookup misses.
    #[test]
    fn test_registry_new_is_empty() {
        let registry = ToolRegistry::new();
        assert!(registry.tool_names().is_empty());
        assert!(!registry.has_tool("bash"));
        assert!(registry.get("nonexistent").is_none());
    }

    /// The default registry contains tools from all three tiers.
    #[test]
    fn test_registry_with_defaults_contains_core_tools() {
        let registry = ToolRegistry::with_defaults(PathBuf::from("/tmp/test"));
        for name in &["bash", "read_file", "write_file", "edit_file", "grep_search", "glob_search"] {
            assert!(
                registry.has_tool(name),
                "default registry missing core tool: {}",
                name
            );
        }
        // Standard tier.
        assert!(registry.has_tool("git_status"));
        assert!(registry.has_tool("git_commit"));
        // Extended tier.
        assert!(registry.has_tool("rg"));
        assert!(registry.has_tool("fd"));
    }

    /// Extended tools only show up in `get_definitions` after `activate`.
    #[test]
    fn test_registry_get_definitions_hides_extended_until_activated() {
        let registry = ToolRegistry::with_defaults(PathBuf::from("/tmp/test"));
        let initial = names_of(&registry.get_definitions());

        assert!(!initial.contains(&"rg".to_string()), "rg should be hidden until activated");
        assert!(!initial.contains(&"fd".to_string()), "fd should be hidden until activated");
        assert!(initial.contains(&"bash".to_string()));
        assert!(initial.contains(&"read_file".to_string()));

        registry.activate("rg");
        let after = names_of(&registry.get_definitions());
        assert!(after.contains(&"rg".to_string()), "rg should be visible after activate");
        assert!(after.len() > initial.len(), "activation should grow visible set");
    }

    /// `get_all_definitions` ignores tiers and activation.
    #[test]
    fn test_registry_get_all_definitions_returns_everything() {
        let registry = ToolRegistry::with_defaults(PathBuf::from("/tmp/test"));
        let all = registry.get_all_definitions();
        let visible = registry.get_definitions();
        assert!(
            all.len() > visible.len(),
            "get_all_definitions ({}) should include hidden extended tools beyond get_definitions ({})",
            all.len(),
            visible.len()
        );
        let all_names = names_of(&all);
        assert!(all_names.contains(&"rg".to_string()));
    }

    /// `query_tools` filters registered tools through the thulp query DSL.
    #[test]
    fn test_registry_query_tools_filters_by_dsl() {
        let registry = ToolRegistry::with_defaults(PathBuf::from("/tmp/test"));
        let bash_match = registry.query_tools("name:bash");
        assert!(
            !bash_match.is_empty(),
            "query_tools('name:bash') should match the bash tool"
        );
        let names = names_of(&bash_match);
        assert!(names.contains(&"bash".to_string()));

        let no_match = registry.query_tools("name:definitely_not_a_tool_xyz");
        assert!(
            no_match.is_empty(),
            "query_tools for nonexistent name should return empty, got {:?}",
            no_match.iter().map(|d| &d.name).collect::<Vec<_>>()
        );
    }

    /// Minimal `Tool` whose `execute` returns a fixed value.
    struct MockTool {
        name: String,
        description: String,
        return_value: Value,
    }

    impl MockTool {
        /// Build a mock with the given name, description and canned result.
        fn new(name: &str, description: &str, return_value: Value) -> Self {
            Self {
                name: name.to_string(),
                description: description.to_string(),
                return_value,
            }
        }
    }

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            &self.name
        }
        fn description(&self) -> &str {
            &self.description
        }
        fn parameters_schema(&self) -> Value {
            serde_json::json!({ "type": "object", "properties": {} })
        }
        async fn execute(&self, _args: Value) -> crate::Result<Value> {
            Ok(self.return_value.clone())
        }
    }

    /// `register` places the tool in the standard tier, so it is visible
    /// without activation.
    #[test]
    fn test_register_defaults_to_standard_tier() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(MockTool::new(
            "mock_std",
            "a test mock",
            Value::Null,
        )));
        let visible = names_of(&registry.get_definitions());
        assert!(
            visible.contains(&"mock_std".to_string()),
            "register() should default to Standard tier (visible without activation), got {:?}",
            visible
        );
    }

    /// Registering a second tool under an existing name replaces the first.
    #[test]
    fn test_register_with_tier_overwrites_same_name() {
        let mut registry = ToolRegistry::new();
        registry.register_with_tier(
            Arc::new(MockTool::new("dup", "first registration", Value::Null)),
            ToolTier::Standard,
        );
        registry.register_with_tier(
            Arc::new(MockTool::new("dup", "second registration", Value::Null)),
            ToolTier::Core,
        );

        let names = registry.tool_names();
        assert_eq!(
            names.iter().filter(|n| **n == "dup").count(),
            1,
            "register_with_tier of an existing name must replace, not duplicate"
        );
        let def = registry.get("dup").expect("dup should exist after overwrite");
        assert_eq!(def.description(), "second registration");
        let visible = names_of(&registry.get_definitions());
        assert!(visible.contains(&"dup".to_string()));
    }

    /// `execute` forwards to the registered tool and returns its result.
    #[tokio::test]
    async fn test_execute_dispatches_to_registered_tool() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(MockTool::new(
            "echo",
            "returns a fixed value",
            serde_json::json!({ "answer": 42 }),
        )));

        let out = registry
            .execute("echo", Value::Null)
            .await
            .expect("execute on a registered tool should succeed");
        assert_eq!(out, serde_json::json!({ "answer": 42 }));
    }

    /// `execute` on an unknown name yields `PawanError::NotFound` naming the tool.
    #[tokio::test]
    async fn test_execute_unknown_tool_returns_not_found() {
        let registry = ToolRegistry::new();
        let err = registry
            .execute("nonexistent_tool", Value::Null)
            .await
            .expect_err("execute on missing tool should fail");
        match err {
            crate::PawanError::NotFound(msg) => {
                assert!(msg.contains("nonexistent_tool"), "error should name the missing tool, got: {}", msg);
            }
            other => panic!("expected PawanError::NotFound, got: {:?}", other),
        }
    }

    /// Core tools are returned even when the query matches nothing and
    /// `max_tools` is smaller than the core set.
    #[test]
    fn test_select_for_query_always_includes_core_tools() {
        let registry = ToolRegistry::with_defaults(PathBuf::from("/tmp/test"));
        let selected = registry.select_for_query("xyzzy plover", 5);
        let names = names_of(&selected);
        for core in &["bash", "read_file", "write_file", "edit_file", "grep_search", "glob_search", "ast_grep"] {
            assert!(
                names.contains(&core.to_string()),
                "select_for_query must include core tool {} regardless of query, got {:?}",
                core,
                names
            );
        }
    }

    /// With room beyond the core set, the cap is respected and relevant
    /// tools fill the remaining slots.
    #[test]
    fn test_select_for_query_caps_at_max_tools_when_possible() {
        let registry = ToolRegistry::with_defaults(PathBuf::from("/tmp/test"));
        let selected = registry.select_for_query("git commit my changes", 10);
        assert!(
            selected.len() <= 10,
            "select_for_query(max=10) returned {} tools, must not exceed cap",
            selected.len()
        );
        let names = names_of(&selected);
        assert!(
            names.iter().any(|n| n.starts_with("git_")),
            "git query should pull in at least one git_ tool, got {:?}",
            names
        );
    }

    /// Activating an unknown name is a silent no-op.
    #[test]
    fn test_activate_no_op_for_unknown_tool_does_not_panic() {
        let registry = ToolRegistry::with_defaults(PathBuf::from("/tmp/test"));
        registry.activate("not_a_real_tool_at_all");
        let visible = names_of(&registry.get_definitions());
        assert!(
            !visible.contains(&"not_a_real_tool_at_all".to_string()),
            "activate of unknown tool must not make it visible"
        );
    }

    /// `tool_names` lists every registered tool and each name resolves.
    #[test]
    fn test_tool_names_lists_every_registered_tool() {
        let registry = ToolRegistry::with_defaults(PathBuf::from("/tmp/test"));
        let names = registry.tool_names();
        // `with_defaults` registers 29 tools unconditionally and 5 more with
        // the `deagle` feature, so the floor of 30 only holds with that
        // feature on; lower it by the deagle-only tools otherwise.
        let expected_min = if cfg!(feature = "deagle") {
            30
        } else {
            30 - DEAGLE_TOOL_COUNT
        };
        assert!(
            names.len() >= expected_min,
            "default registry should expose >={} tools via tool_names(), got {}",
            expected_min,
            names.len()
        );
        for name in &names {
            assert!(registry.has_tool(name));
            assert!(registry.get(name).is_some());
        }
    }

    /// `Default` produces the same empty registry as `new`.
    #[test]
    fn test_default_impl_returns_empty_registry() {
        let registry = ToolRegistry::default();
        assert!(registry.tool_names().is_empty());
        assert!(registry.get_definitions().is_empty());
    }
}