Skip to main content

codesearch/embed/
embedder.rs

1use anyhow::{anyhow, Result};
2use fastembed::{EmbeddingModel as FastEmbedModel, InitOptions, TextEmbedding};
3use ort::execution_providers::CPUExecutionProvider;
4
/// Available embedding models.
///
/// Each variant maps 1:1 onto a `fastembed::EmbeddingModel` (see
/// `ModelType::to_fastembed_model`). Variants whose name ends in `Q` are
/// quantized builds of the same model. The default is
/// [`ModelType::AllMiniLML6V2Q`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ModelType {
    // === MiniLM Family ===
    /// All-MiniLM-L6-v2 - 384 dimensions, fast and efficient
    AllMiniLML6V2,
    /// Quantized All-MiniLM-L6-v2 - 384 dimensions, faster (DEFAULT)
    #[default]
    AllMiniLML6V2Q,
    /// All-MiniLM-L12-v2 - 384 dimensions, better quality than L6
    AllMiniLML12V2,
    /// Quantized All-MiniLM-L12-v2 - 384 dimensions
    AllMiniLML12V2Q,
    /// Paraphrase MiniLM-L12-v2 (multilingual variant, per fastembed's
    /// `ParaphraseMLMiniLML12V2`) - 384 dimensions
    ParaphraseMLMiniLML12V2,

    // === BGE Family ===
    /// BGE Small EN v1.5 - 384 dimensions, good balance
    BGESmallENV15,
    /// Quantized BGE Small EN v1.5 - 384 dimensions, faster
    BGESmallENV15Q,
    /// BGE Base EN v1.5 - 768 dimensions, higher quality
    BGEBaseENV15,
    /// BGE Large EN v1.5 - 1024 dimensions, best BGE quality
    BGELargeENV15,

    // === Nomic Family ===
    /// Nomic Embed Text v1 - 768 dimensions
    NomicEmbedTextV1,
    /// Nomic Embed Text v1.5 - 768 dimensions, improved
    NomicEmbedTextV15,
    /// Quantized Nomic Embed Text v1.5 - 768 dimensions
    NomicEmbedTextV15Q,

    // === Specialized Models ===
    /// Jina Embeddings v2 Base Code - 768 dimensions, optimized for code
    JinaEmbeddingsV2BaseCode,
    /// Multilingual E5 Small - 384 dimensions, multilingual support
    MultilingualE5Small,
    /// MxBai Embed Large v1 - 1024 dimensions, high quality
    MxbaiEmbedLargeV1,
    /// ModernBERT Embed Large - 1024 dimensions, latest architecture
    ModernBertEmbedLarge,
}
49
50impl ModelType {
51    pub fn to_fastembed_model(self) -> FastEmbedModel {
52        match self {
53            // MiniLM Family
54            Self::AllMiniLML6V2 => FastEmbedModel::AllMiniLML6V2,
55            Self::AllMiniLML6V2Q => FastEmbedModel::AllMiniLML6V2Q,
56            Self::AllMiniLML12V2 => FastEmbedModel::AllMiniLML12V2,
57            Self::AllMiniLML12V2Q => FastEmbedModel::AllMiniLML12V2Q,
58            Self::ParaphraseMLMiniLML12V2 => FastEmbedModel::ParaphraseMLMiniLML12V2,
59            // BGE Family
60            Self::BGESmallENV15 => FastEmbedModel::BGESmallENV15,
61            Self::BGESmallENV15Q => FastEmbedModel::BGESmallENV15Q,
62            Self::BGEBaseENV15 => FastEmbedModel::BGEBaseENV15,
63            Self::BGELargeENV15 => FastEmbedModel::BGELargeENV15,
64            // Nomic Family
65            Self::NomicEmbedTextV1 => FastEmbedModel::NomicEmbedTextV1,
66            Self::NomicEmbedTextV15 => FastEmbedModel::NomicEmbedTextV15,
67            Self::NomicEmbedTextV15Q => FastEmbedModel::NomicEmbedTextV15Q,
68            // Specialized
69            Self::JinaEmbeddingsV2BaseCode => FastEmbedModel::JinaEmbeddingsV2BaseCode,
70            Self::MultilingualE5Small => FastEmbedModel::MultilingualE5Small,
71            Self::MxbaiEmbedLargeV1 => FastEmbedModel::MxbaiEmbedLargeV1,
72            Self::ModernBertEmbedLarge => FastEmbedModel::ModernBertEmbedLarge,
73        }
74    }
75
76    pub fn dimensions(&self) -> usize {
77        match self {
78            // 384 dimensions
79            Self::AllMiniLML6V2
80            | Self::AllMiniLML6V2Q
81            | Self::AllMiniLML12V2
82            | Self::AllMiniLML12V2Q
83            | Self::ParaphraseMLMiniLML12V2
84            | Self::BGESmallENV15
85            | Self::BGESmallENV15Q
86            | Self::MultilingualE5Small => 384,
87            // 768 dimensions
88            Self::BGEBaseENV15
89            | Self::NomicEmbedTextV1
90            | Self::NomicEmbedTextV15
91            | Self::NomicEmbedTextV15Q
92            | Self::JinaEmbeddingsV2BaseCode => 768,
93            // 1024 dimensions
94            Self::BGELargeENV15 | Self::MxbaiEmbedLargeV1 | Self::ModernBertEmbedLarge => 1024,
95        }
96    }
97
98    pub fn name(&self) -> &'static str {
99        match self {
100            Self::AllMiniLML6V2 => "sentence-transformers/all-MiniLM-L6-v2",
101            Self::AllMiniLML6V2Q => "sentence-transformers/all-MiniLM-L6-v2 (quantized)",
102            Self::AllMiniLML12V2 => "sentence-transformers/all-MiniLM-L12-v2",
103            Self::AllMiniLML12V2Q => "sentence-transformers/all-MiniLM-L12-v2 (quantized)",
104            Self::ParaphraseMLMiniLML12V2 => "sentence-transformers/paraphrase-MiniLM-L6-v2",
105            Self::BGESmallENV15 => "BAAI/bge-small-en-v1.5",
106            Self::BGESmallENV15Q => "BAAI/bge-small-en-v1.5 (quantized)",
107            Self::BGEBaseENV15 => "BAAI/bge-base-en-v1.5",
108            Self::BGELargeENV15 => "BAAI/bge-large-en-v1.5",
109            Self::NomicEmbedTextV1 => "nomic-ai/nomic-embed-text-v1",
110            Self::NomicEmbedTextV15 => "nomic-ai/nomic-embed-text-v1.5",
111            Self::NomicEmbedTextV15Q => "nomic-ai/nomic-embed-text-v1.5 (quantized)",
112            Self::JinaEmbeddingsV2BaseCode => "jinaai/jina-embeddings-v2-base-code",
113            Self::MultilingualE5Small => "intfloat/multilingual-e5-small",
114            Self::MxbaiEmbedLargeV1 => "mixedbread-ai/mxbai-embed-large-v1",
115            Self::ModernBertEmbedLarge => "lightonai/modernbert-embed-large",
116        }
117    }
118
119    /// Check if model is quantized (faster but slightly less accurate)
120    #[allow(dead_code)] // Reserved for model selection UI
121    pub fn is_quantized(&self) -> bool {
122        matches!(
123            self,
124            Self::AllMiniLML6V2Q
125                | Self::AllMiniLML12V2Q
126                | Self::BGESmallENV15Q
127                | Self::NomicEmbedTextV15Q
128        )
129    }
130
131    /// Get a short identifier for the model (for filenames, etc.)
132    pub fn short_name(&self) -> &'static str {
133        match self {
134            Self::AllMiniLML6V2 => "minilm-l6",
135            Self::AllMiniLML6V2Q => "minilm-l6-q",
136            Self::AllMiniLML12V2 => "minilm-l12",
137            Self::AllMiniLML12V2Q => "minilm-l12-q",
138            Self::ParaphraseMLMiniLML12V2 => "paraphrase-minilm",
139            Self::BGESmallENV15 => "bge-small",
140            Self::BGESmallENV15Q => "bge-small-q",
141            Self::BGEBaseENV15 => "bge-base",
142            Self::BGELargeENV15 => "bge-large",
143            Self::NomicEmbedTextV1 => "nomic-v1",
144            Self::NomicEmbedTextV15 => "nomic-v1.5",
145            Self::NomicEmbedTextV15Q => "nomic-v1.5-q",
146            Self::JinaEmbeddingsV2BaseCode => "jina-code",
147            Self::MultilingualE5Small => "e5-multilingual",
148            Self::MxbaiEmbedLargeV1 => "mxbai-large",
149            Self::ModernBertEmbedLarge => "modernbert-large",
150        }
151    }
152
153    /// List all available models
154    #[allow(dead_code)] // Reserved for model listing command
155    pub fn all() -> &'static [ModelType] {
156        &[
157            Self::AllMiniLML6V2,
158            Self::AllMiniLML6V2Q,
159            Self::AllMiniLML12V2,
160            Self::AllMiniLML12V2Q,
161            Self::ParaphraseMLMiniLML12V2,
162            Self::BGESmallENV15,
163            Self::BGESmallENV15Q,
164            Self::BGEBaseENV15,
165            Self::BGELargeENV15,
166            Self::NomicEmbedTextV1,
167            Self::NomicEmbedTextV15,
168            Self::NomicEmbedTextV15Q,
169            Self::JinaEmbeddingsV2BaseCode,
170            Self::MultilingualE5Small,
171            Self::MxbaiEmbedLargeV1,
172            Self::ModernBertEmbedLarge,
173        ]
174    }
175
176    /// Parse model from string (for CLI)
177    pub fn parse(s: &str) -> Option<Self> {
178        match s.to_lowercase().as_str() {
179            "minilm-l6" | "allminiml6v2" => Some(Self::AllMiniLML6V2),
180            "minilm-l6-q" | "allminiml6v2q" => Some(Self::AllMiniLML6V2Q),
181            "minilm-l12" | "allminiml12v2" => Some(Self::AllMiniLML12V2),
182            "minilm-l12-q" | "allminiml12v2q" => Some(Self::AllMiniLML12V2Q),
183            "paraphrase-minilm" => Some(Self::ParaphraseMLMiniLML12V2),
184            "bge-small" | "bgesmallenv15" => Some(Self::BGESmallENV15),
185            "bge-small-q" | "bgesmallenv15q" => Some(Self::BGESmallENV15Q),
186            "bge-base" | "bgebaseenv15" => Some(Self::BGEBaseENV15),
187            "bge-large" | "bgelargeenv15" => Some(Self::BGELargeENV15),
188            "nomic-v1" | "nomicembedtextv1" => Some(Self::NomicEmbedTextV1),
189            "nomic-v1.5" | "nomicembedtextv15" => Some(Self::NomicEmbedTextV15),
190            "nomic-v1.5-q" | "nomicembedtextv15q" => Some(Self::NomicEmbedTextV15Q),
191            "jina-code" | "jinaembeddingsv2basecode" => Some(Self::JinaEmbeddingsV2BaseCode),
192            "e5-multilingual" | "multilinguale5small" => Some(Self::MultilingualE5Small),
193            "mxbai-large" | "mxbaiembedlargev1" => Some(Self::MxbaiEmbedLargeV1),
194            "modernbert-large" | "modernbertembedlarge" => Some(Self::ModernBertEmbedLarge),
195            _ => None,
196        }
197    }
198}
199
200/// Fast embedding model using fastembed library
201pub struct FastEmbedder {
202    model: TextEmbedding,
203    model_type: ModelType,
204}
205
206impl FastEmbedder {
207    /// Create a new embedder with default model
208    pub fn new() -> Result<Self> {
209        Self::with_model(ModelType::default())
210    }
211
212    /// Create a new embedder with specified model
213    pub fn with_model(model_type: ModelType) -> Result<Self> {
214        Self::with_cache_dir(model_type, None)
215    }
216
217    /// Create a new embedder with specified model and cache directory
218    pub fn with_cache_dir(
219        model_type: ModelType,
220        cache_dir: Option<&std::path::Path>,
221    ) -> Result<Self> {
222        // Set cache directory via environment variable if provided
223        // Note: fastembed library uses FASTEMBED_CACHE_DIR (not FASTEMBED_CACHE_PATH)
224        if let Some(cache_dir) = cache_dir {
225            std::env::set_var(
226                "FASTEMBED_CACHE_DIR",
227                cache_dir.to_string_lossy().to_string(),
228            );
229        }
230
231        // Use CPU execution provider WITH arena allocator for speed.
232        // Arena allocator provides fast memory reuse during inference.
233        let cpu_ep = CPUExecutionProvider::default()
234            .with_arena_allocator(true)
235            .build();
236
237        let model = TextEmbedding::try_new(
238            InitOptions::new(model_type.to_fastembed_model())
239                .with_show_download_progress(false)
240                .with_execution_providers(vec![cpu_ep]),
241        )
242        .map_err(|e| anyhow!("Failed to initialize embedding model: {}", e))?;
243
244        Ok(Self { model, model_type })
245    }
246    /// Embed a batch of texts (processes in mini-batches to avoid OOM)
247    /// Uses adaptive batch size based on model dimensions
248    /// Can be overridden with CODESEARCH_BATCH_SIZE environment variable
249    pub fn embed_batch(&mut self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
250        // Check for env var override (tune with CODESEARCH_BATCH_SIZE=N)
251        let batch_size = if let Ok(env_size) = std::env::var("CODESEARCH_BATCH_SIZE") {
252            env_size.parse().unwrap_or(256)
253        } else {
254            // Adaptive batch size: without arena allocator, ONNX frees buffers after each batch
255            // so larger batches are faster without accumulating memory.
256            match self.model_type.dimensions() {
257                d if d <= 384 => 256, // Small models (MiniLM etc.)
258                d if d <= 768 => 128, // Medium models (BGE-base, Jina etc.)
259                _ => 64,              // Large models (BGE-large, MxBai etc.)
260            }
261        };
262        self.embed_batch_chunked(texts, batch_size)
263    }
264
265    /// Embed a batch of texts with configurable mini-batch size
266    pub fn embed_batch_chunked(
267        &mut self,
268        texts: Vec<String>,
269        batch_size: usize,
270    ) -> Result<Vec<Vec<f32>>> {
271        if texts.is_empty() {
272            return Ok(Vec::new());
273        }
274
275        let mut all_embeddings = Vec::with_capacity(texts.len());
276
277        // Process in mini-batches to avoid OOM with large models
278        for chunk in texts.chunks(batch_size) {
279            // Check for CTRL-C between mini-batches so we don't block for minutes
280            if crate::constants::is_shutdown_requested() {
281                return Err(anyhow!("Embedding interrupted by shutdown request"));
282            }
283
284            let text_refs: Vec<&str> = chunk.iter().map(|s| s.as_str()).collect();
285
286            let embeddings = self
287                .model
288                .embed(text_refs, None)
289                .map_err(|e| anyhow!("Failed to generate embeddings: {}", e))?;
290
291            all_embeddings.extend(embeddings);
292        }
293
294        Ok(all_embeddings)
295    }
296
297    /// Embed a single text
298    pub fn embed_one(&mut self, text: &str) -> Result<Vec<f32>> {
299        let embeddings = self.embed_batch(vec![text.to_string()])?;
300        embeddings
301            .into_iter()
302            .next()
303            .ok_or_else(|| anyhow!("No embedding generated"))
304    }
305
306    /// Get the dimensionality of embeddings
307    pub fn dimensions(&self) -> usize {
308        self.model_type.dimensions()
309    }
310
311    /// Get the model name
312    #[allow(dead_code)] // Reserved for diagnostics
313    pub fn model_name(&self) -> &str {
314        self.model_type.name()
315    }
316
317    /// Get the model type
318    #[allow(dead_code)] // Reserved for diagnostics
319    pub fn model_type(&self) -> ModelType {
320        self.model_type
321    }
322}
323
324impl Default for FastEmbedder {
325    fn default() -> Self {
326        Self::new().expect("Failed to create default embedder")
327    }
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_type_dimensions() {
        // 384 dimension models
        assert_eq!(ModelType::BGESmallENV15.dimensions(), 384);
        assert_eq!(ModelType::BGESmallENV15Q.dimensions(), 384);
        assert_eq!(ModelType::AllMiniLML6V2.dimensions(), 384);
        assert_eq!(ModelType::AllMiniLML6V2Q.dimensions(), 384);
        assert_eq!(ModelType::AllMiniLML12V2.dimensions(), 384);
        assert_eq!(ModelType::AllMiniLML12V2Q.dimensions(), 384);
        assert_eq!(ModelType::ParaphraseMLMiniLML12V2.dimensions(), 384);
        assert_eq!(ModelType::MultilingualE5Small.dimensions(), 384);
        // 768 dimension models
        assert_eq!(ModelType::BGEBaseENV15.dimensions(), 768);
        assert_eq!(ModelType::NomicEmbedTextV1.dimensions(), 768);
        assert_eq!(ModelType::NomicEmbedTextV15.dimensions(), 768);
        assert_eq!(ModelType::NomicEmbedTextV15Q.dimensions(), 768);
        assert_eq!(ModelType::JinaEmbeddingsV2BaseCode.dimensions(), 768);
        // 1024 dimension models
        assert_eq!(ModelType::BGELargeENV15.dimensions(), 1024);
        assert_eq!(ModelType::MxbaiEmbedLargeV1.dimensions(), 1024);
        assert_eq!(ModelType::ModernBertEmbedLarge.dimensions(), 1024);
    }

    #[test]
    fn test_model_type_names() {
        assert_eq!(ModelType::BGESmallENV15.name(), "BAAI/bge-small-en-v1.5");
        assert_eq!(
            ModelType::AllMiniLML6V2.name(),
            "sentence-transformers/all-MiniLM-L6-v2"
        );
        assert_eq!(
            ModelType::JinaEmbeddingsV2BaseCode.name(),
            "jinaai/jina-embeddings-v2-base-code"
        );
    }

    #[test]
    fn test_default_model() {
        let model = ModelType::default();
        assert_eq!(model, ModelType::AllMiniLML6V2Q);
        assert_eq!(model.dimensions(), 384);
    }

    #[test]
    fn test_all_models() {
        let all = ModelType::all();
        assert_eq!(all.len(), 16);
    }

    #[test]
    fn test_names_and_short_names_are_unique() {
        // short_name() is used for filenames, so collisions would mix up caches.
        let all = ModelType::all();
        for (i, a) in all.iter().enumerate() {
            for b in &all[i + 1..] {
                assert_ne!(a, b, "duplicate entry in ModelType::all()");
                assert_ne!(a.short_name(), b.short_name(), "{:?} vs {:?}", a, b);
                assert_ne!(a.name(), b.name(), "{:?} vs {:?}", a, b);
            }
        }
    }

    #[test]
    fn test_short_name_round_trips_through_parse() {
        for &model in ModelType::all() {
            assert_eq!(ModelType::parse(model.short_name()), Some(model));
            // parse() is documented as case-insensitive
            assert_eq!(
                ModelType::parse(&model.short_name().to_uppercase()),
                Some(model)
            );
        }
    }

    #[test]
    fn test_parse() {
        assert_eq!(
            ModelType::parse("minilm-l6"),
            Some(ModelType::AllMiniLML6V2)
        );
        assert_eq!(
            ModelType::parse("minilm-l6-q"),
            Some(ModelType::AllMiniLML6V2Q)
        );
        assert_eq!(
            ModelType::parse("minilm-l12"),
            Some(ModelType::AllMiniLML12V2)
        );
        assert_eq!(
            ModelType::parse("minilm-l12-q"),
            Some(ModelType::AllMiniLML12V2Q)
        );
        assert_eq!(
            ModelType::parse("paraphrase-minilm"),
            Some(ModelType::ParaphraseMLMiniLML12V2)
        );
        assert_eq!(
            ModelType::parse("bge-small"),
            Some(ModelType::BGESmallENV15)
        );
        assert_eq!(
            ModelType::parse("bge-small-q"),
            Some(ModelType::BGESmallENV15Q)
        );
        assert_eq!(ModelType::parse("bge-base"), Some(ModelType::BGEBaseENV15));
        assert_eq!(
            ModelType::parse("nomic-v1"),
            Some(ModelType::NomicEmbedTextV1)
        );
        assert_eq!(
            ModelType::parse("nomic-v1.5"),
            Some(ModelType::NomicEmbedTextV15)
        );
        assert_eq!(
            ModelType::parse("nomic-v1.5-q"),
            Some(ModelType::NomicEmbedTextV15Q)
        );
        assert_eq!(
            ModelType::parse("jina-code"),
            Some(ModelType::JinaEmbeddingsV2BaseCode)
        );
        assert_eq!(ModelType::parse("invalid"), None);
        assert_eq!(ModelType::parse(""), None);
    }

    #[test]
    fn test_is_quantized() {
        assert!(ModelType::AllMiniLML6V2Q.is_quantized());
        assert!(ModelType::BGESmallENV15Q.is_quantized());
        assert!(!ModelType::BGESmallENV15.is_quantized());
        assert!(!ModelType::JinaEmbeddingsV2BaseCode.is_quantized());
    }

    #[test]
    fn test_is_quantized_matches_short_name_suffix() {
        // Quantized variants are exactly the ones whose short name ends in "-q".
        for &model in ModelType::all() {
            assert_eq!(
                model.is_quantized(),
                model.short_name().ends_with("-q"),
                "{:?}",
                model
            );
        }
    }

    #[test]
    #[ignore] // Requires downloading model
    fn test_embedder_creation() {
        let embedder = FastEmbedder::new();
        assert!(embedder.is_ok());

        let embedder = embedder.unwrap();
        assert_eq!(embedder.dimensions(), 384);
    }

    #[test]
    #[ignore] // Requires model
    fn test_embed_single_text() {
        let mut embedder = FastEmbedder::new().unwrap();
        let embedding = embedder.embed_one("Hello, world!").unwrap();

        assert_eq!(embedding.len(), 384);
        // Check embedding is normalized (roughly unit length)
        let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((magnitude - 1.0).abs() < 0.1);
    }

    #[test]
    #[ignore] // Requires model
    fn test_embed_batch() {
        let mut embedder = FastEmbedder::new().unwrap();
        let texts = vec![
            "Hello, world!".to_string(),
            "Rust is awesome".to_string(),
            "Code search with AI".to_string(),
        ];

        let embeddings = embedder.embed_batch(texts).unwrap();

        assert_eq!(embeddings.len(), 3);
        for embedding in embeddings {
            assert_eq!(embedding.len(), 384);
        }
    }

    #[test]
    #[ignore] // Requires model
    fn test_semantic_similarity() {
        let mut embedder = FastEmbedder::new().unwrap();

        let text1 = "The quick brown fox jumps over the lazy dog";
        let text2 = "A fast auburn fox leaps over a sleepy canine";
        let text3 = "Python is a programming language";

        let emb1 = embedder.embed_one(text1).unwrap();
        let emb2 = embedder.embed_one(text2).unwrap();
        let emb3 = embedder.embed_one(text3).unwrap();

        // Cosine similarity
        let sim_1_2 = cosine_similarity(&emb1, &emb2);
        let sim_1_3 = cosine_similarity(&emb1, &emb3);

        // Similar texts should have higher similarity
        assert!(sim_1_2 > sim_1_3);
        assert!(sim_1_2 > 0.7); // Should be quite similar
    }

    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let mag_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let mag_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
        dot / (mag_a * mag_b)
    }
}