use std::f64::consts::PI;
/// An oriented rectangle: a centre, two full side lengths and a rotation angle.
///
/// The box's local axes are rotated from the world axes by `theta` radians using the
/// matrix `[[cos, -sin], [sin, cos]]` (counter-clockwise when the y axis points up).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RotatedBox {
    /// Centre x coordinate.
    pub cx: f64,
    /// Centre y coordinate.
    pub cy: f64,
    /// Full side length along the box's local x axis.
    pub w: f64,
    /// Full side length along the box's local y axis.
    pub h: f64,
    /// Rotation of the local axes relative to the world axes, in radians.
    pub theta: f64,
}

impl RotatedBox {
    /// Creates a box centred at `(cx, cy)` with sides `w` x `h`, rotated by `theta` radians.
    #[inline]
    pub fn new(cx: f64, cy: f64, w: f64, h: f64, theta: f64) -> Self {
        RotatedBox { cx, cy, w, h, theta }
    }

    /// Returns `w * h`. Rotation does not affect the area.
    ///
    /// The product is not clamped, so a negative side length yields a negative area.
    #[inline]
    pub fn area(&self) -> f64 {
        self.w * self.h
    }

    /// Returns the four corners in world coordinates.
    ///
    /// The corners are the local points `(+w/2, +h/2)`, `(-w/2, +h/2)`, `(-w/2, -h/2)` and
    /// `(+w/2, -h/2)`, in that order, rotated by `theta` and translated to the centre.
    /// For positive `w` and `h` this order winds counter-clockwise (positive signed area).
    /// The half-plane clipping used for intersections relies on that winding.
    pub fn corners(&self) -> [[f64; 2]; 4] {
        let (sin_t, cos_t) = self.theta.sin_cos();
        let (half_w, half_h) = (0.5 * self.w, 0.5 * self.h);
        [
            [half_w, half_h],
            [-half_w, half_h],
            [-half_w, -half_h],
            [half_w, -half_h],
        ]
        .map(|[lx, ly]| {
            [
                self.cx + cos_t * lx - sin_t * ly,
                self.cy + sin_t * lx + cos_t * ly,
            ]
        })
    }

    /// Returns `true` if the point `(px, py)` lies inside the box or on its boundary.
    pub fn contains(&self, px: f64, py: f64) -> bool {
        let (sin_t, cos_t) = self.theta.sin_cos();
        let (dx, dy) = (px - self.cx, py - self.cy);
        // Apply the inverse rotation to express the offset in the box's local frame.
        let u = cos_t * dx + sin_t * dy;
        let v = -sin_t * dx + cos_t * dy;
        u.abs() <= 0.5 * self.w && v.abs() <= 0.5 * self.h
    }
}
/// World-axis-aligned bounding box of `b`'s four corners, as `(min_x, min_y, max_x, max_y)`.
fn rotated_box_aabb(b: &RotatedBox) -> (f64, f64, f64, f64) {
    let start = (
        f64::INFINITY,
        f64::INFINITY,
        f64::NEG_INFINITY,
        f64::NEG_INFINITY,
    );
    // Single pass over the corners, widening all four extents at once.
    b.corners()
        .iter()
        .fold(start, |(lo_x, lo_y, hi_x, hi_y), &[x, y]| {
            (lo_x.min(x), lo_y.min(y), hi_x.max(x), hi_y.max(y))
        })
}
/// Smallest world-axis-aligned box containing the corners of both `a` and `b`,
/// as `(min_x, min_y, max_x, max_y)`.
fn enclosing_aabb(a: &RotatedBox, b: &RotatedBox) -> (f64, f64, f64, f64) {
    let first = rotated_box_aabb(a);
    let second = rotated_box_aabb(b);
    (
        first.0.min(second.0),
        first.1.min(second.1),
        first.2.max(second.2),
        first.3.max(second.3),
    )
}
/// Area of a simple polygon by the shoelace formula.
///
/// Either winding order is accepted because the absolute value is taken. Inputs with
/// fewer than three vertices have zero area.
fn polygon_area(vertices: &[[f64; 2]]) -> f64 {
    if vertices.len() < 3 {
        return 0.0;
    }
    // Pair each vertex with its successor, wrapping the last one back to the first.
    let successors = vertices.iter().cycle().skip(1);
    let twice_signed_area = vertices
        .iter()
        .zip(successors)
        .fold(0.0_f64, |acc, (p, q)| acc + p[0] * q[1] - q[0] * p[1]);
    0.5 * twice_signed_area.abs()
}
/// Returns `true` when `p` lies on or to the left of the directed line `p1 -> p2`.
///
/// "Left" means the 2-D cross product `(p2 - p1) x (p - p1)` is non-negative. For a
/// counter-clockwise polygon, that is the interior side of each edge.
#[inline]
fn is_inside_halfplane(p: [f64; 2], p1: [f64; 2], p2: [f64; 2]) -> bool {
    let [ox, oy] = p1;
    let (edge_x, edge_y) = (p2[0] - ox, p2[1] - oy);
    let (rel_x, rel_y) = (p[0] - ox, p[1] - oy);
    edge_x * rel_y - edge_y * rel_x >= 0.0
}
/// Intersection point of the infinite line through `a`, `b` with the infinite line
/// through `p1`, `p2`.
///
/// Returns `None` when the lines are treated as parallel, i.e. when the cross product of
/// their direction vectors has magnitude below `1e-12`. That threshold is absolute and
/// does not scale with the coordinates.
fn line_intersection(a: [f64; 2], b: [f64; 2], p1: [f64; 2], p2: [f64; 2]) -> Option<[f64; 2]> {
    let [ax, ay] = a;
    let (ux, uy) = (b[0] - ax, b[1] - ay);
    let (vx, vy) = (p2[0] - p1[0], p2[1] - p1[1]);
    let denom = ux * vy - uy * vx;
    let parallel = denom.abs() < 1e-12;
    if parallel {
        None
    } else {
        // Solve a + t*u = p1 + s*v for t by crossing both sides with v.
        let t = ((p1[0] - ax) * vy - (p1[1] - ay) * vx) / denom;
        Some([ax + t * ux, ay + t * uy])
    }
}
/// One Sutherland–Hodgman step: keeps the part of `polygon` lying on or to the left of
/// the directed line `p1 -> p2`.
///
/// Walks each edge `prev -> current`. If the edge crosses the line, the crossing point is
/// emitted. If `current` is on the kept side, it is then emitted as well. An empty input
/// yields an empty output.
fn clip_polygon_by_halfplane(polygon: &[[f64; 2]], p1: [f64; 2], p2: [f64; 2]) -> Vec<[f64; 2]> {
    let last = match polygon.last() {
        Some(&vertex) => vertex,
        None => return Vec::new(),
    };
    let mut kept = Vec::with_capacity(polygon.len() + 1);
    // The first edge closes the polygon: it runs from the last vertex to the first.
    let mut prev = last;
    let mut prev_inside = is_inside_halfplane(prev, p1, p2);
    for &current in polygon {
        let current_inside = is_inside_halfplane(current, p1, p2);
        if current_inside != prev_inside {
            // The edge prev -> current crosses the clip line.
            if let Some(crossing) = line_intersection(prev, current, p1, p2) {
                kept.push(crossing);
            }
        }
        if current_inside {
            kept.push(current);
        }
        prev = current;
        prev_inside = current_inside;
    }
    kept
}
/// Area of the overlap between `poly_a` and `poly_b`.
///
/// `poly_a` is clipped in turn against every directed edge of `poly_b`
/// (Sutherland–Hodgman). The result is exact when `poly_b` is convex and wound so that its
/// interior lies to the left of each edge. Returns `0.0` if either polygon is empty or the
/// clipped region vanishes.
fn polygon_intersection_area(poly_a: &[[f64; 2]], poly_b: &[[f64; 2]]) -> f64 {
    if poly_a.is_empty() || poly_b.is_empty() {
        return 0.0;
    }
    let edge_ends = poly_b.iter().cycle().skip(1);
    let mut clipped = poly_a.to_vec();
    for (&edge_start, &edge_end) in poly_b.iter().zip(edge_ends) {
        if clipped.is_empty() {
            // Nothing survived an earlier edge, so the polygons do not overlap.
            return 0.0;
        }
        clipped = clip_polygon_by_halfplane(&clipped, edge_start, edge_end);
    }
    polygon_area(&clipped)
}
/// Area of overlap between two rotated boxes.
///
/// Both boxes are converted to their four corners, and `box1`'s quadrilateral is clipped
/// against `box2`'s edges. Returns `0.0` when the boxes do not overlap.
pub fn rotated_box_intersection(box1: &RotatedBox, box2: &RotatedBox) -> f64 {
    // `corners()` returns a fixed-size array, which coerces to a slice directly. No heap
    // copy is needed, and this function sits in the O(n^2) loops of NMS and the IoU matrix.
    polygon_intersection_area(&box1.corners(), &box2.corners())
}
/// Intersection-over-union of two rotated boxes, clamped to `[0, 1]`.
///
/// Returns `0.0` when the boxes do not overlap or when the union is non-positive
/// (for example, for degenerate zero-area boxes).
pub fn rotated_iou(box1: &RotatedBox, box2: &RotatedBox) -> f64 {
    let overlap = rotated_box_intersection(box1, box2);
    let combined = box1.area() + box2.area() - overlap;
    if overlap <= 0.0 || combined <= 0.0 {
        0.0
    } else {
        (overlap / combined).clamp(0.0, 1.0)
    }
}
/// Area of the smallest rectangle that contains every point in `points` and whose sides
/// are parallel to the axes of a frame rotated by `angle` radians.
///
/// With `angle == 0.0` this is exactly the world-axis-aligned bounding-box area.
fn frame_aligned_enclosing_area(points: &[[f64; 2]], angle: f64) -> f64 {
    let cos_a = angle.cos();
    let sin_a = angle.sin();
    let (mut u_min, mut v_min) = (f64::INFINITY, f64::INFINITY);
    let (mut u_max, mut v_max) = (f64::NEG_INFINITY, f64::NEG_INFINITY);
    for &[x, y] in points {
        // Coordinates of the point along the rotated frame's two axes (inverse rotation).
        let u = cos_a * x + sin_a * y;
        let v = -sin_a * x + cos_a * y;
        u_min = u_min.min(u);
        u_max = u_max.max(u);
        v_min = v_min.min(v);
        v_max = v_max.max(v);
    }
    (u_max - u_min) * (v_max - v_min)
}

/// Generalized IoU of two rotated boxes, in `[-1, 1]`.
///
/// `GIoU = IoU - (|C| - |A ∪ B|) / |C|`, where `C` is a rectangle enclosing both boxes.
/// `C` is the smaller of two candidates:
/// - the bounding rectangle of all eight corners aligned with `box1`'s orientation;
/// - the same, aligned with `box2`'s orientation.
///
/// Because `C` follows the boxes' own orientations:
/// - the result is unchanged when both boxes are rotated together;
/// - identical boxes score 1 (up to rounding) at any angle.
///
/// When both boxes have `theta == 0.0`, `C` is the usual world-axis-aligned enclosing box.
///
/// Returns `-1.0` if the enclosing rectangle has non-positive area (degenerate input).
pub fn rotated_giou(box1: &RotatedBox, box2: &RotatedBox) -> f64 {
    let inter = rotated_box_intersection(box1, box2);
    let union = box1.area() + box2.area() - inter;

    let mut points = [[0.0_f64; 2]; 8];
    points[..4].copy_from_slice(&box1.corners());
    points[4..].copy_from_slice(&box2.corners());
    // Do not use the world-axis AABB here. It depends on the global orientation, so two
    // identical boxes at 45 degrees would get an enclosing area twice their own, and
    // GIoU would be ~0.5 instead of 1.
    let enc_area = frame_aligned_enclosing_area(&points, box1.theta)
        .min(frame_aligned_enclosing_area(&points, box2.theta));
    if enc_area <= 0.0 {
        return -1.0;
    }

    let iou_val = if union <= 0.0 {
        0.0
    } else {
        (inter / union).clamp(0.0, 1.0)
    };
    // enc_area > 0 is guaranteed by the guard above.
    iou_val - (enc_area - union) / enc_area
}
/// Distance-IoU of two rotated boxes: `IoU - rho^2 / c^2`.
///
/// - `rho` is the distance between the box centres.
/// - `c` is the diagonal of the world-axis-aligned rectangle enclosing both boxes' corners.
///
/// If `c^2 < 1e-15` (all corners essentially coincide), the plain IoU is returned.
///
/// Because `c` is measured in the world frame, rotating both boxes together can change
/// the penalty term.
pub fn rotated_diou(box1: &RotatedBox, box2: &RotatedBox) -> f64 {
    let centre_dist_sq = (box1.cx - box2.cx).powi(2) + (box1.cy - box2.cy).powi(2);
    let (min_x, min_y, max_x, max_y) = enclosing_aabb(box1, box2);
    let diag_sq = (max_x - min_x).powi(2) + (max_y - min_y).powi(2);

    let overlap = rotated_box_intersection(box1, box2);
    let combined = box1.area() + box2.area() - overlap;
    let iou = if combined <= 0.0 {
        0.0
    } else {
        (overlap / combined).clamp(0.0, 1.0)
    };

    if diag_sq < 1e-15 {
        iou
    } else {
        iou - centre_dist_sq / diag_sq
    }
}
/// Symmetric `n x n` matrix of pairwise [`rotated_iou`] values for `boxes`.
///
/// The diagonal is set to `1.0` directly rather than computed. For a degenerate
/// zero-area box this may differ from `rotated_iou(b, b)`. Each off-diagonal pair is
/// evaluated once, with the lower index as the first argument, and mirrored.
pub fn rotated_iou_matrix(boxes: &[RotatedBox]) -> Vec<Vec<f64>> {
    let count = boxes.len();
    let mut table = vec![vec![0.0_f64; count]; count];
    for (row, lhs) in boxes.iter().enumerate() {
        table[row][row] = 1.0;
        for (col, rhs) in boxes.iter().enumerate().skip(row + 1) {
            let overlap = rotated_iou(lhs, rhs);
            table[row][col] = overlap;
            table[col][row] = overlap;
        }
    }
    table
}
/// Greedy non-maximum suppression over rotated boxes.
///
/// Boxes are visited in descending score order; ties keep their input order because the
/// sort is stable. Each box not yet suppressed is kept, and it suppresses every
/// lower-ranked box whose [`rotated_iou`] with it is strictly greater than `iou_threshold`.
///
/// NaN scores rank below every other score, so they can never displace a real detection.
///
/// Returns the kept indices into `boxes`, highest score first. Returns an empty vector
/// when `boxes` is empty or when `boxes.len() != scores.len()`.
pub fn rotated_nms(boxes: &[RotatedBox], scores: &[f64], iou_threshold: f64) -> Vec<usize> {
    if boxes.len() != scores.len() || boxes.is_empty() {
        return Vec::new();
    }

    // `partial_cmp(..).unwrap_or(Equal)` on raw scores is not a total order when a score
    // is NaN. `sort_by` may then panic (Rust >= 1.81) or produce an arbitrary order.
    // Mapping NaN to -inf first makes every comparison well-defined.
    let rank = |i: usize| -> f64 {
        let s = scores[i];
        if s.is_nan() {
            f64::NEG_INFINITY
        } else {
            s
        }
    };
    let mut order: Vec<usize> = (0..boxes.len()).collect();
    order.sort_by(|&a, &b| {
        rank(b)
            .partial_cmp(&rank(a))
            .expect("NaN scores are mapped to -inf, so every comparison is defined")
    });

    let mut suppressed = vec![false; boxes.len()];
    let mut kept = Vec::new();
    for (pos, &idx) in order.iter().enumerate() {
        if suppressed[idx] {
            continue;
        }
        kept.push(idx);
        // Everything ranked above `idx` is already kept or suppressed, so only the
        // lower-ranked tail needs testing.
        for &other in &order[pos + 1..] {
            if !suppressed[other] && rotated_iou(&boxes[idx], &boxes[other]) > iou_threshold {
                suppressed[other] = true;
            }
        }
    }
    kept
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference IoU for two axis-aligned boxes given as `[x1, y1, x2, y2]`.
    fn axis_aligned_iou(a: [f64; 4], b: [f64; 4]) -> f64 {
        let [ax1, ay1, ax2, ay2] = a;
        let [bx1, by1, bx2, by2] = b;
        let overlap_w = (ax2.min(bx2) - ax1.max(bx1)).max(0.0);
        let overlap_h = (ay2.min(by2) - ay1.max(by1)).max(0.0);
        let inter = overlap_w * overlap_h;
        let union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - inter;
        if union <= 0.0 {
            0.0
        } else {
            inter / union
        }
    }

    #[test]
    fn test_axis_aligned_matches_standard_iou() {
        // Two 2x2 squares offset by (1, 1): overlap 1, union 7.
        let riou = rotated_iou(
            &RotatedBox::new(1.0, 1.0, 2.0, 2.0, 0.0),
            &RotatedBox::new(2.0, 2.0, 2.0, 2.0, 0.0),
        );
        let expected = axis_aligned_iou([0.0, 0.0, 2.0, 2.0], [1.0, 1.0, 3.0, 3.0]);
        assert!(
            (riou - expected).abs() < 1e-6,
            "rotated IoU (θ=0) should equal standard IoU {expected:.6}, got {riou:.6}"
        );
    }

    #[test]
    fn test_identical_boxes_iou_one() {
        let same = RotatedBox::new(3.0, 5.0, 4.0, 6.0, 0.7);
        let iou = rotated_iou(&same, &same);
        assert!(
            (iou - 1.0).abs() < 1e-9,
            "identical boxes should have IoU=1, got {iou}"
        );
    }

    #[test]
    fn test_non_overlapping_iou_zero() {
        let near = RotatedBox::new(0.0, 0.0, 1.0, 1.0, 0.0);
        let far = RotatedBox::new(100.0, 100.0, 1.0, 1.0, 0.0);
        let iou = rotated_iou(&near, &far);
        assert!(
            iou.abs() < 1e-12,
            "non-overlapping boxes should have IoU=0, got {iou}"
        );
    }

    #[test]
    fn test_square_90_degree_rotation_iou_one() {
        let upright = RotatedBox::new(0.0, 0.0, 2.0, 2.0, 0.0);
        let turned = RotatedBox {
            theta: PI / 2.0,
            ..upright
        };
        let iou = rotated_iou(&upright, &turned);
        assert!(
            (iou - 1.0).abs() < 1e-9,
            "90-degree rotation of a square should have IoU=1, got {iou}"
        );
    }

    #[test]
    fn test_area_preserved_under_rotation() {
        let tilted = RotatedBox::new(0.0, 0.0, 3.0, 4.0, PI / 4.0);
        let expected_area = 3.0 * 4.0;
        assert!(
            (tilted.area() - expected_area).abs() < 1e-12,
            "area should be w*h regardless of angle, got {}",
            tilted.area()
        );
    }

    #[test]
    fn test_giou_le_iou_and_negative_for_nonoverlap() {
        let near = RotatedBox::new(0.0, 0.0, 1.0, 1.0, 0.0);
        let far = RotatedBox::new(5.0, 5.0, 1.0, 1.0, 0.0);
        let iou_val = rotated_iou(&near, &far);
        let giou_val = rotated_giou(&near, &far);
        assert!(
            giou_val <= iou_val + 1e-9,
            "GIoU should be ≤ IoU; got GIoU={giou_val}, IoU={iou_val}"
        );
        assert!(
            giou_val < 0.0,
            "GIoU should be negative for non-overlapping boxes, got {giou_val}"
        );
    }

    #[test]
    fn test_nms_removes_overlapping() {
        // Box 1 is box 0 nudged by 0.1 (IoU ~0.9); box 2 is far from both.
        let boxes = [
            RotatedBox::new(0.0, 0.0, 2.0, 2.0, 0.0),
            RotatedBox::new(0.1, 0.0, 2.0, 2.0, 0.0),
            RotatedBox::new(20.0, 20.0, 2.0, 2.0, 0.0),
        ];
        let scores = [0.9, 0.8, 0.7];
        let kept = rotated_nms(&boxes, &scores, 0.5);
        assert!(kept.contains(&0), "highest-scoring box should be kept");
        assert!(kept.contains(&2), "distant box should be kept");
        assert!(
            !kept.contains(&1),
            "overlapping lower-score box should be suppressed"
        );
    }

    #[test]
    fn test_iou_matrix_symmetric() {
        let boxes = [
            RotatedBox::new(0.0, 0.0, 2.0, 2.0, 0.0),
            RotatedBox::new(1.0, 0.0, 2.0, 2.0, 0.3),
            RotatedBox::new(5.0, 5.0, 1.0, 3.0, PI / 6.0),
        ];
        let mat = rotated_iou_matrix(&boxes);
        for (i, row) in mat.iter().enumerate() {
            for (j, &forward) in row.iter().enumerate() {
                let backward = mat[j][i];
                assert!(
                    (forward - backward).abs() < 1e-12,
                    "IoU matrix should be symmetric: mat[{i}][{j}]={} but mat[{j}][{i}]={}",
                    forward,
                    backward
                );
            }
        }
    }

    #[test]
    fn test_corners_distinct_for_nondegenerate_box() {
        let corners = RotatedBox::new(1.0, 2.0, 3.0, 4.0, PI / 5.0).corners();
        for (i, p) in corners.iter().enumerate() {
            for (j, q) in corners.iter().enumerate().skip(i + 1) {
                let (dx, dy) = (p[0] - q[0], p[1] - q[1]);
                let dist = (dx * dx + dy * dy).sqrt();
                assert!(
                    dist > 1e-6,
                    "corners {i} and {j} should be distinct, distance={dist}"
                );
            }
        }
    }

    #[test]
    fn test_polygon_area_unit_square() {
        let area = polygon_area(&[[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]]);
        assert!(
            (area - 1.0).abs() < 1e-12,
            "unit square area should be 1.0, got {area}"
        );
    }

    #[test]
    fn test_polygon_area_triangle() {
        // Right triangle with legs 4 and 3.
        let area = polygon_area(&[[0.0, 0.0], [4.0, 0.0], [0.0, 3.0]]);
        assert!(
            (area - 6.0).abs() < 1e-12,
            "right triangle area should be 6, got {area}"
        );
    }

    #[test]
    fn test_contains_centre() {
        let rect = RotatedBox::new(5.0, 5.0, 4.0, 6.0, 0.0);
        assert!(rect.contains(5.0, 5.0), "centre should be inside");
    }

    #[test]
    fn test_contains_outside() {
        // Unit-radius square: corners at (±1, ±1).
        let square = RotatedBox::new(0.0, 0.0, 2.0, 2.0, 0.0);
        assert!(
            !square.contains(2.0, 2.0),
            "corner just outside should not be inside"
        );
        assert!(!square.contains(5.0, 0.0), "far point should not be inside");
    }

    #[test]
    fn test_contains_rotated() {
        // A 4x2 box turned by 90 degrees spans ±1 in x and ±2 in y.
        let upright = RotatedBox::new(0.0, 0.0, 4.0, 2.0, PI / 2.0);
        assert!(
            upright.contains(0.0, 1.5),
            "point within rotated extents should be inside"
        );
        assert!(
            !upright.contains(1.5, 0.0),
            "point outside rotated extents should be outside"
        );
    }

    #[test]
    fn test_nms_length_mismatch_returns_empty() {
        let kept = rotated_nms(&[RotatedBox::new(0.0, 0.0, 1.0, 1.0, 0.0)], &[0.9, 0.8], 0.5);
        assert!(
            kept.is_empty(),
            "mismatched lengths should return empty vec"
        );
    }

    #[test]
    fn test_diou_le_iou() {
        let base = RotatedBox::new(0.0, 0.0, 4.0, 4.0, 0.0);
        let shifted = RotatedBox::new(2.0, 0.0, 4.0, 4.0, 0.3);
        let iou_val = rotated_iou(&base, &shifted);
        let diou_val = rotated_diou(&base, &shifted);
        assert!(
            diou_val <= iou_val + 1e-9,
            "DIoU should be ≤ IoU; got DIoU={diou_val}, IoU={iou_val}"
        );
    }

    #[test]
    fn test_diou_identical_boxes() {
        let same = RotatedBox::new(1.0, 2.0, 3.0, 4.0, 0.5);
        let diou_val = rotated_diou(&same, &same);
        assert!(
            (diou_val - 1.0).abs() < 1e-9,
            "identical boxes should have DIoU=1, got {diou_val}"
        );
    }
}