1use crate::error::{CliError, Result};
12use aprender::format::{
13 apr_export, apr_import, ExportFormat, ExportOptions, ImportOptions, ValidationConfig,
14};
15use colored::Colorize;
16use pacha::fetcher::{FetchConfig, ModelFetcher};
17use pacha::format::ModelFormat;
18use serde::{Deserialize, Serialize};
19use std::collections::{HashMap, HashSet};
20use std::io::{self, Read, Write};
21use std::path::Path;
22
/// Result of resolving a user-supplied model reference for `apr pull`.
///
/// `run` dispatches on this: single artifacts go through the `pacha` fetcher, sharded
/// repositories are downloaded file by file from `https://huggingface.co/{org}/{repo}`.
#[derive(Debug)]
enum ResolvedModel {
    /// A single model artifact, identified by the URI handed to `ModelFetcher`.
    SingleFile(String),
    /// A SafeTensors model split into several shard files in one Hugging Face repository.
    Sharded {
        /// Hugging Face organization (or user) that owns the repository.
        org: String,
        /// Repository name within `org`.
        repo: String,
        /// Shard file names, appended to the repository's `resolve/main` URL and joined onto
        /// the local cache directory.
        shard_files: Vec<String>,
    },
}
38
/// Integrity manifest stored as `.apr-manifest.json` in a sharded model's cache directory.
///
/// It is written after a sharded pull and read back on the next pull to decide whether
/// cached shards can be reused.
// NOTE(review): `files` is a `HashMap`, so the key order in the serialized JSON is not
// stable between writes.
#[derive(Debug, Serialize, Deserialize)]
pub struct ShardManifest {
    /// Manifest format version; the writer in this module always emits `1`.
    pub version: u32,
    /// Repository identifier in `org/repo` form.
    pub repo: String,
    /// Size and BLAKE3 checksum for each file, keyed by shard file name.
    pub files: HashMap<String, FileChecksum>,
}
49
/// Recorded integrity data for one downloaded file.
#[derive(Debug, Serialize, Deserialize)]
pub struct FileChecksum {
    /// File size in bytes; compared against the on-disk length when reusing a cached shard.
    pub size: u64,
    /// BLAKE3 digest of the file contents.
    // NOTE(review): presumably hex-encoded; it is produced by `download_file_with_progress`,
    // which is defined outside this view. Confirm there.
    pub blake3: String,
}
56
57pub fn run(model_ref: &str, force: bool) -> Result<()> {
59 println!("{}", "=== APR Pull ===".cyan().bold());
60 println!();
61
62 let resolved = resolve_hf_model(model_ref)?;
64
65 match resolved {
66 ResolvedModel::SingleFile(ref uri) => run_single_file(uri, force),
67 ResolvedModel::Sharded {
68 ref org,
69 ref repo,
70 ref shard_files,
71 } => run_sharded(org, repo, shard_files, force),
72 }
73}
74
75fn run_single_file(model_ref: &str, force: bool) -> Result<()> {
77 println!("Model: {}", model_ref.cyan());
78
79 let mut fetcher = ModelFetcher::with_config(FetchConfig::default()).map_err(|e| {
80 CliError::ValidationFailed(format!("Failed to initialize model fetcher: {e}"))
81 })?;
82
83 if !force && fetcher.is_cached(model_ref) {
84 return handle_cached_model(&mut fetcher, model_ref);
85 }
86
87 let result = download_single_model(&mut fetcher, model_ref)?;
88 ensure_safetensors_companions(&result)?;
89 print_pull_usage(&result.path, true);
90 Ok(())
91}
92
93fn handle_cached_model(fetcher: &mut ModelFetcher, model_ref: &str) -> Result<()> {
95 println!("{} Model already cached", "✓".green());
96 let result = fetcher
97 .pull_quiet(model_ref)
98 .map_err(|e| CliError::ValidationFailed(format!("Failed to get cached model: {e}")))?;
99
100 println!(" Path: {}", result.path.display());
101 println!(" Size: {}", result.size_human());
102 println!(" Format: {}", result.format.name());
103
104 ensure_safetensors_companions(&result)?;
105 print_pull_usage(&result.path, false);
106 Ok(())
107}
108
109fn download_single_model(
111 fetcher: &mut ModelFetcher,
112 model_ref: &str,
113) -> Result<pacha::fetcher::FetchResult> {
114 println!();
115 println!("{}", "Downloading...".yellow());
116
117 let result = fetcher
118 .pull(model_ref, |progress| {
119 let pct = progress.percent();
120 print!(
121 "\r [{:50}] {:5.1}% ({}/{})",
122 "=".repeat((pct / 2.0) as usize),
123 pct,
124 format_bytes(progress.downloaded_bytes),
125 format_bytes(progress.total_bytes)
126 );
127 io::stdout().flush().ok();
128 })
129 .map_err(|e| CliError::NetworkError(format!("Download failed: {e}")))?;
130
131 println!();
132 println!();
133
134 if result.cache_hit {
135 println!("{} Model retrieved from cache", "✓".green());
136 } else {
137 println!("{} Downloaded successfully", "✓".green());
138 }
139
140 println!(" Path: {}", result.path.display().to_string().green());
141 println!(" Size: {}", result.size_human().yellow());
142 println!(" Format: {}", result.format.name());
143 println!(" Hash: {}", &result.hash[..16]);
144 Ok(result)
145}
146
147fn ensure_safetensors_companions(result: &pacha::fetcher::FetchResult) -> Result<()> {
149 if matches!(result.format, ModelFormat::SafeTensors(_)) {
150 fetch_safetensors_companions(&result.path, &result.resolved_uri)?;
151 convert_safetensors_formats(&result.path)?;
152 }
153 Ok(())
154}
155
156fn print_pull_usage(path: &Path, show_serve: bool) {
158 println!();
159 println!("{}", "Usage:".cyan().bold());
160 println!(" apr run {}", path.display());
161 if show_serve {
162 println!(" apr serve {}", path.display());
163 }
164}
165
166fn run_sharded(org: &str, repo: &str, shard_files: &[String], force: bool) -> Result<()> {
168 println!(
169 "Model: {}/{} ({} shards)",
170 org.cyan(),
171 repo.cyan(),
172 shard_files.len().to_string().yellow()
173 );
174
175 let cache_dir = resolve_shard_cache_dir(org, repo)?;
176 std::fs::create_dir_all(&cache_dir)?;
177
178 let base_url = format!("https://huggingface.co/{org}/{repo}/resolve/main");
179 let index_path = cache_dir.join("model.safetensors.index.json");
180
181 download_index_if_needed(&base_url, &index_path, force)?;
182
183 let manifest_path = cache_dir.join(".apr-manifest.json");
184 let existing_manifest = load_existing_manifest(&manifest_path, force);
185
186 let file_checksums = download_all_shards(
187 &cache_dir,
188 &base_url,
189 shard_files,
190 force,
191 existing_manifest.as_ref(),
192 )?;
193
194 download_companion_files(&cache_dir, &base_url, force)?;
195 write_shard_manifest(&manifest_path, org, repo, file_checksums)?;
196
197 println!();
198 println!("{} Downloaded successfully", "✓".green());
199 println!(" Path: {}", index_path.display().to_string().green());
200 println!(" Shards: {}", shard_files.len().to_string().yellow());
201
202 convert_safetensors_formats(&index_path)?;
203
204 println!();
205 println!("{}", "Usage:".cyan().bold());
206 println!(" apr run {}", index_path.display());
207 println!(" apr serve {}", index_path.display());
208 Ok(())
209}
210
211fn resolve_shard_cache_dir(org: &str, repo: &str) -> Result<std::path::PathBuf> {
213 Ok(dirs::home_dir()
214 .ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
215 .join(".apr")
216 .join("cache")
217 .join("hf")
218 .join(org)
219 .join(repo))
220}
221
222fn download_index_if_needed(base_url: &str, index_path: &Path, force: bool) -> Result<()> {
224 if force || !index_path.exists() {
225 println!();
226 println!(" {} model.safetensors.index.json", "Downloading".yellow());
227 download_file(
228 &format!("{base_url}/model.safetensors.index.json"),
229 index_path,
230 )?;
231 } else {
232 println!(" {} model.safetensors.index.json (cached)", "✓".green());
233 }
234 Ok(())
235}
236
237fn load_existing_manifest(manifest_path: &Path, force: bool) -> Option<ShardManifest> {
239 if force || !manifest_path.exists() {
240 return None;
241 }
242 std::fs::read_to_string(manifest_path)
243 .ok()
244 .and_then(|s| serde_json::from_str(&s).ok())
245}
246
247fn download_all_shards(
249 cache_dir: &Path,
250 base_url: &str,
251 shard_files: &[String],
252 force: bool,
253 existing_manifest: Option<&ShardManifest>,
254) -> Result<HashMap<String, FileChecksum>> {
255 let mut file_checksums: HashMap<String, FileChecksum> = HashMap::new();
256 let total = shard_files.len();
257 for (i, shard_file) in shard_files.iter().enumerate() {
258 download_or_verify_shard(
259 cache_dir,
260 base_url,
261 shard_file,
262 i,
263 total,
264 force,
265 existing_manifest,
266 &mut file_checksums,
267 )?;
268 }
269 Ok(file_checksums)
270}
271
272fn download_or_verify_shard(
274 cache_dir: &Path,
275 base_url: &str,
276 shard_file: &str,
277 index: usize,
278 total: usize,
279 force: bool,
280 existing_manifest: Option<&ShardManifest>,
281 checksums: &mut HashMap<String, FileChecksum>,
282) -> Result<()> {
283 let shard_path = cache_dir.join(shard_file);
284
285 if !force && shard_path.exists() {
286 if let Some(manifest) = existing_manifest {
287 if let Some(expected) = manifest.files.get(shard_file) {
288 let actual_size = std::fs::metadata(&shard_path).map(|m| m.len()).unwrap_or(0);
289 if actual_size == expected.size {
290 checksums.insert(
291 shard_file.to_string(),
292 FileChecksum {
293 size: expected.size,
294 blake3: expected.blake3.clone(),
295 },
296 );
297 println!(
298 " {} [{}/{}] {} (cached, verified)",
299 "✓".green(),
300 index + 1,
301 total,
302 shard_file
303 );
304 return Ok(());
305 }
306 println!(
307 " {} [{}/{}] {} (size mismatch: {} vs {} bytes, re-downloading)",
308 "⚠".yellow(),
309 index + 1,
310 total,
311 shard_file,
312 actual_size,
313 expected.size
314 );
315 }
317 } else {
318 println!(
319 " {} [{}/{}] {} (cached)",
320 "✓".green(),
321 index + 1,
322 total,
323 shard_file
324 );
325 return Ok(());
326 }
327 }
328
329 let shard_url = format!("{base_url}/{shard_file}");
330 print!(
331 " {} [{}/{}] {}...",
332 "↓".yellow(),
333 index + 1,
334 total,
335 shard_file
336 );
337 io::stdout().flush().ok();
338
339 let checksum = download_file_with_progress(&shard_url, &shard_path)?;
340 checksums.insert(shard_file.to_string(), checksum);
341 println!(" {}", "done".green());
342 Ok(())
343}
344
345fn download_companion_files(cache_dir: &Path, base_url: &str, force: bool) -> Result<()> {
347 let companions = [
348 ("tokenizer.json", true),
349 ("config.json", true),
350 ("tokenizer_config.json", false),
351 ];
352 for (filename, required) in &companions {
353 let companion_path = cache_dir.join(filename);
354 if !force && companion_path.exists() {
355 println!(" {} {} (cached)", "✓".green(), filename);
356 continue;
357 }
358
359 let url = format!("{base_url}/{filename}");
360 match download_file(&url, &companion_path) {
361 Ok(()) => println!(" {} {}", "✓".green(), filename),
362 Err(e) if *required => {
363 return Err(CliError::ValidationFailed(format!(
364 "{filename} is required for inference but download failed: {e}"
365 )));
366 }
367 Err(_) => println!(" {} {} (not available, optional)", "⚠".yellow(), filename),
368 }
369 }
370 Ok(())
371}
372
373fn write_shard_manifest(
375 manifest_path: &Path,
376 org: &str,
377 repo: &str,
378 file_checksums: HashMap<String, FileChecksum>,
379) -> Result<()> {
380 if file_checksums.is_empty() {
381 return Ok(());
382 }
383 let manifest = ShardManifest {
384 version: 1,
385 repo: format!("{org}/{repo}"),
386 files: file_checksums,
387 };
388 let manifest_json = serde_json::to_string_pretty(&manifest)
389 .map_err(|e| CliError::ValidationFailed(format!("Failed to serialize manifest: {e}")))?;
390 std::fs::write(manifest_path, manifest_json)?;
391 println!(" {} .apr-manifest.json (integrity checksums)", "✓".green());
392 Ok(())
393}
394
395#[allow(clippy::disallowed_methods)]
398pub fn list(json: bool) -> Result<()> {
399 let fetcher = ModelFetcher::new().map_err(|e| {
400 CliError::ValidationFailed(format!("Failed to initialize model fetcher: {e}"))
401 })?;
402
403 let models = fetcher.list();
404
405 if json {
407 let models_json: Vec<serde_json::Value> = models
408 .iter()
409 .map(|m| {
410 serde_json::json!({
411 "name": m.name,
412 "size_bytes": m.size_bytes,
413 "format": m.format.name(),
414 "path": m.path.display().to_string(),
415 })
416 })
417 .collect();
418 let stats = fetcher.stats();
419 let output = serde_json::json!({
420 "models": models_json,
421 "total": models.len(),
422 "total_size_bytes": stats.total_size_bytes,
423 });
424 println!(
425 "{}",
426 serde_json::to_string_pretty(&output).unwrap_or_default()
427 );
428 return Ok(());
429 }
430
431 println!("{}", "=== Cached Models ===".cyan().bold());
432 println!();
433
434 if models.is_empty() {
435 println!("{}", "No cached models found.".dimmed());
436 println!();
437 println!("Pull a model with:");
438 println!(" apr pull hf://Qwen/Qwen2.5-Coder-1.5B-Instruct-GGUF/qwen2.5-coder-1.5b-instruct-q4_k_m.gguf");
439 println!();
440 println!("Or run directly (auto-downloads):");
441 println!(" apr run hf://Qwen/Qwen2.5-Coder-1.5B-Instruct-GGUF/qwen2.5-coder-1.5b-instruct-q4_k_m.gguf");
442 return Ok(());
443 }
444
445 println!(
447 "{:<40} {:<12} {:<12} {}",
448 "NAME".dimmed(),
449 "SIZE".dimmed(),
450 "FORMAT".dimmed(),
451 "PATH".dimmed()
452 );
453 println!("{}", "-".repeat(104).dimmed());
454
455 for model in &models {
456 let size = format_bytes(model.size_bytes);
457 let format = model.format.name();
458 let name = if model.name.len() > 38 {
459 format!("{}...", &model.name[..35])
460 } else {
461 model.name.clone()
462 };
463
464 println!(
465 "{:<40} {:<12} {:<12} {}",
466 name.cyan(),
467 size.yellow(),
468 format,
469 model.path.display().to_string().dimmed()
470 );
471 }
472
473 println!();
474
475 let stats = fetcher.stats();
477 println!(
478 "Total: {} models, {} used",
479 models.len(),
480 format_bytes(stats.total_size_bytes)
481 );
482
483 Ok(())
484}
485
486include!("pull_remove_resolve_model.rs");
487include!("pull_extract_shard.rs");
488include!("pull_04.rs");