1use libloading::Library;
12#[cfg(target_os = "linux")]
13use libloading::os::unix::{Library as UnixLibrary, RTLD_GLOBAL, RTLD_NOW};
14use ndarray::{Array2, ArrayBase, Data, Ix2};
15use std::borrow::Cow;
16use std::path::Path;
17#[cfg(target_os = "linux")]
18use std::path::PathBuf;
19use std::sync::OnceLock;
20
21use super::gpu_error::GpuError;
22
/// Raw integer status code produced by a CUDA driver call; [`check_cuda`] treats `0` as success.
pub type CuResult = i32;
24#[inline]
35pub fn check_cuda(result: CuResult, name: &str) -> Result<(), GpuError> {
36 if result == 0 {
37 Ok(())
38 } else {
39 Err(GpuError::DriverCallFailed {
40 reason: format!("{name} failed with CUDA driver error {result}"),
41 })
42 }
43}
44
45#[must_use]
54pub fn cuda_driver_library_present() -> bool {
55 load_library_names(&cuda_library_candidate_names()).is_ok()
56}
57
58fn load_library_names(candidates: &[String]) -> Result<Library, GpuError> {
59 for candidate in candidates {
60 if let Ok(library) = unsafe { Library::new(candidate) } {
64 return Ok(library);
65 }
66 }
67 Err(GpuError::DriverLibraryUnavailable {
68 reason: format!("could not load any of: {}", candidates.join(", ")),
69 })
70}
71
72fn load_static_cuda_driver_library() -> Result<&'static Library, GpuError> {
73 static LIBRARY: OnceLock<Result<Library, GpuError>> = OnceLock::new();
74 LIBRARY
75 .get_or_init(|| load_library_names(&cuda_library_candidate_names()))
76 .as_ref()
77 .map_err(Clone::clone)
78}
79
80pub fn preload_cuda_driver() -> Result<(), String> {
81 static PRELOAD: OnceLock<Result<(), String>> = OnceLock::new();
82 PRELOAD
83 .get_or_init(|| {
84 load_static_cuda_driver_library()
85 .map(|_| ())
86 .map_err(|err| err.to_string())
87 })
88 .clone()
89}
90
91#[cfg(target_os = "linux")]
92fn preload_cuda_userspace_libraries() -> Result<(), String> {
93 static PRELOAD: OnceLock<Result<Vec<UnixLibrary>, String>> = OnceLock::new();
94 PRELOAD
95 .get_or_init(|| {
96 let paths = cuda_userspace_preload_paths();
97 if paths.is_empty() {
98 return Ok(Vec::new());
99 }
100 let mut loaded = Vec::new();
101 for path in paths {
102 match unsafe { UnixLibrary::open(Some(&path), RTLD_NOW | RTLD_GLOBAL) } {
108 Ok(library) => loaded.push(library),
109 Err(err) => {
110 return Err(format!(
111 "could not preload CUDA userspace library {}: {err}",
112 path.display()
113 ));
114 }
115 }
116 }
117 Ok(loaded)
118 })
119 .as_ref()
120 .map(|_| ())
121 .map_err(Clone::clone)
122}
123
124#[must_use]
139pub fn cuda_compute_library_present(stem: &str) -> bool {
140 #[cfg(target_os = "linux")]
141 {
142 if preload_cuda_userspace_libraries().is_err() {
143 return false;
144 }
145 }
146 static PROBED: OnceLock<std::sync::Mutex<std::collections::HashMap<String, bool>>> =
154 OnceLock::new();
155 static KEEP_ALIVE: OnceLock<std::sync::Mutex<Vec<Library>>> = OnceLock::new();
156 let probed = PROBED.get_or_init(|| std::sync::Mutex::new(std::collections::HashMap::new()));
157 if let Ok(cache) = probed.lock() {
158 if let Some(&present) = cache.get(stem) {
159 return present;
160 }
161 }
162 let present = match load_library_names(&cuda_compute_library_candidate_names(stem)) {
163 Ok(library) => {
164 if let Ok(mut keep) = KEEP_ALIVE
165 .get_or_init(|| std::sync::Mutex::new(Vec::new()))
166 .lock()
167 {
168 keep.push(library);
169 }
170 true
171 }
172 Err(_) => false,
173 };
174 if let Ok(mut cache) = probed.lock() {
175 cache.insert(stem.to_string(), present);
176 }
177 present
178}
179
180#[cfg(target_os = "linux")]
181fn cuda_userspace_preload_paths() -> Vec<PathBuf> {
182 let system_dirs = cuda_system_library_dirs();
183 for dir in &system_dirs {
184 if let Some(stack) = complete_system_cuda_stack(dir) {
185 return dedup_paths(stack);
186 }
187 if let Some(stack) = system_cuda_stack_with_packaged_nvjitlink(dir) {
188 return dedup_paths(stack);
189 }
190 }
191 for root in nvidia_package_roots() {
192 if let Some(stack) = complete_nvidia_cuda_stack(&root) {
193 return dedup_paths(stack);
194 }
195 }
196 Vec::new()
197}
198
199fn cuda_compute_library_candidate_names(stem: &str) -> Vec<String> {
200 let base = format!("lib{stem}");
201 let mut out: Vec<String> = Vec::new();
202 out.push(format!("{base}.so"));
205 out.push(format!("{base}.so.1"));
206 for major in (9..=13).rev() {
209 out.push(format!("{base}.so.{major}"));
210 }
211 #[cfg(target_os = "linux")]
212 {
213 for dir in cuda_system_library_dirs() {
214 out.push(format!("{dir}/{base}.so"));
215 for major in (9..=13).rev() {
216 out.push(format!("{dir}/{base}.so.{major}"));
217 }
218 append_versioned_linux_so_candidates(&mut out, Path::new(dir), &base);
219 }
220 for root in nvidia_package_roots() {
221 let lib_dir = root.join(nvidia_component_for_stem(stem)).join("lib");
222 out.push(format!("{}/{}.so", lib_dir.display(), base));
223 for major in (9..=13).rev() {
224 out.push(format!("{}/{}.so.{major}", lib_dir.display(), base));
225 }
226 append_versioned_linux_so_candidates(&mut out, &lib_dir, &base);
227 }
228 }
229 out
230}
231
/// Returns the fixed list of system directories searched for CUDA libraries, in search order.
#[cfg(target_os = "linux")]
fn cuda_system_library_dirs() -> Vec<&'static str> {
    const DIRS: [&str; 7] = [
        "/usr/local/cuda/lib64",
        "/usr/local/cuda/lib",
        "/usr/local/cuda/targets/x86_64-linux/lib",
        "/usr/lib/x86_64-linux-gnu",
        "/usr/lib64",
        "/usr/lib/wsl/lib",
        "/opt/cuda/lib64",
    ];
    DIRS.to_vec()
}
244
245#[cfg(target_os = "linux")]
246fn complete_system_cuda_stack(dir: &str) -> Option<Vec<PathBuf>> {
247 let dir = Path::new(dir);
248 let stack = vec![
249 first_existing(dir, &["libcudart.so.13", "libcudart.so.12", "libcudart.so"])?,
250 first_existing(
251 dir,
252 &[
253 "libnvJitLink.so.13",
254 "libnvJitLink.so.12",
255 "libnvJitLink.so",
256 ],
257 )?,
258 first_existing(
259 dir,
260 &["libcublasLt.so.13", "libcublasLt.so.12", "libcublasLt.so"],
261 )?,
262 first_existing(dir, &["libcublas.so.13", "libcublas.so.12", "libcublas.so"])?,
263 first_existing(
264 dir,
265 &["libcusparse.so.13", "libcusparse.so.12", "libcusparse.so"],
266 )?,
267 first_existing(
268 dir,
269 &[
270 "libcusolver.so.13",
271 "libcusolver.so.12",
272 "libcusolver.so.11",
273 "libcusolver.so",
274 ],
275 )?,
276 ];
277 Some(stack)
278}
279
280#[cfg(target_os = "linux")]
281fn system_cuda_stack_with_packaged_nvjitlink(dir: &str) -> Option<Vec<PathBuf>> {
282 let dir = Path::new(dir);
283 let nvjitlink = packaged_nvjitlink_library()?;
284 let stack = vec![
285 first_existing(dir, &["libcudart.so.13", "libcudart.so.12", "libcudart.so"])?,
286 nvjitlink,
287 first_existing(
288 dir,
289 &["libcublasLt.so.13", "libcublasLt.so.12", "libcublasLt.so"],
290 )?,
291 first_existing(dir, &["libcublas.so.13", "libcublas.so.12", "libcublas.so"])?,
292 first_existing(
293 dir,
294 &["libcusparse.so.13", "libcusparse.so.12", "libcusparse.so"],
295 )?,
296 first_existing(
297 dir,
298 &[
299 "libcusolver.so.13",
300 "libcusolver.so.12",
301 "libcusolver.so.11",
302 "libcusolver.so",
303 ],
304 )?,
305 ];
306 Some(stack)
307}
308
309#[cfg(target_os = "linux")]
310fn complete_nvidia_cuda_stack(root: &Path) -> Option<Vec<PathBuf>> {
311 let stack = vec![
312 first_existing(
313 &root.join("cuda_runtime").join("lib"),
314 &["libcudart.so.13", "libcudart.so.12", "libcudart.so"],
315 )?,
316 first_existing(
317 &root.join("nvjitlink").join("lib"),
318 &[
319 "libnvJitLink.so.13",
320 "libnvJitLink.so.12",
321 "libnvJitLink.so",
322 ],
323 )?,
324 first_existing(
325 &root.join("cublas").join("lib"),
326 &["libcublasLt.so.13", "libcublasLt.so.12", "libcublasLt.so"],
327 )?,
328 first_existing(
329 &root.join("cublas").join("lib"),
330 &["libcublas.so.13", "libcublas.so.12", "libcublas.so"],
331 )?,
332 first_existing(
333 &root.join("cusparse").join("lib"),
334 &["libcusparse.so.13", "libcusparse.so.12", "libcusparse.so"],
335 )?,
336 first_existing(
337 &root.join("cusolver").join("lib"),
338 &[
339 "libcusolver.so.13",
340 "libcusolver.so.12",
341 "libcusolver.so.11",
342 "libcusolver.so",
343 ],
344 )?,
345 ];
346 Some(stack)
347}
348
349#[cfg(target_os = "linux")]
350fn packaged_nvjitlink_library() -> Option<PathBuf> {
351 for root in nvidia_package_roots() {
352 let lib_dir = root.join("nvjitlink").join("lib");
353 if let Some(path) = first_existing(
354 &lib_dir,
355 &[
356 "libnvJitLink.so.13",
357 "libnvJitLink.so.12",
358 "libnvJitLink.so",
359 ],
360 ) {
361 return Some(path);
362 }
363 }
364 None
365}
366
/// Maps a library stem to the directory name of its component inside an NVIDIA package root.
///
/// `nvJitLink`/`nvjitlink` map to `nvjitlink`, `cudart`/`cuda_runtime` map to `cuda_runtime`,
/// and every other stem is returned unchanged.
#[cfg(target_os = "linux")]
fn nvidia_component_for_stem(stem: &str) -> String {
    let component = match stem {
        "nvJitLink" | "nvjitlink" => "nvjitlink",
        "cudart" | "cuda_runtime" => "cuda_runtime",
        other => other,
    };
    component.to_string()
}
378
379#[cfg(target_os = "linux")]
380fn nvidia_package_roots() -> Vec<PathBuf> {
381 let mut roots = Vec::new();
382 if let Some(home) = current_user_home_dir() {
383 collect_python_nvidia_roots(home.join(".local/lib"), &mut roots);
384 }
385 collect_python_nvidia_roots(Path::new("/usr/local/lib").to_path_buf(), &mut roots);
386 collect_python_nvidia_roots(Path::new("/usr/lib").to_path_buf(), &mut roots);
387 dedup_paths(roots)
388}
389
/// Looks up the current process's home directory without consulting the environment.
///
/// Reads the first value on the `Uid:` line of `/proc/self/status` (the real UID according to
/// proc(5)), then scans `/etc/passwd` for the first entry whose third field equals that UID and
/// returns its sixth field (the home directory).
///
/// Returns `None` when either file cannot be read, the `Uid:` line is missing, or no usable
/// passwd entry matches. Blank or truncated passwd lines are skipped rather than aborting the
/// search.
#[cfg(target_os = "linux")]
fn current_user_home_dir() -> Option<PathBuf> {
    let status = std::fs::read_to_string("/proc/self/status").ok()?;
    let uid = status
        .lines()
        .find_map(|line| line.strip_prefix("Uid:"))?
        .split_whitespace()
        .next()?;
    let passwd = std::fs::read_to_string("/etc/passwd").ok()?;
    for line in passwd.lines() {
        // passwd(5) layout: name:password:uid:gid:gecos:home:shell
        let mut fields = line.split(':');
        // `nth(2)` consumes name and password and yields the uid field; a line with fewer
        // fields (e.g. a blank line) yields `None` and is skipped.
        if fields.nth(2) != Some(uid) {
            continue;
        }
        // Skip gid and gecos; the next field is the home directory. A matching but truncated
        // entry is ignored so a later well-formed entry can still be found.
        if let Some(home) = fields.nth(2) {
            return Some(PathBuf::from(home));
        }
    }
    None
}
412
/// Appends every existing `<base>/python*/{site-packages,dist-packages}/nvidia` path to `out`.
///
/// Does nothing when `base` cannot be listed. Entries whose names are not valid UTF-8 or do
/// not start with `python` are ignored. For each python directory, `site-packages` is checked
/// before `dist-packages`.
#[cfg(target_os = "linux")]
fn collect_python_nvidia_roots(base: PathBuf, out: &mut Vec<PathBuf>) {
    let Ok(entries) = std::fs::read_dir(base) else {
        return;
    };
    for python_dir in entries.flatten().map(|entry| entry.path()) {
        let is_python = python_dir
            .file_name()
            .and_then(|name| name.to_str())
            .is_some_and(|name| name.starts_with("python"));
        if !is_python {
            continue;
        }
        out.extend(
            ["site-packages", "dist-packages"]
                .iter()
                .map(|site_dir| python_dir.join(site_dir).join("nvidia"))
                .filter(|root| root.exists()),
        );
    }
}
434
/// Returns `dir/<name>` for the first entry of `names` that exists on disk, or `None`.
#[cfg(target_os = "linux")]
fn first_existing(dir: &Path, names: &[&str]) -> Option<PathBuf> {
    names
        .iter()
        .map(|name| dir.join(name))
        .find(|candidate| candidate.exists())
}
445
/// Canonicalizes each path where possible and drops duplicates, keeping first-seen order.
///
/// A path that cannot be canonicalized (for example because it does not exist) is kept as given.
#[cfg(target_os = "linux")]
fn dedup_paths(paths: Vec<PathBuf>) -> Vec<PathBuf> {
    paths.into_iter().fold(Vec::new(), |mut unique, path| {
        let resolved = match path.canonicalize() {
            Ok(real) => real,
            Err(_) => path,
        };
        if !unique.contains(&resolved) {
            unique.push(resolved);
        }
        unique
    })
}
457
/// Appends to `out` every file in `dir` whose name starts with `{base}.so.`, in sorted order.
///
/// Paths already present in `out` are not added again. Does nothing when `dir` cannot be
/// listed; entries with non-UTF-8 names are ignored.
#[cfg(target_os = "linux")]
fn append_versioned_linux_so_candidates(out: &mut Vec<String>, dir: &Path, base: &str) {
    let Ok(entries) = std::fs::read_dir(dir) else {
        return;
    };
    let prefix = format!("{base}.so.");
    let mut versioned = entries
        .flatten()
        .map(|entry| entry.path())
        .filter(|path| {
            path.file_name()
                .and_then(|name| name.to_str())
                .is_some_and(|name| name.starts_with(&prefix))
        })
        .collect::<Vec<_>>();
    versioned.sort();
    for path in versioned {
        let candidate = path.to_string_lossy().into_owned();
        if !out.contains(&candidate) {
            out.push(candidate);
        }
    }
}
482
483fn cuda_library_candidate_names() -> Vec<String> {
484 let mut out: Vec<String> = cuda_library_candidates()
485 .iter()
486 .map(|candidate| (*candidate).to_string())
487 .collect();
488 if cfg!(target_os = "linux") {
489 for dir in [
490 "/usr/local/nvidia/lib64",
491 "/usr/local/nvidia/lib",
492 "/usr/local/cuda/compat",
493 "/usr/lib/x86_64-linux-gnu",
494 "/usr/lib64",
495 "/usr/lib/wsl/lib",
496 ] {
497 append_versioned_linux_libcuda_candidates(&mut out, Path::new(dir));
498 }
499 }
500 out
501}
502
/// Appends to `out` every `libcuda.so.*` file in `dir` except `libcuda.so.1`, in sorted order.
///
/// Paths already present in `out` are not added again. Does nothing when `dir` cannot be
/// listed; entries with non-UTF-8 names are ignored.
fn append_versioned_linux_libcuda_candidates(out: &mut Vec<String>, dir: &Path) {
    let Ok(entries) = std::fs::read_dir(dir) else {
        return;
    };
    let mut versioned = entries
        .flatten()
        .map(|entry| entry.path())
        .filter(|path| {
            match path.file_name().and_then(|name| name.to_str()) {
                Some(name) => name.starts_with("libcuda.so.") && name != "libcuda.so.1",
                None => false,
            }
        })
        .collect::<Vec<_>>();
    versioned.sort();
    for path in versioned {
        let candidate = path.to_string_lossy().into_owned();
        if !out.contains(&candidate) {
            out.push(candidate);
        }
    }
}
525
/// Returns the static, ordered list of CUDA driver library names for the target platform.
///
/// Windows gets `nvcuda.dll`; macOS gets the `libcuda.dylib` names; every other target gets
/// the Linux-style absolute `libcuda.so[.1]` paths followed by the bare sonames.
pub fn cuda_library_candidates() -> &'static [&'static str] {
    const WINDOWS: &[&str] = &["nvcuda.dll"];
    const MACOS: &[&str] = &["/usr/local/cuda/lib/libcuda.dylib", "libcuda.dylib"];
    const UNIX_LIKE: &[&str] = &[
        "/usr/local/nvidia/lib64/libcuda.so.1",
        "/usr/local/nvidia/lib64/libcuda.so",
        "/usr/local/nvidia/lib/libcuda.so.1",
        "/usr/local/nvidia/lib/libcuda.so",
        "/usr/local/cuda/compat/libcuda.so.1",
        "/usr/local/cuda/compat/libcuda.so",
        "/usr/lib/x86_64-linux-gnu/libcuda.so.1",
        "/usr/lib/x86_64-linux-gnu/libcuda.so",
        "/usr/lib64/libcuda.so.1",
        "/usr/lib64/libcuda.so",
        "/usr/lib/wsl/lib/libcuda.so.1",
        "/usr/lib/wsl/lib/libcuda.so",
        "libcuda.so.1",
        "libcuda.so",
    ];

    match (cfg!(target_os = "windows"), cfg!(target_os = "macos")) {
        (true, _) => WINDOWS,
        (false, true) => MACOS,
        (false, false) => UNIX_LIKE,
    }
}
550
/// Converts `value` to `i32`, returning `None` when it does not fit.
#[inline]
pub fn to_i32(value: usize) -> Option<i32> {
    match i32::try_from(value) {
        Ok(narrowed) => Some(narrowed),
        Err(_) => None,
    }
}
555
556pub fn to_col_major<'a, S: Data<Elem = f64>>(a: &'a ArrayBase<S, Ix2>) -> Cow<'a, [f64]> {
568 let (rows, cols) = a.dim();
569 let strides = a.strides();
570 if rows > 0
573 && cols > 0
574 && strides[0] == 1
575 && strides[1] == rows as isize
576 && let Some(slice) = a.as_slice_memory_order()
577 {
578 return Cow::Borrowed(slice);
579 }
580 let mut out: Vec<f64> = Vec::with_capacity(rows.saturating_mul(cols));
581 for col in 0..cols {
582 out.extend(a.column(col).iter().copied());
583 }
584 Cow::Owned(out)
585}
586
587pub fn to_row_major<'a, S: Data<Elem = f64>>(a: &'a ArrayBase<S, Ix2>) -> Cow<'a, [f64]> {
601 let (rows, cols) = a.dim();
602 let strides = a.strides();
603 if rows > 0
605 && cols > 0
606 && strides[1] == 1
607 && strides[0] == cols as isize
608 && let Some(slice) = a.as_slice_memory_order()
609 {
610 return Cow::Borrowed(slice);
611 }
612 let mut out: Vec<f64> = Vec::with_capacity(rows.saturating_mul(cols));
613 for row in 0..rows {
614 out.extend(a.row(row).iter().copied());
615 }
616 Cow::Owned(out)
617}
618
619pub fn array_from_row_major(values: Vec<f64>, rows: usize, cols: usize) -> Option<Array2<f64>> {
623 if values.len() != rows.checked_mul(cols)? {
624 return None;
625 }
626 Array2::from_shape_vec((rows, cols), values).ok()
627}
628
629pub fn from_col_major_inplace(values: &[f64], out: &mut Array2<f64>) -> Option<()> {
631 let (rows, cols) = out.dim();
632 if values.len() != rows.checked_mul(cols)? {
633 return None;
634 }
635 for col in 0..cols {
636 let src = ndarray::ArrayView1::from(&values[col * rows..(col + 1) * rows]);
637 out.column_mut(col).assign(&src);
638 }
639 Some(())
640}
641
642pub fn from_col_major(values: &[f64], rows: usize, cols: usize) -> Option<Array2<f64>> {
643 let mut out = Array2::<f64>::zeros((rows, cols));
644 from_col_major_inplace(values, &mut out)?;
645 Some(out)
646}
647
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Values up to and including `i32::MAX` convert losslessly.
    #[test]
    fn to_i32_fits_small_value() {
        let cases = [(0_usize, 0_i32), (42, 42), (i32::MAX as usize, i32::MAX)];
        for &(input, expected) in cases.iter() {
            assert_eq!(to_i32(input), Some(expected));
        }
    }

    /// The first value past `i32::MAX` is rejected.
    #[test]
    fn to_i32_overflows_returns_none() {
        let too_big = i32::MAX as usize + 1;
        assert!(to_i32(too_big).is_none());
    }

    /// A row-major 2x3 input is emitted column by column.
    #[test]
    fn to_col_major_2x3_row_major() {
        let a = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let packed = to_col_major(&a);
        assert_eq!(packed.to_vec(), vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    /// The identity matrix reads the same in either layout.
    #[test]
    fn to_col_major_identity_roundtrip() {
        let a = array![[1.0_f64, 0.0], [0.0, 1.0]];
        let packed = to_col_major(&a);
        assert_eq!(packed.to_vec(), vec![1.0, 0.0, 0.0, 1.0]);
    }

    /// Packing and unpacking a 2x3 array returns the original.
    #[test]
    fn from_col_major_2x3_roundtrip() {
        let original = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let packed = to_col_major(&original);
        let recovered = from_col_major(&packed, 2, 3).expect("should succeed");
        assert_eq!(recovered, original);
    }

    /// Five values cannot fill a 2x3 array.
    #[test]
    fn from_col_major_wrong_length_returns_none() {
        let five = [1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(from_col_major(&five, 2, 3), None);
    }

    /// Eight values cannot fill a 3x3 destination.
    #[test]
    fn from_col_major_inplace_mismatched_buffer_returns_none() {
        let mut out = Array2::<f64>::zeros((3, 3));
        let short = vec![1.0_f64; 8];
        assert_eq!(from_col_major_inplace(&short, &mut out), None);
    }

    /// A 1x1 array holds its single input value.
    #[test]
    fn from_col_major_single_element() {
        let single = from_col_major(&[7.0], 1, 1).expect("should succeed");
        assert_eq!(single[[0, 0]], 7.0);
    }
}