use std::path::{Path, PathBuf};
use std::sync::OnceLock;
/// Sonames of the CUDA runtime libraries that `run` looks up and `dlopen`s,
/// in exactly this order.
///
/// NOTE(review): the order presumably places dependencies before their
/// dependents (e.g. `libcublasLt` ahead of `libcublas`) — confirm before
/// reordering or inserting entries.
const ORDERED_SONAMES: &[&str] = &[
"libcudart.so.12",
"libnvJitLink.so.12",
"libcublasLt.so.12",
"libcublas.so.12",
"libcufft.so.11",
"libcurand.so.10",
"libcusparse.so.12",
"libcusolver.so.11",
];
/// Report describing what the CUDA preload pass searched and managed to load.
#[derive(Debug, Default, Clone)]
pub struct CudaPreload {
/// Sonames whose `dlopen` succeeded, in the order they were loaded.
pub loaded: Vec<String>,
/// Directories that were searched, as returned by `candidate_dirs`.
pub dirs: Vec<PathBuf>,
/// `true` when `libcudnn.so.9` was found and loaded successfully.
pub cudnn: bool,
}
/// Process-wide cache holding the result of the single preload pass.
static PRELOAD: OnceLock<CudaPreload> = OnceLock::new();
/// Performs the CUDA preload on first call and returns the cached report.
///
/// Later calls return the same `CudaPreload` without running the preload
/// again, because the result is stored in a `OnceLock`.
pub fn ensure() -> &'static CudaPreload {
PRELOAD.get_or_init(run)
}
/// Executes the preload pass: loads every library of `ORDERED_SONAMES` that
/// can be found in the candidate directories, then cuDNN when present.
///
/// Loading is best-effort throughout; a library that is missing or fails to
/// load is simply absent from the returned report.
fn run() -> CudaPreload {
    const CUDNN_MAIN: &str = "libcudnn.so.9";

    let search_dirs = candidate_dirs();

    // Core CUDA libraries, attempted strictly in the listed order. Only the
    // sonames that were both located and opened are recorded.
    let mut loaded: Vec<String> = ORDERED_SONAMES
        .iter()
        .copied()
        .filter(|&soname| find_lib(&search_dirs, soname).is_some_and(|path| dlopen_global(&path)))
        .map(String::from)
        .collect();

    // cuDNN: use the first candidate directory that ships the main library.
    let mut cudnn = false;
    let cudnn_dir = search_dirs.iter().find(|dir| dir.join(CUDNN_MAIN).exists());
    if let Some(dir) = cudnn_dir {
        // Companion libraries are opened first (in sorted path order), the
        // main library last. An unreadable directory yields no companions.
        //
        // NOTE(review): besides `libcudnn_*.so.*` the filter below also admits
        // names such as `libcudnn.so.9.x.y` — confirm that is intended.
        let mut companions: Vec<PathBuf> = match std::fs::read_dir(dir) {
            Ok(entries) => entries
                .filter_map(Result::ok)
                .map(|entry| entry.path())
                .filter(|path| {
                    let Some(name) = path.file_name().and_then(|n| n.to_str()) else {
                        return false;
                    };
                    name.starts_with("libcudnn")
                        && name.contains(".so.")
                        && !name.ends_with(CUDNN_MAIN)
                })
                .collect(),
            Err(_) => Vec::new(),
        };
        companions.sort();
        for companion in &companions {
            // Result deliberately ignored: companions are optional.
            dlopen_global(companion);
        }

        if dlopen_global(&dir.join(CUDNN_MAIN)) {
            loaded.push(CUDNN_MAIN.to_string());
            cudnn = true;
        }
    }

    CudaPreload {
        loaded,
        dirs: search_dirs,
        cudnn,
    }
}
/// Builds the ordered, de-duplicated list of existing directories that may
/// contain CUDA libraries.
///
/// Sources, in priority order:
/// 1. every path listed in `NORNIR_CUDA_LIBS`;
/// 2. NVIDIA package `lib` directories under `VIRTUAL_ENV` and `CONDA_PREFIX`;
/// 3. NVIDIA package `lib` directories under each root of
///    `NORNIR_CUDA_SCAN_ROOTS` and under its direct sub-directories;
/// 4. a fixed list of system locations.
///
/// Paths that are not directories are dropped; a path seen twice keeps its
/// first position.
fn candidate_dirs() -> Vec<PathBuf> {
    const SYSTEM_DIRS: [&str; 6] = [
        "/usr/local/lib/ollama/cuda_v12",
        "/usr/local/lib/ollama/cuda_v13",
        "/usr/local/cuda/lib64",
        "/usr/local/cuda-12/lib64",
        "/opt/cuda/lib64",
        "/usr/lib/x86_64-linux-gnu",
    ];

    /// Appends `candidate` when it is an existing directory not yet listed.
    fn add_unique(found: &mut Vec<PathBuf>, candidate: PathBuf) {
        if !candidate.is_dir() {
            return;
        }
        if found.contains(&candidate) {
            return;
        }
        found.push(candidate);
    }

    let mut found: Vec<PathBuf> = Vec::new();

    if let Some(raw) = std::env::var_os("NORNIR_CUDA_LIBS") {
        std::env::split_paths(&raw).for_each(|dir| add_unique(&mut found, dir));
    }

    for key in ["VIRTUAL_ENV", "CONDA_PREFIX"] {
        let Some(prefix) = std::env::var_os(key) else {
            continue;
        };
        for dir in nvidia_pkg_dirs(Path::new(&prefix)) {
            add_unique(&mut found, dir);
        }
    }

    if let Some(raw) = std::env::var_os("NORNIR_CUDA_SCAN_ROOTS") {
        std::env::split_paths(&raw)
            .flat_map(|scan_root| scan_root_for_nvidia(&scan_root))
            .for_each(|dir| add_unique(&mut found, dir));
    }

    for system_dir in SYSTEM_DIRS {
        add_unique(&mut found, PathBuf::from(system_dir));
    }

    found
}
/// Collects every existing `<root>/{lib,lib64}/*/site-packages/nvidia/*/lib`
/// directory.
///
/// Directories that are missing or cannot be read are skipped silently, so
/// the result is empty when `root` does not have that layout.
fn nvidia_pkg_dirs(root: &Path) -> Vec<PathBuf> {
    ["lib", "lib64"]
        .iter()
        // <root>/lib and <root>/lib64, whichever of them can be listed.
        .filter_map(|libname| std::fs::read_dir(root.join(libname)).ok())
        .flat_map(|children| children.filter_map(Result::ok))
        // Each child may carry a `site-packages/nvidia` tree.
        .filter_map(|child| std::fs::read_dir(child.path().join("site-packages/nvidia")).ok())
        .flat_map(|packages| packages.filter_map(Result::ok))
        // Keep only the packages that actually have a `lib` directory.
        .map(|package| package.path().join("lib"))
        .filter(|lib_dir| lib_dir.is_dir())
        .collect()
}
/// Looks for NVIDIA package `lib` directories under `root` itself and under
/// each of its immediate sub-directories.
///
/// Results for `root` come first, followed by the results for each child in
/// the order the directory listing yields them. If `root` cannot be listed,
/// only the results for `root` itself are returned.
fn scan_root_for_nvidia(root: &Path) -> Vec<PathBuf> {
    let mut found = nvidia_pkg_dirs(root);

    let Ok(children) = std::fs::read_dir(root) else {
        return found;
    };

    let child_dirs = children
        .filter_map(Result::ok)
        .map(|entry| entry.path())
        .filter(|path| path.is_dir());
    for child in child_dirs {
        found.extend(nvidia_pkg_dirs(&child));
    }

    found
}
/// Locates the shared library `soname` in `dirs`, searching the directories
/// in order and returning the first hit.
///
/// Within one directory the lookup is:
/// 1. the exact file name `soname`, if it exists;
/// 2. otherwise a fully versioned file name, i.e. `soname` followed by one or
///    more `.<digits>` components (for example `libfoo.so.12.4.127` for
///    `libfoo.so.12`). When several such files exist, the numerically highest
///    version wins.
///
/// Names that merely share the prefix (`libfoo.so.120`, `libfoo.so.12-old`,
/// `libfoo.so.12.bak`) are not accepted. Directories that cannot be read are
/// skipped.
///
/// Returns `None` when no directory contains a match.
fn find_lib(dirs: &[PathBuf], soname: &str) -> Option<PathBuf> {
    /// Parses the numeric version components that follow `soname` in `name`;
    /// `None` when `name` is not `soname` plus a purely numeric dotted suffix.
    fn version_suffix(name: &str, soname: &str) -> Option<Vec<u64>> {
        let rest = name.strip_prefix(soname)?.strip_prefix('.')?;
        rest.split('.')
            .map(|part| {
                // `str::parse` would also accept a leading '+', so check the
                // digits explicitly before parsing.
                if part.is_empty() || !part.bytes().all(|b| b.is_ascii_digit()) {
                    return None;
                }
                part.parse::<u64>().ok()
            })
            .collect()
    }

    for dir in dirs {
        let exact = dir.join(soname);
        if exact.exists() {
            return Some(exact);
        }

        let Ok(entries) = std::fs::read_dir(dir) else {
            continue;
        };
        // Directory listing order is platform dependent, so choose by version
        // instead of taking whichever entry happens to come first.
        let newest = entries
            .flatten()
            .filter_map(|entry| {
                let file_name = entry.file_name();
                let version = version_suffix(file_name.to_str()?, soname)?;
                Some((version, entry.path()))
            })
            .max();
        if let Some((_, path)) = newest {
            return Some(path);
        }
    }
    None
}
/// Opens the shared library at `path` with `RTLD_NOW | RTLD_GLOBAL` and keeps
/// it loaded.
///
/// Returns `true` when the library was opened and `false` on any error; the
/// error itself is discarded.
fn dlopen_global(path: &Path) -> bool {
    use libloading::os::unix::{Library, RTLD_GLOBAL, RTLD_NOW};

    let flags = RTLD_NOW | RTLD_GLOBAL;
    // SAFETY: NOTE(review) — opening a shared object may execute code it
    // contains; this relies on `path` pointing at a trusted library. Confirm
    // against the callers that choose the search directories.
    let opened = unsafe { Library::open(Some(path), flags) };

    // The handle is forgotten rather than dropped so that its destructor
    // never runs and the library is not closed again.
    opened.map(std::mem::forget).is_ok()
}