use crate::ranking::{CandidateKind, RankedCandidate, top_k_with_pinned};
use serde::{Deserialize, Serialize};
/// Where a tool comes from.
///
/// Only `Mcp` tools are penalized during ranking: `rank_tools` subtracts
/// `SearchConfig::mcp_latency_penalty` from their similarity score, while
/// `BuiltIn` and `Plugin` tools are scored on similarity alone.
// NOTE(review): this enum derives Serialize/Deserialize, so variant names (and, for
// index-based formats, variant order) are part of the wire format — change with care.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ToolSourceInfo {
/// Built-in tool; receives no latency penalty.
BuiltIn,
/// Plugin-provided tool; the string is presumably the plugin name (not read in this file).
Plugin(String),
/// MCP tool; ranked with the configured latency penalty. `server` presumably names the
/// MCP server (not read in this file).
Mcp { server: String },
}
/// A tool that can be offered for selection, as seen by the ranking code.
#[derive(Debug, Clone)]
pub struct ToolDescriptor {
/// Tool identifier; copied into `RankedCandidate::source_id`. Presumably this is what the
/// names in `SearchConfig::always_include` refer to — confirm in `top_k_with_pinned`.
pub name: String,
/// Human-readable description; not read by the ranking code in this file.
pub description: String,
/// Token cost of offering this tool; carried into `RankedCandidate::token_cost` and summed
/// when `search_and_prune` computes `token_savings`.
pub token_cost: usize,
/// Origin of the tool; MCP tools receive a latency penalty during ranking.
pub source: ToolSourceInfo,
/// Embedding compared to the query by cosine similarity. `None` yields a score of 0.0.
pub embedding: Option<Vec<f32>>,
}
/// Tuning knobs for `rank_tools` and `search_and_prune`.
#[derive(Debug, Clone)]
pub struct SearchConfig {
/// Passed to `top_k_with_pinned`; presumably the maximum number of tools kept.
pub top_k: usize,
/// Passed to `top_k_with_pinned`; presumably the maximum total `token_cost` kept.
pub token_budget: usize,
/// Subtracted from the similarity score of MCP tools before sorting. The adjusted score
/// is floored at 0.0.
pub mcp_latency_penalty: f64,
/// Tool names passed to `top_k_with_pinned` as pinned entries.
pub always_include: Vec<String>,
}
impl Default for SearchConfig {
fn default() -> Self {
Self {
top_k: 15,
token_budget: 4000,
mcp_latency_penalty: 0.05,
always_include: vec!["memory_store".into(), "delegate".into()],
}
}
}
/// Scores every tool against `query_embedding` and returns one candidate per tool,
/// best first.
///
/// * `raw_score` is the cosine similarity between the tool's embedding and the query.
///   Tools without an embedding score 0.0.
/// * `adjusted_score` is `raw_score` minus `config.mcp_latency_penalty` for MCP tools
///   (no penalty otherwise), floored at 0.0.
///
/// Candidates are sorted by `adjusted_score` in descending order. The sort is stable, so
/// ties keep the order of `tools`.
pub fn rank_tools(
    tools: &[ToolDescriptor],
    query_embedding: &[f32],
    config: &SearchConfig,
) -> Vec<RankedCandidate> {
    let mut ranked = Vec::with_capacity(tools.len());
    for tool in tools {
        let similarity = match &tool.embedding {
            Some(embedding) => cosine_similarity(embedding, query_embedding),
            None => 0.0,
        };
        // Exhaustive on purpose: a new source variant must decide its own penalty.
        let latency_penalty = match tool.source {
            ToolSourceInfo::Mcp { .. } => config.mcp_latency_penalty,
            ToolSourceInfo::BuiltIn | ToolSourceInfo::Plugin(_) => 0.0,
        };
        ranked.push(RankedCandidate {
            source_id: tool.name.clone(),
            source_kind: CandidateKind::Tool,
            raw_score: similarity,
            adjusted_score: (similarity - latency_penalty).max(0.0),
            token_cost: tool.token_cost,
        });
    }
    ranked.sort_by(|lhs, rhs| {
        rhs.adjusted_score
            .partial_cmp(&lhs.adjusted_score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    ranked
}
/// Ranks `tools` against the query, prunes the ranking, and reports what happened.
///
/// Ranking is done by [`rank_tools`]. Pruning is delegated to `top_k_with_pinned` with
/// `config.top_k`, `config.token_budget`, and the names in `config.always_include` as
/// pinned entries.
///
/// Returns the surviving candidates plus a [`ToolSearchStats`] with:
/// * candidate counts before and after pruning;
/// * the token cost removed;
/// * `(name, adjusted_score)` for up to the first ten survivors.
pub fn search_and_prune(
    tools: &[ToolDescriptor],
    query_embedding: &[f32],
    config: &SearchConfig,
) -> (Vec<RankedCandidate>, ToolSearchStats) {
    // How many of the selected candidates are echoed into the stats for diagnostics.
    const TOP_SCORES_REPORTED: usize = 10;

    let ranked = rank_tools(tools, query_embedding, config);
    let always_refs: Vec<&str> = config.always_include.iter().map(String::as_str).collect();
    let pruned = top_k_with_pinned(&ranked, config.top_k, config.token_budget, &always_refs);

    let total_before = tools.len();
    let total_after = pruned.len();
    let tokens_before: usize = ranked.iter().map(|c| c.token_cost).sum();
    let tokens_after: usize = pruned.iter().map(|c| c.token_cost).sum();

    let top_scores = pruned
        .iter()
        .take(TOP_SCORES_REPORTED)
        .map(|c| (c.source_id.clone(), c.adjusted_score))
        .collect();

    let stats = ToolSearchStats {
        candidates_considered: total_before,
        candidates_selected: total_after,
        // `top_k_with_pinned` lives in another module, so don't assume its output is a
        // subset of `ranked`: a plain `-` would panic in debug builds and wrap to a huge
        // value in release builds if it ever returned more candidates or tokens.
        candidates_pruned: total_before.saturating_sub(total_after),
        token_savings: tokens_before.saturating_sub(tokens_after),
        top_scores,
        embedding_status: "ok".to_string(),
    };
    (pruned, stats)
}
#[derive(Debug, Clone, Serialize)]
pub struct ToolSearchStats {
pub candidates_considered: usize,
pub candidates_selected: usize,
pub candidates_pruned: usize,
pub token_savings: usize,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub top_scores: Vec<(String, f64)>,
#[serde(default = "default_embedding_status")]
pub embedding_status: String,
}
/// Fallback value for `ToolSearchStats::embedding_status`: the string `"ok"`.
// Referenced by name from a `#[serde(default = ...)]` attribute. That reference only
// counts as a use when a `Deserialize` impl is generated, hence the `allow`.
#[allow(dead_code)]
fn default_embedding_status() -> String {
    String::from("ok")
}
/// Cosine similarity between two embedding vectors (nominally in `[-1.0, 1.0]`).
///
/// Returns `0.0`, the same score a tool without an embedding receives, when:
/// * either vector has zero magnitude (this includes empty vectors);
/// * the vectors have different lengths.
///
/// Embeddings of different dimensionality come from different models (or a stale cache)
/// and are not comparable. Scoring them as if the shorter one were zero-padded would
/// yield a plausible-looking but meaningless value.
///
/// Sums are accumulated in `f64`. Large `f32` components therefore cannot overflow to
/// infinity (which would make the result NaN), and long vectors lose less precision.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.len() != b.len() {
        return 0.0;
    }
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, na, nb), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, na + x * x, nb + y * y)
        },
    );
    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        return 0.0;
    }
    dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ToolDescriptor` from the fields the ranking tests vary.
    fn tool(
        name: &str,
        description: &str,
        token_cost: usize,
        source: ToolSourceInfo,
        embedding: Option<Vec<f32>>,
    ) -> ToolDescriptor {
        ToolDescriptor {
            name: name.into(),
            description: description.into(),
            token_cost,
            source,
            embedding,
        }
    }

    #[test]
    fn rank_tools_returns_sorted_by_adjusted_score() {
        let tools = vec![
            tool(
                "web_search",
                "Search the web",
                50,
                ToolSourceInfo::BuiltIn,
                Some(vec![0.9, 0.1, 0.0]),
            ),
            tool(
                "memory_store",
                "Store a memory",
                30,
                ToolSourceInfo::BuiltIn,
                Some(vec![0.1, 0.9, 0.0]),
            ),
        ];
        let query_embedding = vec![0.85, 0.15, 0.0];
        let ranked = rank_tools(&tools, &query_embedding, &SearchConfig::default());
        assert_eq!(ranked[0].source_id, "web_search");
    }

    #[test]
    fn mcp_tools_receive_latency_penalty() {
        let tools = vec![
            tool(
                "local_tool",
                "A local tool",
                50,
                ToolSourceInfo::BuiltIn,
                Some(vec![0.9, 0.1]),
            ),
            tool(
                "server::remote_tool",
                "A remote tool",
                50,
                ToolSourceInfo::Mcp {
                    server: "server".into(),
                },
                Some(vec![0.9, 0.1]),
            ),
        ];
        let query_embedding = vec![0.9, 0.1];
        let config = SearchConfig {
            mcp_latency_penalty: 0.1,
            ..Default::default()
        };
        let ranked = rank_tools(&tools, &query_embedding, &config);
        assert_eq!(ranked[0].source_id, "local_tool");
    }

    #[test]
    fn tools_without_embeddings_are_included_unranked() {
        let tools = vec![tool(
            "no_embedding",
            "No embedding yet",
            50,
            ToolSourceInfo::BuiltIn,
            None,
        )];
        let query_embedding = vec![0.9, 0.1];
        let ranked = rank_tools(&tools, &query_embedding, &SearchConfig::default());
        assert_eq!(ranked.len(), 1);
        assert_eq!(ranked[0].adjusted_score, 0.0);
    }

    #[test]
    fn negative_similarity_is_clamped_only_in_adjusted_score() {
        let tools = vec![tool(
            "opposite",
            "Points away from the query",
            10,
            ToolSourceInfo::BuiltIn,
            Some(vec![-1.0, 0.0]),
        )];
        let query_embedding = vec![1.0, 0.0];
        let ranked = rank_tools(&tools, &query_embedding, &SearchConfig::default());
        assert_eq!(ranked[0].raw_score, -1.0);
        assert_eq!(ranked[0].adjusted_score, 0.0);
    }

    #[test]
    fn zero_query_embedding_scores_zero() {
        let tools = vec![tool(
            "any_tool",
            "Any tool",
            10,
            ToolSourceInfo::BuiltIn,
            Some(vec![0.9, 0.1]),
        )];
        let query_embedding = vec![0.0, 0.0];
        let ranked = rank_tools(&tools, &query_embedding, &SearchConfig::default());
        assert_eq!(ranked[0].raw_score, 0.0);
        assert_eq!(ranked[0].adjusted_score, 0.0);
    }
}