1use std::path::{Path, PathBuf};
8
9use anyhow::{Context, Result};
10
11use crate::progress;
12use crate::verbosity::Verbosity;
13
/// One downloadable entry in the model manifest ([`ASSETS`]).
// `sha256` is carried for future integrity checks but is not read yet,
// hence the lint allowance.
#[allow(dead_code)]
struct Asset {
    /// Human-readable label shown in setup output.
    name: &'static str,
    /// Absolute URL the file is fetched from.
    url: &'static str,
    /// File name used inside the cache directory.
    filename: &'static str,
    /// Approximate size in bytes, used for the progress bar and the cache
    /// completeness check; `0` means "unknown".
    expected_bytes: u64,
    /// Hex SHA-256 of the file; empty when no checksum has been published.
    sha256: &'static str,
}
34
/// A single model file to fetch from the HuggingFace Hub, typically built
/// with [`HfModelSource::parse`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HfModelSource {
    /// Repository id in `organization/repository` form.
    pub repo_id: String,
    /// File to download from the repository.
    pub filename: String,
    /// Branch, tag or commit to download from; `None` requests no explicit
    /// revision.
    pub revision: Option<String>,
}
49
50impl HfModelSource {
51 pub fn parse(spec: &str) -> Result<Self> {
65 if spec.is_empty() {
66 anyhow::bail!("Model specification cannot be empty");
67 }
68
69 let (repo_part, revision) = if let Some(pos) = spec.find(':') {
71 let (repo, rev) = spec.split_at(pos);
72 (repo, Some(rev[1..].to_string()))
73 } else if let Some(pos) = spec.find('@') {
74 let (repo, rev) = spec.split_at(pos);
75 (repo, Some(rev[1..].to_string()))
76 } else {
77 (spec, None)
78 };
79
80 if !repo_part.contains('/') {
82 anyhow::bail!(
83 "Invalid repository format: '{}'. Expected format: 'organization/repository'",
84 repo_part
85 );
86 }
87
88 if let Some(ref rev) = revision {
90 if rev.is_empty() {
91 anyhow::bail!("Revision cannot be empty");
92 }
93 }
94
95 Ok(Self {
96 repo_id: repo_part.to_string(),
97 filename: "model.safetensors".to_string(),
98 revision,
99 })
100 }
101
102 pub fn with_filename(mut self, filename: String) -> Self {
104 self.filename = filename;
105 self
106 }
107
108 #[allow(dead_code)]
126 pub fn download(&self, token: Option<&str>) -> Result<PathBuf> {
127 use hf_hub::api::sync::ApiBuilder;
128
129 println!("📥 Downloading from HuggingFace Hub: {}", self.repo_id);
130 if let Some(ref rev) = self.revision {
131 println!(" Revision: {}", rev);
132 }
133 println!(" File: {}", self.filename);
134
135 let mut api_builder = ApiBuilder::new();
137
138 if let Some(token_str) = token {
139 api_builder = api_builder.with_token(Some(token_str.to_string()));
140 }
141
142 let api = api_builder
143 .build()
144 .context("Failed to initialize HuggingFace Hub API client")?;
145
146 use hf_hub::{Repo, RepoType};
148 let repo_obj = if let Some(ref rev) = self.revision {
149 Repo::with_revision(self.repo_id.clone(), RepoType::Model, rev.clone())
150 } else {
151 Repo::new(self.repo_id.clone(), RepoType::Model)
152 };
153
154 let repo = api.repo(repo_obj);
156
157 let file_path = repo.get(&self.filename).with_context(|| {
159 format!(
160 "Failed to download '{}' from repository '{}'{}",
161 self.filename,
162 self.repo_id,
163 self.revision
164 .as_ref()
165 .map(|r| format!(" (revision: {})", r))
166 .unwrap_or_default()
167 )
168 })?;
169
170 println!("✓ Downloaded to: {}", file_path.display());
171
172 Ok(file_path)
173 }
174}
175
176static ASSETS: &[Asset] = &[
181 Asset {
182 name: "FLAME 2023 Head Model",
183 url: "https://github.com/cool-japan/oxigaf/releases/download/v0.1.0/flame2023.tar.gz",
184 filename: "flame2023.tar.gz",
185 expected_bytes: 250_000_000,
186 sha256: "",
187 },
188 Asset {
189 name: "Multi-View Diffusion U-Net",
190 url: "https://github.com/cool-japan/oxigaf/releases/download/v0.1.0/diffusion_unet.safetensors",
191 filename: "diffusion_unet.safetensors",
192 expected_bytes: 1_700_000_000,
193 sha256: "",
194 },
195 Asset {
196 name: "VAE Decoder",
197 url: "https://github.com/cool-japan/oxigaf/releases/download/v0.1.0/vae_decoder.safetensors",
198 filename: "vae_decoder.safetensors",
199 expected_bytes: 200_000_000,
200 sha256: "",
201 },
202 Asset {
203 name: "CLIP Image Encoder",
204 url: "https://github.com/cool-japan/oxigaf/releases/download/v0.1.0/clip_image_encoder.safetensors",
205 filename: "clip_image_encoder.safetensors",
206 expected_bytes: 600_000_000,
207 sha256: "",
208 },
209];
210
211pub fn setup_cache(cache_dir: &Path, verbosity: Verbosity, json_mode: bool) -> Result<()> {
223 let cache_dir = ensure_cache_dir(cache_dir)?;
224
225 if !json_mode {
226 println!();
227 println!("📦 OxiGAF Model Setup");
228 println!(" Cache directory: {}", cache_dir.display());
229 println!();
230 }
231
232 let mut downloaded = 0usize;
233 let mut skipped = 0usize;
234 let mut downloaded_assets = Vec::new();
235
236 for asset in ASSETS {
237 let dest = cache_dir.join(asset.filename);
238
239 if is_cached(&dest, asset.expected_bytes) {
240 if !json_mode {
241 println!(" ✓ {} (cached)", asset.name);
242 }
243 skipped += 1;
244 continue;
245 }
246
247 if !json_mode {
248 println!(" ⬇ Downloading {} …", asset.name);
249 }
250 download_file(asset.url, &dest, asset.expected_bytes, verbosity)
251 .with_context(|| format!("Failed to download {}", asset.name))?;
252 downloaded += 1;
253 downloaded_assets.push(dest.clone());
254 }
255
256 if json_mode {
258 let mut output = crate::json_output::JsonOutput::success(
259 "setup",
260 serde_json::json!({
261 "cache_dir": cache_dir.display().to_string(),
262 "downloaded": downloaded,
263 "skipped": skipped,
264 "total_assets": ASSETS.len()
265 }),
266 );
267
268 for path in downloaded_assets {
270 if path.exists() {
271 output.add_artifact("model".to_string(), path);
272 }
273 }
274
275 output.print();
276 } else {
277 println!();
278 if downloaded > 0 {
279 println!(
280 "✅ Setup complete — downloaded {downloaded} asset(s), {skipped} already cached."
281 );
282 } else {
283 println!("✅ All assets already cached. Nothing to download.");
284 }
285 }
286
287 Ok(())
288}
289
290#[allow(dead_code)]
294pub fn expected_asset_paths(cache_dir: &Path) -> Vec<PathBuf> {
295 ASSETS.iter().map(|a| cache_dir.join(a.filename)).collect()
296}
297
298pub fn get_hf_token() -> Option<String> {
308 if let Ok(token) = std::env::var("HF_TOKEN") {
310 if !token.trim().is_empty() {
311 return Some(token.trim().to_string());
312 }
313 }
314
315 if let Some(home) = dirs::home_dir() {
317 let token_path = home.join(".huggingface").join("token");
318 if let Ok(token) = std::fs::read_to_string(token_path) {
319 let token = token.trim();
320 if !token.is_empty() {
321 return Some(token.to_string());
322 }
323 }
324 }
325
326 None
327}
328
329pub fn download_with_progress(
351 repo_id: &str,
352 filename: &str,
353 revision: Option<&str>,
354 token: Option<&str>,
355 verbosity: Verbosity,
356) -> Result<PathBuf> {
357 use hf_hub::api::sync::ApiBuilder;
358
359 if verbosity != Verbosity::Quiet {
360 println!("📥 Downloading from HuggingFace Hub: {}", repo_id);
361 if let Some(rev) = revision {
362 println!(" Revision: {}", rev);
363 }
364 println!(" File: {}", filename);
365 println!();
366 }
367
368 let mut api_builder = ApiBuilder::new();
370
371 if let Some(token_str) = token {
372 api_builder = api_builder.with_token(Some(token_str.to_string()));
373 }
374
375 if verbosity.show_progress() {
377 api_builder = api_builder.with_progress(true);
378 }
379
380 let api = api_builder
381 .build()
382 .context("Failed to initialize HuggingFace Hub API client")?;
383
384 use hf_hub::{Repo, RepoType};
386 let repo_obj = if let Some(rev) = revision {
387 Repo::with_revision(repo_id.to_string(), RepoType::Model, rev.to_string())
388 } else {
389 Repo::new(repo_id.to_string(), RepoType::Model)
390 };
391
392 let repo = api.repo(repo_obj);
394
395 let file_path = repo.get(filename).with_context(|| {
397 format!(
398 "Failed to download '{}' from repository '{}'{}",
399 filename,
400 repo_id,
401 revision
402 .map(|r| format!(" (revision: {})", r))
403 .unwrap_or_default()
404 )
405 })?;
406
407 if verbosity != Verbosity::Quiet {
408 println!();
409 println!("✓ Downloaded to: {}", file_path.display());
410 }
411
412 Ok(file_path)
413}
414
415fn ensure_cache_dir(cache_dir: &Path) -> Result<PathBuf> {
421 let expanded = crate::config::expand_tilde(cache_dir);
422 std::fs::create_dir_all(&expanded)
423 .with_context(|| format!("Failed to create cache directory: {}", expanded.display()))?;
424 Ok(expanded)
425}
426
/// Reports whether `path` holds a plausibly complete download.
///
/// The path must be a regular file (symlinks are followed). With
/// `expected_bytes == 0` (size unknown) any non-empty file counts. Otherwise
/// the file must be at least 90% of `expected_bytes`: the manifest sizes are
/// approximate, so an exact match is not required.
fn is_cached(path: &Path, expected_bytes: u64) -> bool {
    let meta = match std::fs::metadata(path) {
        Ok(meta) => meta,
        Err(_) => return false,
    };

    // A directory (or other non-file) at the asset path is never a valid asset.
    if !meta.is_file() {
        return false;
    }

    let size = meta.len();
    if expected_bytes == 0 {
        size > 0
    } else {
        // Widen before multiplying so huge expectations cannot overflow u64.
        u128::from(size) >= u128::from(expected_bytes) * 9 / 10
    }
}
443
444fn download_file(url: &str, dest: &Path, expected_bytes: u64, verbosity: Verbosity) -> Result<()> {
448 if let Some(parent) = dest.parent() {
450 std::fs::create_dir_all(parent)?;
451 }
452
453 let dest_str = dest.to_string_lossy().to_string();
454
455 let pb = if expected_bytes > 0 {
457 progress::download_progress(expected_bytes, verbosity)
458 } else {
459 progress::spinner("Downloading...", verbosity)
460 };
461
462 let curl_result = std::process::Command::new("curl")
464 .args([
465 "--fail",
466 "--location",
467 "--output",
468 &dest_str,
469 "--silent",
470 "--show-error",
471 url,
472 ])
473 .status();
474
475 let success = match curl_result {
476 Ok(status) if status.success() => true,
477 _ => {
478 tracing::debug!("curl not available or failed, trying wget");
480 let wget_result = std::process::Command::new("wget")
481 .args(["--quiet", "--output-document", &dest_str, url])
482 .status();
483
484 matches!(wget_result, Ok(status) if status.success())
485 }
486 };
487
488 if success {
489 if expected_bytes > 0 {
491 if let Ok(meta) = std::fs::metadata(dest) {
492 pb.set_position(meta.len());
493 }
494 }
495 pb.finish_and_clear();
496 println!(" ✓ Saved to {}", dest.display());
497 Ok(())
498 } else {
499 pb.abandon();
500 let _ = std::fs::remove_file(dest);
502 anyhow::bail!(
503 "Download failed. Please install `curl` or `wget`, or download manually:\n\
504 \n\
505 \x20 URL: {url}\n\
506 \x20 Save: {dest_str}\n"
507 )
508 }
509}
510
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn asset_manifest_has_entries() {
        assert!(!ASSETS.is_empty());
        for asset in ASSETS {
            assert!(!asset.name.is_empty());
            assert!(!asset.url.is_empty());
            assert!(!asset.filename.is_empty());
        }
    }

    #[test]
    fn expected_paths_match_manifest() {
        let cache_path = std::env::temp_dir().join("oxigaf_cache_test");
        let paths = expected_asset_paths(&cache_path);
        assert_eq!(paths.len(), ASSETS.len());
        assert!(paths[0].ends_with("flame2023.tar.gz"));
    }

    #[test]
    fn parse_plain_repo_uses_default_filename() {
        let source = HfModelSource::parse("org/repo").expect("valid spec");
        assert_eq!(source.repo_id, "org/repo");
        assert_eq!(source.filename, "model.safetensors");
        assert_eq!(source.revision, None);
    }

    #[test]
    fn parse_accepts_both_revision_separators() {
        let colon = HfModelSource::parse("org/repo:v1.0").expect("colon revision");
        assert_eq!(colon.repo_id, "org/repo");
        assert_eq!(colon.revision.as_deref(), Some("v1.0"));

        let at = HfModelSource::parse("org/repo@main").expect("at revision");
        assert_eq!(at.repo_id, "org/repo");
        assert_eq!(at.revision.as_deref(), Some("main"));
    }

    #[test]
    fn parse_rejects_malformed_specs() {
        assert!(HfModelSource::parse("").is_err());
        assert!(HfModelSource::parse("no-slash").is_err());
        assert!(HfModelSource::parse("org/repo@").is_err());
        assert!(HfModelSource::parse("org/repo:").is_err());
    }

    #[test]
    fn with_filename_overrides_default() {
        let source = HfModelSource::parse("org/repo")
            .expect("valid spec")
            .with_filename("weights.bin".to_string());
        assert_eq!(source.filename, "weights.bin");
    }

    #[test]
    fn is_cached_uses_ninety_percent_threshold() {
        let dir = std::env::temp_dir().join(format!("oxigaf_is_cached_{}", std::process::id()));
        std::fs::create_dir_all(&dir).unwrap();
        let file = dir.join("asset.bin");
        std::fs::write(&file, [0u8; 100]).unwrap();

        assert!(is_cached(&file, 100));
        assert!(is_cached(&file, 111)); // 111 * 9 / 10 == 99
        assert!(!is_cached(&file, 113)); // 113 * 9 / 10 == 101
        assert!(is_cached(&file, 0));
        assert!(!is_cached(&dir.join("missing.bin"), 0));

        let _ = std::fs::remove_dir_all(&dir);
    }
}