use async_trait::async_trait;
use kapsl_engine_api::{
BinaryTensorPacket, Engine, EngineError, EngineMetrics, EngineModelInfo, EngineStream,
InferenceRequest, TensorDtype,
};
use std::path::Path;
/// Engine adapter that turns the raw output of a wrapped engine into pooled
/// embedding vectors.
///
/// Loading, unloading, metrics, model info and health checks are forwarded
/// unchanged to `inner`; only inference output is post-processed.
pub struct OnnxEmbedBackend {
    /// Engine that actually runs the model and produces the raw output tensor.
    inner: Box<dyn Engine>,
    /// When `true`, every pooled embedding row is scaled to unit L2 norm.
    normalize: bool,
}
impl OnnxEmbedBackend {
    /// Wraps `inner`; `normalize` selects whether embeddings are L2-normalized.
    pub fn new(inner: Box<dyn Engine>, normalize: bool) -> Self {
        Self { normalize, inner }
    }
}
/// Delegates every lifecycle call to the wrapped engine and converts its
/// inference output into `[batch, dim]` float32 embeddings.
#[async_trait]
impl Engine for OnnxEmbedBackend {
    /// Loads the model through the wrapped engine.
    async fn load(&mut self, model_path: &Path) -> Result<(), EngineError> {
        self.inner.load(model_path).await
    }

    /// Runs the wrapped engine, then pools (and optionally L2-normalizes) its
    /// output via `embed_from_output`. Errors from either step are returned as-is.
    fn infer(&self, request: &InferenceRequest) -> Result<BinaryTensorPacket, EngineError> {
        self.inner
            .infer(request)
            .and_then(|raw| embed_from_output(&raw, request, self.normalize))
    }

    /// Computes the full embedding result eagerly, before the stream is built,
    /// and exposes it as a stream with exactly one item.
    fn infer_stream(&self, request: &InferenceRequest) -> EngineStream {
        let outcome = self.infer(request);
        Box::pin(futures::stream::once(futures::future::ready(outcome)))
    }

    /// Unloads the wrapped engine.
    fn unload(&mut self) {
        self.inner.unload();
    }

    /// Reports the wrapped engine's metrics.
    fn metrics(&self) -> EngineMetrics {
        self.inner.metrics()
    }

    /// Reports the wrapped engine's model info.
    fn model_info(&self) -> Option<EngineModelInfo> {
        self.inner.model_info()
    }

    /// Reports the wrapped engine's health.
    fn health_check(&self) -> Result<(), EngineError> {
        self.inner.health_check()
    }
}
/// Converts a raw float32 model output into one embedding per batch row.
///
/// * `[batch, dim]` outputs are treated as already pooled and kept as-is.
/// * `[batch, seq, dim]` outputs are mean-pooled over `seq`, weighted by the
///   attention mask found in `request` (all ones when none matches).
///
/// Negative dimensions are clamped to zero. When `normalize` is set, every
/// resulting row is scaled to unit L2 norm. The returned packet is always
/// float32 with shape `[batch, dim]`.
///
/// # Errors
///
/// Returns `EngineError::backend` when:
/// * the output dtype is not float32;
/// * the output rank is neither 2 nor 3;
/// * the shape's element count overflows `usize`;
/// * the number of decoded values differs from what the shape implies.
fn embed_from_output(
    output: &BinaryTensorPacket,
    request: &InferenceRequest,
    normalize: bool,
) -> Result<BinaryTensorPacket, EngineError> {
    if output.dtype != TensorDtype::Float32 {
        return Err(EngineError::backend(format!(
            "embedding output dtype {:?} is not supported (expected float32)",
            output.dtype
        )));
    }
    let values = bytes_to_f32(&output.data);
    match output.shape.as_slice() {
        [batch, dim] => {
            let (batch, dim) = (dim_usize(*batch), dim_usize(*dim));
            // Without this check a short buffer makes `l2_normalize_rows` panic
            // on an out-of-range slice, and a long one is passed through with a
            // shape that does not describe it.
            ensure_value_count(&output.shape, &[batch, dim], values.len())?;
            let mut pooled = values;
            if normalize {
                l2_normalize_rows(&mut pooled, batch, dim);
            }
            Ok(f32_packet(vec![batch as i64, dim as i64], pooled))
        }
        [batch, seq, dim] => {
            let (batch, seq, dim) = (dim_usize(*batch), dim_usize(*seq), dim_usize(*dim));
            // `masked_mean_pool` indexes `values` directly, so the length must match.
            ensure_value_count(&output.shape, &[batch, seq, dim], values.len())?;
            let mask = extract_attention_mask(request, batch, seq);
            let mut pooled = masked_mean_pool(&values, batch, seq, dim, &mask);
            if normalize {
                l2_normalize_rows(&mut pooled, batch, dim);
            }
            Ok(f32_packet(vec![batch as i64, dim as i64], pooled))
        }
        other => Err(EngineError::backend(format!(
            "embedding expects a 2-D [batch, dim] or 3-D [batch, seq, dim] output, got shape {:?}",
            other
        ))),
    }
}

/// Checks that `actual` decoded values equal the element count implied by `dims`.
///
/// `shape` is the original signed output shape and is used only in error messages.
///
/// # Errors
///
/// Returns `EngineError::backend` if the product of `dims` overflows `usize`
/// or differs from `actual`.
fn ensure_value_count(shape: &[i64], dims: &[usize], actual: usize) -> Result<(), EngineError> {
    // Checked product: a corrupt shape must neither wrap around (release builds)
    // nor panic (debug builds) while computing the expected count.
    let expected = match dims.iter().try_fold(1usize, |acc, &d| acc.checked_mul(d)) {
        Some(count) => count,
        None => {
            return Err(EngineError::backend(format!(
                "embedding output shape {:?} overflows the addressable element count",
                shape
            )))
        }
    };
    if actual != expected {
        return Err(EngineError::backend(format!(
            "embedding output has {} values but shape {:?} implies {}",
            actual, shape, expected
        )));
    }
    Ok(())
}
/// Collapses `[batch, seq, dim]` hidden states (flattened row-major) into
/// `[batch, dim]` by a mask-weighted mean over `seq`.
///
/// For each batch row `b`, the result is
/// `sum_s(mask[b*seq+s] * hidden[b, s, :]) / max(sum_s mask[b*seq+s], 1e-9)`.
///
/// * Tokens whose weight is exactly `0.0` are skipped.
/// * Mask positions past the end of `mask` count as weight `1.0`.
/// * The `1e-9` floor makes a fully-masked row yield zeros instead of NaN.
///
/// # Panics
///
/// Panics if `hidden` holds fewer than `batch * seq * dim` values.
fn masked_mean_pool(hidden: &[f32], batch: usize, seq: usize, dim: usize, mask: &[f32]) -> Vec<f32> {
    let mut pooled = vec![0f32; batch * dim];
    for row in 0..batch {
        let acc = &mut pooled[row * dim..(row + 1) * dim];
        let mut weight_sum = 0f32;
        for tok in 0..seq {
            let flat = row * seq + tok;
            let weight = match mask.get(flat) {
                Some(&w) => w,
                None => 1.0,
            };
            if weight != 0.0 {
                weight_sum += weight;
                let token = &hidden[flat * dim..(flat + 1) * dim];
                for (slot, &h) in acc.iter_mut().zip(token) {
                    *slot += weight * h;
                }
            }
        }
        let weight_sum = weight_sum.max(1e-9);
        acc.iter_mut().for_each(|slot| *slot /= weight_sum);
    }
    pooled
}
/// Scales each of the first `rows` rows (each `dim` long) of `v` to unit L2 norm.
///
/// * The norm is floored at `1e-12`, so an all-zero row stays zero instead of becoming NaN.
/// * Values past `rows * dim` are left untouched.
///
/// # Panics
///
/// Panics if `v` holds fewer than `rows * dim` values.
fn l2_normalize_rows(v: &mut [f32], rows: usize, dim: usize) {
    let mut start = 0;
    for _ in 0..rows {
        let end = start + dim;
        let row = &mut v[start..end];
        let magnitude = row.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
        row.iter_mut().for_each(|x| *x /= magnitude);
        start = end;
    }
}
/// Returns the attention-mask weights to use for pooling, one per `(batch, seq)` position.
///
/// Picks the first entry of `request.additional_inputs` that meets both conditions:
/// * its name contains `"attention_mask"`;
/// * it decodes (via `packet_to_f32`) to exactly `batch * seq` values.
///
/// When no input qualifies, every token gets weight `1.0`. A mask with the wrong
/// length is silently skipped rather than treated as an error.
fn extract_attention_mask(request: &InferenceRequest, batch: usize, seq: usize) -> Vec<f32> {
    let wanted_len = batch * seq;
    request
        .additional_inputs
        .iter()
        .filter(|input| input.name.contains("attention_mask"))
        .map(|input| packet_to_f32(&input.tensor))
        .find(|weights| weights.len() == wanted_len)
        .unwrap_or_else(|| vec![1.0; wanted_len])
}
/// Converts a signed tensor dimension to `usize`, mapping negative values to `0`.
fn dim_usize(d: i64) -> usize {
    if d < 0 {
        0
    } else {
        d as usize
    }
}
/// Decodes little-endian `f32` values from `data`.
///
/// Trailing bytes that do not form a complete 4-byte value are ignored.
fn bytes_to_f32(data: &[u8]) -> Vec<f32> {
    let mut floats = Vec::with_capacity(data.len() / 4);
    for word in data.chunks_exact(4) {
        floats.push(f32::from_le_bytes([word[0], word[1], word[2], word[3]]));
    }
    floats
}
/// Decodes a tensor packet's little-endian payload into `f32` values.
///
/// * `Float32`, `Int64` and `Int32` payloads are supported; integers are converted with `as f32`.
/// * Any other dtype yields an empty vector.
/// * Trailing bytes that do not form a whole element are ignored.
fn packet_to_f32(packet: &BinaryTensorPacket) -> Vec<f32> {
    let raw: &[u8] = &packet.data;
    match packet.dtype {
        TensorDtype::Float32 => bytes_to_f32(raw),
        TensorDtype::Int64 => {
            let mut out = Vec::with_capacity(raw.len() / 8);
            for chunk in raw.chunks_exact(8) {
                let mut word = [0u8; 8];
                word.copy_from_slice(chunk);
                out.push(i64::from_le_bytes(word) as f32);
            }
            out
        }
        TensorDtype::Int32 => {
            let mut out = Vec::with_capacity(raw.len() / 4);
            for chunk in raw.chunks_exact(4) {
                let mut word = [0u8; 4];
                word.copy_from_slice(chunk);
                out.push(i32::from_le_bytes(word) as f32);
            }
            out
        }
        _ => Vec::new(),
    }
}
/// Builds a float32 tensor packet with the given `shape`, storing `values`
/// as consecutive little-endian bytes.
fn f32_packet(shape: Vec<i64>, values: Vec<f32>) -> BinaryTensorPacket {
    let mut data = Vec::with_capacity(values.len() * 4);
    data.extend(values.iter().flat_map(|v| v.to_le_bytes()));
    BinaryTensorPacket {
        dtype: TensorDtype::Float32,
        shape,
        data,
    }
}
#[cfg(test)]
#[path = "onnx_embed_tests.rs"]
mod onnx_embed_tests;