lumosai_vector_fastembed/
provider.rs1use async_trait::async_trait;
4use std::sync::Arc;
5use tokio::sync::Mutex;
6use tracing::{debug, info, warn};
7
8use lumosai_vector_core::traits::EmbeddingModel;
9use lumosai_vector_core::types::Vector;
10
11use crate::error::{FastEmbedError, Result};
12use crate::models::FastEmbedModel;
13use crate::FastEmbedConfig;
14
15pub struct FastEmbedProvider {
21 model: Arc<Mutex<Option<fastembed::TextEmbedding>>>,
23
24 model_config: FastEmbedModel,
26
27 config: FastEmbedConfig,
29
30 model_name: String,
32
33 dimensions: usize,
35
36 max_sequence_length: usize,
38}
39
40impl FastEmbedProvider {
41 pub async fn new(model: FastEmbedModel, config: FastEmbedConfig) -> Result<Self> {
43 let model_name = model.model_name().to_string();
44 let dimensions = model.dimensions();
45 let max_sequence_length = model.max_sequence_length();
46
47 info!(
48 "Creating FastEmbed provider with model: {} ({}D)",
49 model_name, dimensions
50 );
51
52 let provider = Self {
53 model: Arc::new(Mutex::new(None)),
54 model_config: model,
55 config,
56 model_name,
57 dimensions,
58 max_sequence_length,
59 };
60
61 provider.ensure_model_loaded().await?;
63
64 Ok(provider)
65 }
66
67 pub async fn with_model(model: FastEmbedModel) -> Result<Self> {
69 Self::new(model, FastEmbedConfig::default()).await
70 }
71
72 async fn ensure_model_loaded(&self) -> Result<()> {
74 let mut model_guard = self.model.lock().await;
75
76 if model_guard.is_none() {
77 debug!("Initializing FastEmbed model: {}", self.model_name);
78
79 let mut init_options = fastembed::InitOptions::new(self.model_config.to_fastembed_model())
80 .with_show_download_progress(self.config.show_download_progress);
81
82 if let Some(cache_dir) = &self.config.cache_dir {
83 init_options = init_options.with_cache_dir(cache_dir.into());
84 debug!("Using cache directory: {}", cache_dir);
85 }
86
87 let embedding_model = fastembed::TextEmbedding::try_new(init_options)
94 .map_err(|e| FastEmbedError::ModelInitialization(format!(
95 "Failed to initialize FastEmbed model '{}': {}",
96 self.model_name, e
97 )))?;
98
99 *model_guard = Some(embedding_model);
100 info!("FastEmbed model '{}' initialized successfully", self.model_name);
101 }
102
103 Ok(())
104 }
105
106 pub fn model_config(&self) -> &FastEmbedModel {
108 &self.model_config
109 }
110
111 pub fn config(&self) -> &FastEmbedConfig {
113 &self.config
114 }
115
116 fn validate_text_length(&self, text: &str) -> Result<()> {
118 if text.len() > self.max_sequence_length * 4 { return Err(FastEmbedError::TextTooLong {
122 length: text.len(),
123 max_length: self.max_sequence_length * 4,
124 });
125 }
126 Ok(())
127 }
128
129 fn truncate_text(&self, text: &str) -> String {
131 let max_chars = self.max_sequence_length * 4; if text.len() > max_chars {
133 warn!(
134 "Text truncated from {} to {} characters for model '{}'",
135 text.len(), max_chars, self.model_name
136 );
137 text.chars().take(max_chars).collect()
138 } else {
139 text.to_string()
140 }
141 }
142
143 async fn process_in_batches(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
145 self.ensure_model_loaded().await?;
146
147 let model_guard = self.model.lock().await;
148 let model = model_guard.as_ref().ok_or_else(|| {
149 FastEmbedError::ModelNotInitialized("FastEmbed model not initialized".to_string())
150 })?;
151
152 let mut all_embeddings = Vec::new();
153
154 for chunk in texts.chunks(self.config.max_batch_size) {
155 debug!("Processing batch of {} texts", chunk.len());
156
157 let processed_texts: Vec<String> = chunk
159 .iter()
160 .map(|text| self.truncate_text(text))
161 .collect();
162
163 let embeddings = model.embed(processed_texts, None)
164 .map_err(|e| FastEmbedError::EmbeddingGeneration(format!(
165 "FastEmbed embedding failed: {}", e
166 )))?;
167
168 all_embeddings.extend(embeddings);
169 }
170
171 debug!("Generated {} embeddings", all_embeddings.len());
172 Ok(all_embeddings)
173 }
174}
175
176#[async_trait]
177impl EmbeddingModel for FastEmbedProvider {
178 type Config = FastEmbedConfig;
179
180 async fn embed_text(&self, text: &str) -> std::result::Result<Vector, lumosai_vector_core::error::VectorError> {
181 if let Err(e) = self.validate_text_length(text) {
183 warn!("Text validation failed: {}", e);
184 }
186
187 let texts = vec![text.to_string()];
188 let embeddings = self.process_in_batches(&texts).await
189 .map_err(|e| lumosai_vector_core::error::VectorError::EmbeddingError(e.to_string()))?;
190
191 embeddings.into_iter().next()
192 .ok_or_else(|| lumosai_vector_core::error::VectorError::EmbeddingError(
193 "No embedding returned from FastEmbed".to_string()
194 ))
195 }
196
197 async fn embed_batch(&self, texts: &[String]) -> std::result::Result<Vec<Vector>, lumosai_vector_core::error::VectorError> {
198 if texts.is_empty() {
199 return Ok(Vec::new());
200 }
201
202 debug!("Embedding batch of {} texts", texts.len());
203
204 self.process_in_batches(texts).await
205 .map_err(|e| lumosai_vector_core::error::VectorError::EmbeddingError(e.to_string()))
206 }
207
208 fn dimensions(&self) -> usize {
209 self.dimensions
210 }
211
212 fn model_name(&self) -> &str {
213 &self.model_name
214 }
215
216 fn max_input_length(&self) -> Option<usize> {
217 Some(self.max_sequence_length)
218 }
219
220 async fn health_check(&self) -> std::result::Result<(), lumosai_vector_core::error::VectorError> {
221 self.ensure_model_loaded().await
223 .map_err(|e| lumosai_vector_core::error::VectorError::EmbeddingError(e.to_string()))
224 }
225}
226
227pub struct FastEmbedProviderBuilder {
229 model: FastEmbedModel,
230 config: FastEmbedConfig,
231}
232
233impl FastEmbedProviderBuilder {
234 pub fn new(model: FastEmbedModel) -> Self {
236 Self {
237 model,
238 config: FastEmbedConfig::default(),
239 }
240 }
241
242 pub fn max_batch_size(mut self, size: usize) -> Self {
244 self.config.max_batch_size = size;
245 self
246 }
247
248 pub fn show_download_progress(mut self, show: bool) -> Self {
250 self.config.show_download_progress = show;
251 self
252 }
253
254 pub fn num_threads(mut self, threads: usize) -> Self {
256 self.config.num_threads = Some(threads);
257 self
258 }
259
260 pub fn cache_dir<S: Into<String>>(mut self, dir: S) -> Self {
262 self.config.cache_dir = Some(dir.into());
263 self
264 }
265
266 pub async fn build(self) -> Result<FastEmbedProvider> {
268 FastEmbedProvider::new(self.model, self.config).await
269 }
270}
271
272impl Default for FastEmbedProviderBuilder {
273 fn default() -> Self {
274 Self::new(FastEmbedModel::BGESmallENV15)
275 }
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a provider whose model slot is left empty, so the pure text
    /// helpers can be tested without initializing a FastEmbed model.
    fn unloaded_provider(model: FastEmbedModel) -> FastEmbedProvider {
        let model_name = model.model_name().to_string();
        let dimensions = model.dimensions();
        let max_sequence_length = model.max_sequence_length();

        FastEmbedProvider {
            model: Arc::new(Mutex::new(None)),
            model_config: model,
            config: FastEmbedConfig::default(),
            model_name,
            dimensions,
            max_sequence_length,
        }
    }

    #[tokio::test]
    async fn test_provider_creation() {
        let provider = FastEmbedProvider::with_model(FastEmbedModel::BGESmallENV15).await;

        match provider {
            Ok(p) => {
                assert_eq!(p.dimensions(), 384);
                assert_eq!(p.model_name(), "BAAI/bge-small-en-v1.5");
            }
            // Initialization can fail where the model is unavailable; that
            // must not fail the suite.
            Err(e) => {
                eprintln!("FastEmbed model not available (this is OK in CI): {}", e);
            }
        }
    }

    #[test]
    fn test_text_validation() {
        let provider = unloaded_provider(FastEmbedModel::BGESmallENV15);
        // The provider budgets four characters per unit of sequence length.
        let limit = provider.max_sequence_length * 4;

        // ASCII inputs keep byte length and character count identical.
        assert!(provider.validate_text_length("hello world").is_ok());
        assert!(provider.validate_text_length(&"a".repeat(limit)).is_ok());
        assert!(provider
            .validate_text_length(&"a".repeat(limit + 1))
            .is_err());

        let truncated = provider.truncate_text(&"a".repeat(limit + 10));
        assert_eq!(truncated.chars().count(), limit);
        assert_eq!(provider.truncate_text("short"), "short");
    }

    #[test]
    fn test_builder_config_overrides() {
        let builder = FastEmbedProviderBuilder::new(FastEmbedModel::BGESmallENV15)
            .max_batch_size(128)
            .show_download_progress(false);

        assert_eq!(builder.config.max_batch_size, 128);
        assert!(!builder.config.show_download_progress);
    }

    #[test]
    fn test_builder_pattern() {
        let builder = FastEmbedProviderBuilder::new(FastEmbedModel::BGEBaseENV15)
            .max_batch_size(64)
            .num_threads(2)
            .cache_dir("/tmp/test");

        assert_eq!(builder.config.max_batch_size, 64);
        assert_eq!(builder.config.num_threads, Some(2));
        assert_eq!(builder.config.cache_dir, Some("/tmp/test".to_string()));
    }
}