use std::time::{SystemTime, UNIX_EPOCH};
use tflite::ops::builtin::BuiltinOpResolver;
use tflite::{FlatBufferModel, InterpreterBuilder};
use tracing::{debug, info};
use crate::client::WireBandClient;
use crate::error::{Result, WireBandError};
use crate::frame;
use crate::symbols::METRICS_GAUGE_SET;
/// Returns the current wall-clock time as fractional seconds since the Unix epoch.
///
/// If the system clock reports a time earlier than the epoch, the result is `0.0`
/// rather than an error.
fn unix_ts() -> f64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs_f64(),
        // Clock is set before 1970: fall back to a zero duration.
        Err(_) => 0.0,
    }
}
/// Configuration for running a TensorFlow Lite model and publishing its output.
///
/// Holds only the model location and publishing settings; the values are
/// cheap to clone and safe to print for diagnostics.
#[derive(Debug, Clone)]
pub struct TfliteInference {
    /// Filesystem path of the `.tflite` model file.
    model_path: String,
    /// Prefix joined with a per-call suffix (`"{prefix}/{suffix}"`) to form the topic.
    topic_prefix: String,
    /// Symbol code passed to the frame encoder and to the client when buffering.
    symbol: u16,
}
impl TfliteInference {
pub fn from_file(path: impl Into<String>) -> Result<Self> {
let model_path = path.into();
if !std::path::Path::new(&model_path).exists() {
return Err(WireBandError::Connection(
format!("TFLite model not found: {model_path}")
));
}
info!(path = %model_path, "TFLite model path set");
Ok(Self {
model_path,
topic_prefix: "inference/tflite".to_string(),
symbol: METRICS_GAUGE_SET,
})
}
pub fn topic_prefix(mut self, prefix: impl Into<String>) -> Self {
self.topic_prefix = prefix.into();
self
}
pub fn symbol(mut self, symbol: u16) -> Self {
self.symbol = symbol;
self
}
pub async fn run(
&self,
input: &[f32],
topic_suffix: &str,
client: &WireBandClient,
) -> Result<Vec<f32>> {
let model = FlatBufferModel::build_from_file(&self.model_path)
.ok_or_else(|| WireBandError::Connection(
format!("TFLite failed to load model: {}", self.model_path)
))?;
let resolver = BuiltinOpResolver::default();
let builder = InterpreterBuilder::new(model, resolver)
.map_err(|e| WireBandError::Connection(format!("TFLite builder: {e}")))?;
let mut interp = builder.build()
.map_err(|e| WireBandError::Connection(format!("TFLite build interpreter: {e}")))?;
interp.allocate_tensors()
.map_err(|e| WireBandError::Connection(format!("TFLite allocate tensors: {e}")))?;
let input_idx = interp.inputs()[0];
let input_data: &mut [f32] = interp.tensor_data_mut(input_idx)
.map_err(|e| WireBandError::Connection(format!("TFLite input tensor: {e}")))?;
let n = input_data.len().min(input.len());
input_data[..n].copy_from_slice(&input[..n]);
if !interp.invoke() {
return Err(WireBandError::Connection("TFLite invoke failed".into()));
}
let output_idx = interp.outputs()[0];
let scores: Vec<f32> = interp
.tensor_data::<f32>(output_idx)
.map_err(|e| WireBandError::Connection(format!("TFLite output tensor: {e}")))?
.to_vec();
let topic = format!("{}/{}", self.topic_prefix, topic_suffix);
let payload = serde_json::json!({ "scores": scores });
let encoded = frame::encode(self.symbol, &topic, &payload);
client.buffer_event(topic, self.symbol, encoded, unix_ts()).await;
debug!(output_len = scores.len(), "TFLite inference complete");
Ok(scores)
}
}