/// Length of a measurement vector: `[cx, cy, aspect, height]` as produced by `BBox::to_xyah`.
const MEAS_DIM: usize = 4;
/// Length of the Kalman state: each measured quantity followed (after all of them)
/// by its per-step velocity.
const STATE_DIM: usize = 2 * MEAS_DIM;
/// Axis-aligned bounding box described by its minimum corner `(x1, y1)` and
/// maximum corner `(x2, y2)`.
///
/// [`BBox::new`] sorts the corners on each axis. The fields are public, so a
/// box built with a struct literal may still be inverted; [`BBox::width`],
/// [`BBox::height`] and [`BBox::area`] clamp such extents to zero.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BBox {
    /// Smaller x coordinate (after [`BBox::new`]).
    pub x1: f32,
    /// Smaller y coordinate (after [`BBox::new`]).
    pub y1: f32,
    /// Larger x coordinate (after [`BBox::new`]).
    pub x2: f32,
    /// Larger y coordinate (after [`BBox::new`]).
    pub y2: f32,
}

impl BBox {
    /// Builds a box from two arbitrary corners, ordering each axis so that
    /// `x1 <= x2` and `y1 <= y2`.
    #[must_use]
    pub fn new(x1: f32, y1: f32, x2: f32, y2: f32) -> Self {
        let (lo_x, hi_x) = (x1.min(x2), x1.max(x2));
        let (lo_y, hi_y) = (y1.min(y2), y1.max(y2));
        Self {
            x1: lo_x,
            y1: lo_y,
            x2: hi_x,
            y2: hi_y,
        }
    }

    /// Extent along x, never negative.
    #[must_use]
    pub fn width(&self) -> f32 {
        let span = self.x2 - self.x1;
        span.max(0.0)
    }

    /// Extent along y, never negative.
    #[must_use]
    pub fn height(&self) -> f32 {
        let span = self.y2 - self.y1;
        span.max(0.0)
    }

    /// Area of the box; zero for degenerate or inverted boxes.
    #[must_use]
    pub fn area(&self) -> f32 {
        self.width() * self.height()
    }

    /// Midpoint `(cx, cy)` of the box.
    #[must_use]
    pub fn center(&self) -> (f32, f32) {
        let mid = |a: f32, b: f32| (a + b) * 0.5;
        (mid(self.x1, self.x2), mid(self.y1, self.y2))
    }

    /// Converts to `[cx, cy, aspect, height]` with `aspect = width / height`.
    ///
    /// The height is floored at `1e-6` (also in the returned value) so the
    /// aspect ratio never divides by zero.
    #[must_use]
    pub fn to_xyah(&self) -> [f32; 4] {
        let (cx, cy) = self.center();
        let safe_h = self.height().max(1e-6);
        [cx, cy, self.width() / safe_h, safe_h]
    }

    /// Intersection-over-union of two boxes.
    ///
    /// Returns `0.0` when the boxes do not overlap or when the union is not
    /// positive; otherwise `intersection / union`.
    #[must_use]
    pub fn iou(&self, other: &Self) -> f32 {
        let overlap_w = (self.x2.min(other.x2) - self.x1.max(other.x1)).max(0.0);
        let overlap_h = (self.y2.min(other.y2) - self.y1.max(other.y1)).max(0.0);
        let overlap = overlap_w * overlap_h;
        if overlap == 0.0 {
            return 0.0;
        }
        let union = self.area() + other.area() - overlap;
        if union <= 0.0 {
            0.0
        } else {
            overlap / union
        }
    }

    /// Inverse of [`BBox::to_xyah`]: rebuilds corners from `[cx, cy, aspect, height]`.
    ///
    /// Aspect ratio and height are each floored at `1e-6` first, so the result
    /// always has strictly positive extent on both axes.
    #[must_use]
    fn from_xyah(xyah: &[f32; 4]) -> Self {
        let [cx, cy, aspect, height] = *xyah;
        let h = height.max(1e-6);
        let w = aspect.max(1e-6) * h;
        let (half_w, half_h) = (w * 0.5, h * 0.5);
        Self {
            x1: cx - half_w,
            y1: cy - half_h,
            x2: cx + half_w,
            y2: cy + half_h,
        }
    }
}
/// Element-wise sum of two flat arrays of the same length `N`.
fn mat_add_n<const N: usize>(a: &[f32; N], b: &[f32; N]) -> [f32; N] {
    let mut total = *a;
    for (dst, &rhs) in total.iter_mut().zip(b.iter()) {
        *dst += rhs;
    }
    total
}
/// Multiplies two `S x S` row-major matrices stored as flat arrays of length `SS`.
///
/// `SS` must equal `S * S`. The two are separate const parameters only because
/// `[f32; S * S]` cannot be spelled with a generic `S` on stable Rust. A mismatch
/// previously returned a partially filled matrix without complaint; it is now
/// caught by a debug assertion.
fn mat_mul_sq<const S: usize, const SS: usize>(a: &[f32; SS], b: &[f32; SS]) -> [f32; SS] {
    debug_assert_eq!(SS, S * S, "mat_mul_sq: SS must equal S * S");
    let mut out = [0.0f32; SS];
    for i in 0..S {
        let a_row = &a[i * S..(i + 1) * S];
        for j in 0..S {
            // Accumulate left to right from 0.0, matching a plain `sum += ...` loop.
            out[i * S + j] = a_row
                .iter()
                .enumerate()
                .fold(0.0f32, |acc, (k, &a_ik)| acc + a_ik * b[k * S + j]);
        }
    }
    out
}
/// Multiplies an `R x C` matrix by a `C x K` matrix (both flat, row-major) and
/// returns the `R x K` product as a flat row-major `Vec`.
///
/// # Panics
/// In debug builds, if `a.len() != R * C` or `b.len() != C * K`. In release
/// builds a too-short slice still panics on indexing, and extra trailing
/// elements are ignored.
fn mat_mul_rck<const R: usize, const C: usize, const K: usize>(a: &[f32], b: &[f32]) -> Vec<f32> {
    debug_assert_eq!(a.len(), R * C, "mat_mul_rck: lhs must be R x C");
    debug_assert_eq!(b.len(), C * K, "mat_mul_rck: rhs must be C x K");
    let mut out = vec![0.0f32; R * K];
    for i in 0..R {
        let a_row = &a[i * C..(i + 1) * C];
        for j in 0..K {
            // Accumulate left to right from 0.0, matching a plain `sum += ...` loop.
            out[i * K + j] = a_row
                .iter()
                .enumerate()
                .fold(0.0f32, |acc, (k, &a_ik)| acc + a_ik * b[k * K + j]);
        }
    }
    out
}
/// Transposes an `R x C` row-major matrix, returning the flat `C x R` result.
///
/// # Panics
/// In debug builds, if `a.len() != R * C`. In release builds a too-short slice
/// still panics on indexing.
fn mat_t<const R: usize, const C: usize>(a: &[f32]) -> Vec<f32> {
    debug_assert_eq!(a.len(), R * C, "mat_t: input must be R x C");
    // Output row `j` is input column `j`; emitting j-major / i-minor yields index `j * R + i`.
    (0..C)
        .flat_map(|j| (0..R).map(move |i| a[i * C + j]))
        .collect()
}
/// Multiplies an `R x C` row-major matrix (flat slice) by a length-`C` vector.
///
/// # Panics
/// In debug builds, if `a.len() != R * C`. In release builds a too-short slice
/// still panics on indexing.
fn mat_vec<const R: usize, const C: usize>(a: &[f32], x: &[f32; C]) -> [f32; R] {
    debug_assert_eq!(a.len(), R * C, "mat_vec: matrix must be R x C");
    let mut out = [0.0f32; R];
    for (i, slot) in out.iter_mut().enumerate() {
        let row = &a[i * C..(i + 1) * C];
        // Accumulate left to right from 0.0, matching a plain `s += ...` loop.
        *slot = row.iter().zip(x).fold(0.0f32, |acc, (&m, &v)| acc + m * v);
    }
    out
}
/// Inverts an `n x n` row-major matrix by Gauss-Jordan elimination with
/// partial (row) pivoting on the augmented matrix `[A | I]`.
///
/// Returns `None` as soon as the largest available pivot magnitude in a column
/// is below `1e-9`, i.e. the matrix is treated as singular.
/// NOTE(review): `1e-9` is an absolute `f32` threshold and does not scale with
/// the magnitude of `a`.
///
/// Reads `a[i * n + j]` for `i, j < n`, so it panics if `a.len() < n * n`;
/// any elements beyond `n * n` are ignored.
fn mat_inv_small(a: &[f32], n: usize) -> Option<Vec<f32>> {
// Augmented matrix, `n` rows of width `2n`: left half is `a`, right half is the identity.
let mut aug = vec![0.0f32; n * 2 * n];
for i in 0..n {
for j in 0..n {
aug[i * 2 * n + j] = a[i * n + j];
}
aug[i * 2 * n + n + i] = 1.0;
}
for i in 0..n {
// Partial pivoting: choose the row at or below `i` with the largest |value| in column `i`.
let mut max_r = i;
let mut max_v = aug[i * 2 * n + i].abs();
for k in (i + 1)..n {
let v = aug[k * 2 * n + i].abs();
if v > max_v {
max_v = v;
max_r = k;
}
}
// No usable pivot in this column: treat the matrix as singular.
if max_v < 1e-9 {
return None; }
if max_r != i {
for j in 0..2 * n {
aug.swap(i * 2 * n + j, max_r * 2 * n + j);
}
}
// Normalise the pivot row so the pivot becomes 1.
let pivot = aug[i * 2 * n + i];
for j in 0..2 * n {
aug[i * 2 * n + j] /= pivot;
}
// Eliminate column `i` from every other row (above and below), Gauss-Jordan style.
for k in 0..n {
if k != i {
let factor = aug[k * 2 * n + i];
for j in 0..2 * n {
let v = aug[i * 2 * n + j] * factor;
aug[k * 2 * n + j] -= v;
}
}
}
}
// The left half is now the identity; the right half holds the inverse.
let mut inv = vec![0.0f32; n * n];
for i in 0..n {
for j in 0..n {
inv[i * n + j] = aug[i * 2 * n + n + j];
}
}
Some(inv)
}
/// Constant-velocity transition matrix `F` (`STATE_DIM x STATE_DIM`, row-major).
///
/// It is the identity plus a `1` in row `i`, column `i + MEAS_DIM` for each
/// measured quantity `i`, so one step adds each velocity to its quantity and
/// leaves the velocities unchanged.
fn state_transition() -> [f32; STATE_DIM * STATE_DIM] {
    let mut f = [0.0f32; STATE_DIM * STATE_DIM];
    for row in 0..STATE_DIM {
        f[row * STATE_DIM + row] = 1.0;
        if row < MEAS_DIM {
            f[row * STATE_DIM + row + MEAS_DIM] = 1.0;
        }
    }
    f
}
/// Measurement matrix `H` (`MEAS_DIM x STATE_DIM`, row-major).
///
/// Row `i` selects state entry `i`, so `H x` yields the first `MEAS_DIM`
/// state entries and ignores the velocity half of the state.
fn meas_matrix() -> [f32; MEAS_DIM * STATE_DIM] {
    let mut h = [0.0f32; MEAS_DIM * STATE_DIM];
    h.chunks_exact_mut(STATE_DIM)
        .enumerate()
        .for_each(|(row, cells)| cells[row] = 1.0);
    h
}
/// Diagonal process-noise covariance `Q` (`STATE_DIM x STATE_DIM`, row-major).
///
/// The first four diagonal entries belong to the measured quantities and the
/// last four to their velocities. The aspect-ratio entries (index 2 and its
/// velocity at index 6) are given much smaller variances than the others.
fn process_noise_q() -> [f32; STATE_DIM * STATE_DIM] {
    const DIAG: [f32; STATE_DIM] = [1.0, 1.0, 0.01, 1.0, 0.01, 0.01, 1e-5, 0.01];
    let mut q = [0.0f32; STATE_DIM * STATE_DIM];
    for (i, &var) in DIAG.iter().enumerate() {
        // Diagonal entries of a row-major square matrix are `STATE_DIM + 1` apart.
        q[i * (STATE_DIM + 1)] = var;
    }
    q
}
/// Diagonal measurement-noise covariance `R` (`MEAS_DIM x MEAS_DIM`, row-major).
///
/// The aspect-ratio entry (index 2) is given a much smaller variance than the
/// centre and height entries.
fn meas_noise_r() -> [f32; MEAS_DIM * MEAS_DIM] {
    const DIAG: [f32; MEAS_DIM] = [1.0, 1.0, 0.01, 1.0];
    let mut r = [0.0f32; MEAS_DIM * MEAS_DIM];
    // Stepping by `MEAS_DIM + 1` visits exactly the diagonal of a row-major square matrix.
    r.iter_mut()
        .step_by(MEAS_DIM + 1)
        .zip(DIAG.iter())
        .for_each(|(slot, &var)| *slot = var);
    r
}
#[derive(Debug, Clone)]
pub struct KalmanTrack {
pub track_id: u32,
pub age: u32,
pub hits: u32,
pub time_since_update: u32,
state: [f32; STATE_DIM],
covariance: [f32; STATE_DIM * STATE_DIM],
}
impl KalmanTrack {
#[must_use]
pub fn new(bbox: BBox, track_id: u32) -> Self {
let xyah = bbox.to_xyah();
let mut state = [0.0f32; STATE_DIM];
state[0] = xyah[0];
state[1] = xyah[1];
state[2] = xyah[2];
state[3] = xyah[3];
let mut cov = [0.0f32; STATE_DIM * STATE_DIM];
let pos_var: [f32; 4] = [10.0, 10.0, 0.01, 10.0];
let vel_var: [f32; 4] = [100.0, 100.0, 0.0001, 100.0];
for i in 0..MEAS_DIM {
cov[i * STATE_DIM + i] = pos_var[i];
cov[(i + MEAS_DIM) * STATE_DIM + (i + MEAS_DIM)] = vel_var[i];
}
Self {
track_id,
age: 1,
hits: 1,
time_since_update: 0,
state,
covariance: cov,
}
}
pub fn predict(&mut self) -> BBox {
let f = state_transition();
let q = process_noise_q();
let new_state = mat_vec::<STATE_DIM, STATE_DIM>(&f, &self.state);
self.state = new_state;
let fp = mat_mul_sq::<STATE_DIM, { STATE_DIM * STATE_DIM }>(&f, &self.covariance);
let ft = {
let ft_v = mat_t::<STATE_DIM, STATE_DIM>(&f);
let mut arr = [0.0f32; STATE_DIM * STATE_DIM];
arr.copy_from_slice(&ft_v);
arr
};
let fpft = mat_mul_sq::<STATE_DIM, { STATE_DIM * STATE_DIM }>(&fp, &ft);
self.covariance = mat_add_n::<{ STATE_DIM * STATE_DIM }>(&fpft, &q);
self.age += 1;
self.time_since_update += 1;
self.bbox()
}
pub fn update(&mut self, bbox: BBox) {
let z = bbox.to_xyah();
let h = meas_matrix();
let r = meas_noise_r();
let hx = mat_vec::<MEAS_DIM, STATE_DIM>(&h, &self.state);
let mut innov = [0.0f32; MEAS_DIM];
for i in 0..MEAS_DIM {
innov[i] = z[i] - hx[i];
}
let hp = mat_mul_rck::<MEAS_DIM, STATE_DIM, STATE_DIM>(&h, &self.covariance);
let ht = mat_t::<MEAS_DIM, STATE_DIM>(&h);
let hpht = mat_mul_rck::<MEAS_DIM, STATE_DIM, MEAS_DIM>(&hp, &ht);
let mut s = [0.0f32; MEAS_DIM * MEAS_DIM];
for i in 0..(MEAS_DIM * MEAS_DIM) {
s[i] = hpht[i] + r[i];
}
let pht = mat_mul_rck::<STATE_DIM, STATE_DIM, MEAS_DIM>(&self.covariance, &ht);
let s_inv = match mat_inv_small(&s, MEAS_DIM) {
Some(inv) => inv,
None => {
self.state[0] = z[0];
self.state[1] = z[1];
self.state[2] = z[2];
self.state[3] = z[3];
self.hits += 1;
self.time_since_update = 0;
return;
}
};
let k = mat_mul_rck::<STATE_DIM, MEAS_DIM, MEAS_DIM>(&pht, &s_inv);
for i in 0..STATE_DIM {
let mut sum = 0.0f32;
for j in 0..MEAS_DIM {
sum += k[i * MEAS_DIM + j] * innov[j];
}
self.state[i] += sum;
}
let kh = mat_mul_rck::<STATE_DIM, MEAS_DIM, STATE_DIM>(&k, &h);
let mut i_kh = [0.0f32; STATE_DIM * STATE_DIM];
for i in 0..STATE_DIM {
i_kh[i * STATE_DIM + i] = 1.0;
}
for i in 0..(STATE_DIM * STATE_DIM) {
i_kh[i] -= kh[i];
}
let new_p_v = mat_mul_rck::<STATE_DIM, STATE_DIM, STATE_DIM>(&i_kh, &self.covariance);
self.covariance.copy_from_slice(&new_p_v);
self.hits += 1;
self.time_since_update = 0;
}
#[must_use]
pub fn bbox(&self) -> BBox {
let xyah = [self.state[0], self.state[1], self.state[2], self.state[3]];
BBox::from_xyah(&xyah)
}
#[must_use]
pub fn is_confirmed(&self, min_hits: u32) -> bool {
self.hits >= min_hits
}
#[must_use]
pub fn is_dead(&self, max_age: u32) -> bool {
self.time_since_update > max_age
}
}
/// Snapshot of one track as reported by `SortTrackerV2::update`.
#[derive(Debug, Clone)]
pub struct TrackedObject {
    /// Identifier of the underlying track; stable across frames.
    pub track_id: u32,
    /// Box estimate from the track's filter state.
    pub bbox: BBox,
    /// Heuristic score, computed by `SortTrackerV2::update` as
    /// `1 / (1 + frames since the track was last matched)`.
    pub confidence: f32,
    /// Age counter of the underlying track.
    pub age: u32,
    /// Whether the track had reached the tracker's `min_hits` threshold.
    pub is_confirmed: bool,
}
/// SORT-style multi-object tracker: Kalman prediction per track plus
/// IoU-based assignment of each frame's detections to tracks.
#[derive(Debug, Clone)]
pub struct SortTrackerV2 {
    tracks: Vec<KalmanTrack>,
    next_id: u32,
    /// A track is dropped once it goes more than `max_age` consecutive frames
    /// without a matched detection.
    pub max_age: u32,
    /// Number of matched detections after which a track counts as confirmed.
    pub min_hits: u32,
    /// Minimum IoU for a track/detection pair to be kept as a match. The pair cost
    /// is `1 - IoU` and is filtered against `1 - iou_threshold`. NOTE(review): whether
    /// the bound is inclusive depends on `filter_assignments_by_cost`.
    pub iou_threshold: f32,
}

impl SortTrackerV2 {
    /// Creates an empty tracker; track ids are handed out starting at 1.
    #[must_use]
    pub fn new(max_age: u32, min_hits: u32, iou_threshold: f32) -> Self {
        Self {
            tracks: Vec::new(),
            next_id: 1,
            max_age,
            min_hits,
            iou_threshold,
        }
    }

    /// Tracker with `max_age = 1`, `min_hits = 3`, `iou_threshold = 0.3`.
    #[must_use]
    pub fn default_params() -> Self {
        Self::new(1, 3, 0.3)
    }

    /// Processes one frame of detections and returns the tracks to report.
    ///
    /// The steps are:
    /// 1. Predict every existing track.
    /// 2. Match predictions to `detections` by IoU.
    /// 3. Update matched tracks.
    /// 4. Start a new track for every unmatched detection.
    /// 5. Drop tracks that have exceeded `max_age`.
    ///
    /// A track is reported if it is confirmed (`hits >= min_hits`) or was matched
    /// or created in this frame. Tentative tracks therefore appear with
    /// `is_confirmed == false`, and confirmed tracks keep being reported while
    /// coasting on predictions.
    pub fn update(&mut self, detections: &[BBox]) -> Vec<TrackedObject> {
        let predicted: Vec<BBox> = self.tracks.iter_mut().map(KalmanTrack::predict).collect();
        let (matched, _unmatched_tracks, unmatched_dets) = self.associate(&predicted, detections);

        for &(track_idx, det_idx) in &matched {
            self.tracks[track_idx].update(detections[det_idx]);
        }
        for &det_idx in &unmatched_dets {
            let id = self.next_id;
            self.next_id += 1;
            self.tracks.push(KalmanTrack::new(detections[det_idx], id));
        }

        let max_age = self.max_age;
        self.tracks.retain(|t| !t.is_dead(max_age));

        let min_hits = self.min_hits;
        self.tracks
            .iter()
            .filter(|t| t.is_confirmed(min_hits) || t.time_since_update == 0)
            .map(|t| TrackedObject {
                track_id: t.track_id,
                bbox: t.bbox(),
                confidence: 1.0 / (1.0 + t.time_since_update as f32),
                age: t.age,
                is_confirmed: t.is_confirmed(min_hits),
            })
            .collect()
    }

    /// Matches predicted track boxes to detections.
    ///
    /// Returns `(matched (track_idx, det_idx) pairs, unmatched track indices,
    /// unmatched detection indices)`. The cost matrix is `1 - IoU` (tracks as rows,
    /// detections as columns). It is solved by
    /// `crate::tracking::assignment::hungarian_algorithm`, and the resulting pairs
    /// are filtered at `1 - iou_threshold`.
    fn associate(
        &self,
        predicted: &[BBox],
        detections: &[BBox],
    ) -> (Vec<(usize, usize)>, Vec<usize>, Vec<usize>) {
        if predicted.is_empty() {
            return (Vec::new(), Vec::new(), (0..detections.len()).collect());
        }
        if detections.is_empty() {
            return (Vec::new(), (0..predicted.len()).collect(), Vec::new());
        }

        let n_d = detections.len();
        let cost: Vec<Vec<f64>> = predicted
            .iter()
            .map(|pb| {
                detections
                    .iter()
                    .map(|db| 1.0 - f64::from(pb.iou(db)))
                    .collect()
            })
            .collect();

        let assignments = crate::tracking::assignment::hungarian_algorithm(&cost);
        let filtered = crate::tracking::assignment::filter_assignments_by_cost(
            &assignments,
            &cost,
            1.0 - f64::from(self.iou_threshold),
        );

        let mut matched = Vec::new();
        let mut unmatched_tracks = Vec::new();
        let mut det_used = vec![false; n_d];
        for (t_idx, assignment) in filtered.iter().enumerate() {
            if let Some(d_idx) = assignment {
                matched.push((t_idx, *d_idx));
                det_used[*d_idx] = true;
            } else {
                unmatched_tracks.push(t_idx);
            }
        }
        let unmatched_dets: Vec<usize> = (0..n_d).filter(|&i| !det_used[i]).collect();
        (matched, unmatched_tracks, unmatched_dets)
    }

    /// All live tracks, confirmed or tentative.
    #[must_use]
    pub fn active_tracks(&self) -> Vec<&KalmanTrack> {
        self.tracks.iter().collect()
    }

    /// Number of live tracks, confirmed or tentative.
    #[must_use]
    pub fn track_count(&self) -> usize {
        self.tracks.len()
    }

    /// Drops every track and restarts id allocation at 1.
    pub fn reset(&mut self) {
        self.tracks.clear();
        self.next_id = 1;
    }
}
impl Default for SortTrackerV2 {
fn default() -> Self {
Self::default_params()
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for `BBox`, the matrix helpers, `KalmanTrack` and `SortTrackerV2`.
    use super::*;

    #[test]
    fn test_bbox_area_zero_for_degenerate() {
        let b = BBox::new(5.0, 5.0, 5.0, 5.0);
        assert_eq!(b.area(), 0.0);
    }

    #[test]
    fn test_bbox_area_positive() {
        let b = BBox::new(0.0, 0.0, 4.0, 3.0);
        assert!((b.area() - 12.0).abs() < 1e-5);
    }

    #[test]
    fn test_bbox_center() {
        let b = BBox::new(0.0, 0.0, 10.0, 10.0);
        let (cx, cy) = b.center();
        assert!((cx - 5.0).abs() < 1e-5);
        assert!((cy - 5.0).abs() < 1e-5);
    }

    #[test]
    fn test_bbox_iou_identical() {
        let b = BBox::new(0.0, 0.0, 10.0, 10.0);
        assert!((b.iou(&b) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_bbox_iou_no_overlap() {
        let a = BBox::new(0.0, 0.0, 5.0, 5.0);
        let b = BBox::new(10.0, 10.0, 15.0, 15.0);
        assert_eq!(a.iou(&b), 0.0);
    }

    #[test]
    fn test_bbox_iou_partial_overlap() {
        let a = BBox::new(0.0, 0.0, 10.0, 10.0);
        let b = BBox::new(5.0, 5.0, 15.0, 15.0);
        let iou = a.iou(&b);
        let expected = 25.0 / 175.0;
        assert!((iou - expected).abs() < 1e-4);
    }

    #[test]
    fn test_bbox_iou_symmetric() {
        let a = BBox::new(0.0, 0.0, 10.0, 10.0);
        let b = BBox::new(3.0, 4.0, 12.0, 9.0);
        assert!((a.iou(&b) - b.iou(&a)).abs() < 1e-6);
    }

    #[test]
    fn test_bbox_to_xyah_roundtrip() {
        let b = BBox::new(10.0, 20.0, 50.0, 80.0);
        let xyah = b.to_xyah();
        let b2 = BBox::from_xyah(&xyah);
        assert!((b.x1 - b2.x1).abs() < 1e-3);
        assert!((b.y1 - b2.y1).abs() < 1e-3);
        assert!((b.x2 - b2.x2).abs() < 1e-3);
        assert!((b.y2 - b2.y2).abs() < 1e-3);
    }

    #[test]
    fn test_bbox_new_clamps_order() {
        let b = BBox::new(10.0, 10.0, 5.0, 5.0);
        assert!(b.x1 <= b.x2);
        assert!(b.y1 <= b.y2);
    }

    #[test]
    fn test_mat_inv_small_known_inverse() {
        // [[4, 7], [2, 6]]^-1 = [[0.6, -0.7], [-0.2, 0.4]] (determinant 10).
        let inv = mat_inv_small(&[4.0, 7.0, 2.0, 6.0], 2).expect("matrix is invertible");
        let expected = [0.6, -0.7, -0.2, 0.4];
        for (got, want) in inv.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-5);
        }
    }

    #[test]
    fn test_mat_inv_small_singular_returns_none() {
        assert!(mat_inv_small(&[1.0, 2.0, 2.0, 4.0], 2).is_none());
    }

    #[test]
    fn test_kalman_track_new_bbox() {
        let bbox = BBox::new(100.0, 100.0, 200.0, 200.0);
        let track = KalmanTrack::new(bbox, 1);
        let estimated = track.bbox();
        let (cx, cy) = estimated.center();
        assert!((cx - 150.0).abs() < 1.0);
        assert!((cy - 150.0).abs() < 1.0);
    }

    #[test]
    fn test_kalman_track_stationary_predict_keeps_box() {
        let bbox = BBox::new(100.0, 100.0, 200.0, 200.0);
        let mut track = KalmanTrack::new(bbox, 7);
        let predicted = track.predict();
        let (cx, cy) = predicted.center();
        assert!((cx - 150.0).abs() < 1e-3);
        assert!((cy - 150.0).abs() < 1e-3);
        assert!((predicted.width() - 100.0).abs() < 1e-3);
        assert!((predicted.height() - 100.0).abs() < 1e-3);
    }

    #[test]
    fn test_kalman_track_predict_increments_age() {
        let bbox = BBox::new(0.0, 0.0, 50.0, 50.0);
        let mut track = KalmanTrack::new(bbox, 1);
        let age_before = track.age;
        track.predict();
        assert_eq!(track.age, age_before + 1);
    }

    #[test]
    fn test_kalman_track_update_resets_time_since_update() {
        let bbox = BBox::new(0.0, 0.0, 50.0, 50.0);
        let mut track = KalmanTrack::new(bbox, 1);
        track.predict();
        assert_eq!(track.time_since_update, 1);
        track.update(bbox);
        assert_eq!(track.time_since_update, 0);
    }

    #[test]
    fn test_kalman_track_is_dead() {
        let bbox = BBox::new(0.0, 0.0, 50.0, 50.0);
        let mut track = KalmanTrack::new(bbox, 42);
        assert!(!track.is_dead(1));
        track.predict();
        assert!(!track.is_dead(1));
        track.predict();
        assert!(track.is_dead(1));
    }

    #[test]
    fn test_kalman_track_is_confirmed() {
        let bbox = BBox::new(0.0, 0.0, 50.0, 50.0);
        let mut track = KalmanTrack::new(bbox, 1);
        assert!(!track.is_confirmed(3));
        track.update(bbox);
        assert!(!track.is_confirmed(3));
        track.update(bbox);
        assert!(track.is_confirmed(3));
    }

    #[test]
    fn test_sort_tracker_empty_detections() {
        let mut tracker = SortTrackerV2::new(1, 1, 0.3);
        let tracks = tracker.update(&[]);
        assert!(tracks.is_empty());
    }

    #[test]
    fn test_sort_tracker_single_detection() {
        let mut tracker = SortTrackerV2::new(5, 1, 0.3);
        let dets = vec![BBox::new(0.0, 0.0, 100.0, 100.0)];
        let tracks = tracker.update(&dets);
        assert_eq!(tracks.len(), 1);
        assert_eq!(tracks[0].track_id, 1);
    }

    #[test]
    fn test_sort_tracker_consistent_id_across_frames() {
        let mut tracker = SortTrackerV2::new(5, 1, 0.3);
        let bbox = BBox::new(100.0, 100.0, 150.0, 150.0);
        let t1 = tracker.update(&[bbox]);
        assert_eq!(t1.len(), 1);
        let id = t1[0].track_id;
        let t2 = tracker.update(&[bbox]);
        assert_eq!(t2.len(), 1);
        assert_eq!(t2[0].track_id, id);
    }

    #[test]
    fn test_sort_tracker_new_id_for_non_overlapping() {
        let mut tracker = SortTrackerV2::new(5, 1, 0.3);
        let bbox1 = BBox::new(0.0, 0.0, 50.0, 50.0);
        tracker.update(&[bbox1]);
        let bbox2 = BBox::new(500.0, 500.0, 550.0, 550.0);
        let tracks = tracker.update(&[bbox2]);
        let ids: Vec<u32> = tracks.iter().map(|t| t.track_id).collect();
        assert!(ids.contains(&2));
    }

    #[test]
    fn test_sort_tracker_dead_track_removed() {
        let mut tracker = SortTrackerV2::new(1, 1, 0.3);
        let bbox = BBox::new(0.0, 0.0, 50.0, 50.0);
        tracker.update(&[bbox]);
        tracker.update(&[]);
        tracker.update(&[]);
        let tracks = tracker.update(&[]);
        assert!(tracks.is_empty());
    }

    #[test]
    fn test_sort_tracker_track_count() {
        let mut tracker = SortTrackerV2::new(5, 1, 0.3);
        let dets = vec![
            BBox::new(0.0, 0.0, 50.0, 50.0),
            BBox::new(200.0, 200.0, 250.0, 250.0),
        ];
        tracker.update(&dets);
        assert_eq!(tracker.track_count(), 2);
    }

    #[test]
    fn test_sort_tracker_reset() {
        let mut tracker = SortTrackerV2::new(5, 1, 0.3);
        tracker.update(&[BBox::new(0.0, 0.0, 50.0, 50.0)]);
        tracker.reset();
        assert_eq!(tracker.track_count(), 0);
    }

    #[test]
    fn test_sort_tracker_active_tracks() {
        let mut tracker = SortTrackerV2::new(5, 1, 0.3);
        tracker.update(&[BBox::new(0.0, 0.0, 100.0, 100.0)]);
        assert_eq!(tracker.active_tracks().len(), 1);
    }
}