baracuda_core/
platform.rs1use std::path::PathBuf;
4
/// Coarse operating-system family that decides how CUDA library files are
/// named and where they are searched for.
///
/// `#[non_exhaustive]` lets further platforms be added later without a
/// breaking change for downstream `match` expressions. `Hash` is derived so
/// the family can be used as a key in hashed collections.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[non_exhaustive]
pub enum OsFamily {
    /// Linux (WSL guests are reported as Linux as well).
    Linux,
    /// Microsoft Windows.
    Windows,
    /// Any other target; no library names or search paths are provided for it.
    Unsupported,
}
14
15pub const fn os_family() -> OsFamily {
17 if cfg!(target_os = "linux") {
18 OsFamily::Linux
19 } else if cfg!(target_os = "windows") {
20 OsFamily::Windows
21 } else {
22 OsFamily::Unsupported
23 }
24}
25
26pub fn is_wsl2() -> bool {
31 if !matches!(os_family(), OsFamily::Linux) {
32 return false;
33 }
34 std::fs::read_to_string("/proc/version")
35 .map(|s| {
36 let s = s.to_ascii_lowercase();
37 s.contains("microsoft") || s.contains("wsl")
38 })
39 .unwrap_or(false)
40}
41
42pub fn library_search_paths() -> Vec<PathBuf> {
55 let mut paths = Vec::new();
56
57 for var in [
58 "CUDA_PATH",
59 "CUDA_HOME",
60 "CUDA_ROOT",
61 "CUDA_TOOLKIT_ROOT_DIR",
62 ] {
63 if let Ok(raw) = std::env::var(var) {
64 let base = PathBuf::from(raw);
65 push_os_subdirs(&base, &mut paths);
66 }
67 }
68
69 match os_family() {
70 OsFamily::Linux => {
71 for base in ["/usr/local/cuda", "/opt/cuda"] {
72 push_os_subdirs(&PathBuf::from(base), &mut paths);
73 }
74 paths.push(PathBuf::from("/usr/local/cuda/compat"));
75 paths.push(PathBuf::from("/usr/lib/x86_64-linux-gnu"));
76 paths.push(PathBuf::from("/usr/lib/aarch64-linux-gnu"));
77 paths.push(PathBuf::from("/usr/lib/wsl/lib"));
78 paths.push(PathBuf::from("/lib/wsl/lib"));
79 }
80 OsFamily::Windows => {
81 if let Ok(pf) = std::env::var("ProgramFiles") {
82 let toolkit = PathBuf::from(pf)
83 .join("NVIDIA GPU Computing Toolkit")
84 .join("CUDA");
85 paths.push(toolkit);
88 }
89 for var in [
90 "CUDA_PATH_V13_0",
91 "CUDA_PATH_V12_8",
92 "CUDA_PATH_V12_6",
93 "CUDA_PATH_V12_3",
94 "CUDA_PATH_V12_0",
95 "CUDA_PATH_V11_8",
96 "CUDA_PATH_V11_4",
97 ] {
98 if let Ok(raw) = std::env::var(var) {
99 let base = PathBuf::from(raw);
100 push_os_subdirs(&base, &mut paths);
101 }
102 }
103 }
104 OsFamily::Unsupported => {}
105 }
106
107 paths
108}
109
110fn push_os_subdirs(base: &std::path::Path, out: &mut Vec<PathBuf>) {
111 match os_family() {
112 OsFamily::Linux => {
113 out.push(base.join("lib64"));
114 out.push(base.join("lib"));
115 out.push(base.join("targets/x86_64-linux/lib"));
116 out.push(base.join("lib/stubs"));
117 out.push(base.join("lib64/stubs"));
118 }
119 OsFamily::Windows => {
120 out.push(base.join("bin"));
121 out.push(base.join("lib").join("x64"));
122 }
123 OsFamily::Unsupported => {}
124 }
125}
126
/// File names of the CUDA driver library to try loading, in preference order.
///
/// Linux prefers the SONAME `libcuda.so.1` over the unversioned
/// `libcuda.so`; Windows uses `nvcuda.dll`; other targets get an empty list.
/// Chosen at compile time, so usable in `const` contexts.
pub const fn driver_library_candidates() -> &'static [&'static str] {
    #[cfg(target_os = "linux")]
    const NAMES: &[&str] = &["libcuda.so.1", "libcuda.so"];
    #[cfg(target_os = "windows")]
    const NAMES: &[&str] = &["nvcuda.dll"];
    #[cfg(not(any(target_os = "linux", target_os = "windows")))]
    const NAMES: &[&str] = &[];
    NAMES
}
142
/// File names of the CUDA runtime (`cudart`) library to try loading, newest
/// major version first.
///
/// Linux lists `libcudart.so.13`, `.12`, `.11.0`, then the unversioned
/// `libcudart.so`; Windows lists `cudart64_13.dll`, `_12`, `_110`, `_101`;
/// other targets get an empty list. Chosen at compile time.
pub const fn runtime_library_candidates() -> &'static [&'static str] {
    #[cfg(target_os = "linux")]
    const NAMES: &[&str] = &[
        "libcudart.so.13",
        "libcudart.so.12",
        "libcudart.so.11.0",
        "libcudart.so",
    ];
    #[cfg(target_os = "windows")]
    const NAMES: &[&str] = &[
        "cudart64_13.dll",
        "cudart64_12.dll",
        "cudart64_110.dll",
        "cudart64_101.dll",
    ];
    #[cfg(not(any(target_os = "linux", target_os = "windows")))]
    const NAMES: &[&str] = &[];
    NAMES
}
168
169pub fn versioned_library_candidates(name: &str, preferred_majors: &[&str]) -> Vec<String> {
178 let mut out = Vec::with_capacity(preferred_majors.len() + 2);
179 match os_family() {
180 OsFamily::Linux => {
181 for major in preferred_majors {
182 out.push(format!("lib{name}.so.{major}"));
183 }
184 out.push(format!("lib{name}.so"));
185 }
186 OsFamily::Windows => {
187 for major in preferred_majors {
188 let numeric = major.split('.').next().unwrap_or(major);
191 out.push(format!("{name}64_{numeric}.dll"));
192 }
193 out.push(format!("{name}64.dll"));
194 }
195 OsFamily::Unsupported => {}
196 }
197 out
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn os_family_matches_cfg() {
        let f = os_family();
        if cfg!(target_os = "linux") {
            assert_eq!(f, OsFamily::Linux);
        } else if cfg!(target_os = "windows") {
            assert_eq!(f, OsFamily::Windows);
        } else {
            assert_eq!(f, OsFamily::Unsupported);
        }
    }

    #[test]
    fn driver_candidates_nonempty_on_supported() {
        if matches!(os_family(), OsFamily::Linux | OsFamily::Windows) {
            assert!(!driver_library_candidates().is_empty());
            assert!(!runtime_library_candidates().is_empty());
        }
    }

    #[test]
    fn versioned_candidates_linux_shape() {
        if matches!(os_family(), OsFamily::Linux) {
            let v = versioned_library_candidates("cublas", &["12", "11.0"]);
            assert!(v.iter().any(|s| s == "libcublas.so.12"));
            assert!(v.iter().any(|s| s == "libcublas.so"));
        }
    }

    #[test]
    fn versioned_candidates_windows_shape() {
        if matches!(os_family(), OsFamily::Windows) {
            let v = versioned_library_candidates("cublas", &["12", "11"]);
            assert!(v.iter().any(|s| s == "cublas64_12.dll"));
        }
    }

    #[test]
    fn search_paths_include_env() {
        // Remember any real CUDA_PATH so the developer's environment is restored
        // for the rest of the test process instead of being silently removed.
        let previous = std::env::var_os("CUDA_PATH");
        std::env::set_var("CUDA_PATH", "/tmp/test-cuda-path");
        let paths = library_search_paths();
        match previous {
            Some(value) => std::env::set_var("CUDA_PATH", value),
            None => std::env::remove_var("CUDA_PATH"),
        }
        // Unsupported targets expand a toolkit root into no sub-directories at all.
        if !matches!(os_family(), OsFamily::Unsupported) {
            let has = paths.iter().any(|p| p.starts_with("/tmp/test-cuda-path"));
            assert!(
                has,
                "CUDA_PATH environment should show up in the search list"
            );
        }
    }
}