// xlog-cuda 0.9.2
//
// CUDA kernel provider, buffers, and interop for XLOG
//! Build script for xlog-cuda: compiles CUDA kernels to cubin + portable PTX.

use std::env;
use std::fs;
use std::path::{Path, PathBuf};
use std::process::Command;

// Import the canonical kernel list from the shared manifest.
include!("src/kernel_manifest_data.rs");

/// Locate the `nvcc` executable used to compile the kernels.
///
/// Resolution order:
/// 1. The `NVCC_PATH` environment variable, when it names an existing
///    filesystem entry.
/// 2. A bare `nvcc` resolved through `PATH`, accepted only when
///    `nvcc --version` can be spawned and exits successfully.
///
/// When `NVCC_PATH` is set to a non-empty value that does not exist, a
/// `cargo:warning=` line is emitted before falling back, so a stale override
/// is visible instead of being silently replaced by whatever is on `PATH`.
///
/// # Panics
///
/// Panics when neither lookup yields a usable `nvcc`.
fn find_nvcc() -> PathBuf {
    if let Ok(configured) = env::var("NVCC_PATH") {
        let candidate = PathBuf::from(&configured);
        if candidate.exists() {
            return candidate;
        }
        // An empty value is treated as "unset"; anything else is a
        // misconfiguration worth surfacing in the build output.
        if !configured.is_empty() {
            println!(
                "cargo:warning=NVCC_PATH={} does not exist; falling back to nvcc on PATH",
                candidate.display()
            );
        }
    }

    // Probe PATH by actually running the compiler.
    match Command::new("nvcc").arg("--version").output() {
        Ok(output) if output.status.success() => PathBuf::from("nvcc"),
        _ => {
            panic!(
                "nvcc not found — required to generate portable PTX/cubin build \
                 artifacts. Install the CUDA toolkit and ensure nvcc is on PATH."
            );
        }
    }
}

/// Resolve the directory holding the `.cu` kernel sources.
///
/// Two layouts are probed, in this order:
/// 1. `<manifest_dir>/kernels` — sources shipped inside the crate directory.
/// 2. `<manifest_dir>/../../kernels` — a `kernels` directory two levels above
///    the crate directory.
///
/// # Panics
///
/// Panics when `manifest_dir` lacks the two parent levels needed for the
/// second probe, or when neither candidate is a directory.
fn find_kernel_sources(manifest_dir: &Path) -> PathBuf {
    let in_crate = manifest_dir.join("kernels");
    if in_crate.is_dir() {
        return in_crate;
    }

    // Only walk upwards once the in-crate location has been ruled out.
    let crates_dir = manifest_dir
        .parent()
        .expect("crate directory must have a parent (workspace crates/)");
    let workspace_root = crates_dir
        .parent()
        .expect("workspace crates/ directory must have a parent (workspace root)");
    let in_workspace = workspace_root.join("kernels");

    if !in_workspace.is_dir() {
        panic!(
            "CUDA kernel sources not found. Expected either {} or {}",
            in_crate.display(),
            in_workspace.display()
        );
    }
    in_workspace
}

/// Generate `<out_dir>/embedded_kernel_data.rs`.
///
/// The generated file declares an `EmbeddedKernelPtx` struct, an
/// `EMBEDDED_PORTABLE_PTX` table with one `include_str!` entry per name in
/// `KERNEL_CU_NAMES` (each pointing at `<out_dir>/<name>.portable.ptx`), and a
/// `portable_ptx(name)` lookup function over that table.
///
/// # Panics
///
/// Panics when the generated file cannot be written.
fn write_embedded_kernel_data(out_dir: &Path) {
    // Fixed text emitted before the per-kernel table rows.
    const PROLOGUE: &str = concat!(
        "// @generated by crates/xlog-cuda/build.rs\n",
        "// Portable PTX fallback for sidecar-free Cargo installs.\n\n",
        "pub struct EmbeddedKernelPtx {\n",
        "    pub name: &'static str,\n",
        "    pub ptx: &'static str,\n",
        "}\n\n",
        "pub const EMBEDDED_PORTABLE_PTX: &[EmbeddedKernelPtx] = &[\n",
    );
    // Fixed text emitted after the table rows: closes the table and adds the lookup.
    const EPILOGUE: &str = concat!(
        "];\n\n",
        "pub fn portable_ptx(name: &str) -> Option<&'static str> {\n",
        "    EMBEDDED_PORTABLE_PTX\n",
        "        .iter()\n",
        "        .find(|artifact| artifact.name == name)\n",
        "        .map(|artifact| artifact.ptx)\n",
        "}\n",
    );

    let mut generated = String::from(PROLOGUE);
    for name in KERNEL_CU_NAMES {
        let ptx_file = out_dir
            .join(format!("{name}.portable.ptx"))
            .display()
            .to_string();
        // `{:?}` quotes and escapes both strings so they are valid Rust literals.
        generated += &format!(
            "    EmbeddedKernelPtx {{ name: {name:?}, ptx: include_str!({ptx_file:?}) }},\n"
        );
    }
    generated.push_str(EPILOGUE);

    let destination = out_dir.join("embedded_kernel_data.rs");
    fs::write(destination, generated).expect("write embedded kernel metadata");
}

/// Append `--maxrregcount=64` to `args` when `name` is exactly `"wcoj"`.
///
/// Every other kernel name leaves `args` untouched. The comparison is
/// case-sensitive.
fn push_wcoj_register_cap(args: &mut Vec<String>, name: &str) {
    let register_cap = (name == "wcoj").then(|| String::from("--maxrregcount=64"));
    args.extend(register_cap);
}

fn main() {
    let nvcc = find_nvcc();

    let manifest_dir =
        env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR must be set by cargo");
    let manifest_dir = PathBuf::from(manifest_dir);
    let kernels_dir = find_kernel_sources(&manifest_dir);
    let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR must be set by cargo"));

    // Environment knobs.
    let no_cubin = env::var("XLOG_NO_CUBIN").map(|v| v == "1").unwrap_or(false);
    let cubin_archs: Vec<String> = env::var("XLOG_CUBIN_ARCHS")
        .unwrap_or_else(|_| "sm_120".to_string())
        .split(',')
        .map(|s| s.trim().to_string())
        .filter(|s| !s.is_empty())
        .collect();

    // Rerun triggers.
    println!("cargo:rerun-if-env-changed=XLOG_NO_CUBIN");
    println!("cargo:rerun-if-env-changed=XLOG_CUBIN_ARCHS");
    println!("cargo:rerun-if-env-changed=NVCC_PATH");

    // Emit the canonical kernel artifact root for packaging/staging helpers.
    // The staging script consumes this OUT_DIR and copies the generated
    // cubins/PTX into the installed layout (<exe_dir>/kernels/).
    // Note: this is NOT propagated as DEP_* to downstream crates (no `links`
    // key in Cargo.toml). Packaging tools should read `cargo:kernel-out-dir=`
    // from `cargo build -vv` output or use OUT_DIR directly.
    println!("cargo:kernel-out-dir={}", out_dir.display());

    for name in KERNEL_CU_NAMES {
        let cu_path = kernels_dir.join(format!("{name}.cu"));
        println!("cargo:rerun-if-changed={}", cu_path.display());

        // (a) Generate cubin per arch (unless XLOG_NO_CUBIN=1).
        if !no_cubin {
            for arch in &cubin_archs {
                let cubin_path = out_dir.join(format!("{name}.{arch}.cubin"));
                let mut args = vec![
                    "--cubin".to_string(),
                    format!("-arch={arch}"),
                    "-O3".to_string(),
                    "-o".to_string(),
                    cubin_path
                        .to_str()
                        .expect("cubin output path must be valid UTF-8")
                        .to_string(),
                    cu_path
                        .to_str()
                        .expect("kernel source path must be valid UTF-8")
                        .to_string(),
                ];
                push_wcoj_register_cap(&mut args, name);
                let status = Command::new(&nvcc)
                    .args(&args)
                    .status()
                    .unwrap_or_else(|e| panic!("failed to run nvcc for {name}.{arch}.cubin: {e}"));

                if !status.success() {
                    panic!("nvcc failed to compile {name}.cu to cubin for {arch}");
                }
            }
        }

        // (b) Always generate portable PTX (sm_75 baseline — lowest arch in CUDA 13+).
        let ptx_path = out_dir.join(format!("{name}.portable.ptx"));
        let mut args = vec![
            "--ptx".to_string(),
            "-arch=sm_75".to_string(),
            "-O3".to_string(),
            "-o".to_string(),
            ptx_path
                .to_str()
                .expect("ptx output path must be valid UTF-8")
                .to_string(),
            cu_path
                .to_str()
                .expect("kernel source path must be valid UTF-8")
                .to_string(),
        ];
        push_wcoj_register_cap(&mut args, name);
        let status = Command::new(&nvcc)
            .args(&args)
            .status()
            .unwrap_or_else(|e| panic!("failed to run nvcc for {name}.portable.ptx: {e}"));

        if !status.success() {
            panic!("nvcc failed to compile {name}.cu to portable PTX");
        }

        // Portability downgrade: the embedded portable PTX is JIT-compiled by the
        // target machine's driver. A newer toolkit (e.g. CUDA 13.x -> PTX ISA 9.x)
        // stamps a `.version` that older drivers reject with
        // CUDA_ERROR_UNSUPPORTED_PTX_VERSION. When XLOG_PTX_MAX_VERSION is set (e.g.
        // "8.4" for CUDA 12.4 drivers), rewrite the `.version` directive down to that
        // ISA. The sm_75 baseline kernels use no ISA-9-only constructs, so this is a
        // sound downgrade (verified offline with the matching ptxas).
        maybe_downgrade_ptx_version(&ptx_path);
    }

    write_embedded_kernel_data(&out_dir);
}

/// Parse a PTX ISA version written as `<major>.<minor>` (e.g. `8.4`).
///
/// Returns `None` unless both components are non-empty runs of ASCII digits
/// that fit in a `u32`. Surrounding whitespace is ignored.
fn parse_ptx_isa_version(text: &str) -> Option<(u32, u32)> {
    let (major, minor) = text.trim().split_once('.')?;
    let component = |digits: &str| -> Option<u32> {
        if digits.is_empty() || !digits.bytes().all(|b| b.is_ascii_digit()) {
            return None;
        }
        digits.parse().ok()
    };
    Some((component(major)?, component(minor)?))
}

/// Produce a copy of `text` whose first `.version` directive names `target`.
///
/// Returns `None` when there is nothing to do: no `.version ` line exists, or
/// the first one already names an ISA that is less than or equal to `target`.
/// A directive whose version cannot be parsed is rewritten. Only the directive
/// line changes; its indentation and line ending, and every other byte of
/// `text`, are preserved.
fn lower_ptx_version(text: &str, target: (u32, u32)) -> Option<String> {
    let mut line_start = 0;
    for line in text.split_inclusive('\n') {
        let line_end = line_start + line.len();
        let trimmed = line.trim_start();
        if let Some(rest) = trimmed.strip_prefix(".version ") {
            let current = rest
                .split_whitespace()
                .next()
                .and_then(parse_ptx_isa_version);
            // Never raise the version: an already-compatible file is left alone.
            if current.map_or(false, |found| found <= target) {
                return None;
            }
            let indent = &line[..line.len() - trimmed.len()];
            let body = line.trim_end_matches(|c| c == '\n' || c == '\r');
            let newline = &line[body.len()..];

            let mut out = String::with_capacity(text.len());
            out.push_str(&text[..line_start]);
            out.push_str(&format!(
                "{}.version {}.{}{}",
                indent, target.0, target.1, newline
            ));
            out.push_str(&text[line_end..]);
            return Some(out);
        }
        line_start = line_end;
    }
    None
}

/// Rewrite the `.version` directive of a portable PTX file down to the ISA
/// named by `XLOG_PTX_MAX_VERSION`.
///
/// This is a no-op (the file is not touched) when the variable is unset,
/// blank or not valid Unicode, when the file has no `.version ` line, or when
/// the file's version is already less than or equal to the requested one.
/// A `cargo:rerun-if-env-changed` directive for the variable is always
/// emitted.
///
/// # Panics
///
/// Panics when `XLOG_PTX_MAX_VERSION` is set but is not of the form
/// `<major>.<minor>`, or when the PTX file cannot be read or written.
fn maybe_downgrade_ptx_version(ptx_path: &Path) {
    println!("cargo:rerun-if-env-changed=XLOG_PTX_MAX_VERSION");
    let requested = match env::var("XLOG_PTX_MAX_VERSION") {
        Ok(v) if !v.trim().is_empty() => v,
        _ => return,
    };
    // Fail at build time rather than embedding a malformed directive that
    // would only be rejected when the driver JIT-compiles the PTX.
    let target = parse_ptx_isa_version(&requested).unwrap_or_else(|| {
        panic!(
            "XLOG_PTX_MAX_VERSION must be a PTX ISA version of the form \
             <major>.<minor> (e.g. 8.4), got {:?}",
            requested
        )
    });

    let text = fs::read_to_string(ptx_path)
        .unwrap_or_else(|e| panic!("read PTX {}: {}", ptx_path.display(), e));
    if let Some(downgraded) = lower_ptx_version(&text, target) {
        fs::write(ptx_path, downgraded)
            .unwrap_or_else(|e| panic!("write downgraded PTX {}: {}", ptx_path.display(), e));
    }
}