use ndarray::{Array2, ArrayView, Dim, Dimension, IxDynImpl, s};
use ort::session::Session;
use ort::value::Value;
use tokenizers::Tokenizer;
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Pooling {
Cls,
Mean,
}
pub struct EmbeddingEngine {
tokenizer: Tokenizer,
session: Session,
pooling: Pooling,
need_token_type_ids: bool,
}
impl EmbeddingEngine {
pub fn new(tokenizer: Tokenizer, session: Session, pooling: Pooling) -> Self {
let need_token_type_ids = session.inputs().iter().any(|input| input.name() == "token_type_ids");
Self {
tokenizer,
session,
pooling,
need_token_type_ids,
}
}
pub fn embed<S: AsRef<str>>(&self, texts: &[S], batch_size: usize) -> Result<Vec<Vec<f32>>, EmbedError> {
if texts.is_empty() {
return Ok(Vec::new());
}
let mut all_embeddings = Vec::with_capacity(texts.len());
for batch in texts.chunks(batch_size) {
let batch_embeddings = self.embed_batch(batch)?;
all_embeddings.extend(batch_embeddings);
}
Ok(all_embeddings)
}
fn embed_batch<S: AsRef<str>>(&self, batch: &[S]) -> Result<Vec<Vec<f32>>, EmbedError> {
let inputs: Vec<&str> = batch.iter().map(|t| t.as_ref()).collect();
let encodings = self
.tokenizer
.encode_batch(inputs, true)
.map_err(|e| EmbedError::Tokenizer(e.to_string()))?;
let encoding_length = encodings
.first()
.ok_or_else(|| EmbedError::Tokenizer("Empty encodings".to_string()))?
.len();
let batch_size = batch.len();
let max_size = encoding_length * batch_size;
let mut ids_array = Vec::with_capacity(max_size);
let mut mask_array = Vec::with_capacity(max_size);
let mut type_ids_array = Vec::with_capacity(max_size);
for encoding in &encodings {
ids_array.extend(encoding.get_ids().iter().map(|&x| x as i64));
mask_array.extend(encoding.get_attention_mask().iter().map(|&x| x as i64));
type_ids_array.extend(encoding.get_type_ids().iter().map(|&x| x as i64));
}
let ids_tensor = ndarray::Array::from_shape_vec((batch_size, encoding_length), ids_array)
.map_err(|e| EmbedError::Shape(e.to_string()))?;
let type_ids_tensor = ndarray::Array::from_shape_vec((batch_size, encoding_length), type_ids_array)
.map_err(|e| EmbedError::Shape(e.to_string()))?;
let mask_nd = ndarray::Array::from_shape_vec((batch_size, encoding_length), mask_array)
.map_err(|e| EmbedError::Shape(e.to_string()))?;
let attention_mask_for_pooling = if self.pooling == Pooling::Mean {
Some(mask_nd.clone())
} else {
None
};
let mask_tensor = Value::from_array(mask_nd)?;
let mut session_inputs = ort::inputs![
"input_ids" => Value::from_array(ids_tensor)?,
"attention_mask" => mask_tensor,
];
if self.need_token_type_ids {
session_inputs.push(("token_type_ids".into(), Value::from_array(type_ids_tensor)?.into()));
}
#[allow(unsafe_code)]
let outputs = unsafe {
let session_ptr = &self.session as *const Session as *mut Session;
(*session_ptr).run(session_inputs)
}
.map_err(EmbedError::Ort)?;
let (_, output_value) = outputs.iter().next().ok_or(EmbedError::NoOutput)?;
let tensor: ArrayView<f32, Dim<IxDynImpl>> = output_value.try_extract_array().map_err(EmbedError::Ort)?;
let pooled = match attention_mask_for_pooling {
Some(mask) => mean_pool(&tensor, mask)?,
None => cls_pool(&tensor)?,
};
let embeddings: Vec<Vec<f32>> = pooled
.rows()
.into_iter()
.map(|row| row.as_slice().unwrap_or(&[]).to_vec())
.collect();
Ok(embeddings)
}
}
#[allow(unsafe_code)]
unsafe impl Send for EmbeddingEngine {}
#[allow(unsafe_code)]
unsafe impl Sync for EmbeddingEngine {}
fn cls_pool(tensor: &ArrayView<f32, Dim<IxDynImpl>>) -> Result<Array2<f32>, EmbedError> {
match tensor.dim().ndim() {
2 => Ok(tensor.slice(s![.., ..]).to_owned()),
3 => Ok(tensor.slice(s![.., 0, ..]).to_owned()),
_ => Err(EmbedError::Shape(format!(
"Expected 2D or 3D tensor, got {:?}",
tensor.dim()
))),
}
}
fn mean_pool(tensor: &ArrayView<f32, Dim<IxDynImpl>>, attention_mask: Array2<i64>) -> Result<Array2<f32>, EmbedError> {
if tensor.dim().ndim() == 2 {
return Ok(tensor.slice(s![.., ..]).to_owned());
}
if tensor.dim().ndim() != 3 {
return Err(EmbedError::Shape(format!(
"Expected 2D or 3D tensor, got {:?}",
tensor.dim()
)));
}
let token_embeddings = tensor.slice(s![.., .., ..]);
let mask_dim = attention_mask.dim();
let mask_expanded = attention_mask
.insert_axis(ndarray::Axis(2))
.broadcast(token_embeddings.dim())
.ok_or_else(|| {
EmbedError::Shape(format!(
"Cannot broadcast attention mask {:?} to {:?}",
mask_dim,
token_embeddings.dim()
))
})?
.mapv(|x| x as f32);
let masked = &mask_expanded * &token_embeddings;
let sum = masked.sum_axis(ndarray::Axis(1));
let mask_sum = mask_expanded.sum_axis(ndarray::Axis(1));
let mask_sum = mask_sum.mapv(|x| if x == 0.0 { 1.0 } else { x });
Ok(&sum / &mask_sum)
}
pub fn normalize(v: &[f32]) -> Vec<f32> {
let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > f32::EPSILON {
let inv = 1.0 / norm;
v.iter().map(|&x| x * inv).collect()
} else {
v.to_vec()
}
}
#[derive(Debug)]
pub enum EmbedError {
Tokenizer(String),
Ort(ort::Error),
Shape(String),
NoOutput,
}
impl From<ort::Error> for EmbedError {
fn from(e: ort::Error) -> Self {
Self::Ort(e)
}
}
impl std::fmt::Display for EmbedError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Tokenizer(e) => write!(f, "Tokenizer error: {e}"),
Self::Ort(e) => write!(f, "ONNX Runtime error: {e}"),
Self::Shape(e) => write!(f, "Tensor shape error: {e}"),
Self::NoOutput => write!(f, "Model produced no output tensors"),
}
}
}
impl std::error::Error for EmbedError {}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_normalize_unit_vector() {
let v = vec![3.0, 4.0]; let normalized = normalize(&v);
assert_eq!(normalized.len(), 2);
assert!(
(normalized[0] - 0.6).abs() < 1e-6,
"Expected ~0.6, got {}",
normalized[0]
);
assert!(
(normalized[1] - 0.8).abs() < 1e-6,
"Expected ~0.8, got {}",
normalized[1]
);
let norm: f32 = normalized.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!((norm - 1.0).abs() < 1e-6, "L2 norm should be ~1.0, got {}", norm);
}
#[test]
fn test_normalize_zero_vector() {
let v = vec![0.0, 0.0, 0.0];
let normalized = normalize(&v);
assert_eq!(normalized.len(), 3);
assert!(!normalized.iter().any(|x| x.is_nan()), "No NaN values expected");
assert!(
!normalized.iter().any(|x| x.is_infinite()),
"No infinite values expected"
);
for &val in &normalized {
assert_eq!(val, 0.0, "Zero vector should remain zero");
}
}
#[test]
fn test_normalize_single_element() {
let v = vec![5.0];
let normalized = normalize(&v);
assert_eq!(normalized.len(), 1);
assert!(
(normalized[0] - 1.0).abs() < 1e-6,
"Expected 1.0, got {}",
normalized[0]
);
}
#[test]
fn test_embed_error_display() {
let tokenizer_err = EmbedError::Tokenizer("test error".to_string());
let display = format!("{}", tokenizer_err);
assert!(display.contains("Tokenizer error"), "Tokenizer display: {}", display);
let shape_err = EmbedError::Shape("invalid shape".to_string());
let display = format!("{}", shape_err);
assert!(display.contains("Tensor shape error"), "Shape display: {}", display);
let no_output_err = EmbedError::NoOutput;
let display = format!("{}", no_output_err);
assert!(display.contains("no output"), "NoOutput display: {}", display);
}
#[test]
fn test_pooling_variants() {
let cls = Pooling::Cls;
let mean = Pooling::Mean;
assert_ne!(cls, mean, "Pooling::Cls and Pooling::Mean should be different");
assert_eq!(cls, Pooling::Cls);
assert_eq!(mean, Pooling::Mean);
let cls_clone = cls.clone();
assert_eq!(cls, cls_clone);
let debug_output = format!("{:?}", cls);
assert!(debug_output.contains("Cls"), "Debug output: {}", debug_output);
}
#[test]
fn test_normalize_preserves_length() {
let test_cases = vec![vec![1.0, 2.0, 3.0], vec![-1.0, -2.0], vec![0.1, 0.2, 0.3, 0.4, 0.5]];
for v in test_cases {
let original_len = v.len();
let normalized = normalize(&v);
assert_eq!(
normalized.len(),
original_len,
"Normalization should preserve vector length"
);
}
}
#[test]
fn test_normalize_negative_values() {
let v = vec![-3.0, -4.0]; let normalized = normalize(&v);
assert!((normalized[0] - (-0.6)).abs() < 1e-6, "Expected ~-0.6");
assert!((normalized[1] - (-0.8)).abs() < 1e-6, "Expected ~-0.8");
let norm: f32 = normalized.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!((norm - 1.0).abs() < 1e-6);
}
#[test]
fn test_normalize_very_small_values() {
let v = vec![f32::EPSILON / 2.0, f32::EPSILON / 2.0];
let normalized = normalize(&v);
assert_eq!(normalized, v, "Very small vectors (< epsilon) returned unchanged");
}
#[test]
fn test_embed_error_is_error_type() {
let err = EmbedError::Shape("test".to_string());
let _: &dyn std::error::Error = &err;
}
#[test]
fn test_pooling_traits() {
let cls = Pooling::Cls;
let mean = Pooling::Mean;
let cls_clone = cls.clone();
let mean_clone = mean.clone();
assert_eq!(cls, cls_clone);
assert_eq!(mean, mean_clone);
let cls_debug = format!("{:?}", cls);
let mean_debug = format!("{:?}", mean);
assert!(!cls_debug.is_empty());
assert!(!mean_debug.is_empty());
assert_eq!(cls, cls);
assert_eq!(mean, mean);
assert_ne!(cls, mean);
}
#[test]
fn test_normalize_large_values() {
let v = vec![1e6, 1e6];
let normalized = normalize(&v);
assert!(!normalized.iter().any(|x| x.is_infinite()), "No overflow");
let norm: f32 = normalized.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!((norm - 1.0).abs() < 1e-5, "L2 norm should be ~1.0");
}
#[test]
fn test_normalize_mixed_signs() {
let v = vec![1.0, -1.0, 0.0];
let normalized = normalize(&v);
assert_eq!(normalized.len(), 3);
let norm: f32 = normalized.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!((norm - 1.0).abs() < 1e-6, "L2 norm should be ~1.0");
}
}