rustvani 0.1.0

Voice AI framework for Rust — real-time speech pipelines with STT, LLM, TTS, and Dhara conversation flows
use std::env;
use std::fs;
use std::path::PathBuf;
use std::process::Command;

/// URL for the smart-turn weights file (auto-downloaded at build time).
///
/// Used only as a fallback: the build script downloads from here when the
/// cache has no `smart_turn_weights.bin.gz` and the crate sources contain no
/// `src/turn/smart_turn_weights.bin.gz` to copy from.
// NOTE(review): the object key ends in `+(1).bin.gz`, which looks like a
// renamed upload — confirm this is the intended object.
const SMART_TURN_URL: &str =
    "https://smartturn-rustvani.s3.ap-south-1.amazonaws.com/smart_turn_weights+(1).bin.gz";

/// URL for the official Silero VAD ONNX model.
///
/// Downloaded into the cache as `silero.onnx` when neither a cached copy nor
/// a local `src/vad/data/silero.onnx` exists.
// NOTE(review): the URL tracks the `master` branch rather than a fixed
// revision, so the downloaded file may change over time — consider pinning.
const SILERO_ONNX_URL: &str =
    "https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx";

/// URL for the custom-extracted Silero native weights (`silero_vad_16k.bin`).
///
/// This is a project-specific flat binary. When the constant is `Some`, the
/// build script downloads the file if it is neither cached nor present at
/// `src/vad/data/silero_vad_16k.bin`. Set it to `None` to disable the
/// download; the build script then prints the cache path where the file has
/// to be placed manually.
const SILERO_NATIVE_URL: Option<&str> =
    Some("https://smartturn-rustvani.s3.ap-south-1.amazonaws.com/silero_vad_16k.bin");

/// Build-script entry point.
///
/// Resolves the model cache directory, exports it to the crate as the
/// compile-time environment variable `RUSTVANI_CACHE_DIR`, and then makes sure
/// each model file is present in the cache. For every file the lookup order is:
///
/// 1. an existing (non-empty) cache entry,
/// 2. a local copy inside the crate sources, which is copied into the cache,
/// 3. a download via `download_or_warn`, which only emits `cargo:warning`
///    lines on failure and never aborts the build.
///
/// # Panics
///
/// Panics (failing the build) if the cache directory cannot be created, if
/// `CARGO_MANIFEST_DIR` is not set, or if copying a local model file into the
/// cache fails.
fn main() {
    let cache = cache_dir();
    fs::create_dir_all(&cache).unwrap_or_else(|e| {
        panic!("failed to create cache directory {}: {}", cache.display(), e)
    });

    // Make the cache path available to the library at compile time.
    println!("cargo:rustc-env=RUSTVANI_CACHE_DIR={}", cache.display());
    println!("cargo:rerun-if-changed=build.rs");
    println!("cargo:rerun-if-env-changed=RUSTVANI_CACHE_DIR");

    // `var_os` keeps working when the checkout path is not valid UTF-8.
    let manifest = PathBuf::from(
        env::var_os("CARGO_MANIFEST_DIR")
            .expect("CARGO_MANIFEST_DIR must be set by Cargo when running a build script"),
    );

    // ── Smart-turn weights ───────────────────────────────────────────────
    let weights_cache = cache.join("smart_turn_weights.bin.gz");
    if !is_cached(&weights_cache)
        && !copy_local_if_present(
            &manifest.join("src/turn/smart_turn_weights.bin.gz"),
            &weights_cache,
            "smart-turn weights",
        )
    {
        download_or_warn(&weights_cache, SMART_TURN_URL, "smart-turn weights");
    }

    // ── Silero ONNX model ────────────────────────────────────────────────
    let silero_onnx_cache = cache.join("silero.onnx");
    if !is_cached(&silero_onnx_cache)
        && !copy_local_if_present(
            &manifest.join("src/vad/data/silero.onnx"),
            &silero_onnx_cache,
            "silero.onnx",
        )
    {
        download_or_warn(&silero_onnx_cache, SILERO_ONNX_URL, "silero.onnx");
    }

    // ── Silero native weights (custom flat binary) ───────────────────────
    let silero_bin_cache = cache.join("silero_vad_16k.bin");
    if !is_cached(&silero_bin_cache)
        && !copy_local_if_present(
            &manifest.join("src/vad/data/silero_vad_16k.bin"),
            &silero_bin_cache,
            "silero_vad_16k.bin",
        )
    {
        if let Some(url) = SILERO_NATIVE_URL {
            download_or_warn(&silero_bin_cache, url, "silero_vad_16k.bin");
        } else {
            println!("cargo:warning=silero_vad_16k.bin not found.");
            println!("cargo:warning=Place it manually at: {}", silero_bin_cache.display());
            println!("cargo:warning=Or set SILERO_NATIVE_URL in build.rs to auto-download.");
        }
    }
}

/// Returns `true` when `path` already holds a usable cache entry.
///
/// A zero-byte regular file counts as missing: some download tools create the
/// output file before transferring anything, so a failed download can leave
/// an empty file behind that must not be mistaken for a model.
fn is_cached(path: &std::path::Path) -> bool {
    match fs::metadata(path) {
        Ok(meta) => !(meta.is_file() && meta.len() == 0),
        Err(_) => false,
    }
}

/// Copies `src` into the cache at `cached` if `src` exists.
///
/// Returns `true` when a copy was made and `false` when there is no local
/// file, in which case the caller falls back to downloading.
///
/// # Panics
///
/// Panics if the local file exists but cannot be copied.
fn copy_local_if_present(src: &std::path::Path, cached: &std::path::Path, desc: &str) -> bool {
    if !src.exists() {
        return false;
    }
    println!("cargo:warning=Copying local {} to cache...", desc);
    fs::copy(src, cached).unwrap_or_else(|e| {
        panic!(
            "failed to copy {} to {}: {}",
            src.display(),
            cached.display(),
            e
        )
    });
    true
}

/// Return the cache directory for rustvani model files.
///
/// Respects the `RUSTVANI_CACHE_DIR` env var; otherwise falls back to
/// `<home>/.rustvani/cache`, where `<home>` is taken from `HOME` or, failing
/// that, `USERPROFILE`.
///
/// Variables that are set but empty are treated as unset, so an accidental
/// `RUSTVANI_CACHE_DIR=` does not turn the cache into the current directory.
/// Values are read with `env::var_os`, so paths that are not valid UTF-8 are
/// accepted instead of being silently ignored.
///
/// # Panics
///
/// Panics if none of `RUSTVANI_CACHE_DIR`, `HOME` or `USERPROFILE` is set to a
/// non-empty value.
fn cache_dir() -> PathBuf {
    let non_empty = |name: &str| env::var_os(name).filter(|value| !value.is_empty());

    if let Some(explicit) = non_empty("RUSTVANI_CACHE_DIR") {
        return PathBuf::from(explicit);
    }

    let home = non_empty("HOME")
        .or_else(|| non_empty("USERPROFILE"))
        .expect(
            "Cannot determine home directory. \
             Set RUSTVANI_CACHE_DIR env var.",
        );
    PathBuf::from(home).join(".rustvani").join("cache")
}

/// Tries to download `desc` from `url` into `path` without ever failing the
/// build.
///
/// Progress and failures are reported as `cargo:warning` lines; on failure the
/// URL and the expected destination are printed so the file can be fetched
/// manually.
fn download_or_warn(path: &PathBuf, url: &str, desc: &str) {
    println!("cargo:warning=Downloading {}...", desc);

    let failure = match download(url, path) {
        Ok(()) => return,
        Err(err) => err,
    };

    let expected = path.display();
    println!("cargo:warning=Failed to download {}: {}", desc, failure);
    println!("cargo:warning=URL: {}", url);
    println!("cargo:warning=Expected path: {}", expected);
}

/// Downloads `url` to `dest` by shelling out to the first tool that works.
///
/// Tools are tried in this order: `curl`, `wget`, Windows PowerShell
/// (`powershell`) and PowerShell Core (`pwsh`). Every tool writes to a
/// temporary sibling of `dest`; the file is renamed onto `dest` only after the
/// tool exited successfully and produced a non-empty file. A failed or partial
/// transfer therefore never leaves a file at `dest` that a later build could
/// mistake for a valid download.
///
/// The parent directory of `dest` is created if necessary.
///
/// # Errors
///
/// Returns an error if `dest` has no parent directory or file name, if the
/// parent directory cannot be created, if the finished download cannot be
/// moved into place, or if none of the tools produced a non-empty file.
fn download(url: &str, dest: &PathBuf) -> Result<(), Box<dyn std::error::Error>> {
    let parent = dest
        .parent()
        .ok_or_else(|| format!("destination {} has no parent directory", dest.display()))?;
    let file_name = dest
        .file_name()
        .ok_or_else(|| format!("destination {} has no file name", dest.display()))?;
    fs::create_dir_all(parent)?;

    // Per-process temporary name so concurrent build scripts do not clobber
    // each other's partial downloads.
    let mut tmp_name = file_name.to_os_string();
    tmp_name.push(format!(".{}.part", std::process::id()));
    let tmp = dest.with_file_name(tmp_name);

    // 1. curl (most common). Paths are passed as OS strings, so non-UTF-8
    //    destinations work.
    let mut curl = Command::new("curl");
    curl.args(["-fsSL", "--retry", "2", "-o"]).arg(&tmp).arg(url);

    // 2. wget
    let mut wget = Command::new("wget");
    wget.args(["-q", "-O"]).arg(&tmp).arg(url);

    let mut attempts = vec![curl, wget];

    // 3./4. PowerShell fallbacks. The path is embedded in a script string, so
    // these are only attempted when it is valid UTF-8. Single quotes are
    // doubled, which is how PowerShell escapes them inside '...' literals.
    if let Some(tmp_str) = tmp.to_str() {
        let ps_cmd = format!(
            "Invoke-WebRequest -Uri '{}' -OutFile '{}' -UseBasicParsing",
            url.replace('\'', "''"),
            tmp_str.replace('\'', "''")
        );

        let mut powershell = Command::new("powershell");
        powershell.args(["-ExecutionPolicy", "Bypass", "-Command", ps_cmd.as_str()]);
        attempts.push(powershell);

        let mut pwsh = Command::new("pwsh");
        pwsh.args(["-Command", ps_cmd.as_str()]);
        attempts.push(pwsh);
    }

    for mut tool in attempts {
        // Start every attempt from a clean slate; a previous tool may have
        // left a partial file behind.
        let _ = fs::remove_file(&tmp);

        // A tool that is not installed fails to spawn; treat that like any
        // other failed attempt and move on to the next one.
        let exited_ok = tool.status().map(|s| s.success()).unwrap_or(false);
        let non_empty = fs::metadata(&tmp).map(|m| m.len() > 0).unwrap_or(false);

        if exited_ok && non_empty {
            if let Err(e) = fs::rename(&tmp, dest) {
                let _ = fs::remove_file(&tmp);
                return Err(e.into());
            }
            return Ok(());
        }
    }

    let _ = fs::remove_file(&tmp);
    Err("No working download tool found (tried curl, wget, PowerShell, pwsh)".into())
}