extern crate bindgen;
use std::env;
use std::path::PathBuf;
use bindgen::EnumVariation;
/// Locates the CUDA installation root on non-Windows hosts.
///
/// Candidate roots are tried in order: every entry of the `CUDA_LIBRARY_PATH`
/// environment variable (split with the platform's `PATH` conventions),
/// then `/usr/local/cuda`, then `/opt/cuda`. The first candidate that
/// contains `include/nvrtc.h` is returned.
///
/// # Panics
///
/// Panics if no candidate contains `include/nvrtc.h`.
fn find_cuda() -> PathBuf {
    let cuda_env = env::var_os("CUDA_LIBRARY_PATH").unwrap_or_default();

    // `split_paths` yields an empty path for an empty variable (and for stray
    // separators). Joining onto an empty path would probe `include/nvrtc.h`
    // relative to the current directory, so such entries are skipped.
    let mut candidates: Vec<PathBuf> = env::split_paths(&cuda_env)
        .filter(|entry| !entry.as_os_str().is_empty())
        .collect();
    candidates.push(PathBuf::from("/usr/local/cuda"));
    candidates.push(PathBuf::from("/opt/cuda"));

    let found = candidates
        .into_iter()
        .find(|root| root.join("include/nvrtc.h").is_file());

    match found {
        Some(root) => root,
        None => panic!("Cannot find CUDA NVRTC libraries"),
    }
}
/// Returns the entries of the `CUDA_LIBRARY_PATH` environment variable.
///
/// The value is split with [`std::env::split_paths`], which applies the
/// platform's `PATH` conventions (`;` on Windows, `:` elsewhere). Empty
/// entries are dropped, so an unset or empty variable gives an empty vector.
pub fn read_env() -> Vec<PathBuf> {
    match env::var_os("CUDA_LIBRARY_PATH") {
        Some(raw) => env::split_paths(&raw)
            // An empty entry is never a usable directory; keeping it would let
            // callers treat "" as a configured CUDA location.
            .filter(|entry| !entry.as_os_str().is_empty())
            .collect(),
        None => Vec::new(),
    }
}
/// Locates the CUDA installation root on Windows hosts.
///
/// The first entry reported by `read_env()` (from `CUDA_LIBRARY_PATH`) is
/// returned as-is, without checking its contents. Otherwise `CUDA_PATH` is
/// consulted and returned if it contains `include/nvrtc.h`.
///
/// # Panics
///
/// * If `CUDA_PATH` is set but cargo did not provide `TARGET`.
/// * If `CUDA_PATH` is set and the third component of `TARGET` is not
///   `windows` (including targets with fewer than three components).
/// * In debug builds, if the second component of a Windows target is not `pc`.
/// * If no CUDA location could be determined.
fn find_cuda_windows() -> PathBuf {
    // An explicit CUDA_LIBRARY_PATH takes precedence over CUDA_PATH.
    if let Some(first) = read_env().into_iter().next() {
        return first;
    }

    if let Ok(cuda_path) = env::var("CUDA_PATH") {
        let cuda_path = PathBuf::from(cuda_path);
        let target = env::var("TARGET")
            .expect("cargo did not set the TARGET environment variable as required.");
        let target_components: Vec<&str> = target.split('-').collect();

        // Checked lookup: a short triple must produce the explanatory panic
        // below rather than an index-out-of-bounds panic.
        if target_components.get(2).copied() != Some("windows") {
            panic!(
                "The CUDA_PATH variable is only used by cuda-sys on Windows. Your target is {}.",
                target
            );
        }
        // Index 1 exists here because index 2 was just matched.
        debug_assert_eq!(
            "pc", target_components[1],
            "Expected a Windows target to have the second component be 'pc'. Target: {}",
            target
        );

        if cuda_path.join("include/nvrtc.h").is_file() {
            return cuda_path;
        }
    }

    panic!("Cannot find CUDA NVRTC libraries");
}
/// Build-script entry point.
///
/// Finds the CUDA root, generates Rust bindings for `nvrtc.h` into
/// `$OUT_DIR/nvrtc_bindings.rs`, and prints the cargo directives that set the
/// native library search path and the NVRTC libraries to link (static
/// variants when the `static` feature is enabled, the `nvrtc` dylib otherwise).
///
/// # Panics
///
/// Panics if `OUT_DIR` is not set, if CUDA cannot be located, or if the
/// bindings cannot be generated or written.
fn main() {
    let out_path = PathBuf::from(
        env::var("OUT_DIR").expect("cargo did not set the OUT_DIR environment variable as required."),
    );

    // NOTE: `cfg!` inside a build script describes the platform the script
    // itself was compiled for, i.e. the host running the build.
    let cuda_path = if cfg!(target_os = "windows") {
        find_cuda_windows()
    } else {
        find_cuda()
    };

    bindgen::builder()
        .header("nvrtc.h")
        .clang_arg(format!("-I{}/include", cuda_path.display()))
        // Only emit items whose names match the NVRTC patterns below.
        .allowlist_recursively(false)
        .allowlist_type("^_?nvrtc.*")
        .allowlist_var("^_?nvrtc.*")
        .allowlist_function("^_?nvrtc.*")
        .derive_copy(false)
        .default_enum_style(EnumVariation::Rust { non_exhaustive: false })
        .generate()
        .expect("Unable to generate NVRTC bindings")
        .write_to_file(out_path.join("nvrtc_bindings.rs"))
        .expect("Unable to write NVRTC bindings");

    if cfg!(target_os = "windows") {
        println!(
            "cargo:rustc-link-search=native={}\\lib\\x64",
            cuda_path.display()
        );
    } else {
        println!(
            "cargo:rustc-link-search=native={}/lib64",
            cuda_path.display()
        );
    }

    #[cfg(feature = "static")]
    {
        println!("cargo:rustc-link-lib=static=nvrtc_static");
        println!("cargo:rustc-link-lib=static=nvrtc-builtins_static");
        println!("cargo:rustc-link-lib=static=nvptxcompiler_static");
    }
    #[cfg(not(feature = "static"))]
    println!("cargo:rustc-link-lib=dylib=nvrtc");

    println!("cargo:rerun-if-changed=build.rs");
    // CUDA discovery depends on these variables; without these directives a
    // changed CUDA location would not trigger a re-run of this script.
    println!("cargo:rerun-if-env-changed=CUDA_LIBRARY_PATH");
    println!("cargo:rerun-if-env-changed=CUDA_PATH");
}