1use crate::error::{CliError, Result};
4use colored::Colorize;
5use pacha::fetcher::{FetchConfig, ModelFetcher};
6use pacha::format::ModelFormat;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet};
9use std::io::{self, Read, Write};
10use std::path::Path;
11
/// Outcome of resolving a user-supplied model reference.
///
/// `run` obtains a value of this type from `resolve_hf_model` and dispatches on it to the
/// single-file or the sharded download path.
#[derive(Debug, Clone, PartialEq, Eq)]
enum ResolvedModel {
    /// The reference maps to one file; the payload is the URI handed to `run_single_file`.
    SingleFile(String),
    /// The reference maps to a repository whose weights are split across several files.
    Sharded {
        /// Repository owner; first path segment of the download URL and of the cache directory.
        org: String,
        /// Repository name; second path segment of the download URL and of the cache directory.
        repo: String,
        /// Shard file names, each appended to the repository base URL and the cache directory.
        shard_files: Vec<String>,
    },
}
27
28#[derive(Debug, Serialize, Deserialize)]
33pub struct ShardManifest {
34 pub version: u32,
35 pub repo: String,
36 pub files: HashMap<String, FileChecksum>,
37}
38
39#[derive(Debug, Serialize, Deserialize)]
41pub struct FileChecksum {
42 pub size: u64,
43 pub blake3: String,
44}
45
46pub fn run(model_ref: &str, force: bool) -> Result<()> {
48 contract_pre_pull_cache_integrity!();
49 println!("{}", "=== APR Pull ===".cyan().bold());
50 println!();
51
52 let resolved = resolve_hf_model(model_ref)?;
54
55 let result = match resolved {
56 ResolvedModel::SingleFile(ref uri) => run_single_file(uri, force),
57 ResolvedModel::Sharded {
58 ref org,
59 ref repo,
60 ref shard_files,
61 } => run_sharded(org, repo, shard_files, force),
62 };
63 if let Ok(ref r) = result {
64 contract_post_pull_cache_integrity!(r);
65 }
66 result
67}
68
69fn run_single_file(model_ref: &str, force: bool) -> Result<()> {
75 println!("Model: {}", model_ref.cyan());
76
77 if model_ref.starts_with("hf://") {
79 return run_single_file_streaming(model_ref, force);
80 }
81
82 let mut fetcher = ModelFetcher::with_config(FetchConfig::default()).map_err(|e| {
83 CliError::ValidationFailed(format!("Failed to initialize model fetcher: {e}"))
84 })?;
85
86 if !force && fetcher.is_cached(model_ref) {
87 return handle_cached_model(&mut fetcher, model_ref);
88 }
89
90 let result = download_single_model(&mut fetcher, model_ref)?;
91 ensure_safetensors_companions(&result)?;
92 print_pull_usage(&result.path, true);
93 Ok(())
94}
95
96fn run_single_file_streaming(model_ref: &str, force: bool) -> Result<()> {
103 let path = model_ref.strip_prefix("hf://").unwrap_or(model_ref);
104 let parts: Vec<&str> = path.split('/').collect();
105 if parts.len() < 3 {
106 return Err(CliError::ValidationFailed(format!(
107 "HuggingFace URI must include a filename: {model_ref}"
108 )));
109 }
110
111 let filename = parts[2..].join("/");
112 let url = format!(
113 "https://huggingface.co/{}/{}/resolve/main/{}",
114 parts[0], parts[1], filename
115 );
116
117 let cache_dir = get_pacha_cache_dir()?;
119 std::fs::create_dir_all(&cache_dir)?;
120
121 let uri_hash = blake3::hash(model_ref.as_bytes()).to_hex().to_string();
123 let extension = std::path::Path::new(&filename)
124 .extension()
125 .and_then(|e| e.to_str())
126 .unwrap_or("bin");
127 let cache_filename = format!("{}.{}", &uri_hash[..16], extension);
128 let cache_path = cache_dir.join(&cache_filename);
129
130 if !force && cache_path.exists() {
131 let metadata = std::fs::metadata(&cache_path)?;
132 println!("{} Model already cached", "✓".green());
133 println!(" Path: {}", cache_path.display());
134 println!(" Size: {}", format_bytes(metadata.len()));
135 print_pull_usage(&cache_path, true);
136 return Ok(());
137 }
138
139 println!();
140 println!("{}", "Downloading (streaming)...".yellow());
141
142 let checksum = download_file_with_progress(&url, &cache_path)?;
143
144 println!();
145 println!("{} Downloaded successfully", "✓".green());
146 println!(" Path: {}", cache_path.display().to_string().green());
147 println!(" Size: {}", format_bytes(checksum.size).yellow());
148 println!(" Hash: {}", &checksum.blake3[..16]);
149
150 if extension == "safetensors" {
152 fetch_safetensors_companions(&cache_path, model_ref)?;
153 convert_safetensors_formats(&cache_path)?;
154 }
155
156 print_pull_usage(&cache_path, true);
157 Ok(())
158}
159
160fn get_pacha_cache_dir() -> Result<std::path::PathBuf> {
162 if let Ok(cache_home) = std::env::var("XDG_CACHE_HOME") {
163 return Ok(std::path::PathBuf::from(cache_home)
164 .join("pacha")
165 .join("models"));
166 }
167 Ok(dirs::home_dir()
168 .ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
169 .join(".cache")
170 .join("pacha")
171 .join("models"))
172}
173
174fn handle_cached_model(fetcher: &mut ModelFetcher, model_ref: &str) -> Result<()> {
176 println!("{} Model already cached", "✓".green());
177 let result = fetcher
178 .pull_quiet(model_ref)
179 .map_err(|e| CliError::ValidationFailed(format!("Failed to get cached model: {e}")))?;
180
181 println!(" Path: {}", result.path.display());
182 println!(" Size: {}", result.size_human());
183 println!(" Format: {}", result.format.name());
184
185 ensure_safetensors_companions(&result)?;
186 print_pull_usage(&result.path, false);
187 Ok(())
188}
189
190fn download_single_model(
192 fetcher: &mut ModelFetcher,
193 model_ref: &str,
194) -> Result<pacha::fetcher::FetchResult> {
195 println!();
196 println!("{}", "Downloading...".yellow());
197
198 let result = fetcher
199 .pull(model_ref, |progress| {
200 let pct = progress.percent();
201 print!(
202 "\r [{:50}] {:5.1}% ({}/{})",
203 "=".repeat((pct / 2.0) as usize),
204 pct,
205 format_bytes(progress.downloaded_bytes),
206 format_bytes(progress.total_bytes)
207 );
208 io::stdout().flush().ok();
209 })
210 .map_err(|e| CliError::NetworkError(format!("Download failed: {e}")))?;
211
212 println!();
213 println!();
214
215 if result.cache_hit {
216 println!("{} Model retrieved from cache", "✓".green());
217 } else {
218 println!("{} Downloaded successfully", "✓".green());
219 }
220
221 println!(" Path: {}", result.path.display().to_string().green());
222 println!(" Size: {}", result.size_human().yellow());
223 println!(" Format: {}", result.format.name());
224 println!(" Hash: {}", &result.hash[..16]);
225 Ok(result)
226}
227
228fn ensure_safetensors_companions(result: &pacha::fetcher::FetchResult) -> Result<()> {
230 if matches!(result.format, ModelFormat::SafeTensors(_)) {
231 fetch_safetensors_companions(&result.path, &result.resolved_uri)?;
232 convert_safetensors_formats(&result.path)?;
233 }
234 Ok(())
235}
236
237fn print_pull_usage(path: &Path, show_serve: bool) {
239 println!();
240 println!("{}", "Usage:".cyan().bold());
241 println!(" apr run {}", path.display());
242 if show_serve {
243 println!(" apr serve {}", path.display());
244 }
245}
246
247fn run_sharded(org: &str, repo: &str, shard_files: &[String], force: bool) -> Result<()> {
249 println!(
250 "Model: {}/{} ({} shards)",
251 org.cyan(),
252 repo.cyan(),
253 shard_files.len().to_string().yellow()
254 );
255
256 let cache_dir = resolve_shard_cache_dir(org, repo)?;
257 std::fs::create_dir_all(&cache_dir)?;
258
259 let base_url = format!("https://huggingface.co/{org}/{repo}/resolve/main");
260 let index_path = cache_dir.join("model.safetensors.index.json");
261
262 download_index_if_needed(&base_url, &index_path, force)?;
263
264 let manifest_path = cache_dir.join(".apr-manifest.json");
265 let existing_manifest = load_existing_manifest(&manifest_path, force);
266
267 let file_checksums = download_all_shards(
268 &cache_dir,
269 &base_url,
270 shard_files,
271 force,
272 existing_manifest.as_ref(),
273 )?;
274
275 download_companion_files(&cache_dir, &base_url, force)?;
276 write_shard_manifest(&manifest_path, org, repo, file_checksums)?;
277
278 println!();
279 println!("{} Downloaded successfully", "✓".green());
280 println!(" Path: {}", index_path.display().to_string().green());
281 println!(" Shards: {}", shard_files.len().to_string().yellow());
282
283 convert_safetensors_formats(&index_path)?;
284
285 println!();
286 println!("{}", "Usage:".cyan().bold());
287 println!(" apr run {}", index_path.display());
288 println!(" apr serve {}", index_path.display());
289 Ok(())
290}
291
292fn resolve_shard_cache_dir(org: &str, repo: &str) -> Result<std::path::PathBuf> {
294 Ok(dirs::home_dir()
295 .ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
296 .join(".apr")
297 .join("cache")
298 .join("hf")
299 .join(org)
300 .join(repo))
301}
302
303fn download_index_if_needed(base_url: &str, index_path: &Path, force: bool) -> Result<()> {
305 if force || !index_path.exists() {
306 println!();
307 println!(" {} model.safetensors.index.json", "Downloading".yellow());
308 download_file(
309 &format!("{base_url}/model.safetensors.index.json"),
310 index_path,
311 )?;
312 } else {
313 println!(" {} model.safetensors.index.json (cached)", "✓".green());
314 }
315 Ok(())
316}
317
318fn load_existing_manifest(manifest_path: &Path, force: bool) -> Option<ShardManifest> {
320 if force || !manifest_path.exists() {
321 return None;
322 }
323 std::fs::read_to_string(manifest_path)
324 .ok()
325 .and_then(|s| serde_json::from_str(&s).ok())
326}
327
328fn download_all_shards(
330 cache_dir: &Path,
331 base_url: &str,
332 shard_files: &[String],
333 force: bool,
334 existing_manifest: Option<&ShardManifest>,
335) -> Result<HashMap<String, FileChecksum>> {
336 let mut file_checksums: HashMap<String, FileChecksum> = HashMap::new();
337 let total = shard_files.len();
338 for (i, shard_file) in shard_files.iter().enumerate() {
339 download_or_verify_shard(
340 cache_dir,
341 base_url,
342 shard_file,
343 i,
344 total,
345 force,
346 existing_manifest,
347 &mut file_checksums,
348 )?;
349 }
350 Ok(file_checksums)
351}
352
353fn download_or_verify_shard(
355 cache_dir: &Path,
356 base_url: &str,
357 shard_file: &str,
358 index: usize,
359 total: usize,
360 force: bool,
361 existing_manifest: Option<&ShardManifest>,
362 checksums: &mut HashMap<String, FileChecksum>,
363) -> Result<()> {
364 let shard_path = cache_dir.join(shard_file);
365
366 if !force && shard_path.exists() {
367 if let Some(manifest) = existing_manifest {
368 if let Some(expected) = manifest.files.get(shard_file) {
369 let actual_size = std::fs::metadata(&shard_path).map(|m| m.len()).unwrap_or(0);
370 if actual_size == expected.size {
371 checksums.insert(
372 shard_file.to_string(),
373 FileChecksum {
374 size: expected.size,
375 blake3: expected.blake3.clone(),
376 },
377 );
378 println!(
379 " {} [{}/{}] {} (cached, verified)",
380 "✓".green(),
381 index + 1,
382 total,
383 shard_file
384 );
385 return Ok(());
386 }
387 println!(
388 " {} [{}/{}] {} (size mismatch: {} vs {} bytes, re-downloading)",
389 "⚠".yellow(),
390 index + 1,
391 total,
392 shard_file,
393 actual_size,
394 expected.size
395 );
396 }
398 } else {
399 println!(
400 " {} [{}/{}] {} (cached)",
401 "✓".green(),
402 index + 1,
403 total,
404 shard_file
405 );
406 return Ok(());
407 }
408 }
409
410 let shard_url = format!("{base_url}/{shard_file}");
411 print!(
412 " {} [{}/{}] {}...",
413 "↓".yellow(),
414 index + 1,
415 total,
416 shard_file
417 );
418 io::stdout().flush().ok();
419
420 let checksum = download_file_with_progress(&shard_url, &shard_path)?;
421 checksums.insert(shard_file.to_string(), checksum);
422 println!(" {}", "done".green());
423 Ok(())
424}
425
426fn download_companion_files(cache_dir: &Path, base_url: &str, force: bool) -> Result<()> {
431 let companions = [
433 ("tokenizer.json", false),
434 ("config.json", true),
435 ("tokenizer_config.json", false),
436 ("tokenizer.model", false),
437 ];
438 for (filename, required) in &companions {
439 let companion_path = cache_dir.join(filename);
440 if !force && companion_path.exists() {
441 println!(" {} {} (cached)", "✓".green(), filename);
442 continue;
443 }
444
445 let url = format!("{base_url}/{filename}");
446 match download_file(&url, &companion_path) {
447 Ok(()) => println!(" {} {}", "✓".green(), filename),
448 Err(CliError::HttpNotFound(_)) if *required => {
449 return Err(CliError::ValidationFailed(format!(
450 "{filename} is required for inference but was not found (HTTP 404) at {url}"
451 )));
452 }
453 Err(CliError::HttpNotFound(_)) => {
454 println!(" {} {} (not found in repo)", "⚠".yellow(), filename);
455 }
456 Err(e) if *required => {
457 return Err(CliError::ValidationFailed(format!(
458 "{filename} is required for inference but download failed: {e}"
459 )));
460 }
461 Err(_) => println!(" {} {} (not available, optional)", "⚠".yellow(), filename),
462 }
463 }
464
465 let tokenizer_files = ["tokenizer.json", "tokenizer.model", "tokenizer_config.json"];
467 let has_tokenizer = tokenizer_files.iter().any(|f| cache_dir.join(f).exists());
468 if !has_tokenizer {
469 return Err(CliError::ValidationFailed(format!(
470 "No tokenizer found for this model. Tried: {}.\n\
471 The model may require a custom tokenizer not hosted in the repository.",
472 tokenizer_files.join(", ")
473 )));
474 }
475
476 Ok(())
477}
478
479fn write_shard_manifest(
481 manifest_path: &Path,
482 org: &str,
483 repo: &str,
484 file_checksums: HashMap<String, FileChecksum>,
485) -> Result<()> {
486 if file_checksums.is_empty() {
487 return Ok(());
488 }
489 let manifest = ShardManifest {
490 version: 1,
491 repo: format!("{org}/{repo}"),
492 files: file_checksums,
493 };
494 let manifest_json = serde_json::to_string_pretty(&manifest)
495 .map_err(|e| CliError::ValidationFailed(format!("Failed to serialize manifest: {e}")))?;
496 std::fs::write(manifest_path, manifest_json)?;
497 println!(" {} .apr-manifest.json (integrity checksums)", "✓".green());
498 Ok(())
499}
500
501include!("pull_list.rs");
502include!("pull_remove_resolve_model.rs");
503include!("pull_extract_shard.rs");
504include!("pull_04.rs");