Skip to main content

apr_cli/commands/
pull.rs

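//! Pull command: download and cache models from HuggingFace.
//!
//! Single-file models are stored in the pacha cache (`~/.cache/pacha/models/`,
//! or under `$XDG_CACHE_HOME/pacha/models/` when that variable is set);
//! sharded SafeTensors models are stored under `~/.apr/cache/hf/<org>/<repo>/`.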
2
3use crate::error::{CliError, Result};
4use colored::Colorize;
5use pacha::fetcher::{FetchConfig, ModelFetcher};
6use pacha::format::ModelFormat;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet};
9use std::io::{self, Read, Write};
10use std::path::Path;
11
/// Outcome of resolving a HuggingFace model reference.
///
/// A reference resolves either to one downloadable file or to a set of
/// SafeTensors shards. Sharded models are fetched file by file into
/// `~/.apr/cache/hf/<org>/<repo>/`; single files go through the single-file
/// pull path.
#[derive(Debug)]
enum ResolvedModel {
    /// One file; the payload is the model reference/URI handed to the
    /// single-file pull path.
    SingleFile(String),
    /// A sharded SafeTensors model (several `.safetensors` files plus an
    /// `index.json`).
    Sharded {
        /// Organisation part of the `huggingface.co/<org>/<repo>` address.
        org: String,
        /// Repository part of the `huggingface.co/<org>/<repo>` address.
        repo: String,
        /// Shard file names, relative to the repository root.
        shard_files: Vec<String>,
    },
}
27
28/// GH-213: Manifest recording checksums for each file in a sharded download.
29///
30/// Written to `.apr-manifest.json` in the cache directory after a successful download.
31/// Used by the pre-inference contract gate to verify shard integrity without re-hashing.
32#[derive(Debug, Serialize, Deserialize)]
33pub struct ShardManifest {
34    pub version: u32,
35    pub repo: String,
36    pub files: HashMap<String, FileChecksum>,
37}
38
39/// GH-213: Size and BLAKE3 hash of a downloaded file.
40#[derive(Debug, Serialize, Deserialize)]
41pub struct FileChecksum {
42    pub size: u64,
43    pub blake3: String,
44}
45
46/// Run the pull command
47pub fn run(model_ref: &str, force: bool) -> Result<()> {
48    contract_pre_pull_cache_integrity!();
49    println!("{}", "=== APR Pull ===".cyan().bold());
50    println!();
51
52    // GH-213: Resolve HuggingFace URI — detect single vs sharded models
53    let resolved = resolve_hf_model(model_ref)?;
54
55    match resolved {
56        ResolvedModel::SingleFile(ref uri) => run_single_file(uri, force),
57        ResolvedModel::Sharded {
58            ref org,
59            ref repo,
60            ref shard_files,
61        } => run_sharded(org, repo, shard_files, force),
62    }
63}
64
65/// Pull a single-file model.
66///
67/// GH-352: For HuggingFace URIs, streams directly to disk instead of buffering
68/// the entire file in memory through pacha's resolver. For non-HF URIs (pacha
69/// aliases), falls back to the pacha fetcher.
70fn run_single_file(model_ref: &str, force: bool) -> Result<()> {
71    println!("Model: {}", model_ref.cyan());
72
73    // GH-352: HuggingFace URIs bypass pacha to avoid O(model_size) memory buffering
74    if model_ref.starts_with("hf://") {
75        return run_single_file_streaming(model_ref, force);
76    }
77
78    let mut fetcher = ModelFetcher::with_config(FetchConfig::default()).map_err(|e| {
79        CliError::ValidationFailed(format!("Failed to initialize model fetcher: {e}"))
80    })?;
81
82    if !force && fetcher.is_cached(model_ref) {
83        return handle_cached_model(&mut fetcher, model_ref);
84    }
85
86    let result = download_single_model(&mut fetcher, model_ref)?;
87    ensure_safetensors_companions(&result)?;
88    print_pull_usage(&result.path, true);
89    Ok(())
90}
91
92/// GH-352: Stream a single HuggingFace file directly to disk.
93///
94/// Uses O(64KB) memory instead of O(model_size). The pacha fetcher's
95/// `resolver.resolve()` buffers the entire response via `response.bytes()`,
96/// which consumed ~4.5 GB for a 7B GGUF. This function streams with a 64KB
97/// chunked read, computes BLAKE3 incrementally, and saves to the pacha cache.
98fn run_single_file_streaming(model_ref: &str, force: bool) -> Result<()> {
99    let path = model_ref.strip_prefix("hf://").unwrap_or(model_ref);
100    let parts: Vec<&str> = path.split('/').collect();
101    if parts.len() < 3 {
102        return Err(CliError::ValidationFailed(format!(
103            "HuggingFace URI must include a filename: {model_ref}"
104        )));
105    }
106
107    let filename = parts[2..].join("/");
108    let url = format!(
109        "https://huggingface.co/{}/{}/resolve/main/{}",
110        parts[0], parts[1], filename
111    );
112
113    // Determine cache path in pacha cache dir
114    let cache_dir = get_pacha_cache_dir()?;
115    std::fs::create_dir_all(&cache_dir)?;
116
117    // Check if already cached (by URI hash)
118    let uri_hash = blake3::hash(model_ref.as_bytes()).to_hex().to_string();
119    let extension = std::path::Path::new(&filename)
120        .extension()
121        .and_then(|e| e.to_str())
122        .unwrap_or("bin");
123    let cache_filename = format!("{}.{}", &uri_hash[..16], extension);
124    let cache_path = cache_dir.join(&cache_filename);
125
126    if !force && cache_path.exists() {
127        let metadata = std::fs::metadata(&cache_path)?;
128        println!("{} Model already cached", "✓".green());
129        println!("  Path: {}", cache_path.display());
130        println!("  Size: {}", format_bytes(metadata.len()));
131        print_pull_usage(&cache_path, true);
132        return Ok(());
133    }
134
135    println!();
136    println!("{}", "Downloading (streaming)...".yellow());
137
138    let checksum = download_file_with_progress(&url, &cache_path)?;
139
140    println!();
141    println!("{} Downloaded successfully", "✓".green());
142    println!("  Path: {}", cache_path.display().to_string().green());
143    println!("  Size: {}", format_bytes(checksum.size).yellow());
144    println!("  Hash: {}", &checksum.blake3[..16]);
145
146    // Handle SafeTensors companions
147    if extension == "safetensors" {
148        fetch_safetensors_companions(&cache_path, model_ref)?;
149        convert_safetensors_formats(&cache_path)?;
150    }
151
152    print_pull_usage(&cache_path, true);
153    Ok(())
154}
155
156/// Get the pacha model cache directory.
157fn get_pacha_cache_dir() -> Result<std::path::PathBuf> {
158    if let Ok(cache_home) = std::env::var("XDG_CACHE_HOME") {
159        return Ok(std::path::PathBuf::from(cache_home)
160            .join("pacha")
161            .join("models"));
162    }
163    Ok(dirs::home_dir()
164        .ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
165        .join(".cache")
166        .join("pacha")
167        .join("models"))
168}
169
170/// Handle a model that is already cached in pacha.
171fn handle_cached_model(fetcher: &mut ModelFetcher, model_ref: &str) -> Result<()> {
172    println!("{} Model already cached", "✓".green());
173    let result = fetcher
174        .pull_quiet(model_ref)
175        .map_err(|e| CliError::ValidationFailed(format!("Failed to get cached model: {e}")))?;
176
177    println!("  Path: {}", result.path.display());
178    println!("  Size: {}", result.size_human());
179    println!("  Format: {}", result.format.name());
180
181    ensure_safetensors_companions(&result)?;
182    print_pull_usage(&result.path, false);
183    Ok(())
184}
185
186/// Download a single model with progress bar.
187fn download_single_model(
188    fetcher: &mut ModelFetcher,
189    model_ref: &str,
190) -> Result<pacha::fetcher::FetchResult> {
191    println!();
192    println!("{}", "Downloading...".yellow());
193
194    let result = fetcher
195        .pull(model_ref, |progress| {
196            let pct = progress.percent();
197            print!(
198                "\r  [{:50}] {:5.1}% ({}/{})",
199                "=".repeat((pct / 2.0) as usize),
200                pct,
201                format_bytes(progress.downloaded_bytes),
202                format_bytes(progress.total_bytes)
203            );
204            io::stdout().flush().ok();
205        })
206        .map_err(|e| CliError::NetworkError(format!("Download failed: {e}")))?;
207
208    println!();
209    println!();
210
211    if result.cache_hit {
212        println!("{} Model retrieved from cache", "✓".green());
213    } else {
214        println!("{} Downloaded successfully", "✓".green());
215    }
216
217    println!("  Path: {}", result.path.display().to_string().green());
218    println!("  Size: {}", result.size_human().yellow());
219    println!("  Format: {}", result.format.name());
220    println!("  Hash: {}", &result.hash[..16]);
221    Ok(result)
222}
223
224/// Ensure companion files exist for SafeTensors models (GH-198, GH-211).
225fn ensure_safetensors_companions(result: &pacha::fetcher::FetchResult) -> Result<()> {
226    if matches!(result.format, ModelFormat::SafeTensors(_)) {
227        fetch_safetensors_companions(&result.path, &result.resolved_uri)?;
228        convert_safetensors_formats(&result.path)?;
229    }
230    Ok(())
231}
232
233/// Print usage instructions after a successful pull.
234fn print_pull_usage(path: &Path, show_serve: bool) {
235    println!();
236    println!("{}", "Usage:".cyan().bold());
237    println!("  apr run {}", path.display());
238    if show_serve {
239        println!("  apr serve {}", path.display());
240    }
241}
242
243/// GH-213: Pull a sharded SafeTensors model (3B+ models with multiple shard files)
244fn run_sharded(org: &str, repo: &str, shard_files: &[String], force: bool) -> Result<()> {
245    println!(
246        "Model: {}/{} ({} shards)",
247        org.cyan(),
248        repo.cyan(),
249        shard_files.len().to_string().yellow()
250    );
251
252    let cache_dir = resolve_shard_cache_dir(org, repo)?;
253    std::fs::create_dir_all(&cache_dir)?;
254
255    let base_url = format!("https://huggingface.co/{org}/{repo}/resolve/main");
256    let index_path = cache_dir.join("model.safetensors.index.json");
257
258    download_index_if_needed(&base_url, &index_path, force)?;
259
260    let manifest_path = cache_dir.join(".apr-manifest.json");
261    let existing_manifest = load_existing_manifest(&manifest_path, force);
262
263    let file_checksums = download_all_shards(
264        &cache_dir,
265        &base_url,
266        shard_files,
267        force,
268        existing_manifest.as_ref(),
269    )?;
270
271    download_companion_files(&cache_dir, &base_url, force)?;
272    write_shard_manifest(&manifest_path, org, repo, file_checksums)?;
273
274    println!();
275    println!("{} Downloaded successfully", "✓".green());
276    println!("  Path: {}", index_path.display().to_string().green());
277    println!("  Shards: {}", shard_files.len().to_string().yellow());
278
279    convert_safetensors_formats(&index_path)?;
280
281    println!();
282    println!("{}", "Usage:".cyan().bold());
283    println!("  apr run {}", index_path.display());
284    println!("  apr serve {}", index_path.display());
285    Ok(())
286}
287
288/// Resolve the cache directory for a sharded model.
289fn resolve_shard_cache_dir(org: &str, repo: &str) -> Result<std::path::PathBuf> {
290    Ok(dirs::home_dir()
291        .ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
292        .join(".apr")
293        .join("cache")
294        .join("hf")
295        .join(org)
296        .join(repo))
297}
298
299/// Download the SafeTensors index.json if not already cached.
300fn download_index_if_needed(base_url: &str, index_path: &Path, force: bool) -> Result<()> {
301    if force || !index_path.exists() {
302        println!();
303        println!("  {} model.safetensors.index.json", "Downloading".yellow());
304        download_file(
305            &format!("{base_url}/model.safetensors.index.json"),
306            index_path,
307        )?;
308    } else {
309        println!("  {} model.safetensors.index.json (cached)", "✓".green());
310    }
311    Ok(())
312}
313
314/// Load existing shard manifest for cache-hit verification (GH-213).
315fn load_existing_manifest(manifest_path: &Path, force: bool) -> Option<ShardManifest> {
316    if force || !manifest_path.exists() {
317        return None;
318    }
319    std::fs::read_to_string(manifest_path)
320        .ok()
321        .and_then(|s| serde_json::from_str(&s).ok())
322}
323
324/// Download all shards, collecting checksums for the manifest.
325fn download_all_shards(
326    cache_dir: &Path,
327    base_url: &str,
328    shard_files: &[String],
329    force: bool,
330    existing_manifest: Option<&ShardManifest>,
331) -> Result<HashMap<String, FileChecksum>> {
332    let mut file_checksums: HashMap<String, FileChecksum> = HashMap::new();
333    let total = shard_files.len();
334    for (i, shard_file) in shard_files.iter().enumerate() {
335        download_or_verify_shard(
336            cache_dir,
337            base_url,
338            shard_file,
339            i,
340            total,
341            force,
342            existing_manifest,
343            &mut file_checksums,
344        )?;
345    }
346    Ok(file_checksums)
347}
348
349/// Download or verify a single shard file, updating the checksum map.
350fn download_or_verify_shard(
351    cache_dir: &Path,
352    base_url: &str,
353    shard_file: &str,
354    index: usize,
355    total: usize,
356    force: bool,
357    existing_manifest: Option<&ShardManifest>,
358    checksums: &mut HashMap<String, FileChecksum>,
359) -> Result<()> {
360    let shard_path = cache_dir.join(shard_file);
361
362    if !force && shard_path.exists() {
363        if let Some(manifest) = existing_manifest {
364            if let Some(expected) = manifest.files.get(shard_file) {
365                let actual_size = std::fs::metadata(&shard_path).map(|m| m.len()).unwrap_or(0);
366                if actual_size == expected.size {
367                    checksums.insert(
368                        shard_file.to_string(),
369                        FileChecksum {
370                            size: expected.size,
371                            blake3: expected.blake3.clone(),
372                        },
373                    );
374                    println!(
375                        "  {} [{}/{}] {} (cached, verified)",
376                        "✓".green(),
377                        index + 1,
378                        total,
379                        shard_file
380                    );
381                    return Ok(());
382                }
383                println!(
384                    "  {} [{}/{}] {} (size mismatch: {} vs {} bytes, re-downloading)",
385                    "⚠".yellow(),
386                    index + 1,
387                    total,
388                    shard_file,
389                    actual_size,
390                    expected.size
391                );
392                // Fall through to re-download
393            }
394        } else {
395            println!(
396                "  {} [{}/{}] {} (cached)",
397                "✓".green(),
398                index + 1,
399                total,
400                shard_file
401            );
402            return Ok(());
403        }
404    }
405
406    let shard_url = format!("{base_url}/{shard_file}");
407    print!(
408        "  {} [{}/{}] {}...",
409        "↓".yellow(),
410        index + 1,
411        total,
412        shard_file
413    );
414    io::stdout().flush().ok();
415
416    let checksum = download_file_with_progress(&shard_url, &shard_path)?;
417    checksums.insert(shard_file.to_string(), checksum);
418    println!(" {}", "done".green());
419    Ok(())
420}
421
422/// Download companion files (tokenizer, config) for sharded models.
423///
424/// GH-356: tokenizer.json is optional — some models only have tokenizer.model (SentencePiece)
425/// or tokenizer_config.json. We validate that at least ONE tokenizer file was obtained.
426fn download_companion_files(cache_dir: &Path, base_url: &str, force: bool) -> Result<()> {
427    // (filename, is_required) — tokenizer files are individually optional but collectively required
428    let companions = [
429        ("tokenizer.json", false),
430        ("config.json", true),
431        ("tokenizer_config.json", false),
432        ("tokenizer.model", false),
433    ];
434    for (filename, required) in &companions {
435        let companion_path = cache_dir.join(filename);
436        if !force && companion_path.exists() {
437            println!("  {} {} (cached)", "✓".green(), filename);
438            continue;
439        }
440
441        let url = format!("{base_url}/{filename}");
442        match download_file(&url, &companion_path) {
443            Ok(()) => println!("  {} {}", "✓".green(), filename),
444            Err(CliError::HttpNotFound(_)) if *required => {
445                return Err(CliError::ValidationFailed(format!(
446                    "{filename} is required for inference but was not found (HTTP 404) at {url}"
447                )));
448            }
449            Err(CliError::HttpNotFound(_)) => {
450                println!("  {} {} (not found in repo)", "⚠".yellow(), filename);
451            }
452            Err(e) if *required => {
453                return Err(CliError::ValidationFailed(format!(
454                    "{filename} is required for inference but download failed: {e}"
455                )));
456            }
457            Err(_) => println!("  {} {} (not available, optional)", "⚠".yellow(), filename),
458        }
459    }
460
461    // GH-356: Validate at least one tokenizer file exists
462    let tokenizer_files = ["tokenizer.json", "tokenizer.model", "tokenizer_config.json"];
463    let has_tokenizer = tokenizer_files.iter().any(|f| cache_dir.join(f).exists());
464    if !has_tokenizer {
465        return Err(CliError::ValidationFailed(format!(
466            "No tokenizer found for this model. Tried: {}.\n\
467             The model may require a custom tokenizer not hosted in the repository.",
468            tokenizer_files.join(", ")
469        )));
470    }
471
472    Ok(())
473}
474
475/// Write shard manifest with BLAKE3 checksums for integrity verification.
476fn write_shard_manifest(
477    manifest_path: &Path,
478    org: &str,
479    repo: &str,
480    file_checksums: HashMap<String, FileChecksum>,
481) -> Result<()> {
482    if file_checksums.is_empty() {
483        return Ok(());
484    }
485    let manifest = ShardManifest {
486        version: 1,
487        repo: format!("{org}/{repo}"),
488        files: file_checksums,
489    };
490    let manifest_json = serde_json::to_string_pretty(&manifest)
491        .map_err(|e| CliError::ValidationFailed(format!("Failed to serialize manifest: {e}")))?;
492    std::fs::write(manifest_path, manifest_json)?;
493    println!("  {} .apr-manifest.json (integrity checksums)", "✓".green());
494    Ok(())
495}
496
497include!("pull_list.rs");
498include!("pull_remove_resolve_model.rs");
499include!("pull_extract_shard.rs");
500include!("pull_04.rs");