use super::assignment::{compute_iou, greedy_assignment};
use crate::detect::BoundingBox;
use std::collections::HashMap;
/// Internal per-object state kept by [`IouTracker`].
#[derive(Debug, Clone)]
struct IouTrack {
/// Track identifier; the same value is used as this track's key in `IouTracker::tracks`.
id: u64,
/// Box from the most recent detection that was matched to (or created) this track.
bbox: BoundingBox,
/// Consecutive frames in which the track went unmatched; reset to 0 on a match.
disappeared: usize,
/// Frames processed since the track was created, counting the creation frame.
age: usize,
/// Frames in which a detection was matched to the track, counting the creation frame.
hits: usize,
/// Score reported alongside the box; refreshed on a match and may be decayed while unmatched.
confidence: f64,
}
/// Frame-to-frame object tracker that associates bounding boxes by IoU overlap.
///
/// Each `update` matches the current detections against the live tracks, starts new
/// tracks for unmatched detections, and ages out tracks that keep going unmatched.
#[derive(Debug, Clone)]
pub struct IouTracker {
/// Live tracks keyed by track id.
tracks: HashMap<u64, IouTrack>,
/// Id handed to the next newly registered track (starts at 1, restored by `reset`).
next_id: u64,
/// Missed-frame budget: tracks that keep going unmatched are pruned once their
/// `disappeared` count reaches this value.
max_disappeared: usize,
/// Minimum IoU for a track/detection pair to be eligible for matching
/// (handed to the assignment step as a maximum cost of `1.0 - min_iou`).
min_iou: f64,
/// Tracks with fewer hits than this are left out of the tracks returned by `update`.
min_hits: usize,
}
impl IouTracker {
#[must_use]
pub fn new() -> Self {
Self {
tracks: HashMap::new(),
next_id: 1,
max_disappeared: 5,
min_iou: 0.3,
min_hits: 1,
}
}
#[must_use]
pub const fn with_max_disappeared(mut self, frames: usize) -> Self {
self.max_disappeared = frames;
self
}
#[must_use]
pub const fn with_min_iou(mut self, iou: f64) -> Self {
self.min_iou = iou;
self
}
#[must_use]
pub const fn with_min_hits(mut self, hits: usize) -> Self {
self.min_hits = hits;
self
}
pub fn update(&mut self, detections: &[BoundingBox]) -> Vec<(u64, BoundingBox, f64)> {
if detections.is_empty() {
for track in self.tracks.values_mut() {
track.disappeared += 1;
track.age += 1;
}
self.tracks
.retain(|_, track| track.disappeared < self.max_disappeared);
return self.get_confirmed_tracks();
}
if self.tracks.is_empty() {
for bbox in detections {
self.register(*bbox, 1.0);
}
return self.get_confirmed_tracks();
}
let track_ids: Vec<u64> = self.tracks.keys().copied().collect();
let track_bboxes: Vec<BoundingBox> =
track_ids.iter().map(|id| self.tracks[id].bbox).collect();
let mut cost_matrix = vec![vec![0.0; detections.len()]; track_bboxes.len()];
for (i, track_bbox) in track_bboxes.iter().enumerate() {
for (j, det_bbox) in detections.iter().enumerate() {
let iou = compute_iou(track_bbox, det_bbox);
cost_matrix[i][j] = 1.0 - iou;
}
}
let max_cost = 1.0 - self.min_iou;
let assignments = greedy_assignment(&cost_matrix, max_cost);
let mut matched_detections = vec![false; detections.len()];
for (track_idx, assignment) in assignments.iter().enumerate() {
if let Some(det_idx) = assignment {
let track_id = track_ids[track_idx];
if let Some(track) = self.tracks.get_mut(&track_id) {
track.bbox = detections[*det_idx];
track.disappeared = 0;
track.age += 1;
track.hits += 1;
track.confidence = 1.0;
matched_detections[*det_idx] = true;
}
}
}
for (track_idx, assignment) in assignments.iter().enumerate() {
if assignment.is_none() {
let track_id = track_ids[track_idx];
if let Some(track) = self.tracks.get_mut(&track_id) {
track.disappeared += 1;
track.age += 1;
track.confidence *= 0.8;
}
}
}
for (i, &matched) in matched_detections.iter().enumerate() {
if !matched {
self.register(detections[i], 1.0);
}
}
self.tracks
.retain(|_, track| track.disappeared < self.max_disappeared);
self.get_confirmed_tracks()
}
pub fn update_with_confidence(
&mut self,
detections: &[BoundingBox],
confidences: &[f64],
) -> Vec<(u64, BoundingBox, f64)> {
if detections.len() != confidences.len() {
return self.update(detections);
}
let result = self.update(detections);
let track_ids: Vec<u64> = self.tracks.keys().copied().collect();
for (i, &track_id) in track_ids.iter().enumerate() {
if let Some(track) = self.tracks.get_mut(&track_id) {
if track.age == 1 && i < confidences.len() {
track.confidence = confidences[i];
}
}
}
result
}
fn register(&mut self, bbox: BoundingBox, confidence: f64) {
let track = IouTrack {
id: self.next_id,
bbox,
disappeared: 0,
age: 1,
hits: 1,
confidence,
};
self.tracks.insert(self.next_id, track);
self.next_id += 1;
}
fn get_confirmed_tracks(&self) -> Vec<(u64, BoundingBox, f64)> {
self.tracks
.values()
.filter(|track| track.hits >= self.min_hits)
.map(|track| (track.id, track.bbox, track.confidence))
.collect()
}
pub fn get_all_tracks(&self) -> Vec<(u64, BoundingBox, f64)> {
self.tracks
.values()
.map(|track| (track.id, track.bbox, track.confidence))
.collect()
}
pub fn reset(&mut self) {
self.tracks.clear();
self.next_id = 1;
}
pub fn num_tracks(&self) -> usize {
self.tracks.len()
}
pub fn get_track(&self, id: u64) -> Option<(BoundingBox, f64)> {
self.tracks
.get(&id)
.map(|track| (track.bbox, track.confidence))
}
}
impl Default for IouTracker {
fn default() -> Self {
Self::new()
}
}
/// Wrapper around [`IouTracker`] that also keeps per-track velocity bookkeeping.
#[derive(Debug, Clone)]
pub struct IouTrackerAdvanced {
/// Underlying tracker that performs the IoU association.
base: IouTracker,
// NOTE(review): initialised to 0.8 in `new` but not read by any code visible here.
occlusion_threshold: f64,
/// Per-track velocity keyed by track id.
// NOTE(review): only (0.0, 0.0) is ever stored by the update path; velocity
// estimation appears unimplemented — confirm before relying on these values.
velocities: HashMap<u64, (f32, f32)>,
}
impl IouTrackerAdvanced {
    /// Creates a tracker wrapping a default-configured [`IouTracker`].
    #[must_use]
    pub fn new() -> Self {
        Self {
            base: IouTracker::new(),
            occlusion_threshold: 0.8,
            velocities: HashMap::new(),
        }
    }

    /// Advances the wrapped tracker by one frame and returns its tracks unchanged,
    /// after syncing the per-track velocity map with them.
    pub fn update(&mut self, detections: &[BoundingBox]) -> Vec<(u64, BoundingBox, f64)> {
        let tracks = self.base.update(detections);
        self.update_velocities(&tracks);
        tracks
    }

    /// Keeps exactly one velocity entry per reported track.
    ///
    /// Existing entries are carried over and new ids start at `(0.0, 0.0)`. Ids the base
    /// tracker no longer reports are dropped, so the map cannot grow without bound as
    /// track ids keep increasing.
    // NOTE(review): the velocity values themselves are never estimated here.
    fn update_velocities(&mut self, tracks: &[(u64, BoundingBox, f64)]) {
        let mut refreshed = HashMap::with_capacity(tracks.len());
        for &(id, _, _) in tracks {
            let velocity = self.velocities.get(&id).copied().unwrap_or((0.0, 0.0));
            refreshed.insert(id, velocity);
        }
        self.velocities = refreshed;
    }

    /// Resets the wrapped tracker and forgets all velocity state.
    pub fn reset(&mut self) {
        self.base.reset();
        self.velocities.clear();
    }
}
impl Default for IouTrackerAdvanced {
fn default() -> Self {
Self::new()
}
}