1use std::{
14 collections::HashMap,
15 sync::{Arc, RwLock},
16};
17
18use crate::types::ToolDescriptor;
19
20pub trait ToolCapabilityIndex: Send + Sync {
22 fn upsert(&self, server: &str, tools: Vec<ToolDescriptor>);
25
26 fn remove(&self, server: &str) -> usize;
29
30 fn find(&self, verb_ns: &str, verb_action: &str) -> Vec<ToolDescriptor>;
33
34 fn snapshot(&self) -> Vec<ToolDescriptor>;
36
37 fn top_k(&self, query: &str, k: usize) -> Vec<ToolDescriptor> {
48 score_top_k(self.snapshot(), query, k)
49 }
50}
51
52pub fn score_top_k(mut tools: Vec<ToolDescriptor>, query: &str, k: usize) -> Vec<ToolDescriptor> {
61 if k == 0 {
62 return Vec::new();
63 }
64 let terms = tokenize(query);
65 tools.sort_by(|a, b| {
66 let sa = score_tool(a, &terms);
67 let sb = score_tool(b, &terms);
68 sb.cmp(&sa).then_with(|| a.name.cmp(&b.name))
69 });
70 tools.truncate(k);
71 tools
72}
73
/// Breaks `text` into lowercase terms.
///
/// Every non-alphanumeric character acts as a separator, and empty pieces
/// (produced by leading, trailing, or consecutive separators) are discarded.
fn tokenize(text: &str) -> Vec<String> {
    let mut terms = Vec::new();
    for piece in text.split(|ch: char| !ch.is_alphanumeric()) {
        if piece.is_empty() {
            continue;
        }
        terms.push(piece.to_lowercase());
    }
    terms
}
81
82fn score_tool(tool: &ToolDescriptor, terms: &[String]) -> usize {
85 if terms.is_empty() {
86 return 0;
87 }
88 let mut haystack = tool.name.to_lowercase();
89 if let Some(desc) = &tool.description {
90 haystack.push(' ');
91 haystack.push_str(&desc.to_lowercase());
92 }
93 terms
94 .iter()
95 .filter(|t| haystack.contains(t.as_str()))
96 .count()
97}
98
99#[derive(Default)]
103pub struct InMemoryToolCapabilityIndex {
104 by_server: RwLock<HashMap<String, Vec<ToolDescriptor>>>,
105}
106
107impl InMemoryToolCapabilityIndex {
108 pub fn new() -> Self {
109 Self::default()
110 }
111
112 pub fn shared() -> Arc<dyn ToolCapabilityIndex> {
113 Arc::new(Self::new())
114 }
115}
116
117impl ToolCapabilityIndex for InMemoryToolCapabilityIndex {
118 fn upsert(&self, server: &str, tools: Vec<ToolDescriptor>) {
119 let mut guard = self.by_server.write().expect("capability index poisoned");
120 guard.insert(server.to_string(), tools);
121 }
122
123 fn remove(&self, server: &str) -> usize {
124 let mut guard = self.by_server.write().expect("capability index poisoned");
125 guard.remove(server).map(|v| v.len()).unwrap_or(0)
126 }
127
128 fn find(&self, verb_ns: &str, verb_action: &str) -> Vec<ToolDescriptor> {
129 let guard = self.by_server.read().expect("capability index poisoned");
130 guard
131 .values()
132 .flat_map(|tools| tools.iter())
133 .filter(|t| {
134 let (ns, action) = parse_verb(&t.name);
135 ns == verb_ns && (verb_action == "*" || action == verb_action)
136 })
137 .cloned()
138 .collect()
139 }
140
141 fn snapshot(&self) -> Vec<ToolDescriptor> {
142 let guard = self.by_server.read().expect("capability index poisoned");
143 guard.values().flat_map(|t| t.iter().cloned()).collect()
144 }
145}
146
/// Splits a tool name into `(namespace, action)` at the first `.`.
///
/// A name without any `.` has an empty namespace and the whole name as its
/// action. Only the first `.` is significant, so `"a.b.c"` yields
/// `("a", "b.c")`.
fn parse_verb(name: &str) -> (&str, &str) {
    name.split_once('.').unwrap_or(("", name))
}
155
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a descriptor with a minimal object schema and an optional
    /// description.
    fn descriptor(server: &str, name: &str, description: Option<&str>) -> ToolDescriptor {
        ToolDescriptor {
            server: server.into(),
            name: name.into(),
            description: description.map(Into::into),
            input_schema: json!({"type": "object"}),
        }
    }

    /// Descriptor without a description.
    fn td(server: &str, name: &str) -> ToolDescriptor {
        descriptor(server, name, None)
    }

    /// Descriptor carrying a description.
    fn td_desc(server: &str, name: &str, description: &str) -> ToolDescriptor {
        descriptor(server, name, Some(description))
    }

    /// Index pre-populated with the two-tool `fs` server used by several tests.
    fn fs_index() -> InMemoryToolCapabilityIndex {
        let index = InMemoryToolCapabilityIndex::new();
        index.upsert(
            "fs",
            vec![td("fs", "fs.read_text_file"), td("fs", "fs.write_file")],
        );
        index
    }

    /// Projects descriptors onto their names, preserving order.
    fn names(tools: &[ToolDescriptor]) -> Vec<&str> {
        tools.iter().map(|tool| tool.name.as_str()).collect()
    }

    #[test]
    fn upsert_then_find_by_namespace_and_action() {
        let index = fs_index();
        let found = index.find("fs", "read_text_file");
        assert_eq!(names(&found), ["fs.read_text_file"]);
    }

    #[test]
    fn wildcard_action_returns_whole_namespace() {
        let index = fs_index();
        index.upsert("git", vec![td("git", "git.commit")]);

        let found = index.find("fs", "*");
        // Lookup order across the map is unspecified, so compare sorted names.
        let mut found_names = names(&found);
        found_names.sort();
        assert_eq!(found_names, ["fs.read_text_file", "fs.write_file"]);
    }

    #[test]
    fn upsert_overwrites_previous_tools_for_server() {
        let index = InMemoryToolCapabilityIndex::new();
        index.upsert("fs", vec![td("fs", "fs.read_text_file")]);
        index.upsert("fs", vec![td("fs", "fs.write_file")]);

        let stale = index.find("fs", "read_text_file");
        let current = index.find("fs", "write_file");
        assert!(stale.is_empty());
        assert_eq!(current.len(), 1);
    }

    #[test]
    fn remove_drops_servers_tools() {
        let index = fs_index();
        let removed = index.remove("fs");
        assert_eq!(removed, 2);
        assert!(index.find("fs", "*").is_empty());
    }

    #[test]
    fn remove_unknown_server_is_noop() {
        let index = InMemoryToolCapabilityIndex::new();
        let removed = index.remove("ghost");
        assert_eq!(removed, 0);
    }

    #[test]
    fn dotless_tool_names_match_empty_namespace() {
        let index = InMemoryToolCapabilityIndex::new();
        index.upsert("misc", vec![td("misc", "ping")]);

        let found = index.find("", "ping");
        assert_eq!(names(&found), ["ping"]);
    }

    #[test]
    fn snapshot_returns_all_tools() {
        let index = InMemoryToolCapabilityIndex::new();
        index.upsert("fs", vec![td("fs", "fs.read_text_file")]);
        index.upsert("git", vec![td("git", "git.commit")]);

        let everything = index.snapshot();
        assert_eq!(everything.len(), 2);
    }

    #[test]
    fn top_k_ranks_keyword_matches_first() {
        let index = InMemoryToolCapabilityIndex::new();
        index.upsert(
            "fs",
            vec![
                td_desc("fs", "fs.read_text_file", "Read the contents of a file"),
                td_desc("fs", "fs.write_file", "Write data to a file"),
            ],
        );
        index.upsert(
            "web",
            vec![td_desc("web", "web.search", "Search the web for a query")],
        );

        let ranked = index.top_k("search the web", 2);
        let ranked_names = names(&ranked);
        assert_eq!(ranked_names.len(), 2);
        assert_eq!(ranked_names.first(), Some(&"web.search"));
    }

    #[test]
    fn top_k_caps_result_count() {
        let index = InMemoryToolCapabilityIndex::new();
        index.upsert(
            "fs",
            vec![
                td("fs", "fs.read_text_file"),
                td("fs", "fs.write_file"),
                td("fs", "fs.list_dir"),
            ],
        );

        let ranked = index.top_k("file", 2);
        assert_eq!(ranked.len(), 2);
    }

    #[test]
    fn top_k_zero_returns_empty() {
        let index = InMemoryToolCapabilityIndex::new();
        index.upsert("fs", vec![td("fs", "fs.read_text_file")]);

        let ranked = index.top_k("anything", 0);
        assert!(ranked.is_empty());
    }

    #[test]
    fn top_k_no_match_falls_back_to_filling_slots() {
        let candidates = vec![
            td("git", "git.commit"),
            td("fs", "fs.read_text_file"),
            td("web", "web.search"),
        ];

        // Nothing matches, so every score is zero and name order decides.
        let ranked = score_top_k(candidates, "zzz_no_such_term", 2);
        assert_eq!(names(&ranked), ["fs.read_text_file", "git.commit"]);
    }

    #[test]
    fn score_top_k_empty_query_returns_name_sorted_prefix() {
        let candidates = vec![
            td("web", "web.search"),
            td("fs", "fs.read_text_file"),
            td("git", "git.commit"),
        ];

        let ranked = score_top_k(candidates, "", 2);
        assert_eq!(names(&ranked), ["fs.read_text_file", "git.commit"]);
    }
}
318}