1use std::path::{Path, PathBuf};
40use std::sync::Arc;
41
42use super::embedder::{Embedder, EmbedderError, EmbedderInfo, EmbedderResult};
43use super::fastembed_embedder::FastEmbedder;
44use super::hash_embedder::HashEmbedder;
45
/// Registry name of the embedder returned by [`EmbedderRegistry::default_embedder`].
pub const DEFAULT_EMBEDDER: &str = "minilm";

/// Registry name of the hash embedder, the fallback used by
/// [`EmbedderRegistry::best_available`] when no semantic model is available.
pub const HASH_EMBEDDER: &str = "hash";
51
/// Static description of one embedder known to the registry.
///
/// All string fields are `'static` because entries live in the [`EMBEDDERS`] table.
#[derive(Clone, Debug)]
pub struct RegisteredEmbedder {
    /// Short registry name (e.g. `"minilm"`); used for lookup and for loading.
    pub name: &'static str,
    /// Unique embedder identifier (e.g. `"minilm-384"`).
    pub id: &'static str,
    /// Length of the vectors produced by this embedder.
    pub dimension: usize,
    /// `true` for model-based semantic embedders, `false` for the lexical hash embedder.
    pub is_semantic: bool,
    /// Human-readable one-line description.
    pub description: &'static str,
    /// Whether model files must be present on disk before the embedder can be used.
    pub requires_model_files: bool,
    /// Release date as a `YYYY-MM-DD` string; compared against
    /// [`BAKEOFF_ELIGIBILITY_CUTOFF`] as plain text.
    pub release_date: &'static str,
    /// HuggingFace repository identifier; empty when there is none.
    pub huggingface_id: &'static str,
    /// Model size in bytes; `0` is exported as "no size" in bake-off metadata.
    pub size_bytes: u64,
    /// Marks a baseline embedder, which is never bake-off eligible.
    pub is_baseline: bool,
}
78
/// Files that must all exist in a model directory for a file-backed embedder
/// to count as available.
pub const REQUIRED_ONNX_FILES: &[&str] = &[
    "model.onnx",
    "tokenizer.json",
    "config.json",
    "special_tokens_map.json",
    "tokenizer_config.json",
];

/// Earliest release date (`YYYY-MM-DD`, inclusive) a non-baseline embedder may
/// have and still be bake-off eligible.
pub const BAKEOFF_ELIGIBILITY_CUTOFF: &str = "2025-11-01";
90
91impl RegisteredEmbedder {
92 pub fn is_available(&self, data_dir: &Path) -> bool {
94 if !self.requires_model_files {
95 return true;
96 }
97
98 if let Some(model_dir) = self.model_dir(data_dir) {
99 self.required_files()
100 .iter()
101 .all(|f| model_dir.join(f).is_file())
102 } else {
103 false
104 }
105 }
106
107 pub fn model_dir(&self, data_dir: &Path) -> Option<PathBuf> {
109 if !self.requires_model_files {
110 return None;
111 }
112
113 let dir_name = match self.name {
115 "minilm" => "all-MiniLM-L6-v2",
116 "snowflake-arctic-s" => "snowflake-arctic-embed-s",
117 "nomic-embed" => "nomic-embed-text-v1.5",
118 _ => return None,
119 };
120 Some(data_dir.join("models").join(dir_name))
121 }
122
123 pub fn required_files(&self) -> &'static [&'static str] {
125 if !self.requires_model_files {
126 return &[];
127 }
128 REQUIRED_ONNX_FILES
130 }
131
132 pub fn missing_files(&self, data_dir: &Path) -> Vec<String> {
134 if !self.requires_model_files {
135 return Vec::new();
136 }
137
138 if let Some(model_dir) = self.model_dir(data_dir) {
139 self.required_files()
140 .iter()
141 .filter(|f| !model_dir.join(*f).is_file())
142 .map(|f| (*f).to_string())
143 .collect()
144 } else {
145 Vec::new()
146 }
147 }
148
149 pub fn is_bakeoff_eligible(&self) -> bool {
151 if self.is_baseline {
152 return false;
153 }
154 self.release_date >= BAKEOFF_ELIGIBILITY_CUTOFF
155 }
156
157 pub fn to_model_metadata(&self) -> crate::bakeoff::ModelMetadata {
159 crate::bakeoff::ModelMetadata {
160 id: self.id.to_string(),
161 name: self.name.to_string(),
162 source: self.huggingface_id.to_string(),
163 release_date: self.release_date.to_string(),
164 dimension: Some(self.dimension),
165 size_bytes: if self.size_bytes > 0 {
166 Some(self.size_bytes)
167 } else {
168 None
169 },
170 is_baseline: self.is_baseline,
171 }
172 }
173}
174
175pub static EMBEDDERS: &[RegisteredEmbedder] = &[
180 RegisteredEmbedder {
182 name: "minilm",
183 id: "minilm-384",
184 dimension: 384,
185 is_semantic: true,
186 description: "MiniLM L6 v2 - fast, high-quality semantic embeddings (baseline)",
187 requires_model_files: true,
188 release_date: "2022-08-01",
189 huggingface_id: "sentence-transformers/all-MiniLM-L6-v2",
190 size_bytes: 90_000_000,
191 is_baseline: true,
192 },
193 RegisteredEmbedder {
195 name: "snowflake-arctic-s",
196 id: "snowflake-arctic-s-384",
197 dimension: 384,
198 is_semantic: true,
199 description: "Snowflake Arctic Embed S - small, fast, MiniLM-compatible dimension",
200 requires_model_files: true,
201 release_date: "2025-11-10",
202 huggingface_id: "Snowflake/snowflake-arctic-embed-s",
203 size_bytes: 130_000_000,
204 is_baseline: false,
205 },
206 RegisteredEmbedder {
207 name: "nomic-embed",
208 id: "nomic-embed-768",
209 dimension: 768,
210 is_semantic: true,
211 description: "Nomic Embed Text v1.5 - long context, Matryoshka support",
212 requires_model_files: true,
213 release_date: "2025-11-05",
214 huggingface_id: "nomic-ai/nomic-embed-text-v1.5",
215 size_bytes: 280_000_000,
216 is_baseline: false,
217 },
218 RegisteredEmbedder {
220 name: "hash",
221 id: "fnv1a-384",
222 dimension: 384,
223 is_semantic: false,
224 description: "FNV-1a feature hashing - lexical fallback, always available",
225 requires_model_files: false,
226 release_date: "2020-01-01",
227 huggingface_id: "",
228 size_bytes: 0,
229 is_baseline: true,
230 },
231];
232
/// Looks up entries of the static embedder table and checks their
/// availability against a data directory.
#[derive(Debug, Clone)]
pub struct EmbedderRegistry {
    /// Root directory under which model files are looked up.
    data_dir: PathBuf,
}
237
238impl EmbedderRegistry {
239 pub fn new(data_dir: &Path) -> Self {
241 Self {
242 data_dir: data_dir.to_path_buf(),
243 }
244 }
245
246 pub fn all(&self) -> &'static [RegisteredEmbedder] {
248 EMBEDDERS
249 }
250
251 pub fn available(&self) -> Vec<&'static RegisteredEmbedder> {
253 EMBEDDERS
254 .iter()
255 .filter(|e| e.is_available(&self.data_dir))
256 .collect()
257 }
258
259 pub fn get(&self, name: &str) -> Option<&'static RegisteredEmbedder> {
261 let name_lower = name.to_ascii_lowercase();
262 EMBEDDERS.iter().find(|e| {
263 e.name == name_lower
264 || e.id == name_lower
265 || e.id.starts_with(&format!("{}-", name_lower))
266 })
267 }
268
269 pub fn is_available(&self, name: &str) -> bool {
271 self.get(name)
272 .map(|e| e.is_available(&self.data_dir))
273 .unwrap_or(false)
274 }
275
276 pub fn default_embedder(&self) -> &'static RegisteredEmbedder {
278 self.get(DEFAULT_EMBEDDER)
279 .expect("default embedder must exist")
280 }
281
282 pub fn best_available(&self) -> &'static RegisteredEmbedder {
284 for e in EMBEDDERS.iter().filter(|e| e.is_semantic) {
286 if e.is_available(&self.data_dir) {
287 return e;
288 }
289 }
290 self.get(HASH_EMBEDDER).expect("hash embedder must exist")
292 }
293
294 pub fn bakeoff_eligible(&self) -> Vec<&'static RegisteredEmbedder> {
296 EMBEDDERS
297 .iter()
298 .filter(|e| e.is_bakeoff_eligible())
299 .collect()
300 }
301
302 pub fn available_bakeoff_candidates(&self) -> Vec<&'static RegisteredEmbedder> {
304 EMBEDDERS
305 .iter()
306 .filter(|e| e.is_bakeoff_eligible() && e.is_available(&self.data_dir))
307 .collect()
308 }
309
310 pub fn baseline_embedder(&self) -> Option<&'static RegisteredEmbedder> {
312 EMBEDDERS.iter().find(|e| e.is_baseline)
313 }
314
315 pub fn validate(&self, name: &str) -> EmbedderResult<&'static RegisteredEmbedder> {
319 let embedder = self.get(name).ok_or_else(|| {
320 embedder_unavailable(
321 name,
322 format!(
323 "unknown embedder. Available: {}",
324 EMBEDDERS
325 .iter()
326 .map(|e| e.name)
327 .collect::<Vec<_>>()
328 .join(", ")
329 ),
330 )
331 })?;
332
333 if !embedder.is_available(&self.data_dir) {
334 let missing = embedder.missing_files(&self.data_dir);
335 let model_dir = embedder
336 .model_dir(&self.data_dir)
337 .map(|p| p.display().to_string())
338 .unwrap_or_else(|| "unknown".to_string());
339
340 return Err(embedder_unavailable(
341 name,
342 format!(
343 "missing files in {}: {}. Run 'cass models install' to download.",
344 model_dir,
345 missing.join(", ")
346 ),
347 ));
348 }
349
350 Ok(embedder)
351 }
352}
353
354pub fn get_embedder(data_dir: &Path, name: Option<&str>) -> EmbedderResult<Arc<dyn Embedder>> {
365 let registry = EmbedderRegistry::new(data_dir);
366
367 let embedder_info = match name {
368 Some(n) => registry.validate(n)?,
369 None => registry.best_available(),
370 };
371
372 load_embedder_by_name(data_dir, embedder_info.name)
373}
374
375fn load_embedder_by_name(data_dir: &Path, name: &str) -> EmbedderResult<Arc<dyn Embedder>> {
377 match name {
378 "hash" => {
379 let embedder = HashEmbedder::default();
380 Ok(Arc::new(embedder))
381 }
382 "minilm" | "snowflake-arctic-s" | "nomic-embed" => {
384 let embedder = FastEmbedder::load_by_name(data_dir, name)?;
385 Ok(Arc::new(embedder))
386 }
387 _ => Err(embedder_unavailable(name, "embedder not implemented")),
388 }
389}
390
391fn embedder_unavailable(model: &str, reason: impl Into<String>) -> EmbedderError {
392 EmbedderError::EmbedderUnavailable {
393 model: model.to_string(),
394 reason: reason.into(),
395 }
396}
397
398pub fn get_embedder_info(data_dir: &Path, name: Option<&str>) -> Option<EmbedderInfo> {
400 let registry = EmbedderRegistry::new(data_dir);
401
402 let embedder_info = match name {
403 Some(n) => registry.get(n)?,
404 None => registry.best_available(),
405 };
406
407 Some(EmbedderInfo {
408 id: embedder_info.id.to_string(),
409 dimension: embedder_info.dimension,
410 is_semantic: embedder_info.is_semantic,
411 })
412}
413
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::{TempDir, tempdir};

    /// Builds a registry rooted at a fresh, empty temporary directory.
    ///
    /// The `TempDir` guard is handed back so the directory stays alive for as
    /// long as the caller keeps it.
    fn registry_fixture() -> (TempDir, EmbedderRegistry) {
        let tmp = tempdir().unwrap();
        let registry = EmbedderRegistry::new(tmp.path());
        (tmp, registry)
    }

    #[test]
    fn test_registry_all() {
        let (_tmp, registry) = registry_fixture();
        let count = registry.all().len();
        assert!(count >= 2);
    }

    #[test]
    fn test_registry_get_by_name() {
        let (_tmp, registry) = registry_fixture();

        assert_eq!(registry.get("minilm").map(|e| e.dimension), Some(384));
        assert_eq!(registry.get("hash").map(|e| e.dimension), Some(384));
        assert!(registry.get("unknown").is_none());
    }

    #[test]
    fn test_registry_get_by_id() {
        let (_tmp, registry) = registry_fixture();

        assert_eq!(registry.get("minilm-384").map(|e| e.name), Some("minilm"));
        assert_eq!(registry.get("fnv1a-384").map(|e| e.name), Some("hash"));
    }

    #[test]
    fn test_hash_always_available() {
        let (_tmp, registry) = registry_fixture();

        assert!(registry.is_available("hash"));
        let listed = registry.available().into_iter().any(|e| e.name == "hash");
        assert!(listed);
    }

    #[test]
    fn test_minilm_unavailable_without_files() {
        let (_tmp, registry) = registry_fixture();

        // The temp dir is empty, so no model files can be present.
        assert!(!registry.is_available("minilm"));

        let outcome = registry.validate("minilm");
        assert!(outcome.is_err());
        assert!(matches!(
            outcome,
            Err(EmbedderError::EmbedderUnavailable { .. })
        ));
    }

    #[test]
    fn test_embedder_unavailable_helper_shape() {
        let err = embedder_unavailable("demo", "missing model");
        if let EmbedderError::EmbedderUnavailable { model, reason } = &err {
            assert_eq!(model, "demo");
            assert_eq!(reason, "missing model");
        } else {
            panic!("unexpected error shape: {err:?}");
        }
    }

    #[test]
    fn test_best_available_fallback() {
        let (_tmp, registry) = registry_fixture();

        // With no semantic model installed, the hash embedder is the fallback.
        assert_eq!(registry.best_available().name, "hash");
    }

    #[test]
    fn test_get_embedder_hash() {
        let tmp = tempdir().unwrap();
        let embedder = get_embedder(tmp.path(), Some("hash")).unwrap();
        assert_eq!(embedder.id(), "fnv1a-384");
        assert!(!embedder.is_semantic());
    }

    #[test]
    fn test_get_embedder_default_no_models() {
        let tmp = tempdir().unwrap();
        let embedder = get_embedder(tmp.path(), None).unwrap();
        // No name and no models: falls back to the hash embedder.
        assert_eq!(embedder.id(), "fnv1a-384");
    }

    #[test]
    fn test_validate_unknown_embedder() {
        let (_tmp, registry) = registry_fixture();

        let outcome = registry.validate("nonexistent");
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("unknown embedder"));
        assert!(message.contains("Available:"));
    }

    #[test]
    fn test_registered_embedder_missing_files() {
        let (tmp, registry) = registry_fixture();

        let missing = registry.get("minilm").unwrap().missing_files(tmp.path());
        assert!(!missing.is_empty());
        assert!(missing.iter().any(|file| file == "model.onnx"));
    }

    #[test]
    fn test_get_embedder_info() {
        let tmp = tempdir().unwrap();
        let data_dir = tmp.path();

        let hash_info = get_embedder_info(data_dir, Some("hash")).unwrap();
        assert_eq!(hash_info.id, "fnv1a-384");
        assert!(!hash_info.is_semantic);

        // Info is returned even though the minilm model files are absent.
        let minilm_info = get_embedder_info(data_dir, Some("minilm")).unwrap();
        assert_eq!(minilm_info.id, "minilm-384");
        assert!(minilm_info.is_semantic);
    }

    #[test]
    fn test_bakeoff_eligible_count() {
        let (_tmp, registry) = registry_fixture();

        let names: Vec<&str> = registry
            .bakeoff_eligible()
            .into_iter()
            .map(|e| e.name)
            .collect();

        assert_eq!(
            names.len(),
            2,
            "Expected 2 eligible models, got {}",
            names.len()
        );

        // Baselines are excluded.
        assert!(
            !names.contains(&"minilm"),
            "minilm should not be in eligible list"
        );
        assert!(
            !names.contains(&"hash"),
            "hash should not be in eligible list"
        );

        // Recent non-baseline models are included.
        assert!(
            names.contains(&"snowflake-arctic-s"),
            "snowflake should be in eligible list"
        );
        assert!(
            names.contains(&"nomic-embed"),
            "nomic should be in eligible list"
        );
    }

    #[test]
    fn test_baseline_embedder() {
        let (_tmp, registry) = registry_fixture();

        let found = registry.baseline_embedder();
        assert!(found.is_some());
        let baseline = found.unwrap();
        assert_eq!(baseline.name, "minilm");
        assert!(baseline.is_baseline);
        assert!(!baseline.is_bakeoff_eligible());
    }

    #[test]
    fn test_bakeoff_eligibility_by_date() {
        let (_tmp, registry) = registry_fixture();

        let minilm = registry.get("minilm").unwrap();
        assert!(
            minilm.release_date < BAKEOFF_ELIGIBILITY_CUTOFF,
            "minilm should be released before cutoff"
        );

        for candidate in registry.bakeoff_eligible() {
            assert!(
                candidate.release_date >= BAKEOFF_ELIGIBILITY_CUTOFF,
                "{} should be released after cutoff (date: {})",
                candidate.name,
                candidate.release_date
            );
        }
    }

    #[test]
    fn test_bakeoff_model_metadata_conversion() {
        let (_tmp, registry) = registry_fixture();

        let metadata = registry.get("minilm").unwrap().to_model_metadata();

        assert_eq!(metadata.id, "minilm-384");
        assert_eq!(metadata.name, "minilm");
        assert!(metadata.source.contains("MiniLM"));
        assert_eq!(metadata.release_date, "2022-08-01");
        assert_eq!(metadata.dimension, Some(384));
        assert!(metadata.is_baseline);
        assert!(!metadata.is_eligible());
    }

    #[test]
    fn test_eligible_embedder_metadata() {
        let (_tmp, registry) = registry_fixture();

        // Each eligible model paired with the dimension its metadata must report.
        let expectations = [("snowflake-arctic-s", 384), ("nomic-embed", 768)];
        for (name, dimension) in expectations {
            let embedder = registry.get(name).unwrap();
            assert!(embedder.is_bakeoff_eligible());

            let metadata = embedder.to_model_metadata();
            assert!(!metadata.is_baseline);
            assert!(metadata.is_eligible());
            assert_eq!(metadata.dimension, Some(dimension));
        }
    }

    #[test]
    fn test_all_embedders_have_required_fields() {
        for entry in EMBEDDERS {
            assert!(
                !entry.release_date.is_empty(),
                "{} should have a release date",
                entry.name
            );

            let needs_source = entry.is_semantic && entry.requires_model_files;
            if needs_source {
                assert!(
                    !entry.huggingface_id.is_empty(),
                    "{} should have a huggingface_id",
                    entry.name
                );
            }

            assert!((256..=2048).contains(&entry.dimension));
        }
    }

    #[test]
    fn test_model_dir_for_all_embedders() {
        let tmp = tempdir().unwrap();
        let models_root = tmp.path().join("models");

        for entry in EMBEDDERS.iter().filter(|e| e.requires_model_files) {
            let dir = entry.model_dir(tmp.path());
            assert!(
                dir.is_some(),
                "{} should have a model directory",
                entry.name
            );
            assert!(
                dir.unwrap().starts_with(&models_root),
                "{} model dir should be under models/",
                entry.name
            );
        }
    }
}