lean_ctx/core/embeddings/
model_registry.rs1use std::fmt;
8
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
11#[serde(rename_all = "kebab-case")]
12pub enum EmbeddingModel {
13 AllMiniLmL6V2,
16 JinaCodeV2,
19 NomicEmbedV1_5,
22}
23
24impl EmbeddingModel {
25 pub const DEFAULT: Self = Self::AllMiniLmL6V2;
26
27 pub fn config(self) -> ModelConfig {
28 match self {
29 Self::AllMiniLmL6V2 => ModelConfig {
30 model: self,
31 name: "all-MiniLM-L6-v2",
32 hf_repo: "sentence-transformers/all-MiniLM-L6-v2",
33 onnx_path: "onnx/model.onnx",
34 vocab_file: VocabSource::VocabTxt("vocab.txt"),
35 dimensions: 384,
36 max_seq_len: 256,
37 model_min_bytes: 1_000_000,
38 vocab_min_bytes: 100_000,
39 query_prefix: None,
40 document_prefix: None,
41 needs_token_type_ids: true,
42 },
43 Self::JinaCodeV2 => ModelConfig {
44 model: self,
45 name: "jina-embeddings-v2-base-code",
46 hf_repo: "jinaai/jina-embeddings-v2-base-code",
47 onnx_path: "onnx/model.onnx",
48 vocab_file: VocabSource::VocabTxt("vocab.txt"),
49 dimensions: 768,
50 max_seq_len: 512,
51 model_min_bytes: 100_000_000,
52 vocab_min_bytes: 100_000,
53 query_prefix: None,
54 document_prefix: None,
55 needs_token_type_ids: true,
56 },
57 Self::NomicEmbedV1_5 => ModelConfig {
58 model: self,
59 name: "nomic-embed-text-v1.5",
60 hf_repo: "nomic-ai/nomic-embed-text-v1.5",
61 onnx_path: "onnx/model.onnx",
62 vocab_file: VocabSource::VocabTxt("vocab.txt"),
63 dimensions: 768,
64 max_seq_len: 512,
65 model_min_bytes: 100_000_000,
66 vocab_min_bytes: 100_000,
67 query_prefix: Some("search_query: "),
68 document_prefix: Some("search_document: "),
69 needs_token_type_ids: false,
70 },
71 }
72 }
73
74 pub fn from_str_name(s: &str) -> Option<Self> {
76 match s.to_lowercase().replace('_', "-").as_str() {
77 "all-minilm-l6-v2" | "minilm" | "default" => Some(Self::AllMiniLmL6V2),
78 "jina-code-v2" | "jina-embeddings-v2-base-code" | "jina-code" | "jina" => {
79 Some(Self::JinaCodeV2)
80 }
81 "nomic-embed-v1.5" | "nomic-embed-text-v1.5" | "nomic" | "nomic-embed" => {
82 Some(Self::NomicEmbedV1_5)
83 }
84 _ => None,
85 }
86 }
87
88 pub const ALL: &'static [Self] = &[Self::AllMiniLmL6V2, Self::JinaCodeV2, Self::NomicEmbedV1_5];
90
91 pub fn storage_dir_name(self) -> &'static str {
93 match self {
94 Self::AllMiniLmL6V2 => "all-minilm-l6-v2",
95 Self::JinaCodeV2 => "jina-code-v2",
96 Self::NomicEmbedV1_5 => "nomic-embed-v1.5",
97 }
98 }
99}
100
101impl fmt::Display for EmbeddingModel {
102 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
103 f.write_str(self.config().name)
104 }
105}
106
/// Tokenizer vocabulary file for a model.
///
/// Each variant carries the file name, which `ModelConfig::vocab_url` appends
/// to the repository's `resolve/main/` URL.
#[derive(Debug, Clone, Copy)]
pub enum VocabSource {
    /// A WordPiece `vocab.txt`-style file (the only kind `is_wordpiece` accepts).
    VocabTxt(&'static str),
    /// A tokenizer JSON file (presumably a Hugging Face `tokenizer.json`; not WordPiece).
    TokenizerJson(&'static str),
}

impl VocabSource {
    /// File name of the vocabulary, regardless of its format.
    pub fn filename(&self) -> &'static str {
        let (Self::VocabTxt(name) | Self::TokenizerJson(name)) = *self;
        name
    }

    /// Returns `true` only for the WordPiece `VocabTxt` variant.
    pub fn is_wordpiece(&self) -> bool {
        match self {
            Self::VocabTxt(_) => true,
            Self::TokenizerJson(_) => false,
        }
    }
}
127
128#[derive(Debug, Clone)]
130pub struct ModelConfig {
131 pub model: EmbeddingModel,
132 pub name: &'static str,
133 pub hf_repo: &'static str,
134 pub onnx_path: &'static str,
135 pub vocab_file: VocabSource,
136 pub dimensions: usize,
137 pub max_seq_len: usize,
138 pub model_min_bytes: u64,
139 pub vocab_min_bytes: u64,
140 pub query_prefix: Option<&'static str>,
142 pub document_prefix: Option<&'static str>,
144 pub needs_token_type_ids: bool,
147}
148
149impl ModelConfig {
150 pub fn model_url(&self) -> String {
152 format!(
153 "https://huggingface.co/{}/resolve/main/{}",
154 self.hf_repo, self.onnx_path
155 )
156 }
157
158 pub fn vocab_url(&self) -> String {
160 format!(
161 "https://huggingface.co/{}/resolve/main/{}",
162 self.hf_repo,
163 self.vocab_file.filename()
164 )
165 }
166}
167
168pub fn resolve_model() -> EmbeddingModel {
171 if let Ok(val) = std::env::var("LEAN_CTX_EMBEDDING_MODEL") {
172 if let Some(model) = EmbeddingModel::from_str_name(&val) {
173 return model;
174 }
175 tracing::warn!(
176 "Unknown LEAN_CTX_EMBEDDING_MODEL={val:?}, falling back to default ({})",
177 EmbeddingModel::DEFAULT
178 );
179 }
180 EmbeddingModel::DEFAULT
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn default_model_is_minilm() {
        assert_eq!(EmbeddingModel::DEFAULT, EmbeddingModel::AllMiniLmL6V2);
    }

    #[test]
    fn from_str_name_variants() {
        let cases = [
            ("minilm", Some(EmbeddingModel::AllMiniLmL6V2)),
            ("jina-code-v2", Some(EmbeddingModel::JinaCodeV2)),
            ("jina-code", Some(EmbeddingModel::JinaCodeV2)),
            ("jina", Some(EmbeddingModel::JinaCodeV2)),
            ("nomic-embed-v1.5", Some(EmbeddingModel::NomicEmbedV1_5)),
            ("nomic", Some(EmbeddingModel::NomicEmbedV1_5)),
            ("default", Some(EmbeddingModel::AllMiniLmL6V2)),
            ("unknown", None),
        ];
        for (input, expected) in cases {
            assert_eq!(
                EmbeddingModel::from_str_name(input),
                expected,
                "input {input:?}"
            );
        }
    }

    #[test]
    fn from_str_name_ignores_case_and_separator_style() {
        assert_eq!(
            EmbeddingModel::from_str_name("All_MiniLM_L6_V2"),
            Some(EmbeddingModel::AllMiniLmL6V2)
        );
        assert_eq!(
            EmbeddingModel::from_str_name("JINA_CODE"),
            Some(EmbeddingModel::JinaCodeV2)
        );
    }

    #[test]
    fn canonical_and_storage_names_round_trip() {
        for &model in EmbeddingModel::ALL {
            assert_eq!(
                EmbeddingModel::from_str_name(model.config().name),
                Some(model),
                "canonical name of {model:?}"
            );
            assert_eq!(
                EmbeddingModel::from_str_name(model.storage_dir_name()),
                Some(model),
                "storage dir name of {model:?}"
            );
        }
    }

    #[test]
    fn all_models_have_valid_configs() {
        for &model in EmbeddingModel::ALL {
            let cfg = model.config();
            assert_eq!(cfg.model, model);
            assert!(!cfg.name.is_empty());
            assert!(!cfg.hf_repo.is_empty());
            assert!(!cfg.onnx_path.is_empty());
            assert!(cfg.dimensions > 0);
            assert!(cfg.max_seq_len > 0);
            assert!(cfg.model_min_bytes > 0);
            assert!(cfg.vocab_min_bytes > 0);
        }
    }

    #[test]
    fn model_urls_are_valid() {
        for model in EmbeddingModel::ALL {
            let cfg = model.config();
            let model_url = cfg.model_url();
            let vocab_url = cfg.vocab_url();
            assert!(model_url.starts_with("https://huggingface.co/"));
            assert!(vocab_url.starts_with("https://huggingface.co/"));
            assert!(model_url.contains("resolve/main"));
            assert!(vocab_url.contains("resolve/main"));
            assert!(model_url.ends_with(cfg.onnx_path));
            assert!(vocab_url.ends_with(cfg.vocab_file.filename()));
        }
    }

    #[test]
    fn storage_dir_names_are_unique() {
        let unique: HashSet<&str> = EmbeddingModel::ALL
            .iter()
            .map(|m| m.storage_dir_name())
            .collect();
        assert_eq!(unique.len(), EmbeddingModel::ALL.len());
    }

    #[test]
    fn display_uses_model_name() {
        assert_eq!(
            format!("{}", EmbeddingModel::AllMiniLmL6V2),
            "all-MiniLM-L6-v2"
        );
        assert_eq!(
            format!("{}", EmbeddingModel::JinaCodeV2),
            "jina-embeddings-v2-base-code"
        );
        assert_eq!(
            format!("{}", EmbeddingModel::NomicEmbedV1_5),
            "nomic-embed-text-v1.5"
        );
    }

    #[test]
    fn resolve_model_default() {
        // NOTE: mutates the process-wide environment; any other test that reads
        // LEAN_CTX_EMBEDDING_MODEL concurrently could observe this.
        std::env::remove_var("LEAN_CTX_EMBEDDING_MODEL");
        assert_eq!(resolve_model(), EmbeddingModel::DEFAULT);
    }

    #[test]
    fn jina_code_v2_config_details() {
        let cfg = EmbeddingModel::JinaCodeV2.config();
        assert_eq!(cfg.dimensions, 768);
        assert!(cfg.needs_token_type_ids);
        assert!(cfg.query_prefix.is_none());
    }

    #[test]
    fn nomic_has_prefixes() {
        let cfg = EmbeddingModel::NomicEmbedV1_5.config();
        assert_eq!(cfg.query_prefix, Some("search_query: "));
        assert_eq!(cfg.document_prefix, Some("search_document: "));
        assert!(!cfg.needs_token_type_ids);
    }

    #[test]
    fn minilm_is_wordpiece() {
        let cfg = EmbeddingModel::AllMiniLmL6V2.config();
        assert!(cfg.vocab_file.is_wordpiece());
    }

    #[test]
    fn all_current_models_use_wordpiece() {
        for model in EmbeddingModel::ALL {
            assert!(
                model.config().vocab_file.is_wordpiece(),
                "{model} should use WordPiece vocab.txt"
            );
        }
    }
}