Skip to main content

apr_cli/commands/
pull.rs

1//! Pull command implementation using pacha ModelFetcher
2//!
3//! Downloads and caches models from HuggingFace with Ollama-like UX.
//! Single-file models are cached by pacha at `~/.cache/pacha/models/`; sharded
//! SafeTensors models are cached at `~/.apr/cache/hf/<org>/<repo>/`.
5//!
6//! Usage:
7//!   apr pull qwen2.5-coder:7b           # Short alias
8//!   apr pull hf://bartowski/Qwen2.5-Coder-7B-Instruct-GGUF/qwen2.5-coder-7b-instruct-q4_k_m.gguf
9//!   apr pull TheBloke/Llama-2-7B-GGUF   # Will auto-detect quant
10
11use crate::error::{CliError, Result};
12use aprender::format::{
13    apr_export, apr_import, ExportFormat, ExportOptions, ImportOptions, ValidationConfig,
14};
15use colored::Colorize;
16use pacha::fetcher::{FetchConfig, ModelFetcher};
17use pacha::format::ModelFormat;
18use serde::{Deserialize, Serialize};
19use std::collections::{HashMap, HashSet};
20use std::io::{self, Read, Write};
21use std::path::Path;
22
/// Outcome of resolving a user-supplied HuggingFace model reference.
///
/// The two variants select different download paths: a single file is handed to
/// the pacha fetcher, while a sharded SafeTensors checkpoint is fetched file by
/// file into `~/.apr/cache/hf/<org>/<repo>/`.
#[derive(Debug)]
enum ResolvedModel {
    /// One downloadable file; the string is the reference passed on to pacha.
    SingleFile(String),
    /// A SafeTensors checkpoint split across several shard files plus an index JSON.
    Sharded {
        /// HuggingFace organization (first URL path segment).
        org: String,
        /// Repository name within `org`.
        repo: String,
        /// Shard file names, relative to the repository root.
        shard_files: Vec<String>,
    },
}
38
39/// GH-213: Manifest recording checksums for each file in a sharded download.
40///
41/// Written to `.apr-manifest.json` in the cache directory after a successful download.
42/// Used by the pre-inference contract gate to verify shard integrity without re-hashing.
43#[derive(Debug, Serialize, Deserialize)]
44pub struct ShardManifest {
45    pub version: u32,
46    pub repo: String,
47    pub files: HashMap<String, FileChecksum>,
48}
49
50/// GH-213: Size and BLAKE3 hash of a downloaded file.
51#[derive(Debug, Serialize, Deserialize)]
52pub struct FileChecksum {
53    pub size: u64,
54    pub blake3: String,
55}
56
57/// Run the pull command
58pub fn run(model_ref: &str, force: bool) -> Result<()> {
59    println!("{}", "=== APR Pull ===".cyan().bold());
60    println!();
61
62    // GH-213: Resolve HuggingFace URI — detect single vs sharded models
63    let resolved = resolve_hf_model(model_ref)?;
64
65    match resolved {
66        ResolvedModel::SingleFile(ref uri) => run_single_file(uri, force),
67        ResolvedModel::Sharded {
68            ref org,
69            ref repo,
70            ref shard_files,
71        } => run_sharded(org, repo, shard_files, force),
72    }
73}
74
75/// Pull a single-file model via pacha (existing behavior)
76fn run_single_file(model_ref: &str, force: bool) -> Result<()> {
77    println!("Model: {}", model_ref.cyan());
78
79    let mut fetcher = ModelFetcher::with_config(FetchConfig::default()).map_err(|e| {
80        CliError::ValidationFailed(format!("Failed to initialize model fetcher: {e}"))
81    })?;
82
83    if !force && fetcher.is_cached(model_ref) {
84        return handle_cached_model(&mut fetcher, model_ref);
85    }
86
87    let result = download_single_model(&mut fetcher, model_ref)?;
88    ensure_safetensors_companions(&result)?;
89    print_pull_usage(&result.path, true);
90    Ok(())
91}
92
93/// Handle a model that is already cached in pacha.
94fn handle_cached_model(fetcher: &mut ModelFetcher, model_ref: &str) -> Result<()> {
95    println!("{} Model already cached", "✓".green());
96    let result = fetcher
97        .pull_quiet(model_ref)
98        .map_err(|e| CliError::ValidationFailed(format!("Failed to get cached model: {e}")))?;
99
100    println!("  Path: {}", result.path.display());
101    println!("  Size: {}", result.size_human());
102    println!("  Format: {:?}", result.format);
103
104    ensure_safetensors_companions(&result)?;
105    print_pull_usage(&result.path, false);
106    Ok(())
107}
108
109/// Download a single model with progress bar.
110fn download_single_model(
111    fetcher: &mut ModelFetcher,
112    model_ref: &str,
113) -> Result<pacha::fetcher::FetchResult> {
114    println!();
115    println!("{}", "Downloading...".yellow());
116
117    let result = fetcher
118        .pull(model_ref, |progress| {
119            let pct = progress.percent();
120            print!(
121                "\r  [{:50}] {:5.1}% ({}/{})",
122                "=".repeat((pct / 2.0) as usize),
123                pct,
124                format_bytes(progress.downloaded_bytes),
125                format_bytes(progress.total_bytes)
126            );
127            io::stdout().flush().ok();
128        })
129        .map_err(|e| CliError::NetworkError(format!("Download failed: {e}")))?;
130
131    println!();
132    println!();
133
134    if result.cache_hit {
135        println!("{} Model retrieved from cache", "✓".green());
136    } else {
137        println!("{} Downloaded successfully", "✓".green());
138    }
139
140    println!("  Path: {}", result.path.display().to_string().green());
141    println!("  Size: {}", result.size_human().yellow());
142    println!("  Format: {:?}", result.format);
143    println!("  Hash: {}", &result.hash[..16]);
144    Ok(result)
145}
146
147/// Ensure companion files exist for SafeTensors models (GH-198, GH-211).
148fn ensure_safetensors_companions(result: &pacha::fetcher::FetchResult) -> Result<()> {
149    if matches!(result.format, ModelFormat::SafeTensors(_)) {
150        fetch_safetensors_companions(&result.path, &result.resolved_uri)?;
151        convert_safetensors_formats(&result.path)?;
152    }
153    Ok(())
154}
155
156/// Print usage instructions after a successful pull.
157fn print_pull_usage(path: &Path, show_serve: bool) {
158    println!();
159    println!("{}", "Usage:".cyan().bold());
160    println!("  apr run {}", path.display());
161    if show_serve {
162        println!("  apr serve {}", path.display());
163    }
164}
165
166/// GH-213: Pull a sharded SafeTensors model (3B+ models with multiple shard files)
167fn run_sharded(org: &str, repo: &str, shard_files: &[String], force: bool) -> Result<()> {
168    println!(
169        "Model: {}/{} ({} shards)",
170        org.cyan(),
171        repo.cyan(),
172        shard_files.len().to_string().yellow()
173    );
174
175    let cache_dir = resolve_shard_cache_dir(org, repo)?;
176    std::fs::create_dir_all(&cache_dir)?;
177
178    let base_url = format!("https://huggingface.co/{org}/{repo}/resolve/main");
179    let index_path = cache_dir.join("model.safetensors.index.json");
180
181    download_index_if_needed(&base_url, &index_path, force)?;
182
183    let manifest_path = cache_dir.join(".apr-manifest.json");
184    let existing_manifest = load_existing_manifest(&manifest_path, force);
185
186    let file_checksums = download_all_shards(
187        &cache_dir,
188        &base_url,
189        shard_files,
190        force,
191        existing_manifest.as_ref(),
192    )?;
193
194    download_companion_files(&cache_dir, &base_url, force)?;
195    write_shard_manifest(&manifest_path, org, repo, file_checksums)?;
196
197    println!();
198    println!("{} Downloaded successfully", "✓".green());
199    println!("  Path: {}", index_path.display().to_string().green());
200    println!("  Shards: {}", shard_files.len().to_string().yellow());
201
202    convert_safetensors_formats(&index_path)?;
203
204    println!();
205    println!("{}", "Usage:".cyan().bold());
206    println!("  apr run {}", index_path.display());
207    println!("  apr serve {}", index_path.display());
208    Ok(())
209}
210
211/// Resolve the cache directory for a sharded model.
212fn resolve_shard_cache_dir(org: &str, repo: &str) -> Result<std::path::PathBuf> {
213    Ok(dirs::home_dir()
214        .ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
215        .join(".apr")
216        .join("cache")
217        .join("hf")
218        .join(org)
219        .join(repo))
220}
221
222/// Download the SafeTensors index.json if not already cached.
223fn download_index_if_needed(base_url: &str, index_path: &Path, force: bool) -> Result<()> {
224    if force || !index_path.exists() {
225        println!();
226        println!("  {} model.safetensors.index.json", "Downloading".yellow());
227        download_file(
228            &format!("{base_url}/model.safetensors.index.json"),
229            index_path,
230        )?;
231    } else {
232        println!("  {} model.safetensors.index.json (cached)", "✓".green());
233    }
234    Ok(())
235}
236
237/// Load existing shard manifest for cache-hit verification (GH-213).
238fn load_existing_manifest(manifest_path: &Path, force: bool) -> Option<ShardManifest> {
239    if force || !manifest_path.exists() {
240        return None;
241    }
242    std::fs::read_to_string(manifest_path)
243        .ok()
244        .and_then(|s| serde_json::from_str(&s).ok())
245}
246
247/// Download all shards, collecting checksums for the manifest.
248fn download_all_shards(
249    cache_dir: &Path,
250    base_url: &str,
251    shard_files: &[String],
252    force: bool,
253    existing_manifest: Option<&ShardManifest>,
254) -> Result<HashMap<String, FileChecksum>> {
255    let mut file_checksums: HashMap<String, FileChecksum> = HashMap::new();
256    let total = shard_files.len();
257    for (i, shard_file) in shard_files.iter().enumerate() {
258        download_or_verify_shard(
259            cache_dir,
260            base_url,
261            shard_file,
262            i,
263            total,
264            force,
265            existing_manifest,
266            &mut file_checksums,
267        )?;
268    }
269    Ok(file_checksums)
270}
271
272/// Download or verify a single shard file, updating the checksum map.
273fn download_or_verify_shard(
274    cache_dir: &Path,
275    base_url: &str,
276    shard_file: &str,
277    index: usize,
278    total: usize,
279    force: bool,
280    existing_manifest: Option<&ShardManifest>,
281    checksums: &mut HashMap<String, FileChecksum>,
282) -> Result<()> {
283    let shard_path = cache_dir.join(shard_file);
284
285    if !force && shard_path.exists() {
286        if let Some(manifest) = existing_manifest {
287            if let Some(expected) = manifest.files.get(shard_file) {
288                let actual_size = std::fs::metadata(&shard_path).map(|m| m.len()).unwrap_or(0);
289                if actual_size == expected.size {
290                    checksums.insert(
291                        shard_file.to_string(),
292                        FileChecksum {
293                            size: expected.size,
294                            blake3: expected.blake3.clone(),
295                        },
296                    );
297                    println!(
298                        "  {} [{}/{}] {} (cached, verified)",
299                        "✓".green(),
300                        index + 1,
301                        total,
302                        shard_file
303                    );
304                    return Ok(());
305                }
306                println!(
307                    "  {} [{}/{}] {} (size mismatch: {} vs {} bytes, re-downloading)",
308                    "⚠".yellow(),
309                    index + 1,
310                    total,
311                    shard_file,
312                    actual_size,
313                    expected.size
314                );
315                // Fall through to re-download
316            }
317        } else {
318            println!(
319                "  {} [{}/{}] {} (cached)",
320                "✓".green(),
321                index + 1,
322                total,
323                shard_file
324            );
325            return Ok(());
326        }
327    }
328
329    let shard_url = format!("{base_url}/{shard_file}");
330    print!(
331        "  {} [{}/{}] {}...",
332        "↓".yellow(),
333        index + 1,
334        total,
335        shard_file
336    );
337    io::stdout().flush().ok();
338
339    let checksum = download_file_with_progress(&shard_url, &shard_path)?;
340    checksums.insert(shard_file.to_string(), checksum);
341    println!(" {}", "done".green());
342    Ok(())
343}
344
345/// Download companion files (tokenizer.json, config.json, tokenizer_config.json) for sharded models.
346fn download_companion_files(cache_dir: &Path, base_url: &str, force: bool) -> Result<()> {
347    let companions = [
348        ("tokenizer.json", true),
349        ("config.json", true),
350        ("tokenizer_config.json", false),
351    ];
352    for (filename, required) in &companions {
353        let companion_path = cache_dir.join(filename);
354        if !force && companion_path.exists() {
355            println!("  {} {} (cached)", "✓".green(), filename);
356            continue;
357        }
358
359        let url = format!("{base_url}/{filename}");
360        match download_file(&url, &companion_path) {
361            Ok(()) => println!("  {} {}", "✓".green(), filename),
362            Err(e) if *required => {
363                return Err(CliError::ValidationFailed(format!(
364                    "{filename} is required for inference but download failed: {e}"
365                )));
366            }
367            Err(_) => println!("  {} {} (not available, optional)", "⚠".yellow(), filename),
368        }
369    }
370    Ok(())
371}
372
373/// Write shard manifest with BLAKE3 checksums for integrity verification.
374fn write_shard_manifest(
375    manifest_path: &Path,
376    org: &str,
377    repo: &str,
378    file_checksums: HashMap<String, FileChecksum>,
379) -> Result<()> {
380    if file_checksums.is_empty() {
381        return Ok(());
382    }
383    let manifest = ShardManifest {
384        version: 1,
385        repo: format!("{org}/{repo}"),
386        files: file_checksums,
387    };
388    let manifest_json = serde_json::to_string_pretty(&manifest)
389        .map_err(|e| CliError::ValidationFailed(format!("Failed to serialize manifest: {e}")))?;
390    std::fs::write(manifest_path, manifest_json)?;
391    println!("  {} .apr-manifest.json (integrity checksums)", "✓".green());
392    Ok(())
393}
394
395/// List cached models
396// serde_json::json!() macro uses infallible unwrap internally
397#[allow(clippy::disallowed_methods)]
398pub fn list(json: bool) -> Result<()> {
399    let fetcher = ModelFetcher::new().map_err(|e| {
400        CliError::ValidationFailed(format!("Failed to initialize model fetcher: {e}"))
401    })?;
402
403    let models = fetcher.list();
404
405    // GH-248: JSON output mode
406    if json {
407        let models_json: Vec<serde_json::Value> = models
408            .iter()
409            .map(|m| {
410                serde_json::json!({
411                    "name": m.name,
412                    "size_bytes": m.size_bytes,
413                    "format": format!("{:?}", m.format),
414                    "path": m.path.display().to_string(),
415                })
416            })
417            .collect();
418        let stats = fetcher.stats();
419        let output = serde_json::json!({
420            "models": models_json,
421            "total": models.len(),
422            "total_size_bytes": stats.total_size_bytes,
423        });
424        println!(
425            "{}",
426            serde_json::to_string_pretty(&output).unwrap_or_default()
427        );
428        return Ok(());
429    }
430
431    println!("{}", "=== Cached Models ===".cyan().bold());
432    println!();
433
434    if models.is_empty() {
435        println!("{}", "No cached models found.".dimmed());
436        println!();
437        println!("Pull a model with:");
438        println!("  apr pull hf://Qwen/Qwen2.5-Coder-1.5B-Instruct-GGUF/qwen2.5-coder-1.5b-instruct-q4_k_m.gguf");
439        println!();
440        println!("Or run directly (auto-downloads):");
441        println!("  apr run hf://Qwen/Qwen2.5-Coder-1.5B-Instruct-GGUF/qwen2.5-coder-1.5b-instruct-q4_k_m.gguf");
442        return Ok(());
443    }
444
445    // Print header
446    println!(
447        "{:<40} {:<12} {:<10} {}",
448        "NAME".dimmed(),
449        "SIZE".dimmed(),
450        "FORMAT".dimmed(),
451        "PATH".dimmed()
452    );
453    println!("{}", "-".repeat(100).dimmed());
454
455    for model in &models {
456        let size = format_bytes(model.size_bytes);
457        let format = format!("{:?}", model.format);
458        let name = if model.name.len() > 38 {
459            format!("{}...", &model.name[..35])
460        } else {
461            model.name.clone()
462        };
463
464        println!(
465            "{:<40} {:<12} {:<10} {}",
466            name.cyan(),
467            size.yellow(),
468            format,
469            model.path.display().to_string().dimmed()
470        );
471    }
472
473    println!();
474
475    // Print stats
476    let stats = fetcher.stats();
477    println!(
478        "Total: {} models, {} used",
479        models.len(),
480        format_bytes(stats.total_size_bytes)
481    );
482
483    Ok(())
484}
485
486include!("pull_remove_resolve_model.rs");
487include!("pull_extract_shard.rs");
488include!("pull_04.rs");