nornir 0.4.40

Companion to cargo: dependency tracking, release gating, deploy, benchmarks, and documentation assembly. Project-agnostic.
//! Runtime embedder backend selection (#9) — ONE binary, picked at runtime.
//!
//! A default `cargo install nornir` ships **both** embedders: the pure-Rust
//! **tract** CPU backend ([`super::embed`]) and the GPU-capable **ort** backend
//! ([`super::embed_ort`], CUDA/ROCm). This module is the single place that
//! decides which one to load on the *running* machine, gracefully:
//!
//! ```text
//!   probe onnxruntime present?  ── no ──▶ tract  (CPU, pure Rust)
//!          │ yes
//!   probe NVIDIA CUDA ready?    ── yes ─▶ ort CUDA   (NVIDIA GPU)
//!          │ no
//!   probe AMD  ROCm  ready?     ── yes ─▶ ort ROCm   (AMD GPU)
//!          │ no
//!          └──────────────────────────▶ tract  (CPU, pure Rust)
//! ```
//!
//! The order is **GPU-or-tract**: onnxruntime (the ort C++ runtime) is only ever
//! *loaded* when a GPU is actually usable. Its presence is probed (and
//! `ORT_DYLIB_PATH` armed) up front, but no ort object is built otherwise. With no
//! GPU we stay on the proven pure-Rust CPU path even if onnxruntime happens to be
//! installed — there is no reason to pull the C++ runtime in to do CPU work tract
//! already does well.
//!
//! ## Why a probe-first gate (build/run safety)
//!
//! `embed-ort` is compiled with ort's **`load-dynamic`** feature, so the binary
//! LINKS NO onnxruntime — it `dlopen`s `libonnxruntime.so` at runtime. The
//! consequence: the *first* call into any ort API ([`ort::session::Session`])
//! triggers ort to load the dylib, and if it is absent ort **panics**. So the
//! selector must NEVER construct an ort object unless [`super::cuda::arm_onnxruntime`]
//! has confirmed a real onnxruntime is loadable. That is exactly what
//! [`decide`] + [`load`] enforce: a box with no onnxruntime (or no GPU) returns
//! [`Backend::TractCpu`] and the ort module is never entered.
//!
//! ## Testability (LAW: inject-assert)
//!
//! [`decide`] is a pure function over a [`Probe`] of booleans, so a test can
//! INJECT a fake probe (GPU-ready / not, onnxruntime present / not) and ASSERT
//! the chosen [`Backend`] — without needing a GPU or onnxruntime in CI. The live
//! path ([`live_probe`]) fills the same struct from the real cuda/rocm probes.

use anyhow::Result;

use super::store::Embedder;

/// Which embedder implementation the runtime selector settled on for this machine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Backend {
    /// ONNX Runtime via `ort`, running on the **NVIDIA CUDA** execution provider.
    OrtCuda,
    /// ONNX Runtime via `ort`, running on the **AMD ROCm / MIGraphX** execution provider.
    OrtRocm,
    /// The pure-Rust **tract** CPU embedder — the fallback that is always available.
    TractCpu,
}

impl Backend {
    /// Display name used in diagnostics such as `vector doctor`.
    pub fn label(self) -> &'static str {
        match self {
            Self::TractCpu => "tract-onnx (CPU, pure Rust)",
            Self::OrtRocm => "ort (ONNX Runtime, AMD ROCm GPU)",
            Self::OrtCuda => "ort (ONNX Runtime, NVIDIA CUDA GPU)",
        }
    }

    /// Whether this backend executes on a GPU.
    ///
    /// This is `true` for both ort variants and `false` for tract. The match is
    /// exhaustive on purpose, so adding a variant forces a decision here.
    pub fn is_gpu(self) -> bool {
        match self {
            Self::OrtCuda | Self::OrtRocm => true,
            Self::TractCpu => false,
        }
    }
}

/// The runtime facts the selector decides on. A test injects this directly;
/// [`live_probe`] fills it from the real cuda/rocm/onnxruntime probes.
///
/// `Probe::default()` has every field `false` (nothing detected), which
/// [`decide`] maps to the tract CPU backend.
#[derive(Debug, Clone, Copy, Default)]
pub struct Probe {
    /// `libonnxruntime.so` is present and loadable (the ort dynamic-load dylib).
    /// This is the gate for every ort backend.
    pub onnxruntime: bool,
    /// NVIDIA driver + CUDA runtime + cuDNN are all present (CUDA EP would run).
    pub cuda_ready: bool,
    /// AMD driver + HIP runtime + MIOpen are all present (ROCm EP would run).
    /// The live probe leaves this `false` unless the `embed-ort-rocm` feature
    /// is compiled in.
    pub rocm_ready: bool,
}

/// Pure backend choice (#9): map a runtime [`Probe`] onto a [`Backend`].
///
/// Rules, in order:
/// - Without onnxruntime, every ort path is ruled out.
/// - With onnxruntime present, NVIDIA CUDA is preferred over AMD ROCm, since
///   CUDA is the proven path.
/// - With no usable GPU, the answer is tract CPU even though onnxruntime is
///   available.
///
/// Side-effect free (no I/O, no global state), so tests can drive it directly.
pub fn decide(p: &Probe) -> Backend {
    match (p.onnxruntime, p.cuda_ready, p.rocm_ready) {
        // onnxruntime + NVIDIA ready: CUDA wins regardless of ROCm.
        (true, true, _) => Backend::OrtCuda,
        // onnxruntime + only AMD ready.
        (true, false, true) => Backend::OrtRocm,
        // No onnxruntime, or no usable GPU: the pure-Rust CPU path.
        _ => Backend::TractCpu,
    }
}

/// Probe this machine for real and, as a side effect, **arm** ort's
/// dynamic-load path.
///
/// [`super::cuda::arm_onnxruntime`] sets `ORT_DYLIB_PATH` when a usable
/// onnxruntime is found, so a later ort load resolves the same dylib that was
/// probed. A box without onnxruntime leaves the environment untouched. Arming
/// happens whenever onnxruntime is present, even on a box with no GPU (where
/// [`decide`] still picks tract).
///
/// The probes run in a fixed order: onnxruntime, then CUDA, then ROCm.
pub fn live_probe() -> Probe {
    // The arming call's boolean result doubles as the onnxruntime probe.
    let ort_present = super::cuda::arm_onnxruntime();

    // NVIDIA: driver + cudart + cuDNN. Only the readiness flag (first tuple
    // element) of `preflight()` is needed here.
    let nvidia_ready = super::cuda::preflight().0;

    // AMD: only probed when the ROCm feature is compiled in.
    #[cfg(feature = "embed-ort-rocm")]
    let amd_ready = super::rocm::available();
    #[cfg(not(feature = "embed-ort-rocm"))]
    let amd_ready = false;

    Probe {
        onnxruntime: ort_present,
        cuda_ready: nvidia_ready,
        rocm_ready: amd_ready,
    }
}

/// Report which backend [`load`] would pick on this machine right now.
///
/// It runs the live probes and feeds them to [`decide`]. No model is loaded, so
/// `vector doctor` and other diagnostics can call it cheaply. It does share
/// [`live_probe`]'s arming side effect.
pub fn chosen_backend() -> Backend {
    let probe = live_probe();
    decide(&probe)
}

/// Load whichever embedder the runtime selector picks: probe → [`decide`] →
/// construct.
///
/// ort is only entered when [`decide`] reports onnxruntime plus a usable GPU.
/// Any failure on the ort path degrades to the tract CPU embedder via
/// [`load_backend`], so an embedder is produced whenever tract itself can load.
///
/// # Errors
///
/// Propagates the error from [`load_backend`] (i.e. a tract load failure).
pub fn load() -> Result<Box<dyn Embedder>> {
    load_backend(decide(&live_probe()))
}

/// Construct the embedder for an explicitly chosen [`Backend`], degrading to
/// tract CPU if an ort GPU load fails, so selection is always graceful.
///
/// For [`Backend::OrtCuda`] / [`Backend::OrtRocm`] this calls
/// `embed_ort::OrtEmbedder::load()`. Two failure modes are reported on stderr
/// and answered with the tract CPU embedder instead:
/// - an `Err` returned from that call;
/// - a *panic* inside it. With `load-dynamic`, ort panics when its
///   onnxruntime dylib cannot be loaded; see the module docs.
///
/// Catching the panic matters because callers may pass a GPU backend
/// explicitly, bypassing [`decide`]'s onnxruntime gate. Under a
/// `panic = "abort"` profile, a panic cannot be caught and still aborts.
///
/// # Errors
///
/// Returns an error only if the tract CPU embedder itself fails to load.
pub fn load_backend(backend: Backend) -> Result<Box<dyn Embedder>> {
    match backend {
        Backend::OrtCuda | Backend::OrtRocm => {
            // A bare fn item captures nothing, so it is trivially `UnwindSafe`.
            match std::panic::catch_unwind(super::embed_ort::OrtEmbedder::load) {
                Ok(Ok(e)) => Ok(Box::new(e)),
                Ok(Err(e)) => {
                    // The GPU/ort path failed at runtime. Fall back to the
                    // always-present pure-Rust CPU embedder rather than erroring out.
                    // `label()` already names the runtime, so no "ort/" prefix.
                    eprintln!(
                        "nornir: {} embedder load failed ({e:#}); falling back to tract CPU",
                        backend.label()
                    );
                    load_tract()
                }
                Err(_panic) => {
                    // The panic hook has already printed the payload; just note
                    // the fallback.
                    eprintln!(
                        "nornir: {} embedder panicked while loading; falling back to tract CPU",
                        backend.label()
                    );
                    load_tract()
                }
            }
        }
        Backend::TractCpu => load_tract(),
    }
}

/// Load the always-present pure-Rust tract CPU embedder as a boxed trait object.
fn load_tract() -> Result<Box<dyn Embedder>> {
    Ok(Box::new(super::embed::JinaEmbedder::load()?))
}

// Inject-assert tests for the backend selector. The pure `decide` tests need no
// GPU or onnxruntime. The live/ignored tests exercise the real probes, the
// onnxruntime arming path and the staged model.
#[cfg(test)]
mod tests {
    use super::*;

    /// LAW (inject-assert): when the probe says onnxruntime + a usable NVIDIA
    /// GPU are present, the selector MUST pick the ort CUDA backend.
    #[test]
    fn picks_ort_cuda_when_gpu_ready() {
        let p = Probe { onnxruntime: true, cuda_ready: true, rocm_ready: false };
        assert_eq!(decide(&p), Backend::OrtCuda);
        assert!(decide(&p).is_gpu(), "CUDA choice must report as a GPU backend");
        assert!(decide(&p).label().contains("CUDA"));
    }

    /// LAW (inject-assert): NVIDIA wins over AMD when both probe ready (CUDA is
    /// the proven path), and AMD is chosen when only ROCm is ready.
    #[test]
    fn cuda_precedes_rocm_but_rocm_chosen_alone() {
        let both = Probe { onnxruntime: true, cuda_ready: true, rocm_ready: true };
        assert_eq!(decide(&both), Backend::OrtCuda, "CUDA precedes ROCm");

        let amd = Probe { onnxruntime: true, cuda_ready: false, rocm_ready: true };
        assert_eq!(decide(&amd), Backend::OrtRocm, "ROCm chosen when only AMD is ready");
        assert!(decide(&amd).label().contains("ROCm"));
    }

    /// LAW (inject-assert): no GPU ⇒ tract CPU, even if onnxruntime is present.
    /// We don't pull the C++ runtime in to do CPU work tract already does.
    #[test]
    fn falls_back_to_tract_without_gpu() {
        let cpu_only = Probe { onnxruntime: true, cuda_ready: false, rocm_ready: false };
        assert_eq!(decide(&cpu_only), Backend::TractCpu);
        assert!(!decide(&cpu_only).is_gpu());
    }

    /// LAW (inject-assert): no onnxruntime ⇒ tract CPU regardless of GPU probe —
    /// the gate that keeps ort (which would panic without its dylib) untouched.
    #[test]
    fn no_onnxruntime_never_chooses_ort() {
        // Even with a GPU "ready", absent onnxruntime we must not pick ort.
        let gpu_no_ort = Probe { onnxruntime: false, cuda_ready: true, rocm_ready: true };
        assert_eq!(decide(&gpu_no_ort), Backend::TractCpu,
            "no libonnxruntime.so ⇒ must stay on tract (ort would panic loading its dylib)");
        // The all-false default probe (nothing detected) also lands on tract.
        let nothing = Probe::default();
        assert_eq!(decide(&nothing), Backend::TractCpu);
    }

    /// The live probe must run on any box without panicking and agree with the
    /// pure decision over its own result (the live path and decide() are
    /// consistent). On a box without onnxruntime (the usual CI case) this
    /// concretely asserts the tract CPU branch.
    #[test]
    fn live_probe_is_self_consistent() {
        // NOTE(review): the real probes run twice here (directly and inside
        // chosen_backend()); this assumes their answers are stable between calls.
        let p = live_probe();
        let live = chosen_backend();
        assert_eq!(live, decide(&p), "chosen_backend() must equal decide(live_probe())");
        // A box with no onnxruntime must land on tract (the common CI case).
        if !p.onnxruntime {
            assert_eq!(live, Backend::TractCpu, "no onnxruntime ⇒ tract CPU");
        }
    }

    /// LAW (inject-assert): when a *loadable* `libonnxruntime.so` is present
    /// (pointed to by `$NORNIR_TEST_ORT_STUB`), `arm_onnxruntime()` must detect
    /// it, set `ORT_DYLIB_PATH`, and the live selection must move OFF tract — to
    /// `OrtCuda` on a CUDA box or `OrtRocm` on an AMD one (and stay on tract on
    /// a GPU-less box). Proves the onnxruntime gate + GPU selection plumbing
    /// end-to-end without a real GPU embed. Ignored by default; run with:
    ///   `NORNIR_TEST_ORT_STUB=/path/to/libonnxruntime.so \`
    ///   `  cargo test --lib -- --ignored selection_arms_ort_when_onnxruntime_present`
    #[test]
    #[ignore]
    fn selection_arms_ort_when_onnxruntime_present() {
        let stub = std::env::var("NORNIR_TEST_ORT_STUB")
            .expect("set NORNIR_TEST_ORT_STUB to a loadable libonnxruntime.so");
        // NOTE(review): mutating the process environment while other test
        // threads run is racy. This test is meant to be run alone via the name
        // filter above. Under Rust edition 2024 `set_var` is `unsafe`.
        std::env::set_var("ORT_DYLIB_PATH", &stub);

        // The onnxruntime probe must now report present.
        assert!(super::super::cuda::arm_onnxruntime(), "loadable onnxruntime should arm");
        assert!(super::super::cuda::onnxruntime_dylib().is_some(), "dylib must be discoverable");

        let probe = live_probe();
        assert!(probe.onnxruntime, "live probe must see the onnxruntime stub");
        let backend = decide(&probe);
        if probe.cuda_ready {
            assert_eq!(backend, Backend::OrtCuda, "CUDA box + onnxruntime ⇒ ort CUDA");
        } else if probe.rocm_ready {
            assert_eq!(backend, Backend::OrtRocm, "AMD box + onnxruntime ⇒ ort ROCm");
        } else {
            // No GPU on this box: even with onnxruntime present we stay on tract.
            assert_eq!(backend, Backend::TractCpu, "no GPU ⇒ tract even with onnxruntime");
        }
        // Belt-and-braces: implied by the CUDA branch above, restated via is_gpu().
        assert!(backend.is_gpu() || !probe.cuda_ready,
            "with CUDA ready the selected backend must be a GPU backend");
    }

    /// LAW (inject-assert): the selector's CPU-fallback path actually LOADS and
    /// EMBEDS — not "didn't panic". Ignored by default (needs the model staged);
    /// run with `cargo test --lib -- --ignored selector_load_backend_embeds`.
    /// We force [`Backend::TractCpu`] (the universal fallback) so this runs on
    /// any box. It asserts that a real L2-normalized vector of the model's
    /// dimensionality comes out, and that semantically-close code embeds nearer
    /// than unrelated prose.
    #[test]
    #[ignore]
    fn selector_load_backend_embeds() {
        let e = load_backend(Backend::TractCpu).expect("load tract CPU embedder");
        let dim = super::super::embed_support::dim();

        let v = e.embed(&["fn main() {}".to_string()]).expect("embed");
        assert_eq!(v.len(), 1, "one vector per input text");
        assert_eq!(v[0].len(), dim, "vector has the model's dimensionality");
        let norm: f32 = v[0].iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-3, "vector is L2-normalized (norm {norm})");

        // Real signal: two functions embed closer to each other than to prose.
        // With unit-length vectors, the dot product is the cosine similarity.
        let a = e.embed(&["fn add(a: i32, b: i32) -> i32 { a + b }".into()]).unwrap();
        let b = e.embed(&["pub fn sum(x: i32, y: i32) -> i32 { x + y }".into()]).unwrap();
        let c = e.embed(&["the quick brown fox jumps over the lazy dog".into()]).unwrap();
        let dot = |x: &[f32], y: &[f32]| x.iter().zip(y).map(|(p, q)| p * q).sum::<f32>();
        assert!(
            dot(&a[0], &b[0]) > dot(&a[0], &c[0]),
            "two fns ({}) should embed closer than fn-vs-prose ({})",
            dot(&a[0], &b[0]),
            dot(&a[0], &c[0])
        );
    }
}