Skip to main content

apr_cli/commands/
pull.rs

//! Pull command: download and cache models from HuggingFace (single-file models under `~/.cache/pacha/models/`, sharded SafeTensors models under `~/.apr/cache/hf/`).
2
3use crate::error::{CliError, Result};
4use colored::Colorize;
5use pacha::fetcher::{FetchConfig, ModelFetcher};
6use pacha::format::ModelFormat;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet};
9use std::io::{self, Read, Write};
10use std::path::Path;
11
/// Outcome of resolving a HuggingFace model reference.
///
/// The two variants select different download strategies:
/// - [`ResolvedModel::SingleFile`]: one file (e.g. a GGUF or a small SafeTensors
///   file) handled by the single-file pull path.
/// - [`ResolvedModel::Sharded`]: a multi-shard SafeTensors model whose shards are
///   downloaded directly into `~/.apr/cache/hf/<org>/<repo>/`.
#[derive(Debug)]
enum ResolvedModel {
    /// Reference to a single downloadable file (pre-existing behavior).
    SingleFile(String),
    /// Sharded SafeTensors repository: owning organization, repository name,
    /// and the list of `.safetensors` shard filenames to fetch alongside the
    /// index JSON.
    Sharded {
        org: String,
        repo: String,
        shard_files: Vec<String>,
    },
}
27
28/// GH-213: Manifest recording checksums for each file in a sharded download.
29///
30/// Written to `.apr-manifest.json` in the cache directory after a successful download.
31/// Used by the pre-inference contract gate to verify shard integrity without re-hashing.
32#[derive(Debug, Serialize, Deserialize)]
33pub struct ShardManifest {
34    pub version: u32,
35    pub repo: String,
36    pub files: HashMap<String, FileChecksum>,
37}
38
39/// GH-213: Size and BLAKE3 hash of a downloaded file.
40#[derive(Debug, Serialize, Deserialize)]
41pub struct FileChecksum {
42    pub size: u64,
43    pub blake3: String,
44}
45
/// Run the pull command.
///
/// Resolves `model_ref` with `resolve_hf_model`, then dispatches either to the
/// single-file pull path or to the sharded SafeTensors path. When `force` is
/// true, cached copies are ignored and files are downloaded again.
///
/// # Errors
///
/// Propagates any error from model resolution or from the selected download
/// path.
#[provable_contracts_macros::contract(
    "apr-cli-operations-v1",
    equation = "mutating_output_contract"
)]
pub fn run(model_ref: &str, force: bool) -> Result<()> {
    // Contract precondition hook (macro is defined outside this file).
    contract_pre_pull_cache_integrity!();
    println!("{}", "=== APR Pull ===".cyan().bold());
    println!();

    // GH-213: Resolve HuggingFace URI — detect single vs sharded models
    let resolved = resolve_hf_model(model_ref)?;

    let result = match resolved {
        ResolvedModel::SingleFile(ref uri) => run_single_file(uri, force),
        ResolvedModel::Sharded {
            ref org,
            ref repo,
            ref shard_files,
        } => run_sharded(org, repo, shard_files, force),
    };
    // The contract postcondition is only checked on success; errors are
    // returned to the caller unchanged.
    if let Ok(ref r) = result {
        contract_post_pull_cache_integrity!(r);
    }
    result
}
72
73/// Pull a single-file model.
74///
75/// GH-352: For HuggingFace URIs, streams directly to disk instead of buffering
76/// the entire file in memory through pacha's resolver. For non-HF URIs (pacha
77/// aliases), falls back to the pacha fetcher.
78fn run_single_file(model_ref: &str, force: bool) -> Result<()> {
79    println!("Model: {}", model_ref.cyan());
80
81    // GH-352: HuggingFace URIs bypass pacha to avoid O(model_size) memory buffering
82    if model_ref.starts_with("hf://") {
83        return run_single_file_streaming(model_ref, force);
84    }
85
86    let mut fetcher = ModelFetcher::with_config(FetchConfig::default()).map_err(|e| {
87        CliError::ValidationFailed(format!("Failed to initialize model fetcher: {e}"))
88    })?;
89
90    if !force && fetcher.is_cached(model_ref) {
91        return handle_cached_model(&mut fetcher, model_ref);
92    }
93
94    let result = download_single_model(&mut fetcher, model_ref)?;
95    ensure_safetensors_companions(&result)?;
96    print_pull_usage(&result.path, true);
97    Ok(())
98}
99
100/// GH-352: Stream a single HuggingFace file directly to disk.
101///
102/// Uses O(64KB) memory instead of O(model_size). The pacha fetcher's
103/// `resolver.resolve()` buffers the entire response via `response.bytes()`,
104/// which consumed ~4.5 GB for a 7B GGUF. This function streams with a 64KB
105/// chunked read, computes BLAKE3 incrementally, and saves to the pacha cache.
106fn run_single_file_streaming(model_ref: &str, force: bool) -> Result<()> {
107    let (org, repo, filename) = parse_hf_single_uri(model_ref)?;
108    let url = format!("https://huggingface.co/{org}/{repo}/resolve/main/{filename}");
109
110    let cache_dir = get_pacha_cache_dir()?;
111    std::fs::create_dir_all(&cache_dir)?;
112    let (extension, cache_path) = build_single_cache_path(&cache_dir, model_ref, &filename);
113
114    if !force && cache_path.exists() {
115        return report_cached_single(&cache_path);
116    }
117
118    stream_and_post_process(&url, &cache_path, model_ref, &extension)?;
119    print_pull_usage(&cache_path, true);
120    Ok(())
121}
122
123fn stream_and_post_process(
124    url: &str,
125    cache_path: &std::path::Path,
126    model_ref: &str,
127    extension: &str,
128) -> Result<()> {
129    println!();
130    println!("{}", "Downloading (streaming)...".yellow());
131    let checksum = download_file_with_progress(url, cache_path)?;
132    report_downloaded_single(cache_path, &checksum);
133
134    if extension == "safetensors" {
135        fetch_safetensors_companions(cache_path, model_ref)?;
136        convert_safetensors_formats(cache_path)?;
137    }
138    Ok(())
139}
140
141fn parse_hf_single_uri(model_ref: &str) -> Result<(String, String, String)> {
142    let path = model_ref.strip_prefix("hf://").unwrap_or(model_ref);
143    let parts: Vec<&str> = path.split('/').collect();
144    if parts.len() < 3 {
145        return Err(CliError::ValidationFailed(format!(
146            "HuggingFace URI must include a filename: {model_ref}"
147        )));
148    }
149    Ok((
150        parts[0].to_string(),
151        parts[1].to_string(),
152        parts[2..].join("/"),
153    ))
154}
155
156fn build_single_cache_path(
157    cache_dir: &std::path::Path,
158    model_ref: &str,
159    filename: &str,
160) -> (String, std::path::PathBuf) {
161    let uri_hash = blake3::hash(model_ref.as_bytes()).to_hex().to_string();
162    let extension = std::path::Path::new(filename)
163        .extension()
164        .and_then(|e| e.to_str())
165        .unwrap_or("bin")
166        .to_string();
167    let cache_filename = format!("{}.{extension}", &uri_hash[..16]);
168    let cache_path = cache_dir.join(&cache_filename);
169    (extension, cache_path)
170}
171
172fn report_cached_single(cache_path: &std::path::Path) -> Result<()> {
173    let metadata = std::fs::metadata(cache_path)?;
174    println!("{} Model already cached", "✓".green());
175    println!("  Path: {}", cache_path.display());
176    println!("  Size: {}", format_bytes(metadata.len()));
177    print_pull_usage(cache_path, true);
178    Ok(())
179}
180
181fn report_downloaded_single(cache_path: &std::path::Path, checksum: &FileChecksum) {
182    println!();
183    println!("{} Downloaded successfully", "✓".green());
184    println!("  Path: {}", cache_path.display().to_string().green());
185    println!("  Size: {}", format_bytes(checksum.size).yellow());
186    println!("  Hash: {}", &checksum.blake3[..16]);
187}
188
189/// Get the pacha model cache directory.
190fn get_pacha_cache_dir() -> Result<std::path::PathBuf> {
191    if let Ok(cache_home) = std::env::var("XDG_CACHE_HOME") {
192        return Ok(std::path::PathBuf::from(cache_home)
193            .join("pacha")
194            .join("models"));
195    }
196    Ok(dirs::home_dir()
197        .ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
198        .join(".cache")
199        .join("pacha")
200        .join("models"))
201}
202
203/// Handle a model that is already cached in pacha.
204fn handle_cached_model(fetcher: &mut ModelFetcher, model_ref: &str) -> Result<()> {
205    println!("{} Model already cached", "✓".green());
206    let result = fetcher
207        .pull_quiet(model_ref)
208        .map_err(|e| CliError::ValidationFailed(format!("Failed to get cached model: {e}")))?;
209
210    println!("  Path: {}", result.path.display());
211    println!("  Size: {}", result.size_human());
212    println!("  Format: {}", result.format.name());
213
214    ensure_safetensors_companions(&result)?;
215    print_pull_usage(&result.path, false);
216    Ok(())
217}
218
219/// Download a single model with progress bar.
220fn download_single_model(
221    fetcher: &mut ModelFetcher,
222    model_ref: &str,
223) -> Result<pacha::fetcher::FetchResult> {
224    println!();
225    println!("{}", "Downloading...".yellow());
226
227    let result = fetcher
228        .pull(model_ref, |progress| {
229            let pct = progress.percent();
230            print!(
231                "\r  [{:50}] {:5.1}% ({}/{})",
232                "=".repeat((pct / 2.0) as usize),
233                pct,
234                format_bytes(progress.downloaded_bytes),
235                format_bytes(progress.total_bytes)
236            );
237            io::stdout().flush().ok();
238        })
239        .map_err(|e| CliError::NetworkError(format!("Download failed: {e}")))?;
240
241    println!();
242    println!();
243
244    if result.cache_hit {
245        println!("{} Model retrieved from cache", "✓".green());
246    } else {
247        println!("{} Downloaded successfully", "✓".green());
248    }
249
250    println!("  Path: {}", result.path.display().to_string().green());
251    println!("  Size: {}", result.size_human().yellow());
252    println!("  Format: {}", result.format.name());
253    println!("  Hash: {}", &result.hash[..16]);
254    Ok(result)
255}
256
257/// Ensure companion files exist for SafeTensors models (GH-198, GH-211).
258fn ensure_safetensors_companions(result: &pacha::fetcher::FetchResult) -> Result<()> {
259    if matches!(result.format, ModelFormat::SafeTensors(_)) {
260        fetch_safetensors_companions(&result.path, &result.resolved_uri)?;
261        convert_safetensors_formats(&result.path)?;
262    }
263    Ok(())
264}
265
266/// Print usage instructions after a successful pull.
267fn print_pull_usage(path: &Path, show_serve: bool) {
268    println!();
269    println!("{}", "Usage:".cyan().bold());
270    println!("  apr run {}", path.display());
271    if show_serve {
272        println!("  apr serve {}", path.display());
273    }
274}
275
276/// GH-213: Pull a sharded SafeTensors model (3B+ models with multiple shard files)
277fn run_sharded(org: &str, repo: &str, shard_files: &[String], force: bool) -> Result<()> {
278    println!(
279        "Model: {}/{} ({} shards)",
280        org.cyan(),
281        repo.cyan(),
282        shard_files.len().to_string().yellow()
283    );
284
285    let cache_dir = resolve_shard_cache_dir(org, repo)?;
286    std::fs::create_dir_all(&cache_dir)?;
287
288    let base_url = format!("https://huggingface.co/{org}/{repo}/resolve/main");
289    let index_path = cache_dir.join("model.safetensors.index.json");
290
291    download_index_if_needed(&base_url, &index_path, force)?;
292
293    let manifest_path = cache_dir.join(".apr-manifest.json");
294    let existing_manifest = load_existing_manifest(&manifest_path, force);
295
296    let file_checksums = download_all_shards(
297        &cache_dir,
298        &base_url,
299        shard_files,
300        force,
301        existing_manifest.as_ref(),
302    )?;
303
304    download_companion_files(&cache_dir, &base_url, force)?;
305    write_shard_manifest(&manifest_path, org, repo, file_checksums)?;
306
307    println!();
308    println!("{} Downloaded successfully", "✓".green());
309    println!("  Path: {}", index_path.display().to_string().green());
310    println!("  Shards: {}", shard_files.len().to_string().yellow());
311
312    convert_safetensors_formats(&index_path)?;
313
314    println!();
315    println!("{}", "Usage:".cyan().bold());
316    println!("  apr run {}", index_path.display());
317    println!("  apr serve {}", index_path.display());
318    Ok(())
319}
320
321/// Resolve the cache directory for a sharded model.
322fn resolve_shard_cache_dir(org: &str, repo: &str) -> Result<std::path::PathBuf> {
323    Ok(dirs::home_dir()
324        .ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
325        .join(".apr")
326        .join("cache")
327        .join("hf")
328        .join(org)
329        .join(repo))
330}
331
332/// Download the SafeTensors index.json if not already cached.
333fn download_index_if_needed(base_url: &str, index_path: &Path, force: bool) -> Result<()> {
334    if force || !index_path.exists() {
335        println!();
336        println!("  {} model.safetensors.index.json", "Downloading".yellow());
337        download_file(
338            &format!("{base_url}/model.safetensors.index.json"),
339            index_path,
340        )?;
341    } else {
342        println!("  {} model.safetensors.index.json (cached)", "✓".green());
343    }
344    Ok(())
345}
346
347/// Load existing shard manifest for cache-hit verification (GH-213).
348fn load_existing_manifest(manifest_path: &Path, force: bool) -> Option<ShardManifest> {
349    if force || !manifest_path.exists() {
350        return None;
351    }
352    std::fs::read_to_string(manifest_path)
353        .ok()
354        .and_then(|s| serde_json::from_str(&s).ok())
355}
356
357/// Download all shards, collecting checksums for the manifest.
358fn download_all_shards(
359    cache_dir: &Path,
360    base_url: &str,
361    shard_files: &[String],
362    force: bool,
363    existing_manifest: Option<&ShardManifest>,
364) -> Result<HashMap<String, FileChecksum>> {
365    let mut file_checksums: HashMap<String, FileChecksum> = HashMap::new();
366    let total = shard_files.len();
367    for (i, shard_file) in shard_files.iter().enumerate() {
368        download_or_verify_shard(
369            cache_dir,
370            base_url,
371            shard_file,
372            i,
373            total,
374            force,
375            existing_manifest,
376            &mut file_checksums,
377        )?;
378    }
379    Ok(file_checksums)
380}
381
382/// Download or verify a single shard file, updating the checksum map.
383fn download_or_verify_shard(
384    cache_dir: &Path,
385    base_url: &str,
386    shard_file: &str,
387    index: usize,
388    total: usize,
389    force: bool,
390    existing_manifest: Option<&ShardManifest>,
391    checksums: &mut HashMap<String, FileChecksum>,
392) -> Result<()> {
393    let shard_path = cache_dir.join(shard_file);
394
395    if !force && shard_path.exists() {
396        if let Some(manifest) = existing_manifest {
397            if let Some(expected) = manifest.files.get(shard_file) {
398                let actual_size = std::fs::metadata(&shard_path).map(|m| m.len()).unwrap_or(0);
399                if actual_size == expected.size {
400                    checksums.insert(
401                        shard_file.to_string(),
402                        FileChecksum {
403                            size: expected.size,
404                            blake3: expected.blake3.clone(),
405                        },
406                    );
407                    println!(
408                        "  {} [{}/{}] {} (cached, verified)",
409                        "✓".green(),
410                        index + 1,
411                        total,
412                        shard_file
413                    );
414                    return Ok(());
415                }
416                println!(
417                    "  {} [{}/{}] {} (size mismatch: {} vs {} bytes, re-downloading)",
418                    "⚠".yellow(),
419                    index + 1,
420                    total,
421                    shard_file,
422                    actual_size,
423                    expected.size
424                );
425                // Fall through to re-download
426            }
427        } else {
428            println!(
429                "  {} [{}/{}] {} (cached)",
430                "✓".green(),
431                index + 1,
432                total,
433                shard_file
434            );
435            return Ok(());
436        }
437    }
438
439    let shard_url = format!("{base_url}/{shard_file}");
440    print!(
441        "  {} [{}/{}] {}...",
442        "↓".yellow(),
443        index + 1,
444        total,
445        shard_file
446    );
447    io::stdout().flush().ok();
448
449    let checksum = download_file_with_progress(&shard_url, &shard_path)?;
450    checksums.insert(shard_file.to_string(), checksum);
451    println!(" {}", "done".green());
452    Ok(())
453}
454
455/// Download companion files (tokenizer, config) for sharded models.
456///
457/// GH-356: tokenizer.json is optional — some models only have tokenizer.model (SentencePiece)
458/// or tokenizer_config.json. We validate that at least ONE tokenizer file was obtained.
459fn download_companion_files(cache_dir: &Path, base_url: &str, force: bool) -> Result<()> {
460    // (filename, is_required) — tokenizer files are individually optional but collectively required
461    let companions = [
462        ("tokenizer.json", false),
463        ("config.json", true),
464        ("tokenizer_config.json", false),
465        ("tokenizer.model", false),
466    ];
467    for (filename, required) in &companions {
468        let companion_path = cache_dir.join(filename);
469        if !force && companion_path.exists() {
470            println!("  {} {} (cached)", "✓".green(), filename);
471            continue;
472        }
473
474        let url = format!("{base_url}/{filename}");
475        match download_file(&url, &companion_path) {
476            Ok(()) => println!("  {} {}", "✓".green(), filename),
477            Err(CliError::HttpNotFound(_)) if *required => {
478                return Err(CliError::ValidationFailed(format!(
479                    "{filename} is required for inference but was not found (HTTP 404) at {url}"
480                )));
481            }
482            Err(CliError::HttpNotFound(_)) => {
483                println!("  {} {} (not found in repo)", "⚠".yellow(), filename);
484            }
485            Err(e) if *required => {
486                return Err(CliError::ValidationFailed(format!(
487                    "{filename} is required for inference but download failed: {e}"
488                )));
489            }
490            Err(_) => println!("  {} {} (not available, optional)", "⚠".yellow(), filename),
491        }
492    }
493
494    // GH-356: Validate at least one tokenizer file exists
495    let tokenizer_files = ["tokenizer.json", "tokenizer.model", "tokenizer_config.json"];
496    let has_tokenizer = tokenizer_files.iter().any(|f| cache_dir.join(f).exists());
497    if !has_tokenizer {
498        return Err(CliError::ValidationFailed(format!(
499            "No tokenizer found for this model. Tried: {}.\n\
500             The model may require a custom tokenizer not hosted in the repository.",
501            tokenizer_files.join(", ")
502        )));
503    }
504
505    Ok(())
506}
507
508/// Write shard manifest with BLAKE3 checksums for integrity verification.
509fn write_shard_manifest(
510    manifest_path: &Path,
511    org: &str,
512    repo: &str,
513    file_checksums: HashMap<String, FileChecksum>,
514) -> Result<()> {
515    if file_checksums.is_empty() {
516        return Ok(());
517    }
518    let manifest = ShardManifest {
519        version: 1,
520        repo: format!("{org}/{repo}"),
521        files: file_checksums,
522    };
523    let manifest_json = serde_json::to_string_pretty(&manifest)
524        .map_err(|e| CliError::ValidationFailed(format!("Failed to serialize manifest: {e}")))?;
525    std::fs::write(manifest_path, manifest_json)?;
526    println!("  {} .apr-manifest.json (integrity checksums)", "✓".green());
527    Ok(())
528}
529
530include!("pull_list.rs");
531include!("pull_remove_resolve_model.rs");
532include!("pull_extract_shard.rs");
533include!("pull_04.rs");