Skip to main content

apr_cli/commands/
pull.rs

1//! Pull command implementation using pacha ModelFetcher
2//!
3//! Downloads and caches models from HuggingFace with Ollama-like UX.
4//! Models are cached at `~/.cache/pacha/models/` for reuse.
5//!
6//! Usage:
7//!   apr pull qwen2.5-coder:7b           # Short alias
8//!   apr pull hf://bartowski/Qwen2.5-Coder-7B-Instruct-GGUF/qwen2.5-coder-7b-instruct-q4_k_m.gguf
9//!   apr pull TheBloke/Llama-2-7B-GGUF   # Will auto-detect quant
10
11use crate::error::{CliError, Result};
12use colored::Colorize;
13use pacha::fetcher::{FetchConfig, ModelFetcher};
14use pacha::format::ModelFormat;
15use serde::{Deserialize, Serialize};
16use std::collections::{HashMap, HashSet};
17use std::io::{self, Read, Write};
18use std::path::Path;
19
/// Outcome of resolving a HuggingFace model reference.
///
/// `run` dispatches on this value: a `SingleFile` is handed to
/// `run_single_file`, while a `Sharded` model is handed to `run_sharded`,
/// which stores its files under `~/.apr/cache/hf/<org>/<repo>`.
#[derive(Debug)]
enum ResolvedModel {
    /// One downloadable file, identified by its model reference / URI.
    SingleFile(String),
    /// SafeTensors model split into several `.safetensors` shards plus an
    /// index.json describing them.
    Sharded {
        /// Owner (organisation or user) segment of the repository.
        org: String,
        /// Repository name under `org`.
        repo: String,
        /// Shard file names, fetched relative to the repo's `resolve/main` URL.
        shard_files: Vec<String>,
    },
}
35
36/// GH-213: Manifest recording checksums for each file in a sharded download.
37///
38/// Written to `.apr-manifest.json` in the cache directory after a successful download.
39/// Used by the pre-inference contract gate to verify shard integrity without re-hashing.
40#[derive(Debug, Serialize, Deserialize)]
41pub struct ShardManifest {
42    pub version: u32,
43    pub repo: String,
44    pub files: HashMap<String, FileChecksum>,
45}
46
47/// GH-213: Size and BLAKE3 hash of a downloaded file.
48#[derive(Debug, Serialize, Deserialize)]
49pub struct FileChecksum {
50    pub size: u64,
51    pub blake3: String,
52}
53
54/// Run the pull command
55pub fn run(model_ref: &str, force: bool) -> Result<()> {
56    println!("{}", "=== APR Pull ===".cyan().bold());
57    println!();
58
59    // GH-213: Resolve HuggingFace URI — detect single vs sharded models
60    let resolved = resolve_hf_model(model_ref)?;
61
62    match resolved {
63        ResolvedModel::SingleFile(ref uri) => run_single_file(uri, force),
64        ResolvedModel::Sharded {
65            ref org,
66            ref repo,
67            ref shard_files,
68        } => run_sharded(org, repo, shard_files, force),
69    }
70}
71
72/// Pull a single-file model.
73///
74/// GH-352: For HuggingFace URIs, streams directly to disk instead of buffering
75/// the entire file in memory through pacha's resolver. For non-HF URIs (pacha
76/// aliases), falls back to the pacha fetcher.
77fn run_single_file(model_ref: &str, force: bool) -> Result<()> {
78    println!("Model: {}", model_ref.cyan());
79
80    // GH-352: HuggingFace URIs bypass pacha to avoid O(model_size) memory buffering
81    if model_ref.starts_with("hf://") {
82        return run_single_file_streaming(model_ref, force);
83    }
84
85    let mut fetcher = ModelFetcher::with_config(FetchConfig::default()).map_err(|e| {
86        CliError::ValidationFailed(format!("Failed to initialize model fetcher: {e}"))
87    })?;
88
89    if !force && fetcher.is_cached(model_ref) {
90        return handle_cached_model(&mut fetcher, model_ref);
91    }
92
93    let result = download_single_model(&mut fetcher, model_ref)?;
94    ensure_safetensors_companions(&result)?;
95    print_pull_usage(&result.path, true);
96    Ok(())
97}
98
99/// GH-352: Stream a single HuggingFace file directly to disk.
100///
101/// Uses O(64KB) memory instead of O(model_size). The pacha fetcher's
102/// `resolver.resolve()` buffers the entire response via `response.bytes()`,
103/// which consumed ~4.5 GB for a 7B GGUF. This function streams with a 64KB
104/// chunked read, computes BLAKE3 incrementally, and saves to the pacha cache.
105fn run_single_file_streaming(model_ref: &str, force: bool) -> Result<()> {
106    let path = model_ref.strip_prefix("hf://").unwrap_or(model_ref);
107    let parts: Vec<&str> = path.split('/').collect();
108    if parts.len() < 3 {
109        return Err(CliError::ValidationFailed(format!(
110            "HuggingFace URI must include a filename: {model_ref}"
111        )));
112    }
113
114    let filename = parts[2..].join("/");
115    let url = format!(
116        "https://huggingface.co/{}/{}/resolve/main/{}",
117        parts[0], parts[1], filename
118    );
119
120    // Determine cache path in pacha cache dir
121    let cache_dir = get_pacha_cache_dir()?;
122    std::fs::create_dir_all(&cache_dir)?;
123
124    // Check if already cached (by URI hash)
125    let uri_hash = blake3::hash(model_ref.as_bytes()).to_hex().to_string();
126    let extension = std::path::Path::new(&filename)
127        .extension()
128        .and_then(|e| e.to_str())
129        .unwrap_or("bin");
130    let cache_filename = format!("{}.{}", &uri_hash[..16], extension);
131    let cache_path = cache_dir.join(&cache_filename);
132
133    if !force && cache_path.exists() {
134        let metadata = std::fs::metadata(&cache_path)?;
135        println!("{} Model already cached", "✓".green());
136        println!("  Path: {}", cache_path.display());
137        println!("  Size: {}", format_bytes(metadata.len()));
138        print_pull_usage(&cache_path, true);
139        return Ok(());
140    }
141
142    println!();
143    println!("{}", "Downloading (streaming)...".yellow());
144
145    let checksum = download_file_with_progress(&url, &cache_path)?;
146
147    println!();
148    println!("{} Downloaded successfully", "✓".green());
149    println!("  Path: {}", cache_path.display().to_string().green());
150    println!("  Size: {}", format_bytes(checksum.size).yellow());
151    println!("  Hash: {}", &checksum.blake3[..16]);
152
153    // Handle SafeTensors companions
154    if extension == "safetensors" {
155        fetch_safetensors_companions(&cache_path, model_ref)?;
156        convert_safetensors_formats(&cache_path)?;
157    }
158
159    print_pull_usage(&cache_path, true);
160    Ok(())
161}
162
163/// Get the pacha model cache directory.
164fn get_pacha_cache_dir() -> Result<std::path::PathBuf> {
165    if let Ok(cache_home) = std::env::var("XDG_CACHE_HOME") {
166        return Ok(std::path::PathBuf::from(cache_home)
167            .join("pacha")
168            .join("models"));
169    }
170    Ok(dirs::home_dir()
171        .ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
172        .join(".cache")
173        .join("pacha")
174        .join("models"))
175}
176
177/// Handle a model that is already cached in pacha.
178fn handle_cached_model(fetcher: &mut ModelFetcher, model_ref: &str) -> Result<()> {
179    println!("{} Model already cached", "✓".green());
180    let result = fetcher
181        .pull_quiet(model_ref)
182        .map_err(|e| CliError::ValidationFailed(format!("Failed to get cached model: {e}")))?;
183
184    println!("  Path: {}", result.path.display());
185    println!("  Size: {}", result.size_human());
186    println!("  Format: {}", result.format.name());
187
188    ensure_safetensors_companions(&result)?;
189    print_pull_usage(&result.path, false);
190    Ok(())
191}
192
193/// Download a single model with progress bar.
194fn download_single_model(
195    fetcher: &mut ModelFetcher,
196    model_ref: &str,
197) -> Result<pacha::fetcher::FetchResult> {
198    println!();
199    println!("{}", "Downloading...".yellow());
200
201    let result = fetcher
202        .pull(model_ref, |progress| {
203            let pct = progress.percent();
204            print!(
205                "\r  [{:50}] {:5.1}% ({}/{})",
206                "=".repeat((pct / 2.0) as usize),
207                pct,
208                format_bytes(progress.downloaded_bytes),
209                format_bytes(progress.total_bytes)
210            );
211            io::stdout().flush().ok();
212        })
213        .map_err(|e| CliError::NetworkError(format!("Download failed: {e}")))?;
214
215    println!();
216    println!();
217
218    if result.cache_hit {
219        println!("{} Model retrieved from cache", "✓".green());
220    } else {
221        println!("{} Downloaded successfully", "✓".green());
222    }
223
224    println!("  Path: {}", result.path.display().to_string().green());
225    println!("  Size: {}", result.size_human().yellow());
226    println!("  Format: {}", result.format.name());
227    println!("  Hash: {}", &result.hash[..16]);
228    Ok(result)
229}
230
231/// Ensure companion files exist for SafeTensors models (GH-198, GH-211).
232fn ensure_safetensors_companions(result: &pacha::fetcher::FetchResult) -> Result<()> {
233    if matches!(result.format, ModelFormat::SafeTensors(_)) {
234        fetch_safetensors_companions(&result.path, &result.resolved_uri)?;
235        convert_safetensors_formats(&result.path)?;
236    }
237    Ok(())
238}
239
240/// Print usage instructions after a successful pull.
241fn print_pull_usage(path: &Path, show_serve: bool) {
242    println!();
243    println!("{}", "Usage:".cyan().bold());
244    println!("  apr run {}", path.display());
245    if show_serve {
246        println!("  apr serve {}", path.display());
247    }
248}
249
250/// GH-213: Pull a sharded SafeTensors model (3B+ models with multiple shard files)
251fn run_sharded(org: &str, repo: &str, shard_files: &[String], force: bool) -> Result<()> {
252    println!(
253        "Model: {}/{} ({} shards)",
254        org.cyan(),
255        repo.cyan(),
256        shard_files.len().to_string().yellow()
257    );
258
259    let cache_dir = resolve_shard_cache_dir(org, repo)?;
260    std::fs::create_dir_all(&cache_dir)?;
261
262    let base_url = format!("https://huggingface.co/{org}/{repo}/resolve/main");
263    let index_path = cache_dir.join("model.safetensors.index.json");
264
265    download_index_if_needed(&base_url, &index_path, force)?;
266
267    let manifest_path = cache_dir.join(".apr-manifest.json");
268    let existing_manifest = load_existing_manifest(&manifest_path, force);
269
270    let file_checksums = download_all_shards(
271        &cache_dir,
272        &base_url,
273        shard_files,
274        force,
275        existing_manifest.as_ref(),
276    )?;
277
278    download_companion_files(&cache_dir, &base_url, force)?;
279    write_shard_manifest(&manifest_path, org, repo, file_checksums)?;
280
281    println!();
282    println!("{} Downloaded successfully", "✓".green());
283    println!("  Path: {}", index_path.display().to_string().green());
284    println!("  Shards: {}", shard_files.len().to_string().yellow());
285
286    convert_safetensors_formats(&index_path)?;
287
288    println!();
289    println!("{}", "Usage:".cyan().bold());
290    println!("  apr run {}", index_path.display());
291    println!("  apr serve {}", index_path.display());
292    Ok(())
293}
294
295/// Resolve the cache directory for a sharded model.
296fn resolve_shard_cache_dir(org: &str, repo: &str) -> Result<std::path::PathBuf> {
297    Ok(dirs::home_dir()
298        .ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
299        .join(".apr")
300        .join("cache")
301        .join("hf")
302        .join(org)
303        .join(repo))
304}
305
306/// Download the SafeTensors index.json if not already cached.
307fn download_index_if_needed(base_url: &str, index_path: &Path, force: bool) -> Result<()> {
308    if force || !index_path.exists() {
309        println!();
310        println!("  {} model.safetensors.index.json", "Downloading".yellow());
311        download_file(
312            &format!("{base_url}/model.safetensors.index.json"),
313            index_path,
314        )?;
315    } else {
316        println!("  {} model.safetensors.index.json (cached)", "✓".green());
317    }
318    Ok(())
319}
320
321/// Load existing shard manifest for cache-hit verification (GH-213).
322fn load_existing_manifest(manifest_path: &Path, force: bool) -> Option<ShardManifest> {
323    if force || !manifest_path.exists() {
324        return None;
325    }
326    std::fs::read_to_string(manifest_path)
327        .ok()
328        .and_then(|s| serde_json::from_str(&s).ok())
329}
330
331/// Download all shards, collecting checksums for the manifest.
332fn download_all_shards(
333    cache_dir: &Path,
334    base_url: &str,
335    shard_files: &[String],
336    force: bool,
337    existing_manifest: Option<&ShardManifest>,
338) -> Result<HashMap<String, FileChecksum>> {
339    let mut file_checksums: HashMap<String, FileChecksum> = HashMap::new();
340    let total = shard_files.len();
341    for (i, shard_file) in shard_files.iter().enumerate() {
342        download_or_verify_shard(
343            cache_dir,
344            base_url,
345            shard_file,
346            i,
347            total,
348            force,
349            existing_manifest,
350            &mut file_checksums,
351        )?;
352    }
353    Ok(file_checksums)
354}
355
356/// Download or verify a single shard file, updating the checksum map.
357fn download_or_verify_shard(
358    cache_dir: &Path,
359    base_url: &str,
360    shard_file: &str,
361    index: usize,
362    total: usize,
363    force: bool,
364    existing_manifest: Option<&ShardManifest>,
365    checksums: &mut HashMap<String, FileChecksum>,
366) -> Result<()> {
367    let shard_path = cache_dir.join(shard_file);
368
369    if !force && shard_path.exists() {
370        if let Some(manifest) = existing_manifest {
371            if let Some(expected) = manifest.files.get(shard_file) {
372                let actual_size = std::fs::metadata(&shard_path).map(|m| m.len()).unwrap_or(0);
373                if actual_size == expected.size {
374                    checksums.insert(
375                        shard_file.to_string(),
376                        FileChecksum {
377                            size: expected.size,
378                            blake3: expected.blake3.clone(),
379                        },
380                    );
381                    println!(
382                        "  {} [{}/{}] {} (cached, verified)",
383                        "✓".green(),
384                        index + 1,
385                        total,
386                        shard_file
387                    );
388                    return Ok(());
389                }
390                println!(
391                    "  {} [{}/{}] {} (size mismatch: {} vs {} bytes, re-downloading)",
392                    "⚠".yellow(),
393                    index + 1,
394                    total,
395                    shard_file,
396                    actual_size,
397                    expected.size
398                );
399                // Fall through to re-download
400            }
401        } else {
402            println!(
403                "  {} [{}/{}] {} (cached)",
404                "✓".green(),
405                index + 1,
406                total,
407                shard_file
408            );
409            return Ok(());
410        }
411    }
412
413    let shard_url = format!("{base_url}/{shard_file}");
414    print!(
415        "  {} [{}/{}] {}...",
416        "↓".yellow(),
417        index + 1,
418        total,
419        shard_file
420    );
421    io::stdout().flush().ok();
422
423    let checksum = download_file_with_progress(&shard_url, &shard_path)?;
424    checksums.insert(shard_file.to_string(), checksum);
425    println!(" {}", "done".green());
426    Ok(())
427}
428
429/// Download companion files (tokenizer.json, config.json, tokenizer_config.json) for sharded models.
430fn download_companion_files(cache_dir: &Path, base_url: &str, force: bool) -> Result<()> {
431    let companions = [
432        ("tokenizer.json", true),
433        ("config.json", true),
434        ("tokenizer_config.json", false),
435    ];
436    for (filename, required) in &companions {
437        let companion_path = cache_dir.join(filename);
438        if !force && companion_path.exists() {
439            println!("  {} {} (cached)", "✓".green(), filename);
440            continue;
441        }
442
443        let url = format!("{base_url}/{filename}");
444        match download_file(&url, &companion_path) {
445            Ok(()) => println!("  {} {}", "✓".green(), filename),
446            Err(e) if *required => {
447                return Err(CliError::ValidationFailed(format!(
448                    "{filename} is required for inference but download failed: {e}"
449                )));
450            }
451            Err(_) => println!("  {} {} (not available, optional)", "⚠".yellow(), filename),
452        }
453    }
454    Ok(())
455}
456
457/// Write shard manifest with BLAKE3 checksums for integrity verification.
458fn write_shard_manifest(
459    manifest_path: &Path,
460    org: &str,
461    repo: &str,
462    file_checksums: HashMap<String, FileChecksum>,
463) -> Result<()> {
464    if file_checksums.is_empty() {
465        return Ok(());
466    }
467    let manifest = ShardManifest {
468        version: 1,
469        repo: format!("{org}/{repo}"),
470        files: file_checksums,
471    };
472    let manifest_json = serde_json::to_string_pretty(&manifest)
473        .map_err(|e| CliError::ValidationFailed(format!("Failed to serialize manifest: {e}")))?;
474    std::fs::write(manifest_path, manifest_json)?;
475    println!("  {} .apr-manifest.json (integrity checksums)", "✓".green());
476    Ok(())
477}
478
479/// List cached models
480// serde_json::json!() macro uses infallible unwrap internally
481#[allow(clippy::disallowed_methods)]
482pub fn list(json: bool) -> Result<()> {
483    let fetcher = ModelFetcher::new().map_err(|e| {
484        CliError::ValidationFailed(format!("Failed to initialize model fetcher: {e}"))
485    })?;
486
487    let models = fetcher.list();
488
489    // GH-248: JSON output mode
490    if json {
491        let models_json: Vec<serde_json::Value> = models
492            .iter()
493            .map(|m| {
494                serde_json::json!({
495                    "name": m.name,
496                    "size_bytes": m.size_bytes,
497                    "format": m.format.name(),
498                    "path": m.path.display().to_string(),
499                })
500            })
501            .collect();
502        let stats = fetcher.stats();
503        let output = serde_json::json!({
504            "models": models_json,
505            "total": models.len(),
506            "total_size_bytes": stats.total_size_bytes,
507        });
508        println!(
509            "{}",
510            serde_json::to_string_pretty(&output).unwrap_or_default()
511        );
512        return Ok(());
513    }
514
515    println!("{}", "=== Cached Models ===".cyan().bold());
516    println!();
517
518    if models.is_empty() {
519        println!("{}", "No cached models found.".dimmed());
520        println!();
521        println!("Pull a model with:");
522        println!("  apr pull hf://Qwen/Qwen2.5-Coder-1.5B-Instruct-GGUF/qwen2.5-coder-1.5b-instruct-q4_k_m.gguf");
523        println!();
524        println!("Or run directly (auto-downloads):");
525        println!("  apr run hf://Qwen/Qwen2.5-Coder-1.5B-Instruct-GGUF/qwen2.5-coder-1.5b-instruct-q4_k_m.gguf");
526        return Ok(());
527    }
528
529    // Print header
530    println!(
531        "{:<40} {:<12} {:<12} {}",
532        "NAME".dimmed(),
533        "SIZE".dimmed(),
534        "FORMAT".dimmed(),
535        "PATH".dimmed()
536    );
537    println!("{}", "-".repeat(104).dimmed());
538
539    for model in &models {
540        let size = format_bytes(model.size_bytes);
541        let format = model.format.name();
542        let name = if model.name.len() > 38 {
543            format!("{}...", &model.name[..35])
544        } else {
545            model.name.clone()
546        };
547
548        println!(
549            "{:<40} {:<12} {:<12} {}",
550            name.cyan(),
551            size.yellow(),
552            format,
553            model.path.display().to_string().dimmed()
554        );
555    }
556
557    println!();
558
559    // Print stats
560    let stats = fetcher.stats();
561    println!(
562        "Total: {} models, {} used",
563        models.len(),
564        format_bytes(stats.total_size_bytes)
565    );
566
567    Ok(())
568}
569
570include!("pull_remove_resolve_model.rs");
571include!("pull_extract_shard.rs");
572include!("pull_04.rs");