use yscv_detect::{BoundingBox, Detection, iou};
use crate::types::TrackedDetection;
/// Per-object state that [`ByteTracker`] keeps between frames.
#[derive(Debug, Clone)]
struct ByteTrack {
    /// Track identifier, taken from `ByteTracker::next_id` when the track is created.
    id: usize,
    /// Bounding box of the most recently matched (or creating) detection.
    bbox: BoundingBox,
    /// Confidence score of the most recently matched (or creating) detection.
    score: f32,
    /// Class of the most recently matched (or creating) detection.
    class_id: usize,
    /// Consecutive frames since the track was last matched; reset to 0 on every match.
    age: usize,
    /// Number of frames in which the track was matched, counting the frame that created it.
    /// NOTE(review): incremented but never read within this file.
    hits: usize,
}
/// Multi-object tracker using ByteTrack-style two-stage association.
///
/// Detections are split by confidence. High-confidence detections are associated first and
/// may start new tracks. Low-confidence detections are only used to keep already existing,
/// otherwise unmatched tracks alive.
#[derive(Debug, Clone)]
pub struct ByteTracker {
    /// ID assigned to the next newly created track.
    next_id: usize,
    /// Live tracks, kept in creation order.
    tracks: Vec<ByteTrack>,
    /// Minimum score for a detection to count as high confidence.
    high_threshold: f32,
    /// Minimum score for a detection to be considered at all.
    low_threshold: f32,
    /// Minimum IoU required to associate a track with a detection.
    iou_threshold: f32,
    /// Maximum number of consecutive unmatched frames a track survives.
    max_age: usize,
}
impl ByteTracker {
    /// Creates an empty tracker.
    ///
    /// * `high_threshold` - detections with `score >= high_threshold` take part in the first
    ///   association round and may start new tracks.
    /// * `low_threshold` - detections with `low_threshold <= score < high_threshold` are used only
    ///   in the second round to extend existing tracks; lower scores are ignored entirely.
    /// * `iou_threshold` - minimum IoU for a track/detection pair to be associated.
    /// * `max_age` - number of consecutive unmatched frames a track survives; it is removed on
    ///   the next miss after that.
    ///
    /// Track IDs start at 1 and increase by one for every new track.
    pub fn new(
        high_threshold: f32,
        low_threshold: f32,
        iou_threshold: f32,
        max_age: usize,
    ) -> Self {
        Self {
            next_id: 1,
            tracks: Vec::new(),
            high_threshold,
            low_threshold,
            iou_threshold,
            max_age,
        }
    }

    /// Advances the tracker by one frame and returns every surviving track.
    ///
    /// Association happens in two stages:
    /// 1. all existing tracks are matched against high-confidence detections;
    /// 2. tracks still unmatched are matched against low-confidence detections.
    ///
    /// Unmatched high-confidence detections start new tracks. Unmatched low-confidence
    /// detections are dropped. Tracks not matched this frame age by one and are removed once
    /// their age exceeds `max_age`.
    ///
    /// The result lists all surviving tracks in creation order. This includes tracks that were
    /// not matched this frame, which report the box, score and class of their last match.
    pub fn update(&mut self, detections: &[Detection]) -> Vec<TrackedDetection> {
        let (high, low) = self.split_by_confidence(detections);

        let mut matched_tracks = vec![false; self.tracks.len()];
        let mut matched_dets = vec![false; detections.len()];

        // Stage 1: every existing track competes for the high-confidence detections.
        let all_tracks: Vec<usize> = (0..self.tracks.len()).collect();
        let first_round = greedy_match(
            &self.tracks,
            detections,
            &all_tracks,
            &high,
            self.iou_threshold,
        );
        self.apply_matches(&first_round, detections, &mut matched_tracks, &mut matched_dets);

        // Stage 2: tracks left over may still be recovered by low-confidence detections.
        let leftover_tracks: Vec<usize> = (0..self.tracks.len())
            .filter(|&i| !matched_tracks[i])
            .collect();
        let second_round = greedy_match(
            &self.tracks,
            detections,
            &leftover_tracks,
            &low,
            self.iou_threshold,
        );
        self.apply_matches(&second_round, detections, &mut matched_tracks, &mut matched_dets);

        // Age unmatched tracks *before* spawning new ones, so new tracks keep age 0 without
        // needing a bounds guard on `matched_tracks`.
        for (track, &matched) in self.tracks.iter_mut().zip(&matched_tracks) {
            if !matched {
                track.age += 1;
            }
        }

        // Stage 3: unclaimed high-confidence detections become new tracks.
        for &di in &high {
            if !matched_dets[di] {
                self.spawn_track(&detections[di]);
            }
        }

        let max_age = self.max_age;
        self.tracks.retain(|t| t.age <= max_age);

        self.tracks
            .iter()
            .map(|t| TrackedDetection {
                // Widening cast: usize is at most 64 bits on the targets Rust supports today.
                track_id: t.id as u64,
                detection: Detection {
                    bbox: t.bbox,
                    score: t.score,
                    class_id: t.class_id,
                },
            })
            .collect()
    }

    /// Returns the number of tracks currently kept alive by the tracker.
    pub fn active_track_count(&self) -> usize {
        self.tracks.len()
    }

    /// Splits detection indices into high- and low-confidence groups.
    ///
    /// Detections scoring below `low_threshold` appear in neither group.
    fn split_by_confidence(&self, detections: &[Detection]) -> (Vec<usize>, Vec<usize>) {
        let mut high = Vec::new();
        let mut low = Vec::new();
        for (i, det) in detections.iter().enumerate() {
            if det.score >= self.high_threshold {
                high.push(i);
            } else if det.score >= self.low_threshold {
                low.push(i);
            }
        }
        (high, low)
    }

    /// Copies each matched detection into its track, resets the track's age, and marks both
    /// sides of every `(track, detection)` pair as matched.
    fn apply_matches(
        &mut self,
        pairs: &[(usize, usize)],
        detections: &[Detection],
        matched_tracks: &mut [bool],
        matched_dets: &mut [bool],
    ) {
        for &(ti, di) in pairs {
            let det = &detections[di];
            let track = &mut self.tracks[ti];
            track.bbox = det.bbox;
            track.score = det.score;
            track.class_id = det.class_id;
            track.age = 0;
            track.hits += 1;
            matched_tracks[ti] = true;
            matched_dets[di] = true;
        }
    }

    /// Starts a new track from `det`, assigning it the next free ID.
    fn spawn_track(&mut self, det: &Detection) {
        let id = self.next_id;
        self.next_id += 1;
        self.tracks.push(ByteTrack {
            id,
            bbox: det.bbox,
            score: det.score,
            class_id: det.class_id,
            age: 0,
            hits: 1,
        });
    }
}
/// Greedily associates tracks with detections by IoU.
///
/// Every pair from `track_indices` x `det_indices` whose IoU is at least `iou_threshold` is a
/// candidate. Candidates are accepted in descending IoU order, skipping any pair whose track or
/// detection is already taken. Each track and each detection therefore appears in at most one
/// returned pair. Equal IoUs are resolved in `track_indices` order, then `det_indices` order.
///
/// Returns `(track_index, detection_index)` pairs that index into `tracks` and `detections`.
///
/// # Panics
/// Panics if any index in `track_indices` or `det_indices` is out of bounds.
fn greedy_match(
    tracks: &[ByteTrack],
    detections: &[Detection],
    track_indices: &[usize],
    det_indices: &[usize],
    iou_threshold: f32,
) -> Vec<(usize, usize)> {
    // Gather all admissible pairs first. Visiting tracks one by one would let an earlier track
    // steal a detection that overlaps a later track far better, causing ID swaps.
    let mut candidates: Vec<(f32, usize, usize)> = Vec::new();
    for &ti in track_indices {
        for &di in det_indices {
            let overlap = iou(tracks[ti].bbox, detections[di].bbox);
            // `>=` is false for NaN, so degenerate boxes never become candidates.
            if overlap >= iou_threshold {
                candidates.push((overlap, ti, di));
            }
        }
    }

    // Stable sort: equal IoUs keep their (track, detection) insertion order.
    candidates.sort_by(|a, b| b.0.total_cmp(&a.0));

    let mut used_tracks = vec![false; tracks.len()];
    let mut used_dets = vec![false; detections.len()];
    let mut assignments = Vec::new();
    for (_, ti, di) in candidates {
        if used_tracks[ti] || used_dets[di] {
            continue;
        }
        used_tracks[ti] = true;
        used_dets[di] = true;
        assignments.push((ti, di));
    }
    assignments
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a class-0 detection from corner coordinates and a confidence score.
    fn det(x1: f32, y1: f32, x2: f32, y2: f32, score: f32) -> Detection {
        Detection {
            bbox: BoundingBox { x1, y1, x2, y2 },
            score,
            class_id: 0,
        }
    }

    /// Every high-confidence detection in the first frame starts its own track.
    #[test]
    fn byte_track_creates_tracks() {
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.3, 3);
        let dets = [
            det(10.0, 10.0, 50.0, 50.0, 0.9),
            det(100.0, 100.0, 150.0, 150.0, 0.8),
        ];
        let tracked = tracker.update(&dets);
        assert_eq!(tracked.len(), 2);
        assert_eq!(tracker.active_track_count(), 2);
    }

    /// Slightly moved detections keep the IDs assigned in the previous frame.
    #[test]
    fn byte_track_maintains_ids() {
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.3, 3);
        let dets1 = [
            det(10.0, 10.0, 50.0, 50.0, 0.9),
            det(100.0, 100.0, 150.0, 150.0, 0.8),
        ];
        let tracked1 = tracker.update(&dets1);
        let id0 = tracked1[0].track_id;
        let id1 = tracked1[1].track_id;
        let dets2 = [
            det(12.0, 12.0, 52.0, 52.0, 0.9),
            det(102.0, 102.0, 152.0, 152.0, 0.85),
        ];
        let tracked2 = tracker.update(&dets2);
        assert_eq!(tracked2.len(), 2);
        let ids: Vec<u64> = tracked2.iter().map(|t| t.track_id).collect();
        assert!(ids.contains(&id0));
        assert!(ids.contains(&id1));
    }

    /// A track disappears after `max_age + 1` consecutive frames without a match.
    #[test]
    fn byte_track_removes_old_tracks() {
        let max_age = 2;
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.3, max_age);
        let dets = [det(10.0, 10.0, 50.0, 50.0, 0.9)];
        tracker.update(&dets);
        assert_eq!(tracker.active_track_count(), 1);
        for _ in 0..=max_age {
            tracker.update(&[]);
        }
        assert_eq!(tracker.active_track_count(), 0);
    }

    /// A low-confidence detection can continue an existing track in the second round.
    #[test]
    fn byte_track_low_confidence_association() {
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.3, 3);
        let dets1 = [det(10.0, 10.0, 50.0, 50.0, 0.9)];
        let tracked1 = tracker.update(&dets1);
        let id = tracked1[0].track_id;
        let dets2 = [det(12.0, 12.0, 52.0, 52.0, 0.3)];
        let tracked2 = tracker.update(&dets2);
        assert_eq!(tracked2.len(), 1);
        assert_eq!(tracked2[0].track_id, id);
    }

    /// A detection far from every track starts a new track with a fresh ID.
    #[test]
    fn byte_track_new_track_for_new_object() {
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.3, 3);
        let dets1 = [det(10.0, 10.0, 50.0, 50.0, 0.9)];
        let tracked1 = tracker.update(&dets1);
        assert_eq!(tracked1.len(), 1);
        let id1 = tracked1[0].track_id;
        let dets2 = [
            det(12.0, 12.0, 52.0, 52.0, 0.9),
            det(200.0, 200.0, 250.0, 250.0, 0.8),
        ];
        let tracked2 = tracker.update(&dets2);
        assert_eq!(tracked2.len(), 2);
        let ids: Vec<u64> = tracked2.iter().map(|t| t.track_id).collect();
        assert!(ids.contains(&id1));
        let new_id = ids
            .iter()
            .find(|&&id| id != id1)
            .expect("second track should exist");
        assert_ne!(*new_id, id1);
    }

    /// Three separate objects receive three distinct IDs.
    #[test]
    fn byte_track_three_objects_simultaneously() {
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.3, 5);
        let dets = [
            det(10.0, 10.0, 50.0, 50.0, 0.9),
            det(100.0, 100.0, 140.0, 140.0, 0.8),
            det(200.0, 200.0, 240.0, 240.0, 0.7),
        ];
        let tracked = tracker.update(&dets);
        assert_eq!(tracked.len(), 3);
        assert_eq!(tracker.active_track_count(), 3);
        let mut ids: Vec<u64> = tracked.iter().map(|t| t.track_id).collect();
        ids.sort_unstable();
        ids.dedup();
        assert_eq!(ids.len(), 3);
    }

    /// A frame without detections ages tracks but still reports them with their last state.
    #[test]
    fn byte_track_empty_detections_ages_tracks() {
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.3, 5);
        let first = tracker.update(&[det(10.0, 10.0, 50.0, 50.0, 0.9)]);
        assert_eq!(tracker.active_track_count(), 1);
        let result = tracker.update(&[]);
        assert_eq!(tracker.active_track_count(), 1);
        assert_eq!(result.len(), 1);
        // The coasting track keeps its identity and its last matched score.
        assert_eq!(result[0].track_id, first[0].track_id);
        assert_eq!(result[0].detection.score, 0.9);
    }

    /// A stationary object keeps the same ID across many frames.
    #[test]
    fn byte_track_single_detection_stable_id() {
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.3, 5);
        let d = det(20.0, 20.0, 60.0, 60.0, 0.9);
        let first = tracker.update(&[d]);
        let id = first[0].track_id;
        for _ in 0..10 {
            let tracked = tracker.update(&[d]);
            assert_eq!(tracked.len(), 1);
            assert_eq!(tracked[0].track_id, id);
        }
    }

    /// A slowly moving object keeps the same ID while it moves.
    #[test]
    fn byte_track_id_stability_smooth_motion() {
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.3, 5);
        let first = tracker.update(&[det(10.0, 10.0, 50.0, 50.0, 0.9)]);
        let id = first[0].track_id;
        let positions = [
            (12.0, 12.0, 52.0, 52.0),
            (14.0, 14.0, 54.0, 54.0),
            (16.0, 16.0, 56.0, 56.0),
            (18.0, 18.0, 58.0, 58.0),
        ];
        for (x1, y1, x2, y2) in positions {
            let tracked = tracker.update(&[det(x1, y1, x2, y2, 0.9)]);
            assert_eq!(tracked.len(), 1);
            assert_eq!(tracked[0].track_id, id);
        }
    }

    /// Two overlapping objects are not confused with each other.
    #[test]
    fn byte_track_iou_matching_overlapping_bboxes() {
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.3, 5);
        let dets1 = [
            det(10.0, 10.0, 50.0, 50.0, 0.9),
            det(30.0, 30.0, 70.0, 70.0, 0.8),
        ];
        let tracked1 = tracker.update(&dets1);
        assert_eq!(tracked1.len(), 2);
        let id_a = tracked1[0].track_id;
        let id_b = tracked1[1].track_id;
        assert_ne!(id_a, id_b);
        let dets2 = [
            det(11.0, 11.0, 51.0, 51.0, 0.9),
            det(31.0, 31.0, 71.0, 71.0, 0.8),
        ];
        let tracked2 = tracker.update(&dets2);
        let ids2: Vec<u64> = tracked2.iter().map(|t| t.track_id).collect();
        assert!(ids2.contains(&id_a));
        assert!(ids2.contains(&id_b));
    }

    /// High- and low-confidence detections in the same frame each keep their own track alive.
    #[test]
    fn byte_track_low_vs_high_confidence_precedence() {
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.3, 5);
        let dets1 = [
            det(10.0, 10.0, 50.0, 50.0, 0.9),
            det(100.0, 100.0, 140.0, 140.0, 0.8),
        ];
        let tracked1 = tracker.update(&dets1);
        let id_a = tracked1[0].track_id;
        let id_b = tracked1[1].track_id;
        let dets2 = [
            det(12.0, 12.0, 52.0, 52.0, 0.9),
            // Low-confidence: may only extend an existing track.
            det(102.0, 102.0, 142.0, 142.0, 0.2),
        ];
        let tracked2 = tracker.update(&dets2);
        assert_eq!(tracked2.len(), 2);
        let ids2: Vec<u64> = tracked2.iter().map(|t| t.track_id).collect();
        assert!(ids2.contains(&id_a));
        assert!(ids2.contains(&id_b));
    }

    /// Low-confidence detections never start new tracks.
    #[test]
    fn byte_track_low_confidence_no_new_track() {
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.3, 5);
        let dets = [det(10.0, 10.0, 50.0, 50.0, 0.2)];
        let tracked = tracker.update(&dets);
        assert_eq!(tracked.len(), 0);
        assert_eq!(tracker.active_track_count(), 0);
    }

    /// Detections scoring below the low threshold are ignored entirely.
    #[test]
    fn byte_track_below_low_threshold_ignored() {
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.3, 5);
        // 0.05 is below the 0.1 low threshold.
        let dets = [det(10.0, 10.0, 50.0, 50.0, 0.05)];
        let tracked = tracker.update(&dets);
        assert_eq!(tracked.len(), 0);
        assert_eq!(tracker.active_track_count(), 0);
    }

    /// With `max_age == 0`, a single missed frame removes the track.
    #[test]
    fn byte_track_config_different_max_age() {
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.3, 0);
        tracker.update(&[det(10.0, 10.0, 50.0, 50.0, 0.9)]);
        assert_eq!(tracker.active_track_count(), 1);
        tracker.update(&[]);
        assert_eq!(tracker.active_track_count(), 0);
    }

    /// A strict IoU threshold turns a moved object into a new track.
    #[test]
    fn byte_track_config_different_iou_threshold() {
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.95, 5);
        tracker.update(&[det(10.0, 10.0, 50.0, 50.0, 0.9)]);
        tracker.update(&[det(15.0, 15.0, 55.0, 55.0, 0.9)]);
        assert_eq!(tracker.active_track_count(), 2);
    }

    /// Unmatched tracks are retained (aged) rather than dropped immediately.
    #[test]
    fn byte_track_more_tracks_than_detections() {
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.3, 5);
        let dets = [
            det(10.0, 10.0, 50.0, 50.0, 0.9),
            det(100.0, 100.0, 140.0, 140.0, 0.8),
            det(200.0, 200.0, 240.0, 240.0, 0.7),
        ];
        tracker.update(&dets);
        assert_eq!(tracker.active_track_count(), 3);
        tracker.update(&[det(12.0, 12.0, 52.0, 52.0, 0.9)]);
        assert_eq!(tracker.active_track_count(), 3);
    }

    /// Extra high-confidence detections each start a new track.
    #[test]
    fn byte_track_more_detections_than_tracks() {
        let mut tracker = ByteTracker::new(0.5, 0.1, 0.3, 5);
        tracker.update(&[det(10.0, 10.0, 50.0, 50.0, 0.9)]);
        assert_eq!(tracker.active_track_count(), 1);
        let dets = [
            det(12.0, 12.0, 52.0, 52.0, 0.9),
            det(100.0, 100.0, 140.0, 140.0, 0.8),
            det(200.0, 200.0, 240.0, 240.0, 0.7),
        ];
        tracker.update(&dets);
        assert_eq!(tracker.active_track_count(), 3);
    }
}