use std::path::{Path, PathBuf};
use std::sync::OnceLock;
/// Sonames of the CUDA runtime libraries that `run` tries to preload and that
/// `setup` tries to copy, visited in exactly this order.
///
/// NOTE(review): the ordering presumably mirrors the libraries' dependency
/// chain (e.g. cuBLASLt before cuBLAS, cuSOLVER last) — confirm before
/// reordering or inserting entries.
const ORDERED_SONAMES: &[&str] = &[
"libcudart.so.12",
"libnvJitLink.so.12",
"libcublasLt.so.12",
"libcublas.so.12",
"libcufft.so.11",
"libcurand.so.10",
"libcusparse.so.12",
"libcusolver.so.11",
];
/// Outcome of the CUDA library preload performed by `run`.
#[derive(Debug, Default, Clone)]
pub struct CudaPreload {
/// Sonames that were successfully opened: entries of `ORDERED_SONAMES`,
/// plus `libcudnn.so.9` when cuDNN loaded.
pub loaded: Vec<String>,
/// Directories that were searched, as returned by `candidate_dirs`.
pub dirs: Vec<PathBuf>,
/// `true` when `libcudnn.so.9` itself was opened successfully.
pub cudnn: bool,
}
/// Process-wide cache of the preload result; populated on first use by `ensure`.
static PRELOAD: OnceLock<CudaPreload> = OnceLock::new();
/// Returns the cached preload report, executing `run` the first time it is
/// called.
///
/// The `OnceLock` guarantees the initializer's result is stored only once, so
/// every caller observes the same `&'static CudaPreload`.
pub fn ensure() -> &'static CudaPreload {
PRELOAD.get_or_init(run)
}
/// Locates and `dlopen`s the CUDA runtime libraries and cuDNN, returning a
/// report of what was found.
///
/// Steps, in order:
/// 1. Every entry of `ORDERED_SONAMES` is looked up in the candidate
///    directories and opened; only successful opens are recorded.
/// 2. The first candidate directory containing `libcudnn.so.9` is taken as the
///    cuDNN directory. Its other versioned `libcudnn*` files are opened in
///    sorted path order (failures ignored), then `libcudnn.so.9` itself.
fn run() -> CudaPreload {
    const CUDNN_MAIN: &str = "libcudnn.so.9";

    let search_dirs = candidate_dirs();

    // Step 1: core CUDA runtime libraries, in table order.
    let mut loaded: Vec<String> = Vec::new();
    for &soname in ORDERED_SONAMES {
        let opened = find_lib(&search_dirs, soname).is_some_and(|lib| dlopen_global(&lib));
        if opened {
            loaded.push(soname.to_string());
        }
    }

    // Step 2: cuDNN companions first, then the main cuDNN library.
    let mut cudnn = false;
    let cudnn_dir = search_dirs.iter().find(|d| d.join(CUDNN_MAIN).exists());
    if let Some(dir) = cudnn_dir {
        let mut companions: Vec<PathBuf> = Vec::new();
        if let Ok(entries) = std::fs::read_dir(dir) {
            for entry in entries.flatten() {
                let path = entry.path();
                let is_companion = match path.file_name().and_then(|n| n.to_str()) {
                    Some(n) => {
                        n.starts_with("libcudnn") && n.contains(".so.") && !n.ends_with(CUDNN_MAIN)
                    }
                    None => false,
                };
                if is_companion {
                    companions.push(path);
                }
            }
        }
        companions.sort();
        for lib in &companions {
            // Best-effort: a companion that fails to open is not reported.
            dlopen_global(lib);
        }

        if dlopen_global(&dir.join(CUDNN_MAIN)) {
            loaded.push(CUDNN_MAIN.to_string());
            cudnn = true;
        }
    }

    CudaPreload {
        loaded,
        dirs: search_dirs,
        cudnn,
    }
}
/// Reports whether the NVIDIA driver library `libcuda.so.1` can be opened by
/// the dynamic loader.
///
/// The library handle is dropped again before returning; only the success of
/// the open is reported.
pub fn driver_present() -> bool {
    use libloading::os::unix::{Library, RTLD_NOW};

    let soname = std::ffi::OsStr::new("libcuda.so.1");
    // SAFETY: NOTE(review) — opening a library executes its load-time
    // initializers; this relies on the system's `libcuda.so.1` being a
    // well-behaved driver library.
    let probe = unsafe { Library::open(Some(soname), RTLD_NOW) };
    probe.is_ok()
}
/// Maps a presence flag to the label used in the preflight report:
/// `"found"` for `true`, `"MISSING"` for `false`.
fn yn(b: bool) -> &'static str {
    match b {
        true => "found",
        false => "MISSING",
    }
}
/// Builds a human-readable CUDA readiness report.
///
/// Returns `(gpu_ready, report)`. `gpu_ready` is `true` only when the NVIDIA
/// driver library can be opened, a `libcudart*` library was preloaded, and
/// cuDNN was preloaded. The report lists each check, the libraries loaded, the
/// directories searched, any `ORDERED_SONAMES` entries that were not loaded,
/// a verdict line, and exactly one follow-up hint.
pub fn preflight() -> (bool, String) {
    let preload = ensure();
    let has_driver = driver_present();
    let has_cudart = preload.loaded.iter().any(|lib| lib.starts_with("libcudart"));
    let gpu_ready = has_driver && has_cudart && preload.cudnn;

    let loaded_list = if preload.loaded.is_empty() {
        String::from("(none)")
    } else {
        preload.loaded.join(", ")
    };
    let dir_list = preload
        .dirs
        .iter()
        .map(|d| d.display().to_string())
        .collect::<Vec<_>>()
        .join(", ");
    let not_loaded: Vec<&str> = ORDERED_SONAMES
        .iter()
        .copied()
        .filter(|name| !preload.loaded.iter().any(|l| l == name))
        .collect();
    let verdict = if gpu_ready {
        "READY ✓"
    } else {
        "unavailable → CPU fallback"
    };

    let mut report = String::from("nornir CUDA preflight (embed-ort GPU path)\n");
    report += &format!(" NVIDIA driver libcuda.so.1 : {}\n", yn(has_driver));
    report += &format!(" libcudart (CUDA runtime) : {}\n", yn(has_cudart));
    report += &format!(" cuDNN 9 : {}\n", yn(preload.cudnn));
    report += &format!(" CUDA libs loaded : {}\n", loaded_list);
    report += &format!(" dirs searched : {}\n", dir_list);
    if !not_loaded.is_empty() {
        report += &format!(" runtime libs not found : {}\n", not_loaded.join(", "));
    }
    report += &format!("\n verdict: GPU embedding {}\n", verdict);

    // Exactly one hint applies: ready, driver present but libs incomplete,
    // or no driver at all.
    let hint = if gpu_ready {
        " → all set; embed-ort runs on the GPU.\n"
    } else if has_driver {
        " → CUDA runtime libs incomplete. Put a matched CUDA-12 / cuDNN-9 .so set \
         (libcudart.so.12, libcublas*, libcudnn.so.9 + its sub-libs) in ONE dir and set \
         NORNIR_CUDA_LIBS to it (e.g. /opt/nornir/cuda). See `.nornir/vector.md`.\n"
    } else {
        " → install the NVIDIA GPU driver for your distro (e.g. `sudo apt install \
         nvidia-driver-XXX`, `sudo dnf install akmod-nvidia`, or NVIDIA's .run). nornir \
         can't install the kernel driver — it needs root + a matching kernel module + reboot.\n"
    };
    report.push_str(hint);

    (gpu_ready, report)
}
pub fn setup(target: &Path) -> anyhow::Result<(Vec<String>, Vec<String>)> {
use anyhow::Context;
let dirs = candidate_dirs();
std::fs::create_dir_all(target)
.with_context(|| format!("create {} (need root? try sudo)", target.display()))?;
let mut copied = Vec::new();
let mut missing = Vec::new();
for soname in ORDERED_SONAMES {
match find_lib(&dirs, soname) {
Some(src) => {
let dst = target.join(src.file_name().unwrap_or_default());
std::fs::copy(&src, &dst)
.with_context(|| format!("copy {} -> {}", src.display(), dst.display()))?;
let alias = target.join(soname);
if !alias.exists() {
std::fs::copy(&src, &alias).ok();
}
copied.push(soname.to_string());
}
None => missing.push(soname.to_string()),
}
}
if let Some(dir) = dirs.iter().find(|d| d.join("libcudnn.so.9").exists()) {
let mut n = 0usize;
for e in std::fs::read_dir(dir).into_iter().flatten().flatten() {
let name = e.file_name();
let name = name.to_string_lossy();
if name.starts_with("libcudnn") && name.contains(".so") {
std::fs::copy(e.path(), target.join(e.file_name()))
.with_context(|| format!("copy {}", e.path().display()))?;
n += 1;
}
}
copied.push(format!("libcudnn.so.9 (+{n} files)"));
} else {
missing.push("libcudnn.so.9".to_string());
}
Ok((copied, missing))
}
/// Collects the directories to search for CUDA libraries, in priority order
/// and without duplicates. Only paths that are existing directories are kept.
///
/// Sources, highest priority first:
/// 1. `NORNIR_CUDA_LIBS` (a path list),
/// 2. `nvidia` Python packages under `VIRTUAL_ENV`, then `CONDA_PREFIX`,
/// 3. `nvidia` Python packages under each root in `NORNIR_CUDA_SCAN_ROOTS`
///    (the root itself and its immediate sub-directories),
/// 4. a fixed list of system locations.
fn candidate_dirs() -> Vec<PathBuf> {
    const SYSTEM_DIRS: [&str; 7] = [
        "/opt/nornir/cuda",
        "/usr/local/lib/ollama/cuda_v12",
        "/usr/local/lib/ollama/cuda_v13",
        "/usr/local/cuda/lib64",
        "/usr/local/cuda-12/lib64",
        "/opt/cuda/lib64",
        "/usr/lib/x86_64-linux-gnu",
    ];

    /// Appends `candidate` when it is a directory not already listed.
    fn add_unique(found: &mut Vec<PathBuf>, candidate: PathBuf) {
        if !candidate.is_dir() || found.contains(&candidate) {
            return;
        }
        found.push(candidate);
    }

    let mut found: Vec<PathBuf> = Vec::new();

    if let Some(list) = std::env::var_os("NORNIR_CUDA_LIBS") {
        for dir in std::env::split_paths(&list) {
            add_unique(&mut found, dir);
        }
    }

    for var in ["VIRTUAL_ENV", "CONDA_PREFIX"] {
        let Some(prefix) = std::env::var_os(var) else {
            continue;
        };
        for dir in nvidia_pkg_dirs(Path::new(&prefix)) {
            add_unique(&mut found, dir);
        }
    }

    if let Some(list) = std::env::var_os("NORNIR_CUDA_SCAN_ROOTS") {
        for dir in std::env::split_paths(&list).flat_map(|root| scan_root_for_nvidia(&root)) {
            add_unique(&mut found, dir);
        }
    }

    for dir in SYSTEM_DIRS {
        add_unique(&mut found, PathBuf::from(dir));
    }

    found
}
/// Lists every existing `<root>/{lib,lib64}/<entry>/site-packages/nvidia/<pkg>/lib`
/// directory.
///
/// `<entry>` and `<pkg>` are whatever directory entries exist at those levels;
/// unreadable or missing directories are skipped silently. Results follow
/// directory-listing order, `lib` before `lib64`.
fn nvidia_pkg_dirs(root: &Path) -> Vec<PathBuf> {
    ["lib", "lib64"]
        .iter()
        .filter_map(|lib_name| std::fs::read_dir(root.join(lib_name)).ok())
        .flat_map(|entries| entries.flatten())
        .filter_map(|entry| std::fs::read_dir(entry.path().join("site-packages/nvidia")).ok())
        .flat_map(|packages| packages.flatten())
        .map(|package| package.path().join("lib"))
        .filter(|lib_dir| lib_dir.is_dir())
        .collect()
}
/// Runs `nvidia_pkg_dirs` on `root` and on each of its immediate
/// sub-directories, concatenating the results (those for `root` come first).
///
/// If `root` cannot be listed, only the results for `root` itself are returned.
fn scan_root_for_nvidia(root: &Path) -> Vec<PathBuf> {
    let mut found = nvidia_pkg_dirs(root);
    let Ok(children) = std::fs::read_dir(root) else {
        return found;
    };
    for child in children.flatten().map(|entry| entry.path()).filter(|p| p.is_dir()) {
        found.extend(nvidia_pkg_dirs(&child));
    }
    found
}
/// Finds the shared library `soname` in `dirs`, searching the directories in
/// order and returning the first hit.
///
/// Within one directory an exact `soname` file wins. Otherwise a file named
/// `<soname>.<digits>[.<digits>…]` (e.g. `libcudart.so.12.4.127`) is accepted;
/// when several exist, the numerically highest version is chosen so the result
/// does not depend on directory-listing order. Names that merely share the
/// prefix (`libcudart.so.120`, `libcudart.so.12.bak`) are not matches.
///
/// Returns `None` when no directory contains a match; unreadable directories
/// are skipped.
fn find_lib(dirs: &[PathBuf], soname: &str) -> Option<PathBuf> {
    dirs.iter().find_map(|dir| {
        let exact = dir.join(soname);
        if exact.exists() {
            return Some(exact);
        }
        std::fs::read_dir(dir)
            .ok()?
            .flatten()
            .filter_map(|entry| {
                let name = entry.file_name();
                let version = soname_version_suffix(name.to_str()?.strip_prefix(soname)?)?;
                Some((version, entry.path()))
            })
            .max_by(|a, b| a.0.cmp(&b.0))
            .map(|(_, path)| path)
    })
}

/// Parses the text following a soname as a dotted numeric version
/// (`".4.127"` → `[4, 127]`). Returns `None` for anything else, including the
/// empty string.
fn soname_version_suffix(rest: &str) -> Option<Vec<u64>> {
    rest.strip_prefix('.')?
        .split('.')
        .map(|part| part.parse::<u64>().ok())
        .collect()
}
/// Opens the shared library at `path` with `RTLD_NOW | RTLD_GLOBAL` and keeps
/// it loaded for the rest of the process.
///
/// Returns `true` on success and `false` when the open fails (the error detail
/// is discarded).
fn dlopen_global(path: &Path) -> bool {
    use libloading::os::unix::{Library, RTLD_GLOBAL, RTLD_NOW};

    // SAFETY: NOTE(review) — opening a library executes its load-time
    // initializers; this relies on `path` pointing at a well-behaved CUDA
    // library.
    let opened = unsafe { Library::open(Some(path), RTLD_NOW | RTLD_GLOBAL) };
    let Ok(handle) = opened else {
        return false;
    };
    // Forget the handle so dropping it never closes the library.
    std::mem::forget(handle);
    true
}