nornir 0.4.18

Companion to cargo: dependency tracking, release gating, deploy, benchmarks, and documentation assembly. Project-agnostic.
Documentation
//! tract-onnx embedder (CPU, pure Rust) — `jina-embeddings-v2-base-code`.
//!
//! Implements [`super::store::Embedder`] by running the model's ONNX export on
//! CPU via `tract-onnx`. This is the **pure-Rust** backend (the runtime has no
//! C deps; only the shared `tokenizers` crate does — the accepted exception).
//! The ort backend ([`super::embed_ort`]) is the GPU/CPU alternative; both run
//! the same model and share [`super::embed_support`], so their vectors match.
//!
//! candle's built-in jina model is the *English* architecture and cannot load
//! the code model's QK-LayerNorm weights, so we run the official ONNX export.
//!
//! The graph takes `input_ids` + `attention_mask` (i64 `[1, n]`) and outputs
//! the last hidden state `[1, n, dim]` (dim = the **selected** model's dim, 768
//! for the jina default); we mean-pool + L2-normalize (shared
//! [`super::embed_support::pool_and_normalize`]). The model is selectable via
//! the registry ([`super::embed_registry`]); this backend runs whichever ONNX
//! export `build.rs` fetched.
//!
//! Cargo feature: `embed-tract`.

use std::sync::Arc;

use anyhow::{Context, Result};
use tokenizers::Tokenizer;
use tract_onnx::prelude::*;

use super::embed_support as es;
use super::store::{Embedder, ModelProfile};

/// tract's runnable plan for the optimized ONNX graph.
type OnnxModel = TypedRunnableModel;

/// A loaded ONNX embedding model + its tokenizer, run on CPU with tract.
///
/// The model is whichever ONNX export `build.rs` fetched for the model
/// selected in the registry (`jina-embeddings-v2-base-code` by default), so
/// the output dimension comes from the selected model rather than a fixed 768.
pub struct JinaEmbedder {
    /// Optimized, runnable tract plan produced by [`JinaEmbedder::load`];
    /// shared read-only by the embedding worker threads.
    model: Arc<OnnxModel>,
    /// Tokenizer loaded from the model directory's `tokenizer.json`.
    tokenizer: Tokenizer,
}

impl JinaEmbedder {
    /// Load and optimize the ONNX model and its tokenizer.
    ///
    /// The model dir is resolved at runtime ([`es::model_dir`]) so a service
    /// user reads a readable copy (`$NORNIR_MODEL_DIR` / `/opt/nornir/models`)
    /// rather than the builder's `~/.cache`. The dir must contain
    /// `tokenizer.json` and `model.onnx`.
    ///
    /// # Errors
    ///
    /// Fails if either file cannot be loaded (the error names the path), or if
    /// tract cannot optimize the graph or make it runnable.
    pub fn load() -> Result<Self> {
        let dir = es::model_dir();

        let tokenizer_path = dir.join("tokenizer.json");
        // Format the tokenizer error by hand and include the path, like the
        // ONNX load error below, so a missing/unreadable file is diagnosable.
        let tokenizer = Tokenizer::from_file(&tokenizer_path).map_err(|e| {
            anyhow::anyhow!("load tokenizer {}: {e}", tokenizer_path.display())
        })?;

        let onnx = dir.join("model.onnx");
        let model = tract_onnx::onnx()
            .model_for_path(&onnx)
            .with_context(|| format!("load onnx {}", onnx.display()))?
            .into_optimized()
            .context("optimize onnx graph")?
            .into_runnable()
            .context("make onnx graph runnable")?;
        Ok(Self { model, tokenizer })
    }

    /// Embed a single string: tokenize → forward → mean-pool → L2-normalize.
    ///
    /// Token ids are prepared by [`es::prepare_tokens`] under the
    /// [`es::max_tokens`] limit and fed as `input_ids` + `attention_mask`
    /// tensors of shape `[1, n]`. The graph's first output must be an `f32`
    /// hidden state of shape `[1, n, dim]`, where `dim` is [`es::dim`].
    ///
    /// # Errors
    ///
    /// Fails on tokenizer errors, tract forward errors, a missing output, or a
    /// hidden state whose shape or memory layout is not `[1, n, dim]`.
    fn embed_one(&self, text: &str) -> Result<Vec<f32>> {
        let enc = self
            .tokenizer
            .encode(text, true)
            .map_err(|e| anyhow::anyhow!("tokenize: {e}"))?;
        let (ids, mask) = es::prepare_tokens(enc.get_ids(), es::max_tokens());
        let n = ids.len();

        let input_ids = tract_ndarray::Array2::from_shape_vec((1, n), ids)?.into_tensor();
        let attn = tract_ndarray::Array2::from_shape_vec((1, n), mask)?.into_tensor();

        let outputs = self
            .model
            .run(tvec!(input_ids.into(), attn.into()))
            .context("onnx forward")?;
        let hidden = outputs
            .first()
            .context("onnx graph produced no outputs")?
            .to_plain_array_view::<f32>()
            .context("read hidden state")?; // [1, n, dim]
        let dim = es::dim();
        let shape = hidden.shape();
        // Check the whole shape, not just the last axis: the flat buffer is
        // pooled with this call's `n` and `dim`, so any other batch or
        // sequence length would presumably be pooled over the wrong rows.
        anyhow::ensure!(
            shape == [1, n, dim],
            "unexpected output shape {shape:?} (expected [1, {n}, {dim}])"
        );
        let flat = hidden.as_slice().context("hidden state not contiguous")?;
        Ok(es::pool_and_normalize(flat, n, dim))
    }
}

impl Embedder for JinaEmbedder {
    fn profile(&self) -> ModelProfile {
        es::profile()
    }

    /// Embed `texts`, returning one vector per input in the same order.
    ///
    /// Zero or one text is embedded on the calling thread. Larger batches are
    /// split into contiguous chunks, one per available core (capped at the
    /// batch size), and embedded on scoped threads that share `self`
    /// read-only, so the graph and tokenizer are never cloned.
    ///
    /// # Errors
    ///
    /// Returns the error of the first text, in input order, that fails to
    /// embed. This is the same error the sequential path reports, regardless
    /// of thread timing.
    ///
    /// # Panics
    ///
    /// Re-raises a panic from any worker thread.
    fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        // tract runs one forward per call single-threaded, so parallelism
        // comes from embedding multiple texts across cores. A lone text skips
        // the parallelism probe entirely.
        let n = texts.len();
        let threads = if n <= 1 {
            1
        } else {
            std::thread::available_parallelism()
                .map(|x| x.get())
                .unwrap_or(1)
                .min(n)
        };
        if threads <= 1 {
            return texts.iter().map(|t| self.embed_one(t)).collect();
        }

        let chunk = n.div_ceil(threads);
        std::thread::scope(|s| -> Result<Vec<Vec<f32>>> {
            // Each worker embeds one contiguous chunk and stops at its first
            // failure. Handles are kept in chunk order.
            let workers: Vec<_> = texts
                .chunks(chunk)
                .map(|part| {
                    s.spawn(move || {
                        part.iter()
                            .map(|t| self.embed_one(t))
                            .collect::<Result<Vec<_>>>()
                    })
                })
                .collect();

            // Joining in chunk order means the first `Err` seen belongs to the
            // earliest failing text. Returning early is fine: the scope still
            // joins the remaining workers before it exits.
            let mut out = Vec::with_capacity(n);
            for worker in workers {
                let part = worker
                    .join()
                    .unwrap_or_else(|payload| std::panic::resume_unwind(payload))?;
                out.extend(part);
            }
            Ok(out)
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Dot product of two vectors. For the unit-length vectors `embed`
    /// returns, this is their cosine similarity.
    fn dot(x: &[f32], y: &[f32]) -> f32 {
        x.iter().zip(y).map(|(p, q)| p * q).sum()
    }

    /// Loads the real ONNX model (needs the build-time weight cache). Ignored
    /// by default so the normal test run stays fast and offline; run with
    /// `cargo test --features embed-tract -- --ignored embed`.
    ///
    /// Checks the profile and unit norm. It also checks that a multi-text
    /// batch, which takes the threaded path of `embed` on multi-core hosts,
    /// returns rows in input order, and that two equivalent fns are more
    /// similar than a fn and prose.
    #[test]
    #[ignore]
    fn loads_and_embeds() {
        let e = JinaEmbedder::load().expect("load model");
        let p = e.profile();
        assert_eq!(p.dim, es::dim());
        assert_eq!(p.model_name, es::model_name());

        let v = e.embed(&["fn main() {}".to_string()]).unwrap();
        assert_eq!(v.len(), 1);
        assert_eq!(v[0].len(), es::dim());
        let norm = dot(&v[0], &v[0]).sqrt();
        assert!((norm - 1.0).abs() < 1e-3, "norm {norm}");

        let texts = [
            "fn add(a: i32, b: i32) -> i32 { a + b }".to_string(),
            "pub fn sum(x: i32, y: i32) -> i32 { x + y }".to_string(),
            "the quick brown fox jumps over the lazy dog".to_string(),
        ];
        // One batched call. Each row must match the single-text embedding of
        // the same input, i.e. the batch preserves input order.
        let batch = e.embed(&texts).unwrap();
        assert_eq!(batch.len(), texts.len());
        for (text, row) in texts.iter().zip(&batch) {
            let single = e.embed(std::slice::from_ref(text)).unwrap();
            let max_diff = row
                .iter()
                .zip(&single[0])
                .map(|(x, y)| (x - y).abs())
                .fold(0.0f32, f32::max);
            assert!(max_diff < 1e-4, "batch row for {text:?} differs by {max_diff}");
        }

        // Code semantics: two equivalent Rust fns are closer to each other
        // than the first one is to unrelated prose.
        let (a, b, c) = (&batch[0], &batch[1], &batch[2]);
        assert!(
            dot(a, b) > dot(a, c),
            "two fns ({}) closer than fn-vs-prose ({})",
            dot(a, b),
            dot(a, c)
        );
    }

    /// Full pipeline with the real model: index 3 code files into the
    /// warehouse, then a natural-language query retrieves the right one.
    /// Like `loads_and_embeds`, it needs the build-time weight cache and is
    /// ignored by default.
    #[test]
    #[ignore]
    fn end_to_end_semantic_search() {
        use crate::vector::chunk::ChunkOptions;
        use crate::vector::store::{index_repo, search, RepoRef};
        use crate::warehouse::iceberg::IcebergWarehouse;

        let dir = tempfile::tempdir().unwrap();
        let wh = IcebergWarehouse::open(dir.path()).unwrap();
        let embedder = JinaEmbedder::load().unwrap();

        let files = vec![
            (
                "math.rs".to_string(),
                "pub fn add(a: i32, b: i32) -> i32 { a + b }".to_string(),
            ),
            (
                "io.rs".to_string(),
                "fn read_file(path: &str) -> std::io::Result<String> { std::fs::read_to_string(path) }".to_string(),
            ),
            (
                "net.rs".to_string(),
                "async fn fetch(url: &str) -> Result<String> { http_get(url).await }".to_string(),
            ),
        ];
        let snap = index_repo(
            &wh,
            &RepoRef {
                workspace: "ws",
                repo: "demo",
                git_sha: "sha1",
                branch: "main",
                complete: true,
            },
            &files,
            &ChunkOptions::default(),
            &embedder,
        )
        .unwrap();
        assert_eq!(snap.new_vectors, 3);

        let mp = embedder.profile().id();
        let q = embedder
            .embed(&["a function that adds two integers together".to_string()])
            .unwrap();
        let hits = search(&wh, "demo", Some("sha1"), &mp, &q[0], 3).unwrap();
        // Fail with a clear message rather than an index-out-of-bounds panic.
        let top = hits.first().expect("search returned no hits");
        assert_eq!(
            top.1.file, "math.rs",
            "NL query about adding integers should retrieve the add fn"
        );
    }
}