Skip to main content

lattice_embed/
model.rs

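//! Embedding model definitions.
//!
//! Provides the `EmbeddingModel` enum for selecting a local or remote embedding
//! model, `ModelProvenance` for load-time audit metadata, and `ModelConfig` for
//! pairing a model with an optional MRL output dimension.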
4
5use serde::{Deserialize, Serialize};
6use std::time::SystemTime;
7
/// **Stable**: external consumers may depend on this; breaking changes require a SemVer bump.
///
/// Model provenance information for security audits.
///
/// Records which model variant was loaded, the source identifier it was loaded
/// from, and when the load happened. The accompanying `hash` fingerprints that
/// metadata only; it is not computed from the model weights, so on its own it
/// does not prove the weights are untampered (see the note on the `hash` field).
///
/// # Example
///
/// ```rust
/// use lattice_embed::{EmbeddingModel, ModelProvenance};
///
/// // Created when a model is loaded
/// let provenance = ModelProvenance::new(
///     EmbeddingModel::BgeSmallEnV15,
///     "BAAI/bge-small-en-v1.5".to_string(),
/// );
///
/// assert!(provenance.model_id.contains("BAAI"));
/// assert!(!provenance.hash.is_empty());
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelProvenance {
    /// **Stable**: model variant that was loaded.
    pub model: EmbeddingModel,
    /// **Stable**: source identifier (HuggingFace ID, URL, or file path).
    pub model_id: String,
    /// **Stable**: Blake3 hash of the model identifier + timestamp for uniqueness.
    ///
    /// `ModelProvenance::new` stores it as a hex string derived from the model
    /// identifier, the load timestamp string and the model variant.
    ///
    /// Note: This is a lightweight hash based on metadata, not a full hash
    /// of model weights (which would be expensive). For full model verification,
    /// use the lattice-inference library's built-in checksum verification.
    pub hash: String,
    /// **Stable**: when the model was loaded.
    pub loaded_at: SystemTime,
    /// **Stable**: formatted timestamp string for convenience.
    ///
    /// `ModelProvenance::new` fills this with the RFC 3339 (UTC) rendering of
    /// `loaded_at`.
    pub loaded_at_iso: String,
}
46
47impl ModelProvenance {
48    /// **Stable**: create new provenance information for a loaded model.
49    pub fn new(model: EmbeddingModel, model_id: String) -> Self {
50        let loaded_at = SystemTime::now();
51        let loaded_at_iso = {
52            let dt: chrono::DateTime<chrono::Utc> = loaded_at.into();
53            dt.to_rfc3339()
54        };
55
56        // Create a lightweight hash from model metadata
57        let hash_input = format!("{model_id}:{loaded_at_iso}:{model:?}");
58        let hash = blake3::hash(hash_input.as_bytes()).to_hex().to_string();
59
60        Self {
61            model,
62            model_id,
63            hash,
64            loaded_at,
65            loaded_at_iso,
66        }
67    }
68
69    /// **Stable**: get the model dimensions.
70    pub fn dimensions(&self) -> usize {
71        self.model.dimensions()
72    }
73
74    /// **Stable**: check if this provenance matches expected model.
75    pub fn matches_model(&self, expected: EmbeddingModel) -> bool {
76        self.model == expected
77    }
78}
79
/// **Stable**: external consumers may depend on this; breaking changes require a SemVer bump.
///
/// Supported embedding models.
///
/// This enum represents the embedding models available for text vectorization.
/// Models are categorized as either local (run on-device via lattice-inference) or
/// remote (require API calls).
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
///
/// # Serialization
///
/// Variants serialize under serde's `snake_case` renaming (for example
/// `BgeSmallEnV15` is written as `"bge_small_en_v15"`). Each variant also
/// declares its PascalCase name as a serde `alias`, so that spelling is
/// accepted when deserializing.
///
/// # Example
///
/// ```rust
/// use lattice_embed::EmbeddingModel;
///
/// let model = EmbeddingModel::default();
/// assert_eq!(model.dimensions(), 384);
/// assert!(model.is_local());
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum EmbeddingModel {
    /// BGE small English v1.5 (384 dimensions) - fast and efficient.
    ///
    /// This is the `Default` variant.
    #[default]
    #[serde(alias = "BgeSmallEnV15")]
    BgeSmallEnV15,

    /// BGE base English v1.5 (768 dimensions) - balanced quality/speed.
    #[serde(alias = "BgeBaseEnV15")]
    BgeBaseEnV15,

    /// BGE large English v1.5 (1024 dimensions) - highest quality local.
    #[serde(alias = "BgeLargeEnV15")]
    BgeLargeEnV15,

    /// Multilingual E5 small (384 dimensions) - multilingual, same arch as BGE.
    #[serde(alias = "MultilingualE5Small")]
    MultilingualE5Small,

    /// Multilingual E5 base (768 dimensions) - best multilingual quality/speed.
    #[serde(alias = "MultilingualE5Base")]
    MultilingualE5Base,

    /// Qwen3-Embedding-0.6B (1024 dimensions) - multilingual, decoder-only, GPU-accelerated.
    #[serde(alias = "Qwen3Embedding0_6B")]
    Qwen3Embedding0_6B,

    /// Qwen3-Embedding-4B (2560 dimensions, MRL-capable) - multilingual, decoder-only, GPU-accelerated.
    #[serde(alias = "Qwen3Embedding4B")]
    Qwen3Embedding4B,

    /// all-MiniLM-L6-v2 (384 dimensions) - BERT-class, WordPiece tokenizer, sentence-transformers.
    #[serde(alias = "AllMiniLmL6V2")]
    AllMiniLmL6V2,

    /// paraphrase-multilingual-MiniLM-L12-v2 (384 dimensions) - multilingual, XLM-R base, sentence-transformers.
    #[serde(alias = "ParaphraseMultilingualMiniLmL12V2")]
    ParaphraseMultilingualMiniLmL12V2,

    /// OpenAI text-embedding-3-small (1536 dimensions) - remote API.
    #[serde(alias = "TextEmbedding3Small")]
    TextEmbedding3Small,
}
142
143impl EmbeddingModel {
144    /// **Stable**: get the native (full-resolution) output dimension of this model's embeddings.
145    ///
146    /// Returns the model's intrinsic dimension regardless of any MRL truncation.
147    /// For MRL-capable models with a configured truncation, use `ModelConfig::dimensions()`.
148    #[inline]
149    pub const fn native_dimensions(&self) -> usize {
150        match self {
151            EmbeddingModel::BgeSmallEnV15
152            | EmbeddingModel::MultilingualE5Small
153            | EmbeddingModel::AllMiniLmL6V2
154            | EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => 384,
155            EmbeddingModel::BgeBaseEnV15 | EmbeddingModel::MultilingualE5Base => 768,
156            EmbeddingModel::BgeLargeEnV15 | EmbeddingModel::Qwen3Embedding0_6B => 1024,
157            EmbeddingModel::Qwen3Embedding4B => 2560,
158            EmbeddingModel::TextEmbedding3Small => 1536,
159        }
160    }
161
162    /// **Stable**: get the output dimension of this model's embeddings.
163    ///
164    /// # Example
165    ///
166    /// ```rust
167    /// use lattice_embed::EmbeddingModel;
168    ///
169    /// assert_eq!(EmbeddingModel::BgeSmallEnV15.dimensions(), 384);
170    /// assert_eq!(EmbeddingModel::BgeBaseEnV15.dimensions(), 768);
171    /// assert_eq!(EmbeddingModel::BgeLargeEnV15.dimensions(), 1024);
172    /// ```
173    #[inline]
174    pub const fn dimensions(&self) -> usize {
175        self.native_dimensions()
176    }
177
178    /// **Stable**: check if this model can run locally (via lattice-inference).
179    #[inline]
180    pub const fn is_local(&self) -> bool {
181        matches!(
182            self,
183            EmbeddingModel::BgeSmallEnV15
184                | EmbeddingModel::BgeBaseEnV15
185                | EmbeddingModel::BgeLargeEnV15
186                | EmbeddingModel::MultilingualE5Small
187                | EmbeddingModel::MultilingualE5Base
188                | EmbeddingModel::AllMiniLmL6V2
189                | EmbeddingModel::ParaphraseMultilingualMiniLmL12V2
190                | EmbeddingModel::Qwen3Embedding0_6B
191                | EmbeddingModel::Qwen3Embedding4B
192        )
193    }
194
195    /// **Stable**: check if this model requires a remote API.
196    #[inline]
197    pub const fn is_remote(&self) -> bool {
198        matches!(self, EmbeddingModel::TextEmbedding3Small)
199    }
200
201    /// **Stable**: maximum input tokens supported by this model.
202    ///
203    /// Use this for chunking/truncation decisions. Values are conservative
204    /// to leave room for special tokens.
205    ///
206    /// Reference limits:
207    /// - BGE models: 512 tokens
208    /// - OpenAI text-embedding-3: 8191 tokens
209    /// - Gemini embedding-001: 20000 tokens
210    #[inline]
211    pub const fn max_input_tokens(&self) -> usize {
212        match self {
213            // BGE models have 512 token limit
214            EmbeddingModel::BgeSmallEnV15 => 512,
215            EmbeddingModel::BgeBaseEnV15 => 512,
216            EmbeddingModel::BgeLargeEnV15 => 512,
217            // E5 models have 512 token limit
218            EmbeddingModel::MultilingualE5Small => 512,
219            EmbeddingModel::MultilingualE5Base => 512,
220            // MiniLM has a shorter context window
221            EmbeddingModel::AllMiniLmL6V2 => 256,
222            // paraphrase-multilingual-MiniLM max sequence length 128
223            EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => 128,
224            // Qwen3-Embedding supports 32K but we cap at 8192 for practical use
225            EmbeddingModel::Qwen3Embedding0_6B => 8192,
226            EmbeddingModel::Qwen3Embedding4B => 8192,
227            // OpenAI text-embedding-3-small has 8191 token limit
228            EmbeddingModel::TextEmbedding3Small => 8191,
229        }
230    }
231
232    /// **Stable**: query instruction prefix for asymmetric retrieval.
233    ///
234    /// Some models require different text for queries vs documents (asymmetric retrieval).
235    ///
236    /// - **E5 models** (`MultilingualE5Small`, `MultilingualE5Base`): trained with
237    ///   "query: " / "passage: " asymmetric prefixes. Omitting the prefix degrades
238    ///   retrieval quality significantly — the model expects them during fine-tuning.
239    ///
240    /// - **Qwen3-Embedding** models: require an instruction prompt to align the
241    ///   decoder embedding space for retrieval tasks.
242    ///
243    /// - **BGE / MiniLM** models: trained with contrastive objectives on raw text;
244    ///   no prefix needed.
245    ///
246    /// Returns `Some(prefix)` if the query text should be wrapped as
247    /// `"{prefix}{query}"` before embedding. Returns `None` for models that
248    /// don't need instruction prompting.
249    #[inline]
250    pub const fn query_instruction(&self) -> Option<&'static str> {
251        match self {
252            EmbeddingModel::MultilingualE5Small | EmbeddingModel::MultilingualE5Base => {
253                // E5 asymmetric retrieval: "query: " prefix for queries,
254                // "passage: " prefix for documents (see document_instruction()).
255                Some("query: ")
256            }
257            EmbeddingModel::Qwen3Embedding0_6B | EmbeddingModel::Qwen3Embedding4B => Some(
258                "Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery: ",
259            ),
260            _ => None,
261        }
262    }
263
264    /// **Stable**: document instruction prefix for asymmetric retrieval.
265    ///
266    /// Some models use different prompts for documents vs queries.
267    /// Returns `Some(prefix)` if the document text should be wrapped as
268    /// `"{prefix}{text}"` before embedding at storage time.
269    #[inline]
270    pub const fn document_instruction(&self) -> Option<&'static str> {
271        None
272    }
273
274    /// **Stable**: get the model identifier (HuggingFace ID or provider/model).
275    #[inline]
276    pub const fn model_id(&self) -> &'static str {
277        match self {
278            EmbeddingModel::BgeSmallEnV15 => "BAAI/bge-small-en-v1.5",
279            EmbeddingModel::BgeBaseEnV15 => "BAAI/bge-base-en-v1.5",
280            EmbeddingModel::BgeLargeEnV15 => "BAAI/bge-large-en-v1.5",
281            EmbeddingModel::MultilingualE5Small => "intfloat/multilingual-e5-small",
282            EmbeddingModel::MultilingualE5Base => "intfloat/multilingual-e5-base",
283            EmbeddingModel::AllMiniLmL6V2 => "sentence-transformers/all-MiniLM-L6-v2",
284            EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => {
285                "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
286            }
287            EmbeddingModel::Qwen3Embedding0_6B => "Qwen/Qwen3-Embedding-0.6B",
288            EmbeddingModel::Qwen3Embedding4B => "Qwen/Qwen3-Embedding-4B",
289            EmbeddingModel::TextEmbedding3Small => "text-embedding-3-small",
290        }
291    }
292
293    /// **Stable**: whether this model supports configurable output dimensions (MRL/Matryoshka).
294    #[inline]
295    pub const fn supports_output_dim(&self) -> bool {
296        matches!(
297            self,
298            EmbeddingModel::Qwen3Embedding0_6B | EmbeddingModel::Qwen3Embedding4B
299        )
300    }
301
302    /// **Stable**: embedding key revision string for this model family.
303    #[inline]
304    pub const fn key_version(&self) -> &'static str {
305        match self {
306            EmbeddingModel::TextEmbedding3Small
307            | EmbeddingModel::Qwen3Embedding0_6B
308            | EmbeddingModel::Qwen3Embedding4B => "v3",
309            EmbeddingModel::AllMiniLmL6V2 | EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => {
310                "v2"
311            }
312            _ => "v1.5",
313        }
314    }
315}
316
317impl std::fmt::Display for EmbeddingModel {
318    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
319        match self {
320            EmbeddingModel::BgeSmallEnV15 => write!(f, "bge-small-en-v1.5"),
321            EmbeddingModel::BgeBaseEnV15 => write!(f, "bge-base-en-v1.5"),
322            EmbeddingModel::BgeLargeEnV15 => write!(f, "bge-large-en-v1.5"),
323            EmbeddingModel::MultilingualE5Small => write!(f, "multilingual-e5-small"),
324            EmbeddingModel::MultilingualE5Base => write!(f, "multilingual-e5-base"),
325            EmbeddingModel::Qwen3Embedding0_6B => write!(f, "qwen3-embedding-0.6b"),
326            EmbeddingModel::Qwen3Embedding4B => write!(f, "qwen3-embedding-4b"),
327            EmbeddingModel::AllMiniLmL6V2 => write!(f, "all-minilm-l6-v2"),
328            EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => {
329                write!(f, "paraphrase-multilingual-minilm-l12-v2")
330            }
331            EmbeddingModel::TextEmbedding3Small => write!(f, "text-embedding-3-small"),
332        }
333    }
334}
335
336impl std::str::FromStr for EmbeddingModel {
337    type Err = String;
338
339    /// **Stable**: parse model from string (case-insensitive, flexible matching).
340    ///
341    /// Accepts:
342    /// - Display names: "bge-small-en-v1.5"
343    /// - Short names: "bge-small", "small"
344    /// - HuggingFace IDs: "BAAI/bge-small-en-v1.5"
345    fn from_str(s: &str) -> Result<Self, Self::Err> {
346        let lower = s.to_lowercase();
347        let normalized = lower.trim().replace("_", "-").replace("baai/", "");
348
349        match normalized.as_str() {
350            "bge-small-en-v1.5" | "bge-small-en" | "bge-small" | "small" => {
351                Ok(EmbeddingModel::BgeSmallEnV15)
352            }
353            "bge-base-en-v1.5" | "bge-base-en" | "bge-base" | "base" => {
354                Ok(EmbeddingModel::BgeBaseEnV15)
355            }
356            "bge-large-en-v1.5" | "bge-large-en" | "bge-large" | "large" => {
357                Ok(EmbeddingModel::BgeLargeEnV15)
358            }
359            "multilingual-e5-small" | "e5-small" | "intfloat/multilingual-e5-small" => {
360                Ok(EmbeddingModel::MultilingualE5Small)
361            }
362            "multilingual-e5-base" | "e5-base" | "intfloat/multilingual-e5-base" => {
363                Ok(EmbeddingModel::MultilingualE5Base)
364            }
365            "qwen3-embedding-0.6b" | "qwen3-embedding" | "qwen3" | "qwen/qwen3-embedding-0.6b" => {
366                Ok(EmbeddingModel::Qwen3Embedding0_6B)
367            }
368            "qwen3-embedding-4b" | "qwen3-4b" | "qwen/qwen3-embedding-4b" => {
369                Ok(EmbeddingModel::Qwen3Embedding4B)
370            }
371            "all-minilm-l6-v2"
372            | "minilm"
373            | "all-minilm"
374            | "sentence-transformers/all-minilm-l6-v2" => Ok(EmbeddingModel::AllMiniLmL6V2),
375            "paraphrase-multilingual-minilm-l12-v2"
376            | "paraphrase-multilingual"
377            | "multilingual-minilm"
378            | "sentence-transformers/paraphrase-multilingual-minilm-l12-v2" => {
379                Ok(EmbeddingModel::ParaphraseMultilingualMiniLmL12V2)
380            }
381            "text-embedding-3-small" | "openai-small" | "openai" => {
382                Ok(EmbeddingModel::TextEmbedding3Small)
383            }
384            _ => Err(format!(
385                "unknown embedding model: '{s}'. Valid: bge-small-en-v1.5, bge-base-en-v1.5, bge-large-en-v1.5, multilingual-e5-small, multilingual-e5-base, text-embedding-3-small"
386            )),
387        }
388    }
389}
390
391// ============================================================================
392// ModelConfig — runtime MRL dimension configuration
393// ============================================================================
394
/// Minimum allowed MRL output dimension.
///
/// `ModelConfig::validate` rejects any configured `output_dim` below this value.
pub const MIN_MRL_OUTPUT_DIM: usize = 32;
397
/// Runtime configuration pairing a model with an optional MRL truncation dimension.
///
/// Two `ModelConfig` values with different `output_dim` produce different embedding spaces
/// and must be stored in separate vector index namespaces.
///
/// Both fields are public and the serde derives perform no checking, so a value
/// built with a struct literal or obtained through deserialization has not been
/// through `validate`. Call `validate` (or construct via `try_new`) when the
/// dimension comes from configuration or other external input.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ModelConfig {
    /// The underlying embedding model.
    pub model: EmbeddingModel,
    /// MRL truncation dimension. `None` uses the model's native dimension.
    ///
    /// `#[serde(default)]` makes the field optional in serialized input; when it
    /// is absent the value is `None`.
    #[serde(default)]
    pub output_dim: Option<usize>,
}
410
411impl Default for ModelConfig {
412    fn default() -> Self {
413        Self::new(EmbeddingModel::default())
414    }
415}
416
417impl ModelConfig {
418    /// Create a config with no MRL truncation (native model dimension).
419    pub const fn new(model: EmbeddingModel) -> Self {
420        Self {
421            model,
422            output_dim: None,
423        }
424    }
425
426    /// Create and validate a config with an optional MRL truncation dimension.
427    pub fn try_new(
428        model: EmbeddingModel,
429        output_dim: Option<usize>,
430    ) -> std::result::Result<Self, crate::error::EmbedError> {
431        let config = Self { model, output_dim };
432        config.validate()?;
433        Ok(config)
434    }
435
436    /// Validate that the output dimension is consistent with the model.
437    pub fn validate(&self) -> std::result::Result<(), crate::error::EmbedError> {
438        let Some(dim) = self.output_dim else {
439            return Ok(());
440        };
441        if !self.model.supports_output_dim() {
442            return Err(crate::error::EmbedError::InvalidInput(format!(
443                "{} does not support configurable embedding dimensions",
444                self.model
445            )));
446        }
447        if dim < MIN_MRL_OUTPUT_DIM {
448            return Err(crate::error::EmbedError::InvalidInput(format!(
449                "embedding output dimension {dim} is below minimum {MIN_MRL_OUTPUT_DIM}"
450            )));
451        }
452        let native = self.model.native_dimensions();
453        if dim > native {
454            return Err(crate::error::EmbedError::InvalidInput(format!(
455                "embedding output dimension {dim} exceeds native dimension {native} for {}",
456                self.model
457            )));
458        }
459        Ok(())
460    }
461
462    /// Active output dimension: configured truncation if set, otherwise the model's native dimension.
463    pub fn dimensions(&self) -> usize {
464        self.output_dim
465            .unwrap_or_else(|| self.model.native_dimensions())
466    }
467}
468
#[cfg(test)]
mod tests {
    use super::*;

    /// The `Default` variant is the small BGE model.
    #[test]
    fn test_default_model() {
        assert_eq!(EmbeddingModel::default(), EmbeddingModel::BgeSmallEnV15);
    }

    /// `ModelProvenance::new` fills every field, including a 64-char hex hash.
    #[test]
    fn test_model_provenance_new() {
        let source_id = "BAAI/bge-small-en-v1.5";
        let record = ModelProvenance::new(EmbeddingModel::BgeSmallEnV15, source_id.to_string());

        assert_eq!(record.model, EmbeddingModel::BgeSmallEnV15);
        assert_eq!(record.model_id, source_id);
        assert!(!record.hash.is_empty());
        assert_eq!(record.hash.len(), 64); // blake3 hex is 64 chars
        assert!(!record.loaded_at_iso.is_empty());
    }

    /// Two records for the same model created at different times hash differently.
    #[test]
    fn test_model_provenance_unique_hash() {
        let make_record = || ModelProvenance::new(EmbeddingModel::BgeSmallEnV15, "model1".into());

        let first = make_record();
        std::thread::sleep(std::time::Duration::from_millis(10)); // Ensure different timestamp
        let second = make_record();

        // Different timestamps should produce different hashes
        assert_ne!(first.hash, second.hash);
    }

    /// Provenance reports the dimension of the model it wraps.
    #[test]
    fn test_model_provenance_dimensions() {
        let cases = [
            (EmbeddingModel::BgeSmallEnV15, "small", 384),
            (EmbeddingModel::BgeBaseEnV15, "base", 768),
            (EmbeddingModel::BgeLargeEnV15, "large", 1024),
        ];

        for (model, id, expected) in cases {
            let record = ModelProvenance::new(model, id.into());
            assert_eq!(record.dimensions(), expected);
        }
    }

    /// `matches_model` is true only for the recorded variant.
    #[test]
    fn test_model_provenance_matches_model() {
        let record = ModelProvenance::new(EmbeddingModel::BgeSmallEnV15, "test".into());

        let cases = [
            (EmbeddingModel::BgeSmallEnV15, true),
            (EmbeddingModel::BgeBaseEnV15, false),
            (EmbeddingModel::BgeLargeEnV15, false),
        ];
        for (candidate, should_match) in cases {
            assert_eq!(record.matches_model(candidate), should_match);
        }
    }

    /// Provenance survives a JSON round trip with model, id and hash intact.
    #[test]
    fn test_model_provenance_serialization() {
        let original = ModelProvenance::new(EmbeddingModel::BgeSmallEnV15, "test-model".into());

        let json = serde_json::to_string(&original).unwrap();
        // FP-037: EmbeddingModel has #[serde(rename_all = "snake_case")] so
        // BgeSmallEnV15 serializes as "bge_small_en_v15", not "BgeSmallEnV15".
        assert!(json.contains("bge_small_en_v15"), "json={json}");
        assert!(json.contains("test-model"));
        assert!(json.contains(&original.hash));

        let restored: ModelProvenance = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.model, original.model);
        assert_eq!(restored.model_id, original.model_id);
        assert_eq!(restored.hash, original.hash);
    }

    /// Spot-check native dimensions per model.
    #[test]
    fn test_dimensions() {
        let cases = [
            (EmbeddingModel::BgeSmallEnV15, 384),
            (EmbeddingModel::BgeBaseEnV15, 768),
            (EmbeddingModel::BgeLargeEnV15, 1024),
            (EmbeddingModel::Qwen3Embedding4B, 2560),
        ];
        for (model, expected) in cases {
            assert_eq!(model.dimensions(), expected);
        }
    }

    /// Without a truncation, a config reports the model's native dimension.
    #[test]
    fn test_model_config_native_dims() {
        let cases = [
            (EmbeddingModel::Qwen3Embedding4B, 2560),
            (EmbeddingModel::Qwen3Embedding0_6B, 1024),
            (EmbeddingModel::BgeSmallEnV15, 384),
        ];
        for (model, expected) in cases {
            assert_eq!(ModelConfig::new(model).dimensions(), expected);
        }
    }

    /// A valid truncation is what `dimensions()` reports.
    #[test]
    fn test_model_config_configured_dim() {
        let cases = [
            (EmbeddingModel::Qwen3Embedding4B, 1024),
            (EmbeddingModel::Qwen3Embedding0_6B, 512),
        ];
        for (model, requested) in cases {
            let cfg = ModelConfig::try_new(model, Some(requested)).unwrap();
            assert_eq!(cfg.dimensions(), requested);
        }
    }

    /// Dimensions under the minimum are rejected.
    #[test]
    fn test_model_config_validation_below_min() {
        for too_small in [31, 0] {
            let outcome = ModelConfig::try_new(EmbeddingModel::Qwen3Embedding4B, Some(too_small));
            assert!(outcome.is_err());
        }
    }

    /// Dimensions above the model's native size are rejected.
    #[test]
    fn test_model_config_validation_above_native() {
        let cases = [
            (EmbeddingModel::Qwen3Embedding4B, 2561),
            (EmbeddingModel::Qwen3Embedding0_6B, 1025),
        ];
        for (model, too_large) in cases {
            assert!(ModelConfig::try_new(model, Some(too_large)).is_err());
        }
    }

    /// Models without MRL support reject any configured dimension.
    #[test]
    fn test_model_config_validation_non_mrl_model() {
        let cases = [
            (EmbeddingModel::BgeSmallEnV15, 128),
            (EmbeddingModel::BgeBaseEnV15, 512),
        ];
        for (model, requested) in cases {
            assert!(ModelConfig::try_new(model, Some(requested)).is_err());
        }
    }

    /// Leaving the dimension unset is valid for every model.
    #[test]
    fn test_model_config_none_output_dim_ok_for_any_model() {
        for model in [
            EmbeddingModel::BgeSmallEnV15,
            EmbeddingModel::Qwen3Embedding4B,
        ] {
            assert!(ModelConfig::try_new(model, None).is_ok());
        }
    }

    /// The BGE models are local.
    #[test]
    fn test_is_local() {
        for model in [
            EmbeddingModel::BgeSmallEnV15,
            EmbeddingModel::BgeBaseEnV15,
            EmbeddingModel::BgeLargeEnV15,
        ] {
            assert!(model.is_local());
        }
    }

    /// Display names for the BGE models.
    #[test]
    fn test_display() {
        let cases = [
            (EmbeddingModel::BgeSmallEnV15, "bge-small-en-v1.5"),
            (EmbeddingModel::BgeBaseEnV15, "bge-base-en-v1.5"),
            (EmbeddingModel::BgeLargeEnV15, "bge-large-en-v1.5"),
        ];
        for (model, expected) in cases {
            assert_eq!(model.to_string(), expected);
        }
    }

    /// A model value survives a JSON round trip.
    #[test]
    fn test_serialization_roundtrip() {
        let original = EmbeddingModel::BgeSmallEnV15;
        let json = serde_json::to_string(&original).unwrap();
        let restored: EmbeddingModel = serde_json::from_str(&json).unwrap();
        assert_eq!(original, restored);
    }

    /// The BGE models share a 512-token limit.
    #[test]
    fn test_max_input_tokens() {
        for model in [
            EmbeddingModel::BgeSmallEnV15,
            EmbeddingModel::BgeBaseEnV15,
            EmbeddingModel::BgeLargeEnV15,
        ] {
            assert_eq!(model.max_input_tokens(), 512);
        }
    }

    /// Display names parse back to their variants.
    #[test]
    fn test_from_str_display_names() {
        let cases = [
            ("bge-small-en-v1.5", EmbeddingModel::BgeSmallEnV15),
            ("bge-base-en-v1.5", EmbeddingModel::BgeBaseEnV15),
            ("bge-large-en-v1.5", EmbeddingModel::BgeLargeEnV15),
        ];
        for (input, expected) in cases {
            assert_eq!(input.parse::<EmbeddingModel>().unwrap(), expected);
        }
    }

    /// Short aliases parse, regardless of letter case.
    #[test]
    fn test_from_str_short_names() {
        let cases = [
            ("small", EmbeddingModel::BgeSmallEnV15),
            ("bge-base", EmbeddingModel::BgeBaseEnV15),
            ("LARGE", EmbeddingModel::BgeLargeEnV15), // case insensitive
        ];
        for (input, expected) in cases {
            assert_eq!(input.parse::<EmbeddingModel>().unwrap(), expected);
        }
    }

    /// A HuggingFace ID with the organization prefix parses.
    #[test]
    fn test_from_str_huggingface_ids() {
        let parsed: EmbeddingModel = "BAAI/bge-small-en-v1.5".parse().unwrap();
        assert_eq!(parsed, EmbeddingModel::BgeSmallEnV15);
    }

    /// Unknown names produce an error that says so.
    #[test]
    fn test_from_str_invalid() {
        let message = "unknown-model"
            .parse::<EmbeddingModel>()
            .expect_err("an unrecognized name must not parse");
        assert!(message.contains("unknown embedding model"));
    }
}