use crate::{Hbb, KalmanFilterXYAH, StateCov, StateMean};
use std::sync::atomic::{AtomicUsize, Ordering};
/// Process-wide counter backing `STrack::next_id`. It holds the most recently
/// issued track id, or `0` when no id has been issued since start-up or the
/// last `STrack::reset_id`.
pub(crate) static GLOBAL_TRACK_ID: AtomicUsize = AtomicUsize::new(0);
/// Lifecycle stage of a tracked object.
///
/// The derived ordering follows declaration order:
/// `New < Tracked < Lost < Removed`.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum TrackState {
    /// Initial stage given to a freshly wrapped detection.
    New,
    /// Entered on activation, update or re-activation.
    Tracked,
    /// Not assigned in this file; presumably set by the tracker when a track
    /// goes unmatched. Confirm at the call site.
    Lost,
    /// Not assigned in this file; presumably the terminal stage set by the
    /// tracker. Confirm at the call site.
    Removed,
}
/// A single tracked object: its Kalman state, its most recent box, and the
/// bookkeeping used to manage its lifecycle.
#[derive(Debug, Clone)]
pub struct STrack {
    /// Filter used to predict and correct `mean` / `covariance`.
    pub(crate) kalman_filter: KalmanFilterXYAH,
    /// Kalman state. Indices 0..4 are read as `(cx, cy, aspect, height)`.
    /// Index 7 is zeroed before prediction when the track is not `Tracked`
    /// (presumably the height velocity; confirm against `KalmanFilterXYAH`).
    pub(crate) mean: StateMean,
    /// Kalman state covariance.
    pub(crate) covariance: StateCov,
    /// Latest box. Before activation this is the raw detection. Afterwards it
    /// is rebuilt from `mean` and carries the detection's metadata.
    pub(crate) hbb: Hbb,
    /// Lifecycle stage.
    pub(crate) state: TrackState,
    /// Set by `update` / `re_activate`. `activate` sets it only on frame 1.
    pub(crate) is_activated: bool,
    /// Id drawn from `GLOBAL_TRACK_ID`; `0` means not yet assigned.
    pub(crate) track_id: usize,
    /// Frame of the most recent activate / update / re-activate.
    pub(crate) frame_id: usize,
    /// Frame on which `activate` was called.
    pub(crate) start_frame_id: usize,
    /// Number of `update` calls since the last (re-)activation.
    pub(crate) tracklet_len: usize,
}
impl PartialEq for STrack {
    /// Tracks are equal when they share a `track_id`. Geometry, state and
    /// Kalman data are ignored.
    ///
    /// NOTE(review): every track that is not yet activated has id `0`, so all
    /// of them compare equal to one another.
    fn eq(&self, rhs: &Self) -> bool {
        let (lhs_id, rhs_id) = (self.track_id, rhs.track_id);
        lhs_id == rhs_id
    }
}
impl STrack {
    /// Wraps a raw detection box in a new, not-yet-activated track.
    ///
    /// The Kalman state is zeroed and `track_id` is `0` (unassigned) until
    /// [`STrack::activate`] initialises them.
    pub fn new(hbb: Hbb) -> Self {
        Self {
            kalman_filter: KalmanFilterXYAH::default(),
            mean: [0.0; 8],
            covariance: [[0.0; 8]; 8],
            hbb,
            state: TrackState::New,
            is_activated: false,
            track_id: 0,
            frame_id: 0,
            start_frame_id: 0,
            tracklet_len: 0,
        }
    }

    /// Confidence of the current box, or `1.0` when the box carries none.
    #[inline]
    pub fn score(&self) -> f32 {
        self.hbb.confidence().unwrap_or(1.0)
    }

    /// Converts the Kalman mean components `(cx, cy, aspect, h)` into an
    /// `xywh` box whose top-left corner is clipped to be non-negative.
    ///
    /// Clipping moves only the left/top edge. The right/bottom edge stays
    /// where the filter puts it, so a box crossing the origin shrinks instead
    /// of being shifted. A box lying entirely at negative coordinates
    /// collapses to zero width/height rather than getting a negative extent.
    fn compute_hbb_from_mean(&self) -> Hbb {
        let (cx, cy, a, h) = (self.mean[0], self.mean[1], self.mean[2], self.mean[3]);
        let w = a * h;
        let (x1, y1) = (cx - w / 2.0, cy - h / 2.0);
        // Pull the near edge in to 0 and shrink the extent by the same amount,
        // so the far edge is preserved (clip, don't translate).
        let (x, w) = if x1 < 0.0 { (0.0, (w + x1).max(0.0)) } else { (x1, w) };
        let (y, h) = if y1 < 0.0 { (0.0, (h + y1).max(0.0)) } else { (y1, h) };
        Hbb::from_xywh(x, y, w, h)
    }

    /// Builds a box from the current Kalman mean. It copies confidence, id
    /// and name from `meta`, and stamps this track's id once one is assigned.
    fn hbb_from_mean_with_meta(&self, meta: &Hbb) -> Hbb {
        let mut hbb = self.compute_hbb_from_mean();
        if let Some(conf) = meta.confidence() {
            hbb = hbb.with_confidence(conf);
        }
        if let Some(id) = meta.id() {
            hbb = hbb.with_id(id);
        }
        if let Some(name) = meta.name() {
            hbb = hbb.with_name(name);
        }
        // `0` is the "no id assigned yet" value set by `new`.
        if self.track_id > 0 {
            hbb = hbb.with_track_id(self.track_id);
        }
        hbb
    }

    /// Box derived from the current Kalman mean, carrying the metadata of the
    /// stored box and the track id (if assigned).
    ///
    /// Unlike `self.hbb`, this reflects the latest `predict` call.
    pub fn current_hbb(&self) -> Hbb {
        self.hbb_from_mean_with_meta(&self.hbb)
    }

    /// Advances the Kalman state by one step.
    ///
    /// For tracks not currently `Tracked`, `mean[7]` is zeroed first
    /// (presumably the height velocity, as in ByteTrack) so a lost box does
    /// not keep growing or shrinking. `self.hbb` is not refreshed here; use
    /// [`STrack::current_hbb`] for the predicted box.
    pub fn predict(&mut self) {
        if self.state != TrackState::Tracked {
            self.mean[7] = 0.0;
        }
        let (new_mean, new_covariance) = self.kalman_filter.predict(&self.mean, &self.covariance);
        self.mean = new_mean;
        self.covariance = new_covariance;
    }

    /// Corrects the Kalman state with a matched detection on `frame_id`.
    ///
    /// Marks the track `Tracked` and activated, extends `tracklet_len` by
    /// one, and adopts the detection's confidence, id and name.
    pub fn update(&mut self, new_track: &STrack, frame_id: usize) {
        let (cx, cy, a, h) = new_track.hbb.cxcyah();
        let measurement = [cx, cy, a, h];
        let (new_mean, new_covariance) =
            self.kalman_filter
                .update(&self.mean, &self.covariance, &measurement);
        self.mean = new_mean;
        self.covariance = new_covariance;
        self.frame_id = frame_id;
        self.tracklet_len += 1;
        self.state = TrackState::Tracked;
        self.is_activated = true;
        self.update_hbb_from_detection(new_track);
    }

    /// Rebuilds `self.hbb` from the Kalman mean, keeping its own metadata.
    fn update_hbb_from_prediction(&mut self) {
        self.hbb = self.hbb_from_mean_with_meta(&self.hbb);
    }

    /// Rebuilds `self.hbb` from the Kalman mean, taking metadata from
    /// `detection`.
    fn update_hbb_from_detection(&mut self, detection: &STrack) {
        self.hbb = self.hbb_from_mean_with_meta(&detection.hbb);
    }

    /// Starts a new track on `frame_id`.
    ///
    /// Assigns a fresh id, initialises the Kalman state from this track's own
    /// box, and marks it `Tracked`. Only tracks born on frame 1 count as
    /// activated immediately; others become activated on their first
    /// `update`.
    pub fn activate(&mut self, kalman_filter: KalmanFilterXYAH, frame_id: usize) {
        self.kalman_filter = kalman_filter;
        self.track_id = Self::next_id();
        let (cx, cy, a, h) = self.hbb.cxcyah();
        let measurement = [cx, cy, a, h];
        let (new_mean, new_covariance) = self.kalman_filter.initiate(&measurement);
        self.mean = new_mean;
        self.covariance = new_covariance;
        self.tracklet_len = 0;
        self.state = TrackState::Tracked;
        self.is_activated = frame_id == 1;
        self.frame_id = frame_id;
        self.start_frame_id = frame_id;
        // `track_id` is already non-zero here, so this also stamps the id.
        self.update_hbb_from_prediction();
    }

    /// Revives a lost track with a matched detection on `frame_id`.
    ///
    /// Corrects the Kalman state, resets `tracklet_len`, and marks the track
    /// `Tracked` and activated. When `new_id` is true, a fresh id is drawn
    /// before the box is rebuilt, so the box carries the new id.
    pub fn re_activate(&mut self, new_track: &STrack, frame_id: usize, new_id: bool) {
        let (cx, cy, a, h) = new_track.hbb.cxcyah();
        let measurement = [cx, cy, a, h];
        let (new_mean, new_covariance) =
            self.kalman_filter
                .update(&self.mean, &self.covariance, &measurement);
        self.mean = new_mean;
        self.covariance = new_covariance;
        self.tracklet_len = 0;
        self.state = TrackState::Tracked;
        self.is_activated = true;
        self.frame_id = frame_id;
        if new_id {
            self.track_id = Self::next_id();
        }
        self.update_hbb_from_detection(new_track);
    }

    /// Issues the next track id from the shared counter.
    ///
    /// Ids start at `1`, because `0` is reserved for "unassigned".
    pub fn next_id() -> usize {
        GLOBAL_TRACK_ID.fetch_add(1, Ordering::SeqCst) + 1
    }

    /// Resets the shared counter so the next issued id is `1` again.
    ///
    /// The counter is a process-wide static, so this affects every user of
    /// `STrack`.
    pub fn reset_id() {
        GLOBAL_TRACK_ID.store(0, Ordering::SeqCst);
    }

    /// Most recently issued id, which equals the number of ids issued since
    /// the last reset.
    pub fn get_current_count() -> usize {
        GLOBAL_TRACK_ID.load(Ordering::SeqCst)
    }

    /// Frame on which this track was last activated, updated or re-activated.
    pub fn end_frame(&self) -> usize {
        self.frame_id
    }
}