use std::f64::consts::PI;
use std::ops::{Add, Sub};
use wide::f64x4;
/// A 2-D point (equivalently, a 2-D vector) with `f64` coordinates.
///
/// Supports component-wise `+` and `-`, and a scalar 2-D cross product.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Point {
    /// Horizontal coordinate.
    pub x: f64,
    /// Vertical coordinate.
    pub y: f64,
}

impl Add for Point {
    type Output = Point;

    /// Component-wise sum of two points.
    fn add(self, other: Point) -> Point {
        Point::new(self.x + other.x, self.y + other.y)
    }
}

impl Sub for Point {
    type Output = Point;

    /// Component-wise difference `self - other`.
    fn sub(self, other: Point) -> Point {
        Point::new(self.x - other.x, self.y - other.y)
    }
}

impl Point {
    /// Creates a point from its two coordinates.
    pub fn new(x: f64, y: f64) -> Self {
        Point { x, y }
    }

    /// Returns the scalar 2-D cross product `self.x * other.y - self.y * other.x`.
    ///
    /// The result changes sign when the operands are swapped.
    pub fn cross(self, other: Point) -> f64 {
        self.x * other.y - self.y * other.x
    }
}
/// Computes the four corners of a rotated rectangle.
///
/// * `cx`, `cy` - centre of the rectangle.
/// * `w`, `h` - full width and height; half of each is used as the corner offset.
/// * `a` - rotation angle in degrees (converted to radians internally).
///
/// The corners are returned in the order of the unrotated centre-relative offsets
/// `(-w/2, -h/2)`, `(w/2, -h/2)`, `(w/2, h/2)`, `(-w/2, h/2)`, each rotated by the
/// angle and then translated by the centre.
pub fn cxcywha_to_points(cx: f64, cy: f64, w: f64, h: f64, a: f64) -> (Point, Point, Point, Point) {
    let theta = PI * a / 180.0;
    let sin_t = theta.sin();
    let cos_t = theta.cos();
    let half_w = w / 2.0;
    let half_h = h / 2.0;

    // Rotate a centre-relative offset by `theta`, then translate it to the centre.
    let corner = |off_x: f64, off_y: f64| {
        Point::new(
            cx + (off_x * cos_t - off_y * sin_t),
            cy + (off_x * sin_t + off_y * cos_t),
        )
    };

    (
        corner(-half_w, -half_h),
        corner(half_w, -half_h),
        corner(half_w, half_h),
        corner(-half_w, half_h),
    )
}
/// A rectangle stored as its four corner points.
#[derive(Clone, Copy)]
pub struct Rect {
    /// First corner, as produced by `cxcywha_to_points`.
    pub p1: Point,
    /// Second corner.
    pub p2: Point,
    /// Third corner.
    pub p3: Point,
    /// Fourth corner.
    pub p4: Point,
}

impl Rect {
    /// Builds a rectangle from centre `(cx, cy)`, size `(w, h)` and angle `a`.
    ///
    /// All arguments are forwarded unchanged to `cxcywha_to_points`, whose four
    /// results become `p1`..`p4` in order.
    pub fn new(cx: f64, cy: f64, w: f64, h: f64, a: f64) -> Self {
        let corners = cxcywha_to_points(cx, cy, w, h, a);
        Rect {
            p1: corners.0,
            p2: corners.1,
            p3: corners.2,
            p4: corners.3,
        }
    }

    /// Returns the corners as an array in the order `[p1, p2, p3, p4]`.
    pub fn points(&self) -> [Point; 4] {
        let Rect { p1, p2, p3, p4 } = *self;
        [p1, p2, p3, p4]
    }
}
/// An infinite line in implicit form `a * x + b * y + c = 0`.
#[derive(Debug)]
pub struct Line {
    /// Coefficient of `x`.
    pub a: f64,
    /// Coefficient of `y`.
    pub b: f64,
    /// Constant term.
    pub c: f64,
}

impl Line {
    /// Builds the line passing through `p1` and `p2`.
    ///
    /// The coefficients are `a = p2.y - p1.y`, `b = p1.x - p2.x` and
    /// `c = p2 x p1` (2-D cross product), so both points evaluate to zero.
    pub fn new(p1: Point, p2: Point) -> Self {
        Line {
            a: p2.y - p1.y,
            b: p1.x - p2.x,
            c: p2.x * p1.y - p2.y * p1.x,
        }
    }

    /// Evaluates `a * p.x + b * p.y + c`.
    ///
    /// The value is zero for points on the line; its sign tells on which side
    /// of the line `p` lies.
    pub fn call(&self, p: Point) -> f64 {
        self.a * p.x + self.b * p.y + self.c
    }

    /// Returns the point where `self` and `other` cross.
    ///
    /// No check is made for parallel lines: when the determinant
    /// `self.a * other.b - self.b * other.a` is zero the division yields
    /// non-finite coordinates.
    pub fn intersection(&self, other: &Line) -> Point {
        let det = self.a * other.b - self.b * other.a;
        let x_num = self.b * other.c - self.c * other.b;
        let y_num = self.c * other.a - self.a * other.c;
        Point::new(x_num / det, y_num / det)
    }
}
/// Maximum number of vertices a [`SoaPoly`] can hold.
// NOTE(review): presumably 8 because a quadrilateral clipped by four edges gains at
// most one vertex per clip - confirm.
const MAX_POLY_VERTS: usize = 8;

/// Small fixed-capacity polygon stored as separate x and y coordinate arrays.
///
/// Only the first `len` entries of `xs` and `ys` are meaningful.
struct SoaPoly {
    xs: [f64; MAX_POLY_VERTS],
    ys: [f64; MAX_POLY_VERTS],
    len: usize,
}

impl SoaPoly {
    /// Copies `count` values of `src` starting at `start` into a zero-padded 4-wide lane.
    #[inline]
    fn load_lane(src: &[f64; MAX_POLY_VERTS], start: usize, count: usize) -> [f64; 4] {
        let mut lane = [0.0f64; 4];
        lane[..count].copy_from_slice(&src[start..start + count]);
        lane
    }

    /// Builds a four-vertex polygon from the corners of `rect`, in `Rect::points` order.
    #[inline]
    fn from_rect(rect: &Rect) -> Self {
        let [c0, c1, c2, c3] = rect.points();
        let mut poly = SoaPoly {
            xs: [0.0; MAX_POLY_VERTS],
            ys: [0.0; MAX_POLY_VERTS],
            len: 4,
        };
        poly.xs[..4].copy_from_slice(&[c0.x, c1.x, c2.x, c3.x]);
        poly.ys[..4].copy_from_slice(&[c0.y, c1.y, c2.y, c3.y]);
        poly
    }

    /// Evaluates `line.a * x + line.b * y + line.c` for every stored vertex.
    ///
    /// Returns the per-vertex values (entries at index `len` and beyond are `0.0`)
    /// together with the vertex count.
    #[inline]
    fn line_values(&self, line: &Line) -> ([f64; MAX_POLY_VERTS], usize) {
        let count = self.len;
        let mut out = [0.0f64; MAX_POLY_VERTS];

        // Fewer than four vertices: not enough to fill one lane, use scalar code.
        if count < 4 {
            for ((slot, &x), &y) in out.iter_mut().zip(&self.xs).zip(&self.ys).take(count) {
                *slot = line.a * x + line.b * y + line.c;
            }
            return (out, count);
        }

        let a_lanes = f64x4::splat(line.a);
        let b_lanes = f64x4::splat(line.b);
        let c_lanes = f64x4::splat(line.c);
        let eval = |lane_x: [f64; 4], lane_y: [f64; 4]| -> [f64; 4] {
            (a_lanes * f64x4::from(lane_x) + b_lanes * f64x4::from(lane_y) + c_lanes).into()
        };

        // First four vertices fill one complete lane.
        let head = eval(
            Self::load_lane(&self.xs, 0, 4),
            Self::load_lane(&self.ys, 0, 4),
        );
        out[..4].copy_from_slice(&head);

        // Any remaining vertices go through a zero-padded lane; padding results are dropped.
        let tail = count - 4;
        if tail > 0 {
            let rest = eval(
                Self::load_lane(&self.xs, 4, tail),
                Self::load_lane(&self.ys, 4, tail),
            );
            out[4..4 + tail].copy_from_slice(&rest[..tail]);
        }

        (out, count)
    }

    /// Returns half the sum of `x[i] * y[i + 1] - y[i] * x[i + 1]` over all edges
    /// (indices wrap around), i.e. the signed shoelace area.
    ///
    /// Polygons with two or fewer vertices have area `0.0`.
    #[inline]
    fn area(&self) -> f64 {
        let n = self.len;
        if n <= 2 {
            return 0.0;
        }

        // Successor coordinates: entry `i` holds vertex `i + 1`, wrapping to vertex 0.
        let mut next_xs = [0.0f64; MAX_POLY_VERTS];
        let mut next_ys = [0.0f64; MAX_POLY_VERTS];
        next_xs[..n].copy_from_slice(&self.xs[..n]);
        next_ys[..n].copy_from_slice(&self.ys[..n]);
        next_xs[..n].rotate_left(1);
        next_ys[..n].rotate_left(1);

        // Too few vertices for a full lane: accumulate the cross terms one by one.
        if n < 4 {
            let twice_area = (0..n).fold(0.0f64, |acc, i| {
                acc + (self.xs[i] * next_ys[i] - self.ys[i] * next_xs[i])
            });
            return 0.5 * twice_area;
        }

        let cross = |x: [f64; 4], y: [f64; 4], xn: [f64; 4], yn: [f64; 4]| -> [f64; 4] {
            (f64x4::from(x) * f64x4::from(yn) - f64x4::from(y) * f64x4::from(xn)).into()
        };

        let head = cross(
            Self::load_lane(&self.xs, 0, 4),
            Self::load_lane(&self.ys, 0, 4),
            Self::load_lane(&next_xs, 0, 4),
            Self::load_lane(&next_ys, 0, 4),
        );
        let mut twice_area = 0.0f64;
        twice_area += head[0] + head[1] + head[2] + head[3];

        let tail = n - 4;
        if tail > 0 {
            let rest = cross(
                Self::load_lane(&self.xs, 4, tail),
                Self::load_lane(&self.ys, 4, tail),
                Self::load_lane(&next_xs, 4, tail),
                Self::load_lane(&next_ys, 4, tail),
            );
            twice_area += rest[..tail].iter().sum::<f64>();
        }

        0.5 * twice_area
    }
}
pub fn intersection_area(rect1: &Rect, rect2: &Rect) -> f64 {
let mut poly = SoaPoly::from_rect(rect1);
let r2_pts = rect2.points();
for edge_idx in 0..4 {
if poly.len <= 2 {
return 0.0;
}
let next_idx = (edge_idx + 1) & 3; let line = Line::new(r2_pts[edge_idx], r2_pts[next_idx]);
let (line_vals, n) = poly.line_values(&line);
let mut new_xs = [0.0f64; MAX_POLY_VERTS];
let mut new_ys = [0.0f64; MAX_POLY_VERTS];
let mut new_len: usize = 0;
for i in 0..n {
let next_i = if i + 1 < n { i + 1 } else { 0 };
let s_val = line_vals[i];
let t_val = line_vals[next_i];
if s_val <= 0.0 {
new_xs[new_len] = poly.xs[i];
new_ys[new_len] = poly.ys[i];
new_len += 1;
}
if s_val * t_val < 0.0 {
let s_pt = Point::new(poly.xs[i], poly.ys[i]);
let t_pt = Point::new(poly.xs[next_i], poly.ys[next_i]);
let intersection_pt = line.intersection(&Line::new(s_pt, t_pt));
new_xs[new_len] = intersection_pt.x;
new_ys[new_len] = intersection_pt.y;
new_len += 1;
}
}
poly.xs = new_xs;
poly.ys = new_ys;
poly.len = new_len;
}
poly.area()
}
/// Returns the axis-aligned bounds of `points` as `(min_x, min_y, max_x, max_y)`.
///
/// The accumulators start at `(f64::MAX, f64::MAX, f64::MIN, f64::MIN)`, so an empty
/// slice returns exactly those sentinel values.
pub fn minimal_bounding_rect(points: &[Point]) -> (f64, f64, f64, f64) {
    points.iter().fold(
        (f64::MAX, f64::MAX, f64::MIN, f64::MIN),
        |(lo_x, lo_y, hi_x, hi_y), p| (lo_x.min(p.x), lo_y.min(p.y), hi_x.max(p.x), hi_y.max(p.y)),
    )
}
/// Reports whether the axis-aligned bounding boxes of the two rectangles overlap.
///
/// Boxes that merely touch count as intersecting, because only a strict `<` gap on
/// either axis separates them.
pub fn envelopes_intersect(rect1: &Rect, rect2: &Rect) -> bool {
    let (a_min_x, a_min_y, a_max_x, a_max_y) = minimal_bounding_rect(&rect1.points());
    let (b_min_x, b_min_y, b_max_x, b_max_y) = minimal_bounding_rect(&rect2.points());

    let separated_on_x = a_max_x < b_min_x || b_max_x < a_min_x;
    let separated_on_y = a_max_y < b_min_y || b_max_y < a_min_y;

    !(separated_on_x || separated_on_y)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for area comparisons.
    ///
    /// Corner coordinates go through `sin`/`cos`, whose last-bit results are not
    /// guaranteed to be identical across platforms, so areas are not compared bit-exactly.
    const AREA_TOLERANCE: f64 = 1e-9;

    /// Asserts that `actual` is within `AREA_TOLERANCE` of `expected`.
    fn assert_area_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() <= AREA_TOLERANCE,
            "area {} differs from expected {} by more than {}",
            actual,
            expected,
            AREA_TOLERANCE
        );
    }

    #[test]
    fn test_rotated_intersection_normal_case() {
        let r1 = Rect::new(10., 15., 15., 10., 30.);
        let r2 = Rect::new(15., 15., 20., 10., 0.);
        assert_area_close(intersection_area(&r1, &r2), 110.17763185469022);
    }

    #[test]
    fn test_rotated_intersection_zero_intersection() {
        let r1 = Rect::new(10., 15., 15., 10., 30.);
        let r2 = Rect::new(150., 150., 20., 10., 0.);
        assert_area_close(intersection_area(&r1, &r2), 0.0);
    }

    #[test]
    fn test_rotated_intersection_max_intersection() {
        let r1 = Rect::new(150., 150., 20., 10., 0.);
        let r2 = Rect::new(150., 150., 20., 10., 0.);
        assert_area_close(intersection_area(&r1, &r2), 200.0);
    }

    #[test]
    fn test_intersection_axis_aligned_partial_overlap() {
        // Two 10x10 squares offset by (5, 5) share a 5x5 corner region.
        let r1 = Rect::new(0., 0., 10., 10., 0.);
        let r2 = Rect::new(5., 5., 10., 10., 0.);
        assert_area_close(intersection_area(&r1, &r2), 25.0);
    }

    #[test]
    fn test_envelopes_intersect_overlapping() {
        let r1 = Rect::new(10., 15., 15., 10., 30.);
        let r2 = Rect::new(15., 15., 20., 10., 0.);
        assert!(envelopes_intersect(&r1, &r2));
    }

    #[test]
    fn test_envelopes_intersect_non_overlapping() {
        let r1 = Rect::new(10., 15., 15., 10., 30.);
        let r2 = Rect::new(150., 150., 20., 10., 0.);
        assert!(!envelopes_intersect(&r1, &r2));
    }

    #[test]
    fn test_envelopes_intersect_identical() {
        let r1 = Rect::new(50., 50., 20., 10., 45.);
        let r2 = Rect::new(50., 50., 20., 10., 45.);
        assert!(envelopes_intersect(&r1, &r2));
    }

    #[test]
    fn test_envelopes_intersect_touching_edge() {
        // Boxes [0, 20] and [20, 40] on x share exactly one boundary.
        let r1 = Rect::new(10., 10., 20., 10., 0.);
        let r2 = Rect::new(30., 10., 20., 10., 0.);
        assert!(envelopes_intersect(&r1, &r2));
    }

    #[test]
    fn test_envelopes_intersect_rotated_far_apart() {
        let r1 = Rect::new(0., 0., 10., 10., 45.);
        let r2 = Rect::new(100., 100., 10., 10., 90.);
        assert!(!envelopes_intersect(&r1, &r2));
    }

    #[test]
    fn test_envelopes_intersect_rotated_overlapping() {
        let r1 = Rect::new(0., 0., 20., 10., 45.);
        let r2 = Rect::new(5., 5., 20., 10., -45.);
        assert!(envelopes_intersect(&r1, &r2));
    }
}