lumosai_vector_fastembed/
lib.rs1use std::collections::HashMap;
33use std::sync::Arc;
34use tokio::sync::Mutex;
35
36pub mod models;
37pub mod provider;
38pub mod error;
39
40pub use models::{FastEmbedModel, ModelInfo};
41pub use provider::FastEmbedProvider;
42pub use error::{FastEmbedError, Result};
43
44pub use lumosai_vector_core::types::{Vector, Metadata};
46pub use lumosai_vector_core::traits::EmbeddingModel;
47
48#[derive(Clone)]
50pub struct FastEmbedClient {
51 #[allow(dead_code)]
53 models: Arc<Mutex<HashMap<String, Arc<fastembed::TextEmbedding>>>>,
54
55 #[allow(dead_code)]
57 cache_dir: Option<String>,
58
59 config: FastEmbedConfig,
61}
62
/// Settings used by [`FastEmbedClient`] and handed to the providers it creates.
///
/// Implements `PartialEq`/`Eq` so two configurations can be compared as a
/// whole instead of field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FastEmbedConfig {
    /// Maximum batch size. The `Default` implementation uses `256`.
    pub max_batch_size: usize,

    /// Forwarded to fastembed's `with_show_download_progress` when a model is
    /// initialised. The `Default` implementation uses `true`.
    pub show_download_progress: bool,

    /// Optional thread count; `None` by default.
    // NOTE(review): nothing in this file reads this value - confirm that the
    // provider module consumes it.
    pub num_threads: Option<usize>,

    /// Optional cache directory; when set it is passed to fastembed's
    /// `with_cache_dir`. `None` by default.
    pub cache_dir: Option<String>,
}
78
79impl Default for FastEmbedConfig {
80 fn default() -> Self {
81 Self {
82 max_batch_size: 256,
83 show_download_progress: true,
84 num_threads: None,
85 cache_dir: None,
86 }
87 }
88}
89
90impl FastEmbedClient {
91 pub fn new() -> Self {
93 Self {
94 models: Arc::new(Mutex::new(HashMap::new())),
95 cache_dir: None,
96 config: FastEmbedConfig::default(),
97 }
98 }
99
100 pub fn with_config(config: FastEmbedConfig) -> Self {
102 Self {
103 models: Arc::new(Mutex::new(HashMap::new())),
104 cache_dir: config.cache_dir.clone(),
105 config,
106 }
107 }
108
109 pub async fn embedding_provider(&self, model: FastEmbedModel) -> Result<FastEmbedProvider> {
111 FastEmbedProvider::new(model, self.config.clone()).await
112 }
113
114 #[allow(dead_code)]
116 async fn get_or_create_model(
117 &self,
118 model: &FastEmbedModel,
119 ) -> Result<Arc<fastembed::TextEmbedding>> {
120 let model_key = model.model_name().to_string();
121 let mut models = self.models.lock().await;
122
123 if let Some(existing_model) = models.get(&model_key) {
124 return Ok(existing_model.clone());
125 }
126
127 let mut init_options = fastembed::InitOptions::new(model.to_fastembed_model())
129 .with_show_download_progress(self.config.show_download_progress);
130
131 if let Some(cache_dir) = &self.config.cache_dir {
132 init_options = init_options.with_cache_dir(cache_dir.into());
133 }
134
135 let embedding_model = fastembed::TextEmbedding::try_new(init_options)
141 .map_err(|e| FastEmbedError::ModelInitialization(e.to_string()))?;
142
143 let model_arc = Arc::new(embedding_model);
144 models.insert(model_key, model_arc.clone());
145
146 Ok(model_arc)
147 }
148
149 pub fn available_models() -> Vec<FastEmbedModel> {
151 vec![
152 FastEmbedModel::BGESmallENV15,
153 FastEmbedModel::BGEBaseENV15,
154 FastEmbedModel::BGELargeENV15,
155 FastEmbedModel::AllMiniLML6V2,
156 FastEmbedModel::AllMiniLML12V2,
157 FastEmbedModel::MultilingualE5Small,
158 FastEmbedModel::MultilingualE5Base,
159 FastEmbedModel::MultilingualE5Large,
160 ]
161 }
162
163 pub fn model_info(model: &FastEmbedModel) -> ModelInfo {
165 ModelInfo {
166 name: model.model_name().to_string(),
167 dimensions: model.dimensions(),
168 max_sequence_length: model.max_sequence_length(),
169 description: model.description().to_string(),
170 language_support: model.language_support().iter().map(|s| s.to_string()).collect(),
171 }
172 }
173}
174
175impl Default for FastEmbedClient {
176 fn default() -> Self {
177 Self::new()
178 }
179}
180
181pub struct FastEmbedConfigBuilder {
183 config: FastEmbedConfig,
184}
185
186impl FastEmbedConfigBuilder {
187 pub fn new() -> Self {
189 Self {
190 config: FastEmbedConfig::default(),
191 }
192 }
193
194 pub fn max_batch_size(mut self, size: usize) -> Self {
196 self.config.max_batch_size = size;
197 self
198 }
199
200 pub fn show_download_progress(mut self, show: bool) -> Self {
202 self.config.show_download_progress = show;
203 self
204 }
205
206 pub fn num_threads(mut self, threads: usize) -> Self {
208 self.config.num_threads = Some(threads);
209 self
210 }
211
212 pub fn cache_dir<S: Into<String>>(mut self, dir: S) -> Self {
214 self.config.cache_dir = Some(dir.into());
215 self
216 }
217
218 pub fn build(self) -> FastEmbedConfig {
220 self.config
221 }
222}
223
224impl Default for FastEmbedConfigBuilder {
225 fn default() -> Self {
226 Self::new()
227 }
228}
229
#[cfg(test)]
mod tests {
    use super::*;

    /// A client built with `new` carries the default configuration values.
    #[test]
    fn test_client_creation() {
        let created = FastEmbedClient::new();
        let settings = &created.config;

        assert_eq!(settings.max_batch_size, 256);
        assert!(settings.show_download_progress);
    }

    /// Every builder setter is reflected in the built configuration.
    #[test]
    fn test_config_builder() {
        let builder = FastEmbedConfigBuilder::new()
            .max_batch_size(128)
            .show_download_progress(false)
            .num_threads(4)
            .cache_dir("/tmp/fastembed");
        let built = builder.build();

        assert_eq!(built.max_batch_size, 128);
        assert!(!built.show_download_progress);
        assert_eq!(built.num_threads, Some(4));
        assert_eq!(built.cache_dir, Some("/tmp/fastembed".to_string()));
    }

    /// The advertised model list is non-empty and includes BGE small.
    #[test]
    fn test_available_models() {
        let listed = FastEmbedClient::available_models();

        assert!(!listed.is_empty());
        assert!(listed.contains(&FastEmbedModel::BGESmallENV15));
    }

    /// Model metadata for BGE small reports the expected name and width.
    #[test]
    fn test_model_info() {
        let model = FastEmbedModel::BGESmallENV15;
        let described = FastEmbedClient::model_info(&model);

        assert_eq!(described.name, "BAAI/bge-small-en-v1.5");
        assert_eq!(described.dimensions, 384);
    }
}