use std::path::{Path, PathBuf};
use std::sync::OnceLock;
/// ROCm shared libraries to preload, in the order they are `dlopen`ed.
///
/// Libraries are opened with `RTLD_GLOBAL` one after another, so the order matters:
/// presumably each entry's dependencies appear earlier in the list (HIP runtime first) —
/// confirm against the ROCm dependency graph when adding entries.
const ORDERED_SONAMES: &[&str] = &[
    "libamdhip64.so",
    "librocblas.so",
    "libhipblas.so",
    "librocfft.so",
    "libMIOpen.so",
    "libmigraphx_c.so",
];
/// Outcome of the one-time ROCm library preload (see `ensure`).
#[derive(Debug, Default, Clone)]
pub struct RocmPreload {
    /// Sonames from `ORDERED_SONAMES` that were found and successfully `dlopen`ed, in load order.
    pub loaded: Vec<String>,
    /// Existing directories that were searched for the libraries, in precedence order.
    pub dirs: Vec<PathBuf>,
    /// Whether the HIP runtime (`libamdhip64`) was loaded.
    pub hip: bool,
    /// Whether MIOpen (`libMIOpen`) was loaded.
    pub miopen: bool,
}
/// Process-wide cache of the preload result; written exactly once by `ensure`.
static PRELOAD: OnceLock<RocmPreload> = OnceLock::new();

/// Run the ROCm preload at most once per process and return the cached result.
///
/// The first call performs the directory search and the `dlopen`s; every later call
/// returns the same `'static` reference without repeating any work.
pub fn ensure() -> &'static RocmPreload {
    // Fast path: the preload already happened.
    if let Some(done) = PRELOAD.get() {
        return done;
    }
    PRELOAD.get_or_init(run)
}
/// Whether the AMD GPU embedding path can be used.
///
/// Triggers the one-time preload, then requires the HIP runtime probe to succeed and both
/// the HIP runtime and MIOpen to have been preloaded.
pub fn available() -> bool {
    // Preload first so the libraries are resident before the runtime probe.
    let preload = ensure();
    let libs_ready = preload.hip && preload.miopen;
    driver_present() && libs_ready
}
fn run() -> RocmPreload {
let dirs = candidate_dirs();
let mut out = RocmPreload {
dirs: dirs.clone(),
..Default::default()
};
for soname in ORDERED_SONAMES {
if let Some(path) = find_lib(&dirs, soname) {
if dlopen_global(&path) {
out.loaded.push(soname.to_string());
if soname.starts_with("libamdhip64") {
out.hip = true;
} else if soname.starts_with("libMIOpen") {
out.miopen = true;
}
}
}
}
out
}
pub fn driver_present() -> bool {
use libloading::os::unix::{Library, RTLD_NOW};
unsafe { Library::open(Some(std::ffi::OsStr::new("libamdhip64.so")), RTLD_NOW) }.is_ok()
}
/// Render a presence flag for the preflight report.
fn yn(b: bool) -> &'static str {
    match b {
        true => "found",
        false => "MISSING",
    }
}
/// Probe the ROCm stack and render a human-readable diagnostic report.
///
/// Triggers the one-time preload (via `ensure`) and returns `(gpu_ready, report)`, where
/// `gpu_ready` is the value of `available()`. The report lists:
/// - which runtime pieces were found;
/// - which libraries were loaded and which directories were searched;
/// - a verdict line;
/// - a remediation hint for each failure mode detected.
pub fn preflight() -> (bool, String) {
    use std::fmt::Write as _;

    let p = ensure();
    let driver = driver_present();
    let gpu_ready = available();

    let loaded = if p.loaded.is_empty() {
        "(none)".to_string()
    } else {
        p.loaded.join(", ")
    };
    // Same placeholder as for loaded libs, so an empty search list is visible instead of
    // rendering as a blank field.
    let searched = if p.dirs.is_empty() {
        "(none)".to_string()
    } else {
        p.dirs
            .iter()
            .map(|d| d.display().to_string())
            .collect::<Vec<_>>()
            .join(", ")
    };
    let missing: Vec<&str> = ORDERED_SONAMES
        .iter()
        .copied()
        .filter(|n| !p.loaded.iter().any(|l| l == n))
        .collect();

    // Writing into a `String` cannot fail, so the `fmt::Result`s below are discarded.
    let mut s = String::from("nornir ROCm preflight (embed-ort-rocm AMD GPU path)\n");
    let _ = writeln!(s, " AMD HIP runtime libamdhip64 : {}", yn(driver && p.hip));
    let _ = writeln!(s, " MIOpen (ROCm DNN) : {}", yn(p.miopen));
    let _ = writeln!(s, " ROCm libs loaded : {}", loaded);
    let _ = writeln!(s, " dirs searched : {}", searched);
    if !missing.is_empty() {
        let _ = writeln!(s, " runtime libs not found : {}", missing.join(", "));
    }
    let _ = writeln!(
        s,
        "\n verdict: AMD GPU embedding {}",
        if gpu_ready { "READY ✓" } else { "unavailable → CPU fallback" }
    );

    if !driver {
        s.push_str(
            " → install ROCm for your distro (e.g. `sudo apt install rocm-hip-runtime miopen-hip`, \
             or AMD's amdgpu-install). nornir can't install the kernel driver — it needs root + the \
             amdgpu module + (usually) a reboot.\n",
        );
    }
    // The loader can resolve libamdhip64 on its default path, but the preloader did not find
    // it in any searched directory. Without this hint the report only shows "MISSING".
    if driver && !p.hip {
        s.push_str(
            " → libamdhip64 loads from the default library search path but was not found in the \
             dirs searched; put a matched ROCm .so set in ONE dir and set NORNIR_ROCM_LIBS to it.\n",
        );
    }
    if driver && !p.miopen {
        s.push_str(
            " → ROCm runtime incomplete (no MIOpen). Install `miopen-hip` / `migraphx`, or put a \
             matched ROCm .so set in ONE dir and set NORNIR_ROCM_LIBS to it.\n",
        );
    }
    if gpu_ready {
        s.push_str(" → all set; embed-ort runs on the AMD GPU via MIGraphX/ROCm.\n");
    }
    (gpu_ready, s)
}
/// Build the ordered, de-duplicated list of directories to search for ROCm libraries.
///
/// Precedence (earlier entries win in `find_lib`):
/// 1. every entry of `NORNIR_ROCM_LIBS`, parsed as a platform path list (`split_paths`);
/// 2. `$ROCM_PATH/lib`, `$ROCM_PATH/lib64`, then the same two under `$HIP_PATH`;
/// 3. well-known install locations. These cover AMD's `/opt/rocm*` layout and the system
///    library directories used by Debian-style (`/usr/lib/x86_64-linux-gnu`) and RPM-style
///    (`/usr/lib64`) distro packages.
///
/// Only paths that currently exist as directories are kept.
fn candidate_dirs() -> Vec<PathBuf> {
    let mut dirs: Vec<PathBuf> = Vec::new();
    let push = |p: PathBuf, dirs: &mut Vec<PathBuf>| {
        if p.is_dir() && !dirs.contains(&p) {
            dirs.push(p);
        }
    };
    if let Some(v) = std::env::var_os("NORNIR_ROCM_LIBS") {
        for p in std::env::split_paths(&v) {
            push(p, &mut dirs);
        }
    }
    for key in ["ROCM_PATH", "HIP_PATH"] {
        if let Some(root) = std::env::var_os(key) {
            push(Path::new(&root).join("lib"), &mut dirs);
            push(Path::new(&root).join("lib64"), &mut dirs);
        }
    }
    for sys in [
        "/opt/nornir/rocm",
        "/opt/rocm/lib",
        "/opt/rocm-6.0.0/lib",
        "/opt/rocm-6.1.0/lib",
        "/opt/rocm/lib64",
        "/usr/lib/x86_64-linux-gnu",
        // Fedora/RHEL/openSUSE ship the ROCm runtime (libamdhip64, MIOpen, ...) here.
        // Appended last so it never shadows any previously searched location.
        "/usr/lib64",
    ] {
        push(PathBuf::from(sys), &mut dirs);
    }
    dirs
}
/// Locate `soname` in the first directory of `dirs` that provides it.
///
/// Within one directory, a regular file named exactly `soname` wins. Otherwise the
/// highest-versioned `soname.N[.M...]` regular file is chosen, e.g. `libfoo.so.10` over
/// `libfoo.so.9.1`. Versions are compared numerically, component by component, so the
/// choice is deterministic regardless of `read_dir` order.
///
/// Entries with any other suffix (such as `libfoo.so-gdb.py`) are ignored, as are
/// directories that happen to carry the name. Directories in `dirs` are tried in order, so
/// earlier entries take precedence.
///
/// Returns `None` when no directory contains a match. Unreadable directories are skipped.
fn find_lib(dirs: &[PathBuf], soname: &str) -> Option<PathBuf> {
    // Numeric version suffix of `name` relative to `soname`:
    // `soname` -> `[]`, `soname.6.2` -> `[6, 2]`, anything else -> `None`.
    fn version_of(name: &str, soname: &str) -> Option<Vec<u64>> {
        let rest = name.strip_prefix(soname)?;
        if rest.is_empty() {
            return Some(Vec::new());
        }
        rest.strip_prefix('.')?
            .split('.')
            .map(|part| part.parse::<u64>().ok())
            .collect()
    }

    dirs.iter().find_map(|dir| {
        let exact = dir.join(soname);
        if exact.is_file() {
            return Some(exact);
        }
        std::fs::read_dir(dir)
            .ok()?
            .flatten()
            .filter_map(|entry| {
                let version = version_of(entry.file_name().to_str()?, soname)?;
                let path = entry.path();
                path.is_file().then_some((version, path))
            })
            .max_by(|a, b| a.0.cmp(&b.0))
            .map(|(_, path)| path)
    })
}
/// `dlopen` the library at `path` with `RTLD_NOW | RTLD_GLOBAL` and keep it loaded forever.
///
/// Returns `true` on success. On success the handle is deliberately leaked (never closed),
/// so the library stays resident and its symbols remain in the global namespace for
/// libraries opened afterwards. On failure the loader error is discarded.
fn dlopen_global(path: &Path) -> bool {
    use libloading::os::unix::{Library, RTLD_GLOBAL, RTLD_NOW};

    let flags = RTLD_NOW | RTLD_GLOBAL;
    // SAFETY: loading runs the library's initialisers. Callers only pass ROCm runtime
    // libraries found by `find_lib`, which this module intends to load process-wide.
    let opened = unsafe { Library::open(Some(path), flags) };
    opened.map(std::mem::forget).is_ok()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test that must pass on any host, with or without ROCm installed.
    #[test]
    fn probe_degrades_to_cpu_without_rocm() {
        let p = ensure();
        // The candidate list may legitimately be empty, e.g. on a non-x86_64 host with no ROCm
        // install and none of the well-known dirs. So only check its invariants: every entry
        // exists as a directory, and none is listed twice.
        assert!(
            p.dirs.iter().all(|d| d.is_dir()),
            "candidate dirs must be existing directories: {:?}",
            p.dirs
        );
        for (i, d) in p.dirs.iter().enumerate() {
            assert!(!p.dirs[i + 1..].contains(d), "duplicate candidate dir {d:?}");
        }
        let (ready, report) = preflight();
        assert_eq!(ready, available(), "preflight verdict must match available()");
        assert!(report.contains("ROCm preflight"), "report header missing: {report}");
        if available() {
            assert!(driver_present() && p.hip && p.miopen, "available() implies hip+miopen+driver");
            assert!(report.contains("READY"), "ready box should say READY: {report}");
        } else {
            assert!(!ready, "no ROCm => not ready");
            assert!(report.contains("CPU fallback"), "verdict should name CPU fallback: {report}");
        }
    }
}