use std::path::Path;
use std::sync::Mutex;
use anyhow::{Context, Result};
use ort::session::Session;
use ort::value::Tensor;
use tokenizers::Tokenizer;
use super::embed_support as es;
use super::store::{Embedder, ModelProfile};
/// Text embedder that runs an ONNX model through an `ort` session and tokenizes
/// its input with a `tokenizers` tokenizer.
///
/// Constructed by `load`, `load_cpu` or `load_with`; exposed to the rest of the
/// crate through the `Embedder` trait.
pub struct OrtEmbedder {
/// Inference session. Held behind a `Mutex` because `embed_one` only receives
/// `&self` yet takes a mutable guard to run the model.
session: Mutex<Session>,
/// Tokenizer loaded from `tokenizer.json` in the model directory.
tokenizer: Tokenizer,
/// The `prefer_gpu` flag given at load time. Records the request only; it does
/// not say whether CUDA is actually in use.
gpu_requested: bool,
}
impl OrtEmbedder {
    /// Loads the embedder with the CUDA execution provider requested.
    ///
    /// Equivalent to `Self::load_with(true)`.
    ///
    /// # Errors
    /// Returns whatever [`OrtEmbedder::load_with`] returns.
    pub fn load() -> Result<Self> {
        Self::load_with(true)
    }

    /// Loads the embedder without registering a GPU execution provider.
    ///
    /// Equivalent to `Self::load_with(false)`.
    ///
    /// # Errors
    /// Returns whatever [`OrtEmbedder::load_with`] returns.
    pub fn load_cpu() -> Result<Self> {
        Self::load_with(false)
    }

    /// Reads `tokenizer.json` and `model.onnx` from `es::MODEL_DIR` and builds the
    /// inference session.
    ///
    /// When `prefer_gpu` is true, `super::cuda::ensure()` is called first and the
    /// CUDA execution provider is registered on the session builder. The flag is
    /// stored and later reported by [`OrtEmbedder::gpu_requested`].
    ///
    /// # Errors
    /// Fails if the tokenizer file cannot be loaded, the session builder cannot be
    /// created, registering the CUDA execution provider returns an error, or the
    /// model file cannot be committed into a session.
    pub fn load_with(prefer_gpu: bool) -> Result<Self> {
        let dir = Path::new(es::MODEL_DIR);
        let tokenizer = Tokenizer::from_file(dir.join("tokenizer.json"))
            .map_err(|e| anyhow::anyhow!("load tokenizer.json: {e}"))?;

        let mut builder = Session::builder().context("ort session builder")?;
        if prefer_gpu {
            // The probe result is deliberately discarded here; `cuda_ready` calls
            // the same function when a caller wants to inspect the outcome.
            // NOTE(review): presumably this preloads the CUDA/cuDNN shared
            // libraries before the provider is registered — confirm.
            let _ = super::cuda::ensure();
            builder = builder
                .with_execution_providers([ort::ep::CUDA::default().build()])
                .map_err(|e| anyhow::anyhow!("register CUDA execution provider: {e}"))?;
        }

        let session = builder
            .commit_from_file(dir.join("model.onnx"))
            .context("commit onnx session")?;

        Ok(Self {
            session: Mutex::new(session),
            tokenizer,
            gpu_requested: prefer_gpu,
        })
    }

    /// Returns the `prefer_gpu` value this embedder was loaded with.
    ///
    /// This reflects the request only, not whether CUDA is actually being used.
    pub fn gpu_requested(&self) -> bool {
        self.gpu_requested
    }

    /// Reports whether the CUDA probe found what it needs.
    ///
    /// Calls `super::cuda::ensure()` and returns `true` when its `cudnn` flag is
    /// set and at least one entry in its `loaded` list starts with `"libcudart"`.
    pub fn cuda_ready(&self) -> bool {
        let p = super::cuda::ensure();
        p.cudnn && p.loaded.iter().any(|s| s.starts_with("libcudart"))
    }

    /// Embeds a single text into a vector.
    ///
    /// The text is tokenized with special tokens enabled, the ids are passed
    /// through `es::prepare_tokens`, and the model is run on a batch of one with
    /// `input_ids` and `attention_mask` inputs. The first output is read as `f32`
    /// and handed to `es::pool_and_normalize` together with the token count.
    ///
    /// # Errors
    /// Fails if tokenization fails, the token count does not fit in `i64`, an
    /// input tensor cannot be built, the session mutex is poisoned, the forward
    /// pass fails, the first output is not an `f32` tensor, or that output is not
    /// rank 3 with a last dimension of `es::DIM`.
    fn embed_one(&self, text: &str) -> Result<Vec<f32>> {
        let enc = self
            .tokenizer
            .encode(text, true)
            .map_err(|e| anyhow::anyhow!("tokenize: {e}"))?;
        let (ids, mask) = es::prepare_tokens(enc.get_ids());

        // Batch of one: both inputs share the shape [1, token_count].
        let n = ids.len();
        let seq_len = i64::try_from(n).context("token count does not fit in i64")?;
        let shape = [1_i64, seq_len];
        let id_tensor = Tensor::from_array((shape, ids)).context("build input_ids tensor")?;
        let mask_tensor =
            Tensor::from_array((shape, mask)).context("build attention_mask tensor")?;

        // Surface a poisoned lock as an error rather than panicking: this function
        // already returns `Result`, and an earlier panic while the lock was held
        // should not make every later embedding call panic as well.
        let mut session = self
            .session
            .lock()
            .map_err(|_| anyhow::anyhow!("onnx session mutex poisoned"))?;
        let outputs = session
            .run(ort::inputs![
                "input_ids" => id_tensor,
                "attention_mask" => mask_tensor,
            ])
            .context("onnx forward")?;

        let (out_shape, data) = outputs[0]
            .try_extract_tensor::<f32>()
            .context("extract hidden state")?;
        let dims = out_shape.as_ref();
        // NOTE(review): only the rank and the last dimension are validated; the
        // sequence axis is assumed to match `n` when pooling — confirm against
        // `es::pool_and_normalize`.
        anyhow::ensure!(
            dims.len() == 3 && dims[2] as usize == es::DIM,
            "unexpected output shape {dims:?}"
        );
        Ok(es::pool_and_normalize(data, n))
    }
}
impl Embedder for OrtEmbedder {
    /// Returns the model profile provided by the embedding-support module.
    fn profile(&self) -> ModelProfile {
        es::profile()
    }

    /// Embeds every text in order, one `embed_one` call per text.
    ///
    /// # Errors
    /// Stops at the first text that fails to embed and returns that error.
    fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let mut vectors = Vec::with_capacity(texts.len());
        for text in texts {
            vectors.push(self.embed_one(text)?);
        }
        Ok(vectors)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Dot product over the paired elements of two slices.
    fn dot(lhs: &[f32], rhs: &[f32]) -> f32 {
        lhs.iter().zip(rhs).map(|(l, r)| l * r).sum::<f32>()
    }

    /// End-to-end check of the CPU path: loads the model, verifies the reported
    /// profile, the embedding dimension and unit norm, and checks that two small
    /// functions score a higher dot product with each other than one of them does
    /// with an English sentence.
    ///
    /// Marked `#[ignore]`, so it only runs when explicitly requested.
    #[test]
    #[ignore]
    fn loads_and_embeds_cpu() {
        let embedder = OrtEmbedder::load_cpu().expect("load model");

        let profile = embedder.profile();
        assert_eq!(profile.dim, es::DIM);
        assert_eq!(profile.model_name, es::MODEL_NAME);

        // Embeds exactly one text and returns the raw batch result.
        let embed_single = |text: &str| embedder.embed(&[text.to_string()]).unwrap();

        let batch = embed_single("fn main() {}");
        assert_eq!(batch.len(), 1);
        assert_eq!(batch[0].len(), es::DIM);

        // The embedding is expected to have unit length.
        let norm: f32 = dot(&batch[0], &batch[0]).sqrt();
        assert!((norm - 1.0).abs() < 1e-3, "norm {norm}");

        let add_fn = embed_single("fn add(a: i32, b: i32) -> i32 { a + b }");
        let sum_fn = embed_single("pub fn sum(x: i32, y: i32) -> i32 { x + y }");
        let prose = embed_single("the quick brown fox jumps over the lazy dog");

        let code_similarity = dot(&add_fn[0], &sum_fn[0]);
        let prose_similarity = dot(&add_fn[0], &prose[0]);
        assert!(
            code_similarity > prose_similarity,
            "two fns ({}) closer than fn-vs-prose ({})",
            code_similarity,
            prose_similarity
        );
    }
}