1use anyhow::{anyhow, Result};
2use fastembed::{EmbeddingModel as FastEmbedModel, InitOptions, TextEmbedding};
3use ort::execution_providers::CPUExecutionProvider;
4
/// Embedding models that this module can load through `fastembed`.
///
/// Variants ending in `Q` are the quantized builds of the model with the same prefix.
/// The default is [`ModelType::AllMiniLML6V2Q`].
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Default)]
pub enum ModelType {
    /// all-MiniLM-L6-v2 (short name `minilm-l6`).
    AllMiniLML6V2,
    /// Quantized all-MiniLM-L6-v2 (short name `minilm-l6-q`); the default model.
    #[default]
    AllMiniLML6V2Q,
    /// all-MiniLM-L12-v2 (short name `minilm-l12`).
    AllMiniLML12V2,
    /// Quantized all-MiniLM-L12-v2 (short name `minilm-l12-q`).
    AllMiniLML12V2Q,
    /// Paraphrase MiniLM model (short name `paraphrase-minilm`).
    ParaphraseMLMiniLML12V2,

    /// BGE small English v1.5 (short name `bge-small`).
    BGESmallENV15,
    /// Quantized BGE small English v1.5 (short name `bge-small-q`).
    BGESmallENV15Q,
    /// BGE base English v1.5 (short name `bge-base`).
    BGEBaseENV15,
    /// BGE large English v1.5 (short name `bge-large`).
    BGELargeENV15,

    /// Nomic embed text v1 (short name `nomic-v1`).
    NomicEmbedTextV1,
    /// Nomic embed text v1.5 (short name `nomic-v1.5`).
    NomicEmbedTextV15,
    /// Quantized Nomic embed text v1.5 (short name `nomic-v1.5-q`).
    NomicEmbedTextV15Q,

    /// Jina embeddings v2 base code model (short name `jina-code`).
    JinaEmbeddingsV2BaseCode,
    /// Multilingual E5 small (short name `e5-multilingual`).
    MultilingualE5Small,
    /// mxbai embed large v1 (short name `mxbai-large`).
    MxbaiEmbedLargeV1,
    /// ModernBERT embed large (short name `modernbert-large`).
    ModernBertEmbedLarge,
}
49
50impl ModelType {
51 pub fn to_fastembed_model(self) -> FastEmbedModel {
52 match self {
53 Self::AllMiniLML6V2 => FastEmbedModel::AllMiniLML6V2,
55 Self::AllMiniLML6V2Q => FastEmbedModel::AllMiniLML6V2Q,
56 Self::AllMiniLML12V2 => FastEmbedModel::AllMiniLML12V2,
57 Self::AllMiniLML12V2Q => FastEmbedModel::AllMiniLML12V2Q,
58 Self::ParaphraseMLMiniLML12V2 => FastEmbedModel::ParaphraseMLMiniLML12V2,
59 Self::BGESmallENV15 => FastEmbedModel::BGESmallENV15,
61 Self::BGESmallENV15Q => FastEmbedModel::BGESmallENV15Q,
62 Self::BGEBaseENV15 => FastEmbedModel::BGEBaseENV15,
63 Self::BGELargeENV15 => FastEmbedModel::BGELargeENV15,
64 Self::NomicEmbedTextV1 => FastEmbedModel::NomicEmbedTextV1,
66 Self::NomicEmbedTextV15 => FastEmbedModel::NomicEmbedTextV15,
67 Self::NomicEmbedTextV15Q => FastEmbedModel::NomicEmbedTextV15Q,
68 Self::JinaEmbeddingsV2BaseCode => FastEmbedModel::JinaEmbeddingsV2BaseCode,
70 Self::MultilingualE5Small => FastEmbedModel::MultilingualE5Small,
71 Self::MxbaiEmbedLargeV1 => FastEmbedModel::MxbaiEmbedLargeV1,
72 Self::ModernBertEmbedLarge => FastEmbedModel::ModernBertEmbedLarge,
73 }
74 }
75
76 pub fn dimensions(&self) -> usize {
77 match self {
78 Self::AllMiniLML6V2
80 | Self::AllMiniLML6V2Q
81 | Self::AllMiniLML12V2
82 | Self::AllMiniLML12V2Q
83 | Self::ParaphraseMLMiniLML12V2
84 | Self::BGESmallENV15
85 | Self::BGESmallENV15Q
86 | Self::MultilingualE5Small => 384,
87 Self::BGEBaseENV15
89 | Self::NomicEmbedTextV1
90 | Self::NomicEmbedTextV15
91 | Self::NomicEmbedTextV15Q
92 | Self::JinaEmbeddingsV2BaseCode => 768,
93 Self::BGELargeENV15 | Self::MxbaiEmbedLargeV1 | Self::ModernBertEmbedLarge => 1024,
95 }
96 }
97
98 pub fn name(&self) -> &'static str {
99 match self {
100 Self::AllMiniLML6V2 => "sentence-transformers/all-MiniLM-L6-v2",
101 Self::AllMiniLML6V2Q => "sentence-transformers/all-MiniLM-L6-v2 (quantized)",
102 Self::AllMiniLML12V2 => "sentence-transformers/all-MiniLM-L12-v2",
103 Self::AllMiniLML12V2Q => "sentence-transformers/all-MiniLM-L12-v2 (quantized)",
104 Self::ParaphraseMLMiniLML12V2 => "sentence-transformers/paraphrase-MiniLM-L6-v2",
105 Self::BGESmallENV15 => "BAAI/bge-small-en-v1.5",
106 Self::BGESmallENV15Q => "BAAI/bge-small-en-v1.5 (quantized)",
107 Self::BGEBaseENV15 => "BAAI/bge-base-en-v1.5",
108 Self::BGELargeENV15 => "BAAI/bge-large-en-v1.5",
109 Self::NomicEmbedTextV1 => "nomic-ai/nomic-embed-text-v1",
110 Self::NomicEmbedTextV15 => "nomic-ai/nomic-embed-text-v1.5",
111 Self::NomicEmbedTextV15Q => "nomic-ai/nomic-embed-text-v1.5 (quantized)",
112 Self::JinaEmbeddingsV2BaseCode => "jinaai/jina-embeddings-v2-base-code",
113 Self::MultilingualE5Small => "intfloat/multilingual-e5-small",
114 Self::MxbaiEmbedLargeV1 => "mixedbread-ai/mxbai-embed-large-v1",
115 Self::ModernBertEmbedLarge => "lightonai/modernbert-embed-large",
116 }
117 }
118
119 #[allow(dead_code)] pub fn is_quantized(&self) -> bool {
122 matches!(
123 self,
124 Self::AllMiniLML6V2Q
125 | Self::AllMiniLML12V2Q
126 | Self::BGESmallENV15Q
127 | Self::NomicEmbedTextV15Q
128 )
129 }
130
131 pub fn short_name(&self) -> &'static str {
133 match self {
134 Self::AllMiniLML6V2 => "minilm-l6",
135 Self::AllMiniLML6V2Q => "minilm-l6-q",
136 Self::AllMiniLML12V2 => "minilm-l12",
137 Self::AllMiniLML12V2Q => "minilm-l12-q",
138 Self::ParaphraseMLMiniLML12V2 => "paraphrase-minilm",
139 Self::BGESmallENV15 => "bge-small",
140 Self::BGESmallENV15Q => "bge-small-q",
141 Self::BGEBaseENV15 => "bge-base",
142 Self::BGELargeENV15 => "bge-large",
143 Self::NomicEmbedTextV1 => "nomic-v1",
144 Self::NomicEmbedTextV15 => "nomic-v1.5",
145 Self::NomicEmbedTextV15Q => "nomic-v1.5-q",
146 Self::JinaEmbeddingsV2BaseCode => "jina-code",
147 Self::MultilingualE5Small => "e5-multilingual",
148 Self::MxbaiEmbedLargeV1 => "mxbai-large",
149 Self::ModernBertEmbedLarge => "modernbert-large",
150 }
151 }
152
153 #[allow(dead_code)] pub fn all() -> &'static [ModelType] {
156 &[
157 Self::AllMiniLML6V2,
158 Self::AllMiniLML6V2Q,
159 Self::AllMiniLML12V2,
160 Self::AllMiniLML12V2Q,
161 Self::ParaphraseMLMiniLML12V2,
162 Self::BGESmallENV15,
163 Self::BGESmallENV15Q,
164 Self::BGEBaseENV15,
165 Self::BGELargeENV15,
166 Self::NomicEmbedTextV1,
167 Self::NomicEmbedTextV15,
168 Self::NomicEmbedTextV15Q,
169 Self::JinaEmbeddingsV2BaseCode,
170 Self::MultilingualE5Small,
171 Self::MxbaiEmbedLargeV1,
172 Self::ModernBertEmbedLarge,
173 ]
174 }
175
176 pub fn parse(s: &str) -> Option<Self> {
178 match s.to_lowercase().as_str() {
179 "minilm-l6" | "allminiml6v2" => Some(Self::AllMiniLML6V2),
180 "minilm-l6-q" | "allminiml6v2q" => Some(Self::AllMiniLML6V2Q),
181 "minilm-l12" | "allminiml12v2" => Some(Self::AllMiniLML12V2),
182 "minilm-l12-q" | "allminiml12v2q" => Some(Self::AllMiniLML12V2Q),
183 "paraphrase-minilm" => Some(Self::ParaphraseMLMiniLML12V2),
184 "bge-small" | "bgesmallenv15" => Some(Self::BGESmallENV15),
185 "bge-small-q" | "bgesmallenv15q" => Some(Self::BGESmallENV15Q),
186 "bge-base" | "bgebaseenv15" => Some(Self::BGEBaseENV15),
187 "bge-large" | "bgelargeenv15" => Some(Self::BGELargeENV15),
188 "nomic-v1" | "nomicembedtextv1" => Some(Self::NomicEmbedTextV1),
189 "nomic-v1.5" | "nomicembedtextv15" => Some(Self::NomicEmbedTextV15),
190 "nomic-v1.5-q" | "nomicembedtextv15q" => Some(Self::NomicEmbedTextV15Q),
191 "jina-code" | "jinaembeddingsv2basecode" => Some(Self::JinaEmbeddingsV2BaseCode),
192 "e5-multilingual" | "multilinguale5small" => Some(Self::MultilingualE5Small),
193 "mxbai-large" | "mxbaiembedlargev1" => Some(Self::MxbaiEmbedLargeV1),
194 "modernbert-large" | "modernbertembedlarge" => Some(Self::ModernBertEmbedLarge),
195 _ => None,
196 }
197 }
198}
199
200pub struct FastEmbedder {
202 model: TextEmbedding,
203 model_type: ModelType,
204}
205
206impl FastEmbedder {
207 pub fn new() -> Result<Self> {
209 Self::with_model(ModelType::default())
210 }
211
212 pub fn with_model(model_type: ModelType) -> Result<Self> {
214 Self::with_cache_dir(model_type, None)
215 }
216
217 pub fn with_cache_dir(
219 model_type: ModelType,
220 cache_dir: Option<&std::path::Path>,
221 ) -> Result<Self> {
222 if let Some(cache_dir) = cache_dir {
225 std::env::set_var(
226 "FASTEMBED_CACHE_DIR",
227 cache_dir.to_string_lossy().to_string(),
228 );
229 }
230
231 let cpu_ep = CPUExecutionProvider::default()
234 .with_arena_allocator(true)
235 .build();
236
237 let model = TextEmbedding::try_new(
238 InitOptions::new(model_type.to_fastembed_model())
239 .with_show_download_progress(false)
240 .with_execution_providers(vec![cpu_ep]),
241 )
242 .map_err(|e| anyhow!("Failed to initialize embedding model: {}", e))?;
243
244 Ok(Self { model, model_type })
245 }
246 pub fn embed_batch(&mut self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
250 let batch_size = if let Ok(env_size) = std::env::var("CODESEARCH_BATCH_SIZE") {
252 env_size.parse().unwrap_or(256)
253 } else {
254 match self.model_type.dimensions() {
257 d if d <= 384 => 256, d if d <= 768 => 128, _ => 64, }
261 };
262 self.embed_batch_chunked(texts, batch_size)
263 }
264
265 pub fn embed_batch_chunked(
267 &mut self,
268 texts: Vec<String>,
269 batch_size: usize,
270 ) -> Result<Vec<Vec<f32>>> {
271 if texts.is_empty() {
272 return Ok(Vec::new());
273 }
274
275 let mut all_embeddings = Vec::with_capacity(texts.len());
276
277 for chunk in texts.chunks(batch_size) {
279 if crate::constants::is_shutdown_requested() {
281 return Err(anyhow!("Embedding interrupted by shutdown request"));
282 }
283
284 let text_refs: Vec<&str> = chunk.iter().map(|s| s.as_str()).collect();
285
286 let embeddings = self
287 .model
288 .embed(text_refs, None)
289 .map_err(|e| anyhow!("Failed to generate embeddings: {}", e))?;
290
291 all_embeddings.extend(embeddings);
292 }
293
294 Ok(all_embeddings)
295 }
296
297 pub fn embed_one(&mut self, text: &str) -> Result<Vec<f32>> {
299 let embeddings = self.embed_batch(vec![text.to_string()])?;
300 embeddings
301 .into_iter()
302 .next()
303 .ok_or_else(|| anyhow!("No embedding generated"))
304 }
305
306 pub fn dimensions(&self) -> usize {
308 self.model_type.dimensions()
309 }
310
311 #[allow(dead_code)] pub fn model_name(&self) -> &str {
314 self.model_type.name()
315 }
316
317 #[allow(dead_code)] pub fn model_type(&self) -> ModelType {
320 self.model_type
321 }
322}
323
324impl Default for FastEmbedder {
325 fn default() -> Self {
326 Self::new().expect("Failed to create default embedder")
327 }
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_type_dimensions() {
        // 384-dimensional models
        assert_eq!(ModelType::BGESmallENV15.dimensions(), 384);
        assert_eq!(ModelType::BGESmallENV15Q.dimensions(), 384);
        assert_eq!(ModelType::AllMiniLML6V2.dimensions(), 384);
        assert_eq!(ModelType::AllMiniLML6V2Q.dimensions(), 384);
        assert_eq!(ModelType::AllMiniLML12V2.dimensions(), 384);
        assert_eq!(ModelType::AllMiniLML12V2Q.dimensions(), 384);
        assert_eq!(ModelType::ParaphraseMLMiniLML12V2.dimensions(), 384);
        assert_eq!(ModelType::MultilingualE5Small.dimensions(), 384);
        // 768-dimensional models
        assert_eq!(ModelType::BGEBaseENV15.dimensions(), 768);
        assert_eq!(ModelType::NomicEmbedTextV1.dimensions(), 768);
        assert_eq!(ModelType::NomicEmbedTextV15.dimensions(), 768);
        assert_eq!(ModelType::NomicEmbedTextV15Q.dimensions(), 768);
        assert_eq!(ModelType::JinaEmbeddingsV2BaseCode.dimensions(), 768);
        // 1024-dimensional models
        assert_eq!(ModelType::BGELargeENV15.dimensions(), 1024);
        assert_eq!(ModelType::MxbaiEmbedLargeV1.dimensions(), 1024);
        assert_eq!(ModelType::ModernBertEmbedLarge.dimensions(), 1024);
    }

    #[test]
    fn test_model_type_names() {
        assert_eq!(ModelType::BGESmallENV15.name(), "BAAI/bge-small-en-v1.5");
        assert_eq!(
            ModelType::AllMiniLML6V2.name(),
            "sentence-transformers/all-MiniLM-L6-v2"
        );
        assert_eq!(
            ModelType::JinaEmbeddingsV2BaseCode.name(),
            "jinaai/jina-embeddings-v2-base-code"
        );
    }

    #[test]
    fn test_default_model() {
        let model = ModelType::default();
        assert_eq!(model, ModelType::AllMiniLML6V2Q);
        assert_eq!(model.dimensions(), 384);
    }

    #[test]
    fn test_all_models() {
        let all = ModelType::all();
        assert_eq!(all.len(), 16);
        // Every variant should be listed exactly once.
        for (i, model) in all.iter().enumerate() {
            assert!(
                !all[i + 1..].contains(model),
                "duplicate entry in ModelType::all(): {:?}",
                model
            );
        }
    }

    #[test]
    fn test_parse() {
        assert_eq!(
            ModelType::parse("minilm-l6"),
            Some(ModelType::AllMiniLML6V2)
        );
        assert_eq!(
            ModelType::parse("minilm-l6-q"),
            Some(ModelType::AllMiniLML6V2Q)
        );
        assert_eq!(
            ModelType::parse("minilm-l12"),
            Some(ModelType::AllMiniLML12V2)
        );
        assert_eq!(
            ModelType::parse("minilm-l12-q"),
            Some(ModelType::AllMiniLML12V2Q)
        );
        assert_eq!(
            ModelType::parse("paraphrase-minilm"),
            Some(ModelType::ParaphraseMLMiniLML12V2)
        );
        assert_eq!(
            ModelType::parse("bge-small"),
            Some(ModelType::BGESmallENV15)
        );
        assert_eq!(
            ModelType::parse("bge-small-q"),
            Some(ModelType::BGESmallENV15Q)
        );
        assert_eq!(ModelType::parse("bge-base"), Some(ModelType::BGEBaseENV15));
        assert_eq!(ModelType::parse("bge-large"), Some(ModelType::BGELargeENV15));
        assert_eq!(
            ModelType::parse("nomic-v1"),
            Some(ModelType::NomicEmbedTextV1)
        );
        assert_eq!(
            ModelType::parse("nomic-v1.5"),
            Some(ModelType::NomicEmbedTextV15)
        );
        assert_eq!(
            ModelType::parse("nomic-v1.5-q"),
            Some(ModelType::NomicEmbedTextV15Q)
        );
        assert_eq!(
            ModelType::parse("jina-code"),
            Some(ModelType::JinaEmbeddingsV2BaseCode)
        );
        assert_eq!(
            ModelType::parse("e5-multilingual"),
            Some(ModelType::MultilingualE5Small)
        );
        assert_eq!(
            ModelType::parse("mxbai-large"),
            Some(ModelType::MxbaiEmbedLargeV1)
        );
        assert_eq!(
            ModelType::parse("modernbert-large"),
            Some(ModelType::ModernBertEmbedLarge)
        );
        // Parsing is case-insensitive.
        assert_eq!(ModelType::parse("BGE-SMALL"), Some(ModelType::BGESmallENV15));
        assert_eq!(ModelType::parse("invalid"), None);
        assert_eq!(ModelType::parse(""), None);
    }

    #[test]
    fn test_parse_round_trips_short_names() {
        for &model in ModelType::all() {
            assert_eq!(
                ModelType::parse(model.short_name()),
                Some(model),
                "short name {:?} does not parse back",
                model.short_name()
            );
        }
    }

    #[test]
    fn test_is_quantized() {
        assert!(ModelType::AllMiniLML6V2Q.is_quantized());
        assert!(ModelType::BGESmallENV15Q.is_quantized());
        assert!(!ModelType::BGESmallENV15.is_quantized());
        assert!(!ModelType::JinaEmbeddingsV2BaseCode.is_quantized());
        // Quantized variants are exactly those whose short name carries the `-q` suffix.
        for &model in ModelType::all() {
            assert_eq!(
                model.is_quantized(),
                model.short_name().ends_with("-q"),
                "is_quantized disagrees with short_name for {:?}",
                model
            );
        }
    }

    #[test]
    #[ignore = "loads the embedding model (may download weights)"]
    fn test_embedder_creation() {
        let embedder = FastEmbedder::new();
        assert!(embedder.is_ok());

        let embedder = embedder.unwrap();
        assert_eq!(embedder.dimensions(), 384);
    }

    #[test]
    #[ignore = "loads the embedding model (may download weights)"]
    fn test_embed_single_text() {
        let mut embedder = FastEmbedder::new().unwrap();
        let embedding = embedder.embed_one("Hello, world!").unwrap();

        assert_eq!(embedding.len(), 384);
        let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        // Expect (approximately) unit-normalized output.
        assert!((magnitude - 1.0).abs() < 0.1);
    }

    #[test]
    #[ignore = "loads the embedding model (may download weights)"]
    fn test_embed_batch() {
        let mut embedder = FastEmbedder::new().unwrap();
        let texts = vec![
            "Hello, world!".to_string(),
            "Rust is awesome".to_string(),
            "Code search with AI".to_string(),
        ];

        let embeddings = embedder.embed_batch(texts).unwrap();

        assert_eq!(embeddings.len(), 3);
        for embedding in embeddings {
            assert_eq!(embedding.len(), 384);
        }
    }

    #[test]
    #[ignore = "loads the embedding model (may download weights)"]
    fn test_semantic_similarity() {
        let mut embedder = FastEmbedder::new().unwrap();

        let text1 = "The quick brown fox jumps over the lazy dog";
        let text2 = "A fast auburn fox leaps over a sleepy canine";
        let text3 = "Python is a programming language";

        let emb1 = embedder.embed_one(text1).unwrap();
        let emb2 = embedder.embed_one(text2).unwrap();
        let emb3 = embedder.embed_one(text3).unwrap();

        let sim_1_2 = cosine_similarity(&emb1, &emb2);
        let sim_1_3 = cosine_similarity(&emb1, &emb3);

        // Paraphrases should be closer than unrelated sentences.
        assert!(sim_1_2 > sim_1_3);
        assert!(sim_1_2 > 0.7);
    }

    /// Cosine similarity of two equal-length vectors.
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let mag_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let mag_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
        dot / (mag_a * mag_b)
    }
}