Skip to main content

apr_cli/commands/
pull.rs

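//! Pull command: download and cache models from HuggingFace.
//!
//! Single-file models are cached in `~/.cache/pacha/models/`; sharded SafeTensors
//! models are cached in `~/.apr/cache/hf/<org>/<repo>/`.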
2
3use crate::error::{CliError, Result};
4use colored::Colorize;
5use pacha::fetcher::{FetchConfig, ModelFetcher};
6use pacha::format::ModelFormat;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet};
9use std::io::{self, Read, Write};
10use std::path::Path;
11
/// Outcome of resolving a HuggingFace model reference into something downloadable.
///
/// * [`ResolvedModel::SingleFile`] — one file (small SafeTensors, GGUF, or a
///   non-`hf://` reference), handled by the single-file pull path.
/// * [`ResolvedModel::Sharded`] — a multi-shard SafeTensors checkpoint (3B+ models),
///   downloaded shard-by-shard into `~/.apr/cache/hf/`.
#[derive(Debug)]
enum ResolvedModel {
    /// A single model reference, pulled as one file.
    SingleFile(String),
    /// A sharded SafeTensors repository: several `.safetensors` shards plus an index.json.
    Sharded {
        /// HuggingFace organisation (or user) owning the repository.
        org: String,
        /// Repository name within `org`.
        repo: String,
        /// Shard file names, relative to the repository root on `main`.
        shard_files: Vec<String>,
    },
}
27
28/// GH-213: Manifest recording checksums for each file in a sharded download.
29///
30/// Written to `.apr-manifest.json` in the cache directory after a successful download.
31/// Used by the pre-inference contract gate to verify shard integrity without re-hashing.
32#[derive(Debug, Serialize, Deserialize)]
33pub struct ShardManifest {
34    pub version: u32,
35    pub repo: String,
36    pub files: HashMap<String, FileChecksum>,
37}
38
39/// GH-213: Size and BLAKE3 hash of a downloaded file.
40#[derive(Debug, Serialize, Deserialize)]
41pub struct FileChecksum {
42    pub size: u64,
43    pub blake3: String,
44}
45
46/// Run the pull command
47pub fn run(model_ref: &str, force: bool) -> Result<()> {
48    contract_pre_pull_cache_integrity!();
49    println!("{}", "=== APR Pull ===".cyan().bold());
50    println!();
51
52    // GH-213: Resolve HuggingFace URI — detect single vs sharded models
53    let resolved = resolve_hf_model(model_ref)?;
54
55    let result = match resolved {
56        ResolvedModel::SingleFile(ref uri) => run_single_file(uri, force),
57        ResolvedModel::Sharded {
58            ref org,
59            ref repo,
60            ref shard_files,
61        } => run_sharded(org, repo, shard_files, force),
62    };
63    if let Ok(ref r) = result {
64        contract_post_pull_cache_integrity!(r);
65    }
66    result
67}
68
69/// Pull a single-file model.
70///
71/// GH-352: For HuggingFace URIs, streams directly to disk instead of buffering
72/// the entire file in memory through pacha's resolver. For non-HF URIs (pacha
73/// aliases), falls back to the pacha fetcher.
74fn run_single_file(model_ref: &str, force: bool) -> Result<()> {
75    println!("Model: {}", model_ref.cyan());
76
77    // GH-352: HuggingFace URIs bypass pacha to avoid O(model_size) memory buffering
78    if model_ref.starts_with("hf://") {
79        return run_single_file_streaming(model_ref, force);
80    }
81
82    let mut fetcher = ModelFetcher::with_config(FetchConfig::default()).map_err(|e| {
83        CliError::ValidationFailed(format!("Failed to initialize model fetcher: {e}"))
84    })?;
85
86    if !force && fetcher.is_cached(model_ref) {
87        return handle_cached_model(&mut fetcher, model_ref);
88    }
89
90    let result = download_single_model(&mut fetcher, model_ref)?;
91    ensure_safetensors_companions(&result)?;
92    print_pull_usage(&result.path, true);
93    Ok(())
94}
95
96/// GH-352: Stream a single HuggingFace file directly to disk.
97///
98/// Uses O(64KB) memory instead of O(model_size). The pacha fetcher's
99/// `resolver.resolve()` buffers the entire response via `response.bytes()`,
100/// which consumed ~4.5 GB for a 7B GGUF. This function streams with a 64KB
101/// chunked read, computes BLAKE3 incrementally, and saves to the pacha cache.
102fn run_single_file_streaming(model_ref: &str, force: bool) -> Result<()> {
103    let path = model_ref.strip_prefix("hf://").unwrap_or(model_ref);
104    let parts: Vec<&str> = path.split('/').collect();
105    if parts.len() < 3 {
106        return Err(CliError::ValidationFailed(format!(
107            "HuggingFace URI must include a filename: {model_ref}"
108        )));
109    }
110
111    let filename = parts[2..].join("/");
112    let url = format!(
113        "https://huggingface.co/{}/{}/resolve/main/{}",
114        parts[0], parts[1], filename
115    );
116
117    // Determine cache path in pacha cache dir
118    let cache_dir = get_pacha_cache_dir()?;
119    std::fs::create_dir_all(&cache_dir)?;
120
121    // Check if already cached (by URI hash)
122    let uri_hash = blake3::hash(model_ref.as_bytes()).to_hex().to_string();
123    let extension = std::path::Path::new(&filename)
124        .extension()
125        .and_then(|e| e.to_str())
126        .unwrap_or("bin");
127    let cache_filename = format!("{}.{}", &uri_hash[..16], extension);
128    let cache_path = cache_dir.join(&cache_filename);
129
130    if !force && cache_path.exists() {
131        let metadata = std::fs::metadata(&cache_path)?;
132        println!("{} Model already cached", "✓".green());
133        println!("  Path: {}", cache_path.display());
134        println!("  Size: {}", format_bytes(metadata.len()));
135        print_pull_usage(&cache_path, true);
136        return Ok(());
137    }
138
139    println!();
140    println!("{}", "Downloading (streaming)...".yellow());
141
142    let checksum = download_file_with_progress(&url, &cache_path)?;
143
144    println!();
145    println!("{} Downloaded successfully", "✓".green());
146    println!("  Path: {}", cache_path.display().to_string().green());
147    println!("  Size: {}", format_bytes(checksum.size).yellow());
148    println!("  Hash: {}", &checksum.blake3[..16]);
149
150    // Handle SafeTensors companions
151    if extension == "safetensors" {
152        fetch_safetensors_companions(&cache_path, model_ref)?;
153        convert_safetensors_formats(&cache_path)?;
154    }
155
156    print_pull_usage(&cache_path, true);
157    Ok(())
158}
159
160/// Get the pacha model cache directory.
161fn get_pacha_cache_dir() -> Result<std::path::PathBuf> {
162    if let Ok(cache_home) = std::env::var("XDG_CACHE_HOME") {
163        return Ok(std::path::PathBuf::from(cache_home)
164            .join("pacha")
165            .join("models"));
166    }
167    Ok(dirs::home_dir()
168        .ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
169        .join(".cache")
170        .join("pacha")
171        .join("models"))
172}
173
174/// Handle a model that is already cached in pacha.
175fn handle_cached_model(fetcher: &mut ModelFetcher, model_ref: &str) -> Result<()> {
176    println!("{} Model already cached", "✓".green());
177    let result = fetcher
178        .pull_quiet(model_ref)
179        .map_err(|e| CliError::ValidationFailed(format!("Failed to get cached model: {e}")))?;
180
181    println!("  Path: {}", result.path.display());
182    println!("  Size: {}", result.size_human());
183    println!("  Format: {}", result.format.name());
184
185    ensure_safetensors_companions(&result)?;
186    print_pull_usage(&result.path, false);
187    Ok(())
188}
189
190/// Download a single model with progress bar.
191fn download_single_model(
192    fetcher: &mut ModelFetcher,
193    model_ref: &str,
194) -> Result<pacha::fetcher::FetchResult> {
195    println!();
196    println!("{}", "Downloading...".yellow());
197
198    let result = fetcher
199        .pull(model_ref, |progress| {
200            let pct = progress.percent();
201            print!(
202                "\r  [{:50}] {:5.1}% ({}/{})",
203                "=".repeat((pct / 2.0) as usize),
204                pct,
205                format_bytes(progress.downloaded_bytes),
206                format_bytes(progress.total_bytes)
207            );
208            io::stdout().flush().ok();
209        })
210        .map_err(|e| CliError::NetworkError(format!("Download failed: {e}")))?;
211
212    println!();
213    println!();
214
215    if result.cache_hit {
216        println!("{} Model retrieved from cache", "✓".green());
217    } else {
218        println!("{} Downloaded successfully", "✓".green());
219    }
220
221    println!("  Path: {}", result.path.display().to_string().green());
222    println!("  Size: {}", result.size_human().yellow());
223    println!("  Format: {}", result.format.name());
224    println!("  Hash: {}", &result.hash[..16]);
225    Ok(result)
226}
227
228/// Ensure companion files exist for SafeTensors models (GH-198, GH-211).
229fn ensure_safetensors_companions(result: &pacha::fetcher::FetchResult) -> Result<()> {
230    if matches!(result.format, ModelFormat::SafeTensors(_)) {
231        fetch_safetensors_companions(&result.path, &result.resolved_uri)?;
232        convert_safetensors_formats(&result.path)?;
233    }
234    Ok(())
235}
236
237/// Print usage instructions after a successful pull.
238fn print_pull_usage(path: &Path, show_serve: bool) {
239    println!();
240    println!("{}", "Usage:".cyan().bold());
241    println!("  apr run {}", path.display());
242    if show_serve {
243        println!("  apr serve {}", path.display());
244    }
245}
246
247/// GH-213: Pull a sharded SafeTensors model (3B+ models with multiple shard files)
248fn run_sharded(org: &str, repo: &str, shard_files: &[String], force: bool) -> Result<()> {
249    println!(
250        "Model: {}/{} ({} shards)",
251        org.cyan(),
252        repo.cyan(),
253        shard_files.len().to_string().yellow()
254    );
255
256    let cache_dir = resolve_shard_cache_dir(org, repo)?;
257    std::fs::create_dir_all(&cache_dir)?;
258
259    let base_url = format!("https://huggingface.co/{org}/{repo}/resolve/main");
260    let index_path = cache_dir.join("model.safetensors.index.json");
261
262    download_index_if_needed(&base_url, &index_path, force)?;
263
264    let manifest_path = cache_dir.join(".apr-manifest.json");
265    let existing_manifest = load_existing_manifest(&manifest_path, force);
266
267    let file_checksums = download_all_shards(
268        &cache_dir,
269        &base_url,
270        shard_files,
271        force,
272        existing_manifest.as_ref(),
273    )?;
274
275    download_companion_files(&cache_dir, &base_url, force)?;
276    write_shard_manifest(&manifest_path, org, repo, file_checksums)?;
277
278    println!();
279    println!("{} Downloaded successfully", "✓".green());
280    println!("  Path: {}", index_path.display().to_string().green());
281    println!("  Shards: {}", shard_files.len().to_string().yellow());
282
283    convert_safetensors_formats(&index_path)?;
284
285    println!();
286    println!("{}", "Usage:".cyan().bold());
287    println!("  apr run {}", index_path.display());
288    println!("  apr serve {}", index_path.display());
289    Ok(())
290}
291
292/// Resolve the cache directory for a sharded model.
293fn resolve_shard_cache_dir(org: &str, repo: &str) -> Result<std::path::PathBuf> {
294    Ok(dirs::home_dir()
295        .ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
296        .join(".apr")
297        .join("cache")
298        .join("hf")
299        .join(org)
300        .join(repo))
301}
302
303/// Download the SafeTensors index.json if not already cached.
304fn download_index_if_needed(base_url: &str, index_path: &Path, force: bool) -> Result<()> {
305    if force || !index_path.exists() {
306        println!();
307        println!("  {} model.safetensors.index.json", "Downloading".yellow());
308        download_file(
309            &format!("{base_url}/model.safetensors.index.json"),
310            index_path,
311        )?;
312    } else {
313        println!("  {} model.safetensors.index.json (cached)", "✓".green());
314    }
315    Ok(())
316}
317
318/// Load existing shard manifest for cache-hit verification (GH-213).
319fn load_existing_manifest(manifest_path: &Path, force: bool) -> Option<ShardManifest> {
320    if force || !manifest_path.exists() {
321        return None;
322    }
323    std::fs::read_to_string(manifest_path)
324        .ok()
325        .and_then(|s| serde_json::from_str(&s).ok())
326}
327
328/// Download all shards, collecting checksums for the manifest.
329fn download_all_shards(
330    cache_dir: &Path,
331    base_url: &str,
332    shard_files: &[String],
333    force: bool,
334    existing_manifest: Option<&ShardManifest>,
335) -> Result<HashMap<String, FileChecksum>> {
336    let mut file_checksums: HashMap<String, FileChecksum> = HashMap::new();
337    let total = shard_files.len();
338    for (i, shard_file) in shard_files.iter().enumerate() {
339        download_or_verify_shard(
340            cache_dir,
341            base_url,
342            shard_file,
343            i,
344            total,
345            force,
346            existing_manifest,
347            &mut file_checksums,
348        )?;
349    }
350    Ok(file_checksums)
351}
352
353/// Download or verify a single shard file, updating the checksum map.
354fn download_or_verify_shard(
355    cache_dir: &Path,
356    base_url: &str,
357    shard_file: &str,
358    index: usize,
359    total: usize,
360    force: bool,
361    existing_manifest: Option<&ShardManifest>,
362    checksums: &mut HashMap<String, FileChecksum>,
363) -> Result<()> {
364    let shard_path = cache_dir.join(shard_file);
365
366    if !force && shard_path.exists() {
367        if let Some(manifest) = existing_manifest {
368            if let Some(expected) = manifest.files.get(shard_file) {
369                let actual_size = std::fs::metadata(&shard_path).map(|m| m.len()).unwrap_or(0);
370                if actual_size == expected.size {
371                    checksums.insert(
372                        shard_file.to_string(),
373                        FileChecksum {
374                            size: expected.size,
375                            blake3: expected.blake3.clone(),
376                        },
377                    );
378                    println!(
379                        "  {} [{}/{}] {} (cached, verified)",
380                        "✓".green(),
381                        index + 1,
382                        total,
383                        shard_file
384                    );
385                    return Ok(());
386                }
387                println!(
388                    "  {} [{}/{}] {} (size mismatch: {} vs {} bytes, re-downloading)",
389                    "⚠".yellow(),
390                    index + 1,
391                    total,
392                    shard_file,
393                    actual_size,
394                    expected.size
395                );
396                // Fall through to re-download
397            }
398        } else {
399            println!(
400                "  {} [{}/{}] {} (cached)",
401                "✓".green(),
402                index + 1,
403                total,
404                shard_file
405            );
406            return Ok(());
407        }
408    }
409
410    let shard_url = format!("{base_url}/{shard_file}");
411    print!(
412        "  {} [{}/{}] {}...",
413        "↓".yellow(),
414        index + 1,
415        total,
416        shard_file
417    );
418    io::stdout().flush().ok();
419
420    let checksum = download_file_with_progress(&shard_url, &shard_path)?;
421    checksums.insert(shard_file.to_string(), checksum);
422    println!(" {}", "done".green());
423    Ok(())
424}
425
426/// Download companion files (tokenizer, config) for sharded models.
427///
428/// GH-356: tokenizer.json is optional — some models only have tokenizer.model (SentencePiece)
429/// or tokenizer_config.json. We validate that at least ONE tokenizer file was obtained.
430fn download_companion_files(cache_dir: &Path, base_url: &str, force: bool) -> Result<()> {
431    // (filename, is_required) — tokenizer files are individually optional but collectively required
432    let companions = [
433        ("tokenizer.json", false),
434        ("config.json", true),
435        ("tokenizer_config.json", false),
436        ("tokenizer.model", false),
437    ];
438    for (filename, required) in &companions {
439        let companion_path = cache_dir.join(filename);
440        if !force && companion_path.exists() {
441            println!("  {} {} (cached)", "✓".green(), filename);
442            continue;
443        }
444
445        let url = format!("{base_url}/{filename}");
446        match download_file(&url, &companion_path) {
447            Ok(()) => println!("  {} {}", "✓".green(), filename),
448            Err(CliError::HttpNotFound(_)) if *required => {
449                return Err(CliError::ValidationFailed(format!(
450                    "{filename} is required for inference but was not found (HTTP 404) at {url}"
451                )));
452            }
453            Err(CliError::HttpNotFound(_)) => {
454                println!("  {} {} (not found in repo)", "⚠".yellow(), filename);
455            }
456            Err(e) if *required => {
457                return Err(CliError::ValidationFailed(format!(
458                    "{filename} is required for inference but download failed: {e}"
459                )));
460            }
461            Err(_) => println!("  {} {} (not available, optional)", "⚠".yellow(), filename),
462        }
463    }
464
465    // GH-356: Validate at least one tokenizer file exists
466    let tokenizer_files = ["tokenizer.json", "tokenizer.model", "tokenizer_config.json"];
467    let has_tokenizer = tokenizer_files.iter().any(|f| cache_dir.join(f).exists());
468    if !has_tokenizer {
469        return Err(CliError::ValidationFailed(format!(
470            "No tokenizer found for this model. Tried: {}.\n\
471             The model may require a custom tokenizer not hosted in the repository.",
472            tokenizer_files.join(", ")
473        )));
474    }
475
476    Ok(())
477}
478
479/// Write shard manifest with BLAKE3 checksums for integrity verification.
480fn write_shard_manifest(
481    manifest_path: &Path,
482    org: &str,
483    repo: &str,
484    file_checksums: HashMap<String, FileChecksum>,
485) -> Result<()> {
486    if file_checksums.is_empty() {
487        return Ok(());
488    }
489    let manifest = ShardManifest {
490        version: 1,
491        repo: format!("{org}/{repo}"),
492        files: file_checksums,
493    };
494    let manifest_json = serde_json::to_string_pretty(&manifest)
495        .map_err(|e| CliError::ValidationFailed(format!("Failed to serialize manifest: {e}")))?;
496    std::fs::write(manifest_path, manifest_json)?;
497    println!("  {} .apr-manifest.json (integrity checksums)", "✓".green());
498    Ok(())
499}
500
501include!("pull_list.rs");
502include!("pull_remove_resolve_model.rs");
503include!("pull_extract_shard.rs");
504include!("pull_04.rs");