wireband-edge 0.4.1

Lightweight Wire.Band client — semantic data middleware for any domain (IoT, AI/ML, DeFi, legal, geospatial, supply chain, and more)
Documentation
//! ONNX Runtime inference — run ML models at the edge via ONNX Runtime.
//!
//! Loads a `.onnx` model file and runs inference on f32 input tensors,
//! emitting predictions as Wire.Band events. Common use cases: anomaly
//! detection, predictive maintenance, image classification on Jetson/RPi.
//!
//! Feature-gated: `--features infer-onnx` (uses `ort` + `ndarray`).
//!
//! # Prerequisites
//!
//! ONNX Runtime shared library must be available. On Linux:
//! ```bash
//! # Download from https://github.com/microsoft/onnxruntime/releases
//! export ORT_DYLIB_PATH=/usr/local/lib/libonnxruntime.so
//! ```

use std::time::{SystemTime, UNIX_EPOCH};

use ndarray::{Array, IxDyn};
use ort::session::Session;
use tracing::{debug, info};

use crate::client::WireBandClient;
use crate::error::{Result, WireBandError};
use crate::frame;
use crate::symbols::METRICS_GAUGE_SET;

/// Current wall-clock time as fractional seconds since the Unix epoch.
///
/// Yields `0.0` when the system clock reports a moment before the epoch.
fn unix_ts() -> f64 {
    let now = SystemTime::now();
    now.duration_since(UNIX_EPOCH)
        .map_or(0.0, |elapsed| elapsed.as_secs_f64())
}

/// ONNX Runtime model runner for edge inference.
///
/// Bundles a loaded ONNX Runtime session with the names of the graph nodes
/// to feed and read, plus the topic prefix and symbol used when the
/// resulting scores are buffered as a Wire.Band event.
///
/// # Example
///
/// ```ignore
/// use wireband_edge::infer_onnx::OnnxInference;
///
/// let model = OnnxInference::from_file(
///     "anomaly_detector.onnx",
///     "input",
///     "output",
/// )?;
///
/// // In your sensor loop:
/// let scores = model
///     .run(&sensor_readings, &[1, 16], "predictive_maintenance", &client)
///     .await?;
///
/// if scores[0] > 0.85 {
///     tracing::warn!("Anomaly detected: score={:.2}", scores[0]);
/// }
/// ```
pub struct OnnxInference {
    /// ONNX Runtime session that executes the loaded model.
    session:      Session,
    /// Name of the model input the caller's tensor is bound to.
    input_name:   String,
    /// Name of the model output the scores are read from.
    output_name:  String,
    /// First part of the event topic; the full topic is `{topic_prefix}/{topic_suffix}`.
    topic_prefix: String,
    /// Symbol passed to the frame encoder and the client for each inference event.
    symbol:       u16,
}

impl OnnxInference {
    /// Load an ONNX model from `path` and prepare it for inference.
    ///
    /// - `path`:        location of the `.onnx` model file.
    /// - `input_name`:  name of the input node in the model graph.
    /// - `output_name`: name of the output node to read predictions from.
    ///
    /// Neither node name is checked here; a wrong name surfaces as an error
    /// from [`Self::run`]. The runner starts with the `"inference/onnx"`
    /// topic prefix and the `METRICS_GAUGE_SET` symbol; use
    /// [`Self::topic_prefix`] and [`Self::symbol`] to override them.
    ///
    /// # Errors
    ///
    /// Returns [`WireBandError::Connection`] if the session builder cannot be
    /// created or the model cannot be loaded from `path`.
    pub fn from_file(
        path:        impl AsRef<std::path::Path>,
        input_name:  impl Into<String>,
        output_name: impl Into<String>,
    ) -> Result<Self> {
        // Borrow the path once so it is still available for the log line
        // after the session has been built from it.
        let path = path.as_ref();

        let session = Session::builder()
            .map_err(|e| WireBandError::Connection(format!("ONNX session builder: {e}")))?
            .commit_from_file(path)
            .map_err(|e| WireBandError::Connection(format!("ONNX load model: {e}")))?;

        info!(path = %path.display(), "ONNX model loaded");

        Ok(Self {
            session,
            input_name:   input_name.into(),
            output_name:  output_name.into(),
            topic_prefix: "inference/onnx".to_string(),
            symbol:       METRICS_GAUGE_SET,
        })
    }

    /// Override the Wire.Band topic prefix (default: `"inference/onnx"`).
    ///
    /// Predictions are published to `{prefix}/{topic_suffix}` in [`Self::run`].
    #[must_use]
    pub fn topic_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.topic_prefix = prefix.into();
        self
    }

    /// Override the theta symbol for inference events (default: `METRICS_GAUGE_SET`).
    #[must_use]
    pub fn symbol(mut self, symbol: u16) -> Self {
        self.symbol = symbol;
        self
    }

    /// Run inference and buffer the output scores as a Wire.Band event.
    ///
    /// - `input`:        flat f32 slice (row-major).
    /// - `shape`:        tensor shape (e.g. `&[1, 16]` for batch=1, features=16).
    /// - `topic_suffix`: appended to the topic prefix (e.g. `"zone-a/anomaly"`).
    /// - `client`:       client whose buffer receives the encoded event.
    ///
    /// The event payload is a JSON object with the keys `"scores"` and
    /// `"shape"`. The model itself is executed inline (the session call is
    /// not awaited), so the calling task is occupied until it finishes.
    ///
    /// Returns the raw output scores, flattened in row-major order.
    ///
    /// # Errors
    ///
    /// Returns [`WireBandError::Connection`] if `shape` does not match
    /// `input.len()`, if the inputs cannot be built, if inference fails, if
    /// the configured output name is not among the model outputs, or if the
    /// output cannot be extracted as an `f32` tensor.
    pub async fn run(
        &self,
        input:        &[f32],
        shape:        &[usize],
        topic_suffix: &str,
        client:       &WireBandClient,
    ) -> Result<Vec<f32>> {
        let array = Array::from_shape_vec(IxDyn(shape), input.to_vec())
            .map_err(|e| WireBandError::Connection(format!("ONNX input shape error: {e}")))?;

        let outputs = self.session
            .run(
                ort::inputs![&self.input_name => array.view()]
                    .map_err(|e| WireBandError::Connection(format!("ONNX inputs: {e}")))?,
            )
            .map_err(|e| WireBandError::Connection(format!("ONNX inference: {e}")))?;

        // Look the output up fallibly: indexing by name would panic when the
        // configured output name is not produced by the model.
        let output = outputs
            .get(self.output_name.as_str())
            .ok_or_else(|| {
                WireBandError::Connection(format!(
                    "ONNX output not found: {}",
                    self.output_name
                ))
            })?;

        let tensor = output
            .try_extract_tensor::<f32>()
            .map_err(|e| WireBandError::Connection(format!("ONNX extract output: {e}")))?;

        // Iterate instead of borrowing a slice: a non-contiguous view has no
        // slice form, and falling back to an empty slice would silently
        // discard every score.
        let scores: Vec<f32> = tensor.view().iter().copied().collect();

        let topic   = format!("{}/{}", self.topic_prefix, topic_suffix);
        let payload = serde_json::json!({ "scores": scores, "shape": shape });
        let encoded = frame::encode(self.symbol, &topic, &payload);
        client.buffer_event(topic, self.symbol, encoded, unix_ts()).await;

        debug!(output_len = scores.len(), "ONNX inference complete");
        Ok(scores)
    }
}