mecha10_nodes_object_detector/
lib.rs

//! Object Detector Node
//!
//! Real-time object detection node that:
//! - Subscribes to camera frames from any source
//! - Runs YOLO inference using ONNX Runtime
//! - Publishes detection results with bounding boxes
//! - Supports enable/disable control commands
//!
//! # Topic Interface
//!
//! **Input:** `/camera/rgb` (CameraImage)
//! **Output:** `/vision/object/detections` (DetectionResult)
//! **Control:** `/vision/object/control` (InferenceCommand)
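//!
//! # Control Commands
//!
//! A sketch of the JSON payloads accepted on the control topic, derived from
//! `handle_control` below (values illustrative):
//!
//! ```json
//! { "action": "enable" }
//! { "action": "disable" }
//! { "action": "set_threshold", "params": { "confidence": 0.5 } }
//! ```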

mod config;

pub use config::ObjectDetectorConfig;

use anyhow::{Context as AnyhowContext, Result};
use mecha10_core::health::HealthReportingExt;
use mecha10_core::messages::{HealthStatus, Message};
use mecha10_core::prelude::*;
use mecha10_core::topics::Topic;
use ort::session::builder::GraphOptimizationLevel;
use ort::session::Session;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::time::Instant;
use tracing::info;

// === Message Types ===

/// Camera image message (reused from simulation-bridge)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CameraImage {
    pub camera_id: String,
    pub width: u32,
    pub height: u32,
    pub timestamp: u64,
    #[serde(with = "serde_bytes")]
    pub image_bytes: Vec<u8>,
    pub format: String,
}

impl Message for CameraImage {}

/// Bounding box in normalized coordinates (0-1)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BoundingBox {
    pub x: f32,
    pub y: f32,
    pub width: f32,
    pub height: f32,
}

impl BoundingBox {
    pub fn new(x: f32, y: f32, width: f32, height: f32) -> Self {
        Self { x, y, width, height }
    }

    /// Calculate Intersection over Union (IoU) with another box
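    ///
    /// Illustrative sanity check (not from the original source); assumes the
    /// crate is named `mecha10_nodes_object_detector` as the directory suggests:
    ///
    /// ```
    /// use mecha10_nodes_object_detector::BoundingBox;
    /// let a = BoundingBox::new(0.0, 0.0, 0.5, 0.5);
    /// assert_eq!(a.iou(&a), 1.0); // identical boxes
    /// let b = BoundingBox::new(0.6, 0.6, 0.2, 0.2);
    /// assert_eq!(a.iou(&b), 0.0); // disjoint boxes
    /// ```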
    pub fn iou(&self, other: &BoundingBox) -> f32 {
        let x1 = self.x.max(other.x);
        let y1 = self.y.max(other.y);
        let x2 = (self.x + self.width).min(other.x + other.width);
        let y2 = (self.y + self.height).min(other.y + other.height);

        if x2 < x1 || y2 < y1 {
            return 0.0;
        }

        let intersection = (x2 - x1) * (y2 - y1);
        let union = (self.width * self.height) + (other.width * other.height) - intersection;

        intersection / union
    }
}

/// Single object detection
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Detection {
    pub class_id: u32,
    pub class_name: String,
    pub confidence: f32,
    pub bbox: BoundingBox,
}

/// Detection result message (matches dashboard interface)
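///
/// Representative serialized shape (illustrative values; actual encoding is
/// whatever the transport uses for `Message` types):
///
/// ```json
/// { "frame_id": 42, "timestamp": 1712345678000000, "detections": [],
///   "inference_time_ms": 18.3, "model_name": "yolov8n" }
/// ```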
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DetectionResult {
    pub frame_id: u64,
    pub timestamp: u64,
    pub detections: Vec<Detection>,
    pub inference_time_ms: f32,
    pub model_name: String,
}

impl Message for DetectionResult {}

/// Inference control commands
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceCommand {
    pub action: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<serde_json::Value>,
}

impl Message for InferenceCommand {}

/// Inference state
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum InferenceMode {
    Idle,
    Active,
}

// === Object Detector Node ===

pub struct ObjectDetectorNode {
    config: ObjectDetectorConfig,
    session: Option<Session>,
    frame_count: u64,
    mode: InferenceMode,
    class_names: Vec<String>,
    /// Semaphore to limit concurrent async preprocessing tasks
    preprocessing_semaphore: std::sync::Arc<tokio::sync::Semaphore>,
}

impl ObjectDetectorNode {
    pub fn new(config: ObjectDetectorConfig) -> Self {
        // Limit concurrent preprocessing based on config
        // This prevents memory buildup if frames arrive faster than we can infer
        let max_concurrent = config.max_async_frames.max(1); // At least 1
        let preprocessing_semaphore = std::sync::Arc::new(tokio::sync::Semaphore::new(max_concurrent));

        Self {
            config,
            session: None,
            frame_count: 0,
            mode: InferenceMode::Idle,
            class_names: Self::get_coco_class_names(),
            preprocessing_semaphore,
        }
    }

    /// Handle inference control command
    fn handle_control(&mut self, cmd: InferenceCommand) {
        info!("🎮 Received control command: {:?}", cmd.action);
        match cmd.action.as_str() {
            "enable" => {
                self.mode = InferenceMode::Active;
                info!("✅ Object detection enabled");
            }
            "disable" => {
                self.mode = InferenceMode::Idle;
                info!("⏸️  Object detection disabled");
            }
            "set_threshold" => {
                if let Some(params) = &cmd.params {
                    if let Some(conf) = params.get("confidence").and_then(|v| v.as_f64()) {
                        self.config.confidence_threshold = conf as f32;
                        info!("🎯 Confidence threshold set to {:.2}", conf);
                    }
                }
            }
            _ => {
                tracing::warn!("❌ Unknown command: {}", cmd.action);
            }
        }
    }

    /// Process a camera image and return detections
    async fn detect_objects(&mut self, msg: &CameraImage) -> Result<DetectionResult> {
        let start = Instant::now();

        // OPTIMIZATION: Run CPU-bound preprocessing in separate thread to avoid blocking
        // This allows overlap with GPU inference from previous frame
        // Semaphore limits concurrent preprocessing to prevent memory buildup
        let permit = self.preprocessing_semaphore.clone().acquire_owned().await?;

        let msg_clone = msg.clone();
        let input_size = self.config.input_size;

        let input_tensor = tokio::task::spawn_blocking(move || -> Result<ndarray::Array4<f32>> {
            // 1. Decode image (CPU-bound)
            let image = Self::decode_camera_image_static(&msg_clone)?;

            // 2. Preprocess for YOLO (CPU-bound: resize + normalize)
            let result = Self::preprocess_yolo_image_static(image, input_size);

            // Permit is dropped here, allowing next frame to preprocess
            drop(permit);

            result
        })
        .await??;

        // 3. Run ONNX inference and extract output in a separate scope
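        // Note: session.run() below is synchronous and executes on the async
        // runtime thread; only the decode/preprocess step above is offloaded
        // via spawn_blocking, so a slow inference still stalls this node's
        // select loop for the duration of the call.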
        let output_vec: Vec<f32> = {
            let input_value = ort::value::TensorRef::from_array_view(input_tensor.view())?;

            let session = self.session.as_mut().context("ONNX session not initialized")?;
            let outputs = session.run(ort::inputs![input_value])?;

            // Extract output tensor and copy data (YOLOv8 format: [1, 84, 8400])
            let output_array: ndarray::ArrayViewD<f32> = outputs[0]
                .try_extract_array()
                .context("Failed to extract output tensor")?;
            output_array
                .as_slice()
                .context("Failed to get tensor data as slice")?
                .to_vec()
        };
        // outputs and session borrow are dropped here

        // 4. Process YOLO output to detections
        let detections = self.process_yolo_output(&output_vec, msg.width, msg.height)?;

        let inference_time_ms = start.elapsed().as_secs_f32() * 1000.0;

        Ok(DetectionResult {
            frame_id: self.frame_count,
            timestamp: msg.timestamp,
            detections,
            inference_time_ms,
            model_name: self.config.model_name.clone(),
        })
    }

    /// Decode camera image from various formats (static version for async)
    fn decode_camera_image_static(msg: &CameraImage) -> Result<image::DynamicImage> {
        let bytes = &msg.image_bytes;

        let image = match msg.format.as_str() {
            "jpeg" | "jpg" => image::load_from_memory_with_format(bytes, image::ImageFormat::Jpeg)?,
            "png" => image::load_from_memory_with_format(bytes, image::ImageFormat::Png)?,
            "rgb" | "rgb8" => {
                let img = image::RgbImage::from_raw(msg.width, msg.height, bytes.clone())
                    .context("Failed to create RGB image from raw bytes")?;
                image::DynamicImage::ImageRgb8(img)
            }
            _ => anyhow::bail!("Unsupported image format: {}", msg.format),
        };

        Ok(image)
    }

    /// Decode camera image from various formats (instance method for compatibility)
    #[allow(dead_code)]
    fn decode_camera_image(&self, msg: &CameraImage) -> Result<image::DynamicImage> {
        Self::decode_camera_image_static(msg)
    }

    /// Preprocess image for YOLO (static version for async)
    fn preprocess_yolo_image_static(image: image::DynamicImage, size: u32) -> Result<ndarray::Array4<f32>> {
        // Resize to square using Triangle filter (much faster than Lanczos3, minimal quality loss for object detection)
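        // Note: resize_exact squashes the aspect ratio rather than letterboxing
        // as stock YOLOv8 preprocessing does. Normalized boxes still map back
        // onto the original frame, but accuracy may dip slightly on strongly
        // non-square inputs.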
        let resized = image.resize_exact(size, size, image::imageops::FilterType::Triangle);
        let rgb = resized.to_rgb8();

        // Create tensor [1, 3, H, W] with normalization [0, 1]
        let mut array = ndarray::Array4::<f32>::zeros((1, 3, size as usize, size as usize));

        for (x, y, pixel) in rgb.enumerate_pixels() {
            array[[0, 0, y as usize, x as usize]] = pixel[0] as f32 / 255.0;
            array[[0, 1, y as usize, x as usize]] = pixel[1] as f32 / 255.0;
            array[[0, 2, y as usize, x as usize]] = pixel[2] as f32 / 255.0;
        }

        Ok(array)
    }

    /// Preprocess image for YOLO (instance method for compatibility)
    #[allow(dead_code)]
    fn preprocess_yolo_image(&self, image: image::DynamicImage) -> Result<ndarray::Array4<f32>> {
        Self::preprocess_yolo_image_static(image, self.config.input_size)
    }

    /// Process YOLO output tensor to detection list
    fn process_yolo_output(&self, output_slice: &[f32], _img_width: u32, _img_height: u32) -> Result<Vec<Detection>> {
        // YOLOv8 output format: [1, 84, 8400]
        // 84 = 4 bbox coords + 80 class scores
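        // Layout note: the flattened output is attribute-major, so
        // value(attr, anchor) = output_slice[attr * num_detections + anchor].
        // The 8400 anchor count assumes the standard 640x640 input
        // (80^2 + 40^2 + 20^2 grid cells at strides 8/16/32); other input
        // sizes would change it.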
        let num_detections = 8400;
        let num_classes = 80;

        let mut raw_detections = Vec::new();
        let mut max_score_seen = 0.0f32;

        for i in 0..num_detections {
            // Get bbox coordinates (center_x, center_y, width, height) in input-image pixels;
            // they are normalized to 0-1 below once a detection passes the threshold
            let cx = output_slice[i];
            let cy = output_slice[num_detections + i];
            let w = output_slice[2 * num_detections + i];
            let h = output_slice[3 * num_detections + i];

            // Find best class and confidence (optimized with iterator)
            let (max_class_idx, max_score) = (0..num_classes)
                .map(|c| {
                    let score = output_slice[(4 + c) * num_detections + i];
                    (c, score)
                })
                .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
                .unwrap_or((0, 0.0));

            // Track maximum score seen across all anchors for debugging
            if max_score > max_score_seen {
                max_score_seen = max_score;
            }

            if max_score >= self.config.confidence_threshold {
                // Convert from center format to corner format
                // Normalize to 0-1 based on input size
                let input_size = self.config.input_size as f32;
                let x = (cx - w / 2.0) / input_size;
                let y = (cy - h / 2.0) / input_size;
                let width = w / input_size;
                let height = h / input_size;

                raw_detections.push(Detection {
                    class_id: max_class_idx as u32,
                    class_name: self
                        .class_names
                        .get(max_class_idx)
                        .cloned()
                        .unwrap_or_else(|| format!("class_{}", max_class_idx)),
                    confidence: max_score,
                    bbox: BoundingBox::new(x, y, width, height),
                });
            }
        }

        // Apply Non-Maximum Suppression
        let filtered = self.non_max_suppression(raw_detections);

        // Debug logging
        if filtered.is_empty() {
            tracing::debug!(
                "No detections above threshold {:.2}. Max score seen: {:.4}",
                self.config.confidence_threshold,
                max_score_seen
            );
        }

        Ok(filtered)
    }

    /// Apply Non-Maximum Suppression to filter overlapping detections
    fn non_max_suppression(&self, mut detections: Vec<Detection>) -> Vec<Detection> {
        // Sort by confidence (descending); treat NaN as equal instead of
        // panicking, matching the comparison used in process_yolo_output
        detections.sort_by(|a, b| {
            b.confidence
                .partial_cmp(&a.confidence)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        let mut keep = Vec::new();
        while !detections.is_empty() {
            let current = detections.remove(0);
            keep.push(current.clone());

            detections.retain(|det| {
                // Keep if different class or IoU below threshold
                det.class_id != current.class_id || det.bbox.iou(&current.bbox) < self.config.iou_threshold
            });
        }

        keep
    }

    /// Get COCO class names (80 classes for COCO dataset)
    fn get_coco_class_names() -> Vec<String> {
        vec![
            "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",
            "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
            "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra",
            "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
            "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove",
            "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup",
            "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
            "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
            "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse",
            "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
            "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
            "hair drier", "toothbrush",
        ]
        .iter()
        .map(|s| s.to_string())
        .collect()
    }
}

// === Node Entry Point ===

pub async fn run() -> Result<()> {
    info!("🤖 Starting Object Detector Node");

    // Create context
    let ctx = Context::new("object-detector").await?;

    // Start automatic health reporting (reports healthy every 5 seconds)
    ctx.start_health_reporting(|| async { HealthStatus::healthy() }).await?;

    // Load config
    let config: ObjectDetectorConfig = ctx.load_node_config("object-detector").await?;
    info!("Configuration: {:?}", config);

    // Load ONNX model
    let model_path = PathBuf::from(&config.model_path);
    if !model_path.exists() {
        anyhow::bail!(
            "Model not found at {}. Please download a YOLOv8 ONNX model.",
            model_path.display()
        );
    }

    info!("📦 Loading model from: {}", model_path.display());

    // INT8 quantization support - use quantized model if enabled
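    // (e.g. an illustrative "yolov8n.onnx" would pair with "yolov8n-int8.onnx"
    // in the same directory, per the file-name derivation below)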
    let final_model_path = if config.use_int8 {
        let int8_path = model_path.with_file_name(
            model_path
                .file_stem()
                .context("Model path has no file name")?
                .to_string_lossy()
                .to_string()
                + "-int8.onnx",
        );

        if int8_path.exists() {
            info!("🔢 Using INT8 quantized model for 2x speedup");
            int8_path
        } else {
            info!(
                "⚠️  INT8 enabled but quantized model not found at {}",
                int8_path.display()
            );
            info!("   Falling back to FP32 model. Run conversion script to create INT8 model.");
            model_path
        }
    } else {
        model_path
    };

    // Try to use hardware acceleration (CoreML on macOS, CUDA on Linux with GPU)
    let mut session_builder = Session::builder()?;

    // Configure for INT8 if enabled
    if config.use_int8 {
        session_builder = session_builder
            .with_optimization_level(GraphOptimizationLevel::Level3)?
            .with_intra_threads(4)?; // Parallel INT8 ops
    }

    let session = session_builder
        .with_execution_providers([
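            // Providers are tried in declaration order; if one can't be
            // registered at runtime, ort falls back to the next entry
            // (assumption based on ort's documented fallback behavior)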
            // CoreML for macOS (Apple Silicon or Intel with Neural Engine)
            #[cfg(target_os = "macos")]
            ort::execution_providers::CoreMLExecutionProvider::default().build(),
            // CUDA for NVIDIA GPUs (Linux/Windows)
            #[cfg(not(target_os = "macos"))]
            ort::execution_providers::CUDAExecutionProvider::default().build(),
            // Fallback to CPU if hardware acceleration unavailable
            ort::execution_providers::CPUExecutionProvider::default().build(),
        ])?
        .commit_from_file(&final_model_path)?;

    info!(
        "✅ Model loaded successfully ({}) with hardware acceleration",
        if config.use_int8 { "INT8" } else { "FP32" }
    );

    // Create node
    let mut node = ObjectDetectorNode::new(config.clone());
    node.session = Some(session);

    // Setup topics
    let input_topic = config.input_topic();
    let output_topic = config.output_topic();
    let control_topic = config.control_topic();

    info!("📡 Input: {}", input_topic);
    info!("📡 Control: {}", control_topic);
    info!("📤 Output: {}", output_topic);
    info!(
        "⏸️  Starting in {} mode",
        if config.default_enabled { "ACTIVE" } else { "IDLE" }
    );

    if config.default_enabled {
        node.mode = InferenceMode::Active;
    }

    // Create topic references (leak strings for 'static lifetime)
    let input_topic_str: &'static str = Box::leak(input_topic.into_boxed_str());
    let control_topic_str: &'static str = Box::leak(control_topic.into_boxed_str());
    let output_topic_str: &'static str = Box::leak(output_topic.into_boxed_str());

    let camera_topic = Topic::<mecha10_core::messages::RedisMessage<CameraImage>>::new(input_topic_str);
    let control_topic = Topic::<InferenceCommand>::new(control_topic_str);
    let output_topic = Topic::<DetectionResult>::new(output_topic_str);

    // Subscribe
    let mut camera_receiver = ctx.subscribe(camera_topic).await?;
    let mut control_receiver = ctx.subscribe(control_topic).await?;

    info!("✅ Subscribed to camera and control topics");
    info!("🔄 Entering main processing loop with frame dropping (latest-only)");
    info!("🕐 System time at start: {} μs", now_micros());

    // Main loop
    loop {
        tokio::select! {
            // Handle control commands
            Some(control_msg) = control_receiver.recv() => {
                node.handle_control(control_msg);
            }

            // Handle camera frames - with aggressive frame dropping for low latency
            Some(camera_envelope) = camera_receiver.recv() => {
                // CRITICAL: Drop all queued frames to ensure we process only the latest
                // This prevents processing stale frames during long inference times
                let mut latest_envelope = camera_envelope;
                let mut dropped_count = 0;

                // Drain channel to get the absolute latest frame
                while let Ok(newer_envelope) = camera_receiver.try_recv() {
                    latest_envelope = newer_envelope;
                    dropped_count += 1;
                }

                if dropped_count > 0 {
                    tracing::debug!("📉 Dropped {} old frames, processing latest", dropped_count);
                }

                // Check frame age for diagnostics (but don't reject - timestamps may use different epochs)
                let current_time = now_micros();
                let frame_timestamp = latest_envelope.timestamp;

                // Debug: Log timestamps on first few frames to diagnose epoch issues
                static FRAME_DEBUG_COUNT: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
                let debug_count = FRAME_DEBUG_COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                if debug_count < 3 {
                    tracing::info!(
                        "🕐 Timestamp debug: system={} μs, frame={} μs, diff={} μs",
                        current_time,
                        frame_timestamp,
                        current_time.saturating_sub(frame_timestamp)
                    );
                }

                // Only log frame age if it seems reasonable (< 10 seconds old)
                // This handles cases where Godot and the system use different time bases
                if current_time > frame_timestamp {
                    let frame_age_us = current_time - frame_timestamp;
                    let frame_age_ms = frame_age_us as f64 / 1000.0;

                    if frame_age_ms > 1000.0 && frame_age_ms < 10_000.0 {
                        // Frame is 1-10 seconds old - log warning but still process
                        tracing::warn!(
                            "⏰ Processing old frame ({:.1}ms old) - possible system lag",
                            frame_age_ms
                        );
                    } else if frame_age_ms > 100.0 && frame_age_ms < 1000.0 {
                        // Frame is 100ms-1s old - debug log only
                        tracing::debug!("⏱️  Frame age: {:.1}ms", frame_age_ms);
                    }
                    // If > 10s, likely timestamp epoch mismatch - ignore
                }

                let camera_msg = latest_envelope.payload;

                if node.mode == InferenceMode::Active {
                    node.frame_count += 1;

                    // Process every Nth frame based on config
                    // (frame_skip is clamped to at least 1 to avoid a modulo-by-zero panic)
                    if node.frame_count % (config.frame_skip.max(1) as u64) == 0 {
                        match node.detect_objects(&camera_msg).await {
                            Ok(result) => {
                                info!(
                                    "🎯 Detected {} objects in {:.1}ms",
                                    result.detections.len(),
                                    result.inference_time_ms
                                );

                                // Log detected classes
                                for det in &result.detections {
                                    info!(
                                        "  - {} ({:.1}%)",
                                        det.class_name,
                                        det.confidence * 100.0
                                    );
                                }

                                // Publish detections
                                ctx.publish_to(output_topic, &result).await?;
                            }
                            Err(e) => {
                                tracing::warn!("❌ Detection failed: {}", e);
                            }
                        }
                    }
                }
            }
        }
    }
}
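
// Sanity test for the IoU math above (illustrative values, worked by hand:
// intersection = 0.2 * 0.4 = 0.08, union = 0.16 + 0.16 - 0.08 = 0.24).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn iou_partial_overlap() {
        // Two 0.4 x 0.4 boxes overlapping by half their width -> IoU = 1/3
        let a = BoundingBox::new(0.0, 0.0, 0.4, 0.4);
        let b = BoundingBox::new(0.2, 0.0, 0.4, 0.4);
        assert!((a.iou(&b) - 1.0 / 3.0).abs() < 1e-4);
    }
}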