Skip to main content

apr_cli/commands/
pull.rs

1//! Pull command: download and cache models from HuggingFace (`~/.cache/pacha/models/`).
2
3use crate::error::{CliError, Result};
4use colored::Colorize;
5use pacha::fetcher::{FetchConfig, ModelFetcher};
6use pacha::format::ModelFormat;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet};
9use std::io::{self, Read, Write};
10use std::path::Path;
11
/// Result of resolving a HuggingFace model reference.
///
/// `run` dispatches on this value: `SingleFile` is handed to `run_single_file`
/// (which streams `hf://` URIs straight to disk and uses the pacha fetcher for
/// everything else), while `Sharded` is handed to `run_sharded`, which stores
/// its files under `~/.apr/cache/hf/<org>/<repo>/`.
#[derive(Debug)]
enum ResolvedModel {
    /// Single file model; the payload is the model URI passed to `run_single_file`.
    SingleFile(String),
    /// Sharded SafeTensors model (multiple .safetensors files + index.json)
    Sharded {
        /// Organisation segment, used in both the download URL and the cache path.
        org: String,
        /// Repository segment, used in both the download URL and the cache path.
        repo: String,
        /// Shard file names; each is fetched relative to the repo's `resolve/main` URL.
        shard_files: Vec<String>,
    },
}
27
/// GH-213: Manifest recording checksums for each file in a sharded download.
///
/// Written to `.apr-manifest.json` in the cache directory after a successful download.
/// Used by the pre-inference contract gate to verify shard integrity without re-hashing.
/// On a later pull (without `--force`) it is also read back so that cached shards
/// can be checked against their recorded size.
#[derive(Debug, Serialize, Deserialize)]
pub struct ShardManifest {
    /// Manifest schema version; this file writes `1`.
    pub version: u32,
    /// Repository identifier, written as `"{org}/{repo}"`.
    pub repo: String,
    /// Per-file checksums, keyed by shard file name.
    pub files: HashMap<String, FileChecksum>,
}
38
/// GH-213: Size and BLAKE3 hash of a downloaded file.
#[derive(Debug, Serialize, Deserialize)]
pub struct FileChecksum {
    /// File size in bytes; compared against the on-disk length when re-checking a cached shard.
    pub size: u64,
    /// BLAKE3 digest of the file as a string (presumably hex — confirm against the downloader).
    pub blake3: String,
}
45
46/// Run the pull command
47#[provable_contracts_macros::contract(
48    "apr-cli-operations-v1",
49    equation = "mutating_output_contract"
50)]
51pub fn run(
52    model_ref: &str,
53    force: bool,
54    dry_run: bool,
55    revision: Option<&str>,
56    offline: bool,
57) -> Result<()> {
58    contract_pre_pull_cache_integrity!();
59    println!("{}", "=== APR Pull ===".cyan().bold());
60    println!();
61
62    // CRUX-A-01 FALSIFY-CRUX-A-01-001: --dry-run resolves short name to
63    // canonical URL and exits with zero network I/O.
64    // CRUX-A-03 ALGO-001..003: --dry-run echoes the revision spec the user
65    // supplied (or the default "main") and validates its local form.
66    // CRUX-A-20 ALGO-001..005: --dry-run also echoes the resolved offline
67    // mode so callers can assert CLI-flag / env-var equivalence offline.
68    if dry_run {
69        return run_dry_run(model_ref, revision, offline);
70    }
71
72    // GH-213: Resolve HuggingFace URI — detect single vs sharded models
73    let resolved = resolve_hf_model(model_ref)?;
74
75    let result = match resolved {
76        ResolvedModel::SingleFile(ref uri) => run_single_file(uri, force),
77        ResolvedModel::Sharded {
78            ref org,
79            ref repo,
80            ref shard_files,
81        } => run_sharded(org, repo, shard_files, force),
82    };
83    if let Ok(ref r) = result {
84        contract_post_pull_cache_integrity!(r);
85    }
86    result
87}
88
89/// Pull a single-file model.
90///
91/// GH-352: For HuggingFace URIs, streams directly to disk instead of buffering
92/// the entire file in memory through pacha's resolver. For non-HF URIs (pacha
93/// aliases), falls back to the pacha fetcher.
94fn run_single_file(model_ref: &str, force: bool) -> Result<()> {
95    println!("Model: {}", model_ref.cyan());
96
97    // GH-352: HuggingFace URIs bypass pacha to avoid O(model_size) memory buffering
98    if model_ref.starts_with("hf://") {
99        return run_single_file_streaming(model_ref, force);
100    }
101
102    let mut fetcher = ModelFetcher::with_config(FetchConfig::default()).map_err(|e| {
103        CliError::ValidationFailed(format!("Failed to initialize model fetcher: {e}"))
104    })?;
105
106    if !force && fetcher.is_cached(model_ref) {
107        return handle_cached_model(&mut fetcher, model_ref);
108    }
109
110    let result = download_single_model(&mut fetcher, model_ref)?;
111    ensure_safetensors_companions(&result)?;
112    print_pull_usage(&result.path, true);
113    Ok(())
114}
115
116/// GH-352: Stream a single HuggingFace file directly to disk.
117///
118/// Uses O(64KB) memory instead of O(model_size). The pacha fetcher's
119/// `resolver.resolve()` buffers the entire response via `response.bytes()`,
120/// which consumed ~4.5 GB for a 7B GGUF. This function streams with a 64KB
121/// chunked read, computes BLAKE3 incrementally, and saves to the pacha cache.
122fn run_single_file_streaming(model_ref: &str, force: bool) -> Result<()> {
123    let (org, repo, filename) = parse_hf_single_uri(model_ref)?;
124    let url = format!("https://huggingface.co/{org}/{repo}/resolve/main/{filename}");
125
126    let cache_dir = get_pacha_cache_dir()?;
127    std::fs::create_dir_all(&cache_dir)?;
128    let (extension, cache_path) = build_single_cache_path(&cache_dir, model_ref, &filename);
129
130    if !force && cache_path.exists() {
131        return report_cached_single(&cache_path);
132    }
133
134    stream_and_post_process(&url, &cache_path, model_ref, &extension)?;
135    print_pull_usage(&cache_path, true);
136    Ok(())
137}
138
139fn stream_and_post_process(
140    url: &str,
141    cache_path: &std::path::Path,
142    model_ref: &str,
143    extension: &str,
144) -> Result<()> {
145    println!();
146    println!("{}", "Downloading (streaming)...".yellow());
147    let checksum = download_file_with_progress(url, cache_path)?;
148    report_downloaded_single(cache_path, &checksum);
149
150    if extension == "safetensors" {
151        fetch_safetensors_companions(cache_path, model_ref)?;
152        convert_safetensors_formats(cache_path)?;
153    }
154    Ok(())
155}
156
157fn parse_hf_single_uri(model_ref: &str) -> Result<(String, String, String)> {
158    let path = model_ref.strip_prefix("hf://").unwrap_or(model_ref);
159    let parts: Vec<&str> = path.split('/').collect();
160    if parts.len() < 3 {
161        return Err(CliError::ValidationFailed(format!(
162            "HuggingFace URI must include a filename: {model_ref}"
163        )));
164    }
165    Ok((
166        parts[0].to_string(),
167        parts[1].to_string(),
168        parts[2..].join("/"),
169    ))
170}
171
172fn build_single_cache_path(
173    cache_dir: &std::path::Path,
174    model_ref: &str,
175    filename: &str,
176) -> (String, std::path::PathBuf) {
177    let uri_hash = blake3::hash(model_ref.as_bytes()).to_hex().to_string();
178    let extension = std::path::Path::new(filename)
179        .extension()
180        .and_then(|e| e.to_str())
181        .unwrap_or("bin")
182        .to_string();
183    let cache_filename = format!("{}.{extension}", &uri_hash[..16]);
184    let cache_path = cache_dir.join(&cache_filename);
185    (extension, cache_path)
186}
187
188fn report_cached_single(cache_path: &std::path::Path) -> Result<()> {
189    let metadata = std::fs::metadata(cache_path)?;
190    println!("{} Model already cached", "✓".green());
191    println!("  Path: {}", cache_path.display());
192    println!("  Size: {}", format_bytes(metadata.len()));
193    print_pull_usage(cache_path, true);
194    Ok(())
195}
196
197fn report_downloaded_single(cache_path: &std::path::Path, checksum: &FileChecksum) {
198    println!();
199    println!("{} Downloaded successfully", "✓".green());
200    println!("  Path: {}", cache_path.display().to_string().green());
201    println!("  Size: {}", format_bytes(checksum.size).yellow());
202    println!("  Hash: {}", &checksum.blake3[..16]);
203}
204
205/// Get the pacha model cache directory.
206fn get_pacha_cache_dir() -> Result<std::path::PathBuf> {
207    if let Ok(cache_home) = std::env::var("XDG_CACHE_HOME") {
208        return Ok(std::path::PathBuf::from(cache_home)
209            .join("pacha")
210            .join("models"));
211    }
212    Ok(dirs::home_dir()
213        .ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
214        .join(".cache")
215        .join("pacha")
216        .join("models"))
217}
218
219/// Handle a model that is already cached in pacha.
220fn handle_cached_model(fetcher: &mut ModelFetcher, model_ref: &str) -> Result<()> {
221    println!("{} Model already cached", "✓".green());
222    let result = fetcher
223        .pull_quiet(model_ref)
224        .map_err(|e| CliError::ValidationFailed(format!("Failed to get cached model: {e}")))?;
225
226    println!("  Path: {}", result.path.display());
227    println!("  Size: {}", result.size_human());
228    println!("  Format: {}", result.format.name());
229
230    ensure_safetensors_companions(&result)?;
231    print_pull_usage(&result.path, false);
232    Ok(())
233}
234
235/// Download a single model with progress bar.
236fn download_single_model(
237    fetcher: &mut ModelFetcher,
238    model_ref: &str,
239) -> Result<pacha::fetcher::FetchResult> {
240    println!();
241    println!("{}", "Downloading...".yellow());
242
243    let result = fetcher
244        .pull(model_ref, |progress| {
245            let pct = progress.percent();
246            print!(
247                "\r  [{:50}] {:5.1}% ({}/{})",
248                "=".repeat((pct / 2.0) as usize),
249                pct,
250                format_bytes(progress.downloaded_bytes),
251                format_bytes(progress.total_bytes)
252            );
253            io::stdout().flush().ok();
254        })
255        .map_err(|e| CliError::NetworkError(format!("Download failed: {e}")))?;
256
257    println!();
258    println!();
259
260    if result.cache_hit {
261        println!("{} Model retrieved from cache", "✓".green());
262    } else {
263        println!("{} Downloaded successfully", "✓".green());
264    }
265
266    println!("  Path: {}", result.path.display().to_string().green());
267    println!("  Size: {}", result.size_human().yellow());
268    println!("  Format: {}", result.format.name());
269    println!("  Hash: {}", &result.hash[..16]);
270    Ok(result)
271}
272
273/// Ensure companion files exist for SafeTensors models (GH-198, GH-211).
274fn ensure_safetensors_companions(result: &pacha::fetcher::FetchResult) -> Result<()> {
275    if matches!(result.format, ModelFormat::SafeTensors(_)) {
276        fetch_safetensors_companions(&result.path, &result.resolved_uri)?;
277        convert_safetensors_formats(&result.path)?;
278    }
279    Ok(())
280}
281
282/// Print usage instructions after a successful pull.
283fn print_pull_usage(path: &Path, show_serve: bool) {
284    println!();
285    println!("{}", "Usage:".cyan().bold());
286    println!("  apr run {}", path.display());
287    if show_serve {
288        println!("  apr serve {}", path.display());
289    }
290}
291
292/// GH-213: Pull a sharded SafeTensors model (3B+ models with multiple shard files)
293fn run_sharded(org: &str, repo: &str, shard_files: &[String], force: bool) -> Result<()> {
294    println!(
295        "Model: {}/{} ({} shards)",
296        org.cyan(),
297        repo.cyan(),
298        shard_files.len().to_string().yellow()
299    );
300
301    let cache_dir = resolve_shard_cache_dir(org, repo)?;
302    std::fs::create_dir_all(&cache_dir)?;
303
304    let base_url = format!("https://huggingface.co/{org}/{repo}/resolve/main");
305    let index_path = cache_dir.join("model.safetensors.index.json");
306
307    download_index_if_needed(&base_url, &index_path, force)?;
308
309    let manifest_path = cache_dir.join(".apr-manifest.json");
310    let existing_manifest = load_existing_manifest(&manifest_path, force);
311
312    let file_checksums = download_all_shards(
313        &cache_dir,
314        &base_url,
315        shard_files,
316        force,
317        existing_manifest.as_ref(),
318    )?;
319
320    download_companion_files(&cache_dir, &base_url, force)?;
321    write_shard_manifest(&manifest_path, org, repo, file_checksums)?;
322
323    println!();
324    println!("{} Downloaded successfully", "✓".green());
325    println!("  Path: {}", index_path.display().to_string().green());
326    println!("  Shards: {}", shard_files.len().to_string().yellow());
327
328    convert_safetensors_formats(&index_path)?;
329
330    println!();
331    println!("{}", "Usage:".cyan().bold());
332    println!("  apr run {}", index_path.display());
333    println!("  apr serve {}", index_path.display());
334    Ok(())
335}
336
337/// Resolve the cache directory for a sharded model.
338fn resolve_shard_cache_dir(org: &str, repo: &str) -> Result<std::path::PathBuf> {
339    Ok(dirs::home_dir()
340        .ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
341        .join(".apr")
342        .join("cache")
343        .join("hf")
344        .join(org)
345        .join(repo))
346}
347
348/// Download the SafeTensors index.json if not already cached.
349fn download_index_if_needed(base_url: &str, index_path: &Path, force: bool) -> Result<()> {
350    if force || !index_path.exists() {
351        println!();
352        println!("  {} model.safetensors.index.json", "Downloading".yellow());
353        download_file(
354            &format!("{base_url}/model.safetensors.index.json"),
355            index_path,
356        )?;
357    } else {
358        println!("  {} model.safetensors.index.json (cached)", "✓".green());
359    }
360    Ok(())
361}
362
363/// Load existing shard manifest for cache-hit verification (GH-213).
364fn load_existing_manifest(manifest_path: &Path, force: bool) -> Option<ShardManifest> {
365    if force || !manifest_path.exists() {
366        return None;
367    }
368    std::fs::read_to_string(manifest_path)
369        .ok()
370        .and_then(|s| serde_json::from_str(&s).ok())
371}
372
373/// Download all shards, collecting checksums for the manifest.
374fn download_all_shards(
375    cache_dir: &Path,
376    base_url: &str,
377    shard_files: &[String],
378    force: bool,
379    existing_manifest: Option<&ShardManifest>,
380) -> Result<HashMap<String, FileChecksum>> {
381    let mut file_checksums: HashMap<String, FileChecksum> = HashMap::new();
382    let total = shard_files.len();
383    for (i, shard_file) in shard_files.iter().enumerate() {
384        download_or_verify_shard(
385            cache_dir,
386            base_url,
387            shard_file,
388            i,
389            total,
390            force,
391            existing_manifest,
392            &mut file_checksums,
393        )?;
394    }
395    Ok(file_checksums)
396}
397
398/// Download or verify a single shard file, updating the checksum map.
399fn download_or_verify_shard(
400    cache_dir: &Path,
401    base_url: &str,
402    shard_file: &str,
403    index: usize,
404    total: usize,
405    force: bool,
406    existing_manifest: Option<&ShardManifest>,
407    checksums: &mut HashMap<String, FileChecksum>,
408) -> Result<()> {
409    let shard_path = cache_dir.join(shard_file);
410
411    if !force && shard_path.exists() {
412        if let Some(manifest) = existing_manifest {
413            if let Some(expected) = manifest.files.get(shard_file) {
414                let actual_size = std::fs::metadata(&shard_path).map(|m| m.len()).unwrap_or(0);
415                if actual_size == expected.size {
416                    checksums.insert(
417                        shard_file.to_string(),
418                        FileChecksum {
419                            size: expected.size,
420                            blake3: expected.blake3.clone(),
421                        },
422                    );
423                    println!(
424                        "  {} [{}/{}] {} (cached, verified)",
425                        "✓".green(),
426                        index + 1,
427                        total,
428                        shard_file
429                    );
430                    return Ok(());
431                }
432                println!(
433                    "  {} [{}/{}] {} (size mismatch: {} vs {} bytes, re-downloading)",
434                    "⚠".yellow(),
435                    index + 1,
436                    total,
437                    shard_file,
438                    actual_size,
439                    expected.size
440                );
441                // Fall through to re-download
442            }
443        } else {
444            println!(
445                "  {} [{}/{}] {} (cached)",
446                "✓".green(),
447                index + 1,
448                total,
449                shard_file
450            );
451            return Ok(());
452        }
453    }
454
455    let shard_url = format!("{base_url}/{shard_file}");
456    print!(
457        "  {} [{}/{}] {}...",
458        "↓".yellow(),
459        index + 1,
460        total,
461        shard_file
462    );
463    io::stdout().flush().ok();
464
465    let checksum = download_file_with_progress(&shard_url, &shard_path)?;
466    checksums.insert(shard_file.to_string(), checksum);
467    println!(" {}", "done".green());
468    Ok(())
469}
470
471/// Download companion files (tokenizer, config) for sharded models.
472///
473/// GH-356: tokenizer.json is optional — some models only have tokenizer.model (SentencePiece)
474/// or tokenizer_config.json. We validate that at least ONE tokenizer file was obtained.
475fn download_companion_files(cache_dir: &Path, base_url: &str, force: bool) -> Result<()> {
476    // (filename, is_required) — tokenizer files are individually optional but collectively required
477    let companions = [
478        ("tokenizer.json", false),
479        ("config.json", true),
480        ("tokenizer_config.json", false),
481        ("tokenizer.model", false),
482    ];
483    for (filename, required) in &companions {
484        let companion_path = cache_dir.join(filename);
485        if !force && companion_path.exists() {
486            println!("  {} {} (cached)", "✓".green(), filename);
487            continue;
488        }
489
490        let url = format!("{base_url}/{filename}");
491        match download_file(&url, &companion_path) {
492            Ok(()) => println!("  {} {}", "✓".green(), filename),
493            Err(CliError::HttpNotFound(_)) if *required => {
494                return Err(CliError::ValidationFailed(format!(
495                    "{filename} is required for inference but was not found (HTTP 404) at {url}"
496                )));
497            }
498            Err(CliError::HttpNotFound(_)) => {
499                println!("  {} {} (not found in repo)", "⚠".yellow(), filename);
500            }
501            Err(e) if *required => {
502                return Err(CliError::ValidationFailed(format!(
503                    "{filename} is required for inference but download failed: {e}"
504                )));
505            }
506            Err(_) => println!("  {} {} (not available, optional)", "⚠".yellow(), filename),
507        }
508    }
509
510    // GH-356: Validate at least one tokenizer file exists
511    let tokenizer_files = ["tokenizer.json", "tokenizer.model", "tokenizer_config.json"];
512    let has_tokenizer = tokenizer_files.iter().any(|f| cache_dir.join(f).exists());
513    if !has_tokenizer {
514        return Err(CliError::ValidationFailed(format!(
515            "No tokenizer found for this model. Tried: {}.\n\
516             The model may require a custom tokenizer not hosted in the repository.",
517            tokenizer_files.join(", ")
518        )));
519    }
520
521    Ok(())
522}
523
524/// Write shard manifest with BLAKE3 checksums for integrity verification.
525fn write_shard_manifest(
526    manifest_path: &Path,
527    org: &str,
528    repo: &str,
529    file_checksums: HashMap<String, FileChecksum>,
530) -> Result<()> {
531    if file_checksums.is_empty() {
532        return Ok(());
533    }
534    let manifest = ShardManifest {
535        version: 1,
536        repo: format!("{org}/{repo}"),
537        files: file_checksums,
538    };
539    let manifest_json = serde_json::to_string_pretty(&manifest)
540        .map_err(|e| CliError::ValidationFailed(format!("Failed to serialize manifest: {e}")))?;
541    std::fs::write(manifest_path, manifest_json)?;
542    println!("  {} .apr-manifest.json (integrity checksums)", "✓".green());
543    Ok(())
544}
545
546/// CRUX-A-01 FALSIFY-CRUX-A-01-001: `--dry-run` resolver.
547///
548/// Emits the resolved canonical URL on stdout and returns `Ok(())` with zero
549/// network I/O. Short names are resolved via the embedded alias map
550/// (`configs/aliases.yaml`); scheme-qualified inputs (`hf://…`,
551/// `https://…`) and bare `org/repo` inputs echo as their canonical forms.
552///
553/// CRUX-A-01 FALSIFY-CRUX-A-01-003: unknown short names (no scheme, no `/`)
554/// return an error that includes a Levenshtein ≤ 2 "did you mean …" hint.
555/// CRUX-A-03 ALGO-001..003: `--revision` is classified locally and echoed
556/// in the dry-run output. Malformed revisions (empty, whitespace, URL)
557/// fail fast without touching the network.
558///
559/// CRUX-A-20 ALGO-001..005: the effective offline signal (CLI flag OR
560/// `APR_OFFLINE` OR `HF_HUB_OFFLINE` truthy) is echoed too.
561fn run_dry_run(model_ref: &str, revision: Option<&str>, offline_flag: bool) -> Result<()> {
562    use super::aliases;
563    use super::offline;
564    use super::revision as rev;
565
566    let resolved = if let Some(url) = aliases::resolve_short_name(model_ref) {
567        url
568    } else if !model_ref.contains("://") && model_ref.contains('/') {
569        format!("hf://{model_ref}")
570    } else {
571        return Err(unknown_short_name_error(model_ref));
572    };
573
574    let rev_spec = revision.unwrap_or(rev::DEFAULT_REVISION);
575    let rev_kind = rev::classify_revision(rev_spec).map_err(|msg| {
576        CliError::ValidationFailed(format!("CRUX-A-03: invalid --revision {rev_spec:?}: {msg}"))
577    })?;
578
579    // CRUX-A-20: resolve offline signal from CLI flag + env vars.
580    let env = offline::read_offline_env();
581    let env_borrowed: Vec<(&str, &str)> =
582        env.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect();
583    let is_offline = offline::is_offline(offline_flag, env_borrowed.iter().copied());
584
585    println!("Model:    {}", model_ref.cyan());
586    println!("Resolved: {}", resolved.green());
587    println!("Revision: {} ({:?})", rev_spec.green(), rev_kind);
588    println!(
589        "Offline:  {}",
590        if is_offline {
591            "true".green()
592        } else {
593            "false".yellow()
594        }
595    );
596    println!("Mode:     {} (no network I/O)", "dry-run".yellow());
597    Ok(())
598}
599
600/// CRUX-A-01 FALSIFY-CRUX-A-01-003: build an error carrying a did-you-mean
601/// hint derived from Levenshtein ≤ 2 matches against the alias map.
602fn unknown_short_name_error(name: &str) -> CliError {
603    use super::aliases;
604
605    let suggestions = aliases::did_you_mean(name, 2);
606    let hint = if suggestions.is_empty() {
607        "Run `apr registry aliases --json` to list known short names.".to_string()
608    } else {
609        format!(
610            "did you mean {}? (run `apr registry aliases --json` for the full list)",
611            suggestions
612                .iter()
613                .map(|s| format!("`{s}`"))
614                .collect::<Vec<_>>()
615                .join(", ")
616        )
617    };
618    CliError::ValidationFailed(format!(
619        "CRUX-A-01: unknown short name '{name}' and not a fully-qualified URI. {hint}"
620    ))
621}
622
623include!("pull_list.rs");
624include!("pull_remove_resolve_model.rs");
625include!("pull_extract_shard.rs");
626include!("pull_04.rs");
627include!("pull_dataset.rs");