lean_ctx/core/embeddings/
model_registry.rs1use std::fmt;
9
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
12#[serde(rename_all = "kebab-case")]
13pub enum EmbeddingModel {
14 AllMiniLmL6V2,
17 JinaCodeV2,
20 NomicEmbedV1_5,
23}
24
25impl EmbeddingModel {
26 pub const DEFAULT: Self = Self::AllMiniLmL6V2;
27
28 pub fn config(self) -> ModelConfig {
29 match self {
30 Self::AllMiniLmL6V2 => ModelConfig {
31 model: self,
32 name: "all-MiniLM-L6-v2",
33 hf_repo: "sentence-transformers/all-MiniLM-L6-v2",
34 onnx_path: "onnx/model.onnx",
35 vocab_file: VocabSource::VocabTxt("vocab.txt"),
36 dimensions: 384,
37 max_seq_len: 256,
38 model_min_bytes: 1_000_000,
39 vocab_min_bytes: 100_000,
40 query_prefix: None,
41 document_prefix: None,
42 needs_token_type_ids: true,
43 },
44 Self::JinaCodeV2 => ModelConfig {
45 model: self,
46 name: "jina-embeddings-v2-base-code",
47 hf_repo: "jinaai/jina-embeddings-v2-base-code",
48 onnx_path: "onnx/model.onnx",
49 vocab_file: VocabSource::VocabTxt("vocab.txt"),
50 dimensions: 768,
51 max_seq_len: 512,
52 model_min_bytes: 100_000_000,
53 vocab_min_bytes: 100_000,
54 query_prefix: None,
55 document_prefix: None,
56 needs_token_type_ids: true,
57 },
58 Self::NomicEmbedV1_5 => ModelConfig {
59 model: self,
60 name: "nomic-embed-text-v1.5",
61 hf_repo: "nomic-ai/nomic-embed-text-v1.5",
62 onnx_path: "onnx/model.onnx",
63 vocab_file: VocabSource::VocabTxt("vocab.txt"),
64 dimensions: 768,
65 max_seq_len: 512,
66 model_min_bytes: 100_000_000,
67 vocab_min_bytes: 100_000,
68 query_prefix: Some("search_query: "),
69 document_prefix: Some("search_document: "),
70 needs_token_type_ids: false,
71 },
72 }
73 }
74
75 pub fn from_str_name(s: &str) -> Option<Self> {
77 match s.to_lowercase().replace('_', "-").as_str() {
78 "all-minilm-l6-v2" | "minilm" | "default" => Some(Self::AllMiniLmL6V2),
79 "jina-code-v2" | "jina-embeddings-v2-base-code" | "jina-code" | "jina" => {
80 Some(Self::JinaCodeV2)
81 }
82 "nomic-embed-v1.5" | "nomic-embed-text-v1.5" | "nomic" | "nomic-embed" => {
83 Some(Self::NomicEmbedV1_5)
84 }
85 _ => None,
86 }
87 }
88
89 pub const ALL: &'static [Self] = &[Self::AllMiniLmL6V2, Self::JinaCodeV2, Self::NomicEmbedV1_5];
91
92 pub fn storage_dir_name(self) -> &'static str {
94 match self {
95 Self::AllMiniLmL6V2 => "all-minilm-l6-v2",
96 Self::JinaCodeV2 => "jina-code-v2",
97 Self::NomicEmbedV1_5 => "nomic-embed-v1.5",
98 }
99 }
100}
101
102impl fmt::Display for EmbeddingModel {
103 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
104 f.write_str(self.config().name)
105 }
106}
107
/// The vocabulary file a model ships with, identified by its path inside the model's
/// Hugging Face repo.
#[derive(Debug, Clone, Copy)]
pub enum VocabSource {
    /// A WordPiece-style `vocab.txt` file.
    VocabTxt(&'static str),
    /// A `tokenizer.json`-style file. NOTE(review): no model in this registry uses it yet.
    TokenizerJson(&'static str),
}
116
117impl VocabSource {
118 pub fn filename(&self) -> &'static str {
119 match self {
120 Self::VocabTxt(f) | Self::TokenizerJson(f) => f,
121 }
122 }
123
124 pub fn is_wordpiece(&self) -> bool {
125 matches!(self, Self::VocabTxt(_))
126 }
127}
128
129#[derive(Debug, Clone)]
131pub struct ModelConfig {
132 pub model: EmbeddingModel,
133 pub name: &'static str,
134 pub hf_repo: &'static str,
135 pub onnx_path: &'static str,
136 pub vocab_file: VocabSource,
137 pub dimensions: usize,
138 pub max_seq_len: usize,
139 pub model_min_bytes: u64,
140 pub vocab_min_bytes: u64,
141 pub query_prefix: Option<&'static str>,
143 pub document_prefix: Option<&'static str>,
145 pub needs_token_type_ids: bool,
148}
149
150impl ModelConfig {
151 pub fn model_url(&self) -> String {
153 format!(
154 "https://huggingface.co/{}/resolve/main/{}",
155 self.hf_repo, self.onnx_path
156 )
157 }
158
159 pub fn vocab_url(&self) -> String {
161 format!(
162 "https://huggingface.co/{}/resolve/main/{}",
163 self.hf_repo,
164 self.vocab_file.filename()
165 )
166 }
167}
168
169pub fn resolve_model() -> EmbeddingModel {
175 let env_val = std::env::var("LEAN_CTX_EMBEDDING_MODEL").ok();
176 let config_val = crate::core::config::Config::load().embedding.model;
177 resolve_model_from(env_val.as_deref(), config_val.as_deref())
178}
179
180fn resolve_model_from(env_val: Option<&str>, config_val: Option<&str>) -> EmbeddingModel {
184 for (source, raw) in [
185 ("LEAN_CTX_EMBEDDING_MODEL", env_val),
186 ("[embedding].model", config_val),
187 ] {
188 let Some(name) = raw.map(str::trim).filter(|s| !s.is_empty()) else {
189 continue;
190 };
191 if let Some(model) = EmbeddingModel::from_str_name(name) {
192 return model;
193 }
194 tracing::warn!(
195 "Unknown embedding model {name:?} from {source}; using {} instead",
196 EmbeddingModel::DEFAULT
197 );
198 }
199 EmbeddingModel::DEFAULT
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_model_is_minilm() {
        assert_eq!(EmbeddingModel::DEFAULT, EmbeddingModel::AllMiniLmL6V2);
    }

    #[test]
    fn from_str_name_variants() {
        let cases: &[(&str, Option<EmbeddingModel>)] = &[
            ("minilm", Some(EmbeddingModel::AllMiniLmL6V2)),
            ("jina-code-v2", Some(EmbeddingModel::JinaCodeV2)),
            ("jina-code", Some(EmbeddingModel::JinaCodeV2)),
            ("jina", Some(EmbeddingModel::JinaCodeV2)),
            ("nomic-embed-v1.5", Some(EmbeddingModel::NomicEmbedV1_5)),
            ("nomic", Some(EmbeddingModel::NomicEmbedV1_5)),
            ("default", Some(EmbeddingModel::AllMiniLmL6V2)),
            ("unknown", None),
        ];
        for &(input, expected) in cases {
            assert_eq!(
                EmbeddingModel::from_str_name(input),
                expected,
                "input {input:?}"
            );
        }
    }

    #[test]
    fn all_models_have_valid_configs() {
        for cfg in EmbeddingModel::ALL.iter().map(|m| m.config()) {
            assert!(!cfg.name.is_empty());
            assert!(!cfg.hf_repo.is_empty());
            assert!(cfg.dimensions > 0);
            assert!(cfg.max_seq_len > 0);
            assert!(cfg.model_min_bytes > 0);
            assert!(cfg.vocab_min_bytes > 0);
        }
    }

    #[test]
    fn model_urls_are_valid() {
        for cfg in EmbeddingModel::ALL.iter().map(|m| m.config()) {
            let (model_url, vocab_url) = (cfg.model_url(), cfg.vocab_url());
            assert!(model_url.starts_with("https://huggingface.co/"));
            assert!(vocab_url.starts_with("https://huggingface.co/"));
            assert!(model_url.contains("resolve/main"));
        }
    }

    #[test]
    fn storage_dir_names_are_unique() {
        let mut seen = std::collections::HashSet::new();
        for model in EmbeddingModel::ALL {
            assert!(
                seen.insert(model.storage_dir_name()),
                "duplicate storage dir for {model}"
            );
        }
    }

    #[test]
    fn display_uses_model_name() {
        assert_eq!(
            EmbeddingModel::AllMiniLmL6V2.to_string(),
            "all-MiniLM-L6-v2"
        );
        assert_eq!(
            EmbeddingModel::JinaCodeV2.to_string(),
            "jina-embeddings-v2-base-code"
        );
    }

    #[test]
    fn resolve_defaults_when_nothing_set() {
        for (env, cfg) in [(None, None), (Some(""), Some(" "))] {
            assert_eq!(resolve_model_from(env, cfg), EmbeddingModel::DEFAULT);
        }
    }

    #[test]
    fn config_selects_model_when_env_unset() {
        let cases = [
            ("jina-code-v2", EmbeddingModel::JinaCodeV2),
            ("nomic", EmbeddingModel::NomicEmbedV1_5),
        ];
        for (cfg, expected) in cases {
            assert_eq!(resolve_model_from(None, Some(cfg)), expected);
        }
    }

    #[test]
    fn env_var_overrides_config() {
        let picked = resolve_model_from(Some("minilm"), Some("nomic"));
        assert_eq!(picked, EmbeddingModel::AllMiniLmL6V2);
    }

    #[test]
    fn unknown_name_falls_through_then_defaults() {
        // An unrecognised env value does not block a valid config value.
        let picked = resolve_model_from(Some("bogus"), Some("nomic"));
        assert_eq!(picked, EmbeddingModel::NomicEmbedV1_5);

        // Both unrecognised: fall back to the default.
        let picked = resolve_model_from(Some("bogus"), Some("nope"));
        assert_eq!(picked, EmbeddingModel::DEFAULT);

        // A whitespace-only env value counts as unset.
        let picked = resolve_model_from(Some(" "), Some("jina"));
        assert_eq!(picked, EmbeddingModel::JinaCodeV2);
    }

    #[test]
    fn jina_code_v2_config_details() {
        let jina = EmbeddingModel::JinaCodeV2.config();
        assert_eq!(jina.dimensions, 768);
        assert!(jina.needs_token_type_ids);
        assert!(jina.query_prefix.is_none());
    }

    #[test]
    fn nomic_has_prefixes() {
        let nomic = EmbeddingModel::NomicEmbedV1_5.config();
        assert!(nomic.query_prefix.is_some());
        assert!(nomic.document_prefix.is_some());
        assert!(!nomic.needs_token_type_ids);
    }

    #[test]
    fn minilm_is_wordpiece() {
        let vocab = EmbeddingModel::AllMiniLmL6V2.config().vocab_file;
        assert!(vocab.is_wordpiece());
    }

    #[test]
    fn all_current_models_use_wordpiece() {
        for model in EmbeddingModel::ALL {
            let vocab = model.config().vocab_file;
            assert!(
                vocab.is_wordpiece(),
                "{model} should use WordPiece vocab.txt"
            );
        }
    }
}