1mod config;
16
17pub use config::ObjectDetectorConfig;
18
19use anyhow::{Context as AnyhowContext, Result};
20use mecha10_core::health::HealthReportingExt;
21use mecha10_core::messages::{HealthStatus, Message};
22use mecha10_core::prelude::*;
23use mecha10_core::topics::Topic;
24use ort::session::builder::GraphOptimizationLevel;
25use ort::session::Session;
26use serde::{Deserialize, Serialize};
27use std::path::PathBuf;
28use std::time::Instant;
29use tracing::info;
30
/// A single camera frame as received on the input topic.
///
/// `image_bytes` holds the encoded or raw pixel payload, interpreted
/// according to `format` ("jpeg"/"jpg", "png", or raw "rgb"/"rgb8" —
/// see `decode_camera_image_static`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CameraImage {
    // Identifier of the source camera.
    pub camera_id: String,
    // Frame width in pixels (also the row stride for raw "rgb" payloads).
    pub width: u32,
    // Frame height in pixels.
    pub height: u32,
    // Capture timestamp — presumably microseconds since epoch, since frame
    // age elsewhere is computed against `now_micros()`; TODO confirm.
    pub timestamp: u64,
    // `serde_bytes` serializes this as one byte blob instead of a per-element
    // sequence, which matters for large frames.
    #[serde(with = "serde_bytes")]
    pub image_bytes: Vec<u8>,
    // Payload format tag: "jpeg", "jpg", "png", "rgb", or "rgb8".
    pub format: String,
}

impl Message for CameraImage {}
46
/// Axis-aligned bounding box with (x, y) as the top-left corner.
///
/// Boxes produced by this node are normalized to [0, 1] relative to the
/// model input (see `process_yolo_output`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BoundingBox {
    // Left edge.
    pub x: f32,
    // Top edge.
    pub y: f32,
    pub width: f32,
    pub height: f32,
}
55
56impl BoundingBox {
57 pub fn new(x: f32, y: f32, width: f32, height: f32) -> Self {
58 Self { x, y, width, height }
59 }
60
61 pub fn iou(&self, other: &BoundingBox) -> f32 {
63 let x1 = self.x.max(other.x);
64 let y1 = self.y.max(other.y);
65 let x2 = (self.x + self.width).min(other.x + other.width);
66 let y2 = (self.y + self.height).min(other.y + other.height);
67
68 if x2 < x1 || y2 < y1 {
69 return 0.0;
70 }
71
72 let intersection = (x2 - x1) * (y2 - y1);
73 let union = (self.width * self.height) + (other.width * other.height) - intersection;
74
75 intersection / union
76 }
77}
78
/// One detected object within a frame.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Detection {
    // COCO class index (0-79 for the bundled label list).
    pub class_id: u32,
    // Human-readable label looked up from the COCO class-name table.
    pub class_name: String,
    // Model confidence score for this detection.
    pub confidence: f32,
    // Normalized bounding box (see `BoundingBox`).
    pub bbox: BoundingBox,
}
87
/// Inference result for one processed frame, published on the output topic.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DetectionResult {
    // Running frame counter of the detector node (not the camera's frame id).
    pub frame_id: u64,
    // Timestamp copied from the source `CameraImage`.
    pub timestamp: u64,
    // All detections that survived thresholding and NMS.
    pub detections: Vec<Detection>,
    // Wall-clock time for decode + preprocess + inference + postprocess.
    pub inference_time_ms: f32,
    // Model name from configuration, for downstream attribution.
    pub model_name: String,
}

impl Message for DetectionResult {}
99
/// Runtime control message for the detector.
///
/// Recognized actions (see `handle_control`): "enable", "disable", and
/// "set_threshold" with `params = {"confidence": <f64>}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceCommand {
    pub action: String,
    // Optional action-specific parameters; omitted from JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<serde_json::Value>,
}

impl Message for InferenceCommand {}
109
/// Whether the node is currently running inference on incoming frames.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum InferenceMode {
    // Frames are received but not processed.
    Idle,
    // Frames are processed (subject to `frame_skip`).
    Active,
}
116
/// Stateful object detector: ONNX session plus runtime mode and counters.
pub struct ObjectDetectorNode {
    // Node configuration (thresholds, topics, model settings).
    config: ObjectDetectorConfig,
    // Loaded ONNX Runtime session; `None` until `run` attaches it.
    session: Option<Session>,
    // Frames counted while active; drives frame skipping and result ids.
    frame_count: u64,
    // Current enable/disable state, toggled via control commands.
    mode: InferenceMode,
    // COCO class labels, indexed by class id.
    class_names: Vec<String>,
    // Caps how many frames may be decoded/preprocessed concurrently.
    preprocessing_semaphore: std::sync::Arc<tokio::sync::Semaphore>,
}
128
129impl ObjectDetectorNode {
130 pub fn new(config: ObjectDetectorConfig) -> Self {
131 let max_concurrent = config.max_async_frames.max(1); let preprocessing_semaphore = std::sync::Arc::new(tokio::sync::Semaphore::new(max_concurrent));
135
136 Self {
137 config,
138 session: None,
139 frame_count: 0,
140 mode: InferenceMode::Idle,
141 class_names: Self::get_coco_class_names(),
142 preprocessing_semaphore,
143 }
144 }
145
146 fn handle_control(&mut self, cmd: InferenceCommand) {
148 info!("🎮 Received control command: {:?}", cmd.action);
149 match cmd.action.as_str() {
150 "enable" => {
151 self.mode = InferenceMode::Active;
152 info!("✅ Object detection enabled");
153 }
154 "disable" => {
155 self.mode = InferenceMode::Idle;
156 info!("⏸️ Object detection disabled");
157 }
158 "set_threshold" => {
159 if let Some(params) = &cmd.params {
160 if let Some(conf) = params.get("confidence").and_then(|v| v.as_f64()) {
161 self.config.confidence_threshold = conf as f32;
162 info!("🎯 Confidence threshold set to {:.2}", conf);
163 }
164 }
165 }
166 _ => {
167 tracing::warn!("❌ Unknown command: {}", cmd.action);
168 }
169 }
170 }
171
172 async fn detect_objects(&mut self, msg: &CameraImage) -> Result<DetectionResult> {
174 let start = Instant::now();
175
176 let permit = self.preprocessing_semaphore.clone().acquire_owned().await?;
180
181 let msg_clone = msg.clone();
182 let input_size = self.config.input_size;
183
184 let input_tensor = tokio::task::spawn_blocking(move || -> Result<ndarray::Array4<f32>> {
185 let image = Self::decode_camera_image_static(&msg_clone)?;
187
188 let result = Self::preprocess_yolo_image_static(image, input_size);
190
191 drop(permit);
193
194 result
195 })
196 .await??;
197
198 let output_vec: Vec<f32> = {
200 let input_value = ort::value::TensorRef::from_array_view(input_tensor.view())?;
201
202 let session = self.session.as_mut().context("ONNX session not initialized")?;
203 let outputs = session.run(ort::inputs![input_value])?;
204
205 let output_array: ndarray::ArrayViewD<f32> = outputs[0]
207 .try_extract_array()
208 .context("Failed to extract output tensor")?;
209 output_array
210 .as_slice()
211 .context("Failed to get tensor data as slice")?
212 .to_vec()
213 };
214 let detections = self.process_yolo_output(&output_vec, msg.width, msg.height)?;
218
219 let inference_time_ms = start.elapsed().as_secs_f32() * 1000.0;
220
221 Ok(DetectionResult {
222 frame_id: self.frame_count,
223 timestamp: msg.timestamp,
224 detections,
225 inference_time_ms,
226 model_name: self.config.model_name.clone(),
227 })
228 }
229
230 fn decode_camera_image_static(msg: &CameraImage) -> Result<image::DynamicImage> {
232 let bytes = &msg.image_bytes;
233
234 let image = match msg.format.as_str() {
235 "jpeg" | "jpg" => image::load_from_memory_with_format(bytes, image::ImageFormat::Jpeg)?,
236 "png" => image::load_from_memory_with_format(bytes, image::ImageFormat::Png)?,
237 "rgb" | "rgb8" => {
238 let img = image::RgbImage::from_raw(msg.width, msg.height, bytes.clone())
239 .context("Failed to create RGB image from raw bytes")?;
240 image::DynamicImage::ImageRgb8(img)
241 }
242 _ => anyhow::bail!("Unsupported image format: {}", msg.format),
243 };
244
245 Ok(image)
246 }
247
248 #[allow(dead_code)]
250 fn decode_camera_image(&self, msg: &CameraImage) -> Result<image::DynamicImage> {
251 Self::decode_camera_image_static(msg)
252 }
253
254 fn preprocess_yolo_image_static(image: image::DynamicImage, size: u32) -> Result<ndarray::Array4<f32>> {
256 let resized = image.resize_exact(size, size, image::imageops::FilterType::Triangle);
258 let rgb = resized.to_rgb8();
259
260 let mut array = ndarray::Array4::<f32>::zeros((1, 3, size as usize, size as usize));
262
263 for (x, y, pixel) in rgb.enumerate_pixels() {
264 array[[0, 0, y as usize, x as usize]] = pixel[0] as f32 / 255.0;
265 array[[0, 1, y as usize, x as usize]] = pixel[1] as f32 / 255.0;
266 array[[0, 2, y as usize, x as usize]] = pixel[2] as f32 / 255.0;
267 }
268
269 Ok(array)
270 }
271
272 #[allow(dead_code)]
274 fn preprocess_yolo_image(&self, image: image::DynamicImage) -> Result<ndarray::Array4<f32>> {
275 Self::preprocess_yolo_image_static(image, self.config.input_size)
276 }
277
278 fn process_yolo_output(&self, output_slice: &[f32], _img_width: u32, _img_height: u32) -> Result<Vec<Detection>> {
280 let num_detections = 8400;
283 let num_classes = 80;
284
285 let mut raw_detections = Vec::new();
286 let mut max_score_seen = 0.0f32;
287
288 for i in 0..num_detections {
289 let cx = output_slice[i];
291 let cy = output_slice[num_detections + i];
292 let w = output_slice[2 * num_detections + i];
293 let h = output_slice[3 * num_detections + i];
294
295 let (max_class_idx, max_score) = (0..num_classes)
297 .map(|c| {
298 let score = output_slice[(4 + c) * num_detections + i];
299 (c, score)
300 })
301 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
302 .unwrap_or((0, 0.0));
303
304 if max_score > max_score_seen {
306 max_score_seen = max_score;
307 }
308
309 if max_score >= self.config.confidence_threshold {
310 let input_size = self.config.input_size as f32;
313 let x = (cx - w / 2.0) / input_size;
314 let y = (cy - h / 2.0) / input_size;
315 let width = w / input_size;
316 let height = h / input_size;
317
318 raw_detections.push(Detection {
319 class_id: max_class_idx as u32,
320 class_name: self
321 .class_names
322 .get(max_class_idx)
323 .cloned()
324 .unwrap_or_else(|| format!("class_{}", max_class_idx)),
325 confidence: max_score,
326 bbox: BoundingBox::new(x, y, width, height),
327 });
328 }
329 }
330
331 let filtered = self.non_max_suppression(raw_detections);
333
334 if filtered.is_empty() {
336 tracing::debug!(
337 "No detections above threshold {:.2}. Max score seen: {:.4}",
338 self.config.confidence_threshold,
339 max_score_seen
340 );
341 }
342
343 Ok(filtered)
344 }
345
346 fn non_max_suppression(&self, mut detections: Vec<Detection>) -> Vec<Detection> {
348 detections.sort_by(|a, b| b.confidence.partial_cmp(&a.confidence).unwrap());
350
351 let mut keep = Vec::new();
352 while !detections.is_empty() {
353 let current = detections.remove(0);
354 keep.push(current.clone());
355
356 detections.retain(|det| {
357 det.class_id != current.class_id || det.bbox.iou(¤t.bbox) < self.config.iou_threshold
359 });
360 }
361
362 keep
363 }
364
365 fn get_coco_class_names() -> Vec<String> {
367 vec![
368 "person",
369 "bicycle",
370 "car",
371 "motorcycle",
372 "airplane",
373 "bus",
374 "train",
375 "truck",
376 "boat",
377 "traffic light",
378 "fire hydrant",
379 "stop sign",
380 "parking meter",
381 "bench",
382 "bird",
383 "cat",
384 "dog",
385 "horse",
386 "sheep",
387 "cow",
388 "elephant",
389 "bear",
390 "zebra",
391 "giraffe",
392 "backpack",
393 "umbrella",
394 "handbag",
395 "tie",
396 "suitcase",
397 "frisbee",
398 "skis",
399 "snowboard",
400 "sports ball",
401 "kite",
402 "baseball bat",
403 "baseball glove",
404 "skateboard",
405 "surfboard",
406 "tennis racket",
407 "bottle",
408 "wine glass",
409 "cup",
410 "fork",
411 "knife",
412 "spoon",
413 "bowl",
414 "banana",
415 "apple",
416 "sandwich",
417 "orange",
418 "broccoli",
419 "carrot",
420 "hot dog",
421 "pizza",
422 "donut",
423 "cake",
424 "chair",
425 "couch",
426 "potted plant",
427 "bed",
428 "dining table",
429 "toilet",
430 "tv",
431 "laptop",
432 "mouse",
433 "remote",
434 "keyboard",
435 "cell phone",
436 "microwave",
437 "oven",
438 "toaster",
439 "sink",
440 "refrigerator",
441 "book",
442 "clock",
443 "vase",
444 "scissors",
445 "teddy bear",
446 "hair drier",
447 "toothbrush",
448 ]
449 .iter()
450 .map(|s| s.to_string())
451 .collect()
452 }
453}
454
/// Node entry point: loads configuration and the ONNX model, subscribes to
/// the camera and control topics, then runs the inference loop forever
/// (only returning on error).
pub async fn run() -> Result<()> {
    info!("🤖 Starting Object Detector Node");

    let ctx = Context::new("object-detector").await?;

    // Periodic health reporting for the node supervisor.
    ctx.start_health_reporting(|| async { HealthStatus::healthy() }).await?;

    let config: ObjectDetectorConfig = ctx.load_node_config("object-detector").await?;
    info!("Configuration: {:?}", config);

    let model_path = PathBuf::from(&config.model_path);
    if !model_path.exists() {
        anyhow::bail!(
            "Model not found at {}. Please download a YOLOv8 ONNX model.",
            model_path.display()
        );
    }

    info!("📦 Loading model from: {}", model_path.display());

    // When INT8 is requested, prefer a "<stem>-int8.onnx" sibling of the
    // configured model; otherwise fall back to the FP32 file.
    let final_model_path = if config.use_int8 {
        let int8_path =
            model_path.with_file_name(model_path.file_stem().unwrap().to_string_lossy().to_string() + "-int8.onnx");

        if int8_path.exists() {
            info!("🔢 Using INT8 quantized model for 2x speedup");
            int8_path
        } else {
            info!(
                "⚠️ INT8 enabled but quantized model not found at {}",
                int8_path.display()
            );
            info!(" Falling back to FP32 model. Run conversion script to create INT8 model.");
            model_path
        }
    } else {
        model_path
    };

    let mut session_builder = Session::builder()?;

    // Extra graph optimization / threading only applied in the INT8 path.
    if config.use_int8 {
        session_builder = session_builder
            .with_optimization_level(GraphOptimizationLevel::Level3)?
            .with_intra_threads(4)?; }

    // Execution providers are tried in declaration order; CPU is the final
    // fallback. CoreML on macOS, CUDA elsewhere (both compile-time gated).
    let session = session_builder
        .with_execution_providers([
            #[cfg(target_os = "macos")]
            ort::execution_providers::CoreMLExecutionProvider::default().build(),
            #[cfg(not(target_os = "macos"))]
            ort::execution_providers::CUDAExecutionProvider::default().build(),
            ort::execution_providers::CPUExecutionProvider::default().build(),
        ])?
        .commit_from_file(&final_model_path)?;

    info!(
        "✅ Model loaded successfully ({}) with hardware acceleration",
        if config.use_int8 { "INT8" } else { "FP32" }
    );

    let mut node = ObjectDetectorNode::new(config.clone());
    node.session = Some(session);

    let input_topic = config.input_topic();
    let output_topic = config.output_topic();
    let control_topic = config.control_topic();

    info!("📡 Input: {}", input_topic);
    info!("📡 Control: {}", control_topic);
    info!("📤 Output: {}", output_topic);
    info!(
        "⏸️ Starting in {} mode",
        if config.default_enabled { "ACTIVE" } else { "IDLE" }
    );

    if config.default_enabled {
        node.mode = InferenceMode::Active;
    }

    // Topic names are leaked into 'static strings because the Topic API
    // requires &'static str. This leaks once at startup, for the process
    // lifetime — intentional, not a recurring leak.
    let input_topic_str: &'static str = Box::leak(input_topic.into_boxed_str());
    let control_topic_str: &'static str = Box::leak(control_topic.into_boxed_str());
    let output_topic_str: &'static str = Box::leak(output_topic.into_boxed_str());

    let camera_topic = Topic::<mecha10_core::messages::RedisMessage<CameraImage>>::new(input_topic_str);
    let control_topic = Topic::<InferenceCommand>::new(control_topic_str);
    let output_topic = Topic::<DetectionResult>::new(output_topic_str);

    let mut camera_receiver = ctx.subscribe(camera_topic).await?;
    let mut control_receiver = ctx.subscribe(control_topic).await?;

    info!("✅ Subscribed to camera and control topics");
    info!("🔄 Entering main processing loop with frame dropping (latest-only)");
    info!("🕐 System time at start: {} μs", now_micros());

    loop {
        tokio::select! {
            // Control commands are handled as they arrive.
            Some(control_msg) = control_receiver.recv() => {
                node.handle_control(control_msg);
            }

            Some(camera_envelope) = camera_receiver.recv() => {
                // Latest-only policy: drain everything queued and keep just
                // the newest frame, so inference never falls behind the camera.
                let mut latest_envelope = camera_envelope;
                let mut dropped_count = 0;

                while let Ok(newer_envelope) = camera_receiver.try_recv() {
                    latest_envelope = newer_envelope;
                    dropped_count += 1;
                }

                if dropped_count > 0 {
                    tracing::debug!("📉 Dropped {} old frames, processing latest", dropped_count);
                }

                let current_time = now_micros();
                let frame_timestamp = latest_envelope.timestamp;

                // One-time clock-skew diagnostics for the first few frames.
                static FRAME_DEBUG_COUNT: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
                let debug_count = FRAME_DEBUG_COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                if debug_count < 3 {
                    tracing::info!(
                        "🕐 Timestamp debug: system={} μs, frame={} μs, diff={} μs",
                        current_time,
                        frame_timestamp,
                        current_time.saturating_sub(frame_timestamp)
                    );
                }

                // Log stale frames (guard avoids underflow if clocks disagree).
                if current_time > frame_timestamp {
                    let frame_age_us = current_time - frame_timestamp;
                    let frame_age_ms = frame_age_us as f64 / 1000.0;

                    // Ages beyond 10s are ignored here — presumably treated
                    // as clock mismatch rather than real lag; TODO confirm.
                    if frame_age_ms > 1000.0 && frame_age_ms < 10_000.0 {
                        tracing::warn!(
                            "⏰ Processing old frame ({:.1}ms old) - possible system lag",
                            frame_age_ms
                        );
                    } else if frame_age_ms > 100.0 && frame_age_ms < 1000.0 {
                        tracing::debug!("⏱️ Frame age: {:.1}ms", frame_age_ms);
                    }
                }

                let camera_msg = latest_envelope.payload;

                if node.mode == InferenceMode::Active {
                    node.frame_count += 1;

                    // Only run inference every `frame_skip`-th frame.
                    if node.frame_count % config.frame_skip as u64 == 0 {
                        match node.detect_objects(&camera_msg).await {
                            Ok(result) => {
                                info!(
                                    "🎯 Detected {} objects in {:.1}ms",
                                    result.detections.len(),
                                    result.inference_time_ms
                                );

                                for det in &result.detections {
                                    info!(
                                        "  - {} ({:.1}%)",
                                        det.class_name,
                                        det.confidence * 100.0
                                    );
                                }

                                ctx.publish_to(output_topic, &result).await?;
                            }
                            // Detection failures are logged, not fatal — the
                            // loop keeps processing subsequent frames.
                            Err(e) => {
                                tracing::warn!("❌ Detection failed: {}", e);
                            }
                        }
                    }
                }
            }
        }
    }
}