use anyhow::Result;
use hf_hub::api::sync::ApiBuilder;
use std::path::PathBuf;
/// Hugging Face repo id that [`ensure_model`] falls back to when no model id is supplied.
pub const DEFAULT_MODEL: &str = "lightonai/LateOn-Code-edge";

/// Files that [`ensure_model`] fetches from the Hub, in this order.
///
/// A failed fetch of any entry aborts the download, except `config.json`, whose fetch
/// errors are ignored.
const REQUIRED_FILES: &[&str] = &[
    "model_int8.onnx",
    "tokenizer.json",
    "config_sentence_transformers.json",
    "config.json",
    "onnx_config.json",
];

/// Files that [`ensure_model`] fetches on a best-effort basis; fetch errors are ignored.
const OPTIONAL_FILES: &[&str] = &["model.onnx"];
pub fn ensure_model(model_id: Option<&str>, _quiet: bool) -> Result<PathBuf> {
let model_id = model_id.unwrap_or(DEFAULT_MODEL);
let local_path = PathBuf::from(model_id);
if local_path.exists() && local_path.is_dir() {
return Ok(local_path);
}
let mut builder = ApiBuilder::from_env();
let token_from_env = std::env::var("HF_TOKEN")
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
.ok()
.map(|t| t.trim_matches('"').trim_matches('\'').to_string());
if token_from_env.is_some() {
builder = builder.with_token(token_from_env);
}
let api = builder.build()?;
let repo = api.model(model_id.to_string());
let mut model_dir = None;
for file in REQUIRED_FILES {
match repo.get(file) {
Ok(path) => {
if model_dir.is_none() {
model_dir = path.parent().map(|p| p.to_path_buf());
}
}
Err(e) => {
if *file != "config.json" {
return Err(e.into());
}
}
}
}
for file in OPTIONAL_FILES {
let _ = repo.get(file);
}
model_dir.ok_or_else(|| anyhow::anyhow!("Failed to determine model directory"))
}