use crate::utils::kalman::{CovarianceMatrix, KalmanFilter, MeasurementVector, StateVector};
/// Lifecycle stage of a [`Track`].
///
/// Every track is created as `Tentative`; `Track::update` promotes it to
/// `Confirmed`, and `Track::mark_missed` may move it to `Deleted`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TrackState {
    /// Newly created track that has not yet collected enough hits.
    Tentative,
    /// Track whose hit count reached the configured initialisation threshold.
    Confirmed,
    /// Track that has been marked for removal by `Track::mark_missed`.
    Deleted,
}
/// A single tracked object together with its filter state and bookkeeping
/// counters.
#[derive(Clone, Debug)]
pub struct Track {
    /// Identifier supplied when the track was created.
    pub track_id: u64,
    /// Class label; overwritten on every `update`.
    pub class_id: i64,
    /// Number of measurements absorbed so far (starts at 1, +1 per `update`).
    pub hits: usize,
    /// Number of `predict` steps plus one (starts at 1).
    pub age: usize,
    /// `predict` steps since the last `update`; reset to 0 by `update`.
    pub time_since_update: usize,
    /// Current lifecycle stage.
    pub state: TrackState,
    /// Filter state mean; elements 0..4 are read as
    /// `(center x, center y, aspect ratio, height)` by `xyah_to_tlwh`.
    pub mean: StateVector,
    /// Filter state covariance paired with `mean`.
    pub covariance: CovarianceMatrix,
    /// Score passed to `new` or to the most recent `update`.
    pub score: f32,
    /// Feature vectors in arrival order: one from `new`, one more per `update`.
    pub features: Vec<Vec<f32>>,
    /// Hit count at which a tentative track becomes confirmed.
    _n_init: usize,
    /// A non-tentative track is deleted by `mark_missed` once
    /// `time_since_update` exceeds this value.
    _max_age: usize,
}
impl Track {
    /// Builds a new track in the [`TrackState::Tentative`] state.
    ///
    /// The counters start at `hits = 1`, `age = 1` and
    /// `time_since_update = 0`; `feature` becomes the only stored feature
    /// vector.
    ///
    /// * `mean` / `covariance` - initial filter state.
    /// * `track_id` - identifier stored on the track.
    /// * `class_id` / `score` - initial class label and score.
    /// * `n_init` - hit count at which [`Track::update`] confirms the track.
    /// * `max_age` - limit on `time_since_update` checked by
    ///   [`Track::mark_missed`].
    // One argument per stored field; callers rely on this positional signature.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        mean: StateVector,
        covariance: CovarianceMatrix,
        track_id: u64,
        class_id: i64,
        n_init: usize,
        max_age: usize,
        score: f32,
        feature: Vec<f32>,
    ) -> Self {
        Self {
            track_id,
            class_id,
            hits: 1,
            age: 1,
            time_since_update: 0,
            state: TrackState::Tentative,
            mean,
            covariance,
            score,
            features: vec![feature],
            _n_init: n_init,
            _max_age: max_age,
        }
    }

    /// Converts a `[corner x, corner y, width, height]` box into a
    /// `(center x, center y, aspect ratio, height)` measurement.
    ///
    /// The aspect ratio is `width / height`, with the divisor clamped to at
    /// least `1e-6` so that a zero-height box does not divide by zero.
    pub fn tlwh_to_xyah(tlwh: &[f32; 4]) -> MeasurementVector {
        let [left, top, width, height] = *tlwh;
        let center_x = left + width / 2.0;
        let center_y = top + height / 2.0;
        let aspect = width / height.max(1e-6);
        MeasurementVector::from_vec(vec![center_x, center_y, aspect, height])
    }

    /// Converts the first four state elements
    /// `(center x, center y, aspect ratio, height)` back into a
    /// `[corner x, corner y, width, height]` box.
    pub fn xyah_to_tlwh(state: &StateVector) -> [f32; 4] {
        let height = state[3];
        let width = state[2] * height;
        let left = state[0] - width / 2.0;
        let top = state[1] - height / 2.0;
        [left, top, width, height]
    }

    /// Returns the track's current box, derived from `self.mean` via
    /// [`Track::xyah_to_tlwh`].
    pub fn to_tlwh(&self) -> [f32; 4] {
        Self::xyah_to_tlwh(&self.mean)
    }

    /// Advances the filter state by one `kf.predict` step and increments both
    /// `age` and `time_since_update`.
    pub fn predict(&mut self, kf: &KalmanFilter) {
        let (predicted_mean, predicted_covariance) = kf.predict(&self.mean, &self.covariance);
        self.mean = predicted_mean;
        self.covariance = predicted_covariance;
        self.age += 1;
        self.time_since_update += 1;
    }

    /// Absorbs a matched detection.
    ///
    /// Runs `kf.update` with `detection`, increments `hits`, resets
    /// `time_since_update`, overwrites `score` and `class_id`, and appends
    /// `feature`. A tentative track is promoted to
    /// [`TrackState::Confirmed`] once `hits` reaches the `n_init` value given
    /// to [`Track::new`].
    pub fn update(
        &mut self,
        kf: &KalmanFilter,
        detection: &MeasurementVector,
        score: f32,
        class_id: i64,
        feature: Vec<f32>,
    ) {
        let (corrected_mean, corrected_covariance) =
            kf.update(&self.mean, &self.covariance, detection);
        self.mean = corrected_mean;
        self.covariance = corrected_covariance;

        self.hits += 1;
        self.time_since_update = 0;
        self.score = score;
        self.class_id = class_id;
        self.features.push(feature);

        let enough_hits = self.hits >= self._n_init;
        if self.is_tentative() && enough_hits {
            self.state = TrackState::Confirmed;
        }
    }

    /// Handles a miss: a tentative track is deleted immediately, any other
    /// track only once `time_since_update` strictly exceeds `max_age`.
    pub fn mark_missed(&mut self) {
        let expired = self.time_since_update > self._max_age;
        if self.is_tentative() || expired {
            self.state = TrackState::Deleted;
        }
    }

    /// Returns `true` while the track is [`TrackState::Confirmed`].
    pub fn is_confirmed(&self) -> bool {
        matches!(self.state, TrackState::Confirmed)
    }

    /// Returns `true` while the track is [`TrackState::Tentative`].
    pub fn is_tentative(&self) -> bool {
        matches!(self.state, TrackState::Tentative)
    }

    /// Returns `true` once the track is [`TrackState::Deleted`].
    pub fn is_deleted(&self) -> bool {
        matches!(self.state, TrackState::Deleted)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for all floating-point comparisons.
    const EPS: f32 = 0.01;

    /// Asserts that `actual` is within `EPS` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// Track centred at (100, 100), aspect 0.5, height 50, `n_init = 3`,
    /// `max_age = 30`.
    fn create_track() -> Track {
        let mean = StateVector::from_vec(vec![100.0, 100.0, 0.5, 50.0, 0.0, 0.0, 0.0, 0.0]);
        let covariance = CovarianceMatrix::identity();
        Track::new(mean, covariance, 1, 0, 3, 30, 0.9, vec![1.0; 128])
    }

    /// Measurement that coincides with the position of `create_track()`.
    fn create_measurement() -> MeasurementVector {
        MeasurementVector::from_vec(vec![100.0, 100.0, 0.5, 50.0])
    }

    #[test]
    fn test_track_initial_state() {
        let track = create_track();
        assert!(track.is_tentative());
        assert!(!track.is_confirmed());
        assert!(!track.is_deleted());
        assert_eq!(track.hits, 1);
        assert_eq!(track.age, 1);
        assert_eq!(track.time_since_update, 0);
    }

    #[test]
    fn test_track_tlwh_conversion() {
        let tlwh = [100.0, 100.0, 50.0, 100.0];
        let xyah = Track::tlwh_to_xyah(&tlwh);
        assert_close(xyah[0], 125.0);
        assert_close(xyah[1], 150.0);
        assert_close(xyah[2], 0.5);
        assert_close(xyah[3], 100.0);
    }

    #[test]
    fn test_tlwh_to_xyah_zero_height() {
        // The divisor is clamped to 1e-6, so a degenerate box must still
        // yield a finite aspect ratio instead of inf/NaN.
        let tlwh = [0.0, 0.0, 10.0, 0.0];
        let xyah = Track::tlwh_to_xyah(&tlwh);
        assert!(xyah[2].is_finite());
        assert_close(xyah[3], 0.0);
    }

    #[test]
    fn test_xyah_to_tlwh() {
        let mean = StateVector::from_vec(vec![125.0, 150.0, 0.5, 100.0, 0.0, 0.0, 0.0, 0.0]);
        let tlwh = Track::xyah_to_tlwh(&mean);
        assert_close(tlwh[0], 100.0);
        assert_close(tlwh[1], 100.0);
        assert_close(tlwh[2], 50.0);
        assert_close(tlwh[3], 100.0);
    }

    #[test]
    fn test_track_to_tlwh() {
        // Checking `len() == 4` on a `[f32; 4]` can never fail, so verify the
        // actual box: w = 0.5 * 50 = 25, x = 100 - 12.5, y = 100 - 25.
        let track = create_track();
        let tlwh = track.to_tlwh();
        assert_close(tlwh[0], 87.5);
        assert_close(tlwh[1], 75.0);
        assert_close(tlwh[2], 25.0);
        assert_close(tlwh[3], 50.0);
    }

    #[test]
    fn test_track_predict() {
        let mut track = create_track();
        let kf = KalmanFilter::default();
        let initial_age = track.age;
        track.predict(&kf);
        assert_eq!(track.age, initial_age + 1);
        assert_eq!(track.time_since_update, 1);
    }

    #[test]
    fn test_track_update_confirmation() {
        let mut track = create_track();
        let kf = KalmanFilter::default();
        let measurement = create_measurement();
        assert!(track.is_tentative());

        // Second hit: still below n_init = 3.
        track.update(&kf, &measurement, 0.9, 0, vec![1.0; 128]);
        assert!(track.is_tentative());
        assert_eq!(track.hits, 2);

        // Third hit reaches n_init and confirms the track.
        track.update(&kf, &measurement, 0.9, 0, vec![1.0; 128]);
        assert!(track.is_confirmed());
        assert_eq!(track.hits, 3);
        assert_eq!(track.time_since_update, 0);
    }

    #[test]
    fn test_track_mark_missed_tentative() {
        let mut track = create_track();
        assert!(track.is_tentative());
        track.mark_missed();
        assert!(track.is_deleted());
    }

    #[test]
    fn test_track_mark_missed_confirmed() {
        let mut track = create_track();
        let kf = KalmanFilter::default();
        let measurement = create_measurement();
        track.update(&kf, &measurement, 0.9, 0, vec![1.0; 128]);
        track.update(&kf, &measurement, 0.9, 0, vec![1.0; 128]);
        assert!(track.is_confirmed());

        track.time_since_update = 0;
        track.mark_missed();
        assert!(track.is_confirmed());

        // Exactly max_age is still tolerated: deletion requires `> max_age`.
        track.time_since_update = 30;
        track.mark_missed();
        assert!(track.is_confirmed());

        track.time_since_update = 31;
        track.mark_missed();
        assert!(track.is_deleted());
    }

    #[test]
    fn test_track_features_accumulate() {
        let mut track = create_track();
        let kf = KalmanFilter::default();
        let measurement = create_measurement();
        assert_eq!(track.features.len(), 1);

        track.update(&kf, &measurement, 0.9, 0, vec![2.0; 128]);
        assert_eq!(track.features.len(), 2);

        track.update(&kf, &measurement, 0.9, 0, vec![3.0; 128]);
        assert_eq!(track.features.len(), 3);

        // Features are stored in arrival order.
        assert_eq!(track.features[0], vec![1.0; 128]);
        assert_eq!(track.features[1], vec![2.0; 128]);
        assert_eq!(track.features[2], vec![3.0; 128]);
    }

    #[test]
    fn test_track_state_enum() {
        assert_eq!(TrackState::Tentative, TrackState::Tentative);
        assert_ne!(TrackState::Tentative, TrackState::Confirmed);
        assert_ne!(TrackState::Confirmed, TrackState::Deleted);
    }
}