Skip to main content

inference/
models.rs

1//! Model configurations for supported embedding models.
2//!
3//! Supported models:
4//! - **BGE-large** (BAAI/bge-large-en-v1.5): Highest quality, 1024 dimensions (default)
5//! - **MiniLM** (all-MiniLM-L6-v2): Fast, 384 dimensions, good for general use
6//! - **BGE-small** (BAAI/bge-small-en-v1.5): Balanced, 384 dimensions, high quality
7//! - **E5-small** (intfloat/e5-small-v2): Quality-focused, 384 dimensions
8//! - **ModernBERT-embed-base** (nomic-ai/modernbert-embed-base): 768d, 8192 tokens, Flash Attn
9
10use serde::{Deserialize, Serialize};
11
12/// Supported embedding models.
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
14#[serde(rename_all = "kebab-case")]
15pub enum EmbeddingModel {
16    /// BAAI/bge-large-en-v1.5 - Highest quality, 1024 dimensions (default)
17    /// - Dimensions: 1024
18    /// - Max tokens: 512
19    /// - Speed: Slower than small models, but highest quality
20    #[default]
21    BgeLarge,
22
23    /// all-MiniLM-L6-v2 - Fast and efficient, good for general use
24    /// - Dimensions: 384
25    /// - Max tokens: 256
26    /// - Speed: Fastest
27    MiniLM,
28
29    /// BAAI/bge-small-en-v1.5 - Balanced quality and speed
30    /// - Dimensions: 384
31    /// - Max tokens: 512
32    /// - Speed: Medium
33    BgeSmall,
34
35    /// intfloat/e5-small-v2 - Higher quality embeddings
36    /// - Dimensions: 384
37    /// - Max tokens: 512
38    /// - Speed: Medium
39    E5Small,
40
41    /// nomic-ai/modernbert-embed-base — modern transformer with 8192-token context
42    /// - Dimensions: 768 (native Matryoshka: 768/512/256/128/64)
43    /// - Max tokens: 8192
44    /// - Speed: 25% faster than BGE-Large on same hardware
45    /// - Flash Attention support for long sequences
46    /// - Env var: DAKERA_MODEL=modernbert-embed-base
47    ModernBertEmbedBase,
48}
49
50impl EmbeddingModel {
51    /// Get the HuggingFace model ID.
52    pub fn model_id(&self) -> &'static str {
53        match self {
54            EmbeddingModel::BgeLarge => "BAAI/bge-large-en-v1.5",
55            EmbeddingModel::MiniLM => "sentence-transformers/all-MiniLM-L6-v2",
56            EmbeddingModel::BgeSmall => "BAAI/bge-small-en-v1.5",
57            EmbeddingModel::E5Small => "intfloat/e5-small-v2",
58            EmbeddingModel::ModernBertEmbedBase => "nomic-ai/modernbert-embed-base",
59        }
60    }
61
62    /// Get the embedding dimension for this model.
63    pub fn dimension(&self) -> usize {
64        match self {
65            EmbeddingModel::BgeLarge => 1024,
66            EmbeddingModel::MiniLM => 384,
67            EmbeddingModel::BgeSmall => 384,
68            EmbeddingModel::E5Small => 384,
69            EmbeddingModel::ModernBertEmbedBase => 768,
70        }
71    }
72
73    /// Get the maximum sequence length (in tokens).
74    pub fn max_seq_length(&self) -> usize {
75        match self {
76            EmbeddingModel::BgeLarge => 512,
77            EmbeddingModel::MiniLM => 256,
78            EmbeddingModel::BgeSmall => 512,
79            EmbeddingModel::E5Small => 512,
80            EmbeddingModel::ModernBertEmbedBase => 8192,
81        }
82    }
83
84    /// Get the Matryoshka-supported dimensions for this model (smallest to largest).
85    ///
86    /// Returns `None` for models that do not support MRL truncation.
87    pub fn mrl_dimensions(&self) -> Option<&'static [usize]> {
88        match self {
89            EmbeddingModel::ModernBertEmbedBase => Some(&[64, 128, 256, 512, 768]),
90            _ => None,
91        }
92    }
93
94    /// Get the safetensors model filename (for Candle backend).
95    pub fn safetensors_filename(&self) -> &'static str {
96        "model.safetensors"
97    }
98
99    /// Get the config filename (for Candle/GGUF backends).
100    pub fn config_filename(&self) -> &'static str {
101        "config.json"
102    }
103
104    /// Get the HuggingFace repo hosting the Model2Vec distilled vocabulary matrix.
105    pub fn model2vec_repo_id(&self) -> &'static str {
106        match self {
107            EmbeddingModel::BgeLarge => "dakera-ai/bge-large-model2vec-256d",
108            EmbeddingModel::ModernBertEmbedBase => "dakera-ai/modernbert-model2vec-256d",
109            _ => "dakera-ai/bge-small-model2vec-256d",
110        }
111    }
112
113    /// Get the HuggingFace repo hosting the GGUF quantised models for this embedding.
114    pub fn gguf_repo_id(&self) -> &'static str {
115        match self {
116            EmbeddingModel::BgeLarge => "dakera-ai/bge-large-gguf",
117            EmbeddingModel::ModernBertEmbedBase => "dakera-ai/modernbert-gguf",
118            _ => "dakera-ai/bge-small-gguf",
119        }
120    }
121
122    /// Get the query prefix for models that require it.
123    /// Some models like E5 require a prefix for queries vs documents.
124    pub fn query_prefix(&self) -> Option<&'static str> {
125        match self {
126            EmbeddingModel::BgeLarge => None,
127            EmbeddingModel::MiniLM => None,
128            EmbeddingModel::BgeSmall => None,
129            EmbeddingModel::E5Small => Some("query: "),
130            EmbeddingModel::ModernBertEmbedBase => None,
131        }
132    }
133
134    /// Get the document/passage prefix for models that require it.
135    pub fn document_prefix(&self) -> Option<&'static str> {
136        match self {
137            EmbeddingModel::BgeLarge => None,
138            EmbeddingModel::MiniLM => None,
139            EmbeddingModel::BgeSmall => None,
140            EmbeddingModel::E5Small => Some("passage: "),
141            EmbeddingModel::ModernBertEmbedBase => None,
142        }
143    }
144
145    /// Whether this model uses mean pooling (vs CLS token).
146    pub fn use_mean_pooling(&self) -> bool {
147        match self {
148            EmbeddingModel::BgeLarge => true,
149            EmbeddingModel::MiniLM => true,
150            EmbeddingModel::BgeSmall => true,
151            EmbeddingModel::E5Small => true,
152            EmbeddingModel::ModernBertEmbedBase => true,
153        }
154    }
155
156    /// Whether embeddings should be normalized.
157    pub fn normalize_embeddings(&self) -> bool {
158        true // All supported models use normalized embeddings
159    }
160
161    /// Get approximate tokens per second on CPU (for estimation).
162    pub fn tokens_per_second_cpu(&self) -> usize {
163        match self {
164            EmbeddingModel::BgeLarge => 1000,
165            EmbeddingModel::MiniLM => 5000,
166            EmbeddingModel::BgeSmall => 3000,
167            EmbeddingModel::E5Small => 3000,
168            EmbeddingModel::ModernBertEmbedBase => 1250, // ~25% faster than BGE-Large
169        }
170    }
171
172    /// Get the HuggingFace repository ID hosting the ONNX INT8 model for this embedding model.
173    ///
174    /// These are Xenova-hosted Optimum ONNX exports — quantized INT8, pre-built, no conversion
175    /// needed. BgeLarge: ~130 MB, MiniLM: 23 MB, BGE-small: 35 MB, E5-small: 35 MB.
176    pub fn onnx_repo_id(&self) -> &'static str {
177        match self {
178            EmbeddingModel::BgeLarge => "Xenova/bge-large-en-v1.5",
179            EmbeddingModel::MiniLM => "Xenova/all-MiniLM-L6-v2",
180            EmbeddingModel::BgeSmall => "Xenova/bge-small-en-v1.5",
181            EmbeddingModel::E5Small => "Xenova/e5-small-v2",
182            EmbeddingModel::ModernBertEmbedBase => "Xenova/modernbert-embed-base",
183        }
184    }
185
186    /// Get the ONNX model filename for CPU inference (INT8 quantized).
187    pub fn onnx_filename(&self) -> &'static str {
188        "onnx/model_quantized.onnx"
189    }
190
191    /// Get the ONNX model filename for GPU (CUDA EP) inference.
192    ///
193    /// Returns the FP32 model (`onnx/model.onnx`) instead of INT8. The INT8 quantized
194    /// model has 336 Memcpy CPU↔GPU round-trips caused by ORT falling back to CPU EP
195    /// for every unsupported INT8 op — making CUDA 24× slower than pure CPU inference.
196    /// The FP32 model contains no unsupported ops and runs entirely on-device.
197    pub fn onnx_filename_gpu(&self) -> &'static str {
198        "onnx/model.onnx"
199    }
200
201    /// List all available models.
202    pub fn all() -> &'static [EmbeddingModel] {
203        &[
204            EmbeddingModel::BgeLarge,
205            EmbeddingModel::MiniLM,
206            EmbeddingModel::BgeSmall,
207            EmbeddingModel::E5Small,
208            EmbeddingModel::ModernBertEmbedBase,
209        ]
210    }
211
212    /// Parse model from string (case-insensitive).
213    pub fn parse(s: &str) -> Option<Self> {
214        match s.to_lowercase().as_str() {
215            "bge-large" | "bge-large-en" | "bge-large-en-v1.5" => Some(EmbeddingModel::BgeLarge),
216            "minilm" | "all-minilm-l6-v2" | "mini-lm" => Some(EmbeddingModel::MiniLM),
217            "bge-small" | "bge" | "bge-small-en" => Some(EmbeddingModel::BgeSmall),
218            "e5-small" | "e5" | "e5-small-v2" => Some(EmbeddingModel::E5Small),
219            "modernbert-embed-base" | "modernbert" | "modern-bert" => {
220                Some(EmbeddingModel::ModernBertEmbedBase)
221            }
222            _ => None,
223        }
224    }
225
226    /// Get the active model from `DAKERA_MODEL` env var, defaulting to `BgeLarge`.
227    pub fn from_env() -> Self {
228        std::env::var("DAKERA_MODEL")
229            .ok()
230            .as_deref()
231            .and_then(Self::parse)
232            .unwrap_or_default()
233    }
234}
235
236impl std::fmt::Display for EmbeddingModel {
237    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238        match self {
239            EmbeddingModel::BgeLarge => write!(f, "bge-large-en-v1.5"),
240            EmbeddingModel::MiniLM => write!(f, "all-MiniLM-L6-v2"),
241            EmbeddingModel::BgeSmall => write!(f, "bge-small-en-v1.5"),
242            EmbeddingModel::E5Small => write!(f, "e5-small-v2"),
243            EmbeddingModel::ModernBertEmbedBase => write!(f, "modernbert-embed-base"),
244        }
245    }
246}
247
248/// Configuration for model loading and inference.
249#[derive(Debug, Clone, Serialize, Deserialize)]
250pub struct ModelConfig {
251    /// The embedding model to use.
252    pub model: EmbeddingModel,
253
254    /// Custom cache directory for model files.
255    /// If None, uses HuggingFace default cache.
256    pub cache_dir: Option<String>,
257
258    /// Maximum batch size for inference.
259    pub max_batch_size: usize,
260
261    /// Whether to use GPU acceleration if available.
262    pub use_gpu: bool,
263
264    /// Number of threads for CPU inference.
265    pub num_threads: Option<usize>,
266
267    /// Number of parallel ONNX sessions in the session pool.
268    ///
269    /// Each session holds its own ORT context. Pool members serve batches
270    /// concurrently via `spawn_blocking`, eliminating Mutex head-of-line
271    /// blocking when multiple callers embed text simultaneously.
272    /// Defaults to 4; override with `DAKERA_ONNX_POOL_SIZE` env var at startup.
273    pub session_pool_size: usize,
274
275    /// Force a specific backend kind, bypassing `DAKERA_BACKEND` env var.
276    /// Used by `TieredEngine` to build the fast (static) backend from the same config.
277    #[serde(skip)]
278    pub backend_override: Option<crate::backend::BackendKind>,
279}
280
281impl Default for ModelConfig {
282    fn default() -> Self {
283        // DAK-5746: pool=4 restored. PR#488 regressed LME ingest: pool=1 serializes all
284        // ONNX calls onto session[0]. With 4 concurrent HTTP requests × 7 sub-batches each,
285        // pool=1 produces ~28 serial ONNX calls vs pool=4's 7 parallel chains — ~4× throughput
286        // regression measured at 2761ms/50-text batch on prod. OOM root causes (unbounded HNSW,
287        // RocksDB cache) fixed by PR#488 other changes; pool=4 × BGE-Large INT8 ≈ 1.6GB fits
288        // comfortably on the 8GB server. pool_size 4→1 downgrade was the wrong OOM fix.
289        let pool_size = std::env::var("DAKERA_ONNX_POOL_SIZE")
290            .ok()
291            .and_then(|v| v.parse::<usize>().ok())
292            .filter(|&n| n >= 1)
293            .unwrap_or(4);
294        // DAK-5716: no length-sorting (PR#476 proved sorted batching regresses INT8
295        // quantization quality). DAK-5953: default raised 8→32 — amortises per-call ONNX
296        // overhead 4× with no quality impact (size-only change, not order). Bench sets
297        // DAKERA_ONNX_BATCH_SIZE=128; 32 is a safe default for CPU-only deployments.
298        let max_batch_size = std::env::var("DAKERA_ONNX_BATCH_SIZE")
299            .ok()
300            .and_then(|v| v.parse::<usize>().ok())
301            .filter(|&n| n >= 1)
302            .unwrap_or(32);
303        Self {
304            model: EmbeddingModel::default(),
305            cache_dir: None,
306            max_batch_size,
307            use_gpu: false,
308            num_threads: None,
309            session_pool_size: pool_size,
310            backend_override: None,
311        }
312    }
313}
314
315impl ModelConfig {
316    /// Create a new config with the specified model.
317    pub fn new(model: EmbeddingModel) -> Self {
318        Self {
319            model,
320            ..Default::default()
321        }
322    }
323
324    /// Set the cache directory.
325    pub fn with_cache_dir(mut self, dir: impl Into<String>) -> Self {
326        self.cache_dir = Some(dir.into());
327        self
328    }
329
330    /// Set the maximum batch size.
331    pub fn with_max_batch_size(mut self, size: usize) -> Self {
332        self.max_batch_size = size;
333        self
334    }
335
336    /// Enable GPU acceleration.
337    pub fn with_gpu(mut self, use_gpu: bool) -> Self {
338        self.use_gpu = use_gpu;
339        self
340    }
341
342    /// Set the number of CPU threads.
343    pub fn with_num_threads(mut self, threads: usize) -> Self {
344        self.num_threads = Some(threads);
345        self
346    }
347
348    /// Set the number of parallel ONNX sessions in the pool.
349    pub fn with_session_pool_size(mut self, size: usize) -> Self {
350        self.session_pool_size = size.max(1);
351        self
352    }
353}
354
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_ids() {
        assert_eq!(
            EmbeddingModel::BgeLarge.model_id(),
            "BAAI/bge-large-en-v1.5"
        );
        assert_eq!(
            EmbeddingModel::MiniLM.model_id(),
            "sentence-transformers/all-MiniLM-L6-v2"
        );
        assert_eq!(
            EmbeddingModel::BgeSmall.model_id(),
            "BAAI/bge-small-en-v1.5"
        );
        assert_eq!(EmbeddingModel::E5Small.model_id(), "intfloat/e5-small-v2");
    }

    #[test]
    fn test_dimensions() {
        assert_eq!(EmbeddingModel::BgeLarge.dimension(), 1024);
        assert_eq!(EmbeddingModel::MiniLM.dimension(), 384);
        assert_eq!(EmbeddingModel::BgeSmall.dimension(), 384);
        assert_eq!(EmbeddingModel::E5Small.dimension(), 384);
        for model in EmbeddingModel::all() {
            assert!(model.dimension() > 0, "{model} reports a zero dimension");
            assert!(model.max_seq_length() > 0, "{model} reports a zero max_seq_length");
        }
    }

    #[test]
    fn test_mrl_dimensions_consistent_with_dimension() {
        // Models advertising MRL truncation sizes must list them strictly ascending,
        // topping out at the model's native dimension.
        for model in EmbeddingModel::all() {
            if let Some(dims) = model.mrl_dimensions() {
                assert!(
                    dims.windows(2).all(|w| w[0] < w[1]),
                    "{model}: MRL dimensions are not strictly ascending"
                );
                assert_eq!(
                    dims.last().copied(),
                    Some(model.dimension()),
                    "{model}: largest MRL dimension differs from native dimension"
                );
            }
        }
    }

    #[test]
    fn test_all_lists_each_model_once() {
        let all = EmbeddingModel::all();
        for (i, model) in all.iter().enumerate() {
            assert!(!all[i + 1..].contains(model), "{model} is listed twice");
        }
    }

    #[test]
    fn test_from_str() {
        assert_eq!(
            EmbeddingModel::parse("bge-large"),
            Some(EmbeddingModel::BgeLarge)
        );
        assert_eq!(
            EmbeddingModel::parse("minilm"),
            Some(EmbeddingModel::MiniLM)
        );
        assert_eq!(
            EmbeddingModel::parse("BGE-SMALL"),
            Some(EmbeddingModel::BgeSmall)
        );
        assert_eq!(EmbeddingModel::parse("e5"), Some(EmbeddingModel::E5Small));
        assert_eq!(EmbeddingModel::parse("unknown"), None);
    }

    #[test]
    fn test_e5_prefixes() {
        assert_eq!(EmbeddingModel::E5Small.query_prefix(), Some("query: "));
        assert_eq!(EmbeddingModel::E5Small.document_prefix(), Some("passage: "));
        assert_eq!(EmbeddingModel::MiniLM.query_prefix(), None);
    }

    #[test]
    fn test_prefixes_come_in_pairs() {
        // A model that prefixes queries must also prefix documents (and vice versa).
        for model in EmbeddingModel::all() {
            assert_eq!(
                model.query_prefix().is_some(),
                model.document_prefix().is_some(),
                "{model}: query/document prefixes are unpaired"
            );
        }
    }

    #[test]
    fn test_onnx_filenames() {
        // INT8 model for CPU — all models use the same quantized file
        for model in EmbeddingModel::all() {
            assert_eq!(model.onnx_filename(), "onnx/model_quantized.onnx");
        }
        // FP32 model for GPU — no Memcpy fallback ops
        for model in EmbeddingModel::all() {
            assert_eq!(model.onnx_filename_gpu(), "onnx/model.onnx");
        }
        // Sanity: GPU and CPU filenames are distinct
        assert_ne!(
            EmbeddingModel::BgeLarge.onnx_filename(),
            EmbeddingModel::BgeLarge.onnx_filename_gpu()
        );
    }

    // ── ModernBERT-specific tests ────────────────────────────────────────────

    #[test]
    fn test_modernbert_model_id() {
        assert_eq!(
            EmbeddingModel::ModernBertEmbedBase.model_id(),
            "nomic-ai/modernbert-embed-base"
        );
    }

    #[test]
    fn test_modernbert_dimension_768() {
        assert_eq!(EmbeddingModel::ModernBertEmbedBase.dimension(), 768);
    }

    #[test]
    fn test_modernbert_max_tokens_8192() {
        assert_eq!(EmbeddingModel::ModernBertEmbedBase.max_seq_length(), 8192);
    }

    #[test]
    fn test_modernbert_mrl_dimensions() {
        let dims = EmbeddingModel::ModernBertEmbedBase
            .mrl_dimensions()
            .expect("ModernBERT advertises MRL truncation sizes");
        assert!(dims.contains(&256));
        assert!(dims.contains(&768));
    }

    #[test]
    fn test_modernbert_no_prefix() {
        assert!(EmbeddingModel::ModernBertEmbedBase.query_prefix().is_none());
        assert!(EmbeddingModel::ModernBertEmbedBase
            .document_prefix()
            .is_none());
    }

    #[test]
    fn test_modernbert_parse() {
        assert_eq!(
            EmbeddingModel::parse("modernbert-embed-base"),
            Some(EmbeddingModel::ModernBertEmbedBase)
        );
        assert_eq!(
            EmbeddingModel::parse("modernbert"),
            Some(EmbeddingModel::ModernBertEmbedBase)
        );
        assert_eq!(
            EmbeddingModel::parse("MODERNBERT"),
            Some(EmbeddingModel::ModernBertEmbedBase)
        );
    }

    #[test]
    fn test_modernbert_display() {
        assert_eq!(
            EmbeddingModel::ModernBertEmbedBase.to_string(),
            "modernbert-embed-base"
        );
    }

    #[test]
    fn test_bge_large_no_mrl_dimensions() {
        assert!(EmbeddingModel::BgeLarge.mrl_dimensions().is_none());
    }

    #[test]
    fn test_safetensors_config_filenames() {
        for model in EmbeddingModel::all() {
            assert_eq!(model.safetensors_filename(), "model.safetensors");
            assert_eq!(model.config_filename(), "config.json");
        }
    }

    #[test]
    fn test_model2vec_repo_id_bge_large() {
        assert_eq!(
            EmbeddingModel::BgeLarge.model2vec_repo_id(),
            "dakera-ai/bge-large-model2vec-256d"
        );
    }

    #[test]
    fn test_gguf_repo_id_bge_large() {
        assert_eq!(
            EmbeddingModel::BgeLarge.gguf_repo_id(),
            "dakera-ai/bge-large-gguf"
        );
    }

    #[test]
    fn test_repo_ids_modernbert_and_small_fallback() {
        assert_eq!(
            EmbeddingModel::ModernBertEmbedBase.model2vec_repo_id(),
            "dakera-ai/modernbert-model2vec-256d"
        );
        assert_eq!(
            EmbeddingModel::ModernBertEmbedBase.gguf_repo_id(),
            "dakera-ai/modernbert-gguf"
        );
        // The 384-d models share the BGE-small artefacts.
        for model in [
            EmbeddingModel::MiniLM,
            EmbeddingModel::BgeSmall,
            EmbeddingModel::E5Small,
        ] {
            assert_eq!(
                model.model2vec_repo_id(),
                "dakera-ai/bge-small-model2vec-256d"
            );
            assert_eq!(model.gguf_repo_id(), "dakera-ai/bge-small-gguf");
        }
    }
}