nornir 0.5.1

Companion to cargo: dependency tracking, release gating, deploy, benchmarks, and documentation assembly. Project-agnostic.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
//! tract-onnx embedder (CPU, pure Rust) — `jina-embeddings-v2-base-code`.
//!
//! Implements [`super::store::Embedder`] by running the model's ONNX export on
//! CPU via `tract-onnx`. This is the **pure-Rust** backend (the runtime has no
//! C deps; only the shared `tokenizers` crate does — the accepted exception).
//! The ort backend ([`super::embed_ort`]) is the GPU/CPU alternative; both run
//! the same model and share [`super::embed_support`], so their vectors match.
//!
//! candle's built-in jina model is the *English* architecture and cannot load
//! the code model's QK-LayerNorm weights, so we run the official ONNX export.
//!
//! The graph takes `input_ids` + `attention_mask` (i64 `[batch, sequence]` —
//! **both dims are symbolic/dynamic** in the ONNX export) and outputs the last
//! hidden state `[batch, sequence, dim]` (dim = the **selected** model's dim,
//! 768 for the jina default); we mean-pool + L2-normalize (shared
//! [`super::embed_support::pool_and_normalize`]). The model is selectable via
//! the registry ([`super::embed_registry`]); this backend runs whichever ONNX
//! export `build.rs` fetched.
//!
//! ## Batched forward (CPU throughput)
//!
//! Because the batch dim is dynamic, tract runs one forward at shape
//! `(B, max_len)` just as it runs `(1, n)`. [`JinaEmbedder::embed`] exploits
//! this: it length-buckets the input texts (sort by token count so each batch
//! pads to a similar length — minimal padding waste), forms batches up to
//! [`MAX_BATCH_ROWS`] rows AND a [`MAX_BATCH_TOKENS`] padded-token budget, pads
//! each batch to its own max length, and runs ONE forward per batch. Padding
//! tokens carry attention-mask `0`, and pooling covers only each row's leading
//! real tokens (equivalent to the mask-aware
//! [`es::pool_and_normalize_masked`]), so padding never leaks into a pooled
//! vector. Results are reassembled in the original input order. This is
//! bit-identical to the one-at-a-time path (see the `batched_equals_one_at_a_time`
//! test): the attention mask makes real-token hidden states independent of the
//! padding, and the matmul contraction dim (the model width) is unchanged by
//! batching, so no float ops are reordered.
//!
//! Cargo feature: `embed-tract`.

use std::sync::Arc;

use anyhow::{Context, Result};
use tokenizers::Tokenizer;
use tract_onnx::prelude::*;

use super::embed_support as es;
use super::store::{Embedder, ModelProfile};

/// The runnable tract model type stored (behind an `Arc`) in [`JinaEmbedder`];
/// aliased so the struct field does not spell out the tract type directly.
type OnnxModel = TypedRunnableModel;

/// Default max rows per batched forward. Tuned for the jina-v2-base-code CPU
/// forward: enough to amortize per-call overhead and share the GEMM across
/// rows, small enough that one long text in a bucket doesn't blow up the padded
/// work. Override with `$NORNIR_EMBED_BATCH_ROWS`; an unset, empty,
/// unparseable, or zero value falls back to this default (see `env_usize`).
pub const MAX_BATCH_ROWS: usize = 16;

/// Default padded-token budget per batch (`rows * pad_len`). Caps the total work
/// of a batch so a bucket of long texts forms smaller batches than a bucket of
/// short ones. Override with `$NORNIR_EMBED_BATCH_TOKENS`; an unset, empty,
/// unparseable, or zero value falls back to this default (see `env_usize`).
pub const MAX_BATCH_TOKENS: usize = 8192;

/// Look up environment variable `var` and interpret it as a strictly positive
/// `usize`.
///
/// Surrounding whitespace is trimmed before parsing. `default` is returned
/// whenever the variable is missing or not valid Unicode, is empty, does not
/// parse as a `usize`, or parses to `0`.
fn env_usize(var: &str, default: usize) -> usize {
    match std::env::var(var) {
        Ok(raw) => match raw.trim().parse::<usize>() {
            // Only a strictly positive value overrides the default.
            Ok(parsed) if parsed > 0 => parsed,
            _ => default,
        },
        Err(_) => default,
    }
}

/// A loaded jina-v2-base-code ONNX model + tokenizer (tract / CPU).
///
/// Construct one with [`JinaEmbedder::load`]; vectors are produced through the
/// [`Embedder`] implementation.
pub struct JinaEmbedder {
    /// Runnable graph built by `load` (parsed, optimized, made runnable).
    model: Arc<OnnxModel>,
    /// Tokenizer read from `tokenizer.json` in the model directory.
    tokenizer: Tokenizer,
}

impl JinaEmbedder {
    /// Load + optimize the ONNX model and tokenizer. The model dir is resolved
    /// at runtime ([`es::model_dir`]) so a service user reads a readable copy
    /// (`$NORNIR_MODEL_DIR` / `/opt/nornir/models`) rather than the builder's
    /// `~/.cache`.
    pub fn load() -> Result<Self> {
        let dir = es::model_dir();
        let tokenizer = Tokenizer::from_file(dir.join("tokenizer.json"))
            .map_err(|e| anyhow::anyhow!("load tokenizer.json: {e}"))?;
        let onnx = dir.join("model.onnx");
        let model = tract_onnx::onnx()
            .model_for_path(&onnx)
            .with_context(|| format!("load onnx {}", onnx.display()))?
            .into_optimized()
            .context("optimize onnx graph")?
            .into_runnable()
            .context("make onnx graph runnable")?;
        Ok(Self { model, tokenizer })
    }

    /// Embed a single string: tokenize → forward → mean-pool → L2-normalize.
    fn embed_one(&self, text: &str) -> Result<Vec<f32>> {
        let enc = self
            .tokenizer
            .encode(text, true)
            .map_err(|e| anyhow::anyhow!("tokenize: {e}"))?;
        let (ids, mask) = es::prepare_tokens(enc.get_ids(), es::max_tokens());
        let n = ids.len();

        let input_ids = tract_ndarray::Array2::from_shape_vec((1, n), ids)?.into_tensor();
        let attn = tract_ndarray::Array2::from_shape_vec((1, n), mask)?.into_tensor();

        let outputs = self
            .model
            .run(tvec!(input_ids.into(), attn.into()))
            .context("onnx forward")?;
        let hidden = outputs[0]
            .to_plain_array_view::<f32>()
            .context("read hidden state")?; // [1, n, 768]
        let dim = es::dim();
        let shape = hidden.shape();
        anyhow::ensure!(
            shape.len() == 3 && shape[2] == dim,
            "unexpected output shape {shape:?} (expected last dim {dim})"
        );
        let flat = hidden.as_slice().context("hidden state not contiguous")?;
        Ok(es::pool_and_normalize(flat, n, dim))
    }

    /// Tokenize one text into truncated `i64` ids (length ∈ `[1, max_tokens]`),
    /// the same ids `embed_one` would feed — so a batched forward built from
    /// these is equivalent to running `embed_one` per text.
    fn tokenize(&self, text: &str) -> Result<Vec<i64>> {
        let enc = self
            .tokenizer
            .encode(text, true)
            .map_err(|e| anyhow::anyhow!("tokenize: {e}"))?;
        Ok(es::prepare_ids(enc.get_ids(), es::max_tokens()))
    }

    /// Run ONE forward over a batch of pre-tokenized texts (`toks[i]` for each
    /// `i` in `batch`), padding every row to the batch's longest row. Returns
    /// `(original_index, embedding)` per row. Padding tokens (id `0`, mask `0`)
    /// are ignored by mask-aware pooling, so each row's vector equals its
    /// one-at-a-time `embed_one` result.
    fn forward_batch(&self, batch: &[usize], toks: &[Vec<i64>]) -> Result<Vec<(usize, Vec<f32>)>> {
        let b = batch.len();
        debug_assert!(b > 0, "empty batch");
        let pad_len = batch.iter().map(|&i| toks[i].len()).max().unwrap_or(1).max(1);
        let dim = es::dim();

        // Build the (B, pad_len) `input_ids` + `attention_mask` tensors by writing the
        // real tokens DIRECTLY into tract's aligned buffers — no intermediate `Vec`
        // and no alloc-then-copy into the tensor (`into_tensor` copies). `Tensor::zero`
        // pre-zeros both, so padding ids stay 0 and padding mask stays 0 for free; we
        // only touch the real front cells of each row.
        let mut ids_t = Tensor::zero::<i64>(&[b, pad_len])?;
        let mut mask_t = Tensor::zero::<i64>(&[b, pad_len])?;
        {
            // SAFETY: both tensors were just created contiguous, dtype i64, length
            // exactly b*pad_len; we hold the sole &mut, so the unchecked slices are
            // sound (tract has no safe checked `as_slice_mut`, only the pointer form).
            let ids_s = unsafe { ids_t.as_slice_mut_unchecked::<i64>() };
            let mask_s = unsafe { mask_t.as_slice_mut_unchecked::<i64>() };
            for (row, &i) in batch.iter().enumerate() {
                let base = row * pad_len;
                for (j, &id) in toks[i].iter().enumerate() {
                    ids_s[base + j] = id;
                    mask_s[base + j] = 1;
                }
            }
        }

        let outputs = self
            .model
            .run(tvec!(ids_t.into(), mask_t.into()))
            .context("onnx batched forward")?;
        let hidden = outputs[0]
            .to_plain_array_view::<f32>()
            .context("read hidden state")?; // [b, pad_len, dim]
        let shape = hidden.shape();
        anyhow::ensure!(
            shape.len() == 3 && shape[0] == b && shape[1] == pad_len && shape[2] == dim,
            "unexpected batched output shape {shape:?} (expected [{b}, {pad_len}, {dim}])"
        );
        let flat = hidden.as_slice().context("hidden state not contiguous")?;

        // Pool over the FRONT `n_real` rows of each row's hidden block. Real tokens are
        // front-contiguous and the attention mask zeroed padding in the forward, so
        // this is bit-identical to the mask-aware pool — and needs no mask buffer at
        // pooling time (so `mask_t` was free to move into the forward above).
        let mut out = Vec::with_capacity(b);
        for (row, &i) in batch.iter().enumerate() {
            let n_real = toks[i].len().max(1);
            let start = row * pad_len * dim;
            let row_hidden = &flat[start..start + n_real * dim];
            out.push((i, es::pool_and_normalize(row_hidden, n_real, dim)));
        }
        Ok(out)
    }
}

/// Group text indices (given their token lengths) into batches: sort ascending
/// by token count (length bucketing → similar lengths batch together, minimal
/// padding), then greedily pack up to `max_rows` rows without the padded token
/// count (`rows * running_max_len`) exceeding `max_tokens`. A single text longer
/// than the budget forms its own batch. `order`-stable within a batch.
///
/// * `lens` — token count of each text; a length of `0` is treated as `1`.
/// * `max_rows` / `max_tokens` — caps, each clamped to at least `1`.
///
/// Returns the batches as lists of indices into `lens`; every index appears in
/// exactly one batch and no batch is empty. An empty `lens` yields no batches.
fn plan_batches(lens: &[usize], max_rows: usize, max_tokens: usize) -> Vec<Vec<usize>> {
    let mut order: Vec<usize> = (0..lens.len()).collect();
    // Stable sort: texts of equal length keep their input order.
    order.sort_by_key(|&i| lens[i]);

    let max_rows = max_rows.max(1);
    let max_tokens = max_tokens.max(1);
    let mut batches: Vec<Vec<usize>> = Vec::new();
    let mut cur: Vec<usize> = Vec::new();
    let mut cur_max = 0usize;
    for &i in &order {
        let l = lens[i].max(1);
        let new_max = cur_max.max(l);
        // Saturate instead of overflowing: a plain `*` panics in debug builds and
        // wraps in release, where a wrapped (small) product would slip past the
        // budget check and pack oversized rows together.
        let would_pad = new_max.saturating_mul(cur.len() + 1);
        // Start a new batch when adding this row would exceed either cap — but
        // never emit an empty batch (a lone oversized row still goes in).
        if !cur.is_empty() && (cur.len() >= max_rows || would_pad > max_tokens) {
            batches.push(std::mem::take(&mut cur));
            cur_max = 0;
        }
        cur_max = cur_max.max(l);
        cur.push(i);
    }
    if !cur.is_empty() {
        batches.push(cur);
    }
    batches
}

impl Embedder for JinaEmbedder {
    /// Report the model profile, as supplied by the shared embed-support module.
    fn profile(&self) -> ModelProfile {
        es::profile()
    }

    /// Embed every entry of `texts`, returning one vector per input in the
    /// same order as `texts`.
    ///
    /// An empty slice yields an empty result and a single text goes through
    /// `embed_one`. Otherwise the texts are tokenized once, grouped by
    /// `plan_batches` (caps read from `$NORNIR_EMBED_BATCH_ROWS` /
    /// `$NORNIR_EMBED_BATCH_TOKENS`), and each group is run as one
    /// `forward_batch` call — sequentially when only one worker is usable,
    /// otherwise spread over workers. The first tokenization or forward error
    /// encountered is returned.
    fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        // Trivial sizes skip bucketing entirely.
        match texts {
            [] => return Ok(Vec::new()),
            [only] => return Ok(vec![self.embed_one(only)?]),
            _ => {}
        }
        let total = texts.len();

        // Tokenize each text exactly once; the token counts drive the bucketing.
        let mut toks: Vec<Vec<i64>> = Vec::with_capacity(total);
        for text in texts {
            toks.push(self.tokenize(text)?);
        }
        let lens: Vec<usize> = toks.iter().map(Vec::len).collect();

        let row_cap = env_usize("NORNIR_EMBED_BATCH_ROWS", MAX_BATCH_ROWS);
        let token_cap = env_usize("NORNIR_EMBED_BATCH_TOKENS", MAX_BATCH_TOKENS);
        let batches = plan_batches(&lens, row_cap, token_cap);

        // Never use more workers than there are batches.
        let cores = match std::thread::available_parallelism() {
            Ok(p) => p.get(),
            Err(_) => 1,
        };
        let threads = cores.min(batches.len());

        // One forward per batch. With several workers, each batch is weighted by
        // its padded-token count (rows * widest row) for the balanced map.
        let batch_results: Vec<Result<Vec<(usize, Vec<f32>)>>> = if threads > 1 {
            znippy_zoomies::gatling_forkjoin::gatling_map_balanced(
                &batches,
                threads,
                1,
                |rows: &Vec<usize>| {
                    let widest = rows.iter().map(|&i| lens[i]).max().map_or(1, |w| w.max(1));
                    (rows.len() * widest) as u64
                },
                |_, rows| self.forward_batch(rows, &toks),
            )
        } else {
            batches
                .iter()
                .map(|rows| self.forward_batch(rows, &toks))
                .collect()
        };

        // Place every row's vector back at its original input position.
        let mut ordered: Vec<Vec<f32>> = (0..total).map(|_| Vec::new()).collect();
        for rows in batch_results {
            for (idx, vector) in rows? {
                ordered[idx] = vector;
            }
        }
        Ok(ordered)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Bucketing is pure over token lengths — no model needed. It sorts ascending
    /// (so similar lengths batch together), caps rows at `max_rows`, caps padded
    /// tokens (`rows * running_max_len`) at `max_tokens`, covers every index
    /// exactly once, and lets a lone oversized text form its own batch.
    #[test]
    fn plan_batches_buckets_by_length_and_respects_caps() {
        // 7 texts, mixed lengths. rows cap 3, token budget 12.
        let lens = vec![10usize, 2, 9, 1, 3, 8, 2];
        let batches = plan_batches(&lens, 3, 12);

        // Every index appears exactly once.
        let mut seen: Vec<usize> = batches.iter().flatten().copied().collect();
        seen.sort_unstable();
        assert_eq!(seen, (0..lens.len()).collect::<Vec<_>>());

        for b in &batches {
            assert!(!b.is_empty(), "no empty batch");
            assert!(b.len() <= 3, "row cap honored: {b:?}");
            let pad = b.iter().map(|&i| lens[i]).max().unwrap();
            // Either within budget, or a single row that alone exceeds it.
            assert!(
                pad * b.len() <= 12 || b.len() == 1,
                "token budget honored (or lone oversized row): pad {pad} * {} rows",
                b.len()
            );
            // Ascending-sorted input ⇒ within a batch lengths are non-decreasing.
            let ls: Vec<usize> = b.iter().map(|&i| lens[i]).collect();
            let mut sorted = ls.clone();
            sorted.sort_unstable();
            assert_eq!(ls, sorted, "batch lengths are contiguous in sorted order");
        }

        // The lone long text (len 10) can't share a 12-token budget with anyone
        // (2 rows * 10 = 20 > 12), so it sits in a 1-row batch.
        assert!(
            batches.iter().any(|b| b.len() == 1 && lens[b[0]] == 10),
            "the len-10 text forms its own batch: {batches:?}"
        );
    }

    /// EXACTNESS GUARD (real model). Embedding a set of varied-length texts
    /// (including duplicates and lengths that force multi-row buckets + padding)
    /// via the batched [`JinaEmbedder::embed`] MUST equal running `embed_one`
    /// one-at-a-time — bit-identical (`to_bits()`). The attention mask makes each
    /// real token's hidden state independent of padding, and batching leaves the
    /// matmul contraction dim unchanged, so no float ops are reordered. A tiny
    /// batch budget forces several multi-row, padded batches. Ignored by default
    /// (needs the staged model); run with:
    ///   `cargo test --features embed-tract -- --ignored batched_equals_one_at_a_time`
    ///
    /// Note: bit-identity is only *reported* (via `eprintln!`); the enforced
    /// bound is an absolute difference of at most `1e-5` per component. The two
    /// `NORNIR_EMBED_BATCH_*` variables set here are not restored afterwards.
    #[test]
    #[ignore = "loads the real ONNX model (needs build-time weight cache); run with --features embed-tract -- --ignored"]
    fn batched_equals_one_at_a_time() {
        let e = JinaEmbedder::load().expect("load model");

        // Varied lengths incl. exact duplicates and near-duplicates; short and
        // long so buckets both pad and split.
        let texts: Vec<String> = vec![
            "fn a() {}".into(),
            "fn add(a: i32, b: i32) -> i32 { a + b }".into(),
            "x".into(),
            "the quick brown fox jumps over the lazy dog".into(),
            "fn add(a: i32, b: i32) -> i32 { a + b }".into(), // dup of [1]
            "pub struct Foo { bar: Vec<u8>, baz: Option<String>, qux: [f64; 16] }".into(),
            "y".into(),
            "// a fairly long comment line that tokenizes to a good handful of tokens indeed".into(),
            "let z = compute_the_thing(alpha, beta, gamma, delta, epsilon, zeta);".into(),
            "fn a() {}".into(), // dup of [0]
        ];

        // Reference: strictly one-at-a-time forwards at shape (1, n).
        let reference: Vec<Vec<f32>> =
            texts.iter().map(|t| e.embed_one(t).unwrap()).collect();

        // Force multi-row, padded batches: tiny caps.
        // SAFETY (NOTE(review), unverified): `set_var` is only sound while no
        // other thread reads or writes the process environment. The test harness
        // may run other tests on parallel threads, so this presumably relies on
        // the test being run in isolation (it is `#[ignore]`d) — TODO confirm, or
        // run with `--test-threads=1`.
        unsafe {
            std::env::set_var("NORNIR_EMBED_BATCH_ROWS", "4");
            std::env::set_var("NORNIR_EMBED_BATCH_TOKENS", "64");
        }
        let batched = e.embed(&texts).unwrap();

        assert_eq!(batched.len(), reference.len());
        // Track both the worst absolute drift and whether every component
        // matched bit-for-bit; both are printed below.
        let mut max_abs = 0.0f32;
        let mut exact = true;
        for (idx, (r, b)) in reference.iter().zip(&batched).enumerate() {
            assert_eq!(r.len(), b.len(), "dim mismatch at {idx}");
            for (k, (&rv, &bv)) in r.iter().zip(b).enumerate() {
                if rv.to_bits() != bv.to_bits() {
                    exact = false;
                }
                max_abs = max_abs.max((rv - bv).abs());
                // Hard cap even if not bit-identical: never accept larger drift.
                assert!(
                    (rv - bv).abs() <= 1e-5,
                    "batched vs one-at-a-time drift too large at text {idx} dim {k}: \
                     {rv} vs {bv}"
                );
            }
        }
        eprintln!(
            "batched_equals_one_at_a_time: bit_identical={exact} max_abs_diff={max_abs:e}"
        );
        // Duplicate texts must produce identical vectors to each other.
        assert_eq!(batched[0], batched[9], "dup text [0]==[9]");
        assert_eq!(batched[1], batched[4], "dup text [1]==[4]");
    }

    /// Loads the real ONNX model (needs the build-time weight cache). Ignored
    /// by default so the normal test run stays fast and offline; run with
    /// `cargo test --features embed-tract -- --ignored embed`.
    ///
    /// Checks the reported profile, the output dimension, unit norm, and that
    /// two similar Rust functions score closer than a function vs. prose.
    #[test]
    #[ignore = "loads the real ONNX model (needs build-time weight cache); run with --features embed-tract -- --ignored"]
    fn loads_and_embeds() {
        let e = JinaEmbedder::load().expect("load model");
        let p = e.profile();
        assert_eq!(p.dim, es::dim());
        assert_eq!(p.model_name, es::model_name());

        let v = e.embed(&["fn main() {}".to_string()]).unwrap();
        assert_eq!(v.len(), 1);
        assert_eq!(v[0].len(), es::dim());
        // The embedding is L2-normalized, so its norm should be ~1.
        let norm: f32 = v[0].iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-3, "norm {norm}");

        // Code semantics: two equivalent Rust fns are closer to each other
        // than either is to unrelated prose.
        let a = e.embed(&["fn add(a: i32, b: i32) -> i32 { a + b }".into()]).unwrap();
        let b = e.embed(&["pub fn sum(x: i32, y: i32) -> i32 { x + y }".into()]).unwrap();
        let c = e.embed(&["the quick brown fox jumps over the lazy dog".into()]).unwrap();
        // Dot product of the two vectors, used here as the similarity score.
        let dot = |x: &[f32], y: &[f32]| x.iter().zip(y).map(|(p, q)| p * q).sum::<f32>();
        assert!(
            dot(&a[0], &b[0]) > dot(&a[0], &c[0]),
            "two fns ({}) closer than fn-vs-prose ({})",
            dot(&a[0], &b[0]),
            dot(&a[0], &c[0])
        );
    }

    /// Full pipeline with the real model: index 3 code files into the
    /// warehouse, then a natural-language query retrieves the right one.
    #[test]
    #[ignore = "real model + full warehouse index/search pipeline; run with --features embed-tract -- --ignored"]
    fn end_to_end_semantic_search() {
        use crate::vector::chunk::ChunkOptions;
        use crate::vector::store::{index_repo, search, RepoRef};
        use crate::warehouse::iceberg::IcebergWarehouse;

        // Warehouse lives in a temp dir owned by `dir` for the test's duration.
        let dir = tempfile::tempdir().unwrap();
        let wh = IcebergWarehouse::open(dir.path()).unwrap();
        let embedder = JinaEmbedder::load().unwrap();

        // (file name, file contents) pairs to index.
        let files = vec![
            (
                "math.rs".to_string(),
                "pub fn add(a: i32, b: i32) -> i32 { a + b }".to_string(),
            ),
            (
                "io.rs".to_string(),
                "fn read_file(path: &str) -> std::io::Result<String> { std::fs::read_to_string(path) }".to_string(),
            ),
            (
                "net.rs".to_string(),
                "async fn fetch(url: &str) -> Result<String> { http_get(url).await }".to_string(),
            ),
        ];
        let snap = index_repo(
            &wh,
            &RepoRef {
                workspace: "ws",
                repo: "demo",
                git_sha: "sha1",
                branch: "main",
                complete: true,
            },
            &files,
            &ChunkOptions::default(),
            &embedder,
        )
        .unwrap();
        assert_eq!(snap.new_vectors, 3);

        // Query with the same embedder/profile that produced the index.
        let mp = embedder.profile().id();
        let q = embedder
            .embed(&["a function that adds two integers together".to_string()])
            .unwrap();
        let hits = search(&wh, "demo", Some("sha1"), &mp, &q[0], 3).unwrap();
        assert_eq!(
            hits[0].1.file, "math.rs",
            "NL query about adding integers should retrieve the add fn"
        );
    }
}