// inference/models.rs

//! Model configurations for supported embedding models.
//!
//! Supported models:
//! - **MiniLM** (all-MiniLM-L6-v2): Fast, 384 dimensions, good for general use
//! - **BGE-small** (BAAI/bge-small-en-v1.5): Balanced, 384 dimensions, high quality
//! - **E5-small** (intfloat/e5-small-v2): Quality-focused, 384 dimensions; expects
//!   `query: ` / `passage: ` input prefixes

use serde::{Deserialize, Serialize};

10/// Supported embedding models.
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
12#[serde(rename_all = "kebab-case")]
13pub enum EmbeddingModel {
14    /// all-MiniLM-L6-v2 - Fast and efficient, good for general use
15    /// - Dimensions: 384
16    /// - Max tokens: 256
17    /// - Speed: Fastest
18    #[default]
19    MiniLM,
20
21    /// BAAI/bge-small-en-v1.5 - Balanced quality and speed
22    /// - Dimensions: 384
23    /// - Max tokens: 512
24    /// - Speed: Medium
25    BgeSmall,
26
27    /// intfloat/e5-small-v2 - Higher quality embeddings
28    /// - Dimensions: 384
29    /// - Max tokens: 512
30    /// - Speed: Medium
31    E5Small,
32}
33
34impl EmbeddingModel {
35    /// Get the HuggingFace model ID.
36    pub fn model_id(&self) -> &'static str {
37        match self {
38            EmbeddingModel::MiniLM => "sentence-transformers/all-MiniLM-L6-v2",
39            EmbeddingModel::BgeSmall => "BAAI/bge-small-en-v1.5",
40            EmbeddingModel::E5Small => "intfloat/e5-small-v2",
41        }
42    }
43
44    /// Get the embedding dimension for this model.
45    pub fn dimension(&self) -> usize {
46        match self {
47            EmbeddingModel::MiniLM => 384,
48            EmbeddingModel::BgeSmall => 384,
49            EmbeddingModel::E5Small => 384,
50        }
51    }
52
53    /// Get the maximum sequence length (in tokens).
54    pub fn max_seq_length(&self) -> usize {
55        match self {
56            EmbeddingModel::MiniLM => 256,
57            EmbeddingModel::BgeSmall => 512,
58            EmbeddingModel::E5Small => 512,
59        }
60    }
61
62    /// Get the query prefix for models that require it.
63    /// Some models like E5 require a prefix for queries vs documents.
64    pub fn query_prefix(&self) -> Option<&'static str> {
65        match self {
66            EmbeddingModel::MiniLM => None,
67            EmbeddingModel::BgeSmall => None,
68            EmbeddingModel::E5Small => Some("query: "),
69        }
70    }
71
72    /// Get the document/passage prefix for models that require it.
73    pub fn document_prefix(&self) -> Option<&'static str> {
74        match self {
75            EmbeddingModel::MiniLM => None,
76            EmbeddingModel::BgeSmall => None,
77            EmbeddingModel::E5Small => Some("passage: "),
78        }
79    }
80
81    /// Whether this model uses mean pooling (vs CLS token).
82    pub fn use_mean_pooling(&self) -> bool {
83        match self {
84            EmbeddingModel::MiniLM => true,
85            EmbeddingModel::BgeSmall => true,
86            EmbeddingModel::E5Small => true,
87        }
88    }
89
90    /// Whether embeddings should be normalized.
91    pub fn normalize_embeddings(&self) -> bool {
92        true // All supported models use normalized embeddings
93    }
94
95    /// Get approximate tokens per second on CPU (for estimation).
96    pub fn tokens_per_second_cpu(&self) -> usize {
97        match self {
98            EmbeddingModel::MiniLM => 5000,
99            EmbeddingModel::BgeSmall => 3000,
100            EmbeddingModel::E5Small => 3000,
101        }
102    }
103
104    /// Get the HuggingFace repository ID hosting the ONNX INT8 model for this embedding model.
105    ///
106    /// These are Xenova-hosted Optimum ONNX exports — quantized INT8, pre-built, no conversion
107    /// needed. MiniLM: 23 MB, BGE-small: 35 MB, E5-small: 35 MB.
108    pub fn onnx_repo_id(&self) -> &'static str {
109        match self {
110            EmbeddingModel::MiniLM => "Xenova/all-MiniLM-L6-v2",
111            EmbeddingModel::BgeSmall => "Xenova/bge-small-en-v1.5",
112            EmbeddingModel::E5Small => "Xenova/e5-small-v2",
113        }
114    }
115
116    /// Get the ONNX model filename (path within the repository).
117    pub fn onnx_filename(&self) -> &'static str {
118        "onnx/model_quantized.onnx"
119    }
120
121    /// List all available models.
122    pub fn all() -> &'static [EmbeddingModel] {
123        &[
124            EmbeddingModel::MiniLM,
125            EmbeddingModel::BgeSmall,
126            EmbeddingModel::E5Small,
127        ]
128    }
129
130    /// Parse model from string (case-insensitive).
131    pub fn parse(s: &str) -> Option<Self> {
132        match s.to_lowercase().as_str() {
133            "minilm" | "all-minilm-l6-v2" | "mini-lm" => Some(EmbeddingModel::MiniLM),
134            "bge-small" | "bge" | "bge-small-en" => Some(EmbeddingModel::BgeSmall),
135            "e5-small" | "e5" | "e5-small-v2" => Some(EmbeddingModel::E5Small),
136            _ => None,
137        }
138    }
139}
140
141impl std::fmt::Display for EmbeddingModel {
142    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143        match self {
144            EmbeddingModel::MiniLM => write!(f, "all-MiniLM-L6-v2"),
145            EmbeddingModel::BgeSmall => write!(f, "bge-small-en-v1.5"),
146            EmbeddingModel::E5Small => write!(f, "e5-small-v2"),
147        }
148    }
149}
150
151/// Configuration for model loading and inference.
152#[derive(Debug, Clone, Serialize, Deserialize)]
153pub struct ModelConfig {
154    /// The embedding model to use.
155    pub model: EmbeddingModel,
156
157    /// Custom cache directory for model files.
158    /// If None, uses HuggingFace default cache.
159    pub cache_dir: Option<String>,
160
161    /// Maximum batch size for inference.
162    pub max_batch_size: usize,
163
164    /// Whether to use GPU acceleration if available.
165    pub use_gpu: bool,
166
167    /// Number of threads for CPU inference.
168    pub num_threads: Option<usize>,
169}
170
171impl Default for ModelConfig {
172    fn default() -> Self {
173        Self {
174            model: EmbeddingModel::default(),
175            cache_dir: None,
176            max_batch_size: 32,
177            use_gpu: false,
178            num_threads: None,
179        }
180    }
181}
182
183impl ModelConfig {
184    /// Create a new config with the specified model.
185    pub fn new(model: EmbeddingModel) -> Self {
186        Self {
187            model,
188            ..Default::default()
189        }
190    }
191
192    /// Set the cache directory.
193    pub fn with_cache_dir(mut self, dir: impl Into<String>) -> Self {
194        self.cache_dir = Some(dir.into());
195        self
196    }
197
198    /// Set the maximum batch size.
199    pub fn with_max_batch_size(mut self, size: usize) -> Self {
200        self.max_batch_size = size;
201        self
202    }
203
204    /// Enable GPU acceleration.
205    pub fn with_gpu(mut self, use_gpu: bool) -> Self {
206        self.use_gpu = use_gpu;
207        self
208    }
209
210    /// Set the number of CPU threads.
211    pub fn with_num_threads(mut self, threads: usize) -> Self {
212        self.num_threads = Some(threads);
213        self
214    }
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_ids() {
        let expected = [
            (EmbeddingModel::MiniLM, "sentence-transformers/all-MiniLM-L6-v2"),
            (EmbeddingModel::BgeSmall, "BAAI/bge-small-en-v1.5"),
            (EmbeddingModel::E5Small, "intfloat/e5-small-v2"),
        ];
        for (model, id) in expected {
            assert_eq!(model.model_id(), id, "model_id for {model:?}");
        }
    }

    #[test]
    fn test_dimensions() {
        for model in EmbeddingModel::all() {
            assert_eq!(model.dimension(), 384, "dimension for {model:?}");
        }
    }

    #[test]
    fn test_from_str() {
        let cases = [
            ("minilm", Some(EmbeddingModel::MiniLM)),
            ("BGE-SMALL", Some(EmbeddingModel::BgeSmall)),
            ("e5", Some(EmbeddingModel::E5Small)),
            ("unknown", None),
        ];
        for (input, want) in cases {
            assert_eq!(EmbeddingModel::parse(input), want, "parse({input:?})");
        }
    }

    #[test]
    fn test_e5_prefixes() {
        let e5 = EmbeddingModel::E5Small;
        assert_eq!(e5.query_prefix(), Some("query: "));
        assert_eq!(e5.document_prefix(), Some("passage: "));
        assert_eq!(EmbeddingModel::MiniLM.query_prefix(), None);
    }
}