// yscv_detect/yolo.rs
//! YOLOv8 postprocessing pipeline.
//!
//! Decodes raw YOLOv8 model output tensors into [`Detection`]s. The actual
//! model inference is left to the caller (e.g. via `yscv-onnx`); this module
//! handles the coordinate decoding, confidence filtering, and NMS step.

use yscv_tensor::Tensor;

use crate::{BoundingBox, Detection, non_max_suppression};

/// YOLO model configuration.
///
/// Shared by the YOLOv8 and YOLOv11 decode paths; [`yolov8_coco_config`]
/// builds the standard COCO defaults.
#[derive(Debug, Clone, PartialEq)]
pub struct YoloConfig {
    /// Input image size in pixels (model input is square:
    /// `input_size x input_size`).
    pub input_size: usize,
    /// Number of object classes the model was trained on.
    pub num_classes: usize,
    /// Minimum class score for a prediction to survive filtering.
    pub conf_threshold: f32,
    /// IoU threshold used by non-max suppression.
    pub iou_threshold: f32,
    /// Human-readable class labels, indexed by class id.
    pub class_labels: Vec<String>,
}

/// Returns the 80 COCO class labels, indexed by class id.
#[rustfmt::skip]
pub fn coco_labels() -> Vec<String> {
    // Canonical COCO ordering; class id == index into this table.
    const LABELS: [&str; 80] = [
        "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
        "truck", "boat", "traffic light", "fire hydrant", "stop sign",
        "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep",
        "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella",
        "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
        "sports ball", "kite", "baseball bat", "baseball glove", "skateboard",
        "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork",
        "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
        "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair",
        "couch", "potted plant", "bed", "dining table", "toilet", "tv",
        "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave",
        "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase",
        "scissors", "teddy bear", "hair drier", "toothbrush",
    ];
    LABELS.iter().map(|&s| s.to_string()).collect()
}

49/// Returns the default YOLOv8 config for the COCO 80-class dataset.
50pub fn yolov8_coco_config() -> YoloConfig {
51    YoloConfig {
52        input_size: 640,
53        num_classes: 80,
54        conf_threshold: 0.25,
55        iou_threshold: 0.45,
56        class_labels: coco_labels(),
57    }
58}
59
60/// Decode YOLOv8 raw output tensor into detections.
61///
62/// YOLOv8 output format: `[1, 4 + num_classes, num_preds]` where the first
63/// four rows are `(cx, cy, w, h)` and the remaining rows are per-class
64/// confidence scores. A typical COCO model emits `[1, 84, 8400]`.
65///
66/// Coordinates in the output are relative to the **letterboxed** input image
67/// (i.e. `input_size x input_size`). This function maps them back to the
68/// original `(orig_width, orig_height)` frame.
69///
70/// Returns filtered detections after confidence thresholding and NMS.
71pub fn decode_yolov8_output(
72    output: &Tensor,
73    config: &YoloConfig,
74    orig_width: usize,
75    orig_height: usize,
76) -> Vec<Detection> {
77    let shape = output.shape();
78    // Expect [1, 4+num_classes, num_preds]
79    if shape.len() != 3 || shape[0] != 1 {
80        return Vec::new();
81    }
82    let rows = shape[1]; // 4 + num_classes
83    let num_preds = shape[2];
84    if rows < 5 {
85        return Vec::new();
86    }
87    let num_classes = rows - 4;
88
89    let data = output.data();
90
91    // Compute letterbox scale and padding so we can map coords back.
92    let scale = (config.input_size as f32 / orig_width as f32)
93        .min(config.input_size as f32 / orig_height as f32);
94    let new_w = orig_width as f32 * scale;
95    let new_h = orig_height as f32 * scale;
96    let pad_x = (config.input_size as f32 - new_w) / 2.0;
97    let pad_y = (config.input_size as f32 - new_h) / 2.0;
98
99    let mut candidates = Vec::new();
100
101    for i in 0..num_preds {
102        // Output is laid out row-major: data[row * num_preds + col]
103        let cx = data[i];
104        let cy = data[num_preds + i];
105        let w = data[2 * num_preds + i];
106        let h = data[3 * num_preds + i];
107
108        // Find best class
109        let mut best_score = f32::NEG_INFINITY;
110        let mut best_class = 0usize;
111        for c in 0..num_classes {
112            let s = data[(4 + c) * num_preds + i];
113            if s > best_score {
114                best_score = s;
115                best_class = c;
116            }
117        }
118
119        if best_score < config.conf_threshold {
120            continue;
121        }
122
123        // Convert from letterbox coordinates to original image coordinates.
124        let x1 = ((cx - w / 2.0) - pad_x) / scale;
125        let y1 = ((cy - h / 2.0) - pad_y) / scale;
126        let x2 = ((cx + w / 2.0) - pad_x) / scale;
127        let y2 = ((cy + h / 2.0) - pad_y) / scale;
128
129        // Clamp to image bounds.
130        let x1 = x1.max(0.0).min(orig_width as f32);
131        let y1 = y1.max(0.0).min(orig_height as f32);
132        let x2 = x2.max(0.0).min(orig_width as f32);
133        let y2 = y2.max(0.0).min(orig_height as f32);
134
135        candidates.push(Detection {
136            bbox: BoundingBox { x1, y1, x2, y2 },
137            score: best_score,
138            class_id: best_class,
139        });
140    }
141
142    non_max_suppression(&candidates, config.iou_threshold, candidates.len().max(1))
143}
144
/// Default COCO config for YOLOv11 models.
///
/// Same as YOLOv8 COCO — 80 classes, 640 input, 0.25 conf, 0.45 IoU.
/// The two model families share identical defaults here, so this simply
/// delegates to [`yolov8_coco_config`].
pub fn yolov11_coco_config() -> YoloConfig {
    yolov8_coco_config()
}

152/// Decode YOLOv11 output tensor into detections.
153///
154/// YOLOv11 output shape: `[1, num_preds, 4 + num_classes]` (transposed vs YOLOv8).
155/// Each row is `[cx, cy, w, h, class_0_score, class_1_score, ...]`.
156/// Coordinates are in letterboxed image space; this function maps them back
157/// to the original image dimensions.
158pub fn decode_yolov11_output(
159    output: &Tensor,
160    config: &YoloConfig,
161    orig_width: usize,
162    orig_height: usize,
163) -> Vec<Detection> {
164    let shape = output.shape();
165    // YOLOv11: [1, N, 4+C] or [N, 4+C]
166    let (num_preds, cols) = if shape.len() == 3 {
167        (shape[1], shape[2])
168    } else if shape.len() == 2 {
169        (shape[0], shape[1])
170    } else {
171        return Vec::new();
172    };
173
174    if cols < 5 {
175        return Vec::new();
176    }
177    let num_classes = cols - 4;
178
179    let data = output.data();
180
181    let scale = (config.input_size as f32 / orig_width as f32)
182        .min(config.input_size as f32 / orig_height as f32);
183    let new_w = orig_width as f32 * scale;
184    let new_h = orig_height as f32 * scale;
185    let pad_x = (config.input_size as f32 - new_w) / 2.0;
186    let pad_y = (config.input_size as f32 - new_h) / 2.0;
187
188    let mut candidates = Vec::new();
189
190    // Skip batch dimension offset if present
191    let base = if shape.len() == 3 { 0 } else { 0 };
192
193    for i in 0..num_preds {
194        let row = base + i * cols;
195        let cx = data[row];
196        let cy = data[row + 1];
197        let w = data[row + 2];
198        let h = data[row + 3];
199
200        let mut best_score = f32::NEG_INFINITY;
201        let mut best_class = 0usize;
202        for c in 0..num_classes {
203            let s = data[row + 4 + c];
204            if s > best_score {
205                best_score = s;
206                best_class = c;
207            }
208        }
209
210        if best_score < config.conf_threshold {
211            continue;
212        }
213
214        let x1 = ((cx - w / 2.0) - pad_x) / scale;
215        let y1 = ((cy - h / 2.0) - pad_y) / scale;
216        let x2 = ((cx + w / 2.0) - pad_x) / scale;
217        let y2 = ((cy + h / 2.0) - pad_y) / scale;
218
219        let x1 = x1.max(0.0).min(orig_width as f32);
220        let y1 = y1.max(0.0).min(orig_height as f32);
221        let x2 = x2.max(0.0).min(orig_width as f32);
222        let y2 = y2.max(0.0).min(orig_height as f32);
223
224        candidates.push(Detection {
225            bbox: BoundingBox { x1, y1, x2, y2 },
226            score: best_score,
227            class_id: best_class,
228        });
229    }
230
231    non_max_suppression(&candidates, config.iou_threshold, candidates.len().max(1))
232}
233
234/// Apply letterbox preprocessing: resize an image to a square with padding.
235///
236/// The input `image` is an `[H, W, 3]` f32 tensor (RGB, normalised to 0..1).
237/// Returns `(padded, scale, pad_x, pad_y)` where `padded` has shape
238/// `[target_size, target_size, 3]`.
239pub fn letterbox_preprocess(image: &Tensor, target_size: usize) -> (Tensor, f32, f32, f32) {
240    let shape = image.shape();
241    assert!(
242        shape.len() == 3 && shape[2] == 3,
243        "expected [H, W, 3] tensor"
244    );
245    let src_h = shape[0];
246    let src_w = shape[1];
247    let data = image.data();
248
249    let scale = (target_size as f32 / src_w as f32).min(target_size as f32 / src_h as f32);
250    let new_w = (src_w as f32 * scale).round() as usize;
251    let new_h = (src_h as f32 * scale).round() as usize;
252    let pad_x = (target_size - new_w) as f32 / 2.0;
253    let pad_y = (target_size - new_h) as f32 / 2.0;
254    let pad_left = pad_x.floor() as usize;
255    let pad_top = pad_y.floor() as usize;
256
257    // Fill with 0.5 grey (common YOLO letterbox fill).
258    let total = target_size * target_size * 3;
259    let mut out = vec![0.5f32; total];
260
261    // Nearest-neighbour resize into the padded region.
262    let scale_x = src_w as f32 / new_w as f32;
263    let scale_y = src_h as f32 / new_h as f32;
264
265    for y in 0..new_h {
266        let src_y = ((y as f32 * scale_y) as usize).min(src_h - 1);
267        for x in 0..new_w {
268            let src_x = ((x as f32 * scale_x) as usize).min(src_w - 1);
269            let dst_idx = ((pad_top + y) * target_size + (pad_left + x)) * 3;
270            let src_idx = (src_y * src_w + src_x) * 3;
271            out[dst_idx] = data[src_idx];
272            out[dst_idx + 1] = data[src_idx + 1];
273            out[dst_idx + 2] = data[src_idx + 2];
274        }
275    }
276
277    let tensor = Tensor::from_vec(vec![target_size, target_size, 3], out)
278        .expect("letterbox output tensor creation");
279    (tensor, scale, pad_x, pad_y)
280}
281
282/// Convert an `[H, W, 3]` HWC tensor to `[1, 3, H, W]` NCHW f32 data.
283///
284/// This is a pure layout transformation — no normalisation is applied
285/// (the input is assumed to already be in `[0, 1]`).
286#[allow(dead_code)]
287fn hwc_to_nchw(hwc: &Tensor) -> Vec<f32> {
288    let shape = hwc.shape();
289    let h = shape[0];
290    let w = shape[1];
291    let data = hwc.data();
292    let mut nchw = vec![0.0f32; 3 * h * w];
293    for y in 0..h {
294        for x in 0..w {
295            let src = (y * w + x) * 3;
296            for c in 0..3 {
297                nchw[c * h * w + y * w + x] = data[src + c];
298            }
299        }
300    }
301    nchw
302}
303
/// Run YOLOv8 inference using an ONNX model.
///
/// Takes an ONNX model, input image data (RGB, normalised to `[0,1]`) in
/// `[1, 3, H, W]` NCHW format, original image dimensions, and a
/// [`YoloConfig`].  Returns detected objects after NMS.
///
/// `image_data` must already be sized to
/// `config.input_size x config.input_size` (the input tensor below is built
/// with that shape); `img_height`/`img_width` are the ORIGINAL image
/// dimensions used to map boxes back.
#[cfg(feature = "onnx")]
pub fn detect_yolov8_onnx(
    model: &yscv_onnx::OnnxModel,
    image_data: &[f32],
    img_height: usize,
    img_width: usize,
    config: &YoloConfig,
) -> Result<Vec<Detection>, crate::DetectError> {
    use std::collections::HashMap;

    // Use the model's declared input name, falling back to "images".
    let input_name = model
        .inputs
        .first()
        .cloned()
        .unwrap_or_else(|| "images".to_string());

    // Build the NCHW input tensor; from_vec errors propagate via `?`
    // (presumably when image_data length mismatches the shape — confirm).
    let tensor = Tensor::from_vec(
        vec![1, 3, config.input_size, config.input_size],
        image_data.to_vec(),
    )?;

    let mut inputs = HashMap::new();
    inputs.insert(input_name, tensor);

    let outputs = yscv_onnx::run_onnx_model(model, inputs)?;

    // Use the model's declared output name, falling back to "output0".
    let output_name = model
        .outputs
        .first()
        .cloned()
        .unwrap_or_else(|| "output0".to_string());

    // NOTE(review): MissingInput is reused here to report a missing model
    // *output*; a dedicated error variant might be clearer — confirm intent.
    let output_tensor =
        outputs
            .get(&output_name)
            .ok_or_else(|| yscv_onnx::OnnxError::MissingInput {
                node: "model_output".to_string(),
                input: output_name,
            })?;

    // decode_yolov8_output takes (width, height) — note the swapped order
    // relative to this function's (height, width) parameters.
    Ok(decode_yolov8_output(
        output_tensor,
        config,
        img_width,
        img_height,
    ))
}

/// Run the full YOLOv8 detection pipeline on an HWC image.
///
/// Accepts raw `[H, W, 3]` RGB f32 pixel data (normalised to `[0,1]`),
/// applies letterbox preprocessing, runs ONNX inference, decodes the output,
/// and returns the final detections.
#[cfg(feature = "onnx")]
pub fn detect_yolov8_from_rgb(
    model: &yscv_onnx::OnnxModel,
    rgb_data: &[f32],
    height: usize,
    width: usize,
    config: &YoloConfig,
) -> Result<Vec<Detection>, crate::DetectError> {
    // Wrap the raw pixels in an [H, W, 3] tensor (from_vec presumably
    // validates the buffer length against the shape — confirm).
    let image = Tensor::from_vec(vec![height, width, 3], rgb_data.to_vec())?;
    // Letterbox to the square model input. The scale/pad values are
    // recomputed inside decode_yolov8_output, so they are unused here.
    let (letterboxed, _scale, _pad_x, _pad_y) = letterbox_preprocess(&image, config.input_size);

    // Reorder [H, W, 3] -> [1, 3, H, W] for the ONNX input.
    let nchw = hwc_to_nchw(&letterboxed);

    // Pass the ORIGINAL dimensions so boxes are mapped back correctly.
    detect_yolov8_onnx(model, &nchw, height, width, config)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_coco_labels_count() {
        assert_eq!(coco_labels().len(), 80);
    }

    #[test]
    fn test_yolov8_coco_config_defaults() {
        let cfg = yolov8_coco_config();
        assert_eq!(cfg.input_size, 640);
        assert_eq!(cfg.num_classes, 80);
        assert!((cfg.conf_threshold - 0.25).abs() < 1e-6);
        assert!((cfg.iou_threshold - 0.45).abs() < 1e-6);
        assert_eq!(cfg.class_labels.len(), 80);
    }

    /// Build a synthetic [1, 84, 8400] tensor with exactly one strong
    /// prediction at index 0, class 5 with score 0.9.
    fn make_one_detection_tensor() -> Tensor {
        let num_classes = 80;
        let rows = 4 + num_classes;
        let num_preds = 8400;
        let mut data = vec![0.0f32; rows * num_preds];

        // Prediction at index 0: centre (320, 320), size 100x100 in 640x640.
        // Layout is row-major [row * num_preds + pred], so the four box rows
        // for prediction 0 sit at multiples of num_preds.
        data[0] = 320.0; // cx
        data[num_preds] = 320.0; // cy
        data[2 * num_preds] = 100.0; // w
        data[3 * num_preds] = 100.0; // h

        // Class 5 has score 0.9; others stay at 0.
        data[(4 + 5) * num_preds] = 0.9;

        Tensor::from_vec(vec![1, rows, num_preds], data).unwrap()
    }

    #[test]
    fn test_decode_yolov8_output_basic() {
        let tensor = make_one_detection_tensor();
        let config = YoloConfig {
            input_size: 640,
            num_classes: 80,
            conf_threshold: 0.25,
            iou_threshold: 0.45,
            class_labels: coco_labels(),
        };

        // Original image is also 640x640 so no rescaling.
        let dets = decode_yolov8_output(&tensor, &config, 640, 640);
        assert_eq!(dets.len(), 1);
        assert_eq!(dets[0].class_id, 5);
        assert!((dets[0].score - 0.9).abs() < 1e-6);

        // Box should be (270, 270, 370, 370) in original coords
        // (centre 320 plus/minus half of the 100px size).
        let b = &dets[0].bbox;
        assert!((b.x1 - 270.0).abs() < 1.0);
        assert!((b.y1 - 270.0).abs() < 1.0);
        assert!((b.x2 - 370.0).abs() < 1.0);
        assert!((b.y2 - 370.0).abs() < 1.0);
    }

    #[test]
    fn test_decode_yolov8_output_confidence_filter() {
        let tensor = make_one_detection_tensor();
        let config = YoloConfig {
            input_size: 640,
            num_classes: 80,
            conf_threshold: 0.95, // higher than our 0.9 score
            iou_threshold: 0.45,
            class_labels: coco_labels(),
        };
        let dets = decode_yolov8_output(&tensor, &config, 640, 640);
        assert!(dets.is_empty());
    }

    #[test]
    fn test_decode_yolov8_output_nms() {
        let num_classes = 80;
        let rows = 4 + num_classes;
        let num_preds = 8400;
        let mut data = vec![0.0f32; rows * num_preds];

        // Two highly overlapping boxes, same class (class 0).
        // Box 0: centre (320, 320), 100x100, score 0.9
        data[0] = 320.0;
        data[num_preds] = 320.0;
        data[2 * num_preds] = 100.0;
        data[3 * num_preds] = 100.0;
        data[4 * num_preds] = 0.9;

        // Box 1: centre (325, 325), 100x100, score 0.8 (heavily overlapping)
        data[1] = 325.0;
        data[num_preds + 1] = 325.0;
        data[2 * num_preds + 1] = 100.0;
        data[3 * num_preds + 1] = 100.0;
        data[4 * num_preds + 1] = 0.8;

        let tensor = Tensor::from_vec(vec![1, rows, num_preds], data).unwrap();
        let config = YoloConfig {
            input_size: 640,
            num_classes: 80,
            conf_threshold: 0.25,
            iou_threshold: 0.45,
            class_labels: coco_labels(),
        };

        let dets = decode_yolov8_output(&tensor, &config, 640, 640);
        // NMS should suppress the lower-scoring duplicate.
        assert_eq!(dets.len(), 1);
        assert!((dets[0].score - 0.9).abs() < 1e-6);
    }

    #[test]
    fn test_letterbox_preprocess_square() {
        // 100x100 image → 640x640 should have no padding.
        let img = Tensor::from_vec(vec![100, 100, 3], vec![0.5; 100 * 100 * 3]).unwrap();
        let (out, scale, pad_x, pad_y) = letterbox_preprocess(&img, 640);
        assert_eq!(out.shape(), &[640, 640, 3]);
        assert!((scale - 6.4).abs() < 0.01);
        assert!(pad_x.abs() < 1.0);
        assert!(pad_y.abs() < 1.0);
    }

    #[test]
    fn test_hwc_to_nchw_basic() {
        // 2x2 RGB image
        let data = vec![
            0.1, 0.2, 0.3, // (0,0) R G B
            0.4, 0.5, 0.6, // (0,1)
            0.7, 0.8, 0.9, // (1,0)
            1.0, 0.0, 0.5, // (1,1)
        ];
        let img = Tensor::from_vec(vec![2, 2, 3], data).unwrap();
        let nchw = hwc_to_nchw(&img);
        // Expected layout: [R-plane, G-plane, B-plane], each 2x2
        assert_eq!(nchw.len(), 12);
        // R plane
        assert!((nchw[0] - 0.1).abs() < 1e-6); // (0,0)
        assert!((nchw[1] - 0.4).abs() < 1e-6); // (0,1)
        assert!((nchw[2] - 0.7).abs() < 1e-6); // (1,0)
        assert!((nchw[3] - 1.0).abs() < 1e-6); // (1,1)
        // G plane
        assert!((nchw[4] - 0.2).abs() < 1e-6);
        assert!((nchw[5] - 0.5).abs() < 1e-6);
        assert!((nchw[6] - 0.8).abs() < 1e-6);
        assert!((nchw[7] - 0.0).abs() < 1e-6);
        // B plane
        assert!((nchw[8] - 0.3).abs() < 1e-6);
        assert!((nchw[9] - 0.6).abs() < 1e-6);
        assert!((nchw[10] - 0.9).abs() < 1e-6);
        assert!((nchw[11] - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_letterbox_then_nchw_pipeline() {
        // Rectangular 100x200 (HxW) image through the full preprocess
        // pipeline: letterbox to square, then HWC -> NCHW.
        let img = Tensor::from_vec(vec![100, 200, 3], vec![0.4; 100 * 200 * 3]).unwrap();
        let (letterboxed, _scale, _pad_x, _pad_y) = letterbox_preprocess(&img, 640);
        assert_eq!(letterboxed.shape(), &[640, 640, 3]);
        let nchw = hwc_to_nchw(&letterboxed);
        assert_eq!(nchw.len(), 3 * 640 * 640);
    }

    #[test]
    fn test_letterbox_preprocess_landscape() {
        // 200x100 (WxH) image → scale limited by width: 640/200 = 3.2
        // new_w = 640, new_h = 320 → pad_y = (640-320)/2 = 160
        let img = Tensor::from_vec(vec![100, 200, 3], vec![0.4; 100 * 200 * 3]).unwrap();
        let (out, scale, pad_x, pad_y) = letterbox_preprocess(&img, 640);
        assert_eq!(out.shape(), &[640, 640, 3]);
        assert!((scale - 3.2).abs() < 0.01);
        assert!(pad_x.abs() < 1.0);
        assert!((pad_y - 160.0).abs() < 1.0);

        // Check that the padded (grey 0.5) region exists at top.
        let top_pixel = &out.data()[0..3];
        for &v in top_pixel {
            assert!((v - 0.5).abs() < 1e-6, "top padding should be 0.5 grey");
        }
    }
}