use std::path::PathBuf;
use anyhow::{Context, Result};
use hf_hub::api::sync::Api;
/// Raw contents of the files that make up a user-defined Hugging Face model
/// repository, as produced by `fetch_user_defined_files`.
///
/// Every field holds the complete, unparsed bytes of one file.
pub struct HfModelFiles {
    /// The ONNX model file (its path inside the repository is chosen by the caller).
    pub onnx: Vec<u8>,
    /// Contents of `tokenizer.json`.
    pub tokenizer: Vec<u8>,
    /// Contents of `tokenizer_config.json`.
    pub tokenizer_config: Vec<u8>,
    /// Contents of `special_tokens_map.json`.
    // NOTE(review): the lint allowance suggests nothing reads this field yet; it is
    // still fetched so the tokenizer file set is complete. Confirm before removing.
    #[allow(dead_code)]
    pub special_tokens_map: Vec<u8>,
    /// Contents of `config.json`.
    pub config: Vec<u8>,
}

// Hand-written rather than derived: reporting byte counts keeps debug output
// readable, whereas a derived impl would print every byte of the (likely large)
// model file.
impl std::fmt::Debug for HfModelFiles {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HfModelFiles")
            .field("onnx", &format_args!("{} bytes", self.onnx.len()))
            .field("tokenizer", &format_args!("{} bytes", self.tokenizer.len()))
            .field(
                "tokenizer_config",
                &format_args!("{} bytes", self.tokenizer_config.len()),
            )
            .field(
                "special_tokens_map",
                &format_args!("{} bytes", self.special_tokens_map.len()),
            )
            .field("config", &format_args!("{} bytes", self.config.len()))
            .finish()
    }
}
/// Fetches the files of a user-chosen Hugging Face model repository through
/// `hf_hub` and reads each one fully into memory.
///
/// * `repo` – repository id handed to `Api::model`.
/// * `onnx_path` – path of the ONNX model file inside that repository.
///
/// Files are requested in this order: the ONNX model, `tokenizer.json`,
/// `tokenizer_config.json`, `special_tokens_map.json`, `config.json`.
/// Fetching stops at the first failure.
///
/// # Errors
///
/// Returns an error if the hf-hub API cannot be initialised (`init hf-hub api`),
/// if any file cannot be fetched (`fetch {repo}:{file}`), or if a fetched file
/// cannot be read from disk (`read cached file {path}`).
pub fn fetch_user_defined_files(repo: &str, onnx_path: &str) -> Result<HfModelFiles> {
    let api = Api::new().context("init hf-hub api")?;
    let r = api.model(repo.to_string());

    // Single fetch-then-read path shared by every file, so the error context
    // format stays identical for all of them.
    let fetch = |file: &str| -> Result<Vec<u8>> {
        let path = r
            .get(file)
            .with_context(|| format!("fetch {repo}:{file}"))?;
        read_bytes(path)
    };

    // Struct-literal fields are evaluated in the order written, which keeps the
    // fetch order (and therefore which error surfaces first) unchanged.
    Ok(HfModelFiles {
        onnx: fetch(onnx_path)?,
        tokenizer: fetch("tokenizer.json")?,
        tokenizer_config: fetch("tokenizer_config.json")?,
        special_tokens_map: fetch("special_tokens_map.json")?,
        config: fetch("config.json")?,
    })
}
/// Loads the entire contents of a file that hf-hub has placed on local disk.
///
/// # Errors
///
/// Fails if the file cannot be read; the error message names the offending path.
fn read_bytes(p: PathBuf) -> Result<Vec<u8>> {
    let describe = || format!("read cached file {}", p.display());
    let bytes = std::fs::read(&p).with_context(describe)?;
    Ok(bytes)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time smoke check that the public entry point keeps its signature.
    /// (Calling it would need network access, so it is only referenced.)
    #[test]
    fn function_compiles() {
        let _ = fetch_user_defined_files;
    }

    /// `read_bytes` returns the exact bytes of an existing file.
    #[test]
    fn read_bytes_returns_file_contents() {
        let path = std::env::temp_dir()
            .join(format!("hf_files_read_bytes_{}.bin", std::process::id()));
        std::fs::write(&path, b"\x00\x01hello").unwrap();
        let bytes = read_bytes(path.clone());
        // Clean up before asserting so a failure does not leak the temp file.
        let _ = std::fs::remove_file(&path);
        assert_eq!(bytes.unwrap(), b"\x00\x01hello");
    }

    /// A missing file yields an error whose message includes the path.
    #[test]
    fn read_bytes_error_mentions_path() {
        let path = std::env::temp_dir()
            .join(format!("hf_files_missing_{}.bin", std::process::id()));
        let _ = std::fs::remove_file(&path);
        let err = read_bytes(path.clone()).unwrap_err();
        assert!(err.to_string().contains(&path.display().to_string()));
    }
}