use std::path::{Path, PathBuf};
use std::sync::Arc;
use super::embedder::{Embedder, EmbedderError, EmbedderInfo, EmbedderResult};
use super::fastembed_embedder::FastEmbedder;
use super::hash_embedder::HashEmbedder;
/// Registry name of the embedder returned by `EmbedderRegistry::default_embedder`.
pub const DEFAULT_EMBEDDER: &str = "minilm";
/// Registry name of the hash embedder, which `EmbedderRegistry::best_available`
/// falls back to when no semantic embedder is available.
pub const HASH_EMBEDDER: &str = "hash";
/// Static description of one embedder known to the registry.
///
/// All string fields are `'static` because instances live in the `EMBEDDERS` table.
#[derive(Debug, Clone)]
pub struct RegisteredEmbedder {
/// Short name used for lookup and for selecting the loader (e.g. `"minilm"`).
pub name: &'static str,
/// Embedder identifier (e.g. `"minilm-384"`); also accepted by registry lookup.
pub id: &'static str,
/// Embedding dimension as recorded in the table.
pub dimension: usize,
/// Whether the table marks this embedder as semantic (the hash embedder is not).
pub is_semantic: bool,
/// Human-readable description.
pub description: &'static str,
/// When `true`, availability depends on model files being present on disk.
pub requires_model_files: bool,
/// Release date as a string; compared lexicographically against
/// `BAKEOFF_ELIGIBILITY_CUTOFF`, so it must use a sortable format (the table uses `YYYY-MM-DD`).
pub release_date: &'static str,
/// HuggingFace identifier; exported as `source` in the bakeoff metadata. Empty for the hash embedder.
pub huggingface_id: &'static str,
/// Model size in bytes; `0` is exported as "unknown" (`None`) in the bakeoff metadata.
pub size_bytes: u64,
/// Baseline embedders are never bakeoff-eligible.
pub is_baseline: bool,
}
/// File names that must all exist in an embedder's model directory for an
/// embedder with `requires_model_files == true` to count as available.
pub const REQUIRED_ONNX_FILES: &[&str] = &[
"model.onnx",
"tokenizer.json",
"config.json",
"special_tokens_map.json",
"tokenizer_config.json",
];
/// Non-baseline embedders whose `release_date` is on or after this date are
/// bakeoff-eligible. The comparison is a plain string comparison, so this value
/// and every `release_date` must share a lexicographically sortable format.
pub const BAKEOFF_ELIGIBILITY_CUTOFF: &str = "2025-11-01";
impl RegisteredEmbedder {
    /// Reports whether this embedder can be used with the given `data_dir`.
    ///
    /// Embedders that need no model files are always available. Otherwise every
    /// entry of [`Self::required_files`] must exist as a regular file inside
    /// [`Self::model_dir`]; an embedder without a known model directory is
    /// never available.
    pub fn is_available(&self, data_dir: &Path) -> bool {
        if !self.requires_model_files {
            return true;
        }
        self.model_dir(data_dir).is_some_and(|model_dir| {
            self.required_files()
                .iter()
                .all(|file| model_dir.join(file).is_file())
        })
    }

    /// Returns the directory expected to hold this embedder's model files,
    /// i.e. `<data_dir>/models/<model directory name>`.
    ///
    /// Returns `None` when the embedder needs no model files or when its name
    /// has no directory mapping below.
    pub fn model_dir(&self, data_dir: &Path) -> Option<PathBuf> {
        if !self.requires_model_files {
            return None;
        }
        let dir_name = match self.name {
            "minilm" => "all-MiniLM-L6-v2",
            "snowflake-arctic-s" => "snowflake-arctic-embed-s",
            "nomic-embed" => "nomic-embed-text-v1.5",
            _ => return None,
        };
        Some(data_dir.join("models").join(dir_name))
    }

    /// Returns the file names that must be present for this embedder:
    /// [`REQUIRED_ONNX_FILES`] for file-backed embedders, an empty slice otherwise.
    pub fn required_files(&self) -> &'static [&'static str] {
        if self.requires_model_files {
            REQUIRED_ONNX_FILES
        } else {
            &[]
        }
    }

    /// Lists the required files that are not present under `data_dir`.
    ///
    /// The result is empty exactly when [`Self::is_available`] is `true`:
    /// embedders without model files report nothing, and a file-backed embedder
    /// whose model directory cannot be resolved reports *every* required file
    /// as missing (rather than an empty list that would contradict
    /// `is_available` returning `false`).
    pub fn missing_files(&self, data_dir: &Path) -> Vec<String> {
        // Empty for embedders that do not require model files, so both arms
        // below yield an empty list in that case.
        let required = self.required_files();
        match self.model_dir(data_dir) {
            Some(model_dir) => required
                .iter()
                .filter(|file| !model_dir.join(*file).is_file())
                .map(|file| (*file).to_string())
                .collect(),
            // No directory to look in: nothing can be found, so all are missing.
            None => required.iter().map(|file| (*file).to_string()).collect(),
        }
    }

    /// Returns `true` for non-baseline embedders released on or after
    /// [`BAKEOFF_ELIGIBILITY_CUTOFF`].
    ///
    /// Dates are compared as strings, which is correct as long as both sides
    /// use the same lexicographically sortable format.
    pub fn is_bakeoff_eligible(&self) -> bool {
        !self.is_baseline && self.release_date >= BAKEOFF_ELIGIBILITY_CUTOFF
    }

    /// Converts this registry entry into bakeoff model metadata.
    ///
    /// `huggingface_id` becomes the metadata `source`, and a `size_bytes` of
    /// zero is exported as `None`.
    pub fn to_model_metadata(&self) -> crate::bakeoff::ModelMetadata {
        crate::bakeoff::ModelMetadata {
            id: self.id.to_string(),
            name: self.name.to_string(),
            source: self.huggingface_id.to_string(),
            release_date: self.release_date.to_string(),
            dimension: Some(self.dimension),
            size_bytes: (self.size_bytes > 0).then_some(self.size_bytes),
            is_baseline: self.is_baseline,
        }
    }
}
/// Table of every embedder known to the registry.
///
/// Entry order is significant:
/// - `EmbedderRegistry::best_available` returns the first semantic entry that is available;
/// - `EmbedderRegistry::baseline_embedder` returns the first entry with `is_baseline == true`;
/// - `EmbedderRegistry::get` returns the first entry that matches the requested name.
pub static EMBEDDERS: &[RegisteredEmbedder] = &[
// Semantic baseline; release date is before the bakeoff cutoff.
RegisteredEmbedder {
name: "minilm",
id: "minilm-384",
dimension: 384,
is_semantic: true,
description: "MiniLM L6 v2 - fast, high-quality semantic embeddings (baseline)",
requires_model_files: true,
release_date: "2022-08-01",
huggingface_id: "sentence-transformers/all-MiniLM-L6-v2",
size_bytes: 90_000_000,
is_baseline: true,
},
// Bakeoff candidate with the same dimension as minilm.
RegisteredEmbedder {
name: "snowflake-arctic-s",
id: "snowflake-arctic-s-384",
dimension: 384,
is_semantic: true,
description: "Snowflake Arctic Embed S - small, fast, MiniLM-compatible dimension",
requires_model_files: true,
release_date: "2025-11-10",
huggingface_id: "Snowflake/snowflake-arctic-embed-s",
size_bytes: 130_000_000,
is_baseline: false,
},
// Bakeoff candidate with a larger (768) dimension.
RegisteredEmbedder {
name: "nomic-embed",
id: "nomic-embed-768",
dimension: 768,
is_semantic: true,
description: "Nomic Embed Text v1.5 - long context, Matryoshka support",
requires_model_files: true,
release_date: "2025-11-05",
huggingface_id: "nomic-ai/nomic-embed-text-v1.5",
size_bytes: 280_000_000,
is_baseline: false,
},
// Non-semantic fallback; needs no model files, so it is always available.
RegisteredEmbedder {
name: "hash",
id: "fnv1a-384",
dimension: 384,
is_semantic: false,
description: "FNV-1a feature hashing - lexical fallback, always available",
requires_model_files: false,
release_date: "2020-01-01",
huggingface_id: "",
size_bytes: 0,
is_baseline: true,
},
];
/// View over the static `EMBEDDERS` table bound to a data directory, used to
/// answer availability questions about individual embedders.
pub struct EmbedderRegistry {
// Root directory passed to `RegisteredEmbedder::is_available` and friends.
data_dir: PathBuf,
}
impl EmbedderRegistry {
    /// Creates a registry that checks model availability under `data_dir`.
    pub fn new(data_dir: &Path) -> Self {
        Self {
            data_dir: data_dir.to_path_buf(),
        }
    }

    /// Returns every registered embedder, available or not.
    pub fn all(&self) -> &'static [RegisteredEmbedder] {
        EMBEDDERS
    }

    /// Returns the registered embedders that are currently available under
    /// this registry's data directory, in table order.
    pub fn available(&self) -> Vec<&'static RegisteredEmbedder> {
        EMBEDDERS
            .iter()
            .filter(|e| e.is_available(&self.data_dir))
            .collect()
    }

    /// Looks up an embedder by name or id (ASCII case-insensitive).
    ///
    /// An exact match on `name` or `id` wins. Only if there is none does the
    /// lookup fall back to the first entry whose id starts with `"<name>-"`
    /// (so `"fnv1a"` resolves to `"fnv1a-384"`). Checking exact matches first
    /// keeps a prefix hit on an earlier entry from shadowing an exact hit on a
    /// later one.
    pub fn get(&self, name: &str) -> Option<&'static RegisteredEmbedder> {
        let wanted = name.to_ascii_lowercase();
        // Built once instead of once per table entry.
        let id_prefix = format!("{wanted}-");
        EMBEDDERS
            .iter()
            .find(|e| e.name == wanted || e.id == wanted)
            .or_else(|| {
                EMBEDDERS
                    .iter()
                    .find(|e| e.id.starts_with(id_prefix.as_str()))
            })
    }

    /// Returns `true` when `name` resolves to a registered embedder that is
    /// available; unknown names yield `false`.
    pub fn is_available(&self, name: &str) -> bool {
        self.get(name)
            .is_some_and(|e| e.is_available(&self.data_dir))
    }

    /// Returns the entry named by [`DEFAULT_EMBEDDER`].
    ///
    /// # Panics
    ///
    /// Panics if the table has no such entry, which would be a bug in the
    /// static table.
    pub fn default_embedder(&self) -> &'static RegisteredEmbedder {
        self.get(DEFAULT_EMBEDDER)
            .expect("default embedder must exist")
    }

    /// Returns the first semantic embedder (in table order) that is available,
    /// falling back to the hash embedder when none is.
    ///
    /// # Panics
    ///
    /// Panics if the fallback is needed and the table has no entry named
    /// [`HASH_EMBEDDER`].
    pub fn best_available(&self) -> &'static RegisteredEmbedder {
        EMBEDDERS
            .iter()
            .find(|e| e.is_semantic && e.is_available(&self.data_dir))
            .unwrap_or_else(|| self.get(HASH_EMBEDDER).expect("hash embedder must exist"))
    }

    /// Returns every bakeoff-eligible embedder, regardless of availability.
    pub fn bakeoff_eligible(&self) -> Vec<&'static RegisteredEmbedder> {
        EMBEDDERS
            .iter()
            .filter(|e| e.is_bakeoff_eligible())
            .collect()
    }

    /// Returns the bakeoff-eligible embedders that are also available under
    /// this registry's data directory.
    pub fn available_bakeoff_candidates(&self) -> Vec<&'static RegisteredEmbedder> {
        EMBEDDERS
            .iter()
            .filter(|e| e.is_bakeoff_eligible() && e.is_available(&self.data_dir))
            .collect()
    }

    /// Returns the first table entry flagged as a baseline, if any.
    pub fn baseline_embedder(&self) -> Option<&'static RegisteredEmbedder> {
        EMBEDDERS.iter().find(|e| e.is_baseline)
    }

    /// Resolves `name` and checks that the embedder is usable.
    ///
    /// # Errors
    ///
    /// Returns `EmbedderError::EmbedderUnavailable` when `name` matches no
    /// registered embedder (the reason lists all registered names) or when the
    /// embedder is registered but its model files are missing (the reason
    /// names the model directory and the missing files).
    pub fn validate(&self, name: &str) -> EmbedderResult<&'static RegisteredEmbedder> {
        let Some(embedder) = self.get(name) else {
            let known = EMBEDDERS
                .iter()
                .map(|e| e.name)
                .collect::<Vec<_>>()
                .join(", ");
            return Err(embedder_unavailable(
                name,
                format!("unknown embedder. Available: {known}"),
            ));
        };

        if embedder.is_available(&self.data_dir) {
            return Ok(embedder);
        }

        let missing = embedder.missing_files(&self.data_dir);
        let model_dir = embedder
            .model_dir(&self.data_dir)
            .map(|p| p.display().to_string())
            .unwrap_or_else(|| "unknown".to_string());
        Err(embedder_unavailable(
            name,
            format!(
                "missing files in {}: {}. Run 'cass models install' to download.",
                model_dir,
                missing.join(", ")
            ),
        ))
    }
}
/// Loads an embedder from `data_dir`.
///
/// With `Some(name)` the named embedder is validated first and any validation
/// error is returned as-is. With `None` the registry's best available embedder
/// is chosen.
pub fn get_embedder(data_dir: &Path, name: Option<&str>) -> EmbedderResult<Arc<dyn Embedder>> {
    let registry = EmbedderRegistry::new(data_dir);
    let selected = if let Some(requested) = name {
        registry.validate(requested)?
    } else {
        registry.best_available()
    };
    load_embedder_by_name(data_dir, selected.name)
}
/// Constructs the embedder implementation registered under `name`.
///
/// `"hash"` yields a default `HashEmbedder`; the model-backed names are loaded
/// through `FastEmbedder::load_by_name`. Any other name produces an
/// `EmbedderUnavailable` error.
fn load_embedder_by_name(data_dir: &Path, name: &str) -> EmbedderResult<Arc<dyn Embedder>> {
    let loaded: Arc<dyn Embedder> = match name {
        "hash" => Arc::new(HashEmbedder::default()),
        "minilm" | "snowflake-arctic-s" | "nomic-embed" => {
            Arc::new(FastEmbedder::load_by_name(data_dir, name)?)
        }
        _ => return Err(embedder_unavailable(name, "embedder not implemented")),
    };
    Ok(loaded)
}
/// Builds an `EmbedderError::EmbedderUnavailable` for `model` with the given reason.
fn embedder_unavailable(model: &str, reason: impl Into<String>) -> EmbedderError {
    let model = model.to_owned();
    let reason: String = reason.into();
    EmbedderError::EmbedderUnavailable { model, reason }
}
/// Returns id, dimension and semantic flag for an embedder without loading it.
///
/// With `Some(name)` the entry is looked up by name or id and `None` is
/// returned for unknown names; availability is not checked. With `None` the
/// registry's best available embedder is described.
pub fn get_embedder_info(data_dir: &Path, name: Option<&str>) -> Option<EmbedderInfo> {
    let registry = EmbedderRegistry::new(data_dir);
    let entry = name.map_or_else(
        || Some(registry.best_available()),
        |requested| registry.get(requested),
    )?;
    let info = EmbedderInfo {
        id: entry.id.to_string(),
        dimension: entry.dimension,
        is_semantic: entry.is_semantic,
    };
    Some(info)
}
// Unit tests for the embedder registry. Every test runs against a fresh, empty
// temporary data directory, so no model files are ever present.
#[cfg(test)]
mod tests {
use super::*;
use tempfile::{TempDir, tempdir};
// Builds a registry rooted at a new temp dir. The `TempDir` is returned so the
// caller keeps it alive for as long as the registry is used.
fn registry_fixture() -> (TempDir, EmbedderRegistry) {
let tmp = tempdir().unwrap();
let registry = EmbedderRegistry::new(tmp.path());
(tmp, registry)
}
// The table contains at least two embedders.
#[test]
fn test_registry_all() {
let (_tmp, registry) = registry_fixture();
assert!(registry.all().len() >= 2);
}
// Lookup by short name succeeds for known names and fails for unknown ones.
#[test]
fn test_registry_get_by_name() {
let (_tmp, registry) = registry_fixture();
let minilm = registry.get("minilm");
assert!(minilm.is_some());
assert_eq!(minilm.unwrap().dimension, 384);
let hash = registry.get("hash");
assert!(hash.is_some());
assert_eq!(hash.unwrap().dimension, 384);
let unknown = registry.get("unknown");
assert!(unknown.is_none());
}
// Lookup also accepts the full embedder id.
#[test]
fn test_registry_get_by_id() {
let (_tmp, registry) = registry_fixture();
let minilm = registry.get("minilm-384");
assert!(minilm.is_some());
assert_eq!(minilm.unwrap().name, "minilm");
let hash = registry.get("fnv1a-384");
assert!(hash.is_some());
assert_eq!(hash.unwrap().name, "hash");
}
// The hash embedder needs no model files, so it is available in an empty dir.
#[test]
fn test_hash_always_available() {
let (_tmp, registry) = registry_fixture();
assert!(registry.is_available("hash"));
let available = registry.available();
assert!(available.iter().any(|e| e.name == "hash"));
}
// Without model files minilm is unavailable and validation fails with
// `EmbedderUnavailable`.
#[test]
fn test_minilm_unavailable_without_files() {
let (_tmp, registry) = registry_fixture();
assert!(!registry.is_available("minilm"));
let result = registry.validate("minilm");
assert!(result.is_err());
let err = result.unwrap_err();
assert!(matches!(err, EmbedderError::EmbedderUnavailable { .. }));
}
// The error helper stores its arguments in the `model` and `reason` fields.
#[test]
fn test_embedder_unavailable_helper_shape() {
let err = embedder_unavailable("demo", "missing model");
match err {
EmbedderError::EmbedderUnavailable { model, reason } => {
assert_eq!(model, "demo");
assert_eq!(reason, "missing model");
}
other => panic!("unexpected error shape: {other:?}"),
}
}
// With no semantic embedder available, the best choice is the hash fallback.
#[test]
fn test_best_available_fallback() {
let (_tmp, registry) = registry_fixture();
let best = registry.best_available();
assert_eq!(best.name, "hash");
}
// Requesting "hash" explicitly loads the non-semantic hash embedder.
#[test]
fn test_get_embedder_hash() {
let tmp = tempdir().unwrap();
let embedder = get_embedder(tmp.path(), Some("hash")).unwrap();
assert_eq!(embedder.id(), "fnv1a-384");
assert!(!embedder.is_semantic());
}
// With no name and no models installed, loading falls back to the hash embedder.
#[test]
fn test_get_embedder_default_no_models() {
let tmp = tempdir().unwrap();
let embedder = get_embedder(tmp.path(), None).unwrap();
assert_eq!(embedder.id(), "fnv1a-384");
}
// Unknown names produce an error that lists the registered embedders.
#[test]
fn test_validate_unknown_embedder() {
let (_tmp, registry) = registry_fixture();
let result = registry.validate("nonexistent");
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err.to_string().contains("unknown embedder"));
assert!(err.to_string().contains("Available:"));
}
// In an empty data dir, minilm reports its required files as missing.
#[test]
fn test_registered_embedder_missing_files() {
let (tmp, registry) = registry_fixture();
let minilm = registry.get("minilm").unwrap();
let missing = minilm.missing_files(tmp.path());
assert!(!missing.is_empty());
assert!(missing.contains(&"model.onnx".to_string()));
}
// Info lookup by name works even when the embedder's files are not installed.
#[test]
fn test_get_embedder_info() {
let tmp = tempdir().unwrap();
let hash_info = get_embedder_info(tmp.path(), Some("hash")).unwrap();
assert_eq!(hash_info.id, "fnv1a-384");
assert!(!hash_info.is_semantic);
let minilm_info = get_embedder_info(tmp.path(), Some("minilm")).unwrap();
assert_eq!(minilm_info.id, "minilm-384");
assert!(minilm_info.is_semantic);
}
// Exactly the two non-baseline, post-cutoff models are bakeoff-eligible.
#[test]
fn test_bakeoff_eligible_count() {
let (_tmp, registry) = registry_fixture();
let eligible = registry.bakeoff_eligible();
assert_eq!(
eligible.len(),
2,
"Expected 2 eligible models, got {}",
eligible.len()
);
assert!(
!eligible.iter().any(|e| e.name == "minilm"),
"minilm should not be in eligible list"
);
assert!(
!eligible.iter().any(|e| e.name == "hash"),
"hash should not be in eligible list"
);
assert!(
eligible.iter().any(|e| e.name == "snowflake-arctic-s"),
"snowflake should be in eligible list"
);
assert!(
eligible.iter().any(|e| e.name == "nomic-embed"),
"nomic should be in eligible list"
);
}
// The first baseline entry in the table is minilm, and it is not eligible.
#[test]
fn test_baseline_embedder() {
let (_tmp, registry) = registry_fixture();
let baseline = registry.baseline_embedder();
assert!(baseline.is_some());
let baseline = baseline.unwrap();
assert_eq!(baseline.name, "minilm");
assert!(baseline.is_baseline);
assert!(!baseline.is_bakeoff_eligible());
}
// Release dates are compared as strings against the cutoff.
#[test]
fn test_bakeoff_eligibility_by_date() {
let (_tmp, registry) = registry_fixture();
let minilm = registry.get("minilm").unwrap();
assert!(
minilm.release_date < BAKEOFF_ELIGIBILITY_CUTOFF,
"minilm should be released before cutoff"
);
for e in registry.bakeoff_eligible() {
assert!(
e.release_date >= BAKEOFF_ELIGIBILITY_CUTOFF,
"{} should be released after cutoff (date: {})",
e.name,
e.release_date
);
}
}
// Registry fields are carried over into the bakeoff metadata.
// NOTE(review): `is_eligible` is defined on `crate::bakeoff::ModelMetadata`, outside this file.
#[test]
fn test_bakeoff_model_metadata_conversion() {
let (_tmp, registry) = registry_fixture();
let minilm = registry.get("minilm").unwrap();
let metadata = minilm.to_model_metadata();
assert_eq!(metadata.id, "minilm-384");
assert_eq!(metadata.name, "minilm");
assert!(metadata.source.contains("MiniLM"));
assert_eq!(metadata.release_date, "2022-08-01");
assert_eq!(metadata.dimension, Some(384));
assert!(metadata.is_baseline);
assert!(!metadata.is_eligible());
}
// Eligible embedders stay eligible after conversion to bakeoff metadata.
#[test]
fn test_eligible_embedder_metadata() {
let (_tmp, registry) = registry_fixture();
let snowflake = registry.get("snowflake-arctic-s").unwrap();
assert!(snowflake.is_bakeoff_eligible());
let metadata = snowflake.to_model_metadata();
assert!(!metadata.is_baseline);
assert!(metadata.is_eligible());
assert_eq!(metadata.dimension, Some(384));
let nomic = registry.get("nomic-embed").unwrap();
assert!(nomic.is_bakeoff_eligible());
let metadata = nomic.to_model_metadata();
assert!(!metadata.is_baseline);
assert!(metadata.is_eligible());
assert_eq!(metadata.dimension, Some(768));
}
// Sanity checks on the static table: release date set, HuggingFace id set for
// file-backed semantic models, and dimension within 256..=2048.
#[test]
fn test_all_embedders_have_required_fields() {
for e in EMBEDDERS.iter() {
assert!(
!e.release_date.is_empty(),
"{} should have a release date",
e.name
);
if e.is_semantic && e.requires_model_files {
assert!(
!e.huggingface_id.is_empty(),
"{} should have a huggingface_id",
e.name
);
}
assert!(e.dimension >= 256 && e.dimension <= 2048);
}
}
// Every file-backed embedder must map to a directory under `<data_dir>/models`.
#[test]
fn test_model_dir_for_all_embedders() {
let tmp = tempdir().unwrap();
for e in EMBEDDERS.iter() {
if e.requires_model_files {
let dir = e.model_dir(tmp.path());
assert!(dir.is_some(), "{} should have a model directory", e.name);
let dir = dir.unwrap();
assert!(
dir.starts_with(tmp.path().join("models")),
"{} model dir should be under models/",
e.name
);
}
}
}
}