/// Build-script entry point.
///
/// Registers the files and environment variables that should re-trigger this
/// script, then runs the optional steps selected by Cargo features:
///
/// * `cusparselt` — generate the cuSPARSELt bindings (only compiled in when the
///   build script itself is built with that feature).
/// * `cuda` on a Linux *target* — set up the cuSOLVER compatibility directory.
fn main() {
    println!("cargo:rerun-if-changed=build.rs");
    println!("cargo:rerun-if-env-changed=CARGO_FEATURE_CUSPARSELT");
    println!("cargo:rerun-if-env-changed=CARGO_FEATURE_CUDA");
    println!("cargo:rerun-if-env-changed=CUSPARSELT_INCLUDE_DIR");
    println!("cargo:rerun-if-env-changed=CUSPARSELT_LIB_DIR");
    println!("cargo:rerun-if-env-changed=CUDA_PATH");

    if std::env::var_os("CARGO_FEATURE_CUSPARSELT").is_some() {
        #[cfg(feature = "cusparselt")]
        cusparselt::generate();
    }

    // A build script is compiled for the HOST, so `cfg!(target_os = ...)` would
    // describe the machine running the build rather than the platform being
    // built for. Cargo exports the real target OS in `CARGO_CFG_TARGET_OS`.
    let targets_linux = std::env::var("CARGO_CFG_TARGET_OS")
        .map(|os| os == "linux")
        .unwrap_or(false);
    if std::env::var_os("CARGO_FEATURE_CUDA").is_some() && targets_linux {
        cuda_cusolver_compat::ensure();
    }
}
mod cuda_cusolver_compat {
use std::path::{Path, PathBuf};
/// Creates `$OUT_DIR/cuda-compat` containing `libcusolver.so` and
/// `libcusolver.so.11` symlinks that point at the library returned by
/// `locate_cusolver_so11`, then tells Cargo to add that directory to the
/// linker search path and to the rpath.
///
/// Every failure (library not found, directory or symlink creation failing)
/// is reported as a `cargo:warning` and the function returns early without
/// emitting the link directives; it never fails the build for those cases.
///
/// # Panics
///
/// Panics if `OUT_DIR` is not set, which Cargo always provides to build scripts.
pub fn ensure() {
    let Some(lib) = locate_cusolver_so11() else {
        println!(
            "cargo:warning=ferrotorch-gpu(cuda): no CUDA 12.x cuSOLVER (libcusolver.so.11*) \
             found. The cudarc 12080 pin needs the legacy cusolverDn* symbols that exist only \
             in libcusolver.so.11 (CUDA 12.x); a CUDA 13.x libcusolver.so.12 lacks them. \
             cusolver tests (cusolver::*) may panic with 'undefined symbol: cusolverDnGeqrf'. \
             Install the CUDA 12.8 toolkit or set CUDA_PATH to a CUDA 12.x prefix. Searched \
             $CUDA_PATH/targets/x86_64-linux/lib and /usr/local/cuda-12*."
        );
        return;
    };
    println!("cargo:rerun-if-changed={}", lib.display());

    let out_dir = PathBuf::from(std::env::var_os("OUT_DIR").expect("OUT_DIR set by cargo"));
    let compat_dir = out_dir.join("cuda-compat");
    if let Err(e) = std::fs::create_dir_all(&compat_dir) {
        println!(
            "cargo:warning=ferrotorch-gpu(cuda): failed to create compat dir {}: {e}. \
             cusolver tests may fail.",
            compat_dir.display()
        );
        return;
    }

    for name in ["libcusolver.so", "libcusolver.so.11"] {
        let link = compat_dir.join(name);
        // Best effort: drop a stale link left by a previous run. If removal
        // fails for a real reason, the symlink call below reports it.
        let _ = std::fs::remove_file(&link);

        // `std::os::unix` does not exist on non-Unix hosts. Gating the call keeps
        // the build script compilable there even though this path is only
        // expected to run on Linux.
        #[cfg(unix)]
        let linked = std::os::unix::fs::symlink(&lib, &link);
        #[cfg(not(unix))]
        let linked: std::io::Result<()> = Err(std::io::Error::new(
            std::io::ErrorKind::Unsupported,
            "symlink creation requires a Unix host",
        ));

        if let Err(e) = linked {
            println!(
                "cargo:warning=ferrotorch-gpu(cuda): failed to symlink {} -> {}: {e}. \
                 cusolver tests may fail.",
                link.display(),
                lib.display()
            );
            return;
        }
    }

    let compat_dir_str = compat_dir.to_string_lossy();
    println!("cargo:rustc-link-arg=-Wl,-rpath,{compat_dir_str}");
    println!("cargo:rustc-link-search=native={compat_dir_str}");
}
/// Searches known CUDA 12.x install prefixes for a `libcusolver.so.11*` file.
///
/// Prefixes are probed in this order, and under each one the
/// `targets/x86_64-linux/lib` directory is checked before `lib64`:
///
/// 1. `$CUDA_PATH`, when set;
/// 2. `/usr/local/cuda-12.9`, `/usr/local/cuda-12.8`, `/usr/local/cuda-12`;
/// 3. every `/usr/local/cuda-12.*` entry, highest minor version first.
///
/// Returns the canonicalized path of the first match, or the path as found if
/// canonicalization fails. Returns `None` when no probed directory has a match.
fn locate_cusolver_so11() -> Option<PathBuf> {
    static LIB_SUBDIRS: [&str; 2] = ["targets/x86_64-linux/lib", "lib64"];

    let mut roots: Vec<PathBuf> = Vec::new();
    if let Some(p) = std::env::var_os("CUDA_PATH") {
        roots.push(PathBuf::from(p));
    }
    for root in [
        "/usr/local/cuda-12.9",
        "/usr/local/cuda-12.8",
        "/usr/local/cuda-12",
    ] {
        roots.push(PathBuf::from(root));
    }

    if let Ok(entries) = std::fs::read_dir("/usr/local") {
        // `read_dir` yields entries in a platform/filesystem dependent order.
        // Sort the discovered prefixes so the same install wins on every run.
        let mut discovered: Vec<(Option<u32>, PathBuf)> = entries
            .flatten()
            .filter_map(|entry| {
                let name = entry.file_name();
                let name = name.to_string_lossy();
                let suffix = name.strip_prefix("cuda-12.")?;
                let digits_end = suffix
                    .find(|c: char| !c.is_ascii_digit())
                    .unwrap_or(suffix.len());
                let minor = suffix[..digits_end].parse::<u32>().ok();
                Some((minor, entry.path()))
            })
            .collect();
        // Highest minor first; unparsable suffixes (`None`) sort last; ties are
        // broken by path so the order is total.
        discovered.sort_by(|a, b| b.0.cmp(&a.0).then_with(|| a.1.cmp(&b.1)));
        roots.extend(discovered.into_iter().map(|(_, path)| path));
    }

    for root in &roots {
        for sub in LIB_SUBDIRS {
            if let Some(found) = find_so11_in(&root.join(sub)) {
                // Prefer the resolved real file; fall back to the path as found.
                return Some(found.canonicalize().unwrap_or(found));
            }
        }
    }
    None
}
/// Returns the entry of `dir` whose file name starts with `libcusolver.so.11`
/// and is the longest such name, or `None` if `dir` cannot be read or holds no
/// matching entry.
///
/// When several matching names share the maximum length, the one encountered
/// first during directory iteration is kept.
fn find_so11_in(dir: &Path) -> Option<PathBuf> {
    let listing = std::fs::read_dir(dir).ok()?;
    listing
        .flatten()
        .fold(None, |longest: Option<PathBuf>, item| {
            let raw_name = item.file_name();
            let candidate = raw_name.to_string_lossy();
            if !candidate.starts_with("libcusolver.so.11") {
                return longest;
            }
            // Length of the name currently held (0 when it has no file name).
            let held_len = longest
                .as_deref()
                .and_then(|p| p.file_name())
                .map_or(0, |n| n.len());
            if longest.is_some() && held_len >= candidate.len() {
                longest
            } else {
                Some(item.path())
            }
        })
}
}
#[cfg(feature = "cusparselt")]
mod cusparselt {
use std::path::{Path, PathBuf};
/// Generates Rust bindings for cuSPARSELt into `$OUT_DIR/cusparselt_sys.rs`
/// and emits the Cargo directives needed to link against `cusparseLt`.
///
/// Library search paths are emitted for `$CUSPARSELT_LIB_DIR` (when set),
/// `$CUDA_PATH/lib64` (when set and present) and each well-known directory
/// that exists on this host.
///
/// # Panics
///
/// Panics when `cusparseLt.h` cannot be located, when bindgen fails, when
/// `OUT_DIR` is unset, or when the bindings file cannot be written.
pub fn generate() {
    let Some(header) = locate_header() else {
        println!(
            "cargo:warning=cusparselt feature is enabled but `cusparseLt.h` was not found on this host. Set CUSPARSELT_INCLUDE_DIR to the directory containing cusparseLt.h, or install the NVIDIA cuSPARSELt SDK (https://docs.nvidia.com/cuda/cusparselt/getting_started.html). Searched: $CUSPARSELT_INCLUDE_DIR, $CUDA_PATH/include, /usr/local/cuda/include, /usr/local/cuda-12.*/include, /usr/include, /opt/nvidia/cusparselt/include."
        );
        panic!(
            "ferrotorch-gpu: cusparselt feature requires cusparseLt.h but none of the probed locations contained it. See cargo:warning above for resolution."
        );
    };

    if let Ok(dir) = std::env::var("CUSPARSELT_LIB_DIR") {
        println!("cargo:rustc-link-search=native={dir}");
    }
    // Header discovery honours `$CUDA_PATH`, so the library search does too.
    // Otherwise a header found under `$CUDA_PATH/include` could still fail to
    // link against the library that sits next to it.
    if let Some(root) = std::env::var_os("CUDA_PATH") {
        let lib_dir = PathBuf::from(root).join("lib64");
        if lib_dir.exists() {
            println!("cargo:rustc-link-search=native={}", lib_dir.display());
        }
    }
    for candidate in [
        "/usr/local/cuda/lib64",
        "/usr/local/cuda-12.9/lib64",
        "/usr/local/cuda-12.8/lib64",
        "/usr/lib64",
        "/opt/nvidia/cusparselt/lib64",
    ] {
        if Path::new(candidate).exists() {
            println!("cargo:rustc-link-search=native={candidate}");
        }
    }
    println!("cargo:rustc-link-lib=cusparseLt");
    println!("cargo:rerun-if-changed={}", header.display());

    let mut builder = bindgen::Builder::default()
        .header(header.to_string_lossy().into_owned())
        .allowlist_function("cusparseLt.*")
        .allowlist_type("cusparseLt.*")
        .allowlist_var("CUSPARSELT_.*")
        .allowlist_var("CUSPARSE_.*")
        .allowlist_type("cudaDataType.*")
        .allowlist_type("cudaStream_t")
        .allowlist_type("cusparseStatus_t")
        .allowlist_type("cusparseOperation_t")
        .allowlist_type("cusparseComputeType.*")
        .allowlist_type("cusparseOrder_t")
        .default_enum_style(bindgen::EnumVariation::Rust {
            non_exhaustive: false,
        })
        .derive_default(true)
        .derive_debug(true)
        .layout_tests(false)
        .generate_comments(false);

    // Let clang resolve includes relative to the header's own directory first,
    // then relative to the CUDA include directories.
    if let Some(parent) = header.parent() {
        builder = builder.clang_arg(format!("-I{}", parent.display()));
    }
    for path in cuda_include_dirs() {
        builder = builder.clang_arg(format!("-I{}", path.display()));
    }

    let bindings = builder
        .generate()
        .expect("bindgen failed to generate cusparseLt bindings");
    let out_path = PathBuf::from(std::env::var_os("OUT_DIR").expect("OUT_DIR set by cargo"))
        .join("cusparselt_sys.rs");
    bindings
        .write_to_file(&out_path)
        .expect("failed to write cusparselt_sys.rs");
}
/// Finds `cusparseLt.h` by probing, in order: `$CUSPARSELT_INCLUDE_DIR`,
/// `$CUDA_PATH/include`, then a fixed list of system include directories.
///
/// Returns the path of the first candidate that exists, or `None` when none do.
/// Unset environment variables simply contribute no candidate.
fn locate_header() -> Option<PathBuf> {
    let mut search_dirs: Vec<PathBuf> = Vec::with_capacity(7);
    search_dirs.extend(std::env::var_os("CUSPARSELT_INCLUDE_DIR").map(PathBuf::from));
    search_dirs.extend(
        std::env::var_os("CUDA_PATH").map(|root| PathBuf::from(root).join("include")),
    );
    for fixed in [
        "/usr/local/cuda/include",
        "/usr/local/cuda-12.9/include",
        "/usr/local/cuda-12.8/include",
        "/usr/include",
        "/opt/nvidia/cusparselt/include",
    ] {
        search_dirs.push(PathBuf::from(fixed));
    }

    search_dirs
        .into_iter()
        .map(|dir| dir.join("cusparseLt.h"))
        .find(|header| header.exists())
}
/// Collects include directories to pass to clang.
///
/// `$CUDA_PATH/include` comes first when `CUDA_PATH` is set; it is added
/// without checking that it exists. The well-known system directories follow,
/// each included only if it exists on this host.
fn cuda_include_dirs() -> Vec<PathBuf> {
    let mut dirs: Vec<PathBuf> = std::env::var_os("CUDA_PATH")
        .map(|root| PathBuf::from(root).join("include"))
        .into_iter()
        .collect();

    let defaults = [
        "/usr/local/cuda/include",
        "/usr/local/cuda-12.9/include",
        "/usr/local/cuda-12.8/include",
        "/usr/include",
    ];
    dirs.extend(
        defaults
            .iter()
            .map(|dir| PathBuf::from(*dir))
            .filter(|dir| dir.exists()),
    );
    dirs
}
}