use std::env;
use std::fs;
use std::path::{Path, PathBuf};
use std::process::Command;
/// Build-script entry point.
///
/// When the `cooperative` feature is enabled and `nvcc` can be located, compiles the
/// cooperative-groups CUDA kernels to PTX and enables `cfg(has_nvcc)`. In every other
/// case (feature off, nvcc missing, compilation failure) a stub bindings file is
/// generated so the crate still builds.
fn main() {
    println!("cargo:rerun-if-changed=src/cuda/cooperative_kernels.cu");
    println!("cargo:rerun-if-env-changed=CUDA_PATH");
    println!("cargo:rerun-if-env-changed=CUDA_HOME");
    // Declare `has_nvcc` as an expected cfg on every run (stub builds included) so that
    // `#[cfg(has_nvcc)]` in the crate does not trigger the `unexpected_cfgs` lint on
    // newer toolchains. The single-colon form is used for compatibility with older Cargo.
    println!("cargo:rustc-check-cfg=cfg(has_nvcc)");

    let out_dir = PathBuf::from(
        env::var_os("OUT_DIR").expect("Cargo always sets OUT_DIR for build scripts"),
    );

    let cooperative_enabled = env::var("CARGO_FEATURE_COOPERATIVE").is_ok();
    if !cooperative_enabled {
        generate_stub(&out_dir, "Cooperative feature not enabled");
        return;
    }

    match find_nvcc() {
        Some(nvcc) => {
            println!("cargo:warning=Found nvcc at: {:?}", nvcc);
            match compile_cooperative_kernels(&nvcc, &out_dir) {
                Ok(()) => {
                    println!("cargo:rustc-cfg=has_nvcc");
                    println!("cargo:warning=Cooperative groups kernels compiled successfully");
                }
                Err(e) => {
                    println!("cargo:warning=Failed to compile cooperative kernels: {}", e);
                    generate_stub(&out_dir, &format!("Compilation failed: {}", e));
                }
            }
        }
        None => {
            println!("cargo:warning=nvcc not found - cooperative groups will use fallback");
            generate_stub(&out_dir, "nvcc not found at build time");
        }
    }
}
/// Locates the `nvcc` executable on the build host.
///
/// Search order:
/// 1. `$CUDA_PATH/bin/nvcc`, then `$CUDA_HOME/bin/nvcc` (with the host's executable
///    suffix, e.g. `.exe` on Windows; empty variables are ignored),
/// 2. conventional Unix install locations,
/// 3. a `PATH` lookup via `which` (or `where` on Windows hosts).
///
/// Only regular files (or symlinks to them) are accepted. Returns `None` when no
/// candidate is found.
fn find_nvcc() -> Option<PathBuf> {
    static CUDA_ROOT_VARS: [&str; 2] = ["CUDA_PATH", "CUDA_HOME"];
    static WELL_KNOWN_NVCC: [&str; 3] = [
        "/usr/local/cuda/bin/nvcc",
        "/opt/cuda/bin/nvcc",
        "/usr/bin/nvcc",
    ];

    // The build script runs on the host, so the host suffix is the right one for nvcc.
    let exe_name = format!("nvcc{}", env::consts::EXE_SUFFIX);

    let toolkit_candidates = CUDA_ROOT_VARS
        .iter()
        .filter_map(|var| env::var_os(var))
        .filter(|root| !root.is_empty())
        .map(|root| PathBuf::from(root).join("bin").join(&exe_name));
    let system_candidates = WELL_KNOWN_NVCC.iter().map(|p| PathBuf::from(*p));

    if let Some(found) = toolkit_candidates
        .chain(system_candidates)
        .find(|p| p.is_file())
    {
        return Some(found);
    }

    // Fall back to a PATH lookup with the host's locator tool.
    let locator = if cfg!(windows) { "where" } else { "which" };
    let output = Command::new(locator).arg("nvcc").output().ok()?;
    if !output.status.success() {
        return None;
    }
    // `where` may print several matches, one per line; take the first non-empty one.
    let stdout = String::from_utf8_lossy(&output.stdout);
    let first = stdout.lines().map(str::trim).find(|line| !line.is_empty())?;
    Some(PathBuf::from(first))
}
/// Compiles `src/cuda/cooperative_kernels.cu` (relative to the crate manifest) to PTX
/// with `nvcc`, then writes `cooperative_kernels.rs` into `out_dir` embedding that PTX.
///
/// The intermediate PTX is written to `out_dir/cooperative_kernels.ptx`.
///
/// # Errors
/// Returns a human-readable message if `CARGO_MANIFEST_DIR` is unset, the CUDA source
/// is missing, `nvcc` cannot be launched or exits unsuccessfully, or the PTX / Rust
/// bindings cannot be read / written.
fn compile_cooperative_kernels(nvcc: &Path, out_dir: &Path) -> Result<(), String> {
    let manifest_dir = env::var_os("CARGO_MANIFEST_DIR")
        .map(PathBuf::from)
        .ok_or_else(|| "CARGO_MANIFEST_DIR is not set".to_string())?;
    let cuda_src_path = manifest_dir
        .join("src")
        .join("cuda")
        .join("cooperative_kernels.cu");
    if !cuda_src_path.exists() {
        return Err(format!("CUDA source not found: {:?}", cuda_src_path));
    }

    let ptx_file = out_dir.join("cooperative_kernels.ptx");
    // Paths are passed to nvcc as OS strings, so non-UTF-8 build directories work
    // instead of panicking. NOTE(review): the target architecture is hard-coded to sm_89.
    let status = Command::new(nvcc)
        .args([
            "-ptx",
            "-O3",
            "--generate-line-info",
            "-arch=sm_89",
            "-std=c++17",
            "-w",
            "-o",
        ])
        .arg(&ptx_file)
        .arg(&cuda_src_path)
        .status()
        .map_err(|e| format!("Failed to execute nvcc: {}", e))?;
    if !status.success() {
        return Err(format!(
            "nvcc compilation failed with exit code: {:?}",
            status.code()
        ));
    }

    let ptx_content =
        fs::read_to_string(&ptx_file).map_err(|e| format!("Failed to read PTX: {}", e))?;
    let rust_file = out_dir.join("cooperative_kernels.rs");
    write_rust_code(
        &rust_file,
        &ptx_content,
        true,
        "Cooperative groups compiled successfully with nvcc",
    )
    .map_err(|e| format!("Failed to write Rust bindings: {}", e))?;
    Ok(())
}
/// Writes a fallback `cooperative_kernels.rs` into `out_dir` with no PTX and
/// `HAS_COOPERATIVE_SUPPORT = false`, recording `reason` as the build message.
///
/// # Panics
/// Panics if the stub file cannot be written, since the crate cannot build without it.
fn generate_stub(out_dir: &Path, reason: &str) {
    const NO_PTX: &str = "";
    let stub_path = out_dir.join("cooperative_kernels.rs");
    let outcome = write_rust_code(&stub_path, NO_PTX, false, reason);
    outcome.expect("Failed to write Rust stub");
}
/// Writes the generated Rust source that exposes the cooperative-kernel PTX to the crate.
///
/// The generated file defines:
/// - `COOPERATIVE_KERNEL_PTX`: `ptx` embedded verbatim as a raw string literal,
/// - `HAS_COOPERATIVE_SUPPORT`: the value of `has_support`,
/// - `COOPERATIVE_BUILD_MESSAGE`: `message` as a properly escaped string literal.
///
/// # Errors
/// Returns any I/O error encountered while writing `path`.
fn write_rust_code(
    path: &Path,
    ptx: &str,
    has_support: bool,
    message: &str,
) -> std::io::Result<()> {
    // The raw-string fence must be longer than any `"#…` run inside the PTX, otherwise
    // the literal would terminate early. Keep at least four hashes so typical output
    // is byte-identical to earlier builds.
    let fence = "#".repeat(raw_string_hashes(ptx).max(4));

    let mut code = String::with_capacity(ptx.len() + 1024);
    code.push_str("// Auto-generated cooperative kernel PTX.\n");
    code.push_str("// Generated by build.rs at build time.\n\n");
    code.push_str("/// Pre-compiled PTX for cooperative groups kernels.\n");
    code.push_str("/// Contains:\n");
    code.push_str("/// - coop_persistent_fdtd: Block-based FDTD with grid.sync()\n");
    code.push_str("/// - coop_ring_kernel_entry: Generic cooperative ring kernel\n");
    code.push_str(&format!(
        "pub const COOPERATIVE_KERNEL_PTX: &str = r{}\"",
        fence
    ));
    code.push_str(ptx);
    code.push_str(&format!("\"{};\n\n", fence));

    code.push_str("/// Check if cooperative groups support was compiled.\n");
    code.push_str(&format!(
        "pub const HAS_COOPERATIVE_SUPPORT: bool = {};\n\n",
        has_support
    ));

    code.push_str("/// Build-time message about cooperative support.\n");
    // `{:?}` on a `str` produces a valid Rust string literal: quotes, backslashes and
    // control characters (e.g. a bare `\r`, which may not appear unescaped) are escaped.
    code.push_str(&format!(
        "pub const COOPERATIVE_BUILD_MESSAGE: &str = {:?};\n",
        message
    ));
    fs::write(path, code)
}

/// Returns the minimum number of `#` delimiters a raw string literal needs to hold `s`.
fn raw_string_hashes(s: &str) -> usize {
    s.split('"')
        .skip(1)
        .map(|tail| tail.bytes().take_while(|&b| b == b'#').count())
        .max()
        .unwrap_or(0)
        + 1
}