Skip to main content

apr_cli/commands/
pull.rs

1//! Pull command: download and cache models from HuggingFace (`~/.cache/pacha/models/`).
2
3use crate::error::{CliError, Result};
4use colored::Colorize;
5use pacha::fetcher::{FetchConfig, ModelFetcher};
6use pacha::format::ModelFormat;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet};
9use std::io::{self, Read, Write};
10use std::path::Path;
11
/// Result of resolving a HuggingFace model reference.
///
/// Single-file models (small SafeTensors, GGUF) use the single-file path: pacha
/// aliases go through the pacha fetcher, while `hf://` URIs are streamed directly
/// to the pacha cache directory (GH-352).
/// Sharded models (3B+ SafeTensors) are downloaded directly to `~/.apr/cache/hf/`.
///
/// Produced by `resolve_hf_model` and dispatched by `run`: `SingleFile` goes to
/// `run_single_file`, `Sharded` goes to `run_sharded`.
#[derive(Debug)]
enum ResolvedModel {
    /// Single file downloadable via pacha (existing behavior).
    ///
    /// Holds the model reference handed to `run_single_file` unchanged.
    SingleFile(String),
    /// Sharded SafeTensors model (multiple .safetensors files + index.json)
    Sharded {
        /// Used as `{org}` in `https://huggingface.co/{org}/{repo}` and as a cache subdirectory.
        org: String,
        /// Used as `{repo}` in `https://huggingface.co/{org}/{repo}` and as a cache subdirectory.
        repo: String,
        /// Shard file names, each fetched from `.../resolve/main/<name>` into the cache dir.
        shard_files: Vec<String>,
    },
}
27
/// GH-213: Manifest recording checksums for each file in a sharded download.
///
/// Written to `.apr-manifest.json` in the cache directory after a successful download.
/// Used by the pre-inference contract gate to verify shard integrity without re-hashing.
/// `pull` itself reads it back (`load_existing_manifest`) and reuses a cached shard
/// only when the shard's on-disk length equals the recorded `size`.
#[derive(Debug, Serialize, Deserialize)]
pub struct ShardManifest {
    /// Manifest format version; `write_shard_manifest` currently writes `1`.
    pub version: u32,
    /// Repository identifier, formatted as `"{org}/{repo}"`.
    pub repo: String,
    /// Checksums keyed by shard file name, exactly as it appears in the shard list.
    pub files: HashMap<String, FileChecksum>,
}
38
/// GH-213: Size and BLAKE3 hash of a downloaded file.
#[derive(Debug, Serialize, Deserialize)]
pub struct FileChecksum {
    /// File length in bytes (compared against `std::fs::metadata(..).len()`).
    pub size: u64,
    /// BLAKE3 digest as produced by `download_file_with_progress` (defined elsewhere;
    /// presumably lowercase hex — TODO confirm).
    pub blake3: String,
}
45
46/// Run the pull command
47#[provable_contracts_macros::contract(
48    "apr-cli-operations-v1",
49    equation = "mutating_output_contract"
50)]
51pub fn run(model_ref: &str, force: bool) -> Result<()> {
52    contract_pre_pull_cache_integrity!();
53    println!("{}", "=== APR Pull ===".cyan().bold());
54    println!();
55
56    // GH-213: Resolve HuggingFace URI — detect single vs sharded models
57    let resolved = resolve_hf_model(model_ref)?;
58
59    let result = match resolved {
60        ResolvedModel::SingleFile(ref uri) => run_single_file(uri, force),
61        ResolvedModel::Sharded {
62            ref org,
63            ref repo,
64            ref shard_files,
65        } => run_sharded(org, repo, shard_files, force),
66    };
67    if let Ok(ref r) = result {
68        contract_post_pull_cache_integrity!(r);
69    }
70    result
71}
72
73/// Pull a single-file model.
74///
75/// GH-352: For HuggingFace URIs, streams directly to disk instead of buffering
76/// the entire file in memory through pacha's resolver. For non-HF URIs (pacha
77/// aliases), falls back to the pacha fetcher.
78fn run_single_file(model_ref: &str, force: bool) -> Result<()> {
79    println!("Model: {}", model_ref.cyan());
80
81    // GH-352: HuggingFace URIs bypass pacha to avoid O(model_size) memory buffering
82    if model_ref.starts_with("hf://") {
83        return run_single_file_streaming(model_ref, force);
84    }
85
86    let mut fetcher = ModelFetcher::with_config(FetchConfig::default()).map_err(|e| {
87        CliError::ValidationFailed(format!("Failed to initialize model fetcher: {e}"))
88    })?;
89
90    if !force && fetcher.is_cached(model_ref) {
91        return handle_cached_model(&mut fetcher, model_ref);
92    }
93
94    let result = download_single_model(&mut fetcher, model_ref)?;
95    ensure_safetensors_companions(&result)?;
96    print_pull_usage(&result.path, true);
97    Ok(())
98}
99
100/// GH-352: Stream a single HuggingFace file directly to disk.
101///
102/// Uses O(64KB) memory instead of O(model_size). The pacha fetcher's
103/// `resolver.resolve()` buffers the entire response via `response.bytes()`,
104/// which consumed ~4.5 GB for a 7B GGUF. This function streams with a 64KB
105/// chunked read, computes BLAKE3 incrementally, and saves to the pacha cache.
106fn run_single_file_streaming(model_ref: &str, force: bool) -> Result<()> {
107    let path = model_ref.strip_prefix("hf://").unwrap_or(model_ref);
108    let parts: Vec<&str> = path.split('/').collect();
109    if parts.len() < 3 {
110        return Err(CliError::ValidationFailed(format!(
111            "HuggingFace URI must include a filename: {model_ref}"
112        )));
113    }
114
115    let filename = parts[2..].join("/");
116    let url = format!(
117        "https://huggingface.co/{}/{}/resolve/main/{}",
118        parts[0], parts[1], filename
119    );
120
121    // Determine cache path in pacha cache dir
122    let cache_dir = get_pacha_cache_dir()?;
123    std::fs::create_dir_all(&cache_dir)?;
124
125    // Check if already cached (by URI hash)
126    let uri_hash = blake3::hash(model_ref.as_bytes()).to_hex().to_string();
127    let extension = std::path::Path::new(&filename)
128        .extension()
129        .and_then(|e| e.to_str())
130        .unwrap_or("bin");
131    let cache_filename = format!("{}.{}", &uri_hash[..16], extension);
132    let cache_path = cache_dir.join(&cache_filename);
133
134    if !force && cache_path.exists() {
135        let metadata = std::fs::metadata(&cache_path)?;
136        println!("{} Model already cached", "✓".green());
137        println!("  Path: {}", cache_path.display());
138        println!("  Size: {}", format_bytes(metadata.len()));
139        print_pull_usage(&cache_path, true);
140        return Ok(());
141    }
142
143    println!();
144    println!("{}", "Downloading (streaming)...".yellow());
145
146    let checksum = download_file_with_progress(&url, &cache_path)?;
147
148    println!();
149    println!("{} Downloaded successfully", "✓".green());
150    println!("  Path: {}", cache_path.display().to_string().green());
151    println!("  Size: {}", format_bytes(checksum.size).yellow());
152    println!("  Hash: {}", &checksum.blake3[..16]);
153
154    // Handle SafeTensors companions
155    if extension == "safetensors" {
156        fetch_safetensors_companions(&cache_path, model_ref)?;
157        convert_safetensors_formats(&cache_path)?;
158    }
159
160    print_pull_usage(&cache_path, true);
161    Ok(())
162}
163
164/// Get the pacha model cache directory.
165fn get_pacha_cache_dir() -> Result<std::path::PathBuf> {
166    if let Ok(cache_home) = std::env::var("XDG_CACHE_HOME") {
167        return Ok(std::path::PathBuf::from(cache_home)
168            .join("pacha")
169            .join("models"));
170    }
171    Ok(dirs::home_dir()
172        .ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
173        .join(".cache")
174        .join("pacha")
175        .join("models"))
176}
177
178/// Handle a model that is already cached in pacha.
179fn handle_cached_model(fetcher: &mut ModelFetcher, model_ref: &str) -> Result<()> {
180    println!("{} Model already cached", "✓".green());
181    let result = fetcher
182        .pull_quiet(model_ref)
183        .map_err(|e| CliError::ValidationFailed(format!("Failed to get cached model: {e}")))?;
184
185    println!("  Path: {}", result.path.display());
186    println!("  Size: {}", result.size_human());
187    println!("  Format: {}", result.format.name());
188
189    ensure_safetensors_companions(&result)?;
190    print_pull_usage(&result.path, false);
191    Ok(())
192}
193
194/// Download a single model with progress bar.
195fn download_single_model(
196    fetcher: &mut ModelFetcher,
197    model_ref: &str,
198) -> Result<pacha::fetcher::FetchResult> {
199    println!();
200    println!("{}", "Downloading...".yellow());
201
202    let result = fetcher
203        .pull(model_ref, |progress| {
204            let pct = progress.percent();
205            print!(
206                "\r  [{:50}] {:5.1}% ({}/{})",
207                "=".repeat((pct / 2.0) as usize),
208                pct,
209                format_bytes(progress.downloaded_bytes),
210                format_bytes(progress.total_bytes)
211            );
212            io::stdout().flush().ok();
213        })
214        .map_err(|e| CliError::NetworkError(format!("Download failed: {e}")))?;
215
216    println!();
217    println!();
218
219    if result.cache_hit {
220        println!("{} Model retrieved from cache", "✓".green());
221    } else {
222        println!("{} Downloaded successfully", "✓".green());
223    }
224
225    println!("  Path: {}", result.path.display().to_string().green());
226    println!("  Size: {}", result.size_human().yellow());
227    println!("  Format: {}", result.format.name());
228    println!("  Hash: {}", &result.hash[..16]);
229    Ok(result)
230}
231
232/// Ensure companion files exist for SafeTensors models (GH-198, GH-211).
233fn ensure_safetensors_companions(result: &pacha::fetcher::FetchResult) -> Result<()> {
234    if matches!(result.format, ModelFormat::SafeTensors(_)) {
235        fetch_safetensors_companions(&result.path, &result.resolved_uri)?;
236        convert_safetensors_formats(&result.path)?;
237    }
238    Ok(())
239}
240
241/// Print usage instructions after a successful pull.
242fn print_pull_usage(path: &Path, show_serve: bool) {
243    println!();
244    println!("{}", "Usage:".cyan().bold());
245    println!("  apr run {}", path.display());
246    if show_serve {
247        println!("  apr serve {}", path.display());
248    }
249}
250
251/// GH-213: Pull a sharded SafeTensors model (3B+ models with multiple shard files)
252fn run_sharded(org: &str, repo: &str, shard_files: &[String], force: bool) -> Result<()> {
253    println!(
254        "Model: {}/{} ({} shards)",
255        org.cyan(),
256        repo.cyan(),
257        shard_files.len().to_string().yellow()
258    );
259
260    let cache_dir = resolve_shard_cache_dir(org, repo)?;
261    std::fs::create_dir_all(&cache_dir)?;
262
263    let base_url = format!("https://huggingface.co/{org}/{repo}/resolve/main");
264    let index_path = cache_dir.join("model.safetensors.index.json");
265
266    download_index_if_needed(&base_url, &index_path, force)?;
267
268    let manifest_path = cache_dir.join(".apr-manifest.json");
269    let existing_manifest = load_existing_manifest(&manifest_path, force);
270
271    let file_checksums = download_all_shards(
272        &cache_dir,
273        &base_url,
274        shard_files,
275        force,
276        existing_manifest.as_ref(),
277    )?;
278
279    download_companion_files(&cache_dir, &base_url, force)?;
280    write_shard_manifest(&manifest_path, org, repo, file_checksums)?;
281
282    println!();
283    println!("{} Downloaded successfully", "✓".green());
284    println!("  Path: {}", index_path.display().to_string().green());
285    println!("  Shards: {}", shard_files.len().to_string().yellow());
286
287    convert_safetensors_formats(&index_path)?;
288
289    println!();
290    println!("{}", "Usage:".cyan().bold());
291    println!("  apr run {}", index_path.display());
292    println!("  apr serve {}", index_path.display());
293    Ok(())
294}
295
296/// Resolve the cache directory for a sharded model.
297fn resolve_shard_cache_dir(org: &str, repo: &str) -> Result<std::path::PathBuf> {
298    Ok(dirs::home_dir()
299        .ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
300        .join(".apr")
301        .join("cache")
302        .join("hf")
303        .join(org)
304        .join(repo))
305}
306
307/// Download the SafeTensors index.json if not already cached.
308fn download_index_if_needed(base_url: &str, index_path: &Path, force: bool) -> Result<()> {
309    if force || !index_path.exists() {
310        println!();
311        println!("  {} model.safetensors.index.json", "Downloading".yellow());
312        download_file(
313            &format!("{base_url}/model.safetensors.index.json"),
314            index_path,
315        )?;
316    } else {
317        println!("  {} model.safetensors.index.json (cached)", "✓".green());
318    }
319    Ok(())
320}
321
322/// Load existing shard manifest for cache-hit verification (GH-213).
323fn load_existing_manifest(manifest_path: &Path, force: bool) -> Option<ShardManifest> {
324    if force || !manifest_path.exists() {
325        return None;
326    }
327    std::fs::read_to_string(manifest_path)
328        .ok()
329        .and_then(|s| serde_json::from_str(&s).ok())
330}
331
332/// Download all shards, collecting checksums for the manifest.
333fn download_all_shards(
334    cache_dir: &Path,
335    base_url: &str,
336    shard_files: &[String],
337    force: bool,
338    existing_manifest: Option<&ShardManifest>,
339) -> Result<HashMap<String, FileChecksum>> {
340    let mut file_checksums: HashMap<String, FileChecksum> = HashMap::new();
341    let total = shard_files.len();
342    for (i, shard_file) in shard_files.iter().enumerate() {
343        download_or_verify_shard(
344            cache_dir,
345            base_url,
346            shard_file,
347            i,
348            total,
349            force,
350            existing_manifest,
351            &mut file_checksums,
352        )?;
353    }
354    Ok(file_checksums)
355}
356
357/// Download or verify a single shard file, updating the checksum map.
358fn download_or_verify_shard(
359    cache_dir: &Path,
360    base_url: &str,
361    shard_file: &str,
362    index: usize,
363    total: usize,
364    force: bool,
365    existing_manifest: Option<&ShardManifest>,
366    checksums: &mut HashMap<String, FileChecksum>,
367) -> Result<()> {
368    let shard_path = cache_dir.join(shard_file);
369
370    if !force && shard_path.exists() {
371        if let Some(manifest) = existing_manifest {
372            if let Some(expected) = manifest.files.get(shard_file) {
373                let actual_size = std::fs::metadata(&shard_path).map(|m| m.len()).unwrap_or(0);
374                if actual_size == expected.size {
375                    checksums.insert(
376                        shard_file.to_string(),
377                        FileChecksum {
378                            size: expected.size,
379                            blake3: expected.blake3.clone(),
380                        },
381                    );
382                    println!(
383                        "  {} [{}/{}] {} (cached, verified)",
384                        "✓".green(),
385                        index + 1,
386                        total,
387                        shard_file
388                    );
389                    return Ok(());
390                }
391                println!(
392                    "  {} [{}/{}] {} (size mismatch: {} vs {} bytes, re-downloading)",
393                    "⚠".yellow(),
394                    index + 1,
395                    total,
396                    shard_file,
397                    actual_size,
398                    expected.size
399                );
400                // Fall through to re-download
401            }
402        } else {
403            println!(
404                "  {} [{}/{}] {} (cached)",
405                "✓".green(),
406                index + 1,
407                total,
408                shard_file
409            );
410            return Ok(());
411        }
412    }
413
414    let shard_url = format!("{base_url}/{shard_file}");
415    print!(
416        "  {} [{}/{}] {}...",
417        "↓".yellow(),
418        index + 1,
419        total,
420        shard_file
421    );
422    io::stdout().flush().ok();
423
424    let checksum = download_file_with_progress(&shard_url, &shard_path)?;
425    checksums.insert(shard_file.to_string(), checksum);
426    println!(" {}", "done".green());
427    Ok(())
428}
429
430/// Download companion files (tokenizer, config) for sharded models.
431///
432/// GH-356: tokenizer.json is optional — some models only have tokenizer.model (SentencePiece)
433/// or tokenizer_config.json. We validate that at least ONE tokenizer file was obtained.
434fn download_companion_files(cache_dir: &Path, base_url: &str, force: bool) -> Result<()> {
435    // (filename, is_required) — tokenizer files are individually optional but collectively required
436    let companions = [
437        ("tokenizer.json", false),
438        ("config.json", true),
439        ("tokenizer_config.json", false),
440        ("tokenizer.model", false),
441    ];
442    for (filename, required) in &companions {
443        let companion_path = cache_dir.join(filename);
444        if !force && companion_path.exists() {
445            println!("  {} {} (cached)", "✓".green(), filename);
446            continue;
447        }
448
449        let url = format!("{base_url}/{filename}");
450        match download_file(&url, &companion_path) {
451            Ok(()) => println!("  {} {}", "✓".green(), filename),
452            Err(CliError::HttpNotFound(_)) if *required => {
453                return Err(CliError::ValidationFailed(format!(
454                    "{filename} is required for inference but was not found (HTTP 404) at {url}"
455                )));
456            }
457            Err(CliError::HttpNotFound(_)) => {
458                println!("  {} {} (not found in repo)", "⚠".yellow(), filename);
459            }
460            Err(e) if *required => {
461                return Err(CliError::ValidationFailed(format!(
462                    "{filename} is required for inference but download failed: {e}"
463                )));
464            }
465            Err(_) => println!("  {} {} (not available, optional)", "⚠".yellow(), filename),
466        }
467    }
468
469    // GH-356: Validate at least one tokenizer file exists
470    let tokenizer_files = ["tokenizer.json", "tokenizer.model", "tokenizer_config.json"];
471    let has_tokenizer = tokenizer_files.iter().any(|f| cache_dir.join(f).exists());
472    if !has_tokenizer {
473        return Err(CliError::ValidationFailed(format!(
474            "No tokenizer found for this model. Tried: {}.\n\
475             The model may require a custom tokenizer not hosted in the repository.",
476            tokenizer_files.join(", ")
477        )));
478    }
479
480    Ok(())
481}
482
483/// Write shard manifest with BLAKE3 checksums for integrity verification.
484fn write_shard_manifest(
485    manifest_path: &Path,
486    org: &str,
487    repo: &str,
488    file_checksums: HashMap<String, FileChecksum>,
489) -> Result<()> {
490    if file_checksums.is_empty() {
491        return Ok(());
492    }
493    let manifest = ShardManifest {
494        version: 1,
495        repo: format!("{org}/{repo}"),
496        files: file_checksums,
497    };
498    let manifest_json = serde_json::to_string_pretty(&manifest)
499        .map_err(|e| CliError::ValidationFailed(format!("Failed to serialize manifest: {e}")))?;
500    std::fs::write(manifest_path, manifest_json)?;
501    println!("  {} .apr-manifest.json (integrity checksums)", "✓".green());
502    Ok(())
503}
504
505include!("pull_list.rs");
506include!("pull_remove_resolve_model.rs");
507include!("pull_extract_shard.rs");
508include!("pull_04.rs");