use std::collections::{HashMap, HashSet};
use std::io::IsTerminal;
use std::path::{Path, PathBuf};
use std::process::ExitCode;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use clap::{Args, Parser, Subcommand, ValueEnum};
use tracing_subscriber::EnvFilter;
use hf_fetch_model::cache;
use hf_fetch_model::discover;
use hf_fetch_model::inspect;
use hf_fetch_model::progress::IndicatifProgress;
use hf_fetch_model::repo;
use hf_fetch_model::{
compile_glob_patterns, file_matches, has_glob_chars, FetchConfig, FetchError, Filter,
};
// Top-level CLI parser.
//
// With no subcommand, the flattened `DownloadArgs` drive a whole-repo
// download; `args_conflicts_with_subcommands` keeps positional download
// args from mixing with a subcommand on the same invocation.
// NOTE: comments here are deliberately `//`, not `///` — clap turns doc
// comments on derive items into user-visible help text.
#[derive(Parser)]
#[command(
    name = "hf-fetch-model",
    version,
    about,
    before_help = concat!("hf-fetch-model v", env!("CARGO_PKG_VERSION"))
)]
#[command(args_conflicts_with_subcommands = true)]
struct Cli {
    // `None` means "default mode": download the repo given in `download`.
    #[command(subcommand)]
    command: Option<Commands>,
    // Arguments for the default (no-subcommand) download mode.
    #[command(flatten)]
    download: DownloadArgs,
}
// Arguments for the default whole-repo download mode.
// (`//` comments on purpose: `///` would become clap help text.)
#[derive(Args)]
struct DownloadArgs {
    // Enable debug tracing to stderr.
    #[arg(short, long)]
    verbose: bool,
    // "org/model" identifier; optional only because a subcommand may be
    // used instead — run_download rejects `None` with a usage error.
    #[arg(value_name = "REPO_ID")]
    repo_id: Option<String>,
    // Git revision (branch, tag, or commit); library default when absent.
    #[arg(long)]
    revision: Option<String>,
    // HF API token; falls back to the environment when absent.
    #[arg(long)]
    token: Option<String>,
    // Repeatable include glob(s).
    #[arg(long, action = clap::ArgAction::Append)]
    filter: Vec<String>,
    // Repeatable exclude glob(s).
    #[arg(long, action = clap::ArgAction::Append)]
    exclude: Vec<String>,
    // Predefined filter set (safetensors/gguf/pth/config-only).
    #[arg(long, value_enum)]
    preset: Option<Preset>,
    // Destination directory; with --flat this is the copy target instead.
    #[arg(long)]
    output_dir: Option<PathBuf>,
    // Max concurrently downloading files.
    #[arg(long)]
    concurrency: Option<usize>,
    // Files larger than this (MiB) are split into ranged chunks.
    #[arg(long)]
    chunk_threshold_mib: Option<u64>,
    // Parallel connections per chunked file.
    #[arg(long)]
    connections_per_file: Option<usize>,
    // Print the download plan without fetching anything.
    #[arg(long)]
    dry_run: bool,
    // Copy results into a single flat directory instead of the cache layout.
    #[arg(long)]
    flat: bool,
}
// All subcommands. (`//` comments on purpose: `///` would become clap help.)
#[derive(Subcommand)]
enum Commands {
    // List model families present in the local cache.
    ListFamilies,
    // Find families on the Hub that are not cached locally.
    Discover {
        #[arg(long, default_value = "500")]
        limit: usize,
    },
    // Search the Hub for models matching a query.
    Search {
        query: String,
        #[arg(long, default_value = "20")]
        limit: usize,
        // Require an exact (case-insensitive) repo-id match.
        #[arg(long)]
        exact: bool,
        #[arg(long)]
        library: Option<String>,
        #[arg(long)]
        pipeline: Option<String>,
    },
    // Show model-card metadata and README for one repo.
    Info {
        repo_id: String,
        #[arg(long)]
        revision: Option<String>,
        #[arg(long)]
        token: Option<String>,
        #[arg(long)]
        json: bool,
        // README lines to print; 0 = unlimited.
        #[arg(long, default_value = "40")]
        lines: usize,
    },
    // Download a single file (or glob of files) from a repo.
    DownloadFile {
        #[arg(short, long)]
        verbose: bool,
        repo_id: String,
        // Exact filename, or a glob pattern (see has_glob_chars).
        filename: String,
        #[arg(long)]
        revision: Option<String>,
        #[arg(long)]
        token: Option<String>,
        #[arg(long)]
        output_dir: Option<PathBuf>,
        #[arg(long)]
        chunk_threshold_mib: Option<u64>,
        #[arg(long)]
        connections_per_file: Option<usize>,
        #[arg(long)]
        flat: bool,
    },
    // Cache status; without REPO_ID, summarize every cached repo.
    Status {
        repo_id: Option<String>,
        #[arg(long)]
        revision: Option<String>,
        #[arg(long)]
        token: Option<String>,
    },
    // Compare tensor names/dtypes/shapes between two repos.
    Diff {
        repo_a: String,
        repo_b: String,
        #[arg(long)]
        revision_a: Option<String>,
        #[arg(long)]
        revision_b: Option<String>,
        #[arg(long)]
        token: Option<String>,
        // Read safetensors headers from the local cache instead of the Hub.
        #[arg(long)]
        cached: bool,
        // Substring filter on tensor names.
        #[arg(long)]
        filter: Option<String>,
        #[arg(long)]
        summary: bool,
        #[arg(long)]
        json: bool,
    },
    // Disk usage; without REPO_ID, report the whole cache.
    Du {
        repo_id: Option<String>,
    },
    // Inspect safetensors headers of one file or a whole repo.
    Inspect {
        repo_id: String,
        filename: Option<String>,
        #[arg(long)]
        revision: Option<String>,
        #[arg(long)]
        token: Option<String>,
        #[arg(long)]
        cached: bool,
        #[arg(long)]
        no_metadata: bool,
        #[arg(long)]
        json: bool,
        #[arg(long)]
        filter: Option<String>,
    },
    // List a repo's files with sizes (and optionally checksums/cache state).
    ListFiles {
        repo_id: String,
        #[arg(long)]
        revision: Option<String>,
        #[arg(long)]
        token: Option<String>,
        #[arg(long, action = clap::ArgAction::Append)]
        filter: Vec<String>,
        #[arg(long, action = clap::ArgAction::Append)]
        exclude: Vec<String>,
        #[arg(long, value_enum)]
        preset: Option<Preset>,
        #[arg(long)]
        no_checksum: bool,
        #[arg(long)]
        show_cached: bool,
    },
}
// Named filter presets mapped onto `Filter` constructors in run_download /
// run_dry_run / run_list_files. (`//` so clap help text is unchanged.)
#[derive(Clone, ValueEnum)]
enum Preset {
    Safetensors,
    Gguf,
    Pth,
    ConfigOnly,
}
/// Entry point: parse the CLI, optionally install tracing, dispatch to
/// [`run`], and translate errors into exit codes plus actionable hints.
fn main() -> ExitCode {
    let cli = Cli::parse();
    // Verbose tracing only applies to the download paths; every other
    // subcommand is matched explicitly so adding a variant forces a
    // decision here instead of silently defaulting.
    let verbose = match &cli.command {
        Some(Commands::DownloadFile { verbose, .. }) => *verbose,
        None => cli.download.verbose,
        Some(
            Commands::ListFamilies
            | Commands::Discover { .. }
            | Commands::Search { .. }
            | Commands::Info { .. }
            | Commands::Status { .. }
            | Commands::Diff { .. }
            | Commands::Du { .. }
            | Commands::Inspect { .. }
            | Commands::ListFiles { .. },
        ) => false,
    };
    if verbose {
        // RUST_LOG (if set) wins; otherwise default to crate-level debug.
        let filter = EnvFilter::try_from_default_env()
            .unwrap_or_else(|_| EnvFilter::new("hf_fetch_model=debug"));
        tracing_subscriber::fmt()
            .with_env_filter(filter)
            .with_target(false)
            .with_writer(std::io::stderr)
            .init();
    }
    match run(cli) {
        Ok(()) => ExitCode::SUCCESS,
        Err(FetchError::PartialDownload { path, failures }) => {
            eprintln!();
            eprintln!("error: {} file(s) failed to download:", failures.len());
            for f in &failures {
                eprintln!(" - {}: {}", f.filename, f.reason);
            }
            if let Some(p) = path {
                eprintln!();
                eprintln!("Partial download at: {}", p.display());
            }
            let any_retryable = failures.iter().any(|f| f.retryable);
            if any_retryable {
                eprintln!();
                eprintln!(
                    "hint: re-run the same command to retry failed files \
                     (already-downloaded files will be skipped)"
                );
            }
            ExitCode::FAILURE
        }
        Err(err @ FetchError::RepoNotFound { .. }) => {
            // Bind the error by value and print it directly; the previous
            // code cloned `repo_id` just to reconstruct the same variant
            // for Display.
            eprintln!("error: {err}");
            if let FetchError::RepoNotFound { repo_id } = &err {
                let search_term = repo_id.split('/').nth(1).unwrap_or(repo_id.as_str());
                eprintln!("hint: try `hf-fm search {search_term}` to find matching models");
            }
            ExitCode::FAILURE
        }
        Err(e) => {
            eprintln!("error: {e}");
            ExitCode::FAILURE
        }
    }
}
/// Dispatch a parsed CLI invocation to the matching `run_*` handler.
///
/// Purely structural: each arm destructures the variant and forwards
/// fields, borrowing `String`s as `&str`/`as_deref()` where the handlers
/// take references. `None` (no subcommand) falls through to the default
/// whole-repo download.
fn run(cli: Cli) -> Result<(), FetchError> {
    match cli.command {
        Some(Commands::ListFamilies) => run_list_families(),
        Some(Commands::Discover { limit }) => run_discover(limit),
        Some(Commands::Search {
            query,
            limit,
            exact,
            library,
            pipeline,
        }) => run_search(
            query.as_str(),
            limit,
            exact,
            library.as_deref(),
            pipeline.as_deref(),
        ),
        Some(Commands::Info {
            repo_id,
            revision,
            token,
            json,
            lines,
        }) => run_info(
            repo_id.as_str(),
            revision.as_deref(),
            token.as_deref(),
            json,
            lines,
        ),
        // `verbose` was already consumed by main() for tracing setup.
        Some(Commands::DownloadFile {
            verbose: _,
            repo_id,
            filename,
            revision,
            token,
            output_dir,
            chunk_threshold_mib,
            connections_per_file,
            flat,
        }) => run_download_file(DownloadFileParams {
            repo_id: repo_id.as_str(),
            filename: filename.as_str(),
            revision: revision.as_deref(),
            token: token.as_deref(),
            output_dir,
            chunk_threshold_mib,
            connections_per_file,
            flat,
        }),
        Some(Commands::Status {
            repo_id: Some(repo_id),
            revision,
            token,
        }) => run_status(repo_id.as_str(), revision.as_deref(), token.as_deref()),
        // `status` with no repo summarizes the entire cache.
        Some(Commands::Status { repo_id: None, .. }) => run_status_all(),
        Some(Commands::Diff {
            repo_a,
            repo_b,
            revision_a,
            revision_b,
            token,
            cached,
            filter,
            summary,
            json,
        }) => run_diff(
            repo_a.as_str(),
            repo_b.as_str(),
            revision_a.as_deref(),
            revision_b.as_deref(),
            token.as_deref(),
            cached,
            filter.as_deref(),
            summary,
            json,
        ),
        Some(Commands::Du {
            repo_id: Some(repo_id),
        }) => run_du_repo(repo_id.as_str()),
        Some(Commands::Du { repo_id: None }) => run_du(),
        Some(Commands::Inspect {
            repo_id,
            filename,
            revision,
            token,
            cached,
            no_metadata,
            json,
            filter,
        }) => run_inspect(
            repo_id.as_str(),
            filename.as_deref(),
            revision.as_deref(),
            token.as_deref(),
            cached,
            no_metadata,
            json,
            filter.as_deref(),
        ),
        Some(Commands::ListFiles {
            repo_id,
            revision,
            token,
            filter,
            exclude,
            preset,
            no_checksum,
            show_cached,
        }) => run_list_files(
            repo_id.as_str(),
            revision.as_deref(),
            token.as_deref(),
            &filter,
            &exclude,
            preset.as_ref(),
            no_checksum,
            show_cached,
        ),
        None => run_download(cli.download),
    }
}
/// Throttled progress reporter for non-interactive stderr (no progress bars).
struct NonTtyProgress {
    // When the last progress line was printed (shared across files).
    last_report: Mutex<Instant>,
    // Per-filename highest 10%-bucket already reported.
    last_bucket: Mutex<HashMap<String, u64>>,
}
impl NonTtyProgress {
    /// Reporter starting at "now" with no per-file bucket state.
    fn new() -> Self {
        Self {
            last_report: Mutex::new(Instant::now()),
            last_bucket: Mutex::new(HashMap::new()),
        }
    }
    /// Print a progress line when either 5s have elapsed since the last
    /// line, or the file crossed into a new 10% bucket.
    ///
    /// Both conditions are evaluated eagerly (separate `let`s, not a
    /// short-circuiting `||`) so the bucket map is kept current even when
    /// the time condition alone triggers the print. A poisoned lock makes
    /// the corresponding condition false (`is_ok_and`) rather than panic.
    fn handle(&self, event: &hf_fetch_model::progress::ProgressEvent) {
        // 100% is not printed here; completion is reported by the caller.
        if event.percent >= 100.0 {
            return;
        }
        #[allow(
            clippy::cast_possible_truncation,
            clippy::cast_sign_loss,
            clippy::as_conversions
        )]
        let bucket = (event.percent / 10.0) as u64;
        let elapsed_ok = self
            .last_report
            .lock()
            .is_ok_and(|guard| guard.elapsed().as_secs() >= 5);
        let bucket_crossed = self.last_bucket.lock().is_ok_and(|mut map| {
            let prev = map.entry(event.filename.clone()).or_insert(0);
            if bucket > *prev {
                *prev = bucket;
                true
            } else {
                false
            }
        });
        if elapsed_ok || bucket_crossed {
            // Reset the timer on every print, whichever condition fired.
            if let Ok(mut ts) = self.last_report.lock() {
                *ts = Instant::now();
            }
            #[allow(
                clippy::cast_possible_truncation,
                clippy::cast_sign_loss,
                clippy::as_conversions
            )]
            let pct = event.percent as u64;
            eprintln!(
                "[hf-fm] {}: {}/{} ({pct}%)",
                event.filename,
                format_size(event.bytes_downloaded),
                format_size(event.bytes_total)
            );
        }
    }
}
/// Default mode: download an entire repo according to `args`.
///
/// Validates the repo id, short-circuits to [`run_dry_run`] for `--dry-run`,
/// builds a `FetchConfig` from preset/filters/limits, wires a TTY or
/// non-TTY progress reporter, then runs the download on a fresh tokio
/// runtime. With `--flat` the downloaded files are copied into a single
/// target directory afterwards.
fn run_download(args: DownloadArgs) -> Result<(), FetchError> {
    let dry_run = args.dry_run;
    let repo_id = args.repo_id.as_deref().ok_or_else(|| {
        FetchError::InvalidArgument(
            "REPO_ID is required for download. Usage: hf-fm <REPO_ID>".to_owned(),
        )
    })?;
    if !repo_id.contains('/') {
        return Err(FetchError::InvalidArgument(format!(
            "invalid REPO_ID \"{repo_id}\": expected \"org/model\" format (e.g., \"EleutherAI/pythia-1.4b\")"
        )));
    }
    if dry_run {
        return run_dry_run(repo_id, &args);
    }
    let repo_id = repo_id.to_owned();
    let flat = args.flat;
    // With --flat, --output-dir names the flatten target, not the cache dir.
    let flat_target = if flat { args.output_dir.clone() } else { None };
    // A preset seeds the builder with its filter set; otherwise start empty.
    let mut builder = match args.preset {
        Some(Preset::Safetensors) => Filter::safetensors(),
        Some(Preset::Gguf) => Filter::gguf(),
        Some(Preset::Pth) => Filter::pth(),
        Some(Preset::ConfigOnly) => Filter::config_only(),
        None => FetchConfig::builder(),
    };
    if let Some(ref preset) = args.preset {
        warn_redundant_filters(preset, &args.filter);
    }
    if let Some(rev) = args.revision.as_deref() {
        builder = builder.revision(rev);
    }
    // Explicit --token wins; otherwise let the library read the environment.
    if let Some(tok) = args.token.as_deref() {
        builder = builder.token(tok);
    } else {
        builder = builder.token_from_env();
    }
    for pattern in &args.filter {
        builder = builder.filter(pattern.as_str());
    }
    for pattern in &args.exclude {
        builder = builder.exclude(pattern.as_str());
    }
    if let Some(c) = args.concurrency {
        builder = builder.concurrency(c);
    }
    if let Some(ct) = args.chunk_threshold_mib {
        // CLI takes MiB; the builder takes bytes.
        builder = builder.chunk_threshold(ct.saturating_mul(1024 * 1024));
    }
    if let Some(cpf) = args.connections_per_file {
        builder = builder.connections_per_file(cpf);
    }
    if !flat {
        if let Some(dir) = args.output_dir {
            builder = builder.output_dir(dir);
        }
    }
    // Interactive stderr gets indicatif bars; otherwise throttled text lines.
    let is_tty = std::io::stderr().is_terminal();
    let indicatif = if is_tty {
        let p = Arc::new(IndicatifProgress::new());
        let handle = Arc::clone(&p);
        builder = builder.on_progress(move |e| handle.handle(e));
        Some(p)
    } else {
        let p = Arc::new(NonTtyProgress::new());
        let handle = Arc::clone(&p);
        builder = builder.on_progress(move |e| handle.handle(e));
        None
    };
    let config = builder.build()?;
    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
        path: PathBuf::from("<runtime>"),
        source: e,
    })?;
    let start = Instant::now();
    if flat {
        // Per-file map result so each file can be copied into the flat dir.
        let outcome = rt.block_on(hf_fetch_model::download_files_with_config(repo_id, &config))?;
        let elapsed = start.elapsed();
        if let Some(ref p) = indicatif {
            p.finish();
        }
        let file_map = outcome.inner();
        let target_dir = resolve_flat_target(flat_target.as_deref())?;
        let flat_paths = flatten_files(file_map, &target_dir)?;
        println!(
            "{} file(s) copied to {}:",
            flat_paths.len(),
            target_dir.display()
        );
        for p in &flat_paths {
            println!(" {}", p.display());
        }
        print_download_summary(&target_dir, elapsed);
    } else {
        let outcome = rt.block_on(hf_fetch_model::download_with_config(repo_id, &config))?;
        let elapsed = start.elapsed();
        if let Some(ref p) = indicatif {
            p.finish();
        }
        if outcome.is_cached() {
            println!("Cached at: {}", outcome.inner().display());
        } else {
            println!("Downloaded to: {}", outcome.inner().display());
            print_download_summary(outcome.inner(), elapsed);
        }
    }
    Ok(())
}
/// `--dry-run`: fetch the download plan and print a per-file table plus
/// totals and (when anything remains to download) a recommended config.
///
/// Builds the same filter/auth config as [`run_download`] but without
/// progress wiring or concurrency knobs, since nothing is transferred.
fn run_dry_run(repo_id: &str, args: &DownloadArgs) -> Result<(), FetchError> {
    let mut builder = match args.preset {
        Some(Preset::Safetensors) => Filter::safetensors(),
        Some(Preset::Gguf) => Filter::gguf(),
        Some(Preset::Pth) => Filter::pth(),
        Some(Preset::ConfigOnly) => Filter::config_only(),
        None => FetchConfig::builder(),
    };
    if let Some(ref preset) = args.preset {
        warn_redundant_filters(preset, &args.filter);
    }
    if let Some(rev) = args.revision.as_deref() {
        builder = builder.revision(rev);
    }
    if let Some(tok) = args.token.as_deref() {
        builder = builder.token(tok);
    } else {
        builder = builder.token_from_env();
    }
    for pattern in &args.filter {
        builder = builder.filter(pattern.as_str());
    }
    for pattern in &args.exclude {
        builder = builder.exclude(pattern.as_str());
    }
    if let Some(ref dir) = args.output_dir {
        builder = builder.output_dir(dir.clone());
    }
    let config = builder.build()?;
    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
        path: PathBuf::from("<runtime>"),
        source: e,
    })?;
    let plan = rt.block_on(hf_fetch_model::download_plan(repo_id, &config))?;
    println!(" Repo: {}", plan.repo_id);
    println!(" Revision: {}", plan.revision);
    if args.preset.is_some() || !args.filter.is_empty() {
        println!(" Filter: active (preset or --filter)");
    }
    if args.flat {
        let target = resolve_flat_target(args.output_dir.as_deref())?;
        println!(
            " Flat: {} (files will be copied here)",
            target.display()
        );
    }
    println!();
    // Table header and a box-drawing (U+2500) separator row.
    println!(" {:<48} {:>10} Status", "File", "Size");
    println!(
        " {:\u{2500}<48} {:\u{2500}<10} {:\u{2500}<12}",
        "", "", ""
    );
    for fp in &plan.files {
        let status = if fp.cached {
            "cached \u{2713}"
        } else {
            "to download"
        };
        println!(
            " {:<48} {:>10} {status}",
            fp.filename,
            format_size(fp.size)
        );
    }
    println!("{:\u{2500}<74}", " ");
    let cached_count = plan.files.len() - plan.files_to_download();
    let to_dl = plan.files_to_download();
    println!(
        " Total: {} ({} files, {} cached, {} to download)",
        format_size(plan.total_bytes),
        plan.files.len(),
        cached_count,
        to_dl
    );
    println!(" Download: {}", format_size(plan.download_bytes));
    if !plan.fully_cached() {
        let rec = plan.recommended_config()?;
        println!();
        println!(" Recommended config:");
        println!(" concurrency: {}", rec.concurrency());
        println!(" connections/file: {}", rec.connections_per_file());
        // u64::MAX is the sentinel the planner uses for "never chunk".
        if rec.chunk_threshold() == u64::MAX {
            println!(" chunk threshold: disabled (single-connection per file)");
        } else {
            println!(
                " chunk threshold: {} MiB",
                rec.chunk_threshold() / 1_048_576
            );
        }
    }
    Ok(())
}
/// Parameter bundle for [`run_download_file`] / [`run_download_file_glob`],
/// avoiding a long positional argument list.
struct DownloadFileParams<'a> {
    repo_id: &'a str,
    // Exact filename or a glob pattern.
    filename: &'a str,
    revision: Option<&'a str>,
    token: Option<&'a str>,
    // With `flat`, this names the flatten target rather than the cache dir.
    output_dir: Option<PathBuf>,
    chunk_threshold_mib: Option<u64>,
    connections_per_file: Option<usize>,
    flat: bool,
}
/// `download-file`: fetch a single named file from a repo.
///
/// If `filename` contains glob characters, delegates to
/// [`run_download_file_glob`]. Otherwise builds a `FetchConfig`, wires a
/// progress reporter matching the terminal type, and downloads on a fresh
/// tokio runtime; `--flat` copies the result into a flat directory.
fn run_download_file(params: DownloadFileParams<'_>) -> Result<(), FetchError> {
    let DownloadFileParams {
        repo_id,
        filename,
        revision,
        token,
        output_dir,
        chunk_threshold_mib,
        connections_per_file,
        flat,
    } = params;
    if !repo_id.contains('/') {
        return Err(FetchError::InvalidArgument(format!(
            "invalid REPO_ID \"{repo_id}\": expected \"org/model\" format (e.g., \"mntss/clt-gemma-2-2b-426k\")"
        )));
    }
    // Glob patterns take the multi-file path.
    if has_glob_chars(filename) {
        return run_download_file_glob(DownloadFileParams {
            repo_id,
            filename,
            revision,
            token,
            output_dir,
            chunk_threshold_mib,
            connections_per_file,
            flat,
        });
    }
    let flat_target = if flat { output_dir.clone() } else { None };
    let mut builder = FetchConfig::builder();
    if let Some(rev) = revision {
        builder = builder.revision(rev);
    }
    if let Some(tok) = token {
        builder = builder.token(tok);
    } else {
        builder = builder.token_from_env();
    }
    if let Some(ct) = chunk_threshold_mib {
        // CLI takes MiB; the builder takes bytes.
        builder = builder.chunk_threshold(ct.saturating_mul(1024 * 1024));
    }
    if let Some(cpf) = connections_per_file {
        builder = builder.connections_per_file(cpf);
    }
    if !flat {
        if let Some(dir) = output_dir {
            builder = builder.output_dir(dir);
        }
    }
    let is_tty = std::io::stderr().is_terminal();
    let indicatif = if is_tty {
        let p = Arc::new(IndicatifProgress::new());
        let handle = Arc::clone(&p);
        builder = builder.on_progress(move |e| handle.handle(e));
        Some(p)
    } else {
        let p = Arc::new(NonTtyProgress::new());
        let handle = Arc::clone(&p);
        builder = builder.on_progress(move |e| handle.handle(e));
        None
    };
    let config = builder.build()?;
    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
        path: PathBuf::from("<runtime>"),
        source: e,
    })?;
    let start = Instant::now();
    let outcome = rt.block_on(hf_fetch_model::download_file(
        repo_id.to_owned(),
        filename,
        &config,
    ))?;
    let elapsed = start.elapsed();
    if let Some(ref p) = indicatif {
        p.finish();
    }
    if flat {
        let target_dir = resolve_flat_target(flat_target.as_deref())?;
        let flat_path = flatten_single_file(outcome.inner(), &target_dir)?;
        println!("Copied to: {}", flat_path.display());
    } else if outcome.is_cached() {
        println!("Cached at: {}", outcome.inner().display());
    } else {
        println!("Downloaded to: {}", outcome.inner().display());
        print_download_summary(outcome.inner(), elapsed);
    }
    Ok(())
}
/// `download-file` with a glob pattern: download every matching file.
///
/// The pattern becomes the config's include filter; otherwise this mirrors
/// [`run_download_file`] (auth, progress wiring, runtime, `--flat` copy).
fn run_download_file_glob(params: DownloadFileParams<'_>) -> Result<(), FetchError> {
    let DownloadFileParams {
        repo_id,
        filename: pattern,
        revision,
        token,
        output_dir,
        chunk_threshold_mib,
        connections_per_file,
        flat,
    } = params;
    let flat_target = if flat { output_dir.clone() } else { None };
    // The glob is applied as the repo-wide include filter.
    let mut builder = FetchConfig::builder().filter(pattern);
    if let Some(rev) = revision {
        builder = builder.revision(rev);
    }
    if let Some(tok) = token {
        builder = builder.token(tok);
    } else {
        builder = builder.token_from_env();
    }
    if let Some(ct) = chunk_threshold_mib {
        builder = builder.chunk_threshold(ct.saturating_mul(1024 * 1024));
    }
    if let Some(cpf) = connections_per_file {
        builder = builder.connections_per_file(cpf);
    }
    if !flat {
        if let Some(dir) = output_dir {
            builder = builder.output_dir(dir);
        }
    }
    let is_tty = std::io::stderr().is_terminal();
    let indicatif = if is_tty {
        let p = Arc::new(IndicatifProgress::new());
        let handle = Arc::clone(&p);
        builder = builder.on_progress(move |e| handle.handle(e));
        Some(p)
    } else {
        let p = Arc::new(NonTtyProgress::new());
        let handle = Arc::clone(&p);
        builder = builder.on_progress(move |e| handle.handle(e));
        None
    };
    let config = builder.build()?;
    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
        path: PathBuf::from("<runtime>"),
        source: e,
    })?;
    let start = Instant::now();
    let outcome = rt.block_on(hf_fetch_model::download_files_with_config(
        repo_id.to_owned(),
        &config,
    ))?;
    let elapsed = start.elapsed();
    if let Some(ref p) = indicatif {
        p.finish();
    }
    let file_map = outcome.inner();
    if file_map.is_empty() {
        println!("No files matched pattern \"{pattern}\" in {repo_id}");
        return Ok(());
    }
    if flat {
        let target_dir = resolve_flat_target(flat_target.as_deref())?;
        let flat_paths = flatten_files(file_map, &target_dir)?;
        println!(
            "{} file(s) copied to {}:",
            flat_paths.len(),
            target_dir.display()
        );
        for p in &flat_paths {
            println!(" {}", p.display());
        }
    } else {
        println!("{} file(s) matched pattern \"{pattern}\":", file_map.len());
        for (name, path) in file_map {
            println!(" {name}: {}", path.display());
        }
    }
    let elapsed_secs = elapsed.as_secs_f64();
    if elapsed_secs > 0.0 {
        println!(" completed in {elapsed_secs:.1}s");
    }
    Ok(())
}
/// Print every model family found in the local cache as a two-column table
/// (family name, comma-joined repo list).
fn run_list_families() -> Result<(), FetchError> {
    let families = cache::list_cached_families()?;
    if families.is_empty() {
        println!("No model families found in local cache.");
        return Ok(());
    }
    // Header followed by a dashed separator row.
    println!("{:<16}Models", "Family");
    println!("{:-<16}{:-<64}", "", "");
    for (family, repos) in &families {
        println!("{family:<16}{}", repos.join(", "));
    }
    Ok(())
}
/// Query the Hub for model families that are not yet in the local cache
/// and print each with its most-downloaded model.
fn run_discover(limit: usize) -> Result<(), FetchError> {
    // Families already cached locally, keyed by model type.
    let local_types: HashSet<String> = cache::list_cached_families()?.into_keys().collect();
    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
        path: PathBuf::from("<runtime>"),
        source: e,
    })?;
    let discovered = rt.block_on(discover::discover_new_families(&local_types, limit))?;
    if discovered.is_empty() {
        println!("No new model families found.");
        return Ok(());
    }
    println!("New families not in local cache (top models by downloads):\n");
    println!("{:<16}Top Model", "Family");
    println!("{:-<16}{:-<64}", "", "");
    for entry in &discovered {
        println!("{:<16}{}", entry.model_type, entry.top_model);
    }
    Ok(())
}
/// `search`: query the Hub and print matching models.
///
/// The query is normalized (slashes become commas or spaces depending on
/// whether the user already used commas), split into terms, and the first
/// term is sent to the API; remaining terms are applied client-side as
/// case-insensitive substring filters on the model id. When any
/// client-side filtering is active the API limit is inflated 5x so enough
/// candidates survive. `--exact` looks for a case-insensitive exact
/// repo-id match and prints its model card.
fn run_search(
    query: &str,
    limit: usize,
    exact: bool,
    library: Option<&str>,
    pipeline: Option<&str>,
) -> Result<(), FetchError> {
    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
        path: PathBuf::from("<runtime>"),
        source: e,
    })?;
    // "a/b" with commas present joins into the comma list; otherwise the
    // slash just separates words.
    let has_commas = query.contains(',');
    let normalized = if has_commas {
        query.replace('/', ",")
    } else {
        query.replace('/', " ")
    };
    let terms: Vec<&str> = normalized
        .split(',')
        .map(str::trim)
        .filter(|t| !t.is_empty())
        .collect();
    // Only the first term goes to the API; the rest filter locally.
    let api_query = terms.first().copied().unwrap_or(normalized.as_str());
    let filter_terms: Vec<String> = terms.iter().map(|t| t.to_lowercase()).collect();
    let has_client_filter = filter_terms.len() > 1 || library.is_some() || pipeline.is_some();
    // Over-fetch when filtering client-side so `limit` results can survive.
    let api_limit = if has_client_filter {
        limit.saturating_mul(5)
    } else {
        limit
    };
    let results = rt.block_on(discover::search_models(
        api_query, api_limit, library, pipeline,
    ))?;
    let has_multi_term = filter_terms.len() > 1;
    let filtered: Vec<&discover::SearchResult> = results
        .iter()
        .filter(|r| {
            if !has_multi_term {
                return true;
            }
            // Every term must appear somewhere in the slash-flattened id.
            let id_normalized = r.model_id.replace('/', " ").to_lowercase();
            filter_terms
                .iter()
                .all(|term| id_normalized.contains(term.as_str()))
        })
        .take(limit)
        .collect();
    if exact {
        let exact_match = filtered
            .iter()
            .find(|r| r.model_id.eq_ignore_ascii_case(query));
        if let Some(matched) = exact_match {
            println!("Exact match:\n");
            print_search_result(matched);
            // Model card fetch is best-effort; failure is reported, not fatal.
            match rt.block_on(discover::fetch_model_card(matched.model_id.as_str())) {
                Ok(card) => print_model_card(&card),
                Err(e) => eprintln!("\n (could not fetch model card: {e})"),
            }
        } else {
            println!("No exact match for \"{query}\".");
            if !filtered.is_empty() {
                println!("\nDid you mean:\n");
                for result in &filtered {
                    print_search_result(result);
                }
            }
        }
    } else if filtered.is_empty() {
        // Collapsed from `else { if … }` (clippy::collapsible_else_if).
        println!("No models found matching \"{query}\".");
    } else {
        println!("Models matching \"{query}\" (by downloads):\n");
        for result in &filtered {
            print_search_result(result);
        }
    }
    Ok(())
}
/// Render one search hit as a ready-to-paste `hf-fm` command line with a
/// bracketed library/pipeline suffix when either tag is present.
fn print_search_result(result: &discover::SearchResult) {
    // Collect whichever of the two tags exist, in library-then-pipeline order.
    let tags: Vec<&str> = [
        result.library_name.as_deref(),
        result.pipeline_tag.as_deref(),
    ]
    .into_iter()
    .flatten()
    .collect();
    let suffix = if tags.is_empty() {
        String::new()
    } else {
        format!(" [{}]", tags.join(", "))
    };
    println!(
        " hf-fm {:<48} ({} downloads){suffix}",
        result.model_id,
        format_downloads(result.downloads)
    );
}
/// Print the populated fields of a model card, one labelled line each;
/// absent/empty fields are skipped entirely.
fn print_model_card(card: &discover::ModelCardMetadata) {
    println!();
    if let Some(license) = card.license.as_deref() {
        println!(" License: {license}");
    }
    // Only surface gating when the repo actually requires accepting terms.
    if card.gated.is_gated() {
        println!(
            " Gated: {} (requires accepting terms on HF)",
            card.gated
        );
    }
    if let Some(pipeline) = card.pipeline_tag.as_deref() {
        println!(" Pipeline: {pipeline}");
    }
    if let Some(library) = card.library_name.as_deref() {
        println!(" Library: {library}");
    }
    if !card.tags.is_empty() {
        println!(" Tags: {}", card.tags.join(", "));
    }
    if !card.languages.is_empty() {
        println!(" Languages: {}", card.languages.join(", "));
    }
}
/// `info`: show model-card metadata plus (optionally truncated) README.
///
/// `max_lines == 0` means "print the whole README". With `--json` the
/// metadata and raw README are serialized instead of pretty-printed.
fn run_info(
    repo_id: &str,
    revision: Option<&str>,
    token: Option<&str>,
    json: bool,
    max_lines: usize,
) -> Result<(), FetchError> {
    if !repo_id.contains('/') {
        return Err(FetchError::InvalidArgument(format!(
            "invalid REPO_ID \"{repo_id}\": expected \"owner/model\" format \
             (e.g., \"mistralai/Ministral-3-3B-Instruct-2512\")"
        )));
    }
    // Explicit --token wins over the HF_TOKEN environment variable.
    let token_owned = token
        .map(String::from)
        .or_else(|| std::env::var("HF_TOKEN").ok());
    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
        path: PathBuf::from("<runtime>"),
        source: e,
    })?;
    let card = rt.block_on(discover::fetch_model_card(repo_id))?;
    let readme = rt.block_on(discover::fetch_readme(
        repo_id,
        revision,
        token_owned.as_deref(),
    ))?;
    if json {
        return print_info_json(repo_id, &card, readme.as_deref());
    }
    println!(" Repo: {repo_id}");
    print_model_card(&card);
    if let Some(ref text) = readme {
        // Drop the YAML front matter block before display.
        let body = strip_yaml_front_matter(text);
        println!();
        println!(" README:");
        println!(" {}", "\u{2500}".repeat(70));
        let lines: Vec<&str> = body.lines().collect();
        let display_count = if max_lines == 0 {
            lines.len()
        } else {
            lines.len().min(max_lines)
        };
        // Safe: display_count is clamped to lines.len() above.
        #[allow(clippy::indexing_slicing)]
        for line in &lines[..display_count] {
            println!(" {line}");
        }
        if display_count < lines.len() {
            println!(
                " ... ({} more lines, use --lines 0 for full output)",
                lines.len().saturating_sub(display_count)
            );
        }
    } else {
        println!();
        println!(" (no README.md found)");
    }
    Ok(())
}
/// JSON payload for `info --json`; `None` fields are omitted from output.
#[derive(serde::Serialize)]
struct InfoResult {
    repo_id: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    license: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pipeline_tag: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    library_name: Option<String>,
    tags: Vec<String>,
    languages: Vec<String>,
    // Stringified gating status (Display of the card's gated field).
    gated: String,
    // Raw README text, front matter included.
    #[serde(skip_serializing_if = "Option::is_none")]
    readme: Option<String>,
}
/// Serialize repo metadata (and raw README, when present) as pretty JSON
/// on stdout.
fn print_info_json(
    repo_id: &str,
    card: &discover::ModelCardMetadata,
    readme: Option<&str>,
) -> Result<(), FetchError> {
    let payload = InfoResult {
        repo_id: repo_id.to_owned(),
        license: card.license.clone(),
        pipeline_tag: card.pipeline_tag.clone(),
        library_name: card.library_name.clone(),
        tags: card.tags.clone(),
        languages: card.languages.clone(),
        gated: card.gated.to_string(),
        readme: readme.map(str::to_owned),
    };
    // Serialization failure is surfaced as an Http error, matching the
    // other JSON printers in this file.
    match serde_json::to_string_pretty(&payload) {
        Ok(rendered) => {
            println!("{rendered}");
            Ok(())
        }
        Err(e) => Err(FetchError::Http(format!("failed to serialize JSON: {e}"))),
    }
}
/// Strip a leading YAML front matter block (`--- … ---`) from README text.
///
/// Returns the text unchanged when there is no opening `---` (after
/// leading whitespace) or no closing `\n---`. Otherwise returns the slice
/// after the closing delimiter with the following line break(s) removed.
///
/// Bug fix: the old code trimmed `'\n'` then `'\r'`, so CRLF files
/// (`---\r\n…\r\n---\r\nbody`) kept a stray leading `\n`. Trimming both
/// characters in any order handles LF and CRLF alike.
#[must_use]
fn strip_yaml_front_matter(text: &str) -> &str {
    let trimmed = text.trim_start();
    if !trimmed.starts_with("---") {
        return text;
    }
    // Safe: starts_with("---") guarantees at least 3 bytes.
    #[allow(clippy::indexing_slicing)]
    let after_open = &trimmed[3..];
    if let Some(close_pos) = after_open.find("\n---") {
        // Skip past "\n---"; slicing at len() is fine if the text ends there.
        let body_start = close_pos + 4;
        #[allow(clippy::indexing_slicing)]
        let body = after_open[body_start..].trim_start_matches(|c| c == '\r' || c == '\n');
        return body;
    }
    text
}
/// `status` with no repo: print a per-repo summary table for the whole
/// local cache, flagging repos with partial downloads.
fn run_status_all() -> Result<(), FetchError> {
    let cache_dir = cache::hf_cache_dir()?;
    let summaries = cache::cache_summary()?;
    if summaries.is_empty() {
        println!("No models found in local cache.");
        return Ok(());
    }
    println!("Cache: {}\n", cache_dir.display());
    println!(
        " {:<48} {:>5} {:>10} Status",
        "Repository", "Files", "Size"
    );
    println!(" {:-<48} {:-<5} {:-<10} {:-<8}", "", "", "", "");
    for s in &summaries {
        let status_label = if s.has_partial { "PARTIAL" } else { "ok" };
        println!(
            " {:<48} {:>5} {:>10} {}",
            s.repo_id,
            s.file_count,
            format_size(s.total_size),
            status_label
        );
    }
    println!("\n{} model(s) cached", summaries.len());
    Ok(())
}
/// `du` with no repo: per-repo disk usage for the whole cache, largest
/// first, followed by grand totals.
fn run_du() -> Result<(), FetchError> {
    let mut summaries = cache::cache_summary()?;
    if summaries.is_empty() {
        println!("No models found in local cache.");
        return Ok(());
    }
    // Largest repos first.
    summaries.sort_by_key(|s| std::cmp::Reverse(s.total_size));
    let mut total_size: u64 = 0;
    let mut total_files: usize = 0;
    for summary in &summaries {
        total_size = total_size.saturating_add(summary.total_size);
        total_files = total_files.saturating_add(summary.file_count);
        let partial_marker = if summary.has_partial { " PARTIAL" } else { "" };
        println!(
            " {:>10} {:<48} ({} files){}",
            format_size(summary.total_size),
            summary.repo_id,
            summary.file_count,
            partial_marker,
        );
    }
    println!(" {}", "\u{2500}".repeat(50));
    println!(
        " {:>10} total ({} repos, {} files)",
        format_size(total_size),
        summaries.len(),
        total_files,
    );
    Ok(())
}
/// `du <repo>`: per-file cached sizes for one repo plus a total line.
fn run_du_repo(repo_id: &str) -> Result<(), FetchError> {
    let files = cache::cache_repo_usage(repo_id)?;
    if files.is_empty() {
        println!("No cached files found for {repo_id}.");
        return Ok(());
    }
    // Print each file while accumulating the (overflow-safe) total.
    let mut total_size: u64 = 0;
    for entry in &files {
        total_size = total_size.saturating_add(entry.size);
        println!(" {:>10} {}", format_size(entry.size), entry.filename);
    }
    println!(" {}", "\u{2500}".repeat(50));
    println!(
        " {:>10} total ({} files)",
        format_size(total_size),
        files.len(),
    );
    Ok(())
}
/// Gather every tensor (name → info) across all safetensors files of a
/// repo, either from the local cache (`cached`) or from the Hub.
///
/// Tensors with the same name in multiple files collapse to one map entry
/// (later files win). `token` is only consulted for the remote path.
fn collect_repo_tensors(
    repo_id: &str,
    revision: Option<&str>,
    token: Option<&str>,
    cached: bool,
) -> Result<HashMap<String, inspect::TensorInfo>, FetchError> {
    let results: Vec<(String, inspect::SafetensorsHeaderInfo)> = if cached {
        inspect::inspect_repo_safetensors_cached(repo_id, revision)?
    } else {
        // Explicit token wins; otherwise fall back to HF_TOKEN.
        let token = token
            .map(String::from)
            .or_else(|| std::env::var("HF_TOKEN").ok());
        let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
            path: PathBuf::from("<runtime>"),
            source: e,
        })?;
        let remote_results = rt.block_on(inspect::inspect_repo_safetensors(
            repo_id,
            token.as_deref(),
            revision,
        ))?;
        // Drop the source marker; only (filename, header) is needed here.
        remote_results
            .into_iter()
            .map(|(name, info, _source)| (name, info))
            .collect()
    };
    let mut tensors = HashMap::new();
    for (_filename, info) in results {
        for t in info.tensors {
            tensors.insert(t.name.clone(), t);
        }
    }
    Ok(tensors)
}
/// `diff`: compare tensor inventories of two repos.
///
/// Classifies every tensor name (optionally substring-filtered) into
/// only-in-A, only-in-B, dtype/shape-differs, or matching, then prints
/// either a detailed report (`!summary`), a one-line summary, or JSON.
#[allow(clippy::too_many_arguments, clippy::fn_params_excessive_bools)]
fn run_diff(
    repo_a: &str,
    repo_b: &str,
    revision_a: Option<&str>,
    revision_b: Option<&str>,
    token: Option<&str>,
    cached: bool,
    filter: Option<&str>,
    summary: bool,
    json: bool,
) -> Result<(), FetchError> {
    let tensors_a = collect_repo_tensors(repo_a, revision_a, token, cached)?;
    let tensors_b = collect_repo_tensors(repo_b, revision_b, token, cached)?;
    if tensors_a.is_empty() {
        println!("No .safetensors files found in {repo_a}.");
        println!("Hint: use `hf-fm list-files {repo_a}` to see available file types");
        return Ok(());
    }
    if tensors_b.is_empty() {
        println!("No .safetensors files found in {repo_b}.");
        println!("Hint: use `hf-fm list-files {repo_b}` to see available file types");
        return Ok(());
    }
    // Union of tensor names from both repos, deduped then sorted.
    let mut all_names: Vec<&str> = tensors_a
        .keys()
        .chain(tensors_b.keys())
        .map(String::as_str)
        .collect::<HashSet<&str>>()
        .into_iter()
        .collect();
    all_names.sort_unstable();
    // Substring filter applies to the union before classification.
    if let Some(pattern) = filter {
        all_names.retain(|name| name.contains(pattern));
    }
    let mut only_a: Vec<&str> = Vec::new();
    let mut only_b: Vec<&str> = Vec::new();
    let mut differ: Vec<&str> = Vec::new();
    let mut matching: Vec<&str> = Vec::new();
    for name in &all_names {
        match (tensors_a.get(*name), tensors_b.get(*name)) {
            (Some(_), None) => only_a.push(name),
            (None, Some(_)) => only_b.push(name),
            (Some(a), Some(b)) => {
                if a.dtype == b.dtype && a.shape == b.shape {
                    matching.push(name);
                } else {
                    differ.push(name);
                }
            }
            // Unreachable by construction (names came from one of the maps),
            // but kept for exhaustiveness.
            (None, None) => {}
        }
    }
    // With a filter, totals count only filtered tensors; without, the full
    // per-repo tensor counts.
    let total_a = if filter.is_some() {
        only_a.len() + differ.len() + matching.len()
    } else {
        tensors_a.len()
    };
    let total_b = if filter.is_some() {
        only_b.len() + differ.len() + matching.len()
    } else {
        tensors_b.len()
    };
    if json {
        return print_diff_json(
            repo_a, repo_b, &tensors_a, &tensors_b, &only_a, &only_b, &differ, &matching, filter,
        );
    }
    println!(" A: {repo_a}");
    println!(" B: {repo_b}");
    if !summary {
        println!();
        if !only_a.is_empty() {
            let label = if only_a.len() == 1 {
                "tensor"
            } else {
                "tensors"
            };
            println!(" Only in A ({} {label}):", only_a.len());
            for name in &only_a {
                if let Some(t) = tensors_a.get(*name) {
                    let shape_str = format!("{:?}", t.shape);
                    println!(" {name:<50} {:<8} {shape_str}", t.dtype);
                }
            }
            println!();
        }
        if !only_b.is_empty() {
            let label = if only_b.len() == 1 {
                "tensor"
            } else {
                "tensors"
            };
            println!(" Only in B ({} {label}):", only_b.len());
            for name in &only_b {
                if let Some(t) = tensors_b.get(*name) {
                    let shape_str = format!("{:?}", t.shape);
                    println!(" {name:<50} {:<8} {shape_str}", t.dtype);
                }
            }
            println!();
        }
        if !differ.is_empty() {
            let label = if differ.len() == 1 {
                "tensor"
            } else {
                "tensors"
            };
            println!(" Dtype/shape differences ({} {label}):", differ.len());
            for name in &differ {
                if let Some((a, b)) = tensors_a.get(*name).zip(tensors_b.get(*name)) {
                    let shape_a = format!("{:?}", a.shape);
                    let shape_b = format!("{:?}", b.shape);
                    println!(" {name}");
                    println!(" A: {:<8} {shape_a}", a.dtype);
                    println!(" B: {:<8} {shape_b}", b.dtype);
                }
            }
            println!();
        }
        let match_label = if matching.len() == 1 {
            "tensor"
        } else {
            "tensors"
        };
        println!(" Matching: {} {match_label} identical", matching.len());
    }
    println!(" {}", "\u{2500}".repeat(70));
    // `print!` (no newline) so the optional filter note shares the line.
    print!(
        " A: {} tensors | B: {} tensors | only-A: {} | only-B: {} | differ: {} | match: {}",
        total_a,
        total_b,
        only_a.len(),
        only_b.len(),
        differ.len(),
        matching.len(),
    );
    if let Some(pattern) = filter {
        println!(" (filter: {pattern:?})");
    } else {
        println!();
    }
    Ok(())
}
/// One tensor in the JSON diff; sides absent from a repo are omitted.
#[derive(serde::Serialize)]
struct DiffTensorEntry {
    name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    a: Option<DiffTensorSide>,
    #[serde(skip_serializing_if = "Option::is_none")]
    b: Option<DiffTensorSide>,
}
/// Dtype/shape of a tensor as seen in one repo.
#[derive(serde::Serialize)]
struct DiffTensorSide {
    dtype: String,
    shape: Vec<usize>,
}
/// Top-level JSON payload for `diff --json`.
#[derive(serde::Serialize)]
struct DiffResult {
    repo_a: String,
    repo_b: String,
    only_a: Vec<DiffTensorEntry>,
    only_b: Vec<DiffTensorEntry>,
    differ: Vec<DiffTensorEntry>,
    // Matching tensors are only counted, not listed.
    matching_count: usize,
    #[serde(skip_serializing_if = "Option::is_none")]
    filter: Option<String>,
}
#[allow(clippy::too_many_arguments)]
fn print_diff_json(
repo_a: &str,
repo_b: &str,
tensors_a: &HashMap<String, inspect::TensorInfo>,
tensors_b: &HashMap<String, inspect::TensorInfo>,
only_a: &[&str],
only_b: &[&str],
differ: &[&str],
matching: &[&str],
filter: Option<&str>,
) -> Result<(), FetchError> {
let make_entry = |name: &str,
a: Option<&inspect::TensorInfo>,
b: Option<&inspect::TensorInfo>|
-> DiffTensorEntry {
DiffTensorEntry {
name: name.to_owned(),
a: a.map(|t| DiffTensorSide {
dtype: t.dtype.clone(),
shape: t.shape.clone(),
}),
b: b.map(|t| DiffTensorSide {
dtype: t.dtype.clone(),
shape: t.shape.clone(),
}),
}
};
let result = DiffResult {
repo_a: repo_a.to_owned(),
repo_b: repo_b.to_owned(),
only_a: only_a
.iter()
.map(|n| make_entry(n, tensors_a.get(*n), None))
.collect(),
only_b: only_b
.iter()
.map(|n| make_entry(n, None, tensors_b.get(*n)))
.collect(),
differ: differ
.iter()
.map(|n| make_entry(n, tensors_a.get(*n), tensors_b.get(*n)))
.collect(),
matching_count: matching.len(),
filter: filter.map(str::to_owned),
};
let output = serde_json::to_string_pretty(&result)
.map_err(|e| FetchError::Http(format!("failed to serialize JSON: {e}")))?;
println!("{output}");
Ok(())
}
/// Dispatch an `inspect` invocation: a single-file inspection when
/// `filename` is given, otherwise a whole-repo inspection.
#[allow(clippy::fn_params_excessive_bools, clippy::too_many_arguments)]
fn run_inspect(
    repo_id: &str,
    filename: Option<&str>,
    revision: Option<&str>,
    token: Option<&str>,
    cached: bool,
    no_metadata: bool,
    json: bool,
    filter: Option<&str>,
) -> Result<(), FetchError> {
    if let Some(file) = filename {
        run_inspect_single(
            repo_id,
            file,
            revision,
            token,
            cached,
            no_metadata,
            json,
            filter,
        )
    } else {
        run_inspect_repo(repo_id, revision, token, cached, json, filter)
    }
}
/// Inspect a single `.safetensors` file in a repo and print its tensor table
/// (or the full header as JSON when `json` is set).
///
/// When `cached` is set the header is read from the local HF cache; otherwise
/// a tokio runtime is created and the header is fetched remotely, using
/// `token` or the `HF_TOKEN` env var for auth. `filter` restricts the listed
/// tensors to names containing the substring; both filtered and unfiltered
/// totals are reported in the footer.
#[allow(clippy::fn_params_excessive_bools, clippy::too_many_arguments)]
fn run_inspect_single(
    repo_id: &str,
    filename: &str,
    revision: Option<&str>,
    token: Option<&str>,
    cached: bool,
    no_metadata: bool,
    json: bool,
    filter: Option<&str>,
) -> Result<(), FetchError> {
    let (mut info, source) = if cached {
        let info = inspect::inspect_safetensors_cached(repo_id, filename, revision)?;
        (info, inspect::InspectSource::Cached)
    } else {
        let token = token
            .map(String::from)
            .or_else(|| std::env::var("HF_TOKEN").ok());
        let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
            path: PathBuf::from("<runtime>"),
            source: e,
        })?;
        rt.block_on(inspect::inspect_safetensors(
            repo_id,
            filename,
            token.as_deref(),
            revision,
        ))?
    };
    // Capture unfiltered totals before `retain` drops non-matching tensors.
    let total_tensor_count = info.tensors.len();
    let total_params = info.total_params();
    if let Some(pattern) = filter {
        info.tensors.retain(|t| t.name.as_str().contains(pattern));
    }
    if json {
        let output = serde_json::to_string_pretty(&info)
            .map_err(|e| FetchError::Http(format!("failed to serialize JSON: {e}")))?;
        println!("{output}");
        return Ok(());
    }
    let source_label = match source {
        inspect::InspectSource::Cached => "cached",
        inspect::InspectSource::Remote => "remote (2 HTTP requests)",
    };
    println!(" Repo: {repo_id}");
    // FIX: previously printed a literal "(unknown)" instead of the file
    // actually being inspected.
    println!(" File: {filename}");
    println!(" Source: {source_label}");
    let header_display = format_size(info.header_size);
    if let Some(fs) = info.file_size {
        println!(
            " Header: {header_display} (JSON), {} total",
            format_size(fs)
        );
    } else {
        println!(" Header: {header_display} (JSON)");
    }
    if !no_metadata {
        if let Some(ref meta) = info.metadata {
            let entries: Vec<String> = meta.iter().map(|(k, v)| format!("{k}={v}")).collect();
            println!(" Metadata: {}", entries.join(", "));
        }
    }
    println!();
    println!(
        " {:<50} {:<8} {:<16} {:>10} {:>10}",
        "Tensor", "Dtype", "Shape", "Size", "Params"
    );
    for t in &info.tensors {
        let shape_str = format!("{:?}", t.shape);
        let size_str = format_size(t.byte_len());
        let params_str = inspect::format_params(t.num_elements());
        println!(
            " {:<50} {:<8} {:<16} {:>10} {:>10}",
            t.name, t.dtype, shape_str, size_str, params_str
        );
    }
    println!(" {}", "\u{2500}".repeat(96));
    // Footer totals reflect the post-filter tensor list.
    let filtered_count = info.tensors.len();
    let filtered_params = info.total_params();
    let tensor_label = if filtered_count == 1 {
        "tensor"
    } else {
        "tensors"
    };
    if filter.is_some() {
        println!(
            " {filtered_count}/{total_tensor_count} {tensor_label}, {}/{} params (filter: {:?})",
            inspect::format_params(filtered_params),
            inspect::format_params(total_params),
            filter.unwrap_or_default(),
        );
    } else {
        println!(
            " {filtered_count} {tensor_label}, {} params",
            inspect::format_params(filtered_params)
        );
    }
    Ok(())
}
/// Inspect all `.safetensors` content of a repo: prefer the shard index
/// (model.safetensors.index.json) when present, otherwise per-file headers.
/// `cached` reads only from the local HF cache; otherwise a tokio runtime is
/// created and headers are fetched remotely. After the main report, any
/// adapter (PEFT) config is appended best-effort.
#[allow(clippy::too_many_arguments)]
fn run_inspect_repo(
    repo_id: &str,
    revision: Option<&str>,
    token: Option<&str>,
    cached: bool,
    json: bool,
    filter: Option<&str>,
) -> Result<(), FetchError> {
    if cached {
        // Cached path: everything comes from the local HF cache.
        if let Some(index) = inspect::fetch_shard_index_cached(repo_id, revision)? {
            print_shard_index_summary(repo_id, &index, filter);
            print_adapter_config_if_present(repo_id, revision, None, true, json);
            return Ok(());
        }
        let results = inspect::inspect_repo_safetensors_cached(repo_id, revision)?;
        if results.is_empty() {
            println!("No cached .safetensors files found for {repo_id}.");
            println!("Hint: use `hf-fm list-files {repo_id}` to see available file types");
            return Ok(());
        }
        if json {
            print_multi_file_json(&results, filter)?;
            print_adapter_config_if_present(repo_id, revision, None, true, true);
        } else {
            print_multi_file_summary(repo_id, "cached", &results, filter);
            print_adapter_config_if_present(repo_id, revision, None, true, false);
        }
        return Ok(());
    }
    // Remote path: resolve auth, then fetch via a blocking tokio runtime.
    let resolved_token = token
        .map(String::from)
        .or_else(|| std::env::var("HF_TOKEN").ok());
    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
        path: PathBuf::from("<runtime>"),
        source: e,
    })?;
    if let Some(index) = rt.block_on(inspect::fetch_shard_index(
        repo_id,
        resolved_token.as_deref(),
        revision,
    ))? {
        print_shard_index_summary(repo_id, &index, filter);
        print_adapter_config_if_present(repo_id, revision, resolved_token.as_deref(), false, json);
        return Ok(());
    }
    let results = rt.block_on(inspect::inspect_repo_safetensors(
        repo_id,
        resolved_token.as_deref(),
        revision,
    ))?;
    if results.is_empty() {
        println!("No .safetensors files found in {repo_id}.");
        println!("Hint: use `hf-fm list-files {repo_id}` to see available file types");
        return Ok(());
    }
    // Drop the per-file source marker; the printers only need (name, header).
    let mapped: Vec<(String, inspect::SafetensorsHeaderInfo)> = results
        .into_iter()
        .map(|(name, info, _source)| (name, info))
        .collect();
    if json {
        print_multi_file_json(&mapped, filter)?;
        print_adapter_config_if_present(repo_id, revision, resolved_token.as_deref(), false, true);
    } else {
        print_multi_file_summary(repo_id, "mixed", &mapped, filter);
        print_adapter_config_if_present(repo_id, revision, resolved_token.as_deref(), false, false);
    }
    Ok(())
}
/// Best-effort: print the repo's adapter (PEFT) config if one exists.
/// Any failure — no runtime, fetch error, or no adapter config found —
/// silently prints nothing.
fn print_adapter_config_if_present(
    repo_id: &str,
    revision: Option<&str>,
    token: Option<&str>,
    cached: bool,
    json: bool,
) {
    let fetched = if cached {
        inspect::fetch_adapter_config_cached(repo_id, revision)
    } else {
        match tokio::runtime::Runtime::new() {
            Ok(rt) => rt.block_on(inspect::fetch_adapter_config(repo_id, token, revision)),
            Err(_) => return,
        }
    };
    let config = match fetched {
        Ok(Some(c)) => c,
        _ => return,
    };
    if json {
        if let Ok(text) = serde_json::to_string_pretty(&config) {
            println!("{text}");
        }
        return;
    }
    println!();
    println!(" Adapter config:");
    // Each field is optional in adapter_config.json; print only what is set.
    if let Some(peft) = config.peft_type.as_ref() {
        println!(" PEFT type: {peft}");
    }
    if let Some(base_model) = config.base_model_name_or_path.as_ref() {
        println!(" Base model: {base_model}");
    }
    if let Some(rank) = config.r {
        println!(" Rank (r): {rank}");
    }
    if let Some(alpha) = config.lora_alpha {
        println!(" LoRA alpha: {alpha}");
    }
    if let Some(task) = config.task_type.as_ref() {
        println!(" Task type: {task}");
    }
    if !config.target_modules.is_empty() {
        println!(" Target modules: {}", config.target_modules.join(", "));
    }
}
/// Print a per-shard tensor-count table from a sharded-model index file,
/// optionally restricted to tensor names containing `filter`.
fn print_shard_index_summary(repo_id: &str, index: &inspect::ShardedIndex, filter: Option<&str>) {
    println!(" Repo: {repo_id}");
    println!(" Source: shard index (model.safetensors.index.json)");
    println!();
    let total_tensors = index.weight_map.len();
    let name_matches = |name: &str| filter.map_or(true, |p| name.contains(p));
    // Count matching tensors per shard file.
    let mut per_shard: HashMap<String, usize> = HashMap::new();
    let mut kept: usize = 0;
    for (tensor_name, shard_name) in &index.weight_map {
        if name_matches(tensor_name) {
            *per_shard.entry(shard_name.clone()).or_default() += 1;
            kept += 1;
        }
    }
    println!(" {:<48} {:>8}", "File", "Tensors");
    for shard in &index.shards {
        let count = per_shard.get(shard).copied().unwrap_or(0);
        // With a filter active, hide shards that contributed nothing.
        if count == 0 && filter.is_some() {
            continue;
        }
        println!(" {shard:<48} {count:>8}");
    }
    println!(" {}", "\u{2500}".repeat(58));
    let pluralize =
        |n: usize, one: &'static str, many: &'static str| if n == 1 { one } else { many };
    let displayed_shards = if filter.is_some() {
        per_shard.len()
    } else {
        index.shards.len()
    };
    let shard_label = pluralize(displayed_shards, "shard", "shards");
    let tensor_label = pluralize(kept, "tensor", "tensors");
    if let Some(pattern) = filter {
        println!(
            " {displayed_shards} {shard_label}, {kept}/{total_tensors} {tensor_label} (filter: {pattern:?})"
        );
    } else {
        println!(" {displayed_shards} {shard_label}, {kept} {tensor_label}");
    }
    println!(" Hint: use `hf-fm inspect {repo_id} <filename>` for per-tensor detail");
}
/// Serialize a list of (filename, safetensors header) pairs as pretty JSON.
/// With a `filter`, each header's tensor list is pruned to matching names
/// and files left with no matching tensors are dropped entirely.
fn print_multi_file_json(
    results: &[(String, inspect::SafetensorsHeaderInfo)],
    filter: Option<&str>,
) -> Result<(), FetchError> {
    let serialized = match filter {
        Some(pattern) => {
            let kept: Vec<(String, inspect::SafetensorsHeaderInfo)> = results
                .iter()
                .filter_map(|(name, info)| {
                    let mut pruned = info.clone();
                    pruned.tensors.retain(|t| t.name.as_str().contains(pattern));
                    (!pruned.tensors.is_empty()).then(|| (name.clone(), pruned))
                })
                .collect();
            serde_json::to_string_pretty(&kept)
        }
        None => serde_json::to_string_pretty(results),
    };
    let output =
        serialized.map_err(|e| FetchError::Http(format!("failed to serialize JSON: {e}")))?;
    println!("{output}");
    Ok(())
}
/// Print a per-file tensor-count/params table for multiple safetensors
/// headers. With a `filter`, files with no matching tensors are hidden and
/// the footer shows both filtered and unfiltered totals.
fn print_multi_file_summary(
    repo_id: &str,
    source: &str,
    results: &[(String, inspect::SafetensorsHeaderInfo)],
    filter: Option<&str>,
) {
    println!(" Repo: {repo_id}");
    println!(" Source: {source}");
    println!();
    println!(" {:<48} {:>8} {:>12}", "File", "Tensors", "Params");
    // Unfiltered (all_*) vs displayed (shown_*) running totals.
    let mut all_tensors: usize = 0;
    let mut all_params: u64 = 0;
    let mut shown_tensors: usize = 0;
    let mut shown_params: u64 = 0;
    let mut shown_files: usize = 0;
    for (name, info) in results {
        all_tensors = all_tensors.saturating_add(info.tensors.len());
        all_params = all_params.saturating_add(info.total_params());
        let (count, params) = match filter {
            Some(pattern) => info
                .tensors
                .iter()
                .filter(|t| t.name.as_str().contains(pattern))
                .fold((0usize, 0u64), |(c, p), t| (c + 1, p + t.num_elements())),
            None => (info.tensors.len(), info.total_params()),
        };
        if count == 0 && filter.is_some() {
            continue;
        }
        shown_files += 1;
        shown_tensors = shown_tensors.saturating_add(count);
        shown_params = shown_params.saturating_add(params);
        println!(
            " {name:<48} {count:>8} {:>12}",
            inspect::format_params(params)
        );
    }
    println!(" {}", "\u{2500}".repeat(70));
    let file_label = if shown_files == 1 { "file" } else { "files" };
    let tensor_label = if shown_tensors == 1 {
        "tensor"
    } else {
        "tensors"
    };
    if let Some(pattern) = filter {
        println!(
            " {shown_files} {file_label}, {shown_tensors}/{all_tensors} {tensor_label}, {}/{} params (filter: {pattern:?})",
            inspect::format_params(shown_params),
            inspect::format_params(all_params),
        );
    } else {
        println!(
            " {shown_files} {file_label}, {shown_tensors} {tensor_label}, {} params",
            inspect::format_params(shown_params)
        );
    }
}
/// Report per-file cache status (complete / partial / missing) for a repo
/// against its remote file list, followed by summary counts.
fn run_status(
    repo_id: &str,
    revision: Option<&str>,
    token: Option<&str>,
) -> Result<(), FetchError> {
    let resolved_token = token
        .map(String::from)
        .or_else(|| std::env::var("HF_TOKEN").ok());
    let runtime = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
        path: PathBuf::from("<runtime>"),
        source: e,
    })?;
    let status =
        runtime.block_on(cache::repo_status(repo_id, resolved_token.as_deref(), revision))?;
    let shown_rev = revision.unwrap_or("main");
    if let Some(hash) = &status.commit_hash {
        println!("{repo_id} ({shown_rev} @ {hash})");
    } else {
        println!("{repo_id} ({shown_rev}, not yet cached)");
    }
    println!("Cache: {}\n", status.cache_path.display());
    if status.files.is_empty() {
        println!(" (no files found in remote repository)");
        return Ok(());
    }
    for (filename, state) in &status.files {
        match state {
            cache::FileStatus::Complete { local_size } => println!(
                " {:<48} {:>10} complete",
                filename,
                format_size(*local_size)
            ),
            cache::FileStatus::Partial {
                local_size,
                expected_size,
            } => println!(
                " {:<48} {:>10} / {:<10} PARTIAL",
                filename,
                format_size(*local_size),
                format_size(*expected_size)
            ),
            cache::FileStatus::Missing { expected_size } if *expected_size > 0 => println!(
                " {:<48} {:>10} MISSING",
                filename,
                format_size(*expected_size)
            ),
            // Missing with no known size: print a dash instead of "0 B".
            cache::FileStatus::Missing { .. } => println!(" {filename:<48} — MISSING"),
            _ => println!(" {filename:<48} UNKNOWN"),
        }
    }
    println!();
    println!(
        "{}/{} complete, {} partial, {} missing",
        status.complete_count(),
        status.files.len(),
        status.partial_count(),
        status.missing_count()
    );
    Ok(())
}
/// Print a table of a remote repo's files with size, an optional SHA256
/// prefix, and (optionally) whether each file is already in the local cache.
///
/// Include globs are seeded from `preset` (each preset also pulls config /
/// tokenizer text files) and extended with `filter_patterns`;
/// `exclude_patterns` always apply. `no_checksum` drops the SHA256 column;
/// `show_cached` adds a cached-status column resolved against the HF cache
/// snapshot for `revision` ("main" by default). Auth falls back to the
/// `HF_TOKEN` env var when `token` is None.
#[allow(clippy::too_many_arguments, clippy::fn_params_excessive_bools)]
fn run_list_files(
    repo_id: &str,
    revision: Option<&str>,
    token: Option<&str>,
    filter_patterns: &[String],
    exclude_patterns: &[String],
    preset: Option<&Preset>,
    no_checksum: bool,
    show_cached: bool,
) -> Result<(), FetchError> {
    // Reject obviously malformed repo ids before doing any network work.
    if !repo_id.contains('/') {
        return Err(FetchError::InvalidArgument(format!(
            "invalid REPO_ID \"{repo_id}\": expected \"org/model\" format \
            (e.g., \"google/gemma-2-2b-it\")"
        )));
    }
    // Map the preset to its glob set; explicit --filter patterns are appended.
    let mut include_patterns: Vec<String> = match preset {
        Some(&Preset::Safetensors) => vec![
            "*.safetensors".to_owned(),
            "*.json".to_owned(),
            "*.txt".to_owned(),
        ],
        Some(&Preset::Gguf) => vec!["*.gguf".to_owned(), "*.json".to_owned(), "*.txt".to_owned()],
        Some(&Preset::Pth) => vec![
            "pytorch_model*.bin".to_owned(),
            "*.json".to_owned(),
            "*.txt".to_owned(),
        ],
        Some(&Preset::ConfigOnly) => {
            vec!["*.json".to_owned(), "*.txt".to_owned(), "*.md".to_owned()]
        }
        None => Vec::new(),
    };
    for p in filter_patterns {
        include_patterns.push(p.clone());
    }
    let include = compile_glob_patterns(&include_patterns)?;
    let exclude = compile_glob_patterns(exclude_patterns)?;
    let resolved_token = token
        .map(ToOwned::to_owned)
        .or_else(|| std::env::var("HF_TOKEN").ok());
    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
        path: PathBuf::from("<runtime>"),
        source: e,
    })?;
    let files = rt.block_on(repo::list_repo_files_with_metadata(
        repo_id,
        resolved_token.as_deref(),
        revision,
    ))?;
    let filtered: Vec<_> = files
        .into_iter()
        .filter(|f| {
            file_matches(f.filename.as_str(), include.as_ref(), exclude.as_ref())
        })
        .collect();
    // cache_marks is positionally aligned with `filtered`: cache_marks[i]
    // describes filtered[i]. "\u{2713}" (check) = fully cached, "partial" =
    // on-disk size smaller than expected, "\u{2717}" (cross) = absent.
    let cache_marks: Vec<String> = if show_cached {
        let cache_dir = cache::hf_cache_dir()?;
        // HF cache layout: <cache>/models--org--name/snapshots/<commit>/...
        let repo_folder = format!("models--{}", repo_id.replace('/', "--"));
        let repo_dir = cache_dir.join(&repo_folder);
        let revision_str = revision.unwrap_or("main");
        let commit_hash = cache::read_ref(&repo_dir, revision_str);
        let snapshot_dir = commit_hash.map(|h| repo_dir.join("snapshots").join(h));
        filtered
            .iter()
            .map(|f| {
                let local_path = snapshot_dir
                    .as_ref()
                    .map(|dir| dir.join(f.filename.as_str()));
                match local_path {
                    Some(ref path) if path.exists() => {
                        let local_size = std::fs::metadata(path).map(|m| m.len()).unwrap_or(0);
                        let expected = f.size.unwrap_or(0);
                        if expected > 0 && local_size < expected {
                            "partial".to_owned()
                        } else {
                            "\u{2713}".to_owned()
                        }
                    }
                    _ => "\u{2717}".to_owned(),
                }
            })
            .collect()
    } else {
        Vec::new()
    };
    // Header row layout depends on which columns are enabled.
    if no_checksum {
        if show_cached {
            println!(" {:<48} {:>10} Cached", "File", "Size");
            println!(" {:<48} {:>10} {:-<6}", "", "", "");
        } else {
            println!(" {:<48} {:>10}", "File", "Size");
            println!(" {:<48} {:>10}", "", "");
        }
    } else if show_cached {
        println!(" {:<48} {:>10} {:<12} Cached", "File", "Size", "SHA256");
        println!(" {:<48} {:>10} {:<12} {:-<6}", "", "", "", "");
    } else {
        println!(" {:<48} {:>10} {:<12}", "File", "Size", "SHA256");
        println!(" {:<48} {:>10} {:<12}", "", "", "");
    }
    let mut total_bytes: u64 = 0;
    let mut cached_count: usize = 0;
    for (i, f) in filtered.iter().enumerate() {
        let size = f.size.unwrap_or(0);
        total_bytes = total_bytes.saturating_add(size);
        let size_str = format_size(size);
        // First 12 hex chars of the SHA256, or an em dash when unavailable.
        let sha_str = if no_checksum {
            String::new()
        } else {
            f.sha256
                .as_deref()
                .and_then(|s| s.get(..12))
                .unwrap_or("\u{2014}")
                .to_owned()
        };
        if show_cached {
            let mark = cache_marks.get(i).map_or("\u{2717}", String::as_str);
            if mark == "\u{2713}" {
                cached_count += 1;
            }
            if no_checksum {
                println!(" {:<48} {:>10} {mark}", f.filename, size_str);
            } else {
                println!(
                    " {:<48} {:>10} {:<12} {mark}",
                    f.filename, size_str, sha_str
                );
            }
        } else if no_checksum {
            println!(" {:<48} {:>10}", f.filename, size_str);
        } else {
            println!(" {:<48} {:>10} {sha_str}", f.filename, size_str);
        }
    }
    let count = filtered.len();
    println!(" {:\u{2500}<72}", "");
    if show_cached {
        println!(
            " {count} files, {} total ({cached_count} cached)",
            format_size(total_bytes)
        );
    } else {
        println!(" {count} files, {} total", format_size(total_bytes));
    }
    Ok(())
}
/// Render a byte count as a human-readable binary size: two decimals for
/// GiB/MiB, one for KiB, and a plain integer for anything under 1 KiB.
fn format_size(bytes: u64) -> String {
    const KIB: u64 = 1 << 10;
    const MIB: u64 = 1 << 20;
    const GIB: u64 = 1 << 30;
    // f64 has 53 mantissa bits, plenty for display purposes here.
    #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
    fn ratio(n: u64, unit: u64) -> f64 {
        n as f64 / unit as f64
    }
    if bytes >= GIB {
        format!("{:.2} GiB", ratio(bytes, GIB))
    } else if bytes >= MIB {
        format!("{:.2} MiB", ratio(bytes, MIB))
    } else if bytes >= KIB {
        format!("{:.1} KiB", ratio(bytes, KIB))
    } else {
        format!("{bytes} B")
    }
}
/// Resolve the directory flat-mode files are copied into: the explicit
/// `--output-dir` when given, otherwise the current working directory.
fn resolve_flat_target(output_dir: Option<&Path>) -> Result<PathBuf, FetchError> {
    if let Some(dir) = output_dir {
        return Ok(dir.to_path_buf());
    }
    std::env::current_dir().map_err(|e| FetchError::Io {
        path: PathBuf::from("."),
        source: e,
    })
}
/// Copy every cached file into `target_dir` using only its basename,
/// returning the flat destination paths. The directory is created if needed.
fn flatten_files(
    file_map: &HashMap<String, PathBuf>,
    target_dir: &Path,
) -> Result<Vec<PathBuf>, FetchError> {
    std::fs::create_dir_all(target_dir).map_err(|e| FetchError::Io {
        path: target_dir.to_path_buf(),
        source: e,
    })?;
    file_map
        .iter()
        .map(|(filename, cache_path)| {
            // Strip any repo-relative directory components; fall back to the
            // full name when there is no final path component.
            let leaf = Path::new(filename)
                .file_name()
                .unwrap_or_else(|| std::ffi::OsStr::new(filename.as_str()));
            let destination = target_dir.join(leaf);
            std::fs::copy(cache_path, &destination).map_err(|e| FetchError::Io {
                path: destination.clone(),
                source: e,
            })?;
            Ok(destination)
        })
        .collect()
}
/// Copy one cached file into `target_dir` under its basename ("file" if the
/// path has no final component), returning the flat destination path.
fn flatten_single_file(cache_path: &Path, target_dir: &Path) -> Result<PathBuf, FetchError> {
    std::fs::create_dir_all(target_dir).map_err(|e| FetchError::Io {
        path: target_dir.to_path_buf(),
        source: e,
    })?;
    let leaf = cache_path
        .file_name()
        .unwrap_or_else(|| std::ffi::OsStr::new("file"));
    let destination = target_dir.join(leaf);
    match std::fs::copy(cache_path, &destination) {
        Ok(_) => Ok(destination),
        Err(e) => Err(FetchError::Io {
            path: destination,
            source: e,
        }),
    }
}
/// Warn (on stderr) about each `--filter` pattern that the chosen `--preset`
/// already includes, so users know the flag is a no-op.
fn warn_redundant_filters(preset: &Preset, filters: &[String]) {
    let (globs, name): (&[&str], &str) = match preset {
        Preset::Safetensors => (&["*.safetensors", "*.json", "*.txt"], "safetensors"),
        Preset::Gguf => (&["*.gguf", "*.json", "*.txt"], "gguf"),
        Preset::Pth => (&["pytorch_model*.bin", "*.json", "*.txt"], "pth"),
        Preset::ConfigOnly => (&["*.json", "*.txt", "*.md"], "config-only"),
    };
    for redundant in filters.iter().filter(|f| globs.contains(&f.as_str())) {
        eprintln!("warning: --filter \"{redundant}\" is redundant with --preset {name}");
    }
}
/// Recursively sum the byte sizes of all files under `dir`.
/// Unreadable directories count as 0, and entries whose metadata cannot be
/// read are skipped; additions saturate instead of overflowing.
fn walk_dir_size(dir: &Path) -> u64 {
    std::fs::read_dir(dir).map_or(0, |entries| {
        entries.flatten().fold(0u64, |total, entry| {
            let size = match entry.metadata() {
                Ok(meta) if meta.is_dir() => walk_dir_size(&entry.path()),
                Ok(meta) => meta.len(),
                Err(_) => 0,
            };
            total.saturating_add(size)
        })
    })
}
/// Print a one-line size/duration/throughput summary for a finished download.
/// Nothing is printed when the byte count or elapsed time is zero.
fn print_download_summary(path: &Path, elapsed: Duration) {
    let byte_count = if path.is_dir() {
        walk_dir_size(path)
    } else {
        std::fs::metadata(path).map_or(0, |m| m.len())
    };
    let seconds = elapsed.as_secs_f64();
    if byte_count == 0 || seconds <= 0.0 {
        return;
    }
    #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
    let mib_per_sec = byte_count as f64 / seconds / (1024.0 * 1024.0);
    println!(
        " {} in {seconds:.1}s ({mib_per_sec:.1} MiB/s)",
        format_size(byte_count)
    );
}
/// Format a download count with comma thousands separators
/// (e.g. 1234567 -> "1,234,567").
fn format_downloads(n: u64) -> String {
    let digits = n.to_string();
    // Group digits in threes from the right, then rejoin left-to-right.
    let groups: Vec<String> = digits
        .as_bytes()
        .rchunks(3)
        .rev()
        .map(|chunk| String::from_utf8_lossy(chunk).into_owned())
        .collect();
    groups.join(",")
}