use std::path::{Path, PathBuf};
use anyhow::{Context, Result, bail};
use hf_hub::api::sync::{Api, ApiBuilder, ApiRepo, ApiError};
/// Settings for downloading a model from the Hugging Face Hub.
#[derive(Clone)]
pub struct DownloadConfig {
    /// Hub model repository id (`owner/name`).
    pub repo: String,
    /// Directory the downloaded files are copied into; created if it does not exist.
    pub output_dir: PathBuf,
    /// Optional Hub access token. Never printed by the `Debug` impl.
    pub token: Option<String>,
    /// Re-download files even if they already exist in `output_dir`.
    pub overwrite: bool,
}

// `Debug` is implemented by hand instead of derived so that formatting a config with
// `{:?}` (in logs, panics or error messages) cannot leak the access token.
impl std::fmt::Debug for DownloadConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DownloadConfig")
            .field("repo", &self.repo)
            .field("output_dir", &self.output_dir)
            .field("token", &self.token.as_ref().map(|_| "<redacted>"))
            .field("overwrite", &self.overwrite)
            .finish()
    }
}

impl Default for DownloadConfig {
    /// Defaults: repo `eugenehp/tribev2`, output dir `./weights`, no token, no overwrite.
    fn default() -> Self {
        Self {
            repo: "eugenehp/tribev2".into(),
            output_dir: PathBuf::from("./weights"),
            token: None,
            overwrite: false,
        }
    }
}
/// Local paths of the files fetched for a model.
#[derive(Debug, Clone)]
pub struct ModelFiles {
    /// Local copy of `config.yaml`.
    pub config: PathBuf,
    /// Local copy of the weights: `model.safetensors`, or `best.ckpt` as a fallback.
    pub weights: PathBuf,
    /// `true` when `weights` is a safetensors file, `false` for the `.ckpt` fallback.
    pub weights_is_safetensors: bool,
    /// Local copy of `build_args.json`, if one was found.
    pub build_args: Option<PathBuf>,
}

impl ModelFiles {
    /// Prints the downloaded file paths to stdout.
    ///
    /// When the weights are a PyTorch checkpoint, also prints a `python3 -c` command that
    /// converts them to `model.safetensors` in the checkpoint's directory.
    pub fn print_summary(&self) {
        println!("\n── Downloaded files ─────────────────────────────");
        println!(" config : {}", self.config.display());
        println!(" weights : {}", self.weights.display());
        if !self.weights_is_safetensors {
            println!(" ⚠ weights are a PyTorch .ckpt — convert with:");
            for line in self.conversion_command_lines() {
                println!("{line}");
            }
        }
        if let Some(ba) = &self.build_args {
            println!(" build_args : {}", ba.display());
        }
        println!("─────────────────────────────────────────────────");
    }

    /// Builds the lines of a copy-pasteable `python3 -c "…"` conversion command.
    ///
    /// The Python statements carry no leading whitespace: `python3 -c` rejects an
    /// indented first statement with `IndentationError`. Paths are emitted as raw string
    /// literals (`r'…'`) so Windows separators such as `\b` are not read as Python escape
    /// sequences. Paths containing a single quote are not escaped.
    fn conversion_command_lines(&self) -> Vec<String> {
        let target = self
            .weights
            .parent()
            .unwrap_or(Path::new("."))
            .join("model.safetensors");
        vec![
            " python3 -c \"".to_string(),
            "import torch, safetensors.torch".to_string(),
            format!(
                "ckpt = torch.load(r'{}', map_location='cpu', weights_only=True)",
                self.weights.display()
            ),
            "sd = {k.removeprefix('model.'): v for k, v in ckpt['state_dict'].items()}"
                .to_string(),
            format!("safetensors.torch.save_file(sd, r'{}')", target.display()),
            " \"".to_string(),
        ]
    }
}
/// Downloads the model files from `cfg.repo` into `cfg.output_dir`.
///
/// Fetches `config.yaml` (required), the weights (`model.safetensors`, falling back to
/// `best.ckpt`) and `build_args.json` (optional). Files already present in
/// `cfg.output_dir` are reused unless `cfg.overwrite` is set.
///
/// # Errors
/// Returns an error if the output directory cannot be created, the Hub client cannot be
/// built, `config.yaml` or the weights cannot be fetched or copied, or the Hub request for
/// `build_args.json` fails for a reason other than the file not existing.
pub fn download_model(cfg: &DownloadConfig) -> Result<ModelFiles> {
    std::fs::create_dir_all(&cfg.output_dir)
        .with_context(|| format!("creating output dir {:?}", cfg.output_dir))?;
    let api = build_api(cfg)?;
    let repo = api.model(cfg.repo.clone());
    // The fetch helpers only see the `ApiRepo` handle, so the repo id is attached here to
    // make every failure traceable to the repo the caller asked for.
    let config = fetch_file(&repo, "config.yaml", &cfg.output_dir, cfg.overwrite)
        .with_context(|| format!("downloading config.yaml from repo '{}'", cfg.repo))?;
    let (weights, weights_is_safetensors) =
        fetch_weights(&repo, &cfg.output_dir, cfg.overwrite)
            .with_context(|| format!("downloading model weights from repo '{}'", cfg.repo))?;
    let build_args =
        fetch_optional_file(&repo, "build_args.json", &cfg.output_dir, cfg.overwrite)
            .with_context(|| format!("checking for build_args.json in repo '{}'", cfg.repo))?;
    Ok(ModelFiles { config, weights, weights_is_safetensors, build_args })
}
/// Creates a synchronous Hub client with download progress bars turned on.
///
/// `cfg.token` is handed to the builder only when it is set; otherwise the builder's own
/// defaults are left untouched.
fn build_api(cfg: &DownloadConfig) -> Result<Api> {
    let base = ApiBuilder::new().with_progress(true);
    let authed = match &cfg.token {
        Some(secret) => base.with_token(Some(secret.clone())),
        None => base,
    };
    authed.build().context("building HuggingFace API client")
}
fn fetch_file(
repo: &ApiRepo,
filename: &str,
out_dir: &Path,
overwrite: bool,
) -> Result<PathBuf> {
let dest = out_dir.join(filename);
if dest.exists() && !overwrite {
println!(" ✓ {filename} (already present, skipping)");
return Ok(dest);
}
println!(" ↓ {filename}");
let cached = repo.get(filename)
.with_context(|| format!("fetching {filename} from HF Hub"))?;
std::fs::copy(&cached, &dest)
.with_context(|| format!("copying {filename} to {}", dest.display()))?;
println!(" → {}", dest.display());
Ok(dest)
}
fn fetch_optional_file(
repo: &ApiRepo,
filename: &str,
out_dir: &Path,
overwrite: bool,
) -> Result<Option<PathBuf>> {
let dest = out_dir.join(filename);
if dest.exists() && !overwrite {
println!(" ✓ {filename} (already present, skipping)");
return Ok(Some(dest));
}
match repo.get(filename) {
Ok(cached) => {
println!(" ↓ {filename}");
std::fs::copy(&cached, &dest)
.with_context(|| format!("copying {filename} to {}", dest.display()))?;
println!(" → {}", dest.display());
Ok(Some(dest))
}
Err(e) if is_not_found(&e) => {
println!(" – {filename} not found in repo (optional, skipping)");
Ok(None)
}
Err(e) => Err(e).with_context(|| format!("fetching optional file {filename}")),
}
}
fn fetch_weights(
repo: &ApiRepo,
out_dir: &Path,
overwrite: bool,
) -> Result<(PathBuf, bool)> {
let st_dest = out_dir.join("model.safetensors");
if st_dest.exists() && !overwrite {
println!(" ✓ model.safetensors (already present, skipping)");
return Ok((st_dest, true));
}
match repo.get("model.safetensors") {
Ok(cached) => {
println!(" ↓ model.safetensors");
std::fs::copy(&cached, &st_dest)
.context("copying model.safetensors")?;
println!(" → {}", st_dest.display());
return Ok((st_dest, true));
}
Err(e) if is_not_found(&e) => {
println!(" – model.safetensors not found, trying best.ckpt …");
}
Err(e) => {
return Err(e).context("fetching model.safetensors");
}
}
let ckpt_dest = out_dir.join("best.ckpt");
if ckpt_dest.exists() && !overwrite {
println!(" ✓ best.ckpt (already present, skipping)");
return Ok((ckpt_dest, false));
}
match repo.get("best.ckpt") {
Ok(cached) => {
println!(" ↓ best.ckpt");
std::fs::copy(&cached, &ckpt_dest)
.context("copying best.ckpt")?;
println!(" → {}", ckpt_dest.display());
Ok((ckpt_dest, false))
}
Err(e) if is_not_found(&e) => {
bail!("neither model.safetensors nor best.ckpt found in repo '{}'", "?");
}
Err(e) => Err(e).context("fetching best.ckpt"),
}
}
/// Heuristically decides whether a Hub error means "this file does not exist".
///
/// The error is classified by its rendered text rather than by variant: a message that
/// contains `404` or `not found` (compared case-insensitively) counts as missing. Because
/// this is substring matching, unrelated errors whose text mentions either needle will
/// also be classified as "not found".
fn is_not_found(e: &ApiError) -> bool {
    // `entry not found` contains `not found`, so it needs no needle of its own.
    const NEEDLES: [&str; 2] = ["404", "not found"];
    let rendered = e.to_string().to_lowercase();
    NEEDLES.iter().any(|needle| rendered.contains(needle))
}