#![allow(dead_code)]
use std::collections::VecDeque;
use crate::detect::BoundingBox;
use crate::tracking::assignment::{filter_assignments_by_cost, hungarian_algorithm};
use crate::tracking::kalman::KalmanFilter;
/// Fixed-length (128-D) appearance embedding attached to a detection and
/// compared between frames with cosine distance.
#[derive(Debug, Clone)]
pub struct AppearanceFeature(pub [f32; 128]);

impl AppearanceFeature {
    /// Euclidean (L2) length of the embedding.
    #[must_use]
    pub fn norm(&self) -> f32 {
        let sum_sq: f32 = self.0.iter().map(|&x| x * x).sum();
        sum_sq.sqrt()
    }

    /// Returns a copy scaled to unit L2 length.
    ///
    /// The divisor is clamped to at least `1e-12`, so an all-zero embedding
    /// comes back as all zeros instead of NaNs.
    #[must_use]
    pub fn normalized(&self) -> Self {
        let divisor = self.norm().max(1e-12);
        let mut scaled = self.0;
        for component in scaled.iter_mut() {
            *component /= divisor;
        }
        Self(scaled)
    }

    /// Cosine distance `1 - cos θ` between two embeddings.
    ///
    /// The result is 0 for the same direction, 1 for orthogonal and 2 for
    /// opposite directions. Both norms are clamped to at least `1e-12`. For
    /// finite inputs where either embedding is all zeros, the dot product is
    /// zero and the distance is exactly `1.0`.
    #[must_use]
    pub fn cosine_distance(&self, other: &Self) -> f32 {
        let norm_a = self.norm().max(1e-12);
        let norm_b = other.norm().max(1e-12);
        let dot: f32 = self
            .0
            .iter()
            .zip(other.0.iter())
            .map(|(&a, &b)| a * b)
            .sum();
        1.0 - dot / (norm_a * norm_b)
    }
}

impl Default for AppearanceFeature {
    /// The all-zero embedding.
    ///
    /// Implemented by hand because std only provides `Default` for arrays of
    /// up to 32 elements, so `#[derive(Default)]` cannot be used here.
    fn default() -> Self {
        Self([0.0; 128])
    }
}
/// Plain-data copy of a track's Kalman filter state.
///
/// It holds the 8-element state mean and the 8×8 covariance flattened into
/// 64 values. The flattening is presumably row-major, matching the
/// `row * 8 + col` indexing used when the filter matrices are built. TODO
/// confirm against `KalmanFilter`.
#[derive(Debug, Clone)]
pub struct KalmanState {
    /// State mean. When it comes from a track's filter, entries 0..4 are
    /// `[cx, cy, aspect, height]` and entries 4..8 their velocities.
    pub mean: Vec<f64>,
    /// Flattened 8×8 state covariance (64 values).
    pub covariance: Vec<f64>,
}

impl KalmanState {
    /// Dimension of the DeepSORT state vector.
    const DIM: usize = 8;

    /// Snapshot whose mean and covariance are all zeros.
    fn new() -> Self {
        let dim = Self::DIM;
        Self {
            covariance: vec![0.0; dim * dim],
            mean: vec![0.0; dim],
        }
    }
}
/// Maximum number of appearance embeddings kept per track. Once the history
/// is full, the oldest embedding is evicted before a new one is stored.
const MAX_FEATURE_HISTORY: usize = 100;

/// One tracked object.
///
/// It holds the latest matched box, the Kalman filter that predicts the
/// object's motion, and a rolling history of appearance embeddings.
#[derive(Debug, Clone)]
pub struct DeepSortTrack {
    /// Identifier assigned by the tracker when the track is created.
    pub id: u64,
    /// Box of the most recently matched detection (not the Kalman prediction).
    pub bbox: BoundingBox,
    /// Copy of the filter's mean and covariance, refreshed after every
    /// `predict` and `update`.
    pub kalman_state: KalmanState,
    /// Recent appearance embeddings, oldest first, with at most
    /// `MAX_FEATURE_HISTORY` entries.
    pub features: VecDeque<AppearanceFeature>,
    /// Frames the track has existed: 1 on creation, +1 per predict.
    pub age: u32,
    /// Total matched detections, including the one that created the track.
    pub hits: u32,
    /// Consecutive matched frames; the tracker resets it to 0 on a miss.
    pub hit_streak: u32,
    /// Predict steps since the last matched detection (0 right after a match).
    pub time_since_update: u32,
    kalman: KalmanFilter,
}

impl DeepSortTrack {
    /// Copies `kalman`'s current mean and covariance into a plain snapshot.
    fn snapshot(kalman: &KalmanFilter) -> KalmanState {
        KalmanState {
            mean: kalman.state().to_vec(),
            covariance: kalman.covariance().to_vec(),
        }
    }

    /// Starts a new track from its first detection and embedding.
    ///
    /// # Panics
    /// Panics if `set_state` returns an error. This is presumably only
    /// possible on a dimension mismatch, which the fixed 8-element state from
    /// `bbox_to_kalman_state` rules out.
    fn new(id: u64, bbox: &BoundingBox, feature: &AppearanceFeature) -> Self {
        let mut kalman = build_deep_sort_kalman();
        kalman
            .set_state(bbox_to_kalman_state(bbox))
            .expect("DeepSORT state dim is always 8");
        let kalman_state = Self::snapshot(&kalman);
        let mut features = VecDeque::with_capacity(MAX_FEATURE_HISTORY);
        features.push_back(feature.clone());
        Self {
            id,
            bbox: *bbox,
            kalman_state,
            features,
            age: 1,
            hits: 1,
            hit_streak: 1,
            time_since_update: 0,
            kalman,
        }
    }

    /// Propagates the filter one frame and ages the track.
    ///
    /// It also refreshes `kalman_state` and returns the predicted box. `bbox`
    /// itself is left unchanged.
    fn predict(&mut self) -> BoundingBox {
        let predicted = kalman_state_to_bbox(&self.kalman.predict());
        // Keep the public snapshot in sync with the filter, so readers of
        // `kalman_state` see the predicted mean/covariance rather than the
        // state from the last update.
        self.kalman_state = Self::snapshot(&self.kalman);
        self.age += 1;
        self.time_since_update += 1;
        predicted
    }

    /// Folds a matched detection into the track.
    ///
    /// It corrects the filter with `bbox`, refreshes the snapshot, bumps the
    /// hit counters, and stores `feature`. The oldest embedding is evicted
    /// once `MAX_FEATURE_HISTORY` is reached.
    fn update(&mut self, bbox: &BoundingBox, feature: &AppearanceFeature) {
        let measurement = bbox_to_measurement(bbox);
        // Best-effort: the filter's update result is discarded, and the
        // snapshot below reflects whatever state the filter holds afterwards.
        let _ = self.kalman.update(&measurement);
        self.kalman_state = Self::snapshot(&self.kalman);
        self.bbox = *bbox;
        self.hits += 1;
        self.hit_streak += 1;
        self.time_since_update = 0;
        if self.features.len() >= MAX_FEATURE_HISTORY {
            self.features.pop_front();
        }
        self.features.push_back(feature.clone());
    }

    /// Smallest cosine distance between `query` and any stored embedding.
    ///
    /// Returns `f32::MAX` if the history is empty. `new` always stores one
    /// embedding, and `update` only evicts right before pushing.
    fn min_appearance_distance(&self, query: &AppearanceFeature) -> f32 {
        self.features
            .iter()
            .map(|f| f.cosine_distance(query))
            .fold(f32::MAX, f32::min)
    }
}
/// Tuning parameters for [`DeepSortTracker`].
#[derive(Debug, Clone)]
pub struct DeepSortConfig {
    /// Frames a track may go without a matched detection before it is
    /// dropped. Tracks are kept while `time_since_update <= max_age`.
    pub max_age: u32,
    /// Hit count used by the tracker's output filter to decide which tracks
    /// are reported.
    pub min_hits: u32,
    /// IoU-distance (`1 - IoU`) limit. It is blended into the combined
    /// matching threshold with weight `1 - appearance_weight`.
    pub max_iou_distance: f64,
    /// Cosine-distance limit. It is blended into the combined matching
    /// threshold with weight `appearance_weight`.
    pub max_appearance_distance: f32,
    /// Weight λ of the appearance cost; the IoU cost gets `1 - λ`.
    pub appearance_weight: f32,
}

impl Default for DeepSortConfig {
    fn default() -> Self {
        DeepSortConfig {
            appearance_weight: 0.5,
            max_appearance_distance: 0.4,
            max_iou_distance: 0.7,
            min_hits: 3,
            max_age: 30,
        }
    }
}
#[derive(Debug)]
pub struct DeepSortTracker {
cfg: DeepSortConfig,
tracks: Vec<DeepSortTrack>,
next_id: u64,
}
impl DeepSortTracker {
#[must_use]
pub fn new(cfg: DeepSortConfig) -> Self {
Self {
cfg,
tracks: Vec::new(),
next_id: 1,
}
}
#[must_use]
pub fn update(
&mut self,
detections: &[BoundingBox],
features: &[AppearanceFeature],
) -> Vec<DeepSortTrack> {
let predicted_bboxes: Vec<BoundingBox> =
self.tracks.iter_mut().map(|t| t.predict()).collect();
let n_tracks = self.tracks.len();
let n_dets = detections.len();
let cost_matrix: Vec<Vec<f64>> = if n_tracks == 0 || n_dets == 0 {
vec![]
} else {
let lam = self.cfg.appearance_weight as f64;
let mut mat = vec![vec![0.0_f64; n_dets]; n_tracks];
for (ti, pred_bb) in predicted_bboxes.iter().enumerate() {
for (di, det_bb) in detections.iter().enumerate() {
let iou_cost = iou_distance(pred_bb, det_bb);
let app_feat = features.get(di).cloned().unwrap_or_default();
let app_cost = self.tracks[ti].min_appearance_distance(&app_feat) as f64;
mat[ti][di] = lam * app_cost + (1.0 - lam) * iou_cost;
}
}
mat
};
let max_cost = self.cfg.appearance_weight as f64 * self.cfg.max_appearance_distance as f64
+ (1.0 - self.cfg.appearance_weight as f64) * self.cfg.max_iou_distance;
let raw_assignments = if cost_matrix.is_empty() {
vec![None; n_tracks]
} else {
let raw = hungarian_algorithm(&cost_matrix);
filter_assignments_by_cost(&raw, &cost_matrix, max_cost)
};
let mut det_matched = vec![false; n_dets];
for (ti, assignment) in raw_assignments.iter().enumerate() {
if let Some(di) = assignment {
let feat = features.get(*di).cloned().unwrap_or_default();
self.tracks[ti].update(&detections[*di], &feat);
det_matched[*di] = true;
} else {
self.tracks[ti].hit_streak = 0;
}
}
for (di, det) in detections.iter().enumerate() {
if !det_matched[di] {
let feat = features.get(di).cloned().unwrap_or_default();
let new_track = DeepSortTrack::new(self.next_id, det, &feat);
self.next_id += 1;
self.tracks.push(new_track);
}
}
self.tracks
.retain(|t| t.time_since_update <= self.cfg.max_age);
self.tracks
.iter()
.filter(|t| t.hits >= self.cfg.min_hits || t.time_since_update == 0)
.cloned()
.collect()
}
#[must_use]
pub fn all_tracks(&self) -> &[DeepSortTrack] {
&self.tracks
}
#[must_use]
pub fn track_count(&self) -> usize {
self.tracks.len()
}
}
/// Builds the constant-velocity Kalman filter shared by every DeepSORT track.
///
/// The state (8) is `[cx, cy, aspect, height]` followed by their per-frame
/// velocities. The measurement (4) is `[cx, cy, aspect, height]`. Every
/// matrix is stored flat and indexed here as `row * cols + col`, presumably
/// row-major as `KalmanFilter` expects. TODO confirm.
fn build_deep_sort_kalman() -> KalmanFilter {
    // Square `n x n` matrix (flattened) with `diagonal` on its main diagonal.
    fn diag(diagonal: &[f64]) -> Vec<f64> {
        let n = diagonal.len();
        let mut out = vec![0.0_f64; n * n];
        for (i, &value) in diagonal.iter().enumerate() {
            out[i * n + i] = value;
        }
        out
    }

    const STATE: usize = 8;
    const OBS: usize = 4;
    const DT: f64 = 1.0;

    let mut kf = KalmanFilter::new(8, 4);

    // F: identity plus `DT` coupling each observed quantity to its velocity.
    let mut transition = diag(&[1.0; STATE]);
    for i in 0..OBS {
        transition[i * STATE + OBS + i] = DT;
    }
    kf.transition = transition;

    // H: selects the first four state entries.
    let mut measurement = vec![0.0_f64; OBS * STATE];
    for i in 0..OBS {
        measurement[i * STATE + i] = 1.0;
    }
    kf.measurement = measurement;

    // Q: 1.0 on the observed quantities, 0.01 on their velocities.
    kf.process_noise = diag(&[1.0, 1.0, 1.0, 1.0, 0.01, 0.01, 0.01, 0.01]);
    // R: the aspect-ratio measurement is given much lower noise than the rest.
    kf.measurement_noise = diag(&[1.0, 1.0, 0.01, 1.0]);
    // P0: large initial uncertainty on velocities, except aspect velocity.
    kf.covariance = diag(&[10.0, 10.0, 1e-2, 10.0, 1e4, 1e4, 1e-5, 1e4]);
    kf
}
/// Converts a box into the initial 8-element filter state.
///
/// The state is `[cx, cy, aspect, height, 0, 0, 0, 0]`, with all velocities
/// starting at zero. Height is clamped to at least 1 before it is used for
/// the aspect ratio and the height entry. The centre uses the raw height.
fn bbox_to_kalman_state(bbox: &BoundingBox) -> Vec<f64> {
    let width = f64::from(bbox.width);
    let clamped_h = f64::from(bbox.height.max(1.0));
    let center_x = f64::from(bbox.x) + width / 2.0;
    let center_y = f64::from(bbox.y) + f64::from(bbox.height) / 2.0;
    let mut state = vec![0.0; 8];
    state[..4].copy_from_slice(&[center_x, center_y, width / clamped_h, clamped_h]);
    state
}
/// Converts a filter state back into a box.
///
/// Only the first four entries `[cx, cy, aspect, height]` are used. Aspect is
/// clamped to at least `1e-6` and height to at least 1, so the box is never
/// degenerate. A state shorter than four entries yields the unit box at the
/// origin.
fn kalman_state_to_bbox(state: &[f64]) -> BoundingBox {
    match *state {
        [cx, cy, aspect, height, ..] => {
            let height = height.max(1.0);
            let width = aspect.max(1e-6) * height;
            BoundingBox::new(
                (cx - width / 2.0) as f32,
                (cy - height / 2.0) as f32,
                width as f32,
                height as f32,
            )
        }
        _ => BoundingBox::new(0.0, 0.0, 1.0, 1.0),
    }
}
/// Converts a box into the 4-element measurement `[cx, cy, aspect, height]`.
///
/// Height is clamped to at least 1 for the aspect ratio and the height entry,
/// so `width / height` never divides by zero. The centre uses the raw height.
fn bbox_to_measurement(bbox: &BoundingBox) -> Vec<f64> {
    let clamped_h = f64::from(bbox.height.max(1.0));
    let half_w = f64::from(bbox.width) / 2.0;
    let half_h = f64::from(bbox.height) / 2.0;
    vec![
        f64::from(bbox.x) + half_w,
        f64::from(bbox.y) + half_h,
        f64::from(bbox.width) / clamped_h,
        clamped_h,
    ]
}
/// IoU distance `1 - IoU` between two axis-aligned boxes.
///
/// Each box is given as `(x, y, width, height)`, with `(x, y)` its minimum
/// corner. Returns `1.0` when the union area is not positive, e.g. for two
/// empty boxes.
fn iou_distance(a: &BoundingBox, b: &BoundingBox) -> f64 {
    let overlap_w = ((a.x + a.width).min(b.x + b.width) - a.x.max(b.x)).max(0.0);
    let overlap_h = ((a.y + a.height).min(b.y + b.height) - a.y.max(b.y)).max(0.0);
    let intersection = overlap_w * overlap_h;
    let union = a.width * a.height + b.width * b.height - intersection;
    if union <= 0.0 {
        return 1.0;
    }
    1.0 - f64::from(intersection / union)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    fn make_bbox(x: f32, y: f32, w: f32, h: f32) -> BoundingBox {
        BoundingBox::new(x, y, w, h)
    }

    fn zero_feat() -> AppearanceFeature {
        AppearanceFeature([0.0f32; 128])
    }

    // Every component set to `val`; only unit length after `normalized()`.
    fn unit_feat(val: f32) -> AppearanceFeature {
        AppearanceFeature([val; 128])
    }

    // Default config with `min_hits` lowered to 1. Struct-update syntax avoids
    // clippy's `field_reassign_with_default`.
    fn cfg_min_hits_1() -> DeepSortConfig {
        DeepSortConfig {
            min_hits: 1,
            ..DeepSortConfig::default()
        }
    }

    #[test]
    fn test_feature_norm_zero() {
        let f = zero_feat();
        assert!(f.norm().abs() < 1e-6);
    }

    #[test]
    fn test_feature_norm_unit() {
        let mut arr = [0.0f32; 128];
        arr[0] = 1.0;
        let f = AppearanceFeature(arr);
        assert!((f.norm() - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_feature_normalized_unit_length() {
        let f = unit_feat(1.0).normalized();
        assert!((f.norm() - 1.0).abs() < 1e-4);
    }

    #[test]
    fn test_cosine_distance_identical() {
        let f = unit_feat(1.0);
        let d = f.cosine_distance(&f);
        assert!(d.abs() < 1e-4, "identical features should have distance ~0");
    }

    #[test]
    fn test_cosine_distance_orthogonal() {
        let mut a = [0.0f32; 128];
        let mut b = [0.0f32; 128];
        a[0] = 1.0;
        b[1] = 1.0;
        let fa = AppearanceFeature(a);
        let fb = AppearanceFeature(b);
        let d = fa.cosine_distance(&fb);
        assert!(
            (d - 1.0).abs() < 1e-4,
            "orthogonal features dist should be ~1"
        );
    }

    #[test]
    fn test_cosine_distance_zero_vector() {
        let fa = zero_feat();
        let fb = unit_feat(1.0);
        let d = fa.cosine_distance(&fb);
        assert!(d.is_finite());
    }

    #[test]
    fn test_bbox_to_kalman_round_trip() {
        let bbox = make_bbox(10.0, 20.0, 60.0, 80.0);
        let state = bbox_to_kalman_state(&bbox);
        assert_eq!(state.len(), 8);
        let back = kalman_state_to_bbox(&state);
        assert!((back.x - bbox.x).abs() < 1.0);
        assert!((back.y - bbox.y).abs() < 1.0);
        assert!((back.width - bbox.width).abs() < 1.0);
        assert!((back.height - bbox.height).abs() < 1.0);
    }

    #[test]
    fn test_kalman_state_to_bbox_short_state() {
        let state = [50.0, 50.0];
        let bb = kalman_state_to_bbox(&state);
        assert!(bb.width >= 0.0);
    }

    #[test]
    fn test_iou_distance_full_overlap() {
        let a = make_bbox(0.0, 0.0, 10.0, 10.0);
        let b = make_bbox(0.0, 0.0, 10.0, 10.0);
        let d = iou_distance(&a, &b);
        assert!(d.abs() < 1e-4, "identical boxes distance={d}");
    }

    #[test]
    fn test_iou_distance_no_overlap() {
        let a = make_bbox(0.0, 0.0, 10.0, 10.0);
        let b = make_bbox(100.0, 100.0, 10.0, 10.0);
        let d = iou_distance(&a, &b);
        assert!((d - 1.0).abs() < 1e-4, "non-overlapping distance={d}");
    }

    #[test]
    fn test_track_new() {
        let bbox = make_bbox(10.0, 10.0, 50.0, 50.0);
        let feat = unit_feat(0.5);
        let track = DeepSortTrack::new(1, &bbox, &feat);
        assert_eq!(track.id, 1);
        assert_eq!(track.hits, 1);
        assert_eq!(track.hit_streak, 1);
        assert_eq!(track.age, 1);
        assert_eq!(track.features.len(), 1);
    }

    #[test]
    fn test_track_predict_increases_age() {
        let bbox = make_bbox(0.0, 0.0, 40.0, 40.0);
        let mut track = DeepSortTrack::new(1, &bbox, &zero_feat());
        let _ = track.predict();
        assert_eq!(track.age, 2);
    }

    #[test]
    fn test_track_update_increments_hits() {
        let bbox = make_bbox(0.0, 0.0, 40.0, 40.0);
        let mut track = DeepSortTrack::new(1, &bbox, &zero_feat());
        track.update(&bbox, &zero_feat());
        assert_eq!(track.hits, 2);
        assert_eq!(track.time_since_update, 0);
    }

    #[test]
    fn test_track_feature_history_cap() {
        let bbox = make_bbox(0.0, 0.0, 20.0, 20.0);
        let mut track = DeepSortTrack::new(1, &bbox, &zero_feat());
        for _ in 0..MAX_FEATURE_HISTORY + 10 {
            track.update(&bbox, &zero_feat());
        }
        assert!(track.features.len() <= MAX_FEATURE_HISTORY);
    }

    #[test]
    fn test_track_min_appearance_distance() {
        let bbox = make_bbox(0.0, 0.0, 20.0, 20.0);
        let feat_a = unit_feat(1.0);
        let track = DeepSortTrack::new(1, &bbox, &feat_a);
        let d = track.min_appearance_distance(&feat_a);
        assert!(d < 1e-3, "identical feature min dist={d}");
    }

    #[test]
    fn test_tracker_empty_input() {
        let mut t = DeepSortTracker::new(DeepSortConfig::default());
        let tracks = t.update(&[], &[]);
        assert!(tracks.is_empty());
    }

    #[test]
    fn test_tracker_single_detection_confirms_after_min_hits() {
        let mut t = DeepSortTracker::new(cfg_min_hits_1());
        let det = [make_bbox(50.0, 50.0, 60.0, 60.0)];
        let feat = [zero_feat()];
        let tracks = t.update(&det, &feat);
        assert_eq!(tracks.len(), 1);
    }

    #[test]
    fn test_tracker_track_count_grows() {
        let mut t = DeepSortTracker::new(cfg_min_hits_1());
        let dets = [
            make_bbox(10.0, 10.0, 20.0, 20.0),
            make_bbox(200.0, 200.0, 20.0, 20.0),
        ];
        let feats = [zero_feat(), zero_feat()];
        let _ = t.update(&dets, &feats);
        assert_eq!(t.track_count(), 2);
    }

    #[test]
    fn test_tracker_persists_track_across_frames() {
        let cfg = DeepSortConfig {
            max_age: 5,
            ..cfg_min_hits_1()
        };
        let mut t = DeepSortTracker::new(cfg);
        let det = [make_bbox(100.0, 100.0, 40.0, 40.0)];
        let feat = [zero_feat()];
        let _ = t.update(&det, &feat);
        let id_first = t.all_tracks()[0].id;
        let tracks2 = t.update(&det, &feat);
        let id_second = tracks2[0].id;
        assert_eq!(id_first, id_second, "track ID should persist");
    }

    #[test]
    fn test_tracker_removes_stale_tracks() {
        let cfg = DeepSortConfig {
            max_age: 2,
            ..cfg_min_hits_1()
        };
        let mut t = DeepSortTracker::new(cfg);
        let det = [make_bbox(50.0, 50.0, 30.0, 30.0)];
        let feat = [zero_feat()];
        let _ = t.update(&det, &feat);
        let _ = t.update(&[], &[]);
        let _ = t.update(&[], &[]);
        let _ = t.update(&[], &[]);
        assert_eq!(t.track_count(), 0, "stale track should be removed");
    }

    #[test]
    fn test_tracker_new_id_per_detection() {
        let mut t = DeepSortTracker::new(cfg_min_hits_1());
        let d1 = [make_bbox(10.0, 10.0, 20.0, 20.0)];
        let d2 = [make_bbox(500.0, 500.0, 20.0, 20.0)];
        let _ = t.update(&d1, &[zero_feat()]);
        let _ = t.update(&d2, &[zero_feat()]);
        let unique: HashSet<u64> = t.all_tracks().iter().map(|tr| tr.id).collect();
        assert_eq!(unique.len(), t.track_count());
    }

    #[test]
    fn test_kalman_build_dimensions() {
        let kf = build_deep_sort_kalman();
        assert_eq!(kf.covariance().len(), 64);
        assert_eq!(kf.state().len(), 8);
    }
}