1use crate::error::{CliError, Result};
4use colored::Colorize;
5use pacha::fetcher::{FetchConfig, ModelFetcher};
6use pacha::format::ModelFormat;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet};
9use std::io::{self, Read, Write};
10use std::path::Path;
11
/// How a model reference was resolved before downloading.
///
/// `run` dispatches on this value: single files go to `run_single_file`,
/// sharded models go to `run_sharded`.
#[derive(Debug, Clone, PartialEq, Eq)]
enum ResolvedModel {
    /// A model stored as one file; the string is the reference handed to
    /// `run_single_file`.
    SingleFile(String),
    /// A model split across several shard files of one repository.
    Sharded {
        /// Repository owner; first path segment of the download URL.
        org: String,
        /// Repository name; second path segment of the download URL.
        repo: String,
        /// Shard file names, each downloaded relative to the repository root.
        shard_files: Vec<String>,
    },
}
27
28#[derive(Debug, Serialize, Deserialize)]
33pub struct ShardManifest {
34 pub version: u32,
35 pub repo: String,
36 pub files: HashMap<String, FileChecksum>,
37}
38
39#[derive(Debug, Serialize, Deserialize)]
41pub struct FileChecksum {
42 pub size: u64,
43 pub blake3: String,
44}
45
46#[provable_contracts_macros::contract(
48 "apr-cli-operations-v1",
49 equation = "mutating_output_contract"
50)]
51pub fn run(model_ref: &str, force: bool) -> Result<()> {
52 contract_pre_pull_cache_integrity!();
53 println!("{}", "=== APR Pull ===".cyan().bold());
54 println!();
55
56 let resolved = resolve_hf_model(model_ref)?;
58
59 let result = match resolved {
60 ResolvedModel::SingleFile(ref uri) => run_single_file(uri, force),
61 ResolvedModel::Sharded {
62 ref org,
63 ref repo,
64 ref shard_files,
65 } => run_sharded(org, repo, shard_files, force),
66 };
67 if let Ok(ref r) = result {
68 contract_post_pull_cache_integrity!(r);
69 }
70 result
71}
72
73fn run_single_file(model_ref: &str, force: bool) -> Result<()> {
79 println!("Model: {}", model_ref.cyan());
80
81 if model_ref.starts_with("hf://") {
83 return run_single_file_streaming(model_ref, force);
84 }
85
86 let mut fetcher = ModelFetcher::with_config(FetchConfig::default()).map_err(|e| {
87 CliError::ValidationFailed(format!("Failed to initialize model fetcher: {e}"))
88 })?;
89
90 if !force && fetcher.is_cached(model_ref) {
91 return handle_cached_model(&mut fetcher, model_ref);
92 }
93
94 let result = download_single_model(&mut fetcher, model_ref)?;
95 ensure_safetensors_companions(&result)?;
96 print_pull_usage(&result.path, true);
97 Ok(())
98}
99
100fn run_single_file_streaming(model_ref: &str, force: bool) -> Result<()> {
107 let path = model_ref.strip_prefix("hf://").unwrap_or(model_ref);
108 let parts: Vec<&str> = path.split('/').collect();
109 if parts.len() < 3 {
110 return Err(CliError::ValidationFailed(format!(
111 "HuggingFace URI must include a filename: {model_ref}"
112 )));
113 }
114
115 let filename = parts[2..].join("/");
116 let url = format!(
117 "https://huggingface.co/{}/{}/resolve/main/{}",
118 parts[0], parts[1], filename
119 );
120
121 let cache_dir = get_pacha_cache_dir()?;
123 std::fs::create_dir_all(&cache_dir)?;
124
125 let uri_hash = blake3::hash(model_ref.as_bytes()).to_hex().to_string();
127 let extension = std::path::Path::new(&filename)
128 .extension()
129 .and_then(|e| e.to_str())
130 .unwrap_or("bin");
131 let cache_filename = format!("{}.{}", &uri_hash[..16], extension);
132 let cache_path = cache_dir.join(&cache_filename);
133
134 if !force && cache_path.exists() {
135 let metadata = std::fs::metadata(&cache_path)?;
136 println!("{} Model already cached", "✓".green());
137 println!(" Path: {}", cache_path.display());
138 println!(" Size: {}", format_bytes(metadata.len()));
139 print_pull_usage(&cache_path, true);
140 return Ok(());
141 }
142
143 println!();
144 println!("{}", "Downloading (streaming)...".yellow());
145
146 let checksum = download_file_with_progress(&url, &cache_path)?;
147
148 println!();
149 println!("{} Downloaded successfully", "✓".green());
150 println!(" Path: {}", cache_path.display().to_string().green());
151 println!(" Size: {}", format_bytes(checksum.size).yellow());
152 println!(" Hash: {}", &checksum.blake3[..16]);
153
154 if extension == "safetensors" {
156 fetch_safetensors_companions(&cache_path, model_ref)?;
157 convert_safetensors_formats(&cache_path)?;
158 }
159
160 print_pull_usage(&cache_path, true);
161 Ok(())
162}
163
164fn get_pacha_cache_dir() -> Result<std::path::PathBuf> {
166 if let Ok(cache_home) = std::env::var("XDG_CACHE_HOME") {
167 return Ok(std::path::PathBuf::from(cache_home)
168 .join("pacha")
169 .join("models"));
170 }
171 Ok(dirs::home_dir()
172 .ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
173 .join(".cache")
174 .join("pacha")
175 .join("models"))
176}
177
178fn handle_cached_model(fetcher: &mut ModelFetcher, model_ref: &str) -> Result<()> {
180 println!("{} Model already cached", "✓".green());
181 let result = fetcher
182 .pull_quiet(model_ref)
183 .map_err(|e| CliError::ValidationFailed(format!("Failed to get cached model: {e}")))?;
184
185 println!(" Path: {}", result.path.display());
186 println!(" Size: {}", result.size_human());
187 println!(" Format: {}", result.format.name());
188
189 ensure_safetensors_companions(&result)?;
190 print_pull_usage(&result.path, false);
191 Ok(())
192}
193
194fn download_single_model(
196 fetcher: &mut ModelFetcher,
197 model_ref: &str,
198) -> Result<pacha::fetcher::FetchResult> {
199 println!();
200 println!("{}", "Downloading...".yellow());
201
202 let result = fetcher
203 .pull(model_ref, |progress| {
204 let pct = progress.percent();
205 print!(
206 "\r [{:50}] {:5.1}% ({}/{})",
207 "=".repeat((pct / 2.0) as usize),
208 pct,
209 format_bytes(progress.downloaded_bytes),
210 format_bytes(progress.total_bytes)
211 );
212 io::stdout().flush().ok();
213 })
214 .map_err(|e| CliError::NetworkError(format!("Download failed: {e}")))?;
215
216 println!();
217 println!();
218
219 if result.cache_hit {
220 println!("{} Model retrieved from cache", "✓".green());
221 } else {
222 println!("{} Downloaded successfully", "✓".green());
223 }
224
225 println!(" Path: {}", result.path.display().to_string().green());
226 println!(" Size: {}", result.size_human().yellow());
227 println!(" Format: {}", result.format.name());
228 println!(" Hash: {}", &result.hash[..16]);
229 Ok(result)
230}
231
232fn ensure_safetensors_companions(result: &pacha::fetcher::FetchResult) -> Result<()> {
234 if matches!(result.format, ModelFormat::SafeTensors(_)) {
235 fetch_safetensors_companions(&result.path, &result.resolved_uri)?;
236 convert_safetensors_formats(&result.path)?;
237 }
238 Ok(())
239}
240
241fn print_pull_usage(path: &Path, show_serve: bool) {
243 println!();
244 println!("{}", "Usage:".cyan().bold());
245 println!(" apr run {}", path.display());
246 if show_serve {
247 println!(" apr serve {}", path.display());
248 }
249}
250
251fn run_sharded(org: &str, repo: &str, shard_files: &[String], force: bool) -> Result<()> {
253 println!(
254 "Model: {}/{} ({} shards)",
255 org.cyan(),
256 repo.cyan(),
257 shard_files.len().to_string().yellow()
258 );
259
260 let cache_dir = resolve_shard_cache_dir(org, repo)?;
261 std::fs::create_dir_all(&cache_dir)?;
262
263 let base_url = format!("https://huggingface.co/{org}/{repo}/resolve/main");
264 let index_path = cache_dir.join("model.safetensors.index.json");
265
266 download_index_if_needed(&base_url, &index_path, force)?;
267
268 let manifest_path = cache_dir.join(".apr-manifest.json");
269 let existing_manifest = load_existing_manifest(&manifest_path, force);
270
271 let file_checksums = download_all_shards(
272 &cache_dir,
273 &base_url,
274 shard_files,
275 force,
276 existing_manifest.as_ref(),
277 )?;
278
279 download_companion_files(&cache_dir, &base_url, force)?;
280 write_shard_manifest(&manifest_path, org, repo, file_checksums)?;
281
282 println!();
283 println!("{} Downloaded successfully", "✓".green());
284 println!(" Path: {}", index_path.display().to_string().green());
285 println!(" Shards: {}", shard_files.len().to_string().yellow());
286
287 convert_safetensors_formats(&index_path)?;
288
289 println!();
290 println!("{}", "Usage:".cyan().bold());
291 println!(" apr run {}", index_path.display());
292 println!(" apr serve {}", index_path.display());
293 Ok(())
294}
295
296fn resolve_shard_cache_dir(org: &str, repo: &str) -> Result<std::path::PathBuf> {
298 Ok(dirs::home_dir()
299 .ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
300 .join(".apr")
301 .join("cache")
302 .join("hf")
303 .join(org)
304 .join(repo))
305}
306
307fn download_index_if_needed(base_url: &str, index_path: &Path, force: bool) -> Result<()> {
309 if force || !index_path.exists() {
310 println!();
311 println!(" {} model.safetensors.index.json", "Downloading".yellow());
312 download_file(
313 &format!("{base_url}/model.safetensors.index.json"),
314 index_path,
315 )?;
316 } else {
317 println!(" {} model.safetensors.index.json (cached)", "✓".green());
318 }
319 Ok(())
320}
321
322fn load_existing_manifest(manifest_path: &Path, force: bool) -> Option<ShardManifest> {
324 if force || !manifest_path.exists() {
325 return None;
326 }
327 std::fs::read_to_string(manifest_path)
328 .ok()
329 .and_then(|s| serde_json::from_str(&s).ok())
330}
331
332fn download_all_shards(
334 cache_dir: &Path,
335 base_url: &str,
336 shard_files: &[String],
337 force: bool,
338 existing_manifest: Option<&ShardManifest>,
339) -> Result<HashMap<String, FileChecksum>> {
340 let mut file_checksums: HashMap<String, FileChecksum> = HashMap::new();
341 let total = shard_files.len();
342 for (i, shard_file) in shard_files.iter().enumerate() {
343 download_or_verify_shard(
344 cache_dir,
345 base_url,
346 shard_file,
347 i,
348 total,
349 force,
350 existing_manifest,
351 &mut file_checksums,
352 )?;
353 }
354 Ok(file_checksums)
355}
356
357fn download_or_verify_shard(
359 cache_dir: &Path,
360 base_url: &str,
361 shard_file: &str,
362 index: usize,
363 total: usize,
364 force: bool,
365 existing_manifest: Option<&ShardManifest>,
366 checksums: &mut HashMap<String, FileChecksum>,
367) -> Result<()> {
368 let shard_path = cache_dir.join(shard_file);
369
370 if !force && shard_path.exists() {
371 if let Some(manifest) = existing_manifest {
372 if let Some(expected) = manifest.files.get(shard_file) {
373 let actual_size = std::fs::metadata(&shard_path).map(|m| m.len()).unwrap_or(0);
374 if actual_size == expected.size {
375 checksums.insert(
376 shard_file.to_string(),
377 FileChecksum {
378 size: expected.size,
379 blake3: expected.blake3.clone(),
380 },
381 );
382 println!(
383 " {} [{}/{}] {} (cached, verified)",
384 "✓".green(),
385 index + 1,
386 total,
387 shard_file
388 );
389 return Ok(());
390 }
391 println!(
392 " {} [{}/{}] {} (size mismatch: {} vs {} bytes, re-downloading)",
393 "⚠".yellow(),
394 index + 1,
395 total,
396 shard_file,
397 actual_size,
398 expected.size
399 );
400 }
402 } else {
403 println!(
404 " {} [{}/{}] {} (cached)",
405 "✓".green(),
406 index + 1,
407 total,
408 shard_file
409 );
410 return Ok(());
411 }
412 }
413
414 let shard_url = format!("{base_url}/{shard_file}");
415 print!(
416 " {} [{}/{}] {}...",
417 "↓".yellow(),
418 index + 1,
419 total,
420 shard_file
421 );
422 io::stdout().flush().ok();
423
424 let checksum = download_file_with_progress(&shard_url, &shard_path)?;
425 checksums.insert(shard_file.to_string(), checksum);
426 println!(" {}", "done".green());
427 Ok(())
428}
429
430fn download_companion_files(cache_dir: &Path, base_url: &str, force: bool) -> Result<()> {
435 let companions = [
437 ("tokenizer.json", false),
438 ("config.json", true),
439 ("tokenizer_config.json", false),
440 ("tokenizer.model", false),
441 ];
442 for (filename, required) in &companions {
443 let companion_path = cache_dir.join(filename);
444 if !force && companion_path.exists() {
445 println!(" {} {} (cached)", "✓".green(), filename);
446 continue;
447 }
448
449 let url = format!("{base_url}/{filename}");
450 match download_file(&url, &companion_path) {
451 Ok(()) => println!(" {} {}", "✓".green(), filename),
452 Err(CliError::HttpNotFound(_)) if *required => {
453 return Err(CliError::ValidationFailed(format!(
454 "{filename} is required for inference but was not found (HTTP 404) at {url}"
455 )));
456 }
457 Err(CliError::HttpNotFound(_)) => {
458 println!(" {} {} (not found in repo)", "⚠".yellow(), filename);
459 }
460 Err(e) if *required => {
461 return Err(CliError::ValidationFailed(format!(
462 "{filename} is required for inference but download failed: {e}"
463 )));
464 }
465 Err(_) => println!(" {} {} (not available, optional)", "⚠".yellow(), filename),
466 }
467 }
468
469 let tokenizer_files = ["tokenizer.json", "tokenizer.model", "tokenizer_config.json"];
471 let has_tokenizer = tokenizer_files.iter().any(|f| cache_dir.join(f).exists());
472 if !has_tokenizer {
473 return Err(CliError::ValidationFailed(format!(
474 "No tokenizer found for this model. Tried: {}.\n\
475 The model may require a custom tokenizer not hosted in the repository.",
476 tokenizer_files.join(", ")
477 )));
478 }
479
480 Ok(())
481}
482
483fn write_shard_manifest(
485 manifest_path: &Path,
486 org: &str,
487 repo: &str,
488 file_checksums: HashMap<String, FileChecksum>,
489) -> Result<()> {
490 if file_checksums.is_empty() {
491 return Ok(());
492 }
493 let manifest = ShardManifest {
494 version: 1,
495 repo: format!("{org}/{repo}"),
496 files: file_checksums,
497 };
498 let manifest_json = serde_json::to_string_pretty(&manifest)
499 .map_err(|e| CliError::ValidationFailed(format!("Failed to serialize manifest: {e}")))?;
500 std::fs::write(manifest_path, manifest_json)?;
501 println!(" {} .apr-manifest.json (integrity checksums)", "✓".green());
502 Ok(())
503}
504
505include!("pull_list.rs");
506include!("pull_remove_resolve_model.rs");
507include!("pull_extract_shard.rs");
508include!("pull_04.rs");