// Source: coding_agent_search/search/embedder_registry.rs

//! Embedder registry for model selection (bd-2mbe).
//!
//! This module provides a registry of available embedding backends that allows:
//! - Listing available embedders with metadata
//! - Selecting an embedder by name from CLI/config
//! - Validating model availability before use
//! - Providing a sensible default model
//!
//! **Note**: The core types ([`RegisteredEmbedder`], [`EmbedderRegistry`]) are
//! structurally identical to those in `frankensearch_embed::model_registry`.
//! They are kept locally for now due to build-system sync constraints (rch does
//! not sync sibling path dependencies). See frankensearch-embed for the
//! canonical definitions, which additionally include reranker support, two
//! additional Potion embedders, and richer directory-resolution helpers.
//!
//! # Supported Embedders
//!
//! | Name | ID | Dimension | Type | Notes |
//! |------|-----|-----------|------|-------|
//! | minilm | minilm-384 | 384 | ML | Default semantic embedder; bake-off baseline |
//! | snowflake-arctic-s | snowflake-arctic-s-384 | 384 | ML | Bake-off candidate |
//! | nomic-embed | nomic-embed-768 | 768 | ML | Bake-off candidate |
//! | hash | fnv1a-384 | 384 | Hash | Always available fallback |
//!
//! # Example
//!
//! ```ignore
//! use crate::search::embedder_registry::{EmbedderRegistry, get_embedder};
//!
//! let registry = EmbedderRegistry::new(&data_dir);
//!
//! // List available embedders
//! for info in registry.available() {
//!     println!("{}: {} ({})", info.name, info.id, info.dimension);
//! }
//!
//! // Get embedder by name
//! let embedder = get_embedder(&data_dir, Some("minilm"))?;
//! ```

39use std::path::{Path, PathBuf};
40use std::sync::Arc;
41
42use super::embedder::{Embedder, EmbedderError, EmbedderInfo, EmbedderResult};
43use super::fastembed_embedder::FastEmbedder;
44use super::hash_embedder::HashEmbedder;
45
/// Name of the default semantic embedder, as returned by
/// [`EmbedderRegistry::default_embedder`].
///
/// Note: [`get_embedder`] called with no name does NOT use this constant; it
/// picks [`EmbedderRegistry::best_available`] instead (first available
/// semantic embedder, else the hash fallback).
pub const DEFAULT_EMBEDDER: &str = "minilm";

/// Name of the FNV-1a hash embedder. It needs no model files, so it is always
/// available and serves as the fallback when no semantic model is installed.
pub const HASH_EMBEDDER: &str = "hash";
51
/// Static metadata describing one registered embedding backend.
///
/// All fields are plain values (`&'static str`, integers, flags), so entries can
/// be compared and hashed, e.g. to deduplicate selections.
///
/// Structurally identical to `frankensearch_embed::model_registry::RegisteredEmbedder`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RegisteredEmbedder {
    /// Short name for CLI/config (e.g., "minilm", "hash").
    pub name: &'static str,
    /// Unique embedder ID (e.g., "minilm-384", "fnv1a-384").
    pub id: &'static str,
    /// Output vector dimension.
    pub dimension: usize,
    /// Whether this is a semantic (ML) embedder rather than a lexical one.
    pub is_semantic: bool,
    /// Human-readable description.
    pub description: &'static str,
    /// Whether model files must be present on disk (false = always available).
    pub requires_model_files: bool,
    /// Release/update date in YYYY-MM-DD format, used for bake-off eligibility.
    pub release_date: &'static str,
    /// HuggingFace model ID for download/reference (empty when not applicable).
    pub huggingface_id: &'static str,
    /// Approximate model size in bytes (0 when unknown or not applicable).
    pub size_bytes: u64,
    /// Whether this is a baseline model (never eligible for the bake-off).
    pub is_baseline: bool,
}
78
/// File names that every ONNX-backed embedder must have in its model directory.
///
/// Missing files are reported in the order listed here.
pub const REQUIRED_ONNX_FILES: &[&str] = &[
    "model.onnx",
    "tokenizer.json",
    "config.json",
    "special_tokens_map.json",
    "tokenizer_config.json",
];

/// Earliest `release_date` (YYYY-MM-DD) a non-baseline model may carry and
/// still take part in the bake-off.
///
/// Dates are compared as plain strings, which orders them chronologically only
/// because the format is zero-padded YYYY-MM-DD.
pub const BAKEOFF_ELIGIBILITY_CUTOFF: &str = "2025-11-01";
90
91impl RegisteredEmbedder {
92    /// Check if this embedder is available in the given data directory.
93    pub fn is_available(&self, data_dir: &Path) -> bool {
94        if !self.requires_model_files {
95            return true;
96        }
97
98        if let Some(model_dir) = self.model_dir(data_dir) {
99            self.required_files()
100                .iter()
101                .all(|f| model_dir.join(f).is_file())
102        } else {
103            false
104        }
105    }
106
107    /// Get the model directory path for this embedder (if applicable).
108    pub fn model_dir(&self, data_dir: &Path) -> Option<PathBuf> {
109        if !self.requires_model_files {
110            return None;
111        }
112
113        FastEmbedder::model_dir_for(data_dir, self.name)
114    }
115
116    /// Get required model files for this embedder.
117    pub fn required_files(&self) -> &'static [&'static str] {
118        if !self.requires_model_files {
119            return &[];
120        }
121        // All ONNX-based embedders use the same file structure
122        REQUIRED_ONNX_FILES
123    }
124
125    /// Get missing model files for this embedder.
126    pub fn missing_files(&self, data_dir: &Path) -> Vec<String> {
127        if !self.requires_model_files {
128            return Vec::new();
129        }
130
131        if let Some(model_dir) = self.model_dir(data_dir) {
132            self.required_files()
133                .iter()
134                .filter(|f| !model_dir.join(*f).is_file())
135                .map(|f| (*f).to_string())
136                .collect()
137        } else {
138            Vec::new()
139        }
140    }
141
142    /// Check if this embedder is eligible for the bake-off.
143    pub fn is_bakeoff_eligible(&self) -> bool {
144        if self.is_baseline {
145            return false;
146        }
147        self.release_date >= BAKEOFF_ELIGIBILITY_CUTOFF
148    }
149
150    /// Convert to bakeoff ModelMetadata.
151    pub fn to_model_metadata(&self) -> crate::bakeoff::ModelMetadata {
152        crate::bakeoff::ModelMetadata {
153            id: self.id.to_string(),
154            name: self.name.to_string(),
155            source: self.huggingface_id.to_string(),
156            release_date: self.release_date.to_string(),
157            dimension: Some(self.dimension),
158            size_bytes: if self.size_bytes > 0 {
159                Some(self.size_bytes)
160            } else {
161                None
162            },
163            is_baseline: self.is_baseline,
164        }
165    }
166}
167
168/// Static registry of all supported embedders.
169///
170/// Models marked with `bakeoff_eligible: true` are candidates for the embedding bake-off
171/// (released after 2025-11-01). The baseline (minilm) is not eligible but used for comparison.
172pub static EMBEDDERS: &[RegisteredEmbedder] = &[
173    // === Baseline (not eligible for bake-off) ===
174    RegisteredEmbedder {
175        name: "minilm",
176        id: "minilm-384",
177        dimension: 384,
178        is_semantic: true,
179        description: "MiniLM L6 v2 - fast, high-quality semantic embeddings (baseline)",
180        requires_model_files: true,
181        release_date: "2022-08-01",
182        huggingface_id: "sentence-transformers/all-MiniLM-L6-v2",
183        size_bytes: 90_000_000,
184        is_baseline: true,
185    },
186    // === Bake-off Eligible Models (released >= 2025-11-01, verified checksums) ===
187    RegisteredEmbedder {
188        name: "snowflake-arctic-s",
189        id: "snowflake-arctic-s-384",
190        dimension: 384,
191        is_semantic: true,
192        description: "Snowflake Arctic Embed S - small, fast, MiniLM-compatible dimension",
193        requires_model_files: true,
194        release_date: "2025-11-10",
195        huggingface_id: "Snowflake/snowflake-arctic-embed-s",
196        size_bytes: 130_000_000,
197        is_baseline: false,
198    },
199    RegisteredEmbedder {
200        name: "nomic-embed",
201        id: "nomic-embed-768",
202        dimension: 768,
203        is_semantic: true,
204        description: "Nomic Embed Text v1.5 - long context, Matryoshka support",
205        requires_model_files: true,
206        release_date: "2025-11-05",
207        huggingface_id: "nomic-ai/nomic-embed-text-v1.5",
208        size_bytes: 280_000_000,
209        is_baseline: false,
210    },
211    // === Fallback (always available) ===
212    RegisteredEmbedder {
213        name: "hash",
214        id: "fnv1a-384",
215        dimension: 384,
216        is_semantic: false,
217        description: "FNV-1a feature hashing - lexical fallback, always available",
218        requires_model_files: false,
219        release_date: "2020-01-01",
220        huggingface_id: "",
221        size_bytes: 0,
222        is_baseline: true,
223    },
224];
225
/// View over [`EMBEDDERS`] bound to a cass data directory.
///
/// The directory is used to resolve each embedder's model files when checking
/// availability.
#[derive(Debug, Clone)]
pub struct EmbedderRegistry {
    /// Root data directory given to `EmbedderRegistry::new`.
    data_dir: PathBuf,
}
230
231impl EmbedderRegistry {
232    /// Create a new registry bound to the given data directory.
233    pub fn new(data_dir: &Path) -> Self {
234        Self {
235            data_dir: data_dir.to_path_buf(),
236        }
237    }
238
239    /// Get all registered embedders.
240    pub fn all(&self) -> &'static [RegisteredEmbedder] {
241        EMBEDDERS
242    }
243
244    /// Get only available embedders (model files present).
245    pub fn available(&self) -> Vec<&'static RegisteredEmbedder> {
246        EMBEDDERS
247            .iter()
248            .filter(|e| e.is_available(&self.data_dir))
249            .collect()
250    }
251
252    /// Get embedder info by name.
253    pub fn get(&self, name: &str) -> Option<&'static RegisteredEmbedder> {
254        let name_lower = FastEmbedder::canonical_name(name)
255            .unwrap_or_else(|| name.trim())
256            .to_ascii_lowercase();
257        EMBEDDERS.iter().find(|e| {
258            e.name == name_lower
259                || e.id == name_lower
260                || e.id.starts_with(&format!("{}-", name_lower))
261        })
262    }
263
264    /// Check if an embedder is available by name.
265    pub fn is_available(&self, name: &str) -> bool {
266        self.get(name)
267            .map(|e| e.is_available(&self.data_dir))
268            .unwrap_or(false)
269    }
270
271    /// Get the default embedder info.
272    pub fn default_embedder(&self) -> &'static RegisteredEmbedder {
273        self.get(DEFAULT_EMBEDDER)
274            .expect("default embedder must exist")
275    }
276
277    /// Get the best available embedder (ML if available, hash fallback).
278    pub fn best_available(&self) -> &'static RegisteredEmbedder {
279        // Try ML embedders first
280        for e in EMBEDDERS.iter().filter(|e| e.is_semantic) {
281            if e.is_available(&self.data_dir) {
282                return e;
283            }
284        }
285        // Fall back to hash
286        self.get(HASH_EMBEDDER).expect("hash embedder must exist")
287    }
288
289    /// Get all bake-off eligible embedders.
290    pub fn bakeoff_eligible(&self) -> Vec<&'static RegisteredEmbedder> {
291        EMBEDDERS
292            .iter()
293            .filter(|e| e.is_bakeoff_eligible())
294            .collect()
295    }
296
297    /// Get available bake-off eligible embedders (model files present).
298    pub fn available_bakeoff_candidates(&self) -> Vec<&'static RegisteredEmbedder> {
299        EMBEDDERS
300            .iter()
301            .filter(|e| e.is_bakeoff_eligible() && e.is_available(&self.data_dir))
302            .collect()
303    }
304
305    /// Get the baseline embedder for bake-off comparison.
306    pub fn baseline_embedder(&self) -> Option<&'static RegisteredEmbedder> {
307        EMBEDDERS.iter().find(|e| e.is_baseline)
308    }
309
310    /// Validate that an embedder is ready to use.
311    ///
312    /// Returns `Ok(())` if available, or an error with details about what's missing.
313    pub fn validate(&self, name: &str) -> EmbedderResult<&'static RegisteredEmbedder> {
314        let embedder = self.get(name).ok_or_else(|| {
315            embedder_unavailable(
316                name,
317                format!(
318                    "unknown embedder. Available: {}",
319                    EMBEDDERS
320                        .iter()
321                        .map(|e| e.name)
322                        .collect::<Vec<_>>()
323                        .join(", ")
324                ),
325            )
326        })?;
327
328        if !embedder.is_available(&self.data_dir) {
329            let model_dir = FastEmbedder::runtime_model_dir_for(&self.data_dir, embedder.name);
330            let missing = model_dir
331                .as_ref()
332                .map(|dir| {
333                    embedder
334                        .required_files()
335                        .iter()
336                        .filter(|file| !dir.join(*file).is_file())
337                        .map(|file| (*file).to_string())
338                        .collect::<Vec<_>>()
339                })
340                .unwrap_or_else(|| embedder.missing_files(&self.data_dir));
341            if missing.is_empty() {
342                return Ok(embedder);
343            }
344            let model_dir = model_dir
345                .or_else(|| embedder.model_dir(&self.data_dir))
346                .map(|p| p.display().to_string())
347                .unwrap_or_else(|| "unknown".to_string());
348
349            return Err(embedder_unavailable(
350                name,
351                format!(
352                    "missing files in {}: {}. Run 'cass models install' to download.",
353                    model_dir,
354                    missing.join(", ")
355                ),
356            ));
357        }
358
359        Ok(embedder)
360    }
361}
362
363/// Load an embedder by name (or default if None).
364///
365/// # Arguments
366///
367/// * `data_dir` - The cass data directory containing model files.
368/// * `name` - Optional embedder name. If None, uses the best available.
369///
370/// # Returns
371///
372/// An `Arc<dyn Embedder>` ready for use, or an error if unavailable.
373pub fn get_embedder(data_dir: &Path, name: Option<&str>) -> EmbedderResult<Arc<dyn Embedder>> {
374    let registry = EmbedderRegistry::new(data_dir);
375
376    let embedder_info = match name {
377        Some(n) => registry.validate(n)?,
378        None => registry.best_available(),
379    };
380
381    load_embedder_by_name(data_dir, embedder_info.name)
382}
383
384/// Load an embedder by registered name.
385fn load_embedder_by_name(data_dir: &Path, name: &str) -> EmbedderResult<Arc<dyn Embedder>> {
386    match name {
387        "hash" => {
388            let embedder = HashEmbedder::default();
389            Ok(Arc::new(embedder))
390        }
391        // All ONNX-based embedders (baseline and bake-off candidates)
392        "minilm" | "snowflake-arctic-s" | "nomic-embed" => {
393            let embedder = FastEmbedder::load_by_name(data_dir, name)?;
394            Ok(Arc::new(embedder))
395        }
396        _ => Err(embedder_unavailable(name, "embedder not implemented")),
397    }
398}
399
400fn embedder_unavailable(model: &str, reason: impl Into<String>) -> EmbedderError {
401    EmbedderError::EmbedderUnavailable {
402        model: model.to_string(),
403        reason: reason.into(),
404    }
405}
406
407/// Get embedder info for display/logging.
408pub fn get_embedder_info(data_dir: &Path, name: Option<&str>) -> Option<EmbedderInfo> {
409    let registry = EmbedderRegistry::new(data_dir);
410
411    let embedder_info = match name {
412        Some(n) => registry.get(n)?,
413        None => registry.best_available(),
414    };
415
416    Some(EmbedderInfo {
417        id: embedder_info.id.to_string(),
418        dimension: embedder_info.dimension,
419        is_semantic: embedder_info.is_semantic,
420    })
421}
422
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::{TempDir, tempdir};

    /// Creates an empty temporary data directory and a registry bound to it.
    ///
    /// The returned `TempDir` guard must outlive every use of the registry,
    /// because dropping it deletes the directory.
    fn registry_fixture() -> (TempDir, EmbedderRegistry) {
        let tmp = tempdir().unwrap();
        let registry = EmbedderRegistry::new(tmp.path());
        (tmp, registry)
    }

    #[test]
    fn test_registry_all() {
        let (_tmp, registry) = registry_fixture();
        assert!(registry.all().len() >= 2);
        assert_eq!(registry.all().len(), EMBEDDERS.len());
    }

    #[test]
    fn test_registry_get_by_name() {
        let (_tmp, registry) = registry_fixture();

        let minilm = registry.get("minilm").expect("minilm should be registered");
        assert_eq!(minilm.dimension, 384);

        let hash = registry.get("hash").expect("hash should be registered");
        assert_eq!(hash.dimension, 384);

        assert!(registry.get("unknown").is_none());
    }

    #[test]
    fn test_registry_get_by_id() {
        let (_tmp, registry) = registry_fixture();

        let minilm = registry
            .get("minilm-384")
            .expect("minilm-384 should resolve");
        assert_eq!(minilm.name, "minilm");

        let hash = registry.get("fnv1a-384").expect("fnv1a-384 should resolve");
        assert_eq!(hash.name, "hash");
    }

    #[test]
    fn test_hash_always_available() {
        let (_tmp, registry) = registry_fixture();

        assert!(registry.is_available("hash"));
        let available = registry.available();
        assert!(available.iter().any(|e| e.name == "hash"));
    }

    #[test]
    fn test_minilm_unavailable_without_files() {
        let (_tmp, registry) = registry_fixture();

        // MiniLM should not be available without model files.
        assert!(!registry.is_available("minilm"));

        let err = registry
            .validate("minilm")
            .expect_err("validate should fail without model files");
        let message = err.to_string();
        assert!(matches!(err, EmbedderError::EmbedderUnavailable { .. }));
        assert!(
            message.contains("cass models install"),
            "error should tell the user how to install models: {message}"
        );
    }

    #[test]
    fn test_embedder_unavailable_helper_shape() {
        let err = embedder_unavailable("demo", "missing model");
        match err {
            EmbedderError::EmbedderUnavailable { model, reason } => {
                assert_eq!(model, "demo");
                assert_eq!(reason, "missing model");
            }
            other => panic!("unexpected error shape: {other:?}"),
        }
    }

    #[test]
    fn test_best_available_fallback() {
        let (_tmp, registry) = registry_fixture();

        // Without model files, best_available should return hash.
        let best = registry.best_available();
        assert_eq!(best.name, "hash");
    }

    #[test]
    fn test_get_embedder_hash() {
        let tmp = tempdir().unwrap();
        let embedder = get_embedder(tmp.path(), Some("hash")).unwrap();
        assert_eq!(embedder.id(), "fnv1a-384");
        assert!(!embedder.is_semantic());
    }

    #[test]
    fn test_get_embedder_default_no_models() {
        let tmp = tempdir().unwrap();
        // Without model files, should fall back to hash.
        let embedder = get_embedder(tmp.path(), None).unwrap();
        assert_eq!(embedder.id(), "fnv1a-384");
    }

    #[test]
    fn test_validate_unknown_embedder() {
        let (_tmp, registry) = registry_fixture();

        let err = registry
            .validate("nonexistent")
            .expect_err("unknown embedder must not validate");
        let message = err.to_string();
        assert!(message.contains("unknown embedder"));
        assert!(message.contains("Available:"));
    }

    #[test]
    fn test_registered_embedder_missing_files() {
        let (tmp, registry) = registry_fixture();

        let minilm = registry.get("minilm").unwrap();
        let missing = minilm.missing_files(tmp.path());
        assert!(!missing.is_empty());
        assert!(missing.contains(&"model.onnx".to_string()));

        // Embedders without model files never report anything missing.
        let hash = registry.get("hash").unwrap();
        assert!(hash.missing_files(tmp.path()).is_empty());
    }

    #[test]
    fn test_missing_files_agrees_with_availability() {
        let (tmp, registry) = registry_fixture();

        for e in registry.all() {
            assert_eq!(
                e.is_available(tmp.path()),
                e.missing_files(tmp.path()).is_empty(),
                "{}: is_available and missing_files disagree",
                e.name
            );
        }
    }

    #[test]
    fn test_get_embedder_info() {
        let tmp = tempdir().unwrap();

        let hash_info = get_embedder_info(tmp.path(), Some("hash")).unwrap();
        assert_eq!(hash_info.id, "fnv1a-384");
        assert!(!hash_info.is_semantic);

        let minilm_info = get_embedder_info(tmp.path(), Some("minilm")).unwrap();
        assert_eq!(minilm_info.id, "minilm-384");
        assert!(minilm_info.is_semantic);
    }

    // ==================== Bake-off Tests ====================

    #[test]
    fn test_bakeoff_eligible_count() {
        let (_tmp, registry) = registry_fixture();

        let eligible = registry.bakeoff_eligible();
        // Should have exactly 2 eligible models: snowflake, nomic.
        assert_eq!(
            eligible.len(),
            2,
            "Expected 2 eligible models, got {}",
            eligible.len()
        );

        // MiniLM should NOT be in the eligible list (it's the baseline).
        assert!(
            !eligible.iter().any(|e| e.name == "minilm"),
            "minilm should not be in eligible list"
        );

        // Hash should NOT be in the eligible list (it is flagged as a baseline).
        assert!(
            !eligible.iter().any(|e| e.name == "hash"),
            "hash should not be in eligible list"
        );

        // Verify the correct models are in the eligible list.
        assert!(
            eligible.iter().any(|e| e.name == "snowflake-arctic-s"),
            "snowflake should be in eligible list"
        );
        assert!(
            eligible.iter().any(|e| e.name == "nomic-embed"),
            "nomic should be in eligible list"
        );
    }

    #[test]
    fn test_baseline_embedder() {
        let (_tmp, registry) = registry_fixture();

        let baseline = registry
            .baseline_embedder()
            .expect("a baseline embedder should be registered");
        assert_eq!(baseline.name, "minilm");
        assert!(baseline.is_baseline);
        assert!(baseline.is_semantic, "bake-off baseline must be semantic");
        assert!(!baseline.is_bakeoff_eligible());
    }

    #[test]
    fn test_bakeoff_eligibility_by_date() {
        let (_tmp, registry) = registry_fixture();

        // MiniLM was released before the cutoff (2022-08-01).
        let minilm = registry.get("minilm").unwrap();
        assert!(
            minilm.release_date < BAKEOFF_ELIGIBILITY_CUTOFF,
            "minilm should be released before cutoff"
        );

        // All eligible models should be released on or after the cutoff.
        for e in registry.bakeoff_eligible() {
            assert!(
                e.release_date >= BAKEOFF_ELIGIBILITY_CUTOFF,
                "{} should be released on or after cutoff (date: {})",
                e.name,
                e.release_date
            );
        }
    }

    #[test]
    fn test_bakeoff_model_metadata_conversion() {
        let (_tmp, registry) = registry_fixture();

        let minilm = registry.get("minilm").unwrap();
        let metadata = minilm.to_model_metadata();

        assert_eq!(metadata.id, "minilm-384");
        assert_eq!(metadata.name, "minilm");
        assert!(metadata.source.contains("MiniLM"));
        assert_eq!(metadata.release_date, "2022-08-01");
        assert_eq!(metadata.dimension, Some(384));
        assert!(metadata.is_baseline);
        assert!(!metadata.is_eligible());
    }

    #[test]
    fn test_eligible_embedder_metadata() {
        let (_tmp, registry) = registry_fixture();

        // Check snowflake (eligible candidate, same dimension as minilm).
        let snowflake = registry.get("snowflake-arctic-s").unwrap();
        assert!(snowflake.is_bakeoff_eligible());
        let metadata = snowflake.to_model_metadata();
        assert!(!metadata.is_baseline);
        assert!(metadata.is_eligible());
        assert_eq!(metadata.dimension, Some(384));

        // Check nomic (eligible candidate).
        let nomic = registry.get("nomic-embed").unwrap();
        assert!(nomic.is_bakeoff_eligible());
        let metadata = nomic.to_model_metadata();
        assert!(!metadata.is_baseline);
        assert!(metadata.is_eligible());
        assert_eq!(metadata.dimension, Some(768));
    }

    #[test]
    fn test_all_embedders_have_required_fields() {
        for e in EMBEDDERS {
            // All should have release dates.
            assert!(
                !e.release_date.is_empty(),
                "{} should have a release date",
                e.name
            );

            // All semantic embedders that ship model files need a HuggingFace ID.
            if e.is_semantic && e.requires_model_files {
                assert!(
                    !e.huggingface_id.is_empty(),
                    "{} should have a huggingface_id",
                    e.name
                );
            }

            // Dimensions should be within a plausible range.
            assert!(
                (256..=2048).contains(&e.dimension),
                "{} has an implausible dimension: {}",
                e.name,
                e.dimension
            );
        }
    }

    #[test]
    fn test_model_dir_for_all_embedders() {
        let tmp = tempdir().unwrap();

        for e in EMBEDDERS.iter().filter(|e| e.requires_model_files) {
            let dir = e
                .model_dir(tmp.path())
                .unwrap_or_else(|| panic!("{} should have a model directory", e.name));
            assert!(
                dir.starts_with(tmp.path().join("models")),
                "{} model dir should be under models/",
                e.name
            );
        }
    }
}