1use crate::agent_tool::Tool;
11use crate::registry::ToolRegistry;
12
/// Chooses which tools are exposed to the agent for a given query.
///
/// Tools whose `is_system()` returns `true` are always exposed. In addition, at most
/// `max_visible` non-system tools are exposed, ranked by how well they match the query.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolFilter {
    /// Maximum number of non-system tools returned by `ToolFilter::select`.
    /// System tools are not counted against this limit.
    pub max_visible: usize,
}
18
19impl ToolFilter {
20 pub fn new(max_visible: usize) -> Self {
21 Self { max_visible }
22 }
23
24 pub fn select<'a>(&self, query: &str, registry: &'a ToolRegistry) -> Vec<&'a dyn Tool> {
26 let query_lower = query.to_lowercase();
27 let query_words: Vec<&str> = query_lower.split_whitespace().collect();
28
29 let mut system_tools = Vec::new();
30 let mut scored: Vec<(&dyn Tool, f64)> = Vec::new();
31
32 for tool in registry.list() {
33 if tool.is_system() {
34 system_tools.push(tool);
35 continue;
36 }
37
38 let score = score_tool(tool, &query_lower, &query_words);
39 scored.push((tool, score));
40 }
41
42 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
44
45 let mut result = system_tools;
47 for (tool, _score) in scored.into_iter().take(self.max_visible) {
48 result.push(tool);
49 }
50
51 result
52 }
53}
54
55impl Default for ToolFilter {
56 fn default() -> Self {
57 Self { max_visible: 10 }
58 }
59}
60
61fn score_tool(tool: &dyn Tool, query_lower: &str, query_words: &[&str]) -> f64 {
63 let name = tool.name().to_lowercase();
64 let desc = tool.description().to_lowercase();
65 let combined = format!("{} {}", name, desc);
66 let tool_words: Vec<&str> = combined.split_whitespace().collect();
67
68 let mut score = 0.0;
69
70 if query_lower.contains(&name) {
72 score += 5.0;
73 }
74
75 for qw in query_words {
77 for tw in &tool_words {
78 if qw == tw {
79 score += 2.0;
80 } else {
81 let sim = strsim::normalized_levenshtein(qw, tw);
82 if sim > 0.7 {
83 score += sim;
84 }
85 }
86 }
87 }
88
89 for qw in query_words {
91 if name.contains(qw) {
92 score += 1.5;
93 }
94 }
95
96 score
97}
98
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent_tool::{ToolError, ToolOutput};
    use crate::context::AgentContext;
    use serde_json::Value;

    /// Minimal `Tool` implementation whose name, description and system flag are fixed.
    struct TestTool {
        tool_name: &'static str,
        desc: &'static str,
        system: bool,
    }

    #[async_trait::async_trait]
    impl Tool for TestTool {
        fn name(&self) -> &str {
            self.tool_name
        }
        fn description(&self) -> &str {
            self.desc
        }
        fn is_system(&self) -> bool {
            self.system
        }
        fn parameters_schema(&self) -> Value {
            serde_json::json!({"type": "object"})
        }
        async fn execute(&self, _: Value, _: &mut AgentContext) -> Result<ToolOutput, ToolError> {
            Ok(ToolOutput::text("ok"))
        }
    }

    #[test]
    fn system_tools_always_included() {
        let reg = ToolRegistry::new()
            .register(TestTool {
                tool_name: "finish_task",
                desc: "finish",
                system: true,
            })
            .register(TestTool {
                tool_name: "read_file",
                desc: "read a file from disk",
                system: false,
            })
            .register(TestTool {
                tool_name: "bash",
                desc: "run shell command",
                system: false,
            });

        let filter = ToolFilter::new(1);
        let selected = filter.select("read the file", &reg);

        // The system tool is present even though it is irrelevant to the query...
        assert!(selected.iter().any(|t| t.name() == "finish_task"));
        // ...and it does not consume the single non-system slot.
        let non_sys: Vec<_> = selected.iter().filter(|t| !t.is_system()).collect();
        assert_eq!(non_sys.len(), 1);
    }

    #[test]
    fn relevant_tool_ranked_higher() {
        let reg = ToolRegistry::new()
            .register(TestTool {
                tool_name: "read_file",
                desc: "read a file from disk",
                system: false,
            })
            .register(TestTool {
                tool_name: "bash",
                desc: "run shell command",
                system: false,
            })
            .register(TestTool {
                tool_name: "write_file",
                desc: "write content to a file",
                system: false,
            });

        let filter = ToolFilter::new(2);
        let selected = filter.select("read the file main.rs", &reg);
        // No system tools are registered, so index 0 is the best-scoring tool.
        assert_eq!(selected[0].name(), "read_file");
    }

    #[test]
    fn empty_query_returns_all_up_to_max() {
        let reg = ToolRegistry::new()
            .register(TestTool {
                tool_name: "a",
                desc: "tool a",
                system: false,
            })
            .register(TestTool {
                tool_name: "b",
                desc: "tool b",
                system: false,
            })
            .register(TestTool {
                tool_name: "c",
                desc: "tool c",
                system: false,
            });

        let filter = ToolFilter::new(2);
        let selected = filter.select("", &reg);
        assert_eq!(selected.len(), 2);
    }
}