lumosai_vector_fastembed/
provider.rs

1//! FastEmbed embedding provider implementation
2
3use async_trait::async_trait;
4use std::sync::Arc;
5use tokio::sync::Mutex;
6use tracing::{debug, info, warn};
7
8use lumosai_vector_core::traits::EmbeddingModel;
9use lumosai_vector_core::types::Vector;
10
11use crate::error::{FastEmbedError, Result};
12use crate::models::FastEmbedModel;
13use crate::FastEmbedConfig;
14
15/// FastEmbed embedding provider
16/// 
17/// This provider uses FastEmbed for local embedding generation,
18/// eliminating the need for external API calls and providing
19/// fast, reliable embedding generation.
20pub struct FastEmbedProvider {
21    /// The embedding model instance
22    model: Arc<Mutex<Option<fastembed::TextEmbedding>>>,
23    
24    /// Model configuration
25    model_config: FastEmbedModel,
26    
27    /// Provider configuration
28    config: FastEmbedConfig,
29    
30    /// Model name for identification
31    model_name: String,
32    
33    /// Embedding dimensions
34    dimensions: usize,
35    
36    /// Maximum sequence length
37    max_sequence_length: usize,
38}
39
40impl FastEmbedProvider {
41    /// Create a new FastEmbed provider with the specified model
42    pub async fn new(model: FastEmbedModel, config: FastEmbedConfig) -> Result<Self> {
43        let model_name = model.model_name().to_string();
44        let dimensions = model.dimensions();
45        let max_sequence_length = model.max_sequence_length();
46        
47        info!(
48            "Creating FastEmbed provider with model: {} ({}D)",
49            model_name, dimensions
50        );
51        
52        let provider = Self {
53            model: Arc::new(Mutex::new(None)),
54            model_config: model,
55            config,
56            model_name,
57            dimensions,
58            max_sequence_length,
59        };
60        
61        // Initialize the model
62        provider.ensure_model_loaded().await?;
63        
64        Ok(provider)
65    }
66    
67    /// Create a new FastEmbed provider with default configuration
68    pub async fn with_model(model: FastEmbedModel) -> Result<Self> {
69        Self::new(model, FastEmbedConfig::default()).await
70    }
71    
72    /// Ensure the embedding model is loaded (lazy loading)
73    async fn ensure_model_loaded(&self) -> Result<()> {
74        let mut model_guard = self.model.lock().await;
75        
76        if model_guard.is_none() {
77            debug!("Initializing FastEmbed model: {}", self.model_name);
78            
79            let mut init_options = fastembed::InitOptions::new(self.model_config.to_fastembed_model())
80                .with_show_download_progress(self.config.show_download_progress);
81            
82            if let Some(cache_dir) = &self.config.cache_dir {
83                init_options = init_options.with_cache_dir(cache_dir.into());
84                debug!("Using cache directory: {}", cache_dir);
85            }
86            
87            // Note: with_num_threads is not available in current fastembed version
88            // if let Some(num_threads) = self.config.num_threads {
89            //     init_options = init_options.with_num_threads(num_threads);
90            //     debug!("Using {} threads", num_threads);
91            // }
92            
93            let embedding_model = fastembed::TextEmbedding::try_new(init_options)
94                .map_err(|e| FastEmbedError::ModelInitialization(format!(
95                    "Failed to initialize FastEmbed model '{}': {}", 
96                    self.model_name, e
97                )))?;
98            
99            *model_guard = Some(embedding_model);
100            info!("FastEmbed model '{}' initialized successfully", self.model_name);
101        }
102        
103        Ok(())
104    }
105    
106    /// Get the model configuration
107    pub fn model_config(&self) -> &FastEmbedModel {
108        &self.model_config
109    }
110    
111    /// Get the provider configuration
112    pub fn config(&self) -> &FastEmbedConfig {
113        &self.config
114    }
115    
116    /// Check if text length is within model limits
117    fn validate_text_length(&self, text: &str) -> Result<()> {
118        // Simple character-based check (not token-based)
119        // In practice, you might want to use a tokenizer for more accurate checking
120        if text.len() > self.max_sequence_length * 4 { // Rough estimate: 4 chars per token
121            return Err(FastEmbedError::TextTooLong {
122                length: text.len(),
123                max_length: self.max_sequence_length * 4,
124            });
125        }
126        Ok(())
127    }
128    
129    /// Truncate text if it's too long
130    fn truncate_text(&self, text: &str) -> String {
131        let max_chars = self.max_sequence_length * 4; // Rough estimate
132        if text.len() > max_chars {
133            warn!(
134                "Text truncated from {} to {} characters for model '{}'",
135                text.len(), max_chars, self.model_name
136            );
137            text.chars().take(max_chars).collect()
138        } else {
139            text.to_string()
140        }
141    }
142    
143    /// Process texts in batches to respect model limits
144    async fn process_in_batches(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
145        self.ensure_model_loaded().await?;
146        
147        let model_guard = self.model.lock().await;
148        let model = model_guard.as_ref().ok_or_else(|| {
149            FastEmbedError::ModelNotInitialized("FastEmbed model not initialized".to_string())
150        })?;
151        
152        let mut all_embeddings = Vec::new();
153        
154        for chunk in texts.chunks(self.config.max_batch_size) {
155            debug!("Processing batch of {} texts", chunk.len());
156            
157            // Validate and truncate texts
158            let processed_texts: Vec<String> = chunk
159                .iter()
160                .map(|text| self.truncate_text(text))
161                .collect();
162            
163            let embeddings = model.embed(processed_texts, None)
164                .map_err(|e| FastEmbedError::EmbeddingGeneration(format!(
165                    "FastEmbed embedding failed: {}", e
166                )))?;
167            
168            all_embeddings.extend(embeddings);
169        }
170        
171        debug!("Generated {} embeddings", all_embeddings.len());
172        Ok(all_embeddings)
173    }
174}
175
176#[async_trait]
177impl EmbeddingModel for FastEmbedProvider {
178    type Config = FastEmbedConfig;
179    
180    async fn embed_text(&self, text: &str) -> std::result::Result<Vector, lumosai_vector_core::error::VectorError> {
181        // Validate text length
182        if let Err(e) = self.validate_text_length(text) {
183            warn!("Text validation failed: {}", e);
184            // Continue with truncation instead of failing
185        }
186        
187        let texts = vec![text.to_string()];
188        let embeddings = self.process_in_batches(&texts).await
189            .map_err(|e| lumosai_vector_core::error::VectorError::EmbeddingError(e.to_string()))?;
190        
191        embeddings.into_iter().next()
192            .ok_or_else(|| lumosai_vector_core::error::VectorError::EmbeddingError(
193                "No embedding returned from FastEmbed".to_string()
194            ))
195    }
196    
197    async fn embed_batch(&self, texts: &[String]) -> std::result::Result<Vec<Vector>, lumosai_vector_core::error::VectorError> {
198        if texts.is_empty() {
199            return Ok(Vec::new());
200        }
201        
202        debug!("Embedding batch of {} texts", texts.len());
203        
204        self.process_in_batches(texts).await
205            .map_err(|e| lumosai_vector_core::error::VectorError::EmbeddingError(e.to_string()))
206    }
207    
208    fn dimensions(&self) -> usize {
209        self.dimensions
210    }
211    
212    fn model_name(&self) -> &str {
213        &self.model_name
214    }
215    
216    fn max_input_length(&self) -> Option<usize> {
217        Some(self.max_sequence_length)
218    }
219
220    async fn health_check(&self) -> std::result::Result<(), lumosai_vector_core::error::VectorError> {
221        // Try to ensure the model is loaded as a health check
222        self.ensure_model_loaded().await
223            .map_err(|e| lumosai_vector_core::error::VectorError::EmbeddingError(e.to_string()))
224    }
225}
226
227/// Builder for FastEmbed provider
228pub struct FastEmbedProviderBuilder {
229    model: FastEmbedModel,
230    config: FastEmbedConfig,
231}
232
233impl FastEmbedProviderBuilder {
234    /// Create a new builder with the specified model
235    pub fn new(model: FastEmbedModel) -> Self {
236        Self {
237            model,
238            config: FastEmbedConfig::default(),
239        }
240    }
241    
242    /// Set the maximum batch size for processing
243    pub fn max_batch_size(mut self, size: usize) -> Self {
244        self.config.max_batch_size = size;
245        self
246    }
247    
248    /// Set whether to show download progress
249    pub fn show_download_progress(mut self, show: bool) -> Self {
250        self.config.show_download_progress = show;
251        self
252    }
253    
254    /// Set the number of threads for processing
255    pub fn num_threads(mut self, threads: usize) -> Self {
256        self.config.num_threads = Some(threads);
257        self
258    }
259    
260    /// Set the cache directory for model files
261    pub fn cache_dir<S: Into<String>>(mut self, dir: S) -> Self {
262        self.config.cache_dir = Some(dir.into());
263        self
264    }
265    
266    /// Build the FastEmbed provider
267    pub async fn build(self) -> Result<FastEmbedProvider> {
268        FastEmbedProvider::new(self.model, self.config).await
269    }
270}
271
272impl Default for FastEmbedProviderBuilder {
273    fn default() -> Self {
274        Self::new(FastEmbedModel::BGESmallENV15)
275    }
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    /// Construct a provider with no model loaded, so the pure text helpers can
    /// be exercised without initializing or downloading a fastembed model.
    fn unloaded_provider(max_sequence_length: usize) -> FastEmbedProvider {
        FastEmbedProvider {
            model: Arc::new(Mutex::new(None)),
            model_config: FastEmbedModel::BGESmallENV15,
            config: FastEmbedConfig::default(),
            model_name: "test-model".to_string(),
            dimensions: 384,
            max_sequence_length,
        }
    }

    #[tokio::test]
    async fn test_provider_creation() {
        // Needs the real model, which may be unavailable (e.g. offline CI); in
        // that case the error is reported instead of failing the test.
        match FastEmbedProvider::with_model(FastEmbedModel::BGESmallENV15).await {
            Ok(p) => {
                assert_eq!(p.dimensions(), 384);
                assert_eq!(p.model_name(), "BAAI/bge-small-en-v1.5");
            }
            Err(e) => {
                eprintln!("FastEmbed model not available (this is OK in CI): {}", e);
            }
        }
    }

    #[test]
    fn test_text_validation() {
        // 2 tokens * 4 chars per token = 8-character budget.
        let provider = unloaded_provider(2);

        assert!(provider.validate_text_length("").is_ok());
        assert!(provider.validate_text_length("12345678").is_ok());
        assert!(provider.validate_text_length("123456789").is_err());
    }

    #[test]
    fn test_text_truncation() {
        let provider = unloaded_provider(2);

        assert_eq!(provider.truncate_text("short"), "short");
        assert_eq!(provider.truncate_text("abcdefghij"), "abcdefgh");
        // Multi-byte characters must be cut on a char boundary, never mid-character.
        assert_eq!(provider.truncate_text(&"é".repeat(10)), "é".repeat(8));
    }

    #[test]
    fn test_builder_settings() {
        let builder = FastEmbedProviderBuilder::new(FastEmbedModel::BGESmallENV15)
            .max_batch_size(128)
            .show_download_progress(false);

        assert_eq!(builder.config.max_batch_size, 128);
        assert!(!builder.config.show_download_progress);
    }

    #[test]
    fn test_builder_pattern() {
        let builder = FastEmbedProviderBuilder::new(FastEmbedModel::BGEBaseENV15)
            .max_batch_size(64)
            .num_threads(2)
            .cache_dir("/tmp/test");

        assert_eq!(builder.config.max_batch_size, 64);
        assert_eq!(builder.config.num_threads, Some(2));
        assert_eq!(builder.config.cache_dir, Some("/tmp/test".to_string()));
    }
}