Skip to main content

lean_ctx/core/embeddings/
model_registry.rs

//! Embedding model registry — model configs, selection, and metadata.
//!
//! Supports multiple ONNX embedding models with different dimensions,
//! tokenizers, and download sources. Models are selected via the
//! `LEAN_CTX_EMBEDDING_MODEL` env var or the `[embedding].model` key in `config.toml`
//! (env var wins) — see [`resolve_model`].
7
8use std::fmt;
9
10/// Supported embedding models.
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
12#[serde(rename_all = "kebab-case")]
13pub enum EmbeddingModel {
14    /// all-MiniLM-L6-v2 — generic sentence embeddings (384d, ~91MB).
15    /// Default model for backward compatibility.
16    AllMiniLmL6V2,
17    /// jina-embeddings-v2-base-code — code-optimized, 30 languages (768d, ~642MB).
18    /// Best for mixed code + natural language search.
19    JinaCodeV2,
20    /// nomic-embed-text-v1.5 — top MTEB general-purpose (768d, ~547MB).
21    /// Matryoshka representation learning, supports dimension truncation.
22    NomicEmbedV1_5,
23}
24
25impl EmbeddingModel {
26    pub const DEFAULT: Self = Self::AllMiniLmL6V2;
27
28    pub fn config(self) -> ModelConfig {
29        match self {
30            Self::AllMiniLmL6V2 => ModelConfig {
31                model: self,
32                name: "all-MiniLM-L6-v2",
33                hf_repo: "sentence-transformers/all-MiniLM-L6-v2",
34                onnx_path: "onnx/model.onnx",
35                vocab_file: VocabSource::VocabTxt("vocab.txt"),
36                dimensions: 384,
37                max_seq_len: 256,
38                model_min_bytes: 1_000_000,
39                vocab_min_bytes: 100_000,
40                query_prefix: None,
41                document_prefix: None,
42                needs_token_type_ids: true,
43            },
44            Self::JinaCodeV2 => ModelConfig {
45                model: self,
46                name: "jina-embeddings-v2-base-code",
47                hf_repo: "jinaai/jina-embeddings-v2-base-code",
48                onnx_path: "onnx/model.onnx",
49                vocab_file: VocabSource::VocabTxt("vocab.txt"),
50                dimensions: 768,
51                max_seq_len: 512,
52                model_min_bytes: 100_000_000,
53                vocab_min_bytes: 100_000,
54                query_prefix: None,
55                document_prefix: None,
56                needs_token_type_ids: true,
57            },
58            Self::NomicEmbedV1_5 => ModelConfig {
59                model: self,
60                name: "nomic-embed-text-v1.5",
61                hf_repo: "nomic-ai/nomic-embed-text-v1.5",
62                onnx_path: "onnx/model.onnx",
63                vocab_file: VocabSource::VocabTxt("vocab.txt"),
64                dimensions: 768,
65                max_seq_len: 512,
66                model_min_bytes: 100_000_000,
67                vocab_min_bytes: 100_000,
68                query_prefix: Some("search_query: "),
69                document_prefix: Some("search_document: "),
70                needs_token_type_ids: false,
71            },
72        }
73    }
74
75    /// Parse model name from string (env var / config file).
76    pub fn from_str_name(s: &str) -> Option<Self> {
77        match s.to_lowercase().replace('_', "-").as_str() {
78            "all-minilm-l6-v2" | "minilm" | "default" => Some(Self::AllMiniLmL6V2),
79            "jina-code-v2" | "jina-embeddings-v2-base-code" | "jina-code" | "jina" => {
80                Some(Self::JinaCodeV2)
81            }
82            "nomic-embed-v1.5" | "nomic-embed-text-v1.5" | "nomic" | "nomic-embed" => {
83                Some(Self::NomicEmbedV1_5)
84            }
85            _ => None,
86        }
87    }
88
89    /// All available model variants.
90    pub const ALL: &'static [Self] = &[Self::AllMiniLmL6V2, Self::JinaCodeV2, Self::NomicEmbedV1_5];
91
92    /// Unique subdirectory name for model storage isolation.
93    pub fn storage_dir_name(self) -> &'static str {
94        match self {
95            Self::AllMiniLmL6V2 => "all-minilm-l6-v2",
96            Self::JinaCodeV2 => "jina-code-v2",
97            Self::NomicEmbedV1_5 => "nomic-embed-v1.5",
98        }
99    }
100}
101
102impl fmt::Display for EmbeddingModel {
103    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
104        f.write_str(self.config().name)
105    }
106}
107
/// Vocabulary/tokenizer source for a model.
///
/// The wrapped string is the file name inside the model's HuggingFace repository (it is
/// appended to the repo's `resolve/main/` URL by `ModelConfig::vocab_url`).
/// Derives `PartialEq`/`Eq`/`Hash` so configurations can be compared and used as keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum VocabSource {
    /// Standard BERT vocab.txt (one token per line, WordPiece).
    VocabTxt(&'static str),
    /// HuggingFace tokenizer.json (BPE/Unigram via JSON config).
    TokenizerJson(&'static str),
}
116
117impl VocabSource {
118    pub fn filename(&self) -> &'static str {
119        match self {
120            Self::VocabTxt(f) | Self::TokenizerJson(f) => f,
121        }
122    }
123
124    pub fn is_wordpiece(&self) -> bool {
125        matches!(self, Self::VocabTxt(_))
126    }
127}
128
129/// Complete configuration for a single embedding model.
130#[derive(Debug, Clone)]
131pub struct ModelConfig {
132    pub model: EmbeddingModel,
133    pub name: &'static str,
134    pub hf_repo: &'static str,
135    pub onnx_path: &'static str,
136    pub vocab_file: VocabSource,
137    pub dimensions: usize,
138    pub max_seq_len: usize,
139    pub model_min_bytes: u64,
140    pub vocab_min_bytes: u64,
141    /// Optional prefix prepended to queries before embedding.
142    pub query_prefix: Option<&'static str>,
143    /// Optional prefix prepended to documents/code before embedding.
144    pub document_prefix: Option<&'static str>,
145    /// Whether the model expects token_type_ids input (BERT-style).
146    /// Some models (e.g. nomic-embed) only use input_ids + attention_mask.
147    pub needs_token_type_ids: bool,
148}
149
150impl ModelConfig {
151    /// Full HuggingFace download URL for the ONNX model file.
152    pub fn model_url(&self) -> String {
153        format!(
154            "https://huggingface.co/{}/resolve/main/{}",
155            self.hf_repo, self.onnx_path
156        )
157    }
158
159    /// Full HuggingFace download URL for the vocabulary/tokenizer file.
160    pub fn vocab_url(&self) -> String {
161        format!(
162            "https://huggingface.co/{}/resolve/main/{}",
163            self.hf_repo,
164            self.vocab_file.filename()
165        )
166    }
167}
168
169/// Resolve which embedding model to use.
170///
171/// Priority: `LEAN_CTX_EMBEDDING_MODEL` env var > `[embedding].model` in `config.toml` >
172/// the default model. An unrecognized name is skipped (with a warning) so a typo in one
173/// source never silently swaps the model — which would otherwise force a full re-index.
174pub fn resolve_model() -> EmbeddingModel {
175    let env_val = std::env::var("LEAN_CTX_EMBEDDING_MODEL").ok();
176    let config_val = crate::core::config::Config::load().embedding.model;
177    resolve_model_from(env_val.as_deref(), config_val.as_deref())
178}
179
180/// Pure model resolution used by [`resolve_model`]; kept separate so the env-var/config
181/// precedence is unit-testable without touching the process environment or the on-disk
182/// `config.toml`.
183fn resolve_model_from(env_val: Option<&str>, config_val: Option<&str>) -> EmbeddingModel {
184    for (source, raw) in [
185        ("LEAN_CTX_EMBEDDING_MODEL", env_val),
186        ("[embedding].model", config_val),
187    ] {
188        let Some(name) = raw.map(str::trim).filter(|s| !s.is_empty()) else {
189            continue;
190        };
191        if let Some(model) = EmbeddingModel::from_str_name(name) {
192            return model;
193        }
194        tracing::warn!(
195            "Unknown embedding model {name:?} from {source}; using {} instead",
196            EmbeddingModel::DEFAULT
197        );
198    }
199    EmbeddingModel::DEFAULT
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_model_is_minilm() {
        assert_eq!(EmbeddingModel::DEFAULT, EmbeddingModel::AllMiniLmL6V2);
    }

    #[test]
    fn from_str_name_variants() {
        assert_eq!(
            EmbeddingModel::from_str_name("minilm"),
            Some(EmbeddingModel::AllMiniLmL6V2)
        );
        assert_eq!(
            EmbeddingModel::from_str_name("jina-code-v2"),
            Some(EmbeddingModel::JinaCodeV2)
        );
        assert_eq!(
            EmbeddingModel::from_str_name("jina-code"),
            Some(EmbeddingModel::JinaCodeV2)
        );
        assert_eq!(
            EmbeddingModel::from_str_name("jina"),
            Some(EmbeddingModel::JinaCodeV2)
        );
        assert_eq!(
            EmbeddingModel::from_str_name("nomic-embed-v1.5"),
            Some(EmbeddingModel::NomicEmbedV1_5)
        );
        assert_eq!(
            EmbeddingModel::from_str_name("nomic"),
            Some(EmbeddingModel::NomicEmbedV1_5)
        );
        assert_eq!(
            EmbeddingModel::from_str_name("default"),
            Some(EmbeddingModel::AllMiniLmL6V2)
        );
        assert_eq!(EmbeddingModel::from_str_name("unknown"), None);
    }

    #[test]
    fn from_str_name_ignores_case_and_underscores() {
        assert_eq!(
            EmbeddingModel::from_str_name("ALL_MINILM_L6_V2"),
            Some(EmbeddingModel::AllMiniLmL6V2)
        );
        assert_eq!(
            EmbeddingModel::from_str_name("Jina_Code_V2"),
            Some(EmbeddingModel::JinaCodeV2)
        );
        assert_eq!(
            EmbeddingModel::from_str_name("Nomic-Embed-Text-v1.5"),
            Some(EmbeddingModel::NomicEmbedV1_5)
        );
    }

    #[test]
    fn from_str_name_accepts_canonical_and_storage_names() {
        for &model in EmbeddingModel::ALL {
            assert_eq!(
                EmbeddingModel::from_str_name(model.config().name),
                Some(model),
                "canonical name of {model} should parse back"
            );
            assert_eq!(
                EmbeddingModel::from_str_name(model.storage_dir_name()),
                Some(model),
                "storage dir name of {model} should parse back"
            );
        }
    }

    #[test]
    fn all_models_have_valid_configs() {
        for model in EmbeddingModel::ALL {
            let cfg = model.config();
            assert_eq!(cfg.model, *model);
            assert!(!cfg.name.is_empty());
            assert!(!cfg.hf_repo.is_empty());
            assert!(!cfg.onnx_path.is_empty());
            assert!(!cfg.vocab_file.filename().is_empty());
            assert!(cfg.dimensions > 0);
            assert!(cfg.max_seq_len > 0);
            assert!(cfg.model_min_bytes > 0);
            assert!(cfg.vocab_min_bytes > 0);
        }
    }

    #[test]
    fn model_urls_are_valid() {
        for model in EmbeddingModel::ALL {
            let cfg = model.config();
            let model_url = cfg.model_url();
            let vocab_url = cfg.vocab_url();
            assert!(model_url.starts_with("https://huggingface.co/"));
            assert!(vocab_url.starts_with("https://huggingface.co/"));
            assert!(model_url.contains("resolve/main"));
            assert!(vocab_url.contains("resolve/main"));
            assert!(model_url.contains(cfg.hf_repo));
            assert!(model_url.ends_with(cfg.onnx_path));
            assert!(vocab_url.ends_with(cfg.vocab_file.filename()));
        }
    }

    #[test]
    fn storage_dir_names_are_unique() {
        let names: Vec<_> = EmbeddingModel::ALL
            .iter()
            .map(|m| m.storage_dir_name())
            .collect();
        let unique: std::collections::HashSet<_> = names.iter().collect();
        assert_eq!(names.len(), unique.len());
    }

    #[test]
    fn display_uses_model_name() {
        assert_eq!(
            format!("{}", EmbeddingModel::AllMiniLmL6V2),
            "all-MiniLM-L6-v2"
        );
        assert_eq!(
            format!("{}", EmbeddingModel::JinaCodeV2),
            "jina-embeddings-v2-base-code"
        );
        assert_eq!(
            format!("{}", EmbeddingModel::NomicEmbedV1_5),
            "nomic-embed-text-v1.5"
        );
        for model in EmbeddingModel::ALL {
            assert_eq!(model.to_string(), model.config().name);
        }
    }

    #[test]
    fn resolve_defaults_when_nothing_set() {
        assert_eq!(resolve_model_from(None, None), EmbeddingModel::DEFAULT);
        assert_eq!(
            resolve_model_from(Some(""), Some("   ")),
            EmbeddingModel::DEFAULT
        );
    }

    #[test]
    fn config_selects_model_when_env_unset() {
        assert_eq!(
            resolve_model_from(None, Some("jina-code-v2")),
            EmbeddingModel::JinaCodeV2
        );
        assert_eq!(
            resolve_model_from(None, Some("nomic")),
            EmbeddingModel::NomicEmbedV1_5
        );
    }

    #[test]
    fn env_var_overrides_config() {
        assert_eq!(
            resolve_model_from(Some("minilm"), Some("nomic")),
            EmbeddingModel::AllMiniLmL6V2
        );
    }

    #[test]
    fn unknown_name_falls_through_then_defaults() {
        // Bad env value → valid config value wins.
        assert_eq!(
            resolve_model_from(Some("bogus"), Some("nomic")),
            EmbeddingModel::NomicEmbedV1_5
        );
        // Bad everywhere → default (never silently breaks the index).
        assert_eq!(
            resolve_model_from(Some("bogus"), Some("nope")),
            EmbeddingModel::DEFAULT
        );
        // Empty/whitespace in the higher-priority source is skipped, not treated as a match.
        assert_eq!(
            resolve_model_from(Some("   "), Some("jina")),
            EmbeddingModel::JinaCodeV2
        );
    }

    #[test]
    fn jina_code_v2_config_details() {
        let cfg = EmbeddingModel::JinaCodeV2.config();
        assert_eq!(cfg.dimensions, 768);
        assert!(cfg.needs_token_type_ids);
        assert!(cfg.query_prefix.is_none());
    }

    #[test]
    fn nomic_has_prefixes() {
        let cfg = EmbeddingModel::NomicEmbedV1_5.config();
        assert!(cfg.query_prefix.is_some());
        assert!(cfg.document_prefix.is_some());
        assert!(!cfg.needs_token_type_ids);
    }

    #[test]
    fn minilm_is_wordpiece() {
        let cfg = EmbeddingModel::AllMiniLmL6V2.config();
        assert!(cfg.vocab_file.is_wordpiece());
    }

    #[test]
    fn all_current_models_use_wordpiece() {
        for model in EmbeddingModel::ALL {
            assert!(
                model.config().vocab_file.is_wordpiece(),
                "{model} should use WordPiece vocab.txt"
            );
        }
    }
}