zer-compute 1.0.0

Hardware-accelerated backend for zer: pairwise comparison and EM via CUDA, Vulkan, or AVX2.
Documentation
//! Build script for zer-compute.
//!
//! When the `cuda` feature is enabled: compiles CUDA kernels (.cu → .ptx) via nvcc.
//! When the `vulkan` feature is enabled: compiles Slang shaders (.slang → .spv) via slangc.
//! Output is embedded into OUT_DIR so the Rust code can include it via `include_bytes!`.
//!
//! CUDA notes:
//!   - Requires CUDA Toolkit 13.1 or later (enforced below).
//!   - Targets SM 8.6 (Ampere) as the minimum compute capability.
//!   - Release: `-O3 --use_fast_math --restrict` for maximum throughput.
//!   - Debug (`debug-shaders` feature): `-g -G -O0` for cuda-gdb.
//!
//! Vulkan / Slang notes:
//!   - Requires `slangc` on PATH (https://github.com/shader-slang/slang/releases).
//!   - Compiles to SPIR-V 1.5 targeting Vulkan 1.3 compute.
//!   - Release: `-O3 -matrix-layout-column-major`.

use std::path::PathBuf;
use std::process::Command;

/// Build-script entry point: runs the kernel/shader compiler for each enabled backend feature.
///
/// Cargo exposes each enabled feature to build scripts as a `CARGO_FEATURE_<NAME>`
/// environment variable, so the mere presence of the variable is the feature test.
fn main() {
    let feature_enabled = |var: &str| std::env::var(var).is_ok();

    if feature_enabled("CARGO_FEATURE_CUDA") {
        // `debug-shaders` only selects CUDA compiler flags; the Slang path takes no flag.
        compile_cuda_kernels(feature_enabled("CARGO_FEATURE_DEBUG_SHADERS"));
    }

    if feature_enabled("CARGO_FEATURE_VULKAN") {
        compile_slang_shaders();
    }
}

// ── CUDA version gate ─────────────────────────────────────────────────────────

/// Aborts the build unless the `nvcc` on PATH reports CUDA Toolkit `REQ_MAJOR.REQ_MINOR`
/// (currently 13.1) or newer.
///
/// Runs `nvcc --version` and parses the `release <major>.<minor>,` fragment of its stdout.
///
/// # Panics
/// If `nvcc` cannot be spawned, exits unsuccessfully, prints no parseable release line,
/// or reports a version older than the required one.
fn check_cuda_version() {
    const REQ_MAJOR: u32 = 13;
    const REQ_MINOR: u32 = 1;

    let output = match Command::new("nvcc").arg("--version").output() {
        Ok(o) => o,
        Err(e) => panic!(
            "nvcc not found ({e}). Install the CUDA toolkit to use the `cuda` feature."
        ),
    };

    // Without this check a failing `nvcc --version` would surface as a misleading
    // "could not parse" panic over empty stdout, with nvcc's own diagnostics dropped.
    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        panic!("`nvcc --version` exited with status {}\n{stderr}", output.status);
    }

    let stdout = String::from_utf8_lossy(&output.stdout);
    let (major, minor) = parse_nvcc_release(&stdout)
        .unwrap_or_else(|| panic!("Could not parse CUDA version from nvcc --version:\n{stdout}"));

    // Tuples compare lexicographically: major first, then minor.
    if (major, minor) < (REQ_MAJOR, REQ_MINOR) {
        // The upgrade hint is derived from the constants so it cannot drift from the check.
        panic!(
            "CUDA Toolkit {major}.{minor} is below the required {REQ_MAJOR}.{REQ_MINOR}.\n\
             Update to CUDA Toolkit {REQ_MAJOR}.{REQ_MINOR} or later: https://developer.nvidia.com/cuda-downloads"
        );
    }
}

/// Extracts `(major, minor)` from `nvcc --version` output.
///
/// Returns the version from the first line containing `release <major>.<minor>,`
/// (e.g. `Cuda compilation tools, release 12.4, V12.4.131`), or `None` if no line matches.
fn parse_nvcc_release(version_output: &str) -> Option<(u32, u32)> {
    version_output.lines().find_map(|line| {
        let (_, after_release) = line.split_once("release ")?;
        let (version, _) = after_release.split_once(',')?;
        let (major, minor) = version.split_once('.')?;
        Some((major.parse().ok()?, minor.parse().ok()?))
    })
}

// ── CUDA ─────────────────────────────────────────────────────────────────────

/// Compiles each CUDA kernel in `src/backend/cuda/kernels` to PTX in `OUT_DIR`.
///
/// First verifies the toolkit version via `check_cuda_version`. Then it runs
/// `nvcc -ptx -arch=sm_86` once per kernel, writing `<name>.ptx` into `OUT_DIR` so the
/// crate can embed it with `include_bytes!`.
///
/// # Arguments
/// * `debug` — when `true` (the `debug-shaders` feature), compile with `-g -G -O0` for
///   cuda-gdb; otherwise use `-O3 --use_fast_math --restrict`.
///
/// # Panics
/// If `OUT_DIR` is unset, if `nvcc` cannot be spawned, or if any kernel fails to compile.
/// In the last case the panic message includes nvcc's stderr.
fn compile_cuda_kernels(debug: bool) {
    check_cuda_version();

    let out_dir = PathBuf::from(
        std::env::var("OUT_DIR").expect("OUT_DIR is always set by cargo for build scripts"),
    );
    let kernel_dir = PathBuf::from("src/backend/cuda/kernels");
    let kernels = ["em_reduce", "hello_backend"];
    let n = kernels.len();

    let opt_flags: &[&str] = if debug {
        &["-g", "-G", "-O0"]
    } else {
        &["-O3", "--use_fast_math", "--restrict"]
    };

    if debug {
        println!("   [debug-shaders] CUDA kernels compiled with -g -G -O0");
    }

    // `kernel_dir` is also passed to nvcc as an include path, so any header the `.cu` files
    // include from there is a compile input too. Watching only the `.cu` files would miss
    // header edits. Cargo scans a watched directory's entire contents, so this covers both.
    println!("cargo:rerun-if-changed={}", kernel_dir.display());

    for (i, name) in kernels.iter().enumerate() {
        let cu_path = kernel_dir.join(format!("{name}.cu"));
        let ptx_path = out_dir.join(format!("{name}.ptx"));

        println!("   Compiling CUDA [{}/{n}] {name}.cu", i + 1);

        // Paths go through `.arg` as `OsStr`, so non-UTF-8 paths don't panic.
        let output = Command::new("nvcc")
            .args(["-ptx", "-arch=sm_86"])
            .args(opt_flags)
            .arg("-I")
            .arg(&kernel_dir)
            .arg("-o")
            .arg(&ptx_path)
            .arg(&cu_path)
            .output();

        match output {
            Ok(o) if o.status.success() => {}
            Ok(o) => {
                let stderr = String::from_utf8_lossy(&o.stderr);
                panic!("nvcc exited with status {} while compiling {name}.cu\n{stderr}", o.status);
            }
            Err(e) => {
                panic!(
                    "nvcc not found ({e}). Install the CUDA toolkit to use the `cuda` feature."
                );
            }
        }
    }
}

// ── Vulkan / Slang ────────────────────────────────────────────────────────────

/// Compiles the Slang compute shaders in `src/backend/vulkan/shaders` to SPIR-V in `OUT_DIR`.
///
/// Each table entry names a `.slang` source, one entry point in it, and the output stem.
/// A single source (`em_reduce`) can therefore yield several `.spv` modules, one per entry
/// point. Every module is built with `-target spirv -profile spirv_1_5 -stage compute -O3
/// -matrix-layout-column-major`.
///
/// # Panics
/// If `slangc -v` does not run successfully, if `OUT_DIR` is unset, or if any shader fails
/// to compile. In the last case the panic message includes slangc's stderr.
fn compile_slang_shaders() {
    // Probe the compiler up front so a missing install yields actionable guidance rather
    // than a bare spawn error on the first shader.
    let slangc_available =
        matches!(Command::new("slangc").arg("-v").output(), Ok(o) if o.status.success());
    if !slangc_available {
        panic!(
            "slangc not found on PATH. Install the Slang shader compiler to use the `vulkan` feature.\n\
             Download from: https://github.com/shader-slang/slang/releases\n\
             Then add the bin/ directory to PATH."
        );
    }

    let out_dir = PathBuf::from(
        std::env::var("OUT_DIR").expect("OUT_DIR is always set by cargo for build scripts"),
    );
    let shader_dir = PathBuf::from("src/backend/vulkan/shaders");
    // Each tuple: (slang source stem, entry point name, output spv stem).
    // em_reduce has three entry points compiled to three separate SPIR-V modules.
    let shaders: &[(&str, &str, &str)] = &[
        ("hello_backend", "hello_backend_main", "hello_backend"),
        ("em_reduce",     "em_estep",           "em_estep"),
        ("em_reduce",     "em_mstep_partial",   "em_mstep_partial"),
        ("em_reduce",     "em_mstep_final",     "em_mstep_final"),
    ];

    // Watch the whole directory, not just the top-level sources, in case a shader pulls in
    // sibling modules (Slang `import` / `#include`). Cargo scans a watched directory's
    // entire contents.
    println!("cargo:rerun-if-changed={}", shader_dir.display());

    let n = shaders.len();
    for (i, &(src_stem, entry, out_stem)) in shaders.iter().enumerate() {
        let slang_path = shader_dir.join(format!("{src_stem}.slang"));
        let spv_path = out_dir.join(format!("{out_stem}.spv"));

        println!("   Compiling Slang [{}/{n}] {src_stem}.slang [{entry}] → {out_stem}.spv", i + 1);

        // Capture output (as the CUDA path does) so slangc's stdout never lands on the build
        // script's stdout, which cargo parses for directives. Its diagnostics then go into
        // the panic message.
        let output = Command::new("slangc")
            .arg(&slang_path)
            .args([
                "-target",  "spirv",
                "-profile", "spirv_1_5",
                "-entry",   entry,
                "-stage",   "compute",
                "-O3",
                "-matrix-layout-column-major",
                "-o",
            ])
            .arg(&spv_path)
            .output();

        match output {
            Ok(o) if o.status.success() => {}
            Ok(o) => {
                let stderr = String::from_utf8_lossy(&o.stderr);
                panic!(
                    "slangc failed (exit {}) compiling {src_stem}.slang entry={entry}\n{stderr}",
                    o.status
                );
            }
            Err(e) => panic!("slangc not found ({e})"),
        }
    }
}