// yscv_detect/yolo.rs

1//! YOLOv8 postprocessing pipeline.
2//!
3//! Decodes raw YOLOv8 model output tensors into [`Detection`]s. The actual
4//! model inference is left to the caller (e.g. via `yscv-onnx`); this module
5//! handles the coordinate decoding, confidence filtering, and NMS step.
6
7use yscv_tensor::Tensor;
8
9use crate::{BoundingBox, Detection, non_max_suppression};
10
/// Configuration for a YOLO-family detector.
#[derive(Debug, Clone)]
pub struct YoloConfig {
    /// Side length of the (square) model input, in pixels.
    pub input_size: usize,
    /// Number of object classes the model predicts.
    pub num_classes: usize,
    /// Minimum class score for a prediction to be kept.
    pub conf_threshold: f32,
    /// IoU threshold used by non-max suppression.
    pub iou_threshold: f32,
    /// Human-readable label for each class index.
    pub class_labels: Vec<String>,
}
25
/// Returns the 80 COCO class labels in their canonical index order.
#[rustfmt::skip]
pub fn coco_labels() -> Vec<String> {
    const LABELS: [&str; 80] = [
        "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
        "truck", "boat", "traffic light", "fire hydrant", "stop sign",
        "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep",
        "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella",
        "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
        "sports ball", "kite", "baseball bat", "baseball glove", "skateboard",
        "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork",
        "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
        "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair",
        "couch", "potted plant", "bed", "dining table", "toilet", "tv",
        "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave",
        "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase",
        "scissors", "teddy bear", "hair drier", "toothbrush",
    ];
    LABELS.into_iter().map(String::from).collect()
}
48
49/// Returns the default YOLOv8 config for the COCO 80-class dataset.
50pub fn yolov8_coco_config() -> YoloConfig {
51    YoloConfig {
52        input_size: 640,
53        num_classes: 80,
54        conf_threshold: 0.25,
55        iou_threshold: 0.45,
56        class_labels: coco_labels(),
57    }
58}
59
60/// Decode YOLOv8 raw output tensor into detections.
61///
62/// YOLOv8 output format: `[1, 4 + num_classes, num_preds]` where the first
63/// four rows are `(cx, cy, w, h)` and the remaining rows are per-class
64/// confidence scores. A typical COCO model emits `[1, 84, 8400]`.
65///
66/// Coordinates in the output are relative to the **letterboxed** input image
67/// (i.e. `input_size x input_size`). This function maps them back to the
68/// original `(orig_width, orig_height)` frame.
69///
70/// Returns filtered detections after confidence thresholding and NMS.
71pub fn decode_yolov8_output(
72    output: &Tensor,
73    config: &YoloConfig,
74    orig_width: usize,
75    orig_height: usize,
76) -> Vec<Detection> {
77    let shape = output.shape();
78    // Expect [1, 4+num_classes, num_preds]
79    if shape.len() != 3 || shape[0] != 1 {
80        return Vec::new();
81    }
82    let rows = shape[1]; // 4 + num_classes
83    let num_preds = shape[2];
84    if rows < 5 {
85        return Vec::new();
86    }
87    let num_classes = rows - 4;
88
89    let data = output.data();
90
91    // Compute letterbox scale and padding so we can map coords back.
92    let scale = (config.input_size as f32 / orig_width as f32)
93        .min(config.input_size as f32 / orig_height as f32);
94    let new_w = orig_width as f32 * scale;
95    let new_h = orig_height as f32 * scale;
96    let pad_x = (config.input_size as f32 - new_w) / 2.0;
97    let pad_y = (config.input_size as f32 - new_h) / 2.0;
98
99    let mut candidates = Vec::new();
100
101    // Bounds guard: ensure tensor data is large enough for all accesses
102    let required_len = (4 + num_classes) * num_preds;
103    if data.len() < required_len {
104        return Vec::new();
105    }
106
107    for i in 0..num_preds {
108        // Output is laid out row-major: data[row * num_preds + col]
109        let cx = data[i];
110        let cy = data[num_preds + i];
111        let w = data[2 * num_preds + i];
112        let h = data[3 * num_preds + i];
113
114        // Find best class
115        let mut best_score = f32::NEG_INFINITY;
116        let mut best_class = 0usize;
117        for c in 0..num_classes {
118            let s = data[(4 + c) * num_preds + i];
119            if s > best_score {
120                best_score = s;
121                best_class = c;
122            }
123        }
124
125        if best_score < config.conf_threshold {
126            continue;
127        }
128
129        // Convert from letterbox coordinates to original image coordinates.
130        let x1 = ((cx - w / 2.0) - pad_x) / scale;
131        let y1 = ((cy - h / 2.0) - pad_y) / scale;
132        let x2 = ((cx + w / 2.0) - pad_x) / scale;
133        let y2 = ((cy + h / 2.0) - pad_y) / scale;
134
135        // Clamp to image bounds.
136        let x1 = x1.max(0.0).min(orig_width as f32);
137        let y1 = y1.max(0.0).min(orig_height as f32);
138        let x2 = x2.max(0.0).min(orig_width as f32);
139        let y2 = y2.max(0.0).min(orig_height as f32);
140
141        candidates.push(Detection {
142            bbox: BoundingBox { x1, y1, x2, y2 },
143            score: best_score,
144            class_id: best_class,
145        });
146    }
147
148    non_max_suppression(&candidates, config.iou_threshold, candidates.len().max(1))
149}
150
151/// Default COCO config for YOLOv11 models.
152///
153/// Same as YOLOv8 COCO — 80 classes, 640 input, 0.25 conf, 0.45 IoU.
154pub fn yolov11_coco_config() -> YoloConfig {
155    yolov8_coco_config()
156}
157
158/// Decode YOLOv11 output tensor into detections.
159///
160/// YOLOv11 output shape: `[1, num_preds, 4 + num_classes]` (transposed vs YOLOv8).
161/// Each row is `[cx, cy, w, h, class_0_score, class_1_score, ...]`.
162/// Coordinates are in letterboxed image space; this function maps them back
163/// to the original image dimensions.
164pub fn decode_yolov11_output(
165    output: &Tensor,
166    config: &YoloConfig,
167    orig_width: usize,
168    orig_height: usize,
169) -> Vec<Detection> {
170    let shape = output.shape();
171    // YOLOv11: [1, N, 4+C] or [N, 4+C]
172    let (num_preds, cols) = if shape.len() == 3 {
173        (shape[1], shape[2])
174    } else if shape.len() == 2 {
175        (shape[0], shape[1])
176    } else {
177        return Vec::new();
178    };
179
180    if cols < 5 {
181        return Vec::new();
182    }
183    let num_classes = cols - 4;
184
185    let data = output.data();
186
187    let scale = (config.input_size as f32 / orig_width as f32)
188        .min(config.input_size as f32 / orig_height as f32);
189    let new_w = orig_width as f32 * scale;
190    let new_h = orig_height as f32 * scale;
191    let pad_x = (config.input_size as f32 - new_w) / 2.0;
192    let pad_y = (config.input_size as f32 - new_h) / 2.0;
193
194    let mut candidates = Vec::new();
195
196    // Bounds guard
197    let required_len = num_preds * cols;
198    if data.len() < required_len {
199        return Vec::new();
200    }
201
202    for i in 0..num_preds {
203        let row = i * cols;
204        let cx = data[row];
205        let cy = data[row + 1];
206        let w = data[row + 2];
207        let h = data[row + 3];
208
209        let mut best_score = f32::NEG_INFINITY;
210        let mut best_class = 0usize;
211        for c in 0..num_classes {
212            let s = data[row + 4 + c];
213            if s > best_score {
214                best_score = s;
215                best_class = c;
216            }
217        }
218
219        if best_score < config.conf_threshold {
220            continue;
221        }
222
223        let x1 = ((cx - w / 2.0) - pad_x) / scale;
224        let y1 = ((cy - h / 2.0) - pad_y) / scale;
225        let x2 = ((cx + w / 2.0) - pad_x) / scale;
226        let y2 = ((cy + h / 2.0) - pad_y) / scale;
227
228        let x1 = x1.max(0.0).min(orig_width as f32);
229        let y1 = y1.max(0.0).min(orig_height as f32);
230        let x2 = x2.max(0.0).min(orig_width as f32);
231        let y2 = y2.max(0.0).min(orig_height as f32);
232
233        candidates.push(Detection {
234            bbox: BoundingBox { x1, y1, x2, y2 },
235            score: best_score,
236            class_id: best_class,
237        });
238    }
239
240    non_max_suppression(&candidates, config.iou_threshold, candidates.len().max(1))
241}
242
/// Apply letterbox preprocessing: resize an image to a square with padding.
///
/// The input `image` is an `[H, W, 3]` f32 tensor (RGB, normalised to 0..1).
/// Returns `(padded, scale, pad_x, pad_y)` where `padded` has shape
/// `[target_size, target_size, 3]`, `scale` is the uniform resize factor
/// applied to both axes, and `(pad_x, pad_y)` are the per-side padding
/// offsets in pixels (possibly fractional when the total padding is odd).
///
/// # Panics
///
/// Panics if `image` is not a rank-3 tensor with 3 channels.
pub fn letterbox_preprocess(image: &Tensor, target_size: usize) -> (Tensor, f32, f32, f32) {
    let shape = image.shape();
    assert!(
        shape.len() == 3 && shape[2] == 3,
        "expected [H, W, 3] tensor"
    );
    let src_h = shape[0];
    let src_w = shape[1];
    let data = image.data();

    // Uniform scale that fits the whole image inside the square target.
    let scale = (target_size as f32 / src_w as f32).min(target_size as f32 / src_h as f32);
    // Rounded pixel size of the resized content; `.min` guards against
    // rounding pushing a dimension one past the target.
    let new_w = ((src_w as f32 * scale).round() as usize).min(target_size);
    let new_h = ((src_h as f32 * scale).round() as usize).min(target_size);
    let pad_x = (target_size - new_w) as f32 / 2.0;
    let pad_y = (target_size - new_h) as f32 / 2.0;
    // Integer top-left corner of the resized content inside the square.
    let pad_left = pad_x.floor() as usize;
    let pad_top = pad_y.floor() as usize;

    // Standard YOLO letterbox fill: 114 in uint8 = 114/255 in float.
    let total = target_size * target_size * 3;
    let mut out = vec![114.0f32 / 255.0; total];

    // Resize with anti-aliased bilinear (Pillow-compatible).
    // For downscaling (scale < 1), use a filter window proportional to 1/scale
    // so each output pixel averages the correct number of input pixels.
    // For upscaling (scale >= 1), fall back to standard 2x2 bilinear.
    let inv_scale_x = src_w as f32 / new_w as f32;
    let inv_scale_y = src_h as f32 / new_h as f32;
    // Filter support: bilinear filter has radius 1.0, scaled by inv_scale for downsampling.
    let support_x = if inv_scale_x > 1.0 { inv_scale_x } else { 1.0 };
    let support_y = if inv_scale_y > 1.0 { inv_scale_y } else { 1.0 };

    for y in 0..new_h {
        // Center of this output pixel in source coordinates
        let center_y = (y as f32 + 0.5) * inv_scale_y - 0.5;
        // Source rows whose triangle weight can be non-zero for this row,
        // clipped to the image bounds.
        let y_min = ((center_y - support_y).ceil() as isize).max(0) as usize;
        let y_max = ((center_y + support_y).floor() as isize).min(src_h as isize - 1) as usize;

        for x in 0..new_w {
            let center_x = (x as f32 + 0.5) * inv_scale_x - 0.5;
            let x_min = ((center_x - support_x).ceil() as isize).max(0) as usize;
            let x_max = ((center_x + support_x).floor() as isize).min(src_w as isize - 1) as usize;

            let dst_idx = ((pad_top + y) * target_size + (pad_left + x)) * 3;
            let mut sum = [0.0f32; 3];
            let mut weight_sum = 0.0f32;

            // Weighted average over the (clipped) filter window; weights form
            // a triangle peaking at the output pixel's centre.
            for sy in y_min..=y_max {
                let wy = 1.0 - (sy as f32 - center_y).abs() / support_y;
                if wy <= 0.0 {
                    continue;
                }
                for sx in x_min..=x_max {
                    let wx = 1.0 - (sx as f32 - center_x).abs() / support_x;
                    if wx <= 0.0 {
                        continue;
                    }
                    let w = wx * wy;
                    let src_idx = (sy * src_w + sx) * 3;
                    sum[0] += data[src_idx] * w;
                    sum[1] += data[src_idx + 1] * w;
                    sum[2] += data[src_idx + 2] * w;
                    weight_sum += w;
                }
            }

            // Normalise by the accumulated weight so edge pixels (whose
            // window was clipped) keep the correct brightness.
            if weight_sum > 0.0 {
                let inv_w = 1.0 / weight_sum;
                out[dst_idx] = sum[0] * inv_w;
                out[dst_idx + 1] = sum[1] * inv_w;
                out[dst_idx + 2] = sum[2] * inv_w;
            }
        }
    }

    let tensor = Tensor::from_vec(vec![target_size, target_size, 3], out)
        .unwrap_or_else(|_| unreachable!("letterbox: shape matches pre-allocated output"));
    (tensor, scale, pad_x, pad_y)
}
327
/// Convert an `[H, W, 3]` HWC tensor to `[1, 3, H, W]` NCHW f32 data.
///
/// This is a pure layout shuffle — values are copied unchanged (the input
/// is assumed to already be normalised to `[0, 1]`).
#[cfg(any(feature = "onnx", test))]
fn hwc_to_nchw(hwc: &Tensor) -> Vec<f32> {
    let shape = hwc.shape();
    let (h, w) = (shape[0], shape[1]);
    let plane = h * w;
    let mut nchw = vec![0.0f32; 3 * plane];
    // Each HWC pixel is a contiguous RGB triple; scatter its channels into
    // the three NCHW planes at the same spatial offset (pixel = y * w + x).
    for (pixel, rgb) in hwc.data().chunks_exact(3).take(plane).enumerate() {
        for (c, &value) in rgb.iter().enumerate() {
            nchw[c * plane + pixel] = value;
        }
    }
    nchw
}
349
/// Run YOLOv8 inference using an ONNX model.
///
/// `image_data` must be RGB, normalised to `[0, 1]`, in
/// `[1, 3, input_size, input_size]` NCHW layout. `img_height`/`img_width`
/// are the original (pre-letterbox) image dimensions used to map boxes back.
/// Returns detected objects after NMS.
///
/// # Errors
///
/// Fails if the input tensor cannot be constructed, if the ONNX run fails,
/// or if the model's first declared output is missing from the results.
#[cfg(feature = "onnx")]
pub fn detect_yolov8_onnx(
    model: &yscv_onnx::OnnxModel,
    image_data: &[f32],
    img_height: usize,
    img_width: usize,
    config: &YoloConfig,
) -> Result<Vec<Detection>, crate::DetectError> {
    use std::collections::HashMap;

    // Fall back to the conventional Ultralytics tensor names when the model
    // does not declare its own.
    let input_name = match model.inputs.first() {
        Some(name) => name.clone(),
        None => "images".to_string(),
    };
    let output_name = match model.outputs.first() {
        Some(name) => name.clone(),
        None => "output0".to_string(),
    };

    let side = config.input_size;
    let input = Tensor::from_vec(vec![1, 3, side, side], image_data.to_vec())?;
    let inputs = HashMap::from([(input_name, input)]);

    let outputs = yscv_onnx::run_onnx_model(model, inputs)?;
    let raw = outputs
        .get(&output_name)
        .ok_or_else(|| yscv_onnx::OnnxError::MissingInput {
            node: "model_output".to_string(),
            input: output_name.clone(),
        })?;

    Ok(decode_yolov8_output(raw, config, img_width, img_height))
}
402
/// Run the full YOLOv8 detection pipeline on an HWC image.
///
/// Accepts raw `[H, W, 3]` RGB f32 pixel data (normalised to `[0, 1]`),
/// letterboxes it to `config.input_size`, converts to NCHW, runs ONNX
/// inference, and decodes the output into final detections.
///
/// # Errors
///
/// Fails if the pixel buffer does not match `height * width * 3` or if
/// inference itself fails.
#[cfg(feature = "onnx")]
pub fn detect_yolov8_from_rgb(
    model: &yscv_onnx::OnnxModel,
    rgb_data: &[f32],
    height: usize,
    width: usize,
    config: &YoloConfig,
) -> Result<Vec<Detection>, crate::DetectError> {
    let hwc = Tensor::from_vec(vec![height, width, 3], rgb_data.to_vec())?;
    // Scale/pad are recomputed inside the decoder, so only pixels matter here.
    let (boxed, _, _, _) = letterbox_preprocess(&hwc, config.input_size);
    detect_yolov8_onnx(model, &hwc_to_nchw(&boxed), height, width, config)
}
423
// Unit tests exercising the label data, the YOLOv8 decoder, and the
// preprocessing helpers against synthetic tensors with known answers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_coco_labels_count() {
        assert_eq!(coco_labels().len(), 80);
    }

    #[test]
    fn test_yolov8_coco_config_defaults() {
        let cfg = yolov8_coco_config();
        assert_eq!(cfg.input_size, 640);
        assert_eq!(cfg.num_classes, 80);
        assert!((cfg.conf_threshold - 0.25).abs() < 1e-6);
        assert!((cfg.iou_threshold - 0.45).abs() < 1e-6);
        assert_eq!(cfg.class_labels.len(), 80);
    }

    /// Build a synthetic [1, 84, 8400] tensor with exactly one strong
    /// prediction at index 0, class 5 with score 0.9.
    fn make_one_detection_tensor() -> Tensor {
        let num_classes = 80;
        let rows = 4 + num_classes;
        let num_preds = 8400;
        let mut data = vec![0.0f32; rows * num_preds];

        // Prediction at index 0: centre (320, 320), size 100x100 in 640x640.
        // Layout is column-per-prediction: data[row * num_preds + col].
        data[0] = 320.0; // cx
        data[num_preds] = 320.0; // cy
        data[2 * num_preds] = 100.0; // w
        data[3 * num_preds] = 100.0; // h

        // Class 5 has score 0.9; others stay at 0.
        data[(4 + 5) * num_preds] = 0.9;

        Tensor::from_vec(vec![1, rows, num_preds], data).unwrap()
    }

    #[test]
    fn test_decode_yolov8_output_basic() {
        let tensor = make_one_detection_tensor();
        let config = YoloConfig {
            input_size: 640,
            num_classes: 80,
            conf_threshold: 0.25,
            iou_threshold: 0.45,
            class_labels: coco_labels(),
        };

        // Original image is also 640x640 so no rescaling.
        let dets = decode_yolov8_output(&tensor, &config, 640, 640);
        assert_eq!(dets.len(), 1);
        assert_eq!(dets[0].class_id, 5);
        assert!((dets[0].score - 0.9).abs() < 1e-6);

        // Box should be (270, 270, 370, 370) in original coords.
        let b = &dets[0].bbox;
        assert!((b.x1 - 270.0).abs() < 1.0);
        assert!((b.y1 - 270.0).abs() < 1.0);
        assert!((b.x2 - 370.0).abs() < 1.0);
        assert!((b.y2 - 370.0).abs() < 1.0);
    }

    #[test]
    fn test_decode_yolov8_output_confidence_filter() {
        let tensor = make_one_detection_tensor();
        let config = YoloConfig {
            input_size: 640,
            num_classes: 80,
            conf_threshold: 0.95, // higher than our 0.9 score
            iou_threshold: 0.45,
            class_labels: coco_labels(),
        };
        // The single 0.9-score candidate must be dropped by the threshold.
        let dets = decode_yolov8_output(&tensor, &config, 640, 640);
        assert!(dets.is_empty());
    }

    #[test]
    fn test_decode_yolov8_output_nms() {
        let num_classes = 80;
        let rows = 4 + num_classes;
        let num_preds = 8400;
        let mut data = vec![0.0f32; rows * num_preds];

        // Two highly overlapping boxes, same class (class 0).
        // Box 0: centre (320, 320), 100x100, score 0.9
        data[0] = 320.0;
        data[num_preds] = 320.0;
        data[2 * num_preds] = 100.0;
        data[3 * num_preds] = 100.0;
        data[4 * num_preds] = 0.9;

        // Box 1: centre (325, 325), 100x100, score 0.8 (heavily overlapping)
        data[1] = 325.0;
        data[num_preds + 1] = 325.0;
        data[2 * num_preds + 1] = 100.0;
        data[3 * num_preds + 1] = 100.0;
        data[4 * num_preds + 1] = 0.8;

        let tensor = Tensor::from_vec(vec![1, rows, num_preds], data).unwrap();
        let config = YoloConfig {
            input_size: 640,
            num_classes: 80,
            conf_threshold: 0.25,
            iou_threshold: 0.45,
            class_labels: coco_labels(),
        };

        let dets = decode_yolov8_output(&tensor, &config, 640, 640);
        // NMS should suppress the lower-scoring duplicate.
        assert_eq!(dets.len(), 1);
        assert!((dets[0].score - 0.9).abs() < 1e-6);
    }

    #[test]
    fn test_letterbox_preprocess_square() {
        // 100x100 image → 640x640 should have no padding.
        let img = Tensor::from_vec(vec![100, 100, 3], vec![0.5; 100 * 100 * 3]).unwrap();
        let (out, scale, pad_x, pad_y) = letterbox_preprocess(&img, 640);
        assert_eq!(out.shape(), &[640, 640, 3]);
        assert!((scale - 6.4).abs() < 0.01);
        assert!(pad_x.abs() < 1.0);
        assert!(pad_y.abs() < 1.0);
    }

    #[test]
    fn test_hwc_to_nchw_basic() {
        // 2x2 RGB image
        let data = vec![
            0.1, 0.2, 0.3, // (0,0) R G B
            0.4, 0.5, 0.6, // (0,1)
            0.7, 0.8, 0.9, // (1,0)
            1.0, 0.0, 0.5, // (1,1)
        ];
        let img = Tensor::from_vec(vec![2, 2, 3], data).unwrap();
        let nchw = hwc_to_nchw(&img);
        // Expected layout: [R-plane, G-plane, B-plane], each 2x2
        assert_eq!(nchw.len(), 12);
        // R plane
        assert!((nchw[0] - 0.1).abs() < 1e-6); // (0,0)
        assert!((nchw[1] - 0.4).abs() < 1e-6); // (0,1)
        assert!((nchw[2] - 0.7).abs() < 1e-6); // (1,0)
        assert!((nchw[3] - 1.0).abs() < 1e-6); // (1,1)
        // G plane
        assert!((nchw[4] - 0.2).abs() < 1e-6);
        assert!((nchw[5] - 0.5).abs() < 1e-6);
        assert!((nchw[6] - 0.8).abs() < 1e-6);
        assert!((nchw[7] - 0.0).abs() < 1e-6);
        // B plane
        assert!((nchw[8] - 0.3).abs() < 1e-6);
        assert!((nchw[9] - 0.6).abs() < 1e-6);
        assert!((nchw[10] - 0.9).abs() < 1e-6);
        assert!((nchw[11] - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_letterbox_then_nchw_pipeline() {
        // Rectangular 100x200 image through full preprocess pipeline.
        let img = Tensor::from_vec(vec![100, 200, 3], vec![0.4; 100 * 200 * 3]).unwrap();
        let (letterboxed, _scale, _pad_x, _pad_y) = letterbox_preprocess(&img, 640);
        assert_eq!(letterboxed.shape(), &[640, 640, 3]);
        let nchw = hwc_to_nchw(&letterboxed);
        assert_eq!(nchw.len(), 3 * 640 * 640);
    }

    #[test]
    fn test_letterbox_preprocess_landscape() {
        // 200x100 image → scale limited by width: 640/200 = 3.2
        // new_w = 640, new_h = 320 → pad_y = (640-320)/2 = 160
        let img = Tensor::from_vec(vec![100, 200, 3], vec![0.4; 100 * 200 * 3]).unwrap();
        let (out, scale, pad_x, pad_y) = letterbox_preprocess(&img, 640);
        assert_eq!(out.shape(), &[640, 640, 3]);
        assert!((scale - 3.2).abs() < 0.01);
        assert!(pad_x.abs() < 1.0);
        assert!((pad_y - 160.0).abs() < 1.0);

        // Check that the padded (grey 0.5) region exists at top.
        let top_pixel = &out.data()[0..3];
        for &v in top_pixel {
            assert!(
                (v - 114.0 / 255.0).abs() < 1e-6,
                "top padding should be 114/255 grey"
            );
        }
    }
}