use async_trait::async_trait;
use kapsl_engine_api::{
BinaryTensorPacket, Engine, EngineError, EngineMetrics, EngineModelInfo, EngineStream,
InferenceRequest, TensorDtype,
};
use std::path::Path;
/// Engine wrapper that forwards calls to another engine and reshapes its inference output into
/// classifier scores.
pub struct OnnxClassifyBackend {
/// Wrapped engine that every `Engine` call is delegated to.
inner: Box<dyn Engine>,
/// When `true`, each output row is passed through softmax before being returned from `infer`.
apply_softmax: bool,
}
impl OnnxClassifyBackend {
pub fn new(inner: Box<dyn Engine>, apply_softmax: bool) -> Self {
Self {
inner,
apply_softmax,
}
}
}
// Delegates everything to the wrapped engine; only the inference result is rewritten.
#[async_trait]
impl Engine for OnnxClassifyBackend {
    /// Loads the model at `model_path` through the wrapped engine.
    async fn load(&mut self, model_path: &Path) -> Result<(), EngineError> {
        self.inner.load(model_path).await
    }

    /// Runs inference on the wrapped engine and converts its output into classifier form.
    ///
    /// Errors from the wrapped engine and from output validation are returned unchanged.
    fn infer(&self, request: &InferenceRequest) -> Result<BinaryTensorPacket, EngineError> {
        self.inner
            .infer(request)
            .and_then(|raw| classify_from_output(&raw, self.apply_softmax))
    }

    /// Returns a stream that yields exactly one item: the result of `infer`.
    ///
    /// The inference itself runs eagerly, before the stream is returned.
    fn infer_stream(&self, request: &InferenceRequest) -> EngineStream {
        let outcome = self.infer(request);
        let single = futures::stream::once(async move { outcome });
        Box::pin(single)
    }

    /// Unloads the wrapped engine.
    fn unload(&mut self) {
        self.inner.unload();
    }

    /// Reports the wrapped engine's metrics.
    fn metrics(&self) -> EngineMetrics {
        self.inner.metrics()
    }

    /// Reports the wrapped engine's model info, if it has any.
    fn model_info(&self) -> Option<EngineModelInfo> {
        self.inner.model_info()
    }

    /// Runs the wrapped engine's health check.
    fn health_check(&self) -> Result<(), EngineError> {
        self.inner.health_check()
    }
}
/// Validates a raw engine output and repackages it as a `[batch, classes]` float32 tensor.
///
/// A 1-D `[classes]` output is promoted to `[1, classes]`; a 2-D `[batch, classes]` output keeps
/// its shape. Negative dimensions are counted as zero elements (see `dim_usize`). When
/// `apply_softmax` is set, every row is passed through `softmax_rows`; otherwise the values are
/// returned as received.
///
/// # Errors
///
/// Returns a backend error when:
/// - the dtype is not `Float32`,
/// - the shape is neither 1-D nor 2-D,
/// - the payload length is not a whole number of 4-byte floats,
/// - `batch * classes` does not fit in `usize`, or
/// - the number of values in the payload differs from what the shape implies.
fn classify_from_output(
    output: &BinaryTensorPacket,
    apply_softmax: bool,
) -> Result<BinaryTensorPacket, EngineError> {
    const F32_BYTES: usize = std::mem::size_of::<f32>();

    if output.dtype != TensorDtype::Float32 {
        return Err(EngineError::backend(format!(
            "classifier output dtype {:?} is not supported (expected float32)",
            output.dtype
        )));
    }

    let (batch, classes, shape) = match output.shape.as_slice() {
        [classes] => (1usize, dim_usize(*classes), vec![1, *classes]),
        [batch, classes] => (dim_usize(*batch), dim_usize(*classes), output.shape.clone()),
        other => {
            return Err(EngineError::backend(format!(
                "classifier expects a 1-D [classes] or 2-D [batch, classes] output, got shape {:?}",
                other
            )));
        }
    };

    // Decoding works on whole 4-byte chunks and would silently drop a partial trailing value,
    // so a truncated or corrupted payload has to be rejected here.
    if output.data.len() % F32_BYTES != 0 {
        return Err(EngineError::backend(format!(
            "classifier output has {} bytes, which is not a whole number of float32 values",
            output.data.len()
        )));
    }

    // A plain multiply would panic in debug builds and wrap in release builds on absurd shapes.
    let expected = batch.checked_mul(classes).ok_or_else(|| {
        EngineError::backend(format!(
            "classifier output shape {:?} describes more elements than can be addressed",
            output.shape
        ))
    })?;

    // Compare counts before decoding so malformed input does not cost an allocation.
    let value_count = output.data.len() / F32_BYTES;
    if value_count != expected {
        return Err(EngineError::backend(format!(
            "classifier output has {} values but shape {:?} implies {}",
            value_count, output.shape, expected
        )));
    }

    let mut values = bytes_to_f32(&output.data);
    if apply_softmax {
        softmax_rows(&mut values, batch, classes);
    }
    Ok(f32_packet(shape, values))
}
/// Applies a numerically stable softmax, in place, to each of the first `rows` rows of `v`.
///
/// `v` is read as a row-major matrix with `classes` values per row. The row maximum is
/// subtracted before exponentiating so that large scores do not overflow.
///
/// A row whose exponentials sum to NaN (it contains a NaN or `+inf`, or every entry is `-inf`)
/// has no well-defined softmax and is filled with NaN, so the problem stays visible downstream
/// instead of being turned into huge bogus "probabilities".
///
/// Callers are expected to pass at least `rows * classes` values; an incomplete trailing row is
/// left untouched. With `classes == 0` the call is a no-op.
fn softmax_rows(v: &mut [f32], rows: usize, classes: usize) {
    if classes == 0 {
        // Nothing to normalize, and `chunks_exact_mut(0)` would panic.
        return;
    }
    for row in v.chunks_exact_mut(classes).take(rows) {
        let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let mut sum = 0f32;
        for x in row.iter_mut() {
            *x = (*x - max).exp();
            sum += *x;
        }
        // For an all-finite row the maximum contributes exp(0) = 1, so `sum >= 1` and no
        // divide-by-zero guard is needed. The only degenerate outcome is a NaN sum; clamping it
        // with `f32::max` would discard the NaN and scale the finite entries by the clamp value.
        if sum.is_nan() {
            row.fill(f32::NAN);
            continue;
        }
        for x in row.iter_mut() {
            *x /= sum;
        }
    }
}
/// Converts a signed tensor dimension to an element count, treating negative values as zero.
fn dim_usize(d: i64) -> usize {
    if d < 0 {
        0
    } else {
        d as usize
    }
}
/// Decodes `data` as consecutive little-endian `f32` values.
///
/// Only whole 4-byte groups are decoded; up to three trailing bytes are ignored.
fn bytes_to_f32(data: &[u8]) -> Vec<f32> {
    let mut decoded = Vec::with_capacity(data.len() / 4);
    for chunk in data.chunks_exact(4) {
        let mut word = [0u8; 4];
        word.copy_from_slice(chunk);
        decoded.push(f32::from_le_bytes(word));
    }
    decoded
}
/// Builds a `Float32` packet with the given `shape`, encoding `values` as little-endian bytes.
fn f32_packet(shape: Vec<i64>, values: Vec<f32>) -> BinaryTensorPacket {
    let data: Vec<u8> = values
        .iter()
        .flat_map(|value| value.to_le_bytes())
        .collect();
    BinaryTensorPacket {
        dtype: TensorDtype::Float32,
        shape,
        data,
    }
}
// Unit tests are kept in `onnx_classify_tests.rs` and are compiled only for test builds.
#[cfg(test)]
#[path = "onnx_classify_tests.rs"]
mod onnx_classify_tests;