use super::assignment::{
compute_iou, create_iou_cost_matrix, filter_assignments_by_cost, hungarian_algorithm,
};
use crate::detect::BoundingBox;
use crate::error::{CvError, CvResult};
use crate::tracking::kalman::KalmanFilter;
/// State of one tracked object: its identity, motion filter, latest box and bookkeeping counters.
#[derive(Debug, Clone)]
pub struct Track {
/// Identifier given to the track when it is created.
pub id: u64,
/// Motion filter for this track; created with 7 state and 4 measurement dimensions.
kalman: KalmanFilter,
/// Most recent bounding box, refreshed from the filter state on every predict/update.
pub bbox: BoundingBox,
/// Number of consecutive misses; set back to 0 whenever a detection is absorbed.
pub frames_since_update: usize,
/// Number of detections absorbed so far, counting the one that created the track.
pub hits: usize,
/// Length of the current run of matched updates; a miss sets it to 0.
pub hit_streak: usize,
/// Starts at 1 and grows by one on every prediction step.
pub age: usize,
/// 1.0 on creation and after every update; multiplied by 0.8 on every miss.
pub confidence: f64,
}
impl Track {
    /// Creates a track seeded from its first detection.
    ///
    /// The filter state has seven components: the four measured values produced by
    /// `bbox_to_measurement` (`[cx, cy, s, r]` = box centre, area, width / height) followed by
    /// three rate-of-change terms for `cx`, `cy` and `s`, which start at zero.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `KalmanFilter::set_state`.
    fn new(id: u64, bbox: BoundingBox) -> CvResult<Self> {
        const STATE_DIM: usize = 7;
        const MEASUREMENT_DIM: usize = 4;
        // Only cx, cy and s have a rate-of-change term; the aspect ratio r does not.
        const VELOCITY_TERMS: usize = 3;
        // One time step per frame.
        let dt = 1.0;

        let mut kalman = KalmanFilter::new(STATE_DIM, MEASUREMENT_DIM);

        // Flat 7x7 matrix with STATE_DIM values per row: identity, plus `dt` linking each of
        // cx, cy, s to its rate-of-change term.
        // NOTE(review): presumably KalmanFilter reads this row-major — confirm.
        let mut transition = vec![0.0; STATE_DIM * STATE_DIM];
        for i in 0..STATE_DIM {
            transition[i * STATE_DIM + i] = 1.0;
        }
        for i in 0..VELOCITY_TERMS {
            transition[i * STATE_DIM + MEASUREMENT_DIM + i] = dt;
        }
        kalman.transition = transition;

        // Flat 4x7 matrix that picks the first four state components as the measurement.
        let mut measurement = vec![0.0; MEASUREMENT_DIM * STATE_DIM];
        for i in 0..MEASUREMENT_DIM {
            measurement[i * STATE_DIM + i] = 1.0;
        }
        kalman.measurement = measurement;

        kalman.set_process_noise(0.01);
        kalman.set_measurement_noise(1.0);

        // Reuse the same box -> measurement conversion as `update`, so the initial state and
        // later measurements can never disagree; the remaining components start at zero.
        let mut initial_state = bbox_to_measurement(&bbox);
        initial_state.resize(STATE_DIM, 0.0);
        kalman.set_state(initial_state)?;

        Ok(Self {
            id,
            kalman,
            bbox,
            frames_since_update: 0,
            hits: 1,
            hit_streak: 1,
            age: 1,
            confidence: 1.0,
        })
    }

    /// Advances the filter by one step, stores the predicted box and returns it.
    ///
    /// Also increments `age`.
    fn predict(&mut self) -> BoundingBox {
        let state = self.kalman.predict();
        self.bbox = state_to_bbox(&state);
        self.age += 1;
        self.bbox
    }

    /// Absorbs a matched detection.
    ///
    /// Resets the miss counter, bumps `hits` and `hit_streak`, and restores `confidence` to 1.0.
    /// If the filter rejects the measurement, the raw detection box is stored instead.
    fn update(&mut self, bbox: &BoundingBox) {
        self.frames_since_update = 0;
        self.hits += 1;
        self.hit_streak += 1;

        let measurement = bbox_to_measurement(bbox);
        self.bbox = match self.kalman.update(&measurement) {
            Ok(state) => state_to_bbox(&state),
            // Best effort: keep tracking with the unfiltered detection.
            Err(_) => *bbox,
        };
        self.confidence = 1.0;
    }

    /// Records a frame without a matching detection.
    ///
    /// Increments the miss counter, clears the hit streak and scales `confidence` by 0.8.
    fn mark_missed(&mut self) {
        self.frames_since_update += 1;
        self.hit_streak = 0;
        self.confidence *= 0.8;
    }

    /// Returns the most recently stored bounding box.
    fn get_state(&self) -> BoundingBox {
        self.bbox
    }
}
/// Multi-object tracker that associates detections with predicted track boxes via an IoU cost matrix.
#[derive(Debug, Clone)]
pub struct SortTracker {
/// Tracks currently kept alive, including ones that missed recent frames.
tracks: Vec<Track>,
/// Id given to the next created track; starts at 1 and is restored to 1 by `reset`.
next_id: u64,
/// A track is discarded once its `frames_since_update` reaches this value.
max_age: usize,
/// Hit-streak threshold consulted when `update` selects the tracks it reports.
min_hits: usize,
/// `1.0 - iou_threshold` is passed to `filter_assignments_by_cost` as the maximum cost.
iou_threshold: f64,
}
impl SortTracker {
    /// Creates a tracker with `max_age = 30`, `min_hits = 3` and `iou_threshold = 0.3`.
    ///
    /// Track ids start at 1.
    #[must_use]
    pub fn new() -> Self {
        Self {
            tracks: Vec::new(),
            next_id: 1,
            max_age: 30,
            min_hits: 3,
            iou_threshold: 0.3,
        }
    }

    /// Sets the miss limit: a track is dropped once `frames_since_update` reaches `age`.
    #[must_use]
    pub const fn with_max_age(mut self, age: usize) -> Self {
        self.max_age = age;
        self
    }

    /// Sets the hit-streak threshold used when selecting the tracks reported by `update`.
    #[must_use]
    pub const fn with_min_hits(mut self, hits: usize) -> Self {
        self.min_hits = hits;
        self
    }

    /// Sets the IoU threshold; assignments are filtered with a maximum cost of `1.0 - threshold`.
    #[must_use]
    pub const fn with_iou_threshold(mut self, threshold: f64) -> Self {
        self.iou_threshold = threshold;
        self
    }

    /// Processes the detections of one frame and returns `(id, bbox, confidence)` per
    /// reported track.
    ///
    /// Steps: predict every track, associate predictions with `detections`, update matched
    /// tracks, mark the others as missed, start a track for every unmatched detection, drop
    /// tracks that reached `max_age` misses, then report the tracks that pass the filter below.
    pub fn update(&mut self, detections: &[BoundingBox]) -> Vec<(u64, BoundingBox, f64)> {
        let predicted: Vec<BoundingBox> = self.tracks.iter_mut().map(Track::predict).collect();

        let (matched, unmatched_tracks, unmatched_detections) =
            self.associate_detections_to_tracks(&predicted, detections);

        // All indices are looked up with `get`, so an inconsistent association result can
        // never panic here.
        for (track_idx, det_idx) in matched {
            if let (Some(track), Some(detection)) =
                (self.tracks.get_mut(track_idx), detections.get(det_idx))
            {
                track.update(detection);
            }
        }
        for track_idx in unmatched_tracks {
            if let Some(track) = self.tracks.get_mut(track_idx) {
                track.mark_missed();
            }
        }
        for det_idx in unmatched_detections {
            if let Some(&detection) = detections.get(det_idx) {
                // Best effort: a detection whose track cannot be initialised is skipped and
                // does not consume an id.
                if let Ok(track) = Track::new(self.next_id, detection) {
                    self.tracks.push(track);
                    self.next_id += 1;
                }
            }
        }

        let max_age = self.max_age;
        self.tracks
            .retain(|track| track.frames_since_update < max_age);

        // NOTE(review): a miss resets `hit_streak` to 0, so for `min_hits > 0` this filter
        // reduces to "matched or created this frame" and `min_hits` has no effect. Reference
        // SORT requires BOTH a fresh update and a sufficient streak (with a warm-up period).
        // Left unchanged because switching the operator would change what callers receive.
        let min_hits = self.min_hits;
        self.tracks
            .iter()
            .filter(|track| track.hit_streak >= min_hits || track.frames_since_update == 0)
            .map(|track| (track.id, track.get_state(), track.confidence))
            .collect()
    }

    /// Returns `(id, bbox, confidence)` for every live track, without any filtering.
    pub fn get_tracks(&self) -> Vec<(u64, BoundingBox, f64)> {
        self.tracks
            .iter()
            .map(|track| (track.id, track.bbox, track.confidence))
            .collect()
    }

    /// Drops all tracks and restarts id assignment at 1.
    pub fn reset(&mut self) {
        self.tracks.clear();
        self.next_id = 1;
    }

    /// Associates predicted track boxes with detections.
    ///
    /// Returns `(matched, unmatched_tracks, unmatched_detections)` where `matched` holds
    /// `(track index, detection index)` pairs. Every track index appears exactly once, either
    /// in `matched` or in `unmatched_tracks`, and every detection index appears at most once
    /// in `matched`.
    fn associate_detections_to_tracks(
        &self,
        tracks: &[BoundingBox],
        detections: &[BoundingBox],
    ) -> (Vec<(usize, usize)>, Vec<usize>, Vec<usize>) {
        if tracks.is_empty() {
            return (Vec::new(), Vec::new(), (0..detections.len()).collect());
        }
        if detections.is_empty() {
            return (Vec::new(), (0..tracks.len()).collect(), Vec::new());
        }

        let cost_matrix = create_iou_cost_matrix(tracks, detections);
        let assignments = hungarian_algorithm(&cost_matrix);
        let max_cost = 1.0 - self.iou_threshold;
        let filtered = filter_assignments_by_cost(&assignments, &cost_matrix, max_cost);

        let mut matched = Vec::new();
        let mut unmatched_tracks = Vec::new();
        let mut detection_used = vec![false; detections.len()];

        // Walk the tracks rather than the helper output: a track that the helper did not
        // cover, or whose assigned detection is out of range or already taken, counts as
        // unmatched instead of being silently skipped (or causing an out-of-bounds panic).
        for track_idx in 0..tracks.len() {
            match filtered.get(track_idx).copied().flatten() {
                Some(det_idx) if det_idx < detections.len() && !detection_used[det_idx] => {
                    detection_used[det_idx] = true;
                    matched.push((track_idx, det_idx));
                }
                _ => unmatched_tracks.push(track_idx),
            }
        }

        let unmatched_detections: Vec<usize> = (0..detections.len())
            .filter(|&det_idx| !detection_used[det_idx])
            .collect();

        (matched, unmatched_tracks, unmatched_detections)
    }
}
impl Default for SortTracker {
/// Returns the same tracker as [`SortTracker::new`].
fn default() -> Self {
Self::new()
}
}
/// Tracker that wraps a [`SortTracker`] and carries settings for appearance features.
#[derive(Debug, Clone)]
pub struct DeepSortTracker {
/// Underlying tracker that every update is forwarded to.
sort: SortTracker,
/// Stored feature vectors; starts empty and is cleared by `reset`.
/// NOTE(review): presumably indexed per track, then per observation — confirm.
feature_history: Vec<Vec<Vec<f32>>>,
/// Defaults to 100. NOTE(review): presumably a cap on stored features per track — confirm.
max_feature_history: usize,
/// Defaults to 0.7. NOTE(review): presumably a feature-similarity cut-off — confirm.
feature_threshold: f64,
}
impl DeepSortTracker {
    /// Creates a tracker around a default [`SortTracker`], with an empty feature history,
    /// `max_feature_history = 100` and `feature_threshold = 0.7`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            sort: SortTracker::new(),
            feature_history: Vec::new(),
            max_feature_history: 100,
            feature_threshold: 0.7,
        }
    }

    /// Processes the detections of one frame and returns `(id, bbox, confidence)` per
    /// reported track.
    ///
    /// `features` is accepted for API compatibility but is currently NOT used: association is
    /// delegated unchanged to the wrapped [`SortTracker`], so results are identical to calling
    /// its `update` with the same detections.
    pub fn update_with_features(
        &mut self,
        detections: &[BoundingBox],
        features: &[Vec<f32>],
    ) -> Vec<(u64, BoundingBox, f64)> {
        // Appearance matching is not implemented yet; discard the argument explicitly so the
        // omission is visible here instead of surfacing as an unused-variable warning.
        let _ = features;
        self.sort.update(detections)
    }

    /// Resets the wrapped tracker and clears the stored feature history.
    pub fn reset(&mut self) {
        self.sort.reset();
        self.feature_history.clear();
    }
}
impl Default for DeepSortTracker {
/// Returns the same tracker as [`DeepSortTracker::new`].
fn default() -> Self {
Self::new()
}
}
/// Converts a bounding box into the measurement vector `[cx, cy, s, r]`.
///
/// `cx`/`cy` are the box centre, `s` is the area (`width * height`) and `r` is the aspect
/// ratio (`width / height`).
///
/// A box with zero height makes `width / height` infinite or NaN; feeding such a value into
/// the Kalman filter would contaminate its state, so a non-finite ratio is replaced by 1.0
/// (a square box). All other values are passed through unchanged.
fn bbox_to_measurement(bbox: &BoundingBox) -> Vec<f64> {
    // Aspect ratio reported when `width / height` is not a finite number.
    const FALLBACK_ASPECT_RATIO: f64 = 1.0;

    let cx = (bbox.x + bbox.width / 2.0) as f64;
    let cy = (bbox.y + bbox.height / 2.0) as f64;
    let s = (bbox.width * bbox.height) as f64;

    let ratio = (bbox.width / bbox.height) as f64;
    let r = if ratio.is_finite() {
        ratio
    } else {
        FALLBACK_ASPECT_RATIO
    };

    vec![cx, cy, s, r]
}
/// Converts a filter state whose first four components are `[cx, cy, s, r]` (centre, area,
/// width / height) back into a bounding box.
///
/// The area is clamped to at least 1.0 and the aspect ratio to at least 0.1 before the width
/// (`sqrt(s * r)`) and height (`s / width`) are derived. Components after the fourth are
/// ignored. A state with fewer than four components yields the unit box at the origin.
fn state_to_bbox(state: &[f64]) -> BoundingBox {
    const MIN_AREA: f64 = 1.0;
    const MIN_ASPECT_RATIO: f64 = 0.1;

    let (center_x, center_y, raw_area, raw_ratio) = match state {
        [cx, cy, s, r, ..] => (*cx, *cy, *s, *r),
        _ => return BoundingBox::new(0.0, 0.0, 1.0, 1.0),
    };

    let area = raw_area.max(MIN_AREA);
    let aspect_ratio = raw_ratio.max(MIN_ASPECT_RATIO);
    let width = (area * aspect_ratio).sqrt();
    let height = area / width;

    let left = (center_x - width / 2.0) as f32;
    let top = (center_y - height / 2.0) as f32;
    BoundingBox::new(left, top, width as f32, height as f32)
}