//! Shared helpers for fetching models and tokenizers from the HuggingFace Hub
//! and for building ONNX Runtime inference sessions.

use crate::{Error, Result};
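
/// Returns `true` when the `ANNO_NO_DOWNLOADS` environment variable is set to a
/// truthy value (`1`, `true`, `yes`, `y`, or `on`), signalling that network
/// downloads should be skipped and only the local cache consulted.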
pub fn no_downloads() -> bool {
    match std::env::var("ANNO_NO_DOWNLOADS") {
        Ok(v) => matches!(
            v.trim().to_ascii_lowercase().as_str(),
            "1" | "true" | "yes" | "y" | "on"
        ),
        Err(_) => false,
    }
}
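
/// Builds a synchronous HuggingFace Hub API client, attaching an auth token
/// when `crate::env::hf_token()` finds one.
///
/// # Example
///
/// A minimal sketch of the intended call pattern; the `"org/model"` repo id is
/// illustrative.
///
/// ```ignore
/// let api = hf_api()?;
/// let repo = api.model("org/model".to_string());
/// ```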
pub fn hf_api() -> Result<hf_hub::api::sync::Api> {
    use hf_hub::api::sync::{Api, ApiBuilder};

    crate::env::load_dotenv();
    if let Some(token) = crate::env::hf_token() {
        ApiBuilder::new()
            .with_token(Some(token))
            .build()
            .map_err(|e| Error::Retrieval(format!("HuggingFace API init with token: {}", e)))
    } else {
        Api::new().map_err(|e| Error::Retrieval(format!("HuggingFace API init: {}", e)))
    }
}
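
/// Fetches the first of `candidates` that is available in `repo`, returning its
/// local path. When `ANNO_NO_DOWNLOADS` is set, only the local HuggingFace
/// cache is consulted and an error is returned if no candidate is cached.
///
/// # Example
///
/// A sketch of the fallback behaviour; the repo id and file names are
/// illustrative.
///
/// ```ignore
/// let repo = hf_api()?.model("org/model".to_string());
/// // Tries "onnx/model.onnx" first, then falls back to "model.onnx".
/// let path = download_model_file(&repo, &["onnx/model.onnx", "model.onnx"])?;
/// ```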
pub fn download_model_file(
    repo: &hf_hub::api::sync::ApiRepo,
    candidates: &[&str],
) -> Result<std::path::PathBuf> {
    if candidates.is_empty() {
        return Err(Error::Retrieval(
            "download_model_file: candidates must not be empty".to_string(),
        ));
    }
    if no_downloads() {
        // Offline mode: accept any candidate that is already in the local cache.
        for candidate in candidates {
            if let Some(path) = hf_hub::Cache::default()
                .repo(hf_hub::Repo::model(repo_id_of(repo)))
                .get(candidate)
            {
                return Ok(path);
            }
        }
        return Err(Error::Retrieval(format!(
            "ANNO_NO_DOWNLOADS is set and none of [{}] are present in the \
             HuggingFace cache. Pre-fetch the model (unset ANNO_NO_DOWNLOADS \
             and re-run once), or skip this backend.",
            candidates.join(", "),
        )));
    }
    // Online mode: try each candidate in order, remembering the last error for
    // the failure message.
    let mut last_err = None;
    for candidate in candidates {
        match repo.get(candidate) {
            Ok(path) => return Ok(path),
            Err(e) => last_err = Some(e),
        }
    }
    Err(Error::Retrieval(format!(
        "Failed to download any of [{}]: {}",
        candidates.join(", "),
        last_err
            .map(|e| e.to_string())
            .unwrap_or_else(|| "unknown".to_string())
    )))
}
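
/// Best-effort extraction of the repository id from an `ApiRepo`'s `Debug`
/// output, used for cache lookups since the type does not expose the id
/// directly. Returns an empty string if the format is not recognised.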
fn repo_id_of(repo: &hf_hub::api::sync::ApiRepo) -> String {
    let dbg = format!("{:?}", repo);
    if let Some(start) = dbg.find("repo_id: \"") {
        let rest = &dbg[start + "repo_id: \"".len()..];
        if let Some(end) = rest.find('"') {
            return rest[..end].to_string();
        }
    }
    String::new()
}
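
/// Downloads an ONNX model from `repo`, preferring a quantized variant when
/// `prefer_quantized` is true and falling back to the FP32 model otherwise.
/// Returns the local path together with a flag indicating whether a quantized
/// file was used.
///
/// # Example
///
/// A sketch of the intended usage; the repo id is illustrative.
///
/// ```ignore
/// let repo = hf_api()?.model("org/model".to_string());
/// let (model_path, quantized) = download_onnx_model(&repo, true)?;
/// ```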
pub fn download_onnx_model(
    repo: &hf_hub::api::sync::ApiRepo,
    prefer_quantized: bool,
) -> Result<(std::path::PathBuf, bool)> {
    if prefer_quantized {
        let quantized_candidates = [
            "onnx/model_quantized.onnx",
            "model_quantized.onnx",
            "onnx/model_int8.onnx",
            "model_int8.onnx",
        ];
        for candidate in &quantized_candidates {
            if let Ok(path) = repo.get(candidate) {
                log::info!("[hf_loader] Using quantized model: {}", candidate);
                return Ok((path, true));
            }
        }
    }
    let path = download_model_file(repo, &["onnx/model.onnx", "model.onnx"])?;
    if prefer_quantized {
        log::info!("[hf_loader] Using FP32 model (quantized not available)");
    }
    Ok((path, false))
}

/// Options for building an ONNX Runtime session.
#[cfg(feature = "onnx")]
#[derive(Debug, Clone)]
pub struct OnnxSessionConfig {
    /// Graph optimization level (1, 2, or 3; any other value maps to level 3).
    pub optimization_level: u8,
    /// Number of intra-op threads; 0 keeps the ONNX Runtime default.
    pub num_threads: usize,
    /// Register the CPU execution provider explicitly.
    pub use_cpu_provider: bool,
    /// Prefer the CoreML execution provider when the `onnx-coreml` feature is enabled.
    #[cfg_attr(not(feature = "onnx-coreml"), allow(dead_code))]
    pub prefer_coreml: bool,
}

#[cfg(feature = "onnx")]
impl Default for OnnxSessionConfig {
    fn default() -> Self {
        Self {
            optimization_level: 3,
            num_threads: 0,
            use_cpu_provider: true,
            prefer_coreml: false,
        }
    }
}
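
/// Builds an ONNX Runtime session for the model at `model_path` using the
/// given `OnnxSessionConfig`.
///
/// # Example
///
/// A minimal sketch assuming a locally available model file; the path is
/// illustrative.
///
/// ```ignore
/// let config = OnnxSessionConfig { num_threads: 4, ..Default::default() };
/// let session = create_onnx_session(std::path::Path::new("model.onnx"), config)?;
/// ```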
#[cfg(feature = "onnx")]
pub fn create_onnx_session(
    model_path: &std::path::Path,
    config: OnnxSessionConfig,
) -> Result<ort::session::Session> {
    use ort::session::builder::GraphOptimizationLevel;
    use ort::session::Session;

    let opt_level = match config.optimization_level {
        1 => GraphOptimizationLevel::Level1,
        2 => GraphOptimizationLevel::Level2,
        _ => GraphOptimizationLevel::Level3,
    };
    let mut builder = Session::builder()
        .map_err(|e| Error::Retrieval(format!("ONNX session builder: {}", e)))?
        .with_optimization_level(opt_level)
        .map_err(|e| Error::Retrieval(format!("ONNX optimization level: {}", e)))?;

    // Collect execution providers in priority order: CoreML (when requested and
    // the feature is enabled) ahead of the CPU fallback.
    let mut providers: Vec<ort::execution_providers::ExecutionProviderDispatch> = Vec::new();
    #[cfg(feature = "onnx-coreml")]
    if config.prefer_coreml {
        use ort::execution_providers::CoreMLExecutionProvider;
        providers.push(CoreMLExecutionProvider::default().build());
    }
    if config.use_cpu_provider {
        use ort::execution_providers::CPUExecutionProvider;
        providers.push(CPUExecutionProvider::default().build());
    }
    if !providers.is_empty() {
        builder = builder
            .with_execution_providers(providers)
            .map_err(|e| Error::Retrieval(format!("ONNX execution providers: {}", e)))?;
    }
    if config.num_threads > 0 {
        builder = builder
            .with_intra_threads(config.num_threads)
            .map_err(|e| Error::Retrieval(format!("ONNX thread config: {}", e)))?;
    }
    builder
        .commit_from_file(model_path)
        .map_err(|e| Error::Retrieval(format!("ONNX model load: {}", e)))
}
pub fn load_tokenizer(path: &std::path::Path) -> Result<tokenizers::Tokenizer> {
    tokenizers::Tokenizer::from_file(path)
        .map_err(|e| Error::Retrieval(format!("Tokenizer load: {}", e)))
}