Skip to main content

coding_agent_search/search/
embedder_registry.rs

1//! Embedder registry for model selection (bd-2mbe).
2//!
3//! This module provides a registry of available embedding backends that allows:
4//! - Listing available embedders with metadata
5//! - Selecting embedder by name from CLI/config
6//! - Validating model availability before use
7//! - Providing a sensible default model
8//!
9//! **Note**: The core types ([`RegisteredEmbedder`], [`EmbedderRegistry`]) are
10//! structurally identical to those in `frankensearch_embed::model_registry`.
11//! They are kept locally for now due to build-system sync constraints (rch does
12//! not sync sibling path dependencies).  See frankensearch-embed for the
13//! canonical definitions, which additionally include reranker support, two
14//! additional Potion embedders, and richer directory-resolution helpers.
15//!
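//! # Supported Embedders
//!
//! | Name | ID | Dimension | Type | Notes |
//! |------|-----|-----------|------|-------|
//! | minilm | minilm-384 | 384 | ML | Default semantic embedder (bake-off baseline) |
//! | snowflake-arctic-s | snowflake-arctic-s-384 | 384 | ML | Bake-off candidate |
//! | nomic-embed | nomic-embed-768 | 768 | ML | Bake-off candidate |
//! | hash | fnv1a-384 | 384 | Hash | Always available fallback |
//!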
23//! # Example
24//!
25//! ```ignore
26//! use crate::search::embedder_registry::{EmbedderRegistry, get_embedder};
27//!
28//! let registry = EmbedderRegistry::new(&data_dir);
29//!
30//! // List available embedders
31//! for info in registry.available() {
32//!     println!("{}: {} ({})", info.name, info.id, info.dimension);
33//! }
34//!
35//! // Get embedder by name
36//! let embedder = get_embedder(&data_dir, Some("minilm"))?;
37//! ```
38
39use std::path::{Path, PathBuf};
40use std::sync::Arc;
41
42use super::embedder::{Embedder, EmbedderError, EmbedderInfo, EmbedderResult};
43use super::fastembed_embedder::FastEmbedder;
44use super::hash_embedder::HashEmbedder;
45
/// Name of the default embedder, as returned by `EmbedderRegistry::default_embedder`.
///
/// Note: `get_embedder(_, None)` does not use this constant; it picks the best
/// *available* embedder instead.
pub const DEFAULT_EMBEDDER: &str = "minilm";

/// Registry name of the FNV-1a hash embedder, which needs no model files and is
/// therefore always available as a fallback.
pub const HASH_EMBEDDER: &str = "hash";
51
/// Static description of one embedding backend known to the registry.
///
/// Structurally identical to `frankensearch_embed::model_registry::RegisteredEmbedder`.
/// Every field is plain, immutable data, so entries can be compared with `==`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RegisteredEmbedder {
    /// Short name for CLI/config (e.g., "minilm", "hash").
    pub name: &'static str,
    /// Unique embedder ID (e.g., "minilm-384", "fnv1a-384").
    pub id: &'static str,
    /// Output vector dimension.
    pub dimension: usize,
    /// Whether this is a semantic (ML) embedder rather than a lexical one.
    pub is_semantic: bool,
    /// Human-readable description.
    pub description: &'static str,
    /// Whether model files must be present on disk (false = always available).
    pub requires_model_files: bool,
    /// Release/update date in `YYYY-MM-DD` format, used for bake-off eligibility.
    /// Kept in ISO order so it can be compared as a string.
    pub release_date: &'static str,
    /// HuggingFace model ID for download/reference (empty when not applicable).
    pub huggingface_id: &'static str,
    /// Approximate model size in bytes (0 = unknown / not applicable).
    pub size_bytes: u64,
    /// Whether this is a baseline model (never eligible for the bake-off).
    pub is_baseline: bool,
}
78
/// Files that every ONNX-backed embedder needs inside its model directory
/// before it is considered available.
pub const REQUIRED_ONNX_FILES: &[&str] = &[
    "model.onnx",
    "tokenizer.json",
    "config.json",
    "special_tokens_map.json",
    "tokenizer_config.json",
];

/// Earliest release date (inclusive, `YYYY-MM-DD`) a non-baseline model may have
/// to take part in the bake-off.
///
/// Dates are compared as strings, which is sound because ISO-8601 dates sort
/// lexicographically.
pub const BAKEOFF_ELIGIBILITY_CUTOFF: &str = "2025-11-01";
90
91impl RegisteredEmbedder {
92    /// Check if this embedder is available in the given data directory.
93    pub fn is_available(&self, data_dir: &Path) -> bool {
94        if !self.requires_model_files {
95            return true;
96        }
97
98        if let Some(model_dir) = self.model_dir(data_dir) {
99            self.required_files()
100                .iter()
101                .all(|f| model_dir.join(f).is_file())
102        } else {
103            false
104        }
105    }
106
107    /// Get the model directory path for this embedder (if applicable).
108    pub fn model_dir(&self, data_dir: &Path) -> Option<PathBuf> {
109        if !self.requires_model_files {
110            return None;
111        }
112
113        // Map embedder names to their model directory names
114        let dir_name = match self.name {
115            "minilm" => "all-MiniLM-L6-v2",
116            "snowflake-arctic-s" => "snowflake-arctic-embed-s",
117            "nomic-embed" => "nomic-embed-text-v1.5",
118            _ => return None,
119        };
120        Some(data_dir.join("models").join(dir_name))
121    }
122
123    /// Get required model files for this embedder.
124    pub fn required_files(&self) -> &'static [&'static str] {
125        if !self.requires_model_files {
126            return &[];
127        }
128        // All ONNX-based embedders use the same file structure
129        REQUIRED_ONNX_FILES
130    }
131
132    /// Get missing model files for this embedder.
133    pub fn missing_files(&self, data_dir: &Path) -> Vec<String> {
134        if !self.requires_model_files {
135            return Vec::new();
136        }
137
138        if let Some(model_dir) = self.model_dir(data_dir) {
139            self.required_files()
140                .iter()
141                .filter(|f| !model_dir.join(*f).is_file())
142                .map(|f| (*f).to_string())
143                .collect()
144        } else {
145            Vec::new()
146        }
147    }
148
149    /// Check if this embedder is eligible for the bake-off.
150    pub fn is_bakeoff_eligible(&self) -> bool {
151        if self.is_baseline {
152            return false;
153        }
154        self.release_date >= BAKEOFF_ELIGIBILITY_CUTOFF
155    }
156
157    /// Convert to bakeoff ModelMetadata.
158    pub fn to_model_metadata(&self) -> crate::bakeoff::ModelMetadata {
159        crate::bakeoff::ModelMetadata {
160            id: self.id.to_string(),
161            name: self.name.to_string(),
162            source: self.huggingface_id.to_string(),
163            release_date: self.release_date.to_string(),
164            dimension: Some(self.dimension),
165            size_bytes: if self.size_bytes > 0 {
166                Some(self.size_bytes)
167            } else {
168                None
169            },
170            is_baseline: self.is_baseline,
171        }
172    }
173}
174
175/// Static registry of all supported embedders.
176///
177/// Models marked with `bakeoff_eligible: true` are candidates for the embedding bake-off
178/// (released after 2025-11-01). The baseline (minilm) is not eligible but used for comparison.
179pub static EMBEDDERS: &[RegisteredEmbedder] = &[
180    // === Baseline (not eligible for bake-off) ===
181    RegisteredEmbedder {
182        name: "minilm",
183        id: "minilm-384",
184        dimension: 384,
185        is_semantic: true,
186        description: "MiniLM L6 v2 - fast, high-quality semantic embeddings (baseline)",
187        requires_model_files: true,
188        release_date: "2022-08-01",
189        huggingface_id: "sentence-transformers/all-MiniLM-L6-v2",
190        size_bytes: 90_000_000,
191        is_baseline: true,
192    },
193    // === Bake-off Eligible Models (released >= 2025-11-01, verified checksums) ===
194    RegisteredEmbedder {
195        name: "snowflake-arctic-s",
196        id: "snowflake-arctic-s-384",
197        dimension: 384,
198        is_semantic: true,
199        description: "Snowflake Arctic Embed S - small, fast, MiniLM-compatible dimension",
200        requires_model_files: true,
201        release_date: "2025-11-10",
202        huggingface_id: "Snowflake/snowflake-arctic-embed-s",
203        size_bytes: 130_000_000,
204        is_baseline: false,
205    },
206    RegisteredEmbedder {
207        name: "nomic-embed",
208        id: "nomic-embed-768",
209        dimension: 768,
210        is_semantic: true,
211        description: "Nomic Embed Text v1.5 - long context, Matryoshka support",
212        requires_model_files: true,
213        release_date: "2025-11-05",
214        huggingface_id: "nomic-ai/nomic-embed-text-v1.5",
215        size_bytes: 280_000_000,
216        is_baseline: false,
217    },
218    // === Fallback (always available) ===
219    RegisteredEmbedder {
220        name: "hash",
221        id: "fnv1a-384",
222        dimension: 384,
223        is_semantic: false,
224        description: "FNV-1a feature hashing - lexical fallback, always available",
225        requires_model_files: false,
226        release_date: "2020-01-01",
227        huggingface_id: "",
228        size_bytes: 0,
229        is_baseline: true,
230    },
231];
232
/// Embedder registry bound to a cass data directory.
///
/// All availability checks resolve model files relative to `data_dir`.
#[derive(Debug, Clone)]
pub struct EmbedderRegistry {
    data_dir: PathBuf,
}
237
238impl EmbedderRegistry {
239    /// Create a new registry bound to the given data directory.
240    pub fn new(data_dir: &Path) -> Self {
241        Self {
242            data_dir: data_dir.to_path_buf(),
243        }
244    }
245
246    /// Get all registered embedders.
247    pub fn all(&self) -> &'static [RegisteredEmbedder] {
248        EMBEDDERS
249    }
250
251    /// Get only available embedders (model files present).
252    pub fn available(&self) -> Vec<&'static RegisteredEmbedder> {
253        EMBEDDERS
254            .iter()
255            .filter(|e| e.is_available(&self.data_dir))
256            .collect()
257    }
258
259    /// Get embedder info by name.
260    pub fn get(&self, name: &str) -> Option<&'static RegisteredEmbedder> {
261        let name_lower = name.to_ascii_lowercase();
262        EMBEDDERS.iter().find(|e| {
263            e.name == name_lower
264                || e.id == name_lower
265                || e.id.starts_with(&format!("{}-", name_lower))
266        })
267    }
268
269    /// Check if an embedder is available by name.
270    pub fn is_available(&self, name: &str) -> bool {
271        self.get(name)
272            .map(|e| e.is_available(&self.data_dir))
273            .unwrap_or(false)
274    }
275
276    /// Get the default embedder info.
277    pub fn default_embedder(&self) -> &'static RegisteredEmbedder {
278        self.get(DEFAULT_EMBEDDER)
279            .expect("default embedder must exist")
280    }
281
282    /// Get the best available embedder (ML if available, hash fallback).
283    pub fn best_available(&self) -> &'static RegisteredEmbedder {
284        // Try ML embedders first
285        for e in EMBEDDERS.iter().filter(|e| e.is_semantic) {
286            if e.is_available(&self.data_dir) {
287                return e;
288            }
289        }
290        // Fall back to hash
291        self.get(HASH_EMBEDDER).expect("hash embedder must exist")
292    }
293
294    /// Get all bake-off eligible embedders.
295    pub fn bakeoff_eligible(&self) -> Vec<&'static RegisteredEmbedder> {
296        EMBEDDERS
297            .iter()
298            .filter(|e| e.is_bakeoff_eligible())
299            .collect()
300    }
301
302    /// Get available bake-off eligible embedders (model files present).
303    pub fn available_bakeoff_candidates(&self) -> Vec<&'static RegisteredEmbedder> {
304        EMBEDDERS
305            .iter()
306            .filter(|e| e.is_bakeoff_eligible() && e.is_available(&self.data_dir))
307            .collect()
308    }
309
310    /// Get the baseline embedder for bake-off comparison.
311    pub fn baseline_embedder(&self) -> Option<&'static RegisteredEmbedder> {
312        EMBEDDERS.iter().find(|e| e.is_baseline)
313    }
314
315    /// Validate that an embedder is ready to use.
316    ///
317    /// Returns `Ok(())` if available, or an error with details about what's missing.
318    pub fn validate(&self, name: &str) -> EmbedderResult<&'static RegisteredEmbedder> {
319        let embedder = self.get(name).ok_or_else(|| {
320            embedder_unavailable(
321                name,
322                format!(
323                    "unknown embedder. Available: {}",
324                    EMBEDDERS
325                        .iter()
326                        .map(|e| e.name)
327                        .collect::<Vec<_>>()
328                        .join(", ")
329                ),
330            )
331        })?;
332
333        if !embedder.is_available(&self.data_dir) {
334            let missing = embedder.missing_files(&self.data_dir);
335            let model_dir = embedder
336                .model_dir(&self.data_dir)
337                .map(|p| p.display().to_string())
338                .unwrap_or_else(|| "unknown".to_string());
339
340            return Err(embedder_unavailable(
341                name,
342                format!(
343                    "missing files in {}: {}. Run 'cass models install' to download.",
344                    model_dir,
345                    missing.join(", ")
346                ),
347            ));
348        }
349
350        Ok(embedder)
351    }
352}
353
354/// Load an embedder by name (or default if None).
355///
356/// # Arguments
357///
358/// * `data_dir` - The cass data directory containing model files.
359/// * `name` - Optional embedder name. If None, uses the best available.
360///
361/// # Returns
362///
363/// An `Arc<dyn Embedder>` ready for use, or an error if unavailable.
364pub fn get_embedder(data_dir: &Path, name: Option<&str>) -> EmbedderResult<Arc<dyn Embedder>> {
365    let registry = EmbedderRegistry::new(data_dir);
366
367    let embedder_info = match name {
368        Some(n) => registry.validate(n)?,
369        None => registry.best_available(),
370    };
371
372    load_embedder_by_name(data_dir, embedder_info.name)
373}
374
375/// Load an embedder by registered name.
376fn load_embedder_by_name(data_dir: &Path, name: &str) -> EmbedderResult<Arc<dyn Embedder>> {
377    match name {
378        "hash" => {
379            let embedder = HashEmbedder::default();
380            Ok(Arc::new(embedder))
381        }
382        // All ONNX-based embedders (baseline and bake-off candidates)
383        "minilm" | "snowflake-arctic-s" | "nomic-embed" => {
384            let embedder = FastEmbedder::load_by_name(data_dir, name)?;
385            Ok(Arc::new(embedder))
386        }
387        _ => Err(embedder_unavailable(name, "embedder not implemented")),
388    }
389}
390
391fn embedder_unavailable(model: &str, reason: impl Into<String>) -> EmbedderError {
392    EmbedderError::EmbedderUnavailable {
393        model: model.to_string(),
394        reason: reason.into(),
395    }
396}
397
398/// Get embedder info for display/logging.
399pub fn get_embedder_info(data_dir: &Path, name: Option<&str>) -> Option<EmbedderInfo> {
400    let registry = EmbedderRegistry::new(data_dir);
401
402    let embedder_info = match name {
403        Some(n) => registry.get(n)?,
404        None => registry.best_available(),
405    };
406
407    Some(EmbedderInfo {
408        id: embedder_info.id.to_string(),
409        dimension: embedder_info.dimension,
410        is_semantic: embedder_info.is_semantic,
411    })
412}
413
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;
    use tempfile::{TempDir, tempdir};

    /// Create a registry rooted at a fresh, empty temporary directory.
    ///
    /// The `TempDir` guard is returned alongside the registry so the directory is
    /// not deleted while the test still uses it.
    fn registry_fixture() -> (TempDir, EmbedderRegistry) {
        let tmp = tempdir().unwrap();
        let registry = EmbedderRegistry::new(tmp.path());
        (tmp, registry)
    }

    #[test]
    fn test_registry_all() {
        let (_tmp, registry) = registry_fixture();
        assert!(registry.all().len() >= 2);
    }

    #[test]
    fn test_registry_names_and_ids_are_unique() {
        // `get` returns the first match, so a duplicate would silently shadow an entry.
        let names: HashSet<&str> = EMBEDDERS.iter().map(|e| e.name).collect();
        let ids: HashSet<&str> = EMBEDDERS.iter().map(|e| e.id).collect();
        assert_eq!(names.len(), EMBEDDERS.len(), "embedder names must be unique");
        assert_eq!(ids.len(), EMBEDDERS.len(), "embedder ids must be unique");
    }

    #[test]
    fn test_registry_get_by_name() {
        let (_tmp, registry) = registry_fixture();

        let minilm = registry.get("minilm").expect("minilm should be registered");
        assert_eq!(minilm.dimension, 384);

        let hash = registry.get("hash").expect("hash should be registered");
        assert_eq!(hash.dimension, 384);

        assert!(registry.get("unknown").is_none());
    }

    #[test]
    fn test_registry_get_by_id() {
        let (_tmp, registry) = registry_fixture();

        let minilm = registry.get("minilm-384").expect("lookup by id should work");
        assert_eq!(minilm.name, "minilm");

        let hash = registry.get("fnv1a-384").expect("lookup by id should work");
        assert_eq!(hash.name, "hash");
    }

    #[test]
    fn test_registry_get_by_id_prefix() {
        let (_tmp, registry) = registry_fixture();

        // A dash-terminated prefix of the id also resolves to the entry.
        assert_eq!(registry.get("nomic").map(|e| e.name), Some("nomic-embed"));
        assert_eq!(registry.get("fnv1a").map(|e| e.name), Some("hash"));
    }

    #[test]
    fn test_registry_get_is_case_insensitive() {
        let (_tmp, registry) = registry_fixture();

        assert_eq!(registry.get("MiniLM").map(|e| e.name), Some("minilm"));
        assert_eq!(registry.get("FNV1A-384").map(|e| e.name), Some("hash"));
    }

    #[test]
    fn test_hash_always_available() {
        let (_tmp, registry) = registry_fixture();

        assert!(registry.is_available("hash"));
        let available = registry.available();
        assert!(available.iter().any(|e| e.name == "hash"));
    }

    #[test]
    fn test_minilm_unavailable_without_files() {
        let (_tmp, registry) = registry_fixture();

        // MiniLM should not be available without model files.
        assert!(!registry.is_available("minilm"));

        let err = registry
            .validate("minilm")
            .expect_err("minilm should fail validation without model files");
        assert!(matches!(err, EmbedderError::EmbedderUnavailable { .. }));
    }

    #[test]
    fn test_embedder_unavailable_helper_shape() {
        let err = embedder_unavailable("demo", "missing model");
        match err {
            EmbedderError::EmbedderUnavailable { model, reason } => {
                assert_eq!(model, "demo");
                assert_eq!(reason, "missing model");
            }
            other => panic!("unexpected error shape: {other:?}"),
        }
    }

    #[test]
    fn test_best_available_fallback() {
        let (_tmp, registry) = registry_fixture();

        // Without model files, best_available should return hash.
        let best = registry.best_available();
        assert_eq!(best.name, "hash");
    }

    #[test]
    fn test_get_embedder_hash() {
        let tmp = tempdir().unwrap();
        let embedder = get_embedder(tmp.path(), Some("hash")).unwrap();
        assert_eq!(embedder.id(), "fnv1a-384");
        assert!(!embedder.is_semantic());
    }

    #[test]
    fn test_get_embedder_default_no_models() {
        let tmp = tempdir().unwrap();
        // Without model files, should fall back to hash.
        let embedder = get_embedder(tmp.path(), None).unwrap();
        assert_eq!(embedder.id(), "fnv1a-384");
    }

    #[test]
    fn test_validate_unknown_embedder() {
        let (_tmp, registry) = registry_fixture();

        let err = registry
            .validate("nonexistent")
            .expect_err("unknown embedder should fail validation");
        let message = err.to_string();
        assert!(message.contains("unknown embedder"));
        assert!(message.contains("Available:"));
    }

    #[test]
    fn test_registered_embedder_missing_files() {
        let (tmp, registry) = registry_fixture();

        let minilm = registry.get("minilm").unwrap();
        let missing = minilm.missing_files(tmp.path());
        assert!(!missing.is_empty());
        assert!(missing.contains(&"model.onnx".to_string()));
    }

    #[test]
    fn test_get_embedder_info() {
        let tmp = tempdir().unwrap();

        let hash_info = get_embedder_info(tmp.path(), Some("hash")).unwrap();
        assert_eq!(hash_info.id, "fnv1a-384");
        assert!(!hash_info.is_semantic);

        let minilm_info = get_embedder_info(tmp.path(), Some("minilm")).unwrap();
        assert_eq!(minilm_info.id, "minilm-384");
        assert!(minilm_info.is_semantic);
    }

    // ==================== Bake-off Tests ====================

    #[test]
    fn test_bakeoff_eligible_count() {
        let (_tmp, registry) = registry_fixture();

        let eligible = registry.bakeoff_eligible();
        // Should have exactly 2 eligible models: snowflake, nomic.
        assert_eq!(
            eligible.len(),
            2,
            "Expected 2 eligible models, got {}",
            eligible.len()
        );

        // MiniLM should NOT be in the eligible list (it's the baseline).
        assert!(
            !eligible.iter().any(|e| e.name == "minilm"),
            "minilm should not be in eligible list"
        );

        // Hash should NOT be in the eligible list (not semantic).
        assert!(
            !eligible.iter().any(|e| e.name == "hash"),
            "hash should not be in eligible list"
        );

        // Verify the correct models are in the eligible list.
        assert!(
            eligible.iter().any(|e| e.name == "snowflake-arctic-s"),
            "snowflake should be in eligible list"
        );
        assert!(
            eligible.iter().any(|e| e.name == "nomic-embed"),
            "nomic should be in eligible list"
        );
    }

    #[test]
    fn test_baseline_embedder() {
        let (_tmp, registry) = registry_fixture();

        let baseline = registry
            .baseline_embedder()
            .expect("a baseline embedder should exist");
        assert_eq!(baseline.name, "minilm");
        assert!(baseline.is_baseline);
        assert!(!baseline.is_bakeoff_eligible());
    }

    #[test]
    fn test_bakeoff_eligibility_by_date() {
        let (_tmp, registry) = registry_fixture();

        // MiniLM was released before the cutoff (2022-08-01).
        let minilm = registry.get("minilm").unwrap();
        assert!(
            minilm.release_date < BAKEOFF_ELIGIBILITY_CUTOFF,
            "minilm should be released before cutoff"
        );

        // All eligible models should be released on or after the cutoff.
        for e in registry.bakeoff_eligible() {
            assert!(
                e.release_date >= BAKEOFF_ELIGIBILITY_CUTOFF,
                "{} should be released after cutoff (date: {})",
                e.name,
                e.release_date
            );
        }
    }

    #[test]
    fn test_bakeoff_model_metadata_conversion() {
        let (_tmp, registry) = registry_fixture();

        let minilm = registry.get("minilm").unwrap();
        let metadata = minilm.to_model_metadata();

        assert_eq!(metadata.id, "minilm-384");
        assert_eq!(metadata.name, "minilm");
        assert!(metadata.source.contains("MiniLM"));
        assert_eq!(metadata.release_date, "2022-08-01");
        assert_eq!(metadata.dimension, Some(384));
        assert!(metadata.is_baseline);
        assert!(!metadata.is_eligible());
    }

    #[test]
    fn test_eligible_embedder_metadata() {
        let (_tmp, registry) = registry_fixture();

        // Snowflake: eligible candidate with the same dimension as minilm.
        let snowflake = registry.get("snowflake-arctic-s").unwrap();
        assert!(snowflake.is_bakeoff_eligible());
        let metadata = snowflake.to_model_metadata();
        assert!(!metadata.is_baseline);
        assert!(metadata.is_eligible());
        assert_eq!(metadata.dimension, Some(384));

        // Nomic: eligible candidate.
        let nomic = registry.get("nomic-embed").unwrap();
        assert!(nomic.is_bakeoff_eligible());
        let metadata = nomic.to_model_metadata();
        assert!(!metadata.is_baseline);
        assert!(metadata.is_eligible());
        assert_eq!(metadata.dimension, Some(768));
    }

    #[test]
    fn test_all_embedders_have_required_fields() {
        for e in EMBEDDERS {
            // All should have valid release dates.
            assert!(
                !e.release_date.is_empty(),
                "{} should have a release date",
                e.name
            );

            // All model-backed semantic embedders should have HuggingFace IDs.
            if e.is_semantic && e.requires_model_files {
                assert!(
                    !e.huggingface_id.is_empty(),
                    "{} should have a huggingface_id",
                    e.name
                );
            }

            // Dimensions should be reasonable.
            assert!(
                (256..=2048).contains(&e.dimension),
                "{} has an implausible dimension {}",
                e.name,
                e.dimension
            );
        }
    }

    #[test]
    fn test_model_dir_for_all_embedders() {
        let tmp = tempdir().unwrap();

        for e in EMBEDDERS.iter().filter(|e| e.requires_model_files) {
            let dir = e
                .model_dir(tmp.path())
                .unwrap_or_else(|| panic!("{} should have a model directory", e.name));
            assert!(
                dir.starts_with(tmp.path().join("models")),
                "{} model dir should be under models/",
                e.name
            );
        }
    }
}