trackforge 0.1.9

A unified, high-performance computer vision tracking library.
Documentation
use crate::utils::kalman::{CovarianceMatrix, KalmanFilter, MeasurementVector, StateVector};

#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum TrackState {
    Tentative,
    Confirmed,
    Deleted,
}

#[derive(Debug, Clone)]
pub struct Track {
    pub track_id: u64,
    pub class_id: i64,
    pub hits: usize,
    pub age: usize,
    pub time_since_update: usize,
    pub state: TrackState,
    pub mean: StateVector,
    pub covariance: CovarianceMatrix,
    pub score: f32,

    /// Features (embeddings) collected during the current update cycle or while tentative.
    /// These are flushed to the metric gallery when appropriate.
    pub features: Vec<Vec<f32>>,

    _n_init: usize,
    _max_age: usize,
}

impl Track {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        mean: StateVector,
        covariance: CovarianceMatrix,
        track_id: u64,
        class_id: i64,
        n_init: usize,
        max_age: usize,
        score: f32,
        feature: Vec<f32>,
    ) -> Self {
        Self {
            mean,
            covariance,
            track_id,
            class_id,
            hits: 1,
            age: 1,
            time_since_update: 0,
            state: TrackState::Tentative,
            score,
            features: vec![feature],
            _n_init: n_init,
            _max_age: max_age,
        }
    }

    /// Convert TLWH to (x, y, a, h)
    pub fn tlwh_to_xyah(tlwh: &[f32; 4]) -> MeasurementVector {
        let x = tlwh[0] + tlwh[2] / 2.0;
        let y = tlwh[1] + tlwh[3] / 2.0;
        let a = tlwh[2] / tlwh[3].max(1e-6);
        let h = tlwh[3];
        MeasurementVector::from_vec(vec![x, y, a, h])
    }

    /// Convert (x, y, a, h) to TLWH
    pub fn xyah_to_tlwh(state: &StateVector) -> [f32; 4] {
        let w = state[2] * state[3];
        let h = state[3];
        let x = state[0] - w / 2.0;
        let y = state[1] - h / 2.0;
        [x, y, w, h]
    }

    pub fn to_tlwh(&self) -> [f32; 4] {
        Self::xyah_to_tlwh(&self.mean)
    }

    pub fn predict(&mut self, kf: &KalmanFilter) {
        let (mean, covariance) = kf.predict(&self.mean, &self.covariance);
        self.mean = mean;
        self.covariance = covariance;
        self.age += 1;
        self.time_since_update += 1;
    }

    pub fn update(
        &mut self,
        kf: &KalmanFilter,
        detection: &MeasurementVector,
        score: f32,
        class_id: i64,
        feature: Vec<f32>,
    ) {
        let (mean, covariance) = kf.update(&self.mean, &self.covariance, detection);
        self.mean = mean;
        self.covariance = covariance;
        self.hits += 1;
        self.time_since_update = 0;
        self.score = score;
        self.class_id = class_id;
        self.features.push(feature);

        if self.state == TrackState::Tentative && self.hits >= self._n_init {
            self.state = TrackState::Confirmed;
        }
    }

    pub fn mark_missed(&mut self) {
        if self.state == TrackState::Tentative || self.time_since_update > self._max_age {
            self.state = TrackState::Deleted;
        }
    }

    pub fn is_confirmed(&self) -> bool {
        self.state == TrackState::Confirmed
    }

    pub fn is_tentative(&self) -> bool {
        self.state == TrackState::Tentative
    }

    pub fn is_deleted(&self) -> bool {
        self.state == TrackState::Deleted
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn create_track() -> Track {
        let mean = StateVector::from_vec(vec![100.0, 100.0, 0.5, 50.0, 0.0, 0.0, 0.0, 0.0]);
        let covariance = CovarianceMatrix::identity();
        Track::new(mean, covariance, 1, 0, 3, 30, 0.9, vec![1.0; 128])
    }

    #[test]
    fn test_track_initial_state() {
        let track = create_track();
        assert!(track.is_tentative());
        assert!(!track.is_confirmed());
        assert!(!track.is_deleted());
        assert_eq!(track.hits, 1);
        assert_eq!(track.age, 1);
        assert_eq!(track.time_since_update, 0);
    }

    #[test]
    fn test_track_tlwh_conversion() {
        // Test tlwh to xyah conversion
        let tlwh = [100.0, 100.0, 50.0, 100.0];
        let xyah = Track::tlwh_to_xyah(&tlwh);

        assert!((xyah[0] - 125.0).abs() < 0.01); // x center
        assert!((xyah[1] - 150.0).abs() < 0.01); // y center
        assert!((xyah[2] - 0.5).abs() < 0.01); // aspect ratio
        assert!((xyah[3] - 100.0).abs() < 0.01); // height
    }

    #[test]
    fn test_xyah_to_tlwh() {
        let mean = StateVector::from_vec(vec![125.0, 150.0, 0.5, 100.0, 0.0, 0.0, 0.0, 0.0]);
        let tlwh = Track::xyah_to_tlwh(&mean);

        assert!((tlwh[0] - 100.0).abs() < 0.01); // x
        assert!((tlwh[1] - 100.0).abs() < 0.01); // y
        assert!((tlwh[2] - 50.0).abs() < 0.01); // width
        assert!((tlwh[3] - 100.0).abs() < 0.01); // height
    }

    #[test]
    fn test_track_to_tlwh() {
        let track = create_track();
        let tlwh = track.to_tlwh();

        // Should convert from xyah format in mean
        assert_eq!(tlwh.len(), 4);
    }

    #[test]
    fn test_track_predict() {
        let mut track = create_track();
        let kf = KalmanFilter::default();

        let initial_age = track.age;
        track.predict(&kf);

        assert_eq!(track.age, initial_age + 1);
        assert_eq!(track.time_since_update, 1);
    }

    #[test]
    fn test_track_update_confirmation() {
        let mut track = create_track();
        let kf = KalmanFilter::default();
        let measurement = MeasurementVector::from_vec(vec![100.0, 100.0, 0.5, 50.0]);

        assert!(track.is_tentative());

        // Hit 2
        track.update(&kf, &measurement, 0.9, 0, vec![1.0; 128]);
        assert!(track.is_tentative());

        // Hit 3 - should confirm (n_init = 3)
        track.update(&kf, &measurement, 0.9, 0, vec![1.0; 128]);
        assert!(track.is_confirmed());
    }

    #[test]
    fn test_track_mark_missed_tentative() {
        let mut track = create_track();
        assert!(track.is_tentative());

        track.mark_missed();
        assert!(track.is_deleted());
    }

    #[test]
    fn test_track_mark_missed_confirmed() {
        let mut track = create_track();
        let kf = KalmanFilter::default();
        let measurement = MeasurementVector::from_vec(vec![100.0, 100.0, 0.5, 50.0]);

        // Confirm the track
        track.update(&kf, &measurement, 0.9, 0, vec![1.0; 128]);
        track.update(&kf, &measurement, 0.9, 0, vec![1.0; 128]);
        assert!(track.is_confirmed());

        // First miss shouldn't delete (time_since_update not > max_age)
        track.time_since_update = 0;
        track.mark_missed();
        assert!(track.is_confirmed()); // Still confirmed

        // Miss past max_age
        track.time_since_update = 31;
        track.mark_missed();
        assert!(track.is_deleted());
    }

    #[test]
    fn test_track_features_accumulate() {
        let mut track = create_track();
        let kf = KalmanFilter::default();
        let measurement = MeasurementVector::from_vec(vec![100.0, 100.0, 0.5, 50.0]);

        assert_eq!(track.features.len(), 1);

        track.update(&kf, &measurement, 0.9, 0, vec![2.0; 128]);
        assert_eq!(track.features.len(), 2);

        track.update(&kf, &measurement, 0.9, 0, vec![3.0; 128]);
        assert_eq!(track.features.len(), 3);
    }

    #[test]
    fn test_track_state_enum() {
        assert_eq!(TrackState::Tentative, TrackState::Tentative);
        assert_ne!(TrackState::Tentative, TrackState::Confirmed);
        assert_ne!(TrackState::Confirmed, TrackState::Deleted);
    }
}