nornir 0.4.18

Companion to cargo: dependency tracking, release gating, deploy, benchmarks, and documentation assembly. Project-agnostic.
Documentation
//! ort (ONNX Runtime) embedder (GPU/CPU) — `jina-embeddings-v2-base-code`.
//!
//! Implements [`super::store::Embedder`] by running the model's ONNX export via
//! the `ort` crate (Microsoft ONNX Runtime). Requests the **CUDA execution
//! provider** and falls back to CPU when CUDA isn't available, so the same
//! binary works on a GPU box or a plain server.
//!
//! This is the GPU-capable alternative to the pure-Rust tract backend
//! ([`super::embed`]). Both run the same model and share
//! [`super::embed_support`], so a vector produced here matches one from tract
//! (up to floating-point rounding) — embeddings made by either backend share
//! the same `model_profile`
//! dedup key in the warehouse. The model itself is selectable via the registry
//! ([`super::embed_registry`]); this backend runs whichever ONNX export
//! `build.rs` fetched (jina-v2-base-code, 768-dim, by default).
//!
//! ⚠ This is the heaviest exception to the no-C-dependencies rule: `ort`
//! links the onnxruntime C++ library (and, with CUDA present, the CUDA
//! runtime). A nornir build without `embed-ort` pulls in none of this.
//! Exception accepted by Rickard, 2026-06-04.
//!
//! onnxruntime already parallelizes a single `run` internally (intra-op
//! threads), so we simply run one forward pass per text instead of fanning
//! texts out across threads; GPU work serializes on the device anyway.
//!
//! Cargo feature: `embed-ort`.

use std::sync::Mutex;

use anyhow::{Context, Result};
use ort::session::Session;
use ort::value::Tensor;
use tokenizers::Tokenizer;

use super::embed_support as es;
use super::store::{Embedder, ModelProfile};

/// A loaded ONNX embedding model (jina-v2-base-code by default) and its
/// tokenizer, executed through ort on the GPU (CUDA EP) or the CPU.
///
/// `Session::run` takes `&mut self`, so the session lives behind a `Mutex`:
/// concurrent calls serialize on it, but each individual `run` already uses
/// many cores (CPU EP) or the GPU.
pub struct OrtEmbedder {
    /// The ort inference session; locked for each forward pass and the
    /// extraction of its output.
    session: Mutex<Session>,
    /// Tokenizer loaded from `tokenizer.json` in the model directory.
    tokenizer: Tokenizer,
    /// Whether the CUDA EP was *requested* at load time. This does not prove
    /// CUDA is in use: it falls back to CPU silently if the CUDA runtime is
    /// absent.
    gpu_requested: bool,
}

impl OrtEmbedder {
    /// Load the model, requesting the CUDA execution provider (CPU fallback).
    ///
    /// Shorthand for `OrtEmbedder::load_with(true)`.
    ///
    /// # Errors
    /// Same as [`OrtEmbedder::load_with`].
    pub fn load() -> Result<Self> {
        Self::load_with(true)
    }

    /// Load the model on CPU only (the CUDA EP is not requested).
    ///
    /// Shorthand for `OrtEmbedder::load_with(false)`.
    ///
    /// # Errors
    /// Same as [`OrtEmbedder::load_with`].
    pub fn load_cpu() -> Result<Self> {
        Self::load_with(false)
    }

    /// Load `tokenizer.json` and `model.onnx` from `es::model_dir()` and build
    /// an ort session.
    ///
    /// When `prefer_gpu`, register the CUDA EP first; ort's default
    /// registration is best-effort, so a machine without CUDA quietly uses the
    /// CPU EP instead.
    ///
    /// # Errors
    /// Fails if the tokenizer cannot be loaded, the session builder cannot be
    /// created, the CUDA EP registration call itself returns an error, or the
    /// ONNX file cannot be committed to a session.
    pub fn load_with(prefer_gpu: bool) -> Result<Self> {
        let dir = es::model_dir();
        let tokenizer = Tokenizer::from_file(dir.join("tokenizer.json"))
            .map_err(|e| anyhow::anyhow!("load tokenizer.json: {e}"))?;

        let mut builder = Session::builder().context("ort session builder")?;
        if prefer_gpu {
            // Make CUDA libs resolvable before the provider dlopen's them,
            // so the GPU EP works without a manual LD_LIBRARY_PATH. The
            // returned preload report is deliberately ignored here: loading
            // must still succeed on CPU-only machines, and `cuda_ready()`
            // re-queries it when a caller wants to know.
            let _ = super::cuda::ensure();
            builder = builder
                .with_execution_providers([ort::ep::CUDA::default().build()])
                .map_err(|e| anyhow::anyhow!("register CUDA execution provider: {e}"))?;
        }
        let session = builder
            .commit_from_file(dir.join("model.onnx"))
            .context("commit onnx session")?;

        Ok(Self {
            session: Mutex::new(session),
            tokenizer,
            gpu_requested: prefer_gpu,
        })
    }

    /// True if the CUDA EP was requested at load time. (ort doesn't expose
    /// which EP actually serviced a run; CUDA falls back to CPU silently.)
    pub fn gpu_requested(&self) -> bool {
        self.gpu_requested
    }

    /// Whether the runtime CUDA-lib preload found a usable set (cudart +
    /// cuDNN). When false, the CUDA EP almost certainly fell back to CPU.
    pub fn cuda_ready(&self) -> bool {
        let p = super::cuda::ensure();
        p.cudnn && p.loaded.iter().any(|s| s.starts_with("libcudart"))
    }

    /// Tokenize `text`, run one forward pass, and return the pooled,
    /// normalized embedding of length `es::dim()`.
    ///
    /// # Errors
    /// Fails on tokenization or tensor-construction errors, if the session
    /// mutex was poisoned by an earlier panic, if the forward pass fails, or
    /// if the model output is not shaped `[1, n, dim]` for the `n` input
    /// tokens.
    fn embed_one(&self, text: &str) -> Result<Vec<f32>> {
        let enc = self
            .tokenizer
            .encode(text, true)
            .map_err(|e| anyhow::anyhow!("tokenize: {e}"))?;
        let (ids, mask) = es::prepare_tokens(enc.get_ids(), es::max_tokens());
        let n = ids.len();
        // A batch of one sequence of `n` tokens.
        let shape = [1_i64, n as i64];

        let id_tensor = Tensor::from_array((shape, ids)).context("build input_ids tensor")?;
        let mask_tensor = Tensor::from_array((shape, mask)).context("build attention_mask tensor")?;

        // A poisoned mutex means some thread panicked while holding the
        // session; surface that as an error instead of panicking every
        // subsequent caller.
        let mut session = self
            .session
            .lock()
            .map_err(|_| anyhow::anyhow!("ort session mutex poisoned by an earlier panic"))?;
        let outputs = session
            .run(ort::inputs![
                "input_ids" => id_tensor,
                "attention_mask" => mask_tensor,
            ])
            .context("onnx forward")?;
        let (out_shape, data) = outputs[0]
            .try_extract_tensor::<f32>()
            .context("extract hidden state")?;
        let dim = es::dim();
        let dims = out_shape.as_ref();
        // `pool_and_normalize(data, n, dim)` is handed `n` and `dim` as the
        // layout of `data` (presumably `n` token rows of `dim` floats), so
        // validate the whole [batch=1, seq=n, hidden=dim] shape rather than
        // only the hidden size; a mismatch would otherwise be mis-pooled.
        anyhow::ensure!(
            dims.len() == 3 && dims[0] == 1 && dims[1] as usize == n && dims[2] as usize == dim,
            "unexpected output shape {dims:?} (expected [1, {n}, {dim}])"
        );
        Ok(es::pool_and_normalize(data, n, dim))
    }
}

impl Embedder for OrtEmbedder {
    /// Model identity (name, dimension, …) from `embed_support`, the same
    /// profile the tract backend reports.
    fn profile(&self) -> ModelProfile {
        es::profile()
    }

    /// Embed each text with its own forward pass, stopping at the first error.
    ///
    /// Sequential on purpose: onnxruntime already parallelizes inside a single
    /// `run`, and every run serializes on the session `Mutex` (and on the GPU).
    fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let mut vectors = Vec::with_capacity(texts.len());
        for text in texts {
            vectors.push(self.embed_one(text)?);
        }
        Ok(vectors)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Dot product of two vectors; for the unit-length vectors the embedder
    /// returns this is their cosine similarity.
    fn dot(x: &[f32], y: &[f32]) -> f32 {
        x.iter().zip(y).map(|(p, q)| p * q).sum::<f32>()
    }

    /// Embed a single snippet through the public batch API.
    fn embed_text(embedder: &OrtEmbedder, text: &str) -> Vec<Vec<f32>> {
        embedder.embed(&[text.to_owned()]).unwrap()
    }

    /// Loads the real ONNX model via ort on the CPU EP (no CUDA needed), then
    /// checks the profile, vector length, unit norm, and that two similar
    /// functions embed closer together than a function and English prose.
    ///
    /// Ignored by default; run with
    /// `cargo test --features embed-ort -- --ignored embed`.
    #[test]
    #[ignore]
    fn loads_and_embeds_cpu() {
        let embedder = OrtEmbedder::load_cpu().expect("load model");

        let profile = embedder.profile();
        assert_eq!(profile.dim, es::dim());
        assert_eq!(profile.model_name, es::model_name());

        let single = embed_text(&embedder, "fn main() {}");
        assert_eq!(single.len(), 1);
        let vector = &single[0];
        assert_eq!(vector.len(), es::dim());
        let squared: f32 = vector.iter().map(|x| x * x).sum();
        let norm = squared.sqrt();
        assert!((norm - 1.0).abs() < 1e-3, "norm {norm}");

        let add_fn = embed_text(&embedder, "fn add(a: i32, b: i32) -> i32 { a + b }");
        let sum_fn = embed_text(&embedder, "pub fn sum(x: i32, y: i32) -> i32 { x + y }");
        let prose = embed_text(&embedder, "the quick brown fox jumps over the lazy dog");

        let fn_vs_fn = dot(&add_fn[0], &sum_fn[0]);
        let fn_vs_prose = dot(&add_fn[0], &prose[0]);
        assert!(
            fn_vs_fn > fn_vs_prose,
            "two fns ({}) closer than fn-vs-prose ({})",
            fn_vs_fn,
            fn_vs_prose
        );
    }
}