1use anyhow::{Context, Result};
13pub use brainwires_core::EmbeddingProvider as EmbeddingProviderTrait;
14use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
15use lru::LruCache;
16use std::collections::hash_map::DefaultHasher;
17use std::hash::{Hash, Hasher};
18use std::num::NonZeroUsize;
19use std::sync::{Arc, RwLock};
20
/// Capacity (number of entries) of the LRU cache built by `CachedEmbeddingProvider`.
const DEFAULT_CACHE_SIZE: usize = 1000;
/// Vector length reported for the MiniLM models and for `BAAI/bge-small-en-v1.5`.
const EMBEDDING_DIM_MINILM: usize = 384;
/// Vector length reported for `BAAI/bge-base-en-v1.5`.
const EMBEDDING_DIM_BGE_BASE: usize = 768;
25
/// Owns a FastEmbed text-embedding model together with the metadata reported
/// to callers (vector dimension and model name).
pub struct FastEmbedManager {
    /// The loaded model. It sits behind a lock because embedding is performed
    /// through a write guard, so only one embedding call runs at a time.
    model: RwLock<TextEmbedding>,
    /// Vector length returned by `dimension()`.
    dimension: usize,
    /// Model identifier returned by `model_name()`.
    model_name: String,
}
37
38impl FastEmbedManager {
39 pub fn new() -> Result<Self> {
41 Self::with_model(EmbeddingModel::AllMiniLML6V2)
42 }
43
44 pub fn from_model_name(model_name: &str) -> Result<Self> {
46 let model = match model_name {
47 "all-MiniLM-L6-v2" => EmbeddingModel::AllMiniLML6V2,
48 "all-MiniLM-L12-v2" => EmbeddingModel::AllMiniLML12V2,
49 "BAAI/bge-base-en-v1.5" => EmbeddingModel::BGEBaseENV15,
50 "BAAI/bge-small-en-v1.5" => EmbeddingModel::BGESmallENV15,
51 _ => {
52 tracing::warn!(
53 "Unknown model '{}', falling back to all-MiniLM-L6-v2",
54 model_name
55 );
56 EmbeddingModel::AllMiniLML6V2
57 }
58 };
59 Self::with_model(model)
60 }
61
62 pub fn with_model(model: EmbeddingModel) -> Result<Self> {
64 tracing::info!("Initializing FastEmbed model: {:?}", model);
65
66 let (dimension, name) = match model {
67 EmbeddingModel::AllMiniLML6V2 => (EMBEDDING_DIM_MINILM, "all-MiniLM-L6-v2"),
68 EmbeddingModel::AllMiniLML12V2 => (EMBEDDING_DIM_MINILM, "all-MiniLM-L12-v2"),
69 EmbeddingModel::BGEBaseENV15 => (EMBEDDING_DIM_BGE_BASE, "BAAI/bge-base-en-v1.5"),
70 EmbeddingModel::BGESmallENV15 => (EMBEDDING_DIM_MINILM, "BAAI/bge-small-en-v1.5"),
71 _ => (EMBEDDING_DIM_MINILM, "all-MiniLM-L6-v2"),
72 };
73
74 let mut options = InitOptions::default();
75 options.model_name = model;
76 options.show_download_progress = true;
77
78 let embedding_model =
79 TextEmbedding::try_new(options).context("Failed to initialize FastEmbed model")?;
80
81 Ok(Self {
82 model: RwLock::new(embedding_model),
83 dimension,
84 model_name: name.to_string(),
85 })
86 }
87
88 pub fn embed_batch_vec(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
93 if texts.is_empty() {
94 return Ok(vec![]);
95 }
96
97 tracing::debug!("Generating embeddings for {} texts", texts.len());
98
99 let mut model = self.model.write().unwrap_or_else(|poisoned| {
100 tracing::warn!("FastEmbed model lock was poisoned, recovering...");
101 poisoned.into_inner()
102 });
103
104 let embeddings = model
105 .embed(texts, None)
106 .context("Failed to generate embeddings")?;
107
108 Ok(embeddings)
109 }
110
111 pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
113 let embeddings = self.embed_batch_vec(vec![text.to_string()])?;
114 embeddings
115 .into_iter()
116 .next()
117 .ok_or_else(|| anyhow::anyhow!("No embedding generated"))
118 }
119
120 pub fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
122 self.embed_batch_vec(texts.to_vec())
123 }
124
125 pub fn dimension(&self) -> usize {
127 self.dimension
128 }
129
130 pub fn model_name(&self) -> &str {
132 &self.model_name
133 }
134}
135
136impl EmbeddingProviderTrait for FastEmbedManager {
137 fn embed(&self, text: &str) -> Result<Vec<f32>> {
138 let embeddings = self.embed_batch_vec(vec![text.to_string()])?;
139 embeddings
140 .into_iter()
141 .next()
142 .ok_or_else(|| anyhow::anyhow!("No embedding generated"))
143 }
144
145 fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
146 self.embed_batch_vec(texts.to_vec())
147 }
148
149 fn dimension(&self) -> usize {
150 self.dimension
151 }
152
153 fn model_name(&self) -> &str {
154 &self.model_name
155 }
156}
157
158impl Default for FastEmbedManager {
159 fn default() -> Self {
160 Self::new().expect("Failed to initialize default FastEmbed model")
161 }
162}
163
/// Wraps a shared `FastEmbedManager` and memoizes single-text embeddings in a
/// bounded LRU cache.
pub struct CachedEmbeddingProvider {
    /// The model manager; shared via `Arc` so clones reuse the same loaded model.
    inner: Arc<FastEmbedManager>,
    /// Embeddings keyed by a 64-bit hash of the input text. The text itself is
    /// not stored, so two texts with the same hash would share one entry.
    cache: RwLock<LruCache<u64, Vec<f32>>>,
}
174
175impl CachedEmbeddingProvider {
176 pub fn new() -> Result<Self> {
178 let inner = FastEmbedManager::new().context("Failed to create embedding provider")?;
179
180 Ok(Self {
181 inner: Arc::new(inner),
182 cache: RwLock::new(LruCache::new(
183 NonZeroUsize::new(DEFAULT_CACHE_SIZE).expect("DEFAULT_CACHE_SIZE is non-zero"),
184 )),
185 })
186 }
187
188 pub fn with_manager(manager: Arc<FastEmbedManager>) -> Self {
190 Self {
191 inner: manager,
192 cache: RwLock::new(LruCache::new(
193 NonZeroUsize::new(DEFAULT_CACHE_SIZE).expect("DEFAULT_CACHE_SIZE is non-zero"),
194 )),
195 }
196 }
197
198 fn hash_text(text: &str) -> u64 {
200 let mut hasher = DefaultHasher::new();
201 text.hash(&mut hasher);
202 hasher.finish()
203 }
204
205 pub fn embed_cached(&self, text: &str) -> Result<Vec<f32>> {
210 let cache_key = Self::hash_text(text);
211
212 if let Ok(cache) = self.cache.read()
214 && let Some(embedding) = cache.peek(&cache_key)
215 {
216 return Ok(embedding.clone());
217 }
218
219 let embedding = self.inner.embed(text)?;
221
222 if let Ok(mut cache) = self.cache.write() {
224 cache.put(cache_key, embedding.clone());
225 }
226
227 Ok(embedding)
228 }
229
230 pub fn cache_len(&self) -> usize {
232 self.cache.read().map(|c| c.len()).unwrap_or(0)
233 }
234
235 pub fn clear_cache(&self) {
237 if let Ok(mut cache) = self.cache.write() {
238 cache.clear();
239 }
240 }
241
242 pub fn inner(&self) -> &Arc<FastEmbedManager> {
244 &self.inner
245 }
246
247 pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
252 self.embed_cached(text)
253 }
254
255 pub fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
257 self.inner.embed_batch_vec(texts.to_vec())
258 }
259
260 pub fn dimension(&self) -> usize {
262 self.inner.dimension
263 }
264
265 pub fn model_name(&self) -> &str {
267 &self.inner.model_name
268 }
269}
270
271impl EmbeddingProviderTrait for CachedEmbeddingProvider {
272 fn embed(&self, text: &str) -> Result<Vec<f32>> {
273 self.embed_cached(text)
274 }
275
276 fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
277 self.inner.embed_batch(texts)
278 }
279
280 fn dimension(&self) -> usize {
281 self.inner.dimension()
282 }
283
284 fn model_name(&self) -> &str {
285 self.inner.model_name()
286 }
287}
288
289impl Clone for CachedEmbeddingProvider {
290 fn clone(&self) -> Self {
291 Self {
292 inner: Arc::clone(&self.inner),
293 cache: RwLock::new(LruCache::new(
294 NonZeroUsize::new(DEFAULT_CACHE_SIZE).expect("DEFAULT_CACHE_SIZE is non-zero"),
295 )),
296 }
297 }
298}
299
/// Alias that lets callers refer to the cached provider as `EmbeddingProvider`.
pub type EmbeddingProvider = CachedEmbeddingProvider;
302
#[cfg(test)]
mod tests {
    use super::*;

    /// Vector length expected from the default model.
    const EXPECTED_DIM: usize = 384;

    // ---- FastEmbedManager ----

    #[test]
    fn test_fastembed_creation() {
        let mgr = FastEmbedManager::new().unwrap();
        assert_eq!(mgr.dimension(), EXPECTED_DIM);
        assert_eq!(mgr.model_name(), "all-MiniLM-L6-v2");
    }

    #[test]
    fn test_fastembed_embed_single() {
        let mgr = FastEmbedManager::new().unwrap();
        let vector = mgr.embed("Hello, world!").unwrap();
        assert_eq!(vector.len(), EXPECTED_DIM);
    }

    #[test]
    fn test_fastembed_embed_batch() {
        let mgr = FastEmbedManager::new().unwrap();
        let snippets = vec![
            String::from("fn main() { println!(\"Hello, world!\"); }"),
            String::from("pub struct Vector { x: f32, y: f32 }"),
        ];

        let vectors = mgr.embed_batch(&snippets).unwrap();
        assert_eq!(vectors.len(), 2);
        for vector in &vectors {
            assert_eq!(vector.len(), EXPECTED_DIM);
        }
    }

    #[test]
    fn test_fastembed_empty_batch() {
        let mgr = FastEmbedManager::new().unwrap();
        let vectors = mgr.embed_batch_vec(Vec::new()).unwrap();
        assert_eq!(vectors.len(), 0);
    }

    #[test]
    fn test_fastembed_default() {
        let mgr = FastEmbedManager::default();
        assert_eq!(mgr.dimension(), EXPECTED_DIM);
    }

    #[test]
    fn test_fastembed_from_model_name() {
        let mgr = FastEmbedManager::from_model_name("all-MiniLM-L6-v2").unwrap();
        assert_eq!(mgr.dimension(), EXPECTED_DIM);
    }

    #[test]
    fn test_fastembed_unknown_model_fallback() {
        // An unrecognized name must resolve to the default model.
        let mgr = FastEmbedManager::from_model_name("unknown-model").unwrap();
        assert_eq!(mgr.dimension(), EXPECTED_DIM);
        assert_eq!(mgr.model_name(), "all-MiniLM-L6-v2");
    }

    // ---- CachedEmbeddingProvider ----

    #[test]
    fn test_cached_provider_creation() {
        let cached = CachedEmbeddingProvider::new().unwrap();
        assert_eq!(cached.dimension(), EXPECTED_DIM);
    }

    #[test]
    fn test_cached_provider_embed_single() {
        let cached = CachedEmbeddingProvider::new().unwrap();
        let vector = cached.embed("Hello, world!").unwrap();

        assert_eq!(vector.len(), EXPECTED_DIM);

        // The vector's Euclidean norm should be close to 1.
        let squared_sum: f32 = vector.iter().map(|v| v * v).sum();
        let norm = squared_sum.sqrt();
        assert!((norm - 1.0).abs() < 0.1);
    }

    #[test]
    fn test_cached_provider_embed_batch() {
        let cached = CachedEmbeddingProvider::new().unwrap();
        let messages = vec![
            String::from("First message"),
            String::from("Second message"),
            String::from("Third message"),
        ];

        let vectors = cached.embed_batch(&messages).unwrap();

        assert_eq!(vectors.len(), 3);
        for vector in &vectors {
            assert_eq!(vector.len(), EXPECTED_DIM);
        }
    }

    #[test]
    fn test_cached_provider_clone() {
        let original = CachedEmbeddingProvider::new().unwrap();
        let copy = original.clone();

        assert_eq!(original.dimension(), copy.dimension());
    }

    #[test]
    fn test_cached_provider_caching() {
        let cached = CachedEmbeddingProvider::new().unwrap();

        // First request is a miss and fills one slot.
        let first = cached.embed_cached("test query").unwrap();
        assert_eq!(cached.cache_len(), 1);

        // Repeating the text must reuse that slot and return the same vector.
        let second = cached.embed_cached("test query").unwrap();
        assert_eq!(cached.cache_len(), 1);
        assert_eq!(first, second);

        // A different text occupies a second slot.
        let _third = cached.embed_cached("different query").unwrap();
        assert_eq!(cached.cache_len(), 2);

        // Clearing empties the cache.
        cached.clear_cache();
        assert_eq!(cached.cache_len(), 0);
    }
}