use std::collections::HashMap;
use serde_json::Value;
use crate::search::scoring::{Bm25Params, bm25_score};
use crate::search::tokenizer;
// BM25 tuning parameters passed to `bm25_score` when scoring index entries.
// NOTE(review): `k1` and `b` presumably carry their conventional BM25 meanings
// (term-frequency saturation and length normalization) — confirm against `Bm25Params`.
const PARAMS: Bm25Params = Bm25Params { k1: 1.5, b: 0.2 };
/// One indexed item together with the statistics needed to score it.
#[derive(Debug, Clone)]
struct ToolEntry {
    /// Identifier used to look the entry up for removal and echoed in results.
    pub id: String,
    /// Which kind of provider the entry was indexed from.
    pub source: ToolSource,
    /// Term -> count map produced by `tokenizer::tokenize_text` over the
    /// entry's indexed text.
    pub tf: HashMap<String, u32>,
    /// Sum of every count stored in `tf`.
    pub token_count: u32,
    /// JSON definition attached to the entry, if one was supplied at indexing time.
    pub schema: Option<Value>,
    /// Description text exactly as it was supplied at indexing time.
    pub description: String,
}
/// Identifies where an indexed entry came from.
///
/// The enum is fieldless, so it derives `Eq` and `Hash` in addition to
/// `PartialEq`; this lets it be used as a `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ToolSource {
    /// Entry indexed from an MCP manager's tool list.
    Mcp,
    /// Entry indexed from the skill registry.
    Skill,
    /// Entry indexed from a configured agent definition.
    Agent,
}
/// A single ranked hit returned by the index's search methods.
#[derive(Debug, Clone)]
pub struct ToolSearchResult {
    /// Identifier of the matching entry.
    pub id: String,
    /// Source the matching entry was indexed from.
    pub source: ToolSource,
    /// Score returned by `bm25_score` for this entry; results are ordered by
    /// descending score.
    pub score: f64,
    /// Description stored with the entry.
    pub description: String,
    /// JSON definition stored with the entry, if any.
    pub schema: Option<Value>,
}
/// In-memory search index over tools, skills and agents, scored with BM25.
pub struct ToolIndex {
    /// Indexed entries, kept in insertion order.
    entries: Vec<ToolEntry>,
    /// For each term, the number of entries whose `tf` map contains it.
    doc_freq: HashMap<String, u32>,
    /// Running sum of `token_count` over all entries.
    total_tokens: u64,
}
impl Default for ToolIndex {
fn default() -> Self {
Self::new()
}
}
impl ToolIndex {
    /// Creates an empty index: no entries and zeroed corpus statistics.
    pub fn new() -> Self {
        Self {
            entries: Vec::new(),
            doc_freq: HashMap::new(),
            total_tokens: 0,
        }
    }

    /// Tokenizes an item's text and appends it to the index.
    ///
    /// The indexed text is `"{name} {description}"`, optionally followed by a
    /// space and `instructions` after passing them through
    /// `crate::util::truncate_bytes(.., 1500)`. Document frequencies and the
    /// total token count are updated to include the new entry.
    ///
    /// No check is made for an existing entry with the same `id`; use
    /// [`ToolIndex::upsert_entry`] when replacement semantics are wanted.
    fn add_entry(
        &mut self,
        id: String,
        source: ToolSource,
        name: &str,
        description: &str,
        instructions: Option<&str>,
        schema: Option<Value>,
    ) {
        let mut combined = format!("{} {}", name, description);
        if let Some(inst) = instructions {
            // Cap how much instruction text is indexed per entry.
            let truncated = crate::util::truncate_bytes(inst, 1500);
            combined.push(' ');
            combined.push_str(truncated);
        }

        let tf = tokenizer::tokenize_text(&combined);
        let token_count: u32 = tf.values().sum();

        // Each distinct term in this entry counts once toward its document frequency.
        for term in tf.keys() {
            *self.doc_freq.entry(term.clone()).or_default() += 1;
        }
        self.total_tokens += u64::from(token_count);

        self.entries.push(ToolEntry {
            id,
            source,
            tf,
            token_count,
            schema,
            description: description.to_string(),
        });
    }

    /// Subtracts an already-detached entry's contribution from the corpus
    /// statistics (`total_tokens` and `doc_freq`).
    ///
    /// Terms whose document frequency drops to zero are removed from the map.
    fn forget_entry_stats(&mut self, entry: &ToolEntry) {
        self.total_tokens = self
            .total_tokens
            .saturating_sub(u64::from(entry.token_count));
        for term in entry.tf.keys() {
            if let Some(count) = self.doc_freq.get_mut(term) {
                *count = count.saturating_sub(1);
                if *count == 0 {
                    self.doc_freq.remove(term);
                }
            }
        }
    }

    /// Removes the first entry whose id equals `id`.
    ///
    /// Returns `true` if an entry was found and removed, `false` otherwise.
    /// Corpus statistics are adjusted to match.
    pub fn remove_entry(&mut self, id: &str) -> bool {
        match self.entries.iter().position(|e| e.id == id) {
            Some(pos) => {
                let entry = self.entries.remove(pos);
                self.forget_entry_stats(&entry);
                true
            }
            None => false,
        }
    }

    /// Replaces the entry with the given `id`, or inserts it if absent.
    ///
    /// Any existing entry with that id is removed first (regardless of its
    /// source), then the new entry is appended via `add_entry`.
    pub fn upsert_entry(
        &mut self,
        id: String,
        source: ToolSource,
        name: &str,
        description: &str,
        instructions: Option<&str>,
        schema: Option<Value>,
    ) {
        self.remove_entry(&id);
        self.add_entry(id, source, name, description, instructions, schema);
    }

    /// Removes every entry that was indexed from `source`.
    ///
    /// Entries are selected by their `source` field in a single pass, so an
    /// entry from another source that happens to share an id with a removed
    /// entry is left untouched. The relative order of the remaining entries is
    /// preserved.
    pub fn remove_source(&mut self, source: ToolSource) {
        let (removed, kept): (Vec<ToolEntry>, Vec<ToolEntry>) =
            std::mem::take(&mut self.entries)
                .into_iter()
                .partition(|e| e.source == source);
        self.entries = kept;
        for entry in &removed {
            self.forget_entry_stats(entry);
        }
    }

    /// Finds the first definition whose `function.name` equals `tool_name` and
    /// returns its `function.description` (empty string if missing or not a
    /// string) together with a clone of the whole definition.
    ///
    /// Returns `(String::new(), None)` when no definition matches.
    fn mcp_tool_details<'a>(
        mut defs: impl Iterator<Item = &'a Value>,
        tool_name: &str,
    ) -> (String, Option<Value>) {
        defs.find(|def| def["function"]["name"].as_str() == Some(tool_name))
            .map(|def| {
                let description = def["function"]["description"]
                    .as_str()
                    .unwrap_or("")
                    .to_string();
                (description, Some(def.clone()))
            })
            .unwrap_or_default()
    }

    /// Shared driver for the MCP indexing entry points.
    ///
    /// Walks every tool name of every server reported by `manager` and indexes
    /// it under `ToolSource::Mcp`, using the tool name as both id and name and
    /// the server's instructions as extra text. With `replace_existing` the
    /// entry is upserted; otherwise it is appended unconditionally.
    fn ingest_mcp_tools(
        &mut self,
        manager: &crate::mcp::manager::McpManager,
        replace_existing: bool,
    ) {
        // Fetch the definitions once; they do not depend on the loop variables.
        let defs = manager.tool_definitions();
        for (_server_name, meta) in manager.server_meta() {
            for tool_name in &meta.tool_names {
                let (desc, schema) = Self::mcp_tool_details(defs.iter(), tool_name);
                let instructions = meta.instructions.as_deref();
                if replace_existing {
                    self.upsert_entry(
                        tool_name.clone(),
                        ToolSource::Mcp,
                        tool_name,
                        &desc,
                        instructions,
                        schema,
                    );
                } else {
                    self.add_entry(
                        tool_name.clone(),
                        ToolSource::Mcp,
                        tool_name,
                        &desc,
                        instructions,
                        schema,
                    );
                }
            }
        }
    }

    /// Drops all MCP entries and re-indexes the tools reported by `manager`.
    ///
    /// Tools are upserted, so a tool name that appears more than once (or that
    /// collides with an existing id) ends up with a single entry.
    pub fn reindex_mcp_tools(&mut self, manager: &crate::mcp::manager::McpManager) {
        self.remove_source(ToolSource::Mcp);
        self.ingest_mcp_tools(manager, true);
    }

    /// Drops all skill entries and re-indexes the skills in `registry`.
    pub fn reindex_skills(&mut self, registry: &crate::skills::SkillRegistry) {
        self.remove_source(ToolSource::Skill);
        self.index_skills(registry);
    }

    /// Drops all agent entries and re-indexes `agents`.
    pub fn reindex_agents(&mut self, agents: &[crate::config::AgentDef]) {
        self.remove_source(ToolSource::Agent);
        self.index_agents(agents);
    }

    /// Appends an entry for every tool reported by `manager`.
    ///
    /// Existing entries are not removed or de-duplicated; use
    /// [`ToolIndex::reindex_mcp_tools`] to refresh an already-populated index.
    pub fn index_mcp_tools(&mut self, manager: &crate::mcp::manager::McpManager) {
        self.ingest_mcp_tools(manager, false);
    }

    /// Appends an entry for every skill in `registry`.
    ///
    /// The skill's tags (space-joined) and body excerpt are indexed as extra
    /// text alongside its name and description. Skills carry no schema.
    pub fn index_skills(&mut self, registry: &crate::skills::SkillRegistry) {
        for skill in registry.all() {
            let extra = format!("{} {}", skill.tags.join(" "), skill.body_excerpt);
            self.add_entry(
                skill.name.clone(),
                ToolSource::Skill,
                &skill.name,
                &skill.description,
                Some(extra.trim()),
                None,
            );
        }
    }

    /// Appends an entry for every agent in `agents`.
    ///
    /// Agents without a description are indexed with the fallback text
    /// `"general-purpose coding agent"`. The agent's tags and the first (at
    /// most) 1000 bytes of its system prompt are indexed as extra text. Agents
    /// carry no schema.
    pub fn index_agents(&mut self, agents: &[crate::config::AgentDef]) {
        for agent in agents {
            let desc = agent
                .description
                .as_deref()
                .unwrap_or("general-purpose coding agent");
            let tags_str = agent.tags.join(" ");

            // Back off to the nearest char boundary so the slice below cannot
            // split a multi-byte character.
            let mut prompt_end = agent.system_prompt.len().min(1000);
            while prompt_end > 0 && !agent.system_prompt.is_char_boundary(prompt_end) {
                prompt_end -= 1;
            }
            let prompt_excerpt = &agent.system_prompt[..prompt_end];

            let extra = format!("{} {}", tags_str, prompt_excerpt);
            self.add_entry(
                agent.name.clone(),
                ToolSource::Agent,
                &agent.name,
                desc,
                Some(extra.trim()),
                None,
            );
        }
    }

    /// Scores every entry against `query` and returns those with a positive
    /// score, best first.
    ///
    /// Returns an empty vector when the index is empty or the query yields no
    /// tokens. The sort is stable, so equally scored entries keep insertion
    /// order.
    fn ranked(&self, query: &str) -> Vec<(&ToolEntry, f64)> {
        if self.entries.is_empty() {
            return Vec::new();
        }
        let query_tokens = tokenizer::tokenize_query(query);
        if query_tokens.is_empty() {
            return Vec::new();
        }

        // `entries` is non-empty here, so the division is well defined.
        let n = self.entries.len();
        let avg_dl = self.total_tokens as f64 / n as f64;

        let mut scored: Vec<(&ToolEntry, f64)> = self
            .entries
            .iter()
            .filter_map(|entry| {
                let score = bm25_score(
                    &entry.tf,
                    entry.token_count,
                    &query_tokens,
                    &self.doc_freq,
                    n,
                    avg_dl,
                    &PARAMS,
                );
                // `score > 0.0` is false for NaN too, so only comparable
                // scores ever reach the sort below.
                if score > 0.0 {
                    Some((entry, score))
                } else {
                    None
                }
            })
            .collect();
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scored
    }

    /// Builds the public result record for a scored entry.
    fn to_result(entry: &ToolEntry, score: f64) -> ToolSearchResult {
        ToolSearchResult {
            id: entry.id.clone(),
            source: entry.source,
            score,
            description: entry.description.clone(),
            schema: entry.schema.clone(),
        }
    }

    /// Returns up to `max_results` entries matching `query`, best first.
    ///
    /// Only entries with a positive score are returned. The result is empty
    /// when the index is empty, the query produces no tokens, nothing matches,
    /// or `max_results` is zero.
    pub fn search(&self, query: &str, max_results: usize) -> Vec<ToolSearchResult> {
        if max_results == 0 {
            return Vec::new();
        }
        self.ranked(query)
            .into_iter()
            .take(max_results)
            .map(|(entry, score)| Self::to_result(entry, score))
            .collect()
    }

    /// Returns up to `max_results` entries from `source` matching `query`,
    /// best first.
    ///
    /// The filter is applied to the complete ranking, so matches from `source`
    /// are found even when many higher-scoring entries belong to other sources.
    pub fn search_by_source(
        &self,
        query: &str,
        source: ToolSource,
        max_results: usize,
    ) -> Vec<ToolSearchResult> {
        if max_results == 0 {
            return Vec::new();
        }
        self.ranked(query)
            .into_iter()
            .filter(|(entry, _)| entry.source == source)
            .take(max_results)
            .map(|(entry, score)| Self::to_result(entry, score))
            .collect()
    }

    /// Picks the schemas of the most relevant entries for `user_prompt`.
    ///
    /// Up to `2 * max_k` search results are examined in rank order. Results
    /// that carry a schema contribute it to the output; results without one
    /// are skipped but still update the reference score. Selection stops when
    /// `max_k` schemas have been collected, or — once at least one schema has
    /// been collected — when a result scores below 40% of the result before it.
    pub fn auto_select(&self, user_prompt: &str, max_k: usize) -> Vec<Value> {
        // Saturate so a huge `max_k` cannot overflow the candidate count.
        let results = self.search(user_prompt, max_k.saturating_mul(2));
        let mut prev_score = match results.first() {
            Some(top) => top.score,
            None => return Vec::new(),
        };

        let mut selected = Vec::new();
        for result in results {
            if selected.len() >= max_k {
                break;
            }
            if !selected.is_empty() && result.score < prev_score * 0.4 {
                break;
            }
            prev_score = result.score;
            if let Some(schema) = result.schema {
                selected.push(schema);
            }
        }
        selected
    }

    /// Number of entries currently held by the index.
    pub fn entry_count(&self) -> usize {
        self.entries.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds the shared fixture: three MCP tools (each with a schema) followed
    /// by two skills (no schema, no instructions), in that insertion order.
    fn make_index_with_tools() -> ToolIndex {
        let schema_for =
            |name: &str| json!({"type": "function", "function": {"name": name}});

        // (id/name, source, description, instructions, has_schema)
        let fixtures = vec![
            (
                "mcp__ctx7__query-docs",
                ToolSource::Mcp,
                "Query library documentation for up-to-date API references",
                Some("Use this for looking up framework documentation"),
                true,
            ),
            (
                "mcp__playwright__navigate",
                ToolSource::Mcp,
                "Navigate browser to a URL for testing",
                Some("Browser automation and E2E testing"),
                true,
            ),
            (
                "mcp__sequential__think",
                ToolSource::Mcp,
                "Create a structured thinking step for complex analysis",
                Some("Multi-step reasoning for debugging and architecture"),
                true,
            ),
            (
                "git-commit",
                ToolSource::Skill,
                "Create conventional commits with proper formatting",
                None,
                false,
            ),
            (
                "deploy-app",
                ToolSource::Skill,
                "Deploy application to staging or production environment",
                None,
                false,
            ),
        ];

        let mut index = ToolIndex::new();
        for (name, source, description, instructions, has_schema) in fixtures {
            let schema = if has_schema {
                Some(schema_for(name))
            } else {
                None
            };
            index.add_entry(name.into(), source, name, description, instructions, schema);
        }
        index
    }

    #[test]
    fn test_search_finds_relevant_tools() {
        let index = make_index_with_tools();
        let hits = index.search("documentation library API", 5);
        assert!(!hits.is_empty());
        assert_eq!(hits[0].id, "mcp__ctx7__query-docs");
    }

    #[test]
    fn test_search_browser_testing() {
        let index = make_index_with_tools();
        let hits = index.search("browser testing navigate", 5);
        assert!(!hits.is_empty());
        assert_eq!(hits[0].id, "mcp__playwright__navigate");
    }

    #[test]
    fn test_search_skill() {
        let index = make_index_with_tools();
        let hits = index.search("commit git conventional", 5);
        assert!(!hits.is_empty());
        assert_eq!(hits[0].id, "git-commit");
        assert_eq!(hits[0].source, ToolSource::Skill);
    }

    #[test]
    fn test_search_complex_analysis() {
        let index = make_index_with_tools();
        let hits = index.search("debug complex analysis reasoning", 5);
        assert!(!hits.is_empty());
        assert_eq!(hits[0].id, "mcp__sequential__think");
    }

    #[test]
    fn test_auto_select_returns_schemas_only() {
        let index = make_index_with_tools();
        let schemas = index.auto_select("documentation API lookup", 5);
        assert!(!schemas.is_empty());
        // Every fixture schema is tagged with type "function".
        for schema in &schemas {
            assert_eq!(schema["type"].as_str(), Some("function"));
        }
    }

    #[test]
    fn test_auto_select_elbow_detection() {
        let index = make_index_with_tools();
        let schemas = index.auto_select("playwright browser navigate URL", 5);
        assert!(!schemas.is_empty());
        // Only the three MCP fixtures carry schemas, so at most three can come back.
        assert!(schemas.len() <= 3);
    }

    #[test]
    fn test_search_empty_query() {
        let index = make_index_with_tools();
        assert!(index.search("", 5).is_empty());
    }

    #[test]
    fn test_search_no_match() {
        let index = make_index_with_tools();
        assert!(index.search("quantum entanglement physics", 5).is_empty());
    }

    #[test]
    fn test_entry_count() {
        let index = make_index_with_tools();
        assert_eq!(index.entry_count(), 5);
    }

    #[test]
    fn test_remove_entry() {
        let mut index = make_index_with_tools();
        assert_eq!(index.entry_count(), 5);
        assert!(index.remove_entry("git-commit"));
        assert_eq!(index.entry_count(), 4);

        // The removed skill must no longer be the top hit for its own keywords.
        let hits = index.search("commit git conventional", 5);
        assert!(hits.is_empty() || hits[0].id != "git-commit");
    }

    #[test]
    fn test_remove_nonexistent() {
        let mut index = make_index_with_tools();
        assert!(!index.remove_entry("nonexistent-tool"));
        assert_eq!(index.entry_count(), 5);
    }

    #[test]
    fn test_upsert_entry() {
        let mut index = make_index_with_tools();
        index.upsert_entry(
            "git-commit".into(),
            ToolSource::Skill,
            "git-commit",
            "Create semantic versioning commits with changelog",
            None,
            None,
        );
        // Upserting an existing id replaces it rather than adding a sixth entry.
        assert_eq!(index.entry_count(), 5);

        let hits = index.search("changelog versioning", 5);
        assert!(!hits.is_empty());
        assert_eq!(hits[0].id, "git-commit");
    }

    #[test]
    fn test_remove_source() {
        let mut index = make_index_with_tools();
        index.remove_source(ToolSource::Skill);
        // Two skills removed, three MCP tools remain.
        assert_eq!(index.entry_count(), 3);

        let hits = index.search("commit", 5);
        assert!(hits.is_empty() || hits[0].source != ToolSource::Skill);
    }

    /// Simulates a skill reindex by hand (remove the source, then add an
    /// entry); no `SkillRegistry` is constructed here.
    #[test]
    fn test_reindex_skills() {
        let mut index = make_index_with_tools();
        let before = index.entry_count();

        index.remove_source(ToolSource::Skill);
        let after_remove = index.entry_count();

        index.add_entry(
            "new-skill".into(),
            ToolSource::Skill,
            "new-skill",
            "A fresh skill added after reindex",
            None,
            None,
        );
        assert!(index.entry_count() > after_remove);
        assert!(index.entry_count() <= before);
    }

    /// Adds one agent entry by hand and checks that removing the agent source
    /// drops exactly that entry; no `AgentDef` is constructed here.
    #[test]
    fn test_reindex_agents() {
        let mut index = make_index_with_tools();
        index.add_entry(
            "agent:planner".into(),
            ToolSource::Agent,
            "planner",
            "Plans tasks and coordinates work",
            None,
            None,
        );
        let with_agent = index.entry_count();

        index.remove_source(ToolSource::Agent);
        assert_eq!(index.entry_count(), with_agent - 1);
    }
}