use std::env;
use std::fs;
use std::path::{Path, PathBuf};
use std::process::Command;
include!("src/kernel_manifest_data.rs");
/// Locates the `nvcc` compiler used to build the CUDA kernels.
///
/// Resolution order:
/// 1. `NVCC_PATH`, if it is set and names a path that exists.
/// 2. `nvcc` on `PATH`, if `nvcc --version` can be spawned and exits successfully.
///
/// A non-blank `NVCC_PATH` that points at a missing path is reported as a cargo warning
/// before falling back to `PATH`, so a typo in the override is not silently ignored.
///
/// # Panics
/// Panics if neither location yields a working `nvcc`. The message includes why the
/// `PATH` probe failed.
fn find_nvcc() -> PathBuf {
    if let Ok(raw) = env::var("NVCC_PATH") {
        let path = PathBuf::from(&raw);
        if path.exists() {
            return path;
        }
        if !raw.trim().is_empty() {
            println!(
                "cargo:warning=NVCC_PATH={} does not exist; falling back to nvcc on PATH",
                path.display()
            );
        }
    }

    let failure = match Command::new("nvcc").arg("--version").output() {
        Ok(output) if output.status.success() => return PathBuf::from("nvcc"),
        Ok(output) => format!("`nvcc --version` exited with {}", output.status),
        Err(e) => format!("could not run nvcc: {e}"),
    };
    panic!(
        "nvcc not found — required to generate portable PTX/cubin build \
         artifacts. Install the CUDA toolkit and ensure nvcc is on PATH. ({failure})"
    )
}
/// Resolves the directory that holds the `.cu` kernel sources.
///
/// `<manifest_dir>/kernels` (packaged-crate layout) wins when it is a directory. Otherwise
/// the function falls back to `<manifest_dir>/../../kernels` (workspace-checkout layout).
///
/// # Panics
/// Panics if the packaged directory is absent and `manifest_dir` does not have two ancestor
/// levels. Also panics if neither candidate is a directory.
fn find_kernel_sources(manifest_dir: &Path) -> PathBuf {
    let in_crate = manifest_dir.join("kernels");
    if in_crate.is_dir() {
        return in_crate;
    }

    // Only walk up the tree once the packaged layout has been ruled out.
    let crates_dir = manifest_dir
        .parent()
        .expect("crate directory must have a parent (workspace crates/)");
    let workspace_root = crates_dir
        .parent()
        .expect("workspace crates/ directory must have a parent (workspace root)");
    let in_workspace = workspace_root.join("kernels");

    if in_workspace.is_dir() {
        in_workspace
    } else {
        panic!(
            "CUDA kernel sources not found. Expected either {} or {}",
            in_crate.display(),
            in_workspace.display()
        )
    }
}
/// Generates `embedded_kernel_data.rs` inside `out_dir`.
///
/// The generated module contains:
/// - an `EmbeddedKernelPtx` record type;
/// - an `EMBEDDED_PORTABLE_PTX` table with one entry per name in `KERNEL_CU_NAMES`, whose
///   `ptx` field `include_str!`s `<out_dir>/<name>.portable.ptx`;
/// - a `portable_ptx(name)` lookup over that table.
///
/// # Panics
/// Panics if the file cannot be written.
fn write_embedded_kernel_data(out_dir: &Path) {
    // Everything before the per-kernel table entries.
    const PRELUDE: &str = concat!(
        "// @generated by crates/xlog-cuda/build.rs\n",
        "// Portable PTX fallback for sidecar-free Cargo installs.\n\n",
        "pub struct EmbeddedKernelPtx {\n",
        " pub name: &'static str,\n",
        " pub ptx: &'static str,\n",
        "}\n\n",
        "pub const EMBEDDED_PORTABLE_PTX: &[EmbeddedKernelPtx] = &[\n",
    );
    // Closes the table and emits the lookup helper.
    const EPILOGUE: &str = concat!(
        "];\n\n",
        "pub fn portable_ptx(name: &str) -> Option<&'static str> {\n",
        " EMBEDDED_PORTABLE_PTX\n",
        " .iter()\n",
        " .find(|artifact| artifact.name == name)\n",
        " .map(|artifact| artifact.ptx)\n",
        "}\n",
    );

    let mut generated = String::from(PRELUDE);
    for name in KERNEL_CU_NAMES {
        let artifact = out_dir.join(format!("{name}.portable.ptx"));
        let artifact = artifact.display().to_string();
        // `{:?}` renders a quoted, escaped string literal, so backslashes in the path
        // (e.g. on Windows) survive being spliced into `include_str!`.
        generated += &format!(
            " EmbeddedKernelPtx {{ name: {:?}, ptx: include_str!({:?}) }},\n",
            name, artifact
        );
    }
    generated.push_str(EPILOGUE);

    fs::write(out_dir.join("embedded_kernel_data.rs"), generated)
        .expect("write embedded kernel metadata");
}
/// Appends `--maxrregcount=64` to the nvcc `args` when compiling the `wcoj` kernel.
///
/// The flag limits the per-thread register count for that kernel only. Every other kernel
/// name leaves `args` untouched. The rationale for the cap is not visible in this file.
fn push_wcoj_register_cap(args: &mut Vec<String>, name: &str) {
    if name != "wcoj" {
        return;
    }
    args.push(String::from("--maxrregcount=64"));
}
fn main() {
let nvcc = find_nvcc();
let manifest_dir =
env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR must be set by cargo");
let manifest_dir = PathBuf::from(manifest_dir);
let kernels_dir = find_kernel_sources(&manifest_dir);
let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR must be set by cargo"));
let no_cubin = env::var("XLOG_NO_CUBIN").map(|v| v == "1").unwrap_or(false);
let cubin_archs: Vec<String> = env::var("XLOG_CUBIN_ARCHS")
.unwrap_or_else(|_| "sm_120".to_string())
.split(',')
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.collect();
println!("cargo:rerun-if-env-changed=XLOG_NO_CUBIN");
println!("cargo:rerun-if-env-changed=XLOG_CUBIN_ARCHS");
println!("cargo:rerun-if-env-changed=NVCC_PATH");
println!("cargo:kernel-out-dir={}", out_dir.display());
for name in KERNEL_CU_NAMES {
let cu_path = kernels_dir.join(format!("{name}.cu"));
println!("cargo:rerun-if-changed={}", cu_path.display());
if !no_cubin {
for arch in &cubin_archs {
let cubin_path = out_dir.join(format!("{name}.{arch}.cubin"));
let mut args = vec![
"--cubin".to_string(),
format!("-arch={arch}"),
"-O3".to_string(),
"-o".to_string(),
cubin_path
.to_str()
.expect("cubin output path must be valid UTF-8")
.to_string(),
cu_path
.to_str()
.expect("kernel source path must be valid UTF-8")
.to_string(),
];
push_wcoj_register_cap(&mut args, name);
let status = Command::new(&nvcc)
.args(&args)
.status()
.unwrap_or_else(|e| panic!("failed to run nvcc for {name}.{arch}.cubin: {e}"));
if !status.success() {
panic!("nvcc failed to compile {name}.cu to cubin for {arch}");
}
}
}
let ptx_path = out_dir.join(format!("{name}.portable.ptx"));
let mut args = vec![
"--ptx".to_string(),
"-arch=sm_75".to_string(),
"-O3".to_string(),
"-o".to_string(),
ptx_path
.to_str()
.expect("ptx output path must be valid UTF-8")
.to_string(),
cu_path
.to_str()
.expect("kernel source path must be valid UTF-8")
.to_string(),
];
push_wcoj_register_cap(&mut args, name);
let status = Command::new(&nvcc)
.args(&args)
.status()
.unwrap_or_else(|e| panic!("failed to run nvcc for {name}.portable.ptx: {e}"));
if !status.success() {
panic!("nvcc failed to compile {name}.cu to portable PTX");
}
maybe_downgrade_ptx_version(&ptx_path);
}
write_embedded_kernel_data(&out_dir);
}
/// Caps the `.version` directive of the PTX file at `ptx_path` to `XLOG_PTX_MAX_VERSION`.
///
/// Behaviour:
/// - The function is a no-op when the variable is unset or blank.
/// - Otherwise the first `.version` directive is lowered to the requested `MAJOR.MINOR`,
///   but only if it is currently newer. The cap never raises the version nvcc emitted.
/// - Versions are compared numerically, so `7.10` is treated as newer than `7.8`.
/// - A directive whose current version cannot be parsed is overwritten, as before.
/// - Only the directive line is replaced. Indentation, line endings and every other byte of
///   the file are preserved, and the file is rewritten only when the directive changes.
///
/// # Panics
/// Panics if `XLOG_PTX_MAX_VERSION` is not of the form `MAJOR.MINOR`, or if the PTX file
/// cannot be read or written.
fn maybe_downgrade_ptx_version(ptx_path: &Path) {
    /// Parses a PTX ISA version of the form `MAJOR.MINOR` (e.g. `8.0`).
    fn parse_version(text: &str) -> Option<(u32, u32)> {
        let (major, minor) = text.trim().split_once('.')?;
        Some((major.parse().ok()?, minor.parse().ok()?))
    }

    const DIRECTIVE: &str = ".version ";

    println!("cargo:rerun-if-env-changed=XLOG_PTX_MAX_VERSION");
    let requested = match env::var("XLOG_PTX_MAX_VERSION") {
        Ok(v) if !v.trim().is_empty() => v.trim().to_string(),
        _ => return,
    };
    let cap = parse_version(&requested).unwrap_or_else(|| {
        panic!("XLOG_PTX_MAX_VERSION must be MAJOR.MINOR (for example 7.8), got {requested:?}")
    });

    let text = fs::read_to_string(ptx_path)
        .unwrap_or_else(|e| panic!("read PTX {}: {e}", ptx_path.display()));

    // Byte range [start, end) of the first `.version` line, excluding its "\n" / "\r\n"
    // terminator, plus the byte length of its leading indentation.
    let mut found = None;
    let mut offset = 0;
    for raw_line in text.split_inclusive('\n') {
        let line = raw_line.strip_suffix('\n').unwrap_or(raw_line);
        let line = line.strip_suffix('\r').unwrap_or(line);
        let trimmed = line.trim_start();
        if trimmed.starts_with(DIRECTIVE) {
            found = Some((offset, offset + line.len(), line.len() - trimmed.len()));
            break;
        }
        offset += raw_line.len();
    }

    let Some((start, end, indent_len)) = found else {
        println!(
            "cargo:warning=no `.version` directive in {}; XLOG_PTX_MAX_VERSION not applied",
            ptx_path.display()
        );
        return;
    };

    let current = text[start + indent_len..end][DIRECTIVE.len()..]
        .split_whitespace()
        .next()
        .and_then(parse_version);
    // Only ever lower the version; (major, minor) tuples compare numerically.
    if matches!(current, Some(v) if v <= cap) {
        return;
    }

    let replacement = format!("{DIRECTIVE}{}.{}", cap.0, cap.1);
    let mut out = String::with_capacity(text.len() + replacement.len());
    out.push_str(&text[..start + indent_len]);
    out.push_str(&replacement);
    out.push_str(&text[end..]);
    fs::write(ptx_path, out)
        .unwrap_or_else(|e| panic!("write downgraded PTX {}: {e}", ptx_path.display()));
}