#![doc = include_str!("README.md")]
use crate::utils::kalman::{CovarianceMatrix, KalmanFilter, MeasurementVector, StateVector};
/// A single tracked object (or a raw detection that has not been activated yet).
///
/// Carries the latest bounding box and score together with the Kalman filter
/// state (`mean`, `covariance`) that `predict`, `update` and `re_activate` maintain.
#[derive(Debug, Clone)]
pub struct STrack {
/// Bounding box as `[left, top, width, height]`.
pub tlwh: [f32; 4],
/// Score of the detection most recently assigned to this track.
pub score: f32,
/// Class label supplied at construction; `update`/`re_activate` do not change it.
pub class_id: i64,
/// Track identifier; `0` until `activate` (or `re_activate` with an ID) assigns one.
pub track_id: u64,
/// Current lifecycle state.
pub state: TrackState,
/// `false` after `new`; set to `true` by `activate`, `update` and `re_activate`.
pub is_activated: bool,
/// Frame index passed to the most recent `activate`, `update` or `re_activate` call.
pub frame_id: usize,
/// Frame index passed to `activate`.
pub start_frame: usize,
/// Incremented by every `update`; reset to `0` by `activate` and `re_activate`.
pub tracklet_len: usize,
/// Kalman state vector. Components 0..=3 are read as center x, center y,
/// aspect ratio (width / height) and height; component 7 is zeroed by
/// `predict` for tracks that are not in the `Tracked` state.
pub mean: StateVector,
/// Kalman state covariance paired with `mean`.
pub covariance: CovarianceMatrix,
}
/// Lifecycle state of an [`STrack`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TrackState {
    /// Freshly constructed by `STrack::new`; not yet activated.
    New,
    /// Matched to a detection on its most recent activation or update.
    Tracked,
    /// Went unmatched on a frame; `ByteTrack::update` keeps such tracks around
    /// for a limited number of frames so they can be re-activated.
    Lost,
    /// Not assigned by any code in this file.
    Removed,
}
impl STrack {
/// Builds an inactive track from a raw detection.
///
/// The result is in [`TrackState::New`], has `track_id == 0`, all frame
/// counters at zero, a zero state vector and an identity covariance.
/// Call `activate` to give it an ID and a real Kalman state.
pub fn new(tlwh: [f32; 4], score: f32, class_id: i64) -> Self {
    let mean = StateVector::zeros();
    let covariance = CovarianceMatrix::identity();
    Self {
        // Detection payload.
        tlwh,
        score,
        class_id,
        // Lifecycle bookkeeping: nothing has happened to this track yet.
        state: TrackState::New,
        is_activated: false,
        track_id: 0,
        frame_id: 0,
        start_frame: 0,
        tracklet_len: 0,
        // Placeholder filter state, overwritten by `activate`.
        mean,
        covariance,
    }
}
/// Starts a new track from this detection.
///
/// Initialises the Kalman state from the current `tlwh` box via
/// `kf.initiate`, assigns `track_id`, records `frame_id` as both the start
/// and the latest frame, and marks the track as `Tracked` and activated.
pub fn activate(&mut self, kf: &KalmanFilter, frame_id: usize, track_id: u64) {
    let seed = self.tlwh_to_xyah(self.tlwh);
    (self.mean, self.covariance) = kf.initiate(&seed);

    self.track_id = track_id;
    self.start_frame = frame_id;
    self.frame_id = frame_id;
    self.tracklet_len = 0;

    self.is_activated = true;
    self.state = TrackState::Tracked;
}
pub fn re_activate(
&mut self,
new_track: STrack,
frame_id: usize,
new_track_id: Option<u64>,
kf: &KalmanFilter,
) {
let measurement = self.tlwh_to_xyah(new_track.tlwh);
let (mean, covariance) = kf.update(&self.mean, &self.covariance, &measurement);
self.mean = mean;
self.covariance = covariance;
self.state = TrackState::Tracked;
self.is_activated = true;
self.frame_id = frame_id;
self.tracklet_len = 0;
self.score = new_track.score;
self.tlwh = new_track.tlwh;
if let Some(id) = new_track_id {
self.track_id = id;
}
}
pub fn update(&mut self, new_track: STrack, frame_id: usize, kf: &KalmanFilter) {
self.frame_id = frame_id;
self.tracklet_len += 1;
self.state = TrackState::Tracked;
self.is_activated = true;
self.score = new_track.score;
self.tlwh = new_track.tlwh;
let measurement = self.tlwh_to_xyah(new_track.tlwh);
let (mean, covariance) = kf.update(&self.mean, &self.covariance, &measurement);
self.mean = mean;
self.covariance = covariance;
}
/// Advances the Kalman state by one step and refreshes `tlwh` from it.
///
/// For tracks that are not currently `Tracked`, state component 7 is
/// zeroed before predicting (presumably the height velocity, so the box
/// does not keep resizing while unobserved; confirm against the
/// `KalmanFilter` state layout).
pub fn predict(&mut self, kf: &KalmanFilter) {
    if !matches!(self.state, TrackState::Tracked) {
        self.mean[7] = 0.0;
    }
    (self.mean, self.covariance) = kf.predict(&self.mean, &self.covariance);
    self.tlwh = self.tlwh_from_xyah(&self.mean);
}
/// Converts `[left, top, width, height]` into the measurement
/// `(center x, center y, aspect ratio, height)`.
///
/// The aspect ratio is `width / height`, so a zero height yields a
/// non-finite value.
fn tlwh_to_xyah(&self, tlwh: [f32; 4]) -> MeasurementVector {
    let [left, top, width, height] = tlwh;
    let center_x = left + width / 2.0;
    let center_y = top + height / 2.0;
    let aspect = width / height;
    MeasurementVector::from_vec(vec![center_x, center_y, aspect, height])
}
/// Recovers `[left, top, width, height]` from the first four components
/// of a state vector laid out as `(center x, center y, aspect ratio, height)`.
fn tlwh_from_xyah(&self, xyah: &StateVector) -> [f32; 4] {
    let (center_x, center_y, aspect, height) = (xyah[0], xyah[1], xyah[2], xyah[3]);
    let width = aspect * height;
    [
        center_x - width / 2.0,
        center_y - height / 2.0,
        width,
        height,
    ]
}
}
/// Multi-object tracker that associates detections with tracks by IoU in
/// two passes: high-score detections first, then low-score detections
/// against the tracks that are still unmatched.
pub struct ByteTrack {
/// Tracks that were matched, re-found or created on the most recent `update`.
tracked_stracks: Vec<STrack>,
/// Tracks that went unmatched and are still within the retention window.
lost_stracks: Vec<STrack>,
/// Number of `update` calls so far; incremented at the start of each call.
frame_id: usize,
/// Maximum number of frames a lost track is retained since it was last matched.
buffer_size: usize,
/// Detections with `score >= track_thresh` take part in the first association pass;
/// the rest are only used in the second pass.
track_thresh: f32,
/// Maximum IoU distance (`1 - IoU`) accepted in the first association pass.
match_thresh: f32,
// `kalman_filter` (declared on the same line below) is the filter handed to every
// track's `activate`, `predict`, `update` and `re_activate`.
/// Minimum score an unmatched high-score detection needs to start a new track.
det_thresh: f32, kalman_filter: KalmanFilter,
/// ID that will be given to the next newly created track; starts at 1.
next_id: u64,
}
impl ByteTrack {
/// Creates an empty tracker.
///
/// * `track_thresh` - score at or above which a detection is treated as high-confidence.
/// * `track_buffer` - number of frames a lost track is retained.
/// * `match_thresh` - maximum IoU distance accepted in the first association pass.
/// * `det_thresh` - minimum score required to start a new track.
///
/// The frame counter starts at 0 and the first assigned track ID is 1.
pub fn new(track_thresh: f32, track_buffer: usize, match_thresh: f32, det_thresh: f32) -> Self {
    let kalman_filter = KalmanFilter::default();
    Self {
        // Configuration.
        track_thresh,
        match_thresh,
        det_thresh,
        buffer_size: track_buffer,
        // Mutable tracking state.
        kalman_filter,
        frame_id: 0,
        next_id: 1,
        tracked_stracks: vec![],
        lost_stracks: vec![],
    }
}
/// Advances the tracker by one frame and returns the tracks active on it.
///
/// `output_results` holds one `(tlwh, score, class_id)` tuple per detection.
///
/// Steps performed:
/// 1. Detections are split into high-score (`score >= track_thresh`) and
///    low-score groups.
/// 2. Every tracked and lost track is advanced with the Kalman filter.
/// 3. Tracked and lost tracks are matched against high-score detections by
///    IoU distance (limit `match_thresh`). Matched tracked tracks are
///    updated; matched lost tracks are re-activated with their old ID.
/// 4. Tracked tracks left over are matched against low-score detections
///    (limit 0.5). Those still unmatched become `Lost`.
/// 5. Unmatched high-score detections with `score >= det_thresh` start new
///    tracks with the next sequential ID.
/// 6. Previously lost tracks that stayed unmatched are retained only while
///    `frame_id - last_matched_frame <= buffer_size`.
///
/// Unmatched indices are sorted before use so that ID assignment and the
/// order of the returned tracks do not depend on the order in which the
/// assignment helper happens to report them.
///
/// Returns clones of all tracks matched, re-found or created on this frame;
/// lost tracks are not included.
pub fn update(&mut self, output_results: Vec<([f32; 4], f32, i64)>) -> Vec<STrack> {
    // Maximum IoU distance accepted when pairing leftover tracks with low-score detections.
    const SECOND_MATCH_THRESH: f32 = 0.5;

    self.frame_id += 1;

    let mut activated_stracks: Vec<STrack> = Vec::new();
    let mut refind_stracks: Vec<STrack> = Vec::new();
    let mut lost_stracks: Vec<STrack> = Vec::new();

    // Step 1: split detections by confidence (input order is preserved in each group).
    let (detections_high, detections_low): (Vec<STrack>, Vec<STrack>) = output_results
        .into_iter()
        .map(|(tlwh, score, class_id)| STrack::new(tlwh, score, class_id))
        .partition(|det| det.score >= self.track_thresh);

    // Step 2: predict every known track forward to the current frame.
    for track in self
        .tracked_stracks
        .iter_mut()
        .chain(self.lost_stracks.iter_mut())
    {
        track.predict(&self.kalman_filter);
    }

    // Candidate pool: activated tracked tracks followed by lost tracks. Both lists
    // are rebuilt from scratch at the end of this call, so the tracks are moved
    // into the pool rather than cloned.
    let mut strack_pool: Vec<STrack> = std::mem::take(&mut self.tracked_stracks)
        .into_iter()
        .filter(|track| track.is_activated)
        .collect();
    strack_pool.append(&mut self.lost_stracks);

    // Step 3: first association, pool vs. high-score detections.
    let (matches, mut u_track, mut u_detection) =
        if strack_pool.is_empty() || detections_high.is_empty() {
            (
                Vec::new(),
                (0..strack_pool.len()).collect::<Vec<usize>>(),
                (0..detections_high.len()).collect::<Vec<usize>>(),
            )
        } else {
            let (dists, _, _) = self.iou_distance(&strack_pool, &detections_high);
            self.linear_assignment(&dists, self.match_thresh)
        };
    u_track.sort_unstable();
    u_detection.sort_unstable();

    for (itrack, idet) in matches {
        let track = &mut strack_pool[itrack];
        let det = detections_high[idet].clone();
        if track.state == TrackState::Tracked {
            track.update(det, self.frame_id, &self.kalman_filter);
            activated_stracks.push(track.clone());
        } else {
            track.re_activate(det, self.frame_id, None, &self.kalman_filter);
            refind_stracks.push(track.clone());
        }
    }

    // Step 4: second association, leftover *tracked* tracks vs. low-score detections.
    let mut r_tracked_stracks: Vec<STrack> = u_track
        .iter()
        .map(|&i| &strack_pool[i])
        .filter(|track| track.state == TrackState::Tracked)
        .cloned()
        .collect();

    let (matches, mut u_track_second, _) =
        if r_tracked_stracks.is_empty() || detections_low.is_empty() {
            (
                Vec::new(),
                (0..r_tracked_stracks.len()).collect::<Vec<usize>>(),
                (0..detections_low.len()).collect::<Vec<usize>>(),
            )
        } else {
            let (dists, _, _) = self.iou_distance(&r_tracked_stracks, &detections_low);
            self.linear_assignment(&dists, SECOND_MATCH_THRESH)
        };
    u_track_second.sort_unstable();

    // Every entry of `r_tracked_stracks` is `Tracked` (see the filter above), so a
    // match here is always a plain update, never a re-activation.
    for (itrack, idet) in matches {
        let track = &mut r_tracked_stracks[itrack];
        track.update(
            detections_low[idet].clone(),
            self.frame_id,
            &self.kalman_filter,
        );
        activated_stracks.push(track.clone());
    }

    // Tracked tracks that found no detection in either pass become lost.
    for &it in &u_track_second {
        let track = &mut r_tracked_stracks[it];
        track.state = TrackState::Lost;
        lost_stracks.push(track.clone());
    }

    // Step 5: start new tracks from confident detections nobody claimed.
    for &i in &u_detection {
        let det = &detections_high[i];
        if det.score < self.det_thresh {
            continue;
        }
        let mut new_track = det.clone();
        let id = self.next_id;
        self.next_id += 1;
        new_track.activate(&self.kalman_filter, self.frame_id, id);
        activated_stracks.push(new_track);
    }

    // Step 6: keep already-lost tracks only while they are inside the retention window.
    for &i in &u_track {
        let track = &strack_pool[i];
        if track.state == TrackState::Lost
            && self.frame_id - track.frame_id <= self.buffer_size
        {
            lost_stracks.push(track.clone());
        }
    }

    activated_stracks.extend(refind_stracks);
    self.tracked_stracks = activated_stracks;
    self.lost_stracks = lost_stracks;

    self.tracked_stracks
        .iter()
        .filter(|track| track.is_activated)
        .cloned()
        .collect()
}
/// Builds the IoU distance matrix between tracks and detections.
///
/// Entry `[r][c]` is `1 - IoU(stracks[r], detections[c])`, computed from the
/// `tlwh` boxes by `crate::utils::geometry::iou_batch`. The second and third
/// tuple elements are the row indices `0..stracks.len()` and column indices
/// `0..detections.len()`.
///
/// When either input is empty, all three returned vectors are empty.
fn iou_distance(
    &self,
    stracks: &[STrack],
    detections: &[STrack],
) -> (Vec<Vec<f32>>, Vec<usize>, Vec<usize>) {
    if stracks.is_empty() || detections.is_empty() {
        return (Vec::new(), Vec::new(), Vec::new());
    }

    let track_boxes: Vec<[f32; 4]> = stracks.iter().map(|track| track.tlwh).collect();
    let detection_boxes: Vec<[f32; 4]> = detections.iter().map(|det| det.tlwh).collect();

    let mut cost_matrix: Vec<Vec<f32>> = Vec::new();
    for iou_row in crate::utils::geometry::iou_batch(&track_boxes, &detection_boxes) {
        // Higher overlap means lower cost.
        cost_matrix.push(iou_row.into_iter().map(|iou| 1.0 - iou).collect());
    }

    let track_indices: Vec<usize> = (0..track_boxes.len()).collect();
    let detection_indices: Vec<usize> = (0..detection_boxes.len()).collect();
    (cost_matrix, track_indices, detection_indices)
}
/// Greedy minimum-cost assignment between rows (tracks) and columns (detections).
///
/// All `(row, column)` pairs whose cost is `<= thresh` are visited in
/// ascending cost order (ties keep row-major order); a pair is accepted when
/// neither its row nor its column has been used yet. This is a greedy
/// matching, not an optimal (Hungarian) assignment.
///
/// Pairs with a NaN cost are never matched. A NaN fails every comparison, so
/// a "reject if `cost > thresh`" test lets it through, and depending on its
/// sign bit `total_cmp` can even sort it ahead of every valid candidate.
///
/// Returns `(matches, unmatched_rows, unmatched_columns)`. Both unmatched
/// lists are in ascending index order, so callers that hand out IDs while
/// iterating over them behave the same on every run. The column count is
/// taken from the first row; entries beyond it in longer rows are ignored.
/// An empty matrix yields three empty vectors.
fn linear_assignment(
    &self,
    cost_matrix: &[Vec<f32>],
    thresh: f32,
) -> (Vec<(usize, usize)>, Vec<usize>, Vec<usize>) {
    /// Indices whose flag is still `false`, in ascending order.
    fn unmatched(taken: &[bool]) -> Vec<usize> {
        taken
            .iter()
            .enumerate()
            .filter(|&(_, &used)| !used)
            .map(|(index, _)| index)
            .collect()
    }

    if cost_matrix.is_empty() {
        return (Vec::new(), Vec::new(), Vec::new());
    }
    let rows = cost_matrix.len();
    let cols = cost_matrix[0].len();

    // Keep only admissible pairs; `cost <= thresh` is false for NaN, which drops it.
    let mut candidates: Vec<(f32, usize, usize)> = cost_matrix
        .iter()
        .enumerate()
        .flat_map(|(r, row)| {
            row.iter()
                .enumerate()
                .map(move |(c, &cost)| (cost, r, c))
        })
        .filter(|&(cost, _, c)| c < cols && cost <= thresh)
        .collect();
    // Stable sort: equal costs stay in row-major order.
    candidates.sort_by(|a, b| a.0.total_cmp(&b.0));

    let mut row_taken = vec![false; rows];
    let mut col_taken = vec![false; cols];
    let mut matches = Vec::new();
    for (_, r, c) in candidates {
        if !row_taken[r] && !col_taken[c] {
            row_taken[r] = true;
            col_taken[c] = true;
            matches.push((r, c));
        }
    }

    (matches, unmatched(&row_taken), unmatched(&col_taken))
}
}
#[cfg(feature = "python")]
use pyo3::prelude::*;
/// One tracked object as handed back to Python by `PyByteTrack::update`:
/// `(track_id, tlwh, score, class_id)`.
#[cfg(feature = "python")]
type PyTrackingResult = (u64, [f32; 4], f32, i64);
// Python-facing wrapper around `ByteTrack`, exported under the class name `BYTETRACK`
// when the `python` feature is enabled.
#[cfg(feature = "python")]
#[pyclass(name = "BYTETRACK")]
pub struct PyByteTrack {
// The wrapped tracker; every method forwards to it.
inner: ByteTrack,
}
#[cfg(feature = "python")]
#[pymethods]
impl PyByteTrack {
// Python constructor. All four arguments are optional on the Python side and
// default to track_thresh=0.5, track_buffer=30, match_thresh=0.8, det_thresh=0.6;
// they are forwarded unchanged to `ByteTrack::new`.
#[new]
#[pyo3(signature = (track_thresh=0.5, track_buffer=30, match_thresh=0.8, det_thresh=0.6))]
fn new(track_thresh: f32, track_buffer: usize, match_thresh: f32, det_thresh: f32) -> Self {
    let inner = ByteTrack::new(track_thresh, track_buffer, match_thresh, det_thresh);
    Self { inner }
}
// Feeds one frame of `(tlwh, score, class_id)` detections to the wrapped tracker
// and returns one `(track_id, tlwh, score, class_id)` tuple per active track.
// This wrapper itself never produces an error; it always returns `Ok`.
fn update(
    &mut self,
    output_results: Vec<([f32; 4], f32, i64)>,
) -> PyResult<Vec<PyTrackingResult>> {
    let tracks = self.inner.update(output_results);
    let mut rows: Vec<PyTrackingResult> = Vec::with_capacity(tracks.len());
    for track in tracks {
        rows.push((track.track_id, track.tlwh, track.score, track.class_id));
    }
    Ok(rows)
}
}
#[cfg(test)]
mod tests {
use super::*;
/// A freshly constructed track stores its inputs and starts out inactive.
#[test]
fn test_strack_init() {
    let bbox = [10.0, 10.0, 50.0, 100.0];
    let confidence = 0.9;
    let label = 1;

    let track = STrack::new(bbox, confidence, label);

    // Inputs are stored verbatim.
    assert_eq!(track.tlwh, bbox);
    assert_eq!(track.score, confidence);
    assert_eq!(track.class_id, label);
    // Lifecycle starts at New / not activated.
    assert_eq!(track.state, TrackState::New);
    assert!(!track.is_activated);
}
/// A detection that moves slightly between frames keeps the same track ID.
#[test]
fn test_bytetrack_update_simple() {
    let mut tracker = ByteTrack::new(0.5, 30, 0.8, 0.6);

    // Frame 1: a confident detection starts a track.
    let first = tracker.update(vec![([10.0, 10.0, 50.0, 100.0], 0.9_f32, 0_i64)]);
    assert_eq!(first.len(), 1);
    let expected_id = first[0].track_id;
    assert_eq!(first[0].state, TrackState::Tracked);

    // Frame 2: the box shifts by a few pixels and must match the same track.
    let second = tracker.update(vec![([15.0, 15.0, 50.0, 100.0], 0.9_f32, 0_i64)]);
    assert_eq!(second.len(), 1);
    assert_eq!(second[0].track_id, expected_id);
}
/// A low-score detection can still continue an existing track via the second pass.
#[test]
fn test_bytetrack_low_conf_match() {
    let mut tracker = ByteTrack::new(0.6, 30, 0.8, 0.6);

    // Frame 1: score above the track threshold creates the track.
    let confident = tracker.update(vec![([10.0, 10.0, 50.0, 50.0], 0.9, 0)]);
    assert_eq!(confident.len(), 1);
    let id = confident[0].track_id;

    // Frame 2: score below the track threshold, nearly the same box.
    let weak = tracker.update(vec![([12.0, 12.0, 50.0, 50.0], 0.4, 0)]);
    assert_eq!(weak.len(), 1, "Expected 1 track, got {}", weak.len());
    assert_eq!(weak[0].track_id, id);
}
/// Each tracker owns its ID counter: two fresh trackers both start at ID 1.
#[test]
fn test_bytetrack_instance_isolation() {
    let detection = ([100.0, 100.0, 50.0, 100.0], 0.9_f32, 0_i64);
    let mut trackers = [
        ByteTrack::new(0.5, 30, 0.8, 0.6),
        ByteTrack::new(0.5, 30, 0.8, 0.6),
    ];

    for tracker in trackers.iter_mut() {
        let tracks = tracker.update(vec![detection]);
        assert_eq!(tracks.len(), 1);
        assert_eq!(tracks[0].track_id, 1);
    }
}
/// Non-overlapping detections on consecutive frames receive consecutive IDs.
#[test]
fn test_bytetrack_id_sequential() {
    let mut tracker = ByteTrack::new(0.5, 30, 0.8, 0.6);

    let first = tracker.update(vec![([100.0, 100.0, 50.0, 100.0], 0.9_f32, 0_i64)]);
    assert_eq!(first[0].track_id, 1);

    // Far away from the first box, so it cannot match and must open a new track.
    let second = tracker.update(vec![([200.0, 200.0, 50.0, 100.0], 0.9_f32, 1_i64)]);
    assert_eq!(second[0].track_id, 2);
}
/// A track that disappears for one frame is re-activated with its original ID.
#[test]
fn test_bytetrack_re_activate_lost_track() {
    let mut tracker = ByteTrack::new(0.5, 30, 0.8, 0.6);
    let detection = ([10.0, 10.0, 50.0, 100.0], 0.9_f32, 0_i64);

    // Frame 1: track is created.
    let seen = tracker.update(vec![detection]);
    assert_eq!(seen.len(), 1);
    let original_id = seen[0].track_id;

    // Frame 2: no detections at all, so the track goes lost.
    let missed = tracker.update(Vec::new());
    assert_eq!(missed.len(), 0, "track should not appear while lost");

    // Frame 3: the same box reappears.
    let refound = tracker.update(vec![detection]);
    assert_eq!(refound.len(), 1);
    assert_eq!(
        refound[0].track_id, original_id,
        "re-activated track retains its original ID"
    );
}
}