use crate::detect::BoundingBox;
use std::collections::HashMap;
/// Internal state for one tracked object.
#[derive(Debug, Clone)]
struct CentroidTrack {
    /// Identifier assigned at registration; the same value is used as this
    /// track's key in `CentroidTracker::tracks`.
    id: u64,
    /// Centre `(x + width / 2, y + height / 2)` of the most recently matched box.
    centroid: (f32, f32),
    /// Most recently matched bounding box.
    bbox: BoundingBox,
    /// Number of consecutive updates in which this track went unmatched;
    /// reset to 0 whenever it is matched to a detection.
    disappeared: usize,
}
/// Multi-object tracker that associates detections across frames by matching
/// bounding-box centroids.
#[derive(Debug, Clone)]
pub struct CentroidTracker {
    /// Live tracks keyed by track id.
    tracks: HashMap<u64, CentroidTrack>,
    /// Id handed to the next newly registered track; starts at 1.
    next_id: u64,
    /// Miss-count threshold used when pruning stale tracks in `update`.
    max_disappeared: usize,
    /// Largest centroid distance (inclusive) at which a detection may be
    /// matched to an existing track.
    max_distance: f64,
}
impl CentroidTracker {
    /// Creates a tracker with the default settings: `max_disappeared = 50`
    /// and `max_distance = 100.0`. Track ids start at 1.
    #[must_use]
    pub fn new() -> Self {
        Self {
            tracks: HashMap::new(),
            next_id: 1,
            max_disappeared: 50,
            max_distance: 100.0,
        }
    }

    /// Sets how many consecutive frames a track may go unmatched and still be
    /// kept. A track is removed once it has been missed for more than `frames`
    /// consecutive updates; `0` drops a track on its first miss.
    #[must_use]
    pub const fn with_max_disappeared(mut self, frames: usize) -> Self {
        self.max_disappeared = frames;
        self
    }

    /// Sets the largest centroid distance (inclusive) at which a detection can
    /// be matched to an existing track. Farther detections start new tracks.
    #[must_use]
    pub const fn with_max_distance(mut self, distance: f64) -> Self {
        self.max_distance = distance;
        self
    }

    /// Feeds one frame of detections into the tracker and returns the live tracks.
    ///
    /// Each existing track is matched to at most one detection (and vice versa)
    /// greedily, closest pair first, considering only pairs within
    /// `max_distance`. Matched tracks take the detection's box and have their
    /// miss counter reset; unmatched tracks have it incremented; unmatched
    /// detections become new tracks. Tracks missed for more than
    /// `max_disappeared` consecutive frames are then dropped.
    ///
    /// Returns `(id, bbox, confidence)` for every remaining track, sorted by
    /// id, where `confidence = 1 / (1 + frames_missed)`.
    pub fn update(&mut self, detections: &[BoundingBox]) -> Vec<(u64, BoundingBox, f64)> {
        if detections.is_empty() {
            self.mark_all_missed();
            self.prune_stale();
            return self.get_active_tracks();
        }

        let detection_centroids: Vec<(f32, f32)> =
            detections.iter().map(Self::centroid_of).collect();

        if self.tracks.is_empty() {
            for (&centroid, &bbox) in detection_centroids.iter().zip(detections) {
                self.register(centroid, bbox);
            }
            return self.get_active_tracks();
        }

        // Sorted so tie-breaking in the greedy matcher (and therefore which
        // track claims which detection) does not depend on HashMap order.
        let mut track_ids: Vec<u64> = self.tracks.keys().copied().collect();
        track_ids.sort_unstable();
        let track_centroids: Vec<(f32, f32)> =
            track_ids.iter().map(|id| self.tracks[id].centroid).collect();
        let distances = compute_distance_matrix(&track_centroids, &detection_centroids);
        let (matched_tracks, matched_detections) = self.greedy_assignment(&distances, &track_ids);

        // Count every track as missed, then reset the matched ones; this
        // avoids a linear search through `matched_tracks` per track.
        self.mark_all_missed();
        let mut detection_matched = vec![false; detections.len()];
        for (&track_id, &det_idx) in matched_tracks.iter().zip(&matched_detections) {
            detection_matched[det_idx] = true;
            if let Some(track) = self.tracks.get_mut(&track_id) {
                track.centroid = detection_centroids[det_idx];
                track.bbox = detections[det_idx];
                track.disappeared = 0;
            }
        }

        // Every detection that no track claimed starts a new track.
        for (i, matched) in detection_matched.into_iter().enumerate() {
            if !matched {
                self.register(detection_centroids[i], detections[i]);
            }
        }

        self.prune_stale();
        self.get_active_tracks()
    }

    /// Centre point `(x + width / 2, y + height / 2)` of a bounding box.
    fn centroid_of(bbox: &BoundingBox) -> (f32, f32) {
        (bbox.x + bbox.width / 2.0, bbox.y + bbox.height / 2.0)
    }

    /// Increments every track's miss counter. Saturating, so a very large
    /// `max_disappeared` (e.g. `usize::MAX`) can never overflow the counter.
    fn mark_all_missed(&mut self) {
        for track in self.tracks.values_mut() {
            track.disappeared = track.disappeared.saturating_add(1);
        }
    }

    /// Removes tracks that have been missed for more than `max_disappeared`
    /// consecutive frames. Freshly matched or registered tracks
    /// (`disappeared == 0`) always survive, even when `max_disappeared == 0`.
    fn prune_stale(&mut self) {
        let max_disappeared = self.max_disappeared;
        self.tracks
            .retain(|_, track| track.disappeared <= max_disappeared);
    }

    /// Starts a new track for `bbox` under the next free id.
    fn register(&mut self, centroid: (f32, f32), bbox: BoundingBox) {
        let track = CentroidTrack {
            id: self.next_id,
            centroid,
            bbox,
            disappeared: 0,
        };
        self.tracks.insert(self.next_id, track);
        self.next_id += 1;
    }

    /// Greedily pairs tracks (rows of `distances`) with detections (columns),
    /// closest pair first, ignoring pairs farther than `max_distance`.
    ///
    /// Returns parallel vectors: the track id `matched_tracks[k]` (taken from
    /// `track_ids`) is matched to detection index `matched_detections[k]`.
    fn greedy_assignment(
        &self,
        distances: &[Vec<f64>],
        track_ids: &[u64],
    ) -> (Vec<u64>, Vec<usize>) {
        let n_detections = distances.first().map_or(0, Vec::len);
        let mut track_used = vec![false; distances.len()];
        let mut detection_used = vec![false; n_detections];

        let mut pairs: Vec<(f64, usize, usize)> = distances
            .iter()
            .enumerate()
            .flat_map(|(i, row)| row.iter().enumerate().map(move |(j, &dist)| (dist, i, j)))
            // NaN distances are dropped here because `NaN <= x` is false.
            .filter(|&(dist, _, _)| dist <= self.max_distance)
            .collect();
        // Stable sort: equal distances keep (track, detection) index order.
        pairs.sort_by(|a, b| a.0.total_cmp(&b.0));

        let mut matched_tracks = Vec::new();
        let mut matched_detections = Vec::new();
        for (_, track_idx, det_idx) in pairs {
            if !track_used[track_idx] && !detection_used[det_idx] {
                track_used[track_idx] = true;
                detection_used[det_idx] = true;
                matched_tracks.push(track_ids[track_idx]);
                matched_detections.push(det_idx);
            }
        }
        (matched_tracks, matched_detections)
    }

    /// Returns `(id, bbox, confidence)` for every track, sorted by id.
    /// Confidence is `1 / (1 + frames_missed)`: 1.0 for a track matched this
    /// frame, decaying while it keeps going unmatched.
    fn get_active_tracks(&self) -> Vec<(u64, BoundingBox, f64)> {
        let mut active: Vec<(u64, BoundingBox, f64)> = self
            .tracks
            .values()
            .map(|track| {
                let confidence = 1.0 / (1.0 + track.disappeared as f64);
                (track.id, track.bbox, confidence)
            })
            .collect();
        active.sort_unstable_by_key(|&(id, _, _)| id);
        active
    }

    /// Returns `(id, bbox, frames_missed)` for every track, sorted by id.
    #[must_use]
    pub fn get_all_tracks(&self) -> Vec<(u64, BoundingBox, usize)> {
        let mut all: Vec<(u64, BoundingBox, usize)> = self
            .tracks
            .values()
            .map(|track| (track.id, track.bbox, track.disappeared))
            .collect();
        all.sort_unstable_by_key(|&(id, _, _)| id);
        all
    }

    /// Removes all tracks and restarts id numbering at 1. Configuration
    /// (`max_disappeared`, `max_distance`) is kept.
    pub fn reset(&mut self) {
        self.tracks.clear();
        self.next_id = 1;
    }

    /// Number of tracks currently held, including ones missed in recent frames.
    #[must_use]
    pub fn num_tracks(&self) -> usize {
        self.tracks.len()
    }
}
impl Default for CentroidTracker {
fn default() -> Self {
Self::new()
}
}
/// Computes the pairwise Euclidean distances between two point sets.
///
/// Entry `[i][j]` is the distance from `centroids1[i]` to `centroids2[j]`.
/// The result has `centroids1.len()` rows of `centroids2.len()` columns.
/// Rows are empty when `centroids2` is empty.
///
/// Coordinates are widened to `f64` *before* subtracting, so the difference
/// is not first rounded to `f32` precision.
fn compute_distance_matrix(centroids1: &[(f32, f32)], centroids2: &[(f32, f32)]) -> Vec<Vec<f64>> {
    centroids1
        .iter()
        .map(|&(x1, y1)| {
            centroids2
                .iter()
                .map(|&(x2, y2)| {
                    let dx = f64::from(x1) - f64::from(x2);
                    let dy = f64::from(y1) - f64::from(y2);
                    dx.hypot(dy)
                })
                .collect()
        })
        .collect()
}