1use serde::{Deserialize, Serialize};
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
14#[serde(rename_all = "kebab-case")]
15pub enum EmbeddingModel {
16 #[default]
21 BgeLarge,
22
23 MiniLM,
28
29 BgeSmall,
34
35 E5Small,
40
41 ModernBertEmbedBase,
48}
49
50impl EmbeddingModel {
51 pub fn model_id(&self) -> &'static str {
53 match self {
54 EmbeddingModel::BgeLarge => "BAAI/bge-large-en-v1.5",
55 EmbeddingModel::MiniLM => "sentence-transformers/all-MiniLM-L6-v2",
56 EmbeddingModel::BgeSmall => "BAAI/bge-small-en-v1.5",
57 EmbeddingModel::E5Small => "intfloat/e5-small-v2",
58 EmbeddingModel::ModernBertEmbedBase => "nomic-ai/modernbert-embed-base",
59 }
60 }
61
62 pub fn dimension(&self) -> usize {
64 match self {
65 EmbeddingModel::BgeLarge => 1024,
66 EmbeddingModel::MiniLM => 384,
67 EmbeddingModel::BgeSmall => 384,
68 EmbeddingModel::E5Small => 384,
69 EmbeddingModel::ModernBertEmbedBase => 768,
70 }
71 }
72
73 pub fn max_seq_length(&self) -> usize {
75 match self {
76 EmbeddingModel::BgeLarge => 512,
77 EmbeddingModel::MiniLM => 256,
78 EmbeddingModel::BgeSmall => 512,
79 EmbeddingModel::E5Small => 512,
80 EmbeddingModel::ModernBertEmbedBase => 8192,
81 }
82 }
83
84 pub fn mrl_dimensions(&self) -> Option<&'static [usize]> {
88 match self {
89 EmbeddingModel::ModernBertEmbedBase => Some(&[64, 128, 256, 512, 768]),
90 _ => None,
91 }
92 }
93
94 pub fn safetensors_filename(&self) -> &'static str {
96 "model.safetensors"
97 }
98
99 pub fn config_filename(&self) -> &'static str {
101 "config.json"
102 }
103
104 pub fn model2vec_repo_id(&self) -> &'static str {
106 match self {
107 EmbeddingModel::BgeLarge => "dakera-ai/bge-large-model2vec-256d",
108 EmbeddingModel::ModernBertEmbedBase => "dakera-ai/modernbert-model2vec-256d",
109 _ => "dakera-ai/bge-small-model2vec-256d",
110 }
111 }
112
113 pub fn gguf_repo_id(&self) -> &'static str {
115 match self {
116 EmbeddingModel::BgeLarge => "dakera-ai/bge-large-gguf",
117 EmbeddingModel::ModernBertEmbedBase => "dakera-ai/modernbert-gguf",
118 _ => "dakera-ai/bge-small-gguf",
119 }
120 }
121
122 pub fn query_prefix(&self) -> Option<&'static str> {
125 match self {
126 EmbeddingModel::BgeLarge => None,
127 EmbeddingModel::MiniLM => None,
128 EmbeddingModel::BgeSmall => None,
129 EmbeddingModel::E5Small => Some("query: "),
130 EmbeddingModel::ModernBertEmbedBase => None,
131 }
132 }
133
134 pub fn document_prefix(&self) -> Option<&'static str> {
136 match self {
137 EmbeddingModel::BgeLarge => None,
138 EmbeddingModel::MiniLM => None,
139 EmbeddingModel::BgeSmall => None,
140 EmbeddingModel::E5Small => Some("passage: "),
141 EmbeddingModel::ModernBertEmbedBase => None,
142 }
143 }
144
145 pub fn use_mean_pooling(&self) -> bool {
147 match self {
148 EmbeddingModel::BgeLarge => true,
149 EmbeddingModel::MiniLM => true,
150 EmbeddingModel::BgeSmall => true,
151 EmbeddingModel::E5Small => true,
152 EmbeddingModel::ModernBertEmbedBase => true,
153 }
154 }
155
156 pub fn normalize_embeddings(&self) -> bool {
158 true }
160
161 pub fn tokens_per_second_cpu(&self) -> usize {
163 match self {
164 EmbeddingModel::BgeLarge => 1000,
165 EmbeddingModel::MiniLM => 5000,
166 EmbeddingModel::BgeSmall => 3000,
167 EmbeddingModel::E5Small => 3000,
168 EmbeddingModel::ModernBertEmbedBase => 1250, }
170 }
171
172 pub fn onnx_repo_id(&self) -> &'static str {
177 match self {
178 EmbeddingModel::BgeLarge => "Xenova/bge-large-en-v1.5",
179 EmbeddingModel::MiniLM => "Xenova/all-MiniLM-L6-v2",
180 EmbeddingModel::BgeSmall => "Xenova/bge-small-en-v1.5",
181 EmbeddingModel::E5Small => "Xenova/e5-small-v2",
182 EmbeddingModel::ModernBertEmbedBase => "Xenova/modernbert-embed-base",
183 }
184 }
185
186 pub fn onnx_filename(&self) -> &'static str {
188 "onnx/model_quantized.onnx"
189 }
190
191 pub fn onnx_filename_gpu(&self) -> &'static str {
198 "onnx/model.onnx"
199 }
200
201 pub fn all() -> &'static [EmbeddingModel] {
203 &[
204 EmbeddingModel::BgeLarge,
205 EmbeddingModel::MiniLM,
206 EmbeddingModel::BgeSmall,
207 EmbeddingModel::E5Small,
208 EmbeddingModel::ModernBertEmbedBase,
209 ]
210 }
211
212 pub fn parse(s: &str) -> Option<Self> {
214 match s.to_lowercase().as_str() {
215 "bge-large" | "bge-large-en" | "bge-large-en-v1.5" => Some(EmbeddingModel::BgeLarge),
216 "minilm" | "all-minilm-l6-v2" | "mini-lm" => Some(EmbeddingModel::MiniLM),
217 "bge-small" | "bge" | "bge-small-en" => Some(EmbeddingModel::BgeSmall),
218 "e5-small" | "e5" | "e5-small-v2" => Some(EmbeddingModel::E5Small),
219 "modernbert-embed-base" | "modernbert" | "modern-bert" => {
220 Some(EmbeddingModel::ModernBertEmbedBase)
221 }
222 _ => None,
223 }
224 }
225
226 pub fn from_env() -> Self {
228 std::env::var("DAKERA_MODEL")
229 .ok()
230 .as_deref()
231 .and_then(Self::parse)
232 .unwrap_or_default()
233 }
234}
235
236impl std::fmt::Display for EmbeddingModel {
237 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238 match self {
239 EmbeddingModel::BgeLarge => write!(f, "bge-large-en-v1.5"),
240 EmbeddingModel::MiniLM => write!(f, "all-MiniLM-L6-v2"),
241 EmbeddingModel::BgeSmall => write!(f, "bge-small-en-v1.5"),
242 EmbeddingModel::E5Small => write!(f, "e5-small-v2"),
243 EmbeddingModel::ModernBertEmbedBase => write!(f, "modernbert-embed-base"),
244 }
245 }
246}
247
/// Settings for loading and running an [`EmbeddingModel`].
///
/// Build one with [`ModelConfig::new`] or [`Default`] and refine it with the `with_*` methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Which embedding model to load.
    pub model: EmbeddingModel,

    /// Optional directory for model files; `None` leaves the location to the loader
    /// (NOTE(review): confirm the fallback location used when this is `None`).
    pub cache_dir: Option<String>,

    /// Upper bound on inputs per batch. `Default` reads `DAKERA_ONNX_BATCH_SIZE`,
    /// falling back to 32 when it is unset, unparsable, or zero.
    pub max_batch_size: usize,

    /// Whether GPU execution is requested (presumably selecting
    /// `EmbeddingModel::onnx_filename_gpu` — confirm in the loader). `false` by default.
    pub use_gpu: bool,

    /// Explicit thread count; `None` (the default) presumably keeps the runtime's
    /// own choice — TODO confirm in the session setup.
    pub num_threads: Option<usize>,

    /// Number of inference sessions kept in the pool. `Default` reads
    /// `DAKERA_ONNX_POOL_SIZE` (falling back to 4); `with_session_pool_size`
    /// clamps to at least 1.
    pub session_pool_size: usize,

    /// Forces a specific backend instead of automatic selection (presumed — confirm
    /// against `crate::backend`). `#[serde(skip)]` keeps it out of serialized output
    /// and leaves it `None` after deserialization.
    #[serde(skip)]
    pub backend_override: Option<crate::backend::BackendKind>,
}
280
281impl Default for ModelConfig {
282 fn default() -> Self {
283 let pool_size = std::env::var("DAKERA_ONNX_POOL_SIZE")
290 .ok()
291 .and_then(|v| v.parse::<usize>().ok())
292 .filter(|&n| n >= 1)
293 .unwrap_or(4);
294 let max_batch_size = std::env::var("DAKERA_ONNX_BATCH_SIZE")
299 .ok()
300 .and_then(|v| v.parse::<usize>().ok())
301 .filter(|&n| n >= 1)
302 .unwrap_or(32);
303 Self {
304 model: EmbeddingModel::default(),
305 cache_dir: None,
306 max_batch_size,
307 use_gpu: false,
308 num_threads: None,
309 session_pool_size: pool_size,
310 backend_override: None,
311 }
312 }
313}
314
315impl ModelConfig {
316 pub fn new(model: EmbeddingModel) -> Self {
318 Self {
319 model,
320 ..Default::default()
321 }
322 }
323
324 pub fn with_cache_dir(mut self, dir: impl Into<String>) -> Self {
326 self.cache_dir = Some(dir.into());
327 self
328 }
329
330 pub fn with_max_batch_size(mut self, size: usize) -> Self {
332 self.max_batch_size = size;
333 self
334 }
335
336 pub fn with_gpu(mut self, use_gpu: bool) -> Self {
338 self.use_gpu = use_gpu;
339 self
340 }
341
342 pub fn with_num_threads(mut self, threads: usize) -> Self {
344 self.num_threads = Some(threads);
345 self
346 }
347
348 pub fn with_session_pool_size(mut self, size: usize) -> Self {
350 self.session_pool_size = size.max(1);
351 self
352 }
353}
354
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_ids() {
        assert_eq!(
            EmbeddingModel::BgeLarge.model_id(),
            "BAAI/bge-large-en-v1.5"
        );
        assert_eq!(
            EmbeddingModel::MiniLM.model_id(),
            "sentence-transformers/all-MiniLM-L6-v2"
        );
        assert_eq!(
            EmbeddingModel::BgeSmall.model_id(),
            "BAAI/bge-small-en-v1.5"
        );
        assert_eq!(EmbeddingModel::E5Small.model_id(), "intfloat/e5-small-v2");
    }

    #[test]
    fn test_dimensions() {
        assert_eq!(EmbeddingModel::BgeLarge.dimension(), 1024);
        assert_eq!(EmbeddingModel::MiniLM.dimension(), 384);
        assert_eq!(EmbeddingModel::BgeSmall.dimension(), 384);
        assert_eq!(EmbeddingModel::E5Small.dimension(), 384);
        // Every model, including ones not pinned above, must report a usable dimension.
        for model in EmbeddingModel::all() {
            assert!(model.dimension() > 0);
        }
    }

    #[test]
    fn test_all_lists_each_model_once() {
        let unique: std::collections::HashSet<EmbeddingModel> =
            EmbeddingModel::all().iter().copied().collect();
        assert_eq!(unique.len(), EmbeddingModel::all().len());
        assert_eq!(unique.len(), 5);
    }

    #[test]
    fn test_default_model_is_bge_large() {
        assert_eq!(EmbeddingModel::default(), EmbeddingModel::BgeLarge);
    }

    #[test]
    fn test_from_str() {
        assert_eq!(
            EmbeddingModel::parse("bge-large"),
            Some(EmbeddingModel::BgeLarge)
        );
        assert_eq!(
            EmbeddingModel::parse("minilm"),
            Some(EmbeddingModel::MiniLM)
        );
        assert_eq!(
            EmbeddingModel::parse("BGE-SMALL"),
            Some(EmbeddingModel::BgeSmall)
        );
        assert_eq!(EmbeddingModel::parse("e5"), Some(EmbeddingModel::E5Small));
        assert_eq!(EmbeddingModel::parse("unknown"), None);
    }

    #[test]
    fn test_e5_prefixes() {
        assert_eq!(EmbeddingModel::E5Small.query_prefix(), Some("query: "));
        assert_eq!(EmbeddingModel::E5Small.document_prefix(), Some("passage: "));
        assert_eq!(EmbeddingModel::MiniLM.query_prefix(), None);
    }

    #[test]
    fn test_onnx_filenames() {
        for model in EmbeddingModel::all() {
            assert_eq!(model.onnx_filename(), "onnx/model_quantized.onnx");
            assert_eq!(model.onnx_filename_gpu(), "onnx/model.onnx");
        }
        // The CPU (quantized) and GPU files must be distinct artifacts.
        assert_ne!(
            EmbeddingModel::BgeLarge.onnx_filename(),
            EmbeddingModel::BgeLarge.onnx_filename_gpu()
        );
    }

    #[test]
    fn test_modernbert_model_id() {
        assert_eq!(
            EmbeddingModel::ModernBertEmbedBase.model_id(),
            "nomic-ai/modernbert-embed-base"
        );
    }

    #[test]
    fn test_modernbert_dimension_768() {
        assert_eq!(EmbeddingModel::ModernBertEmbedBase.dimension(), 768);
    }

    #[test]
    fn test_modernbert_max_tokens_8192() {
        assert_eq!(EmbeddingModel::ModernBertEmbedBase.max_seq_length(), 8192);
    }

    #[test]
    fn test_modernbert_mrl_dimensions() {
        let dims = EmbeddingModel::ModernBertEmbedBase
            .mrl_dimensions()
            .expect("ModernBERT should advertise MRL dimensions");
        assert!(dims.contains(&256));
        assert!(dims.contains(&768));
        // The largest MRL size is the full embedding dimension.
        assert_eq!(
            dims.last().copied(),
            Some(EmbeddingModel::ModernBertEmbedBase.dimension())
        );
    }

    #[test]
    fn test_modernbert_no_prefix() {
        assert!(EmbeddingModel::ModernBertEmbedBase.query_prefix().is_none());
        assert!(EmbeddingModel::ModernBertEmbedBase
            .document_prefix()
            .is_none());
    }

    #[test]
    fn test_modernbert_parse() {
        assert_eq!(
            EmbeddingModel::parse("modernbert-embed-base"),
            Some(EmbeddingModel::ModernBertEmbedBase)
        );
        assert_eq!(
            EmbeddingModel::parse("modernbert"),
            Some(EmbeddingModel::ModernBertEmbedBase)
        );
        assert_eq!(
            EmbeddingModel::parse("MODERNBERT"),
            Some(EmbeddingModel::ModernBertEmbedBase)
        );
    }

    #[test]
    fn test_modernbert_display() {
        assert_eq!(
            EmbeddingModel::ModernBertEmbedBase.to_string(),
            "modernbert-embed-base"
        );
    }

    #[test]
    fn test_bge_large_no_mrl_dimensions() {
        assert!(EmbeddingModel::BgeLarge.mrl_dimensions().is_none());
    }

    #[test]
    fn test_safetensors_config_filenames() {
        for model in EmbeddingModel::all() {
            assert_eq!(model.safetensors_filename(), "model.safetensors");
            assert_eq!(model.config_filename(), "config.json");
        }
    }

    #[test]
    fn test_model2vec_repo_id_bge_large() {
        assert_eq!(
            EmbeddingModel::BgeLarge.model2vec_repo_id(),
            "dakera-ai/bge-large-model2vec-256d"
        );
    }

    #[test]
    fn test_gguf_repo_id_bge_large() {
        assert_eq!(
            EmbeddingModel::BgeLarge.gguf_repo_id(),
            "dakera-ai/bge-large-gguf"
        );
    }

    #[test]
    fn test_session_pool_size_clamped_to_one() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(0);
        assert_eq!(cfg.session_pool_size, 1);
        assert_eq!(cfg.model, EmbeddingModel::MiniLM);
    }
}