/// Build script entry point: locates the CUDA Toolkit, compiles
/// `src/cuda_kernels/autograd.cu` with nvcc (when one is available), and emits
/// the link directives for `cudart` and `cuda`.
///
/// # Panics
/// Panics if no CUDA Toolkit can be located (see [`find_cuda_path`]), or if
/// the `cc` compilation of the kernel fails.
fn main() {
    let cuda_path = find_cuda_path();
    let lib_path = format!("{}/lib64", cuda_path);
    let nvcc_path = format!("{}/bin/nvcc", cuda_path);
    let gpu_arch = detect_gpu_arch();

    // Prefer the toolkit's own nvcc; fall back to whatever `nvcc` is on PATH.
    let toolkit_nvcc = std::path::Path::new(&nvcc_path).exists();
    let nvcc_exists = toolkit_nvcc || nvcc_on_path();

    if nvcc_exists {
        // Only point cc at the toolkit nvcc when it really exists. Otherwise
        // leave NVCC untouched so cc resolves `nvcc` from PATH, which is the
        // binary the check above just verified. Forcing the missing toolkit
        // path here would make the compile fail.
        let nvcc_used: &str = if toolkit_nvcc {
            std::env::set_var("NVCC", &nvcc_path);
            &nvcc_path
        } else {
            "nvcc (from PATH)"
        };

        let mut build = cc::Build::new();
        build.cuda(true).cudart("shared");

        // Default to sm_75 when no GPU could be queried (e.g. headless build hosts).
        let arch = gpu_arch.unwrap_or_else(|| "75".to_string());
        // Native SASS for the detected architecture...
        let gencode = format!("arch=compute_{},code=sm_{}", arch, arch);
        build.flag("-gencode").flag(&gencode);
        // ...plus PTX for the same virtual arch, so newer GPUs can JIT-compile it.
        let gencode_ptx = format!("arch=compute_{},code=compute_{}", arch, arch);
        build.flag("-gencode").flag(&gencode_ptx);

        build
            .file("src/cuda_kernels/autograd.cu")
            .compile("tl_cuda_autograd_kernels");
        println!(
            "cargo:warning=CUDA kernels compiled for sm_{} with {}",
            arch, nvcc_used
        );
    } else {
        println!("cargo:warning=nvcc not found at {} or in PATH.", nvcc_path);
    }

    // The runtime libraries are linked even when the kernels were not compiled.
    println!("cargo:rustc-link-search=native={}", lib_path);
    println!("cargo:rustc-link-lib=dylib=cudart");
    println!("cargo:rustc-link-lib=dylib=cuda");
    println!("cargo:rerun-if-env-changed=CUDA_PATH");
    println!("cargo:rerun-if-env-changed=CUDA_HOME");
    println!("cargo:rerun-if-changed=src/cuda_kernels/autograd.cu");
}

/// Resolves the CUDA Toolkit root directory, trying in order:
/// `CUDA_PATH`, `CUDA_HOME`, `/usr/local/cuda` (if it contains
/// `lib64/libcudart.so`), then the highest-versioned `/usr/local/cuda-*` dir.
///
/// # Panics
/// Panics if none of the above yields a candidate.
fn find_cuda_path() -> String {
    std::env::var("CUDA_PATH")
        .or_else(|_| std::env::var("CUDA_HOME"))
        .unwrap_or_else(|_| {
            if std::path::Path::new("/usr/local/cuda/lib64/libcudart.so").exists() {
                return "/usr/local/cuda".to_string();
            }
            // Unreadable directories or entries are skipped (best effort).
            std::fs::read_dir("/usr/local/")
                .into_iter()
                .flatten()
                .flatten()
                .filter_map(|entry| {
                    let version = entry.file_name().to_str().and_then(cuda_dir_version)?;
                    Some((version, entry.path()))
                })
                // Compare numeric versions, not strings: "cuda-12.10" > "cuda-12.9".
                .max_by(|a, b| a.0.cmp(&b.0))
                .map(|(_, path)| path.to_string_lossy().into_owned())
                .expect("CUDA Toolkit not found. Set CUDA_PATH or CUDA_HOME.")
        })
}

/// Extracts the numeric version components from a `cuda-X.Y[.Z]` directory
/// name, e.g. `"cuda-12.4"` -> `[12, 4]`. Returns `None` if the name lacks the
/// `cuda-` prefix; components that are not plain integers count as 0.
fn cuda_dir_version(name: &str) -> Option<Vec<u32>> {
    let rest = name.strip_prefix("cuda-")?;
    Some(
        rest.split('.')
            .map(|part| part.parse().unwrap_or(0))
            .collect(),
    )
}

/// Returns `true` when `nvcc --version` can be spawned from PATH and exits successfully.
fn nvcc_on_path() -> bool {
    std::process::Command::new("nvcc")
        .arg("--version")
        .output()
        .map(|o| o.status.success())
        .unwrap_or(false)
}
/// Queries `nvidia-smi` for the GPU compute capability and returns it as the
/// digit suffix used in `sm_XX` / `compute_XX` (e.g. `"86"` for 8.6).
///
/// `nvidia-smi` prints one `major.minor` line per GPU. On multi-GPU machines
/// the lowest capability is chosen, so the embedded `compute_XX` PTX can be
/// JIT-compiled for every device present.
///
/// Returns `None` when `nvidia-smi` cannot be run, exits unsuccessfully, or
/// prints no parseable capability.
fn detect_gpu_arch() -> Option<String> {
    let output = std::process::Command::new("nvidia-smi")
        .arg("--query-gpu=compute_cap")
        .arg("--format=csv,noheader,nounits")
        .output()
        .ok()?;
    if !output.status.success() {
        return None;
    }
    let stdout = String::from_utf8_lossy(&output.stdout);
    let (major, minor) = lowest_compute_cap(&stdout)?;
    let cap = format!("{}.{}", major, minor);
    let arch = format!("{}{}", major, minor);
    println!(
        "cargo:warning=Detected GPU compute capability: {} (sm_{})",
        cap, arch
    );
    Some(arch)
}

/// Parses `major.minor` lines (surrounding whitespace ignored) and returns the
/// lowest `(major, minor)` pair. Lines that do not parse, such as `[N/A]`,
/// are skipped.
fn lowest_compute_cap(text: &str) -> Option<(u32, u32)> {
    text.lines()
        .filter_map(|line| {
            let (major, minor) = line.trim().split_once('.')?;
            Some((major.parse().ok()?, minor.parse().ok()?))
        })
        .min()
}