1use crate::error::{CliError, Result};
4use colored::Colorize;
5use pacha::fetcher::{FetchConfig, ModelFetcher};
6use pacha::format::ModelFormat;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet};
9use std::io::{self, Read, Write};
10use std::path::Path;
11
/// Result of resolving a user-supplied model reference (produced by
/// `resolve_hf_model` and dispatched on by `run`).
#[derive(Debug)]
pub(crate) enum ResolvedModel {
    /// A single downloadable file; `run_single_file` streams `hf://` URIs
    /// directly and hands any other reference to pacha's `ModelFetcher`.
    SingleFile(String),
    /// A Hugging Face repository whose weights are split across several files.
    Sharded {
        /// Repository owner, used as `https://huggingface.co/{org}/...`.
        org: String,
        /// Repository name within `org`.
        repo: String,
        /// Shard file names, joined onto the repo's `resolve/main` URL and onto
        /// the local cache directory.
        shard_files: Vec<String>,
    },
}
27
/// Integrity manifest stored as `.apr-manifest.json` in a sharded model's
/// cache directory.
///
/// On a later pull, a cached shard whose on-disk size matches the recorded
/// size is reused without re-downloading (see `download_or_verify_shard`).
#[derive(Debug, Serialize, Deserialize)]
pub struct ShardManifest {
    /// Manifest schema version; `write_shard_manifest` writes `1`.
    pub version: u32,
    /// Repository identifier in `org/repo` form.
    pub repo: String,
    /// Per-file checksums keyed by shard file name.
    pub files: HashMap<String, FileChecksum>,
}
38
/// Size and BLAKE3 digest recorded for one downloaded file.
#[derive(Debug, Serialize, Deserialize)]
pub struct FileChecksum {
    /// File length in bytes.
    pub size: u64,
    /// BLAKE3 digest as text (presumably lowercase hex, as produced by
    /// `download_file_with_progress` — not visible here, TODO confirm).
    pub blake3: String,
}
45
46#[provable_contracts_macros::contract(
48 "apr-cli-operations-v1",
49 equation = "mutating_output_contract"
50)]
51pub fn run(
52 model_ref: &str,
53 force: bool,
54 dry_run: bool,
55 revision: Option<&str>,
56 offline: bool,
57) -> Result<()> {
58 contract_pre_pull_cache_integrity!();
59 println!("{}", "=== APR Pull ===".cyan().bold());
60 println!();
61
62 if dry_run {
69 return run_dry_run(model_ref, revision, offline);
70 }
71
72 let resolved = resolve_hf_model(model_ref)?;
74
75 let result = match resolved {
76 ResolvedModel::SingleFile(ref uri) => run_single_file(uri, force),
77 ResolvedModel::Sharded {
78 ref org,
79 ref repo,
80 ref shard_files,
81 } => run_sharded(org, repo, shard_files, force),
82 };
83 if let Ok(ref r) = result {
84 contract_post_pull_cache_integrity!(r);
85 }
86 result
87}
88
89fn run_single_file(model_ref: &str, force: bool) -> Result<()> {
95 println!("Model: {}", model_ref.cyan());
96
97 if model_ref.starts_with("hf://") {
99 return run_single_file_streaming(model_ref, force);
100 }
101
102 let mut fetcher = ModelFetcher::with_config(FetchConfig::default()).map_err(|e| {
103 CliError::ValidationFailed(format!("Failed to initialize model fetcher: {e}"))
104 })?;
105
106 if !force && fetcher.is_cached(model_ref) {
107 return handle_cached_model(&mut fetcher, model_ref);
108 }
109
110 let result = download_single_model(&mut fetcher, model_ref)?;
111 ensure_safetensors_companions(&result)?;
112 print_pull_usage(&result.path, true);
113 Ok(())
114}
115
116fn run_single_file_streaming(model_ref: &str, force: bool) -> Result<()> {
123 let (org, repo, filename) = parse_hf_single_uri(model_ref)?;
124 let url = format!("https://huggingface.co/{org}/{repo}/resolve/main/{filename}");
125
126 let cache_dir = get_pacha_cache_dir()?;
127 std::fs::create_dir_all(&cache_dir)?;
128 let (extension, cache_path) = build_single_cache_path(&cache_dir, model_ref, &filename);
129
130 if !force && cache_path.exists() {
131 return report_cached_single(&cache_path);
132 }
133
134 stream_and_post_process(&url, &cache_path, model_ref, &extension)?;
135 print_pull_usage(&cache_path, true);
136 Ok(())
137}
138
139fn stream_and_post_process(
140 url: &str,
141 cache_path: &std::path::Path,
142 model_ref: &str,
143 extension: &str,
144) -> Result<()> {
145 println!();
146 println!("{}", "Downloading (streaming)...".yellow());
147 let checksum = download_file_with_progress(url, cache_path)?;
148 report_downloaded_single(cache_path, &checksum);
149
150 if extension == "safetensors" {
151 fetch_safetensors_companions(cache_path, model_ref)?;
152 convert_safetensors_formats(cache_path)?;
153 }
154 Ok(())
155}
156
157fn parse_hf_single_uri(model_ref: &str) -> Result<(String, String, String)> {
158 let path = model_ref.strip_prefix("hf://").unwrap_or(model_ref);
159 let parts: Vec<&str> = path.split('/').collect();
160 if parts.len() < 3 {
161 return Err(CliError::ValidationFailed(format!(
162 "HuggingFace URI must include a filename: {model_ref}"
163 )));
164 }
165 Ok((
166 parts[0].to_string(),
167 parts[1].to_string(),
168 parts[2..].join("/"),
169 ))
170}
171
172pub(crate) fn build_single_cache_path(
173 cache_dir: &std::path::Path,
174 model_ref: &str,
175 filename: &str,
176) -> (String, std::path::PathBuf) {
177 let uri_hash = blake3::hash(model_ref.as_bytes()).to_hex().to_string();
178 let extension = std::path::Path::new(filename)
179 .extension()
180 .and_then(|e| e.to_str())
181 .unwrap_or("bin")
182 .to_string();
183 let cache_filename = format!("{}.{extension}", &uri_hash[..16]);
184 let cache_path = cache_dir.join(&cache_filename);
185 (extension, cache_path)
186}
187
188fn report_cached_single(cache_path: &std::path::Path) -> Result<()> {
189 let metadata = std::fs::metadata(cache_path)?;
190 println!("{} Model already cached", "✓".green());
191 println!(" Path: {}", cache_path.display());
192 println!(" Size: {}", format_bytes(metadata.len()));
193 print_pull_usage(cache_path, true);
194 Ok(())
195}
196
197fn report_downloaded_single(cache_path: &std::path::Path, checksum: &FileChecksum) {
198 println!();
199 println!("{} Downloaded successfully", "✓".green());
200 println!(" Path: {}", cache_path.display().to_string().green());
201 println!(" Size: {}", format_bytes(checksum.size).yellow());
202 println!(" Hash: {}", &checksum.blake3[..16]);
203}
204
205pub(crate) fn get_pacha_cache_dir() -> Result<std::path::PathBuf> {
207 if let Ok(cache_home) = std::env::var("XDG_CACHE_HOME") {
208 return Ok(std::path::PathBuf::from(cache_home)
209 .join("pacha")
210 .join("models"));
211 }
212 Ok(dirs::home_dir()
213 .ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
214 .join(".cache")
215 .join("pacha")
216 .join("models"))
217}
218
219fn handle_cached_model(fetcher: &mut ModelFetcher, model_ref: &str) -> Result<()> {
221 println!("{} Model already cached", "✓".green());
222 let result = fetcher
223 .pull_quiet(model_ref)
224 .map_err(|e| CliError::ValidationFailed(format!("Failed to get cached model: {e}")))?;
225
226 println!(" Path: {}", result.path.display());
227 println!(" Size: {}", result.size_human());
228 println!(" Format: {}", result.format.name());
229
230 ensure_safetensors_companions(&result)?;
231 print_pull_usage(&result.path, false);
232 Ok(())
233}
234
235fn download_single_model(
237 fetcher: &mut ModelFetcher,
238 model_ref: &str,
239) -> Result<pacha::fetcher::FetchResult> {
240 println!();
241 println!("{}", "Downloading...".yellow());
242
243 let result = fetcher
244 .pull(model_ref, |progress| {
245 let pct = progress.percent();
246 print!(
247 "\r [{:50}] {:5.1}% ({}/{})",
248 "=".repeat((pct / 2.0) as usize),
249 pct,
250 format_bytes(progress.downloaded_bytes),
251 format_bytes(progress.total_bytes)
252 );
253 io::stdout().flush().ok();
254 })
255 .map_err(|e| CliError::NetworkError(format!("Download failed: {e}")))?;
256
257 println!();
258 println!();
259
260 if result.cache_hit {
261 println!("{} Model retrieved from cache", "✓".green());
262 } else {
263 println!("{} Downloaded successfully", "✓".green());
264 }
265
266 println!(" Path: {}", result.path.display().to_string().green());
267 println!(" Size: {}", result.size_human().yellow());
268 println!(" Format: {}", result.format.name());
269 println!(" Hash: {}", &result.hash[..16]);
270 Ok(result)
271}
272
273fn ensure_safetensors_companions(result: &pacha::fetcher::FetchResult) -> Result<()> {
275 if matches!(result.format, ModelFormat::SafeTensors(_)) {
276 fetch_safetensors_companions(&result.path, &result.resolved_uri)?;
277 convert_safetensors_formats(&result.path)?;
278 }
279 Ok(())
280}
281
282fn print_pull_usage(path: &Path, show_serve: bool) {
284 println!();
285 println!("{}", "Usage:".cyan().bold());
286 println!(" apr run {}", path.display());
287 if show_serve {
288 println!(" apr serve {}", path.display());
289 }
290}
291
292fn run_sharded(org: &str, repo: &str, shard_files: &[String], force: bool) -> Result<()> {
294 println!(
295 "Model: {}/{} ({} shards)",
296 org.cyan(),
297 repo.cyan(),
298 shard_files.len().to_string().yellow()
299 );
300
301 let cache_dir = resolve_shard_cache_dir(org, repo)?;
302 std::fs::create_dir_all(&cache_dir)?;
303
304 let base_url = format!("https://huggingface.co/{org}/{repo}/resolve/main");
305
306 if shard_files
310 .iter()
311 .all(|f| f.to_lowercase().ends_with(".gguf"))
312 {
313 return run_sharded_gguf(org, repo, &cache_dir, &base_url, shard_files, force);
314 }
315
316 let index_path = cache_dir.join("model.safetensors.index.json");
317
318 download_index_if_needed(&base_url, &index_path, force)?;
319
320 let manifest_path = cache_dir.join(".apr-manifest.json");
321 let existing_manifest = load_existing_manifest(&manifest_path, force);
322
323 let file_checksums = download_all_shards(
324 &cache_dir,
325 &base_url,
326 shard_files,
327 force,
328 existing_manifest.as_ref(),
329 )?;
330
331 download_companion_files(&cache_dir, &base_url, force)?;
332 write_shard_manifest(&manifest_path, org, repo, file_checksums)?;
333
334 println!();
335 println!("{} Downloaded successfully", "✓".green());
336 println!(" Path: {}", index_path.display().to_string().green());
337 println!(" Shards: {}", shard_files.len().to_string().yellow());
338
339 convert_safetensors_formats(&index_path)?;
340
341 println!();
342 println!("{}", "Usage:".cyan().bold());
343 println!(" apr run {}", index_path.display());
344 println!(" apr serve {}", index_path.display());
345 Ok(())
346}
347
348fn run_sharded_gguf(
353 org: &str,
354 repo: &str,
355 cache_dir: &Path,
356 base_url: &str,
357 shard_files: &[String],
358 force: bool,
359) -> Result<()> {
360 let manifest_path = cache_dir.join(".apr-manifest.json");
361 let existing_manifest = load_existing_manifest(&manifest_path, force);
362
363 let file_checksums = download_all_shards(
364 cache_dir,
365 base_url,
366 shard_files,
367 force,
368 existing_manifest.as_ref(),
369 )?;
370
371 download_companion_files(cache_dir, base_url, force)?;
372 write_shard_manifest(&manifest_path, org, repo, file_checksums)?;
373
374 println!();
375 println!(
376 "{} Downloaded {} GGUF shards",
377 "✓".green(),
378 shard_files.len().to_string().yellow()
379 );
380
381 let part_paths: Vec<std::path::PathBuf> =
385 shard_files.iter().map(|f| cache_dir.join(f)).collect();
386 let merged_path = cache_dir.join("model.gguf");
387 match aprender::format::gguf::merge_gguf_shards(&part_paths, &merged_path) {
388 Ok(()) => {
389 for part in &part_paths {
392 if let Err(e) = std::fs::remove_file(part) {
393 eprintln!(
394 " {} could not remove shard {} ({e})",
395 "!".yellow(),
396 part.display()
397 );
398 }
399 }
400 println!(
401 " {} merged {} parts → model.gguf",
402 "✓".green(),
403 shard_files.len().to_string().yellow()
404 );
405 println!(" Path: {}", merged_path.display().to_string().green());
406 println!();
407 println!("{}", "Usage:".cyan().bold());
408 println!(" apr run {}", merged_path.display());
409 println!(" apr serve {}", merged_path.display());
410 }
411 Err(e) => {
412 eprintln!(" {} could not assemble the sharded model: {e}", "✗".red());
415 eprintln!(
416 " The {} parts were downloaded to {} but cannot be run \
417 individually. Please file an issue (#1893) with the model name.",
418 shard_files.len(),
419 cache_dir.display()
420 );
421 return Err(CliError::ValidationFailed(format!(
422 "sharded GGUF merge failed: {e}"
423 )));
424 }
425 }
426 Ok(())
427}
428
429fn resolve_shard_cache_dir(org: &str, repo: &str) -> Result<std::path::PathBuf> {
431 Ok(dirs::home_dir()
432 .ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
433 .join(".apr")
434 .join("cache")
435 .join("hf")
436 .join(org)
437 .join(repo))
438}
439
440fn download_index_if_needed(base_url: &str, index_path: &Path, force: bool) -> Result<()> {
442 if force || !index_path.exists() {
443 println!();
444 println!(" {} model.safetensors.index.json", "Downloading".yellow());
445 download_file(
446 &format!("{base_url}/model.safetensors.index.json"),
447 index_path,
448 )?;
449 } else {
450 println!(" {} model.safetensors.index.json (cached)", "✓".green());
451 }
452 Ok(())
453}
454
455fn load_existing_manifest(manifest_path: &Path, force: bool) -> Option<ShardManifest> {
457 if force || !manifest_path.exists() {
458 return None;
459 }
460 std::fs::read_to_string(manifest_path)
461 .ok()
462 .and_then(|s| serde_json::from_str(&s).ok())
463}
464
465fn download_all_shards(
467 cache_dir: &Path,
468 base_url: &str,
469 shard_files: &[String],
470 force: bool,
471 existing_manifest: Option<&ShardManifest>,
472) -> Result<HashMap<String, FileChecksum>> {
473 let mut file_checksums: HashMap<String, FileChecksum> = HashMap::new();
474 let total = shard_files.len();
475 for (i, shard_file) in shard_files.iter().enumerate() {
476 download_or_verify_shard(
477 cache_dir,
478 base_url,
479 shard_file,
480 i,
481 total,
482 force,
483 existing_manifest,
484 &mut file_checksums,
485 )?;
486 }
487 Ok(file_checksums)
488}
489
490fn download_or_verify_shard(
492 cache_dir: &Path,
493 base_url: &str,
494 shard_file: &str,
495 index: usize,
496 total: usize,
497 force: bool,
498 existing_manifest: Option<&ShardManifest>,
499 checksums: &mut HashMap<String, FileChecksum>,
500) -> Result<()> {
501 let shard_path = cache_dir.join(shard_file);
502
503 if !force && shard_path.exists() {
504 if let Some(manifest) = existing_manifest {
505 if let Some(expected) = manifest.files.get(shard_file) {
506 let actual_size = std::fs::metadata(&shard_path).map(|m| m.len()).unwrap_or(0);
507 if actual_size == expected.size {
508 checksums.insert(
509 shard_file.to_string(),
510 FileChecksum {
511 size: expected.size,
512 blake3: expected.blake3.clone(),
513 },
514 );
515 println!(
516 " {} [{}/{}] {} (cached, verified)",
517 "✓".green(),
518 index + 1,
519 total,
520 shard_file
521 );
522 return Ok(());
523 }
524 println!(
525 " {} [{}/{}] {} (size mismatch: {} vs {} bytes, re-downloading)",
526 "⚠".yellow(),
527 index + 1,
528 total,
529 shard_file,
530 actual_size,
531 expected.size
532 );
533 }
535 } else {
536 println!(
537 " {} [{}/{}] {} (cached)",
538 "✓".green(),
539 index + 1,
540 total,
541 shard_file
542 );
543 return Ok(());
544 }
545 }
546
547 let shard_url = format!("{base_url}/{shard_file}");
548 print!(
549 " {} [{}/{}] {}...",
550 "↓".yellow(),
551 index + 1,
552 total,
553 shard_file
554 );
555 io::stdout().flush().ok();
556
557 let checksum = download_file_with_progress(&shard_url, &shard_path)?;
558 checksums.insert(shard_file.to_string(), checksum);
559 println!(" {}", "done".green());
560 Ok(())
561}
562
563fn download_companion_files(cache_dir: &Path, base_url: &str, force: bool) -> Result<()> {
568 let companions = [
570 ("tokenizer.json", false),
571 ("config.json", true),
572 ("tokenizer_config.json", false),
573 ("tokenizer.model", false),
574 ];
575 for (filename, required) in &companions {
576 let companion_path = cache_dir.join(filename);
577 if !force && companion_path.exists() {
578 println!(" {} {} (cached)", "✓".green(), filename);
579 continue;
580 }
581
582 let url = format!("{base_url}/{filename}");
583 match download_file(&url, &companion_path) {
584 Ok(()) => println!(" {} {}", "✓".green(), filename),
585 Err(CliError::HttpNotFound(_)) if *required => {
586 return Err(CliError::ValidationFailed(format!(
587 "{filename} is required for inference but was not found (HTTP 404) at {url}"
588 )));
589 }
590 Err(CliError::HttpNotFound(_)) => {
591 println!(" {} {} (not found in repo)", "⚠".yellow(), filename);
592 }
593 Err(e) if *required => {
594 return Err(CliError::ValidationFailed(format!(
595 "{filename} is required for inference but download failed: {e}"
596 )));
597 }
598 Err(_) => println!(" {} {} (not available, optional)", "⚠".yellow(), filename),
599 }
600 }
601
602 let tokenizer_files = ["tokenizer.json", "tokenizer.model", "tokenizer_config.json"];
604 let has_tokenizer = tokenizer_files.iter().any(|f| cache_dir.join(f).exists());
605 if !has_tokenizer {
606 return Err(CliError::ValidationFailed(format!(
607 "No tokenizer found for this model. Tried: {}.\n\
608 The model may require a custom tokenizer not hosted in the repository.",
609 tokenizer_files.join(", ")
610 )));
611 }
612
613 Ok(())
614}
615
616fn write_shard_manifest(
618 manifest_path: &Path,
619 org: &str,
620 repo: &str,
621 file_checksums: HashMap<String, FileChecksum>,
622) -> Result<()> {
623 if file_checksums.is_empty() {
624 return Ok(());
625 }
626 let manifest = ShardManifest {
627 version: 1,
628 repo: format!("{org}/{repo}"),
629 files: file_checksums,
630 };
631 let manifest_json = serde_json::to_string_pretty(&manifest)
632 .map_err(|e| CliError::ValidationFailed(format!("Failed to serialize manifest: {e}")))?;
633 std::fs::write(manifest_path, manifest_json)?;
634 println!(" {} .apr-manifest.json (integrity checksums)", "✓".green());
635 Ok(())
636}
637
638fn run_dry_run(model_ref: &str, revision: Option<&str>, offline_flag: bool) -> Result<()> {
654 use super::aliases;
655 use super::offline;
656 use super::revision as rev;
657
658 let resolved = if let Some(url) = aliases::resolve_short_name(model_ref) {
659 url
660 } else if !model_ref.contains("://") && model_ref.contains('/') {
661 format!("hf://{model_ref}")
662 } else {
663 return Err(unknown_short_name_error(model_ref));
664 };
665
666 let rev_spec = revision.unwrap_or(rev::DEFAULT_REVISION);
667 let rev_kind = rev::classify_revision(rev_spec).map_err(|msg| {
668 CliError::ValidationFailed(format!("CRUX-A-03: invalid --revision {rev_spec:?}: {msg}"))
669 })?;
670
671 let env = offline::read_offline_env();
673 let env_borrowed: Vec<(&str, &str)> =
674 env.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect();
675 let is_offline = offline::is_offline(offline_flag, env_borrowed.iter().copied());
676
677 println!("Model: {}", model_ref.cyan());
678 println!("Resolved: {}", resolved.green());
679 println!("Revision: {} ({:?})", rev_spec.green(), rev_kind);
680 println!(
681 "Offline: {}",
682 if is_offline {
683 "true".green()
684 } else {
685 "false".yellow()
686 }
687 );
688 println!("Mode: {} (no network I/O)", "dry-run".yellow());
689 Ok(())
690}
691
692fn unknown_short_name_error(name: &str) -> CliError {
695 use super::aliases;
696
697 let suggestions = aliases::did_you_mean(name, 2);
698 let hint = if suggestions.is_empty() {
699 "Run `apr registry aliases --json` to list known short names.".to_string()
700 } else {
701 format!(
702 "did you mean {}? (run `apr registry aliases --json` for the full list)",
703 suggestions
704 .iter()
705 .map(|s| format!("`{s}`"))
706 .collect::<Vec<_>>()
707 .join(", ")
708 )
709 };
710 CliError::ValidationFailed(format!(
711 "CRUX-A-01: unknown short name '{name}' and not a fully-qualified URI. {hint}"
712 ))
713}
714
715include!("pull_list.rs");
716include!("pull_remove_resolve_model.rs");
717include!("pull_extract_shard.rs");
718include!("pull_04.rs");
719include!("pull_dataset.rs");
720
/// Interop check: a file produced by `merge_gguf_shards` must be accepted by
/// realizar's GGUF loader. Built only for tests with the `inference` feature.
#[cfg(all(test, feature = "inference"))]
mod sharded_gguf_interop_tests {
    use aprender::format::gguf::{
        export_tensors_to_gguf, merge_gguf_shards, GgmlType, GgufTensor, GgufValue,
    };
    use std::path::Path;

    /// Serialises one part (tensors plus metadata) to GGUF bytes and writes
    /// them to `path`. Panics on failure.
    fn write_part(path: &Path, tensors: &[GgufTensor], meta: &[(String, GgufValue)]) {
        let mut buf = Vec::new();
        export_tensors_to_gguf(&mut buf, tensors, meta).expect("export part");
        std::fs::write(path, &buf).expect("write part");
    }

    /// Builds a two-part split model, merges it, and asserts that realizar
    /// parses the merged bytes.
    #[test]
    fn merged_sharded_gguf_loads_in_realizar() {
        // Directory name includes the process id so separate test processes
        // do not share files.
        let dir = std::env::temp_dir().join(format!("apr-merge-interop-{}", std::process::id()));
        std::fs::create_dir_all(&dir).expect("mkdir");
        let p0 = dir.join("model-00001-of-00002.gguf");
        let p1 = dir.join("model-00002-of-00002.gguf");
        let merged = dir.join("model.gguf");

        // One-dimensional F32 tensor of 4 elements (16 bytes), every byte set
        // to `fill`.
        let tensor = |name: &str, fill: u8| GgufTensor {
            name: name.into(),
            shape: vec![4],
            dtype: GgmlType::F32,
            data: vec![fill; 16],
        };
        // Part 0 carries the model-level metadata plus split.no/split.count.
        write_part(
            &p0,
            &[tensor("blk.0.weight", 1)],
            &[
                (
                    "general.architecture".into(),
                    GgufValue::String("gemma".into()),
                ),
                ("gemma.embedding_length".into(), GgufValue::Uint32(2048)),
                ("gemma.block_count".into(), GgufValue::Uint32(18)),
                ("split.no".into(), GgufValue::Uint16(0)),
                ("split.count".into(), GgufValue::Uint16(2)),
            ],
        );
        // Part 1 carries only its split index. Presumably the merger takes
        // model-level metadata from the first part — confirm in aprender.
        write_part(
            &p1,
            &[tensor("blk.1.weight", 2)],
            &[("split.no".into(), GgufValue::Uint16(1))],
        );

        merge_gguf_shards(&[p0, p1], &merged).expect("merge");
        let bytes = std::fs::read(&merged).expect("read merged");

        let parsed = realizar::gguf::GGUFModel::from_bytes(&bytes);
        assert!(
            parsed.is_ok(),
            "realizar's GGUF loader must accept the merged sharded file: {:?}",
            parsed.err()
        );

        // NOTE(review): cleanup only runs when the assertion above passes; a
        // failing run leaves `dir` behind.
        std::fs::remove_dir_all(&dir).ok();
    }
}