1use libloading::Library;
12#[cfg(target_os = "linux")]
13use libloading::os::unix::{Library as UnixLibrary, RTLD_GLOBAL, RTLD_NOW};
14use ndarray::{Array2, ArrayBase, Data, Ix2};
15use std::borrow::Cow;
16use std::path::Path;
17#[cfg(target_os = "linux")]
18use std::path::PathBuf;
19use std::sync::OnceLock;
20
21use super::gpu_error::GpuError;
22
23pub type CuResult = i32;
24#[inline]
35pub fn check_cuda(result: CuResult, name: &str) -> Result<(), GpuError> {
36 if result == 0 {
37 Ok(())
38 } else {
39 Err(GpuError::DriverCallFailed {
40 reason: format!("{name} failed with CUDA driver error {result}"),
41 })
42 }
43}
44
45#[must_use]
54pub fn cuda_driver_library_present() -> bool {
55 load_library_names(&cuda_library_candidate_names()).is_ok()
56}
57
58fn load_library_names(candidates: &[String]) -> Result<Library, GpuError> {
59 for candidate in candidates {
60 if let Ok(library) = unsafe { Library::new(candidate) } {
64 return Ok(library);
65 }
66 }
67 Err(GpuError::DriverLibraryUnavailable {
68 reason: format!("could not load any of: {}", candidates.join(", ")),
69 })
70}
71
72fn load_static_cuda_driver_library() -> Result<&'static Library, GpuError> {
73 static LIBRARY: OnceLock<Result<Library, GpuError>> = OnceLock::new();
74 LIBRARY
75 .get_or_init(|| load_library_names(&cuda_library_candidate_names()))
76 .as_ref()
77 .map_err(Clone::clone)
78}
79
80pub fn preload_cuda_driver() -> Result<(), String> {
81 static PRELOAD: OnceLock<Result<(), String>> = OnceLock::new();
82 PRELOAD
83 .get_or_init(|| {
84 load_static_cuda_driver_library()
85 .map(|_| ())
86 .map_err(|err| err.to_string())
87 })
88 .clone()
89}
90
91#[cfg(target_os = "linux")]
92fn preload_cuda_userspace_libraries() -> Result<(), String> {
93 static PRELOAD: OnceLock<Result<Vec<UnixLibrary>, String>> = OnceLock::new();
94 PRELOAD
95 .get_or_init(|| {
96 let paths = cuda_userspace_preload_paths();
97 if paths.is_empty() {
98 return Ok(Vec::new());
99 }
100 let mut loaded = Vec::new();
101 for path in paths {
102 match unsafe { UnixLibrary::open(Some(&path), RTLD_NOW | RTLD_GLOBAL) } {
108 Ok(library) => loaded.push(library),
109 Err(err) => {
110 return Err(format!(
111 "could not preload CUDA userspace library {}: {err}",
112 path.display()
113 ));
114 }
115 }
116 }
117 Ok(loaded)
118 })
119 .as_ref()
120 .map(|_| ())
121 .map_err(Clone::clone)
122}
123
124#[must_use]
139pub fn cuda_compute_library_present(stem: &str) -> bool {
140 #[cfg(target_os = "linux")]
141 {
142 if preload_cuda_userspace_libraries().is_err() {
143 return false;
144 }
145 }
146 static PROBED: OnceLock<std::sync::Mutex<std::collections::HashMap<String, bool>>> =
154 OnceLock::new();
155 static KEEP_ALIVE: OnceLock<std::sync::Mutex<Vec<Library>>> = OnceLock::new();
156 let probed = PROBED.get_or_init(|| std::sync::Mutex::new(std::collections::HashMap::new()));
157 if let Ok(cache) = probed.lock() {
158 if let Some(&present) = cache.get(stem) {
159 return present;
160 }
161 }
162 let present = match load_library_names(&cuda_compute_library_candidate_names(stem)) {
163 Ok(library) => {
164 if let Ok(mut keep) = KEEP_ALIVE
165 .get_or_init(|| std::sync::Mutex::new(Vec::new()))
166 .lock()
167 {
168 keep.push(library);
169 }
170 true
171 }
172 Err(_) => false,
173 };
174 if let Ok(mut cache) = probed.lock() {
175 cache.insert(stem.to_string(), present);
176 }
177 present
178}
179
180#[cfg(target_os = "linux")]
181fn cuda_userspace_preload_paths() -> Vec<PathBuf> {
182 let system_dirs = cuda_system_library_dirs();
183 for dir in &system_dirs {
184 if let Some(stack) = complete_system_cuda_stack(dir) {
185 return dedup_paths(stack);
186 }
187 if let Some(stack) = system_cuda_stack_with_packaged_nvjitlink(dir) {
188 return dedup_paths(stack);
189 }
190 }
191 for root in nvidia_package_roots() {
192 if let Some(stack) = complete_nvidia_cuda_stack(&root) {
193 return dedup_paths(stack);
194 }
195 }
196 Vec::new()
197}
198
199fn cuda_compute_library_candidate_names(stem: &str) -> Vec<String> {
200 let base = format!("lib{stem}");
201 let mut out: Vec<String> = Vec::new();
202 out.push(format!("{base}.so"));
205 out.push(format!("{base}.so.1"));
206 for major in (9..=13).rev() {
209 out.push(format!("{base}.so.{major}"));
210 }
211 #[cfg(target_os = "linux")]
212 {
213 for dir in cuda_system_library_dirs() {
214 out.push(format!("{dir}/{base}.so"));
215 for major in (9..=13).rev() {
216 out.push(format!("{dir}/{base}.so.{major}"));
217 }
218 append_versioned_linux_so_candidates(&mut out, Path::new(dir), &base);
219 }
220 for root in nvidia_package_roots() {
221 let lib_dir = root.join(nvidia_component_for_stem(stem)).join("lib");
222 out.push(format!("{}/{}.so", lib_dir.display(), base));
223 for major in (9..=13).rev() {
224 out.push(format!("{}/{}.so.{major}", lib_dir.display(), base));
225 }
226 append_versioned_linux_so_candidates(&mut out, &lib_dir, &base);
227 }
228 }
229 out
230}
231
/// Returns the fixed system directories searched for CUDA userspace
/// libraries, in search order.
#[cfg(target_os = "linux")]
fn cuda_system_library_dirs() -> Vec<&'static str> {
    const DIRS: [&str; 7] = [
        "/usr/local/cuda/lib64",
        "/usr/local/cuda/lib",
        "/usr/local/cuda/targets/x86_64-linux/lib",
        "/usr/lib/x86_64-linux-gnu",
        "/usr/lib64",
        "/usr/lib/wsl/lib",
        "/opt/cuda/lib64",
    ];
    DIRS.to_vec()
}
244
245#[cfg(target_os = "linux")]
246fn complete_system_cuda_stack(dir: &str) -> Option<Vec<PathBuf>> {
247 let dir = Path::new(dir);
248 let stack = vec![
249 first_existing(dir, &["libcudart.so.12", "libcudart.so"])?,
250 first_existing(dir, &["libnvJitLink.so.12", "libnvJitLink.so"])?,
251 first_existing(dir, &["libcublasLt.so.12", "libcublasLt.so"])?,
252 first_existing(dir, &["libcublas.so.12", "libcublas.so"])?,
253 first_existing(dir, &["libcusparse.so.12", "libcusparse.so"])?,
254 first_existing(
255 dir,
256 &["libcusolver.so.12", "libcusolver.so.11", "libcusolver.so"],
257 )?,
258 ];
259 Some(stack)
260}
261
262#[cfg(target_os = "linux")]
263fn system_cuda_stack_with_packaged_nvjitlink(dir: &str) -> Option<Vec<PathBuf>> {
264 let dir = Path::new(dir);
265 let nvjitlink = packaged_nvjitlink_library()?;
266 let stack = vec![
267 first_existing(dir, &["libcudart.so.12", "libcudart.so"])?,
268 nvjitlink,
269 first_existing(dir, &["libcublasLt.so.12", "libcublasLt.so"])?,
270 first_existing(dir, &["libcublas.so.12", "libcublas.so"])?,
271 first_existing(dir, &["libcusparse.so.12", "libcusparse.so"])?,
272 first_existing(
273 dir,
274 &["libcusolver.so.12", "libcusolver.so.11", "libcusolver.so"],
275 )?,
276 ];
277 Some(stack)
278}
279
280#[cfg(target_os = "linux")]
281fn complete_nvidia_cuda_stack(root: &Path) -> Option<Vec<PathBuf>> {
282 let stack = vec![
283 first_existing(
284 &root.join("cuda_runtime").join("lib"),
285 &["libcudart.so.12", "libcudart.so"],
286 )?,
287 first_existing(
288 &root.join("nvjitlink").join("lib"),
289 &["libnvJitLink.so.12", "libnvJitLink.so"],
290 )?,
291 first_existing(
292 &root.join("cublas").join("lib"),
293 &["libcublasLt.so.12", "libcublasLt.so"],
294 )?,
295 first_existing(
296 &root.join("cublas").join("lib"),
297 &["libcublas.so.12", "libcublas.so"],
298 )?,
299 first_existing(
300 &root.join("cusparse").join("lib"),
301 &["libcusparse.so.12", "libcusparse.so"],
302 )?,
303 first_existing(
304 &root.join("cusolver").join("lib"),
305 &["libcusolver.so.12", "libcusolver.so.11", "libcusolver.so"],
306 )?,
307 ];
308 Some(stack)
309}
310
311#[cfg(target_os = "linux")]
312fn packaged_nvjitlink_library() -> Option<PathBuf> {
313 for root in nvidia_package_roots() {
314 let lib_dir = root.join("nvjitlink").join("lib");
315 if let Some(path) = first_existing(&lib_dir, &["libnvJitLink.so.12", "libnvJitLink.so"]) {
316 return Some(path);
317 }
318 }
319 None
320}
321
/// Maps a library stem to the NVIDIA package component directory that ships
/// it (the `<component>` in `<root>/<component>/lib`).
///
/// `cublasLt` maps to `cublas` because `complete_nvidia_cuda_stack` looks for
/// `libcublasLt` under `cublas/lib`. Stems without a special mapping are
/// returned unchanged.
#[cfg(target_os = "linux")]
fn nvidia_component_for_stem(stem: &str) -> String {
    let component = match stem {
        "cublas" | "cublasLt" => "cublas",
        "nvJitLink" | "nvjitlink" => "nvjitlink",
        "cudart" | "cuda_runtime" => "cuda_runtime",
        // Covers "cusolver", "cusparse" and any stem whose component
        // directory carries the same name.
        other => other,
    };
    component.to_string()
}
333
334#[cfg(target_os = "linux")]
335fn nvidia_package_roots() -> Vec<PathBuf> {
336 let mut roots = Vec::new();
337 if let Some(home) = current_user_home_dir() {
338 collect_python_nvidia_roots(home.join(".local/lib"), &mut roots);
339 }
340 collect_python_nvidia_roots(Path::new("/usr/local/lib").to_path_buf(), &mut roots);
341 collect_python_nvidia_roots(Path::new("/usr/lib").to_path_buf(), &mut roots);
342 dedup_paths(roots)
343}
344
/// Looks up the current user's home directory without consulting the
/// environment.
///
/// The uid is the first value on the `Uid:` line of `/proc/self/status`; the
/// home directory is the matching entry's sixth field in `/etc/passwd`.
/// Returns `None` when either file cannot be read or no usable entry exists.
#[cfg(target_os = "linux")]
fn current_user_home_dir() -> Option<PathBuf> {
    let status = std::fs::read_to_string("/proc/self/status").ok()?;
    let uid = status
        .lines()
        .find_map(|line| line.strip_prefix("Uid:"))?
        .split_whitespace()
        .next()?;
    let passwd = std::fs::read_to_string("/etc/passwd").ok()?;
    home_dir_for_uid(&passwd, uid)
}

/// Returns the home directory of the first usable passwd entry whose uid
/// field equals `uid`.
///
/// Blank lines, comments and entries with too few fields are skipped instead
/// of ending the search, so one malformed line cannot hide later entries.
/// Entries with an empty home field are skipped as well.
#[cfg(target_os = "linux")]
fn home_dir_for_uid(passwd: &str, uid: &str) -> Option<PathBuf> {
    passwd.lines().find_map(|line| {
        // Field layout: name:password:uid:gid:gecos:home:shell
        let mut fields = line.split(':');
        if fields.nth(2)? != uid {
            return None;
        }
        // `nth(2)` on the remainder skips gid and gecos and yields home.
        fields
            .nth(2)
            .filter(|home| !home.is_empty())
            .map(PathBuf::from)
    })
}
367
/// Appends every existing `<base>/python*/{site-packages,dist-packages}/nvidia`
/// directory to `out`.
///
/// The `python*` directories are visited in sorted path order so the result
/// does not depend on the order in which the file system lists them. The sort
/// is plain lexicographic, not version-aware. Within one Python directory
/// `site-packages` comes before `dist-packages`. An unreadable `base`
/// contributes nothing.
#[cfg(target_os = "linux")]
fn collect_python_nvidia_roots(base: PathBuf, out: &mut Vec<PathBuf>) {
    let Ok(entries) = std::fs::read_dir(base) else {
        return;
    };
    let mut python_dirs: Vec<PathBuf> = entries
        .flatten()
        .map(|entry| entry.path())
        .filter(|path| {
            path.file_name()
                .and_then(|name| name.to_str())
                .is_some_and(|name| name.starts_with("python"))
        })
        .collect();
    // `read_dir` order is unspecified; sort to keep root selection stable.
    python_dirs.sort();
    for python_dir in python_dirs {
        for site_dir in ["site-packages", "dist-packages"] {
            let root = python_dir.join(site_dir).join("nvidia");
            if root.exists() {
                out.push(root);
            }
        }
    }
}
389
/// Returns `dir/<name>` for the first entry of `names` that exists on disk,
/// or `None` when none of them does.
#[cfg(target_os = "linux")]
fn first_existing(dir: &Path, names: &[&str]) -> Option<PathBuf> {
    names
        .iter()
        .map(|name| dir.join(name))
        .find(|candidate| candidate.exists())
}
400
/// Canonicalizes each path where possible and drops later duplicates.
///
/// A path that cannot be canonicalized is kept as given. The order of first
/// occurrences is preserved.
#[cfg(target_os = "linux")]
fn dedup_paths(paths: Vec<PathBuf>) -> Vec<PathBuf> {
    let mut unique: Vec<PathBuf> = Vec::new();
    for original in paths {
        let resolved = match original.canonicalize() {
            Ok(canonical) => canonical,
            Err(_) => original,
        };
        if !unique.contains(&resolved) {
            unique.push(resolved);
        }
    }
    unique
}
412
/// Appends every file in `dir` whose name starts with `<base>.so.` to `out`.
///
/// Matches are added in sorted path order, and a candidate already present in
/// `out` is not added again. An unreadable `dir` leaves `out` untouched.
#[cfg(target_os = "linux")]
fn append_versioned_linux_so_candidates(out: &mut Vec<String>, dir: &Path, base: &str) {
    let entries = match std::fs::read_dir(dir) {
        Ok(entries) => entries,
        Err(_) => return,
    };
    let prefix = format!("{base}.so.");
    let mut matches: Vec<_> = entries
        .flatten()
        .map(|entry| entry.path())
        .filter(|path| {
            path.file_name()
                .and_then(|name| name.to_str())
                .is_some_and(|name| name.starts_with(&prefix))
        })
        .collect();
    matches.sort();
    for path in &matches {
        let candidate = path.to_string_lossy().into_owned();
        if !out.contains(&candidate) {
            out.push(candidate);
        }
    }
}
437
438fn cuda_library_candidate_names() -> Vec<String> {
439 let mut out: Vec<String> = cuda_library_candidates()
440 .iter()
441 .map(|candidate| (*candidate).to_string())
442 .collect();
443 if cfg!(target_os = "linux") {
444 for dir in [
445 "/usr/local/nvidia/lib64",
446 "/usr/local/nvidia/lib",
447 "/usr/local/cuda/compat",
448 "/usr/lib/x86_64-linux-gnu",
449 "/usr/lib64",
450 "/usr/lib/wsl/lib",
451 ] {
452 append_versioned_linux_libcuda_candidates(&mut out, Path::new(dir));
453 }
454 }
455 out
456}
457
/// Appends every `libcuda.so.<version>` file in `dir` to `out`, except
/// `libcuda.so.1`.
///
/// Matches are added in sorted path order, and a candidate already present in
/// `out` is not added again. An unreadable `dir` leaves `out` untouched.
fn append_versioned_linux_libcuda_candidates(out: &mut Vec<String>, dir: &Path) {
    let entries = match std::fs::read_dir(dir) {
        Ok(entries) => entries,
        Err(_) => return,
    };
    let mut matches: Vec<_> = entries
        .flatten()
        .map(|entry| entry.path())
        .filter(|path| {
            path.file_name()
                .and_then(|name| name.to_str())
                .and_then(|name| name.strip_prefix("libcuda.so."))
                .is_some_and(|version| version != "1")
        })
        .collect();
    matches.sort();
    for path in &matches {
        let candidate = path.to_string_lossy().into_owned();
        if !out.contains(&candidate) {
            out.push(candidate);
        }
    }
}
480
/// Returns the static, platform-specific list of CUDA driver library names
/// to try, in order.
///
/// Windows and macOS each have their own short list. Every other target gets
/// the `libcuda.so` list, with absolute paths first and bare names last.
pub fn cuda_library_candidates() -> &'static [&'static str] {
    const WINDOWS: &[&str] = &["nvcuda.dll"];
    const MACOS: &[&str] = &["/usr/local/cuda/lib/libcuda.dylib", "libcuda.dylib"];
    const OTHER: &[&str] = &[
        "/usr/local/nvidia/lib64/libcuda.so.1",
        "/usr/local/nvidia/lib64/libcuda.so",
        "/usr/local/nvidia/lib/libcuda.so.1",
        "/usr/local/nvidia/lib/libcuda.so",
        "/usr/local/cuda/compat/libcuda.so.1",
        "/usr/local/cuda/compat/libcuda.so",
        "/usr/lib/x86_64-linux-gnu/libcuda.so.1",
        "/usr/lib/x86_64-linux-gnu/libcuda.so",
        "/usr/lib64/libcuda.so.1",
        "/usr/lib64/libcuda.so",
        "/usr/lib/wsl/lib/libcuda.so.1",
        "/usr/lib/wsl/lib/libcuda.so",
        "libcuda.so.1",
        "libcuda.so",
    ];

    match (cfg!(target_os = "windows"), cfg!(target_os = "macos")) {
        (true, _) => WINDOWS,
        (false, true) => MACOS,
        (false, false) => OTHER,
    }
}
505
/// Converts `value` to `i32`, returning `None` when it does not fit.
#[inline]
pub fn to_i32(value: usize) -> Option<i32> {
    value.try_into().ok()
}
510
511pub fn to_col_major<'a, S: Data<Elem = f64>>(a: &'a ArrayBase<S, Ix2>) -> Cow<'a, [f64]> {
523 let (rows, cols) = a.dim();
524 let strides = a.strides();
525 if rows > 0
528 && cols > 0
529 && strides[0] == 1
530 && strides[1] == rows as isize
531 && let Some(slice) = a.as_slice_memory_order()
532 {
533 return Cow::Borrowed(slice);
534 }
535 let mut out: Vec<f64> = Vec::with_capacity(rows.saturating_mul(cols));
536 for col in 0..cols {
537 out.extend(a.column(col).iter().copied());
538 }
539 Cow::Owned(out)
540}
541
542pub fn from_col_major_inplace(values: &[f64], out: &mut Array2<f64>) -> Option<()> {
544 let (rows, cols) = out.dim();
545 if values.len() != rows.checked_mul(cols)? {
546 return None;
547 }
548 for col in 0..cols {
549 let src = ndarray::ArrayView1::from(&values[col * rows..(col + 1) * rows]);
550 out.column_mut(col).assign(&src);
551 }
552 Some(())
553}
554
555pub fn from_col_major(values: &[f64], rows: usize, cols: usize) -> Option<Array2<f64>> {
556 let mut out = Array2::<f64>::zeros((rows, cols));
557 from_col_major_inplace(values, &mut out)?;
558 Some(out)
559}