1use crate::error::{CliError, Result};
4use colored::Colorize;
5use pacha::fetcher::{FetchConfig, ModelFetcher};
6use pacha::format::ModelFormat;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet};
9use std::io::{self, Read, Write};
10use std::path::Path;
11
/// Result of resolving a user-supplied model reference (produced by
/// `resolve_hf_model` and consumed by [`run`]).
#[derive(Debug)]
enum ResolvedModel {
    /// The reference maps to one downloadable file; the payload is the URI
    /// handed to `run_single_file`.
    SingleFile(String),
    /// The model is split across several shard files of a single repository.
    Sharded {
        /// Organisation segment of the `https://huggingface.co/{org}/{repo}` URL.
        org: String,
        /// Repository segment of the `https://huggingface.co/{org}/{repo}` URL.
        repo: String,
        /// Shard file names, each downloaded relative to the repository's
        /// `resolve/main` URL.
        shard_files: Vec<String>,
    },
}
27
/// Integrity manifest for a sharded download, serialized as JSON to
/// `.apr-manifest.json` inside the shard cache directory.
#[derive(Debug, Serialize, Deserialize)]
pub struct ShardManifest {
    /// Manifest schema version; `write_shard_manifest` currently writes `1`.
    pub version: u32,
    /// Repository identifier in `"{org}/{repo}"` form.
    pub repo: String,
    /// Recorded checksum for each shard, keyed by shard file name.
    pub files: HashMap<String, FileChecksum>,
}
38
/// Size and hash recorded for a single downloaded file.
#[derive(Debug, Serialize, Deserialize)]
pub struct FileChecksum {
    /// File size in bytes; compared against the on-disk length when a cached
    /// shard is re-validated.
    pub size: u64,
    /// BLAKE3 digest of the file as text (presumably lowercase hex — confirm
    /// against `download_file_with_progress`).
    pub blake3: String,
}
45
46#[provable_contracts_macros::contract(
48 "apr-cli-operations-v1",
49 equation = "mutating_output_contract"
50)]
51pub fn run(model_ref: &str, force: bool) -> Result<()> {
52 contract_pre_pull_cache_integrity!();
53 println!("{}", "=== APR Pull ===".cyan().bold());
54 println!();
55
56 let resolved = resolve_hf_model(model_ref)?;
58
59 let result = match resolved {
60 ResolvedModel::SingleFile(ref uri) => run_single_file(uri, force),
61 ResolvedModel::Sharded {
62 ref org,
63 ref repo,
64 ref shard_files,
65 } => run_sharded(org, repo, shard_files, force),
66 };
67 if let Ok(ref r) = result {
68 contract_post_pull_cache_integrity!(r);
69 }
70 result
71}
72
73fn run_single_file(model_ref: &str, force: bool) -> Result<()> {
79 println!("Model: {}", model_ref.cyan());
80
81 if model_ref.starts_with("hf://") {
83 return run_single_file_streaming(model_ref, force);
84 }
85
86 let mut fetcher = ModelFetcher::with_config(FetchConfig::default()).map_err(|e| {
87 CliError::ValidationFailed(format!("Failed to initialize model fetcher: {e}"))
88 })?;
89
90 if !force && fetcher.is_cached(model_ref) {
91 return handle_cached_model(&mut fetcher, model_ref);
92 }
93
94 let result = download_single_model(&mut fetcher, model_ref)?;
95 ensure_safetensors_companions(&result)?;
96 print_pull_usage(&result.path, true);
97 Ok(())
98}
99
100fn run_single_file_streaming(model_ref: &str, force: bool) -> Result<()> {
107 let (org, repo, filename) = parse_hf_single_uri(model_ref)?;
108 let url = format!("https://huggingface.co/{org}/{repo}/resolve/main/{filename}");
109
110 let cache_dir = get_pacha_cache_dir()?;
111 std::fs::create_dir_all(&cache_dir)?;
112 let (extension, cache_path) = build_single_cache_path(&cache_dir, model_ref, &filename);
113
114 if !force && cache_path.exists() {
115 return report_cached_single(&cache_path);
116 }
117
118 stream_and_post_process(&url, &cache_path, model_ref, &extension)?;
119 print_pull_usage(&cache_path, true);
120 Ok(())
121}
122
123fn stream_and_post_process(
124 url: &str,
125 cache_path: &std::path::Path,
126 model_ref: &str,
127 extension: &str,
128) -> Result<()> {
129 println!();
130 println!("{}", "Downloading (streaming)...".yellow());
131 let checksum = download_file_with_progress(url, cache_path)?;
132 report_downloaded_single(cache_path, &checksum);
133
134 if extension == "safetensors" {
135 fetch_safetensors_companions(cache_path, model_ref)?;
136 convert_safetensors_formats(cache_path)?;
137 }
138 Ok(())
139}
140
141fn parse_hf_single_uri(model_ref: &str) -> Result<(String, String, String)> {
142 let path = model_ref.strip_prefix("hf://").unwrap_or(model_ref);
143 let parts: Vec<&str> = path.split('/').collect();
144 if parts.len() < 3 {
145 return Err(CliError::ValidationFailed(format!(
146 "HuggingFace URI must include a filename: {model_ref}"
147 )));
148 }
149 Ok((
150 parts[0].to_string(),
151 parts[1].to_string(),
152 parts[2..].join("/"),
153 ))
154}
155
156fn build_single_cache_path(
157 cache_dir: &std::path::Path,
158 model_ref: &str,
159 filename: &str,
160) -> (String, std::path::PathBuf) {
161 let uri_hash = blake3::hash(model_ref.as_bytes()).to_hex().to_string();
162 let extension = std::path::Path::new(filename)
163 .extension()
164 .and_then(|e| e.to_str())
165 .unwrap_or("bin")
166 .to_string();
167 let cache_filename = format!("{}.{extension}", &uri_hash[..16]);
168 let cache_path = cache_dir.join(&cache_filename);
169 (extension, cache_path)
170}
171
172fn report_cached_single(cache_path: &std::path::Path) -> Result<()> {
173 let metadata = std::fs::metadata(cache_path)?;
174 println!("{} Model already cached", "✓".green());
175 println!(" Path: {}", cache_path.display());
176 println!(" Size: {}", format_bytes(metadata.len()));
177 print_pull_usage(cache_path, true);
178 Ok(())
179}
180
181fn report_downloaded_single(cache_path: &std::path::Path, checksum: &FileChecksum) {
182 println!();
183 println!("{} Downloaded successfully", "✓".green());
184 println!(" Path: {}", cache_path.display().to_string().green());
185 println!(" Size: {}", format_bytes(checksum.size).yellow());
186 println!(" Hash: {}", &checksum.blake3[..16]);
187}
188
189fn get_pacha_cache_dir() -> Result<std::path::PathBuf> {
191 if let Ok(cache_home) = std::env::var("XDG_CACHE_HOME") {
192 return Ok(std::path::PathBuf::from(cache_home)
193 .join("pacha")
194 .join("models"));
195 }
196 Ok(dirs::home_dir()
197 .ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
198 .join(".cache")
199 .join("pacha")
200 .join("models"))
201}
202
203fn handle_cached_model(fetcher: &mut ModelFetcher, model_ref: &str) -> Result<()> {
205 println!("{} Model already cached", "✓".green());
206 let result = fetcher
207 .pull_quiet(model_ref)
208 .map_err(|e| CliError::ValidationFailed(format!("Failed to get cached model: {e}")))?;
209
210 println!(" Path: {}", result.path.display());
211 println!(" Size: {}", result.size_human());
212 println!(" Format: {}", result.format.name());
213
214 ensure_safetensors_companions(&result)?;
215 print_pull_usage(&result.path, false);
216 Ok(())
217}
218
219fn download_single_model(
221 fetcher: &mut ModelFetcher,
222 model_ref: &str,
223) -> Result<pacha::fetcher::FetchResult> {
224 println!();
225 println!("{}", "Downloading...".yellow());
226
227 let result = fetcher
228 .pull(model_ref, |progress| {
229 let pct = progress.percent();
230 print!(
231 "\r [{:50}] {:5.1}% ({}/{})",
232 "=".repeat((pct / 2.0) as usize),
233 pct,
234 format_bytes(progress.downloaded_bytes),
235 format_bytes(progress.total_bytes)
236 );
237 io::stdout().flush().ok();
238 })
239 .map_err(|e| CliError::NetworkError(format!("Download failed: {e}")))?;
240
241 println!();
242 println!();
243
244 if result.cache_hit {
245 println!("{} Model retrieved from cache", "✓".green());
246 } else {
247 println!("{} Downloaded successfully", "✓".green());
248 }
249
250 println!(" Path: {}", result.path.display().to_string().green());
251 println!(" Size: {}", result.size_human().yellow());
252 println!(" Format: {}", result.format.name());
253 println!(" Hash: {}", &result.hash[..16]);
254 Ok(result)
255}
256
257fn ensure_safetensors_companions(result: &pacha::fetcher::FetchResult) -> Result<()> {
259 if matches!(result.format, ModelFormat::SafeTensors(_)) {
260 fetch_safetensors_companions(&result.path, &result.resolved_uri)?;
261 convert_safetensors_formats(&result.path)?;
262 }
263 Ok(())
264}
265
266fn print_pull_usage(path: &Path, show_serve: bool) {
268 println!();
269 println!("{}", "Usage:".cyan().bold());
270 println!(" apr run {}", path.display());
271 if show_serve {
272 println!(" apr serve {}", path.display());
273 }
274}
275
276fn run_sharded(org: &str, repo: &str, shard_files: &[String], force: bool) -> Result<()> {
278 println!(
279 "Model: {}/{} ({} shards)",
280 org.cyan(),
281 repo.cyan(),
282 shard_files.len().to_string().yellow()
283 );
284
285 let cache_dir = resolve_shard_cache_dir(org, repo)?;
286 std::fs::create_dir_all(&cache_dir)?;
287
288 let base_url = format!("https://huggingface.co/{org}/{repo}/resolve/main");
289 let index_path = cache_dir.join("model.safetensors.index.json");
290
291 download_index_if_needed(&base_url, &index_path, force)?;
292
293 let manifest_path = cache_dir.join(".apr-manifest.json");
294 let existing_manifest = load_existing_manifest(&manifest_path, force);
295
296 let file_checksums = download_all_shards(
297 &cache_dir,
298 &base_url,
299 shard_files,
300 force,
301 existing_manifest.as_ref(),
302 )?;
303
304 download_companion_files(&cache_dir, &base_url, force)?;
305 write_shard_manifest(&manifest_path, org, repo, file_checksums)?;
306
307 println!();
308 println!("{} Downloaded successfully", "✓".green());
309 println!(" Path: {}", index_path.display().to_string().green());
310 println!(" Shards: {}", shard_files.len().to_string().yellow());
311
312 convert_safetensors_formats(&index_path)?;
313
314 println!();
315 println!("{}", "Usage:".cyan().bold());
316 println!(" apr run {}", index_path.display());
317 println!(" apr serve {}", index_path.display());
318 Ok(())
319}
320
321fn resolve_shard_cache_dir(org: &str, repo: &str) -> Result<std::path::PathBuf> {
323 Ok(dirs::home_dir()
324 .ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
325 .join(".apr")
326 .join("cache")
327 .join("hf")
328 .join(org)
329 .join(repo))
330}
331
332fn download_index_if_needed(base_url: &str, index_path: &Path, force: bool) -> Result<()> {
334 if force || !index_path.exists() {
335 println!();
336 println!(" {} model.safetensors.index.json", "Downloading".yellow());
337 download_file(
338 &format!("{base_url}/model.safetensors.index.json"),
339 index_path,
340 )?;
341 } else {
342 println!(" {} model.safetensors.index.json (cached)", "✓".green());
343 }
344 Ok(())
345}
346
347fn load_existing_manifest(manifest_path: &Path, force: bool) -> Option<ShardManifest> {
349 if force || !manifest_path.exists() {
350 return None;
351 }
352 std::fs::read_to_string(manifest_path)
353 .ok()
354 .and_then(|s| serde_json::from_str(&s).ok())
355}
356
357fn download_all_shards(
359 cache_dir: &Path,
360 base_url: &str,
361 shard_files: &[String],
362 force: bool,
363 existing_manifest: Option<&ShardManifest>,
364) -> Result<HashMap<String, FileChecksum>> {
365 let mut file_checksums: HashMap<String, FileChecksum> = HashMap::new();
366 let total = shard_files.len();
367 for (i, shard_file) in shard_files.iter().enumerate() {
368 download_or_verify_shard(
369 cache_dir,
370 base_url,
371 shard_file,
372 i,
373 total,
374 force,
375 existing_manifest,
376 &mut file_checksums,
377 )?;
378 }
379 Ok(file_checksums)
380}
381
382fn download_or_verify_shard(
384 cache_dir: &Path,
385 base_url: &str,
386 shard_file: &str,
387 index: usize,
388 total: usize,
389 force: bool,
390 existing_manifest: Option<&ShardManifest>,
391 checksums: &mut HashMap<String, FileChecksum>,
392) -> Result<()> {
393 let shard_path = cache_dir.join(shard_file);
394
395 if !force && shard_path.exists() {
396 if let Some(manifest) = existing_manifest {
397 if let Some(expected) = manifest.files.get(shard_file) {
398 let actual_size = std::fs::metadata(&shard_path).map(|m| m.len()).unwrap_or(0);
399 if actual_size == expected.size {
400 checksums.insert(
401 shard_file.to_string(),
402 FileChecksum {
403 size: expected.size,
404 blake3: expected.blake3.clone(),
405 },
406 );
407 println!(
408 " {} [{}/{}] {} (cached, verified)",
409 "✓".green(),
410 index + 1,
411 total,
412 shard_file
413 );
414 return Ok(());
415 }
416 println!(
417 " {} [{}/{}] {} (size mismatch: {} vs {} bytes, re-downloading)",
418 "⚠".yellow(),
419 index + 1,
420 total,
421 shard_file,
422 actual_size,
423 expected.size
424 );
425 }
427 } else {
428 println!(
429 " {} [{}/{}] {} (cached)",
430 "✓".green(),
431 index + 1,
432 total,
433 shard_file
434 );
435 return Ok(());
436 }
437 }
438
439 let shard_url = format!("{base_url}/{shard_file}");
440 print!(
441 " {} [{}/{}] {}...",
442 "↓".yellow(),
443 index + 1,
444 total,
445 shard_file
446 );
447 io::stdout().flush().ok();
448
449 let checksum = download_file_with_progress(&shard_url, &shard_path)?;
450 checksums.insert(shard_file.to_string(), checksum);
451 println!(" {}", "done".green());
452 Ok(())
453}
454
455fn download_companion_files(cache_dir: &Path, base_url: &str, force: bool) -> Result<()> {
460 let companions = [
462 ("tokenizer.json", false),
463 ("config.json", true),
464 ("tokenizer_config.json", false),
465 ("tokenizer.model", false),
466 ];
467 for (filename, required) in &companions {
468 let companion_path = cache_dir.join(filename);
469 if !force && companion_path.exists() {
470 println!(" {} {} (cached)", "✓".green(), filename);
471 continue;
472 }
473
474 let url = format!("{base_url}/{filename}");
475 match download_file(&url, &companion_path) {
476 Ok(()) => println!(" {} {}", "✓".green(), filename),
477 Err(CliError::HttpNotFound(_)) if *required => {
478 return Err(CliError::ValidationFailed(format!(
479 "{filename} is required for inference but was not found (HTTP 404) at {url}"
480 )));
481 }
482 Err(CliError::HttpNotFound(_)) => {
483 println!(" {} {} (not found in repo)", "⚠".yellow(), filename);
484 }
485 Err(e) if *required => {
486 return Err(CliError::ValidationFailed(format!(
487 "{filename} is required for inference but download failed: {e}"
488 )));
489 }
490 Err(_) => println!(" {} {} (not available, optional)", "⚠".yellow(), filename),
491 }
492 }
493
494 let tokenizer_files = ["tokenizer.json", "tokenizer.model", "tokenizer_config.json"];
496 let has_tokenizer = tokenizer_files.iter().any(|f| cache_dir.join(f).exists());
497 if !has_tokenizer {
498 return Err(CliError::ValidationFailed(format!(
499 "No tokenizer found for this model. Tried: {}.\n\
500 The model may require a custom tokenizer not hosted in the repository.",
501 tokenizer_files.join(", ")
502 )));
503 }
504
505 Ok(())
506}
507
508fn write_shard_manifest(
510 manifest_path: &Path,
511 org: &str,
512 repo: &str,
513 file_checksums: HashMap<String, FileChecksum>,
514) -> Result<()> {
515 if file_checksums.is_empty() {
516 return Ok(());
517 }
518 let manifest = ShardManifest {
519 version: 1,
520 repo: format!("{org}/{repo}"),
521 files: file_checksums,
522 };
523 let manifest_json = serde_json::to_string_pretty(&manifest)
524 .map_err(|e| CliError::ValidationFailed(format!("Failed to serialize manifest: {e}")))?;
525 std::fs::write(manifest_path, manifest_json)?;
526 println!(" {} .apr-manifest.json (integrity checksums)", "✓".green());
527 Ok(())
528}
529
// Sibling source files are spliced in textually, so their items share this
// module's scope and imports. Helpers called above but not defined in this
// file (e.g. `resolve_hf_model`, `download_file`, `download_file_with_progress`,
// `format_bytes`) presumably live in these files — confirm before moving code.
include!("pull_list.rs");
include!("pull_remove_resolve_model.rs");
include!("pull_extract_shard.rs");
include!("pull_04.rs");