1use std::path::{Path, PathBuf};
40use std::sync::Arc;
41
42use super::embedder::{Embedder, EmbedderError, EmbedderInfo, EmbedderResult};
43use super::fastembed_embedder::FastEmbedder;
44use super::hash_embedder::HashEmbedder;
45
/// Registry name of the embedder that [`EmbedderRegistry::default_embedder`] looks up.
pub const DEFAULT_EMBEDDER: &str = "minilm";
48
/// Registry name of the hash embedder that [`EmbedderRegistry::best_available`]
/// falls back to when no semantic embedder is available.
pub const HASH_EMBEDDER: &str = "hash";
51
/// Static description of one embedder known to the registry.
///
/// Instances live in the [`EMBEDDERS`] table; every string field is `'static`.
#[derive(Debug, Clone)]
pub struct RegisteredEmbedder {
    /// Short name; matched by [`EmbedderRegistry::get`] and used to resolve the
    /// model directory and to load the embedder.
    pub name: &'static str,
    /// Identifier; also accepted by [`EmbedderRegistry::get`] (exactly, or as
    /// `<name>-…` prefix) and copied into `EmbedderInfo::id`.
    pub id: &'static str,
    /// Embedding dimension reported for this embedder.
    pub dimension: usize,
    /// Whether the embedder is semantic; only semantic entries are considered
    /// by [`EmbedderRegistry::best_available`] before the hash fallback.
    pub is_semantic: bool,
    /// Human-readable description.
    pub description: &'static str,
    /// When `false` the embedder is always available and has no model directory.
    pub requires_model_files: bool,
    /// Release date string; compared lexicographically against
    /// [`BAKEOFF_ELIGIBILITY_CUTOFF`] (table entries use `YYYY-MM-DD`).
    pub release_date: &'static str,
    /// Exported as `source` in the bake-off model metadata; empty for the hash entry.
    pub huggingface_id: &'static str,
    /// Model size in bytes; `0` is exported as "unknown" (`None`) in the metadata.
    pub size_bytes: u64,
    /// Baseline entries are never bake-off eligible.
    pub is_baseline: bool,
}
78
/// File names that must all exist in an embedder's model directory.
///
/// Returned by [`RegisteredEmbedder::required_files`] for every embedder that
/// requires model files.
pub const REQUIRED_ONNX_FILES: &[&str] = &[
    "model.onnx",
    "tokenizer.json",
    "config.json",
    "special_tokens_map.json",
    "tokenizer_config.json",
];
87
/// Earliest release date (inclusive) that makes a non-baseline embedder
/// bake-off eligible; compared as a string in
/// [`RegisteredEmbedder::is_bakeoff_eligible`].
pub const BAKEOFF_ELIGIBILITY_CUTOFF: &str = "2025-11-01";
90
91impl RegisteredEmbedder {
92 pub fn is_available(&self, data_dir: &Path) -> bool {
94 if !self.requires_model_files {
95 return true;
96 }
97
98 if let Some(model_dir) = self.model_dir(data_dir) {
99 self.required_files()
100 .iter()
101 .all(|f| model_dir.join(f).is_file())
102 } else {
103 false
104 }
105 }
106
107 pub fn model_dir(&self, data_dir: &Path) -> Option<PathBuf> {
109 if !self.requires_model_files {
110 return None;
111 }
112
113 FastEmbedder::model_dir_for(data_dir, self.name)
114 }
115
116 pub fn required_files(&self) -> &'static [&'static str] {
118 if !self.requires_model_files {
119 return &[];
120 }
121 REQUIRED_ONNX_FILES
123 }
124
125 pub fn missing_files(&self, data_dir: &Path) -> Vec<String> {
127 if !self.requires_model_files {
128 return Vec::new();
129 }
130
131 if let Some(model_dir) = self.model_dir(data_dir) {
132 self.required_files()
133 .iter()
134 .filter(|f| !model_dir.join(*f).is_file())
135 .map(|f| (*f).to_string())
136 .collect()
137 } else {
138 Vec::new()
139 }
140 }
141
142 pub fn is_bakeoff_eligible(&self) -> bool {
144 if self.is_baseline {
145 return false;
146 }
147 self.release_date >= BAKEOFF_ELIGIBILITY_CUTOFF
148 }
149
150 pub fn to_model_metadata(&self) -> crate::bakeoff::ModelMetadata {
152 crate::bakeoff::ModelMetadata {
153 id: self.id.to_string(),
154 name: self.name.to_string(),
155 source: self.huggingface_id.to_string(),
156 release_date: self.release_date.to_string(),
157 dimension: Some(self.dimension),
158 size_bytes: if self.size_bytes > 0 {
159 Some(self.size_bytes)
160 } else {
161 None
162 },
163 is_baseline: self.is_baseline,
164 }
165 }
166}
167
/// Static table of every embedder known to the registry.
///
/// Entry order is significant: [`EmbedderRegistry::best_available`] returns the
/// first available semantic entry, and [`EmbedderRegistry::baseline_embedder`]
/// returns the first entry flagged `is_baseline`.
pub static EMBEDDERS: &[RegisteredEmbedder] = &[
    // Baseline semantic model; this is the entry `DEFAULT_EMBEDDER` names.
    RegisteredEmbedder {
        name: "minilm",
        id: "minilm-384",
        dimension: 384,
        is_semantic: true,
        description: "MiniLM L6 v2 - fast, high-quality semantic embeddings (baseline)",
        requires_model_files: true,
        release_date: "2022-08-01",
        huggingface_id: "sentence-transformers/all-MiniLM-L6-v2",
        size_bytes: 90_000_000,
        is_baseline: true,
    },
    // Non-baseline entry with a release date on/after the bake-off cutoff.
    RegisteredEmbedder {
        name: "snowflake-arctic-s",
        id: "snowflake-arctic-s-384",
        dimension: 384,
        is_semantic: true,
        description: "Snowflake Arctic Embed S - small, fast, MiniLM-compatible dimension",
        requires_model_files: true,
        release_date: "2025-11-10",
        huggingface_id: "Snowflake/snowflake-arctic-embed-s",
        size_bytes: 130_000_000,
        is_baseline: false,
    },
    // Non-baseline entry with a release date on/after the bake-off cutoff.
    RegisteredEmbedder {
        name: "nomic-embed",
        id: "nomic-embed-768",
        dimension: 768,
        is_semantic: true,
        description: "Nomic Embed Text v1.5 - long context, Matryoshka support",
        requires_model_files: true,
        release_date: "2025-11-05",
        huggingface_id: "nomic-ai/nomic-embed-text-v1.5",
        size_bytes: 280_000_000,
        is_baseline: false,
    },
    // Non-semantic fallback: needs no model files, so it is always available.
    RegisteredEmbedder {
        name: "hash",
        id: "fnv1a-384",
        dimension: 384,
        is_semantic: false,
        description: "FNV-1a feature hashing - lexical fallback, always available",
        requires_model_files: false,
        release_date: "2020-01-01",
        huggingface_id: "",
        size_bytes: 0,
        is_baseline: true,
    },
];
225
/// Lookup and availability checks over the static [`EMBEDDERS`] table,
/// resolved against a single data directory.
#[derive(Debug, Clone)]
pub struct EmbedderRegistry {
    /// Data directory handed to every availability and model-directory check.
    data_dir: PathBuf,
}
230
231impl EmbedderRegistry {
232 pub fn new(data_dir: &Path) -> Self {
234 Self {
235 data_dir: data_dir.to_path_buf(),
236 }
237 }
238
239 pub fn all(&self) -> &'static [RegisteredEmbedder] {
241 EMBEDDERS
242 }
243
244 pub fn available(&self) -> Vec<&'static RegisteredEmbedder> {
246 EMBEDDERS
247 .iter()
248 .filter(|e| e.is_available(&self.data_dir))
249 .collect()
250 }
251
252 pub fn get(&self, name: &str) -> Option<&'static RegisteredEmbedder> {
254 let name_lower = FastEmbedder::canonical_name(name)
255 .unwrap_or_else(|| name.trim())
256 .to_ascii_lowercase();
257 EMBEDDERS.iter().find(|e| {
258 e.name == name_lower
259 || e.id == name_lower
260 || e.id.starts_with(&format!("{}-", name_lower))
261 })
262 }
263
264 pub fn is_available(&self, name: &str) -> bool {
266 self.get(name)
267 .map(|e| e.is_available(&self.data_dir))
268 .unwrap_or(false)
269 }
270
271 pub fn default_embedder(&self) -> &'static RegisteredEmbedder {
273 self.get(DEFAULT_EMBEDDER)
274 .expect("default embedder must exist")
275 }
276
277 pub fn best_available(&self) -> &'static RegisteredEmbedder {
279 for e in EMBEDDERS.iter().filter(|e| e.is_semantic) {
281 if e.is_available(&self.data_dir) {
282 return e;
283 }
284 }
285 self.get(HASH_EMBEDDER).expect("hash embedder must exist")
287 }
288
289 pub fn bakeoff_eligible(&self) -> Vec<&'static RegisteredEmbedder> {
291 EMBEDDERS
292 .iter()
293 .filter(|e| e.is_bakeoff_eligible())
294 .collect()
295 }
296
297 pub fn available_bakeoff_candidates(&self) -> Vec<&'static RegisteredEmbedder> {
299 EMBEDDERS
300 .iter()
301 .filter(|e| e.is_bakeoff_eligible() && e.is_available(&self.data_dir))
302 .collect()
303 }
304
305 pub fn baseline_embedder(&self) -> Option<&'static RegisteredEmbedder> {
307 EMBEDDERS.iter().find(|e| e.is_baseline)
308 }
309
310 pub fn validate(&self, name: &str) -> EmbedderResult<&'static RegisteredEmbedder> {
314 let embedder = self.get(name).ok_or_else(|| {
315 embedder_unavailable(
316 name,
317 format!(
318 "unknown embedder. Available: {}",
319 EMBEDDERS
320 .iter()
321 .map(|e| e.name)
322 .collect::<Vec<_>>()
323 .join(", ")
324 ),
325 )
326 })?;
327
328 if !embedder.is_available(&self.data_dir) {
329 let model_dir = FastEmbedder::runtime_model_dir_for(&self.data_dir, embedder.name);
330 let missing = model_dir
331 .as_ref()
332 .map(|dir| {
333 embedder
334 .required_files()
335 .iter()
336 .filter(|file| !dir.join(*file).is_file())
337 .map(|file| (*file).to_string())
338 .collect::<Vec<_>>()
339 })
340 .unwrap_or_else(|| embedder.missing_files(&self.data_dir));
341 if missing.is_empty() {
342 return Ok(embedder);
343 }
344 let model_dir = model_dir
345 .or_else(|| embedder.model_dir(&self.data_dir))
346 .map(|p| p.display().to_string())
347 .unwrap_or_else(|| "unknown".to_string());
348
349 return Err(embedder_unavailable(
350 name,
351 format!(
352 "missing files in {}: {}. Run 'cass models install' to download.",
353 model_dir,
354 missing.join(", ")
355 ),
356 ));
357 }
358
359 Ok(embedder)
360 }
361}
362
363pub fn get_embedder(data_dir: &Path, name: Option<&str>) -> EmbedderResult<Arc<dyn Embedder>> {
374 let registry = EmbedderRegistry::new(data_dir);
375
376 let embedder_info = match name {
377 Some(n) => registry.validate(n)?,
378 None => registry.best_available(),
379 };
380
381 load_embedder_by_name(data_dir, embedder_info.name)
382}
383
384fn load_embedder_by_name(data_dir: &Path, name: &str) -> EmbedderResult<Arc<dyn Embedder>> {
386 match name {
387 "hash" => {
388 let embedder = HashEmbedder::default();
389 Ok(Arc::new(embedder))
390 }
391 "minilm" | "snowflake-arctic-s" | "nomic-embed" => {
393 let embedder = FastEmbedder::load_by_name(data_dir, name)?;
394 Ok(Arc::new(embedder))
395 }
396 _ => Err(embedder_unavailable(name, "embedder not implemented")),
397 }
398}
399
400fn embedder_unavailable(model: &str, reason: impl Into<String>) -> EmbedderError {
401 EmbedderError::EmbedderUnavailable {
402 model: model.to_string(),
403 reason: reason.into(),
404 }
405}
406
407pub fn get_embedder_info(data_dir: &Path, name: Option<&str>) -> Option<EmbedderInfo> {
409 let registry = EmbedderRegistry::new(data_dir);
410
411 let embedder_info = match name {
412 Some(n) => registry.get(n)?,
413 None => registry.best_available(),
414 };
415
416 Some(EmbedderInfo {
417 id: embedder_info.id.to_string(),
418 dimension: embedder_info.dimension,
419 is_semantic: embedder_info.is_semantic,
420 })
421}
422
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::{TempDir, tempdir};

    /// Builds a registry over a fresh, empty temporary data directory.
    ///
    /// The `TempDir` guard is returned so the directory outlives the registry.
    fn registry_fixture() -> (TempDir, EmbedderRegistry) {
        let dir = tempdir().unwrap();
        let registry = EmbedderRegistry::new(dir.path());
        (dir, registry)
    }

    /// The table exposes at least two entries.
    #[test]
    fn test_registry_all() {
        let (_guard, registry) = registry_fixture();
        let count = registry.all().len();
        assert!(count >= 2);
    }

    /// Known names resolve; unknown names do not.
    #[test]
    fn test_registry_get_by_name() {
        let (_guard, registry) = registry_fixture();

        for known in ["minilm", "hash"] {
            let dimension = registry.get(known).map(|e| e.dimension);
            assert_eq!(dimension, Some(384));
        }

        assert!(registry.get("unknown").is_none());
    }

    /// Full ids resolve to the entry with the matching name.
    #[test]
    fn test_registry_get_by_id() {
        let (_guard, registry) = registry_fixture();

        for (id, expected_name) in [("minilm-384", "minilm"), ("fnv1a-384", "hash")] {
            let found = registry.get(id).map(|e| e.name);
            assert_eq!(found, Some(expected_name));
        }
    }

    /// The hash embedder needs no files and is therefore always listed.
    #[test]
    fn test_hash_always_available() {
        let (_guard, registry) = registry_fixture();

        assert!(registry.is_available("hash"));
        let names: Vec<_> = registry.available().into_iter().map(|e| e.name).collect();
        assert!(names.contains(&"hash"));
    }

    /// Without model files, minilm is unavailable and fails validation.
    #[test]
    fn test_minilm_unavailable_without_files() {
        let (_guard, registry) = registry_fixture();

        assert!(!registry.is_available("minilm"));

        let outcome = registry.validate("minilm");
        assert!(matches!(
            outcome,
            Err(EmbedderError::EmbedderUnavailable { .. })
        ));
    }

    /// The helper fills both fields of the unavailable-error variant.
    #[test]
    fn test_embedder_unavailable_helper_shape() {
        match embedder_unavailable("demo", "missing model") {
            EmbedderError::EmbedderUnavailable { model, reason } => {
                assert_eq!(
                    (model.as_str(), reason.as_str()),
                    ("demo", "missing model")
                );
            }
            other => panic!("unexpected error shape: {other:?}"),
        }
    }

    /// With no models installed the best choice is the hash embedder.
    #[test]
    fn test_best_available_fallback() {
        let (_guard, registry) = registry_fixture();
        assert_eq!(registry.best_available().name, "hash");
    }

    /// Requesting the hash embedder by name loads it.
    #[test]
    fn test_get_embedder_hash() {
        let data_dir = tempdir().unwrap();
        let loaded = get_embedder(data_dir.path(), Some("hash")).unwrap();
        assert_eq!(loaded.id(), "fnv1a-384");
        assert!(!loaded.is_semantic());
    }

    /// With no name and no models, the hash embedder is loaded.
    #[test]
    fn test_get_embedder_default_no_models() {
        let data_dir = tempdir().unwrap();
        let loaded = get_embedder(data_dir.path(), None).unwrap();
        assert_eq!(loaded.id(), "fnv1a-384");
    }

    /// Unknown names produce an error that lists the registered names.
    #[test]
    fn test_validate_unknown_embedder() {
        let (_guard, registry) = registry_fixture();

        let message = registry.validate("nonexistent").unwrap_err().to_string();
        assert!(message.contains("unknown embedder"));
        assert!(message.contains("Available:"));
    }

    /// An empty data directory reports the model file as missing.
    #[test]
    fn test_registered_embedder_missing_files() {
        let (data_dir, registry) = registry_fixture();

        let missing = registry
            .get("minilm")
            .unwrap()
            .missing_files(data_dir.path());
        assert!(!missing.is_empty());
        assert!(missing.iter().any(|file| file == "model.onnx"));
    }

    /// Info is returned for named embedders even when they are not installed.
    #[test]
    fn test_get_embedder_info() {
        let data_dir = tempdir().unwrap();
        let root = data_dir.path();

        let hash = get_embedder_info(root, Some("hash")).unwrap();
        assert_eq!(hash.id, "fnv1a-384");
        assert!(!hash.is_semantic);

        let minilm = get_embedder_info(root, Some("minilm")).unwrap();
        assert_eq!(minilm.id, "minilm-384");
        assert!(minilm.is_semantic);
    }

    /// Exactly the two non-baseline, post-cutoff models are eligible.
    #[test]
    fn test_bakeoff_eligible_count() {
        let (_guard, registry) = registry_fixture();

        let eligible = registry.bakeoff_eligible();
        let lists = |name: &str| eligible.iter().any(|e| e.name == name);

        assert_eq!(
            eligible.len(),
            2,
            "Expected 2 eligible models, got {}",
            eligible.len()
        );
        assert!(!lists("minilm"), "minilm should not be in eligible list");
        assert!(!lists("hash"), "hash should not be in eligible list");
        assert!(
            lists("snowflake-arctic-s"),
            "snowflake should be in eligible list"
        );
        assert!(lists("nomic-embed"), "nomic should be in eligible list");
    }

    /// The first baseline in the table is minilm, and it is not eligible.
    #[test]
    fn test_baseline_embedder() {
        let (_guard, registry) = registry_fixture();

        let baseline = registry.baseline_embedder().unwrap();
        assert_eq!(baseline.name, "minilm");
        assert!(baseline.is_baseline);
        assert!(!baseline.is_bakeoff_eligible());
    }

    /// Eligibility lines up with the release-date cutoff.
    #[test]
    fn test_bakeoff_eligibility_by_date() {
        let (_guard, registry) = registry_fixture();

        let minilm_date = registry.get("minilm").unwrap().release_date;
        assert!(
            minilm_date < BAKEOFF_ELIGIBILITY_CUTOFF,
            "minilm should be released before cutoff"
        );

        for candidate in registry.bakeoff_eligible() {
            assert!(
                candidate.release_date >= BAKEOFF_ELIGIBILITY_CUTOFF,
                "{} should be released after cutoff (date: {})",
                candidate.name,
                candidate.release_date
            );
        }
    }

    /// Metadata for the baseline carries the registry fields over.
    #[test]
    fn test_bakeoff_model_metadata_conversion() {
        let (_guard, registry) = registry_fixture();

        let metadata = registry.get("minilm").unwrap().to_model_metadata();

        assert_eq!(metadata.id, "minilm-384");
        assert_eq!(metadata.name, "minilm");
        assert!(metadata.source.contains("MiniLM"));
        assert_eq!(metadata.release_date, "2022-08-01");
        assert_eq!(metadata.dimension, Some(384));
        assert!(metadata.is_baseline);
        assert!(!metadata.is_eligible());
    }

    /// Eligible embedders stay eligible after conversion to metadata.
    #[test]
    fn test_eligible_embedder_metadata() {
        let (_guard, registry) = registry_fixture();

        for (name, dimension) in [("snowflake-arctic-s", 384), ("nomic-embed", 768)] {
            let entry = registry.get(name).unwrap();
            assert!(entry.is_bakeoff_eligible());

            let metadata = entry.to_model_metadata();
            assert!(!metadata.is_baseline);
            assert!(metadata.is_eligible());
            assert_eq!(metadata.dimension, Some(dimension));
        }
    }

    /// Every table entry carries the fields the rest of the module relies on.
    #[test]
    fn test_all_embedders_have_required_fields() {
        for e in EMBEDDERS {
            assert!(
                !e.release_date.is_empty(),
                "{} should have a release date",
                e.name
            );

            let needs_source = e.is_semantic && e.requires_model_files;
            if needs_source {
                assert!(
                    !e.huggingface_id.is_empty(),
                    "{} should have a huggingface_id",
                    e.name
                );
            }

            assert!((256..=2048).contains(&e.dimension));
        }
    }

    /// Model-backed embedders resolve to a directory under `models/`.
    #[test]
    fn test_model_dir_for_all_embedders() {
        let data_dir = tempdir().unwrap();
        let models_root = data_dir.path().join("models");

        for e in EMBEDDERS.iter().filter(|e| e.requires_model_files) {
            let dir = e.model_dir(data_dir.path());
            assert!(dir.is_some(), "{} should have a model directory", e.name);
            assert!(
                dir.unwrap().starts_with(&models_root),
                "{} model dir should be under models/",
                e.name
            );
        }
    }
}