// brainwires_storage/embeddings.rs
1//! Embedding Provider
2//!
3//! Provides text embeddings using FastEmbed with LRU caching.
4//!
5//! This module is the canonical owner of embedding infrastructure in the framework:
6//!
7//! - **FastEmbedManager** - Low-level wrapper around the fastembed crate (ONNX model)
8//! - **CachedEmbeddingProvider** - LRU-cached wrapper that reduces latency for repeated queries
9//!
10//! Both implement the `brainwires_core::EmbeddingProvider` trait.
11
12use anyhow::{Context, Result};
13pub use brainwires_core::EmbeddingProvider as EmbeddingProviderTrait;
14use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
15use lru::LruCache;
16use std::collections::hash_map::DefaultHasher;
17use std::hash::{Hash, Hasher};
18use std::num::NonZeroUsize;
19use std::sync::{Arc, RwLock};
20
/// Default cache size for embeddings (1000 entries)
const DEFAULT_CACHE_SIZE: usize = 1000;
/// Output dimensionality shared by the MiniLM family and BGE-small (384-d).
const EMBEDDING_DIM_MINILM: usize = 384;
/// Output dimensionality of BGE-base (768-d).
const EMBEDDING_DIM_BGE_BASE: usize = 768;
25
26// ── FastEmbedManager ────────────────────────────────────────────────────────
27
/// FastEmbed-based embedding provider using ONNX models.
///
/// Uses RwLock for safe interior mutability since fastembed's `embed()` requires `&mut self`.
/// Default model is all-MiniLM-L6-v2 (384 dimensions).
pub struct FastEmbedManager {
    // Write-locked for every embed call because `TextEmbedding::embed`
    // takes `&mut self`; inference is therefore serialized across threads.
    model: RwLock<TextEmbedding>,
    // Dimensionality of the vectors this model produces (set at construction).
    dimension: usize,
    // Human-readable model identifier, e.g. "all-MiniLM-L6-v2".
    model_name: String,
}
37
38impl FastEmbedManager {
39    /// Create a new FastEmbedManager with the default model (all-MiniLM-L6-v2)
40    pub fn new() -> Result<Self> {
41        Self::with_model(EmbeddingModel::AllMiniLML6V2)
42    }
43
44    /// Create a new FastEmbedManager from a model name string
45    pub fn from_model_name(model_name: &str) -> Result<Self> {
46        let model = match model_name {
47            "all-MiniLM-L6-v2" => EmbeddingModel::AllMiniLML6V2,
48            "all-MiniLM-L12-v2" => EmbeddingModel::AllMiniLML12V2,
49            "BAAI/bge-base-en-v1.5" => EmbeddingModel::BGEBaseENV15,
50            "BAAI/bge-small-en-v1.5" => EmbeddingModel::BGESmallENV15,
51            _ => {
52                tracing::warn!(
53                    "Unknown model '{}', falling back to all-MiniLM-L6-v2",
54                    model_name
55                );
56                EmbeddingModel::AllMiniLML6V2
57            }
58        };
59        Self::with_model(model)
60    }
61
62    /// Create a new FastEmbedManager with a specific model
63    pub fn with_model(model: EmbeddingModel) -> Result<Self> {
64        tracing::info!("Initializing FastEmbed model: {:?}", model);
65
66        let (dimension, name) = match model {
67            EmbeddingModel::AllMiniLML6V2 => (EMBEDDING_DIM_MINILM, "all-MiniLM-L6-v2"),
68            EmbeddingModel::AllMiniLML12V2 => (EMBEDDING_DIM_MINILM, "all-MiniLM-L12-v2"),
69            EmbeddingModel::BGEBaseENV15 => (EMBEDDING_DIM_BGE_BASE, "BAAI/bge-base-en-v1.5"),
70            EmbeddingModel::BGESmallENV15 => (EMBEDDING_DIM_MINILM, "BAAI/bge-small-en-v1.5"),
71            _ => (EMBEDDING_DIM_MINILM, "all-MiniLM-L6-v2"),
72        };
73
74        let mut options = InitOptions::default();
75        options.model_name = model;
76        options.show_download_progress = true;
77
78        let embedding_model =
79            TextEmbedding::try_new(options).context("Failed to initialize FastEmbed model")?;
80
81        Ok(Self {
82            model: RwLock::new(embedding_model),
83            dimension,
84            model_name: name.to_string(),
85        })
86    }
87
88    /// Generate embeddings for a batch of texts (raw, no caching).
89    ///
90    /// This is the low-level batch method. Prefer using `CachedEmbeddingProvider`
91    /// for repeated queries.
92    pub fn embed_batch_vec(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
93        if texts.is_empty() {
94            return Ok(vec![]);
95        }
96
97        tracing::debug!("Generating embeddings for {} texts", texts.len());
98
99        let mut model = self.model.write().unwrap_or_else(|poisoned| {
100            tracing::warn!("FastEmbed model lock was poisoned, recovering...");
101            poisoned.into_inner()
102        });
103
104        let embeddings = model
105            .embed(texts, None)
106            .context("Failed to generate embeddings")?;
107
108        Ok(embeddings)
109    }
110
111    /// Generate an embedding for a single text (inherent method for convenience).
112    pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
113        let embeddings = self.embed_batch_vec(vec![text.to_string()])?;
114        embeddings
115            .into_iter()
116            .next()
117            .ok_or_else(|| anyhow::anyhow!("No embedding generated"))
118    }
119
120    /// Generate embeddings for a batch of texts (inherent method for convenience).
121    pub fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
122        self.embed_batch_vec(texts.to_vec())
123    }
124
125    /// Get the dimensionality of the embedding vectors.
126    pub fn dimension(&self) -> usize {
127        self.dimension
128    }
129
130    /// Get the model name.
131    pub fn model_name(&self) -> &str {
132        &self.model_name
133    }
134}
135
136impl EmbeddingProviderTrait for FastEmbedManager {
137    fn embed(&self, text: &str) -> Result<Vec<f32>> {
138        let embeddings = self.embed_batch_vec(vec![text.to_string()])?;
139        embeddings
140            .into_iter()
141            .next()
142            .ok_or_else(|| anyhow::anyhow!("No embedding generated"))
143    }
144
145    fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
146        self.embed_batch_vec(texts.to_vec())
147    }
148
149    fn dimension(&self) -> usize {
150        self.dimension
151    }
152
153    fn model_name(&self) -> &str {
154        &self.model_name
155    }
156}
157
impl Default for FastEmbedManager {
    /// Builds the default (all-MiniLM-L6-v2) manager.
    ///
    /// # Panics
    ///
    /// Panics if the ONNX model cannot be initialized (e.g. a failed model
    /// download); use [`FastEmbedManager::new`] to handle the error instead.
    fn default() -> Self {
        Self::new().expect("Failed to initialize default FastEmbed model")
    }
}
163
164// ── CachedEmbeddingProvider ─────────────────────────────────────────────────
165
/// LRU-cached embedding provider for generating text embeddings.
///
/// Wraps `FastEmbedManager` and adds an LRU cache for memoizing query embeddings
/// to reduce latency in agent loops that often repeat similar queries.
pub struct CachedEmbeddingProvider {
    // Shared model backend; `Arc` so clones reuse the same loaded model.
    inner: Arc<FastEmbedManager>,
    // Text-hash -> embedding memo, bounded at DEFAULT_CACHE_SIZE entries.
    cache: RwLock<LruCache<u64, Vec<f32>>>,
}
174
175impl CachedEmbeddingProvider {
176    /// Create a new cached embedding provider with the default model
177    pub fn new() -> Result<Self> {
178        let inner = FastEmbedManager::new().context("Failed to create embedding provider")?;
179
180        Ok(Self {
181            inner: Arc::new(inner),
182            cache: RwLock::new(LruCache::new(
183                NonZeroUsize::new(DEFAULT_CACHE_SIZE).expect("DEFAULT_CACHE_SIZE is non-zero"),
184            )),
185        })
186    }
187
188    /// Create a cached wrapper around an existing FastEmbedManager
189    pub fn with_manager(manager: Arc<FastEmbedManager>) -> Self {
190        Self {
191            inner: manager,
192            cache: RwLock::new(LruCache::new(
193                NonZeroUsize::new(DEFAULT_CACHE_SIZE).expect("DEFAULT_CACHE_SIZE is non-zero"),
194            )),
195        }
196    }
197
198    /// Hash text to a cache key
199    fn hash_text(text: &str) -> u64 {
200        let mut hasher = DefaultHasher::new();
201        text.hash(&mut hasher);
202        hasher.finish()
203    }
204
205    /// Generate an embedding with caching
206    ///
207    /// Checks the LRU cache first; if not found, generates the embedding
208    /// and stores it in the cache.
209    pub fn embed_cached(&self, text: &str) -> Result<Vec<f32>> {
210        let cache_key = Self::hash_text(text);
211
212        // Check cache first (read lock)
213        if let Ok(cache) = self.cache.read()
214            && let Some(embedding) = cache.peek(&cache_key)
215        {
216            return Ok(embedding.clone());
217        }
218
219        // Generate embedding
220        let embedding = self.inner.embed(text)?;
221
222        // Store in cache (write lock)
223        if let Ok(mut cache) = self.cache.write() {
224            cache.put(cache_key, embedding.clone());
225        }
226
227        Ok(embedding)
228    }
229
230    /// Get the number of cached embeddings
231    pub fn cache_len(&self) -> usize {
232        self.cache.read().map(|c| c.len()).unwrap_or(0)
233    }
234
235    /// Clear the embedding cache
236    pub fn clear_cache(&self) {
237        if let Ok(mut cache) = self.cache.write() {
238            cache.clear();
239        }
240    }
241
242    /// Get a reference to the underlying FastEmbedManager
243    pub fn inner(&self) -> &Arc<FastEmbedManager> {
244        &self.inner
245    }
246
247    /// Generate an embedding for a single text (inherent method for convenience).
248    ///
249    /// This delegates to the `EmbeddingProvider` trait implementation, making
250    /// the method available without requiring the trait to be in scope.
251    pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
252        self.embed_cached(text)
253    }
254
255    /// Generate embeddings for a batch of texts (inherent method for convenience).
256    pub fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
257        self.inner.embed_batch_vec(texts.to_vec())
258    }
259
260    /// Get the dimensionality of the embedding vectors (inherent method for convenience).
261    pub fn dimension(&self) -> usize {
262        self.inner.dimension
263    }
264
265    /// Get the model name (inherent method for convenience).
266    pub fn model_name(&self) -> &str {
267        &self.inner.model_name
268    }
269}
270
271impl EmbeddingProviderTrait for CachedEmbeddingProvider {
272    fn embed(&self, text: &str) -> Result<Vec<f32>> {
273        self.embed_cached(text)
274    }
275
276    fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
277        self.inner.embed_batch(texts)
278    }
279
280    fn dimension(&self) -> usize {
281        self.inner.dimension()
282    }
283
284    fn model_name(&self) -> &str {
285        self.inner.model_name()
286    }
287}
288
impl Clone for CachedEmbeddingProvider {
    /// Clones share the underlying model (a cheap `Arc` refcount bump) but
    /// each clone starts with a fresh, EMPTY LRU cache — previously cached
    /// embeddings are neither copied to nor shared with the clone.
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
            cache: RwLock::new(LruCache::new(
                NonZeroUsize::new(DEFAULT_CACHE_SIZE).expect("DEFAULT_CACHE_SIZE is non-zero"),
            )),
        }
    }
}
299
/// Type alias for backward compatibility
///
/// Older call sites name `EmbeddingProvider` as a concrete type; new code
/// should use `CachedEmbeddingProvider` (or the trait) directly.
pub type EmbeddingProvider = CachedEmbeddingProvider;
302
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): every test here initializes a real FastEmbed model; the
    // first run downloads the ONNX weights, so these tests need network
    // access and are slow compared to ordinary unit tests.

    // ── FastEmbedManager tests ──────────────────────────────────────────

    #[test]
    fn test_fastembed_creation() {
        let manager = FastEmbedManager::new().unwrap();
        assert_eq!(manager.dimension(), 384);
        assert_eq!(manager.model_name(), "all-MiniLM-L6-v2");
    }

    #[test]
    fn test_fastembed_embed_single() {
        let manager = FastEmbedManager::new().unwrap();
        let embedding = manager.embed("Hello, world!").unwrap();
        assert_eq!(embedding.len(), 384);
    }

    #[test]
    fn test_fastembed_embed_batch() {
        let manager = FastEmbedManager::new().unwrap();
        let texts = vec![
            "fn main() { println!(\"Hello, world!\"); }".to_string(),
            "pub struct Vector { x: f32, y: f32 }".to_string(),
        ];

        let embeddings = manager.embed_batch(&texts).unwrap();
        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0].len(), 384);
        assert_eq!(embeddings[1].len(), 384);
    }

    // Empty input must short-circuit without touching the model.
    #[test]
    fn test_fastembed_empty_batch() {
        let manager = FastEmbedManager::new().unwrap();
        let embeddings = manager.embed_batch_vec(vec![]).unwrap();
        assert_eq!(embeddings.len(), 0);
    }

    #[test]
    fn test_fastembed_default() {
        let manager = FastEmbedManager::default();
        assert_eq!(manager.dimension(), 384);
    }

    #[test]
    fn test_fastembed_from_model_name() {
        let manager = FastEmbedManager::from_model_name("all-MiniLM-L6-v2").unwrap();
        assert_eq!(manager.dimension(), 384);
    }

    // Unknown names fall back to MiniLM rather than erroring.
    #[test]
    fn test_fastembed_unknown_model_fallback() {
        let manager = FastEmbedManager::from_model_name("unknown-model").unwrap();
        assert_eq!(manager.dimension(), 384);
        assert_eq!(manager.model_name(), "all-MiniLM-L6-v2");
    }

    // ── CachedEmbeddingProvider tests ───────────────────────────────────

    #[test]
    fn test_cached_provider_creation() {
        let provider = CachedEmbeddingProvider::new().unwrap();
        assert_eq!(provider.dimension(), 384);
    }

    #[test]
    fn test_cached_provider_embed_single() {
        let provider = CachedEmbeddingProvider::new().unwrap();
        let embedding = provider.embed("Hello, world!").unwrap();

        assert_eq!(embedding.len(), 384);

        // Verify it's normalized (approximately)
        // NOTE(review): assumes fastembed L2-normalizes MiniLM output —
        // confirm against the fastembed crate's defaults if this flakes.
        let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((magnitude - 1.0).abs() < 0.1);
    }

    #[test]
    fn test_cached_provider_embed_batch() {
        let provider = CachedEmbeddingProvider::new().unwrap();
        let texts = vec![
            "First message".to_string(),
            "Second message".to_string(),
            "Third message".to_string(),
        ];

        let embeddings = provider.embed_batch(&texts).unwrap();

        assert_eq!(embeddings.len(), 3);
        assert_eq!(embeddings[0].len(), 384);
        assert_eq!(embeddings[1].len(), 384);
        assert_eq!(embeddings[2].len(), 384);
    }

    // Clones share the model, so dimensions must agree.
    #[test]
    fn test_cached_provider_clone() {
        let provider = CachedEmbeddingProvider::new().unwrap();
        let cloned = provider.clone();

        assert_eq!(provider.dimension(), cloned.dimension());
    }

    #[test]
    fn test_cached_provider_caching() {
        let provider = CachedEmbeddingProvider::new().unwrap();

        // First call should compute and cache
        let embedding1 = provider.embed_cached("test query").unwrap();
        assert_eq!(provider.cache_len(), 1);

        // Second call should return cached value
        let embedding2 = provider.embed_cached("test query").unwrap();
        assert_eq!(provider.cache_len(), 1); // Still 1, not 2

        // Embeddings should be identical
        assert_eq!(embedding1, embedding2);

        // Different query should add to cache
        let _embedding3 = provider.embed_cached("different query").unwrap();
        assert_eq!(provider.cache_len(), 2);

        // Clear cache
        provider.clear_cache();
        assert_eq!(provider.cache_len(), 0);
    }
}