use crate::error::{VisionError, VisionResult};
/// Configuration for an [`AnchorGenerator`].
///
/// The values are validated by [`AnchorGenerator::new`]. Every `(base_size, aspect_ratio)`
/// pair forms one anchor template, and every template is placed at every cell of every
/// feature-map level.
#[derive(Debug, Clone, PartialEq)]
pub struct AnchorConfig {
/// Anchor base sizes, in the same coordinate units as the cell centres derived from
/// `strides`. Must be non-empty; each value must be finite and `> 0`.
pub base_sizes: Vec<f32>,
/// Aspect ratios (width / height). Must be non-empty; each value must be finite and `> 0`.
/// A ratio `r` gives width `base * sqrt(r)` and height `base / sqrt(r)`.
pub aspect_ratios: Vec<f32>,
/// Grid size `(height, width)` of each feature-map level. Must be non-empty and have the
/// same length as `strides`; both dimensions must be non-zero.
pub feature_sizes: Vec<(usize, usize)>,
/// Stride of each level (one per entry of `feature_sizes`); must be non-zero. Cell
/// `(gy, gx)` of a level is centred at `(gx * stride + stride / 2, gy * stride + stride / 2)`.
pub strides: Vec<usize>,
}
/// Generates axis-aligned anchor boxes over one or more feature-map grids.
///
/// Build one with [`AnchorGenerator::new`], which validates the [`AnchorConfig`]. The
/// field is private, so code outside this module can only obtain a generator through `new`.
#[derive(Debug, Clone)]
pub struct AnchorGenerator {
cfg: AnchorConfig,
}
impl AnchorGenerator {
    /// Validates `cfg` and wraps it in a generator.
    ///
    /// Checks run in a fixed order and the first failure is returned.
    ///
    /// # Errors
    ///
    /// - [`VisionError::EmptyInput`] if `base_sizes`, `aspect_ratios` or `feature_sizes`
    ///   is empty.
    /// - [`VisionError::DimensionMismatch`] if `feature_sizes.len() != strides.len()`
    ///   (`expected` = stride count, `got` = feature-size count).
    /// - [`VisionError::DimensionMismatch`] with the placeholder payload
    ///   `{ expected: 1, got: 0 }` if any base size or aspect ratio is non-positive or
    ///   non-finite, or if any stride is zero.
    /// - [`VisionError::InvalidImageSize`] (with `channels: 1`) for the first feature-map
    ///   level that has a zero height or width.
    pub fn new(cfg: AnchorConfig) -> VisionResult<Self> {
        if cfg.base_sizes.is_empty() {
            return Err(VisionError::EmptyInput("anchor base_sizes"));
        }
        if cfg.aspect_ratios.is_empty() {
            return Err(VisionError::EmptyInput("anchor aspect_ratios"));
        }
        if cfg.feature_sizes.len() != cfg.strides.len() {
            return Err(VisionError::DimensionMismatch {
                expected: cfg.strides.len(),
                got: cfg.feature_sizes.len(),
            });
        }
        if cfg.feature_sizes.is_empty() {
            return Err(VisionError::EmptyInput("anchor feature_sizes"));
        }

        // NOTE(review): this placeholder payload serves as a generic "invalid value" error.
        // Existing callers and tests match on the `DimensionMismatch` variant, so it is kept.
        let invalid_value = || VisionError::DimensionMismatch {
            expected: 1,
            got: 0,
        };
        if cfg.base_sizes.iter().any(|&b| b <= 0.0 || !b.is_finite()) {
            return Err(invalid_value());
        }
        if cfg.aspect_ratios.iter().any(|&r| r <= 0.0 || !r.is_finite()) {
            return Err(invalid_value());
        }
        if cfg.strides.iter().any(|&s| s == 0) {
            return Err(invalid_value());
        }
        if let Some(&(gh, gw)) = cfg
            .feature_sizes
            .iter()
            .find(|&&(gh, gw)| gh == 0 || gw == 0)
        {
            return Err(VisionError::InvalidImageSize {
                height: gh,
                width: gw,
                channels: 1,
            });
        }
        Ok(Self { cfg })
    }

    /// Returns the validated configuration.
    #[must_use]
    #[inline]
    pub fn config(&self) -> &AnchorConfig {
        &self.cfg
    }

    /// Returns the number of anchors that [`generate`](Self::generate) produces.
    ///
    /// The count is the sum over levels of
    /// `grid_h * grid_w * base_sizes.len() * aspect_ratios.len()`.
    #[must_use]
    pub fn n_anchors(&self) -> usize {
        let templates_per_cell = self.cfg.base_sizes.len() * self.cfg.aspect_ratios.len();
        self.cfg
            .feature_sizes
            .iter()
            .map(|&(gh, gw)| gh * gw * templates_per_cell)
            .sum()
    }

    /// Generates every anchor as a flat buffer `[x1, y1, x2, y2, x1, y1, ...]` of length
    /// `4 * self.n_anchors()`.
    ///
    /// Anchors are ordered by level, then row (`gy`), then column (`gx`), then base size,
    /// then aspect ratio.
    ///
    /// Each anchor is centred at `(gx * stride + stride / 2, gy * stride + stride / 2)`. It
    /// has width `base * sqrt(ratio)` and height `base / sqrt(ratio)`.
    ///
    /// # Errors
    ///
    /// Returns [`VisionError::Internal`] if a level has no matching stride. `new` rejects
    /// such configurations.
    pub fn generate(&self) -> VisionResult<Vec<f32>> {
        // A template's half-extents do not depend on the cell, so compute them once rather
        // than once per grid cell. The base-outer / ratio-inner order fixes the output order.
        let mut half_extents: Vec<(f32, f32)> =
            Vec::with_capacity(self.cfg.base_sizes.len() * self.cfg.aspect_ratios.len());
        for &base in &self.cfg.base_sizes {
            for &ratio in &self.cfg.aspect_ratios {
                let sqrt_r = ratio.sqrt();
                let w = base * sqrt_r;
                let h = base / sqrt_r;
                half_extents.push((w * 0.5, h * 0.5));
            }
        }

        let mut out = Vec::with_capacity(self.n_anchors() * 4);
        for (level, &(grid_h, grid_w)) in self.cfg.feature_sizes.iter().enumerate() {
            // `new` guarantees one stride per level. This returns a typed error instead of
            // panicking in case that invariant is ever broken.
            let stride = match self.cfg.strides.get(level) {
                Some(&s) => s as f32,
                None => {
                    return Err(VisionError::Internal(format!(
                        "stride index {level} out of range"
                    )));
                }
            };
            let half_stride = stride * 0.5;
            for gy in 0..grid_h {
                let cy = gy as f32 * stride + half_stride;
                for gx in 0..grid_w {
                    let cx = gx as f32 * stride + half_stride;
                    for &(half_w, half_h) in &half_extents {
                        out.extend_from_slice(&[
                            cx - half_w,
                            cy - half_h,
                            cx + half_w,
                            cy + half_h,
                        ]);
                    }
                }
            }
        }
        Ok(out)
    }
}
/// Intersection-over-union of two `[x1, y1, x2, y2]` boxes.
///
/// Returns `0.0` when either box has non-positive width or height, or when the union is
/// non-positive. Otherwise returns the overlap ratio clamped to `[0, 1]`. NaN coordinates
/// can propagate into a NaN result.
#[must_use]
pub fn iou(box_a: &[f32; 4], box_b: &[f32; 4]) -> f32 {
    let (area_a, area_b) = (box_area(box_a), box_area(box_b));
    if area_a <= 0.0 || area_b <= 0.0 {
        return 0.0;
    }

    // The intersection rectangle takes the innermost edge on each axis; a negative extent
    // means no overlap and is clamped to zero.
    let overlap_w = (box_a[2].min(box_b[2]) - box_a[0].max(box_b[0])).max(0.0);
    let overlap_h = (box_a[3].min(box_b[3]) - box_a[1].max(box_b[1])).max(0.0);
    let intersection = overlap_w * overlap_h;

    let union = area_a + area_b - intersection;
    if union <= 0.0 {
        0.0
    } else {
        (intersection / union).clamp(0.0, 1.0)
    }
}
/// Area of an `[x1, y1, x2, y2]` box.
///
/// Returns `0.0` when the width or height is zero or negative. A NaN side length does not
/// trigger that check, so NaN coordinates can yield a NaN area.
#[inline]
fn box_area(b: &[f32; 4]) -> f32 {
    let [x1, y1, x2, y2] = *b;
    let width = x2 - x1;
    let height = y2 - y1;
    if width <= 0.0 || height <= 0.0 {
        return 0.0;
    }
    width * height
}
/// Copies box `i` (four consecutive values) out of a flat coordinate buffer.
///
/// # Panics
///
/// Panics if `boxes` holds fewer than `4 * (i + 1)` values.
#[inline]
fn read_box(boxes: &[f32], i: usize) -> [f32; 4] {
    let start = i * 4;
    let quad = &boxes[start..start + 4];
    [quad[0], quad[1], quad[2], quad[3]]
}
/// Greedy (hard) non-maximum suppression over `n` boxes.
///
/// `boxes` is a flat buffer of `n` boxes `[x1, y1, x2, y2]` (length `4 * n`); `scores` holds
/// one score per box.
///
/// Candidates are visited in descending score order. Equal scores are broken by the lower
/// index, and NaN scores rank below every real score. A candidate is kept unless its IoU
/// with an already-kept box is strictly greater than `iou_threshold`.
///
/// `max_keep` caps the number of returned indices; `0` means no cap.
///
/// Returns the kept indices in acceptance order.
///
/// # Errors
///
/// - [`VisionError::EmptyInput`] if `n == 0`.
/// - [`VisionError::DimensionMismatch`] if `boxes.len() != 4 * n` or `scores.len() != n`.
/// - [`VisionError::NonFinite`] if `iou_threshold` is NaN or outside `[0, 1]`.
pub fn nms(
    boxes: &[f32],
    scores: &[f32],
    n: usize,
    iou_threshold: f32,
    max_keep: usize,
) -> VisionResult<Vec<usize>> {
    use std::cmp::Ordering;

    // Orders scores from highest to lowest and ranks NaN below every number. This makes the
    // comparator a total preorder. `partial_cmp(..).unwrap_or(Equal)` is not one when NaN is
    // present, and the std sorts may panic or misorder on such comparators.
    fn score_desc(a: f32, b: f32) -> Ordering {
        match (a.is_nan(), b.is_nan()) {
            (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (true, true) => Ordering::Equal,
        }
    }

    if n == 0 {
        return Err(VisionError::EmptyInput("nms boxes"));
    }
    // An absurdly large `n` would overflow `n * 4`; no such product can match a real slice.
    if n.checked_mul(4) != Some(boxes.len()) {
        return Err(VisionError::DimensionMismatch {
            expected: n.saturating_mul(4),
            got: boxes.len(),
        });
    }
    if scores.len() != n {
        return Err(VisionError::DimensionMismatch {
            expected: n,
            got: scores.len(),
        });
    }
    // `contains` is false for NaN, so this also rejects a NaN threshold.
    if !(0.0..=1.0).contains(&iou_threshold) {
        return Err(VisionError::NonFinite("nms iou_threshold"));
    }

    let mut order: Vec<usize> = (0..n).collect();
    // The index tie-break makes every comparison strict, so an unstable sort is deterministic.
    order.sort_unstable_by(|&a, &b| score_desc(scores[a], scores[b]).then(a.cmp(&b)));

    let cap = if max_keep == 0 { n } else { max_keep.min(n) };
    let mut kept: Vec<usize> = Vec::with_capacity(cap);
    let mut kept_boxes: Vec<[f32; 4]> = Vec::with_capacity(cap);
    for idx in order {
        if kept.len() >= cap {
            break;
        }
        let candidate = read_box(boxes, idx);
        // Suppress only on IoU strictly above the threshold; a NaN IoU never suppresses.
        let suppressed = kept_boxes
            .iter()
            .any(|kb| iou(&candidate, kb) > iou_threshold);
        if !suppressed {
            kept.push(idx);
            kept_boxes.push(candidate);
        }
    }
    Ok(kept)
}
/// Gaussian soft non-maximum suppression over `n` boxes.
///
/// `boxes` is a flat buffer of `n` boxes `[x1, y1, x2, y2]` (length `4 * n`); `scores` holds
/// one score per box.
///
/// Each round emits the remaining box with the highest current score, ties going to the
/// lowest original index. Every other remaining score is then multiplied by
/// `exp(-iou² / sigma)`. Processing stops once the best remaining score is
/// `<= score_threshold`. NaN scores are never selected.
///
/// Returns `(original index, score at selection time)` pairs in selection order.
///
/// # Errors
///
/// - [`VisionError::EmptyInput`] if `n == 0`.
/// - [`VisionError::DimensionMismatch`] if `boxes.len() != 4 * n` or `scores.len() != n`.
/// - [`VisionError::NonFinite`] if `sigma` is not finite and positive, or if
///   `score_threshold` is not finite.
pub fn soft_nms(
    boxes: &[f32],
    scores: &[f32],
    n: usize,
    sigma: f32,
    score_threshold: f32,
) -> VisionResult<Vec<(usize, f32)>> {
    if n == 0 {
        return Err(VisionError::EmptyInput("soft_nms boxes"));
    }
    // An absurdly large `n` would overflow `n * 4`; no such product can match a real slice.
    if n.checked_mul(4) != Some(boxes.len()) {
        return Err(VisionError::DimensionMismatch {
            expected: n.saturating_mul(4),
            got: boxes.len(),
        });
    }
    if scores.len() != n {
        return Err(VisionError::DimensionMismatch {
            expected: n,
            got: scores.len(),
        });
    }
    if sigma <= 0.0 || !sigma.is_finite() {
        return Err(VisionError::NonFinite("soft_nms sigma"));
    }
    if !score_threshold.is_finite() {
        return Err(VisionError::NonFinite("soft_nms score_threshold"));
    }

    // Pool entries are (original index, current score, box), kept in ascending index order.
    let mut pool: Vec<(usize, f32, [f32; 4])> =
        (0..n).map(|i| (i, scores[i], read_box(boxes, i))).collect();
    let inv_sigma = 1.0_f32 / sigma;
    let mut out: Vec<(usize, f32)> = Vec::with_capacity(n);
    while !pool.is_empty() {
        // Take the first entry with the strictly greatest score, which is the lowest index
        // among ties. NaN never compares greater, so it is never picked; an all-NaN pool
        // yields -inf and ends the loop below.
        let (max_pos, max_score) = pool.iter().enumerate().fold(
            (0usize, f32::NEG_INFINITY),
            |(best_i, best_s), (i, e)| {
                if e.1 > best_s {
                    (i, e.1)
                } else {
                    (best_i, best_s)
                }
            },
        );
        if max_score <= score_threshold {
            break;
        }
        // `remove` (not `swap_remove`) keeps the pool in index order, so ties stay
        // index-ordered in later rounds as well. The decay pass below is already O(n).
        let (pivot_idx, pivot_score, pivot_box) = pool.remove(max_pos);
        out.push((pivot_idx, pivot_score));
        for entry in pool.iter_mut() {
            let ov = iou(&pivot_box, &entry.2);
            let decay = (-(ov * ov) * inv_sigma).exp();
            entry.1 *= decay;
        }
    }
    Ok(out)
}
// Unit tests for anchor generation, IoU, hard NMS and soft-NMS.
#[cfg(test)]
mod tests {
use super::*;
// One level: a 2x2 grid with stride 16 and a single 32x32 square template per cell
// (4 anchors in total).
fn cfg_single_level() -> AnchorConfig {
AnchorConfig {
base_sizes: vec![32.0],
aspect_ratios: vec![1.0],
feature_sizes: vec![(2, 2)],
strides: vec![16],
}
}
// Two levels; 2 base sizes x 3 aspect ratios = 6 templates per cell.
fn cfg_two_levels() -> AnchorConfig {
AnchorConfig {
base_sizes: vec![16.0, 32.0],
aspect_ratios: vec![0.5, 1.0, 2.0],
feature_sizes: vec![(2, 3), (1, 2)],
strides: vec![8, 16],
}
}
#[test]
fn iou_identical_boxes_is_one() {
let a = [0.0_f32, 0.0, 10.0, 10.0];
let v = iou(&a, &a);
assert!((v - 1.0).abs() < 1e-6, "expected 1.0, got {v}");
}
#[test]
fn iou_disjoint_boxes_is_zero() {
let a = [0.0_f32, 0.0, 1.0, 1.0];
let b = [2.0_f32, 2.0, 3.0, 3.0];
assert!(iou(&a, &b).abs() < 1e-7);
}
#[test]
fn iou_half_overlap_known_value() {
// Intersection is 5 x 10 = 50 and union is 100 + 100 - 50 = 150, so IoU = 1/3.
let a = [0.0_f32, 0.0, 10.0, 10.0];
let b = [5.0_f32, 0.0, 15.0, 10.0];
let v = iou(&a, &b);
assert!((v - (1.0 / 3.0)).abs() < 1e-5, "expected 1/3, got {v}");
}
#[test]
fn iou_degenerate_box_is_zero() {
// `b` has zero width and height, so its area is 0 and `iou` short-circuits to 0.
let a = [0.0_f32, 0.0, 10.0, 10.0];
let b = [5.0_f32, 5.0, 5.0, 5.0]; assert!(iou(&a, &b).abs() < 1e-7);
}
#[test]
fn anchor_generator_construction_ok() {
let g = AnchorGenerator::new(cfg_single_level()).expect("ok");
assert_eq!(g.config().base_sizes, vec![32.0]);
}
#[test]
fn anchor_count_matches_sigma_formula() {
// Sum over levels of cells x templates:
// level 0 = (2 x 3 cells) x (2 x 3 templates) = 36;
// level 1 = (1 x 2 cells) x 6 templates = 12.
let cfg = cfg_two_levels();
let level0 = 2 * 3 * 2 * 3; let level1 = 2 * 2 * 3; let expected = level0 + level1;
let g = AnchorGenerator::new(cfg).expect("ok");
assert_eq!(g.n_anchors(), expected);
}
#[test]
fn anchor_generate_output_length_matches_n_anchors() {
let g = AnchorGenerator::new(cfg_two_levels()).expect("ok");
let out = g.generate().expect("generate ok");
assert_eq!(out.len(), g.n_anchors() * 4);
}
#[test]
fn anchor_boxes_have_positive_extent() {
let g = AnchorGenerator::new(cfg_two_levels()).expect("ok");
let out = g.generate().expect("ok");
let n = g.n_anchors();
for i in 0..n {
let b = read_box(&out, i);
assert!(b[2] > b[0], "anchor {i} x2 <= x1: {b:?}");
assert!(b[3] > b[1], "anchor {i} y2 <= y1: {b:?}");
}
}
#[test]
fn anchor_center_at_cell_centre() {
// With stride 16, cell (gy=0, gx=0) is centred at (8, 8). Anchor 3 is cell
// (gy=1, gx=1), centred at (1*16 + 8, 1*16 + 8) = (24, 24).
let g = AnchorGenerator::new(cfg_single_level()).expect("ok");
let out = g.generate().expect("ok");
let b0 = read_box(&out, 0);
let cx0 = 0.5 * (b0[0] + b0[2]);
let cy0 = 0.5 * (b0[1] + b0[3]);
assert!((cx0 - 8.0).abs() < 1e-5);
assert!((cy0 - 8.0).abs() < 1e-5);
let b3 = read_box(&out, 3);
let cx3 = 0.5 * (b3[0] + b3[2]);
let cy3 = 0.5 * (b3[1] + b3[3]);
assert!((cx3 - 24.0).abs() < 1e-5, "got cx={cx3}");
assert!((cy3 - 24.0).abs() < 1e-5, "got cy={cy3}");
}
#[test]
fn anchor_size_follows_sqrt_ratio() {
// Ratio 4 gives sqrt = 2, so w = 32 * 2 = 64 and h = 32 / 2 = 16.
let cfg = AnchorConfig {
base_sizes: vec![32.0],
aspect_ratios: vec![4.0],
feature_sizes: vec![(1, 1)],
strides: vec![16],
};
let g = AnchorGenerator::new(cfg).expect("ok");
let out = g.generate().expect("ok");
let b = read_box(&out, 0);
let w = b[2] - b[0];
let h = b[3] - b[1];
assert!((w - 64.0).abs() < 1e-4, "expected w=64, got {w}");
assert!((h - 16.0).abs() < 1e-4, "expected h=16, got {h}");
}
#[test]
fn err_empty_base_sizes() {
let cfg = AnchorConfig {
base_sizes: vec![],
aspect_ratios: vec![1.0],
feature_sizes: vec![(2, 2)],
strides: vec![16],
};
let r = AnchorGenerator::new(cfg);
assert!(matches!(r, Err(VisionError::EmptyInput(_))));
}
#[test]
fn err_empty_aspect_ratios() {
let cfg = AnchorConfig {
base_sizes: vec![1.0],
aspect_ratios: vec![],
feature_sizes: vec![(2, 2)],
strides: vec![16],
};
let r = AnchorGenerator::new(cfg);
assert!(matches!(r, Err(VisionError::EmptyInput(_))));
}
#[test]
fn err_feature_size_strides_length_mismatch() {
let cfg = AnchorConfig {
base_sizes: vec![1.0],
aspect_ratios: vec![1.0],
feature_sizes: vec![(2, 2), (1, 1)],
strides: vec![16],
};
let r = AnchorGenerator::new(cfg);
assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
}
#[test]
fn err_zero_grid_size() {
let cfg = AnchorConfig {
base_sizes: vec![1.0],
aspect_ratios: vec![1.0],
feature_sizes: vec![(0, 2)],
strides: vec![16],
};
let r = AnchorGenerator::new(cfg);
assert!(matches!(r, Err(VisionError::InvalidImageSize { .. })));
}
#[test]
fn err_zero_base_size() {
// `AnchorGenerator::new` reports non-positive base sizes with the `DimensionMismatch`
// variant (placeholder payload), which this test pins.
let cfg = AnchorConfig {
base_sizes: vec![0.0],
aspect_ratios: vec![1.0],
feature_sizes: vec![(2, 2)],
strides: vec![16],
};
let r = AnchorGenerator::new(cfg);
assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
}
#[test]
fn nms_keeps_highest_score_and_suppresses_overlap() {
// IoU of the two boxes is 81 / 100 = 0.81 > 0.3, so the lower-scored box 0 is dropped.
let boxes = vec![0.0_f32, 0.0, 10.0, 10.0, 1.0, 1.0, 10.0, 10.0];
let scores = vec![0.4_f32, 0.9];
let kept = nms(&boxes, &scores, 2, 0.3, 0).expect("ok");
assert_eq!(kept, vec![1]);
}
#[test]
fn nms_keeps_two_disjoint_boxes() {
let boxes = vec![0.0_f32, 0.0, 1.0, 1.0, 5.0, 5.0, 6.0, 6.0];
let scores = vec![0.9_f32, 0.8];
let kept = nms(&boxes, &scores, 2, 0.5, 0).expect("ok");
assert_eq!(kept.len(), 2);
assert_eq!(kept, vec![0, 1]);
}
#[test]
fn nms_max_keep_caps_output() {
// Three disjoint boxes; `max_keep = 2` stops after the two highest scores.
let boxes = vec![
0.0_f32, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0,
];
let scores = vec![0.9_f32, 0.7, 0.5];
let kept = nms(&boxes, &scores, 3, 0.5, 2).expect("ok");
assert_eq!(kept.len(), 2);
assert_eq!(kept, vec![0, 1]);
}
#[test]
fn nms_threshold_one_keeps_all() {
// Identical boxes have IoU exactly 1.0. Suppression needs IoU strictly greater than the
// threshold, so both boxes survive.
let boxes = vec![0.0_f32, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0];
let scores = vec![0.9_f32, 0.8];
let kept = nms(&boxes, &scores, 2, 1.0, 0).expect("ok");
assert_eq!(kept.len(), 2);
}
#[test]
fn nms_single_box_is_kept() {
let boxes = vec![0.0_f32, 0.0, 1.0, 1.0];
let scores = vec![0.5_f32];
let kept = nms(&boxes, &scores, 1, 0.5, 0).expect("ok");
assert_eq!(kept, vec![0]);
}
#[test]
fn nms_returns_indices_in_descending_score_order() {
// Three disjoint boxes, so nothing is suppressed and only the ordering is tested.
let boxes = vec![
0.0_f32, 0.0, 1.0, 1.0, 5.0_f32, 0.0, 6.0, 1.0, 10.0_f32, 0.0, 11.0, 1.0, ];
let scores = vec![0.3_f32, 0.9, 0.5];
let kept = nms(&boxes, &scores, 3, 0.5, 0).expect("ok");
assert_eq!(kept, vec![1, 2, 0]);
}
#[test]
fn nms_err_empty_boxes() {
let r = nms(&[], &[], 0, 0.5, 0);
assert!(matches!(r, Err(VisionError::EmptyInput(_))));
}
#[test]
fn nms_err_length_mismatch() {
let boxes = vec![0.0_f32, 0.0, 1.0, 1.0];
let scores = vec![0.9_f32, 0.8];
let r = nms(&boxes, &scores, 2, 0.5, 0);
assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
}
#[test]
fn nms_err_iou_out_of_range() {
// `nms` reports out-of-range thresholds with the `NonFinite` variant.
let boxes = vec![0.0_f32, 0.0, 1.0, 1.0];
let scores = vec![0.5_f32];
let r = nms(&boxes, &scores, 1, 1.5, 0);
assert!(matches!(r, Err(VisionError::NonFinite(_))));
let r2 = nms(&boxes, &scores, 1, -0.1, 0);
assert!(matches!(r2, Err(VisionError::NonFinite(_))));
}
#[test]
fn nms_deterministic() {
let boxes = vec![
0.0_f32, 0.0, 1.0, 1.0, 0.5, 0.5, 1.5, 1.5, 5.0, 5.0, 6.0, 6.0, ];
let scores = vec![0.4_f32, 0.7, 0.6];
let k1 = nms(&boxes, &scores, 3, 0.3, 0).expect("ok");
let k2 = nms(&boxes, &scores, 3, 0.3, 0).expect("ok");
assert_eq!(k1, k2);
}
#[test]
fn soft_nms_decays_overlapping_score() {
// IoU = 0.81, so box 1 is scaled by exp(-0.81^2 / 0.5) ~= 0.27, giving ~0.215.
// That is still above the 0.0 threshold, so both boxes are emitted.
let boxes = vec![0.0_f32, 0.0, 10.0, 10.0, 1.0, 1.0, 10.0, 10.0];
let scores = vec![0.9_f32, 0.8];
let out = soft_nms(&boxes, &scores, 2, 0.5, 0.0).expect("ok");
assert_eq!(out.len(), 2);
assert_eq!(out[0].0, 0);
assert_eq!(out[1].0, 1);
assert!(out[1].1 < 0.8, "expected decay, got {}", out[1].1);
}
#[test]
fn soft_nms_drops_below_score_threshold() {
// Identical boxes give IoU 1, so box 1 decays to 0.5 * exp(-2) ~= 0.068, which is
// <= 0.4; processing stops after box 0.
let boxes = vec![0.0_f32, 0.0, 10.0, 10.0, 0.0, 0.0, 10.0, 10.0];
let scores = vec![1.0_f32, 0.5];
let out = soft_nms(&boxes, &scores, 2, 0.5, 0.4).expect("ok");
assert_eq!(out.len(), 1);
assert_eq!(out[0].0, 0);
}
#[test]
fn soft_nms_disjoint_no_decay() {
// Disjoint boxes have IoU 0, so the decay factor is exp(0) = 1.
let boxes = vec![0.0_f32, 0.0, 1.0, 1.0, 5.0, 5.0, 6.0, 6.0];
let scores = vec![0.7_f32, 0.6];
let out = soft_nms(&boxes, &scores, 2, 0.5, 0.0).expect("ok");
assert_eq!(out.len(), 2);
assert!((out[0].1 - 0.7).abs() < 1e-5);
assert!((out[1].1 - 0.6).abs() < 1e-5);
}
#[test]
fn soft_nms_deterministic() {
let boxes = vec![
0.0_f32, 0.0, 1.0, 1.0, 0.5, 0.5, 1.5, 1.5, 5.0, 5.0, 6.0, 6.0,
];
let scores = vec![0.4_f32, 0.7, 0.6];
let a = soft_nms(&boxes, &scores, 3, 0.5, 0.1).expect("ok");
let b = soft_nms(&boxes, &scores, 3, 0.5, 0.1).expect("ok");
assert_eq!(a, b);
}
#[test]
fn soft_nms_err_invalid_sigma() {
let boxes = vec![0.0_f32, 0.0, 1.0, 1.0];
let scores = vec![0.5_f32];
let r = soft_nms(&boxes, &scores, 1, 0.0, 0.0);
assert!(matches!(r, Err(VisionError::NonFinite(_))));
let r2 = soft_nms(&boxes, &scores, 1, -1.0, 0.0);
assert!(matches!(r2, Err(VisionError::NonFinite(_))));
}
#[test]
fn soft_nms_err_empty() {
let r = soft_nms(&[], &[], 0, 0.5, 0.0);
assert!(matches!(r, Err(VisionError::EmptyInput(_))));
}
}