Skip to main content

oxigaf_cli/
assets.rs

1//! Model download and cache management.
2//!
3//! Provides the `setup` command implementation: downloads required model weights
4//! (FLAME, diffusion U-Net, VAE, CLIP) into a local cache directory and
5//! verifies file integrity by size.
6
7use std::path::{Path, PathBuf};
8
9use anyhow::{Context, Result};
10
11use crate::progress;
12use crate::verbosity::Verbosity;
13
14// ---------------------------------------------------------------------------
15// Asset manifest
16// ---------------------------------------------------------------------------
17
/// One entry of the downloadable-model manifest.
// `dead_code` is allowed because `sha256` is filled in by the manifest but is
// not read anywhere in this file yet.
#[allow(dead_code)]
struct Asset {
    /// Display name printed in progress and status lines.
    name: &'static str,
    /// Location the asset is fetched from.
    url: &'static str,
    /// Name of the file written inside the cache directory.
    filename: &'static str,
    /// Expected size in bytes. Sizes the progress bar and drives the
    /// size-based cache check; `0` disables the size comparison.
    expected_bytes: u64,
    /// SHA-256 digest as a hex string. Empty for every entry at the moment;
    /// checksums for the official model bundles are planned for a later
    /// release.
    sha256: &'static str,
}
34
/// HuggingFace Hub model source specification.
///
/// Identifies a single file in a Hub repository, optionally pinned to a
/// revision. Model identifiers have one of these shapes:
/// - `"cool-japan/oxigaf-flame-2023"` (default revision)
/// - `"cool-japan/oxigaf-flame-2023:main"` (branch/tag)
/// - `"cool-japan/oxigaf-flame-2023@v1.0"` (specific revision)
///
/// The type is plain data, so it derives `Debug`, `Clone` and equality to be
/// usable in logs, tests and collections.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HfModelSource {
    /// Repository identifier (e.g., "cool-japan/oxigaf-flame")
    pub repo_id: String,
    /// Model filename within the repository
    pub filename: String,
    /// Optional revision (branch, tag, or commit SHA); `None` means the
    /// repository's default revision.
    pub revision: Option<String>,
}
49
50impl HfModelSource {
51    /// Parse a HuggingFace model specification string.
52    ///
53    /// # Examples
54    ///
55    /// ```no_run
56    /// # use oxigaf_cli::assets::HfModelSource;
57    /// let source = HfModelSource::parse("cool-japan/oxigaf-flame").unwrap();
58    /// assert_eq!(source.repo_id, "cool-japan/oxigaf-flame");
59    /// assert!(source.revision.is_none());
60    ///
61    /// let source = HfModelSource::parse("cool-japan/oxigaf-flame:main").unwrap();
62    /// assert_eq!(source.revision, Some("main".to_string()));
63    /// ```
64    pub fn parse(spec: &str) -> Result<Self> {
65        if spec.is_empty() {
66            anyhow::bail!("Model specification cannot be empty");
67        }
68
69        // Split on ':' or '@' to extract revision
70        let (repo_part, revision) = if let Some(pos) = spec.find(':') {
71            let (repo, rev) = spec.split_at(pos);
72            (repo, Some(rev[1..].to_string()))
73        } else if let Some(pos) = spec.find('@') {
74            let (repo, rev) = spec.split_at(pos);
75            (repo, Some(rev[1..].to_string()))
76        } else {
77            (spec, None)
78        };
79
80        // Validate repository format (should be "org/repo")
81        if !repo_part.contains('/') {
82            anyhow::bail!(
83                "Invalid repository format: '{}'. Expected format: 'organization/repository'",
84                repo_part
85            );
86        }
87
88        // Validate revision is not empty if specified
89        if let Some(ref rev) = revision {
90            if rev.is_empty() {
91                anyhow::bail!("Revision cannot be empty");
92            }
93        }
94
95        Ok(Self {
96            repo_id: repo_part.to_string(),
97            filename: "model.safetensors".to_string(),
98            revision,
99        })
100    }
101
102    /// Set a custom filename instead of the default "model.safetensors".
103    pub fn with_filename(mut self, filename: String) -> Self {
104        self.filename = filename;
105        self
106    }
107
108    /// Download the model from HuggingFace Hub.
109    ///
110    /// # Arguments
111    ///
112    /// * `token` - Optional HuggingFace authentication token for private models
113    ///
114    /// # Returns
115    ///
116    /// The path to the downloaded model file in the HuggingFace cache directory.
117    ///
118    /// # Errors
119    ///
120    /// Returns an error if:
121    /// - The API client cannot be initialized
122    /// - The repository or file is not found
123    /// - Network errors occur during download
124    /// - Authentication fails for private models
125    #[allow(dead_code)]
126    pub fn download(&self, token: Option<&str>) -> Result<PathBuf> {
127        use hf_hub::api::sync::ApiBuilder;
128
129        println!("📥 Downloading from HuggingFace Hub: {}", self.repo_id);
130        if let Some(ref rev) = self.revision {
131            println!("   Revision: {}", rev);
132        }
133        println!("   File: {}", self.filename);
134
135        // Build API client with optional token
136        let mut api_builder = ApiBuilder::new();
137
138        if let Some(token_str) = token {
139            api_builder = api_builder.with_token(Some(token_str.to_string()));
140        }
141
142        let api = api_builder
143            .build()
144            .context("Failed to initialize HuggingFace Hub API client")?;
145
146        // Create Repo with revision
147        use hf_hub::{Repo, RepoType};
148        let repo_obj = if let Some(ref rev) = self.revision {
149            Repo::with_revision(self.repo_id.clone(), RepoType::Model, rev.clone())
150        } else {
151            Repo::new(self.repo_id.clone(), RepoType::Model)
152        };
153
154        // Get the repository handle from the API
155        let repo = api.repo(repo_obj);
156
157        // Download the file (hf-hub handles caching and resumable downloads)
158        let file_path = repo.get(&self.filename).with_context(|| {
159            format!(
160                "Failed to download '{}' from repository '{}'{}",
161                self.filename,
162                self.repo_id,
163                self.revision
164                    .as_ref()
165                    .map(|r| format!(" (revision: {})", r))
166                    .unwrap_or_default()
167            )
168        })?;
169
170        println!("✓ Downloaded to: {}", file_path.display());
171
172        Ok(file_path)
173    }
174}
175
176/// Placeholder asset manifest.
177///
178/// The URLs below point at the project's GitHub releases page. Replace them
179/// with real artifact URLs once the model weights are published.
180static ASSETS: &[Asset] = &[
181    Asset {
182        name: "FLAME 2023 Head Model",
183        url: "https://github.com/cool-japan/oxigaf/releases/download/v0.1.0/flame2023.tar.gz",
184        filename: "flame2023.tar.gz",
185        expected_bytes: 250_000_000,
186        sha256: "",
187    },
188    Asset {
189        name: "Multi-View Diffusion U-Net",
190        url: "https://github.com/cool-japan/oxigaf/releases/download/v0.1.0/diffusion_unet.safetensors",
191        filename: "diffusion_unet.safetensors",
192        expected_bytes: 1_700_000_000,
193        sha256: "",
194    },
195    Asset {
196        name: "VAE Decoder",
197        url: "https://github.com/cool-japan/oxigaf/releases/download/v0.1.0/vae_decoder.safetensors",
198        filename: "vae_decoder.safetensors",
199        expected_bytes: 200_000_000,
200        sha256: "",
201    },
202    Asset {
203        name: "CLIP Image Encoder",
204        url: "https://github.com/cool-japan/oxigaf/releases/download/v0.1.0/clip_image_encoder.safetensors",
205        filename: "clip_image_encoder.safetensors",
206        expected_bytes: 600_000_000,
207        sha256: "",
208    },
209];
210
211// ---------------------------------------------------------------------------
212// Public API
213// ---------------------------------------------------------------------------
214
215/// Ensure all required model assets are present in `cache_dir`.
216///
217/// For each asset:
218/// 1. If the file already exists (and has a reasonable size), skip it.
219/// 2. Otherwise, attempt to download it via `curl` (or `wget` as a fallback).
220///
221/// This function prints user-facing progress directly to stdout.
222pub fn setup_cache(cache_dir: &Path, verbosity: Verbosity, json_mode: bool) -> Result<()> {
223    let cache_dir = ensure_cache_dir(cache_dir)?;
224
225    if !json_mode {
226        println!();
227        println!("📦  OxiGAF Model Setup");
228        println!("    Cache directory: {}", cache_dir.display());
229        println!();
230    }
231
232    let mut downloaded = 0usize;
233    let mut skipped = 0usize;
234    let mut downloaded_assets = Vec::new();
235
236    for asset in ASSETS {
237        let dest = cache_dir.join(asset.filename);
238
239        if is_cached(&dest, asset.expected_bytes) {
240            if !json_mode {
241                println!("  ✓  {} (cached)", asset.name);
242            }
243            skipped += 1;
244            continue;
245        }
246
247        if !json_mode {
248            println!("  ⬇  Downloading {} …", asset.name);
249        }
250        download_file(asset.url, &dest, asset.expected_bytes, verbosity)
251            .with_context(|| format!("Failed to download {}", asset.name))?;
252        downloaded += 1;
253        downloaded_assets.push(dest.clone());
254    }
255
256    // Output based on mode
257    if json_mode {
258        let mut output = crate::json_output::JsonOutput::success(
259            "setup",
260            serde_json::json!({
261                "cache_dir": cache_dir.display().to_string(),
262                "downloaded": downloaded,
263                "skipped": skipped,
264                "total_assets": ASSETS.len()
265            }),
266        );
267
268        // Add downloaded files as artifacts
269        for path in downloaded_assets {
270            if path.exists() {
271                output.add_artifact("model".to_string(), path);
272            }
273        }
274
275        output.print();
276    } else {
277        println!();
278        if downloaded > 0 {
279            println!(
280                "✅  Setup complete — downloaded {downloaded} asset(s), {skipped} already cached."
281            );
282        } else {
283            println!("✅  All assets already cached. Nothing to download.");
284        }
285    }
286
287    Ok(())
288}
289
290/// Return the list of expected asset file paths within the cache directory.
291///
292/// Useful for tooling that needs to verify the cache without downloading.
293#[allow(dead_code)]
294pub fn expected_asset_paths(cache_dir: &Path) -> Vec<PathBuf> {
295    ASSETS.iter().map(|a| cache_dir.join(a.filename)).collect()
296}
297
298/// Get HuggingFace authentication token from environment or config file.
299///
300/// Checks sources in the following order:
301/// 1. HF_TOKEN environment variable
302/// 2. ~/.huggingface/token file
303///
304/// # Returns
305///
306/// The authentication token if found, `None` otherwise.
307pub fn get_hf_token() -> Option<String> {
308    // Check environment variable first
309    if let Ok(token) = std::env::var("HF_TOKEN") {
310        if !token.trim().is_empty() {
311            return Some(token.trim().to_string());
312        }
313    }
314
315    // Fall back to reading from ~/.huggingface/token
316    if let Some(home) = dirs::home_dir() {
317        let token_path = home.join(".huggingface").join("token");
318        if let Ok(token) = std::fs::read_to_string(token_path) {
319            let token = token.trim();
320            if !token.is_empty() {
321                return Some(token.to_string());
322            }
323        }
324    }
325
326    None
327}
328
329/// Download a model from HuggingFace Hub with progress tracking.
330///
331/// # Arguments
332///
333/// * `repo_id` - HuggingFace repository identifier (e.g., "cool-japan/oxigaf-flame")
334/// * `filename` - Filename within the repository
335/// * `revision` - Optional revision (branch, tag, or commit SHA)
336/// * `token` - Optional authentication token
337/// * `verbosity` - Controls progress display
338///
339/// # Returns
340///
341/// The path to the downloaded model file in the HuggingFace cache directory.
342///
343/// # Errors
344///
345/// Returns an error if:
346/// - The API client cannot be initialized
347/// - The repository or file is not found
348/// - Network errors occur during download
349/// - Authentication fails for private models
350pub fn download_with_progress(
351    repo_id: &str,
352    filename: &str,
353    revision: Option<&str>,
354    token: Option<&str>,
355    verbosity: Verbosity,
356) -> Result<PathBuf> {
357    use hf_hub::api::sync::ApiBuilder;
358
359    if verbosity != Verbosity::Quiet {
360        println!("📥 Downloading from HuggingFace Hub: {}", repo_id);
361        if let Some(rev) = revision {
362            println!("   Revision: {}", rev);
363        }
364        println!("   File: {}", filename);
365        println!();
366    }
367
368    // Build API client with optional token and progress tracking
369    let mut api_builder = ApiBuilder::new();
370
371    if let Some(token_str) = token {
372        api_builder = api_builder.with_token(Some(token_str.to_string()));
373    }
374
375    // Enable progress based on verbosity
376    if verbosity.show_progress() {
377        api_builder = api_builder.with_progress(true);
378    }
379
380    let api = api_builder
381        .build()
382        .context("Failed to initialize HuggingFace Hub API client")?;
383
384    // Create Repo with revision
385    use hf_hub::{Repo, RepoType};
386    let repo_obj = if let Some(rev) = revision {
387        Repo::with_revision(repo_id.to_string(), RepoType::Model, rev.to_string())
388    } else {
389        Repo::new(repo_id.to_string(), RepoType::Model)
390    };
391
392    // Get the repository handle from the API
393    let repo = api.repo(repo_obj);
394
395    // Download the file (hf-hub handles caching and resumable downloads)
396    let file_path = repo.get(filename).with_context(|| {
397        format!(
398            "Failed to download '{}' from repository '{}'{}",
399            filename,
400            repo_id,
401            revision
402                .map(|r| format!(" (revision: {})", r))
403                .unwrap_or_default()
404        )
405    })?;
406
407    if verbosity != Verbosity::Quiet {
408        println!();
409        println!("✓ Downloaded to: {}", file_path.display());
410    }
411
412    Ok(file_path)
413}
414
415// ---------------------------------------------------------------------------
416// Internal helpers
417// ---------------------------------------------------------------------------
418
419/// Create the cache directory (expanding `~` if needed).
420fn ensure_cache_dir(cache_dir: &Path) -> Result<PathBuf> {
421    let expanded = crate::config::expand_tilde(cache_dir);
422    std::fs::create_dir_all(&expanded)
423        .with_context(|| format!("Failed to create cache directory: {}", expanded.display()))?;
424    Ok(expanded)
425}
426
/// Check whether `path` is a regular file that looks like a complete download.
///
/// A file counts as cached when it is non-empty and at least 90% of
/// `expected_bytes` (rounded down). The tolerance exists because compressed
/// archives may vary slightly in size. With `expected_bytes == 0` only the
/// non-empty check applies.
///
/// Returns `false` when the path is missing, cannot be stat'ed, is not a
/// regular file (e.g. a directory), or is empty.
fn is_cached(path: &Path, expected_bytes: u64) -> bool {
    let meta = match std::fs::metadata(path) {
        Ok(meta) if meta.is_file() => meta,
        _ => return false,
    };

    let size = meta.len();
    if size == 0 {
        // An empty file is never a finished download, even when the 90%
        // threshold rounds down to zero for a tiny expected size.
        return false;
    }

    // Widen to u128 so that `expected_bytes * 9` cannot overflow for very
    // large expected sizes.
    u128::from(size) >= u128::from(expected_bytes) * 9 / 10
}
443
444/// Download a file from `url` to `dest` using `curl` or `wget`.
445///
446/// A progress bar is shown via `indicatif` while the download runs.
447fn download_file(url: &str, dest: &Path, expected_bytes: u64, verbosity: Verbosity) -> Result<()> {
448    // Ensure parent directory exists.
449    if let Some(parent) = dest.parent() {
450        std::fs::create_dir_all(parent)?;
451    }
452
453    let dest_str = dest.to_string_lossy().to_string();
454
455    // Progress bar (indeterminate or sized).
456    let pb = if expected_bytes > 0 {
457        progress::download_progress(expected_bytes, verbosity)
458    } else {
459        progress::spinner("Downloading...", verbosity)
460    };
461
462    // Try curl first (most common on Linux / macOS).
463    let curl_result = std::process::Command::new("curl")
464        .args([
465            "--fail",
466            "--location",
467            "--output",
468            &dest_str,
469            "--silent",
470            "--show-error",
471            url,
472        ])
473        .status();
474
475    let success = match curl_result {
476        Ok(status) if status.success() => true,
477        _ => {
478            // Fall back to wget.
479            tracing::debug!("curl not available or failed, trying wget");
480            let wget_result = std::process::Command::new("wget")
481                .args(["--quiet", "--output-document", &dest_str, url])
482                .status();
483
484            matches!(wget_result, Ok(status) if status.success())
485        }
486    };
487
488    if success {
489        // Update progress bar to completion.
490        if expected_bytes > 0 {
491            if let Ok(meta) = std::fs::metadata(dest) {
492                pb.set_position(meta.len());
493            }
494        }
495        pb.finish_and_clear();
496        println!("     ✓  Saved to {}", dest.display());
497        Ok(())
498    } else {
499        pb.abandon();
500        // Clean up partial download.
501        let _ = std::fs::remove_file(dest);
502        anyhow::bail!(
503            "Download failed. Please install `curl` or `wget`, or download manually:\n\
504             \n\
505             \x20  URL:  {url}\n\
506             \x20  Save: {dest_str}\n"
507        )
508    }
509}
510
#[cfg(test)]
mod tests {
    use super::*;

    /// The manifest is non-empty and every entry has a name, URL and filename.
    #[test]
    fn asset_manifest_has_entries() {
        assert!(!ASSETS.is_empty());
        assert!(ASSETS.iter().all(|asset| !asset.name.is_empty()));
        assert!(ASSETS.iter().all(|asset| !asset.url.is_empty()));
        assert!(ASSETS.iter().all(|asset| !asset.filename.is_empty()));
    }

    /// `expected_asset_paths` yields one path per manifest entry, starting
    /// with the FLAME archive.
    #[test]
    fn expected_paths_match_manifest() {
        let cache_path = std::env::temp_dir().join("oxigaf_cache_test");
        let paths = expected_asset_paths(&cache_path);
        assert_eq!(paths.len(), ASSETS.len());
        let first_is_flame = paths
            .first()
            .map_or(false, |path| path.ends_with("flame2023.tar.gz"));
        assert!(first_is_flame);
    }
}