Skip to main content

roboticus_agent/
tool_search.rs

1//! Semantic tool search — rank and prune tools before presenting to the LLM.
2//!
3//! Embeds tool descriptions at registration time, ranks against user query
4//! at inference time, and prunes to top-K within a token budget. Uses the
5//! `RankedCandidate` type from `ranking.rs` for convergence with memory retrieval.
6
7use crate::ranking::{CandidateKind, RankedCandidate, top_k_with_pinned};
8use serde::{Deserialize, Serialize};
9
/// Where a tool came from (simplified for search ranking).
///
/// `rank_tools` only distinguishes `Mcp` from the other variants: MCP-sourced
/// tools have `SearchConfig::mcp_latency_penalty` subtracted from their score,
/// while every other variant is ranked without a penalty.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ToolSourceInfo {
    /// Presumably a tool shipped with the agent itself — confirm against the registry.
    BuiltIn,
    /// Tool contributed by a plugin; the payload is presumably the plugin's name — TODO confirm.
    Plugin(String),
    /// Tool associated with the MCP server named by `server`; penalized during ranking.
    Mcp { server: String },
}
17
/// A tool descriptor with its cached embedding.
///
/// This is the input record consumed by `rank_tools` and `search_and_prune`.
#[derive(Debug, Clone)]
pub struct ToolDescriptor {
    /// Tool name; `rank_tools` copies it into the candidate's `source_id`.
    pub name: String,
    /// Human-readable description. Not consulted by `rank_tools` itself;
    /// presumably the text the embedding was computed from — TODO confirm.
    pub description: String,
    /// Token cost of the tool; copied onto the ranked candidate and summed
    /// when `search_and_prune` computes `token_savings`.
    pub token_cost: usize,
    /// Origin of the tool; only the `Mcp` variant affects ranking.
    pub source: ToolSourceInfo,
    /// Cached embedding vector. `None` makes `rank_tools` assign a raw score of 0.0.
    pub embedding: Option<Vec<f32>>,
}
27
/// Tunable knobs for ranking and pruning tools.
#[derive(Debug, Clone)]
pub struct SearchConfig {
    /// Maximum number of tools to keep; forwarded to `top_k_with_pinned`.
    pub top_k: usize,
    /// Token budget for the selected tools; forwarded to `top_k_with_pinned`.
    pub token_budget: usize,
    /// Amount subtracted from the raw similarity of MCP-sourced tools.
    pub mcp_latency_penalty: f64,
    /// Tool names handed to `top_k_with_pinned` as pinned entries.
    pub always_include: Vec<String>,
}

impl Default for SearchConfig {
    /// Defaults: keep 15 tools within 4000 tokens, penalize MCP tools by
    /// 0.05, and pin `memory_store` and `delegate`.
    fn default() -> Self {
        let pinned = vec![String::from("memory_store"), String::from("delegate")];

        SearchConfig {
            always_include: pinned,
            mcp_latency_penalty: 0.05,
            token_budget: 4000,
            top_k: 15,
        }
    }
}
47
48/// Rank tools by cosine similarity to the query embedding.
49pub fn rank_tools(
50    tools: &[ToolDescriptor],
51    query_embedding: &[f32],
52    config: &SearchConfig,
53) -> Vec<RankedCandidate> {
54    let mut candidates: Vec<RankedCandidate> = tools
55        .iter()
56        .map(|tool| {
57            let raw_score = tool
58                .embedding
59                .as_ref()
60                .map(|emb| cosine_similarity(emb, query_embedding))
61                .unwrap_or(0.0);
62
63            let penalty = match &tool.source {
64                ToolSourceInfo::Mcp { .. } => config.mcp_latency_penalty,
65                _ => 0.0,
66            };
67
68            RankedCandidate {
69                source_id: tool.name.clone(),
70                source_kind: CandidateKind::Tool,
71                raw_score,
72                adjusted_score: (raw_score - penalty).max(0.0),
73                token_cost: tool.token_cost,
74            }
75        })
76        .collect();
77
78    candidates.sort_by(|a, b| {
79        b.adjusted_score
80            .partial_cmp(&a.adjusted_score)
81            .unwrap_or(std::cmp::Ordering::Equal)
82    });
83    candidates
84}
85
86/// Rank and prune to top-K within budget, preserving pinned tools.
87pub fn search_and_prune(
88    tools: &[ToolDescriptor],
89    query_embedding: &[f32],
90    config: &SearchConfig,
91) -> (Vec<RankedCandidate>, ToolSearchStats) {
92    let ranked = rank_tools(tools, query_embedding, config);
93    let total_before = tools.len();
94    let always_refs: Vec<&str> = config.always_include.iter().map(|s| s.as_str()).collect();
95    let pruned = top_k_with_pinned(&ranked, config.top_k, config.token_budget, &always_refs);
96    let total_after = pruned.len();
97
98    let top_scores: Vec<(String, f64)> = pruned
99        .iter()
100        .take(10)
101        .map(|c| (c.source_id.clone(), c.adjusted_score))
102        .collect();
103
104    let stats = ToolSearchStats {
105        candidates_considered: total_before,
106        candidates_selected: total_after,
107        candidates_pruned: total_before - total_after,
108        token_savings: ranked.iter().map(|c| c.token_cost).sum::<usize>()
109            - pruned.iter().map(|c| c.token_cost).sum::<usize>(),
110        top_scores,
111        embedding_status: "ok".to_string(),
112    };
113
114    (pruned, stats)
115}
116
/// Stats for trace annotation.
///
/// Built by `search_and_prune` to summarize one ranking/pruning pass.
// NOTE(review): only `Serialize` is derived here, so the `serde(default ...)`
// attributes below presumably have no effect unless `Deserialize` is derived
// as well — confirm whether that derive is intended.
#[derive(Debug, Clone, Serialize)]
pub struct ToolSearchStats {
    /// Number of tools passed into the search.
    pub candidates_considered: usize,
    /// Number of candidates that survived pruning.
    pub candidates_selected: usize,
    /// Number of candidates dropped (`candidates_considered` minus `candidates_selected`).
    pub candidates_pruned: usize,
    /// Summed token cost of all ranked candidates minus that of the selected ones.
    pub token_savings: usize,
    /// Top-10 selected tools with their adjusted ranking scores.
    /// Empty when embedding failed (graceful degradation).
    /// Omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub top_scores: Vec<(String, f64)>,
    /// Whether the embedding provider was used successfully.
    /// `"ok"` = normal ranking, `"failed"` = fallback to static ordering.
    #[serde(default = "default_embedding_status")]
    pub embedding_status: String,
}
133
/// Default value for `ToolSearchStats::embedding_status`: the string `"ok"`.
#[allow(dead_code)] // Only referenced by name from a serde(default) attribute
fn default_embedding_status() -> String {
    String::from("ok")
}
138
/// Cosine similarity between two embedding vectors, always in `[-1.0, 1.0]`.
///
/// Returns `0.0` ("no signal") when the similarity is undefined:
/// - the vectors have different lengths (e.g. embeddings produced by
///   different models), since pairing only a common prefix would give a
///   meaningless score;
/// - either vector is empty or has zero magnitude;
/// - any component is NaN or infinite.
///
/// Sums are accumulated in `f64` so long vectors neither lose precision nor
/// overflow the way `f32` accumulation can.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.len() != b.len() {
        return 0.0;
    }

    let (mut dot, mut sq_a, mut sq_b) = (0.0_f64, 0.0_f64, 0.0_f64);
    for (&x, &y) in a.iter().zip(b) {
        let (x, y) = (f64::from(x), f64::from(y));
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }

    let denom = sq_a.sqrt() * sq_b.sqrt();
    if denom == 0.0 || !denom.is_finite() {
        return 0.0;
    }

    let similarity = dot / denom;
    if !similarity.is_finite() {
        return 0.0;
    }
    // Rounding can push the ratio a hair outside the mathematical range.
    similarity.clamp(-1.0, 1.0)
}
148
#[cfg(test)]
mod tests {
    use super::*;

    /// Assemble a `ToolDescriptor` fixture from borrowed strings.
    fn descriptor(
        name: &str,
        description: &str,
        token_cost: usize,
        source: ToolSourceInfo,
        embedding: Option<Vec<f32>>,
    ) -> ToolDescriptor {
        ToolDescriptor {
            name: name.to_string(),
            description: description.to_string(),
            token_cost,
            source,
            embedding,
        }
    }

    /// The tool whose embedding points the same way as the query ranks first.
    #[test]
    fn rank_tools_returns_sorted_by_adjusted_score() {
        let tools = [
            descriptor(
                "web_search",
                "Search the web",
                50,
                ToolSourceInfo::BuiltIn,
                Some(vec![0.9, 0.1, 0.0]),
            ),
            descriptor(
                "memory_store",
                "Store a memory",
                30,
                ToolSourceInfo::BuiltIn,
                Some(vec![0.1, 0.9, 0.0]),
            ),
        ];
        let query = [0.85_f32, 0.15, 0.0];

        let ranked = rank_tools(&tools, &query, &SearchConfig::default());

        assert_eq!(ranked[0].source_id, "web_search");
    }

    /// With identical embeddings, the MCP tool loses to the local one.
    #[test]
    fn mcp_tools_receive_latency_penalty() {
        let shared_embedding = vec![0.9_f32, 0.1];
        let tools = [
            descriptor(
                "local_tool",
                "A local tool",
                50,
                ToolSourceInfo::BuiltIn,
                Some(shared_embedding.clone()),
            ),
            descriptor(
                "server::remote_tool",
                "A remote tool",
                50,
                ToolSourceInfo::Mcp {
                    server: "server".to_string(),
                },
                Some(shared_embedding.clone()),
            ),
        ];
        let config = SearchConfig {
            mcp_latency_penalty: 0.1,
            ..SearchConfig::default()
        };

        let ranked = rank_tools(&tools, &shared_embedding, &config);

        assert_eq!(ranked[0].source_id, "local_tool");
    }

    /// A tool with no embedding is still returned, scored at zero.
    #[test]
    fn tools_without_embeddings_are_included_unranked() {
        let tools = [descriptor(
            "no_embedding",
            "No embedding yet",
            50,
            ToolSourceInfo::BuiltIn,
            None,
        )];
        let query = [0.9_f32, 0.1];

        let ranked = rank_tools(&tools, &query, &SearchConfig::default());

        assert_eq!(ranked.len(), 1);
        assert_eq!(ranked[0].adjusted_score, 0.0);
    }
}