Skip to main content

lattice_embed/
model.rs

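//! Embedding model definitions.
//!
//! Provides the `EmbeddingModel` enum for selecting local (on-device) and remote
//! embedding models, `ModelConfig` for an optional MRL output-dimension
//! truncation, and `ModelProvenance` for recording which model was loaded and when.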
4
5use serde::{Deserialize, Serialize};
6use std::time::SystemTime;
7
8/// **Stable**: external consumers may depend on this; breaking changes require a SemVer bump.
9///
10/// Model provenance information for security audits.
11///
12/// Tracks metadata about when and how a model was loaded, including a hash
13/// for verification that the model hasn't been tampered with.
14///
15/// # Example
16///
17/// ```rust
18/// use lattice_embed::{EmbeddingModel, ModelProvenance};
19///
20/// // Created when a model is loaded
21/// let provenance = ModelProvenance::new(
22///     EmbeddingModel::BgeSmallEnV15,
23///     "BAAI/bge-small-en-v1.5".to_string(),
24/// );
25///
26/// assert!(provenance.model_id.contains("BAAI"));
27/// assert!(!provenance.hash.is_empty());
28/// ```
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct ModelProvenance {
31    /// **Stable**: model variant that was loaded.
32    pub model: EmbeddingModel,
33    /// **Stable**: source identifier (HuggingFace ID, URL, or file path).
34    pub model_id: String,
35    /// **Stable**: Blake3 hash of the model identifier + timestamp for uniqueness.
36    ///
37    /// Note: This is a lightweight hash based on metadata, not a full hash
38    /// of model weights (which would be expensive). For full model verification,
39    /// use the lattice-inference library's built-in checksum verification.
40    pub hash: String,
41    /// **Stable**: when the model was loaded.
42    pub loaded_at: SystemTime,
43    /// **Stable**: formatted timestamp string for convenience.
44    pub loaded_at_iso: String,
45}
46
47impl ModelProvenance {
48    /// **Stable**: create new provenance information for a loaded model.
49    pub fn new(model: EmbeddingModel, model_id: String) -> Self {
50        let loaded_at = SystemTime::now();
51        let loaded_at_iso = {
52            let dt: chrono::DateTime<chrono::Utc> = loaded_at.into();
53            dt.to_rfc3339()
54        };
55
56        // Create a lightweight hash from model metadata
57        let hash_input = format!("{model_id}:{loaded_at_iso}:{model:?}");
58        let hash = blake3::hash(hash_input.as_bytes()).to_hex().to_string();
59
60        Self {
61            model,
62            model_id,
63            hash,
64            loaded_at,
65            loaded_at_iso,
66        }
67    }
68
69    /// **Stable**: get the model dimensions.
70    pub fn dimensions(&self) -> usize {
71        self.model.dimensions()
72    }
73
74    /// **Stable**: check if this provenance matches expected model.
75    pub fn matches_model(&self, expected: EmbeddingModel) -> bool {
76        self.model == expected
77    }
78}
79
80/// **Stable**: external consumers may depend on this; breaking changes require a SemVer bump.
81///
82/// Supported embedding models.
83///
84/// This enum represents the embedding models available for text vectorization.
85/// Models are categorized as either local (run on-device via lattice-inference) or
86/// remote (require API calls).
87///
88/// # Example
89///
90/// ```rust
91/// use lattice_embed::EmbeddingModel;
92///
93/// let model = EmbeddingModel::default();
94/// assert_eq!(model.dimensions(), 384);
95/// assert!(model.is_local());
96/// ```
97#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
98#[serde(rename_all = "snake_case")]
99#[non_exhaustive]
100pub enum EmbeddingModel {
101    /// BGE small English v1.5 (384 dimensions) - fast and efficient.
102    #[default]
103    #[serde(alias = "BgeSmallEnV15")]
104    BgeSmallEnV15,
105
106    /// BGE base English v1.5 (768 dimensions) - balanced quality/speed.
107    #[serde(alias = "BgeBaseEnV15")]
108    BgeBaseEnV15,
109
110    /// BGE large English v1.5 (1024 dimensions) - highest quality local.
111    #[serde(alias = "BgeLargeEnV15")]
112    BgeLargeEnV15,
113
114    /// Multilingual E5 small (384 dimensions) - multilingual, same arch as BGE.
115    #[serde(alias = "MultilingualE5Small")]
116    MultilingualE5Small,
117
118    /// Multilingual E5 base (768 dimensions) - best multilingual quality/speed.
119    #[serde(alias = "MultilingualE5Base")]
120    MultilingualE5Base,
121
122    /// Qwen3-Embedding-0.6B (1024 dimensions) - multilingual, decoder-only, GPU-accelerated.
123    #[serde(alias = "Qwen3Embedding0_6B")]
124    Qwen3Embedding0_6B,
125
126    /// Qwen3-Embedding-4B (2560 dimensions, MRL-capable) - multilingual, decoder-only, GPU-accelerated.
127    #[serde(alias = "Qwen3Embedding4B")]
128    Qwen3Embedding4B,
129
130    /// all-MiniLM-L6-v2 (384 dimensions) - BERT-class, WordPiece tokenizer, sentence-transformers.
131    #[serde(alias = "AllMiniLmL6V2")]
132    AllMiniLmL6V2,
133
134    /// paraphrase-multilingual-MiniLM-L12-v2 (384 dimensions) - multilingual, XLM-R base, sentence-transformers.
135    #[serde(alias = "ParaphraseMultilingualMiniLmL12V2")]
136    ParaphraseMultilingualMiniLmL12V2,
137
138    /// OpenAI text-embedding-3-small (1536 dimensions) - remote API.
139    #[serde(alias = "TextEmbedding3Small")]
140    TextEmbedding3Small,
141}
142
143impl EmbeddingModel {
144    /// **Stable**: get the native (full-resolution) output dimension of this model's embeddings.
145    ///
146    /// Returns the model's intrinsic dimension regardless of any MRL truncation.
147    /// For MRL-capable models with a configured truncation, use `ModelConfig::dimensions()`.
148    #[inline]
149    pub const fn native_dimensions(&self) -> usize {
150        match self {
151            EmbeddingModel::BgeSmallEnV15
152            | EmbeddingModel::MultilingualE5Small
153            | EmbeddingModel::AllMiniLmL6V2
154            | EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => 384,
155            EmbeddingModel::BgeBaseEnV15 | EmbeddingModel::MultilingualE5Base => 768,
156            EmbeddingModel::BgeLargeEnV15 | EmbeddingModel::Qwen3Embedding0_6B => 1024,
157            EmbeddingModel::Qwen3Embedding4B => 2560,
158            EmbeddingModel::TextEmbedding3Small => 1536,
159        }
160    }
161
162    /// **Stable**: get the output dimension of this model's embeddings.
163    ///
164    /// # Example
165    ///
166    /// ```rust
167    /// use lattice_embed::EmbeddingModel;
168    ///
169    /// assert_eq!(EmbeddingModel::BgeSmallEnV15.dimensions(), 384);
170    /// assert_eq!(EmbeddingModel::BgeBaseEnV15.dimensions(), 768);
171    /// assert_eq!(EmbeddingModel::BgeLargeEnV15.dimensions(), 1024);
172    /// ```
173    #[inline]
174    pub const fn dimensions(&self) -> usize {
175        self.native_dimensions()
176    }
177
178    /// **Stable**: check if this model can run locally (via lattice-inference).
179    #[inline]
180    pub const fn is_local(&self) -> bool {
181        matches!(
182            self,
183            EmbeddingModel::BgeSmallEnV15
184                | EmbeddingModel::BgeBaseEnV15
185                | EmbeddingModel::BgeLargeEnV15
186                | EmbeddingModel::MultilingualE5Small
187                | EmbeddingModel::MultilingualE5Base
188                | EmbeddingModel::AllMiniLmL6V2
189                | EmbeddingModel::ParaphraseMultilingualMiniLmL12V2
190                | EmbeddingModel::Qwen3Embedding0_6B
191                | EmbeddingModel::Qwen3Embedding4B
192        )
193    }
194
195    /// **Stable**: check if this model requires a remote API.
196    #[inline]
197    pub const fn is_remote(&self) -> bool {
198        matches!(self, EmbeddingModel::TextEmbedding3Small)
199    }
200
201    /// **Stable**: maximum input tokens supported by this model.
202    ///
203    /// Use this for chunking/truncation decisions. Values are conservative
204    /// to leave room for special tokens.
205    ///
206    /// Reference limits:
207    /// - BGE models: 512 tokens
208    /// - OpenAI text-embedding-3: 8191 tokens
209    /// - Gemini embedding-001: 20000 tokens
210    #[inline]
211    pub const fn max_input_tokens(&self) -> usize {
212        match self {
213            // BGE models have 512 token limit
214            EmbeddingModel::BgeSmallEnV15 => 512,
215            EmbeddingModel::BgeBaseEnV15 => 512,
216            EmbeddingModel::BgeLargeEnV15 => 512,
217            // E5 models have 512 token limit
218            EmbeddingModel::MultilingualE5Small => 512,
219            EmbeddingModel::MultilingualE5Base => 512,
220            // MiniLM has a shorter context window
221            EmbeddingModel::AllMiniLmL6V2 => 256,
222            // paraphrase-multilingual-MiniLM max sequence length 128
223            EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => 128,
224            // Qwen3-Embedding supports 32K but we cap at 8192 for practical use
225            EmbeddingModel::Qwen3Embedding0_6B => 8192,
226            EmbeddingModel::Qwen3Embedding4B => 8192,
227            // OpenAI text-embedding-3-small has 8191 token limit
228            EmbeddingModel::TextEmbedding3Small => 8191,
229        }
230    }
231
232    /// **Stable**: query instruction prefix for asymmetric retrieval.
233    ///
234    /// Some models require different text for queries vs documents (asymmetric retrieval).
235    ///
236    /// - **E5 models** (`MultilingualE5Small`, `MultilingualE5Base`): trained with
237    ///   "query: " / "passage: " asymmetric prefixes. Omitting the prefix degrades
238    ///   retrieval quality significantly — the model expects them during fine-tuning.
239    ///
240    /// - **Qwen3-Embedding** models: require an instruction prompt to align the
241    ///   decoder embedding space for retrieval tasks.
242    ///
243    /// - **BGE / MiniLM** models: trained with contrastive objectives on raw text;
244    ///   no prefix needed.
245    ///
246    /// Returns `Some(prefix)` if the query text should be wrapped as
247    /// `"{prefix}{query}"` before embedding. Returns `None` for models that
248    /// don't need instruction prompting.
249    #[inline]
250    pub const fn query_instruction(&self) -> Option<&'static str> {
251        match self {
252            EmbeddingModel::MultilingualE5Small | EmbeddingModel::MultilingualE5Base => {
253                // E5 asymmetric retrieval: "query: " prefix for queries,
254                // "passage: " prefix for documents (see document_instruction()).
255                Some("query: ")
256            }
257            EmbeddingModel::Qwen3Embedding0_6B | EmbeddingModel::Qwen3Embedding4B => Some(
258                "Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery: ",
259            ),
260            _ => None,
261        }
262    }
263
264    /// **Stable**: document instruction prefix for asymmetric retrieval.
265    ///
266    /// Some models use different prompts for documents vs queries.
267    /// Returns `Some(prefix)` if the document text should be wrapped as
268    /// `"{prefix}{text}"` before embedding at storage time.
269    ///
270    /// - **E5 models**: trained with `"passage: "` prefix on document/passage inputs.
271    ///   Omitting the prefix on the document side degrades retrieval quality because
272    ///   the model's embedding space was conditioned on this asymmetry during fine-tuning.
273    /// - **BGE / MiniLM**: no document prefix required (contrastive training on raw text).
274    /// - **Qwen3-Embedding**: raw passage text is used without an instruction prefix;
275    ///   only the query side carries the task instruction.
276    #[inline]
277    pub const fn document_instruction(&self) -> Option<&'static str> {
278        match self {
279            EmbeddingModel::MultilingualE5Small | EmbeddingModel::MultilingualE5Base => {
280                // E5 asymmetric retrieval: "passage: " prefix for documents/passages.
281                Some("passage: ")
282            }
283            _ => None,
284        }
285    }
286
287    /// **Stable**: get the model identifier (HuggingFace ID or provider/model).
288    #[inline]
289    pub const fn model_id(&self) -> &'static str {
290        match self {
291            EmbeddingModel::BgeSmallEnV15 => "BAAI/bge-small-en-v1.5",
292            EmbeddingModel::BgeBaseEnV15 => "BAAI/bge-base-en-v1.5",
293            EmbeddingModel::BgeLargeEnV15 => "BAAI/bge-large-en-v1.5",
294            EmbeddingModel::MultilingualE5Small => "intfloat/multilingual-e5-small",
295            EmbeddingModel::MultilingualE5Base => "intfloat/multilingual-e5-base",
296            EmbeddingModel::AllMiniLmL6V2 => "sentence-transformers/all-MiniLM-L6-v2",
297            EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => {
298                "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
299            }
300            EmbeddingModel::Qwen3Embedding0_6B => "Qwen/Qwen3-Embedding-0.6B",
301            EmbeddingModel::Qwen3Embedding4B => "Qwen/Qwen3-Embedding-4B",
302            EmbeddingModel::TextEmbedding3Small => "text-embedding-3-small",
303        }
304    }
305
306    /// **Stable**: whether this model supports configurable output dimensions (MRL/Matryoshka).
307    #[inline]
308    pub const fn supports_output_dim(&self) -> bool {
309        matches!(
310            self,
311            EmbeddingModel::Qwen3Embedding0_6B | EmbeddingModel::Qwen3Embedding4B
312        )
313    }
314
315    /// **Stable**: the pooling strategy this model expects from BERT-family inference.
316    ///
317    /// BGE v1.5 models use CLS-token pooling (first token) as documented on their
318    /// HuggingFace model cards (`model_output[0][:, 0]`).  All other BERT-family
319    /// models (E5, MiniLM) use masked mean pooling.
320    ///
321    /// Returns `None` for non-BERT models (Qwen3, OpenAI remote) which have their
322    /// own pooling paths.
323    ///
324    /// Only available when the `native` feature is enabled (requires `lattice-inference`).
325    #[cfg(feature = "native")]
326    #[inline]
327    pub const fn bert_pooling(&self) -> Option<lattice_inference::BertPooling> {
328        match self {
329            // BGE v1.5 — CLS pooling per model card
330            EmbeddingModel::BgeSmallEnV15
331            | EmbeddingModel::BgeBaseEnV15
332            | EmbeddingModel::BgeLargeEnV15 => Some(lattice_inference::BertPooling::CLS),
333            // E5 multilingual — masked mean pooling per model card
334            EmbeddingModel::MultilingualE5Small | EmbeddingModel::MultilingualE5Base => {
335                Some(lattice_inference::BertPooling::Mean)
336            }
337            // MiniLM family — masked mean pooling per sentence-transformers convention
338            EmbeddingModel::AllMiniLmL6V2 | EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => {
339                Some(lattice_inference::BertPooling::Mean)
340            }
341            // Qwen and remote models — not BERT-family, pooling handled separately
342            EmbeddingModel::Qwen3Embedding0_6B
343            | EmbeddingModel::Qwen3Embedding4B
344            | EmbeddingModel::TextEmbedding3Small => None,
345        }
346    }
347
348    /// **Stable**: embedding key revision string for this model family.
349    #[inline]
350    pub const fn key_version(&self) -> &'static str {
351        match self {
352            EmbeddingModel::TextEmbedding3Small
353            | EmbeddingModel::Qwen3Embedding0_6B
354            | EmbeddingModel::Qwen3Embedding4B => "v3",
355            EmbeddingModel::AllMiniLmL6V2 | EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => {
356                "v2"
357            }
358            _ => "v1.5",
359        }
360    }
361}
362
363impl std::fmt::Display for EmbeddingModel {
364    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
365        match self {
366            EmbeddingModel::BgeSmallEnV15 => write!(f, "bge-small-en-v1.5"),
367            EmbeddingModel::BgeBaseEnV15 => write!(f, "bge-base-en-v1.5"),
368            EmbeddingModel::BgeLargeEnV15 => write!(f, "bge-large-en-v1.5"),
369            EmbeddingModel::MultilingualE5Small => write!(f, "multilingual-e5-small"),
370            EmbeddingModel::MultilingualE5Base => write!(f, "multilingual-e5-base"),
371            EmbeddingModel::Qwen3Embedding0_6B => write!(f, "qwen3-embedding-0.6b"),
372            EmbeddingModel::Qwen3Embedding4B => write!(f, "qwen3-embedding-4b"),
373            EmbeddingModel::AllMiniLmL6V2 => write!(f, "all-minilm-l6-v2"),
374            EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => {
375                write!(f, "paraphrase-multilingual-minilm-l12-v2")
376            }
377            EmbeddingModel::TextEmbedding3Small => write!(f, "text-embedding-3-small"),
378        }
379    }
380}
381
382impl std::str::FromStr for EmbeddingModel {
383    type Err = String;
384
385    /// **Stable**: parse model from string (case-insensitive, flexible matching).
386    ///
387    /// Accepts:
388    /// - Display names: "bge-small-en-v1.5"
389    /// - Short names: "bge-small", "small"
390    /// - HuggingFace IDs: "BAAI/bge-small-en-v1.5"
391    fn from_str(s: &str) -> Result<Self, Self::Err> {
392        let lower = s.to_lowercase();
393        let normalized = lower.trim().replace("_", "-").replace("baai/", "");
394
395        match normalized.as_str() {
396            "bge-small-en-v1.5" | "bge-small-en" | "bge-small" | "small" => {
397                Ok(EmbeddingModel::BgeSmallEnV15)
398            }
399            "bge-base-en-v1.5" | "bge-base-en" | "bge-base" | "base" => {
400                Ok(EmbeddingModel::BgeBaseEnV15)
401            }
402            "bge-large-en-v1.5" | "bge-large-en" | "bge-large" | "large" => {
403                Ok(EmbeddingModel::BgeLargeEnV15)
404            }
405            "multilingual-e5-small" | "e5-small" | "intfloat/multilingual-e5-small" => {
406                Ok(EmbeddingModel::MultilingualE5Small)
407            }
408            "multilingual-e5-base" | "e5-base" | "intfloat/multilingual-e5-base" => {
409                Ok(EmbeddingModel::MultilingualE5Base)
410            }
411            "qwen3-embedding-0.6b" | "qwen3-embedding" | "qwen3" | "qwen/qwen3-embedding-0.6b" => {
412                Ok(EmbeddingModel::Qwen3Embedding0_6B)
413            }
414            "qwen3-embedding-4b" | "qwen3-4b" | "qwen/qwen3-embedding-4b" => {
415                Ok(EmbeddingModel::Qwen3Embedding4B)
416            }
417            "all-minilm-l6-v2"
418            | "minilm"
419            | "all-minilm"
420            | "sentence-transformers/all-minilm-l6-v2" => Ok(EmbeddingModel::AllMiniLmL6V2),
421            "paraphrase-multilingual-minilm-l12-v2"
422            | "paraphrase-multilingual"
423            | "multilingual-minilm"
424            | "sentence-transformers/paraphrase-multilingual-minilm-l12-v2" => {
425                Ok(EmbeddingModel::ParaphraseMultilingualMiniLmL12V2)
426            }
427            "text-embedding-3-small" | "openai-small" | "openai" => {
428                Ok(EmbeddingModel::TextEmbedding3Small)
429            }
430            _ => Err(format!(
431                "unknown embedding model: '{s}'. Valid: bge-small-en-v1.5, bge-base-en-v1.5, bge-large-en-v1.5, multilingual-e5-small, multilingual-e5-base, text-embedding-3-small"
432            )),
433        }
434    }
435}
436
437// ============================================================================
438// ModelConfig — runtime MRL dimension configuration
439// ============================================================================
440
/// Smallest MRL truncation dimension that `ModelConfig::validate` accepts.
///
/// A configured `output_dim` below this value is rejected.
pub const MIN_MRL_OUTPUT_DIM: usize = 32;
443
444/// Runtime configuration pairing a model with an optional MRL truncation dimension.
445///
446/// Two `ModelConfig` values with different `output_dim` produce different embedding spaces
447/// and must be stored in separate vector index namespaces.
448#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
449pub struct ModelConfig {
450    /// The underlying embedding model.
451    pub model: EmbeddingModel,
452    /// MRL truncation dimension. `None` uses the model's native dimension.
453    #[serde(default)]
454    pub output_dim: Option<usize>,
455}
456
457impl Default for ModelConfig {
458    fn default() -> Self {
459        Self::new(EmbeddingModel::default())
460    }
461}
462
463impl ModelConfig {
464    /// Create a config with no MRL truncation (native model dimension).
465    pub const fn new(model: EmbeddingModel) -> Self {
466        Self {
467            model,
468            output_dim: None,
469        }
470    }
471
472    /// Create and validate a config with an optional MRL truncation dimension.
473    pub fn try_new(
474        model: EmbeddingModel,
475        output_dim: Option<usize>,
476    ) -> std::result::Result<Self, crate::error::EmbedError> {
477        let config = Self { model, output_dim };
478        config.validate()?;
479        Ok(config)
480    }
481
482    /// Validate that the output dimension is consistent with the model.
483    pub fn validate(&self) -> std::result::Result<(), crate::error::EmbedError> {
484        let Some(dim) = self.output_dim else {
485            return Ok(());
486        };
487        if !self.model.supports_output_dim() {
488            return Err(crate::error::EmbedError::InvalidInput(format!(
489                "{} does not support configurable embedding dimensions",
490                self.model
491            )));
492        }
493        if dim < MIN_MRL_OUTPUT_DIM {
494            return Err(crate::error::EmbedError::InvalidInput(format!(
495                "embedding output dimension {dim} is below minimum {MIN_MRL_OUTPUT_DIM}"
496            )));
497        }
498        let native = self.model.native_dimensions();
499        if dim > native {
500            return Err(crate::error::EmbedError::InvalidInput(format!(
501                "embedding output dimension {dim} exceeds native dimension {native} for {}",
502                self.model
503            )));
504        }
505        Ok(())
506    }
507
508    /// Active output dimension: configured truncation if set, otherwise the model's native dimension.
509    pub fn dimensions(&self) -> usize {
510        self.output_dim
511            .unwrap_or_else(|| self.model.native_dimensions())
512    }
513}
514
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `EmbeddingModel` variant, in declaration order.
    ///
    /// Keep in sync with the enum: the all-variant sweeps below only cover what
    /// is listed here.
    const ALL_MODELS: [EmbeddingModel; 10] = [
        EmbeddingModel::BgeSmallEnV15,
        EmbeddingModel::BgeBaseEnV15,
        EmbeddingModel::BgeLargeEnV15,
        EmbeddingModel::MultilingualE5Small,
        EmbeddingModel::MultilingualE5Base,
        EmbeddingModel::Qwen3Embedding0_6B,
        EmbeddingModel::Qwen3Embedding4B,
        EmbeddingModel::AllMiniLmL6V2,
        EmbeddingModel::ParaphraseMultilingualMiniLmL12V2,
        EmbeddingModel::TextEmbedding3Small,
    ];

    #[test]
    fn test_default_model() {
        let model = EmbeddingModel::default();
        assert_eq!(model, EmbeddingModel::BgeSmallEnV15);
    }

    #[test]
    fn test_model_provenance_new() {
        let provenance = ModelProvenance::new(
            EmbeddingModel::BgeSmallEnV15,
            "BAAI/bge-small-en-v1.5".into(),
        );

        assert_eq!(provenance.model, EmbeddingModel::BgeSmallEnV15);
        assert_eq!(provenance.model_id, "BAAI/bge-small-en-v1.5");
        assert!(!provenance.hash.is_empty());
        assert_eq!(provenance.hash.len(), 64); // blake3 hex is 64 chars
        assert!(!provenance.loaded_at_iso.is_empty());
    }

    #[test]
    fn test_model_provenance_unique_hash() {
        let p1 = ModelProvenance::new(EmbeddingModel::BgeSmallEnV15, "model1".into());
        std::thread::sleep(std::time::Duration::from_millis(10)); // Ensure different timestamp
        let p2 = ModelProvenance::new(EmbeddingModel::BgeSmallEnV15, "model1".into());

        // Different timestamps should produce different hashes
        assert_ne!(p1.hash, p2.hash);
    }

    #[test]
    fn test_model_provenance_dimensions() {
        let p1 = ModelProvenance::new(EmbeddingModel::BgeSmallEnV15, "small".into());
        assert_eq!(p1.dimensions(), 384);

        let p2 = ModelProvenance::new(EmbeddingModel::BgeBaseEnV15, "base".into());
        assert_eq!(p2.dimensions(), 768);

        let p3 = ModelProvenance::new(EmbeddingModel::BgeLargeEnV15, "large".into());
        assert_eq!(p3.dimensions(), 1024);
    }

    #[test]
    fn test_model_provenance_matches_model() {
        let provenance = ModelProvenance::new(EmbeddingModel::BgeSmallEnV15, "test".into());

        assert!(provenance.matches_model(EmbeddingModel::BgeSmallEnV15));
        assert!(!provenance.matches_model(EmbeddingModel::BgeBaseEnV15));
        assert!(!provenance.matches_model(EmbeddingModel::BgeLargeEnV15));
    }

    #[test]
    fn test_model_provenance_serialization() {
        let provenance = ModelProvenance::new(EmbeddingModel::BgeSmallEnV15, "test-model".into());

        let json = serde_json::to_string(&provenance).unwrap();
        // FP-037: EmbeddingModel has #[serde(rename_all = "snake_case")] so
        // BgeSmallEnV15 serializes as "bge_small_en_v15", not "BgeSmallEnV15".
        assert!(json.contains("bge_small_en_v15"), "json={json}");
        assert!(json.contains("test-model"));
        assert!(json.contains(&provenance.hash));

        let parsed: ModelProvenance = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.model, provenance.model);
        assert_eq!(parsed.model_id, provenance.model_id);
        assert_eq!(parsed.hash, provenance.hash);
    }

    #[test]
    fn test_dimensions() {
        assert_eq!(EmbeddingModel::BgeSmallEnV15.dimensions(), 384);
        assert_eq!(EmbeddingModel::BgeBaseEnV15.dimensions(), 768);
        assert_eq!(EmbeddingModel::BgeLargeEnV15.dimensions(), 1024);
        assert_eq!(EmbeddingModel::Qwen3Embedding4B.dimensions(), 2560);
    }

    #[test]
    fn test_model_config_native_dims() {
        assert_eq!(
            ModelConfig::new(EmbeddingModel::Qwen3Embedding4B).dimensions(),
            2560
        );
        assert_eq!(
            ModelConfig::new(EmbeddingModel::Qwen3Embedding0_6B).dimensions(),
            1024
        );
        assert_eq!(
            ModelConfig::new(EmbeddingModel::BgeSmallEnV15).dimensions(),
            384
        );
    }

    #[test]
    fn test_model_config_configured_dim() {
        let cfg = ModelConfig::try_new(EmbeddingModel::Qwen3Embedding4B, Some(1024)).unwrap();
        assert_eq!(cfg.dimensions(), 1024);

        let cfg = ModelConfig::try_new(EmbeddingModel::Qwen3Embedding0_6B, Some(512)).unwrap();
        assert_eq!(cfg.dimensions(), 512);
    }

    #[test]
    fn test_model_config_validation_below_min() {
        assert!(ModelConfig::try_new(EmbeddingModel::Qwen3Embedding4B, Some(31)).is_err());
        assert!(ModelConfig::try_new(EmbeddingModel::Qwen3Embedding4B, Some(0)).is_err());
    }

    #[test]
    fn test_model_config_validation_above_native() {
        assert!(ModelConfig::try_new(EmbeddingModel::Qwen3Embedding4B, Some(2561)).is_err());
        assert!(ModelConfig::try_new(EmbeddingModel::Qwen3Embedding0_6B, Some(1025)).is_err());
    }

    #[test]
    fn test_model_config_validation_non_mrl_model() {
        assert!(ModelConfig::try_new(EmbeddingModel::BgeSmallEnV15, Some(128)).is_err());
        assert!(ModelConfig::try_new(EmbeddingModel::BgeBaseEnV15, Some(512)).is_err());
    }

    #[test]
    fn test_model_config_none_output_dim_ok_for_any_model() {
        assert!(ModelConfig::try_new(EmbeddingModel::BgeSmallEnV15, None).is_ok());
        assert!(ModelConfig::try_new(EmbeddingModel::Qwen3Embedding4B, None).is_ok());
    }

    #[test]
    fn test_is_local() {
        assert!(EmbeddingModel::BgeSmallEnV15.is_local());
        assert!(EmbeddingModel::BgeBaseEnV15.is_local());
        assert!(EmbeddingModel::BgeLargeEnV15.is_local());
    }

    /// Every model is exactly one of local or remote.
    #[test]
    fn test_local_and_remote_are_mutually_exclusive() {
        for model in ALL_MODELS {
            assert_ne!(
                model.is_local(),
                model.is_remote(),
                "{model:?} must be exactly one of local/remote"
            );
        }
    }

    #[test]
    fn test_display() {
        assert_eq!(
            EmbeddingModel::BgeSmallEnV15.to_string(),
            "bge-small-en-v1.5"
        );
        assert_eq!(EmbeddingModel::BgeBaseEnV15.to_string(), "bge-base-en-v1.5");
        assert_eq!(
            EmbeddingModel::BgeLargeEnV15.to_string(),
            "bge-large-en-v1.5"
        );
    }

    /// The Display form of every variant must parse back to the same variant.
    #[test]
    fn test_display_round_trips_through_from_str() {
        for model in ALL_MODELS {
            let parsed: EmbeddingModel = model
                .to_string()
                .parse()
                .unwrap_or_else(|e| panic!("Display form of {model:?} did not parse: {e}"));
            assert_eq!(parsed, model);
        }
    }

    /// The HuggingFace/provider identifier of every variant must parse back.
    #[test]
    fn test_model_id_round_trips_through_from_str() {
        for model in ALL_MODELS {
            let parsed: EmbeddingModel = model
                .model_id()
                .parse()
                .unwrap_or_else(|e| panic!("model_id of {model:?} did not parse: {e}"));
            assert_eq!(parsed, model);
        }
    }

    #[test]
    fn test_serialization_roundtrip() {
        let model = EmbeddingModel::BgeSmallEnV15;
        let json = serde_json::to_string(&model).unwrap();
        let parsed: EmbeddingModel = serde_json::from_str(&json).unwrap();
        assert_eq!(model, parsed);
    }

    /// Every variant round-trips through serde and also accepts its PascalCase alias.
    #[test]
    fn test_serde_roundtrip_and_pascal_case_alias_all_variants() {
        for model in ALL_MODELS {
            let json = serde_json::to_string(&model).unwrap();
            let parsed: EmbeddingModel = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, model, "snake_case round-trip failed for {model:?}");

            // Debug prints the variant name, which is exactly each serde alias.
            let alias_json = format!("\"{model:?}\"");
            let from_alias: EmbeddingModel = serde_json::from_str(&alias_json).unwrap();
            assert_eq!(from_alias, model, "alias {alias_json} failed");
        }
    }

    #[test]
    fn test_max_input_tokens() {
        assert_eq!(EmbeddingModel::BgeSmallEnV15.max_input_tokens(), 512);
        assert_eq!(EmbeddingModel::BgeBaseEnV15.max_input_tokens(), 512);
        assert_eq!(EmbeddingModel::BgeLargeEnV15.max_input_tokens(), 512);
    }

    #[test]
    fn test_instruction_prefixes() {
        assert_eq!(
            EmbeddingModel::MultilingualE5Small.query_instruction(),
            Some("query: ")
        );
        assert_eq!(
            EmbeddingModel::MultilingualE5Base.document_instruction(),
            Some("passage: ")
        );
        assert!(EmbeddingModel::Qwen3Embedding4B.query_instruction().is_some());
        assert_eq!(EmbeddingModel::Qwen3Embedding4B.document_instruction(), None);
        assert_eq!(EmbeddingModel::BgeSmallEnV15.query_instruction(), None);
        assert_eq!(EmbeddingModel::TextEmbedding3Small.query_instruction(), None);
    }

    #[test]
    fn test_from_str_display_names() {
        assert_eq!(
            "bge-small-en-v1.5".parse::<EmbeddingModel>().unwrap(),
            EmbeddingModel::BgeSmallEnV15
        );
        assert_eq!(
            "bge-base-en-v1.5".parse::<EmbeddingModel>().unwrap(),
            EmbeddingModel::BgeBaseEnV15
        );
        assert_eq!(
            "bge-large-en-v1.5".parse::<EmbeddingModel>().unwrap(),
            EmbeddingModel::BgeLargeEnV15
        );
    }

    #[test]
    fn test_from_str_short_names() {
        assert_eq!(
            "small".parse::<EmbeddingModel>().unwrap(),
            EmbeddingModel::BgeSmallEnV15
        );
        assert_eq!(
            "bge-base".parse::<EmbeddingModel>().unwrap(),
            EmbeddingModel::BgeBaseEnV15
        );
        assert_eq!(
            "LARGE".parse::<EmbeddingModel>().unwrap(), // case insensitive
            EmbeddingModel::BgeLargeEnV15
        );
    }

    #[test]
    fn test_from_str_huggingface_ids() {
        assert_eq!(
            "BAAI/bge-small-en-v1.5".parse::<EmbeddingModel>().unwrap(),
            EmbeddingModel::BgeSmallEnV15
        );
    }

    #[test]
    fn test_from_str_invalid() {
        let result = "unknown-model".parse::<EmbeddingModel>();
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("unknown embedding model"));
    }

    // -------------------------------------------------------------------------
    // bert_pooling() routing tests (P1-E3) — require `native` feature
    // -------------------------------------------------------------------------

    /// BGE small/base/large must use CLS pooling per their HF model cards.
    #[cfg(feature = "native")]
    #[test]
    fn test_bge_models_use_cls_pooling() {
        use lattice_inference::BertPooling;

        assert_eq!(
            EmbeddingModel::BgeSmallEnV15.bert_pooling(),
            Some(BertPooling::CLS),
            "BgeSmallEnV15 must use CLS pooling"
        );
        assert_eq!(
            EmbeddingModel::BgeBaseEnV15.bert_pooling(),
            Some(BertPooling::CLS),
            "BgeBaseEnV15 must use CLS pooling"
        );
        assert_eq!(
            EmbeddingModel::BgeLargeEnV15.bert_pooling(),
            Some(BertPooling::CLS),
            "BgeLargeEnV15 must use CLS pooling"
        );
    }

    /// E5 models must use mean pooling per their HF model cards.
    #[cfg(feature = "native")]
    #[test]
    fn test_e5_models_use_mean_pooling() {
        use lattice_inference::BertPooling;

        assert_eq!(
            EmbeddingModel::MultilingualE5Small.bert_pooling(),
            Some(BertPooling::Mean),
            "MultilingualE5Small must use mean pooling"
        );
        assert_eq!(
            EmbeddingModel::MultilingualE5Base.bert_pooling(),
            Some(BertPooling::Mean),
            "MultilingualE5Base must use mean pooling"
        );
    }

    /// MiniLM models must use mean pooling per sentence-transformers convention.
    #[cfg(feature = "native")]
    #[test]
    fn test_minilm_models_use_mean_pooling() {
        use lattice_inference::BertPooling;

        assert_eq!(
            EmbeddingModel::AllMiniLmL6V2.bert_pooling(),
            Some(BertPooling::Mean),
            "AllMiniLmL6V2 must use mean pooling"
        );
        assert_eq!(
            EmbeddingModel::ParaphraseMultilingualMiniLmL12V2.bert_pooling(),
            Some(BertPooling::Mean),
            "ParaphraseMultilingualMiniLmL12V2 must use mean pooling"
        );
    }

    /// Qwen and remote models return None — they have separate pooling paths.
    #[cfg(feature = "native")]
    #[test]
    fn test_non_bert_models_return_none_pooling() {
        assert_eq!(
            EmbeddingModel::Qwen3Embedding0_6B.bert_pooling(),
            None,
            "Qwen model must return None for bert_pooling()"
        );
        assert_eq!(
            EmbeddingModel::Qwen3Embedding4B.bert_pooling(),
            None,
            "Qwen model must return None for bert_pooling()"
        );
        assert_eq!(
            EmbeddingModel::TextEmbedding3Small.bert_pooling(),
            None,
            "Remote model must return None for bert_pooling()"
        );
    }

    /// BGE and E5 use DIFFERENT pooling strategies — this is the key correctness distinction.
    #[cfg(feature = "native")]
    #[test]
    fn test_bge_and_e5_use_different_pooling() {
        assert_ne!(
            EmbeddingModel::BgeSmallEnV15.bert_pooling(),
            EmbeddingModel::MultilingualE5Small.bert_pooling(),
            "BGE and E5 must use different pooling strategies"
        );
    }
}