Skip to main content

inference/
models.rs

//! Model configurations for supported embedding models.
//!
//! Supported models:
//! - **BGE-large** (BAAI/bge-large-en-v1.5): Highest quality, 1024 dimensions (default)
//! - **MiniLM** (all-MiniLM-L6-v2): Fast, 384 dimensions, good for general use
//! - **BGE-small** (BAAI/bge-small-en-v1.5): Balanced, 384 dimensions, high quality
//! - **E5-small** (intfloat/e5-small-v2): Quality-focused, 384 dimensions

9use serde::{Deserialize, Serialize};
10
11/// Supported embedding models.
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
13#[serde(rename_all = "kebab-case")]
14pub enum EmbeddingModel {
15    /// BAAI/bge-large-en-v1.5 - Highest quality, 1024 dimensions (default)
16    /// - Dimensions: 1024
17    /// - Max tokens: 512
18    /// - Speed: Slower than small models, but highest quality
19    #[default]
20    BgeLarge,
21
22    /// all-MiniLM-L6-v2 - Fast and efficient, good for general use
23    /// - Dimensions: 384
24    /// - Max tokens: 256
25    /// - Speed: Fastest
26    MiniLM,
27
28    /// BAAI/bge-small-en-v1.5 - Balanced quality and speed
29    /// - Dimensions: 384
30    /// - Max tokens: 512
31    /// - Speed: Medium
32    BgeSmall,
33
34    /// intfloat/e5-small-v2 - Higher quality embeddings
35    /// - Dimensions: 384
36    /// - Max tokens: 512
37    /// - Speed: Medium
38    E5Small,
39}
40
41impl EmbeddingModel {
42    /// Get the HuggingFace model ID.
43    pub fn model_id(&self) -> &'static str {
44        match self {
45            EmbeddingModel::BgeLarge => "BAAI/bge-large-en-v1.5",
46            EmbeddingModel::MiniLM => "sentence-transformers/all-MiniLM-L6-v2",
47            EmbeddingModel::BgeSmall => "BAAI/bge-small-en-v1.5",
48            EmbeddingModel::E5Small => "intfloat/e5-small-v2",
49        }
50    }
51
52    /// Get the embedding dimension for this model.
53    pub fn dimension(&self) -> usize {
54        match self {
55            EmbeddingModel::BgeLarge => 1024,
56            EmbeddingModel::MiniLM => 384,
57            EmbeddingModel::BgeSmall => 384,
58            EmbeddingModel::E5Small => 384,
59        }
60    }
61
62    /// Get the maximum sequence length (in tokens).
63    pub fn max_seq_length(&self) -> usize {
64        match self {
65            EmbeddingModel::BgeLarge => 512,
66            EmbeddingModel::MiniLM => 256,
67            EmbeddingModel::BgeSmall => 512,
68            EmbeddingModel::E5Small => 512,
69        }
70    }
71
72    /// Get the query prefix for models that require it.
73    /// Some models like E5 require a prefix for queries vs documents.
74    pub fn query_prefix(&self) -> Option<&'static str> {
75        match self {
76            EmbeddingModel::BgeLarge => None,
77            EmbeddingModel::MiniLM => None,
78            EmbeddingModel::BgeSmall => None,
79            EmbeddingModel::E5Small => Some("query: "),
80        }
81    }
82
83    /// Get the document/passage prefix for models that require it.
84    pub fn document_prefix(&self) -> Option<&'static str> {
85        match self {
86            EmbeddingModel::BgeLarge => None,
87            EmbeddingModel::MiniLM => None,
88            EmbeddingModel::BgeSmall => None,
89            EmbeddingModel::E5Small => Some("passage: "),
90        }
91    }
92
93    /// Whether this model uses mean pooling (vs CLS token).
94    pub fn use_mean_pooling(&self) -> bool {
95        match self {
96            EmbeddingModel::BgeLarge => true,
97            EmbeddingModel::MiniLM => true,
98            EmbeddingModel::BgeSmall => true,
99            EmbeddingModel::E5Small => true,
100        }
101    }
102
103    /// Whether embeddings should be normalized.
104    pub fn normalize_embeddings(&self) -> bool {
105        true // All supported models use normalized embeddings
106    }
107
108    /// Get approximate tokens per second on CPU (for estimation).
109    pub fn tokens_per_second_cpu(&self) -> usize {
110        match self {
111            EmbeddingModel::BgeLarge => 1000,
112            EmbeddingModel::MiniLM => 5000,
113            EmbeddingModel::BgeSmall => 3000,
114            EmbeddingModel::E5Small => 3000,
115        }
116    }
117
118    /// Get the HuggingFace repository ID hosting the ONNX INT8 model for this embedding model.
119    ///
120    /// These are Xenova-hosted Optimum ONNX exports — quantized INT8, pre-built, no conversion
121    /// needed. BgeLarge: ~130 MB, MiniLM: 23 MB, BGE-small: 35 MB, E5-small: 35 MB.
122    pub fn onnx_repo_id(&self) -> &'static str {
123        match self {
124            EmbeddingModel::BgeLarge => "Xenova/bge-large-en-v1.5",
125            EmbeddingModel::MiniLM => "Xenova/all-MiniLM-L6-v2",
126            EmbeddingModel::BgeSmall => "Xenova/bge-small-en-v1.5",
127            EmbeddingModel::E5Small => "Xenova/e5-small-v2",
128        }
129    }
130
131    /// Get the ONNX model filename (path within the repository).
132    pub fn onnx_filename(&self) -> &'static str {
133        "onnx/model_quantized.onnx"
134    }
135
136    /// List all available models.
137    pub fn all() -> &'static [EmbeddingModel] {
138        &[
139            EmbeddingModel::BgeLarge,
140            EmbeddingModel::MiniLM,
141            EmbeddingModel::BgeSmall,
142            EmbeddingModel::E5Small,
143        ]
144    }
145
146    /// Parse model from string (case-insensitive).
147    pub fn parse(s: &str) -> Option<Self> {
148        match s.to_lowercase().as_str() {
149            "bge-large" | "bge-large-en" | "bge-large-en-v1.5" => Some(EmbeddingModel::BgeLarge),
150            "minilm" | "all-minilm-l6-v2" | "mini-lm" => Some(EmbeddingModel::MiniLM),
151            "bge-small" | "bge" | "bge-small-en" => Some(EmbeddingModel::BgeSmall),
152            "e5-small" | "e5" | "e5-small-v2" => Some(EmbeddingModel::E5Small),
153            _ => None,
154        }
155    }
156}
157
158impl std::fmt::Display for EmbeddingModel {
159    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160        match self {
161            EmbeddingModel::BgeLarge => write!(f, "bge-large-en-v1.5"),
162            EmbeddingModel::MiniLM => write!(f, "all-MiniLM-L6-v2"),
163            EmbeddingModel::BgeSmall => write!(f, "bge-small-en-v1.5"),
164            EmbeddingModel::E5Small => write!(f, "e5-small-v2"),
165        }
166    }
167}
168
169/// Configuration for model loading and inference.
170#[derive(Debug, Clone, Serialize, Deserialize)]
171pub struct ModelConfig {
172    /// The embedding model to use.
173    pub model: EmbeddingModel,
174
175    /// Custom cache directory for model files.
176    /// If None, uses HuggingFace default cache.
177    pub cache_dir: Option<String>,
178
179    /// Maximum batch size for inference.
180    pub max_batch_size: usize,
181
182    /// Whether to use GPU acceleration if available.
183    pub use_gpu: bool,
184
185    /// Number of threads for CPU inference.
186    pub num_threads: Option<usize>,
187
188    /// Number of parallel ONNX sessions in the session pool.
189    ///
190    /// Each session holds its own ORT context. Pool members serve batches
191    /// concurrently via `spawn_blocking`, eliminating Mutex head-of-line
192    /// blocking when multiple callers embed text simultaneously.
193    /// Defaults to 4; override with `DAKERA_ONNX_POOL_SIZE` env var at startup.
194    pub session_pool_size: usize,
195}
196
197impl Default for ModelConfig {
198    fn default() -> Self {
199        let pool_size = std::env::var("DAKERA_ONNX_POOL_SIZE")
200            .ok()
201            .and_then(|v| v.parse::<usize>().ok())
202            .filter(|&n| n >= 1)
203            .unwrap_or(4);
204        Self {
205            model: EmbeddingModel::default(),
206            cache_dir: None,
207            max_batch_size: 32,
208            use_gpu: false,
209            num_threads: None,
210            session_pool_size: pool_size,
211        }
212    }
213}
214
215impl ModelConfig {
216    /// Create a new config with the specified model.
217    pub fn new(model: EmbeddingModel) -> Self {
218        Self {
219            model,
220            ..Default::default()
221        }
222    }
223
224    /// Set the cache directory.
225    pub fn with_cache_dir(mut self, dir: impl Into<String>) -> Self {
226        self.cache_dir = Some(dir.into());
227        self
228    }
229
230    /// Set the maximum batch size.
231    pub fn with_max_batch_size(mut self, size: usize) -> Self {
232        self.max_batch_size = size;
233        self
234    }
235
236    /// Enable GPU acceleration.
237    pub fn with_gpu(mut self, use_gpu: bool) -> Self {
238        self.use_gpu = use_gpu;
239        self
240    }
241
242    /// Set the number of CPU threads.
243    pub fn with_num_threads(mut self, threads: usize) -> Self {
244        self.num_threads = Some(threads);
245        self
246    }
247
248    /// Set the number of parallel ONNX sessions in the pool.
249    pub fn with_session_pool_size(mut self, size: usize) -> Self {
250        self.session_pool_size = size.max(1);
251        self
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_ids() {
        let cases = [
            (EmbeddingModel::BgeLarge, "BAAI/bge-large-en-v1.5"),
            (EmbeddingModel::MiniLM, "sentence-transformers/all-MiniLM-L6-v2"),
            (EmbeddingModel::BgeSmall, "BAAI/bge-small-en-v1.5"),
            (EmbeddingModel::E5Small, "intfloat/e5-small-v2"),
        ];
        for &(model, expected_id) in &cases {
            assert_eq!(model.model_id(), expected_id, "model_id for {model:?}");
        }
    }

    #[test]
    fn test_dimensions() {
        let cases = [
            (EmbeddingModel::BgeLarge, 1024),
            (EmbeddingModel::MiniLM, 384),
            (EmbeddingModel::BgeSmall, 384),
            (EmbeddingModel::E5Small, 384),
        ];
        for &(model, expected_dim) in &cases {
            assert_eq!(model.dimension(), expected_dim, "dimension for {model:?}");
        }
        // Every listed model must report a usable, non-zero dimension.
        assert!(EmbeddingModel::all().iter().all(|m| m.dimension() > 0));
    }

    #[test]
    fn test_from_str() {
        let cases = [
            ("bge-large", Some(EmbeddingModel::BgeLarge)),
            ("minilm", Some(EmbeddingModel::MiniLM)),
            ("BGE-SMALL", Some(EmbeddingModel::BgeSmall)),
            ("e5", Some(EmbeddingModel::E5Small)),
            ("unknown", None),
        ];
        for &(input, expected) in &cases {
            assert_eq!(EmbeddingModel::parse(input), expected, "parse({input:?})");
        }
    }

    #[test]
    fn test_e5_prefixes() {
        let e5 = EmbeddingModel::E5Small;
        assert_eq!(e5.query_prefix(), Some("query: "));
        assert_eq!(e5.document_prefix(), Some("passage: "));
        assert!(EmbeddingModel::MiniLM.query_prefix().is_none());
    }
}