mod config;
pub use config::ObjectDetectorConfig;
use anyhow::{Context as AnyhowContext, Result};
use mecha10_core::health::HealthReportingExt;
use mecha10_core::messages::{HealthStatus, Message};
use mecha10_core::prelude::*;
use mecha10_core::topics::Topic;
use ort::session::builder::GraphOptimizationLevel;
use ort::session::Session;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::time::Instant;
use tracing::info;
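/// A single camera frame as received on the input topic. `image_bytes`
/// carries either an encoded image ("jpeg"/"png") or raw interleaved
/// pixels ("rgb"/"rgb8"), as indicated by `format`.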
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CameraImage {
pub camera_id: String,
pub width: u32,
pub height: u32,
pub timestamp: u64,
#[serde(with = "serde_bytes")]
pub image_bytes: Vec<u8>,
pub format: String,
}
impl Message for CameraImage {}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BoundingBox {
pub x: f32,
pub y: f32,
pub width: f32,
pub height: f32,
}
impl BoundingBox {
pub fn new(x: f32, y: f32, width: f32, height: f32) -> Self {
Self { x, y, width, height }
}
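    /// Intersection-over-Union of two axis-aligned boxes: the overlap
    /// area divided by the area of the union. Disjoint boxes score 0.0.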
pub fn iou(&self, other: &BoundingBox) -> f32 {
let x1 = self.x.max(other.x);
let y1 = self.y.max(other.y);
let x2 = (self.x + self.width).min(other.x + other.width);
let y2 = (self.y + self.height).min(other.y + other.height);
if x2 < x1 || y2 < y1 {
return 0.0;
}
let intersection = (x2 - x1) * (y2 - y1);
        let union = (self.width * self.height) + (other.width * other.height) - intersection;
        if union <= 0.0 {
            // Degenerate boxes (zero area) would otherwise yield NaN.
            return 0.0;
        }
        intersection / union
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Detection {
pub class_id: u32,
pub class_name: String,
pub confidence: f32,
pub bbox: BoundingBox,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DetectionResult {
pub frame_id: u64,
pub timestamp: u64,
pub detections: Vec<Detection>,
pub inference_time_ms: f32,
pub model_name: String,
}
impl Message for DetectionResult {}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceCommand {
pub action: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<serde_json::Value>,
}
impl Message for InferenceCommand {}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum InferenceMode {
Idle,
Active,
}
pub struct ObjectDetectorNode {
config: ObjectDetectorConfig,
session: Option<Session>,
frame_count: u64,
mode: InferenceMode,
class_names: Vec<String>,
preprocessing_semaphore: std::sync::Arc<tokio::sync::Semaphore>,
}
impl ObjectDetectorNode {
pub fn new(config: ObjectDetectorConfig) -> Self {
        let max_concurrent = config.max_async_frames.max(1);
        let preprocessing_semaphore =
            std::sync::Arc::new(tokio::sync::Semaphore::new(max_concurrent));
Self {
config,
session: None,
frame_count: 0,
mode: InferenceMode::Idle,
class_names: Self::get_coco_class_names(),
preprocessing_semaphore,
}
}
fn handle_control(&mut self, cmd: InferenceCommand) {
info!("🎮 Received control command: {:?}", cmd.action);
match cmd.action.as_str() {
"enable" => {
self.mode = InferenceMode::Active;
info!("✅ Object detection enabled");
}
"disable" => {
self.mode = InferenceMode::Idle;
info!("⏸️ Object detection disabled");
}
"set_threshold" => {
if let Some(params) = &cmd.params {
if let Some(conf) = params.get("confidence").and_then(|v| v.as_f64()) {
self.config.confidence_threshold = conf as f32;
info!("🎯 Confidence threshold set to {:.2}", conf);
}
}
}
_ => {
tracing::warn!("❌ Unknown command: {}", cmd.action);
}
}
}
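    /// Run one frame through the pipeline: decode and preprocess on a
    /// blocking thread (the semaphore bounds how many frames may be in
    /// flight at once), then run ONNX inference and decode the YOLO
    /// output on the current task.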
async fn detect_objects(&mut self, msg: &CameraImage) -> Result<DetectionResult> {
let start = Instant::now();
let permit = self.preprocessing_semaphore.clone().acquire_owned().await?;
let msg_clone = msg.clone();
let input_size = self.config.input_size;
let input_tensor = tokio::task::spawn_blocking(move || -> Result<ndarray::Array4<f32>> {
let image = Self::decode_camera_image_static(&msg_clone)?;
let result = Self::preprocess_yolo_image_static(image, input_size);
drop(permit);
result
})
.await??;
let output_vec: Vec<f32> = {
let input_value = ort::value::TensorRef::from_array_view(input_tensor.view())?;
let session = self.session.as_mut().context("ONNX session not initialized")?;
let outputs = session.run(ort::inputs![input_value])?;
let output_array: ndarray::ArrayViewD<f32> = outputs[0]
.try_extract_array()
.context("Failed to extract output tensor")?;
output_array
.as_slice()
.context("Failed to get tensor data as slice")?
.to_vec()
};
let detections = self.process_yolo_output(&output_vec, msg.width, msg.height)?;
let inference_time_ms = start.elapsed().as_secs_f32() * 1000.0;
Ok(DetectionResult {
frame_id: self.frame_count,
timestamp: msg.timestamp,
detections,
inference_time_ms,
model_name: self.config.model_name.clone(),
})
}
fn decode_camera_image_static(msg: &CameraImage) -> Result<image::DynamicImage> {
let bytes = &msg.image_bytes;
let image = match msg.format.as_str() {
"jpeg" | "jpg" => image::load_from_memory_with_format(bytes, image::ImageFormat::Jpeg)?,
"png" => image::load_from_memory_with_format(bytes, image::ImageFormat::Png)?,
"rgb" | "rgb8" => {
let img = image::RgbImage::from_raw(msg.width, msg.height, bytes.clone())
.context("Failed to create RGB image from raw bytes")?;
image::DynamicImage::ImageRgb8(img)
}
_ => anyhow::bail!("Unsupported image format: {}", msg.format),
};
Ok(image)
}
#[allow(dead_code)]
fn decode_camera_image(&self, msg: &CameraImage) -> Result<image::DynamicImage> {
Self::decode_camera_image_static(msg)
}
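    /// Resize to the model's square input and pack into a normalized
    /// [1, 3, size, size] NCHW float tensor. Note that `resize_exact`
    /// stretches the frame to fit, so aspect ratio is not preserved.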
fn preprocess_yolo_image_static(image: image::DynamicImage, size: u32) -> Result<ndarray::Array4<f32>> {
let resized = image.resize_exact(size, size, image::imageops::FilterType::Triangle);
let rgb = resized.to_rgb8();
let mut array = ndarray::Array4::<f32>::zeros((1, 3, size as usize, size as usize));
for (x, y, pixel) in rgb.enumerate_pixels() {
array[[0, 0, y as usize, x as usize]] = pixel[0] as f32 / 255.0;
array[[0, 1, y as usize, x as usize]] = pixel[1] as f32 / 255.0;
array[[0, 2, y as usize, x as usize]] = pixel[2] as f32 / 255.0;
}
Ok(array)
}
#[allow(dead_code)]
fn preprocess_yolo_image(&self, image: image::DynamicImage) -> Result<ndarray::Array4<f32>> {
Self::preprocess_yolo_image_static(image, self.config.input_size)
}
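    /// Decode the raw YOLOv8 output into detections. The model emits a
    /// [1, 84, 8400] tensor: 4 box values (cx, cy, w, h) followed by 80
    /// COCO class scores, stored attribute-major, so attribute `a` of
    /// anchor `i` lives at `output_slice[a * 8400 + i]`. The hard-coded
    /// 8400 anchors assume the standard 640x640 export. Box coordinates
    /// are normalized to [0, 1] before NMS.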
fn process_yolo_output(&self, output_slice: &[f32], _img_width: u32, _img_height: u32) -> Result<Vec<Detection>> {
let num_detections = 8400;
let num_classes = 80;
let mut raw_detections = Vec::new();
let mut max_score_seen = 0.0f32;
for i in 0..num_detections {
let cx = output_slice[i];
let cy = output_slice[num_detections + i];
let w = output_slice[2 * num_detections + i];
let h = output_slice[3 * num_detections + i];
let (max_class_idx, max_score) = (0..num_classes)
.map(|c| {
let score = output_slice[(4 + c) * num_detections + i];
(c, score)
})
.max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
.unwrap_or((0, 0.0));
if max_score > max_score_seen {
max_score_seen = max_score;
}
if max_score >= self.config.confidence_threshold {
let input_size = self.config.input_size as f32;
let x = (cx - w / 2.0) / input_size;
let y = (cy - h / 2.0) / input_size;
let width = w / input_size;
let height = h / input_size;
raw_detections.push(Detection {
class_id: max_class_idx as u32,
class_name: self
.class_names
.get(max_class_idx)
.cloned()
.unwrap_or_else(|| format!("class_{}", max_class_idx)),
confidence: max_score,
bbox: BoundingBox::new(x, y, width, height),
});
}
}
let filtered = self.non_max_suppression(raw_detections);
if filtered.is_empty() {
tracing::debug!(
"No detections above threshold {:.2}. Max score seen: {:.4}",
self.config.confidence_threshold,
max_score_seen
);
}
Ok(filtered)
}
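    /// Greedy per-class non-maximum suppression: repeatedly keep the most
    /// confident remaining detection and drop every same-class detection
    /// whose IoU with it meets or exceeds the configured threshold.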
fn non_max_suppression(&self, mut detections: Vec<Detection>) -> Vec<Detection> {
        detections.sort_by(|a, b| {
            b.confidence
                .partial_cmp(&a.confidence)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
let mut keep = Vec::new();
while !detections.is_empty() {
let current = detections.remove(0);
keep.push(current.clone());
detections.retain(|det| {
                det.class_id != current.class_id || det.bbox.iou(&current.bbox) < self.config.iou_threshold
});
}
keep
}
fn get_coco_class_names() -> Vec<String> {
vec![
"person",
"bicycle",
"car",
"motorcycle",
"airplane",
"bus",
"train",
"truck",
"boat",
"traffic light",
"fire hydrant",
"stop sign",
"parking meter",
"bench",
"bird",
"cat",
"dog",
"horse",
"sheep",
"cow",
"elephant",
"bear",
"zebra",
"giraffe",
"backpack",
"umbrella",
"handbag",
"tie",
"suitcase",
"frisbee",
"skis",
"snowboard",
"sports ball",
"kite",
"baseball bat",
"baseball glove",
"skateboard",
"surfboard",
"tennis racket",
"bottle",
"wine glass",
"cup",
"fork",
"knife",
"spoon",
"bowl",
"banana",
"apple",
"sandwich",
"orange",
"broccoli",
"carrot",
"hot dog",
"pizza",
"donut",
"cake",
"chair",
"couch",
"potted plant",
"bed",
"dining table",
"toilet",
"tv",
"laptop",
"mouse",
"remote",
"keyboard",
"cell phone",
"microwave",
"oven",
"toaster",
"sink",
"refrigerator",
"book",
"clock",
"vase",
"scissors",
"teddy bear",
"hair drier",
"toothbrush",
]
.iter()
.map(|s| s.to_string())
.collect()
}
}
pub async fn run() -> Result<()> {
info!("🤖 Starting Object Detector Node");
let ctx = Context::new("object-detector").await?;
ctx.start_health_reporting(|| async { HealthStatus::healthy() }).await?;
let config: ObjectDetectorConfig = ctx.load_node_config("object-detector").await?;
info!("Configuration: {:?}", config);
let model_path = PathBuf::from(&config.model_path);
if !model_path.exists() {
anyhow::bail!(
"Model not found at {}. Please download a YOLOv8 ONNX model.",
model_path.display()
);
}
info!("📦 Loading model from: {}", model_path.display());
    let final_model_path = if config.use_int8 {
        let stem = model_path
            .file_stem()
            .context("Model path has no file stem")?
            .to_string_lossy();
        let int8_path = model_path.with_file_name(format!("{stem}-int8.onnx"));
if int8_path.exists() {
info!("🔢 Using INT8 quantized model for 2x speedup");
int8_path
} else {
info!(
"⚠️ INT8 enabled but quantized model not found at {}",
int8_path.display()
);
info!(" Falling back to FP32 model. Run conversion script to create INT8 model.");
model_path
}
} else {
model_path
};
let mut session_builder = Session::builder()?;
if config.use_int8 {
session_builder = session_builder
.with_optimization_level(GraphOptimizationLevel::Level3)?
            .with_intra_threads(4)?;
    }
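    // Providers are tried in order: CoreML on macOS (CUDA elsewhere),
    // with the CPU provider as the final fallback when no accelerator
    // is available at runtime.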
let session = session_builder
.with_execution_providers([
#[cfg(target_os = "macos")]
ort::execution_providers::CoreMLExecutionProvider::default().build(),
#[cfg(not(target_os = "macos"))]
ort::execution_providers::CUDAExecutionProvider::default().build(),
ort::execution_providers::CPUExecutionProvider::default().build(),
])?
.commit_from_file(&final_model_path)?;
info!(
"✅ Model loaded successfully ({}) with hardware acceleration",
if config.use_int8 { "INT8" } else { "FP32" }
);
let mut node = ObjectDetectorNode::new(config.clone());
node.session = Some(session);
let input_topic = config.input_topic();
let output_topic = config.output_topic();
let control_topic = config.control_topic();
info!("📡 Input: {}", input_topic);
info!("📡 Control: {}", control_topic);
info!("📤 Output: {}", output_topic);
info!(
"⏸️ Starting in {} mode",
if config.default_enabled { "ACTIVE" } else { "IDLE" }
);
if config.default_enabled {
node.mode = InferenceMode::Active;
}
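    // Topic names must be &'static str; leaking the three strings is a
    // deliberate one-time allocation that lives for the rest of the process.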
let input_topic_str: &'static str = Box::leak(input_topic.into_boxed_str());
let control_topic_str: &'static str = Box::leak(control_topic.into_boxed_str());
let output_topic_str: &'static str = Box::leak(output_topic.into_boxed_str());
let camera_topic = Topic::<mecha10_core::messages::RedisMessage<CameraImage>>::new(input_topic_str);
let control_topic = Topic::<InferenceCommand>::new(control_topic_str);
let output_topic = Topic::<DetectionResult>::new(output_topic_str);
let mut camera_receiver = ctx.subscribe(camera_topic).await?;
let mut control_receiver = ctx.subscribe(control_topic).await?;
info!("✅ Subscribed to camera and control topics");
info!("🔄 Entering main processing loop with frame dropping (latest-only)");
info!("🕐 System time at start: {} μs", now_micros());
loop {
tokio::select! {
Some(control_msg) = control_receiver.recv() => {
node.handle_control(control_msg);
}
Some(camera_envelope) = camera_receiver.recv() => {
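                // Drain everything queued behind this frame so inference
                // always runs on the most recent image; older frames are
                // counted and dropped.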
let mut latest_envelope = camera_envelope;
let mut dropped_count = 0;
while let Ok(newer_envelope) = camera_receiver.try_recv() {
latest_envelope = newer_envelope;
dropped_count += 1;
}
if dropped_count > 0 {
tracing::debug!("📉 Dropped {} old frames, processing latest", dropped_count);
}
let current_time = now_micros();
let frame_timestamp = latest_envelope.timestamp;
static FRAME_DEBUG_COUNT: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
let debug_count = FRAME_DEBUG_COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
if debug_count < 3 {
tracing::info!(
"🕐 Timestamp debug: system={} μs, frame={} μs, diff={} μs",
current_time,
frame_timestamp,
current_time.saturating_sub(frame_timestamp)
);
}
if current_time > frame_timestamp {
let frame_age_us = current_time - frame_timestamp;
let frame_age_ms = frame_age_us as f64 / 1000.0;
if frame_age_ms > 1000.0 && frame_age_ms < 10_000.0 {
tracing::warn!(
"⏰ Processing old frame ({:.1}ms old) - possible system lag",
frame_age_ms
);
} else if frame_age_ms > 100.0 && frame_age_ms < 1000.0 {
tracing::debug!("⏱️ Frame age: {:.1}ms", frame_age_ms);
}
}
let camera_msg = latest_envelope.payload;
if node.mode == InferenceMode::Active {
node.frame_count += 1;
if node.frame_count % config.frame_skip as u64 == 0 {
match node.detect_objects(&camera_msg).await {
Ok(result) => {
info!(
"🎯 Detected {} objects in {:.1}ms",
result.detections.len(),
result.inference_time_ms
);
for det in &result.detections {
info!(
" - {} ({:.1}%)",
det.class_name,
det.confidence * 100.0
);
}
ctx.publish_to(output_topic, &result).await?;
}
Err(e) => {
tracing::warn!("❌ Detection failed: {}", e);
}
}
}
}
}
}
}
}
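
// A minimal sanity check for the IoU math above; illustrative values
// only, not an exhaustive test suite.
#[cfg(test)]
mod tests {
    use super::BoundingBox;

    #[test]
    fn iou_basic_cases() {
        let a = BoundingBox::new(0.0, 0.0, 10.0, 10.0);
        // Identical boxes overlap completely.
        assert!((a.iou(&a) - 1.0).abs() < 1e-6);

        // 5x5 overlap, union = 100 + 100 - 25 = 175.
        let b = BoundingBox::new(5.0, 5.0, 10.0, 10.0);
        assert!((a.iou(&b) - 25.0 / 175.0).abs() < 1e-6);

        // Disjoint boxes have zero IoU.
        let c = BoundingBox::new(20.0, 20.0, 1.0, 1.0);
        assert_eq!(a.iou(&c), 0.0);
    }
}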