Skip to main content

cuda_config/
lib.rs

1use glob::glob;
2use std::{env, path::PathBuf};
3
/// Returns the directories listed in the `CUDA_LIBRARY_PATH` environment variable.
///
/// The location of the libcuda, libcudart, and libcublas can be hardcoded with the
/// `CUDA_LIBRARY_PATH` environment variable. The value is split with the host platform's
/// path-list separator (`;` on Windows, `:` elsewhere) via [`env::split_paths`].
///
/// Empty entries (from an empty variable or a leading, doubled or trailing separator) are
/// skipped. Otherwise an empty `PathBuf` would be produced, which later joins resolve
/// relative to the current directory.
///
/// Returns an empty `Vec` when the variable is unset.
pub fn read_env() -> Vec<PathBuf> {
    // `var_os` (rather than `var`) keeps non-UTF-8 paths instead of silently ignoring them.
    match env::var_os("CUDA_LIBRARY_PATH") {
        Some(value) => env::split_paths(&value)
            .filter(|p| !p.as_os_str().is_empty())
            .collect(),
        None => Vec::new(),
    }
}
18
19pub fn find_cuda() -> Vec<PathBuf> {
20    let mut candidates = read_env();
21    candidates.push(PathBuf::from("/opt/cuda"));
22    candidates.push(PathBuf::from("/usr/local/cuda"));
23    for e in glob("/usr/local/cuda-*").unwrap() {
24        if let Ok(path) = e {
25            candidates.push(path)
26        }
27    }
28
29    let mut valid_paths = vec![];
30    for base in &candidates {
31        let lib = PathBuf::from(base).join("lib64");
32        if lib.is_dir() {
33            valid_paths.push(lib.clone());
34            valid_paths.push(lib.join("stubs"));
35        }
36        let base = base.join("targets/x86_64-linux");
37        let header = base.join("include/cuda.h");
38        if header.is_file() {
39            valid_paths.push(base.join("lib"));
40            valid_paths.push(base.join("lib/stubs"));
41            continue;
42        }
43    }
44    eprintln!("Found CUDA paths: {:?}", valid_paths);
45    valid_paths
46}
47
48pub fn find_cuda_windows() -> PathBuf {
49    let paths = read_env();
50    if !paths.is_empty() {
51        return paths[0].clone();
52    }
53
54    if let Ok(path) = env::var("CUDA_PATH") {
55        // If CUDA_LIBRARY_PATH is not found, then CUDA_PATH will be used when building for
56        // Windows to locate the Cuda installation. Cuda installs the full Cuda SDK for 64-bit,
57        // but only a limited set of libraries for 32-bit. Namely, it does not include cublas in
58        // 32-bit, which cuda-sys requires.
59
60        // 'path' points to the base of the CUDA Installation. The lib directory is a
61        // sub-directory.
62        let path = PathBuf::from(path);
63
64        // To do this the right way, we check to see which target we're building for.
65        let target = env::var("TARGET")
66            .expect("cargo did not set the TARGET environment variable as required.");
67
68        // Targets use '-' separators. e.g. x86_64-pc-windows-msvc
69        let target_components: Vec<_> = target.as_str().split("-").collect();
70
71        // We check that we're building for Windows. This code assumes that the layout in
72        // CUDA_PATH matches Windows.
73        if target_components[2] != "windows" {
74            panic!(
75                "The CUDA_PATH variable is only used by cuda-sys on Windows. Your target is {}.",
76                target
77            );
78        }
79
80        // Sanity check that the second component of 'target' is "pc"
81        debug_assert_eq!(
82            "pc", target_components[1],
83            "Expected a Windows target to have the second component be 'pc'. Target: {}",
84            target
85        );
86
87        // x86_64 should use the libs in the "lib/x64" directory. If we ever support i686 (which
88        // does not ship with cublas support), its libraries are in "lib/Win32".
89        let lib_path = match target_components[0] {
90            "x86_64" => "x64",
91            "i686" => {
92                // lib path would be "Win32" if we support i686. "cublas" is not present in the
93                // 32-bit install.
94                panic!("Rust cuda-sys does not currently support 32-bit Windows.");
95            }
96            _ => {
97                panic!("Rust cuda-sys only supports the x86_64 Windows architecture.");
98            }
99        };
100
101        // i.e. $CUDA_PATH/lib/x64
102        return path.join("lib").join(lib_path);
103    }
104
105    // No idea where to look for CUDA
106    panic!("CUDA cannot find");
107}