wireband-edge 0.4.0

Lightweight Wire.Band client — semantic data middleware for any domain (IoT, AI/ML, DeFi, legal, geospatial, supply chain, and more)
Documentation
//! TensorFlow Lite inference — run quantized ML models on constrained hardware.
//!
//! Loads a `.tflite` flatbuffer model, fills the input tensor with f32 data,
//! invokes the interpreter, and emits predictions as Wire.Band events.
//!
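//! Well-suited for small anomaly detectors on a Raspberry Pi, including
//! INT8-quantized models whose input and output tensors are `f32` (the input
//! is written and the output read as `f32` slices). Only the built-in op
//! resolver is registered; no hardware delegate (e.g. Coral Edge TPU) is
//! configured by this module.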
//!
//! Feature-gated: `--features infer-tflite` (uses `tflite`).
//!
//! # Prerequisites
//!
//! The `tflite` crate compiles TensorFlow Lite from source, which requires
//! CMake and a C++ toolchain:
//! ```bash
//! sudo apt-get install cmake g++
//! cargo build --features infer-tflite   # first build takes ~5 min
//! ```

use std::time::{SystemTime, UNIX_EPOCH};

use tflite::ops::builtin::BuiltinOpResolver;
use tflite::{FlatBufferModel, InterpreterBuilder};
use tracing::{debug, info};

use crate::client::WireBandClient;
use crate::error::{Result, WireBandError};
use crate::frame;
use crate::symbols::METRICS_GAUGE_SET;

/// Current wall-clock time as fractional seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, `0.0` is returned
/// instead of failing.
fn unix_ts() -> f64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs_f64(),
        Err(_) => 0.0,
    }
}

/// TensorFlow Lite model runner for edge inference.
///
/// The model file is reloaded on each call to `run`, which is acceptable for
/// most sensor polling rates. For high-frequency inference (>100 Hz), cache
/// the `Interpreter` externally and call the TFLite C API directly.
///
/// # Example
///
/// ```ignore
/// use wireband_edge::infer_tflite::TfliteInference;
///
/// let model = TfliteInference::from_file("anomaly.tflite")?;
///
/// // In your sensor loop:
/// let scores = model
///     .run(&sensor_readings, "predictive_maintenance", &client)
///     .await?;
///
/// if scores[0] > 0.85 {
///     tracing::warn!("TFLite anomaly: score={:.2}", scores[0]);
/// }
/// ```
#[derive(Debug, Clone)]
pub struct TfliteInference {
    /// Filesystem path of the `.tflite` model; re-read on every `run`.
    model_path:   String,
    /// Topic prefix; inference events go to `{topic_prefix}/{topic_suffix}`.
    topic_prefix: String,
    /// Theta symbol attached to each inference event.
    symbol:       u16,
}

impl TfliteInference {
    /// Load a TFLite model from `path` (`.tflite` flatbuffer file).
    ///
    /// Only checks that `path` names an existing regular file. The model is
    /// parsed fresh on each [`run`](Self::run) call, so a corrupt file is not
    /// detected until then.
    ///
    /// # Errors
    ///
    /// Returns [`WireBandError::Connection`] if `path` does not exist or is not
    /// a regular file (e.g. it is a directory).
    pub fn from_file(path: impl Into<String>) -> Result<Self> {
        let model_path = path.into();
        // `is_file` rather than `exists`: a directory would pass `exists` here
        // and only fail later, inside `run`.
        if !std::path::Path::new(&model_path).is_file() {
            return Err(WireBandError::Connection(
                format!("TFLite model not found: {model_path}")
            ));
        }
        info!(path = %model_path, "TFLite model path set");
        Ok(Self {
            model_path,
            topic_prefix: "inference/tflite".to_string(),
            symbol:       METRICS_GAUGE_SET,
        })
    }

    /// Override the Wire.Band topic prefix (default: `"inference/tflite"`).
    ///
    /// [`run`](Self::run) buffers inference events under
    /// `{prefix}/{topic_suffix}`.
    pub fn topic_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.topic_prefix = prefix.into();
        self
    }

    /// Override the theta symbol for inference events (default: `METRICS_GAUGE_SET`).
    pub fn symbol(mut self, symbol: u16) -> Self {
        self.symbol = symbol;
        self
    }

    /// Run inference on `input` and buffer the output scores as a Wire.Band event.
    ///
    /// - `input`:        flat f32 values for the model's first input tensor; the
    ///                   length must equal that tensor's element count.
    /// - `topic_suffix`: appended to the topic prefix (e.g. `"zone-a/anomaly"`).
    /// - `client`:       client whose event buffer receives the encoded
    ///                   `{"scores": [...]}` payload.
    ///
    /// Returns the raw output scores from the model's first output tensor.
    ///
    /// Model loading (file I/O) and interpreter invocation run synchronously
    /// on the calling task. The only `.await` is on `client.buffer_event`.
    ///
    /// # Errors
    ///
    /// Returns [`WireBandError::Connection`] in any of these cases:
    /// - the model cannot be loaded or an interpreter cannot be built;
    /// - the model has no input or output tensor;
    /// - the first input or output tensor cannot be accessed as an `f32` slice;
    /// - `input.len()` differs from the input tensor's size;
    /// - invocation fails.
    pub async fn run(
        &self,
        input:        &[f32],
        topic_suffix: &str,
        client:       &WireBandClient,
    ) -> Result<Vec<f32>> {
        // Inference happens in a synchronous helper. The model and interpreter
        // are dropped before the `.await` below instead of living in this
        // future's state.
        let scores = self.infer_blocking(input)?;

        let topic   = format!("{}/{}", self.topic_prefix, topic_suffix);
        let payload = serde_json::json!({ "scores": scores });
        let encoded = frame::encode(self.symbol, &topic, &payload);
        client.buffer_event(topic, self.symbol, encoded, unix_ts()).await;

        debug!(output_len = scores.len(), "TFLite inference complete");
        Ok(scores)
    }

    /// Load the model, feed `input` to its first input tensor, invoke it, and
    /// return a copy of its first output tensor. Blocking: performs file I/O
    /// and runs the interpreter on the current thread.
    fn infer_blocking(&self, input: &[f32]) -> Result<Vec<f32>> {
        let model = FlatBufferModel::build_from_file(&self.model_path)
            .map_err(|e| WireBandError::Connection(
                format!("TFLite failed to load model: {} ({e})", self.model_path)
            ))?;

        let resolver = BuiltinOpResolver::default();
        let builder  = InterpreterBuilder::new(model, resolver)
            .map_err(|e| WireBandError::Connection(format!("TFLite builder: {e}")))?;
        let mut interp = builder.build()
            .map_err(|e| WireBandError::Connection(format!("TFLite build interpreter: {e}")))?;

        interp.allocate_tensors()
            .map_err(|e| WireBandError::Connection(format!("TFLite allocate tensors: {e}")))?;

        // Fill the first input tensor. A size mismatch is an error, not a
        // truncation. A short input would run the model on a tensor whose tail
        // was never written, and a long input would silently drop values.
        let input_idx = interp.inputs().first().copied()
            .ok_or_else(|| WireBandError::Connection("TFLite model has no input tensors".into()))?;
        let input_data: &mut [f32] = interp.tensor_data_mut(input_idx)
            .map_err(|e| WireBandError::Connection(format!("TFLite input tensor: {e}")))?;
        if input_data.len() != input.len() {
            return Err(WireBandError::Connection(format!(
                "TFLite input size mismatch: model expects {} values, got {}",
                input_data.len(),
                input.len()
            )));
        }
        input_data.copy_from_slice(input);

        interp.invoke()
            .map_err(|e| WireBandError::Connection(format!("TFLite invoke failed: {e}")))?;

        // Copy the first output tensor out so the interpreter can be dropped.
        let output_idx = interp.outputs().first().copied()
            .ok_or_else(|| WireBandError::Connection("TFLite model has no output tensors".into()))?;
        let scores: Vec<f32> = interp
            .tensor_data::<f32>(output_idx)
            .map_err(|e| WireBandError::Connection(format!("TFLite output tensor: {e}")))?
            .to_vec();

        Ok(scores)
    }
}