1pub mod agent;
10pub mod ares_bridge;
11pub mod bash;
12pub mod batch;
13#[cfg(feature = "deagle")]
14pub mod deagle;
15pub mod edit;
16#[cfg(test)]
17mod edit_tests;
18pub mod file;
19pub mod git;
20pub mod lsp_tool;
21pub mod mise;
22pub mod native;
23pub mod native_search;
24pub mod search;
25pub mod task;
26
27use async_trait::async_trait;
28use serde_json::Value;
29use std::collections::HashMap;
30use std::sync::Arc;
31
32pub use thulp_core::ToolDefinition;
39
40#[async_trait]
42pub trait Tool: Send + Sync {
43 fn name(&self) -> &str;
45
46 fn description(&self) -> &str;
48
49 fn mutating(&self) -> bool {
54 false }
56
57 fn parameters_schema(&self) -> Value;
59
60 async fn execute(&self, args: Value) -> crate::Result<Value>;
62 fn thulp_definition(&self) -> thulp_core::ToolDefinition {
65 let params = thulp_core::ToolDefinition::parse_mcp_input_schema(&self.parameters_schema())
66 .unwrap_or_default();
67 thulp_core::ToolDefinition::builder(self.name())
68 .description(self.description())
69 .parameters(params)
70 .build()
71 }
72
73 fn validate_args(&self, args: &Value) -> std::result::Result<(), String> {
76 self.thulp_definition()
77 .validate_args(args)
78 .map_err(|e| e.to_string())
79 }
80}
81
/// Visibility tier assigned to a registered tool.
///
/// The registry uses the tier to decide which tools are offered:
/// `Core` and `Standard` tools are always returned by
/// `ToolRegistry::get_definitions`, while `Extended` tools are returned only
/// after `ToolRegistry::activate` has been called for them.
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` so the tier can be
/// used as a key in hash maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ToolTier {
    /// Always included, including by `ToolRegistry::select_for_query`
    /// regardless of the query.
    Core,
    /// Included by default; the tier used by `ToolRegistry::register`.
    Standard,
    /// Hidden from `ToolRegistry::get_definitions` until activated.
    Extended,
}
94
/// Name-indexed collection of tools, with a tier recorded per tool and a set
/// of tool names that have been activated.
pub struct ToolRegistry {
    /// Registered tools, keyed by the name each tool reports.
    tools: HashMap<String, Arc<dyn Tool>>,
    /// Tier recorded for each registered tool name.
    tiers: HashMap<String, ToolTier>,
    /// Names passed to `activate`; kept behind a mutex so that `activate`
    /// can insert through a shared `&self` reference.
    activated: std::sync::Mutex<std::collections::HashSet<String>>,
    /// Lowercased "name description" text per tool, written at registration
    /// and read by `select_for_query` for keyword matching.
    tool_text_cache: HashMap<String, String>,
}
107
108impl ToolRegistry {
109 pub fn new() -> Self {
111 Self {
112 tools: HashMap::new(),
113 tiers: HashMap::new(),
114 activated: std::sync::Mutex::new(std::collections::HashSet::new()),
115 tool_text_cache: HashMap::new(),
116 }
117 }
118
119 pub fn with_defaults(workspace_root: std::path::PathBuf) -> Self {
125 let mut registry = Self::new();
126 use ToolTier::*;
127
128 registry.register_with_tier(Arc::new(bash::BashTool::new(workspace_root.clone())), Core);
130 registry.register_with_tier(
131 Arc::new(file::ReadFileTool::new(workspace_root.clone())),
132 Core,
133 );
134 registry.register_with_tier(
135 Arc::new(file::WriteFileTool::new(workspace_root.clone())),
136 Core,
137 );
138 registry.register_with_tier(
139 Arc::new(edit::EditFileTool::new(workspace_root.clone())),
140 Core,
141 );
142 registry.register_with_tier(
143 Arc::new(native::AstGrepTool::new(workspace_root.clone())),
144 Core,
145 );
146 registry.register_with_tier(
147 Arc::new(native::GlobSearchTool::new(workspace_root.clone())),
148 Core,
149 );
150 registry.register_with_tier(
151 Arc::new(native::GrepSearchTool::new(workspace_root.clone())),
152 Core,
153 );
154
155 registry.register_with_tier(
157 Arc::new(file::ListDirectoryTool::new(workspace_root.clone())),
158 Standard,
159 );
160 registry.register_with_tier(
161 Arc::new(edit::EditFileLinesTool::new(workspace_root.clone())),
162 Standard,
163 );
164 registry.register_with_tier(
165 Arc::new(edit::InsertAfterTool::new(workspace_root.clone())),
166 Standard,
167 );
168 registry.register_with_tier(
169 Arc::new(edit::AppendFileTool::new(workspace_root.clone())),
170 Standard,
171 );
172 registry.register_with_tier(
173 Arc::new(git::GitStatusTool::new(workspace_root.clone())),
174 Standard,
175 );
176 registry.register_with_tier(
177 Arc::new(git::GitDiffTool::new(workspace_root.clone())),
178 Standard,
179 );
180 registry.register_with_tier(
181 Arc::new(git::GitAddTool::new(workspace_root.clone())),
182 Standard,
183 );
184 registry.register_with_tier(
185 Arc::new(git::GitCommitTool::new(workspace_root.clone())),
186 Standard,
187 );
188 registry.register_with_tier(
189 Arc::new(git::GitLogTool::new(workspace_root.clone())),
190 Standard,
191 );
192 registry.register_with_tier(
193 Arc::new(git::GitBlameTool::new(workspace_root.clone())),
194 Standard,
195 );
196 registry.register_with_tier(
197 Arc::new(git::GitBranchTool::new(workspace_root.clone())),
198 Standard,
199 );
200 registry.register_with_tier(
201 Arc::new(git::GitCheckoutTool::new(workspace_root.clone())),
202 Standard,
203 );
204 registry.register_with_tier(
205 Arc::new(git::GitStashTool::new(workspace_root.clone())),
206 Standard,
207 );
208 registry.register_with_tier(
209 Arc::new(agent::SpawnAgentsTool::new(workspace_root.clone())),
210 Standard,
211 );
212 registry.register_with_tier(
213 Arc::new(agent::SpawnAgentTool::new(workspace_root.clone())),
214 Standard,
215 );
216 registry.register_with_tier(
217 Arc::new(batch::BatchTool::new(workspace_root.clone())),
218 Standard,
219 );
220 registry.register_with_tier(
221 Arc::new(task::TaskTool::new(workspace_root.clone())),
222 Standard,
223 );
224
225 registry.register_with_tier(
227 Arc::new(native::RipgrepTool::new(workspace_root.clone())),
228 Extended,
229 );
230 registry.register_with_tier(
231 Arc::new(native::FdTool::new(workspace_root.clone())),
232 Extended,
233 );
234 registry.register_with_tier(
235 Arc::new(native::SdTool::new(workspace_root.clone())),
236 Extended,
237 );
238 registry.register_with_tier(
239 Arc::new(native::ErdTool::new(workspace_root.clone())),
240 Extended,
241 );
242 registry.register_with_tier(
243 Arc::new(native::MiseTool::new(workspace_root.clone())),
244 Extended,
245 );
246 registry.register_with_tier(
247 Arc::new(native::ZoxideTool::new(workspace_root.clone())),
248 Extended,
249 );
250 registry.register_with_tier(
251 Arc::new(native::LspTool::new(workspace_root.clone())),
252 Extended,
253 );
254
255 #[cfg(feature = "deagle")]
257 {
258 registry.register_with_tier(
259 Arc::new(deagle::DeagleSearchTool::new(workspace_root.clone())),
260 Extended,
261 );
262 registry.register_with_tier(
263 Arc::new(deagle::DeagleKeywordTool::new(workspace_root.clone())),
264 Extended,
265 );
266 registry.register_with_tier(
267 Arc::new(deagle::DeagleSgTool::new(workspace_root.clone())),
268 Extended,
269 );
270 registry.register_with_tier(
271 Arc::new(deagle::DeagleStatsTool::new(workspace_root.clone())),
272 Extended,
273 );
274 registry.register_with_tier(
275 Arc::new(deagle::DeagleMapTool::new(workspace_root)),
276 Extended,
277 );
278 }
279
280 registry
281 }
282
283 pub fn register(&mut self, tool: Arc<dyn Tool>) {
285 self.register_with_tier(tool, ToolTier::Standard);
286 }
287
288 pub fn register_with_tier(&mut self, tool: Arc<dyn Tool>, tier: ToolTier) {
290 let name = tool.name().to_string();
291 let cached_text = format!("{} {}", name, tool.description()).to_lowercase();
292 self.tool_text_cache.insert(name.clone(), cached_text);
293 self.tiers.insert(name.clone(), tier);
294 self.tools.insert(name, tool);
295 }
296
297 pub fn get(&self, name: &str) -> Option<&Arc<dyn Tool>> {
299 self.tools.get(name)
300 }
301
302 pub fn has_tool(&self, name: &str) -> bool {
304 self.tools.contains_key(name)
305 }
306
307 pub async fn execute(&self, name: &str, args: Value) -> crate::Result<Value> {
309 match self.tools.get(name) {
310 Some(tool) => tool.execute(args).await,
311 None => Err(crate::PawanError::NotFound(format!(
312 "Tool not found: {}",
313 name
314 ))),
315 }
316 }
317
318 pub fn get_definitions(&self) -> Vec<ToolDefinition> {
321 let activated = self.activated.lock().unwrap_or_else(|e| e.into_inner());
322 self.tools
323 .iter()
324 .filter(|(name, _)| {
325 match self
326 .tiers
327 .get(name.as_str())
328 .copied()
329 .unwrap_or(ToolTier::Standard)
330 {
331 ToolTier::Core | ToolTier::Standard => true,
332 ToolTier::Extended => activated.contains(name.as_str()),
333 }
334 })
335 .map(|(_, tool)| tool.thulp_definition())
336 .collect()
337 }
338
339 pub fn select_for_query(&self, query: &str, max_tools: usize) -> Vec<ToolDefinition> {
345 let query_lower = query.to_lowercase();
346 let query_words: Vec<&str> = query_lower.split_whitespace().collect();
347
348 let mut scored: Vec<(i32, String)> = Vec::new();
349
350 for name in self.tools.keys() {
351 let tier = self
352 .tiers
353 .get(name.as_str())
354 .copied()
355 .unwrap_or(ToolTier::Standard);
356
357 if tier == ToolTier::Core {
359 continue;
360 }
361
362 let tool_text = self
364 .tool_text_cache
365 .get(name.as_str())
366 .map(|s| s.as_str())
367 .unwrap_or("");
368 let mut score: i32 = 0;
369
370 for word in &query_words {
371 if word.len() < 3 {
372 continue;
373 } if tool_text.contains(word) {
375 score += 2;
376 }
377 }
378
379 let search_words = [
381 "search",
382 "find",
383 "web",
384 "query",
385 "look",
386 "google",
387 "bing",
388 "wikipedia",
389 ];
390 let git_words = [
391 "git", "commit", "branch", "diff", "status", "log", "stash", "checkout", "blame",
392 ];
393 let file_words = [
394 "file",
395 "read",
396 "write",
397 "edit",
398 "append",
399 "insert",
400 "directory",
401 "list",
402 ];
403 let code_words = [
404 "refactor", "rename", "replace", "ast", "lsp", "symbol", "function", "struct",
405 ];
406 let tool_words = [
407 "install", "mise", "tool", "runtime", "build", "test", "cargo",
408 ];
409
410 for word in &query_words {
411 if search_words.contains(word) && tool_text.contains("search") {
412 score += 3;
413 }
414 if git_words.contains(word) && tool_text.contains("git") {
415 score += 3;
416 }
417 if file_words.contains(word)
418 && (tool_text.contains("file") || tool_text.contains("edit"))
419 {
420 score += 3;
421 }
422 if code_words.contains(word)
423 && (tool_text.contains("ast") || tool_text.contains("lsp"))
424 {
425 score += 3;
426 }
427 if tool_words.contains(word) && tool_text.contains("mise") {
428 score += 3;
429 }
430 }
431
432 if name.starts_with("mcp_") {
434 score += 1;
435 if name.contains("search") || name.contains("web") {
436 let web_words = [
437 "web", "search", "internet", "online", "find", "look up", "google",
438 ];
439 if web_words.iter().any(|w| query_lower.contains(w)) {
440 score += 10; }
442 }
443 }
444
445 let activated = self.activated.lock().unwrap_or_else(|e| e.into_inner());
447 if tier == ToolTier::Extended && activated.contains(name.as_str()) {
448 score += 2;
449 }
450
451 if score > 0 || tier == ToolTier::Standard {
452 scored.push((score, name.clone()));
453 }
454 }
455
456 scored.sort_by_key(|&(score, _)| std::cmp::Reverse(score));
458
459 let mut result: Vec<ToolDefinition> = self
461 .tools
462 .iter()
463 .filter(|(name, _)| {
464 self.tiers
465 .get(name.as_str())
466 .copied()
467 .unwrap_or(ToolTier::Standard)
468 == ToolTier::Core
469 })
470 .map(|(_, tool)| tool.thulp_definition())
471 .collect();
472
473 let remaining_slots = max_tools.saturating_sub(result.len());
474 for (_, name) in scored.into_iter().take(remaining_slots) {
475 if let Some(tool) = self.tools.get(&name) {
476 result.push(tool.thulp_definition());
477 }
478 }
479
480 result
481 }
482
483 pub fn get_all_definitions(&self) -> Vec<ToolDefinition> {
485 self.tools.values().map(|t| t.thulp_definition()).collect()
486 }
487
488 pub fn activate(&self, name: &str) {
490 if self.tools.contains_key(name) {
491 self.activated
492 .lock()
493 .unwrap_or_else(|e| e.into_inner())
494 .insert(name.to_string());
495 }
496 }
497
498 pub fn tool_names(&self) -> Vec<&str> {
500 self.tools.keys().map(|s| s.as_str()).collect()
501 }
502
503 pub fn query_tools(&self, query: &str) -> Vec<thulp_core::ToolDefinition> {
520 let criteria = match thulp_query::parse_query(query) {
521 Ok(c) => c,
522 Err(e) => {
523 tracing::warn!(query = %query, error = %e, "failed to parse tool query");
524 return Vec::new();
525 }
526 };
527
528 self.tools
529 .values()
530 .map(|tool| tool.thulp_definition())
531 .filter(|def| criteria.matches(def))
532 .collect()
533 }
534}
535
536impl Default for ToolRegistry {
537 fn default() -> Self {
538 Self::new()
539 }
540}
541
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Registry holding the default tool set, rooted at a throwaway path.
    fn default_registry() -> ToolRegistry {
        ToolRegistry::with_defaults(PathBuf::from("/tmp/test"))
    }

    /// Names carried by `defs`, in the order given.
    fn names_of(defs: &[ToolDefinition]) -> Vec<String> {
        defs.iter().map(|def| def.name.clone()).collect()
    }

    /// Names of the tools currently returned by `get_definitions`.
    fn visible_names(registry: &ToolRegistry) -> Vec<String> {
        names_of(&registry.get_definitions())
    }

    /// Whether `names` contains exactly `wanted`.
    fn has_name(names: &[String], wanted: &str) -> bool {
        names.iter().any(|n| n == wanted)
    }

    #[test]
    fn test_registry_new_is_empty() {
        let empty = ToolRegistry::new();
        assert!(empty.tool_names().is_empty());
        assert!(!empty.has_tool("bash"));
        assert!(empty.get("nonexistent").is_none());
    }

    #[test]
    fn test_registry_with_defaults_contains_core_tools() {
        let registry = default_registry();
        let expected_core = [
            "bash",
            "read_file",
            "write_file",
            "edit_file",
            "grep_search",
            "glob_search",
        ];
        for name in expected_core.iter() {
            assert!(
                registry.has_tool(name),
                "default registry missing core tool: {}",
                name
            );
        }
        // Standard tier samples.
        assert!(registry.has_tool("git_status"));
        assert!(registry.has_tool("git_commit"));
        // Extended tier samples.
        assert!(registry.has_tool("rg"));
        assert!(registry.has_tool("fd"));
    }

    #[test]
    fn test_registry_get_definitions_hides_extended_until_activated() {
        let registry = default_registry();
        let initial = visible_names(&registry);

        assert!(
            !has_name(&initial, "rg"),
            "rg should be hidden until activated"
        );
        assert!(
            !has_name(&initial, "fd"),
            "fd should be hidden until activated"
        );
        assert!(has_name(&initial, "bash"));
        assert!(has_name(&initial, "read_file"));

        registry.activate("rg");
        let after = visible_names(&registry);
        assert!(
            has_name(&after, "rg"),
            "rg should be visible after activate"
        );
        assert!(
            after.len() > initial.len(),
            "activation should grow visible set"
        );
    }

    #[test]
    fn test_registry_get_all_definitions_returns_everything() {
        let registry = default_registry();
        let all = registry.get_all_definitions();
        let visible = registry.get_definitions();
        assert!(
            all.len() > visible.len(),
            "get_all_definitions ({}) should include hidden extended tools beyond get_definitions ({})",
            all.len(),
            visible.len()
        );
        let all_names = names_of(&all);
        assert!(has_name(&all_names, "rg"));
    }

    #[test]
    fn test_registry_query_tools_filters_by_dsl() {
        let registry = default_registry();

        let bash_match = registry.query_tools("name:bash");
        assert!(
            !bash_match.is_empty(),
            "query_tools('name:bash') should match the bash tool"
        );
        let names = names_of(&bash_match);
        assert!(has_name(&names, "bash"));

        let no_match = registry.query_tools("name:definitely_not_a_tool_xyz");
        assert!(
            no_match.is_empty(),
            "query_tools for nonexistent name should return empty, got {:?}",
            no_match.iter().map(|d| &d.name).collect::<Vec<_>>()
        );
    }

    /// Minimal `Tool` that reports a fixed name/description and returns a
    /// canned value from `execute`.
    struct MockTool {
        name: String,
        description: String,
        return_value: Value,
    }

    impl MockTool {
        fn new(name: &str, description: &str, return_value: Value) -> Self {
            let name = name.to_owned();
            let description = description.to_owned();
            Self {
                name,
                description,
                return_value,
            }
        }
    }

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            self.name.as_str()
        }
        fn description(&self) -> &str {
            self.description.as_str()
        }
        fn parameters_schema(&self) -> Value {
            serde_json::json!({ "type": "object", "properties": {} })
        }
        async fn execute(&self, _args: Value) -> crate::Result<Value> {
            let canned = self.return_value.clone();
            Ok(canned)
        }
    }

    #[test]
    fn test_register_defaults_to_standard_tier() {
        let mut registry = ToolRegistry::new();
        let mock = MockTool::new("mock_std", "a test mock", Value::Null);
        registry.register(Arc::new(mock));

        let visible = visible_names(&registry);
        assert!(
            has_name(&visible, "mock_std"),
            "register() should default to Standard tier (visible without activation), got {:?}",
            visible
        );
    }

    #[test]
    fn test_register_with_tier_overwrites_same_name() {
        let mut registry = ToolRegistry::new();
        let first = MockTool::new("dup", "first registration", Value::Null);
        let second = MockTool::new("dup", "second registration", Value::Null);
        registry.register_with_tier(Arc::new(first), ToolTier::Standard);
        registry.register_with_tier(Arc::new(second), ToolTier::Core);

        let names = registry.tool_names();
        let dup_count = names.iter().filter(|n| **n == "dup").count();
        assert_eq!(
            dup_count,
            1,
            "register_with_tier of an existing name must replace, not duplicate"
        );

        let def = registry
            .get("dup")
            .expect("dup should exist after overwrite");
        assert_eq!(def.description(), "second registration");

        let visible = visible_names(&registry);
        assert!(has_name(&visible, "dup"));
    }

    #[tokio::test]
    async fn test_execute_dispatches_to_registered_tool() {
        let mut registry = ToolRegistry::new();
        let echo = MockTool::new(
            "echo",
            "returns a fixed value",
            serde_json::json!({ "answer": 42 }),
        );
        registry.register(Arc::new(echo));

        let out = registry
            .execute("echo", Value::Null)
            .await
            .expect("execute on a registered tool should succeed");
        assert_eq!(out, serde_json::json!({ "answer": 42 }));
    }

    #[tokio::test]
    async fn test_execute_unknown_tool_returns_not_found() {
        let registry = ToolRegistry::new();
        let outcome = registry.execute("nonexistent_tool", Value::Null).await;
        let err = outcome.expect_err("execute on missing tool should fail");
        match err {
            crate::PawanError::NotFound(msg) => {
                assert!(
                    msg.contains("nonexistent_tool"),
                    "error should name the missing tool, got: {}",
                    msg
                );
            }
            other => panic!("expected PawanError::NotFound, got: {:?}", other),
        }
    }

    #[test]
    fn test_select_for_query_always_includes_core_tools() {
        let registry = default_registry();
        // Nonsense query with a cap below the number of core tools.
        let selected = registry.select_for_query("xyzzy plover", 5);
        let names = names_of(&selected);
        let expected_core = [
            "bash",
            "read_file",
            "write_file",
            "edit_file",
            "grep_search",
            "glob_search",
            "ast_grep",
        ];
        for core in expected_core.iter() {
            assert!(
                names.contains(&core.to_string()),
                "select_for_query must include core tool {} regardless of query, got {:?}",
                core,
                names
            );
        }
    }

    #[test]
    fn test_select_for_query_caps_at_max_tools_when_possible() {
        let registry = default_registry();
        let selected = registry.select_for_query("git commit my changes", 10);
        assert!(
            selected.len() <= 10,
            "select_for_query(max=10) returned {} tools, must not exceed cap",
            selected.len()
        );
        let names = names_of(&selected);
        assert!(
            names.iter().any(|n| n.starts_with("git_")),
            "git query should pull in at least one git_ tool, got {:?}",
            names
        );
    }

    #[test]
    fn test_activate_no_op_for_unknown_tool_does_not_panic() {
        let registry = default_registry();
        registry.activate("not_a_real_tool_at_all");
        let visible = visible_names(&registry);
        assert!(
            !has_name(&visible, "not_a_real_tool_at_all"),
            "activate of unknown tool must not make it visible"
        );
    }

    #[test]
    fn test_tool_names_lists_every_registered_tool() {
        let registry = default_registry();
        let names = registry.tool_names();
        assert!(
            names.len() >= 30,
            "default registry should expose >=30 tools via tool_names(), got {}",
            names.len()
        );
        // Every listed name must round-trip through the lookup APIs.
        for name in names.iter() {
            assert!(registry.has_tool(name));
            assert!(registry.get(name).is_some());
        }
    }

    #[test]
    fn test_default_impl_returns_empty_registry() {
        let registry = ToolRegistry::default();
        assert!(registry.tool_names().is_empty());
        assert!(registry.get_definitions().is_empty());
    }
}