Skip to main content

lean_ctx/core/embeddings/
model_registry.rs

1//! Embedding model registry — model configs, selection, and metadata.
2//!
3//! Supports multiple ONNX embedding models with different dimensions,
4//! tokenizers, and download sources. Models are selected via the
5//! `LEAN_CTX_EMBEDDING_MODEL` env var or config file.
6
7use std::fmt;
8
9/// Supported embedding models.
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
11#[serde(rename_all = "kebab-case")]
12pub enum EmbeddingModel {
13    /// all-MiniLM-L6-v2 — generic sentence embeddings (384d, ~91MB).
14    /// Default model for backward compatibility.
15    AllMiniLmL6V2,
16    /// jina-embeddings-v2-base-code — code-optimized, 30 languages (768d, ~642MB).
17    /// Best for mixed code + natural language search.
18    JinaCodeV2,
19    /// nomic-embed-text-v1.5 — top MTEB general-purpose (768d, ~547MB).
20    /// Matryoshka representation learning, supports dimension truncation.
21    NomicEmbedV1_5,
22}
23
24impl EmbeddingModel {
25    pub const DEFAULT: Self = Self::AllMiniLmL6V2;
26
27    pub fn config(self) -> ModelConfig {
28        match self {
29            Self::AllMiniLmL6V2 => ModelConfig {
30                model: self,
31                name: "all-MiniLM-L6-v2",
32                hf_repo: "sentence-transformers/all-MiniLM-L6-v2",
33                onnx_path: "onnx/model.onnx",
34                vocab_file: VocabSource::VocabTxt("vocab.txt"),
35                dimensions: 384,
36                max_seq_len: 256,
37                model_min_bytes: 1_000_000,
38                vocab_min_bytes: 100_000,
39                query_prefix: None,
40                document_prefix: None,
41                needs_token_type_ids: true,
42            },
43            Self::JinaCodeV2 => ModelConfig {
44                model: self,
45                name: "jina-embeddings-v2-base-code",
46                hf_repo: "jinaai/jina-embeddings-v2-base-code",
47                onnx_path: "onnx/model.onnx",
48                vocab_file: VocabSource::VocabTxt("vocab.txt"),
49                dimensions: 768,
50                max_seq_len: 512,
51                model_min_bytes: 100_000_000,
52                vocab_min_bytes: 100_000,
53                query_prefix: None,
54                document_prefix: None,
55                needs_token_type_ids: true,
56            },
57            Self::NomicEmbedV1_5 => ModelConfig {
58                model: self,
59                name: "nomic-embed-text-v1.5",
60                hf_repo: "nomic-ai/nomic-embed-text-v1.5",
61                onnx_path: "onnx/model.onnx",
62                vocab_file: VocabSource::VocabTxt("vocab.txt"),
63                dimensions: 768,
64                max_seq_len: 512,
65                model_min_bytes: 100_000_000,
66                vocab_min_bytes: 100_000,
67                query_prefix: Some("search_query: "),
68                document_prefix: Some("search_document: "),
69                needs_token_type_ids: false,
70            },
71        }
72    }
73
74    /// Parse model name from string (env var / config file).
75    pub fn from_str_name(s: &str) -> Option<Self> {
76        match s.to_lowercase().replace('_', "-").as_str() {
77            "all-minilm-l6-v2" | "minilm" | "default" => Some(Self::AllMiniLmL6V2),
78            "jina-code-v2" | "jina-embeddings-v2-base-code" | "jina-code" | "jina" => {
79                Some(Self::JinaCodeV2)
80            }
81            "nomic-embed-v1.5" | "nomic-embed-text-v1.5" | "nomic" | "nomic-embed" => {
82                Some(Self::NomicEmbedV1_5)
83            }
84            _ => None,
85        }
86    }
87
88    /// All available model variants.
89    pub const ALL: &'static [Self] = &[Self::AllMiniLmL6V2, Self::JinaCodeV2, Self::NomicEmbedV1_5];
90
91    /// Unique subdirectory name for model storage isolation.
92    pub fn storage_dir_name(self) -> &'static str {
93        match self {
94            Self::AllMiniLmL6V2 => "all-minilm-l6-v2",
95            Self::JinaCodeV2 => "jina-code-v2",
96            Self::NomicEmbedV1_5 => "nomic-embed-v1.5",
97        }
98    }
99}
100
101impl fmt::Display for EmbeddingModel {
102    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
103        f.write_str(self.config().name)
104    }
105}
106
/// Vocabulary/tokenizer source for a model.
///
/// The wrapped string is the file's path inside the model's HuggingFace
/// repository; `ModelConfig::vocab_url` appends it to the repo's download URL.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum VocabSource {
    /// Standard BERT vocab.txt (one token per line, WordPiece).
    VocabTxt(&'static str),
    /// HuggingFace tokenizer.json (BPE/Unigram via JSON config).
    TokenizerJson(&'static str),
}

impl VocabSource {
    /// File name (repo-relative path) of the vocabulary/tokenizer file,
    /// regardless of which format it is.
    pub fn filename(&self) -> &'static str {
        match *self {
            Self::VocabTxt(file) | Self::TokenizerJson(file) => file,
        }
    }

    /// `true` for a WordPiece `vocab.txt` source, `false` for `tokenizer.json`.
    pub fn is_wordpiece(&self) -> bool {
        matches!(self, Self::VocabTxt(_))
    }
}
127
128/// Complete configuration for a single embedding model.
129#[derive(Debug, Clone)]
130pub struct ModelConfig {
131    pub model: EmbeddingModel,
132    pub name: &'static str,
133    pub hf_repo: &'static str,
134    pub onnx_path: &'static str,
135    pub vocab_file: VocabSource,
136    pub dimensions: usize,
137    pub max_seq_len: usize,
138    pub model_min_bytes: u64,
139    pub vocab_min_bytes: u64,
140    /// Optional prefix prepended to queries before embedding.
141    pub query_prefix: Option<&'static str>,
142    /// Optional prefix prepended to documents/code before embedding.
143    pub document_prefix: Option<&'static str>,
144    /// Whether the model expects token_type_ids input (BERT-style).
145    /// Some models (e.g. nomic-embed) only use input_ids + attention_mask.
146    pub needs_token_type_ids: bool,
147}
148
149impl ModelConfig {
150    /// Full HuggingFace download URL for the ONNX model file.
151    pub fn model_url(&self) -> String {
152        format!(
153            "https://huggingface.co/{}/resolve/main/{}",
154            self.hf_repo, self.onnx_path
155        )
156    }
157
158    /// Full HuggingFace download URL for the vocabulary/tokenizer file.
159    pub fn vocab_url(&self) -> String {
160        format!(
161            "https://huggingface.co/{}/resolve/main/{}",
162            self.hf_repo,
163            self.vocab_file.filename()
164        )
165    }
166}
167
168/// Resolve which embedding model to use.
169/// Priority: env var > config > default.
170pub fn resolve_model() -> EmbeddingModel {
171    if let Ok(val) = std::env::var("LEAN_CTX_EMBEDDING_MODEL") {
172        if let Some(model) = EmbeddingModel::from_str_name(&val) {
173            return model;
174        }
175        tracing::warn!(
176            "Unknown LEAN_CTX_EMBEDDING_MODEL={val:?}, falling back to default ({})",
177            EmbeddingModel::DEFAULT
178        );
179    }
180    EmbeddingModel::DEFAULT
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Configuration of every registered model, in `ALL` order.
    fn all_configs() -> impl Iterator<Item = ModelConfig> {
        EmbeddingModel::ALL.iter().map(|model| model.config())
    }

    #[test]
    fn default_model_is_minilm() {
        assert_eq!(EmbeddingModel::DEFAULT, EmbeddingModel::AllMiniLmL6V2);
    }

    #[test]
    fn from_str_name_variants() {
        use EmbeddingModel::{AllMiniLmL6V2, JinaCodeV2, NomicEmbedV1_5};

        let cases = [
            ("minilm", Some(AllMiniLmL6V2)),
            ("jina-code-v2", Some(JinaCodeV2)),
            ("jina-code", Some(JinaCodeV2)),
            ("jina", Some(JinaCodeV2)),
            ("nomic-embed-v1.5", Some(NomicEmbedV1_5)),
            ("nomic", Some(NomicEmbedV1_5)),
            ("default", Some(AllMiniLmL6V2)),
            ("unknown", None),
        ];
        for (input, expected) in cases {
            assert_eq!(
                EmbeddingModel::from_str_name(input),
                expected,
                "input: {input:?}"
            );
        }
    }

    #[test]
    fn all_models_have_valid_configs() {
        for cfg in all_configs() {
            assert!(!cfg.name.is_empty());
            assert!(!cfg.hf_repo.is_empty());
            assert!(cfg.dimensions > 0);
            assert!(cfg.max_seq_len > 0);
            assert!(cfg.model_min_bytes > 0);
            assert!(cfg.vocab_min_bytes > 0);
        }
    }

    #[test]
    fn model_urls_are_valid() {
        const HF_BASE: &str = "https://huggingface.co/";
        for cfg in all_configs() {
            let (model_url, vocab_url) = (cfg.model_url(), cfg.vocab_url());
            assert!(model_url.starts_with(HF_BASE));
            assert!(vocab_url.starts_with(HF_BASE));
            assert!(model_url.contains("resolve/main"));
        }
    }

    #[test]
    fn storage_dir_names_are_unique() {
        let mut seen = HashSet::new();
        for model in EmbeddingModel::ALL {
            assert!(
                seen.insert(model.storage_dir_name()),
                "duplicate storage dir for {model}"
            );
        }
    }

    #[test]
    fn display_uses_model_name() {
        assert_eq!(EmbeddingModel::AllMiniLmL6V2.to_string(), "all-MiniLM-L6-v2");
        assert_eq!(
            EmbeddingModel::JinaCodeV2.to_string(),
            "jina-embeddings-v2-base-code"
        );
    }

    #[test]
    fn resolve_model_default() {
        std::env::remove_var("LEAN_CTX_EMBEDDING_MODEL");
        assert_eq!(resolve_model(), EmbeddingModel::DEFAULT);
    }

    #[test]
    fn jina_code_v2_config_details() {
        let ModelConfig {
            dimensions,
            needs_token_type_ids,
            query_prefix,
            ..
        } = EmbeddingModel::JinaCodeV2.config();
        assert_eq!(dimensions, 768);
        assert!(needs_token_type_ids);
        assert!(query_prefix.is_none());
    }

    #[test]
    fn nomic_has_prefixes() {
        let ModelConfig {
            query_prefix,
            document_prefix,
            needs_token_type_ids,
            ..
        } = EmbeddingModel::NomicEmbedV1_5.config();
        assert!(query_prefix.is_some());
        assert!(document_prefix.is_some());
        assert!(!needs_token_type_ids);
    }

    #[test]
    fn minilm_is_wordpiece() {
        assert!(EmbeddingModel::AllMiniLmL6V2
            .config()
            .vocab_file
            .is_wordpiece());
    }

    #[test]
    fn all_current_models_use_wordpiece() {
        for model in EmbeddingModel::ALL {
            assert!(
                model.config().vocab_file.is_wordpiece(),
                "{model} should use WordPiece vocab.txt"
            );
        }
    }
}