use crate::utils::kalman::MeasurementVector;
use nalgebra::SVector;
/// Converts a `[left, top, width, height]` box into the
/// `[center_x, center_y, aspect, height]` measurement layout, where `aspect = width / height`.
///
/// Only the divisor of the aspect ratio is clamped to at least `1e-6`, so a zero height does
/// not divide by zero. The height component stored in the result is the raw input value.
pub fn tlwh_to_xyah(tlwh: &[f32; 4]) -> MeasurementVector {
    let [left, top, width, height] = *tlwh;
    let center_x = left + width / 2.0;
    let center_y = top + height / 2.0;
    // Guard only the divisor; the stored height stays exactly as given.
    let aspect = width / height.max(1e-6);
    MeasurementVector::from_vec(vec![center_x, center_y, aspect, height])
}
/// Converts a state vector whose first four components are
/// `[center_x, center_y, aspect, height]` back into a `[left, top, width, height]` box,
/// with `width = aspect * height`.
///
/// Any components beyond index 3 are ignored, so a longer state vector can be passed directly.
///
/// # Panics
///
/// Panics if `N < 4`, because indices 0 through 3 are read.
pub fn xyah_to_tlwh<const N: usize>(state: &SVector<f32, N>) -> [f32; 4] {
    let (center_x, center_y, aspect, height) = (state[0], state[1], state[2], state[3]);
    let width = aspect * height;
    [center_x - width / 2.0, center_y - height / 2.0, width, height]
}
/// Intersection-over-union of two axis-aligned boxes given as `[left, top, width, height]`.
///
/// For finite boxes with non-negative width and height the result lies in `[0.0, 1.0]`.
/// Returns `0.0` when the union area is not positive (e.g. both boxes are degenerate).
///
/// All areas are derived from the same corner coordinates used for the intersection.
/// Computing each box's area as `width * height` instead lets rounding in `left + width`
/// make the intersection exceed an individual area (visible at large coordinates). That
/// pushes the IoU above `1.0` and any `1.0 - iou` cost below zero.
pub fn iou(box1: &[f32; 4], box2: &[f32; 4]) -> f32 {
    // Corner form [x1, y1, x2, y2]; same arithmetic as `tlwh_to_tlbr`, inlined so this
    // function stands alone.
    let corners = |b: &[f32; 4]| [b[0], b[1], b[0] + b[2], b[1] + b[3]];
    let a = corners(box1);
    let b = corners(box2);

    // Overlap extents, clamped so disjoint boxes contribute zero intersection.
    let inter_w = (a[2].min(b[2]) - a[0].max(b[0])).max(0.0);
    let inter_h = (a[3].min(b[3]) - a[1].max(b[1])).max(0.0);
    let inter_area = inter_w * inter_h;

    // Areas from the corner extents, consistent with the intersection above.
    let area1 = (a[2] - a[0]) * (a[3] - a[1]);
    let area2 = (b[2] - b[0]) * (b[3] - b[1]);
    let union_area = area1 + area2 - inter_area;
    if union_area <= 0.0 {
        return 0.0;
    }
    inter_area / union_area
}
/// Converts a `[left, top, width, height]` box to corner form `[left, top, right, bottom]`,
/// where `right = left + width` and `bottom = top + height`.
pub fn tlwh_to_tlbr(tlwh: &[f32; 4]) -> [f32; 4] {
    let [left, top, width, height] = *tlwh;
    [left, top, left + width, top + height]
}
/// Pairwise IoU between every box in `bboxes1` and every box in `bboxes2`.
///
/// The matrix has one row per box in `bboxes1` and one column per box in `bboxes2`, so
/// `result[i][j] == iou(&bboxes1[i], &bboxes2[j])`. An empty `bboxes1` yields an empty
/// matrix; an empty `bboxes2` yields one empty row per box in `bboxes1`.
pub fn iou_batch(bboxes1: &[[f32; 4]], bboxes2: &[[f32; 4]]) -> Vec<Vec<f32>> {
    bboxes1
        .iter()
        .map(|row_box| {
            bboxes2
                .iter()
                .map(|col_box| iou(row_box, col_box))
                .collect()
        })
        .collect()
}
/// Builds an association cost matrix of `1.0 - IoU` with tracks as rows and detections as
/// columns: a perfect overlap costs `0.0` and disjoint boxes cost `1.0`.
///
/// An empty `tracks` yields an empty matrix; an empty `dets` yields one empty row per track.
pub fn iou_cost_matrix(tracks: &[[f32; 4]], dets: &[[f32; 4]]) -> Vec<Vec<f32>> {
    let mut cost = Vec::with_capacity(tracks.len());
    for track in tracks {
        let mut row = Vec::with_capacity(dets.len());
        for det in dets {
            row.push(1.0 - iou(track, det));
        }
        cost.push(row);
    }
    cost
}
// Unit tests for the box conversion and IoU helpers in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tlwh_to_tlbr() {
        let tlwh = [10.0, 20.0, 30.0, 40.0];
        let tlbr = tlwh_to_tlbr(&tlwh);
        assert_eq!(tlbr, [10.0, 20.0, 40.0, 60.0]);
    }

    #[test]
    fn test_iou_overlapping() {
        let box1 = [0.0, 0.0, 10.0, 10.0];
        let box2 = [5.0, 5.0, 10.0, 10.0];
        let val = iou(&box1, &box2);
        // 25 / (100 + 100 - 25)
        assert!((val - 0.142857).abs() < 1e-5);
    }

    #[test]
    fn test_iou_no_overlap() {
        let box1 = [0.0, 0.0, 10.0, 10.0];
        let box2 = [20.0, 20.0, 10.0, 10.0];
        let val = iou(&box1, &box2);
        assert_eq!(val, 0.0);
    }

    #[test]
    fn test_iou_degenerate_zero_area() {
        // Both boxes have zero area, so the union is zero and IoU falls back to 0.
        let box1 = [0.0, 0.0, 0.0, 10.0];
        let box2 = [0.0, 0.0, 10.0, 0.0];
        assert_eq!(iou(&box1, &box2), 0.0);
    }

    #[test]
    fn test_iou_contained() {
        let box1 = [0.0, 0.0, 100.0, 100.0];
        let box2 = [25.0, 25.0, 50.0, 50.0];
        let val = iou(&box1, &box2);
        assert_eq!(val, 0.25);
    }

    #[test]
    fn test_iou_batch() {
        let boxes1 = [[0.0, 0.0, 10.0, 10.0], [100.0, 100.0, 10.0, 10.0]];
        let boxes2 = [
            [0.0, 0.0, 10.0, 10.0],
            [5.0, 5.0, 10.0, 10.0],
            [200.0, 200.0, 10.0, 10.0],
        ];
        let ious = iou_batch(&boxes1, &boxes2);
        assert_eq!(ious.len(), 2);
        assert_eq!(ious[0].len(), 3);
        assert_eq!(ious[1].len(), 3);
        assert_eq!(ious[0][0], 1.0);
        assert!((ious[0][1] - 0.142857).abs() < 1e-4);
        assert_eq!(ious[0][2], 0.0);
        assert_eq!(ious[1][0], 0.0);
        assert_eq!(ious[1][1], 0.0);
        assert_eq!(ious[1][2], 0.0);
    }

    #[test]
    fn test_iou_batch_empty() {
        let boxes = [[0.0, 0.0, 10.0, 10.0]];
        let no_cols = iou_batch(&boxes, &[]);
        assert_eq!(no_cols.len(), 1);
        assert!(no_cols[0].is_empty());
        assert!(iou_batch(&[], &boxes).is_empty());
    }

    #[test]
    fn test_iou_cost_matrix() {
        let tracks = [[0.0, 0.0, 10.0, 10.0], [100.0, 100.0, 10.0, 10.0]];
        let dets = [[0.0, 0.0, 10.0, 10.0], [200.0, 200.0, 10.0, 10.0]];
        let cost = iou_cost_matrix(&tracks, &dets);
        assert_eq!(cost.len(), 2);
        assert_eq!(cost[0].len(), 2);
        assert_eq!(cost[0][0], 0.0);
        assert_eq!(cost[0][1], 1.0);
        assert_eq!(cost[1][0], 1.0);
        assert_eq!(cost[1][1], 1.0);
    }

    #[test]
    fn test_iou_cost_matrix_empty() {
        let tracks = [[0.0, 0.0, 10.0, 10.0]];
        let cost = iou_cost_matrix(&tracks, &[]);
        assert_eq!(cost.len(), 1);
        assert!(cost[0].is_empty());
        assert!(iou_cost_matrix(&[], &tracks).is_empty());
    }
}