//! muthr 0.1.7
//!
//! Zero-trust orchestration for autonomous AI agents
use indicatif::{ProgressBar, ProgressStyle};
use std::path::PathBuf;
use tokio::fs;
use tokio::io::AsyncWriteExt;

pub async fn download(source: &str, file: Option<&str>) -> Result<(), color_eyre::Report> {
    let (repo, filename) = match (source, file) {
        (url, None) if url.starts_with("http") && url.contains("huggingface.co") => {
            parse_hf_url(url)?
        }
        (repo, Some(file)) => (repo.to_string(), file.to_string()),
        _ => {
            eprintln!("[ERR ] Usage: muthr download <hf-repo> <filename> | <hf-url>");
            eprintln!("  Example: muthr download unsloth/Qwen3.6-35B-A3B-GGUF Qwen3.6-35B-A3B-UD-Q4_K_M.gguf");
            return Ok(());
        }
    };

    if !filename.ends_with(".gguf") {
        eprintln!("[ERR ] Expected a .gguf file, got '{}'", filename);
        return Ok(());
    }

    let home = std::env::var("HOME")?;
    let model_dir =
        std::env::var("LLAMA_MODEL_DIR").unwrap_or_else(|_| format!("{}/opt/models", home));

    let model_subdir = PathBuf::from(&model_dir).join(&repo);
    let target_path = model_subdir.join(&filename);
    let url = format!("https://huggingface.co/{}/resolve/main/{}", repo, filename);

    fs::create_dir_all(&model_subdir).await?;

    if target_path.exists() {
        eprintln!(
            "[WARN] File exists at {:?}. Aborting to prevent overwrite.",
            target_path
        );
        return Ok(());
    }

    println!("[PROC] Fetching: {}", filename);
    println!("       From:    https://huggingface.co/{}", repo);

    let mut headers = reqwest::header::HeaderMap::new();
    if let Ok(token) = std::env::var("HF_TOKEN") {
        let auth_val = reqwest::header::HeaderValue::from_str(&format!("Bearer {}", token))?;
        headers.insert(reqwest::header::AUTHORIZATION, auth_val);
    }

    let client = reqwest::Client::builder()
        .default_headers(headers)
        .build()?;

    let mut response = client.get(&url).send().await?;

    if !response.status().is_success() {
        return Err(color_eyre::eyre::eyre!(
            "Download failed: {}",
            response.status()
        ));
    }

    let total_size = response.content_length().unwrap_or(0);
    let pb = ProgressBar::new(total_size);
    pb.set_style(
        ProgressStyle::default_bar()
            .template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({eta})")?
            .progress_chars("#>-"),
    );

    let tmp_file = format!("{}.tmp", target_path.display());
    let mut file = fs::File::create(&tmp_file).await?;

    while let Some(chunk) = response.chunk().await? {
        file.write_all(&chunk).await?;
        pb.inc(chunk.len() as u64);
    }

    pb.finish_with_message("Downloaded");
    fs::rename(&tmp_file, &target_path).await?;

    println!("[ OK ] Download complete.");
    if let Ok(metadata) = fs::metadata(&target_path).await {
        println!("       Size: {}", human_size(metadata.len()));
    }

    Ok(())
}

/// Splits a Hugging Face file URL into `(repo, filename)`.
///
/// Accepts `https://`, `http://` or scheme-less URLs of the form
/// `huggingface.co/<repo>/{blob,raw,resolve}/main/<filename>`. Any `?query` or
/// `#fragment` is ignored, such as the `?download=true` appended by the Hub's download
/// button. `<filename>` may contain sub-directories.
///
/// # Errors
///
/// Returns "Invalid HuggingFace URL format" if the `huggingface.co/` prefix or all of the
/// `/blob/main/`, `/raw/main/`, `/resolve/main/` markers are missing, or if the repo or
/// the file name is empty.
fn parse_hf_url(url: &str) -> Result<(String, String), color_eyre::Report> {
    // Markers separating `<repo>` from `<filename>`, tried in this order.
    const MARKERS: [&str; 3] = ["/blob/main/", "/raw/main/", "/resolve/main/"];
    let invalid = || color_eyre::eyre::eyre!("Invalid HuggingFace URL format");

    let rest = url
        .strip_prefix("https://")
        .or_else(|| url.strip_prefix("http://"))
        .unwrap_or(url);
    // Drop any query string / fragment so it doesn't end up in the file name.
    let rest = match rest.find(|c: char| c == '?' || c == '#') {
        Some(end) => &rest[..end],
        None => rest,
    };
    let rest = rest.strip_prefix("huggingface.co/").ok_or_else(invalid)?;

    // Slice past the marker that actually matched. The markers differ in length, so a
    // fixed offset would cut into (or run past the end of) the file name.
    let (repo, filename) = MARKERS
        .iter()
        .find_map(|&marker| {
            rest.find(marker)
                .map(|idx| (&rest[..idx], &rest[idx + marker.len()..]))
        })
        .ok_or_else(invalid)?;

    if repo.is_empty() || filename.is_empty() {
        return Err(invalid());
    }

    Ok((repo.to_string(), filename.to_string()))
}

/// Renders a byte count as a human-readable size using 1024-based steps.
///
/// The value is scaled down by 1024 while it is at least 1024 and a larger unit is
/// still available (the largest is `TB`). It is then printed with one decimal place,
/// e.g. `1536` -> `"1.5 KB"`.
fn human_size(bytes: u64) -> String {
    const UNITS: [&str; 5] = ["B", "KB", "MB", "GB", "TB"];
    const STEP: f64 = 1024.0;

    let mut value = bytes as f64;
    let mut unit = UNITS[0];
    for &bigger in &UNITS[1..] {
        if value < STEP {
            break;
        }
        value /= STEP;
        unit = bigger;
    }

    format!("{:.1} {}", value, unit)
}