1use crate::error::{CliError, Result};
4use colored::Colorize;
5use pacha::fetcher::{FetchConfig, ModelFetcher};
6use pacha::format::ModelFormat;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet};
9use std::io::{self, Read, Write};
10use std::path::Path;
11
/// Outcome of resolving the model reference given to `apr pull`.
#[derive(Debug)]
enum ResolvedModel {
    /// A single file; the payload is the URI handed to the single-file pull path.
    SingleFile(String),
    /// A HuggingFace repository whose weights are split across multiple shard files.
    Sharded {
        /// HuggingFace organisation/user segment of the repository URL.
        org: String,
        /// Repository name within `org`.
        repo: String,
        /// Names of the shard files to fetch from the repository.
        shard_files: Vec<String>,
    },
}
27
28#[derive(Debug, Serialize, Deserialize)]
33pub struct ShardManifest {
34 pub version: u32,
35 pub repo: String,
36 pub files: HashMap<String, FileChecksum>,
37}
38
39#[derive(Debug, Serialize, Deserialize)]
41pub struct FileChecksum {
42 pub size: u64,
43 pub blake3: String,
44}
45
46pub fn run(model_ref: &str, force: bool) -> Result<()> {
48 contract_pre_pull_cache_integrity!();
49 println!("{}", "=== APR Pull ===".cyan().bold());
50 println!();
51
52 let resolved = resolve_hf_model(model_ref)?;
54
55 match resolved {
56 ResolvedModel::SingleFile(ref uri) => run_single_file(uri, force),
57 ResolvedModel::Sharded {
58 ref org,
59 ref repo,
60 ref shard_files,
61 } => run_sharded(org, repo, shard_files, force),
62 }
63}
64
65fn run_single_file(model_ref: &str, force: bool) -> Result<()> {
71 println!("Model: {}", model_ref.cyan());
72
73 if model_ref.starts_with("hf://") {
75 return run_single_file_streaming(model_ref, force);
76 }
77
78 let mut fetcher = ModelFetcher::with_config(FetchConfig::default()).map_err(|e| {
79 CliError::ValidationFailed(format!("Failed to initialize model fetcher: {e}"))
80 })?;
81
82 if !force && fetcher.is_cached(model_ref) {
83 return handle_cached_model(&mut fetcher, model_ref);
84 }
85
86 let result = download_single_model(&mut fetcher, model_ref)?;
87 ensure_safetensors_companions(&result)?;
88 print_pull_usage(&result.path, true);
89 Ok(())
90}
91
92fn run_single_file_streaming(model_ref: &str, force: bool) -> Result<()> {
99 let path = model_ref.strip_prefix("hf://").unwrap_or(model_ref);
100 let parts: Vec<&str> = path.split('/').collect();
101 if parts.len() < 3 {
102 return Err(CliError::ValidationFailed(format!(
103 "HuggingFace URI must include a filename: {model_ref}"
104 )));
105 }
106
107 let filename = parts[2..].join("/");
108 let url = format!(
109 "https://huggingface.co/{}/{}/resolve/main/{}",
110 parts[0], parts[1], filename
111 );
112
113 let cache_dir = get_pacha_cache_dir()?;
115 std::fs::create_dir_all(&cache_dir)?;
116
117 let uri_hash = blake3::hash(model_ref.as_bytes()).to_hex().to_string();
119 let extension = std::path::Path::new(&filename)
120 .extension()
121 .and_then(|e| e.to_str())
122 .unwrap_or("bin");
123 let cache_filename = format!("{}.{}", &uri_hash[..16], extension);
124 let cache_path = cache_dir.join(&cache_filename);
125
126 if !force && cache_path.exists() {
127 let metadata = std::fs::metadata(&cache_path)?;
128 println!("{} Model already cached", "✓".green());
129 println!(" Path: {}", cache_path.display());
130 println!(" Size: {}", format_bytes(metadata.len()));
131 print_pull_usage(&cache_path, true);
132 return Ok(());
133 }
134
135 println!();
136 println!("{}", "Downloading (streaming)...".yellow());
137
138 let checksum = download_file_with_progress(&url, &cache_path)?;
139
140 println!();
141 println!("{} Downloaded successfully", "✓".green());
142 println!(" Path: {}", cache_path.display().to_string().green());
143 println!(" Size: {}", format_bytes(checksum.size).yellow());
144 println!(" Hash: {}", &checksum.blake3[..16]);
145
146 if extension == "safetensors" {
148 fetch_safetensors_companions(&cache_path, model_ref)?;
149 convert_safetensors_formats(&cache_path)?;
150 }
151
152 print_pull_usage(&cache_path, true);
153 Ok(())
154}
155
156fn get_pacha_cache_dir() -> Result<std::path::PathBuf> {
158 if let Ok(cache_home) = std::env::var("XDG_CACHE_HOME") {
159 return Ok(std::path::PathBuf::from(cache_home)
160 .join("pacha")
161 .join("models"));
162 }
163 Ok(dirs::home_dir()
164 .ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
165 .join(".cache")
166 .join("pacha")
167 .join("models"))
168}
169
170fn handle_cached_model(fetcher: &mut ModelFetcher, model_ref: &str) -> Result<()> {
172 println!("{} Model already cached", "✓".green());
173 let result = fetcher
174 .pull_quiet(model_ref)
175 .map_err(|e| CliError::ValidationFailed(format!("Failed to get cached model: {e}")))?;
176
177 println!(" Path: {}", result.path.display());
178 println!(" Size: {}", result.size_human());
179 println!(" Format: {}", result.format.name());
180
181 ensure_safetensors_companions(&result)?;
182 print_pull_usage(&result.path, false);
183 Ok(())
184}
185
186fn download_single_model(
188 fetcher: &mut ModelFetcher,
189 model_ref: &str,
190) -> Result<pacha::fetcher::FetchResult> {
191 println!();
192 println!("{}", "Downloading...".yellow());
193
194 let result = fetcher
195 .pull(model_ref, |progress| {
196 let pct = progress.percent();
197 print!(
198 "\r [{:50}] {:5.1}% ({}/{})",
199 "=".repeat((pct / 2.0) as usize),
200 pct,
201 format_bytes(progress.downloaded_bytes),
202 format_bytes(progress.total_bytes)
203 );
204 io::stdout().flush().ok();
205 })
206 .map_err(|e| CliError::NetworkError(format!("Download failed: {e}")))?;
207
208 println!();
209 println!();
210
211 if result.cache_hit {
212 println!("{} Model retrieved from cache", "✓".green());
213 } else {
214 println!("{} Downloaded successfully", "✓".green());
215 }
216
217 println!(" Path: {}", result.path.display().to_string().green());
218 println!(" Size: {}", result.size_human().yellow());
219 println!(" Format: {}", result.format.name());
220 println!(" Hash: {}", &result.hash[..16]);
221 Ok(result)
222}
223
224fn ensure_safetensors_companions(result: &pacha::fetcher::FetchResult) -> Result<()> {
226 if matches!(result.format, ModelFormat::SafeTensors(_)) {
227 fetch_safetensors_companions(&result.path, &result.resolved_uri)?;
228 convert_safetensors_formats(&result.path)?;
229 }
230 Ok(())
231}
232
233fn print_pull_usage(path: &Path, show_serve: bool) {
235 println!();
236 println!("{}", "Usage:".cyan().bold());
237 println!(" apr run {}", path.display());
238 if show_serve {
239 println!(" apr serve {}", path.display());
240 }
241}
242
243fn run_sharded(org: &str, repo: &str, shard_files: &[String], force: bool) -> Result<()> {
245 println!(
246 "Model: {}/{} ({} shards)",
247 org.cyan(),
248 repo.cyan(),
249 shard_files.len().to_string().yellow()
250 );
251
252 let cache_dir = resolve_shard_cache_dir(org, repo)?;
253 std::fs::create_dir_all(&cache_dir)?;
254
255 let base_url = format!("https://huggingface.co/{org}/{repo}/resolve/main");
256 let index_path = cache_dir.join("model.safetensors.index.json");
257
258 download_index_if_needed(&base_url, &index_path, force)?;
259
260 let manifest_path = cache_dir.join(".apr-manifest.json");
261 let existing_manifest = load_existing_manifest(&manifest_path, force);
262
263 let file_checksums = download_all_shards(
264 &cache_dir,
265 &base_url,
266 shard_files,
267 force,
268 existing_manifest.as_ref(),
269 )?;
270
271 download_companion_files(&cache_dir, &base_url, force)?;
272 write_shard_manifest(&manifest_path, org, repo, file_checksums)?;
273
274 println!();
275 println!("{} Downloaded successfully", "✓".green());
276 println!(" Path: {}", index_path.display().to_string().green());
277 println!(" Shards: {}", shard_files.len().to_string().yellow());
278
279 convert_safetensors_formats(&index_path)?;
280
281 println!();
282 println!("{}", "Usage:".cyan().bold());
283 println!(" apr run {}", index_path.display());
284 println!(" apr serve {}", index_path.display());
285 Ok(())
286}
287
288fn resolve_shard_cache_dir(org: &str, repo: &str) -> Result<std::path::PathBuf> {
290 Ok(dirs::home_dir()
291 .ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
292 .join(".apr")
293 .join("cache")
294 .join("hf")
295 .join(org)
296 .join(repo))
297}
298
299fn download_index_if_needed(base_url: &str, index_path: &Path, force: bool) -> Result<()> {
301 if force || !index_path.exists() {
302 println!();
303 println!(" {} model.safetensors.index.json", "Downloading".yellow());
304 download_file(
305 &format!("{base_url}/model.safetensors.index.json"),
306 index_path,
307 )?;
308 } else {
309 println!(" {} model.safetensors.index.json (cached)", "✓".green());
310 }
311 Ok(())
312}
313
314fn load_existing_manifest(manifest_path: &Path, force: bool) -> Option<ShardManifest> {
316 if force || !manifest_path.exists() {
317 return None;
318 }
319 std::fs::read_to_string(manifest_path)
320 .ok()
321 .and_then(|s| serde_json::from_str(&s).ok())
322}
323
324fn download_all_shards(
326 cache_dir: &Path,
327 base_url: &str,
328 shard_files: &[String],
329 force: bool,
330 existing_manifest: Option<&ShardManifest>,
331) -> Result<HashMap<String, FileChecksum>> {
332 let mut file_checksums: HashMap<String, FileChecksum> = HashMap::new();
333 let total = shard_files.len();
334 for (i, shard_file) in shard_files.iter().enumerate() {
335 download_or_verify_shard(
336 cache_dir,
337 base_url,
338 shard_file,
339 i,
340 total,
341 force,
342 existing_manifest,
343 &mut file_checksums,
344 )?;
345 }
346 Ok(file_checksums)
347}
348
349fn download_or_verify_shard(
351 cache_dir: &Path,
352 base_url: &str,
353 shard_file: &str,
354 index: usize,
355 total: usize,
356 force: bool,
357 existing_manifest: Option<&ShardManifest>,
358 checksums: &mut HashMap<String, FileChecksum>,
359) -> Result<()> {
360 let shard_path = cache_dir.join(shard_file);
361
362 if !force && shard_path.exists() {
363 if let Some(manifest) = existing_manifest {
364 if let Some(expected) = manifest.files.get(shard_file) {
365 let actual_size = std::fs::metadata(&shard_path).map(|m| m.len()).unwrap_or(0);
366 if actual_size == expected.size {
367 checksums.insert(
368 shard_file.to_string(),
369 FileChecksum {
370 size: expected.size,
371 blake3: expected.blake3.clone(),
372 },
373 );
374 println!(
375 " {} [{}/{}] {} (cached, verified)",
376 "✓".green(),
377 index + 1,
378 total,
379 shard_file
380 );
381 return Ok(());
382 }
383 println!(
384 " {} [{}/{}] {} (size mismatch: {} vs {} bytes, re-downloading)",
385 "⚠".yellow(),
386 index + 1,
387 total,
388 shard_file,
389 actual_size,
390 expected.size
391 );
392 }
394 } else {
395 println!(
396 " {} [{}/{}] {} (cached)",
397 "✓".green(),
398 index + 1,
399 total,
400 shard_file
401 );
402 return Ok(());
403 }
404 }
405
406 let shard_url = format!("{base_url}/{shard_file}");
407 print!(
408 " {} [{}/{}] {}...",
409 "↓".yellow(),
410 index + 1,
411 total,
412 shard_file
413 );
414 io::stdout().flush().ok();
415
416 let checksum = download_file_with_progress(&shard_url, &shard_path)?;
417 checksums.insert(shard_file.to_string(), checksum);
418 println!(" {}", "done".green());
419 Ok(())
420}
421
422fn download_companion_files(cache_dir: &Path, base_url: &str, force: bool) -> Result<()> {
427 let companions = [
429 ("tokenizer.json", false),
430 ("config.json", true),
431 ("tokenizer_config.json", false),
432 ("tokenizer.model", false),
433 ];
434 for (filename, required) in &companions {
435 let companion_path = cache_dir.join(filename);
436 if !force && companion_path.exists() {
437 println!(" {} {} (cached)", "✓".green(), filename);
438 continue;
439 }
440
441 let url = format!("{base_url}/{filename}");
442 match download_file(&url, &companion_path) {
443 Ok(()) => println!(" {} {}", "✓".green(), filename),
444 Err(CliError::HttpNotFound(_)) if *required => {
445 return Err(CliError::ValidationFailed(format!(
446 "{filename} is required for inference but was not found (HTTP 404) at {url}"
447 )));
448 }
449 Err(CliError::HttpNotFound(_)) => {
450 println!(" {} {} (not found in repo)", "⚠".yellow(), filename);
451 }
452 Err(e) if *required => {
453 return Err(CliError::ValidationFailed(format!(
454 "{filename} is required for inference but download failed: {e}"
455 )));
456 }
457 Err(_) => println!(" {} {} (not available, optional)", "⚠".yellow(), filename),
458 }
459 }
460
461 let tokenizer_files = ["tokenizer.json", "tokenizer.model", "tokenizer_config.json"];
463 let has_tokenizer = tokenizer_files.iter().any(|f| cache_dir.join(f).exists());
464 if !has_tokenizer {
465 return Err(CliError::ValidationFailed(format!(
466 "No tokenizer found for this model. Tried: {}.\n\
467 The model may require a custom tokenizer not hosted in the repository.",
468 tokenizer_files.join(", ")
469 )));
470 }
471
472 Ok(())
473}
474
475fn write_shard_manifest(
477 manifest_path: &Path,
478 org: &str,
479 repo: &str,
480 file_checksums: HashMap<String, FileChecksum>,
481) -> Result<()> {
482 if file_checksums.is_empty() {
483 return Ok(());
484 }
485 let manifest = ShardManifest {
486 version: 1,
487 repo: format!("{org}/{repo}"),
488 files: file_checksums,
489 };
490 let manifest_json = serde_json::to_string_pretty(&manifest)
491 .map_err(|e| CliError::ValidationFailed(format!("Failed to serialize manifest: {e}")))?;
492 std::fs::write(manifest_path, manifest_json)?;
493 println!(" {} .apr-manifest.json (integrity checksums)", "✓".green());
494 Ok(())
495}
496
497include!("pull_list.rs");
498include!("pull_remove_resolve_model.rs");
499include!("pull_extract_shard.rs");
500include!("pull_04.rs");