Skip to main content

baracuda_core/
platform.rs

1//! Platform detection and default library search paths.
2
3use std::path::PathBuf;
4
/// Coarse host operating-system category, as far as the loader cares.
///
/// Only Linux and Windows hosts are loadable; every other target maps to
/// [`OsFamily::Unsupported`].
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum OsFamily {
    /// Built for `target_os = "linux"` (WSL guests included).
    Linux,
    /// Built for `target_os = "windows"`.
    Windows,
    /// Any other target (e.g. macOS). baracuda refuses to load on these.
    Unsupported,
}
14
15/// Detect the current OS family at runtime (cheap; branches on `cfg!`).
16pub const fn os_family() -> OsFamily {
17    if cfg!(target_os = "linux") {
18        OsFamily::Linux
19    } else if cfg!(target_os = "windows") {
20        OsFamily::Windows
21    } else {
22        OsFamily::Unsupported
23    }
24}
25
26/// `true` when running under the Windows Subsystem for Linux (WSL2).
27///
28/// Cheap probe: reads `/proc/version`. Safe to call on non-Linux hosts; it
29/// simply returns `false`.
30pub fn is_wsl2() -> bool {
31    if !matches!(os_family(), OsFamily::Linux) {
32        return false;
33    }
34    std::fs::read_to_string("/proc/version")
35        .map(|s| {
36            let s = s.to_ascii_lowercase();
37            s.contains("microsoft") || s.contains("wsl")
38        })
39        .unwrap_or(false)
40}
41
42/// Directories the loader will search for NVIDIA shared libraries.
43///
44/// Order (first hit wins):
45///
46/// 1. `$CUDA_PATH`, `$CUDA_HOME`, `$CUDA_ROOT`, `$CUDA_TOOLKIT_ROOT_DIR`
47///    (each joined with the OS-appropriate `lib`/`bin` subdirectory).
48/// 2. OS defaults: `/usr/local/cuda/*`, `/usr/local/cuda/compat`,
49///    `/usr/lib/x86_64-linux-gnu`, `/usr/lib/wsl/lib` on Linux;
50///    `%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v*\bin` on Windows.
51///
52/// The returned paths are candidates only; the loader silently skips any
53/// that don't exist.
54pub fn library_search_paths() -> Vec<PathBuf> {
55    let mut paths = Vec::new();
56
57    for var in [
58        "CUDA_PATH",
59        "CUDA_HOME",
60        "CUDA_ROOT",
61        "CUDA_TOOLKIT_ROOT_DIR",
62    ] {
63        if let Ok(raw) = std::env::var(var) {
64            let base = PathBuf::from(raw);
65            push_os_subdirs(&base, &mut paths);
66        }
67    }
68
69    match os_family() {
70        OsFamily::Linux => {
71            for base in ["/usr/local/cuda", "/opt/cuda"] {
72                push_os_subdirs(&PathBuf::from(base), &mut paths);
73            }
74            paths.push(PathBuf::from("/usr/local/cuda/compat"));
75            paths.push(PathBuf::from("/usr/lib/x86_64-linux-gnu"));
76            paths.push(PathBuf::from("/usr/lib/aarch64-linux-gnu"));
77            paths.push(PathBuf::from("/usr/lib/wsl/lib"));
78            paths.push(PathBuf::from("/lib/wsl/lib"));
79        }
80        OsFamily::Windows => {
81            if let Ok(pf) = std::env::var("ProgramFiles") {
82                let toolkit = PathBuf::from(pf)
83                    .join("NVIDIA GPU Computing Toolkit")
84                    .join("CUDA");
85                // We can't glob at const time; caller pairs this with specific
86                // CUDA_PATH_V12_6 env vars. Keep the prefix as a hint.
87                paths.push(toolkit);
88            }
89            for var in [
90                "CUDA_PATH_V13_0",
91                "CUDA_PATH_V12_8",
92                "CUDA_PATH_V12_6",
93                "CUDA_PATH_V12_3",
94                "CUDA_PATH_V12_0",
95                "CUDA_PATH_V11_8",
96                "CUDA_PATH_V11_4",
97            ] {
98                if let Ok(raw) = std::env::var(var) {
99                    let base = PathBuf::from(raw);
100                    push_os_subdirs(&base, &mut paths);
101                }
102            }
103        }
104        OsFamily::Unsupported => {}
105    }
106
107    paths
108}
109
110fn push_os_subdirs(base: &std::path::Path, out: &mut Vec<PathBuf>) {
111    match os_family() {
112        OsFamily::Linux => {
113            out.push(base.join("lib64"));
114            out.push(base.join("lib"));
115            out.push(base.join("targets/x86_64-linux/lib"));
116            out.push(base.join("lib/stubs"));
117            out.push(base.join("lib64/stubs"));
118        }
119        OsFamily::Windows => {
120            out.push(base.join("bin"));
121            out.push(base.join("lib").join("x64"));
122        }
123        OsFamily::Unsupported => {}
124    }
125}
126
/// Filenames under which the CUDA driver library (`libcuda`) is probed, most
/// preferred first.
///
/// Empty on targets other than Linux and Windows.
pub const fn driver_library_candidates() -> &'static [&'static str] {
    #[cfg(target_os = "linux")]
    const DRIVER: &[&str] = &["libcuda.so.1", "libcuda.so"];
    #[cfg(target_os = "windows")]
    const DRIVER: &[&str] = &["nvcuda.dll"];
    #[cfg(not(any(target_os = "linux", target_os = "windows")))]
    const DRIVER: &[&str] = &[];

    DRIVER
}
142
/// Filenames under which the CUDA runtime library (`libcudart`) is probed,
/// newest major version first.
///
/// Empty on targets other than Linux and Windows.
pub const fn runtime_library_candidates() -> &'static [&'static str] {
    #[cfg(target_os = "linux")]
    const CUDART: &[&str] = &[
        "libcudart.so.13",
        "libcudart.so.12",
        "libcudart.so.11.0",
        "libcudart.so",
    ];
    #[cfg(target_os = "windows")]
    const CUDART: &[&str] = &[
        "cudart64_13.dll",
        "cudart64_12.dll",
        "cudart64_110.dll",
        "cudart64_101.dll",
    ];
    #[cfg(not(any(target_os = "linux", target_os = "windows")))]
    const CUDART: &[&str] = &[];

    CUDART
}
168
169/// Build the list of probe filenames for a generic CUDA library, across the
170/// major versions baracuda targets. For example,
171/// `versioned_library_candidates("cublas", "12", "11.0")` yields
172/// `libcublas.so.12, libcublas.so.11.0, libcublas.so` on Linux and
173/// `cublas64_12.dll, cublas64_11.dll, cublas64_110.dll` on Windows.
174///
175/// Individual `-sys` crates provide their own curated candidate lists;
176/// this helper is for opportunistic probing during development.
177pub fn versioned_library_candidates(name: &str, preferred_majors: &[&str]) -> Vec<String> {
178    let mut out = Vec::with_capacity(preferred_majors.len() + 2);
179    match os_family() {
180        OsFamily::Linux => {
181            for major in preferred_majors {
182                out.push(format!("lib{name}.so.{major}"));
183            }
184            out.push(format!("lib{name}.so"));
185        }
186        OsFamily::Windows => {
187            for major in preferred_majors {
188                // Windows CUDA DLL convention: cublas64_12.dll, cublas64_11.dll, ...
189                // Keep only leading digit(s) before a dot.
190                let numeric = major.split('.').next().unwrap_or(major);
191                out.push(format!("{name}64_{numeric}.dll"));
192            }
193            out.push(format!("{name}64.dll"));
194        }
195        OsFamily::Unsupported => {}
196    }
197    out
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;

    /// Temporarily overrides one environment variable.
    ///
    /// The previous value (or its absence) is restored when the guard drops,
    /// including during a panic, so a developer's real setting survives the
    /// test run.
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<OsString>,
    }

    impl EnvVarGuard {
        fn set(key: &'static str, value: &str) -> Self {
            let previous = std::env::var_os(key);
            std::env::set_var(key, value);
            EnvVarGuard { key, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => std::env::set_var(self.key, value),
                None => std::env::remove_var(self.key),
            }
        }
    }

    #[test]
    fn os_family_matches_cfg() {
        let f = os_family();
        if cfg!(target_os = "linux") {
            assert_eq!(f, OsFamily::Linux);
        } else if cfg!(target_os = "windows") {
            assert_eq!(f, OsFamily::Windows);
        } else {
            assert_eq!(f, OsFamily::Unsupported);
        }
    }

    #[test]
    fn driver_candidates_nonempty_on_supported() {
        if matches!(os_family(), OsFamily::Linux | OsFamily::Windows) {
            assert!(!driver_library_candidates().is_empty());
            assert!(!runtime_library_candidates().is_empty());
        }
    }

    #[test]
    fn versioned_candidates_linux_shape() {
        if matches!(os_family(), OsFamily::Linux) {
            let v = versioned_library_candidates("cublas", &["12", "11.0"]);
            assert_eq!(v, ["libcublas.so.12", "libcublas.so.11.0", "libcublas.so"]);
        }
    }

    #[test]
    fn versioned_candidates_windows_shape() {
        if matches!(os_family(), OsFamily::Windows) {
            let v = versioned_library_candidates("cublas", &["12", "11"]);
            assert_eq!(v, ["cublas64_12.dll", "cublas64_11.dll", "cublas64.dll"]);
        }
    }

    #[test]
    fn search_paths_include_env() {
        let paths = {
            let _guard = EnvVarGuard::set("CUDA_PATH", "/tmp/test-cuda-path");
            library_search_paths()
        };
        // Unsupported hosts add no OS subdirectories, so nothing env-derived
        // is expected there.
        if matches!(os_family(), OsFamily::Linux | OsFamily::Windows) {
            let has = paths.iter().any(|p| p.starts_with("/tmp/test-cuda-path"));
            assert!(
                has,
                "CUDA_PATH environment should show up in the search list"
            );
        }
    }
}