ruvllm-cli 2.2.0

CLI for RuvLLM model management and inference on Apple Silicon
//! Model download command implementation
//!
//! Downloads models from HuggingFace Hub with progress indication,
//! supporting various quantization formats optimized for Apple Silicon.

use anyhow::{Context, Result};
use bytesize::ByteSize;
use colored::Colorize;
use console::style;
use hf_hub::api::tokio::Api;
use hf_hub::{Repo, RepoType};
use indicatif::{ProgressBar, ProgressStyle};
use std::path::{Path, PathBuf};

use crate::models::{get_model, resolve_model_id, QuantPreset};

/// Run the download command
pub async fn run(
    model: &str,
    quantization: &str,
    force: bool,
    revision: Option<&str>,
    cache_dir: &str,
) -> Result<()> {
    let model_id = resolve_model_id(model);
    let quant = QuantPreset::from_str(quantization)
        .ok_or_else(|| anyhow::anyhow!("Invalid quantization format: {}", quantization))?;

    println!();
    println!(
        "{} {} ({})",
        style("Downloading:").bold().cyan(),
        model_id,
        quant
    );
    println!();

    // Get model info if available
    if let Some(model_def) = get_model(model) {
        println!("  {} {}", "Name:".dimmed(), model_def.name);
        println!("  {} {}", "Architecture:".dimmed(), model_def.architecture);
        println!("  {} {}B", "Parameters:".dimmed(), model_def.params_b);
        println!(
            "  {} ~{:.1} GB",
            "Est. Memory:".dimmed(),
            quant.estimate_memory_gb(model_def.params_b)
        );
        println!();
    }

    // Initialize HuggingFace API
    let api = Api::new().context("Failed to initialize HuggingFace API")?;

    // Create repo reference
    let repo = if let Some(rev) = revision {
        api.repo(Repo::with_revision(
            model_id.clone(),
            RepoType::Model,
            rev.to_string(),
        ))
    } else {
        api.repo(Repo::new(model_id.clone(), RepoType::Model))
    };

    // Determine files to download
    let files_to_download = get_files_to_download(&model_id, quant);

    // Create cache directory
    let model_cache_dir = PathBuf::from(cache_dir).join("models").join(&model_id);
    tokio::fs::create_dir_all(&model_cache_dir)
        .await
        .context("Failed to create cache directory")?;

    // Download each file
    for file_name in &files_to_download {
        let target_path = model_cache_dir.join(file_name);

        // Check if file exists
        if target_path.exists() && !force {
            let size = tokio::fs::metadata(&target_path).await?.len();
            println!(
                "  {} {} ({})",
                style("Cached:").green(),
                file_name,
                ByteSize(size)
            );
            continue;
        }

        println!("  {} {}", style("Downloading:").yellow(), file_name);

        // Download with progress
        let downloaded_path = download_with_progress(&repo, file_name).await?;

        // Copy to cache directory
        tokio::fs::copy(&downloaded_path, &target_path)
            .await
            .context("Failed to copy file to cache")?;

        let size = tokio::fs::metadata(&target_path).await?.len();
        println!(
            "  {} {} ({})",
            style("Downloaded:").green(),
            file_name,
            ByteSize(size)
        );
    }

    println!();
    println!(
        "{} Model ready at: {}",
        style("Success!").green().bold(),
        model_cache_dir.display()
    );
    println!();

    // Print usage hint
    println!("{}", "Quick start:".bold());
    println!("  ruvllm chat {}", model);
    println!("  ruvllm serve {}", model);
    println!();

    Ok(())
}

/// Fetch `file_name` from `repo` and return the local path hf-hub stored it at.
///
/// A progress bar is configured while the request runs and cleared afterwards.
///
/// NOTE(review): the bar is never advanced (there is no `inc`/`set_position`
/// call), so it cannot report byte progress itself. Any visible download progress
/// would have to come from hf-hub's own reporting. Confirm this and consider
/// dropping the bar.
///
/// # Errors
///
/// Returns an error naming `file_name` if the hf-hub request fails.
async fn download_with_progress(
    repo: &hf_hub::api::tokio::ApiRepo,
    file_name: &str,
) -> Result<PathBuf> {
    let bar_style = ProgressStyle::default_bar()
        .template("    [{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})")
        .unwrap()
        .progress_chars("#>-");
    let progress = ProgressBar::new(100);
    progress.set_style(bar_style);

    let fetched = repo
        .get(file_name)
        .await
        .with_context(|| format!("Failed to download {}", file_name))?;

    progress.finish_and_clear();
    Ok(fetched)
}

/// Build the ordered list of repository files to fetch for `model_id` at `quant`.
///
/// The list is always: the tokenizer/config JSON files, then one weight entry,
/// then `special_tokens_map.json` and `generation_config.json`.
/// - The weight entry is a wildcard of the form `*<gguf suffix>` when the repo
///   id contains `GGUF` or a quantization other than `QuantPreset::None` is
///   requested.
/// - Otherwise the weight entry is the literal `model.safetensors`.
fn get_files_to_download(model_id: &str, quant: QuantPreset) -> Vec<String> {
    let wants_gguf = model_id.contains("GGUF") || quant != QuantPreset::None;

    let weights = if wants_gguf {
        format!("*{}", quant.gguf_suffix())
    } else {
        String::from("model.safetensors")
    };

    const LEADING: [&str; 3] = ["tokenizer.json", "tokenizer_config.json", "config.json"];
    const TRAILING: [&str; 2] = ["special_tokens_map.json", "generation_config.json"];

    LEADING
        .iter()
        .map(|name| name.to_string())
        .chain(std::iter::once(weights))
        .chain(TRAILING.iter().map(|name| name.to_string()))
        .collect()
}

/// Check if a model is already downloaded
pub async fn is_model_downloaded(model: &str, cache_dir: &str) -> bool {
    let model_id = resolve_model_id(model);
    let model_cache_dir = PathBuf::from(cache_dir).join("models").join(&model_id);

    // Check for tokenizer and at least one model file
    let tokenizer_exists = model_cache_dir.join("tokenizer.json").exists();
    let has_weights = tokio::fs::read_dir(&model_cache_dir)
        .await
        .ok()
        .map(|mut dir| {
            use futures::StreamExt;
            // Simplified check - just see if directory exists and has files
            true
        })
        .unwrap_or(false);

    tokenizer_exists && has_weights
}

/// Return the cache location for `model`: `<cache_dir>/models/<resolved model id>`.
///
/// The path is computed only; nothing on disk is checked or created.
pub fn get_model_path(model: &str, cache_dir: &str) -> PathBuf {
    let mut location = PathBuf::from(cache_dir);
    location.push("models");
    location.push(resolve_model_id(model));
    location
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Quantized requests select a GGUF weight entry alongside the metadata files.
    #[test]
    fn test_files_to_download() {
        let files = get_files_to_download("test/model", QuantPreset::Q4K);
        assert!(files.contains(&"tokenizer.json".to_string()));
        assert!(files.iter().any(|f| f.contains("Q4_K_M")));
    }

    /// Unquantized requests for a non-GGUF repo select plain SafeTensors weights.
    #[test]
    fn test_files_to_download_unquantized_uses_safetensors() {
        let files = get_files_to_download("test/model", QuantPreset::None);
        assert!(files.contains(&"model.safetensors".to_string()));
        assert!(files.contains(&"config.json".to_string()));
        assert!(!files.iter().any(|f| f.contains('*')));
    }

    /// A repo id containing `GGUF` selects GGUF weights even without quantization.
    #[test]
    fn test_files_to_download_gguf_repo_skips_safetensors() {
        let files = get_files_to_download("org/Model-GGUF", QuantPreset::None);
        assert!(!files.contains(&"model.safetensors".to_string()));
        assert!(files.iter().any(|f| f.starts_with('*')));
        assert!(files.contains(&"config.json".to_string()));
    }
}