brainwires_tool_runtime/
tool_embedding.rs1use anyhow::{Context, Result};
7use brainwires_rag::rag::embedding::FastEmbedManager;
8use std::sync::{Arc, OnceLock};
9
10static EMBED_MANAGER: OnceLock<Arc<FastEmbedManager>> = OnceLock::new();
11
12fn get_embed_manager() -> Result<&'static Arc<FastEmbedManager>> {
14 EMBED_MANAGER.get().ok_or(()).or_else(|_| {
15 let manager = Arc::new(FastEmbedManager::new()?);
16 let _ = EMBED_MANAGER.set(manager.clone());
19 Ok::<_, anyhow::Error>(EMBED_MANAGER.get().unwrap())
20 })
21}
22
23pub struct ToolEmbeddingIndex {
28 entries: Vec<ToolEmbeddingEntry>,
30 tool_count: usize,
32}
33
/// A single indexed tool: its name together with its embedding vector.
struct ToolEmbeddingEntry {
    /// Tool name, returned verbatim in search results.
    name: String,
    /// Embedding vector compared against the query embedding.
    embedding: Vec<f32>,
}
38
39impl ToolEmbeddingIndex {
40 pub fn build(tools: &[(String, String)]) -> Result<Self> {
45 if tools.is_empty() {
46 return Ok(Self {
47 entries: vec![],
48 tool_count: 0,
49 });
50 }
51
52 let manager = get_embed_manager().context("Failed to initialize embedding model")?;
53
54 let texts: Vec<String> = tools
56 .iter()
57 .map(|(name, desc)| format!("{}: {}", name, desc))
58 .collect();
59
60 let embeddings = manager
61 .embed_batch(&texts)
62 .context("Failed to generate tool embeddings")?;
63
64 let entries = tools
65 .iter()
66 .zip(embeddings)
67 .map(|((name, _), embedding)| ToolEmbeddingEntry {
68 name: name.clone(),
69 embedding,
70 })
71 .collect();
72
73 Ok(Self {
74 entries,
75 tool_count: tools.len(),
76 })
77 }
78
79 pub fn search(&self, query: &str, limit: usize, min_score: f32) -> Result<Vec<(String, f32)>> {
84 if self.entries.is_empty() {
85 return Ok(vec![]);
86 }
87
88 let manager = get_embed_manager().context("Failed to get embedding model")?;
89 let query_vec = manager.embed(query).context("Failed to embed query")?;
90 let query_vec = &query_vec;
91
92 let mut scored: Vec<(String, f32)> = self
93 .entries
94 .iter()
95 .map(|entry| {
96 let score = cosine_similarity(query_vec, &entry.embedding);
97 (entry.name.clone(), score)
98 })
99 .filter(|(_, score)| *score >= min_score)
100 .collect();
101
102 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
103 scored.truncate(limit);
104
105 Ok(scored)
106 }
107
108 pub fn tool_count(&self) -> usize {
110 self.tool_count
111 }
112}
113
/// Cosine similarity of two vectors.
///
/// Returns `0.0` when the product of the two norms is zero (for example when
/// either vector is all zeros or empty). Debug builds assert that both slices
/// have the same length; otherwise only the overlapping prefix is compared.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "vectors must have equal dimensions");

    // Accumulate the dot product and both squared norms in a single pass.
    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, sq_a, sq_b), (x, y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );

    let denom = sq_a.sqrt() * sq_b.sqrt();
    if denom == 0.0 {
        return 0.0;
    }
    dot / denom
}
131
#[cfg(test)]
mod tests {
    use super::*;

    /// Small fixed catalogue of `(name, description)` pairs shared by the tests.
    fn sample_tools() -> Vec<(String, String)> {
        const CATALOGUE: [(&str, &str); 5] = [
            ("read_file", "Read the contents of a file from disk"),
            ("write_file", "Write content to a file on disk"),
            ("execute_command", "Execute a shell command in bash"),
            ("git_commit", "Create a git commit with a message"),
            ("optimize_png", "Optimize and compress PNG image files"),
        ];

        CATALOGUE
            .iter()
            .map(|&(name, desc)| (name.to_string(), desc.to_string()))
            .collect()
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let v = [1.0f32, 2.0, 3.0];
        let similarity = cosine_similarity(&v, &v);
        assert!((similarity - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let similarity = cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]);
        assert!(similarity.abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        let similarity = cosine_similarity(&[1.0, 2.0], &[0.0, 0.0]);
        assert_eq!(similarity, 0.0);
    }

    #[test]
    fn test_build_empty() {
        let index = ToolEmbeddingIndex::build(&[]).unwrap();
        assert_eq!(index.tool_count(), 0);

        let hits = index.search("anything", 10, 0.0).unwrap();
        assert!(hits.is_empty());
    }

    #[test]
    fn test_build_and_search() {
        let index = ToolEmbeddingIndex::build(&sample_tools()).unwrap();
        assert_eq!(index.tool_count(), 5);

        let hits = index.search("compress image", 5, 0.0).unwrap();
        assert!(!hits.is_empty());

        // The image tool should rank first for an image-related query.
        let (best_name, _) = &hits[0];
        assert_eq!(best_name, "optimize_png");
    }

    #[test]
    fn test_search_file_operations() {
        let index = ToolEmbeddingIndex::build(&sample_tools()).unwrap();

        let hits = index.search("load a document", 3, 0.0).unwrap();
        assert!(!hits.is_empty());

        let top_names: Vec<&str> = hits.iter().map(|(name, _)| name.as_str()).collect();
        let has_file_tool = top_names
            .iter()
            .any(|name| matches!(*name, "read_file" | "write_file"));
        assert!(
            has_file_tool,
            "Expected file tools in results, got: {:?}",
            top_names
        );
    }

    #[test]
    fn test_min_score_filtering() {
        let index = ToolEmbeddingIndex::build(&sample_tools()).unwrap();

        // A very high threshold should let at most one tool through.
        let hits = index
            .search("random unrelated query xyz", 10, 0.95)
            .unwrap();
        assert!(
            hits.len() <= 1,
            "Expected few/no results with high min_score, got {}",
            hits.len()
        );
    }

    #[test]
    fn test_limit_respected() {
        let index = ToolEmbeddingIndex::build(&sample_tools()).unwrap();

        let hits = index.search("file", 2, 0.0).unwrap();
        assert!(hits.len() <= 2);
    }
}