1pub mod agent;
10pub mod bash;
11pub mod deagle;
12pub mod edit;
13#[cfg(test)]
14mod edit_tests;
15pub mod file;
16pub mod git;
17pub mod native;
18pub mod search;
19pub mod ares_bridge;
20
21use async_trait::async_trait;
22use serde_json::Value;
23use std::collections::HashMap;
24use std::sync::Arc;
25
26pub use thulp_core::ToolDefinition;
33
34#[async_trait]
36pub trait Tool: Send + Sync {
37 fn name(&self) -> &str;
39
40 fn description(&self) -> &str;
42
43 fn mutating(&self) -> bool {
48 false }
50
51 fn parameters_schema(&self) -> Value;
53
54 async fn execute(&self, args: Value) -> crate::Result<Value>;
56 fn thulp_definition(&self) -> thulp_core::ToolDefinition {
59 let params = thulp_core::ToolDefinition::parse_mcp_input_schema(&self.parameters_schema())
60 .unwrap_or_default();
61 thulp_core::ToolDefinition::builder(self.name())
62 .description(self.description())
63 .parameters(params)
64 .build()
65 }
66
67 fn validate_args(&self, args: &Value) -> std::result::Result<(), String> {
70 self.thulp_definition()
71 .validate_args(args)
72 .map_err(|e| e.to_string())
73 }
74
75 fn to_definition(&self) -> ToolDefinition {
81 self.thulp_definition()
82 }
83}
84
/// Visibility tier that controls when a tool's definition is offered to the model.
///
/// - `Core`: always returned by `ToolRegistry::get_definitions` and always included by
///   `ToolRegistry::select_for_query`, whatever the query.
/// - `Standard`: returned by `get_definitions`; in `select_for_query` it competes for
///   the remaining slots even with a zero relevance score.
/// - `Extended`: hidden from `get_definitions` until `ToolRegistry::activate` is
///   called for it; in `select_for_query` it needs a positive score.
///
/// `ToolRegistry::register` uses `Standard`, and unknown names are treated as
/// `Standard`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ToolTier {
    /// Always offered.
    Core,
    /// Offered by default.
    Standard,
    /// Offered only once activated, or when relevant to a query.
    Extended,
}
97
98pub struct ToolRegistry {
103 tools: HashMap<String, Arc<dyn Tool>>,
104 tiers: HashMap<String, ToolTier>,
105 activated: std::sync::Mutex<std::collections::HashSet<String>>,
107 tool_text_cache: HashMap<String, String>,
109}
110
111impl ToolRegistry {
112 pub fn new() -> Self {
114 Self {
115 tools: HashMap::new(),
116 tiers: HashMap::new(),
117 activated: std::sync::Mutex::new(std::collections::HashSet::new()),
118 tool_text_cache: HashMap::new(),
119 }
120 }
121
122 pub fn with_defaults(workspace_root: std::path::PathBuf) -> Self {
128 let mut registry = Self::new();
129 use ToolTier::*;
130
131 registry.register_with_tier(Arc::new(bash::BashTool::new(workspace_root.clone())), Core);
133 registry.register_with_tier(Arc::new(file::ReadFileTool::new(workspace_root.clone())), Core);
134 registry.register_with_tier(Arc::new(file::WriteFileTool::new(workspace_root.clone())), Core);
135 registry.register_with_tier(Arc::new(edit::EditFileTool::new(workspace_root.clone())), Core);
136 registry.register_with_tier(Arc::new(native::AstGrepTool::new(workspace_root.clone())), Core);
137 registry.register_with_tier(Arc::new(native::GlobSearchTool::new(workspace_root.clone())), Core);
138 registry.register_with_tier(Arc::new(native::GrepSearchTool::new(workspace_root.clone())), Core);
139
140 registry.register_with_tier(Arc::new(file::ListDirectoryTool::new(workspace_root.clone())), Standard);
142 registry.register_with_tier(Arc::new(edit::EditFileLinesTool::new(workspace_root.clone())), Standard);
143 registry.register_with_tier(Arc::new(edit::InsertAfterTool::new(workspace_root.clone())), Standard);
144 registry.register_with_tier(Arc::new(edit::AppendFileTool::new(workspace_root.clone())), Standard);
145 registry.register_with_tier(Arc::new(git::GitStatusTool::new(workspace_root.clone())), Standard);
146 registry.register_with_tier(Arc::new(git::GitDiffTool::new(workspace_root.clone())), Standard);
147 registry.register_with_tier(Arc::new(git::GitAddTool::new(workspace_root.clone())), Standard);
148 registry.register_with_tier(Arc::new(git::GitCommitTool::new(workspace_root.clone())), Standard);
149 registry.register_with_tier(Arc::new(git::GitLogTool::new(workspace_root.clone())), Standard);
150 registry.register_with_tier(Arc::new(git::GitBlameTool::new(workspace_root.clone())), Standard);
151 registry.register_with_tier(Arc::new(git::GitBranchTool::new(workspace_root.clone())), Standard);
152 registry.register_with_tier(Arc::new(git::GitCheckoutTool::new(workspace_root.clone())), Standard);
153 registry.register_with_tier(Arc::new(git::GitStashTool::new(workspace_root.clone())), Standard);
154 registry.register_with_tier(Arc::new(agent::SpawnAgentsTool::new(workspace_root.clone())), Standard);
155 registry.register_with_tier(Arc::new(agent::SpawnAgentTool::new(workspace_root.clone())), Standard);
156
157 registry.register_with_tier(Arc::new(native::RipgrepTool::new(workspace_root.clone())), Extended);
159 registry.register_with_tier(Arc::new(native::FdTool::new(workspace_root.clone())), Extended);
160 registry.register_with_tier(Arc::new(native::SdTool::new(workspace_root.clone())), Extended);
161 registry.register_with_tier(Arc::new(native::ErdTool::new(workspace_root.clone())), Extended);
162 registry.register_with_tier(Arc::new(native::MiseTool::new(workspace_root.clone())), Extended);
163 registry.register_with_tier(Arc::new(native::ZoxideTool::new(workspace_root.clone())), Extended);
164 registry.register_with_tier(Arc::new(native::LspTool::new(workspace_root.clone())), Extended);
165
166 registry.register_with_tier(Arc::new(deagle::DeagleSearchTool::new(workspace_root.clone())), Extended);
168 registry.register_with_tier(Arc::new(deagle::DeagleKeywordTool::new(workspace_root.clone())), Extended);
169 registry.register_with_tier(Arc::new(deagle::DeagleSgTool::new(workspace_root.clone())), Extended);
170 registry.register_with_tier(Arc::new(deagle::DeagleStatsTool::new(workspace_root.clone())), Extended);
171 registry.register_with_tier(Arc::new(deagle::DeagleMapTool::new(workspace_root)), Extended);
172
173 registry
174 }
175
176 pub fn register(&mut self, tool: Arc<dyn Tool>) {
178 self.register_with_tier(tool, ToolTier::Standard);
179 }
180
181 pub fn register_with_tier(&mut self, tool: Arc<dyn Tool>, tier: ToolTier) {
183 let name = tool.name().to_string();
184 let cached_text = format!("{} {}", name, tool.description()).to_lowercase();
185 self.tool_text_cache.insert(name.clone(), cached_text);
186 self.tiers.insert(name.clone(), tier);
187 self.tools.insert(name, tool);
188 }
189
190 pub fn get(&self, name: &str) -> Option<&Arc<dyn Tool>> {
192 self.tools.get(name)
193 }
194
195 pub fn has_tool(&self, name: &str) -> bool {
197 self.tools.contains_key(name)
198 }
199
200 pub async fn execute(&self, name: &str, args: Value) -> crate::Result<Value> {
202 match self.tools.get(name) {
203 Some(tool) => tool.execute(args).await,
204 None => Err(crate::PawanError::NotFound(format!(
205 "Tool not found: {}",
206 name
207 ))),
208 }
209 }
210
211 pub fn get_definitions(&self) -> Vec<ToolDefinition> {
214 let activated = self.activated.lock().unwrap_or_else(|e| e.into_inner());
215 self.tools.iter()
216 .filter(|(name, _)| {
217 match self.tiers.get(name.as_str()).copied().unwrap_or(ToolTier::Standard) {
218 ToolTier::Core | ToolTier::Standard => true,
219 ToolTier::Extended => activated.contains(name.as_str()),
220 }
221 })
222 .map(|(_, tool)| tool.to_definition())
223 .collect()
224 }
225
226 pub fn select_for_query(&self, query: &str, max_tools: usize) -> Vec<ToolDefinition> {
232 let query_lower = query.to_lowercase();
233 let query_words: Vec<&str> = query_lower.split_whitespace().collect();
234
235 let mut scored: Vec<(i32, String)> = Vec::new();
236
237 for name in self.tools.keys() {
238 let tier = self.tiers.get(name.as_str()).copied().unwrap_or(ToolTier::Standard);
239
240 if tier == ToolTier::Core { continue; }
242
243 let tool_text = self.tool_text_cache.get(name.as_str())
245 .map(|s| s.as_str())
246 .unwrap_or("");
247 let mut score: i32 = 0;
248
249 for word in &query_words {
250 if word.len() < 3 { continue; } if tool_text.contains(word) { score += 2; }
252 }
253
254 let search_words = ["search", "find", "web", "query", "look", "google", "bing", "wikipedia"];
256 let git_words = ["git", "commit", "branch", "diff", "status", "log", "stash", "checkout", "blame"];
257 let file_words = ["file", "read", "write", "edit", "append", "insert", "directory", "list"];
258 let code_words = ["refactor", "rename", "replace", "ast", "lsp", "symbol", "function", "struct"];
259 let tool_words = ["install", "mise", "tool", "runtime", "build", "test", "cargo"];
260
261 for word in &query_words {
262 if search_words.contains(word) && tool_text.contains("search") { score += 3; }
263 if git_words.contains(word) && tool_text.contains("git") { score += 3; }
264 if file_words.contains(word) && (tool_text.contains("file") || tool_text.contains("edit")) { score += 3; }
265 if code_words.contains(word) && (tool_text.contains("ast") || tool_text.contains("lsp")) { score += 3; }
266 if tool_words.contains(word) && tool_text.contains("mise") { score += 3; }
267 }
268
269 if name.starts_with("mcp_") {
271 score += 1;
272 if name.contains("search") || name.contains("web") {
273 let web_words = ["web", "search", "internet", "online", "find", "look up", "google"];
274 if web_words.iter().any(|w| query_lower.contains(w)) {
275 score += 10; }
277 }
278 }
279
280 let activated = self.activated.lock().unwrap_or_else(|e| e.into_inner());
282 if tier == ToolTier::Extended && activated.contains(name.as_str()) { score += 2; }
283
284 if score > 0 || tier == ToolTier::Standard {
285 scored.push((score, name.clone()));
286 }
287 }
288
289 scored.sort_by(|a, b| b.0.cmp(&a.0));
291
292 let mut result: Vec<ToolDefinition> = self.tools.iter()
294 .filter(|(name, _)| {
295 self.tiers.get(name.as_str()).copied().unwrap_or(ToolTier::Standard) == ToolTier::Core
296 })
297 .map(|(_, tool)| tool.to_definition())
298 .collect();
299
300 let remaining_slots = max_tools.saturating_sub(result.len());
301 for (_, name) in scored.into_iter().take(remaining_slots) {
302 if let Some(tool) = self.tools.get(&name) {
303 result.push(tool.to_definition());
304 }
305 }
306
307 result
308 }
309
310 pub fn get_all_definitions(&self) -> Vec<ToolDefinition> {
312 self.tools.values().map(|t| t.to_definition()).collect()
313 }
314
315 pub fn activate(&self, name: &str) {
317 if self.tools.contains_key(name) {
318 self.activated.lock().unwrap_or_else(|e| e.into_inner()).insert(name.to_string());
319 }
320 }
321
322 pub fn tool_names(&self) -> Vec<&str> {
324 self.tools.keys().map(|s| s.as_str()).collect()
325 }
326
327 pub fn query_tools(&self, query: &str) -> Vec<thulp_core::ToolDefinition> {
344 let criteria = match thulp_query::parse_query(query) {
345 Ok(c) => c,
346 Err(e) => {
347 tracing::warn!(query = %query, error = %e, "failed to parse tool query");
348 return Vec::new();
349 }
350 };
351
352 self.tools
353 .values()
354 .map(|tool| tool.thulp_definition())
355 .filter(|def| criteria.matches(def))
356 .collect()
357 }
358}
359
360impl Default for ToolRegistry {
361 fn default() -> Self {
362 Self::new()
363 }
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Registry holding every built-in tool, rooted at a dummy workspace path.
    fn default_registry() -> ToolRegistry {
        ToolRegistry::with_defaults(PathBuf::from("/tmp/test"))
    }

    /// Names of `defs`, in the order given.
    fn names_of(defs: &[ToolDefinition]) -> Vec<String> {
        defs.iter().map(|d| d.name.clone()).collect()
    }

    /// Whether `wanted` appears among `names`.
    fn has_name(names: &[String], wanted: &str) -> bool {
        names.iter().any(|n| n == wanted)
    }

    #[test]
    fn test_registry_new_is_empty() {
        let empty = ToolRegistry::new();
        assert!(empty.tool_names().is_empty());
        assert!(!empty.has_tool("bash"));
        assert!(empty.get("nonexistent").is_none());
    }

    #[test]
    fn test_registry_with_defaults_contains_core_tools() {
        let registry = default_registry();
        let core = ["bash", "read_file", "write_file", "edit_file", "grep_search", "glob_search"];
        for name in core.iter() {
            assert!(
                registry.has_tool(name),
                "default registry missing core tool: {}",
                name
            );
        }
        // A sample of Standard (git) and Extended (rg / fd) tools.
        for name in ["git_status", "git_commit", "rg", "fd"].iter() {
            assert!(registry.has_tool(name));
        }
    }

    #[test]
    fn test_registry_get_definitions_hides_extended_until_activated() {
        let registry = default_registry();
        let before = names_of(&registry.get_definitions());

        assert!(!has_name(&before, "rg"), "rg should be hidden until activated");
        assert!(!has_name(&before, "fd"), "fd should be hidden until activated");
        assert!(has_name(&before, "bash"));
        assert!(has_name(&before, "read_file"));

        registry.activate("rg");
        let after = names_of(&registry.get_definitions());
        assert!(has_name(&after, "rg"), "rg should be visible after activate");
        assert!(after.len() > before.len(), "activation should grow visible set");
    }

    #[test]
    fn test_registry_get_all_definitions_returns_everything() {
        let registry = default_registry();
        let all = registry.get_all_definitions();
        let visible = registry.get_definitions();
        assert!(
            all.len() > visible.len(),
            "get_all_definitions ({}) should include hidden extended tools beyond get_definitions ({})",
            all.len(),
            visible.len()
        );
        assert!(has_name(&names_of(&all), "rg"));
    }

    #[test]
    fn test_registry_query_tools_filters_by_dsl() {
        let registry = default_registry();

        let bash_match = registry.query_tools("name:bash");
        assert!(
            !bash_match.is_empty(),
            "query_tools('name:bash') should match the bash tool"
        );
        assert!(has_name(&names_of(&bash_match), "bash"));

        let no_match = registry.query_tools("name:definitely_not_a_tool_xyz");
        assert!(
            no_match.is_empty(),
            "query_tools for nonexistent name should return empty, got {:?}",
            no_match.iter().map(|d| &d.name).collect::<Vec<_>>()
        );
    }

    /// Minimal tool whose `execute` always returns a fixed value.
    struct MockTool {
        name: String,
        description: String,
        return_value: Value,
    }

    impl MockTool {
        fn new(name: &str, description: &str, return_value: Value) -> Self {
            MockTool {
                name: name.to_owned(),
                description: description.to_owned(),
                return_value,
            }
        }
    }

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            &self.description
        }

        fn parameters_schema(&self) -> Value {
            serde_json::json!({ "type": "object", "properties": {} })
        }

        async fn execute(&self, _args: Value) -> crate::Result<Value> {
            Ok(self.return_value.clone())
        }
    }

    #[test]
    fn test_register_defaults_to_standard_tier() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(MockTool::new("mock_std", "a test mock", Value::Null)));

        let visible = names_of(&registry.get_definitions());
        assert!(
            has_name(&visible, "mock_std"),
            "register() should default to Standard tier (visible without activation), got {:?}",
            visible
        );
    }

    #[test]
    fn test_register_with_tier_overwrites_same_name() {
        let mut registry = ToolRegistry::new();
        let first = MockTool::new("dup", "first registration", Value::Null);
        let second = MockTool::new("dup", "second registration", Value::Null);
        registry.register_with_tier(Arc::new(first), ToolTier::Standard);
        registry.register_with_tier(Arc::new(second), ToolTier::Core);

        let dup_count = registry
            .tool_names()
            .into_iter()
            .filter(|n| *n == "dup")
            .count();
        assert_eq!(
            dup_count,
            1,
            "register_with_tier of an existing name must replace, not duplicate"
        );

        let tool = registry.get("dup").expect("dup should exist after overwrite");
        assert_eq!(tool.description(), "second registration");
        assert!(has_name(&names_of(&registry.get_definitions()), "dup"));
    }

    #[tokio::test]
    async fn test_execute_dispatches_to_registered_tool() {
        let mut registry = ToolRegistry::new();
        let answer = serde_json::json!({ "answer": 42 });
        registry.register(Arc::new(MockTool::new("echo", "returns a fixed value", answer)));

        let out = registry
            .execute("echo", Value::Null)
            .await
            .expect("execute on a registered tool should succeed");
        assert_eq!(out, serde_json::json!({ "answer": 42 }));
    }

    #[tokio::test]
    async fn test_execute_unknown_tool_returns_not_found() {
        let registry = ToolRegistry::new();
        let err = registry
            .execute("nonexistent_tool", Value::Null)
            .await
            .expect_err("execute on missing tool should fail");
        match err {
            crate::PawanError::NotFound(msg) => assert!(
                msg.contains("nonexistent_tool"),
                "error should name the missing tool, got: {}",
                msg
            ),
            other => panic!("expected PawanError::NotFound, got: {:?}", other),
        }
    }

    #[test]
    fn test_select_for_query_always_includes_core_tools() {
        let registry = default_registry();
        let names = names_of(&registry.select_for_query("xyzzy plover", 5));
        let core = [
            "bash",
            "read_file",
            "write_file",
            "edit_file",
            "grep_search",
            "glob_search",
            "ast_grep",
        ];
        for tool in core.iter() {
            assert!(
                has_name(&names, tool),
                "select_for_query must include core tool {} regardless of query, got {:?}",
                tool,
                names
            );
        }
    }

    #[test]
    fn test_select_for_query_caps_at_max_tools_when_possible() {
        let registry = default_registry();
        let selected = registry.select_for_query("git commit my changes", 10);
        assert!(
            selected.len() <= 10,
            "select_for_query(max=10) returned {} tools, must not exceed cap",
            selected.len()
        );

        let names = names_of(&selected);
        assert!(
            names.iter().any(|n| n.starts_with("git_")),
            "git query should pull in at least one git_ tool, got {:?}",
            names
        );
    }

    #[test]
    fn test_activate_no_op_for_unknown_tool_does_not_panic() {
        let registry = default_registry();
        registry.activate("not_a_real_tool_at_all");

        let visible = names_of(&registry.get_definitions());
        assert!(
            !has_name(&visible, "not_a_real_tool_at_all"),
            "activate of unknown tool must not make it visible"
        );
    }

    #[test]
    fn test_tool_names_lists_every_registered_tool() {
        let registry = default_registry();
        let names = registry.tool_names();
        assert!(
            names.len() >= 30,
            "default registry should expose >=30 tools via tool_names(), got {}",
            names.len()
        );
        for name in names.iter() {
            assert!(registry.has_tool(name));
            assert!(registry.get(name).is_some());
        }
    }

    #[test]
    fn test_default_impl_returns_empty_registry() {
        let registry = ToolRegistry::default();
        assert!(registry.tool_names().is_empty());
        assert!(registry.get_definitions().is_empty());
    }
}