use std::borrow::Cow;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::fmt::Write as _;
/// How much of a tool's metadata the `render` helper includes in its output.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum DisclosureDepth {
/// The tool name only.
Minimal,
/// `name: description`, with a long description shortened and suffixed with `…`.
Summary,
/// Like `Summary`, plus the top-level parameter names from the tool's JSON schema (if any).
Parameters,
/// Name, the full description, and the raw schema JSON when one is registered.
Full,
}
/// One registered tool as stored inside `ToolSearchIndex`.
struct ToolEntry {
/// Tool name; `register` replaces any existing entry with the same name.
name: String,
/// Grouping key used for namespace browsing and the grouped listing.
namespace: String,
/// Free-text description; scored against query words and shown in rendered output.
description: String,
/// Extra keywords, scored at a lower weight and offered as alternative keywords.
tags: Vec<String>,
/// Queries recorded by `record_success` (deduplicated, at most 10).
/// NOTE(review): not consulted by the scoring code in this file.
example_queries: Vec<String>,
/// Raw JSON schema text; parameter names are read from its top-level `properties`.
schema_json: Option<String>,
/// Number of successful uses recorded via `record_success`; feeds a log-scaled score boost.
call_count: u64,
/// Produced by `compute_embedding`, which currently always returns an empty vector;
/// nothing reads it yet, hence the `dead_code` allowance.
#[allow(dead_code)] embedding: Vec<f32>,
}
/// A single hit returned by searching or browsing a `ToolSearchIndex`.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ToolSearchResult {
/// Name of the matched tool.
pub name: String,
/// Namespace the tool was registered under.
pub namespace: String,
/// Relevance score (higher is better); namespace browsing always reports `1.0`.
pub score: f32,
/// One-line summary of the tool (`name: description`).
pub rendered: String,
/// For low-confidence search hits, the namespace of the top-ranked result; otherwise `None`.
pub nearest_namespace: Option<String>,
/// Up to five tags from the returned tools that the query did not already mention;
/// every result of one search carries the same list, and browsing leaves it empty.
pub alternative_keywords: Vec<String>,
/// `"high"` (score > 5.0), `"medium"` (score > 2.0) or `"low"`; browsing always uses `"high"`.
pub confidence_level: String,
}
/// Keyword-based index over registered tools.
///
/// Supports ranked search, namespace browsing, usage tracking, and a content
/// hash of the registered names and descriptions.
pub struct ToolSearchIndex {
    /// Registered tools in registration order; re-registering a name moves it to the end.
    entries: Vec<ToolEntry>,
    /// Names of tools returned by earlier searches; they are scored ×0.8 in later searches.
    loaded_schemas: HashSet<String>,
    /// Hex SHA-256 of the sorted `(name, description)` pairs; empty until the first registration.
    registry_hash: String,
}

impl Default for ToolSearchIndex {
    fn default() -> Self {
        Self::new()
    }
}

impl ToolSearchIndex {
    /// Creates an empty index.
    #[must_use]
    pub fn new() -> Self {
        Self {
            entries: Vec::new(),
            loaded_schemas: HashSet::new(),
            registry_hash: String::new(),
        }
    }

    /// Registers a tool, replacing any existing entry with the same `name`,
    /// and refreshes the registry hash.
    ///
    /// Replacing an entry discards its call count and recorded example queries.
    pub fn register(
        &mut self,
        name: &str,
        namespace: &str,
        description: &str,
        tags: &[&str],
        schema_json: Option<&str>,
    ) {
        self.entries.retain(|e| e.name != name);
        let embedding = compute_embedding(name, description);
        self.entries.push(ToolEntry {
            name: name.to_owned(),
            namespace: namespace.to_owned(),
            description: description.to_owned(),
            tags: tags.iter().map(|&t| t.to_owned()).collect(),
            example_queries: Vec::new(),
            schema_json: schema_json.map(str::to_owned),
            call_count: 0,
            embedding,
        });
        self.recompute_hash();
    }

    /// Returns up to `top_k` tools ranked by keyword relevance to `query`.
    ///
    /// Every registered tool is a candidate, so hits are returned even when no
    /// keyword matches. Confidence is `"high"` above 5.0, `"medium"` above 2.0,
    /// and `"low"` otherwise. Returned tools are remembered and scored ×0.8 in
    /// later searches.
    pub fn search(&mut self, query: &str, top_k: usize) -> Vec<ToolSearchResult> {
        self.search_excluding(query, top_k, &HashSet::new())
    }

    /// Core of [`Self::search`]: ranks every tool whose name is not in `exclude`.
    fn search_excluding(
        &mut self,
        query: &str,
        top_k: usize,
        exclude: &HashSet<String>,
    ) -> Vec<ToolSearchResult> {
        let query_words: HashSet<&str> = query.split_whitespace().collect();
        let mut scored: Vec<(usize, f32)> = self
            .entries
            .iter()
            .enumerate()
            .filter(|(_, entry)| !exclude.contains(&entry.name))
            .map(|(i, entry)| {
                let mut score = keyword_score(entry, &query_words);
                // Down-weight tools surfaced before so repeated searches rotate candidates.
                if self.loaded_schemas.contains(&entry.name) {
                    score *= 0.8;
                }
                (i, score)
            })
            .collect();
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scored.truncate(top_k);

        let top_namespace = scored
            .first()
            .map(|&(i, _)| self.entries[i].namespace.clone());

        // Up to five tags from the ranked hits that the query did not mention.
        // Deduplicate in rank order: collecting into a HashSet before `take(5)`
        // would pick an arbitrary, run-to-run varying subset.
        let alternative_keywords: Vec<String> = {
            let mut seen: HashSet<&str> = HashSet::new();
            scored
                .iter()
                .flat_map(|&(i, _)| self.entries[i].tags.iter().map(String::as_str))
                .filter(|&tag| !query_words.iter().any(|w| w.eq_ignore_ascii_case(tag)))
                .filter(|&tag| seen.insert(tag))
                .take(5)
                .map(str::to_owned)
                .collect()
        };

        let results: Vec<ToolSearchResult> = scored
            .into_iter()
            .map(|(i, score)| {
                let entry = &self.entries[i];
                let confidence_level = if score > 5.0 {
                    "high"
                } else if score > 2.0 {
                    "medium"
                } else {
                    "low"
                };
                // Weak hits also point at the best hit's namespace for browsing.
                let nearest_namespace = if confidence_level == "low" {
                    top_namespace.clone()
                } else {
                    None
                };
                ToolSearchResult {
                    name: entry.name.clone(),
                    namespace: entry.namespace.clone(),
                    score,
                    rendered: render(entry, DisclosureDepth::Summary),
                    nearest_namespace,
                    alternative_keywords: alternative_keywords.clone(),
                    confidence_level: confidence_level.to_owned(),
                }
            })
            .collect();
        self.loaded_schemas
            .extend(results.iter().map(|r| r.name.clone()));
        results
    }

    /// Lists every tool registered under `namespace`, in registration order,
    /// each with score `1.0` and `"high"` confidence.
    pub fn browse_namespace(&self, namespace: &str) -> Vec<ToolSearchResult> {
        self.entries
            .iter()
            .filter(|e| e.namespace == namespace)
            .map(|e| ToolSearchResult {
                name: e.name.clone(),
                namespace: e.namespace.clone(),
                score: 1.0,
                rendered: render(e, DisclosureDepth::Summary),
                nearest_namespace: None,
                alternative_keywords: Vec::new(),
                confidence_level: "high".to_owned(),
            })
            .collect()
    }

    /// Returns `(name, description)` for every tool in registration order.
    pub fn list_compact(&self) -> Vec<(String, String)> {
        self.entries
            .iter()
            .map(|e| (e.name.clone(), e.description.clone()))
            .collect()
    }

    /// Records that `query` was successfully served by `tool_name`.
    ///
    /// Increments the tool's call count (saturating) and stores `query` as an
    /// example if it is new and fewer than 10 examples are stored. Unknown
    /// tool names are ignored.
    pub fn record_success(&mut self, query: &str, tool_name: &str) {
        if let Some(entry) = self.entries.iter_mut().find(|e| e.name == tool_name) {
            entry.call_count = entry.call_count.saturating_add(1);
            if !entry.example_queries.iter().any(|q| q == query) && entry.example_queries.len() < 10
            {
                entry.example_queries.push(query.to_owned());
            }
        }
    }

    /// Runs up to `steps` searches for `query`, each excluding the tools
    /// already found, and returns the concatenated hits (no duplicates).
    ///
    /// Stops early once a step yields no new tools.
    pub fn search_progressive(
        &mut self,
        query: &str,
        steps: usize,
        per_step_k: usize,
    ) -> Vec<ToolSearchResult> {
        let mut seen: HashSet<String> = HashSet::new();
        let mut all_results: Vec<ToolSearchResult> = Vec::new();
        for _ in 0..steps {
            let step_results = self.search_excluding(query, per_step_k, &seen);
            if step_results.is_empty() {
                break;
            }
            for r in step_results {
                if seen.insert(r.name.clone()) {
                    all_results.push(r);
                }
            }
        }
        all_results
    }

    /// Hex SHA-256 of the registered names and descriptions (empty before any registration).
    pub fn registry_hash(&self) -> &str {
        &self.registry_hash
    }

    /// Recomputes `registry_hash` from the name-sorted `(name, description)` pairs.
    ///
    /// Namespaces, tags and schemas do not contribute to the hash.
    fn recompute_hash(&mut self) {
        use sha2::{Digest, Sha256};
        let sorted: BTreeMap<&str, &str> = self
            .entries
            .iter()
            .map(|e| (e.name.as_str(), e.description.as_str()))
            .collect();
        let data = serde_json::to_string(&sorted).unwrap_or_else(|_| format!("{sorted:?}"));
        let mut hasher = Sha256::new();
        hasher.update(data.as_bytes());
        self.registry_hash = format!("{:x}", hasher.finalize());
    }
}
/// Placeholder embedding hook: ignores both inputs and returns an empty vector,
/// so every `ToolEntry::embedding` is empty and ranking relies on keyword scoring alone.
const fn compute_embedding(_name: &str, _description: &str) -> Vec<f32> {
Vec::new()
}
/// Scores how well `entry` matches the words of a query.
///
/// Distinct matching words are weighted by field: namespace ×5, name ×3,
/// description ×2. Tag words count ×1.5 per occurrence. A usage boost of
/// `0.5·log2(1 + call_count)`, capped at 3.0, is added. Comparison is
/// case-insensitive and ignores leading/trailing punctuation, so the
/// description word `"File."` matches the query word `"file"`.
fn keyword_score(entry: &ToolEntry, query_words: &HashSet<&str>) -> f32 {
    /// Canonical token form used on both sides of every comparison.
    fn normalize(token: &str) -> String {
        token
            .trim_matches(|c: char| !c.is_alphanumeric())
            .to_lowercase()
    }
    /// Distinct, non-empty normalized tokens.
    fn token_set<'a>(tokens: impl Iterator<Item = &'a str>) -> HashSet<String> {
        tokens.map(normalize).filter(|t| !t.is_empty()).collect()
    }

    let query = token_set(query_words.iter().copied());
    let overlap = |tokens: HashSet<String>| tokens.intersection(&query).count();

    #[allow(clippy::cast_precision_loss)]
    let ns_score = overlap(token_set(entry.namespace.split(['-', '_']))) as f32 * 5.0;
    #[allow(clippy::cast_precision_loss)]
    let name_score = overlap(token_set(entry.name.split(['-', '_', ' ']))) as f32 * 3.0;
    #[allow(clippy::cast_precision_loss)]
    let desc_score = overlap(token_set(entry.description.split_whitespace())) as f32 * 2.0;
    // Tag words are intentionally counted per occurrence, not deduplicated.
    #[allow(clippy::cast_precision_loss)]
    let tag_score = entry
        .tags
        .iter()
        .flat_map(|t| t.split(['-', '_', ' ']))
        .map(normalize)
        .filter(|w| query.contains(w))
        .count() as f32
        * 1.5;
    #[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
    let freq_boost = ((1.0 + entry.call_count as f64).log2() * 0.5).min(3.0) as f32;
    ns_score + name_score + desc_score + tag_score + freq_boost
}
fn render(entry: &ToolEntry, depth: DisclosureDepth) -> String {
match depth {
DisclosureDepth::Minimal => entry.name.clone(),
DisclosureDepth::Summary => {
let desc = if entry.description.len() > 100 {
format!("{}…", &entry.description[..100])
} else {
entry.description.clone()
};
format!("{}: {desc}", entry.name)
}
DisclosureDepth::Parameters => {
let summary = if entry.description.len() > 100 {
format!("{}…", &entry.description[..100])
} else {
entry.description.clone()
};
let params = extract_parameter_names(entry.schema_json.as_deref());
if params.is_empty() {
format!("{}: {summary}", entry.name)
} else {
format!("{}: {summary} (params: {})", entry.name, params.join(", "))
}
}
DisclosureDepth::Full => {
let mut out = format!("name: {}\ndescription: {}\n", entry.name, entry.description);
if let Some(ref schema) = entry.schema_json {
out.push_str("schema: ");
out.push_str(schema);
}
out
}
}
}
/// Lists the keys of the top-level `properties` object of a JSON schema.
///
/// Yields an empty vector when no schema is given, the text is not valid JSON,
/// or `properties` is absent or not an object. Key order follows
/// `serde_json`'s map iteration order.
fn extract_parameter_names(schema_json: Option<&str>) -> Vec<String> {
    let property_keys = || -> Option<Vec<String>> {
        let schema: serde_json::Value = serde_json::from_str(schema_json?).ok()?;
        let properties = schema.get("properties")?.as_object()?;
        Some(properties.keys().cloned().collect())
    };
    property_keys().unwrap_or_default()
}
/// Arguments for [`run_tool_search`]; build with [`ToolSearchArgs::new`] and the `with_*` methods.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ToolSearchArgs {
    /// Free-text query to rank tools against.
    pub query: Option<String>,
    /// Namespace to browse; when set it takes precedence over `query`.
    pub namespace: Option<String>,
    /// Maximum number of search hits to return.
    pub top_k: Option<usize>,
}

impl ToolSearchArgs {
    /// Arguments with every field unset.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            query: None,
            namespace: None,
            top_k: None,
        }
    }

    /// Sets the search query, replacing any previous one.
    #[must_use]
    pub fn with_query(self, query: impl Into<String>) -> Self {
        Self {
            query: Some(query.into()),
            ..self
        }
    }

    /// Sets the namespace to browse, replacing any previous one.
    #[must_use]
    pub fn with_namespace(self, namespace: impl Into<String>) -> Self {
        Self {
            namespace: Some(namespace.into()),
            ..self
        }
    }

    /// Sets the maximum number of search hits.
    #[must_use]
    pub const fn with_top_k(mut self, top_k: usize) -> Self {
        self.top_k = Some(top_k);
        self
    }
}

impl Default for ToolSearchArgs {
    fn default() -> Self {
        Self::new()
    }
}
/// Executes a tool-search request and formats the outcome as text.
///
/// If `args.namespace` is set, lists that namespace's tools and ignores the
/// query. Otherwise it searches for `args.query`, returning at most
/// `args.top_k` hits (default 5), each prefixed with its confidence level.
/// A missing or whitespace-only query yields a usage hint instead of a search.
pub fn run_tool_search(index: &mut ToolSearchIndex, args: &ToolSearchArgs) -> String {
    let top_k = args.top_k.unwrap_or(5);
    if let Some(ns) = args.namespace.as_deref() {
        let results = index.browse_namespace(ns);
        if results.is_empty() {
            return format!("No tools found in namespace '{ns}'.");
        }
        let mut out = format!("Tools in namespace '{ns}':\n");
        for r in &results {
            let _ = writeln!(out, " - {}", r.rendered);
        }
        return out;
    }
    // A blank query has no words to match, so any hits would be unrelated
    // tools; treat it the same as a missing query.
    let Some(query) = args.query.as_deref().filter(|q| !q.trim().is_empty()) else {
        return "Provide either 'query' or 'namespace'.".to_owned();
    };
    let results = index.search(query, top_k);
    if results.is_empty() {
        return format!("No tools matched '{query}'.");
    }
    let mut out = format!("Tool search results for '{query}':\n");
    for r in &results {
        let _ = writeln!(out, " [{}] {}", r.confidence_level, r.rendered);
    }
    out
}
pub fn run_tool_list(index: &ToolSearchIndex) -> String {
let mut grouped: BTreeMap<&str, Vec<(&str, &str)>> = BTreeMap::new();
for entry in &index.entries {
grouped
.entry(entry.namespace.as_str())
.or_default()
.push((entry.name.as_str(), entry.description.as_str()));
}
let mut out = String::new();
for (ns, tools) in &grouped {
out.push_str(ns);
out.push_str(":\n");
for (name, desc) in tools {
let short_desc = if desc.len() > 80 {
format!("{}…", &desc[..80])
} else {
(*desc).to_owned()
};
let _ = writeln!(out, " - {name}: {short_desc}");
}
}
out
}
/// Formats ranked results into a single listing capped at roughly 5 000 tokens.
///
/// Each line is `[label] rendered (score=S)`. The label is `full` for ranks
/// 0–4, `summary` for ranks 5–14, and `minimal` after that. Labels are
/// informational only: every line uses the result's pre-rendered text.
/// Tokens are estimated as `bytes / 4` per line. Output stops at the first
/// line that would push the running total past the cap, even if later lines
/// are shorter.
pub fn allocate_budget(results: &[ToolSearchResult]) -> String {
    const TOKEN_CAP: usize = 5_000;
    const CHARS_PER_TOKEN: usize = 4;
    let mut spent: usize = 0;
    results
        .iter()
        .enumerate()
        .map(|(rank, result)| {
            let label = match rank {
                0..=4 => "full",
                5..=14 => "summary",
                _ => "minimal",
            };
            format!("[{label}] {} (score={:.2})\n", result.rendered, result.score)
        })
        .take_while(|line| {
            spent += line.len() / CHARS_PER_TOKEN;
            spent <= TOKEN_CAP
        })
        .collect()
}
/// Re-weights search results by whether each tool fits the kind of argument
/// the query seems to mention, then re-sorts by descending score.
///
/// The query is matched case-insensitively. If it looks like a path (contains
/// `/`, `.rs`, `.py` or `file`), tools that are not file tools get ×0.7. File
/// tools are the `vfs` namespace or names containing `read`, `write`, `file`
/// or `glob`. If it mentions a code symbol (`function`, `method`, `struct`,
/// `class`), tools that are not LSP tools get ×0.85. LSP tools are the `lsp`
/// namespace or names containing `symbol` or `goto`. The sort is stable, so
/// equal scores keep their input order.
pub fn verify_parameter_types(results: &mut [ToolSearchResult], query: &str) {
    const PATH_HINTS: &[&str] = &["/", ".rs", ".py", "file"];
    const SYMBOL_HINTS: &[&str] = &["function", "method", "struct", "class"];
    const FILE_TOOL_HINTS: &[&str] = &["read", "write", "file", "glob"];

    // These depend only on the query, so evaluate them once rather than per result.
    let query_lower = query.to_lowercase();
    let looks_like_path = PATH_HINTS.iter().any(|&hint| query_lower.contains(hint));
    let looks_like_symbol = SYMBOL_HINTS.iter().any(|&hint| query_lower.contains(hint));

    for r in results.iter_mut() {
        let is_file_tool =
            r.namespace == "vfs" || FILE_TOOL_HINTS.iter().any(|&hint| r.name.contains(hint));
        if looks_like_path && !is_file_tool {
            r.score *= 0.7;
        }
        let is_lsp_tool =
            r.namespace == "lsp" || r.name.contains("symbol") || r.name.contains("goto");
        if looks_like_symbol && !is_lsp_tool {
            r.score *= 0.85;
        }
    }
    results.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
}
/// Counts observed tool-to-tool transitions and suggests likely next tools.
pub struct ToolTransitionGraph {
    /// `from -> (to -> number of recorded transitions)`.
    transitions: HashMap<String, HashMap<String, f64>>,
    /// Number of recorded transitions (across all tools) per halving of successor scores.
    half_life: usize,
    /// Total number of transitions recorded so far.
    total_invocations: usize,
}

impl ToolTransitionGraph {
    /// Creates an empty graph. A `half_life` of 0 is treated as 1.
    #[must_use]
    pub fn new(half_life: usize) -> Self {
        Self {
            transitions: HashMap::new(),
            half_life,
            total_invocations: 0,
        }
    }

    /// Records one observed transition from tool `from` to tool `to`.
    pub fn record_transition(&mut self, from: &str, to: &str) {
        *self
            .transitions
            .entry(from.to_owned())
            .or_default()
            .entry(to.to_owned())
            .or_insert(0.0) += 1.0;
        self.total_invocations = self.total_invocations.saturating_add(1);
    }

    /// Returns the tools observed after `current`, best first.
    ///
    /// Each score is that successor's share of `current`'s transitions,
    /// multiplied by `0.5^(total_invocations / half_life)`. The factor is the
    /// same for every successor, so it scales scores without changing their
    /// order. Equal scores are ordered by name so the output is deterministic.
    /// Unknown tools yield an empty vector.
    pub fn successors(&self, current: &str) -> Vec<(String, f32)> {
        let Some(counts) = self.transitions.get(current) else {
            return Vec::new();
        };
        let total: f64 = counts.values().sum();
        let halvings = self.total_invocations / self.half_life.max(1);
        // Saturate instead of `as i32`: wrapping could yield a negative
        // exponent, which would amplify scores instead of decaying them.
        let decay = 0.5_f64.powi(i32::try_from(halvings).unwrap_or(i32::MAX));
        let mut results: Vec<(String, f32)> = counts
            .iter()
            .map(|(name, count)| {
                #[allow(clippy::cast_possible_truncation)]
                let score = ((count / total) * decay) as f32;
                (name.clone(), score)
            })
            .collect();
        results.sort_by(|a, b| {
            b.1.partial_cmp(&a.1)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| a.0.cmp(&b.0))
        });
        results
    }
}
/// Rewrites a raw search query before it is handed to the index.
///
/// Implementations return the input borrowed when nothing changes and an
/// owned string otherwise.
pub trait QueryPreprocessor: Sync + Send {
    /// Returns the (possibly rewritten) form of `query`.
    fn preprocess<'q>(&self, query: &'q str) -> Cow<'q, str>;
}
/// Heuristic [`QueryPreprocessor`] that condenses long natural-language
/// queries to a few content words.
pub struct IntentExtractor;

impl QueryPreprocessor for IntentExtractor {
    /// Returns `query` unchanged when it has at most five words.
    ///
    /// Otherwise drops common English stop words (ASCII case-insensitive),
    /// keeps the first five remaining words, and joins them with single
    /// spaces. If every word is a stop word, the first five words are kept
    /// instead, so a non-empty query never collapses to an empty string.
    fn preprocess<'a>(&self, query: &'a str) -> Cow<'a, str> {
        const MAX_WORDS: usize = 5;
        const STOP_WORDS: &[&str] = &[
            "the", "a", "an", "in", "for", "to", "of", "that", "which", "with", "from",
        ];
        let words: Vec<&str> = query.split_whitespace().collect();
        if words.len() <= MAX_WORDS {
            return Cow::Borrowed(query);
        }
        // The stop words are ASCII, so an ASCII case-insensitive check avoids a
        // per-word `to_lowercase` allocation.
        let is_stop_word = |w: &str| STOP_WORDS.iter().any(|s| s.eq_ignore_ascii_case(w));
        let mut kept: Vec<&str> = words
            .iter()
            .copied()
            .filter(|&w| !is_stop_word(w))
            .take(MAX_WORDS)
            .collect();
        if kept.is_empty() {
            // `words.len() > MAX_WORDS` here, so the slice is in bounds.
            kept.extend_from_slice(&words[..MAX_WORDS]);
        }
        Cow::Owned(kept.join(" "))
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::cast_precision_loss)]
mod tests {
// Unit tests for the index, its free-function helpers, the transition graph
// and the intent extractor defined in this file.
use super::*;
// Three-tool fixture: `read_file` and `list_dir` in `vfs`, `search_code` in `index`.
fn sample_index() -> ToolSearchIndex {
let mut idx = ToolSearchIndex::new();
idx.register(
"read_file",
"vfs",
"Read the contents of a file",
&["file", "read"],
None,
);
idx.register(
"search_code",
"index",
"Semantic code search using embeddings",
&["search", "semantic"],
None,
);
idx.register(
"list_dir",
"vfs",
"List directory contents",
&["ls", "directory"],
None,
);
idx
}
// A query sharing words with read_file's name, description and tags ranks it first.
#[test]
fn tool_search_finds_by_query() {
let mut idx = sample_index();
let results = idx.search("read file contents", 3);
assert!(!results.is_empty());
assert_eq!(results[0].name, "read_file");
}
// Browsing returns only (and all) tools registered under the given namespace.
#[test]
fn namespace_browse_returns_all() {
let mut idx = ToolSearchIndex::new();
idx.register("read_file", "vfs", "Read file", &[], None);
idx.register("write_file", "vfs", "Write file", &[], None);
idx.register("search_code", "index", "Search code", &[], None);
let vfs_tools = idx.browse_namespace("vfs");
assert_eq!(vfs_tools.len(), 2);
}
// Query words that spell out a tool's name put that tool at the top.
#[test]
fn exact_name_match_ranks_first() {
let mut idx = ToolSearchIndex::new();
idx.register(
"search_code",
"index",
"Semantic code search",
&["search"],
None,
);
idx.register("read_file", "vfs", "Read file", &["file"], None);
idx.register("list_dir", "vfs", "List directory contents", &[], None);
let results = idx.search("search code", 3);
assert!(!results.is_empty());
assert_eq!(results[0].name, "search_code");
}
// A tool returned by one search scores lower when the same search is repeated.
#[test]
fn adaptive_scoring_penalises_loaded_schemas() {
let mut idx = ToolSearchIndex::new();
idx.register(
"find_func",
"lsp",
"Find function definition",
&["function", "find"],
None,
);
let first_results = idx.search("find function", 1);
assert!(!first_results.is_empty());
let score_before = first_results[0].score;
let second_results = idx.search("find function", 1);
assert!(!second_results.is_empty());
let score_after = second_results[0].score;
assert!(
score_after < score_before,
"expected penalised score {score_after} < {score_before}"
);
}
// The registry hash starts empty and changes with each new registration.
#[test]
fn registry_hash_changes_on_registration() {
let mut idx = ToolSearchIndex::new();
let h0 = idx.registry_hash().to_owned();
idx.register("read_file", "vfs", "Read file", &[], None);
let h1 = idx.registry_hash().to_owned();
idx.register("write_file", "vfs", "Write file", &[], None);
let h2 = idx.registry_hash().to_owned();
assert_ne!(h0, h1);
assert_ne!(h1, h2);
}
// At most 10 distinct example queries are kept per tool.
#[test]
fn record_success_capped() {
let mut idx = ToolSearchIndex::new();
idx.register("read_file", "vfs", "Read file", &[], None);
for i in 0..15 {
idx.record_success(&format!("query {i}"), "read_file");
}
let entry = idx.entries.iter().find(|e| e.name == "read_file").unwrap();
assert_eq!(entry.example_queries.len(), 10);
}
// Repeating the same successful query stores it only once.
#[test]
fn record_success_no_duplicates() {
let mut idx = ToolSearchIndex::new();
idx.register("read_file", "vfs", "Read file", &[], None);
for _ in 0..5 {
idx.record_success("read a file", "read_file");
}
let entry = idx.entries.iter().find(|e| e.name == "read_file").unwrap();
assert_eq!(entry.example_queries.len(), 1);
}
// Progressive retrieval never returns the same tool name twice.
#[test]
fn progressive_retrieval_deduplicates() {
let mut idx = ToolSearchIndex::new();
idx.register("read_file", "vfs", "Read file contents", &["file"], None);
idx.register("write_file", "vfs", "Write file contents", &["file"], None);
idx.register(
"search_code",
"index",
"Search code semantically",
&["search"],
None,
);
let results = idx.search_progressive("file operations", 2, 2);
let names: Vec<&String> = results.iter().map(|r| &r.name).collect();
let unique_names: HashSet<&String> = names.iter().copied().collect();
assert_eq!(names.len(), unique_names.len());
}
// The most frequently observed successor is ranked first.
#[test]
fn transition_graph_boosts_successors() {
let mut g = ToolTransitionGraph::new(100);
g.record_transition("read_file", "write_file");
g.record_transition("read_file", "write_file");
g.record_transition("read_file", "search_code");
let successors = g.successors("read_file");
assert!(!successors.is_empty());
assert_eq!(successors[0].0, "write_file");
}
// Queries longer than five words are reduced to at most five words.
#[test]
fn intent_extractor_shortens_long_query() {
let extractor = IntentExtractor;
let long = "I need to find the function that handles authentication in the codebase";
let short = extractor.preprocess(long);
assert!(short.split_whitespace().count() <= 5);
}
// Queries of five words or fewer pass through unchanged.
#[test]
fn intent_extractor_passthrough_short_query() {
let extractor = IntentExtractor;
let q = "read file";
let result = extractor.preprocess(q);
assert_eq!(result, q);
}
// A path-like query demotes the LSP tool below the file tool despite its higher starting score.
#[test]
fn parameter_verification_demotes_non_file_tools() {
let mut results = vec![
ToolSearchResult {
name: "go_to_definition".to_owned(),
namespace: "lsp".to_owned(),
score: 6.0,
rendered: String::new(),
nearest_namespace: None,
alternative_keywords: Vec::new(),
confidence_level: "high".to_owned(),
},
ToolSearchResult {
name: "read_file".to_owned(),
namespace: "vfs".to_owned(),
score: 5.0,
rendered: String::new(),
nearest_namespace: None,
alternative_keywords: Vec::new(),
confidence_level: "high".to_owned(),
},
];
verify_parameter_types(&mut results, "read /src/main.rs file");
assert_eq!(results[0].name, "read_file");
}
// The listing contains a header line for each namespace.
#[test]
fn run_tool_list_grouped_output() {
let mut idx = ToolSearchIndex::new();
idx.register("read_file", "vfs", "Read file", &[], None);
idx.register("write_file", "vfs", "Write file", &[], None);
idx.register("search_code", "index", "Search code", &[], None);
let output = run_tool_list(&idx);
assert!(output.contains("vfs:"));
assert!(output.contains("index:"));
}
// Twenty short results fit the budget and include both `full` and `summary` labels.
#[test]
fn allocate_budget_produces_output() {
let results: Vec<ToolSearchResult> = (0..20)
.map(|i| ToolSearchResult {
name: format!("tool_{i}"),
namespace: "vfs".to_owned(),
score: 10.0 - i as f32,
rendered: format!("tool_{i}: does something useful"),
nearest_namespace: None,
alternative_keywords: Vec::new(),
confidence_level: "high".to_owned(),
})
.collect();
let output = allocate_budget(&results);
assert!(!output.is_empty());
assert!(output.contains("[full]"));
assert!(output.contains("[summary]"));
}
}