project_rag/embedding/
fastembed_manager.rs1use super::EmbeddingProvider;
2use anyhow::{Context, Result};
3use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
4use std::sync::RwLock;
5
6pub struct FastEmbedManager {
10 model: RwLock<TextEmbedding>,
11 dimension: usize,
12}
13
14impl FastEmbedManager {
15 pub fn new() -> Result<Self> {
17 Self::with_model(EmbeddingModel::AllMiniLML6V2)
18 }
19
20 pub fn from_model_name(model_name: &str) -> Result<Self> {
22 let model = match model_name {
23 "all-MiniLM-L6-v2" => EmbeddingModel::AllMiniLML6V2,
24 "all-MiniLM-L12-v2" => EmbeddingModel::AllMiniLML12V2,
25 "BAAI/bge-base-en-v1.5" => EmbeddingModel::BGEBaseENV15,
26 "BAAI/bge-small-en-v1.5" => EmbeddingModel::BGESmallENV15,
27 _ => {
28 tracing::warn!(
29 "Unknown model '{}', falling back to all-MiniLM-L6-v2",
30 model_name
31 );
32 EmbeddingModel::AllMiniLML6V2
33 }
34 };
35 Self::with_model(model)
36 }
37
38 pub fn with_model(model: EmbeddingModel) -> Result<Self> {
40 tracing::info!("Initializing FastEmbed model: {:?}", model);
41
42 let dimension = match model {
44 EmbeddingModel::AllMiniLML6V2 => 384,
45 EmbeddingModel::AllMiniLML12V2 => 384,
46 EmbeddingModel::BGEBaseENV15 => 768,
47 EmbeddingModel::BGESmallENV15 => 384,
48 _ => 384, };
50
51 let mut options = InitOptions::default();
52 options.model_name = model;
53 options.show_download_progress = true;
54
55 let embedding_model =
56 TextEmbedding::try_new(options).context("Failed to initialize FastEmbed model")?;
57
58 Ok(Self {
59 model: RwLock::new(embedding_model),
60 dimension,
61 })
62 }
63}
64
65impl EmbeddingProvider for FastEmbedManager {
66 fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
67 if texts.is_empty() {
68 return Ok(vec![]);
69 }
70
71 tracing::debug!("Generating embeddings for {} texts", texts.len());
72
73 let mut model = self.model.write().unwrap_or_else(|poisoned| {
76 tracing::warn!("FastEmbed model lock was poisoned, recovering...");
77 poisoned.into_inner()
78 });
79
80 let embeddings = model
84 .embed(texts, None)
85 .context("Failed to generate embeddings")?;
86
87 Ok(embeddings)
88 }
89
90 fn dimension(&self) -> usize {
91 self.dimension
92 }
93
94 fn model_name(&self) -> &str {
95 "all-MiniLM-L6-v2"
96 }
97}
98
99impl Default for FastEmbedManager {
100 fn default() -> Self {
101 Self::new().expect("Failed to initialize default FastEmbed model")
102 }
103}
104
#[cfg(test)]
mod tests {
    use super::*;

    /// Vector length expected from the default model in these tests.
    const DEFAULT_DIM: usize = 384;
    /// Name expected from a manager built with `new()` / `default()`.
    const DEFAULT_NAME: &str = "all-MiniLM-L6-v2";

    /// Builds a manager with the default model, panicking on failure.
    fn default_manager() -> FastEmbedManager {
        FastEmbedManager::new().unwrap()
    }

    // Two code snippets must yield two vectors of the default dimension.
    #[test]
    fn test_embedding_generation() {
        let inputs = vec![
            String::from("fn main() { println!(\"Hello, world!\"); }"),
            String::from("pub struct Vector { x: f32, y: f32 }"),
        ];

        let vectors = default_manager().embed_batch(inputs).unwrap();

        assert_eq!(vectors.len(), 2);
        assert_eq!(vectors[0].len(), DEFAULT_DIM);
        assert_eq!(vectors[1].len(), DEFAULT_DIM);
    }

    // An empty batch yields an empty result.
    #[test]
    fn test_empty_batch() {
        let vectors = default_manager().embed_batch(Vec::new()).unwrap();
        assert_eq!(vectors.len(), 0);
    }

    #[test]
    fn test_dimension() {
        assert_eq!(default_manager().dimension(), DEFAULT_DIM);
    }

    #[test]
    fn test_model_name() {
        assert_eq!(default_manager().model_name(), DEFAULT_NAME);
    }

    // `Default` must behave like `new()`.
    #[test]
    fn test_default() {
        let provider = FastEmbedManager::default();
        assert_eq!(provider.dimension(), DEFAULT_DIM);
        assert_eq!(provider.model_name(), DEFAULT_NAME);
    }

    #[test]
    fn test_single_text() {
        let inputs = vec![String::from("Hello world")];
        let vectors = default_manager().embed_batch(inputs).unwrap();
        assert_eq!(vectors.len(), 1);
        assert_eq!(vectors[0].len(), DEFAULT_DIM);
    }

    // Ten inputs must yield ten vectors, each of the default dimension.
    #[test]
    fn test_large_batch() {
        let mut inputs = Vec::with_capacity(10);
        for idx in 0..10 {
            inputs.push(format!("Test text {}", idx));
        }

        let vectors = default_manager().embed_batch(inputs).unwrap();

        assert_eq!(vectors.len(), 10);
        for vector in &vectors {
            assert_eq!(vector.len(), DEFAULT_DIM);
        }
    }

    #[test]
    fn test_with_model_allminilm_l12() {
        let provider = FastEmbedManager::with_model(EmbeddingModel::AllMiniLML12V2).unwrap();
        assert_eq!(provider.dimension(), 384);
    }

    #[test]
    fn test_with_model_bge_base() {
        let provider = FastEmbedManager::with_model(EmbeddingModel::BGEBaseENV15).unwrap();
        assert_eq!(provider.dimension(), 768);
    }

    #[test]
    fn test_with_model_bge_small() {
        let provider = FastEmbedManager::with_model(EmbeddingModel::BGESmallENV15).unwrap();
        assert_eq!(provider.dimension(), 384);
    }
}