Skip to main content

apr_cli/commands/
pull.rs

//! Pull command: download and cache models from HuggingFace (single-file models in `~/.cache/pacha/models/`, sharded models in `~/.apr/cache/hf/`).
2
3use crate::error::{CliError, Result};
4use colored::Colorize;
5use pacha::fetcher::{FetchConfig, ModelFetcher};
6use pacha::format::ModelFormat;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet};
9use std::io::{self, Read, Write};
10use std::path::Path;
11
/// Outcome of resolving a HuggingFace model reference.
///
/// The variant selects the download path taken by `run`:
/// `SingleFile` is handed to `run_single_file`, `Sharded` to `run_sharded`
/// (which stores its files under `~/.apr/cache/hf/<org>/<repo>/`).
#[derive(Debug)]
enum ResolvedModel {
    /// A model that consists of one downloadable file; the payload is the
    /// model reference / URI passed on to the single-file pull.
    SingleFile(String),
    /// A SafeTensors model split into several shard files plus an
    /// `model.safetensors.index.json` index.
    Sharded {
        /// Organisation (owner) segment of the HuggingFace repository.
        org: String,
        /// Repository name segment of the HuggingFace repository.
        repo: String,
        /// Names of the shard files to download, relative to the repository root.
        shard_files: Vec<String>,
    },
}
27
28/// GH-213: Manifest recording checksums for each file in a sharded download.
29///
30/// Written to `.apr-manifest.json` in the cache directory after a successful download.
31/// Used by the pre-inference contract gate to verify shard integrity without re-hashing.
32#[derive(Debug, Serialize, Deserialize)]
33pub struct ShardManifest {
34    pub version: u32,
35    pub repo: String,
36    pub files: HashMap<String, FileChecksum>,
37}
38
39/// GH-213: Size and BLAKE3 hash of a downloaded file.
40#[derive(Debug, Serialize, Deserialize)]
41pub struct FileChecksum {
42    pub size: u64,
43    pub blake3: String,
44}
45
46/// Run the pull command
47pub fn run(model_ref: &str, force: bool) -> Result<()> {
48    println!("{}", "=== APR Pull ===".cyan().bold());
49    println!();
50
51    // GH-213: Resolve HuggingFace URI — detect single vs sharded models
52    let resolved = resolve_hf_model(model_ref)?;
53
54    match resolved {
55        ResolvedModel::SingleFile(ref uri) => run_single_file(uri, force),
56        ResolvedModel::Sharded {
57            ref org,
58            ref repo,
59            ref shard_files,
60        } => run_sharded(org, repo, shard_files, force),
61    }
62}
63
64/// Pull a single-file model.
65///
66/// GH-352: For HuggingFace URIs, streams directly to disk instead of buffering
67/// the entire file in memory through pacha's resolver. For non-HF URIs (pacha
68/// aliases), falls back to the pacha fetcher.
69fn run_single_file(model_ref: &str, force: bool) -> Result<()> {
70    println!("Model: {}", model_ref.cyan());
71
72    // GH-352: HuggingFace URIs bypass pacha to avoid O(model_size) memory buffering
73    if model_ref.starts_with("hf://") {
74        return run_single_file_streaming(model_ref, force);
75    }
76
77    let mut fetcher = ModelFetcher::with_config(FetchConfig::default()).map_err(|e| {
78        CliError::ValidationFailed(format!("Failed to initialize model fetcher: {e}"))
79    })?;
80
81    if !force && fetcher.is_cached(model_ref) {
82        return handle_cached_model(&mut fetcher, model_ref);
83    }
84
85    let result = download_single_model(&mut fetcher, model_ref)?;
86    ensure_safetensors_companions(&result)?;
87    print_pull_usage(&result.path, true);
88    Ok(())
89}
90
91/// GH-352: Stream a single HuggingFace file directly to disk.
92///
93/// Uses O(64KB) memory instead of O(model_size). The pacha fetcher's
94/// `resolver.resolve()` buffers the entire response via `response.bytes()`,
95/// which consumed ~4.5 GB for a 7B GGUF. This function streams with a 64KB
96/// chunked read, computes BLAKE3 incrementally, and saves to the pacha cache.
97fn run_single_file_streaming(model_ref: &str, force: bool) -> Result<()> {
98    let path = model_ref.strip_prefix("hf://").unwrap_or(model_ref);
99    let parts: Vec<&str> = path.split('/').collect();
100    if parts.len() < 3 {
101        return Err(CliError::ValidationFailed(format!(
102            "HuggingFace URI must include a filename: {model_ref}"
103        )));
104    }
105
106    let filename = parts[2..].join("/");
107    let url = format!(
108        "https://huggingface.co/{}/{}/resolve/main/{}",
109        parts[0], parts[1], filename
110    );
111
112    // Determine cache path in pacha cache dir
113    let cache_dir = get_pacha_cache_dir()?;
114    std::fs::create_dir_all(&cache_dir)?;
115
116    // Check if already cached (by URI hash)
117    let uri_hash = blake3::hash(model_ref.as_bytes()).to_hex().to_string();
118    let extension = std::path::Path::new(&filename)
119        .extension()
120        .and_then(|e| e.to_str())
121        .unwrap_or("bin");
122    let cache_filename = format!("{}.{}", &uri_hash[..16], extension);
123    let cache_path = cache_dir.join(&cache_filename);
124
125    if !force && cache_path.exists() {
126        let metadata = std::fs::metadata(&cache_path)?;
127        println!("{} Model already cached", "✓".green());
128        println!("  Path: {}", cache_path.display());
129        println!("  Size: {}", format_bytes(metadata.len()));
130        print_pull_usage(&cache_path, true);
131        return Ok(());
132    }
133
134    println!();
135    println!("{}", "Downloading (streaming)...".yellow());
136
137    let checksum = download_file_with_progress(&url, &cache_path)?;
138
139    println!();
140    println!("{} Downloaded successfully", "✓".green());
141    println!("  Path: {}", cache_path.display().to_string().green());
142    println!("  Size: {}", format_bytes(checksum.size).yellow());
143    println!("  Hash: {}", &checksum.blake3[..16]);
144
145    // Handle SafeTensors companions
146    if extension == "safetensors" {
147        fetch_safetensors_companions(&cache_path, model_ref)?;
148        convert_safetensors_formats(&cache_path)?;
149    }
150
151    print_pull_usage(&cache_path, true);
152    Ok(())
153}
154
155/// Get the pacha model cache directory.
156fn get_pacha_cache_dir() -> Result<std::path::PathBuf> {
157    if let Ok(cache_home) = std::env::var("XDG_CACHE_HOME") {
158        return Ok(std::path::PathBuf::from(cache_home)
159            .join("pacha")
160            .join("models"));
161    }
162    Ok(dirs::home_dir()
163        .ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
164        .join(".cache")
165        .join("pacha")
166        .join("models"))
167}
168
169/// Handle a model that is already cached in pacha.
170fn handle_cached_model(fetcher: &mut ModelFetcher, model_ref: &str) -> Result<()> {
171    println!("{} Model already cached", "✓".green());
172    let result = fetcher
173        .pull_quiet(model_ref)
174        .map_err(|e| CliError::ValidationFailed(format!("Failed to get cached model: {e}")))?;
175
176    println!("  Path: {}", result.path.display());
177    println!("  Size: {}", result.size_human());
178    println!("  Format: {}", result.format.name());
179
180    ensure_safetensors_companions(&result)?;
181    print_pull_usage(&result.path, false);
182    Ok(())
183}
184
185/// Download a single model with progress bar.
186fn download_single_model(
187    fetcher: &mut ModelFetcher,
188    model_ref: &str,
189) -> Result<pacha::fetcher::FetchResult> {
190    println!();
191    println!("{}", "Downloading...".yellow());
192
193    let result = fetcher
194        .pull(model_ref, |progress| {
195            let pct = progress.percent();
196            print!(
197                "\r  [{:50}] {:5.1}% ({}/{})",
198                "=".repeat((pct / 2.0) as usize),
199                pct,
200                format_bytes(progress.downloaded_bytes),
201                format_bytes(progress.total_bytes)
202            );
203            io::stdout().flush().ok();
204        })
205        .map_err(|e| CliError::NetworkError(format!("Download failed: {e}")))?;
206
207    println!();
208    println!();
209
210    if result.cache_hit {
211        println!("{} Model retrieved from cache", "✓".green());
212    } else {
213        println!("{} Downloaded successfully", "✓".green());
214    }
215
216    println!("  Path: {}", result.path.display().to_string().green());
217    println!("  Size: {}", result.size_human().yellow());
218    println!("  Format: {}", result.format.name());
219    println!("  Hash: {}", &result.hash[..16]);
220    Ok(result)
221}
222
223/// Ensure companion files exist for SafeTensors models (GH-198, GH-211).
224fn ensure_safetensors_companions(result: &pacha::fetcher::FetchResult) -> Result<()> {
225    if matches!(result.format, ModelFormat::SafeTensors(_)) {
226        fetch_safetensors_companions(&result.path, &result.resolved_uri)?;
227        convert_safetensors_formats(&result.path)?;
228    }
229    Ok(())
230}
231
232/// Print usage instructions after a successful pull.
233fn print_pull_usage(path: &Path, show_serve: bool) {
234    println!();
235    println!("{}", "Usage:".cyan().bold());
236    println!("  apr run {}", path.display());
237    if show_serve {
238        println!("  apr serve {}", path.display());
239    }
240}
241
242/// GH-213: Pull a sharded SafeTensors model (3B+ models with multiple shard files)
243fn run_sharded(org: &str, repo: &str, shard_files: &[String], force: bool) -> Result<()> {
244    println!(
245        "Model: {}/{} ({} shards)",
246        org.cyan(),
247        repo.cyan(),
248        shard_files.len().to_string().yellow()
249    );
250
251    let cache_dir = resolve_shard_cache_dir(org, repo)?;
252    std::fs::create_dir_all(&cache_dir)?;
253
254    let base_url = format!("https://huggingface.co/{org}/{repo}/resolve/main");
255    let index_path = cache_dir.join("model.safetensors.index.json");
256
257    download_index_if_needed(&base_url, &index_path, force)?;
258
259    let manifest_path = cache_dir.join(".apr-manifest.json");
260    let existing_manifest = load_existing_manifest(&manifest_path, force);
261
262    let file_checksums = download_all_shards(
263        &cache_dir,
264        &base_url,
265        shard_files,
266        force,
267        existing_manifest.as_ref(),
268    )?;
269
270    download_companion_files(&cache_dir, &base_url, force)?;
271    write_shard_manifest(&manifest_path, org, repo, file_checksums)?;
272
273    println!();
274    println!("{} Downloaded successfully", "✓".green());
275    println!("  Path: {}", index_path.display().to_string().green());
276    println!("  Shards: {}", shard_files.len().to_string().yellow());
277
278    convert_safetensors_formats(&index_path)?;
279
280    println!();
281    println!("{}", "Usage:".cyan().bold());
282    println!("  apr run {}", index_path.display());
283    println!("  apr serve {}", index_path.display());
284    Ok(())
285}
286
287/// Resolve the cache directory for a sharded model.
288fn resolve_shard_cache_dir(org: &str, repo: &str) -> Result<std::path::PathBuf> {
289    Ok(dirs::home_dir()
290        .ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
291        .join(".apr")
292        .join("cache")
293        .join("hf")
294        .join(org)
295        .join(repo))
296}
297
298/// Download the SafeTensors index.json if not already cached.
299fn download_index_if_needed(base_url: &str, index_path: &Path, force: bool) -> Result<()> {
300    if force || !index_path.exists() {
301        println!();
302        println!("  {} model.safetensors.index.json", "Downloading".yellow());
303        download_file(
304            &format!("{base_url}/model.safetensors.index.json"),
305            index_path,
306        )?;
307    } else {
308        println!("  {} model.safetensors.index.json (cached)", "✓".green());
309    }
310    Ok(())
311}
312
313/// Load existing shard manifest for cache-hit verification (GH-213).
314fn load_existing_manifest(manifest_path: &Path, force: bool) -> Option<ShardManifest> {
315    if force || !manifest_path.exists() {
316        return None;
317    }
318    std::fs::read_to_string(manifest_path)
319        .ok()
320        .and_then(|s| serde_json::from_str(&s).ok())
321}
322
323/// Download all shards, collecting checksums for the manifest.
324fn download_all_shards(
325    cache_dir: &Path,
326    base_url: &str,
327    shard_files: &[String],
328    force: bool,
329    existing_manifest: Option<&ShardManifest>,
330) -> Result<HashMap<String, FileChecksum>> {
331    let mut file_checksums: HashMap<String, FileChecksum> = HashMap::new();
332    let total = shard_files.len();
333    for (i, shard_file) in shard_files.iter().enumerate() {
334        download_or_verify_shard(
335            cache_dir,
336            base_url,
337            shard_file,
338            i,
339            total,
340            force,
341            existing_manifest,
342            &mut file_checksums,
343        )?;
344    }
345    Ok(file_checksums)
346}
347
348/// Download or verify a single shard file, updating the checksum map.
349fn download_or_verify_shard(
350    cache_dir: &Path,
351    base_url: &str,
352    shard_file: &str,
353    index: usize,
354    total: usize,
355    force: bool,
356    existing_manifest: Option<&ShardManifest>,
357    checksums: &mut HashMap<String, FileChecksum>,
358) -> Result<()> {
359    let shard_path = cache_dir.join(shard_file);
360
361    if !force && shard_path.exists() {
362        if let Some(manifest) = existing_manifest {
363            if let Some(expected) = manifest.files.get(shard_file) {
364                let actual_size = std::fs::metadata(&shard_path).map(|m| m.len()).unwrap_or(0);
365                if actual_size == expected.size {
366                    checksums.insert(
367                        shard_file.to_string(),
368                        FileChecksum {
369                            size: expected.size,
370                            blake3: expected.blake3.clone(),
371                        },
372                    );
373                    println!(
374                        "  {} [{}/{}] {} (cached, verified)",
375                        "✓".green(),
376                        index + 1,
377                        total,
378                        shard_file
379                    );
380                    return Ok(());
381                }
382                println!(
383                    "  {} [{}/{}] {} (size mismatch: {} vs {} bytes, re-downloading)",
384                    "⚠".yellow(),
385                    index + 1,
386                    total,
387                    shard_file,
388                    actual_size,
389                    expected.size
390                );
391                // Fall through to re-download
392            }
393        } else {
394            println!(
395                "  {} [{}/{}] {} (cached)",
396                "✓".green(),
397                index + 1,
398                total,
399                shard_file
400            );
401            return Ok(());
402        }
403    }
404
405    let shard_url = format!("{base_url}/{shard_file}");
406    print!(
407        "  {} [{}/{}] {}...",
408        "↓".yellow(),
409        index + 1,
410        total,
411        shard_file
412    );
413    io::stdout().flush().ok();
414
415    let checksum = download_file_with_progress(&shard_url, &shard_path)?;
416    checksums.insert(shard_file.to_string(), checksum);
417    println!(" {}", "done".green());
418    Ok(())
419}
420
421/// Download companion files (tokenizer, config) for sharded models.
422///
423/// GH-356: tokenizer.json is optional — some models only have tokenizer.model (SentencePiece)
424/// or tokenizer_config.json. We validate that at least ONE tokenizer file was obtained.
425fn download_companion_files(cache_dir: &Path, base_url: &str, force: bool) -> Result<()> {
426    // (filename, is_required) — tokenizer files are individually optional but collectively required
427    let companions = [
428        ("tokenizer.json", false),
429        ("config.json", true),
430        ("tokenizer_config.json", false),
431        ("tokenizer.model", false),
432    ];
433    for (filename, required) in &companions {
434        let companion_path = cache_dir.join(filename);
435        if !force && companion_path.exists() {
436            println!("  {} {} (cached)", "✓".green(), filename);
437            continue;
438        }
439
440        let url = format!("{base_url}/{filename}");
441        match download_file(&url, &companion_path) {
442            Ok(()) => println!("  {} {}", "✓".green(), filename),
443            Err(CliError::HttpNotFound(_)) if *required => {
444                return Err(CliError::ValidationFailed(format!(
445                    "{filename} is required for inference but was not found (HTTP 404) at {url}"
446                )));
447            }
448            Err(CliError::HttpNotFound(_)) => {
449                println!("  {} {} (not found in repo)", "⚠".yellow(), filename);
450            }
451            Err(e) if *required => {
452                return Err(CliError::ValidationFailed(format!(
453                    "{filename} is required for inference but download failed: {e}"
454                )));
455            }
456            Err(_) => println!("  {} {} (not available, optional)", "⚠".yellow(), filename),
457        }
458    }
459
460    // GH-356: Validate at least one tokenizer file exists
461    let tokenizer_files = ["tokenizer.json", "tokenizer.model", "tokenizer_config.json"];
462    let has_tokenizer = tokenizer_files.iter().any(|f| cache_dir.join(f).exists());
463    if !has_tokenizer {
464        return Err(CliError::ValidationFailed(format!(
465            "No tokenizer found for this model. Tried: {}.\n\
466             The model may require a custom tokenizer not hosted in the repository.",
467            tokenizer_files.join(", ")
468        )));
469    }
470
471    Ok(())
472}
473
474/// Write shard manifest with BLAKE3 checksums for integrity verification.
475fn write_shard_manifest(
476    manifest_path: &Path,
477    org: &str,
478    repo: &str,
479    file_checksums: HashMap<String, FileChecksum>,
480) -> Result<()> {
481    if file_checksums.is_empty() {
482        return Ok(());
483    }
484    let manifest = ShardManifest {
485        version: 1,
486        repo: format!("{org}/{repo}"),
487        files: file_checksums,
488    };
489    let manifest_json = serde_json::to_string_pretty(&manifest)
490        .map_err(|e| CliError::ValidationFailed(format!("Failed to serialize manifest: {e}")))?;
491    std::fs::write(manifest_path, manifest_json)?;
492    println!("  {} .apr-manifest.json (integrity checksums)", "✓".green());
493    Ok(())
494}
495
496include!("pull_list.rs");
497include!("pull_remove_resolve_model.rs");
498include!("pull_extract_shard.rs");
499include!("pull_04.rs");