use std::time::{SystemTime, UNIX_EPOCH};
use ndarray::{Array, IxDyn};
use ort::session::Session;
use tracing::{debug, info};
use crate::client::WireBandClient;
use crate::error::{Result, WireBandError};
use crate::frame;
use crate::symbols::METRICS_GAUGE_SET;
/// Returns the current wall-clock time as fractional seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, this yields `0.0`
/// instead of failing.
fn unix_ts() -> f64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs_f64(),
        Err(_) => 0.0,
    }
}
/// A loaded ONNX model whose `f32` outputs are returned to the caller and
/// also published as a WireBand event on every `run` call.
pub struct OnnxInference {
    /// ONNX Runtime session that executes the model.
    session: Session,
    /// Name of the graph input that the caller's array is bound to.
    input_name: String,
    /// Name of the graph output whose `f32` values are extracted.
    output_name: String,
    /// Symbol passed to `frame::encode` and `buffer_event` for published events.
    symbol: u16,
    /// Topic prefix; the event topic is `"{topic_prefix}/{topic_suffix}"`.
    topic_prefix: String,
}
impl OnnxInference {
    /// Loads an ONNX model from `path` and prepares it for inference.
    ///
    /// `input_name` and `output_name` are the graph tensor names used by
    /// [`OnnxInference::run`]. They are not validated against the model here;
    /// an unknown output name is reported as an error when `run` is called.
    /// Published events default to the `"inference/onnx"` topic prefix and the
    /// [`METRICS_GAUGE_SET`] symbol. Override them with
    /// [`OnnxInference::topic_prefix`] and [`OnnxInference::symbol`].
    ///
    /// # Errors
    ///
    /// Returns [`WireBandError::Connection`] if the session builder cannot be
    /// created or the model file cannot be loaded.
    pub fn from_file(
        path: impl AsRef<std::path::Path>,
        input_name: impl Into<String>,
        output_name: impl Into<String>,
    ) -> Result<Self> {
        let path = path.as_ref();
        let session = Session::builder()
            .map_err(|e| WireBandError::Connection(format!("ONNX session builder: {e}")))?
            .commit_from_file(path)
            .map_err(|e| WireBandError::Connection(format!("ONNX load model: {e}")))?;
        // Record which model was loaded; useful when several instances coexist.
        info!(path = %path.display(), "ONNX model loaded");
        Ok(Self {
            session,
            input_name: input_name.into(),
            output_name: output_name.into(),
            topic_prefix: "inference/onnx".to_string(),
            symbol: METRICS_GAUGE_SET,
        })
    }

    /// Replaces the topic prefix used for published events (builder style).
    #[must_use]
    pub fn topic_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.topic_prefix = prefix.into();
        self
    }

    /// Replaces the frame symbol used for published events (builder style).
    #[must_use]
    pub fn symbol(mut self, symbol: u16) -> Self {
        self.symbol = symbol;
        self
    }

    /// Runs the model on `input` and returns the configured output as `f32`s.
    ///
    /// `input` is reshaped to `shape`, so its length must equal the product of
    /// `shape`. The output tensor is flattened in logical (row-major) order.
    /// The scores are also JSON-encoded as `{"scores": [...], "shape": [...]}`
    /// and handed to `client.buffer_event` under the topic
    /// `"{topic_prefix}/{topic_suffix}"`. Note that `"shape"` is the *input*
    /// shape passed by the caller, not the output tensor's shape.
    ///
    /// # Errors
    ///
    /// Returns [`WireBandError::Connection`] in any of these cases:
    /// - `input` does not fit `shape`;
    /// - the inputs cannot be bound or inference fails;
    /// - the model has no output named `output_name`;
    /// - that output cannot be extracted as an `f32` tensor.
    pub async fn run(
        &self,
        input: &[f32],
        shape: &[usize],
        topic_suffix: &str,
        client: &WireBandClient,
    ) -> Result<Vec<f32>> {
        let array = Array::from_shape_vec(IxDyn(shape), input.to_vec())
            .map_err(|e| WireBandError::Connection(format!("ONNX input shape error: {e}")))?;
        let outputs = self
            .session
            .run(
                ort::inputs![&self.input_name => array.view()]
                    .map_err(|e| WireBandError::Connection(format!("ONNX inputs: {e}")))?,
            )
            .map_err(|e| WireBandError::Connection(format!("ONNX inference: {e}")))?;
        // Indexing `outputs[name]` panics on an unknown key. A misconfigured
        // output name must surface as an error rather than abort the task.
        let output = outputs.get(self.output_name.as_str()).ok_or_else(|| {
            WireBandError::Connection(format!(
                "ONNX extract output: no output named `{}`",
                self.output_name
            ))
        })?;
        // Iterate instead of using `as_slice()`. `as_slice()` returns `None` for
        // non-contiguous views, which would otherwise mean publishing an empty
        // score vector without any error.
        let scores: Vec<f32> = output
            .try_extract_tensor::<f32>()
            .map_err(|e| WireBandError::Connection(format!("ONNX extract output: {e}")))?
            .view()
            .iter()
            .copied()
            .collect();
        let topic = format!("{}/{}", self.topic_prefix, topic_suffix);
        let payload = serde_json::json!({ "scores": scores, "shape": shape });
        let encoded = frame::encode(self.symbol, &topic, &payload);
        client.buffer_event(topic, self.symbol, encoded, unix_ts()).await;
        debug!(output_len = scores.len(), "ONNX inference complete");
        Ok(scores)
    }
}