lumosai_vector_fastembed/
lib.rs

//! # LumosAI FastEmbed Integration
//!
//! This crate provides FastEmbed integration for LumosAI vector storage,
//! enabling local embedding generation without external API dependencies.
//!
//! ## Features
//!
//! - **Local Processing**: Generate embeddings locally without API calls
//! - **Multiple Models**: Support for several pre-trained models (see `FastEmbedClient::available_models`)
//! - **High Performance**: Optimized for batch processing
//! - **Easy Integration**: Seamless integration with LumosAI vector storage
//!
//! ## Quick Start
//!
//! ```rust,no_run
//! use lumosai_vector_fastembed::{FastEmbedConfig, FastEmbedModel, FastEmbedProvider};
//! use lumosai_vector_core::traits::EmbeddingModel;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     // Create a FastEmbed provider (model files may be downloaded on first use)
//!     let provider =
//!         FastEmbedProvider::new(FastEmbedModel::BGESmallENV15, FastEmbedConfig::default()).await?;
//!
//!     // Generate an embedding
//!     let embedding = provider.embed_text("Hello, world!").await?;
//!     println!("Embedding dimensions: {}", embedding.len());
//!
//!     Ok(())
//! }
//! ```
31
32use std::collections::HashMap;
33use std::sync::Arc;
34use tokio::sync::Mutex;
35
36pub mod models;
37pub mod provider;
38pub mod error;
39
40pub use models::{FastEmbedModel, ModelInfo};
41pub use provider::FastEmbedProvider;
42pub use error::{FastEmbedError, Result};
43
44// Re-export core types for convenience
45pub use lumosai_vector_core::types::{Vector, Metadata};
46pub use lumosai_vector_core::traits::EmbeddingModel;
47
48/// FastEmbed client for managing embedding models
49#[derive(Clone)]
50pub struct FastEmbedClient {
51    /// Cache of initialized models
52    #[allow(dead_code)]
53    models: Arc<Mutex<HashMap<String, Arc<fastembed::TextEmbedding>>>>,
54
55    /// Default cache directory for models
56    #[allow(dead_code)]
57    cache_dir: Option<String>,
58    
59    /// Default configuration
60    config: FastEmbedConfig,
61}
62
/// Configuration for [`FastEmbedClient`] and the providers it creates.
///
/// Use [`FastEmbedConfig::default`] for the standard settings, or
/// [`FastEmbedConfigBuilder`] to override individual fields.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FastEmbedConfig {
    /// Maximum number of inputs processed in a single batch (default: `256`).
    pub max_batch_size: usize,

    /// Whether model downloads report progress (default: `true`).
    pub show_download_progress: bool,

    /// Requested number of threads for processing (default: `None`).
    ///
    /// NOTE(review): the model-initialization path in this module does not
    /// apply this value, because the fastembed version in use lacks
    /// `with_num_threads`. Confirm whether `FastEmbedProvider` honours it.
    pub num_threads: Option<usize>,

    /// Directory where model files are cached (default: `None`).
    ///
    /// When `None`, no cache directory is passed to fastembed, so its own
    /// default location is used.
    pub cache_dir: Option<String>,
}

impl Default for FastEmbedConfig {
    fn default() -> Self {
        Self {
            max_batch_size: 256,
            show_download_progress: true,
            num_threads: None,
            cache_dir: None,
        }
    }
}
89
90impl FastEmbedClient {
91    /// Create a new FastEmbed client with default configuration
92    pub fn new() -> Self {
93        Self {
94            models: Arc::new(Mutex::new(HashMap::new())),
95            cache_dir: None,
96            config: FastEmbedConfig::default(),
97        }
98    }
99    
100    /// Create a new FastEmbed client with custom configuration
101    pub fn with_config(config: FastEmbedConfig) -> Self {
102        Self {
103            models: Arc::new(Mutex::new(HashMap::new())),
104            cache_dir: config.cache_dir.clone(),
105            config,
106        }
107    }
108    
109    /// Create an embedding provider for the specified model
110    pub async fn embedding_provider(&self, model: FastEmbedModel) -> Result<FastEmbedProvider> {
111        FastEmbedProvider::new(model, self.config.clone()).await
112    }
113    
114    /// Get or create a model instance
115    #[allow(dead_code)]
116    async fn get_or_create_model(
117        &self,
118        model: &FastEmbedModel,
119    ) -> Result<Arc<fastembed::TextEmbedding>> {
120        let model_key = model.model_name().to_string();
121        let mut models = self.models.lock().await;
122        
123        if let Some(existing_model) = models.get(&model_key) {
124            return Ok(existing_model.clone());
125        }
126        
127        // Create new model instance
128        let mut init_options = fastembed::InitOptions::new(model.to_fastembed_model())
129            .with_show_download_progress(self.config.show_download_progress);
130        
131        if let Some(cache_dir) = &self.config.cache_dir {
132            init_options = init_options.with_cache_dir(cache_dir.into());
133        }
134        
135        // Note: with_num_threads is not available in current fastembed version
136        // if let Some(num_threads) = self.config.num_threads {
137        //     init_options = init_options.with_num_threads(num_threads);
138        // }
139        
140        let embedding_model = fastembed::TextEmbedding::try_new(init_options)
141            .map_err(|e| FastEmbedError::ModelInitialization(e.to_string()))?;
142        
143        let model_arc = Arc::new(embedding_model);
144        models.insert(model_key, model_arc.clone());
145        
146        Ok(model_arc)
147    }
148    
149    /// List available models
150    pub fn available_models() -> Vec<FastEmbedModel> {
151        vec![
152            FastEmbedModel::BGESmallENV15,
153            FastEmbedModel::BGEBaseENV15,
154            FastEmbedModel::BGELargeENV15,
155            FastEmbedModel::AllMiniLML6V2,
156            FastEmbedModel::AllMiniLML12V2,
157            FastEmbedModel::MultilingualE5Small,
158            FastEmbedModel::MultilingualE5Base,
159            FastEmbedModel::MultilingualE5Large,
160        ]
161    }
162    
163    /// Get model information
164    pub fn model_info(model: &FastEmbedModel) -> ModelInfo {
165        ModelInfo {
166            name: model.model_name().to_string(),
167            dimensions: model.dimensions(),
168            max_sequence_length: model.max_sequence_length(),
169            description: model.description().to_string(),
170            language_support: model.language_support().iter().map(|s| s.to_string()).collect(),
171        }
172    }
173}
174
175impl Default for FastEmbedClient {
176    fn default() -> Self {
177        Self::new()
178    }
179}
180
181/// Builder for FastEmbed configuration
182pub struct FastEmbedConfigBuilder {
183    config: FastEmbedConfig,
184}
185
186impl FastEmbedConfigBuilder {
187    /// Create a new configuration builder
188    pub fn new() -> Self {
189        Self {
190            config: FastEmbedConfig::default(),
191        }
192    }
193    
194    /// Set the maximum batch size
195    pub fn max_batch_size(mut self, size: usize) -> Self {
196        self.config.max_batch_size = size;
197        self
198    }
199    
200    /// Set whether to show download progress
201    pub fn show_download_progress(mut self, show: bool) -> Self {
202        self.config.show_download_progress = show;
203        self
204    }
205    
206    /// Set the number of threads for processing
207    pub fn num_threads(mut self, threads: usize) -> Self {
208        self.config.num_threads = Some(threads);
209        self
210    }
211    
212    /// Set the cache directory for model files
213    pub fn cache_dir<S: Into<String>>(mut self, dir: S) -> Self {
214        self.config.cache_dir = Some(dir.into());
215        self
216    }
217    
218    /// Build the configuration
219    pub fn build(self) -> FastEmbedConfig {
220        self.config
221    }
222}
223
224impl Default for FastEmbedConfigBuilder {
225    fn default() -> Self {
226        Self::new()
227    }
228}
229
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` yields the default configuration and no client-level cache directory.
    #[test]
    fn test_client_creation() {
        let client = FastEmbedClient::new();
        assert_eq!(client.config.max_batch_size, 256);
        assert!(client.config.show_download_progress);
        assert_eq!(client.config.num_threads, None);
        assert_eq!(client.config.cache_dir, None);
        assert_eq!(client.cache_dir, None);
    }

    /// `with_config` stores the config and mirrors its cache directory into the client.
    #[test]
    fn test_client_with_config_propagates_cache_dir() {
        let config = FastEmbedConfigBuilder::new()
            .max_batch_size(32)
            .cache_dir("/tmp/fastembed")
            .build();
        let client = FastEmbedClient::with_config(config);

        assert_eq!(client.config.max_batch_size, 32);
        assert_eq!(client.cache_dir.as_deref(), Some("/tmp/fastembed"));
        assert_eq!(client.config.cache_dir.as_deref(), Some("/tmp/fastembed"));
    }

    /// A builder with no setters called must produce `FastEmbedConfig::default()`.
    #[test]
    fn test_builder_defaults_match_config_default() {
        let built = FastEmbedConfigBuilder::default().build();
        let expected = FastEmbedConfig::default();

        assert_eq!(built.max_batch_size, expected.max_batch_size);
        assert_eq!(built.show_download_progress, expected.show_download_progress);
        assert_eq!(built.num_threads, expected.num_threads);
        assert_eq!(built.cache_dir, expected.cache_dir);
    }

    #[test]
    fn test_config_builder() {
        let config = FastEmbedConfigBuilder::new()
            .max_batch_size(128)
            .show_download_progress(false)
            .num_threads(4)
            .cache_dir("/tmp/fastembed")
            .build();

        assert_eq!(config.max_batch_size, 128);
        assert!(!config.show_download_progress);
        assert_eq!(config.num_threads, Some(4));
        assert_eq!(config.cache_dir, Some("/tmp/fastembed".to_string()));
    }

    #[test]
    fn test_available_models() {
        let models = FastEmbedClient::available_models();
        assert!(!models.is_empty());
        assert!(models.contains(&FastEmbedModel::BGESmallENV15));

        // Each model should be listed exactly once.
        for (i, model) in models.iter().enumerate() {
            assert!(
                !models[i + 1..].contains(model),
                "duplicate model entry at index {}",
                i
            );
        }
    }

    #[test]
    fn test_model_info() {
        let info = FastEmbedClient::model_info(&FastEmbedModel::BGESmallENV15);
        assert_eq!(info.name, "BAAI/bge-small-en-v1.5");
        assert_eq!(info.dimensions, 384);
    }
}