nornir 0.4.32

Companion to cargo: dependency tracking, release gating, deploy, benchmarks, and documentation assembly. Project-agnostic.
//! ort (ONNX Runtime) embedder (GPU/CPU) — `jina-embeddings-v2-base-code`.
//!
//! Implements [`super::store::Embedder`] by running the model's ONNX export via
//! the `ort` crate (Microsoft ONNX Runtime). Requests a **GPU execution
//! provider** and falls back to CPU when none is available, so the same binary
//! works on a GPU box or a plain server. Two GPU vendors are supported:
//!
//! - **NVIDIA / CUDA** (always wired when `embed-ort` is on): the `CUDA` EP,
//!   with [`super::cuda`] preloading the runtime libs.
//! - **AMD / ROCm** (extra feature `embed-ort-rocm`): the `MIGraphX` then
//!   `ROCm` EPs, with [`super::rocm`] preloading the runtime libs and a real
//!   runtime probe ([`super::rocm::available`]) so the AMD EP is only requested
//!   when ROCm is actually installed. A non-AMD build never pulls this in; a
//!   build with it but no AMD GPU degrades to the CUDA-then-CPU path unchanged.
//!
//! EP registration is **best-effort** (`fail_silently`): a missing GPU runtime
//! quietly drops to the next provider, ending at CPU — it never errors out of
//! `load`.
//!
//! This is the GPU-capable alternative to the pure-Rust tract backend
//! ([`super::embed`]). Both run the same model and share
//! [`super::embed_support`], so a vector produced here matches one from tract
//! up to floating-point rounding (which can differ between execution
//! providers) — embeddings made by either backend share the same
//! `model_profile` dedup key in the warehouse. The model itself is selectable
//! via the registry
//! ([`super::embed_registry`]); this backend runs whichever ONNX export
//! `build.rs` fetched (jina-v2-base-code, 768-dim, by default).
//!
//! ⚠ Heaviest no-C exception: `ort` links the onnxruntime C++ library (and,
//! with CUDA present, the CUDA runtime). A nornir build without `embed-ort`
//! pulls none of this. Accepted by Rickard, 2026-06-04.
//!
//! onnxruntime parallelizes a single `run` internally (intra-op threads), so
//! we call it once per text; GPU work serializes on the device anyway.
//!
//! Cargo feature: `embed-ort`.

use std::sync::Mutex;

use anyhow::{Context, Result};
use ort::session::Session;
use ort::value::Tensor;
use tokenizers::Tokenizer;

use super::embed_support as es;
use super::store::{Embedder, ModelProfile};

/// A loaded jina-v2-base-code ONNX model + tokenizer (ort / GPU-or-CPU).
///
/// `Session::run` takes `&mut self`, so the session is behind a `Mutex`; calls
/// serialize, but each `run` already uses many cores (CPU EP) or the GPU.
pub struct OrtEmbedder {
    /// The ONNX Runtime session; locked for the duration of each forward pass.
    session: Mutex<Session>,
    /// Tokenizer loaded from the model directory's `tokenizer.json`.
    tokenizer: Tokenizer,
    /// Whether a GPU EP was *requested* (it falls back to CPU silently if the
    /// GPU runtime is absent).
    gpu_requested: bool,
    /// Which GPU vendor's EP was *requested*, for diagnostics: `"cuda"`,
    /// `"rocm"`, or `"cpu"`. `"cpu"` means no GPU EP was requested at all. A
    /// vendor whose runtime turned out to be missing is still reported here,
    /// because EP registration is best-effort and ort drops to CPU silently.
    gpu_vendor: &'static str,
}

impl OrtEmbedder {
    /// Load the model, requesting the best available GPU execution provider
    /// (AMD ROCm when `embed-ort-rocm` is on and ROCm is present, else NVIDIA
    /// CUDA), with a clean CPU fallback.
    ///
    /// # Errors
    /// Same as [`Self::load_with`].
    pub fn load() -> Result<Self> {
        Self::load_with(true)
    }

    /// Load the model on CPU only (do not request any GPU EP).
    ///
    /// # Errors
    /// Same as [`Self::load_with`].
    pub fn load_cpu() -> Result<Self> {
        Self::load_with(false)
    }

    /// Load the model. When `prefer_gpu`, register a GPU EP first; ort's
    /// registration is best-effort (`fail_silently`), so a machine without that
    /// GPU runtime quietly uses the CPU EP instead.
    ///
    /// EP precedence: AMD ROCm (only when `embed-ort-rocm` is compiled **and**
    /// [`super::rocm::available`] runtime-probes true) → NVIDIA CUDA → CPU.
    /// MIGraphX is registered ahead of plain ROCm because it graph-compiles the
    /// whole model; ort tries them in order and falls through on registration
    /// failure.
    ///
    /// # Errors
    /// Fails if `tokenizer.json` or `model.onnx` cannot be loaded from
    /// `es::model_dir()`, or if creating the session builder or the
    /// EP-registration call itself errors. A missing GPU runtime is *not* an
    /// error; it silently degrades to CPU.
    pub fn load_with(prefer_gpu: bool) -> Result<Self> {
        let dir = es::model_dir();
        let tokenizer = Tokenizer::from_file(dir.join("tokenizer.json"))
            .map_err(|e| anyhow::anyhow!("load tokenizer.json: {e}"))?;

        let mut builder = Session::builder().context("ort session builder")?;
        let mut gpu_vendor = "cpu";
        if prefer_gpu {
            // AMD ROCm path — only when the feature is built and the runtime
            // probe says ROCm is actually installed. On any other box this
            // branch is skipped entirely, leaving the CUDA path below untouched.
            #[cfg(feature = "embed-ort-rocm")]
            {
                if super::rocm::available() {
                    // Make the ROCm libs resolvable before the provider
                    // dlopens them, mirroring the CUDA preload. Best-effort:
                    // the result is deliberately ignored.
                    let _ = super::rocm::ensure();
                    builder = builder
                        .with_execution_providers([
                            // MIGraphX compiles the graph; ROCm is the generic
                            // fallback. Both fail_silently → CPU on any miss.
                            ort::ep::MIGraphX::default().build().fail_silently(),
                            ort::ep::ROCm::default().build().fail_silently(),
                        ])
                        .map_err(|e| anyhow::anyhow!("register ROCm/MIGraphX EP: {e}"))?;
                    gpu_vendor = "rocm";
                }
            }
            // NVIDIA CUDA path — the original behaviour, used whenever the ROCm
            // branch didn't claim the GPU (no AMD feature, or no ROCm present).
            if gpu_vendor != "rocm" {
                // Make CUDA libs resolvable before the provider dlopens them,
                // so the GPU EP works without a manual LD_LIBRARY_PATH.
                // Best-effort: the result is deliberately ignored.
                let _ = super::cuda::ensure();
                builder = builder
                    .with_execution_providers([ort::ep::CUDA::default().build().fail_silently()])
                    .map_err(|e| anyhow::anyhow!("register CUDA execution provider: {e}"))?;
                gpu_vendor = "cuda";
            }
        }
        let session = builder
            .commit_from_file(dir.join("model.onnx"))
            .context("commit onnx session")?;

        Ok(Self {
            session: Mutex::new(session),
            tokenizer,
            gpu_requested: prefer_gpu,
            gpu_vendor,
        })
    }

    /// True if a GPU EP was requested at load time. (ort doesn't expose which EP
    /// actually serviced a run; GPU EPs fall back to CPU silently.)
    pub fn gpu_requested(&self) -> bool {
        self.gpu_requested
    }

    /// Which GPU vendor's EP was requested: `"cuda"`, `"rocm"`, or `"cpu"`.
    ///
    /// `"cpu"` means no GPU EP was requested (see [`Self::load_cpu`]). A
    /// requested vendor is reported even if its runtime was absent and ort fell
    /// back to CPU.
    pub fn gpu_vendor(&self) -> &'static str {
        self.gpu_vendor
    }

    /// Whether the ROCm runtime probe ([`super::rocm::available`]) reports
    /// ROCm as installed. Always `false` unless the `embed-ort-rocm` feature is
    /// compiled.
    pub fn rocm_ready(&self) -> bool {
        // `super::rocm` only exists with the feature, so this must be a `cfg`
        // attribute rather than a `cfg!` expression.
        #[cfg(feature = "embed-ort-rocm")]
        let ready = super::rocm::available();
        #[cfg(not(feature = "embed-ort-rocm"))]
        let ready = false;
        ready
    }

    /// Whether the runtime CUDA-lib preload (`super::cuda::ensure`) found a
    /// usable set (cudart + cuDNN). Only meaningful when [`Self::gpu_vendor`]
    /// is `"cuda"`: in that case, false means the CUDA EP almost certainly fell
    /// back to CPU.
    pub fn cuda_ready(&self) -> bool {
        let p = super::cuda::ensure();
        p.cudnn && p.loaded.iter().any(|s| s.starts_with("libcudart"))
    }

    /// Tokenize `text`, run one forward pass, and pool + normalize the hidden
    /// state into a single `es::dim()`-length vector via
    /// `es::pool_and_normalize`.
    ///
    /// # Errors
    /// Fails on tokenization, tensor construction, the forward pass, output
    /// extraction, an unexpected output shape, or a poisoned session lock.
    fn embed_one(&self, text: &str) -> Result<Vec<f32>> {
        let enc = self
            .tokenizer
            .encode(text, true)
            .map_err(|e| anyhow::anyhow!("tokenize: {e}"))?;
        let (ids, mask) = es::prepare_tokens(enc.get_ids(), es::max_tokens());
        let n = ids.len();
        let shape = [1_i64, n as i64];

        let id_tensor = Tensor::from_array((shape, ids)).context("build input_ids tensor")?;
        let mask_tensor =
            Tensor::from_array((shape, mask)).context("build attention_mask tensor")?;

        // A poisoned lock means an earlier call panicked while holding the
        // session; report it as an error rather than panicking this caller too.
        let mut session = self
            .session
            .lock()
            .map_err(|_| anyhow::anyhow!("ort session mutex poisoned by an earlier panic"))?;
        let outputs = session
            .run(ort::inputs![
                "input_ids" => id_tensor,
                "attention_mask" => mask_tensor,
            ])
            .context("onnx forward")?;
        let (out_shape, data) = outputs[0]
            .try_extract_tensor::<f32>()
            .context("extract hidden state")?;
        let dim = es::dim();
        let dims = out_shape.as_ref();
        // Validate the whole [batch = 1, seq = n, dim] shape, not just the last
        // axis: `pool_and_normalize` is told there are `n` rows of `dim`, so a
        // batch or sequence-length mismatch must be rejected here.
        anyhow::ensure!(
            dims.len() == 3 && dims[0] == 1 && dims[1] as usize == n && dims[2] as usize == dim,
            "unexpected output shape {dims:?} (expected [1, {n}, {dim}])"
        );
        Ok(es::pool_and_normalize(data, n, dim))
    }
}

impl Embedder for OrtEmbedder {
    /// The model profile shared with the tract backend (via `embed_support`),
    /// so vectors from either backend dedup under the same key.
    fn profile(&self) -> ModelProfile {
        es::profile()
    }

    /// Embed `texts` in order, one forward pass per text, stopping at the
    /// first failure.
    ///
    /// Each pass holds the session mutex, so concurrent callers serialize;
    /// onnxruntime parallelizes inside a single run anyway.
    fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let mut vectors = Vec::with_capacity(texts.len());
        for text in texts {
            vectors.push(self.embed_one(text)?);
        }
        Ok(vectors)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Dot product of two vectors. The embeddings are unit-normalized, so this
    /// equals their cosine similarity.
    fn dot(x: &[f32], y: &[f32]) -> f32 {
        x.iter().zip(y).map(|(p, q)| p * q).sum()
    }

    /// Loads the real ONNX model via ort on the CPU EP only (no GPU EP is
    /// requested), then checks the profile, output shape, unit norm and a
    /// coarse similarity ordering. Ignored by default; run with
    /// `cargo test --features embed-ort -- --ignored embed`.
    #[test]
    #[ignore]
    fn loads_and_embeds_cpu() {
        let e = OrtEmbedder::load_cpu().expect("load model");
        assert!(!e.gpu_requested());
        assert_eq!(e.gpu_vendor(), "cpu");

        let p = e.profile();
        assert_eq!(p.dim, es::dim());
        assert_eq!(p.model_name, es::model_name());

        let v = e.embed(&["fn main() {}".to_string()]).unwrap();
        assert_eq!(v.len(), 1);
        assert_eq!(v[0].len(), es::dim());
        let norm: f32 = v[0].iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-3, "norm {norm}");

        let a = e.embed(&["fn add(a: i32, b: i32) -> i32 { a + b }".into()]).unwrap();
        let b = e.embed(&["pub fn sum(x: i32, y: i32) -> i32 { x + y }".into()]).unwrap();
        let c = e.embed(&["the quick brown fox jumps over the lazy dog".into()]).unwrap();
        let fn_fn = dot(&a[0], &b[0]);
        let fn_prose = dot(&a[0], &c[0]);
        assert!(
            fn_fn > fn_prose,
            "expected two fns ({fn_fn}) to be closer than fn-vs-prose ({fn_prose})"
        );
    }
}