use grafeo_common::utils::error::{Error, Result};
use ort::session::Session;
use parking_lot::Mutex;
use std::path::Path;
use tokenizers::Tokenizer;
use super::EmbeddingModel;
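/// Text-embedding model backed by an ONNX Runtime `Session` and a
/// Hugging Face `Tokenizer`.
///
/// The session lives behind a `Mutex` because inference here is driven
/// through a mutable borrow, so concurrent callers serialize on the lock.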
pub struct OnnxEmbeddingModel {
session: Mutex<Session>,
tokenizer: Tokenizer,
dimensions: usize,
name: String,
batch_size: usize,
}
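// SAFETY: the `Session` is only ever reached through the `Mutex`, so no two
// threads touch it concurrently, and the `Tokenizer` is never mutated after
// construction. These impls assume neither type holds thread-affine state.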
#[allow(unsafe_code)]
unsafe impl Send for OnnxEmbeddingModel {}
#[allow(unsafe_code)]
unsafe impl Sync for OnnxEmbeddingModel {}
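// Manual `Debug` impl: only lightweight metadata is reported; the session and
// tokenizer are deliberately omitted.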
impl std::fmt::Debug for OnnxEmbeddingModel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("OnnxEmbeddingModel")
.field("name", &self.name)
.field("dimensions", &self.dimensions)
.field("batch_size", &self.batch_size)
.finish()
}
}
impl OnnxEmbeddingModel {
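    /// Loads a model described by `config` with default `EmbeddingOptions`.
    ///
    /// Model and tokenizer files are resolved (and, if needed, fetched) via
    /// `super::download::resolve`.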
#[cfg(feature = "embed")]
pub fn from_config(config: super::EmbeddingModelConfig) -> Result<Self> {
Self::from_config_with_options(config, super::EmbeddingOptions::default())
}
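    /// Like `from_config`, but with explicit `options` controlling the ONNX
    /// Runtime inter-/intra-op thread counts and the embedding batch size.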
#[cfg(feature = "embed")]
pub fn from_config_with_options(
config: super::EmbeddingModelConfig,
options: super::EmbeddingOptions,
) -> Result<Self> {
let resolved = super::download::resolve(&config)?;
let mut session = Session::builder()
.map_err(|e| Error::Internal(format!("Failed to create ONNX session builder: {e}")))?
            .with_inter_threads(options.inter_threads)
            .map_err(|e| Error::Internal(format!("Failed to set inter-op thread count: {e}")))?
            .with_intra_threads(options.intra_threads)
            .map_err(|e| Error::Internal(format!("Failed to set intra-op thread count: {e}")))?
.commit_from_file(&resolved.model_path)
.map_err(|e| Error::Internal(format!("Failed to load ONNX model: {e}")))?;
let tokenizer = Tokenizer::from_file(&resolved.tokenizer_path)
.map_err(|e| Error::Internal(format!("Failed to load tokenizer: {e}")))?;
let dimensions = Self::probe_dimensions(&mut session, &tokenizer)?;
Ok(Self {
session: Mutex::new(session),
tokenizer,
dimensions,
name: resolved.name,
batch_size: options.batch_size,
})
}
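    /// Loads a model directly from an ONNX model file and a Hugging Face
    /// tokenizer file, with single-threaded ONNX Runtime execution and a
    /// default batch size of 32.
    ///
    /// ```ignore
    /// // Illustrative only; the paths are placeholders.
    /// let model = OnnxEmbeddingModel::from_files(
    ///     "my-model",
    ///     "path/to/model.onnx",
    ///     "path/to/tokenizer.json",
    /// )?;
    /// ```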
pub fn from_files(
name: impl Into<String>,
model_path: impl AsRef<Path>,
tokenizer_path: impl AsRef<Path>,
) -> Result<Self> {
let mut session = Session::builder()
.map_err(|e| Error::Internal(format!("Failed to create ONNX session builder: {e}")))?
            .with_inter_threads(1)
            .map_err(|e| Error::Internal(format!("Failed to set inter-op thread count: {e}")))?
            .with_intra_threads(1)
            .map_err(|e| Error::Internal(format!("Failed to set intra-op thread count: {e}")))?
.commit_from_file(model_path.as_ref())
.map_err(|e| Error::Internal(format!("Failed to load ONNX model: {e}")))?;
let tokenizer = Tokenizer::from_file(tokenizer_path.as_ref())
.map_err(|e| Error::Internal(format!("Failed to load tokenizer: {e}")))?;
let dimensions = Self::probe_dimensions(&mut session, &tokenizer)?;
Ok(Self {
session: Mutex::new(session),
tokenizer,
dimensions,
name: name.into(),
batch_size: 32,
})
}
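    /// Sets how many texts are embedded per forward pass (default: 32).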
#[must_use]
pub fn with_batch_size(mut self, batch_size: usize) -> Self {
self.batch_size = batch_size;
self
}
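    /// Builds an owned `i64` input tensor of the given shape for ONNX Runtime.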
fn make_i64_tensor(
shape: Vec<usize>,
data: Vec<i64>,
) -> Result<ort::value::Value<ort::value::TensorValueType<i64>>> {
ort::value::Tensor::from_array((shape, data.into_boxed_slice()))
.map_err(|e| Error::Internal(format!("Failed to create tensor: {e}")))
}
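    /// Discovers the embedding width by running a single dummy input
    /// ("hello") through the model and reading the last axis of its first
    /// output tensor.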
fn probe_dimensions(session: &mut Session, tokenizer: &Tokenizer) -> Result<usize> {
let encoding = tokenizer
.encode("hello", false)
.map_err(|e| Error::Internal(format!("Tokenizer encode error: {e}")))?;
let ids: Vec<i64> = encoding.get_ids().iter().map(|&id| id as i64).collect();
let mask: Vec<i64> = encoding
.get_attention_mask()
.iter()
.map(|&m| m as i64)
.collect();
let len = ids.len();
let input_ids = Self::make_i64_tensor(vec![1, len], ids)?;
let attention_mask = Self::make_i64_tensor(vec![1, len], mask)?;
let outputs = session
.run(ort::inputs![input_ids, attention_mask])
.map_err(|e| Error::Internal(format!("Model inference failed: {e}")))?;
let output = &outputs[0];
let (shape, _data) = output
.try_extract_tensor::<f32>()
.map_err(|e| Error::Internal(format!("Failed to extract output tensor: {e}")))?;
let dims = *shape
.last()
.ok_or_else(|| Error::Internal("Model output has no dimensions".to_string()))?;
Ok(dims as usize)
}
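    /// Embeds one batch: tokenizes `texts`, pads every sequence to the
    /// longest one in the batch, runs a single forward pass, and extracts one
    /// vector per input.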
fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(Vec::new());
}
let text_strings: Vec<String> = texts.iter().map(|t| (*t).to_string()).collect();
let encodings = self
.tokenizer
.encode_batch(text_strings, false)
.map_err(|e| Error::Internal(format!("Tokenizer batch encode error: {e}")))?;
let max_len = encodings
.iter()
.map(|e| e.get_ids().len())
.max()
.unwrap_or(0);
let batch = texts.len();
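        // Flat, zero-initialised buffers double as padding: positions past a
        // sequence's end keep token id 0 and attention mask 0. (If every text
        // tokenizes to nothing, `max_len` is 0 and the buffers are empty.)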
let mut all_ids = vec![0i64; batch * max_len];
let mut all_mask = vec![0i64; batch * max_len];
for (i, enc) in encodings.iter().enumerate() {
for (j, &id) in enc.get_ids().iter().enumerate() {
all_ids[i * max_len + j] = id as i64;
}
for (j, &m) in enc.get_attention_mask().iter().enumerate() {
all_mask[i * max_len + j] = m as i64;
}
}
let input_ids = Self::make_i64_tensor(vec![batch, max_len], all_ids)?;
let attention_mask = Self::make_i64_tensor(vec![batch, max_len], all_mask)?;
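        // Inference needs a mutable borrow of the session, so concurrent
        // callers queue on the mutex; the guard stays held while the outputs
        // are extracted below.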
let mut session = self.session.lock();
let outputs = session
.run(ort::inputs![input_ids, attention_mask])
.map_err(|e| Error::Internal(format!("Model inference failed: {e}")))?;
let output = &outputs[0];
let (shape, data) = output
.try_extract_tensor::<f32>()
.map_err(|e| Error::Internal(format!("Failed to extract output tensor: {e}")))?;
let shape_usize: Vec<usize> = shape.iter().map(|&s| s as usize).collect();
let result = if shape_usize.len() == 3 {
let seq_len = shape_usize[1];
let dims = shape_usize[2];
(0..batch)
.map(|b| {
                let offset = b * seq_len * dims;
                data[offset..offset + dims].to_vec()
})
.collect()
} else if shape_usize.len() == 2 {
let dims = shape_usize[1];
(0..batch)
.map(|b| {
let offset = b * dims;
data[offset..offset + dims].to_vec()
})
.collect()
} else {
return Err(Error::Internal(format!(
"Unexpected output shape: {shape_usize:?}"
)));
};
Ok(result)
}
}
impl EmbeddingModel for OnnxEmbeddingModel {
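    // Chunk the input so a large slice becomes several bounded forward passes.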
fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
let mut results = Vec::with_capacity(texts.len());
for chunk in texts.chunks(self.batch_size) {
results.extend(self.embed_batch(chunk)?);
}
Ok(results)
}
fn dimensions(&self) -> usize {
self.dimensions
}
fn name(&self) -> &str {
&self.name
}
}