// yscv_detect/src/yolo.rs
//! YOLOv8 postprocessing pipeline.
//!
//! Decodes raw YOLOv8 model output tensors into [`Detection`]s. The actual
//! model inference is left to the caller (e.g. via `yscv-onnx`); this module
//! handles the coordinate decoding, confidence filtering, and NMS step.

use yscv_tensor::Tensor;

use crate::{BoundingBox, Detection, non_max_suppression};

/// Configuration for a YOLO detection model.
#[derive(Debug, Clone)]
pub struct YoloConfig {
    /// Side length of the (square) network input, in pixels.
    pub input_size: usize,
    /// Number of object classes the model predicts.
    pub num_classes: usize,
    /// Minimum per-class confidence for a prediction to be kept.
    pub conf_threshold: f32,
    /// IoU threshold used during non-maximum suppression.
    pub iou_threshold: f32,
    /// Human-readable label for each class index.
    pub class_labels: Vec<String>,
}

/// Returns the 80 COCO class labels, in canonical dataset order.
#[rustfmt::skip]
pub fn coco_labels() -> Vec<String> {
    [
        "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
        "truck", "boat", "traffic light", "fire hydrant", "stop sign",
        "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep",
        "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella",
        "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
        "sports ball", "kite", "baseball bat", "baseball glove", "skateboard",
        "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork",
        "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
        "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair",
        "couch", "potted plant", "bed", "dining table", "toilet", "tv",
        "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave",
        "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase",
        "scissors", "teddy bear", "hair drier", "toothbrush",
    ]
    .into_iter()
    .map(String::from)
    .collect()
}

49/// Returns the default YOLOv8 config for the COCO 80-class dataset.
50pub fn yolov8_coco_config() -> YoloConfig {
51    YoloConfig {
52        input_size: 640,
53        num_classes: 80,
54        conf_threshold: 0.25,
55        iou_threshold: 0.45,
56        class_labels: coco_labels(),
57    }
58}
59
/// Decode YOLOv8 raw output tensor into detections.
///
/// YOLOv8 output format: `[1, 4 + num_classes, num_preds]` where the first
/// four rows are `(cx, cy, w, h)` and the remaining rows are per-class
/// confidence scores. A typical COCO model emits `[1, 84, 8400]`.
///
/// Coordinates in the output are relative to the **letterboxed** input image
/// (i.e. `input_size x input_size`). This function maps them back to the
/// original `(orig_width, orig_height)` frame.
///
/// Returns filtered detections after confidence thresholding and NMS.
/// A tensor whose shape is not `[1, rows, preds]` with `rows >= 5` yields
/// an empty vector rather than an error.
pub fn decode_yolov8_output(
    output: &Tensor,
    config: &YoloConfig,
    orig_width: usize,
    orig_height: usize,
) -> Vec<Detection> {
    let shape = output.shape();
    // Expect [1, 4+num_classes, num_preds]; bail out silently otherwise.
    if shape.len() != 3 || shape[0] != 1 {
        return Vec::new();
    }
    let rows = shape[1]; // 4 + num_classes
    let num_preds = shape[2];
    if rows < 5 {
        return Vec::new();
    }
    // Class count is derived from the tensor itself; `config.num_classes` is
    // deliberately not consulted here, so a mismatched config only affects
    // label lookup by the caller, not the decoding.
    let num_classes = rows - 4;

    let data = output.data();

    // Compute letterbox scale and padding so we can map coords back.
    // NOTE(review): `new_w`/`new_h` are kept fractional here, whereas
    // `letterbox_preprocess` rounds to whole pixels before padding — the two
    // pad values can differ by up to half a pixel. Likely harmless, but
    // confirm if exact pixel alignment matters.
    let scale = (config.input_size as f32 / orig_width as f32)
        .min(config.input_size as f32 / orig_height as f32);
    let new_w = orig_width as f32 * scale;
    let new_h = orig_height as f32 * scale;
    let pad_x = (config.input_size as f32 - new_w) / 2.0;
    let pad_y = (config.input_size as f32 - new_h) / 2.0;

    let mut candidates = Vec::new();

    for i in 0..num_preds {
        // Output is laid out row-major: data[row * num_preds + col]
        let cx = data[i];
        let cy = data[num_preds + i];
        let w = data[2 * num_preds + i];
        let h = data[3 * num_preds + i];

        // Find best class: arg-max over the per-class score rows. The class
        // score alone serves as the detection confidence (the tensor layout
        // above has no separate objectness row).
        let mut best_score = f32::NEG_INFINITY;
        let mut best_class = 0usize;
        for c in 0..num_classes {
            let s = data[(4 + c) * num_preds + i];
            if s > best_score {
                best_score = s;
                best_class = c;
            }
        }

        if best_score < config.conf_threshold {
            continue;
        }

        // Convert from letterbox coordinates to original image coordinates:
        // centre/size -> corners, subtract padding, then undo the resize scale.
        let x1 = ((cx - w / 2.0) - pad_x) / scale;
        let y1 = ((cy - h / 2.0) - pad_y) / scale;
        let x2 = ((cx + w / 2.0) - pad_x) / scale;
        let y2 = ((cy + h / 2.0) - pad_y) / scale;

        // Clamp to image bounds.
        let x1 = x1.max(0.0).min(orig_width as f32);
        let y1 = y1.max(0.0).min(orig_height as f32);
        let x2 = x2.max(0.0).min(orig_width as f32);
        let y2 = y2.max(0.0).min(orig_height as f32);

        candidates.push(Detection {
            bbox: BoundingBox { x1, y1, x2, y2 },
            score: best_score,
            class_id: best_class,
        });
    }

    // Keep everything that survives NMS; `.max(1)` presumably guards against
    // passing a zero max-detections argument — confirm against
    // `non_max_suppression`'s contract.
    non_max_suppression(&candidates, config.iou_threshold, candidates.len().max(1))
}

145/// Apply letterbox preprocessing: resize an image to a square with padding.
146///
147/// The input `image` is an `[H, W, 3]` f32 tensor (RGB, normalised to 0..1).
148/// Returns `(padded, scale, pad_x, pad_y)` where `padded` has shape
149/// `[target_size, target_size, 3]`.
150pub fn letterbox_preprocess(image: &Tensor, target_size: usize) -> (Tensor, f32, f32, f32) {
151    let shape = image.shape();
152    assert!(
153        shape.len() == 3 && shape[2] == 3,
154        "expected [H, W, 3] tensor"
155    );
156    let src_h = shape[0];
157    let src_w = shape[1];
158    let data = image.data();
159
160    let scale = (target_size as f32 / src_w as f32).min(target_size as f32 / src_h as f32);
161    let new_w = (src_w as f32 * scale).round() as usize;
162    let new_h = (src_h as f32 * scale).round() as usize;
163    let pad_x = (target_size - new_w) as f32 / 2.0;
164    let pad_y = (target_size - new_h) as f32 / 2.0;
165    let pad_left = pad_x.floor() as usize;
166    let pad_top = pad_y.floor() as usize;
167
168    // Fill with 0.5 grey (common YOLO letterbox fill).
169    let total = target_size * target_size * 3;
170    let mut out = vec![0.5f32; total];
171
172    // Nearest-neighbour resize into the padded region.
173    let scale_x = src_w as f32 / new_w as f32;
174    let scale_y = src_h as f32 / new_h as f32;
175
176    for y in 0..new_h {
177        let src_y = ((y as f32 * scale_y) as usize).min(src_h - 1);
178        for x in 0..new_w {
179            let src_x = ((x as f32 * scale_x) as usize).min(src_w - 1);
180            let dst_idx = ((pad_top + y) * target_size + (pad_left + x)) * 3;
181            let src_idx = (src_y * src_w + src_x) * 3;
182            out[dst_idx] = data[src_idx];
183            out[dst_idx + 1] = data[src_idx + 1];
184            out[dst_idx + 2] = data[src_idx + 2];
185        }
186    }
187
188    let tensor = Tensor::from_vec(vec![target_size, target_size, 3], out)
189        .expect("letterbox output tensor creation");
190    (tensor, scale, pad_x, pad_y)
191}
192
193/// Convert an `[H, W, 3]` HWC tensor to `[1, 3, H, W]` NCHW f32 data.
194///
195/// This is a pure layout transformation — no normalisation is applied
196/// (the input is assumed to already be in `[0, 1]`).
197#[allow(dead_code)]
198fn hwc_to_nchw(hwc: &Tensor) -> Vec<f32> {
199    let shape = hwc.shape();
200    let h = shape[0];
201    let w = shape[1];
202    let data = hwc.data();
203    let mut nchw = vec![0.0f32; 3 * h * w];
204    for y in 0..h {
205        for x in 0..w {
206            let src = (y * w + x) * 3;
207            for c in 0..3 {
208                nchw[c * h * w + y * w + x] = data[src + c];
209            }
210        }
211    }
212    nchw
213}
214
/// Run YOLOv8 inference using an ONNX model.
///
/// Takes an ONNX model, input image data (RGB, normalised to `[0,1]`) in
/// `[1, 3, H, W]` NCHW format, original image dimensions, and a
/// [`YoloConfig`].  Returns detected objects after NMS.
///
/// `image_data` must contain `3 * config.input_size * config.input_size`
/// values; `Tensor::from_vec` presumably rejects a mismatched length, with
/// the error propagated via `?` — confirm against `yscv_tensor`.
#[cfg(feature = "onnx")]
pub fn detect_yolov8_onnx(
    model: &yscv_onnx::OnnxModel,
    image_data: &[f32],
    img_height: usize,
    img_width: usize,
    config: &YoloConfig,
) -> Result<Vec<Detection>, crate::DetectError> {
    use std::collections::HashMap;

    // Use the model's declared first input name, falling back to "images"
    // (the name YOLOv8 ONNX exports conventionally use).
    let input_name = model
        .inputs
        .first()
        .cloned()
        .unwrap_or_else(|| "images".to_string());

    let tensor = Tensor::from_vec(
        vec![1, 3, config.input_size, config.input_size],
        image_data.to_vec(),
    )?;

    let mut inputs = HashMap::new();
    inputs.insert(input_name, tensor);

    let outputs = yscv_onnx::run_onnx_model(model, inputs)?;

    // Same fallback scheme for the output: "output0" is the conventional
    // YOLOv8 export name.
    let output_name = model
        .outputs
        .first()
        .cloned()
        .unwrap_or_else(|| "output0".to_string());

    // NOTE(review): a *missing output* is reported by reusing
    // `OnnxError::MissingInput` with a synthetic node name — slightly
    // misleading variant choice; consider a dedicated error if the enum
    // grows one.
    let output_tensor =
        outputs
            .get(&output_name)
            .ok_or_else(|| yscv_onnx::OnnxError::MissingInput {
                node: "model_output".to_string(),
                input: output_name,
            })?;

    // Decode expects (width, height) order, while this function's parameters
    // are (height, width) — the swap here is intentional.
    Ok(decode_yolov8_output(
        output_tensor,
        config,
        img_width,
        img_height,
    ))
}

/// Run the full YOLOv8 detection pipeline on an HWC image.
///
/// Accepts raw `[H, W, 3]` RGB f32 pixel data (normalised to `[0,1]`),
/// applies letterbox preprocessing, runs ONNX inference, decodes the output,
/// and returns the final detections.
#[cfg(feature = "onnx")]
pub fn detect_yolov8_from_rgb(
    model: &yscv_onnx::OnnxModel,
    rgb_data: &[f32],
    height: usize,
    width: usize,
    config: &YoloConfig,
) -> Result<Vec<Detection>, crate::DetectError> {
    // Wrap the raw pixels in an [H, W, 3] tensor for preprocessing.
    let hwc = Tensor::from_vec(vec![height, width, 3], rgb_data.to_vec())?;
    // Letterbox to the square network input. The scale/pad values are
    // recomputed inside the decoder, so they are discarded here.
    let (square, _, _, _) = letterbox_preprocess(&hwc, config.input_size);
    // Reorder to NCHW as expected by the ONNX model input.
    let nchw = hwc_to_nchw(&square);
    detect_yolov8_onnx(model, &nchw, height, width, config)
}

// Unit tests: label list sanity, config defaults, decode (basic, threshold,
// NMS), letterbox geometry, and HWC->NCHW layout.
#[cfg(test)]
mod tests {
    use super::*;

    // The canonical COCO label list must have exactly 80 entries.
    #[test]
    fn test_coco_labels_count() {
        assert_eq!(coco_labels().len(), 80);
    }

    // Default config matches the standard YOLOv8 export parameters.
    #[test]
    fn test_yolov8_coco_config_defaults() {
        let cfg = yolov8_coco_config();
        assert_eq!(cfg.input_size, 640);
        assert_eq!(cfg.num_classes, 80);
        assert!((cfg.conf_threshold - 0.25).abs() < 1e-6);
        assert!((cfg.iou_threshold - 0.45).abs() < 1e-6);
        assert_eq!(cfg.class_labels.len(), 80);
    }

    /// Build a synthetic [1, 84, 8400] tensor with exactly one strong
    /// prediction at index 0, class 5 with score 0.9.
    fn make_one_detection_tensor() -> Tensor {
        let num_classes = 80;
        let rows = 4 + num_classes;
        let num_preds = 8400;
        let mut data = vec![0.0f32; rows * num_preds];

        // Prediction at index 0: centre (320, 320), size 100x100 in 640x640.
        // Rows are laid out [cx, cy, w, h, class0..class79] with stride
        // num_preds, matching decode_yolov8_output's indexing.
        data[0] = 320.0; // cx
        data[num_preds] = 320.0; // cy
        data[2 * num_preds] = 100.0; // w
        data[3 * num_preds] = 100.0; // h

        // Class 5 has score 0.9; others stay at 0.
        data[(4 + 5) * num_preds] = 0.9;

        Tensor::from_vec(vec![1, rows, num_preds], data).unwrap()
    }

    // Single confident prediction decodes to one detection with the expected
    // class, score, and corner coordinates.
    #[test]
    fn test_decode_yolov8_output_basic() {
        let tensor = make_one_detection_tensor();
        let config = YoloConfig {
            input_size: 640,
            num_classes: 80,
            conf_threshold: 0.25,
            iou_threshold: 0.45,
            class_labels: coco_labels(),
        };

        // Original image is also 640x640 so no rescaling.
        let dets = decode_yolov8_output(&tensor, &config, 640, 640);
        assert_eq!(dets.len(), 1);
        assert_eq!(dets[0].class_id, 5);
        assert!((dets[0].score - 0.9).abs() < 1e-6);

        // Box should be (270, 270, 370, 370) in original coords.
        let b = &dets[0].bbox;
        assert!((b.x1 - 270.0).abs() < 1.0);
        assert!((b.y1 - 270.0).abs() < 1.0);
        assert!((b.x2 - 370.0).abs() < 1.0);
        assert!((b.y2 - 370.0).abs() < 1.0);
    }

    // A threshold above the best score filters everything out.
    #[test]
    fn test_decode_yolov8_output_confidence_filter() {
        let tensor = make_one_detection_tensor();
        let config = YoloConfig {
            input_size: 640,
            num_classes: 80,
            conf_threshold: 0.95, // higher than our 0.9 score
            iou_threshold: 0.45,
            class_labels: coco_labels(),
        };
        let dets = decode_yolov8_output(&tensor, &config, 640, 640);
        assert!(dets.is_empty());
    }

    // Two near-identical boxes of the same class: NMS keeps only the
    // higher-scoring one.
    #[test]
    fn test_decode_yolov8_output_nms() {
        let num_classes = 80;
        let rows = 4 + num_classes;
        let num_preds = 8400;
        let mut data = vec![0.0f32; rows * num_preds];

        // Two highly overlapping boxes, same class (class 0).
        // Box 0: centre (320, 320), 100x100, score 0.9
        data[0] = 320.0;
        data[num_preds] = 320.0;
        data[2 * num_preds] = 100.0;
        data[3 * num_preds] = 100.0;
        data[4 * num_preds] = 0.9;

        // Box 1: centre (325, 325), 100x100, score 0.8 (heavily overlapping)
        data[1] = 325.0;
        data[num_preds + 1] = 325.0;
        data[2 * num_preds + 1] = 100.0;
        data[3 * num_preds + 1] = 100.0;
        data[4 * num_preds + 1] = 0.8;

        let tensor = Tensor::from_vec(vec![1, rows, num_preds], data).unwrap();
        let config = YoloConfig {
            input_size: 640,
            num_classes: 80,
            conf_threshold: 0.25,
            iou_threshold: 0.45,
            class_labels: coco_labels(),
        };

        let dets = decode_yolov8_output(&tensor, &config, 640, 640);
        // NMS should suppress the lower-scoring duplicate.
        assert_eq!(dets.len(), 1);
        assert!((dets[0].score - 0.9).abs() < 1e-6);
    }

    // A square input scales uniformly and needs no padding.
    #[test]
    fn test_letterbox_preprocess_square() {
        // 100x100 image → 640x640 should have no padding.
        let img = Tensor::from_vec(vec![100, 100, 3], vec![0.5; 100 * 100 * 3]).unwrap();
        let (out, scale, pad_x, pad_y) = letterbox_preprocess(&img, 640);
        assert_eq!(out.shape(), &[640, 640, 3]);
        assert!((scale - 6.4).abs() < 0.01);
        assert!(pad_x.abs() < 1.0);
        assert!(pad_y.abs() < 1.0);
    }

    // HWC interleaved pixels end up in three contiguous channel planes.
    #[test]
    fn test_hwc_to_nchw_basic() {
        // 2x2 RGB image
        let data = vec![
            0.1, 0.2, 0.3, // (0,0) R G B
            0.4, 0.5, 0.6, // (0,1)
            0.7, 0.8, 0.9, // (1,0)
            1.0, 0.0, 0.5, // (1,1)
        ];
        let img = Tensor::from_vec(vec![2, 2, 3], data).unwrap();
        let nchw = hwc_to_nchw(&img);
        // Expected layout: [R-plane, G-plane, B-plane], each 2x2
        assert_eq!(nchw.len(), 12);
        // R plane
        assert!((nchw[0] - 0.1).abs() < 1e-6); // (0,0)
        assert!((nchw[1] - 0.4).abs() < 1e-6); // (0,1)
        assert!((nchw[2] - 0.7).abs() < 1e-6); // (1,0)
        assert!((nchw[3] - 1.0).abs() < 1e-6); // (1,1)
        // G plane
        assert!((nchw[4] - 0.2).abs() < 1e-6);
        assert!((nchw[5] - 0.5).abs() < 1e-6);
        assert!((nchw[6] - 0.8).abs() < 1e-6);
        assert!((nchw[7] - 0.0).abs() < 1e-6);
        // B plane
        assert!((nchw[8] - 0.3).abs() < 1e-6);
        assert!((nchw[9] - 0.6).abs() < 1e-6);
        assert!((nchw[10] - 0.9).abs() < 1e-6);
        assert!((nchw[11] - 0.5).abs() < 1e-6);
    }

    // End-to-end preprocess shape check on a non-square input.
    #[test]
    fn test_letterbox_then_nchw_pipeline() {
        // Rectangular 100x200 image through full preprocess pipeline.
        let img = Tensor::from_vec(vec![100, 200, 3], vec![0.4; 100 * 200 * 3]).unwrap();
        let (letterboxed, _scale, _pad_x, _pad_y) = letterbox_preprocess(&img, 640);
        assert_eq!(letterboxed.shape(), &[640, 640, 3]);
        let nchw = hwc_to_nchw(&letterboxed);
        assert_eq!(nchw.len(), 3 * 640 * 640);
    }

    // Landscape input: width-limited scale, vertical grey padding.
    #[test]
    fn test_letterbox_preprocess_landscape() {
        // 200x100 image → scale limited by width: 640/200 = 3.2
        // new_w = 640, new_h = 320 → pad_y = (640-320)/2 = 160
        let img = Tensor::from_vec(vec![100, 200, 3], vec![0.4; 100 * 200 * 3]).unwrap();
        let (out, scale, pad_x, pad_y) = letterbox_preprocess(&img, 640);
        assert_eq!(out.shape(), &[640, 640, 3]);
        assert!((scale - 3.2).abs() < 0.01);
        assert!(pad_x.abs() < 1.0);
        assert!((pad_y - 160.0).abs() < 1.0);

        // Check that the padded (grey 0.5) region exists at top.
        let top_pixel = &out.data()[0..3];
        for &v in top_pixel {
            assert!((v - 0.5).abs() < 1e-6, "top padding should be 0.5 grey");
        }
    }
}