use crate::error::{MlError, MlResult};
use crate::pipelines::types::Detection;
use crate::postprocess::{nms, sigmoid, BoundingBox};
/// Tunable parameters for `decode_yolov8_output`.
#[derive(Clone, Copy, Debug)]
pub struct DecodeOptions {
    /// Number of class-logit channels that follow the four box channels.
    pub num_classes: usize,
    /// Minimum sigmoid confidence an anchor's best class needs to be kept.
    pub conf_threshold: f32,
    /// IoU threshold passed to the per-class non-maximum suppression.
    pub iou_threshold: f32,
}

impl Default for DecodeOptions {
    /// 80 classes, confidence threshold 0.25, IoU threshold 0.45.
    fn default() -> Self {
        // NOTE(review): 80 presumably matches the label set of the stock
        // YOLOv8 checkpoints — confirm against the model actually loaded.
        DecodeOptions {
            iou_threshold: 0.45,
            conf_threshold: 0.25,
            num_classes: 80,
        }
    }
}
pub fn decode_yolov8_output(
data: &[f32],
shape: &[usize],
opts: &DecodeOptions,
) -> MlResult<Vec<Detection>> {
let channels = 4 + opts.num_classes;
let num_anchors = validate_yolov8_shape(shape, channels, data.len())?;
if num_anchors == 0 {
return Ok(Vec::new());
}
let mut boxes: Vec<BoundingBox> = Vec::new();
let mut classes: Vec<u32> = Vec::new();
let mut scores: Vec<f32> = Vec::new();
for anchor in 0..num_anchors {
let cx = data[anchor];
let cy = data[num_anchors + anchor];
let w = data[2 * num_anchors + anchor];
let h = data[3 * num_anchors + anchor];
let mut best_class = 0_u32;
let mut best_score = f32::NEG_INFINITY;
for cls in 0..opts.num_classes {
let logit = data[(4 + cls) * num_anchors + anchor];
if logit > best_score {
best_score = logit;
best_class = cls as u32;
}
}
let conf = sigmoid(best_score);
if conf < opts.conf_threshold {
continue;
}
let bbox = BoundingBox::from_xywh_center(cx, cy, w, h);
if bbox.area() <= 0.0 {
continue;
}
boxes.push(bbox);
classes.push(best_class);
scores.push(conf);
}
if boxes.is_empty() {
return Ok(Vec::new());
}
let mut keep_mask = vec![false; boxes.len()];
let mut unique_classes: Vec<u32> = classes.clone();
unique_classes.sort_unstable();
unique_classes.dedup();
for cls in unique_classes {
let subset: Vec<usize> = classes
.iter()
.enumerate()
.filter_map(|(i, &c)| if c == cls { Some(i) } else { None })
.collect();
let sub_boxes: Vec<BoundingBox> = subset.iter().map(|&i| boxes[i]).collect();
let sub_scores: Vec<f32> = subset.iter().map(|&i| scores[i]).collect();
let kept = nms(&sub_boxes, &sub_scores, opts.iou_threshold);
for local_idx in kept {
keep_mask[subset[local_idx]] = true;
}
}
let mut out: Vec<Detection> = (0..boxes.len())
.filter(|&i| keep_mask[i])
.map(|i| Detection::new(boxes[i], classes[i], scores[i]))
.collect();
out.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
Ok(out)
}
/// Checks that `shape` describes a YOLOv8 head with `expected_channels` channels
/// whose element count equals `total_len`, and returns the anchor count.
///
/// Two layouts are accepted: `[1, C, A]` (a batch of exactly one) and `[C, A]`.
///
/// # Errors
///
/// Returns a postprocess error in any of these cases:
/// - the rank or batch size is unsupported;
/// - `C` differs from `expected_channels`;
/// - the saturating product `C * A` differs from `total_len`.
fn validate_yolov8_shape(
    shape: &[usize],
    expected_channels: usize,
    total_len: usize,
) -> MlResult<usize> {
    let (channels, anchors) = match shape {
        [1, c, a] | [c, a] => (*c, *a),
        _ => {
            return Err(MlError::postprocess(format!(
                "yolov8: expected rank-3 [1, C, A] or rank-2 [C, A] output, got shape {shape:?}"
            )));
        }
    };

    if channels != expected_channels {
        return Err(MlError::postprocess(format!(
            "yolov8: expected {expected_channels} channels, got {channels} (shape {shape:?})"
        )));
    }

    // Saturating so an overflowing shape can never spuriously match `total_len`.
    if channels.saturating_mul(anchors) == total_len {
        Ok(anchors)
    } else {
        Err(MlError::postprocess(format!(
            "yolov8: output length {total_len} does not match shape {shape:?}"
        )))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Options shared by the two-class synthetic scenarios below.
    fn two_class_opts() -> DecodeOptions {
        DecodeOptions {
            num_classes: 2,
            conf_threshold: 0.25,
            iou_threshold: 0.45,
        }
    }

    /// Packs `(cx, cy, w, h, logits)` tuples into a channel-major buffer.
    ///
    /// Returns the buffer together with its `[1, 4 + num_classes, A]` shape.
    fn build_fake_output(
        anchors: &[(f32, f32, f32, f32, [f32; 2])],
        num_classes: usize,
    ) -> (Vec<f32>, Vec<usize>) {
        let count = anchors.len();
        let channels = 4 + num_classes;
        let mut buf = vec![0.0_f32; channels * count];
        for (col, &(cx, cy, w, h, logits)) in anchors.iter().enumerate() {
            let box_vals = [cx, cy, w, h];
            for (row, value) in box_vals.iter().chain(logits.iter()).enumerate() {
                buf[row * count + col] = *value;
            }
        }
        (buf, vec![1, channels, count])
    }

    #[test]
    fn decode_empty_output_returns_empty() {
        let dets =
            decode_yolov8_output(&[], &[1, 84, 0], &DecodeOptions::default()).expect("ok");
        assert!(dets.is_empty());
    }

    #[test]
    fn decode_below_threshold_is_filtered() {
        let (data, shape) = build_fake_output(&[(10.0, 10.0, 4.0, 4.0, [-3.0, -5.0])], 2);
        let dets = decode_yolov8_output(&data, &shape, &two_class_opts()).expect("ok");
        assert!(dets.is_empty());
    }

    #[test]
    fn decode_picks_highest_class() {
        let (data, shape) = build_fake_output(&[(10.0, 10.0, 4.0, 4.0, [-2.0, 3.0])], 2);
        let dets = decode_yolov8_output(&data, &shape, &two_class_opts()).expect("ok");
        assert_eq!(dets.len(), 1);
        let top = &dets[0];
        assert_eq!(top.class_id, 1);
        assert!(top.score > 0.9);
        assert!((top.bbox.x0 - 8.0).abs() < 1e-5);
        assert!((top.bbox.x1 - 12.0).abs() < 1e-5);
    }

    #[test]
    fn decode_nms_suppresses_duplicates_of_same_class() {
        let anchors = [
            (10.0, 10.0, 4.0, 4.0, [5.0, -5.0]),
            (10.2, 10.0, 4.0, 4.0, [4.0, -5.0]),
        ];
        let (data, shape) = build_fake_output(&anchors, 2);
        let dets = decode_yolov8_output(&data, &shape, &two_class_opts()).expect("ok");
        assert_eq!(dets.len(), 1);
        assert_eq!(dets[0].class_id, 0);
    }

    #[test]
    fn decode_keeps_overlapping_boxes_of_different_classes() {
        let anchors = [
            (10.0, 10.0, 4.0, 4.0, [5.0, -5.0]),
            (10.2, 10.0, 4.0, 4.0, [-5.0, 4.0]),
        ];
        let (data, shape) = build_fake_output(&anchors, 2);
        let dets = decode_yolov8_output(&data, &shape, &two_class_opts()).expect("ok");
        assert_eq!(dets.len(), 2);
        assert!(dets[0].score >= dets[1].score);
    }

    #[test]
    fn decode_rejects_wrong_channel_count() {
        let data = vec![0.0_f32; 84 * 10];
        let err = decode_yolov8_output(&data, &[1, 50, 10], &DecodeOptions::default())
            .expect_err("must fail");
        assert!(matches!(err, MlError::Postprocess(_)));
    }

    #[test]
    fn decode_rejects_mismatched_length() {
        let data = vec![0.0_f32; 10];
        let err = decode_yolov8_output(&data, &[1, 84, 10], &DecodeOptions::default())
            .expect_err("must fail");
        assert!(matches!(err, MlError::Postprocess(_)));
    }

    #[test]
    fn decode_accepts_rank_two_shape() {
        let (data, shape_3d) = build_fake_output(&[(10.0, 10.0, 4.0, 4.0, [5.0, -5.0])], 2);
        let dets =
            decode_yolov8_output(&data, &shape_3d[1..], &two_class_opts()).expect("ok");
        assert_eq!(dets.len(), 1);
    }
}