1use serde::{Deserialize, Serialize};
10
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
13#[serde(rename_all = "kebab-case")]
14pub enum EmbeddingModel {
15 #[default]
20 BgeLarge,
21
22 MiniLM,
27
28 BgeSmall,
33
34 E5Small,
39}
40
41impl EmbeddingModel {
42 pub fn model_id(&self) -> &'static str {
44 match self {
45 EmbeddingModel::BgeLarge => "BAAI/bge-large-en-v1.5",
46 EmbeddingModel::MiniLM => "sentence-transformers/all-MiniLM-L6-v2",
47 EmbeddingModel::BgeSmall => "BAAI/bge-small-en-v1.5",
48 EmbeddingModel::E5Small => "intfloat/e5-small-v2",
49 }
50 }
51
52 pub fn dimension(&self) -> usize {
54 match self {
55 EmbeddingModel::BgeLarge => 1024,
56 EmbeddingModel::MiniLM => 384,
57 EmbeddingModel::BgeSmall => 384,
58 EmbeddingModel::E5Small => 384,
59 }
60 }
61
62 pub fn max_seq_length(&self) -> usize {
64 match self {
65 EmbeddingModel::BgeLarge => 512,
66 EmbeddingModel::MiniLM => 256,
67 EmbeddingModel::BgeSmall => 512,
68 EmbeddingModel::E5Small => 512,
69 }
70 }
71
72 pub fn query_prefix(&self) -> Option<&'static str> {
75 match self {
76 EmbeddingModel::BgeLarge => None,
77 EmbeddingModel::MiniLM => None,
78 EmbeddingModel::BgeSmall => None,
79 EmbeddingModel::E5Small => Some("query: "),
80 }
81 }
82
83 pub fn document_prefix(&self) -> Option<&'static str> {
85 match self {
86 EmbeddingModel::BgeLarge => None,
87 EmbeddingModel::MiniLM => None,
88 EmbeddingModel::BgeSmall => None,
89 EmbeddingModel::E5Small => Some("passage: "),
90 }
91 }
92
93 pub fn use_mean_pooling(&self) -> bool {
95 match self {
96 EmbeddingModel::BgeLarge => true,
97 EmbeddingModel::MiniLM => true,
98 EmbeddingModel::BgeSmall => true,
99 EmbeddingModel::E5Small => true,
100 }
101 }
102
103 pub fn normalize_embeddings(&self) -> bool {
105 true }
107
108 pub fn tokens_per_second_cpu(&self) -> usize {
110 match self {
111 EmbeddingModel::BgeLarge => 1000,
112 EmbeddingModel::MiniLM => 5000,
113 EmbeddingModel::BgeSmall => 3000,
114 EmbeddingModel::E5Small => 3000,
115 }
116 }
117
118 pub fn onnx_repo_id(&self) -> &'static str {
123 match self {
124 EmbeddingModel::BgeLarge => "Xenova/bge-large-en-v1.5",
125 EmbeddingModel::MiniLM => "Xenova/all-MiniLM-L6-v2",
126 EmbeddingModel::BgeSmall => "Xenova/bge-small-en-v1.5",
127 EmbeddingModel::E5Small => "Xenova/e5-small-v2",
128 }
129 }
130
131 pub fn onnx_filename(&self) -> &'static str {
133 "onnx/model_quantized.onnx"
134 }
135
136 pub fn onnx_filename_gpu(&self) -> &'static str {
143 "onnx/model.onnx"
144 }
145
146 pub fn all() -> &'static [EmbeddingModel] {
148 &[
149 EmbeddingModel::BgeLarge,
150 EmbeddingModel::MiniLM,
151 EmbeddingModel::BgeSmall,
152 EmbeddingModel::E5Small,
153 ]
154 }
155
156 pub fn parse(s: &str) -> Option<Self> {
158 match s.to_lowercase().as_str() {
159 "bge-large" | "bge-large-en" | "bge-large-en-v1.5" => Some(EmbeddingModel::BgeLarge),
160 "minilm" | "all-minilm-l6-v2" | "mini-lm" => Some(EmbeddingModel::MiniLM),
161 "bge-small" | "bge" | "bge-small-en" => Some(EmbeddingModel::BgeSmall),
162 "e5-small" | "e5" | "e5-small-v2" => Some(EmbeddingModel::E5Small),
163 _ => None,
164 }
165 }
166}
167
168impl std::fmt::Display for EmbeddingModel {
169 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
170 match self {
171 EmbeddingModel::BgeLarge => write!(f, "bge-large-en-v1.5"),
172 EmbeddingModel::MiniLM => write!(f, "all-MiniLM-L6-v2"),
173 EmbeddingModel::BgeSmall => write!(f, "bge-small-en-v1.5"),
174 EmbeddingModel::E5Small => write!(f, "e5-small-v2"),
175 }
176 }
177}
178
179#[derive(Debug, Clone, Serialize, Deserialize)]
181pub struct ModelConfig {
182 pub model: EmbeddingModel,
184
185 pub cache_dir: Option<String>,
188
189 pub max_batch_size: usize,
191
192 pub use_gpu: bool,
194
195 pub num_threads: Option<usize>,
197
198 pub session_pool_size: usize,
205}
206
207impl Default for ModelConfig {
208 fn default() -> Self {
209 let pool_size = std::env::var("DAKERA_ONNX_POOL_SIZE")
216 .ok()
217 .and_then(|v| v.parse::<usize>().ok())
218 .filter(|&n| n >= 1)
219 .unwrap_or(4);
220 let max_batch_size = std::env::var("DAKERA_ONNX_BATCH_SIZE")
225 .ok()
226 .and_then(|v| v.parse::<usize>().ok())
227 .filter(|&n| n >= 1)
228 .unwrap_or(32);
229 Self {
230 model: EmbeddingModel::default(),
231 cache_dir: None,
232 max_batch_size,
233 use_gpu: false,
234 num_threads: None,
235 session_pool_size: pool_size,
236 }
237 }
238}
239
240impl ModelConfig {
241 pub fn new(model: EmbeddingModel) -> Self {
243 Self {
244 model,
245 ..Default::default()
246 }
247 }
248
249 pub fn with_cache_dir(mut self, dir: impl Into<String>) -> Self {
251 self.cache_dir = Some(dir.into());
252 self
253 }
254
255 pub fn with_max_batch_size(mut self, size: usize) -> Self {
257 self.max_batch_size = size;
258 self
259 }
260
261 pub fn with_gpu(mut self, use_gpu: bool) -> Self {
263 self.use_gpu = use_gpu;
264 self
265 }
266
267 pub fn with_num_threads(mut self, threads: usize) -> Self {
269 self.num_threads = Some(threads);
270 self
271 }
272
273 pub fn with_session_pool_size(mut self, size: usize) -> Self {
275 self.session_pool_size = size.max(1);
276 self
277 }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// Each variant reports its expected upstream model identifier.
    #[test]
    fn test_model_ids() {
        let expected = [
            (EmbeddingModel::BgeLarge, "BAAI/bge-large-en-v1.5"),
            (
                EmbeddingModel::MiniLM,
                "sentence-transformers/all-MiniLM-L6-v2",
            ),
            (EmbeddingModel::BgeSmall, "BAAI/bge-small-en-v1.5"),
            (EmbeddingModel::E5Small, "intfloat/e5-small-v2"),
        ];
        for &(model, id) in expected.iter() {
            assert_eq!(model.model_id(), id);
        }
    }

    /// Dimensions match the known values and are never zero.
    #[test]
    fn test_dimensions() {
        let expected = [
            (EmbeddingModel::BgeLarge, 1024),
            (EmbeddingModel::MiniLM, 384),
            (EmbeddingModel::BgeSmall, 384),
            (EmbeddingModel::E5Small, 384),
        ];
        for &(model, dim) in expected.iter() {
            assert_eq!(model.dimension(), dim);
        }
        assert!(EmbeddingModel::all().iter().all(|m| m.dimension() > 0));
    }

    /// `parse` resolves aliases case-insensitively and rejects unknown names.
    #[test]
    fn test_from_str() {
        let cases = [
            ("bge-large", Some(EmbeddingModel::BgeLarge)),
            ("minilm", Some(EmbeddingModel::MiniLM)),
            ("BGE-SMALL", Some(EmbeddingModel::BgeSmall)),
            ("e5", Some(EmbeddingModel::E5Small)),
            ("unknown", None),
        ];
        for &(input, want) in cases.iter() {
            assert_eq!(EmbeddingModel::parse(input), want);
        }
    }

    /// E5 carries query/passage prefixes; MiniLM has no query prefix.
    #[test]
    fn test_e5_prefixes() {
        let e5 = EmbeddingModel::E5Small;
        assert_eq!(e5.query_prefix(), Some("query: "));
        assert_eq!(e5.document_prefix(), Some("passage: "));
        assert_eq!(EmbeddingModel::MiniLM.query_prefix(), None);
    }

    /// Every model uses the same two ONNX paths, and the two paths differ.
    #[test]
    fn test_onnx_filenames() {
        for model in EmbeddingModel::all() {
            assert_eq!(model.onnx_filename(), "onnx/model_quantized.onnx");
            assert_eq!(model.onnx_filename_gpu(), "onnx/model.onnx");
        }
        let large = EmbeddingModel::BgeLarge;
        assert_ne!(large.onnx_filename(), large.onnx_filename_gpu());
    }
}