//! ONNX-backed semantic matcher: runs a local transformer model through
//! `ort`, then mean-pools and normalizes the token embeddings into a
//! single sentence vector.

use std::borrow::Cow;
use std::path::Path;
use std::sync::Mutex;
use ort::session::{Session, SessionInputValue};
use ort::value::Tensor;
use tokenizers::Tokenizer;
use super::semantic_matcher::SemanticMatcher;
use crate::error::SyaraError;
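
/// Token-count cap applied before building input tensors; callers can
/// override it via `with_max_length`.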
const DEFAULT_MAX_LENGTH: usize = 256;
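// Tensor names conventionally used by Hugging Face transformer ONNX exports.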
const INPUT_IDS: &str = "input_ids";
const ATTENTION_MASK: &str = "attention_mask";
const TOKEN_TYPE_IDS: &str = "token_type_ids";
const LAST_HIDDEN_STATE: &str = "last_hidden_state";
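
/// Embeds text with a local ONNX transformer model: tokenize, run the
/// model, mean-pool `last_hidden_state` under the attention mask, and
/// L2-normalize the result.
///
/// The `Session` sits behind a `Mutex` because `Session::run` takes
/// `&mut self` in `ort` 2.x, while `SemanticMatcher::embed` only
/// receives `&self`.
///
/// Usage sketch (the model directory below is illustrative, not shipped
/// with this crate):
///
/// ```text
/// let matcher = OnnxEmbeddingMatcher::from_dir("models/all-MiniLM-L6-v2")?
///     .with_max_length(128);
/// let embedding = matcher.embed("hello world")?;
/// ```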
pub struct OnnxEmbeddingMatcher {
session: Mutex<Session>,
tokenizer: Tokenizer,
max_length: usize,
needs_token_type_ids: bool,
}
impl OnnxEmbeddingMatcher {
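    /// Convenience constructor: expects `model.onnx` and `tokenizer.json`
    /// side by side in `model_dir`, the layout produced by common export
    /// tools such as `optimum-cli export onnx`.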
pub fn from_dir(model_dir: impl AsRef<Path>) -> Result<Self, SyaraError> {
let dir = model_dir.as_ref();
Self::from_paths(dir.join("model.onnx"), dir.join("tokenizer.json"))
}
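
    /// Loads the tokenizer and ONNX model from explicit paths, then probes
    /// the session's declared inputs to decide whether `token_type_ids`
    /// must be fed at inference time.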
pub fn from_paths(
model: impl AsRef<Path>,
tokenizer: impl AsRef<Path>,
) -> Result<Self, SyaraError> {
let model_path = model.as_ref();
let tokenizer_path = tokenizer.as_ref();
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(|e| {
SyaraError::SemanticError(format!(
"failed to load tokenizer from {}: {e}",
tokenizer_path.display()
))
})?;
let session = Session::builder()
.map_err(|e| {
SyaraError::SemanticError(format!(
"failed to build ONNX session: {e}"
))
})?
.commit_from_file(model_path)
.map_err(|e| {
SyaraError::SemanticError(format!(
"failed to load ONNX model from {}: {e}",
model_path.display()
))
})?;
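        // Probe the graph once: BERT-style exports declare `token_type_ids`,
        // while others (e.g. DistilBERT-style models) omit it.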
let needs_token_type_ids = session
.inputs()
.iter()
            .any(|input| input.name() == TOKEN_TYPE_IDS);
Ok(Self {
session: Mutex::new(session),
tokenizer,
max_length: DEFAULT_MAX_LENGTH,
needs_token_type_ids,
})
}
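
    /// Builder-style override for the truncation length (defaults to
    /// `DEFAULT_MAX_LENGTH`). Longer inputs are cut, not windowed.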
pub fn with_max_length(mut self, max_length: usize) -> Self {
self.max_length = max_length;
self
}
}
impl SemanticMatcher for OnnxEmbeddingMatcher {
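    /// Tokenizes `text`, runs the model, mean-pools the token embeddings
    /// under the attention mask, and L2-normalizes the result so cosine
    /// similarity reduces to a dot product.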
fn embed(&self, text: &str) -> Result<Vec<f32>, SyaraError> {
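        // Empty input: nothing to tokenize, so return an empty embedding.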
if text.is_empty() {
return Ok(vec![]);
}
let encoding = self
.tokenizer
.encode(text, true)
.map_err(|e| SyaraError::SemanticError(format!("tokenize: {e}")))?;
let ids_src = encoding.get_ids();
let mask_src = encoding.get_attention_mask();
let types_src = encoding.get_type_ids();
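        // Truncate here rather than relying on the tokenizer's own
        // truncation settings, so `max_length` is always honored.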
let seq_len = ids_src.len().min(self.max_length);
if seq_len == 0 {
return Ok(vec![]);
}
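        // Hugging Face ONNX exports take int64 tensors, while `tokenizers`
        // yields u32, so widen every value.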
let ids: Vec<i64> = ids_src[..seq_len].iter().map(|&v| v as i64).collect();
let mask: Vec<i64> =
mask_src[..seq_len].iter().map(|&v| v as i64).collect();
let shape = [1_usize, seq_len];
let ids_tensor = Tensor::from_array((shape, ids)).map_err(|e| {
SyaraError::SemanticError(format!("input_ids tensor: {e}"))
})?;
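        // The tensor consumes its backing `Vec`, so clone the mask: the
        // pooling loop below still needs it.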
let mask_tensor =
Tensor::from_array((shape, mask.clone())).map_err(|e| {
SyaraError::SemanticError(format!("attention_mask tensor: {e}"))
})?;
let mut inputs: Vec<(Cow<'_, str>, SessionInputValue<'_>)> = vec![
(Cow::Borrowed(INPUT_IDS), SessionInputValue::from(ids_tensor)),
(
Cow::Borrowed(ATTENTION_MASK),
SessionInputValue::from(mask_tensor),
),
];
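        // Feed `token_type_ids` only when the graph declares that input;
        // ONNX Runtime rejects input names the model does not define.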
if self.needs_token_type_ids {
let types: Vec<i64> =
types_src[..seq_len].iter().map(|&v| v as i64).collect();
let types_tensor =
Tensor::from_array((shape, types)).map_err(|e| {
SyaraError::SemanticError(format!(
"token_type_ids tensor: {e}"
))
})?;
inputs.push((
Cow::Borrowed(TOKEN_TYPE_IDS),
SessionInputValue::from(types_tensor),
));
}
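        // `Session::run` needs exclusive access, hence the mutex; a
        // poisoned lock surfaces as a recoverable error, not a panic.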
let mut session = self.session.lock().map_err(|_| {
SyaraError::SemanticError("ONNX session mutex poisoned".into())
})?;
let outputs = session.run(inputs).map_err(|e| {
SyaraError::SemanticError(format!("ONNX run failed: {e}"))
})?;
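        // Prefer the conventional `last_hidden_state` name, falling back
        // to the model's first output if it is named differently.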
let output_value = outputs
.get(LAST_HIDDEN_STATE)
.or_else(|| outputs.get(outputs.iter().next()?.0))
.ok_or_else(|| {
SyaraError::SemanticError("ONNX model produced no outputs".into())
})?;
let hidden = output_value.try_extract_array::<f32>().map_err(|e| {
SyaraError::SemanticError(format!(
"failed to extract last_hidden_state: {e}"
))
})?;
let shape = hidden.shape();
if shape.len() != 3 || shape[0] != 1 || shape[1] != seq_len {
return Err(SyaraError::SemanticError(format!(
"unexpected output shape {:?}, expected [1, {seq_len}, H]",
shape
)));
}
let hidden_dim = shape[2];
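        // Masked mean pooling: sum each token's vector weighted by its
        // attention-mask value, then divide by the number of attended
        // tokens so padding and masked positions contribute nothing.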
let mut pooled = vec![0.0_f32; hidden_dim];
let mut mask_sum = 0.0_f32;
for t in 0..seq_len {
let m = mask[t] as f32;
if m == 0.0 {
continue;
}
mask_sum += m;
            for (h, out) in pooled.iter_mut().enumerate() {
*out += hidden[[0, t, h]] * m;
}
}
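        // Degenerate case: an all-zero attention mask would divide by
        // zero, so return a zero vector of the right width instead.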
if mask_sum == 0.0 {
return Ok(vec![0.0; hidden_dim]);
}
for v in pooled.iter_mut() {
*v /= mask_sum;
}
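        // Unit-normalize so comparisons reduce to a dot product; skip a
        // zero norm to avoid producing NaNs.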
let norm: f32 = pooled.iter().map(|v| v * v).sum::<f32>().sqrt();
if norm > 0.0 {
for v in pooled.iter_mut() {
*v /= norm;
}
}
Ok(pooled)
}
}
#[cfg(test)]
mod tests {
use super::*;
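    // No model files are needed: `from_paths` loads the tokenizer before
    // touching ONNX Runtime, so the missing tokenizer fails first.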
#[test]
fn tokenizer_load_missing_file() {
let result = OnnxEmbeddingMatcher::from_paths(
"/nonexistent/model.onnx",
"/nonexistent/tokenizer.json",
);
match result {
Ok(_) => panic!("loading missing tokenizer must fail"),
Err(SyaraError::SemanticError(msg)) => {
assert!(
msg.contains("tokenizer"),
"expected tokenizer error, got: {msg}"
);
}
Err(other) => panic!("expected SemanticError, got {other:?}"),
}
}
}