use std::{env, process::Command};
/// Minimum NVIDIA GPU architecture to compile CUDA code for.
///
/// Each discriminant is the number that this type's methods splice into nvcc's
/// `compute_XX` / `sm_XX` architecture names (e.g. `86` becomes `sm_86`).
// NOTE(review): GeForce RTX 50-series (consumer Blackwell) GPUs report compute
// capability 12.0 (`sm_120`); `100` is the datacenter Blackwell target (`sm_100`).
// A build embedding only `code=sm_100` SASS would presumably not run on an RTX 50
// card — confirm the intended value for `Rtx5`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum GpuArchitecture {
    /// Compiled as `compute_75` / `sm_75`.
    Rtx2 = 75,
    /// Compiled as `compute_86` / `sm_86`.
    Rtx3 = 86,
    /// Compiled as `compute_89` / `sm_89`.
    Rtx4 = 89,
    /// Compiled as `compute_100` / `sm_100` (see the review note above).
    Rtx5 = 100,
}
impl GpuArchitecture {
    /// The architecture number stored in the enum discriminant (e.g. `86`).
    fn arch_number(self) -> u8 {
        self as u8
    }

    /// Value for `nvcc -gencode`, of the form `arch=compute_XX,code=sm_XX`.
    pub fn gencode_val(&self) -> String {
        let n = self.arch_number();
        format!("arch=compute_{n},code=sm_{n}")
    }

    /// nvcc flag selecting the virtual architecture: `-arch=compute_XX`.
    pub fn compute_val(&self) -> String {
        format!("-arch=compute_{}", self.arch_number())
    }

    /// nvcc flag selecting the real architecture: `-arch=sm_XX`.
    pub fn sm_val(&self) -> String {
        format!("-arch=sm_{}", self.arch_number())
    }
}
/// Compiles a CUDA source file to PTX with `nvcc`, for use from a Cargo build script.
///
/// Every entry of `cuda_files` is registered with `cargo:rerun-if-changed`, but only
/// `cuda_files[0]` is passed to `nvcc`; the remaining entries only trigger rebuilds
/// (presumably headers the first file includes — confirm with callers). The PTX is
/// written to `{filename}.ptx`, relative to the current working directory. Does nothing
/// if `cuda_files` is empty.
///
/// If `nvcc` cannot be launched at all, a `cargo:warning` is emitted and the build
/// continues without the PTX file.
///
/// # Panics
///
/// Panics if `nvcc` runs but exits with a failure status; the panic message includes
/// the exit status, stdout, and stderr.
pub fn build_ptx(min_arch: GpuArchitecture, cuda_files: &[&str], filename: &str) {
    let Some(&main_file) = cuda_files.first() else {
        return;
    };
    for kernel in cuda_files {
        println!("cargo:rerun-if-changed={kernel}");
    }
    let output_path = format!("{filename}.ptx");
    let compilation_result = Command::new("nvcc")
        .args([
            main_file,
            &min_arch.compute_val(),
            "-ptx",
            "-O3",
            "-o",
            &output_path,
        ])
        .output();
    match compilation_result {
        Ok(output) if output.status.success() => {
            println!("Compiled the following PTX files: {cuda_files:?}");
        }
        Ok(output) => panic!(
            "CUDA PTX compilation problem:\nstatus: {}\nstdout: {}\nstderr: {}",
            output.status,
            String::from_utf8_lossy(&output.stdout),
            String::from_utf8_lossy(&output.stderr)
        ),
        // Failing to launch nvcc (e.g. not on PATH) stays non-fatal, but build-script
        // stderr is hidden on successful builds, so surface it as a Cargo warning.
        Err(e) => println!("cargo:warning=Unable to compile PTX: {e}"),
    }
}
/// Compiles a CUDA source file into a static library with the `cc` crate and links
/// cuFFT, for use from a Cargo build script.
///
/// Every entry of `cuda_files` is registered with `cargo:rerun-if-changed`, but only
/// `cuda_files[0]` is compiled. The library is named `filename`. Does nothing if
/// `cuda_files` is empty.
///
/// The CUDA toolkit root is read from the `CUDA_PATH` environment variable, defaulting
/// to `/usr/local/cuda`. On Windows targets its `lib\x64` directory is added to the link
/// search path; on Linux targets it is `lib64`. In both cases `cufft` is linked. No CUDA
/// link directives are emitted for other target OSes.
///
/// # Panics
///
/// Panics if compilation fails (`cc::Build::compile` panics on error).
pub fn build_host(min_arch: GpuArchitecture, cuda_files: &[&str], filename: &str) {
    // Mirror `build_ptx`: an empty list is a no-op instead of an index-out-of-bounds panic.
    let Some(&main_file) = cuda_files.first() else {
        return;
    };
    for kernel in cuda_files {
        println!("cargo:rerun-if-changed={kernel}");
    }
    // CUDA_PATH influences the link search path, so changing it must rerun this script.
    println!("cargo:rerun-if-env-changed=CUDA_PATH");
    let cuda_path = env::var("CUDA_PATH").unwrap_or_else(|_| "/usr/local/cuda".into());

    // Build scripts are compiled for the host, so `cfg!(target_os)` / `#[cfg]` here would
    // describe the host OS. Cargo exposes the real target OS via CARGO_CFG_TARGET_OS. Fall
    // back to the host OS when not run by Cargo.
    let target_os =
        env::var("CARGO_CFG_TARGET_OS").unwrap_or_else(|_| env::consts::OS.to_owned());

    let mut build = cc::Build::new();
    build
        .cuda(true)
        .file(main_file)
        .flag("-O3")
        .flag("-std=c++20")
        .flag(min_arch.sm_val());
    if target_os == "linux" {
        build.flag("-Xcompiler=-fPIC");
    }
    build.compile(filename);

    let lib_dir = match target_os.as_str() {
        "windows" => Some(format!("{cuda_path}\\lib\\x64")),
        "linux" => Some(format!("{cuda_path}/lib64")),
        _ => None,
    };
    if let Some(lib_dir) = lib_dir {
        println!("cargo:rustc-link-search=native={lib_dir}");
        println!("cargo:rustc-link-lib=cufft");
    }
}