1use serde::{Deserialize, Serialize};
6use std::time::SystemTime;
7
/// Record of which embedding model was loaded and when.
///
/// The field descriptions below describe the values as populated by
/// [`ModelProvenance::new`]; all fields are public, so other code may
/// construct or mutate records differently.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelProvenance {
    /// The embedding model variant this record refers to.
    pub model: EmbeddingModel,
    /// Caller-supplied model identifier string.
    pub model_id: String,
    /// Hex-encoded BLAKE3 digest of `"{model_id}:{loaded_at_iso}:{model:?}"`.
    pub hash: String,
    /// System time captured when the record was created.
    pub loaded_at: SystemTime,
    /// `loaded_at` rendered as an RFC 3339 timestamp in UTC.
    pub loaded_at_iso: String,
}
46
47impl ModelProvenance {
48 pub fn new(model: EmbeddingModel, model_id: String) -> Self {
50 let loaded_at = SystemTime::now();
51 let loaded_at_iso = {
52 let dt: chrono::DateTime<chrono::Utc> = loaded_at.into();
53 dt.to_rfc3339()
54 };
55
56 let hash_input = format!("{model_id}:{loaded_at_iso}:{model:?}");
58 let hash = blake3::hash(hash_input.as_bytes()).to_hex().to_string();
59
60 Self {
61 model,
62 model_id,
63 hash,
64 loaded_at,
65 loaded_at_iso,
66 }
67 }
68
69 pub fn dimensions(&self) -> usize {
71 self.model.dimensions()
72 }
73
74 pub fn matches_model(&self, expected: EmbeddingModel) -> bool {
76 self.model == expected
77 }
78}
79
/// Embedding models known to this crate.
///
/// Variants are (de)serialized with serde's `snake_case` renaming; each
/// variant additionally accepts its own PascalCase name as an alias when
/// deserializing. The enum is `#[non_exhaustive]`, so code in other crates
/// must include a wildcard arm when matching on it.
///
/// The model identifiers and vector widths quoted on the variants are the
/// values returned by `model_id()` and `native_dimensions()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum EmbeddingModel {
    /// `BAAI/bge-small-en-v1.5`, 384 dimensions. This is the default variant.
    #[default]
    #[serde(alias = "BgeSmallEnV15")]
    BgeSmallEnV15,

    /// `BAAI/bge-base-en-v1.5`, 768 dimensions.
    #[serde(alias = "BgeBaseEnV15")]
    BgeBaseEnV15,

    /// `BAAI/bge-large-en-v1.5`, 1024 dimensions.
    #[serde(alias = "BgeLargeEnV15")]
    BgeLargeEnV15,

    /// `intfloat/multilingual-e5-small`, 384 dimensions.
    #[serde(alias = "MultilingualE5Small")]
    MultilingualE5Small,

    /// `intfloat/multilingual-e5-base`, 768 dimensions.
    #[serde(alias = "MultilingualE5Base")]
    MultilingualE5Base,

    /// `Qwen/Qwen3-Embedding-0.6B`, 1024 native dimensions.
    #[serde(alias = "Qwen3Embedding0_6B")]
    Qwen3Embedding0_6B,

    /// `Qwen/Qwen3-Embedding-4B`, 2560 native dimensions.
    #[serde(alias = "Qwen3Embedding4B")]
    Qwen3Embedding4B,

    /// `sentence-transformers/all-MiniLM-L6-v2`, 384 dimensions.
    #[serde(alias = "AllMiniLmL6V2")]
    AllMiniLmL6V2,

    /// `sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2`, 384 dimensions.
    #[serde(alias = "ParaphraseMultilingualMiniLmL12V2")]
    ParaphraseMultilingualMiniLmL12V2,

    /// `text-embedding-3-small`, 1536 dimensions; the only variant for which
    /// `is_remote()` returns `true`.
    #[serde(alias = "TextEmbedding3Small")]
    TextEmbedding3Small,
}
142
143impl EmbeddingModel {
144 #[inline]
149 pub const fn native_dimensions(&self) -> usize {
150 match self {
151 EmbeddingModel::BgeSmallEnV15
152 | EmbeddingModel::MultilingualE5Small
153 | EmbeddingModel::AllMiniLmL6V2
154 | EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => 384,
155 EmbeddingModel::BgeBaseEnV15 | EmbeddingModel::MultilingualE5Base => 768,
156 EmbeddingModel::BgeLargeEnV15 | EmbeddingModel::Qwen3Embedding0_6B => 1024,
157 EmbeddingModel::Qwen3Embedding4B => 2560,
158 EmbeddingModel::TextEmbedding3Small => 1536,
159 }
160 }
161
162 #[inline]
174 pub const fn dimensions(&self) -> usize {
175 self.native_dimensions()
176 }
177
178 #[inline]
180 pub const fn is_local(&self) -> bool {
181 matches!(
182 self,
183 EmbeddingModel::BgeSmallEnV15
184 | EmbeddingModel::BgeBaseEnV15
185 | EmbeddingModel::BgeLargeEnV15
186 | EmbeddingModel::MultilingualE5Small
187 | EmbeddingModel::MultilingualE5Base
188 | EmbeddingModel::AllMiniLmL6V2
189 | EmbeddingModel::ParaphraseMultilingualMiniLmL12V2
190 | EmbeddingModel::Qwen3Embedding0_6B
191 | EmbeddingModel::Qwen3Embedding4B
192 )
193 }
194
195 #[inline]
197 pub const fn is_remote(&self) -> bool {
198 matches!(self, EmbeddingModel::TextEmbedding3Small)
199 }
200
201 #[inline]
211 pub const fn max_input_tokens(&self) -> usize {
212 match self {
213 EmbeddingModel::BgeSmallEnV15 => 512,
215 EmbeddingModel::BgeBaseEnV15 => 512,
216 EmbeddingModel::BgeLargeEnV15 => 512,
217 EmbeddingModel::MultilingualE5Small => 512,
219 EmbeddingModel::MultilingualE5Base => 512,
220 EmbeddingModel::AllMiniLmL6V2 => 256,
222 EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => 128,
224 EmbeddingModel::Qwen3Embedding0_6B => 8192,
226 EmbeddingModel::Qwen3Embedding4B => 8192,
227 EmbeddingModel::TextEmbedding3Small => 8191,
229 }
230 }
231
232 #[inline]
250 pub const fn query_instruction(&self) -> Option<&'static str> {
251 match self {
252 EmbeddingModel::MultilingualE5Small | EmbeddingModel::MultilingualE5Base => {
253 Some("query: ")
256 }
257 EmbeddingModel::Qwen3Embedding0_6B | EmbeddingModel::Qwen3Embedding4B => Some(
258 "Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery: ",
259 ),
260 _ => None,
261 }
262 }
263
264 #[inline]
277 pub const fn document_instruction(&self) -> Option<&'static str> {
278 match self {
279 EmbeddingModel::MultilingualE5Small | EmbeddingModel::MultilingualE5Base => {
280 Some("passage: ")
282 }
283 _ => None,
284 }
285 }
286
287 #[inline]
289 pub const fn model_id(&self) -> &'static str {
290 match self {
291 EmbeddingModel::BgeSmallEnV15 => "BAAI/bge-small-en-v1.5",
292 EmbeddingModel::BgeBaseEnV15 => "BAAI/bge-base-en-v1.5",
293 EmbeddingModel::BgeLargeEnV15 => "BAAI/bge-large-en-v1.5",
294 EmbeddingModel::MultilingualE5Small => "intfloat/multilingual-e5-small",
295 EmbeddingModel::MultilingualE5Base => "intfloat/multilingual-e5-base",
296 EmbeddingModel::AllMiniLmL6V2 => "sentence-transformers/all-MiniLM-L6-v2",
297 EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => {
298 "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
299 }
300 EmbeddingModel::Qwen3Embedding0_6B => "Qwen/Qwen3-Embedding-0.6B",
301 EmbeddingModel::Qwen3Embedding4B => "Qwen/Qwen3-Embedding-4B",
302 EmbeddingModel::TextEmbedding3Small => "text-embedding-3-small",
303 }
304 }
305
306 #[inline]
308 pub const fn supports_output_dim(&self) -> bool {
309 matches!(
310 self,
311 EmbeddingModel::Qwen3Embedding0_6B | EmbeddingModel::Qwen3Embedding4B
312 )
313 }
314
315 #[cfg(feature = "native")]
326 #[inline]
327 pub const fn bert_pooling(&self) -> Option<lattice_inference::BertPooling> {
328 match self {
329 EmbeddingModel::BgeSmallEnV15
331 | EmbeddingModel::BgeBaseEnV15
332 | EmbeddingModel::BgeLargeEnV15 => Some(lattice_inference::BertPooling::CLS),
333 EmbeddingModel::MultilingualE5Small | EmbeddingModel::MultilingualE5Base => {
335 Some(lattice_inference::BertPooling::Mean)
336 }
337 EmbeddingModel::AllMiniLmL6V2 | EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => {
339 Some(lattice_inference::BertPooling::Mean)
340 }
341 EmbeddingModel::Qwen3Embedding0_6B
343 | EmbeddingModel::Qwen3Embedding4B
344 | EmbeddingModel::TextEmbedding3Small => None,
345 }
346 }
347
348 #[inline]
350 pub const fn key_version(&self) -> &'static str {
351 match self {
352 EmbeddingModel::TextEmbedding3Small
353 | EmbeddingModel::Qwen3Embedding0_6B
354 | EmbeddingModel::Qwen3Embedding4B => "v3",
355 EmbeddingModel::AllMiniLmL6V2 | EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => {
356 "v2"
357 }
358 _ => "v1.5",
359 }
360 }
361}
362
363impl std::fmt::Display for EmbeddingModel {
364 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
365 match self {
366 EmbeddingModel::BgeSmallEnV15 => write!(f, "bge-small-en-v1.5"),
367 EmbeddingModel::BgeBaseEnV15 => write!(f, "bge-base-en-v1.5"),
368 EmbeddingModel::BgeLargeEnV15 => write!(f, "bge-large-en-v1.5"),
369 EmbeddingModel::MultilingualE5Small => write!(f, "multilingual-e5-small"),
370 EmbeddingModel::MultilingualE5Base => write!(f, "multilingual-e5-base"),
371 EmbeddingModel::Qwen3Embedding0_6B => write!(f, "qwen3-embedding-0.6b"),
372 EmbeddingModel::Qwen3Embedding4B => write!(f, "qwen3-embedding-4b"),
373 EmbeddingModel::AllMiniLmL6V2 => write!(f, "all-minilm-l6-v2"),
374 EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => {
375 write!(f, "paraphrase-multilingual-minilm-l12-v2")
376 }
377 EmbeddingModel::TextEmbedding3Small => write!(f, "text-embedding-3-small"),
378 }
379 }
380}
381
382impl std::str::FromStr for EmbeddingModel {
383 type Err = String;
384
385 fn from_str(s: &str) -> Result<Self, Self::Err> {
392 let lower = s.to_lowercase();
393 let normalized = lower.trim().replace("_", "-").replace("baai/", "");
394
395 match normalized.as_str() {
396 "bge-small-en-v1.5" | "bge-small-en" | "bge-small" | "small" => {
397 Ok(EmbeddingModel::BgeSmallEnV15)
398 }
399 "bge-base-en-v1.5" | "bge-base-en" | "bge-base" | "base" => {
400 Ok(EmbeddingModel::BgeBaseEnV15)
401 }
402 "bge-large-en-v1.5" | "bge-large-en" | "bge-large" | "large" => {
403 Ok(EmbeddingModel::BgeLargeEnV15)
404 }
405 "multilingual-e5-small" | "e5-small" | "intfloat/multilingual-e5-small" => {
406 Ok(EmbeddingModel::MultilingualE5Small)
407 }
408 "multilingual-e5-base" | "e5-base" | "intfloat/multilingual-e5-base" => {
409 Ok(EmbeddingModel::MultilingualE5Base)
410 }
411 "qwen3-embedding-0.6b" | "qwen3-embedding" | "qwen3" | "qwen/qwen3-embedding-0.6b" => {
412 Ok(EmbeddingModel::Qwen3Embedding0_6B)
413 }
414 "qwen3-embedding-4b" | "qwen3-4b" | "qwen/qwen3-embedding-4b" => {
415 Ok(EmbeddingModel::Qwen3Embedding4B)
416 }
417 "all-minilm-l6-v2"
418 | "minilm"
419 | "all-minilm"
420 | "sentence-transformers/all-minilm-l6-v2" => Ok(EmbeddingModel::AllMiniLmL6V2),
421 "paraphrase-multilingual-minilm-l12-v2"
422 | "paraphrase-multilingual"
423 | "multilingual-minilm"
424 | "sentence-transformers/paraphrase-multilingual-minilm-l12-v2" => {
425 Ok(EmbeddingModel::ParaphraseMultilingualMiniLmL12V2)
426 }
427 "text-embedding-3-small" | "openai-small" | "openai" => {
428 Ok(EmbeddingModel::TextEmbedding3Small)
429 }
430 _ => Err(format!(
431 "unknown embedding model: '{s}'. Valid: bge-small-en-v1.5, bge-base-en-v1.5, bge-large-en-v1.5, multilingual-e5-small, multilingual-e5-base, text-embedding-3-small"
432 )),
433 }
434 }
435}
436
/// Smallest `output_dim` that `ModelConfig::validate` accepts; smaller
/// requested dimensions are rejected as invalid input.
// NOTE(review): "MRL" presumably stands for Matryoshka Representation Learning — confirm.
pub const MIN_MRL_OUTPUT_DIM: usize = 32;
443
/// An embedding model paired with an optional configured output dimension.
///
/// Construct with `ModelConfig::new` (no override) or `ModelConfig::try_new`
/// (validated override). The fields are public, so a value built directly
/// should be checked with `validate()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ModelConfig {
    /// The embedding model to use.
    pub model: EmbeddingModel,
    /// Requested output dimension; `None` means the model's native dimension.
    /// Defaults to `None` when absent from deserialized input.
    #[serde(default)]
    pub output_dim: Option<usize>,
}
456
457impl Default for ModelConfig {
458 fn default() -> Self {
459 Self::new(EmbeddingModel::default())
460 }
461}
462
463impl ModelConfig {
464 pub const fn new(model: EmbeddingModel) -> Self {
466 Self {
467 model,
468 output_dim: None,
469 }
470 }
471
472 pub fn try_new(
474 model: EmbeddingModel,
475 output_dim: Option<usize>,
476 ) -> std::result::Result<Self, crate::error::EmbedError> {
477 let config = Self { model, output_dim };
478 config.validate()?;
479 Ok(config)
480 }
481
482 pub fn validate(&self) -> std::result::Result<(), crate::error::EmbedError> {
484 let Some(dim) = self.output_dim else {
485 return Ok(());
486 };
487 if !self.model.supports_output_dim() {
488 return Err(crate::error::EmbedError::InvalidInput(format!(
489 "{} does not support configurable embedding dimensions",
490 self.model
491 )));
492 }
493 if dim < MIN_MRL_OUTPUT_DIM {
494 return Err(crate::error::EmbedError::InvalidInput(format!(
495 "embedding output dimension {dim} is below minimum {MIN_MRL_OUTPUT_DIM}"
496 )));
497 }
498 let native = self.model.native_dimensions();
499 if dim > native {
500 return Err(crate::error::EmbedError::InvalidInput(format!(
501 "embedding output dimension {dim} exceeds native dimension {native} for {}",
502 self.model
503 )));
504 }
505 Ok(())
506 }
507
508 pub fn dimensions(&self) -> usize {
510 self.output_dim
511 .unwrap_or_else(|| self.model.native_dimensions())
512 }
513}
514
#[cfg(test)]
mod tests {
    use super::*;

    /// The derived default must be the small BGE model.
    #[test]
    fn test_default_model() {
        assert_eq!(EmbeddingModel::default(), EmbeddingModel::BgeSmallEnV15);
    }

    /// `new` keeps its inputs and fills in a 64-character hash and a timestamp.
    #[test]
    fn test_model_provenance_new() {
        let model_id = "BAAI/bge-small-en-v1.5";
        let provenance = ModelProvenance::new(EmbeddingModel::BgeSmallEnV15, model_id.into());

        assert_eq!(provenance.model, EmbeddingModel::BgeSmallEnV15);
        assert_eq!(provenance.model_id, model_id);
        assert!(!provenance.hash.is_empty());
        assert_eq!(provenance.hash.len(), 64);
        assert!(!provenance.loaded_at_iso.is_empty());
    }

    /// Two records created at different times must hash differently.
    #[test]
    fn test_model_provenance_unique_hash() {
        let first = ModelProvenance::new(EmbeddingModel::BgeSmallEnV15, "model1".into());
        // Let the clock advance so the two timestamps differ.
        std::thread::sleep(std::time::Duration::from_millis(10));
        let second = ModelProvenance::new(EmbeddingModel::BgeSmallEnV15, "model1".into());

        assert_ne!(first.hash, second.hash);
    }

    /// Provenance dimensions follow the recorded model.
    #[test]
    fn test_model_provenance_dimensions() {
        let cases = [
            (EmbeddingModel::BgeSmallEnV15, "small", 384),
            (EmbeddingModel::BgeBaseEnV15, "base", 768),
            (EmbeddingModel::BgeLargeEnV15, "large", 1024),
        ];
        for (model, id, expected) in cases {
            let provenance = ModelProvenance::new(model, id.into());
            assert_eq!(provenance.dimensions(), expected);
        }
    }

    /// `matches_model` is true only for the recorded variant.
    #[test]
    fn test_model_provenance_matches_model() {
        let provenance = ModelProvenance::new(EmbeddingModel::BgeSmallEnV15, "test".into());

        assert!(provenance.matches_model(EmbeddingModel::BgeSmallEnV15));
        for other in [EmbeddingModel::BgeBaseEnV15, EmbeddingModel::BgeLargeEnV15] {
            assert!(!provenance.matches_model(other));
        }
    }

    /// Provenance survives a JSON round trip and uses the snake_case model name.
    #[test]
    fn test_model_provenance_serialization() {
        let original = ModelProvenance::new(EmbeddingModel::BgeSmallEnV15, "test-model".into());

        let json = serde_json::to_string(&original).unwrap();
        assert!(json.contains("bge_small_en_v15"), "json={json}");
        assert!(json.contains("test-model"));
        assert!(json.contains(&original.hash));

        let restored: ModelProvenance = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.model, original.model);
        assert_eq!(restored.model_id, original.model_id);
        assert_eq!(restored.hash, original.hash);
    }

    /// Spot-check the per-model embedding widths.
    #[test]
    fn test_dimensions() {
        let cases = [
            (EmbeddingModel::BgeSmallEnV15, 384),
            (EmbeddingModel::BgeBaseEnV15, 768),
            (EmbeddingModel::BgeLargeEnV15, 1024),
            (EmbeddingModel::Qwen3Embedding4B, 2560),
        ];
        for (model, expected) in cases {
            assert_eq!(model.dimensions(), expected);
        }
    }

    /// Without an override, a config reports the model's native width.
    #[test]
    fn test_model_config_native_dims() {
        let cases = [
            (EmbeddingModel::Qwen3Embedding4B, 2560),
            (EmbeddingModel::Qwen3Embedding0_6B, 1024),
            (EmbeddingModel::BgeSmallEnV15, 384),
        ];
        for (model, expected) in cases {
            assert_eq!(ModelConfig::new(model).dimensions(), expected);
        }
    }

    /// A valid override is what `dimensions` reports.
    #[test]
    fn test_model_config_configured_dim() {
        let cases = [
            (EmbeddingModel::Qwen3Embedding4B, 1024),
            (EmbeddingModel::Qwen3Embedding0_6B, 512),
        ];
        for (model, requested) in cases {
            let cfg = ModelConfig::try_new(model, Some(requested)).unwrap();
            assert_eq!(cfg.dimensions(), requested);
        }
    }

    /// Overrides below the minimum are rejected.
    #[test]
    fn test_model_config_validation_below_min() {
        for too_small in [31, 0] {
            assert!(ModelConfig::try_new(EmbeddingModel::Qwen3Embedding4B, Some(too_small)).is_err());
        }
    }

    /// Overrides above the native width are rejected.
    #[test]
    fn test_model_config_validation_above_native() {
        let cases = [
            (EmbeddingModel::Qwen3Embedding4B, 2561),
            (EmbeddingModel::Qwen3Embedding0_6B, 1025),
        ];
        for (model, too_large) in cases {
            assert!(ModelConfig::try_new(model, Some(too_large)).is_err());
        }
    }

    /// Models without configurable dimensions reject any override.
    #[test]
    fn test_model_config_validation_non_mrl_model() {
        let cases = [
            (EmbeddingModel::BgeSmallEnV15, 128),
            (EmbeddingModel::BgeBaseEnV15, 512),
        ];
        for (model, requested) in cases {
            assert!(ModelConfig::try_new(model, Some(requested)).is_err());
        }
    }

    /// Omitting the override is valid for every model.
    #[test]
    fn test_model_config_none_output_dim_ok_for_any_model() {
        for model in [EmbeddingModel::BgeSmallEnV15, EmbeddingModel::Qwen3Embedding4B] {
            assert!(ModelConfig::try_new(model, None).is_ok());
        }
    }

    /// The BGE models are reported as local.
    #[test]
    fn test_is_local() {
        let bge_models = [
            EmbeddingModel::BgeSmallEnV15,
            EmbeddingModel::BgeBaseEnV15,
            EmbeddingModel::BgeLargeEnV15,
        ];
        for model in bge_models {
            assert!(model.is_local());
        }
    }

    /// `Display` yields the short hyphenated names.
    #[test]
    fn test_display() {
        let cases = [
            (EmbeddingModel::BgeSmallEnV15, "bge-small-en-v1.5"),
            (EmbeddingModel::BgeBaseEnV15, "bge-base-en-v1.5"),
            (EmbeddingModel::BgeLargeEnV15, "bge-large-en-v1.5"),
        ];
        for (model, expected) in cases {
            assert_eq!(model.to_string(), expected);
        }
    }

    /// A model survives a JSON round trip.
    #[test]
    fn test_serialization_roundtrip() {
        let original = EmbeddingModel::BgeSmallEnV15;
        let json = serde_json::to_string(&original).unwrap();
        let restored: EmbeddingModel = serde_json::from_str(&json).unwrap();
        assert_eq!(original, restored);
    }

    /// The BGE models all report a 512-token limit.
    #[test]
    fn test_max_input_tokens() {
        let bge_models = [
            EmbeddingModel::BgeSmallEnV15,
            EmbeddingModel::BgeBaseEnV15,
            EmbeddingModel::BgeLargeEnV15,
        ];
        for model in bge_models {
            assert_eq!(model.max_input_tokens(), 512);
        }
    }

    /// Canonical display names parse back to their variants.
    #[test]
    fn test_from_str_display_names() {
        let cases = [
            ("bge-small-en-v1.5", EmbeddingModel::BgeSmallEnV15),
            ("bge-base-en-v1.5", EmbeddingModel::BgeBaseEnV15),
            ("bge-large-en-v1.5", EmbeddingModel::BgeLargeEnV15),
        ];
        for (input, expected) in cases {
            assert_eq!(input.parse::<EmbeddingModel>().unwrap(), expected);
        }
    }

    /// Short aliases parse, and matching ignores case.
    #[test]
    fn test_from_str_short_names() {
        let cases = [
            ("small", EmbeddingModel::BgeSmallEnV15),
            ("bge-base", EmbeddingModel::BgeBaseEnV15),
            ("LARGE", EmbeddingModel::BgeLargeEnV15),
        ];
        for (input, expected) in cases {
            assert_eq!(input.parse::<EmbeddingModel>().unwrap(), expected);
        }
    }

    /// A full hub identifier with the organisation prefix parses.
    #[test]
    fn test_from_str_huggingface_ids() {
        let parsed = "BAAI/bge-small-en-v1.5".parse::<EmbeddingModel>().unwrap();
        assert_eq!(parsed, EmbeddingModel::BgeSmallEnV15);
    }

    /// Unknown names produce a descriptive error.
    #[test]
    fn test_from_str_invalid() {
        let outcome = "unknown-model".parse::<EmbeddingModel>();
        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().contains("unknown embedding model"));
    }

    /// All BGE variants pool with CLS.
    #[cfg(feature = "native")]
    #[test]
    fn test_bge_models_use_cls_pooling() {
        use lattice_inference::BertPooling;

        let cases = [
            (
                EmbeddingModel::BgeSmallEnV15,
                "BgeSmallEnV15 must use CLS pooling",
            ),
            (
                EmbeddingModel::BgeBaseEnV15,
                "BgeBaseEnV15 must use CLS pooling",
            ),
            (
                EmbeddingModel::BgeLargeEnV15,
                "BgeLargeEnV15 must use CLS pooling",
            ),
        ];
        for (model, message) in cases {
            assert_eq!(model.bert_pooling(), Some(BertPooling::CLS), "{message}");
        }
    }

    /// Both E5 variants pool with Mean.
    #[cfg(feature = "native")]
    #[test]
    fn test_e5_models_use_mean_pooling() {
        use lattice_inference::BertPooling;

        let cases = [
            (
                EmbeddingModel::MultilingualE5Small,
                "MultilingualE5Small must use mean pooling",
            ),
            (
                EmbeddingModel::MultilingualE5Base,
                "MultilingualE5Base must use mean pooling",
            ),
        ];
        for (model, message) in cases {
            assert_eq!(model.bert_pooling(), Some(BertPooling::Mean), "{message}");
        }
    }

    /// Both MiniLM variants pool with Mean.
    #[cfg(feature = "native")]
    #[test]
    fn test_minilm_models_use_mean_pooling() {
        use lattice_inference::BertPooling;

        let cases = [
            (
                EmbeddingModel::AllMiniLmL6V2,
                "AllMiniLmL6V2 must use mean pooling",
            ),
            (
                EmbeddingModel::ParaphraseMultilingualMiniLmL12V2,
                "ParaphraseMultilingualMiniLmL12V2 must use mean pooling",
            ),
        ];
        for (model, message) in cases {
            assert_eq!(model.bert_pooling(), Some(BertPooling::Mean), "{message}");
        }
    }

    /// Qwen3 and remote variants have no BERT pooling mode.
    #[cfg(feature = "native")]
    #[test]
    fn test_non_bert_models_return_none_pooling() {
        let cases = [
            (
                EmbeddingModel::Qwen3Embedding0_6B,
                "Qwen model must return None for bert_pooling()",
            ),
            (
                EmbeddingModel::Qwen3Embedding4B,
                "Qwen model must return None for bert_pooling()",
            ),
            (
                EmbeddingModel::TextEmbedding3Small,
                "Remote model must return None for bert_pooling()",
            ),
        ];
        for (model, message) in cases {
            assert_eq!(model.bert_pooling(), None, "{message}");
        }
    }

    /// BGE and E5 must not share a pooling mode.
    #[cfg(feature = "native")]
    #[test]
    fn test_bge_and_e5_use_different_pooling() {
        let bge = EmbeddingModel::BgeSmallEnV15.bert_pooling();
        let e5 = EmbeddingModel::MultilingualE5Small.bert_pooling();
        assert_ne!(
            bge, e5,
            "BGE and E5 must use different pooling strategies"
        );
    }
}