Skip to main content

brainwires_tool_runtime/
tool_embedding.rs

1//! ToolEmbedding - Semantic tool discovery via embedding similarity
2//!
3//! Embeds tool names and descriptions into vectors, then uses cosine
4//! similarity to find semantically relevant tools for a given query.
5
6use anyhow::{Context, Result};
7use brainwires_rag::rag::embedding::FastEmbedManager;
8use std::sync::{Arc, OnceLock};
9
10static EMBED_MANAGER: OnceLock<Arc<FastEmbedManager>> = OnceLock::new();
11
12/// Get or initialize the shared FastEmbedManager.
13fn get_embed_manager() -> Result<&'static Arc<FastEmbedManager>> {
14    EMBED_MANAGER.get().ok_or(()).or_else(|_| {
15        let manager = Arc::new(FastEmbedManager::new()?);
16        // Another thread may have initialized it between our check and here;
17        // that's fine — just use whichever won the race.
18        let _ = EMBED_MANAGER.set(manager.clone());
19        Ok::<_, anyhow::Error>(EMBED_MANAGER.get().unwrap())
20    })
21}
22
/// Pre-computed embedding index for semantic tool discovery.
///
/// Stores one embedding per tool, computed from the string
/// `"{name}: {description}"`, and supports cosine-similarity search against
/// user queries via [`ToolEmbeddingIndex::search`].
#[derive(Debug, Clone)]
pub struct ToolEmbeddingIndex {
    /// One entry per indexed tool, in the order the tools were supplied.
    entries: Vec<ToolEmbeddingEntry>,
    /// Number of tools when the index was built (for staleness detection).
    tool_count: usize,
}

/// A single tool's name paired with the embedding of its text representation.
#[derive(Debug, Clone)]
struct ToolEmbeddingEntry {
    name: String,
    embedding: Vec<f32>,
}
38
39impl ToolEmbeddingIndex {
40    /// Build an index from tool (name, description) pairs.
41    ///
42    /// Each tool is embedded as `"{name}: {description}"`.
43    /// Returns an empty index if no tools are provided.
44    pub fn build(tools: &[(String, String)]) -> Result<Self> {
45        if tools.is_empty() {
46            return Ok(Self {
47                entries: vec![],
48                tool_count: 0,
49            });
50        }
51
52        let manager = get_embed_manager().context("Failed to initialize embedding model")?;
53
54        // Build text representations for embedding
55        let texts: Vec<String> = tools
56            .iter()
57            .map(|(name, desc)| format!("{}: {}", name, desc))
58            .collect();
59
60        let embeddings = manager
61            .embed_batch(&texts)
62            .context("Failed to generate tool embeddings")?;
63
64        let entries = tools
65            .iter()
66            .zip(embeddings)
67            .map(|((name, _), embedding)| ToolEmbeddingEntry {
68                name: name.clone(),
69                embedding,
70            })
71            .collect();
72
73        Ok(Self {
74            entries,
75            tool_count: tools.len(),
76        })
77    }
78
79    /// Search for tools semantically similar to the query.
80    ///
81    /// Returns `(tool_name, similarity_score)` pairs sorted by score descending,
82    /// filtered by `min_score` and capped at `limit`.
83    pub fn search(&self, query: &str, limit: usize, min_score: f32) -> Result<Vec<(String, f32)>> {
84        if self.entries.is_empty() {
85            return Ok(vec![]);
86        }
87
88        let manager = get_embed_manager().context("Failed to get embedding model")?;
89        let query_vec = manager.embed(query).context("Failed to embed query")?;
90        let query_vec = &query_vec;
91
92        let mut scored: Vec<(String, f32)> = self
93            .entries
94            .iter()
95            .map(|entry| {
96                let score = cosine_similarity(query_vec, &entry.embedding);
97                (entry.name.clone(), score)
98            })
99            .filter(|(_, score)| *score >= min_score)
100            .collect();
101
102        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
103        scored.truncate(limit);
104
105        Ok(scored)
106    }
107
108    /// Number of tools in the index (for staleness detection).
109    pub fn tool_count(&self) -> usize {
110        self.tool_count
111    }
112}
113
/// Cosine similarity between two equal-length vectors.
///
/// Computes `dot(a, b) / (|a| * |b|)`, returning `0.0` when the product of the
/// magnitudes is zero (e.g. either input is a zero vector).
///
/// Lengths are only checked in debug builds; in release builds the trailing
/// elements of the longer slice are ignored.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "vectors must have equal dimensions");

    let (dot, sq_len_a, sq_len_b) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, sq_len_a, sq_len_b), (&x, &y)| (dot + x * y, sq_len_a + x * x, sq_len_b + y * y),
    );

    // Square roots are taken separately so the squared norms are never
    // multiplied together, which could overflow f32 for large vectors.
    let magnitude = sq_len_a.sqrt() * sq_len_b.sqrt();
    if magnitude != 0.0 {
        dot / magnitude
    } else {
        0.0
    }
}
131
#[cfg(test)]
mod tests {
    use super::*;

    fn sample_tools() -> Vec<(String, String)> {
        vec![
            (
                "read_file".to_string(),
                "Read the contents of a file from disk".to_string(),
            ),
            (
                "write_file".to_string(),
                "Write content to a file on disk".to_string(),
            ),
            (
                "execute_command".to_string(),
                "Execute a shell command in bash".to_string(),
            ),
            (
                "git_commit".to_string(),
                "Create a git commit with a message".to_string(),
            ),
            (
                "optimize_png".to_string(),
                "Optimize and compress PNG image files".to_string(),
            ),
        ]
    }

    /// Builds an index over `sample_tools()`. This constructs the real
    /// embedding model via `FastEmbedManager::new()`.
    fn sample_index() -> ToolEmbeddingIndex {
        ToolEmbeddingIndex::build(&sample_tools()).unwrap()
    }

    /// Asserts scores are in non-increasing order, as `search` documents.
    fn assert_sorted_descending(results: &[(String, f32)]) {
        assert!(
            results.windows(2).all(|w| w[0].1 >= w[1].1),
            "results not sorted by score descending: {:?}",
            results
        );
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let a = vec![1.0, 2.0, 3.0];
        assert!((cosine_similarity(&a, &a) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(cosine_similarity(&a, &b).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let a = vec![1.0, 2.0];
        let b = vec![-1.0, -2.0];
        assert!((cosine_similarity(&a, &b) + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        let a = vec![1.0, 2.0];
        let b = vec![0.0, 0.0];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    #[test]
    fn test_build_empty() {
        let index = ToolEmbeddingIndex::build(&[]).unwrap();
        assert_eq!(index.tool_count(), 0);
        let results = index.search("anything", 10, 0.0).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_build_and_search() {
        let index = sample_index();
        assert_eq!(index.tool_count(), 5);

        // "compress image" should find "optimize_png" (described as "Optimize and compress PNG image files")
        let results = index.search("compress image", 5, 0.0).unwrap();
        assert!(!results.is_empty());
        assert_sorted_descending(&results);
        // The top result should be optimize_png
        assert_eq!(results[0].0, "optimize_png");
    }

    #[test]
    fn test_search_file_operations() {
        let index = sample_index();

        // "load a document" should find file reading tools
        let results = index.search("load a document", 3, 0.0).unwrap();
        assert!(!results.is_empty());
        assert_sorted_descending(&results);
        // read_file or write_file should be in top results
        let top_names: Vec<&str> = results.iter().map(|(n, _)| n.as_str()).collect();
        assert!(
            top_names.contains(&"read_file") || top_names.contains(&"write_file"),
            "Expected file tools in results, got: {:?}",
            top_names
        );
    }

    #[test]
    fn test_min_score_filtering() {
        let index = sample_index();

        // With a very high min_score, most results should be filtered out
        let results = index
            .search("random unrelated query xyz", 10, 0.95)
            .unwrap();
        // Very unlikely anything scores above 0.95 for an unrelated query
        assert!(
            results.len() <= 1,
            "Expected few/no results with high min_score, got {}",
            results.len()
        );
        // Whatever survives must actually meet the threshold.
        assert!(results.iter().all(|(_, score)| *score >= 0.95));
    }

    #[test]
    fn test_limit_respected() {
        let index = sample_index();

        let results = index.search("file", 2, 0.0).unwrap();
        assert!(results.len() <= 2);
        assert_sorted_descending(&results);
    }

    #[test]
    fn test_limit_zero_returns_empty() {
        let index = sample_index();

        let results = index.search("file", 0, 0.0).unwrap();
        assert!(results.is_empty());
    }
}