use std::sync::Arc;
use anyhow::{Context, Result};
use tokenizers::Tokenizer;
use tract_onnx::prelude::*;
use super::embed_support as es;
use super::store::{Embedder, ModelProfile};
// Local shorthand for the runnable ONNX graph type stored by the embedder
// (`TypedRunnableModel`, presumably provided by the `tract_onnx::prelude` glob import).
type OnnxModel = TypedRunnableModel;
/// Text embedder that pairs a tokenizer with a runnable ONNX model.
///
/// Instances are built by `JinaEmbedder::load`, which reads both pieces from the
/// directory returned by `es::model_dir()`.
pub struct JinaEmbedder {
/// Runnable ONNX graph executed once per embedded text, held behind an `Arc`.
model: Arc<OnnxModel>,
/// Tokenizer used to turn input text into token ids before the forward pass.
tokenizer: Tokenizer,
}
impl JinaEmbedder {
    /// Builds an embedder from the files found in the directory returned by
    /// `es::model_dir()`.
    ///
    /// `tokenizer.json` supplies the tokenizer and `model.onnx` supplies the graph,
    /// which is optimized and made runnable before it is stored.
    ///
    /// # Errors
    ///
    /// Fails if the tokenizer file cannot be loaded, or if the ONNX file cannot be
    /// loaded, optimized, or converted into a runnable model.
    pub fn load() -> Result<Self> {
        let model_dir = es::model_dir();

        // The tokenizer error is formatted by hand rather than propagated with `?`.
        let tokenizer = match Tokenizer::from_file(model_dir.join("tokenizer.json")) {
            Ok(loaded) => loaded,
            Err(e) => return Err(anyhow::anyhow!("load tokenizer.json: {e}")),
        };

        let onnx_path = model_dir.join("model.onnx");
        let parsed = tract_onnx::onnx()
            .model_for_path(&onnx_path)
            .with_context(|| format!("load onnx {}", onnx_path.display()))?;
        let optimized = parsed.into_optimized().context("optimize onnx graph")?;
        let model = optimized
            .into_runnable()
            .context("make onnx graph runnable")?;

        Ok(Self { model, tokenizer })
    }

    /// Embeds a single piece of text.
    ///
    /// The text is tokenized, the ids are passed through `es::prepare_tokens` together
    /// with `es::max_tokens()` (presumably a length cap — confirm in `embed_support`),
    /// and the resulting ids and mask are fed to the model as two `(1, n)` inputs. The
    /// first model output must be a rank-3 `f32` tensor whose last dimension equals
    /// `es::dim()`; its flat data is then handed to `es::pool_and_normalize`.
    ///
    /// # Errors
    ///
    /// Fails if tokenization fails, the input tensors cannot be built, the forward
    /// pass fails, or the output has an unexpected type, shape, or memory layout.
    fn embed_one(&self, text: &str) -> Result<Vec<f32>> {
        let encoding = match self.tokenizer.encode(text, true) {
            Ok(encoded) => encoded,
            Err(e) => return Err(anyhow::anyhow!("tokenize: {e}")),
        };

        let (token_ids, token_mask) = es::prepare_tokens(encoding.get_ids(), es::max_tokens());
        let seq_len = token_ids.len();

        // Both inputs are a single-row batch of `seq_len` columns.
        let ids_tensor =
            tract_ndarray::Array2::from_shape_vec((1, seq_len), token_ids)?.into_tensor();
        let mask_tensor =
            tract_ndarray::Array2::from_shape_vec((1, seq_len), token_mask)?.into_tensor();

        let outputs = self
            .model
            .run(tvec!(ids_tensor.into(), mask_tensor.into()))
            .context("onnx forward")?;

        let hidden = outputs[0]
            .to_plain_array_view::<f32>()
            .context("read hidden state")?;

        // `shape` and `dim` keep these names because the message below captures them inline.
        let dim = es::dim();
        let shape = hidden.shape();
        if shape.len() != 3 || shape[2] != dim {
            anyhow::bail!("unexpected output shape {shape:?} (expected last dim {dim})");
        }

        let flat = hidden
            .as_slice()
            .context("hidden state not contiguous")?;
        Ok(es::pool_and_normalize(flat, seq_len, dim))
    }
}
impl Embedder for JinaEmbedder {
fn profile(&self) -> ModelProfile {
es::profile()
}
fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
let n = texts.len();
if n <= 1 {
return texts.iter().map(|t| self.embed_one(t)).collect();
}
let threads = std::thread::available_parallelism()
.map(|x| x.get())
.unwrap_or(1)
.min(n);
if threads <= 1 {
return texts.iter().map(|t| self.embed_one(t)).collect();
}
let mut out: Vec<Vec<f32>> = vec![Vec::new(); n];
let chunk = n.div_ceil(threads);
let err = std::sync::Mutex::new(None::<anyhow::Error>);
std::thread::scope(|s| {
for (texts_part, out_part) in texts.chunks(chunk).zip(out.chunks_mut(chunk)) {
let err = &err;
s.spawn(move || {
for (t, slot) in texts_part.iter().zip(out_part.iter_mut()) {
match self.embed_one(t) {
Ok(v) => *slot = v,
Err(e) => {
*err.lock().unwrap() = Some(e);
return;
}
}
}
});
}
});
if let Some(e) = err.into_inner().unwrap() {
return Err(e);
}
Ok(out)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Dot product of two vectors, pairing elements up to the shorter length.
    fn dot(x: &[f32], y: &[f32]) -> f32 {
        x.iter().zip(y).map(|(p, q)| p * q).sum::<f32>()
    }

    /// Loads the real model and checks the profile, the output size and norm, and
    /// that two code snippets score closer to each other than code does to prose.
    /// Ignored by default because it needs the model files on disk.
    #[test]
    #[ignore]
    fn loads_and_embeds() {
        let embedder = JinaEmbedder::load().expect("load model");

        let profile = embedder.profile();
        assert_eq!(profile.dim, es::dim());
        assert_eq!(profile.model_name, es::model_name());

        let single = embedder.embed(&["fn main() {}".to_string()]).unwrap();
        assert_eq!(single.len(), 1);
        assert_eq!(single[0].len(), es::dim());

        // `norm` keeps this name because the failure message captures it inline.
        let norm: f32 = single[0].iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-3, "norm {norm}");

        let embed_text = |s: &str| embedder.embed(&[s.into()]).unwrap();
        let a = embed_text("fn add(a: i32, b: i32) -> i32 { a + b }");
        let b = embed_text("pub fn sum(x: i32, y: i32) -> i32 { x + y }");
        let c = embed_text("the quick brown fox jumps over the lazy dog");

        let code_vs_code = dot(&a[0], &b[0]);
        let code_vs_prose = dot(&a[0], &c[0]);
        assert!(
            code_vs_code > code_vs_prose,
            "two fns ({}) closer than fn-vs-prose ({})",
            code_vs_code,
            code_vs_prose
        );
    }

    /// Indexes three tiny files into a temporary warehouse and checks that a
    /// natural-language query about adding integers ranks `math.rs` first.
    /// Ignored by default because it needs the model files on disk.
    #[test]
    #[ignore]
    fn end_to_end_semantic_search() {
        use crate::vector::chunk::ChunkOptions;
        use crate::vector::store::{index_repo, search, RepoRef};
        use crate::warehouse::iceberg::IcebergWarehouse;

        let dir = tempfile::tempdir().unwrap();
        let wh = IcebergWarehouse::open(dir.path()).unwrap();
        let embedder = JinaEmbedder::load().unwrap();

        let files: Vec<(String, String)> = [
            ("math.rs", "pub fn add(a: i32, b: i32) -> i32 { a + b }"),
            (
                "io.rs",
                "fn read_file(path: &str) -> std::io::Result<String> { std::fs::read_to_string(path) }",
            ),
            (
                "net.rs",
                "async fn fetch(url: &str) -> Result<String> { http_get(url).await }",
            ),
        ]
        .iter()
        .map(|(path, body)| (path.to_string(), body.to_string()))
        .collect();

        let repo = RepoRef {
            workspace: "ws",
            repo: "demo",
            git_sha: "sha1",
            branch: "main",
            complete: true,
        };
        let chunking = ChunkOptions::default();
        let snap = index_repo(&wh, &repo, &files, &chunking, &embedder).unwrap();
        assert_eq!(snap.new_vectors, 3);

        let profile_id = embedder.profile().id();
        let query = embedder
            .embed(&["a function that adds two integers together".to_string()])
            .unwrap();
        let hits = search(&wh, "demo", Some("sha1"), &profile_id, &query[0], 3).unwrap();
        assert_eq!(
            hits[0].1.file, "math.rs",
            "NL query about adding integers should retrieve the add fn"
        );
    }
}