use std::path::PathBuf;
const TYPICAL_CUDA_PATH_ENV_VARS: [&str; 5] = [
"CUDA_HOME",
"CUDA_PATH",
"CUDA_ROOT",
"CUDA_TOOLKIT_ROOT_DIR",
"CUDNN_LIB",
];
const SUPPORTED_CUDA_VERSIONS: &[((usize, usize), bool)] = &[
((13, 2), cfg!(feature = "cuda-13020")),
((13, 1), cfg!(feature = "cuda-13010")),
((13, 0), cfg!(feature = "cuda-13000")),
((12, 9), cfg!(feature = "cuda-12090")),
((12, 8), cfg!(feature = "cuda-12080")),
((12, 6), cfg!(feature = "cuda-12060")),
((12, 5), cfg!(feature = "cuda-12050")),
((12, 4), cfg!(feature = "cuda-12040")),
((12, 3), cfg!(feature = "cuda-12030")),
((12, 2), cfg!(feature = "cuda-12020")),
((12, 1), cfg!(feature = "cuda-12010")),
((12, 0), cfg!(feature = "cuda-12000")),
((11, 8), cfg!(feature = "cuda-11080")),
((11, 7), cfg!(feature = "cuda-11070")),
((11, 6), cfg!(feature = "cuda-11060")),
((11, 5), cfg!(feature = "cuda-11050")),
((11, 4), cfg!(feature = "cuda-11040")),
];
const SUPPORTED_NCCL_VERSIONS: &[((u32, u32), bool)] = &[
((2, 30), cfg!(feature = "nccl-02030")),
((2, 29), cfg!(feature = "nccl-02029")),
((2, 28), cfg!(feature = "nccl-02028")),
((2, 27), cfg!(feature = "nccl-02027")),
((2, 26), cfg!(feature = "nccl-02026")),
((2, 25), cfg!(feature = "nccl-02025")),
((2, 24), cfg!(feature = "nccl-02024")),
((2, 22), cfg!(feature = "nccl-02022")),
];
const SUPPORTED_CUDNN_VERSIONS: &[((u32, u32), bool)] = &[
((9, 21), cfg!(feature = "cudnn-09021")),
((9, 10), cfg!(feature = "cudnn-09010")),
((8, 9), cfg!(feature = "cudnn-08009")),
];
const SUPPORTED_CUTENSOR_VERSIONS: &[((u32, u32), bool)] = &[
((2, 6), cfg!(feature = "cutensor-02006")),
((2, 5), cfg!(feature = "cutensor-02005")),
((2, 4), cfg!(feature = "cutensor-02004")),
((2, 3), cfg!(feature = "cutensor-02003")),
];
fn detect_version_from_env() -> Option<(usize, usize)> {
match std::env::var("CUDARC_CUDA_VERSION") {
Ok(version) => {
let version = version.as_str();
for &((major, minor), _) in SUPPORTED_CUDA_VERSIONS.iter() {
if version == format!("{major}0{minor}0") {
return Some((major, minor));
}
}
panic!("Unsupported cuda toolkit version: `$CUDARC_CUDA_VERSION={version}`. Please raise a github issue.")
}
_ => None,
}
}
fn detect_version_from_feature() -> Option<(usize, usize)> {
for &((major, minor), is_feature_set) in SUPPORTED_CUDA_VERSIONS.iter() {
if is_feature_set {
return Some((major, minor));
}
}
None
}
fn main() {
#[cfg(all(
not(feature = "dynamic-linking"),
not(feature = "static-linking"),
not(feature = "dynamic-loading"),
not(feature = "fallback-dynamic-loading")
))]
panic!("None between `dynamic-loading`, `fallback-dynamic-loading`, `dynamic-linking` and `static-linking` features are active, this is a bug");
#[cfg(all(feature = "dynamic-linking", feature = "static-linking"))]
panic!("Both `dynamic-linking` and `static-linking` features are active, this is a bug");
#[cfg(all(feature = "dynamic-loading", feature = "static-linking"))]
panic!("Both `dynamic-loading` and `static-linking` features are active, this is a bug");
#[cfg(all(feature = "dynamic-loading", feature = "dynamic-linking"))]
panic!("Both `dynamic-loading` and `dynamic-linking` features are active, this is a bug");
#[cfg(all(
feature = "fallback-dynamic-loading",
not(any(
feature = "dynamic-loading",
feature = "dynamic-linking",
feature = "static-linking"
))
))]
println!("cargo:rustc-cfg=feature=\"dynamic-loading\"");
println!("cargo:rerun-if-changed=build.rs");
println!("cargo:rerun-if-env-changed=CUDARC_CUDA_VERSION");
TYPICAL_CUDA_PATH_ENV_VARS
.iter()
.for_each(|var| println!("cargo:rerun-if-env-changed={var}"));
let (major, minor): (usize, usize) = if let Some((major, minor)) = detect_version_from_env() {
println!("cargo:rustc-cfg=feature=\"cuda-{major}0{minor}0\"");
(major, minor)
} else if let Some((major, minor)) = detect_version_from_feature() {
(major, minor)
} else {
#[cfg(not(feature = "cuda-version-from-build-system"))]
panic!("Must specify one of the following features: [cuda-version-from-build-system, cuda-13020, cuda-13010, cuda-13000, cuda-12090, cuda-12080, cuda-12060, cuda-12050, cuda-12040, cuda-12030, cuda-12020, cuda-12010, cuda-12000, cuda-11080, cuda-11070, cuda-11060, cuda-11050, cuda-11040]");
#[cfg(feature = "cuda-version-from-build-system")]
{
let (major, minor) = cuda_version_from_build_system();
println!("cargo:rustc-cfg=feature=\"cuda-{major}0{minor}0\"");
(major, minor)
}
};
println!("cargo:rustc-env=CUDA_MAJOR_VERSION={major}");
println!("cargo:rustc-env=CUDA_MINOR_VERSION={minor}");
#[cfg(feature = "nccl-version-from-build-system")]
{
let (major, minor) = nccl_version_from_build_system();
println!("cargo:rustc-cfg=feature=\"nccl-{major:02}{minor:03}\"");
}
#[cfg(feature = "cudnn-version-from-build-system")]
{
let (major, minor) = cudnn_version_from_build_system();
println!("cargo:rustc-cfg=feature=\"cudnn-{major:02}{minor:03}\"");
}
#[cfg(feature = "cutensor-version-from-build-system")]
{
let (major, minor) = cutensor_version_from_build_system();
println!("cargo:rustc-cfg=feature=\"cutensor-{major:02}{minor:03}\"");
}
#[cfg(feature = "dynamic-linking")]
dynamic_linking(major, minor);
#[cfg(feature = "static-linking")]
static_linking(major, minor);
}
#[allow(unused)]
fn cuda_version_from_build_system() -> (usize, usize) {
let output = match std::process::Command::new("nvcc").arg("--version").output() {
Ok(output) if output.status.success() => output,
output_result => {
#[cfg(feature = "fallback-latest")]
{
let latest = SUPPORTED_CUDA_VERSIONS[0].0;
println!("cargo:warning=Failed to run `nvcc --version`. Following `-F fallback-latest`; using CUDA {latest:?}.");
return latest;
}
#[cfg(not(feature = "fallback-latest"))]
panic!("`nvcc --version` failed.\n{output_result:?}");
}
};
let stdout = String::from_utf8_lossy(&output.stdout);
let version_line = stdout.lines().nth(3).unwrap();
let release_section = version_line.split(", ").nth(1).unwrap();
let version_number = release_section.split(' ').nth(1).unwrap();
for &((major, minor), _) in SUPPORTED_CUDA_VERSIONS.iter() {
if version_number == format!("{major}.{minor}") {
return (major, minor);
}
}
panic!("Unsupported cuda toolkit version: `{version_number}`. Please raise a github issue.")
}
#[allow(unused)]
fn dynamic_linking(major: usize, minor: usize) {
for path in link_searches(major, minor) {
println!("cargo:rustc-link-search=native={}", path.display());
}
#[cfg(feature = "driver")]
println!("cargo:rustc-link-lib=dylib=cuda");
#[cfg(any(
feature = "nccl",
feature = "nccl-version-from-build-system",
feature = "nccl-02022",
feature = "nccl-02024",
feature = "nccl-02025",
feature = "nccl-02026",
feature = "nccl-02027",
feature = "nccl-02028",
feature = "nccl-02029",
feature = "nccl-02030",
))]
println!("cargo:rustc-link-lib=dylib=nccl");
#[cfg(feature = "nvrtc")]
println!("cargo:rustc-link-lib=dylib=nvrtc");
#[cfg(feature = "curand")]
println!("cargo:rustc-link-lib=dylib=curand");
#[cfg(feature = "cublas")]
println!("cargo:rustc-link-lib=dylib=cublas");
#[cfg(any(feature = "cublas", feature = "cublaslt"))]
println!("cargo:rustc-link-lib=dylib=cublasLt");
#[cfg(feature = "cupti")]
println!("cargo:rustc-link-lib=dylib=cupti");
#[cfg(feature = "cusparse")]
println!("cargo:rustc-link-lib=dylib=cusparse");
#[cfg(feature = "cusolver")]
println!("cargo:rustc-link-lib=dylib=cusolver");
#[cfg(feature = "cusolvermg")]
println!("cargo:rustc-link-lib=dylib=cusolverMg");
#[cfg(any(
feature = "cudnn",
feature = "cudnn-version-from-build-system",
feature = "cudnn-08009",
feature = "cudnn-09010",
feature = "cudnn-09021",
))]
println!("cargo:rustc-link-lib=dylib=cudnn");
#[cfg(feature = "runtime")]
println!("cargo:rustc-link-lib=dylib=cudart");
#[cfg(feature = "cufile")]
{
println!("cargo:rustc-link-lib=dylib=cufile");
println!("cargo:rustc-link-lib=dylib=cufile_rdma");
}
#[cfg(feature = "nvtx")]
println!("cargo:rustc-link-lib=dylib=nvToolsExt");
#[cfg(any(
feature = "cutensor",
feature = "cutensor-version-from-build-system",
feature = "cutensor-02003",
feature = "cutensor-02004",
feature = "cutensor-02005",
feature = "cutensor-02006",
))]
println!("cargo:rustc-link-lib=dylib=cutensor");
}
#[allow(unused)]
fn static_linking(major: usize, minor: usize) {
for path in link_searches(major, minor) {
println!("cargo:rustc-link-search=native={}", path.display());
}
println!("cargo:rustc-link-lib=static:+whole-archive=stdc++");
#[cfg(any(feature = "driver", feature = "runtime"))]
{
println!("cargo:rustc-link-lib=dylib=cuda");
println!("cargo:rustc-link-lib=static:+whole-archive=cudart_static");
}
#[cfg(any(
feature = "nccl",
feature = "nccl-version-from-build-system",
feature = "nccl-02022",
feature = "nccl-02024",
feature = "nccl-02025",
feature = "nccl-02026",
feature = "nccl-02027",
feature = "nccl-02028",
feature = "nccl-02029",
feature = "nccl-02030",
))]
println!("cargo:rustc-link-lib=static:+whole-archive=nccl_static");
#[cfg(feature = "nvrtc")]
{
println!("cargo:rustc-link-lib=static:+whole-archive=nvrtc_static");
println!("cargo:rustc-link-lib=static:+whole-archive=nvptxcompiler_static");
println!("cargo:rustc-link-lib=static:+whole-archive=nvrtc-builtins_static");
}
#[cfg(any(
feature = "curand",
feature = "cublas",
feature = "cublaslt",
feature = "cusparse",
feature = "cusolver"
))]
println!("cargo:rustc-link-lib=static:+whole-archive=culibos");
#[cfg(feature = "curand")]
println!("cargo:rustc-link-lib=static:+whole-archive=curand_static");
#[cfg(feature = "cublas")]
println!("cargo:rustc-link-lib=static:+whole-archive=cublas_static");
#[cfg(any(feature = "cublas", feature = "cublaslt"))]
println!("cargo:rustc-link-lib=static:+whole-archive=cublasLt_static");
#[cfg(feature = "cupti")]
println!("cargo:rustc-link-lib=static:+whole-archive=cupti_static");
#[cfg(feature = "cusparse")]
println!("cargo:rustc-link-lib=static:+whole-archive=cusparse_static");
#[cfg(feature = "cusolver")]
{
println!("cargo:rustc-link-lib=static:+whole-archive=cusolver_static");
println!("cargo:rustc-link-lib=static:+whole-archive=cusolver_lapack_static");
println!("cargo:rustc-link-lib=static:+whole-archive=cusolver_metis_static");
}
#[cfg(feature = "cusolvermg")]
println!("cargo:rustc-link-lib=dylib=cusolverMg");
#[cfg(any(
feature = "cudnn",
feature = "cudnn-version-from-build-system",
feature = "cudnn-08009",
feature = "cudnn-09010",
feature = "cudnn-09021",
))]
println!("cargo:rustc-link-lib=static:+whole-archive=cudnn");
#[cfg(feature = "cufile")]
{
println!("cargo:rustc-link-lib=static:+whole-archive=cufile_static");
println!("cargo:rustc-link-lib=static:+whole-archive=cufile_rdma_static");
}
#[cfg(feature = "nvtx")]
println!("cargo:rustc-link-lib=dylib=nvToolsExt");
#[cfg(any(
feature = "cutensor",
feature = "cutensor-version-from-build-system",
feature = "cutensor-02003",
feature = "cutensor-02004",
feature = "cutensor-02005",
feature = "cutensor-02006",
))]
println!("cargo:rustc-link-lib=static:+whole-archive=cutensor_static");
}
#[allow(unused)]
fn link_searches(major: usize, minor: usize) -> Vec<PathBuf> {
let env_vars = TYPICAL_CUDA_PATH_ENV_VARS
.iter()
.map(std::env::var)
.filter_map(Result::ok)
.collect::<Vec<_>>();
#[cfg(feature = "dynamic-linking")]
if env_vars.is_empty() && std::env::var("CONDA_PREFIX").is_ok() {
println!("cargo::warning=Detected $CONDA_PREFIX, but no CUDA path was set through one of: {TYPICAL_CUDA_PATH_ENV_VARS:?}. Linking to system CUDA libraries; linker errors may occur. To use CUDA installed via conda please ensure the environment contains all required dependencies (e.g. the \"cuda-driver-dev\") and retry building with CUDA_HOME=$CONDA_PREFIX.")
}
let mut typical_locations: Vec<String> = vec![
"/usr".into(),
"/usr/local/cuda".into(),
"/opt/cuda".into(),
"/usr/lib/cuda".into(),
"C:/Program Files/NVIDIA GPU Computing Toolkit".into(),
"C:/Program Files/NVIDIA".into(),
"C:/CUDA".into(),
];
for &((cudnn_major, cudnn_minor), active) in SUPPORTED_CUDNN_VERSIONS {
if active {
typical_locations.push(format!(
"C:/Program Files/NVIDIA/CUDNN/v{cudnn_major}.{cudnn_minor}"
));
}
}
let possible_locations = if env_vars.is_empty() {
typical_locations
} else {
env_vars
};
let mut candidates = Vec::new();
for root in possible_locations.into_iter().map(Into::<PathBuf>::into) {
candidates.extend(
[
"lib".into(),
"lib/stubs".into(),
"lib/x64".into(),
"lib/Win32".into(),
"lib/x86_64".into(),
"lib/x86_64-linux-gnu".into(),
"lib64".into(),
"lib64/stubs".into(),
"targets/x86_64-linux".into(),
"targets/x86_64-linux/lib".into(),
"targets/x86_64-linux/lib/stubs".into(),
std::format!("lib/{major}.{minor}/x64"),
std::format!("lib/{major}.{minor}/x86_64"),
]
.iter()
.map(|p| root.join(p))
.filter(|p| p.is_dir()),
)
}
candidates
}
#[allow(unused)]
fn nccl_version_from_build_system() -> (u32, u32) {
let Some(header) = find_header("nccl.h") else {
#[cfg(feature = "fallback-latest")]
{
let latest = SUPPORTED_NCCL_VERSIONS[0].0;
println!(
"cargo:warning=nccl.h not found. Following `-F fallback-latest`; using NCCL {}.{}.",
latest.0, latest.1
);
return latest;
}
#[cfg(not(feature = "fallback-latest"))]
panic!("nccl.h not found — install NCCL development headers or set NCCL_HOME");
};
let major = parse_define(&header, "NCCL_MAJOR").expect("NCCL_MAJOR not found in nccl.h");
let minor = parse_define(&header, "NCCL_MINOR").expect("NCCL_MINOR not found in nccl.h");
if !SUPPORTED_NCCL_VERSIONS
.iter()
.any(|&(v, _)| v == (major, minor))
{
#[cfg(feature = "fallback-latest")]
return SUPPORTED_NCCL_VERSIONS[0].0;
#[cfg(not(feature = "fallback-latest"))]
panic!("Unsupported NCCL version {major}.{minor}. Please raise a github issue.");
}
(major, minor)
}
#[allow(unused)]
fn cudnn_version_from_build_system() -> (u32, u32) {
let Some(header) = find_header("cudnn_version.h").or_else(|| find_header("cudnn.h")) else {
#[cfg(feature = "fallback-latest")]
{
let latest = SUPPORTED_CUDNN_VERSIONS[0].0;
println!("cargo:warning=cudnn_version.h not found. Following `-F fallback-latest`; using cuDNN {}.{}.", latest.0, latest.1);
return latest;
}
#[cfg(not(feature = "fallback-latest"))]
panic!("cudnn_version.h not found — install cuDNN development headers");
};
let major = parse_define(&header, "CUDNN_MAJOR").expect("CUDNN_MAJOR not found");
let minor = parse_define(&header, "CUDNN_MINOR").expect("CUDNN_MINOR not found");
if !SUPPORTED_CUDNN_VERSIONS
.iter()
.any(|&(v, _)| v == (major, minor))
{
#[cfg(feature = "fallback-latest")]
return SUPPORTED_CUDNN_VERSIONS[0].0;
#[cfg(not(feature = "fallback-latest"))]
panic!("Unsupported cuDNN version {major}.{minor}). Please raise a github issue.");
}
(major, minor)
}
#[allow(unused)]
fn cutensor_version_from_build_system() -> (u32, u32) {
let Some(header) = find_header("cutensor.h") else {
#[cfg(feature = "fallback-latest")]
{
let latest = SUPPORTED_CUTENSOR_VERSIONS[0].0;
println!("cargo:warning=cutensor.h not found. Following `-F fallback-latest`; using cuTENSOR {}.{}.", latest.0, latest.1);
return latest;
}
#[cfg(not(feature = "fallback-latest"))]
panic!("cutensor.h not found — install cuTENSOR development headers");
};
let major = parse_define(&header, "CUTENSOR_MAJOR").expect("CUTENSOR_MAJOR not found");
let minor = parse_define(&header, "CUTENSOR_MINOR").expect("CUTENSOR_MINOR not found");
if !SUPPORTED_CUTENSOR_VERSIONS
.iter()
.any(|&(v, _)| v == (major, minor))
{
#[cfg(feature = "fallback-latest")]
return SUPPORTED_CUTENSOR_VERSIONS[0].0;
#[cfg(not(feature = "fallback-latest"))]
panic!("Unsupported cuTENSOR version {major}.{minor}). Please raise a github issue.");
}
(major, minor)
}
fn find_header(filename: &str) -> Option<std::path::PathBuf> {
let mut search_dirs: Vec<std::path::PathBuf> = Vec::new();
for var in &["CUDA_PATH", "CUDA_HOME", "CUDNN_PATH", "CUTENSOR_PATH"] {
if let Ok(val) = std::env::var(var) {
search_dirs.push(std::path::PathBuf::from(val).join("include"));
}
}
for &((major, minor), active) in SUPPORTED_CUDNN_VERSIONS {
if active {
search_dirs.push(std::path::PathBuf::from(format!(
"C:/Program Files/NVIDIA/CUDNN/v{major}.{minor}/include"
)));
}
}
for dir in &[
"/usr/local/cuda/include",
"/usr/local/include",
"/opt/cuda/include",
"/usr/include",
] {
search_dirs.push(std::path::Path::new(dir).to_path_buf());
}
for dir in search_dirs {
let path = dir.join(filename);
if path.exists() {
return Some(path);
}
}
None
}
fn parse_define(header: &std::path::Path, name: &str) -> Option<u32> {
let content = std::fs::read_to_string(header).ok()?;
for line in content.lines().map(str::trim) {
if let Some(name_value) = line.strip_prefix("#define").map(str::trim) {
if let Some(value) = name_value.strip_prefix(name).map(str::trim) {
if let Ok(v) = value.parse() {
return Some(v);
}
}
}
}
None
}