use crate::error::{CliError, Result};
use colored::Colorize;
use pacha::fetcher::{FetchConfig, ModelFetcher};
use pacha::format::ModelFormat;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::io::{self, Read, Write};
use std::path::Path;
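/// Result of resolving a model reference: either a single downloadable file or a
/// sharded repository with multiple shard files.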
#[derive(Debug)]
enum ResolvedModel {
SingleFile(String),
Sharded {
org: String,
repo: String,
shard_files: Vec<String>,
},
}
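/// Integrity manifest written next to sharded downloads (`.apr-manifest.json`),
/// recording the size and BLAKE3 hash of each downloaded shard.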
#[derive(Debug, Serialize, Deserialize)]
pub struct ShardManifest {
pub version: u32,
pub repo: String,
pub files: HashMap<String, FileChecksum>,
}
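/// Size in bytes and BLAKE3 hash of a single downloaded file.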
#[derive(Debug, Serialize, Deserialize)]
pub struct FileChecksum {
pub size: u64,
pub blake3: String,
}
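/// Entry point for `apr pull`: resolves the model reference and dispatches to the
/// single-file or sharded download path, honoring the `force`, `dry_run`,
/// `revision`, and `offline` options.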
#[provable_contracts_macros::contract(
"apr-cli-operations-v1",
equation = "mutating_output_contract"
)]
pub fn run(
model_ref: &str,
force: bool,
dry_run: bool,
revision: Option<&str>,
offline: bool,
) -> Result<()> {
contract_pre_pull_cache_integrity!();
println!("{}", "=== APR Pull ===".cyan().bold());
println!();
if dry_run {
return run_dry_run(model_ref, revision, offline);
}
let resolved = resolve_hf_model(model_ref)?;
let result = match resolved {
ResolvedModel::SingleFile(ref uri) => run_single_file(uri, force),
ResolvedModel::Sharded {
ref org,
ref repo,
ref shard_files,
} => run_sharded(org, repo, shard_files, force),
};
if let Ok(ref r) = result {
contract_post_pull_cache_integrity!(r);
}
result
}
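/// Pulls a single-file model. `hf://` references take the direct streaming path;
/// everything else goes through the pacha `ModelFetcher` cache.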
fn run_single_file(model_ref: &str, force: bool) -> Result<()> {
println!("Model: {}", model_ref.cyan());
if model_ref.starts_with("hf://") {
return run_single_file_streaming(model_ref, force);
}
let mut fetcher = ModelFetcher::with_config(FetchConfig::default()).map_err(|e| {
CliError::ValidationFailed(format!("Failed to initialize model fetcher: {e}"))
})?;
if !force && fetcher.is_cached(model_ref) {
return handle_cached_model(&mut fetcher, model_ref);
}
let result = download_single_model(&mut fetcher, model_ref)?;
ensure_safetensors_companions(&result)?;
print_pull_usage(&result.path, true);
Ok(())
}
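/// Streams a single file from huggingface.co into the pacha cache, keyed by a
/// BLAKE3 hash of the model reference.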
fn run_single_file_streaming(model_ref: &str, force: bool) -> Result<()> {
let (org, repo, filename) = parse_hf_single_uri(model_ref)?;
let url = format!("https://huggingface.co/{org}/{repo}/resolve/main/{filename}");
let cache_dir = get_pacha_cache_dir()?;
std::fs::create_dir_all(&cache_dir)?;
let (extension, cache_path) = build_single_cache_path(&cache_dir, model_ref, &filename);
if !force && cache_path.exists() {
return report_cached_single(&cache_path);
}
stream_and_post_process(&url, &cache_path, model_ref, &extension)?;
print_pull_usage(&cache_path, true);
Ok(())
}
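/// Downloads `url` to `cache_path`; for safetensors files, also fetches companion
/// files and runs format conversion.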
fn stream_and_post_process(
url: &str,
cache_path: &std::path::Path,
model_ref: &str,
extension: &str,
) -> Result<()> {
println!();
println!("{}", "Downloading (streaming)...".yellow());
let checksum = download_file_with_progress(url, cache_path)?;
report_downloaded_single(cache_path, &checksum);
if extension == "safetensors" {
fetch_safetensors_companions(cache_path, model_ref)?;
convert_safetensors_formats(cache_path)?;
}
Ok(())
}
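/// Splits an `hf://org/repo/path/to/file` reference into (org, repo, filename).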
fn parse_hf_single_uri(model_ref: &str) -> Result<(String, String, String)> {
let path = model_ref.strip_prefix("hf://").unwrap_or(model_ref);
let parts: Vec<&str> = path.split('/').collect();
if parts.len() < 3 {
return Err(CliError::ValidationFailed(format!(
"HuggingFace URI must include a filename: {model_ref}"
)));
}
Ok((
parts[0].to_string(),
parts[1].to_string(),
parts[2..].join("/"),
))
}
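/// Derives the cache file name from the first 16 hex characters of the BLAKE3 hash
/// of the model reference, preserving the original file extension (default `bin`).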
fn build_single_cache_path(
cache_dir: &std::path::Path,
model_ref: &str,
filename: &str,
) -> (String, std::path::PathBuf) {
let uri_hash = blake3::hash(model_ref.as_bytes()).to_hex().to_string();
let extension = std::path::Path::new(filename)
.extension()
.and_then(|e| e.to_str())
.unwrap_or("bin")
.to_string();
let cache_filename = format!("{}.{extension}", &uri_hash[..16]);
let cache_path = cache_dir.join(&cache_filename);
(extension, cache_path)
}
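/// Prints the path and size of an already-cached single-file model.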
fn report_cached_single(cache_path: &std::path::Path) -> Result<()> {
let metadata = std::fs::metadata(cache_path)?;
println!("{} Model already cached", "✓".green());
println!(" Path: {}", cache_path.display());
println!(" Size: {}", format_bytes(metadata.len()));
print_pull_usage(cache_path, true);
Ok(())
}
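/// Prints the path, size, and truncated hash of a freshly downloaded file.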
fn report_downloaded_single(cache_path: &std::path::Path, checksum: &FileChecksum) {
println!();
println!("{} Downloaded successfully", "✓".green());
println!(" Path: {}", cache_path.display().to_string().green());
println!(" Size: {}", format_bytes(checksum.size).yellow());
println!(" Hash: {}", &checksum.blake3[..16]);
}
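/// Returns the pacha model cache directory, honoring `XDG_CACHE_HOME` and falling
/// back to `~/.cache/pacha/models`.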
fn get_pacha_cache_dir() -> Result<std::path::PathBuf> {
if let Ok(cache_home) = std::env::var("XDG_CACHE_HOME") {
return Ok(std::path::PathBuf::from(cache_home)
.join("pacha")
.join("models"));
}
Ok(dirs::home_dir()
.ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
.join(".cache")
.join("pacha")
.join("models"))
}
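/// Prints details for an already-cached model after resolving it via the fetcher's
/// quiet pull, then ensures safetensors companion files are present.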
fn handle_cached_model(fetcher: &mut ModelFetcher, model_ref: &str) -> Result<()> {
println!("{} Model already cached", "✓".green());
let result = fetcher
.pull_quiet(model_ref)
.map_err(|e| CliError::ValidationFailed(format!("Failed to get cached model: {e}")))?;
println!(" Path: {}", result.path.display());
println!(" Size: {}", result.size_human());
println!(" Format: {}", result.format.name());
ensure_safetensors_companions(&result)?;
print_pull_usage(&result.path, false);
Ok(())
}
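/// Downloads a model through the pacha fetcher, rendering an inline progress bar
/// and printing a result summary.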
fn download_single_model(
fetcher: &mut ModelFetcher,
model_ref: &str,
) -> Result<pacha::fetcher::FetchResult> {
println!();
println!("{}", "Downloading...".yellow());
let result = fetcher
.pull(model_ref, |progress| {
let pct = progress.percent();
print!(
"\r [{:50}] {:5.1}% ({}/{})",
"=".repeat((pct / 2.0) as usize),
pct,
format_bytes(progress.downloaded_bytes),
format_bytes(progress.total_bytes)
);
io::stdout().flush().ok();
})
.map_err(|e| CliError::NetworkError(format!("Download failed: {e}")))?;
println!();
println!();
if result.cache_hit {
println!("{} Model retrieved from cache", "✓".green());
} else {
println!("{} Downloaded successfully", "✓".green());
}
println!(" Path: {}", result.path.display().to_string().green());
println!(" Size: {}", result.size_human().yellow());
println!(" Format: {}", result.format.name());
println!(" Hash: {}", &result.hash[..16]);
Ok(result)
}
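/// For safetensors results, fetches tokenizer/config companion files and runs
/// format conversion on the downloaded model.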
fn ensure_safetensors_companions(result: &pacha::fetcher::FetchResult) -> Result<()> {
if matches!(result.format, ModelFormat::SafeTensors(_)) {
fetch_safetensors_companions(&result.path, &result.resolved_uri)?;
convert_safetensors_formats(&result.path)?;
}
Ok(())
}
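/// Prints follow-up `apr run` (and optionally `apr serve`) commands for the model.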
fn print_pull_usage(path: &Path, show_serve: bool) {
println!();
println!("{}", "Usage:".cyan().bold());
println!(" apr run {}", path.display());
if show_serve {
println!(" apr serve {}", path.display());
}
}
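/// Pulls a sharded safetensors model: the index file, every shard, and the
/// companion files, then writes an integrity manifest.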
fn run_sharded(org: &str, repo: &str, shard_files: &[String], force: bool) -> Result<()> {
println!(
"Model: {}/{} ({} shards)",
org.cyan(),
repo.cyan(),
shard_files.len().to_string().yellow()
);
let cache_dir = resolve_shard_cache_dir(org, repo)?;
std::fs::create_dir_all(&cache_dir)?;
let base_url = format!("https://huggingface.co/{org}/{repo}/resolve/main");
let index_path = cache_dir.join("model.safetensors.index.json");
download_index_if_needed(&base_url, &index_path, force)?;
let manifest_path = cache_dir.join(".apr-manifest.json");
let existing_manifest = load_existing_manifest(&manifest_path, force);
let file_checksums = download_all_shards(
&cache_dir,
&base_url,
shard_files,
force,
existing_manifest.as_ref(),
)?;
download_companion_files(&cache_dir, &base_url, force)?;
write_shard_manifest(&manifest_path, org, repo, file_checksums)?;
println!();
println!("{} Downloaded successfully", "✓".green());
println!(" Path: {}", index_path.display().to_string().green());
println!(" Shards: {}", shard_files.len().to_string().yellow());
convert_safetensors_formats(&index_path)?;
println!();
println!("{}", "Usage:".cyan().bold());
println!(" apr run {}", index_path.display());
println!(" apr serve {}", index_path.display());
Ok(())
}
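/// Cache directory for sharded models: `~/.apr/cache/hf/<org>/<repo>`.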
fn resolve_shard_cache_dir(org: &str, repo: &str) -> Result<std::path::PathBuf> {
Ok(dirs::home_dir()
.ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
.join(".apr")
.join("cache")
.join("hf")
.join(org)
.join(repo))
}
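/// Downloads `model.safetensors.index.json` unless a cached copy exists and
/// `force` is not set.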
fn download_index_if_needed(base_url: &str, index_path: &Path, force: bool) -> Result<()> {
if force || !index_path.exists() {
println!();
println!(" {} model.safetensors.index.json", "Downloading".yellow());
download_file(
&format!("{base_url}/model.safetensors.index.json"),
index_path,
)?;
} else {
println!(" {} model.safetensors.index.json (cached)", "✓".green());
}
Ok(())
}
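/// Loads a previously written shard manifest; returns `None` when `force` is set
/// or the manifest is missing or unreadable.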
fn load_existing_manifest(manifest_path: &Path, force: bool) -> Option<ShardManifest> {
if force || !manifest_path.exists() {
return None;
}
std::fs::read_to_string(manifest_path)
.ok()
.and_then(|s| serde_json::from_str(&s).ok())
}
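/// Downloads or verifies every shard and returns the checksum map for the new
/// manifest.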
fn download_all_shards(
cache_dir: &Path,
base_url: &str,
shard_files: &[String],
force: bool,
existing_manifest: Option<&ShardManifest>,
) -> Result<HashMap<String, FileChecksum>> {
let mut file_checksums: HashMap<String, FileChecksum> = HashMap::new();
let total = shard_files.len();
for (i, shard_file) in shard_files.iter().enumerate() {
download_or_verify_shard(
cache_dir,
base_url,
shard_file,
i,
total,
force,
existing_manifest,
&mut file_checksums,
)?;
}
Ok(file_checksums)
}
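/// Downloads one shard, reusing a cached copy when its size matches the recorded
/// manifest entry; a size mismatch triggers a re-download.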
fn download_or_verify_shard(
cache_dir: &Path,
base_url: &str,
shard_file: &str,
index: usize,
total: usize,
force: bool,
existing_manifest: Option<&ShardManifest>,
checksums: &mut HashMap<String, FileChecksum>,
) -> Result<()> {
let shard_path = cache_dir.join(shard_file);
if !force && shard_path.exists() {
if let Some(manifest) = existing_manifest {
if let Some(expected) = manifest.files.get(shard_file) {
let actual_size = std::fs::metadata(&shard_path).map(|m| m.len()).unwrap_or(0);
if actual_size == expected.size {
checksums.insert(
shard_file.to_string(),
FileChecksum {
size: expected.size,
blake3: expected.blake3.clone(),
},
);
println!(
" {} [{}/{}] {} (cached, verified)",
"✓".green(),
index + 1,
total,
shard_file
);
return Ok(());
}
println!(
" {} [{}/{}] {} (size mismatch: {} vs {} bytes, re-downloading)",
"⚠".yellow(),
index + 1,
total,
shard_file,
actual_size,
expected.size
);
}
} else {
println!(
" {} [{}/{}] {} (cached)",
"✓".green(),
index + 1,
total,
shard_file
);
return Ok(());
}
}
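    // Fall through: no cached copy, forced refresh, or verification failed above; download the shard.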
let shard_url = format!("{base_url}/{shard_file}");
print!(
" {} [{}/{}] {}...",
"↓".yellow(),
index + 1,
total,
shard_file
);
io::stdout().flush().ok();
let checksum = download_file_with_progress(&shard_url, &shard_path)?;
checksums.insert(shard_file.to_string(), checksum);
println!(" {}", "done".green());
Ok(())
}
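/// Fetches tokenizer and config companion files next to the shards. `config.json`
/// is required; at least one tokenizer file must also be present.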
fn download_companion_files(cache_dir: &Path, base_url: &str, force: bool) -> Result<()> {
let companions = [
("tokenizer.json", false),
("config.json", true),
("tokenizer_config.json", false),
("tokenizer.model", false),
];
for (filename, required) in &companions {
let companion_path = cache_dir.join(filename);
if !force && companion_path.exists() {
println!(" {} {} (cached)", "✓".green(), filename);
continue;
}
let url = format!("{base_url}/{filename}");
match download_file(&url, &companion_path) {
Ok(()) => println!(" {} {}", "✓".green(), filename),
Err(CliError::HttpNotFound(_)) if *required => {
return Err(CliError::ValidationFailed(format!(
"{filename} is required for inference but was not found (HTTP 404) at {url}"
)));
}
Err(CliError::HttpNotFound(_)) => {
println!(" {} {} (not found in repo)", "⚠".yellow(), filename);
}
Err(e) if *required => {
return Err(CliError::ValidationFailed(format!(
"{filename} is required for inference but download failed: {e}"
)));
}
Err(_) => println!(" {} {} (not available, optional)", "⚠".yellow(), filename),
}
}
let tokenizer_files = ["tokenizer.json", "tokenizer.model", "tokenizer_config.json"];
let has_tokenizer = tokenizer_files.iter().any(|f| cache_dir.join(f).exists());
if !has_tokenizer {
return Err(CliError::ValidationFailed(format!(
"No tokenizer found for this model. Tried: {}.\n\
The model may require a custom tokenizer not hosted in the repository.",
tokenizer_files.join(", ")
)));
}
Ok(())
}
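/// Persists shard checksums to `.apr-manifest.json` so later pulls can verify
/// cached shards without re-downloading them.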
fn write_shard_manifest(
manifest_path: &Path,
org: &str,
repo: &str,
file_checksums: HashMap<String, FileChecksum>,
) -> Result<()> {
if file_checksums.is_empty() {
return Ok(());
}
let manifest = ShardManifest {
version: 1,
repo: format!("{org}/{repo}"),
files: file_checksums,
};
let manifest_json = serde_json::to_string_pretty(&manifest)
.map_err(|e| CliError::ValidationFailed(format!("Failed to serialize manifest: {e}")))?;
std::fs::write(manifest_path, manifest_json)?;
println!(" {} .apr-manifest.json (integrity checksums)", "✓".green());
Ok(())
}
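/// Dry-run path: resolves the reference, revision, and offline state, then prints
/// what a real pull would do without performing any network I/O.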
fn run_dry_run(model_ref: &str, revision: Option<&str>, offline_flag: bool) -> Result<()> {
use super::aliases;
use super::offline;
use super::revision as rev;
let resolved = if let Some(url) = aliases::resolve_short_name(model_ref) {
url
} else if !model_ref.contains("://") && model_ref.contains('/') {
format!("hf://{model_ref}")
} else {
return Err(unknown_short_name_error(model_ref));
};
let rev_spec = revision.unwrap_or(rev::DEFAULT_REVISION);
let rev_kind = rev::classify_revision(rev_spec).map_err(|msg| {
CliError::ValidationFailed(format!("CRUX-A-03: invalid --revision {rev_spec:?}: {msg}"))
})?;
let env = offline::read_offline_env();
let env_borrowed: Vec<(&str, &str)> =
env.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect();
let is_offline = offline::is_offline(offline_flag, env_borrowed.iter().copied());
println!("Model: {}", model_ref.cyan());
println!("Resolved: {}", resolved.green());
println!("Revision: {} ({:?})", rev_spec.green(), rev_kind);
println!(
"Offline: {}",
if is_offline {
"true".green()
} else {
"false".yellow()
}
);
println!("Mode: {} (no network I/O)", "dry-run".yellow());
Ok(())
}
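/// Builds the CRUX-A-01 error for an unresolvable short name, adding "did you
/// mean" suggestions when close aliases exist.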
fn unknown_short_name_error(name: &str) -> CliError {
use super::aliases;
let suggestions = aliases::did_you_mean(name, 2);
let hint = if suggestions.is_empty() {
"Run `apr registry aliases --json` to list known short names.".to_string()
} else {
format!(
"did you mean {}? (run `apr registry aliases --json` for the full list)",
suggestions
.iter()
.map(|s| format!("`{s}`"))
.collect::<Vec<_>>()
.join(", ")
)
};
CliError::ValidationFailed(format!(
"CRUX-A-01: unknown short name '{name}' and not a fully-qualified URI. {hint}"
))
}
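// Additional pull subcommands and helpers are textually included from sibling files.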
include!("pull_list.rs");
include!("pull_remove_resolve_model.rs");
include!("pull_extract_shard.rs");
include!("pull_04.rs");
include!("pull_dataset.rs");