1use crate::error::{CliError, Result};
4use colored::Colorize;
5use pacha::fetcher::{FetchConfig, ModelFetcher};
6use pacha::format::ModelFormat;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet};
9use std::io::{self, Read, Write};
10use std::path::Path;
11
/// Result of resolving a user-supplied model reference for `apr pull`.
///
/// Returned by `resolve_hf_model` (defined outside this view) and dispatched by [`run`].
#[derive(Debug)]
enum ResolvedModel {
    /// A single-file model; the string is passed to `run_single_file` as the model
    /// reference (an `hf://` URI takes the streaming path there).
    SingleFile(String),
    /// A HuggingFace repository whose weights are split across several shard files;
    /// handled by `run_sharded`.
    Sharded {
        /// Repository owner, used as `<org>` in `huggingface.co/<org>/<repo>`.
        org: String,
        /// Repository name, used as `<repo>` in `huggingface.co/<org>/<repo>`.
        repo: String,
        /// Shard file names, each fetched from `.../<org>/<repo>/resolve/main/<shard>`.
        shard_files: Vec<String>,
    },
}
27
/// Integrity manifest stored as `.apr-manifest.json` in a sharded model's cache directory.
///
/// A later pull reads it back and compares each cached shard's on-disk size against the
/// recorded size before deciding whether to re-download. serde serializes fields in
/// declaration order, so reordering them changes the JSON layout.
#[derive(Debug, Serialize, Deserialize)]
pub struct ShardManifest {
    /// Manifest format version; currently always written as `1`.
    pub version: u32,
    /// Repository id in `org/repo` form.
    pub repo: String,
    /// Per-file checksums keyed by shard file name.
    pub files: HashMap<String, FileChecksum>,
}
38
/// Size and BLAKE3 digest recorded for one downloaded file.
#[derive(Debug, Serialize, Deserialize)]
pub struct FileChecksum {
    /// File size in bytes.
    pub size: u64,
    /// BLAKE3 digest as text (presumably lowercase hex as produced by
    /// `download_file_with_progress`, which is not visible here — TODO confirm).
    pub blake3: String,
}
45
/// Entry point for `apr pull`: fetch a model into the local cache.
///
/// * `model_ref` – model reference as typed by the user.
/// * `force` – re-download even when a cached copy exists.
/// * `dry_run` – only resolve and print what would be fetched (see `run_dry_run`).
/// * `revision` – revision to validate and report. NOTE(review): it is only consulted on
///   the dry-run path; the download paths below do not receive it.
/// * `offline` – offline flag; likewise only consulted on the dry-run path.
///
/// # Errors
/// Propagates resolution, validation, network, and I/O errors from the helpers.
#[provable_contracts_macros::contract(
    "apr-cli-operations-v1",
    equation = "mutating_output_contract"
)]
pub fn run(
    model_ref: &str,
    force: bool,
    dry_run: bool,
    revision: Option<&str>,
    offline: bool,
) -> Result<()> {
    // Contract pre/post hooks; presumably generated by the `contract` attribute above,
    // which is not visible here — TODO confirm.
    contract_pre_pull_cache_integrity!();
    println!("{}", "=== APR Pull ===".cyan().bold());
    println!();

    if dry_run {
        // Dry runs only resolve and report; no download happens.
        return run_dry_run(model_ref, revision, offline);
    }

    let resolved = resolve_hf_model(model_ref)?;

    let result = match resolved {
        ResolvedModel::SingleFile(ref uri) => run_single_file(uri, force),
        ResolvedModel::Sharded {
            ref org,
            ref repo,
            ref shard_files,
        } => run_sharded(org, repo, shard_files, force),
    };
    // The post-condition is only checked when the pull succeeded.
    if let Ok(ref r) = result {
        contract_post_pull_cache_integrity!(r);
    }
    result
}
88
89fn run_single_file(model_ref: &str, force: bool) -> Result<()> {
95 println!("Model: {}", model_ref.cyan());
96
97 if model_ref.starts_with("hf://") {
99 return run_single_file_streaming(model_ref, force);
100 }
101
102 let mut fetcher = ModelFetcher::with_config(FetchConfig::default()).map_err(|e| {
103 CliError::ValidationFailed(format!("Failed to initialize model fetcher: {e}"))
104 })?;
105
106 if !force && fetcher.is_cached(model_ref) {
107 return handle_cached_model(&mut fetcher, model_ref);
108 }
109
110 let result = download_single_model(&mut fetcher, model_ref)?;
111 ensure_safetensors_companions(&result)?;
112 print_pull_usage(&result.path, true);
113 Ok(())
114}
115
116fn run_single_file_streaming(model_ref: &str, force: bool) -> Result<()> {
123 let (org, repo, filename) = parse_hf_single_uri(model_ref)?;
124 let url = format!("https://huggingface.co/{org}/{repo}/resolve/main/{filename}");
125
126 let cache_dir = get_pacha_cache_dir()?;
127 std::fs::create_dir_all(&cache_dir)?;
128 let (extension, cache_path) = build_single_cache_path(&cache_dir, model_ref, &filename);
129
130 if !force && cache_path.exists() {
131 return report_cached_single(&cache_path);
132 }
133
134 stream_and_post_process(&url, &cache_path, model_ref, &extension)?;
135 print_pull_usage(&cache_path, true);
136 Ok(())
137}
138
139fn stream_and_post_process(
140 url: &str,
141 cache_path: &std::path::Path,
142 model_ref: &str,
143 extension: &str,
144) -> Result<()> {
145 println!();
146 println!("{}", "Downloading (streaming)...".yellow());
147 let checksum = download_file_with_progress(url, cache_path)?;
148 report_downloaded_single(cache_path, &checksum);
149
150 if extension == "safetensors" {
151 fetch_safetensors_companions(cache_path, model_ref)?;
152 convert_safetensors_formats(cache_path)?;
153 }
154 Ok(())
155}
156
157fn parse_hf_single_uri(model_ref: &str) -> Result<(String, String, String)> {
158 let path = model_ref.strip_prefix("hf://").unwrap_or(model_ref);
159 let parts: Vec<&str> = path.split('/').collect();
160 if parts.len() < 3 {
161 return Err(CliError::ValidationFailed(format!(
162 "HuggingFace URI must include a filename: {model_ref}"
163 )));
164 }
165 Ok((
166 parts[0].to_string(),
167 parts[1].to_string(),
168 parts[2..].join("/"),
169 ))
170}
171
172fn build_single_cache_path(
173 cache_dir: &std::path::Path,
174 model_ref: &str,
175 filename: &str,
176) -> (String, std::path::PathBuf) {
177 let uri_hash = blake3::hash(model_ref.as_bytes()).to_hex().to_string();
178 let extension = std::path::Path::new(filename)
179 .extension()
180 .and_then(|e| e.to_str())
181 .unwrap_or("bin")
182 .to_string();
183 let cache_filename = format!("{}.{extension}", &uri_hash[..16]);
184 let cache_path = cache_dir.join(&cache_filename);
185 (extension, cache_path)
186}
187
188fn report_cached_single(cache_path: &std::path::Path) -> Result<()> {
189 let metadata = std::fs::metadata(cache_path)?;
190 println!("{} Model already cached", "✓".green());
191 println!(" Path: {}", cache_path.display());
192 println!(" Size: {}", format_bytes(metadata.len()));
193 print_pull_usage(cache_path, true);
194 Ok(())
195}
196
197fn report_downloaded_single(cache_path: &std::path::Path, checksum: &FileChecksum) {
198 println!();
199 println!("{} Downloaded successfully", "✓".green());
200 println!(" Path: {}", cache_path.display().to_string().green());
201 println!(" Size: {}", format_bytes(checksum.size).yellow());
202 println!(" Hash: {}", &checksum.blake3[..16]);
203}
204
205fn get_pacha_cache_dir() -> Result<std::path::PathBuf> {
207 if let Ok(cache_home) = std::env::var("XDG_CACHE_HOME") {
208 return Ok(std::path::PathBuf::from(cache_home)
209 .join("pacha")
210 .join("models"));
211 }
212 Ok(dirs::home_dir()
213 .ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
214 .join(".cache")
215 .join("pacha")
216 .join("models"))
217}
218
219fn handle_cached_model(fetcher: &mut ModelFetcher, model_ref: &str) -> Result<()> {
221 println!("{} Model already cached", "✓".green());
222 let result = fetcher
223 .pull_quiet(model_ref)
224 .map_err(|e| CliError::ValidationFailed(format!("Failed to get cached model: {e}")))?;
225
226 println!(" Path: {}", result.path.display());
227 println!(" Size: {}", result.size_human());
228 println!(" Format: {}", result.format.name());
229
230 ensure_safetensors_companions(&result)?;
231 print_pull_usage(&result.path, false);
232 Ok(())
233}
234
235fn download_single_model(
237 fetcher: &mut ModelFetcher,
238 model_ref: &str,
239) -> Result<pacha::fetcher::FetchResult> {
240 println!();
241 println!("{}", "Downloading...".yellow());
242
243 let result = fetcher
244 .pull(model_ref, |progress| {
245 let pct = progress.percent();
246 print!(
247 "\r [{:50}] {:5.1}% ({}/{})",
248 "=".repeat((pct / 2.0) as usize),
249 pct,
250 format_bytes(progress.downloaded_bytes),
251 format_bytes(progress.total_bytes)
252 );
253 io::stdout().flush().ok();
254 })
255 .map_err(|e| CliError::NetworkError(format!("Download failed: {e}")))?;
256
257 println!();
258 println!();
259
260 if result.cache_hit {
261 println!("{} Model retrieved from cache", "✓".green());
262 } else {
263 println!("{} Downloaded successfully", "✓".green());
264 }
265
266 println!(" Path: {}", result.path.display().to_string().green());
267 println!(" Size: {}", result.size_human().yellow());
268 println!(" Format: {}", result.format.name());
269 println!(" Hash: {}", &result.hash[..16]);
270 Ok(result)
271}
272
273fn ensure_safetensors_companions(result: &pacha::fetcher::FetchResult) -> Result<()> {
275 if matches!(result.format, ModelFormat::SafeTensors(_)) {
276 fetch_safetensors_companions(&result.path, &result.resolved_uri)?;
277 convert_safetensors_formats(&result.path)?;
278 }
279 Ok(())
280}
281
282fn print_pull_usage(path: &Path, show_serve: bool) {
284 println!();
285 println!("{}", "Usage:".cyan().bold());
286 println!(" apr run {}", path.display());
287 if show_serve {
288 println!(" apr serve {}", path.display());
289 }
290}
291
292fn run_sharded(org: &str, repo: &str, shard_files: &[String], force: bool) -> Result<()> {
294 println!(
295 "Model: {}/{} ({} shards)",
296 org.cyan(),
297 repo.cyan(),
298 shard_files.len().to_string().yellow()
299 );
300
301 let cache_dir = resolve_shard_cache_dir(org, repo)?;
302 std::fs::create_dir_all(&cache_dir)?;
303
304 let base_url = format!("https://huggingface.co/{org}/{repo}/resolve/main");
305 let index_path = cache_dir.join("model.safetensors.index.json");
306
307 download_index_if_needed(&base_url, &index_path, force)?;
308
309 let manifest_path = cache_dir.join(".apr-manifest.json");
310 let existing_manifest = load_existing_manifest(&manifest_path, force);
311
312 let file_checksums = download_all_shards(
313 &cache_dir,
314 &base_url,
315 shard_files,
316 force,
317 existing_manifest.as_ref(),
318 )?;
319
320 download_companion_files(&cache_dir, &base_url, force)?;
321 write_shard_manifest(&manifest_path, org, repo, file_checksums)?;
322
323 println!();
324 println!("{} Downloaded successfully", "✓".green());
325 println!(" Path: {}", index_path.display().to_string().green());
326 println!(" Shards: {}", shard_files.len().to_string().yellow());
327
328 convert_safetensors_formats(&index_path)?;
329
330 println!();
331 println!("{}", "Usage:".cyan().bold());
332 println!(" apr run {}", index_path.display());
333 println!(" apr serve {}", index_path.display());
334 Ok(())
335}
336
337fn resolve_shard_cache_dir(org: &str, repo: &str) -> Result<std::path::PathBuf> {
339 Ok(dirs::home_dir()
340 .ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
341 .join(".apr")
342 .join("cache")
343 .join("hf")
344 .join(org)
345 .join(repo))
346}
347
348fn download_index_if_needed(base_url: &str, index_path: &Path, force: bool) -> Result<()> {
350 if force || !index_path.exists() {
351 println!();
352 println!(" {} model.safetensors.index.json", "Downloading".yellow());
353 download_file(
354 &format!("{base_url}/model.safetensors.index.json"),
355 index_path,
356 )?;
357 } else {
358 println!(" {} model.safetensors.index.json (cached)", "✓".green());
359 }
360 Ok(())
361}
362
363fn load_existing_manifest(manifest_path: &Path, force: bool) -> Option<ShardManifest> {
365 if force || !manifest_path.exists() {
366 return None;
367 }
368 std::fs::read_to_string(manifest_path)
369 .ok()
370 .and_then(|s| serde_json::from_str(&s).ok())
371}
372
373fn download_all_shards(
375 cache_dir: &Path,
376 base_url: &str,
377 shard_files: &[String],
378 force: bool,
379 existing_manifest: Option<&ShardManifest>,
380) -> Result<HashMap<String, FileChecksum>> {
381 let mut file_checksums: HashMap<String, FileChecksum> = HashMap::new();
382 let total = shard_files.len();
383 for (i, shard_file) in shard_files.iter().enumerate() {
384 download_or_verify_shard(
385 cache_dir,
386 base_url,
387 shard_file,
388 i,
389 total,
390 force,
391 existing_manifest,
392 &mut file_checksums,
393 )?;
394 }
395 Ok(file_checksums)
396}
397
398fn download_or_verify_shard(
400 cache_dir: &Path,
401 base_url: &str,
402 shard_file: &str,
403 index: usize,
404 total: usize,
405 force: bool,
406 existing_manifest: Option<&ShardManifest>,
407 checksums: &mut HashMap<String, FileChecksum>,
408) -> Result<()> {
409 let shard_path = cache_dir.join(shard_file);
410
411 if !force && shard_path.exists() {
412 if let Some(manifest) = existing_manifest {
413 if let Some(expected) = manifest.files.get(shard_file) {
414 let actual_size = std::fs::metadata(&shard_path).map(|m| m.len()).unwrap_or(0);
415 if actual_size == expected.size {
416 checksums.insert(
417 shard_file.to_string(),
418 FileChecksum {
419 size: expected.size,
420 blake3: expected.blake3.clone(),
421 },
422 );
423 println!(
424 " {} [{}/{}] {} (cached, verified)",
425 "✓".green(),
426 index + 1,
427 total,
428 shard_file
429 );
430 return Ok(());
431 }
432 println!(
433 " {} [{}/{}] {} (size mismatch: {} vs {} bytes, re-downloading)",
434 "⚠".yellow(),
435 index + 1,
436 total,
437 shard_file,
438 actual_size,
439 expected.size
440 );
441 }
443 } else {
444 println!(
445 " {} [{}/{}] {} (cached)",
446 "✓".green(),
447 index + 1,
448 total,
449 shard_file
450 );
451 return Ok(());
452 }
453 }
454
455 let shard_url = format!("{base_url}/{shard_file}");
456 print!(
457 " {} [{}/{}] {}...",
458 "↓".yellow(),
459 index + 1,
460 total,
461 shard_file
462 );
463 io::stdout().flush().ok();
464
465 let checksum = download_file_with_progress(&shard_url, &shard_path)?;
466 checksums.insert(shard_file.to_string(), checksum);
467 println!(" {}", "done".green());
468 Ok(())
469}
470
471fn download_companion_files(cache_dir: &Path, base_url: &str, force: bool) -> Result<()> {
476 let companions = [
478 ("tokenizer.json", false),
479 ("config.json", true),
480 ("tokenizer_config.json", false),
481 ("tokenizer.model", false),
482 ];
483 for (filename, required) in &companions {
484 let companion_path = cache_dir.join(filename);
485 if !force && companion_path.exists() {
486 println!(" {} {} (cached)", "✓".green(), filename);
487 continue;
488 }
489
490 let url = format!("{base_url}/{filename}");
491 match download_file(&url, &companion_path) {
492 Ok(()) => println!(" {} {}", "✓".green(), filename),
493 Err(CliError::HttpNotFound(_)) if *required => {
494 return Err(CliError::ValidationFailed(format!(
495 "{filename} is required for inference but was not found (HTTP 404) at {url}"
496 )));
497 }
498 Err(CliError::HttpNotFound(_)) => {
499 println!(" {} {} (not found in repo)", "⚠".yellow(), filename);
500 }
501 Err(e) if *required => {
502 return Err(CliError::ValidationFailed(format!(
503 "{filename} is required for inference but download failed: {e}"
504 )));
505 }
506 Err(_) => println!(" {} {} (not available, optional)", "⚠".yellow(), filename),
507 }
508 }
509
510 let tokenizer_files = ["tokenizer.json", "tokenizer.model", "tokenizer_config.json"];
512 let has_tokenizer = tokenizer_files.iter().any(|f| cache_dir.join(f).exists());
513 if !has_tokenizer {
514 return Err(CliError::ValidationFailed(format!(
515 "No tokenizer found for this model. Tried: {}.\n\
516 The model may require a custom tokenizer not hosted in the repository.",
517 tokenizer_files.join(", ")
518 )));
519 }
520
521 Ok(())
522}
523
524fn write_shard_manifest(
526 manifest_path: &Path,
527 org: &str,
528 repo: &str,
529 file_checksums: HashMap<String, FileChecksum>,
530) -> Result<()> {
531 if file_checksums.is_empty() {
532 return Ok(());
533 }
534 let manifest = ShardManifest {
535 version: 1,
536 repo: format!("{org}/{repo}"),
537 files: file_checksums,
538 };
539 let manifest_json = serde_json::to_string_pretty(&manifest)
540 .map_err(|e| CliError::ValidationFailed(format!("Failed to serialize manifest: {e}")))?;
541 std::fs::write(manifest_path, manifest_json)?;
542 println!(" {} .apr-manifest.json (integrity checksums)", "✓".green());
543 Ok(())
544}
545
546fn run_dry_run(model_ref: &str, revision: Option<&str>, offline_flag: bool) -> Result<()> {
562 use super::aliases;
563 use super::offline;
564 use super::revision as rev;
565
566 let resolved = if let Some(url) = aliases::resolve_short_name(model_ref) {
567 url
568 } else if !model_ref.contains("://") && model_ref.contains('/') {
569 format!("hf://{model_ref}")
570 } else {
571 return Err(unknown_short_name_error(model_ref));
572 };
573
574 let rev_spec = revision.unwrap_or(rev::DEFAULT_REVISION);
575 let rev_kind = rev::classify_revision(rev_spec).map_err(|msg| {
576 CliError::ValidationFailed(format!("CRUX-A-03: invalid --revision {rev_spec:?}: {msg}"))
577 })?;
578
579 let env = offline::read_offline_env();
581 let env_borrowed: Vec<(&str, &str)> =
582 env.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect();
583 let is_offline = offline::is_offline(offline_flag, env_borrowed.iter().copied());
584
585 println!("Model: {}", model_ref.cyan());
586 println!("Resolved: {}", resolved.green());
587 println!("Revision: {} ({:?})", rev_spec.green(), rev_kind);
588 println!(
589 "Offline: {}",
590 if is_offline {
591 "true".green()
592 } else {
593 "false".yellow()
594 }
595 );
596 println!("Mode: {} (no network I/O)", "dry-run".yellow());
597 Ok(())
598}
599
600fn unknown_short_name_error(name: &str) -> CliError {
603 use super::aliases;
604
605 let suggestions = aliases::did_you_mean(name, 2);
606 let hint = if suggestions.is_empty() {
607 "Run `apr registry aliases --json` to list known short names.".to_string()
608 } else {
609 format!(
610 "did you mean {}? (run `apr registry aliases --json` for the full list)",
611 suggestions
612 .iter()
613 .map(|s| format!("`{s}`"))
614 .collect::<Vec<_>>()
615 .join(", ")
616 )
617 };
618 CliError::ValidationFailed(format!(
619 "CRUX-A-01: unknown short name '{name}' and not a fully-qualified URI. {hint}"
620 ))
621}
622
623include!("pull_list.rs");
624include!("pull_remove_resolve_model.rs");
625include!("pull_extract_shard.rs");
626include!("pull_04.rs");
627include!("pull_dataset.rs");