1use serde::{Deserialize, Serialize};
6use std::time::SystemTime;
7
/// Record of which embedding model was loaded, with an identifying hash.
///
/// Built by [`ModelProvenance::new`], which captures the current wall-clock
/// time and derives `hash` from the model identifier, that timestamp, and the
/// model variant. The struct derives `Serialize`/`Deserialize`, so it can be
/// stored and read back.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelProvenance {
    /// The embedding model variant this record describes.
    pub model: EmbeddingModel,
    /// Identifier string supplied by the caller at construction time,
    /// e.g. `"BAAI/bge-small-en-v1.5"`. Stored verbatim.
    pub model_id: String,
    /// Hex-encoded BLAKE3 digest (64 characters) of
    /// `"{model_id}:{loaded_at_iso}:{model:?}"`.
    pub hash: String,
    /// Wall-clock time captured when the record was constructed.
    pub loaded_at: SystemTime,
    /// `loaded_at` rendered as an RFC 3339 string in UTC.
    pub loaded_at_iso: String,
}
46
47impl ModelProvenance {
48 pub fn new(model: EmbeddingModel, model_id: String) -> Self {
50 let loaded_at = SystemTime::now();
51 let loaded_at_iso = {
52 let dt: chrono::DateTime<chrono::Utc> = loaded_at.into();
53 dt.to_rfc3339()
54 };
55
56 let hash_input = format!("{model_id}:{loaded_at_iso}:{model:?}");
58 let hash = blake3::hash(hash_input.as_bytes()).to_hex().to_string();
59
60 Self {
61 model,
62 model_id,
63 hash,
64 loaded_at,
65 loaded_at_iso,
66 }
67 }
68
69 pub fn dimensions(&self) -> usize {
71 self.model.dimensions()
72 }
73
74 pub fn matches_model(&self, expected: EmbeddingModel) -> bool {
76 self.model == expected
77 }
78}
79
/// Embedding models known to this module.
///
/// Serialized using serde's `snake_case` renaming of the variant name (for
/// example `BgeSmallEnV15` serializes as `"bge_small_en_v15"`). Each variant
/// also accepts its exact Rust variant name as a deserialization alias.
/// NOTE(review): the aliases presumably keep values written before the
/// `snake_case` renaming readable — confirm.
///
/// `#[non_exhaustive]` forces downstream crates to include a wildcard arm when
/// matching, so adding a model is not a breaking change for them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum EmbeddingModel {
    /// `BAAI/bge-small-en-v1.5`: 384 dimensions, 512 input tokens. The default.
    #[default]
    #[serde(alias = "BgeSmallEnV15")]
    BgeSmallEnV15,

    /// `BAAI/bge-base-en-v1.5`: 768 dimensions, 512 input tokens.
    #[serde(alias = "BgeBaseEnV15")]
    BgeBaseEnV15,

    /// `BAAI/bge-large-en-v1.5`: 1024 dimensions, 512 input tokens.
    #[serde(alias = "BgeLargeEnV15")]
    BgeLargeEnV15,

    /// `intfloat/multilingual-e5-small`: 384 dimensions; queries are
    /// prefixed with `"query: "`.
    #[serde(alias = "MultilingualE5Small")]
    MultilingualE5Small,

    /// `intfloat/multilingual-e5-base`: 768 dimensions; queries are
    /// prefixed with `"query: "`.
    #[serde(alias = "MultilingualE5Base")]
    MultilingualE5Base,

    /// `Qwen/Qwen3-Embedding-0.6B`: 1024 native dimensions (a smaller output
    /// dimension can be configured), 8192 input tokens.
    #[serde(alias = "Qwen3Embedding0_6B")]
    Qwen3Embedding0_6B,

    /// `Qwen/Qwen3-Embedding-4B`: 2560 native dimensions (a smaller output
    /// dimension can be configured), 8192 input tokens.
    #[serde(alias = "Qwen3Embedding4B")]
    Qwen3Embedding4B,

    /// `sentence-transformers/all-MiniLM-L6-v2`: 384 dimensions, 256 input tokens.
    #[serde(alias = "AllMiniLmL6V2")]
    AllMiniLmL6V2,

    /// `sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2`:
    /// 384 dimensions, 128 input tokens.
    #[serde(alias = "ParaphraseMultilingualMiniLmL12V2")]
    ParaphraseMultilingualMiniLmL12V2,

    /// OpenAI `text-embedding-3-small`, the only remote model:
    /// 1536 dimensions, 8191 input tokens.
    #[serde(alias = "TextEmbedding3Small")]
    TextEmbedding3Small,
}
142
143impl EmbeddingModel {
144 #[inline]
149 pub const fn native_dimensions(&self) -> usize {
150 match self {
151 EmbeddingModel::BgeSmallEnV15
152 | EmbeddingModel::MultilingualE5Small
153 | EmbeddingModel::AllMiniLmL6V2
154 | EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => 384,
155 EmbeddingModel::BgeBaseEnV15 | EmbeddingModel::MultilingualE5Base => 768,
156 EmbeddingModel::BgeLargeEnV15 | EmbeddingModel::Qwen3Embedding0_6B => 1024,
157 EmbeddingModel::Qwen3Embedding4B => 2560,
158 EmbeddingModel::TextEmbedding3Small => 1536,
159 }
160 }
161
162 #[inline]
174 pub const fn dimensions(&self) -> usize {
175 self.native_dimensions()
176 }
177
178 #[inline]
180 pub const fn is_local(&self) -> bool {
181 matches!(
182 self,
183 EmbeddingModel::BgeSmallEnV15
184 | EmbeddingModel::BgeBaseEnV15
185 | EmbeddingModel::BgeLargeEnV15
186 | EmbeddingModel::MultilingualE5Small
187 | EmbeddingModel::MultilingualE5Base
188 | EmbeddingModel::AllMiniLmL6V2
189 | EmbeddingModel::ParaphraseMultilingualMiniLmL12V2
190 | EmbeddingModel::Qwen3Embedding0_6B
191 | EmbeddingModel::Qwen3Embedding4B
192 )
193 }
194
195 #[inline]
197 pub const fn is_remote(&self) -> bool {
198 matches!(self, EmbeddingModel::TextEmbedding3Small)
199 }
200
201 #[inline]
211 pub const fn max_input_tokens(&self) -> usize {
212 match self {
213 EmbeddingModel::BgeSmallEnV15 => 512,
215 EmbeddingModel::BgeBaseEnV15 => 512,
216 EmbeddingModel::BgeLargeEnV15 => 512,
217 EmbeddingModel::MultilingualE5Small => 512,
219 EmbeddingModel::MultilingualE5Base => 512,
220 EmbeddingModel::AllMiniLmL6V2 => 256,
222 EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => 128,
224 EmbeddingModel::Qwen3Embedding0_6B => 8192,
226 EmbeddingModel::Qwen3Embedding4B => 8192,
227 EmbeddingModel::TextEmbedding3Small => 8191,
229 }
230 }
231
232 #[inline]
250 pub const fn query_instruction(&self) -> Option<&'static str> {
251 match self {
252 EmbeddingModel::MultilingualE5Small | EmbeddingModel::MultilingualE5Base => {
253 Some("query: ")
256 }
257 EmbeddingModel::Qwen3Embedding0_6B | EmbeddingModel::Qwen3Embedding4B => Some(
258 "Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery: ",
259 ),
260 _ => None,
261 }
262 }
263
264 #[inline]
270 pub const fn document_instruction(&self) -> Option<&'static str> {
271 None
272 }
273
274 #[inline]
276 pub const fn model_id(&self) -> &'static str {
277 match self {
278 EmbeddingModel::BgeSmallEnV15 => "BAAI/bge-small-en-v1.5",
279 EmbeddingModel::BgeBaseEnV15 => "BAAI/bge-base-en-v1.5",
280 EmbeddingModel::BgeLargeEnV15 => "BAAI/bge-large-en-v1.5",
281 EmbeddingModel::MultilingualE5Small => "intfloat/multilingual-e5-small",
282 EmbeddingModel::MultilingualE5Base => "intfloat/multilingual-e5-base",
283 EmbeddingModel::AllMiniLmL6V2 => "sentence-transformers/all-MiniLM-L6-v2",
284 EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => {
285 "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
286 }
287 EmbeddingModel::Qwen3Embedding0_6B => "Qwen/Qwen3-Embedding-0.6B",
288 EmbeddingModel::Qwen3Embedding4B => "Qwen/Qwen3-Embedding-4B",
289 EmbeddingModel::TextEmbedding3Small => "text-embedding-3-small",
290 }
291 }
292
293 #[inline]
295 pub const fn supports_output_dim(&self) -> bool {
296 matches!(
297 self,
298 EmbeddingModel::Qwen3Embedding0_6B | EmbeddingModel::Qwen3Embedding4B
299 )
300 }
301
302 #[inline]
304 pub const fn key_version(&self) -> &'static str {
305 match self {
306 EmbeddingModel::TextEmbedding3Small
307 | EmbeddingModel::Qwen3Embedding0_6B
308 | EmbeddingModel::Qwen3Embedding4B => "v3",
309 EmbeddingModel::AllMiniLmL6V2 | EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => {
310 "v2"
311 }
312 _ => "v1.5",
313 }
314 }
315}
316
317impl std::fmt::Display for EmbeddingModel {
318 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
319 match self {
320 EmbeddingModel::BgeSmallEnV15 => write!(f, "bge-small-en-v1.5"),
321 EmbeddingModel::BgeBaseEnV15 => write!(f, "bge-base-en-v1.5"),
322 EmbeddingModel::BgeLargeEnV15 => write!(f, "bge-large-en-v1.5"),
323 EmbeddingModel::MultilingualE5Small => write!(f, "multilingual-e5-small"),
324 EmbeddingModel::MultilingualE5Base => write!(f, "multilingual-e5-base"),
325 EmbeddingModel::Qwen3Embedding0_6B => write!(f, "qwen3-embedding-0.6b"),
326 EmbeddingModel::Qwen3Embedding4B => write!(f, "qwen3-embedding-4b"),
327 EmbeddingModel::AllMiniLmL6V2 => write!(f, "all-minilm-l6-v2"),
328 EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => {
329 write!(f, "paraphrase-multilingual-minilm-l12-v2")
330 }
331 EmbeddingModel::TextEmbedding3Small => write!(f, "text-embedding-3-small"),
332 }
333 }
334}
335
336impl std::str::FromStr for EmbeddingModel {
337 type Err = String;
338
339 fn from_str(s: &str) -> Result<Self, Self::Err> {
346 let lower = s.to_lowercase();
347 let normalized = lower.trim().replace("_", "-").replace("baai/", "");
348
349 match normalized.as_str() {
350 "bge-small-en-v1.5" | "bge-small-en" | "bge-small" | "small" => {
351 Ok(EmbeddingModel::BgeSmallEnV15)
352 }
353 "bge-base-en-v1.5" | "bge-base-en" | "bge-base" | "base" => {
354 Ok(EmbeddingModel::BgeBaseEnV15)
355 }
356 "bge-large-en-v1.5" | "bge-large-en" | "bge-large" | "large" => {
357 Ok(EmbeddingModel::BgeLargeEnV15)
358 }
359 "multilingual-e5-small" | "e5-small" | "intfloat/multilingual-e5-small" => {
360 Ok(EmbeddingModel::MultilingualE5Small)
361 }
362 "multilingual-e5-base" | "e5-base" | "intfloat/multilingual-e5-base" => {
363 Ok(EmbeddingModel::MultilingualE5Base)
364 }
365 "qwen3-embedding-0.6b" | "qwen3-embedding" | "qwen3" | "qwen/qwen3-embedding-0.6b" => {
366 Ok(EmbeddingModel::Qwen3Embedding0_6B)
367 }
368 "qwen3-embedding-4b" | "qwen3-4b" | "qwen/qwen3-embedding-4b" => {
369 Ok(EmbeddingModel::Qwen3Embedding4B)
370 }
371 "all-minilm-l6-v2"
372 | "minilm"
373 | "all-minilm"
374 | "sentence-transformers/all-minilm-l6-v2" => Ok(EmbeddingModel::AllMiniLmL6V2),
375 "paraphrase-multilingual-minilm-l12-v2"
376 | "paraphrase-multilingual"
377 | "multilingual-minilm"
378 | "sentence-transformers/paraphrase-multilingual-minilm-l12-v2" => {
379 Ok(EmbeddingModel::ParaphraseMultilingualMiniLmL12V2)
380 }
381 "text-embedding-3-small" | "openai-small" | "openai" => {
382 Ok(EmbeddingModel::TextEmbedding3Small)
383 }
384 _ => Err(format!(
385 "unknown embedding model: '{s}'. Valid: bge-small-en-v1.5, bge-base-en-v1.5, bge-large-en-v1.5, multilingual-e5-small, multilingual-e5-base, text-embedding-3-small"
386 )),
387 }
388 }
389}
390
/// Smallest `output_dim` accepted by `ModelConfig::validate` for models that
/// support a configurable (reduced) embedding width.
///
/// NOTE(review): "MRL" presumably stands for Matryoshka Representation
/// Learning, the technique that makes truncated embeddings usable — confirm.
pub const MIN_MRL_OUTPUT_DIM: usize = 32;
397
/// An embedding model together with an optional reduced output dimension.
///
/// Construct with [`ModelConfig::new`] (native width) or
/// [`ModelConfig::try_new`] (validated custom width).
///
/// NOTE(review): the fields are public and the derived `Deserialize` performs
/// no validation. A value built by struct literal or deserialized from storage
/// may therefore hold an `output_dim` that [`ModelConfig::validate`] would
/// reject. Call `validate()` after deserializing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ModelConfig {
    /// The embedding model to use.
    pub model: EmbeddingModel,
    /// Requested embedding width. `None` means the model's native width.
    /// `Some(d)` is valid only for models whose `supports_output_dim()` is
    /// true, with `MIN_MRL_OUTPUT_DIM <= d <= native_dimensions()`.
    /// Treated as `None` when absent from serialized input.
    #[serde(default)]
    pub output_dim: Option<usize>,
}
410
411impl Default for ModelConfig {
412 fn default() -> Self {
413 Self::new(EmbeddingModel::default())
414 }
415}
416
417impl ModelConfig {
418 pub const fn new(model: EmbeddingModel) -> Self {
420 Self {
421 model,
422 output_dim: None,
423 }
424 }
425
426 pub fn try_new(
428 model: EmbeddingModel,
429 output_dim: Option<usize>,
430 ) -> std::result::Result<Self, crate::error::EmbedError> {
431 let config = Self { model, output_dim };
432 config.validate()?;
433 Ok(config)
434 }
435
436 pub fn validate(&self) -> std::result::Result<(), crate::error::EmbedError> {
438 let Some(dim) = self.output_dim else {
439 return Ok(());
440 };
441 if !self.model.supports_output_dim() {
442 return Err(crate::error::EmbedError::InvalidInput(format!(
443 "{} does not support configurable embedding dimensions",
444 self.model
445 )));
446 }
447 if dim < MIN_MRL_OUTPUT_DIM {
448 return Err(crate::error::EmbedError::InvalidInput(format!(
449 "embedding output dimension {dim} is below minimum {MIN_MRL_OUTPUT_DIM}"
450 )));
451 }
452 let native = self.model.native_dimensions();
453 if dim > native {
454 return Err(crate::error::EmbedError::InvalidInput(format!(
455 "embedding output dimension {dim} exceeds native dimension {native} for {}",
456 self.model
457 )));
458 }
459 Ok(())
460 }
461
462 pub fn dimensions(&self) -> usize {
464 self.output_dim
465 .unwrap_or_else(|| self.model.native_dimensions())
466 }
467}
468
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant, so the table-driven tests below cover the whole enum.
    const ALL_MODELS: [EmbeddingModel; 10] = [
        EmbeddingModel::BgeSmallEnV15,
        EmbeddingModel::BgeBaseEnV15,
        EmbeddingModel::BgeLargeEnV15,
        EmbeddingModel::MultilingualE5Small,
        EmbeddingModel::MultilingualE5Base,
        EmbeddingModel::Qwen3Embedding0_6B,
        EmbeddingModel::Qwen3Embedding4B,
        EmbeddingModel::AllMiniLmL6V2,
        EmbeddingModel::ParaphraseMultilingualMiniLmL12V2,
        EmbeddingModel::TextEmbedding3Small,
    ];

    #[test]
    fn test_default_model() {
        let model = EmbeddingModel::default();
        assert_eq!(model, EmbeddingModel::BgeSmallEnV15);
    }

    #[test]
    fn test_model_provenance_new() {
        let provenance = ModelProvenance::new(
            EmbeddingModel::BgeSmallEnV15,
            "BAAI/bge-small-en-v1.5".into(),
        );

        assert_eq!(provenance.model, EmbeddingModel::BgeSmallEnV15);
        assert_eq!(provenance.model_id, "BAAI/bge-small-en-v1.5");
        assert!(!provenance.hash.is_empty());
        // A BLAKE3 digest is 32 bytes, i.e. 64 hex characters.
        assert_eq!(provenance.hash.len(), 64);
        assert!(!provenance.loaded_at_iso.is_empty());
    }

    #[test]
    fn test_model_provenance_unique_hash() {
        let p1 = ModelProvenance::new(EmbeddingModel::BgeSmallEnV15, "model1".into());
        // The hash includes the load timestamp; sleep so the two records get
        // distinct timestamps even on coarse system clocks.
        std::thread::sleep(std::time::Duration::from_millis(10));
        let p2 = ModelProvenance::new(EmbeddingModel::BgeSmallEnV15, "model1".into());

        assert_ne!(p1.hash, p2.hash);
    }

    #[test]
    fn test_model_provenance_dimensions() {
        let p1 = ModelProvenance::new(EmbeddingModel::BgeSmallEnV15, "small".into());
        assert_eq!(p1.dimensions(), 384);

        let p2 = ModelProvenance::new(EmbeddingModel::BgeBaseEnV15, "base".into());
        assert_eq!(p2.dimensions(), 768);

        let p3 = ModelProvenance::new(EmbeddingModel::BgeLargeEnV15, "large".into());
        assert_eq!(p3.dimensions(), 1024);
    }

    #[test]
    fn test_model_provenance_matches_model() {
        let provenance = ModelProvenance::new(EmbeddingModel::BgeSmallEnV15, "test".into());

        assert!(provenance.matches_model(EmbeddingModel::BgeSmallEnV15));
        assert!(!provenance.matches_model(EmbeddingModel::BgeBaseEnV15));
        assert!(!provenance.matches_model(EmbeddingModel::BgeLargeEnV15));
    }

    #[test]
    fn test_model_provenance_serialization() {
        let provenance = ModelProvenance::new(EmbeddingModel::BgeSmallEnV15, "test-model".into());

        let json = serde_json::to_string(&provenance).unwrap();
        assert!(json.contains("bge_small_en_v15"), "json={json}");
        assert!(json.contains("test-model"));
        assert!(json.contains(&provenance.hash));

        let parsed: ModelProvenance = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.model, provenance.model);
        assert_eq!(parsed.model_id, provenance.model_id);
        assert_eq!(parsed.hash, provenance.hash);
    }

    #[test]
    fn test_dimensions() {
        assert_eq!(EmbeddingModel::BgeSmallEnV15.dimensions(), 384);
        assert_eq!(EmbeddingModel::BgeBaseEnV15.dimensions(), 768);
        assert_eq!(EmbeddingModel::BgeLargeEnV15.dimensions(), 1024);
        assert_eq!(EmbeddingModel::Qwen3Embedding4B.dimensions(), 2560);
    }

    #[test]
    fn test_dimensions_agree_with_native_for_all_models() {
        for model in ALL_MODELS {
            assert_eq!(model.dimensions(), model.native_dimensions(), "{model}");
            assert_eq!(
                ModelConfig::new(model).dimensions(),
                model.native_dimensions(),
                "{model}"
            );
        }
    }

    #[test]
    fn test_model_config_native_dims() {
        assert_eq!(
            ModelConfig::new(EmbeddingModel::Qwen3Embedding4B).dimensions(),
            2560
        );
        assert_eq!(
            ModelConfig::new(EmbeddingModel::Qwen3Embedding0_6B).dimensions(),
            1024
        );
        assert_eq!(
            ModelConfig::new(EmbeddingModel::BgeSmallEnV15).dimensions(),
            384
        );
    }

    #[test]
    fn test_model_config_configured_dim() {
        let cfg = ModelConfig::try_new(EmbeddingModel::Qwen3Embedding4B, Some(1024)).unwrap();
        assert_eq!(cfg.dimensions(), 1024);

        let cfg = ModelConfig::try_new(EmbeddingModel::Qwen3Embedding0_6B, Some(512)).unwrap();
        assert_eq!(cfg.dimensions(), 512);
    }

    #[test]
    fn test_model_config_validation_below_min() {
        assert!(ModelConfig::try_new(EmbeddingModel::Qwen3Embedding4B, Some(31)).is_err());
        assert!(ModelConfig::try_new(EmbeddingModel::Qwen3Embedding4B, Some(0)).is_err());
    }

    #[test]
    fn test_model_config_validation_above_native() {
        assert!(ModelConfig::try_new(EmbeddingModel::Qwen3Embedding4B, Some(2561)).is_err());
        assert!(ModelConfig::try_new(EmbeddingModel::Qwen3Embedding0_6B, Some(1025)).is_err());
    }

    #[test]
    fn test_model_config_validation_non_mrl_model() {
        assert!(ModelConfig::try_new(EmbeddingModel::BgeSmallEnV15, Some(128)).is_err());
        assert!(ModelConfig::try_new(EmbeddingModel::BgeBaseEnV15, Some(512)).is_err());
    }

    #[test]
    fn test_model_config_none_output_dim_ok_for_any_model() {
        assert!(ModelConfig::try_new(EmbeddingModel::BgeSmallEnV15, None).is_ok());
        assert!(ModelConfig::try_new(EmbeddingModel::Qwen3Embedding4B, None).is_ok());
    }

    #[test]
    fn test_output_dim_support_matches_validation() {
        // Requesting exactly the native width is valid iff the model supports
        // a configurable output dimension at all.
        for model in ALL_MODELS {
            let result = ModelConfig::try_new(model, Some(model.native_dimensions()));
            assert_eq!(result.is_ok(), model.supports_output_dim(), "{model}");
        }
    }

    #[test]
    fn test_model_config_deserialize_without_output_dim() {
        let cfg: ModelConfig = serde_json::from_str(r#"{"model":"bge_small_en_v15"}"#).unwrap();
        assert_eq!(cfg, ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
    }

    #[test]
    fn test_is_local() {
        assert!(EmbeddingModel::BgeSmallEnV15.is_local());
        assert!(EmbeddingModel::BgeBaseEnV15.is_local());
        assert!(EmbeddingModel::BgeLargeEnV15.is_local());
    }

    #[test]
    fn test_local_and_remote_are_complementary() {
        for model in ALL_MODELS {
            assert_ne!(model.is_local(), model.is_remote(), "{model}");
        }
        assert!(EmbeddingModel::TextEmbedding3Small.is_remote());
    }

    #[test]
    fn test_display() {
        assert_eq!(
            EmbeddingModel::BgeSmallEnV15.to_string(),
            "bge-small-en-v1.5"
        );
        assert_eq!(EmbeddingModel::BgeBaseEnV15.to_string(), "bge-base-en-v1.5");
        assert_eq!(
            EmbeddingModel::BgeLargeEnV15.to_string(),
            "bge-large-en-v1.5"
        );
    }

    #[test]
    fn test_serialization_roundtrip() {
        for model in ALL_MODELS {
            let json = serde_json::to_string(&model).unwrap();
            let parsed: EmbeddingModel = serde_json::from_str(&json).unwrap();
            assert_eq!(model, parsed, "json={json}");
        }
    }

    #[test]
    fn test_deserialize_legacy_variant_names() {
        let small: EmbeddingModel = serde_json::from_str(r#""BgeSmallEnV15""#).unwrap();
        assert_eq!(small, EmbeddingModel::BgeSmallEnV15);

        let qwen: EmbeddingModel = serde_json::from_str(r#""Qwen3Embedding4B""#).unwrap();
        assert_eq!(qwen, EmbeddingModel::Qwen3Embedding4B);
    }

    #[test]
    fn test_max_input_tokens() {
        assert_eq!(EmbeddingModel::BgeSmallEnV15.max_input_tokens(), 512);
        assert_eq!(EmbeddingModel::BgeBaseEnV15.max_input_tokens(), 512);
        assert_eq!(EmbeddingModel::BgeLargeEnV15.max_input_tokens(), 512);
    }

    #[test]
    fn test_from_str_display_names() {
        assert_eq!(
            "bge-small-en-v1.5".parse::<EmbeddingModel>().unwrap(),
            EmbeddingModel::BgeSmallEnV15
        );
        assert_eq!(
            "bge-base-en-v1.5".parse::<EmbeddingModel>().unwrap(),
            EmbeddingModel::BgeBaseEnV15
        );
        assert_eq!(
            "bge-large-en-v1.5".parse::<EmbeddingModel>().unwrap(),
            EmbeddingModel::BgeLargeEnV15
        );
    }

    #[test]
    fn test_from_str_short_names() {
        assert_eq!(
            "small".parse::<EmbeddingModel>().unwrap(),
            EmbeddingModel::BgeSmallEnV15
        );
        assert_eq!(
            "bge-base".parse::<EmbeddingModel>().unwrap(),
            EmbeddingModel::BgeBaseEnV15
        );
        assert_eq!(
            "LARGE".parse::<EmbeddingModel>().unwrap(),
            EmbeddingModel::BgeLargeEnV15
        );
    }

    #[test]
    fn test_from_str_huggingface_ids() {
        assert_eq!(
            "BAAI/bge-small-en-v1.5".parse::<EmbeddingModel>().unwrap(),
            EmbeddingModel::BgeSmallEnV15
        );
    }

    #[test]
    fn test_from_str_accepts_display_and_model_id_for_all_models() {
        for model in ALL_MODELS {
            assert_eq!(model.to_string().parse::<EmbeddingModel>(), Ok(model));
            assert_eq!(model.model_id().parse::<EmbeddingModel>(), Ok(model));
        }
    }

    #[test]
    fn test_from_str_invalid() {
        let result = "unknown-model".parse::<EmbeddingModel>();
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("unknown embedding model"));
    }
}
679}