use crate::error::{MetricsError, Result};
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub struct RotatedBBox {
pub center_x: f64,
pub center_y: f64,
pub width: f64,
pub height: f64,
pub angle_rad: f64,
}
impl RotatedBBox {
pub fn new(center_x: f64, center_y: f64, width: f64, height: f64, angle_rad: f64) -> Self {
Self {
center_x,
center_y,
width,
height,
angle_rad,
}
}
pub fn area(&self) -> f64 {
self.width * self.height
}
}
impl Default for RotatedBBox {
fn default() -> Self {
Self {
center_x: 0.0,
center_y: 0.0,
width: 1.0,
height: 1.0,
angle_rad: 0.0,
}
}
}
pub fn rotated_bbox_corners(bbox: &RotatedBBox) -> [(f64, f64); 4] {
let cos_a = bbox.angle_rad.cos();
let sin_a = bbox.angle_rad.sin();
let hw = bbox.width * 0.5;
let hh = bbox.height * 0.5;
let local = [(hw, hh), (-hw, hh), (-hw, -hh), (hw, -hh)];
let mut corners = [(0.0, 0.0); 4];
for (i, &(lx, ly)) in local.iter().enumerate() {
corners[i] = (
bbox.center_x + cos_a * lx - sin_a * ly,
bbox.center_y + sin_a * lx + cos_a * ly,
);
}
corners
}
pub fn polygon_area(vertices: &[(f64, f64)]) -> f64 {
let n = vertices.len();
if n < 3 {
return 0.0;
}
let mut area = 0.0_f64;
for i in 0..n {
let j = (i + 1) % n;
area += vertices[i].0 * vertices[j].1;
area -= vertices[j].0 * vertices[i].1;
}
area.abs() * 0.5
}
#[inline]
fn is_inside(p: (f64, f64), p1: (f64, f64), p2: (f64, f64)) -> bool {
let cross = (p2.0 - p1.0) * (p.1 - p1.1) - (p2.1 - p1.1) * (p.0 - p1.0);
cross >= 0.0
}
fn line_intersection(
a: (f64, f64),
b: (f64, f64),
p1: (f64, f64),
p2: (f64, f64),
) -> Option<(f64, f64)> {
let d_ab = (b.0 - a.0, b.1 - a.1);
let d_p = (p2.0 - p1.0, p2.1 - p1.1);
let denom = d_ab.0 * d_p.1 - d_ab.1 * d_p.0;
if denom.abs() < 1e-12 {
return None; }
let t = ((p1.0 - a.0) * d_p.1 - (p1.1 - a.1) * d_p.0) / denom;
Some((a.0 + t * d_ab.0, a.1 + t * d_ab.1))
}
fn clip_polygon_by_halfplane(
poly: &[(f64, f64)],
p1: (f64, f64),
p2: (f64, f64),
) -> Vec<(f64, f64)> {
if poly.is_empty() {
return vec![];
}
let mut output = Vec::with_capacity(poly.len() + 1);
let n = poly.len();
for i in 0..n {
let current = poly[i];
let prev = poly[(i + n - 1) % n];
let current_inside = is_inside(current, p1, p2);
let prev_inside = is_inside(prev, p1, p2);
if current_inside {
if !prev_inside {
if let Some(pt) = line_intersection(prev, current, p1, p2) {
output.push(pt);
}
}
output.push(current);
} else if prev_inside {
if let Some(pt) = line_intersection(prev, current, p1, p2) {
output.push(pt);
}
}
}
output
}
fn polygon_intersection_area(corners_a: &[(f64, f64)], corners_b: &[(f64, f64)]) -> f64 {
if corners_a.is_empty() || corners_b.is_empty() {
return 0.0;
}
let mut clipped: Vec<(f64, f64)> = corners_a.to_vec();
let n_b = corners_b.len();
for i in 0..n_b {
if clipped.is_empty() {
break;
}
let p1 = corners_b[i];
let p2 = corners_b[(i + 1) % n_b];
clipped = clip_polygon_by_halfplane(&clipped, p1, p2);
}
polygon_area(&clipped)
}
pub fn rotated_iou(a: &RotatedBBox, b: &RotatedBBox) -> f64 {
let corners_a = rotated_bbox_corners(a);
let corners_b = rotated_bbox_corners(b);
let poly_a: Vec<(f64, f64)> = corners_a.to_vec();
let poly_b: Vec<(f64, f64)> = corners_b.to_vec();
let inter_area = polygon_intersection_area(&poly_a, &poly_b);
if inter_area <= 0.0 {
return 0.0;
}
let area_a = a.area();
let area_b = b.area();
let union_area = area_a + area_b - inter_area;
if union_area <= 0.0 {
0.0
} else {
(inter_area / union_area).clamp(0.0, 1.0)
}
}
pub fn rotated_nms(
boxes: &[RotatedBBox],
scores: &[f64],
iou_threshold: f64,
) -> Result<Vec<usize>> {
if boxes.len() != scores.len() {
return Err(MetricsError::DimensionMismatch(format!(
"boxes len {} != scores len {}",
boxes.len(),
scores.len()
)));
}
if boxes.is_empty() {
return Ok(vec![]);
}
let mut order: Vec<usize> = (0..boxes.len()).collect();
order.sort_by(|&a, &b| {
scores[b]
.partial_cmp(&scores[a])
.unwrap_or(std::cmp::Ordering::Equal)
});
let mut suppressed = vec![false; boxes.len()];
let mut kept = Vec::new();
for &idx in &order {
if suppressed[idx] {
continue;
}
kept.push(idx);
for &other in &order {
if suppressed[other] || other == idx {
continue;
}
let iou = rotated_iou(&boxes[idx], &boxes[other]);
if iou > iou_threshold {
suppressed[other] = true;
}
}
}
Ok(kept)
}
pub fn rotated_iou_matrix(boxes_a: &[RotatedBBox], boxes_b: &[RotatedBBox]) -> Vec<Vec<f64>> {
boxes_a
.iter()
.map(|a| boxes_b.iter().map(|b| rotated_iou(a, b)).collect())
.collect()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_rotated_iou_identical_boxes() {
let a = RotatedBBox::new(0.0, 0.0, 2.0, 3.0, 0.0);
let b = RotatedBBox::new(0.0, 0.0, 2.0, 3.0, 0.0);
let iou = rotated_iou(&a, &b);
assert!(
(iou - 1.0).abs() < 1e-9,
"identical boxes should have IoU=1.0, got {iou}"
);
}
#[test]
fn test_rotated_iou_identical_rotated() {
let angle = std::f64::consts::FRAC_PI_4; let a = RotatedBBox::new(0.0, 0.0, 2.0, 2.0, angle);
let b = RotatedBBox::new(0.0, 0.0, 2.0, 2.0, angle);
let iou = rotated_iou(&a, &b);
assert!(
(iou - 1.0).abs() < 1e-9,
"identical rotated boxes should have IoU=1.0, got {iou}"
);
}
#[test]
fn test_rotated_iou_non_overlapping() {
let a = RotatedBBox::new(0.0, 0.0, 1.0, 1.0, 0.0);
let b = RotatedBBox::new(100.0, 100.0, 1.0, 1.0, 0.0);
let iou = rotated_iou(&a, &b);
assert!(
iou.abs() < 1e-12,
"non-overlapping boxes should have IoU=0, got {iou}"
);
}
#[test]
fn test_rotated_iou_axis_aligned_matches_standard() {
let a = RotatedBBox::new(0.0, 0.0, 2.0, 2.0, 0.0);
let b = RotatedBBox::new(1.0, 0.0, 2.0, 2.0, 0.0);
let iou = rotated_iou(&a, &b);
let expected = 2.0 / 6.0;
assert!(
(iou - expected).abs() < 1e-6,
"axis-aligned IoU should be {expected}, got {iou}"
);
}
#[test]
fn test_rotated_iou_45_degree() {
let a = RotatedBBox::new(0.0, 0.0, 2.0, 2.0, 0.0);
let b = RotatedBBox::new(0.0, 0.0, 2.0, 2.0, std::f64::consts::FRAC_PI_4);
let iou = rotated_iou(&a, &b);
assert!(iou > 0.0, "45-degree overlap should be > 0, got {iou}");
assert!(iou < 1.0, "45-degree overlap should be < 1, got {iou}");
assert!(
iou > 0.3 && iou < 1.0,
"45-degree IoU should be reasonable, got {iou}"
);
}
#[test]
fn test_polygon_area_triangle() {
let triangle = vec![(0.0, 0.0), (4.0, 0.0), (0.0, 3.0)];
let area = polygon_area(&triangle);
assert!(
(area - 6.0).abs() < 1e-12,
"triangle area should be 6, got {area}"
);
}
#[test]
fn test_polygon_area_rectangle() {
let rect = vec![(0.0, 0.0), (5.0, 0.0), (5.0, 3.0), (0.0, 3.0)];
let area = polygon_area(&rect);
assert!(
(area - 15.0).abs() < 1e-12,
"rectangle area should be 15, got {area}"
);
}
#[test]
fn test_polygon_area_empty() {
let area = polygon_area(&[]);
assert!(area.abs() < 1e-12);
}
#[test]
fn test_polygon_area_line() {
let line = vec![(0.0, 0.0), (1.0, 1.0)];
let area = polygon_area(&line);
assert!(area.abs() < 1e-12, "line should have 0 area");
}
#[test]
fn test_rotated_bbox_corners_no_rotation() {
let bbox = RotatedBBox::new(1.0, 2.0, 4.0, 6.0, 0.0);
let corners = rotated_bbox_corners(&bbox);
assert!((corners[0].0 - 3.0).abs() < 1e-12);
assert!((corners[0].1 - 5.0).abs() < 1e-12);
assert!((corners[2].0 - (-1.0)).abs() < 1e-12);
assert!((corners[2].1 - (-1.0)).abs() < 1e-12);
}
#[test]
fn test_rotated_bbox_corners_90_degrees() {
let bbox = RotatedBBox::new(0.0, 0.0, 2.0, 4.0, std::f64::consts::FRAC_PI_2);
let corners = rotated_bbox_corners(&bbox);
assert!((corners[0].0 - (-2.0)).abs() < 1e-9);
assert!((corners[0].1 - 1.0).abs() < 1e-9);
}
#[test]
fn test_rotated_nms_removes_overlapping() {
let boxes = vec![
RotatedBBox::new(0.0, 0.0, 2.0, 2.0, 0.0),
RotatedBBox::new(0.1, 0.0, 2.0, 2.0, 0.0), RotatedBBox::new(10.0, 10.0, 2.0, 2.0, 0.0), ];
let scores = vec![0.9, 0.8, 0.7];
let kept = rotated_nms(&boxes, &scores, 0.5).expect("should succeed");
assert!(kept.contains(&0), "highest scoring box should be kept");
assert!(kept.contains(&2), "distant box should be kept");
assert!(
!kept.contains(&1),
"overlapping lower-score box should be suppressed"
);
}
#[test]
fn test_rotated_nms_empty() {
let kept = rotated_nms(&[], &[], 0.5).expect("should succeed");
assert!(kept.is_empty());
}
#[test]
fn test_rotated_nms_no_suppression() {
let boxes = vec![
RotatedBBox::new(0.0, 0.0, 1.0, 1.0, 0.0),
RotatedBBox::new(50.0, 50.0, 1.0, 1.0, 0.0),
];
let scores = vec![0.9, 0.8];
let kept = rotated_nms(&boxes, &scores, 0.5).expect("should succeed");
assert_eq!(kept.len(), 2, "non-overlapping boxes should all be kept");
}
#[test]
fn test_rotated_iou_matrix_shape() {
let a = vec![
RotatedBBox::new(0.0, 0.0, 1.0, 1.0, 0.0),
RotatedBBox::new(5.0, 5.0, 1.0, 1.0, 0.0),
];
let b = vec![
RotatedBBox::new(0.0, 0.0, 1.0, 1.0, 0.0),
RotatedBBox::new(2.0, 2.0, 1.0, 1.0, 0.0),
RotatedBBox::new(5.0, 5.0, 1.0, 1.0, 0.0),
];
let mat = rotated_iou_matrix(&a, &b);
assert_eq!(mat.len(), 2);
assert_eq!(mat[0].len(), 3);
assert!((mat[0][0] - 1.0).abs() < 1e-9);
assert!((mat[1][2] - 1.0).abs() < 1e-9);
assert!(mat[0][2].abs() < 1e-9);
}
#[test]
fn test_rotated_bbox_default() {
let bbox = RotatedBBox::default();
assert!((bbox.center_x).abs() < 1e-12);
assert!((bbox.width - 1.0).abs() < 1e-12);
assert!((bbox.area() - 1.0).abs() < 1e-12);
}
#[test]
fn test_rotated_nms_dimension_mismatch() {
let boxes = vec![RotatedBBox::default()];
let scores = vec![0.5, 0.3]; let result = rotated_nms(&boxes, &scores, 0.5);
assert!(result.is_err());
}
}