1use crate::error::{CliError, Result};
12use colored::Colorize;
13use pacha::fetcher::{FetchConfig, ModelFetcher};
14use pacha::format::ModelFormat;
15use serde::{Deserialize, Serialize};
16use std::collections::{HashMap, HashSet};
17use std::io::{self, Read, Write};
18use std::path::Path;
19
/// Outcome of resolving the user-supplied model reference passed to [`run`].
///
/// Produced by `resolve_hf_model` (defined outside this view); [`run`] dispatches on the
/// variant to either the single-file or the sharded download path.
#[derive(Debug)]
enum ResolvedModel {
    /// One model file; the string is the URI handed to `run_single_file` (an `hf://`
    /// URI or any other reference the pacha `ModelFetcher` understands).
    SingleFile(String),
    /// A HuggingFace repository whose weights are split over several files. The sharded
    /// path also fetches `model.safetensors.index.json` for it.
    Sharded {
        /// First path segment of the HuggingFace URL (`huggingface.co/<org>/...`).
        org: String,
        /// Second path segment of the HuggingFace URL (`huggingface.co/<org>/<repo>`).
        repo: String,
        /// Shard file names, each fetched from `<org>/<repo>/resolve/main/<name>`.
        shard_files: Vec<String>,
    },
}
35
/// Integrity manifest for a sharded download, stored as `.apr-manifest.json` inside the
/// shard cache directory.
///
/// Written by `write_shard_manifest` and read back by `load_existing_manifest`, so a later
/// pull can accept cached shards without downloading them again.
#[derive(Debug, Serialize, Deserialize)]
pub struct ShardManifest {
    /// Manifest format version; this file always writes `1`.
    pub version: u32,
    /// Repository identifier in `org/repo` form.
    pub repo: String,
    /// Recorded size and digest for each shard, keyed by shard file name.
    pub files: HashMap<String, FileChecksum>,
}
46
/// Size and BLAKE3 digest recorded for one downloaded file.
///
/// When cached shards are re-checked, only `size` is compared against the file on disk;
/// `blake3` is carried forward as-is rather than recomputed.
#[derive(Debug, Serialize, Deserialize)]
pub struct FileChecksum {
    /// File size in bytes.
    pub size: u64,
    /// BLAKE3 digest string as produced by `download_file_with_progress` (presumably
    /// lowercase hex — confirm there).
    pub blake3: String,
}
53
54pub fn run(model_ref: &str, force: bool) -> Result<()> {
56 println!("{}", "=== APR Pull ===".cyan().bold());
57 println!();
58
59 let resolved = resolve_hf_model(model_ref)?;
61
62 match resolved {
63 ResolvedModel::SingleFile(ref uri) => run_single_file(uri, force),
64 ResolvedModel::Sharded {
65 ref org,
66 ref repo,
67 ref shard_files,
68 } => run_sharded(org, repo, shard_files, force),
69 }
70}
71
72fn run_single_file(model_ref: &str, force: bool) -> Result<()> {
78 println!("Model: {}", model_ref.cyan());
79
80 if model_ref.starts_with("hf://") {
82 return run_single_file_streaming(model_ref, force);
83 }
84
85 let mut fetcher = ModelFetcher::with_config(FetchConfig::default()).map_err(|e| {
86 CliError::ValidationFailed(format!("Failed to initialize model fetcher: {e}"))
87 })?;
88
89 if !force && fetcher.is_cached(model_ref) {
90 return handle_cached_model(&mut fetcher, model_ref);
91 }
92
93 let result = download_single_model(&mut fetcher, model_ref)?;
94 ensure_safetensors_companions(&result)?;
95 print_pull_usage(&result.path, true);
96 Ok(())
97}
98
99fn run_single_file_streaming(model_ref: &str, force: bool) -> Result<()> {
106 let path = model_ref.strip_prefix("hf://").unwrap_or(model_ref);
107 let parts: Vec<&str> = path.split('/').collect();
108 if parts.len() < 3 {
109 return Err(CliError::ValidationFailed(format!(
110 "HuggingFace URI must include a filename: {model_ref}"
111 )));
112 }
113
114 let filename = parts[2..].join("/");
115 let url = format!(
116 "https://huggingface.co/{}/{}/resolve/main/{}",
117 parts[0], parts[1], filename
118 );
119
120 let cache_dir = get_pacha_cache_dir()?;
122 std::fs::create_dir_all(&cache_dir)?;
123
124 let uri_hash = blake3::hash(model_ref.as_bytes()).to_hex().to_string();
126 let extension = std::path::Path::new(&filename)
127 .extension()
128 .and_then(|e| e.to_str())
129 .unwrap_or("bin");
130 let cache_filename = format!("{}.{}", &uri_hash[..16], extension);
131 let cache_path = cache_dir.join(&cache_filename);
132
133 if !force && cache_path.exists() {
134 let metadata = std::fs::metadata(&cache_path)?;
135 println!("{} Model already cached", "✓".green());
136 println!(" Path: {}", cache_path.display());
137 println!(" Size: {}", format_bytes(metadata.len()));
138 print_pull_usage(&cache_path, true);
139 return Ok(());
140 }
141
142 println!();
143 println!("{}", "Downloading (streaming)...".yellow());
144
145 let checksum = download_file_with_progress(&url, &cache_path)?;
146
147 println!();
148 println!("{} Downloaded successfully", "✓".green());
149 println!(" Path: {}", cache_path.display().to_string().green());
150 println!(" Size: {}", format_bytes(checksum.size).yellow());
151 println!(" Hash: {}", &checksum.blake3[..16]);
152
153 if extension == "safetensors" {
155 fetch_safetensors_companions(&cache_path, model_ref)?;
156 convert_safetensors_formats(&cache_path)?;
157 }
158
159 print_pull_usage(&cache_path, true);
160 Ok(())
161}
162
163fn get_pacha_cache_dir() -> Result<std::path::PathBuf> {
165 if let Ok(cache_home) = std::env::var("XDG_CACHE_HOME") {
166 return Ok(std::path::PathBuf::from(cache_home)
167 .join("pacha")
168 .join("models"));
169 }
170 Ok(dirs::home_dir()
171 .ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
172 .join(".cache")
173 .join("pacha")
174 .join("models"))
175}
176
177fn handle_cached_model(fetcher: &mut ModelFetcher, model_ref: &str) -> Result<()> {
179 println!("{} Model already cached", "✓".green());
180 let result = fetcher
181 .pull_quiet(model_ref)
182 .map_err(|e| CliError::ValidationFailed(format!("Failed to get cached model: {e}")))?;
183
184 println!(" Path: {}", result.path.display());
185 println!(" Size: {}", result.size_human());
186 println!(" Format: {}", result.format.name());
187
188 ensure_safetensors_companions(&result)?;
189 print_pull_usage(&result.path, false);
190 Ok(())
191}
192
193fn download_single_model(
195 fetcher: &mut ModelFetcher,
196 model_ref: &str,
197) -> Result<pacha::fetcher::FetchResult> {
198 println!();
199 println!("{}", "Downloading...".yellow());
200
201 let result = fetcher
202 .pull(model_ref, |progress| {
203 let pct = progress.percent();
204 print!(
205 "\r [{:50}] {:5.1}% ({}/{})",
206 "=".repeat((pct / 2.0) as usize),
207 pct,
208 format_bytes(progress.downloaded_bytes),
209 format_bytes(progress.total_bytes)
210 );
211 io::stdout().flush().ok();
212 })
213 .map_err(|e| CliError::NetworkError(format!("Download failed: {e}")))?;
214
215 println!();
216 println!();
217
218 if result.cache_hit {
219 println!("{} Model retrieved from cache", "✓".green());
220 } else {
221 println!("{} Downloaded successfully", "✓".green());
222 }
223
224 println!(" Path: {}", result.path.display().to_string().green());
225 println!(" Size: {}", result.size_human().yellow());
226 println!(" Format: {}", result.format.name());
227 println!(" Hash: {}", &result.hash[..16]);
228 Ok(result)
229}
230
231fn ensure_safetensors_companions(result: &pacha::fetcher::FetchResult) -> Result<()> {
233 if matches!(result.format, ModelFormat::SafeTensors(_)) {
234 fetch_safetensors_companions(&result.path, &result.resolved_uri)?;
235 convert_safetensors_formats(&result.path)?;
236 }
237 Ok(())
238}
239
240fn print_pull_usage(path: &Path, show_serve: bool) {
242 println!();
243 println!("{}", "Usage:".cyan().bold());
244 println!(" apr run {}", path.display());
245 if show_serve {
246 println!(" apr serve {}", path.display());
247 }
248}
249
250fn run_sharded(org: &str, repo: &str, shard_files: &[String], force: bool) -> Result<()> {
252 println!(
253 "Model: {}/{} ({} shards)",
254 org.cyan(),
255 repo.cyan(),
256 shard_files.len().to_string().yellow()
257 );
258
259 let cache_dir = resolve_shard_cache_dir(org, repo)?;
260 std::fs::create_dir_all(&cache_dir)?;
261
262 let base_url = format!("https://huggingface.co/{org}/{repo}/resolve/main");
263 let index_path = cache_dir.join("model.safetensors.index.json");
264
265 download_index_if_needed(&base_url, &index_path, force)?;
266
267 let manifest_path = cache_dir.join(".apr-manifest.json");
268 let existing_manifest = load_existing_manifest(&manifest_path, force);
269
270 let file_checksums = download_all_shards(
271 &cache_dir,
272 &base_url,
273 shard_files,
274 force,
275 existing_manifest.as_ref(),
276 )?;
277
278 download_companion_files(&cache_dir, &base_url, force)?;
279 write_shard_manifest(&manifest_path, org, repo, file_checksums)?;
280
281 println!();
282 println!("{} Downloaded successfully", "✓".green());
283 println!(" Path: {}", index_path.display().to_string().green());
284 println!(" Shards: {}", shard_files.len().to_string().yellow());
285
286 convert_safetensors_formats(&index_path)?;
287
288 println!();
289 println!("{}", "Usage:".cyan().bold());
290 println!(" apr run {}", index_path.display());
291 println!(" apr serve {}", index_path.display());
292 Ok(())
293}
294
295fn resolve_shard_cache_dir(org: &str, repo: &str) -> Result<std::path::PathBuf> {
297 Ok(dirs::home_dir()
298 .ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
299 .join(".apr")
300 .join("cache")
301 .join("hf")
302 .join(org)
303 .join(repo))
304}
305
306fn download_index_if_needed(base_url: &str, index_path: &Path, force: bool) -> Result<()> {
308 if force || !index_path.exists() {
309 println!();
310 println!(" {} model.safetensors.index.json", "Downloading".yellow());
311 download_file(
312 &format!("{base_url}/model.safetensors.index.json"),
313 index_path,
314 )?;
315 } else {
316 println!(" {} model.safetensors.index.json (cached)", "✓".green());
317 }
318 Ok(())
319}
320
321fn load_existing_manifest(manifest_path: &Path, force: bool) -> Option<ShardManifest> {
323 if force || !manifest_path.exists() {
324 return None;
325 }
326 std::fs::read_to_string(manifest_path)
327 .ok()
328 .and_then(|s| serde_json::from_str(&s).ok())
329}
330
331fn download_all_shards(
333 cache_dir: &Path,
334 base_url: &str,
335 shard_files: &[String],
336 force: bool,
337 existing_manifest: Option<&ShardManifest>,
338) -> Result<HashMap<String, FileChecksum>> {
339 let mut file_checksums: HashMap<String, FileChecksum> = HashMap::new();
340 let total = shard_files.len();
341 for (i, shard_file) in shard_files.iter().enumerate() {
342 download_or_verify_shard(
343 cache_dir,
344 base_url,
345 shard_file,
346 i,
347 total,
348 force,
349 existing_manifest,
350 &mut file_checksums,
351 )?;
352 }
353 Ok(file_checksums)
354}
355
356fn download_or_verify_shard(
358 cache_dir: &Path,
359 base_url: &str,
360 shard_file: &str,
361 index: usize,
362 total: usize,
363 force: bool,
364 existing_manifest: Option<&ShardManifest>,
365 checksums: &mut HashMap<String, FileChecksum>,
366) -> Result<()> {
367 let shard_path = cache_dir.join(shard_file);
368
369 if !force && shard_path.exists() {
370 if let Some(manifest) = existing_manifest {
371 if let Some(expected) = manifest.files.get(shard_file) {
372 let actual_size = std::fs::metadata(&shard_path).map(|m| m.len()).unwrap_or(0);
373 if actual_size == expected.size {
374 checksums.insert(
375 shard_file.to_string(),
376 FileChecksum {
377 size: expected.size,
378 blake3: expected.blake3.clone(),
379 },
380 );
381 println!(
382 " {} [{}/{}] {} (cached, verified)",
383 "✓".green(),
384 index + 1,
385 total,
386 shard_file
387 );
388 return Ok(());
389 }
390 println!(
391 " {} [{}/{}] {} (size mismatch: {} vs {} bytes, re-downloading)",
392 "⚠".yellow(),
393 index + 1,
394 total,
395 shard_file,
396 actual_size,
397 expected.size
398 );
399 }
401 } else {
402 println!(
403 " {} [{}/{}] {} (cached)",
404 "✓".green(),
405 index + 1,
406 total,
407 shard_file
408 );
409 return Ok(());
410 }
411 }
412
413 let shard_url = format!("{base_url}/{shard_file}");
414 print!(
415 " {} [{}/{}] {}...",
416 "↓".yellow(),
417 index + 1,
418 total,
419 shard_file
420 );
421 io::stdout().flush().ok();
422
423 let checksum = download_file_with_progress(&shard_url, &shard_path)?;
424 checksums.insert(shard_file.to_string(), checksum);
425 println!(" {}", "done".green());
426 Ok(())
427}
428
429fn download_companion_files(cache_dir: &Path, base_url: &str, force: bool) -> Result<()> {
431 let companions = [
432 ("tokenizer.json", true),
433 ("config.json", true),
434 ("tokenizer_config.json", false),
435 ];
436 for (filename, required) in &companions {
437 let companion_path = cache_dir.join(filename);
438 if !force && companion_path.exists() {
439 println!(" {} {} (cached)", "✓".green(), filename);
440 continue;
441 }
442
443 let url = format!("{base_url}/{filename}");
444 match download_file(&url, &companion_path) {
445 Ok(()) => println!(" {} {}", "✓".green(), filename),
446 Err(e) if *required => {
447 return Err(CliError::ValidationFailed(format!(
448 "{filename} is required for inference but download failed: {e}"
449 )));
450 }
451 Err(_) => println!(" {} {} (not available, optional)", "⚠".yellow(), filename),
452 }
453 }
454 Ok(())
455}
456
457fn write_shard_manifest(
459 manifest_path: &Path,
460 org: &str,
461 repo: &str,
462 file_checksums: HashMap<String, FileChecksum>,
463) -> Result<()> {
464 if file_checksums.is_empty() {
465 return Ok(());
466 }
467 let manifest = ShardManifest {
468 version: 1,
469 repo: format!("{org}/{repo}"),
470 files: file_checksums,
471 };
472 let manifest_json = serde_json::to_string_pretty(&manifest)
473 .map_err(|e| CliError::ValidationFailed(format!("Failed to serialize manifest: {e}")))?;
474 std::fs::write(manifest_path, manifest_json)?;
475 println!(" {} .apr-manifest.json (integrity checksums)", "✓".green());
476 Ok(())
477}
478
479#[allow(clippy::disallowed_methods)]
482pub fn list(json: bool) -> Result<()> {
483 let fetcher = ModelFetcher::new().map_err(|e| {
484 CliError::ValidationFailed(format!("Failed to initialize model fetcher: {e}"))
485 })?;
486
487 let models = fetcher.list();
488
489 if json {
491 let models_json: Vec<serde_json::Value> = models
492 .iter()
493 .map(|m| {
494 serde_json::json!({
495 "name": m.name,
496 "size_bytes": m.size_bytes,
497 "format": m.format.name(),
498 "path": m.path.display().to_string(),
499 })
500 })
501 .collect();
502 let stats = fetcher.stats();
503 let output = serde_json::json!({
504 "models": models_json,
505 "total": models.len(),
506 "total_size_bytes": stats.total_size_bytes,
507 });
508 println!(
509 "{}",
510 serde_json::to_string_pretty(&output).unwrap_or_default()
511 );
512 return Ok(());
513 }
514
515 println!("{}", "=== Cached Models ===".cyan().bold());
516 println!();
517
518 if models.is_empty() {
519 println!("{}", "No cached models found.".dimmed());
520 println!();
521 println!("Pull a model with:");
522 println!(" apr pull hf://Qwen/Qwen2.5-Coder-1.5B-Instruct-GGUF/qwen2.5-coder-1.5b-instruct-q4_k_m.gguf");
523 println!();
524 println!("Or run directly (auto-downloads):");
525 println!(" apr run hf://Qwen/Qwen2.5-Coder-1.5B-Instruct-GGUF/qwen2.5-coder-1.5b-instruct-q4_k_m.gguf");
526 return Ok(());
527 }
528
529 println!(
531 "{:<40} {:<12} {:<12} {}",
532 "NAME".dimmed(),
533 "SIZE".dimmed(),
534 "FORMAT".dimmed(),
535 "PATH".dimmed()
536 );
537 println!("{}", "-".repeat(104).dimmed());
538
539 for model in &models {
540 let size = format_bytes(model.size_bytes);
541 let format = model.format.name();
542 let name = if model.name.len() > 38 {
543 format!("{}...", &model.name[..35])
544 } else {
545 model.name.clone()
546 };
547
548 println!(
549 "{:<40} {:<12} {:<12} {}",
550 name.cyan(),
551 size.yellow(),
552 format,
553 model.path.display().to_string().dimmed()
554 );
555 }
556
557 println!();
558
559 let stats = fetcher.stats();
561 println!(
562 "Total: {} models, {} used",
563 models.len(),
564 format_bytes(stats.total_size_bytes)
565 );
566
567 Ok(())
568}
569
570include!("pull_remove_resolve_model.rs");
571include!("pull_extract_shard.rs");
572include!("pull_04.rs");