/// Build script entry point used when the `cuda_kv` feature is disabled.
///
/// There is no CUDA kernel to build or link in this configuration, so the
/// script does nothing and emits no cargo directives.
#[cfg(not(feature = "cuda_kv"))]
fn main() {
    // Intentionally empty: nothing to compile without `cuda_kv`.
}
/// Build script entry point used when the `cuda_kv` feature is enabled.
///
/// Locates a CUDA toolkit, compiles `src/kernels/block_copy.cu` with `nvcc` into a
/// static library under `OUT_DIR`, and emits the cargo directives that link it
/// together with the CUDA runtime libraries.
///
/// # Panics
///
/// Panics (failing the build) if `nvcc` or the archiver cannot be started or exits
/// unsuccessfully; the tool's stdout and stderr are included in the message.
#[cfg(feature = "cuda_kv")]
fn main() {
    use std::path::{Path, PathBuf};
    use std::process::Command;

    const KERNEL_SRC: &str = "src/kernels/block_copy.cu";
    // Evaluated for the platform the build script itself runs on.
    let windows = cfg!(target_os = "windows");

    println!("cargo:rerun-if-changed={}", KERNEL_SRC);

    // Prefer the toolkit that owns the `nvcc` found on PATH (laid out as <root>/bin/nvcc).
    // Otherwise fall back to get_cuda_root_or_default() and use that toolkit's own nvcc,
    // so a build configured only through the environment still finds the compiler.
    let (nvcc, cuda_root) = match find_nvcc_on_path() {
        Some(nvcc_path) => {
            println!("cargo:info=nvcc found in path");
            let root = nvcc_path
                .parent()
                .and_then(Path::parent)
                .map(Path::to_path_buf)
                .unwrap_or_else(|| PathBuf::from(get_cuda_root_or_default()));
            (nvcc_path, root)
        }
        None => {
            println!("cargo:warning=nvcc not found in path");
            let root = PathBuf::from(get_cuda_root_or_default());
            let nvcc_path = root
                .join("bin")
                .join(format!("nvcc{}", std::env::consts::EXE_SUFFIX));
            (nvcc_path, root)
        }
    };
    println!("cargo:info=Using CUDA installation at: {}", cuda_root.display());

    // Windows toolkits keep their import libraries in lib\x64; Linux toolkits use lib64.
    let cuda_lib_path = if windows {
        cuda_root.join("lib").join("x64")
    } else {
        cuda_root.join("lib64")
    };
    println!("cargo:info=Using CUDA libs: {}", cuda_lib_path.display());
    println!("cargo:rustc-link-search=native={}", cuda_lib_path.display());
    if !windows {
        // Embed the CUDA lib dir as an rpath so the shared CUDA libraries resolve at run
        // time. `-Wl,-rpath` is a GNU-style flag with no meaning to MSVC's link.exe.
        println!("cargo:rustc-link-arg=-Wl,-rpath,{}", cuda_lib_path.display());
    }

    // Write build artifacts to OUT_DIR instead of the source tree: cargo expects build
    // scripts not to modify files outside OUT_DIR, and an absolute link-search path does
    // not depend on the directory rustc happens to be invoked from.
    let out_dir = PathBuf::from(
        std::env::var_os("OUT_DIR").expect("cargo always sets OUT_DIR for build scripts"),
    );
    let object = out_dir.join("libblock_copy.o");

    let mut compile = Command::new(&nvcc);
    compile.arg(KERNEL_SRC).arg("-O3");
    if !windows {
        // Position-independent code, so the object can end up inside a shared library.
        // `-fPIC` is a GCC/Clang flag that MSVC's cl.exe does not understand.
        compile.arg("--compiler-options").arg("-fPIC");
    }
    compile.arg("-o").arg(&object).arg("-c");
    run_tool(&mut compile, "compile CUDA kernel");

    let mut archive = if windows {
        let mut lib = Command::new("lib");
        lib.arg(format!("/OUT:{}", out_dir.join("block_copy.lib").display()))
            .arg(&object);
        lib
    } else {
        let mut ar = Command::new("ar");
        ar.arg("rcs").arg(out_dir.join("libblock_copy.a")).arg(&object);
        ar
    };
    run_tool(&mut archive, "create static library");

    println!("cargo:rustc-link-search=native={}", out_dir.display());
    println!("cargo:rustc-link-lib=static=block_copy");
    // Emitted after the static archive: single-pass linkers such as GNU ld resolve
    // symbols left to right, so the libraries the kernel object needs must follow it.
    for lib in &["cudart", "cuda", "cudadevrt"] {
        println!("cargo:rustc-link-lib=dylib={}", lib);
    }
}

/// Runs `cmd` to completion and fails the build if it cannot be started or exits
/// unsuccessfully. `what` names the step in the panic message.
#[cfg(feature = "cuda_kv")]
fn run_tool(cmd: &mut std::process::Command, what: &str) {
    let output = match cmd.output() {
        Ok(output) => output,
        Err(err) => panic!("Failed to {}: could not run {:?}: {}", what, cmd, err),
    };
    if !output.status.success() {
        // Include stdout too: MSVC tools such as lib.exe report errors there.
        panic!(
            "Failed to {}: {:?} exited with {}\n--- stdout ---\n{}\n--- stderr ---\n{}",
            what,
            cmd,
            output.status,
            String::from_utf8_lossy(&output.stdout),
            String::from_utf8_lossy(&output.stderr)
        );
    }
}

/// Returns the first `nvcc` executable found in the directories listed in `PATH`,
/// or `None` if `PATH` is unset or none of its directories contains `nvcc`.
#[cfg(feature = "cuda_kv")]
fn find_nvcc_on_path() -> Option<std::path::PathBuf> {
    let exe = format!("nvcc{}", std::env::consts::EXE_SUFFIX);
    let path = std::env::var_os("PATH")?;
    std::env::split_paths(&path)
        .map(|dir| dir.join(&exe))
        .find(|candidate| candidate.is_file())
}
/// Environment variables consulted, in priority order, for the CUDA toolkit root.
///
/// `CUDA_PATH` is the variable the Windows CUDA installer sets; `CUDA_HOME` is a
/// widespread convention on Linux.
// Also compiled under `test` so the resolution logic can be unit-tested without
// enabling the feature; real cargo builds never compile the build script with --test.
#[cfg(any(feature = "cuda_kv", test))]
const CUDA_ROOT_ENV_VARS: [&str; 3] = ["CUDA_ROOT", "CUDA_PATH", "CUDA_HOME"];

/// Pure core of `get_cuda_root_or_default`: resolves the toolkit root.
///
/// Returns the first variable in `CUDA_ROOT_ENV_VARS` whose value from `lookup` is
/// present and non-empty. If there is none, it returns a platform default: the CUDA
/// 11.8 install directory when `windows` is true, and `/usr/local/cuda` otherwise.
#[cfg(any(feature = "cuda_kv", test))]
fn cuda_root_from<F>(lookup: F, windows: bool) -> String
where
    F: Fn(&str) -> Option<String>,
{
    CUDA_ROOT_ENV_VARS
        .iter()
        .filter_map(|var| lookup(var))
        // An empty value (e.g. `CUDA_ROOT=`) is treated as unset rather than as "".
        .find(|value| !value.is_empty())
        .unwrap_or_else(|| {
            if windows {
                "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.8".to_string()
            } else {
                "/usr/local/cuda".to_string()
            }
        })
}

/// Returns the CUDA toolkit root to use when it cannot be derived from `nvcc` on `PATH`.
///
/// Checks `CUDA_ROOT`, then `CUDA_PATH`, then `CUDA_HOME`, and ignores variables that
/// are unset, empty, or not valid Unicode. If none of them applies, it falls back to the
/// CUDA 11.8 install directory on Windows and to `/usr/local/cuda` elsewhere.
///
/// It also emits `cargo:rerun-if-env-changed` for each of these variables. The
/// script's `rerun-if-changed` directive turns off cargo's default rerun behavior,
/// so without these lines a change in toolkit selection would be missed.
#[cfg(feature = "cuda_kv")]
fn get_cuda_root_or_default() -> String {
    for var in CUDA_ROOT_ENV_VARS.iter() {
        println!("cargo:rerun-if-env-changed={}", var);
    }
    cuda_root_from(|key| std::env::var(key).ok(), cfg!(target_os = "windows"))
}