#![doc = include_str!("README.md")]
use crate::utils::kalman::{CovarianceMatrix, KalmanFilter, MeasurementVector, StateVector};
use std::collections::HashSet;
/// Lifecycle state of a tracked object.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SortTrackState {
    /// Initial state of a freshly created track; not reported yet.
    Tentative,
    /// The track has accumulated enough hits and is reported whenever it is matched.
    Confirmed,
    /// Marked for removal; the tracker drops such tracks at the end of the frame.
    Deleted,
}
/// One tracked object together with its Kalman filter state.
#[derive(Clone, Debug)]
pub struct SortTrack {
    /// Current box as `[left, top, width, height]`. After a matched update this is the raw
    /// detection box; after a prediction it is derived from the Kalman mean.
    pub tlwh: [f32; 4],
    /// Score of the most recently assigned detection.
    pub score: f32,
    /// Class id of the most recently assigned detection.
    pub class_id: i64,
    /// Identifier given to the track when it was created.
    pub track_id: u64,
    /// Lifecycle state of the track.
    pub state: SortTrackState,
    /// Number of detections assigned to this track, including the one that created it.
    pub hits: usize,
    /// Predictions since the last assigned detection (0 right after creation or an update).
    pub time_since_update: usize,
    /// 1 at creation, incremented by every prediction step.
    pub age: usize,
    /// Kalman state mean; its first four entries are `(center_x, center_y, aspect, height)`.
    mean: StateVector,
    /// Kalman state covariance.
    covariance: CovarianceMatrix,
}
impl SortTrack {
pub fn new(
tlwh: [f32; 4],
score: f32,
class_id: i64,
kf: &KalmanFilter,
track_id: u64,
) -> Self {
let measurement = Self::tlwh_to_xyah(&tlwh);
let (mean, covariance) = kf.initiate(&measurement);
Self {
tlwh,
score,
class_id,
track_id,
state: SortTrackState::Tentative,
hits: 1,
time_since_update: 0,
age: 1,
mean,
covariance,
}
}
pub fn predict(&mut self, kf: &KalmanFilter) {
let (mean, covariance) = kf.predict(&self.mean, &self.covariance);
self.mean = mean;
self.covariance = covariance;
self.age += 1;
self.time_since_update += 1;
self.tlwh = Self::xyah_to_tlwh(&self.mean);
}
fn update(&mut self, detection: &Detection, kf: &KalmanFilter) {
let measurement = Self::tlwh_to_xyah(&detection.tlwh);
let (mean, covariance) = kf.update(&self.mean, &self.covariance, &measurement);
self.mean = mean;
self.covariance = covariance;
self.tlwh = detection.tlwh;
self.score = detection.score;
self.class_id = detection.class_id;
self.hits += 1;
self.time_since_update = 0;
}
pub fn mark_deleted(&mut self) {
self.state = SortTrackState::Deleted;
}
pub fn is_confirmed(&self) -> bool {
self.state == SortTrackState::Confirmed
}
fn tlwh_to_xyah(tlwh: &[f32; 4]) -> MeasurementVector {
let x = tlwh[0] + tlwh[2] / 2.0;
let y = tlwh[1] + tlwh[3] / 2.0;
let a = tlwh[2] / tlwh[3].max(1e-6);
let h = tlwh[3];
MeasurementVector::from_vec(vec![x, y, a, h])
}
fn xyah_to_tlwh(state: &StateVector) -> [f32; 4] {
let w = state[2] * state[3];
let h = state[3];
let x = state[0] - w / 2.0;
let y = state[1] - h / 2.0;
[x, y, w, h]
}
}
/// One input detection for a single frame.
#[derive(Clone, Debug)]
struct Detection {
    /// Box as `[left, top, width, height]`.
    tlwh: [f32; 4],
    /// Detector confidence, copied onto the matched track.
    score: f32,
    /// Detector class id, copied onto the matched track.
    class_id: i64,
}
/// SORT-style multi-object tracker.
///
/// Each frame, detections are associated with existing tracks by IoU against the tracks'
/// Kalman-predicted boxes.
pub struct Sort {
    /// Live tracks; deleted ones are pruned at the end of every update.
    tracks: Vec<SortTrack>,
    /// A track is deleted once its `time_since_update` exceeds this value.
    max_age: usize,
    /// Hits a track needs before it becomes confirmed.
    min_hits: usize,
    /// Minimum IoU for a detection/track pair to be matched.
    iou_threshold: f32,
    /// Filter used to initiate, predict and update every track.
    kalman_filter: KalmanFilter,
    /// ID handed to the next newly created track.
    next_id: u64,
}
impl Sort {
    /// Creates an empty tracker.
    ///
    /// * `max_age` - a track is deleted once its `time_since_update` exceeds this value.
    /// * `min_hits` - a track becomes confirmed (and therefore reported) once its `hits`
    ///   counter reaches this value. `hits` counts the creating detection plus every
    ///   matched update.
    /// * `iou_threshold` - minimum IoU between a track's predicted box and a detection for
    ///   the pair to be eligible for matching.
    ///
    /// Track IDs are assigned from 1 upward.
    pub fn new(max_age: usize, min_hits: usize, iou_threshold: f32) -> Self {
        Self {
            tracks: Vec::new(),
            max_age,
            min_hits,
            iou_threshold,
            kalman_filter: KalmanFilter::default(),
            next_id: 1,
        }
    }

    /// Advances the tracker by one frame.
    ///
    /// `detections` holds `(tlwh, score, class_id)` tuples, with `tlwh` given as
    /// `[left, top, width, height]`. For the frame this:
    ///
    /// 1. Kalman-predicts every existing track.
    /// 2. Matches detections to tracks by IoU (see `associate`).
    /// 3. Updates matched tracks and starts a new tentative track for each unmatched
    ///    detection. New IDs are handed out in ascending detection-index order.
    /// 4. Confirms tracks updated this frame whose `hits >= min_hits`, and deletes tracks
    ///    whose `time_since_update > max_age`.
    ///
    /// Returns clones of the confirmed tracks that received a detection this frame.
    pub fn update(&mut self, detections: Vec<([f32; 4], f32, i64)>) -> Vec<SortTrack> {
        let detections: Vec<Detection> = detections
            .into_iter()
            .map(|(tlwh, score, class_id)| Detection {
                tlwh,
                score,
                class_id,
            })
            .collect();

        for track in &mut self.tracks {
            track.predict(&self.kalman_filter);
        }

        let (matches, unmatched_dets, _unmatched_trks) = self.associate(&detections);
        for (det_idx, trk_idx) in matches {
            self.tracks[trk_idx].update(&detections[det_idx], &self.kalman_filter);
        }
        // `unmatched_dets` is sorted ascending, so ID assignment is reproducible run to run.
        for det_idx in unmatched_dets {
            let det = &detections[det_idx];
            let new_track = SortTrack::new(
                det.tlwh,
                det.score,
                det.class_id,
                &self.kalman_filter,
                self.next_id,
            );
            self.next_id += 1;
            self.tracks.push(new_track);
        }

        for track in &mut self.tracks {
            if track.time_since_update == 0 && track.hits >= self.min_hits {
                track.state = SortTrackState::Confirmed;
            }
            if track.time_since_update > self.max_age {
                track.mark_deleted();
            }
        }
        self.tracks.retain(|t| t.state != SortTrackState::Deleted);

        self.tracks
            .iter()
            .filter(|t| t.is_confirmed() && t.time_since_update == 0)
            .cloned()
            .collect()
    }

    /// Matches `detections` against the current tracks by IoU.
    ///
    /// The IoU matrix comes from `crate::utils::geometry::iou_batch(tracks, detections)`.
    /// It is treated as rows = tracks and columns = detections, the layout
    /// `linear_assignment` expects. Each entry is converted to a cost `1 - IoU` and
    /// gated at `1 - iou_threshold`.
    ///
    /// Returns `(matches, unmatched_detections, unmatched_tracks)`, where each match is
    /// `(det_idx, trk_idx)`.
    fn associate(&self, detections: &[Detection]) -> (Vec<(usize, usize)>, Vec<usize>, Vec<usize>) {
        if self.tracks.is_empty() {
            return (Vec::new(), (0..detections.len()).collect(), Vec::new());
        }
        if detections.is_empty() {
            return (Vec::new(), Vec::new(), (0..self.tracks.len()).collect());
        }
        let track_boxes: Vec<[f32; 4]> = self.tracks.iter().map(|t| t.tlwh).collect();
        let det_boxes: Vec<[f32; 4]> = detections.iter().map(|d| d.tlwh).collect();
        let ious = crate::utils::geometry::iou_batch(&track_boxes, &det_boxes);
        let cost_matrix: Vec<Vec<f32>> = ious
            .iter()
            .map(|row| row.iter().map(|&iou| 1.0 - iou).collect())
            .collect();
        self.linear_assignment(&cost_matrix, 1.0 - self.iou_threshold)
    }

    /// Greedily assigns rows (tracks) to columns (detections) of `cost_matrix`.
    ///
    /// Only pairs with `cost <= thresh` are candidates. Candidates are visited in
    /// ascending cost order, and a pair is accepted only if neither its track nor its
    /// detection is already used. This is a greedy approximation, not an optimal
    /// (Hungarian) assignment. Equal costs keep row-major order because `sort_by` is
    /// stable. The column count is taken from the first row.
    ///
    /// Returns `(matches, unmatched_dets, unmatched_tracks)`, each match being
    /// `(det_idx, trk_idx)`. Both unmatched lists are sorted ascending so results do not
    /// depend on `HashSet` iteration order, which is randomized.
    fn linear_assignment(
        &self,
        cost_matrix: &[Vec<f32>],
        thresh: f32,
    ) -> (Vec<(usize, usize)>, Vec<usize>, Vec<usize>) {
        if cost_matrix.is_empty() {
            return (Vec::new(), Vec::new(), Vec::new());
        }
        let rows = cost_matrix.len();
        let cols = cost_matrix[0].len();
        let mut unmatched_tracks: HashSet<usize> = (0..rows).collect();
        let mut unmatched_dets: HashSet<usize> = (0..cols).collect();

        // Gate before sorting. `cost <= thresh` is false for NaN, so a NaN IoU can never
        // yield a match. The original `cost > thresh` skip let NaN costs through.
        let mut candidates: Vec<(f32, usize, usize)> = Vec::new();
        for (trk_idx, row) in cost_matrix.iter().enumerate() {
            for (det_idx, &cost) in row.iter().enumerate() {
                if cost <= thresh {
                    candidates.push((cost, trk_idx, det_idx));
                }
            }
        }
        // No NaN survives the gate, so `partial_cmp` is total here; the fallback is unreachable.
        candidates.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));

        let mut matches = Vec::new();
        for (_, trk_idx, det_idx) in candidates {
            if unmatched_tracks.contains(&trk_idx) && unmatched_dets.contains(&det_idx) {
                unmatched_tracks.remove(&trk_idx);
                unmatched_dets.remove(&det_idx);
                matches.push((det_idx, trk_idx));
            }
        }

        let mut unmatched_dets: Vec<usize> = unmatched_dets.into_iter().collect();
        unmatched_dets.sort_unstable();
        let mut unmatched_tracks: Vec<usize> = unmatched_tracks.into_iter().collect();
        unmatched_tracks.sort_unstable();
        (matches, unmatched_dets, unmatched_tracks)
    }
}
impl Default for Sort {
fn default() -> Self {
Self::new(1, 3, 0.3)
}
}
#[cfg(feature = "python")]
use pyo3::prelude::*;
/// Per-track tuple handed back to Python: `(track_id, tlwh, score, class_id)`.
#[cfg(feature = "python")]
type PyTrackingResult = (u64, [f32; 4], f32, i64);
// Python-facing wrapper exposed as `Sort`; it simply owns a Rust `Sort` tracker.
// (Plain `//` comments on purpose: `///` on a pyclass would become its Python `__doc__`.)
#[cfg(feature = "python")]
#[pyclass(name = "Sort")]
pub struct PySort {
    inner: Sort,
}
#[cfg(feature = "python")]
#[pymethods]
impl PySort {
    // Python constructor: `Sort(max_age=1, min_hits=3, iou_threshold=0.3)`.
    #[new]
    #[pyo3(signature = (max_age=1, min_hits=3, iou_threshold=0.3))]
    fn new(max_age: usize, min_hits: usize, iou_threshold: f32) -> Self {
        PySort {
            inner: Sort::new(max_age, min_hits, iou_threshold),
        }
    }

    // Feeds one frame of `(tlwh, score, class_id)` detections to the wrapped tracker.
    // Returns `(track_id, tlwh, score, class_id)` for every track the tracker reports
    // this frame. This method always returns `Ok`.
    fn update(&mut self, detections: Vec<([f32; 4], f32, i64)>) -> PyResult<Vec<PyTrackingResult>> {
        let results: Vec<PyTrackingResult> = self
            .inner
            .update(detections)
            .into_iter()
            .map(|track| (track.track_id, track.tlwh, track.score, track.class_id))
            .collect();
        Ok(results)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Box `[left, top, width, height]` shared by most tests.
    const BASE_BOX: [f32; 4] = [100.0, 100.0, 50.0, 100.0];

    /// Wraps a single `(tlwh, score, class_id)` detection into the frame format
    /// `Sort::update` takes.
    fn frame(tlwh: [f32; 4], score: f32, class_id: i64) -> Vec<([f32; 4], f32, i64)> {
        vec![(tlwh, score, class_id)]
    }

    #[test]
    fn test_sort_track_creation() {
        let filter = KalmanFilter::default();
        let created = SortTrack::new([10.0, 20.0, 30.0, 40.0], 0.9, 1, &filter, 1);
        assert_eq!(created.tlwh, [10.0, 20.0, 30.0, 40.0]);
        assert_eq!(created.score, 0.9);
        assert_eq!(created.class_id, 1);
        assert_eq!(created.state, SortTrackState::Tentative);
        assert_eq!(created.hits, 1);
        assert_eq!(created.time_since_update, 0);
    }

    #[test]
    fn test_sort_single_detection() {
        let mut sort = Sort::new(1, 1, 0.3);
        let reported = sort.update(frame(BASE_BOX, 0.9, 0));
        assert_eq!(reported.len(), 1);
        assert!(reported[0].track_id > 0);
    }

    #[test]
    fn test_sort_track_continuity() {
        let mut sort = Sort::new(1, 1, 0.3);
        let first = sort.update(frame(BASE_BOX, 0.9, 0));
        assert_eq!(first.len(), 1);
        let original_id = first[0].track_id;

        // A slightly shifted box must be matched to the same track.
        let second = sort.update(frame([105.0, 105.0, 50.0, 100.0], 0.9, 0));
        assert_eq!(second.len(), 1);
        assert_eq!(second[0].track_id, original_id);
    }

    #[test]
    fn test_sort_min_hits() {
        let mut sort = Sort::new(1, 3, 0.3);
        // Two hits are not enough when three are required...
        for _ in 0..2 {
            assert!(sort.update(frame(BASE_BOX, 0.9, 0)).is_empty());
        }
        // ...the third hit confirms the track.
        assert_eq!(sort.update(frame(BASE_BOX, 0.9, 0)).len(), 1);
    }

    #[test]
    fn test_sort_max_age() {
        let mut sort = Sort::new(2, 1, 0.3);
        let _ = sort.update(frame(BASE_BOX, 0.9, 0));
        // A track that was not matched this frame is not reported.
        assert!(sort.update(Vec::new()).is_empty());
        // Two more empty frames push time_since_update past max_age = 2.
        let _ = sort.update(Vec::new());
        let _ = sort.update(Vec::new());
        assert_eq!(sort.update(frame(BASE_BOX, 0.9, 0)).len(), 1);
    }

    #[test]
    fn test_sort_multiple_objects() {
        let mut sort = Sort::new(1, 1, 0.3);
        let reported = sort.update(vec![
            (BASE_BOX, 0.9, 0),
            ([300.0, 300.0, 50.0, 100.0], 0.85, 1),
        ]);
        assert_eq!(reported.len(), 2);
        assert_ne!(reported[0].track_id, reported[1].track_id);
    }
}