use std::sync::Arc;
use anyhow::{Context, Result};
use tokenizers::Tokenizer;
use tract_onnx::prelude::*;
use super::embed_support as es;
use super::store::{Embedder, ModelProfile};
/// Runnable tract plan for the ONNX graph, as produced by `into_optimized()` +
/// `into_runnable()` in [`JinaEmbedder::load`].
type OnnxModel = TypedRunnableModel;
/// Default cap on rows (texts) per batched forward pass; overridable at call time
/// through the `NORNIR_EMBED_BATCH_ROWS` environment variable.
pub const MAX_BATCH_ROWS: usize = 16;
/// Default cap on padded tokens per batch (longest row length × row count);
/// overridable through `NORNIR_EMBED_BATCH_TOKENS`. A single row longer than the
/// cap is still admitted as its own one-row batch.
pub const MAX_BATCH_TOKENS: usize = 8192;
/// Reads a strictly positive integer from the environment variable `var`.
///
/// Leading/trailing whitespace in the value is ignored. `default` is returned
/// when the variable is unset (or not valid Unicode), does not parse as a
/// `usize`, or parses to zero.
fn env_usize(var: &str, default: usize) -> usize {
    match std::env::var(var) {
        Ok(raw) => match raw.trim().parse::<usize>() {
            Ok(value) if value > 0 => value,
            _ => default,
        },
        Err(_) => default,
    }
}
/// ONNX text embedder: a `tokenizers` tokenizer paired with a tract-optimized
/// ONNX graph, both loaded from `es::model_dir()`.
pub struct JinaEmbedder {
    /// Optimized, runnable ONNX graph (built once in [`JinaEmbedder::load`]).
    model: Arc<OnnxModel>,
    /// Tokenizer loaded from `tokenizer.json` in the model directory.
    tokenizer: Tokenizer,
}

impl JinaEmbedder {
    /// Loads `tokenizer.json` and `model.onnx` from `es::model_dir()`.
    ///
    /// The ONNX graph is optimized and turned into a runnable plan here, once;
    /// every later forward pass reuses the stored plan.
    ///
    /// # Errors
    /// Fails if the tokenizer file cannot be loaded, or if tract cannot load,
    /// optimize, or plan the ONNX graph.
    pub fn load() -> Result<Self> {
        let dir = es::model_dir();
        let tokenizer = Tokenizer::from_file(dir.join("tokenizer.json"))
            .map_err(|e| anyhow::anyhow!("load tokenizer.json: {e}"))?;
        let onnx = dir.join("model.onnx");
        let model = tract_onnx::onnx()
            .model_for_path(&onnx)
            .with_context(|| format!("load onnx {}", onnx.display()))?
            .into_optimized()
            .context("optimize onnx graph")?
            .into_runnable()
            .context("make onnx graph runnable")?;
        Ok(Self { model, tokenizer })
    }

    /// Embeds a single text with one unpadded `[1, n]` forward pass.
    ///
    /// Token ids and the attention mask come from `es::prepare_tokens`, limited by
    /// `es::max_tokens()`. The rank-3 hidden state (last dim `es::dim()`) is reduced
    /// to one vector by `es::pool_and_normalize` over all `n` positions.
    ///
    /// # Errors
    /// Fails on tokenizer errors, forward-pass errors, or an unexpected /
    /// non-contiguous output tensor.
    fn embed_one(&self, text: &str) -> Result<Vec<f32>> {
        let enc = self
            .tokenizer
            .encode(text, true)
            .map_err(|e| anyhow::anyhow!("tokenize: {e}"))?;
        let (ids, mask) = es::prepare_tokens(enc.get_ids(), es::max_tokens());
        let n = ids.len();
        let input_ids = tract_ndarray::Array2::from_shape_vec((1, n), ids)?.into_tensor();
        let attn = tract_ndarray::Array2::from_shape_vec((1, n), mask)?.into_tensor();
        let outputs = self
            .model
            .run(tvec!(input_ids.into(), attn.into()))
            .context("onnx forward")?;
        let hidden = outputs[0]
            .to_plain_array_view::<f32>()
            .context("read hidden state")?;
        let dim = es::dim();
        let shape = hidden.shape();
        anyhow::ensure!(
            shape.len() == 3 && shape[2] == dim,
            "unexpected output shape {shape:?} (expected last dim {dim})"
        );
        let flat = hidden.as_slice().context("hidden state not contiguous")?;
        Ok(es::pool_and_normalize(flat, n, dim))
    }

    /// Tokenizes `text` (with special tokens) into model-ready ids via
    /// `es::prepare_ids`, limited by `es::max_tokens()`.
    ///
    /// # Errors
    /// Fails if the tokenizer rejects the input.
    fn tokenize(&self, text: &str) -> Result<Vec<i64>> {
        let enc = self
            .tokenizer
            .encode(text, true)
            .map_err(|e| anyhow::anyhow!("tokenize: {e}"))?;
        Ok(es::prepare_ids(enc.get_ids(), es::max_tokens()))
    }

    /// Runs one padded forward pass over the texts whose indices are in `batch`.
    ///
    /// `toks` holds the prepared ids of *all* texts (see [`Self::tokenize`]);
    /// `batch` selects this pass's rows. Each row is right-padded with id `0` to the
    /// longest row (`pad_len`), with attention mask `1` on real tokens and `0` on
    /// padding. Each row is pooled over its real tokens only; an empty row pools
    /// over its single (padding) position.
    ///
    /// Returns `(text_index, embedding)` pairs in `batch` order.
    ///
    /// # Errors
    /// Fails if the forward pass fails or the output is not a contiguous
    /// `[batch.len(), pad_len, es::dim()]` tensor.
    fn forward_batch(&self, batch: &[usize], toks: &[Vec<i64>]) -> Result<Vec<(usize, Vec<f32>)>> {
        let b = batch.len();
        debug_assert!(b > 0, "empty batch");
        // At least one column so a batch of empty rows still has a valid shape.
        let pad_len = batch.iter().map(|&i| toks[i].len()).max().unwrap_or(1).max(1);
        let dim = es::dim();

        // Build the padded inputs in safe owned buffers (id 0 / mask 0 = padding)
        // rather than writing through `unsafe` unchecked tensor slices.
        let mut ids = vec![0i64; b * pad_len];
        let mut mask = vec![0i64; b * pad_len];
        for (row, &i) in batch.iter().enumerate() {
            let real = &toks[i];
            let base = row * pad_len;
            ids[base..base + real.len()].copy_from_slice(real);
            mask[base..base + real.len()].fill(1);
        }
        let ids_t = tract_ndarray::Array2::from_shape_vec((b, pad_len), ids)?.into_tensor();
        let mask_t = tract_ndarray::Array2::from_shape_vec((b, pad_len), mask)?.into_tensor();

        let outputs = self
            .model
            .run(tvec!(ids_t.into(), mask_t.into()))
            .context("onnx batched forward")?;
        let hidden = outputs[0]
            .to_plain_array_view::<f32>()
            .context("read hidden state")?;
        let shape = hidden.shape();
        anyhow::ensure!(
            shape.len() == 3 && shape[0] == b && shape[1] == pad_len && shape[2] == dim,
            "unexpected batched output shape {shape:?} (expected [{b}, {pad_len}, {dim}])"
        );
        let flat = hidden.as_slice().context("hidden state not contiguous")?;

        let out: Vec<(usize, Vec<f32>)> = batch
            .iter()
            .enumerate()
            .map(|(row, &i)| {
                // Real tokens sit at the front of each row (right padding), so the
                // row's first `n_real` positions are exactly its real tokens.
                let n_real = toks[i].len().max(1);
                let start = row * pad_len * dim;
                let row_hidden = &flat[start..start + n_real * dim];
                (i, es::pool_and_normalize(row_hidden, n_real, dim))
            })
            .collect();
        Ok(out)
    }
}
/// Greedily groups text indices into batches for padded forward passes.
///
/// Indices are visited in ascending order of `lens`. The sort is stable, so ties
/// keep input order and the plan is deterministic. Each batch therefore holds
/// texts of similar length and wastes little padding. Before a text is added, the
/// open batch is closed when:
/// - it already has `max_rows` rows, or
/// - its padded size after adding the text (longest length × row count) would
///   exceed `max_tokens`.
///
/// Lengths of 0 count as 1, and both caps are clamped to at least 1. A single
/// text longer than `max_tokens` still forms its own one-row batch; no batch is
/// ever empty.
///
/// Returns the batches in planning order; every index in `0..lens.len()` appears
/// in exactly one batch.
fn plan_batches(lens: &[usize], max_rows: usize, max_tokens: usize) -> Vec<Vec<usize>> {
    let row_cap = max_rows.max(1);
    let token_cap = max_tokens.max(1);

    let mut sorted: Vec<usize> = (0..lens.len()).collect();
    sorted.sort_by_key(|&idx| lens[idx]);

    let mut planned: Vec<Vec<usize>> = Vec::new();
    let mut current: Vec<usize> = Vec::new();
    let mut longest = 0usize;
    for idx in sorted {
        let len = lens[idx].max(1);
        let rows_full = current.len() >= row_cap;
        let over_budget = longest.max(len) * (current.len() + 1) > token_cap;
        if !current.is_empty() && (rows_full || over_budget) {
            planned.push(std::mem::take(&mut current));
            longest = 0;
        }
        longest = longest.max(len);
        current.push(idx);
    }
    if !current.is_empty() {
        planned.push(current);
    }
    planned
}
impl Embedder for JinaEmbedder {
    /// Describes this model; delegated to `es::profile()`.
    fn profile(&self) -> ModelProfile {
        es::profile()
    }

    /// Embeds every text and returns the vectors in the same order as `texts`.
    ///
    /// Zero inputs return an empty result, and a single input takes the unpadded
    /// [`JinaEmbedder::embed_one`] path. Larger inputs are handled as follows:
    /// 1. All texts are tokenized up front.
    /// 2. They are grouped by length with [`plan_batches`]. The row and
    ///    padded-token caps come from `NORNIR_EMBED_BATCH_ROWS` /
    ///    `NORNIR_EMBED_BATCH_TOKENS`, defaulting to [`MAX_BATCH_ROWS`] /
    ///    [`MAX_BATCH_TOKENS`].
    /// 3. Each batch runs as one padded forward pass. Batches run sequentially,
    ///    or across up to `available_parallelism()` workers (never more than the
    ///    number of batches).
    ///
    /// # Errors
    /// Returns the first tokenization error, or the error of the first failing
    /// batch in plan order.
    fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        match texts {
            [] => return Ok(Vec::new()),
            [only] => return Ok(vec![self.embed_one(only)?]),
            _ => {}
        }

        let mut toks: Vec<Vec<i64>> = Vec::with_capacity(texts.len());
        for text in texts {
            toks.push(self.tokenize(text)?);
        }
        let lens: Vec<usize> = toks.iter().map(Vec::len).collect();

        let batches = plan_batches(
            &lens,
            env_usize("NORNIR_EMBED_BATCH_ROWS", MAX_BATCH_ROWS),
            env_usize("NORNIR_EMBED_BATCH_TOKENS", MAX_BATCH_TOKENS),
        );

        // Never use more workers than there are batches to run.
        let workers = std::thread::available_parallelism()
            .map_or(1, |p| p.get())
            .min(batches.len());

        let per_batch: Vec<Result<Vec<(usize, Vec<f32>)>>> = if workers > 1 {
            znippy_zoomies::gatling_forkjoin::gatling_map_balanced(
                &batches,
                workers,
                1,
                // Load-balancing cost hint: padded token count of the batch.
                |batch: &Vec<usize>| {
                    let longest = batch.iter().map(|&i| lens[i]).max().unwrap_or(0);
                    (batch.len() * longest.max(1)) as u64
                },
                |_, batch| self.forward_batch(batch, &toks),
            )
        } else {
            batches
                .iter()
                .map(|batch| self.forward_batch(batch, &toks))
                .collect()
        };

        // Scatter each batch's (index, vector) pairs back into input order.
        let mut embeddings: Vec<Vec<f32>> = vec![Vec::new(); texts.len()];
        for result in per_batch {
            for (idx, vector) in result? {
                embeddings[idx] = vector;
            }
        }
        Ok(embeddings)
    }
}
// Unit test for `plan_batches` (no model needed), plus model-backed tests that are
// `#[ignore]`d by default because they load the real ONNX weights.
#[cfg(test)]
mod tests {
use super::*;
// Pure planning check with caps of 3 rows / 12 padded tokens:
// - every index is planned exactly once and no batch is empty;
// - the row cap holds;
// - the token budget holds, except for a lone oversized row;
// - lengths within a batch are ascending;
// - the len-10 text (10 * 2 > 12, so it cannot pair) stands alone.
#[test]
fn plan_batches_buckets_by_length_and_respects_caps() {
let lens = vec![10usize, 2, 9, 1, 3, 8, 2];
let batches = plan_batches(&lens, 3, 12);
// Coverage: flattening all batches must yield each index exactly once.
let mut seen: Vec<usize> = batches.iter().flatten().copied().collect();
seen.sort_unstable();
assert_eq!(seen, (0..lens.len()).collect::<Vec<_>>());
for b in &batches {
assert!(!b.is_empty(), "no empty batch");
assert!(b.len() <= 3, "row cap honored: {b:?}");
let pad = b.iter().map(|&i| lens[i]).max().unwrap();
assert!(
pad * b.len() <= 12 || b.len() == 1,
"token budget honored (or lone oversized row): pad {pad} * {} rows",
b.len()
);
// Batches are drawn from the length-sorted order, so lengths inside a batch
// must already be non-decreasing.
let ls: Vec<usize> = b.iter().map(|&i| lens[i]).collect();
let mut sorted = ls.clone();
sorted.sort_unstable();
assert_eq!(ls, sorted, "batch lengths are contiguous in sorted order");
}
assert!(
batches.iter().any(|b| b.len() == 1 && lens[b[0]] == 10),
"the len-10 text forms its own batch: {batches:?}"
);
}
// The batched path must agree with the one-text path.
// - References come from `embed_one`.
// - Small caps (4 rows / 64 padded tokens) are then set through the env vars that
//   `embed` reads, so the 10 texts span several batches.
// - Values must match within 1e-5; whether they are bit-identical is only reported.
// NOTE(review): `set_var` mutates process-wide state and is never restored. The
// test harness runs tests on parallel threads, so other tests calling `embed` may
// observe these caps; concurrent env access is also why `set_var` is `unsafe`.
// Consider restoring the vars or serializing this test.
#[test]
#[ignore = "loads the real ONNX model (needs build-time weight cache); run with --features embed-tract -- --ignored"]
fn batched_equals_one_at_a_time() {
let e = JinaEmbedder::load().expect("load model");
// Mixed lengths, with duplicates at indices 0/9 and 1/4.
let texts: Vec<String> = vec![
"fn a() {}".into(),
"fn add(a: i32, b: i32) -> i32 { a + b }".into(),
"x".into(),
"the quick brown fox jumps over the lazy dog".into(),
"fn add(a: i32, b: i32) -> i32 { a + b }".into(), "pub struct Foo { bar: Vec<u8>, baz: Option<String>, qux: [f64; 16] }".into(),
"y".into(),
"// a fairly long comment line that tokenizes to a good handful of tokens indeed".into(),
"let z = compute_the_thing(alpha, beta, gamma, delta, epsilon, zeta);".into(),
"fn a() {}".into(), ];
let reference: Vec<Vec<f32>> =
texts.iter().map(|t| e.embed_one(t).unwrap()).collect();
unsafe {
std::env::set_var("NORNIR_EMBED_BATCH_ROWS", "4");
std::env::set_var("NORNIR_EMBED_BATCH_TOKENS", "64");
}
let batched = e.embed(&texts).unwrap();
assert_eq!(batched.len(), reference.len());
// Track the largest absolute difference and whether every value is bit-equal,
// for the diagnostic line printed below.
let mut max_abs = 0.0f32;
let mut exact = true;
for (idx, (r, b)) in reference.iter().zip(&batched).enumerate() {
assert_eq!(r.len(), b.len(), "dim mismatch at {idx}");
for (k, (&rv, &bv)) in r.iter().zip(b).enumerate() {
if rv.to_bits() != bv.to_bits() {
exact = false;
}
max_abs = max_abs.max((rv - bv).abs());
assert!(
(rv - bv).abs() <= 1e-5,
"batched vs one-at-a-time drift too large at text {idx} dim {k}: \
{rv} vs {bv}"
);
}
}
eprintln!(
"batched_equals_one_at_a_time: bit_identical={exact} max_abs_diff={max_abs:e}"
);
// Duplicate texts must yield exactly equal vectors.
// NOTE(review): exact equality assumes duplicates see identical numerics. If a
// batch boundary separates a pair, their batches' padded lengths may differ —
// confirm this holds for the chosen caps.
assert_eq!(batched[0], batched[9], "dup text [0]==[9]");
assert_eq!(batched[1], batched[4], "dup text [1]==[4]");
}
// Smoke test of the real model:
// - the profile matches `es`;
// - a single embedding has `es::dim()` entries and unit L2 norm (within 1e-3);
// - two integer-adding functions score a higher dot product with each other than
//   a function does with an English sentence.
#[test]
#[ignore = "loads the real ONNX model (needs build-time weight cache); run with --features embed-tract -- --ignored"]
fn loads_and_embeds() {
let e = JinaEmbedder::load().expect("load model");
let p = e.profile();
assert_eq!(p.dim, es::dim());
assert_eq!(p.model_name, es::model_name());
let v = e.embed(&["fn main() {}".to_string()]).unwrap();
assert_eq!(v.len(), 1);
assert_eq!(v[0].len(), es::dim());
let norm: f32 = v[0].iter().map(|x| x * x).sum::<f32>().sqrt();
assert!((norm - 1.0).abs() < 1e-3, "norm {norm}");
let a = e.embed(&["fn add(a: i32, b: i32) -> i32 { a + b }".into()]).unwrap();
let b = e.embed(&["pub fn sum(x: i32, y: i32) -> i32 { x + y }".into()]).unwrap();
let c = e.embed(&["the quick brown fox jumps over the lazy dog".into()]).unwrap();
let dot = |x: &[f32], y: &[f32]| x.iter().zip(y).map(|(p, q)| p * q).sum::<f32>();
assert!(
dot(&a[0], &b[0]) > dot(&a[0], &c[0]),
"two fns ({}) closer than fn-vs-prose ({})",
dot(&a[0], &b[0]),
dot(&a[0], &c[0])
);
}
// End-to-end: index three one-function files into a temporary Iceberg warehouse
// (expecting 3 new vectors), then check that a natural-language query about
// adding integers ranks `math.rs` first.
#[test]
#[ignore = "real model + full warehouse index/search pipeline; run with --features embed-tract -- --ignored"]
fn end_to_end_semantic_search() {
use crate::vector::chunk::ChunkOptions;
use crate::vector::store::{index_repo, search, RepoRef};
use crate::warehouse::iceberg::IcebergWarehouse;
let dir = tempfile::tempdir().unwrap();
let wh = IcebergWarehouse::open(dir.path()).unwrap();
let embedder = JinaEmbedder::load().unwrap();
let files = vec![
(
"math.rs".to_string(),
"pub fn add(a: i32, b: i32) -> i32 { a + b }".to_string(),
),
(
"io.rs".to_string(),
"fn read_file(path: &str) -> std::io::Result<String> { std::fs::read_to_string(path) }".to_string(),
),
(
"net.rs".to_string(),
"async fn fetch(url: &str) -> Result<String> { http_get(url).await }".to_string(),
),
];
let snap = index_repo(
&wh,
&RepoRef {
workspace: "ws",
repo: "demo",
git_sha: "sha1",
branch: "main",
complete: true,
},
&files,
&ChunkOptions::default(),
&embedder,
)
.unwrap();
assert_eq!(snap.new_vectors, 3);
let mp = embedder.profile().id();
let q = embedder
.embed(&["a function that adds two integers together".to_string()])
.unwrap();
let hits = search(&wh, "demo", Some("sha1"), &mp, &q[0], 3).unwrap();
assert_eq!(
hits[0].1.file, "math.rs",
"NL query about adding integers should retrieve the add fn"
);
}
}