use std::path::PathBuf;
use grafeo_common::utils::error::{Error, Result};
use hf_hub::api::sync::Api;
use super::config::{EmbeddingModelConfig, ResolveInfo};
pub(crate) struct ResolvedModel {
pub model_path: PathBuf,
pub tokenizer_path: PathBuf,
pub name: String,
}
pub(crate) fn resolve(config: &EmbeddingModelConfig) -> Result<ResolvedModel> {
let name = config.display_name();
match config.resolve_info() {
ResolveInfo::Local {
model_path,
tokenizer_path,
} => Ok(ResolvedModel {
model_path: model_path.clone(),
tokenizer_path: tokenizer_path.clone(),
name,
}),
ResolveInfo::Hub {
repo_id,
model_file,
tokenizer_file,
} => {
let api = Api::new().map_err(|e| {
Error::Internal(format!("Failed to initialize HuggingFace Hub client: {e}"))
})?;
let repo = api.model(repo_id.to_string());
let model_path = repo.get(model_file).map_err(|e| {
Error::Internal(format!(
"Failed to download model file '{model_file}' from '{repo_id}': {e}"
))
})?;
let tokenizer_path = repo.get(tokenizer_file).map_err(|e| {
Error::Internal(format!(
"Failed to download tokenizer file '{tokenizer_file}' from '{repo_id}': {e}"
))
})?;
Ok(ResolvedModel {
model_path,
tokenizer_path,
name,
})
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn local_config_resolves_directly() {
let config = EmbeddingModelConfig::Local {
model_path: "/some/path/model.onnx".into(),
tokenizer_path: "/some/path/tokenizer.json".into(),
};
let resolved = resolve(&config).unwrap();
assert_eq!(resolved.model_path, PathBuf::from("/some/path/model.onnx"));
assert_eq!(
resolved.tokenizer_path,
PathBuf::from("/some/path/tokenizer.json")
);
assert_eq!(resolved.name, "model");
}
#[test]
#[ignore = "requires network access (~23MB download on first run)"]
fn preset_downloads_minilm_l6() {
let config = EmbeddingModelConfig::MiniLmL6v2;
let resolved = resolve(&config).unwrap();
assert!(resolved.model_path.exists());
assert!(resolved.tokenizer_path.exists());
assert_eq!(resolved.name, "all-MiniLM-L6-v2");
}
}