1use crate::ranking::{CandidateKind, RankedCandidate, top_k_with_pinned};
8use serde::{Deserialize, Serialize};
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub enum ToolSourceInfo {
13 BuiltIn,
14 Plugin(String),
15 Mcp { server: String },
16}
17
18#[derive(Debug, Clone)]
20pub struct ToolDescriptor {
21 pub name: String,
22 pub description: String,
23 pub token_cost: usize,
24 pub source: ToolSourceInfo,
25 pub embedding: Option<Vec<f32>>,
26}
27
/// Tunables for tool ranking and pruning.
#[derive(Clone, Debug)]
pub struct SearchConfig {
    /// Passed to `top_k_with_pinned` as the `k` argument.
    // NOTE(review): presumably the maximum number of non-pinned tools kept — confirm.
    pub top_k: usize,
    /// Passed to `top_k_with_pinned` as the token budget.
    pub token_budget: usize,
    /// Amount subtracted from the raw similarity score of MCP-backed tools.
    pub mcp_latency_penalty: f64,
    /// Tool names handed to `top_k_with_pinned` as pinned entries.
    pub always_include: Vec<String>,
}

impl Default for SearchConfig {
    /// Defaults: `top_k = 15`, `token_budget = 4000`, `mcp_latency_penalty = 0.05`,
    /// and `memory_store` / `delegate` always included.
    fn default() -> Self {
        let always_include = vec![String::from("memory_store"), String::from("delegate")];

        SearchConfig {
            top_k: 15,
            token_budget: 4000,
            mcp_latency_penalty: 0.05,
            always_include,
        }
    }
}
47
48pub fn rank_tools(
50 tools: &[ToolDescriptor],
51 query_embedding: &[f32],
52 config: &SearchConfig,
53) -> Vec<RankedCandidate> {
54 let mut candidates: Vec<RankedCandidate> = tools
55 .iter()
56 .map(|tool| {
57 let raw_score = tool
58 .embedding
59 .as_ref()
60 .map(|emb| cosine_similarity(emb, query_embedding))
61 .unwrap_or(0.0);
62
63 let penalty = match &tool.source {
64 ToolSourceInfo::Mcp { .. } => config.mcp_latency_penalty,
65 _ => 0.0,
66 };
67
68 RankedCandidate {
69 source_id: tool.name.clone(),
70 source_kind: CandidateKind::Tool,
71 raw_score,
72 adjusted_score: (raw_score - penalty).max(0.0),
73 token_cost: tool.token_cost,
74 }
75 })
76 .collect();
77
78 candidates.sort_by(|a, b| {
79 b.adjusted_score
80 .partial_cmp(&a.adjusted_score)
81 .unwrap_or(std::cmp::Ordering::Equal)
82 });
83 candidates
84}
85
86pub fn search_and_prune(
88 tools: &[ToolDescriptor],
89 query_embedding: &[f32],
90 config: &SearchConfig,
91) -> (Vec<RankedCandidate>, ToolSearchStats) {
92 let ranked = rank_tools(tools, query_embedding, config);
93 let total_before = tools.len();
94 let always_refs: Vec<&str> = config.always_include.iter().map(|s| s.as_str()).collect();
95 let pruned = top_k_with_pinned(&ranked, config.top_k, config.token_budget, &always_refs);
96 let total_after = pruned.len();
97
98 let top_scores: Vec<(String, f64)> = pruned
99 .iter()
100 .take(10)
101 .map(|c| (c.source_id.clone(), c.adjusted_score))
102 .collect();
103
104 let stats = ToolSearchStats {
105 candidates_considered: total_before,
106 candidates_selected: total_after,
107 candidates_pruned: total_before - total_after,
108 token_savings: ranked.iter().map(|c| c.token_cost).sum::<usize>()
109 - pruned.iter().map(|c| c.token_cost).sum::<usize>(),
110 top_scores,
111 embedding_status: "ok".to_string(),
112 };
113
114 (pruned, stats)
115}
116
117#[derive(Debug, Clone, Serialize)]
119pub struct ToolSearchStats {
120 pub candidates_considered: usize,
121 pub candidates_selected: usize,
122 pub candidates_pruned: usize,
123 pub token_savings: usize,
124 #[serde(default, skip_serializing_if = "Vec::is_empty")]
127 pub top_scores: Vec<(String, f64)>,
128 #[serde(default = "default_embedding_status")]
131 pub embedding_status: String,
132}
133
/// Default value for `ToolSearchStats::embedding_status`: `"ok"`.
///
/// Referenced by name from the `#[serde(default = "default_embedding_status")]`
/// attribute on that field.
// No code in this file calls the function directly, hence the lint exemption.
// NOTE(review): presumably kept for a future `Deserialize` derive — confirm.
#[allow(dead_code)]
fn default_embedding_status() -> String {
    String::from("ok")
}
138
/// Computes the cosine similarity between two embedding vectors.
///
/// Returns a value in `[-1.0, 1.0]`. When the similarity is undefined the function
/// returns the neutral score `0.0` instead of a misleading or non-finite number:
///
/// - the vectors have different lengths (comparing embeddings of different
///   dimensionality is meaningless, and zipping would silently truncate),
/// - either vector is empty or has zero magnitude,
/// - the result is not finite (an input contains NaN or infinity).
///
/// Sums are accumulated in `f64`, so large-magnitude `f32` components cannot
/// overflow while being squared and rounding error stays small.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Single pass: dot product and both squared norms.
    let (dot, sq_norm_a, sq_norm_b) = a.iter().zip(b.iter()).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );

    let denominator = sq_norm_a.sqrt() * sq_norm_b.sqrt();
    if denominator == 0.0 {
        return 0.0;
    }

    let similarity = dot / denominator;
    if similarity.is_finite() {
        // Rounding can push the quotient a hair outside the mathematical range.
        similarity.clamp(-1.0, 1.0)
    } else {
        0.0
    }
}
148
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ToolDescriptor` fixture from its five fields.
    fn tool(
        name: &str,
        description: &str,
        token_cost: usize,
        source: ToolSourceInfo,
        embedding: Option<Vec<f32>>,
    ) -> ToolDescriptor {
        ToolDescriptor {
            name: name.to_owned(),
            description: description.to_owned(),
            token_cost,
            source,
            embedding,
        }
    }

    /// The tool whose embedding is closest to the query must come first.
    #[test]
    fn rank_tools_returns_sorted_by_adjusted_score() {
        let tools = [
            tool(
                "web_search",
                "Search the web",
                50,
                ToolSourceInfo::BuiltIn,
                Some(vec![0.9, 0.1, 0.0]),
            ),
            tool(
                "memory_store",
                "Store a memory",
                30,
                ToolSourceInfo::BuiltIn,
                Some(vec![0.1, 0.9, 0.0]),
            ),
        ];
        let query = [0.85_f32, 0.15, 0.0];

        let ranked = rank_tools(&tools, &query, &SearchConfig::default());

        assert_eq!(ranked[0].source_id, "web_search");
    }

    /// With identical embeddings, the MCP tool must rank below the local one.
    #[test]
    fn mcp_tools_receive_latency_penalty() {
        let remote_source = ToolSourceInfo::Mcp {
            server: "server".to_owned(),
        };
        let tools = [
            tool(
                "local_tool",
                "A local tool",
                50,
                ToolSourceInfo::BuiltIn,
                Some(vec![0.9, 0.1]),
            ),
            tool(
                "server::remote_tool",
                "A remote tool",
                50,
                remote_source,
                Some(vec![0.9, 0.1]),
            ),
        ];
        let query = [0.9_f32, 0.1];
        let config = SearchConfig {
            mcp_latency_penalty: 0.1,
            ..SearchConfig::default()
        };

        let ranked = rank_tools(&tools, &query, &config);

        assert_eq!(ranked[0].source_id, "local_tool");
    }

    /// A tool with no embedding is still returned, with a zero score.
    #[test]
    fn tools_without_embeddings_are_included_unranked() {
        let tools = [tool(
            "no_embedding",
            "No embedding yet",
            50,
            ToolSourceInfo::BuiltIn,
            None,
        )];
        let query = [0.9_f32, 0.1];

        let ranked = rank_tools(&tools, &query, &SearchConfig::default());

        assert_eq!(ranked.len(), 1);
        assert_eq!(ranked[0].adjusted_score, 0.0);
    }
}