use std::{env, error::Error, path::Path, process::exit, str};
use std::path::PathBuf;
/// Returns the CUDA toolkit root directory read from the `CUDA_TOOLKIT_PATH` environment
/// variable.
///
/// # Panics
///
/// Panics with `CUDA_TOOLKIT_PATH is required but not set: <VarError>` when the variable is
/// missing or its value is not valid Unicode.
pub fn cuda_toolkit_dir() -> String {
    match env::var("CUDA_TOOLKIT_PATH") {
        Ok(toolkit_root) => toolkit_root,
        // Same message shape that `Result::expect` produces: "<msg>: <err:?>".
        Err(lookup_error) => panic!("{}: {:?}", "CUDA_TOOLKIT_PATH is required but not set", lookup_error),
    }
}
/// Build-script entry point: runs [`run`] and, if it fails, prints the error to stderr and
/// exits the process with status 1.
fn main() {
    match run() {
        Ok(()) => {}
        Err(failure) => {
            eprintln!("{}", failure);
            exit(1);
        }
    }
}
/// Emits Cargo link directives for the CUDA toolkit and generates FFI bindings for
/// `wrapper.h`.
///
/// The toolkit root is read once via [`cuda_toolkit_dir`]. Each directory returned by
/// [`collect_paths`] is printed as a native link-search path. The script links the `cuda`
/// and `curand` dynamic libraries and writes the bindgen output to `$OUT_DIR/bindings.rs`.
///
/// # Errors
///
/// Returns an error in three cases:
/// - bindgen fails to generate bindings;
/// - `OUT_DIR` is not set;
/// - the bindings file cannot be written.
///
/// # Panics
///
/// Panics (inside [`cuda_toolkit_dir`]) if `CUDA_TOOLKIT_PATH` is not set.
fn run() -> Result<(), Box<dyn Error>> {
    // After any `rerun-if-*` directive is printed, Cargo reruns this script ONLY when one of
    // the listed inputs changes. The toolkit location must therefore be tracked explicitly.
    // Otherwise, pointing CUDA_TOOLKIT_PATH at another install would reuse stale bindings.
    println!("cargo:rerun-if-changed=wrapper.h");
    println!("cargo:rerun-if-env-changed=CUDA_TOOLKIT_PATH");

    let toolkit_dir = cuda_toolkit_dir();

    for path in collect_paths(&toolkit_dir) {
        println!("cargo:rustc-link-search=native={}", path.display());
    }
    println!("cargo:rustc-link-lib=dylib=cuda");
    println!("cargo:rustc-link-lib=dylib=curand");

    let include_dir = Path::new(&toolkit_dir).join("include");
    let bindings = bindgen::builder()
        .header("wrapper.h")
        .clang_arg(format!("-I{}", include_dir.display()))
        .parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
        .generate()
        // Report the failure through `Result` instead of panicking. Formatting with `{:?}`
        // works whether this bindgen version's error type is `()` or `BindgenError`.
        .map_err(|err| format!("failed to generate CUDA bindings: {:?}", err))?;

    let out_file = Path::new(&env::var("OUT_DIR")?).join("bindings.rs");
    bindings.write_to_file(out_file)?;
    Ok(())
}
/// Returns the library directories under a CUDA toolkit root that should be passed to the
/// linker as native search paths.
///
/// Two layouts are probed. Their directories are returned in this order:
///
/// * `<root>/lib64` and `<root>/lib64/stubs`, when `<root>/lib64` is a directory;
/// * `<root>/targets/x86_64-linux/lib` and `<root>/targets/x86_64-linux/lib/stubs`, when
///   `<root>/targets/x86_64-linux/include/cuda.h` is a file.
///
/// Only the `lib64` directory or the `cuda.h` header is checked. The `stubs` and `lib`
/// entries are returned without verifying that they exist. A root matching neither layout
/// yields an empty vector.
///
/// Accepts anything path-like (`&str`, `&String`, `&Path`, `PathBuf`, ...).
pub fn collect_paths(cuda_toolkit: impl AsRef<Path>) -> Vec<PathBuf> {
    let root = cuda_toolkit.as_ref();
    let mut valid_paths = Vec::with_capacity(4);

    let lib64 = root.join("lib64");
    if lib64.is_dir() {
        // Compute the stubs path first so `lib64` can be moved into the vector without a clone.
        let stubs = lib64.join("stubs");
        valid_paths.push(lib64);
        valid_paths.push(stubs);
    }

    // NOTE(review): only the x86_64 target layout is probed. Other architectures (for
    // example a `targets/sbsa-linux` tree) are not searched; confirm whether they are needed.
    let target = root.join("targets/x86_64-linux");
    if target.join("include/cuda.h").is_file() {
        valid_paths.push(target.join("lib"));
        valid_paths.push(target.join("lib/stubs"));
    }

    valid_paths
}