use anyhow::Result;
use super::store::Embedder;
/// Embedding inference backend that `decide` can select.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Backend {
    /// ONNX Runtime (via `ort`) on an NVIDIA CUDA GPU.
    OrtCuda,
    /// ONNX Runtime (via `ort`) on an AMD ROCm GPU.
    OrtRocm,
    /// Pure-Rust `tract-onnx` inference on the CPU.
    TractCpu,
}

impl Backend {
    /// Human-readable description of this backend, used in log messages.
    pub fn label(self) -> &'static str {
        match self {
            Self::OrtCuda => "ort (ONNX Runtime, NVIDIA CUDA GPU)",
            Self::OrtRocm => "ort (ONNX Runtime, AMD ROCm GPU)",
            Self::TractCpu => "tract-onnx (CPU, pure Rust)",
        }
    }

    /// `true` for the ort GPU backends, `false` for the tract CPU backend.
    pub fn is_gpu(self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::OrtCuda | Self::OrtRocm => true,
            Self::TractCpu => false,
        }
    }
}
/// Snapshot of the runtime capabilities that drive backend selection.
///
/// `Probe::default()` has every flag `false`, i.e. nothing available.
#[derive(Debug, Default, Clone, Copy)]
pub struct Probe {
    /// Whether the onnxruntime shared library could be found and armed for `ort`.
    pub onnxruntime: bool,
    /// Whether the CUDA preflight check reported the GPU as ready.
    pub cuda_ready: bool,
    /// Whether a ROCm GPU is ready; `live_probe` only checks this when built
    /// with the `embed-ort-rocm` feature and reports `false` otherwise.
    pub rocm_ready: bool,
}
pub fn decide(p: &Probe) -> Backend {
if p.onnxruntime {
if p.cuda_ready {
return Backend::OrtCuda;
}
if p.rocm_ready {
return Backend::OrtRocm;
}
}
Backend::TractCpu
}
/// Run the real capability checks on this machine and package them as a [`Probe`].
///
/// The checks run in a fixed order: onnxruntime is armed first, then the
/// CUDA preflight runs, then the ROCm check.
/// NOTE(review): confirm whether the later checks rely on the arming step;
/// keep the order until then.
pub fn live_probe() -> Probe {
    // ROCm readiness is only probed when the ROCm ort build is compiled in.
    #[cfg(feature = "embed-ort-rocm")]
    fn rocm() -> bool {
        super::rocm::available()
    }
    #[cfg(not(feature = "embed-ort-rocm"))]
    fn rocm() -> bool {
        false
    }

    let onnxruntime = super::cuda::arm_onnxruntime();
    let cuda_ready = super::cuda::preflight().0;
    Probe {
        onnxruntime,
        cuda_ready,
        rocm_ready: rocm(),
    }
}
/// The backend [`load`] would pick on this machine right now.
///
/// Re-runs the live capability probe on every call.
pub fn chosen_backend() -> Backend {
    let probe = live_probe();
    decide(&probe)
}
pub fn load() -> Result<Box<dyn Embedder>> {
let probe = live_probe();
let backend = decide(&probe);
load_backend(backend)
}
/// Construct the embedder for an explicitly chosen [`Backend`].
///
/// Both GPU backends (`OrtCuda` and `OrtRocm`) are loaded through
/// `super::embed_ort::OrtEmbedder::load()`. If that fails, the error is
/// printed to stderr and the tract CPU embedder is loaded instead, so a broken
/// GPU/ort setup degrades to CPU rather than aborting.
///
/// # Errors
///
/// Returns an error only when `super::embed::JinaEmbedder::load()` (the tract
/// CPU embedder) fails, whether it was requested directly or used as the fallback.
pub fn load_backend(backend: Backend) -> Result<Box<dyn Embedder>> {
    match backend {
        // NOTE(review): CUDA and ROCm share one loader here; presumably
        // OrtEmbedder::load selects the execution provider itself — confirm.
        Backend::OrtCuda | Backend::OrtRocm => match super::embed_ort::OrtEmbedder::load() {
            Ok(embedder) => Ok(Box::new(embedder)),
            Err(err) => {
                // `label()` already starts with "ort (ONNX Runtime, ...)", so no
                // extra "ort/" prefix (it previously printed "ort/ort (...)").
                eprintln!(
                    "nornir: {} embedder load failed ({err:#}); falling back to tract CPU",
                    backend.label()
                );
                Ok(Box::new(super::embed::JinaEmbedder::load()?))
            }
        },
        Backend::TractCpu => Ok(Box::new(super::embed::JinaEmbedder::load()?)),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn picks_ort_cuda_when_gpu_ready() {
        let p = Probe { onnxruntime: true, cuda_ready: true, rocm_ready: false };
        assert_eq!(decide(&p), Backend::OrtCuda);
        assert!(decide(&p).is_gpu(), "CUDA choice must report as a GPU backend");
        assert!(decide(&p).label().contains("CUDA"));
    }

    #[test]
    fn cuda_precedes_rocm_but_rocm_chosen_alone() {
        let both = Probe { onnxruntime: true, cuda_ready: true, rocm_ready: true };
        assert_eq!(decide(&both), Backend::OrtCuda, "CUDA precedes ROCm");
        let amd = Probe { onnxruntime: true, cuda_ready: false, rocm_ready: true };
        assert_eq!(decide(&amd), Backend::OrtRocm, "ROCm chosen when only AMD is ready");
        assert!(decide(&amd).label().contains("ROCm"));
    }

    #[test]
    fn falls_back_to_tract_without_gpu() {
        let cpu_only = Probe { onnxruntime: true, cuda_ready: false, rocm_ready: false };
        assert_eq!(decide(&cpu_only), Backend::TractCpu);
        assert!(!decide(&cpu_only).is_gpu());
    }

    #[test]
    fn no_onnxruntime_never_chooses_ort() {
        let gpu_no_ort = Probe { onnxruntime: false, cuda_ready: true, rocm_ready: true };
        assert_eq!(
            decide(&gpu_no_ort),
            Backend::TractCpu,
            "no libonnxruntime.so ⇒ must stay on tract (ort would panic loading its dylib)"
        );
        let nothing = Probe::default();
        assert_eq!(decide(&nothing), Backend::TractCpu);
    }

    /// Runs against the real machine, so it can only check internal consistency.
    #[test]
    fn live_probe_is_self_consistent() {
        let p = live_probe();
        let live = chosen_backend();
        assert_eq!(live, decide(&p), "chosen_backend() must equal decide(live_probe())");
        if !p.onnxruntime {
            assert_eq!(live, Backend::TractCpu, "no onnxruntime ⇒ tract CPU");
        }
    }

    /// Ignored by default: requires `NORNIR_TEST_ORT_STUB` to point at a
    /// loadable `libonnxruntime.so`. Mutates `ORT_DYLIB_PATH` for the process.
    #[test]
    #[ignore]
    fn selection_arms_ort_when_onnxruntime_present() {
        let stub = std::env::var("NORNIR_TEST_ORT_STUB")
            .expect("set NORNIR_TEST_ORT_STUB to a loadable libonnxruntime.so");
        std::env::set_var("ORT_DYLIB_PATH", &stub);
        assert!(super::super::cuda::arm_onnxruntime(), "loadable onnxruntime should arm");
        assert!(super::super::cuda::onnxruntime_dylib().is_some(), "dylib must be discoverable");
        let probe = live_probe();
        assert!(probe.onnxruntime, "live probe must see the onnxruntime stub");
        let backend = decide(&probe);
        if probe.cuda_ready {
            assert_eq!(backend, Backend::OrtCuda, "CUDA box + onnxruntime ⇒ ort CUDA");
        } else if probe.rocm_ready {
            assert_eq!(backend, Backend::OrtRocm, "AMD box + onnxruntime ⇒ ort ROCm");
        } else {
            assert_eq!(backend, Backend::TractCpu, "no GPU ⇒ tract even with onnxruntime");
        }
        assert!(
            backend.is_gpu() || !probe.cuda_ready,
            "with CUDA ready the selected backend must be a GPU backend"
        );
    }

    /// Ignored by default: loads the real tract CPU embedder (presumably needs
    /// the model files available locally — confirm) and checks output shape,
    /// L2 normalization and a coarse semantic-similarity ordering.
    #[test]
    #[ignore]
    fn selector_load_backend_embeds() {
        let e = load_backend(Backend::TractCpu).expect("load tract CPU embedder");
        let dim = super::super::embed_support::dim();
        let v = e.embed(&["fn main() {}".to_string()]).expect("embed");
        assert_eq!(v.len(), 1, "one vector per input text");
        assert_eq!(v[0].len(), dim, "vector has the model's dimensionality");
        let norm: f32 = v[0].iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-3, "vector is L2-normalized (norm {norm})");
        let a = e.embed(&["fn add(a: i32, b: i32) -> i32 { a + b }".into()]).unwrap();
        let b = e.embed(&["pub fn sum(x: i32, y: i32) -> i32 { x + y }".into()]).unwrap();
        let c = e.embed(&["the quick brown fox jumps over the lazy dog".into()]).unwrap();
        let dot = |x: &[f32], y: &[f32]| x.iter().zip(y).map(|(p, q)| p * q).sum::<f32>();
        assert!(
            dot(&a[0], &b[0]) > dot(&a[0], &c[0]),
            "two fns ({}) should embed closer than fn-vs-prose ({})",
            dot(&a[0], &b[0]),
            dot(&a[0], &c[0])
        );
    }
}