1use crate::error::{CliError, Result};
4use colored::Colorize;
5use pacha::fetcher::{FetchConfig, ModelFetcher};
6use pacha::format::ModelFormat;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet};
9use std::io::{self, Read, Write};
10use std::path::Path;
11
/// Outcome of resolving a user-supplied model reference for `apr pull`;
/// [`run`] dispatches on it to the single-file or sharded download path.
#[derive(Debug)]
enum ResolvedModel {
    /// A single downloadable model; the string is the reference handed to
    /// `run_single_file` (streamed directly when it starts with `hf://`).
    SingleFile(String),
    /// A HuggingFace repository whose weights are split across several shard
    /// files listed next to `model.safetensors.index.json`.
    Sharded {
        /// Repository namespace, as in `huggingface.co/{org}/{repo}`.
        org: String,
        /// Repository name within `org`.
        repo: String,
        /// Shard file names, relative to the repository's `resolve/main` root.
        shard_files: Vec<String>,
    },
}
27
/// Integrity manifest stored as `.apr-manifest.json` in a sharded model's
/// cache directory.
///
/// It maps each shard file name to its recorded size and BLAKE3 digest. On a
/// later pull, a cached shard whose on-disk size matches its entry is
/// accepted without re-downloading.
#[derive(Debug, Serialize, Deserialize)]
pub struct ShardManifest {
    /// Manifest format version (written as `1`).
    pub version: u32,
    /// Source repository in `"org/repo"` form.
    pub repo: String,
    /// Per-shard checksums keyed by shard file name.
    pub files: HashMap<String, FileChecksum>,
}
38
/// Size and BLAKE3 digest recorded for one downloaded file.
#[derive(Debug, Serialize, Deserialize)]
pub struct FileChecksum {
    /// File size in bytes (compared against the on-disk file length).
    pub size: u64,
    /// BLAKE3 digest string. Presumably hex-encoded, since only its first 16
    /// characters are printed; TODO confirm against `download_file_with_progress`.
    pub blake3: String,
}
45
46pub fn run(model_ref: &str, force: bool) -> Result<()> {
48 println!("{}", "=== APR Pull ===".cyan().bold());
49 println!();
50
51 let resolved = resolve_hf_model(model_ref)?;
53
54 match resolved {
55 ResolvedModel::SingleFile(ref uri) => run_single_file(uri, force),
56 ResolvedModel::Sharded {
57 ref org,
58 ref repo,
59 ref shard_files,
60 } => run_sharded(org, repo, shard_files, force),
61 }
62}
63
64fn run_single_file(model_ref: &str, force: bool) -> Result<()> {
70 println!("Model: {}", model_ref.cyan());
71
72 if model_ref.starts_with("hf://") {
74 return run_single_file_streaming(model_ref, force);
75 }
76
77 let mut fetcher = ModelFetcher::with_config(FetchConfig::default()).map_err(|e| {
78 CliError::ValidationFailed(format!("Failed to initialize model fetcher: {e}"))
79 })?;
80
81 if !force && fetcher.is_cached(model_ref) {
82 return handle_cached_model(&mut fetcher, model_ref);
83 }
84
85 let result = download_single_model(&mut fetcher, model_ref)?;
86 ensure_safetensors_companions(&result)?;
87 print_pull_usage(&result.path, true);
88 Ok(())
89}
90
91fn run_single_file_streaming(model_ref: &str, force: bool) -> Result<()> {
98 let path = model_ref.strip_prefix("hf://").unwrap_or(model_ref);
99 let parts: Vec<&str> = path.split('/').collect();
100 if parts.len() < 3 {
101 return Err(CliError::ValidationFailed(format!(
102 "HuggingFace URI must include a filename: {model_ref}"
103 )));
104 }
105
106 let filename = parts[2..].join("/");
107 let url = format!(
108 "https://huggingface.co/{}/{}/resolve/main/{}",
109 parts[0], parts[1], filename
110 );
111
112 let cache_dir = get_pacha_cache_dir()?;
114 std::fs::create_dir_all(&cache_dir)?;
115
116 let uri_hash = blake3::hash(model_ref.as_bytes()).to_hex().to_string();
118 let extension = std::path::Path::new(&filename)
119 .extension()
120 .and_then(|e| e.to_str())
121 .unwrap_or("bin");
122 let cache_filename = format!("{}.{}", &uri_hash[..16], extension);
123 let cache_path = cache_dir.join(&cache_filename);
124
125 if !force && cache_path.exists() {
126 let metadata = std::fs::metadata(&cache_path)?;
127 println!("{} Model already cached", "✓".green());
128 println!(" Path: {}", cache_path.display());
129 println!(" Size: {}", format_bytes(metadata.len()));
130 print_pull_usage(&cache_path, true);
131 return Ok(());
132 }
133
134 println!();
135 println!("{}", "Downloading (streaming)...".yellow());
136
137 let checksum = download_file_with_progress(&url, &cache_path)?;
138
139 println!();
140 println!("{} Downloaded successfully", "✓".green());
141 println!(" Path: {}", cache_path.display().to_string().green());
142 println!(" Size: {}", format_bytes(checksum.size).yellow());
143 println!(" Hash: {}", &checksum.blake3[..16]);
144
145 if extension == "safetensors" {
147 fetch_safetensors_companions(&cache_path, model_ref)?;
148 convert_safetensors_formats(&cache_path)?;
149 }
150
151 print_pull_usage(&cache_path, true);
152 Ok(())
153}
154
155fn get_pacha_cache_dir() -> Result<std::path::PathBuf> {
157 if let Ok(cache_home) = std::env::var("XDG_CACHE_HOME") {
158 return Ok(std::path::PathBuf::from(cache_home)
159 .join("pacha")
160 .join("models"));
161 }
162 Ok(dirs::home_dir()
163 .ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
164 .join(".cache")
165 .join("pacha")
166 .join("models"))
167}
168
169fn handle_cached_model(fetcher: &mut ModelFetcher, model_ref: &str) -> Result<()> {
171 println!("{} Model already cached", "✓".green());
172 let result = fetcher
173 .pull_quiet(model_ref)
174 .map_err(|e| CliError::ValidationFailed(format!("Failed to get cached model: {e}")))?;
175
176 println!(" Path: {}", result.path.display());
177 println!(" Size: {}", result.size_human());
178 println!(" Format: {}", result.format.name());
179
180 ensure_safetensors_companions(&result)?;
181 print_pull_usage(&result.path, false);
182 Ok(())
183}
184
185fn download_single_model(
187 fetcher: &mut ModelFetcher,
188 model_ref: &str,
189) -> Result<pacha::fetcher::FetchResult> {
190 println!();
191 println!("{}", "Downloading...".yellow());
192
193 let result = fetcher
194 .pull(model_ref, |progress| {
195 let pct = progress.percent();
196 print!(
197 "\r [{:50}] {:5.1}% ({}/{})",
198 "=".repeat((pct / 2.0) as usize),
199 pct,
200 format_bytes(progress.downloaded_bytes),
201 format_bytes(progress.total_bytes)
202 );
203 io::stdout().flush().ok();
204 })
205 .map_err(|e| CliError::NetworkError(format!("Download failed: {e}")))?;
206
207 println!();
208 println!();
209
210 if result.cache_hit {
211 println!("{} Model retrieved from cache", "✓".green());
212 } else {
213 println!("{} Downloaded successfully", "✓".green());
214 }
215
216 println!(" Path: {}", result.path.display().to_string().green());
217 println!(" Size: {}", result.size_human().yellow());
218 println!(" Format: {}", result.format.name());
219 println!(" Hash: {}", &result.hash[..16]);
220 Ok(result)
221}
222
223fn ensure_safetensors_companions(result: &pacha::fetcher::FetchResult) -> Result<()> {
225 if matches!(result.format, ModelFormat::SafeTensors(_)) {
226 fetch_safetensors_companions(&result.path, &result.resolved_uri)?;
227 convert_safetensors_formats(&result.path)?;
228 }
229 Ok(())
230}
231
232fn print_pull_usage(path: &Path, show_serve: bool) {
234 println!();
235 println!("{}", "Usage:".cyan().bold());
236 println!(" apr run {}", path.display());
237 if show_serve {
238 println!(" apr serve {}", path.display());
239 }
240}
241
242fn run_sharded(org: &str, repo: &str, shard_files: &[String], force: bool) -> Result<()> {
244 println!(
245 "Model: {}/{} ({} shards)",
246 org.cyan(),
247 repo.cyan(),
248 shard_files.len().to_string().yellow()
249 );
250
251 let cache_dir = resolve_shard_cache_dir(org, repo)?;
252 std::fs::create_dir_all(&cache_dir)?;
253
254 let base_url = format!("https://huggingface.co/{org}/{repo}/resolve/main");
255 let index_path = cache_dir.join("model.safetensors.index.json");
256
257 download_index_if_needed(&base_url, &index_path, force)?;
258
259 let manifest_path = cache_dir.join(".apr-manifest.json");
260 let existing_manifest = load_existing_manifest(&manifest_path, force);
261
262 let file_checksums = download_all_shards(
263 &cache_dir,
264 &base_url,
265 shard_files,
266 force,
267 existing_manifest.as_ref(),
268 )?;
269
270 download_companion_files(&cache_dir, &base_url, force)?;
271 write_shard_manifest(&manifest_path, org, repo, file_checksums)?;
272
273 println!();
274 println!("{} Downloaded successfully", "✓".green());
275 println!(" Path: {}", index_path.display().to_string().green());
276 println!(" Shards: {}", shard_files.len().to_string().yellow());
277
278 convert_safetensors_formats(&index_path)?;
279
280 println!();
281 println!("{}", "Usage:".cyan().bold());
282 println!(" apr run {}", index_path.display());
283 println!(" apr serve {}", index_path.display());
284 Ok(())
285}
286
287fn resolve_shard_cache_dir(org: &str, repo: &str) -> Result<std::path::PathBuf> {
289 Ok(dirs::home_dir()
290 .ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
291 .join(".apr")
292 .join("cache")
293 .join("hf")
294 .join(org)
295 .join(repo))
296}
297
298fn download_index_if_needed(base_url: &str, index_path: &Path, force: bool) -> Result<()> {
300 if force || !index_path.exists() {
301 println!();
302 println!(" {} model.safetensors.index.json", "Downloading".yellow());
303 download_file(
304 &format!("{base_url}/model.safetensors.index.json"),
305 index_path,
306 )?;
307 } else {
308 println!(" {} model.safetensors.index.json (cached)", "✓".green());
309 }
310 Ok(())
311}
312
313fn load_existing_manifest(manifest_path: &Path, force: bool) -> Option<ShardManifest> {
315 if force || !manifest_path.exists() {
316 return None;
317 }
318 std::fs::read_to_string(manifest_path)
319 .ok()
320 .and_then(|s| serde_json::from_str(&s).ok())
321}
322
323fn download_all_shards(
325 cache_dir: &Path,
326 base_url: &str,
327 shard_files: &[String],
328 force: bool,
329 existing_manifest: Option<&ShardManifest>,
330) -> Result<HashMap<String, FileChecksum>> {
331 let mut file_checksums: HashMap<String, FileChecksum> = HashMap::new();
332 let total = shard_files.len();
333 for (i, shard_file) in shard_files.iter().enumerate() {
334 download_or_verify_shard(
335 cache_dir,
336 base_url,
337 shard_file,
338 i,
339 total,
340 force,
341 existing_manifest,
342 &mut file_checksums,
343 )?;
344 }
345 Ok(file_checksums)
346}
347
348fn download_or_verify_shard(
350 cache_dir: &Path,
351 base_url: &str,
352 shard_file: &str,
353 index: usize,
354 total: usize,
355 force: bool,
356 existing_manifest: Option<&ShardManifest>,
357 checksums: &mut HashMap<String, FileChecksum>,
358) -> Result<()> {
359 let shard_path = cache_dir.join(shard_file);
360
361 if !force && shard_path.exists() {
362 if let Some(manifest) = existing_manifest {
363 if let Some(expected) = manifest.files.get(shard_file) {
364 let actual_size = std::fs::metadata(&shard_path).map(|m| m.len()).unwrap_or(0);
365 if actual_size == expected.size {
366 checksums.insert(
367 shard_file.to_string(),
368 FileChecksum {
369 size: expected.size,
370 blake3: expected.blake3.clone(),
371 },
372 );
373 println!(
374 " {} [{}/{}] {} (cached, verified)",
375 "✓".green(),
376 index + 1,
377 total,
378 shard_file
379 );
380 return Ok(());
381 }
382 println!(
383 " {} [{}/{}] {} (size mismatch: {} vs {} bytes, re-downloading)",
384 "⚠".yellow(),
385 index + 1,
386 total,
387 shard_file,
388 actual_size,
389 expected.size
390 );
391 }
393 } else {
394 println!(
395 " {} [{}/{}] {} (cached)",
396 "✓".green(),
397 index + 1,
398 total,
399 shard_file
400 );
401 return Ok(());
402 }
403 }
404
405 let shard_url = format!("{base_url}/{shard_file}");
406 print!(
407 " {} [{}/{}] {}...",
408 "↓".yellow(),
409 index + 1,
410 total,
411 shard_file
412 );
413 io::stdout().flush().ok();
414
415 let checksum = download_file_with_progress(&shard_url, &shard_path)?;
416 checksums.insert(shard_file.to_string(), checksum);
417 println!(" {}", "done".green());
418 Ok(())
419}
420
421fn download_companion_files(cache_dir: &Path, base_url: &str, force: bool) -> Result<()> {
426 let companions = [
428 ("tokenizer.json", false),
429 ("config.json", true),
430 ("tokenizer_config.json", false),
431 ("tokenizer.model", false),
432 ];
433 for (filename, required) in &companions {
434 let companion_path = cache_dir.join(filename);
435 if !force && companion_path.exists() {
436 println!(" {} {} (cached)", "✓".green(), filename);
437 continue;
438 }
439
440 let url = format!("{base_url}/{filename}");
441 match download_file(&url, &companion_path) {
442 Ok(()) => println!(" {} {}", "✓".green(), filename),
443 Err(CliError::HttpNotFound(_)) if *required => {
444 return Err(CliError::ValidationFailed(format!(
445 "{filename} is required for inference but was not found (HTTP 404) at {url}"
446 )));
447 }
448 Err(CliError::HttpNotFound(_)) => {
449 println!(" {} {} (not found in repo)", "⚠".yellow(), filename);
450 }
451 Err(e) if *required => {
452 return Err(CliError::ValidationFailed(format!(
453 "{filename} is required for inference but download failed: {e}"
454 )));
455 }
456 Err(_) => println!(" {} {} (not available, optional)", "⚠".yellow(), filename),
457 }
458 }
459
460 let tokenizer_files = ["tokenizer.json", "tokenizer.model", "tokenizer_config.json"];
462 let has_tokenizer = tokenizer_files.iter().any(|f| cache_dir.join(f).exists());
463 if !has_tokenizer {
464 return Err(CliError::ValidationFailed(format!(
465 "No tokenizer found for this model. Tried: {}.\n\
466 The model may require a custom tokenizer not hosted in the repository.",
467 tokenizer_files.join(", ")
468 )));
469 }
470
471 Ok(())
472}
473
474fn write_shard_manifest(
476 manifest_path: &Path,
477 org: &str,
478 repo: &str,
479 file_checksums: HashMap<String, FileChecksum>,
480) -> Result<()> {
481 if file_checksums.is_empty() {
482 return Ok(());
483 }
484 let manifest = ShardManifest {
485 version: 1,
486 repo: format!("{org}/{repo}"),
487 files: file_checksums,
488 };
489 let manifest_json = serde_json::to_string_pretty(&manifest)
490 .map_err(|e| CliError::ValidationFailed(format!("Failed to serialize manifest: {e}")))?;
491 std::fs::write(manifest_path, manifest_json)?;
492 println!(" {} .apr-manifest.json (integrity checksums)", "✓".green());
493 Ok(())
494}
495
496include!("pull_list.rs");
497include!("pull_remove_resolve_model.rs");
498include!("pull_extract_shard.rs");
499include!("pull_04.rs");