1use crate::error::{CliError, Result};
4use colored::Colorize;
5use pacha::fetcher::{FetchConfig, ModelFetcher};
6use pacha::format::ModelFormat;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet};
9use std::io::{self, Read, Write};
10use std::path::Path;
11
/// What a user-supplied model reference resolves to before `apr pull` fetches it.
///
/// A reference names either one downloadable file, or a HuggingFace repository
/// whose weights are split across several shard files.
#[derive(Debug)]
pub(crate) enum ResolvedModel {
    /// A single-file model URI (`hf://...` URIs are streamed; others go through pacha).
    SingleFile(String),
    /// A sharded HuggingFace model: `org/repo` plus the shard file names to fetch.
    Sharded {
        /// HuggingFace organization / user.
        org: String,
        /// Repository name within `org`.
        repo: String,
        /// Shard file names relative to the repository root.
        shard_files: Vec<String>,
    },
}
27
28#[derive(Debug, Serialize, Deserialize)]
33pub struct ShardManifest {
34 pub version: u32,
35 pub repo: String,
36 pub files: HashMap<String, FileChecksum>,
37}
38
39#[derive(Debug, Serialize, Deserialize)]
41pub struct FileChecksum {
42 pub size: u64,
43 pub blake3: String,
44}
45
46#[provable_contracts_macros::contract(
48 "apr-cli-operations-v1",
49 equation = "mutating_output_contract"
50)]
51pub fn run(
52 model_ref: &str,
53 force: bool,
54 dry_run: bool,
55 revision: Option<&str>,
56 offline: bool,
57) -> Result<()> {
58 contract_pre_pull_cache_integrity!();
59 println!("{}", "=== APR Pull ===".cyan().bold());
60 println!();
61
62 if dry_run {
69 return run_dry_run(model_ref, revision, offline);
70 }
71
72 let resolved = resolve_hf_model(model_ref)?;
74
75 let result = match resolved {
76 ResolvedModel::SingleFile(ref uri) => run_single_file(uri, force),
77 ResolvedModel::Sharded {
78 ref org,
79 ref repo,
80 ref shard_files,
81 } => run_sharded(org, repo, shard_files, force),
82 };
83 if let Ok(ref r) = result {
84 contract_post_pull_cache_integrity!(r);
85 }
86 result
87}
88
89fn run_single_file(model_ref: &str, force: bool) -> Result<()> {
95 println!("Model: {}", model_ref.cyan());
96
97 if model_ref.starts_with("hf://") {
99 return run_single_file_streaming(model_ref, force);
100 }
101
102 let mut fetcher = ModelFetcher::with_config(FetchConfig::default()).map_err(|e| {
103 CliError::ValidationFailed(format!("Failed to initialize model fetcher: {e}"))
104 })?;
105
106 if !force && fetcher.is_cached(model_ref) {
107 return handle_cached_model(&mut fetcher, model_ref);
108 }
109
110 let result = download_single_model(&mut fetcher, model_ref)?;
111 ensure_safetensors_companions(&result)?;
112 print_pull_usage(&result.path, true);
113 Ok(())
114}
115
116fn run_single_file_streaming(model_ref: &str, force: bool) -> Result<()> {
123 let (org, repo, filename) = parse_hf_single_uri(model_ref)?;
124 let url = format!("https://huggingface.co/{org}/{repo}/resolve/main/{filename}");
125
126 let cache_dir = get_pacha_cache_dir()?;
127 std::fs::create_dir_all(&cache_dir)?;
128 let (extension, cache_path) = build_single_cache_path(&cache_dir, model_ref, &filename);
129
130 if !force && cache_path.exists() {
131 return report_cached_single(&cache_path);
132 }
133
134 stream_and_post_process(&url, &cache_path, model_ref, &extension)?;
135 print_pull_usage(&cache_path, true);
136 Ok(())
137}
138
139fn stream_and_post_process(
140 url: &str,
141 cache_path: &std::path::Path,
142 model_ref: &str,
143 extension: &str,
144) -> Result<()> {
145 println!();
146 println!("{}", "Downloading (streaming)...".yellow());
147 let checksum = download_file_with_progress(url, cache_path)?;
148 report_downloaded_single(cache_path, &checksum);
149
150 if extension == "safetensors" {
151 fetch_safetensors_companions(cache_path, model_ref)?;
152 convert_safetensors_formats(cache_path)?;
153 }
154 Ok(())
155}
156
157fn parse_hf_single_uri(model_ref: &str) -> Result<(String, String, String)> {
158 let path = model_ref.strip_prefix("hf://").unwrap_or(model_ref);
159 let parts: Vec<&str> = path.split('/').collect();
160 if parts.len() < 3 {
161 return Err(CliError::ValidationFailed(format!(
162 "HuggingFace URI must include a filename: {model_ref}"
163 )));
164 }
165 Ok((
166 parts[0].to_string(),
167 parts[1].to_string(),
168 parts[2..].join("/"),
169 ))
170}
171
172pub(crate) fn build_single_cache_path(
173 cache_dir: &std::path::Path,
174 model_ref: &str,
175 filename: &str,
176) -> (String, std::path::PathBuf) {
177 let uri_hash = blake3::hash(model_ref.as_bytes()).to_hex().to_string();
178 let extension = std::path::Path::new(filename)
179 .extension()
180 .and_then(|e| e.to_str())
181 .unwrap_or("bin")
182 .to_string();
183 let cache_filename = format!("{}.{extension}", &uri_hash[..16]);
184 let cache_path = cache_dir.join(&cache_filename);
185 (extension, cache_path)
186}
187
188fn report_cached_single(cache_path: &std::path::Path) -> Result<()> {
189 let metadata = std::fs::metadata(cache_path)?;
190 println!("{} Model already cached", "✓".green());
191 println!(" Path: {}", cache_path.display());
192 println!(" Size: {}", format_bytes(metadata.len()));
193 print_pull_usage(cache_path, true);
194 Ok(())
195}
196
197fn report_downloaded_single(cache_path: &std::path::Path, checksum: &FileChecksum) {
198 println!();
199 println!("{} Downloaded successfully", "✓".green());
200 println!(" Path: {}", cache_path.display().to_string().green());
201 println!(" Size: {}", format_bytes(checksum.size).yellow());
202 println!(" Hash: {}", &checksum.blake3[..16]);
203}
204
205pub(crate) fn get_pacha_cache_dir() -> Result<std::path::PathBuf> {
207 if let Ok(cache_home) = std::env::var("XDG_CACHE_HOME") {
208 return Ok(std::path::PathBuf::from(cache_home)
209 .join("pacha")
210 .join("models"));
211 }
212 Ok(dirs::home_dir()
213 .ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
214 .join(".cache")
215 .join("pacha")
216 .join("models"))
217}
218
219fn handle_cached_model(fetcher: &mut ModelFetcher, model_ref: &str) -> Result<()> {
221 println!("{} Model already cached", "✓".green());
222 let result = fetcher
223 .pull_quiet(model_ref)
224 .map_err(|e| CliError::ValidationFailed(format!("Failed to get cached model: {e}")))?;
225
226 println!(" Path: {}", result.path.display());
227 println!(" Size: {}", result.size_human());
228 println!(" Format: {}", result.format.name());
229
230 ensure_safetensors_companions(&result)?;
231 print_pull_usage(&result.path, false);
232 Ok(())
233}
234
235fn download_single_model(
237 fetcher: &mut ModelFetcher,
238 model_ref: &str,
239) -> Result<pacha::fetcher::FetchResult> {
240 println!();
241 println!("{}", "Downloading...".yellow());
242
243 let result = fetcher
244 .pull(model_ref, |progress| {
245 let pct = progress.percent();
246 print!(
247 "\r [{:50}] {:5.1}% ({}/{})",
248 "=".repeat((pct / 2.0) as usize),
249 pct,
250 format_bytes(progress.downloaded_bytes),
251 format_bytes(progress.total_bytes)
252 );
253 io::stdout().flush().ok();
254 })
255 .map_err(|e| CliError::NetworkError(format!("Download failed: {e}")))?;
256
257 println!();
258 println!();
259
260 if result.cache_hit {
261 println!("{} Model retrieved from cache", "✓".green());
262 } else {
263 println!("{} Downloaded successfully", "✓".green());
264 }
265
266 println!(" Path: {}", result.path.display().to_string().green());
267 println!(" Size: {}", result.size_human().yellow());
268 println!(" Format: {}", result.format.name());
269 println!(" Hash: {}", &result.hash[..16]);
270 Ok(result)
271}
272
273fn ensure_safetensors_companions(result: &pacha::fetcher::FetchResult) -> Result<()> {
275 if matches!(result.format, ModelFormat::SafeTensors(_)) {
276 fetch_safetensors_companions(&result.path, &result.resolved_uri)?;
277 convert_safetensors_formats(&result.path)?;
278 }
279 Ok(())
280}
281
282fn print_pull_usage(path: &Path, show_serve: bool) {
284 println!();
285 println!("{}", "Usage:".cyan().bold());
286 println!(" apr run {}", path.display());
287 if show_serve {
288 println!(" apr serve {}", path.display());
289 }
290}
291
292fn run_sharded(org: &str, repo: &str, shard_files: &[String], force: bool) -> Result<()> {
294 println!(
295 "Model: {}/{} ({} shards)",
296 org.cyan(),
297 repo.cyan(),
298 shard_files.len().to_string().yellow()
299 );
300
301 let cache_dir = resolve_shard_cache_dir(org, repo)?;
302 std::fs::create_dir_all(&cache_dir)?;
303
304 let base_url = format!("https://huggingface.co/{org}/{repo}/resolve/main");
305
306 if shard_files
310 .iter()
311 .all(|f| f.to_lowercase().ends_with(".gguf"))
312 {
313 return run_sharded_gguf(org, repo, &cache_dir, &base_url, shard_files, force);
314 }
315
316 let index_path = cache_dir.join("model.safetensors.index.json");
317
318 download_index_if_needed(&base_url, &index_path, force)?;
319
320 let manifest_path = cache_dir.join(".apr-manifest.json");
321 let existing_manifest = load_existing_manifest(&manifest_path, force);
322
323 let file_checksums = download_all_shards(
324 &cache_dir,
325 &base_url,
326 shard_files,
327 force,
328 existing_manifest.as_ref(),
329 )?;
330
331 download_companion_files(&cache_dir, &base_url, force)?;
332 write_shard_manifest(&manifest_path, org, repo, file_checksums)?;
333
334 println!();
335 println!("{} Downloaded successfully", "✓".green());
336 println!(" Path: {}", index_path.display().to_string().green());
337 println!(" Shards: {}", shard_files.len().to_string().yellow());
338
339 convert_safetensors_formats(&index_path)?;
340
341 println!();
342 println!("{}", "Usage:".cyan().bold());
343 println!(" apr run {}", index_path.display());
344 println!(" apr serve {}", index_path.display());
345 Ok(())
346}
347
348fn run_sharded_gguf(
353 org: &str,
354 repo: &str,
355 cache_dir: &Path,
356 base_url: &str,
357 shard_files: &[String],
358 force: bool,
359) -> Result<()> {
360 let manifest_path = cache_dir.join(".apr-manifest.json");
361 let existing_manifest = load_existing_manifest(&manifest_path, force);
362
363 let file_checksums = download_all_shards(
364 cache_dir,
365 base_url,
366 shard_files,
367 force,
368 existing_manifest.as_ref(),
369 )?;
370
371 download_companion_files(cache_dir, base_url, force)?;
372 write_shard_manifest(&manifest_path, org, repo, file_checksums)?;
373
374 let first_part = cache_dir.join(shard_files.first().map_or("", String::as_str));
375 println!();
376 println!(
377 "{} Downloaded {} GGUF shards",
378 "✓".green(),
379 shard_files.len().to_string().yellow()
380 );
381 println!(" Path: {}", first_part.display().to_string().green());
382 println!();
383 println!("{}", "Usage:".cyan().bold());
384 println!(" apr run {}", first_part.display());
385 println!(" apr serve {}", first_part.display());
386 Ok(())
387}
388
389fn resolve_shard_cache_dir(org: &str, repo: &str) -> Result<std::path::PathBuf> {
391 Ok(dirs::home_dir()
392 .ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
393 .join(".apr")
394 .join("cache")
395 .join("hf")
396 .join(org)
397 .join(repo))
398}
399
400fn download_index_if_needed(base_url: &str, index_path: &Path, force: bool) -> Result<()> {
402 if force || !index_path.exists() {
403 println!();
404 println!(" {} model.safetensors.index.json", "Downloading".yellow());
405 download_file(
406 &format!("{base_url}/model.safetensors.index.json"),
407 index_path,
408 )?;
409 } else {
410 println!(" {} model.safetensors.index.json (cached)", "✓".green());
411 }
412 Ok(())
413}
414
415fn load_existing_manifest(manifest_path: &Path, force: bool) -> Option<ShardManifest> {
417 if force || !manifest_path.exists() {
418 return None;
419 }
420 std::fs::read_to_string(manifest_path)
421 .ok()
422 .and_then(|s| serde_json::from_str(&s).ok())
423}
424
425fn download_all_shards(
427 cache_dir: &Path,
428 base_url: &str,
429 shard_files: &[String],
430 force: bool,
431 existing_manifest: Option<&ShardManifest>,
432) -> Result<HashMap<String, FileChecksum>> {
433 let mut file_checksums: HashMap<String, FileChecksum> = HashMap::new();
434 let total = shard_files.len();
435 for (i, shard_file) in shard_files.iter().enumerate() {
436 download_or_verify_shard(
437 cache_dir,
438 base_url,
439 shard_file,
440 i,
441 total,
442 force,
443 existing_manifest,
444 &mut file_checksums,
445 )?;
446 }
447 Ok(file_checksums)
448}
449
450fn download_or_verify_shard(
452 cache_dir: &Path,
453 base_url: &str,
454 shard_file: &str,
455 index: usize,
456 total: usize,
457 force: bool,
458 existing_manifest: Option<&ShardManifest>,
459 checksums: &mut HashMap<String, FileChecksum>,
460) -> Result<()> {
461 let shard_path = cache_dir.join(shard_file);
462
463 if !force && shard_path.exists() {
464 if let Some(manifest) = existing_manifest {
465 if let Some(expected) = manifest.files.get(shard_file) {
466 let actual_size = std::fs::metadata(&shard_path).map(|m| m.len()).unwrap_or(0);
467 if actual_size == expected.size {
468 checksums.insert(
469 shard_file.to_string(),
470 FileChecksum {
471 size: expected.size,
472 blake3: expected.blake3.clone(),
473 },
474 );
475 println!(
476 " {} [{}/{}] {} (cached, verified)",
477 "✓".green(),
478 index + 1,
479 total,
480 shard_file
481 );
482 return Ok(());
483 }
484 println!(
485 " {} [{}/{}] {} (size mismatch: {} vs {} bytes, re-downloading)",
486 "⚠".yellow(),
487 index + 1,
488 total,
489 shard_file,
490 actual_size,
491 expected.size
492 );
493 }
495 } else {
496 println!(
497 " {} [{}/{}] {} (cached)",
498 "✓".green(),
499 index + 1,
500 total,
501 shard_file
502 );
503 return Ok(());
504 }
505 }
506
507 let shard_url = format!("{base_url}/{shard_file}");
508 print!(
509 " {} [{}/{}] {}...",
510 "↓".yellow(),
511 index + 1,
512 total,
513 shard_file
514 );
515 io::stdout().flush().ok();
516
517 let checksum = download_file_with_progress(&shard_url, &shard_path)?;
518 checksums.insert(shard_file.to_string(), checksum);
519 println!(" {}", "done".green());
520 Ok(())
521}
522
523fn download_companion_files(cache_dir: &Path, base_url: &str, force: bool) -> Result<()> {
528 let companions = [
530 ("tokenizer.json", false),
531 ("config.json", true),
532 ("tokenizer_config.json", false),
533 ("tokenizer.model", false),
534 ];
535 for (filename, required) in &companions {
536 let companion_path = cache_dir.join(filename);
537 if !force && companion_path.exists() {
538 println!(" {} {} (cached)", "✓".green(), filename);
539 continue;
540 }
541
542 let url = format!("{base_url}/{filename}");
543 match download_file(&url, &companion_path) {
544 Ok(()) => println!(" {} {}", "✓".green(), filename),
545 Err(CliError::HttpNotFound(_)) if *required => {
546 return Err(CliError::ValidationFailed(format!(
547 "{filename} is required for inference but was not found (HTTP 404) at {url}"
548 )));
549 }
550 Err(CliError::HttpNotFound(_)) => {
551 println!(" {} {} (not found in repo)", "⚠".yellow(), filename);
552 }
553 Err(e) if *required => {
554 return Err(CliError::ValidationFailed(format!(
555 "{filename} is required for inference but download failed: {e}"
556 )));
557 }
558 Err(_) => println!(" {} {} (not available, optional)", "⚠".yellow(), filename),
559 }
560 }
561
562 let tokenizer_files = ["tokenizer.json", "tokenizer.model", "tokenizer_config.json"];
564 let has_tokenizer = tokenizer_files.iter().any(|f| cache_dir.join(f).exists());
565 if !has_tokenizer {
566 return Err(CliError::ValidationFailed(format!(
567 "No tokenizer found for this model. Tried: {}.\n\
568 The model may require a custom tokenizer not hosted in the repository.",
569 tokenizer_files.join(", ")
570 )));
571 }
572
573 Ok(())
574}
575
576fn write_shard_manifest(
578 manifest_path: &Path,
579 org: &str,
580 repo: &str,
581 file_checksums: HashMap<String, FileChecksum>,
582) -> Result<()> {
583 if file_checksums.is_empty() {
584 return Ok(());
585 }
586 let manifest = ShardManifest {
587 version: 1,
588 repo: format!("{org}/{repo}"),
589 files: file_checksums,
590 };
591 let manifest_json = serde_json::to_string_pretty(&manifest)
592 .map_err(|e| CliError::ValidationFailed(format!("Failed to serialize manifest: {e}")))?;
593 std::fs::write(manifest_path, manifest_json)?;
594 println!(" {} .apr-manifest.json (integrity checksums)", "✓".green());
595 Ok(())
596}
597
598fn run_dry_run(model_ref: &str, revision: Option<&str>, offline_flag: bool) -> Result<()> {
614 use super::aliases;
615 use super::offline;
616 use super::revision as rev;
617
618 let resolved = if let Some(url) = aliases::resolve_short_name(model_ref) {
619 url
620 } else if !model_ref.contains("://") && model_ref.contains('/') {
621 format!("hf://{model_ref}")
622 } else {
623 return Err(unknown_short_name_error(model_ref));
624 };
625
626 let rev_spec = revision.unwrap_or(rev::DEFAULT_REVISION);
627 let rev_kind = rev::classify_revision(rev_spec).map_err(|msg| {
628 CliError::ValidationFailed(format!("CRUX-A-03: invalid --revision {rev_spec:?}: {msg}"))
629 })?;
630
631 let env = offline::read_offline_env();
633 let env_borrowed: Vec<(&str, &str)> =
634 env.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect();
635 let is_offline = offline::is_offline(offline_flag, env_borrowed.iter().copied());
636
637 println!("Model: {}", model_ref.cyan());
638 println!("Resolved: {}", resolved.green());
639 println!("Revision: {} ({:?})", rev_spec.green(), rev_kind);
640 println!(
641 "Offline: {}",
642 if is_offline {
643 "true".green()
644 } else {
645 "false".yellow()
646 }
647 );
648 println!("Mode: {} (no network I/O)", "dry-run".yellow());
649 Ok(())
650}
651
652fn unknown_short_name_error(name: &str) -> CliError {
655 use super::aliases;
656
657 let suggestions = aliases::did_you_mean(name, 2);
658 let hint = if suggestions.is_empty() {
659 "Run `apr registry aliases --json` to list known short names.".to_string()
660 } else {
661 format!(
662 "did you mean {}? (run `apr registry aliases --json` for the full list)",
663 suggestions
664 .iter()
665 .map(|s| format!("`{s}`"))
666 .collect::<Vec<_>>()
667 .join(", ")
668 )
669 };
670 CliError::ValidationFailed(format!(
671 "CRUX-A-01: unknown short name '{name}' and not a fully-qualified URI. {hint}"
672 ))
673}
674
// The rest of the `pull` implementation is spliced in verbatim from sibling files.
// Helpers used above but not defined in this file presumably come from these
// includes — confirm. Examples: `resolve_hf_model`, `download_file`,
// `download_file_with_progress`, `format_bytes`, `fetch_safetensors_companions`,
// `convert_safetensors_formats`.
include!("pull_list.rs");
include!("pull_remove_resolve_model.rs");
include!("pull_extract_shard.rs");
include!("pull_04.rs");
include!("pull_dataset.rs");