Skip to main content

edgefirst_tracker/
bytetrack.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::{kalman::ConstantVelocityXYAHModel2, DetectionBox, TrackInfo, Tracker};
5use lapjv::{lapjv, Matrix};
6use log::{debug, trace};
7use nalgebra::{Dyn, OMatrix, U4};
8use uuid::Uuid;
9
10#[allow(dead_code)]
11#[derive(Default)]
12pub struct ByteTrack {
13    pub track_extra_lifespan: u64,
14    pub track_high_conf: f32,
15    pub track_iou: f32,
16    pub track_update: f32,
17
18    pub tracklets: Vec<Tracklet>,
19    pub frame_count: i32,
20}
21
22#[derive(Debug, Clone)]
23pub struct Tracklet {
24    pub id: Uuid,
25    pub filter: ConstantVelocityXYAHModel2<f32>,
26    pub count: i32,
27    pub created: u64,
28    pub last_updated: u64,
29}
30
31impl Tracklet {
32    fn update<T: DetectionBox>(&mut self, detect_box: &T, ts: u64) {
33        self.count += 1;
34        self.last_updated = ts;
35        self.filter.update(&vaalbox_to_xyah(&detect_box.bbox()));
36    }
37
38    pub fn get_predicted_location(&self) -> [f32; 4] {
39        let predicted_xyah = self.filter.mean.as_slice();
40        xyah_to_vaalbox(predicted_xyah)
41    }
42}
43
44fn vaalbox_to_xyah(vaal_box: &[f32; 4]) -> [f32; 4] {
45    let x = (vaal_box[2] + vaal_box[0]) / 2.0;
46    let y = (vaal_box[3] + vaal_box[1]) / 2.0;
47    let w = (vaal_box[2] - vaal_box[0]).max(EPSILON);
48    let h = (vaal_box[3] - vaal_box[1]).max(EPSILON);
49    let a = w / h;
50
51    [x, y, a, h]
52}
53
/// Converts `[center_x, center_y, aspect, height]` (first four elements of
/// `xyah`) back into a corner box `[x1, y1, x2, y2]`, with width = height * aspect.
///
/// # Panics
/// Panics if `xyah` has fewer than four elements. Extra elements are ignored.
fn xyah_to_vaalbox(xyah: &[f32]) -> [f32; 4] {
    assert!(xyah.len() >= 4);
    let (cx, cy, aspect, height) = (xyah[0], xyah[1], xyah[2], xyah[3]);
    let half_w = height * aspect / 2.0;
    let half_h = height / 2.0;
    [cx - half_w, cy - half_h, cx + half_w, cy + half_h]
}
62
/// Cost marking a detection/tracklet pair that must never be matched; any
/// assignment whose cost reaches this value is discarded.
const INVALID_MATCH: f32 = 1_000_000.0;
/// Floor for box width/height and threshold below which an IoU union is
/// treated as empty.
const EPSILON: f32 = 1e-5;
65
66fn iou(box1: &[f32], box2: &[f32]) -> f32 {
67    let intersection = (box1[2].min(box2[2]) - box1[0].max(box2[0])).max(0.0)
68        * (box1[3].min(box2[3]) - box1[1].max(box2[1])).max(0.0);
69
70    let union = (box1[2] - box1[0]) * (box1[3] - box1[1])
71        + (box2[2] - box2[0]) * (box2[3] - box2[1])
72        - intersection;
73
74    if union <= EPSILON {
75        return 0.0;
76    }
77
78    intersection / union
79}
80
81fn box_cost<T: DetectionBox>(
82    track: &Tracklet,
83    new_box: &T,
84    distance: f32,
85    score_threshold: f32,
86    iou_threshold: f32,
87) -> f32 {
88    let _ = distance;
89
90    if new_box.score() < score_threshold {
91        return INVALID_MATCH;
92    }
93
94    // use iou between predicted box and real box:
95    let predicted_xyah = track.filter.mean.as_slice();
96    let expected = xyah_to_vaalbox(predicted_xyah);
97    let iou = iou(&expected, &new_box.bbox());
98    if iou < iou_threshold {
99        return INVALID_MATCH;
100    }
101    (1.5 - new_box.score()) + (1.5 - iou)
102}
103
104impl ByteTrack {
105    pub fn new() -> ByteTrack {
106        ByteTrack {
107            track_extra_lifespan: 500_000_000,
108            track_high_conf: 0.7,
109            track_iou: 0.25,
110            track_update: 0.25,
111            tracklets: Vec::new(),
112            frame_count: 0,
113        }
114    }
115
116    fn compute_costs<T: DetectionBox>(
117        &mut self,
118        boxes: &[T],
119        score_threshold: f32,
120        iou_threshold: f32,
121        box_filter: &[bool],
122        track_filter: &[bool],
123    ) -> Matrix<f32> {
124        // costs matrix must be square
125        let dims = boxes.len().max(self.tracklets.len());
126        let mut measurements = OMatrix::<f32, Dyn, U4>::from_element(boxes.len(), 0.0);
127        for (i, mut row) in measurements.row_iter_mut().enumerate() {
128            row.copy_from_slice(&vaalbox_to_xyah(&boxes[i].bbox()));
129        }
130
131        // TODO: use matrix math for IOU, should speed up computation, and store it in
132        // distances
133
134        Matrix::from_shape_fn((dims, dims), |(x, y)| {
135            if x < boxes.len() && y < self.tracklets.len() {
136                if box_filter[x] || track_filter[y] {
137                    INVALID_MATCH
138                } else {
139                    box_cost(
140                        &self.tracklets[y],
141                        &boxes[x],
142                        // distances[(x, y)],
143                        0.0,
144                        score_threshold,
145                        iou_threshold,
146                    )
147                }
148            } else {
149                0.0
150            }
151        })
152    }
153
154    /// Process assignments from linear assignment and update tracking state.
155    /// Returns true if any matches were made.
156    #[allow(clippy::too_many_arguments)]
157    fn process_assignments<T: DetectionBox>(
158        &mut self,
159        assignments: &[usize],
160        boxes: &[T],
161        costs: &Matrix<f32>,
162        matched: &mut [bool],
163        tracked: &mut [bool],
164        matched_info: &mut [Option<TrackInfo>],
165        timestamp: u64,
166        skip_already_matched: bool,
167    ) {
168        for (i, &x) in assignments.iter().enumerate() {
169            if i >= boxes.len() || x >= self.tracklets.len() {
170                continue;
171            }
172
173            // Filter out invalid assignments
174            if costs[(i, x)] >= INVALID_MATCH {
175                continue;
176            }
177
178            // For second pass, skip already matched boxes/tracklets
179            if skip_already_matched && (matched[i] || tracked[x]) {
180                continue;
181            }
182
183            if skip_already_matched {
184                trace!(
185                    "Cost: {} Box: {:#?} UUID: {} Mean: {}",
186                    costs[(i, x)],
187                    boxes[i],
188                    self.tracklets[x].id,
189                    self.tracklets[x].filter.mean
190                );
191            }
192
193            matched[i] = true;
194            matched_info[i] = Some(TrackInfo {
195                uuid: self.tracklets[x].id,
196                count: self.tracklets[x].count,
197                created: self.tracklets[x].created,
198                tracked_location: self.tracklets[x].get_predicted_location(),
199                last_updated: timestamp,
200            });
201            assert!(!tracked[x]);
202            tracked[x] = true;
203            self.tracklets[x].update(&boxes[i], timestamp);
204        }
205    }
206
207    /// Remove expired tracklets based on timestamp.
208    fn remove_expired_tracklets(&mut self, timestamp: u64) {
209        // must iterate from the back
210        for i in (0..self.tracklets.len()).rev() {
211            let expiry = self.tracklets[i].last_updated + self.track_extra_lifespan;
212            if expiry < timestamp {
213                debug!("Tracklet removed: {:?}", self.tracklets[i].id);
214                let _ = self.tracklets.swap_remove(i);
215            }
216        }
217    }
218
219    /// Create new tracklets from unmatched high-confidence boxes.
220    fn create_new_tracklets<T: DetectionBox>(
221        &mut self,
222        boxes: &[T],
223        high_conf_indices: &[usize],
224        matched: &[bool],
225        matched_info: &mut [Option<TrackInfo>],
226        timestamp: u64,
227    ) {
228        for &i in high_conf_indices {
229            if matched[i] {
230                continue;
231            }
232
233            let id = Uuid::new_v4();
234            let new_tracklet = Tracklet {
235                id,
236                filter: ConstantVelocityXYAHModel2::new(
237                    &vaalbox_to_xyah(&boxes[i].bbox()),
238                    self.track_update,
239                ),
240                last_updated: timestamp,
241                count: 1,
242                created: timestamp,
243            };
244            matched_info[i] = Some(TrackInfo {
245                uuid: new_tracklet.id,
246                count: new_tracklet.count,
247                created: new_tracklet.created,
248                tracked_location: new_tracklet.get_predicted_location(),
249                last_updated: timestamp,
250            });
251            self.tracklets.push(new_tracklet);
252        }
253    }
254}
255
256impl<T> Tracker<T> for ByteTrack
257where
258    T: DetectionBox,
259{
260    fn update(&mut self, boxes: &[T], timestamp: u64) -> Vec<Option<TrackInfo>> {
261        self.frame_count += 1;
262
263        // Identify high-confidence detections
264        let high_conf_ind: Vec<usize> = boxes
265            .iter()
266            .enumerate()
267            .filter(|(_, b)| b.score() >= self.track_high_conf)
268            .map(|(x, _)| x)
269            .collect();
270
271        let mut matched = vec![false; boxes.len()];
272        let mut tracked = vec![false; self.tracklets.len()];
273        let mut matched_info = vec![None; boxes.len()];
274
275        // First pass: match high-confidence detections
276        if !self.tracklets.is_empty() {
277            for track in &mut self.tracklets {
278                track.filter.predict();
279            }
280
281            let costs = self.compute_costs(
282                boxes,
283                self.track_high_conf,
284                self.track_iou,
285                &matched,
286                &tracked,
287            );
288            let ans = lapjv(&costs).unwrap();
289            self.process_assignments(
290                &ans.0,
291                boxes,
292                &costs,
293                &mut matched,
294                &mut tracked,
295                &mut matched_info,
296                timestamp,
297                false,
298            );
299        }
300
301        // Second pass: match remaining tracklets to low-confidence detections
302        if !self.tracklets.is_empty() {
303            let costs = self.compute_costs(boxes, 0.0, self.track_iou, &matched, &tracked);
304            let ans = lapjv(&costs).unwrap();
305            self.process_assignments(
306                &ans.0,
307                boxes,
308                &costs,
309                &mut matched,
310                &mut tracked,
311                &mut matched_info,
312                timestamp,
313                true,
314            );
315        }
316
317        // Remove expired tracklets
318        self.remove_expired_tracklets(timestamp);
319
320        // Create new tracklets from unmatched high-confidence boxes
321        self.create_new_tracklets(
322            boxes,
323            &high_conf_ind,
324            &matched,
325            &mut matched_info,
326            timestamp,
327        );
328
329        matched_info
330    }
331
332    fn get_active_tracks(&self) -> Vec<TrackInfo> {
333        self.tracklets
334            .iter()
335            .map(|t| TrackInfo {
336                uuid: t.id,
337                tracked_location: t.get_predicted_location(),
338                count: t.count,
339                created: t.created,
340                last_updated: t.last_updated,
341            })
342            .collect()
343    }
344}
345
#[cfg(test)]
mod tests {
    use super::{iou, vaalbox_to_xyah, xyah_to_vaalbox, ByteTrack};
    use crate::{DetectionBox, Tracker};

    /// Mock detection for testing
    #[derive(Debug, Clone)]
    struct MockDetection {
        bbox: [f32; 4],
        score: f32,
        label: usize,
    }

    impl MockDetection {
        fn new(x1: f32, y1: f32, x2: f32, y2: f32, score: f32) -> Self {
            Self {
                bbox: [x1, y1, x2, y2],
                score,
                label: 0,
            }
        }
    }

    impl DetectionBox for MockDetection {
        fn bbox(&self) -> [f32; 4] {
            self.bbox
        }

        fn score(&self) -> f32 {
            self.score
        }

        fn label(&self) -> usize {
            self.label
        }
    }

    #[test]
    fn test_vaalbox_xyah_roundtrip() {
        let box1 = [0.0134, 0.02135, 0.12438, 0.691];
        let xyah = vaalbox_to_xyah(&box1);
        let box2 = xyah_to_vaalbox(&xyah);

        // The round trip applies several rounding f32 operations (midpoint,
        // divide, multiply). `f32::EPSILON` is the float spacing near 1.0 and
        // is too tight an absolute bound to rely on here. 1e-6 is still far
        // below any meaningful box difference.
        const TOL: f32 = 1e-6;
        for (a, b) in box1.iter().zip(box2.iter()) {
            assert!((a - b).abs() < TOL, "roundtrip mismatch: {a} vs {b}");
        }
    }

    #[test]
    fn test_iou_identical_boxes() {
        let box1 = [0.1, 0.1, 0.5, 0.5];
        let box2 = [0.1, 0.1, 0.5, 0.5];
        let result = iou(&box1, &box2);
        assert!(
            (result - 1.0).abs() < 0.001,
            "IOU of identical boxes should be 1.0"
        );
    }

    #[test]
    fn test_iou_no_overlap() {
        let box1 = [0.0, 0.0, 0.2, 0.2];
        let box2 = [0.5, 0.5, 0.7, 0.7];
        let result = iou(&box1, &box2);
        assert!(result < 0.001, "IOU of non-overlapping boxes should be ~0");
    }

    #[test]
    fn test_iou_partial_overlap() {
        let box1 = [0.0, 0.0, 0.5, 0.5];
        let box2 = [0.25, 0.25, 0.75, 0.75];
        let result = iou(&box1, &box2);
        // Intersection: 0.25*0.25 = 0.0625, Union: 0.25+0.25-0.0625 = 0.4375,
        // so IoU = 0.0625 / 0.4375 = 1/7 (all terms are exact in binary).
        assert!(
            (result - 1.0 / 7.0).abs() < 1e-6,
            "IOU should be 1/7 (~0.143), got {result}"
        );
    }

    #[test]
    fn test_iou_degenerate_boxes() {
        // Zero-area boxes have an empty union; iou reports 0 rather than NaN.
        let point = [0.3, 0.3, 0.3, 0.3];
        assert_eq!(iou(&point, &point), 0.0);
    }

    #[test]
    #[should_panic]
    fn test_xyah_to_vaalbox_rejects_short_slice() {
        let _ = xyah_to_vaalbox(&[0.5, 0.5, 1.0]);
    }

    #[test]
    fn test_bytetrack_new() {
        let tracker = ByteTrack::new();
        assert_eq!(tracker.frame_count, 0);
        assert!(tracker.tracklets.is_empty());
        assert_eq!(tracker.track_high_conf, 0.7);
        assert_eq!(tracker.track_iou, 0.25);
    }

    #[test]
    fn test_bytetrack_single_detection_creates_tracklet() {
        let mut tracker = ByteTrack::new();
        let detections = vec![MockDetection::new(0.1, 0.1, 0.3, 0.3, 0.9)];

        let results = tracker.update(&detections, 1000);

        assert_eq!(results.len(), 1);
        assert!(
            results[0].is_some(),
            "High-confidence detection should create tracklet"
        );
        assert_eq!(tracker.tracklets.len(), 1);
        assert_eq!(tracker.frame_count, 1);
    }

    #[test]
    fn test_bytetrack_low_confidence_no_tracklet() {
        let mut tracker = ByteTrack::new();
        // Score below track_high_conf (0.7)
        let detections = vec![MockDetection::new(0.1, 0.1, 0.3, 0.3, 0.5)];

        let results = tracker.update(&detections, 1000);

        assert_eq!(results.len(), 1);
        assert!(
            results[0].is_none(),
            "Low-confidence detection should not create tracklet"
        );
        assert!(tracker.tracklets.is_empty());
    }

    #[test]
    fn test_bytetrack_low_confidence_continues_track() {
        let mut tracker = ByteTrack::new();

        let det1 = vec![MockDetection::new(0.2, 0.2, 0.4, 0.4, 0.9)];
        let res1 = tracker.update(&det1, 1000);
        let uuid1 = res1[0].unwrap().uuid;

        // Below track_high_conf, so it cannot start a tracklet, but the second
        // (low-confidence) pass should still associate it with the existing one.
        let det2 = vec![MockDetection::new(0.2, 0.2, 0.4, 0.4, 0.5)];
        let res2 = tracker.update(&det2, 2000);

        assert_eq!(res2[0].map(|info| info.uuid), Some(uuid1));
        assert_eq!(tracker.tracklets.len(), 1);
        assert_eq!(tracker.tracklets[0].count, 2);
    }

    #[test]
    fn test_bytetrack_tracking_across_frames() {
        let mut tracker = ByteTrack::new();

        // Frame 1: Create tracklet with a larger box that's easier to track
        let det1 = vec![MockDetection::new(0.2, 0.2, 0.4, 0.4, 0.9)];
        let res1 = tracker.update(&det1, 1000);
        assert!(res1[0].is_some());
        let uuid1 = res1[0].unwrap().uuid;
        assert_eq!(tracker.tracklets.len(), 1);
        // After creation, tracklet count is 1
        assert_eq!(tracker.tracklets[0].count, 1);

        // Frame 2: Same location - should match existing tracklet
        let det2 = vec![MockDetection::new(0.2, 0.2, 0.4, 0.4, 0.9)];
        let res2 = tracker.update(&det2, 2000);
        assert!(res2[0].is_some());
        let info2 = res2[0].unwrap();

        // Verify tracklet was matched, not a new one created
        assert_eq!(tracker.tracklets.len(), 1, "Should still have one tracklet");
        assert_eq!(info2.uuid, uuid1, "Should match same tracklet");
        // After second update, the internal tracklet count should be 2
        assert_eq!(tracker.tracklets[0].count, 2, "Internal count should be 2");
    }

    #[test]
    fn test_bytetrack_multiple_detections() {
        let mut tracker = ByteTrack::new();

        let detections = vec![
            MockDetection::new(0.1, 0.1, 0.2, 0.2, 0.9),
            MockDetection::new(0.5, 0.5, 0.6, 0.6, 0.85),
            MockDetection::new(0.8, 0.8, 0.9, 0.9, 0.95),
        ];

        let results = tracker.update(&detections, 1000);

        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.is_some()));
        assert_eq!(tracker.tracklets.len(), 3);
    }

    #[test]
    fn test_bytetrack_tracklet_expiry() {
        let mut tracker = ByteTrack::new();
        tracker.track_extra_lifespan = 1000; // 1 second

        // Create tracklet
        let det1 = vec![MockDetection::new(0.1, 0.1, 0.3, 0.3, 0.9)];
        tracker.update(&det1, 1000);
        assert_eq!(tracker.tracklets.len(), 1);

        // Update with no detections after lifespan expires
        let empty: Vec<MockDetection> = vec![];
        tracker.update(&empty, 3000); // 2 seconds later

        assert!(tracker.tracklets.is_empty(), "Tracklet should have expired");
    }

    #[test]
    fn test_bytetrack_tracklet_survives_at_lifespan_boundary() {
        let mut tracker = ByteTrack::new();
        tracker.track_extra_lifespan = 1000;

        let det1 = vec![MockDetection::new(0.1, 0.1, 0.3, 0.3, 0.9)];
        tracker.update(&det1, 1000);

        // expiry = 1000 + 1000 = 2000, which is not < 2000, so the tracklet stays.
        let empty: Vec<MockDetection> = vec![];
        tracker.update(&empty, 2000);

        assert_eq!(tracker.tracklets.len(), 1);
    }

    #[test]
    fn test_bytetrack_get_active_tracks() {
        let mut tracker = ByteTrack::new();

        let detections = vec![
            MockDetection::new(0.1, 0.1, 0.2, 0.2, 0.9),
            MockDetection::new(0.5, 0.5, 0.6, 0.6, 0.85),
        ];
        tracker.update(&detections, 1000);

        let active = <ByteTrack as Tracker<MockDetection>>::get_active_tracks(&tracker);
        assert_eq!(active.len(), 2);
        assert!(active.iter().all(|t| t.count == 1));
        assert!(active.iter().all(|t| t.created == 1000));
    }

    #[test]
    fn test_bytetrack_empty_detections() {
        let mut tracker = ByteTrack::new();
        let empty: Vec<MockDetection> = vec![];

        let results = tracker.update(&empty, 1000);

        assert!(results.is_empty());
        assert!(tracker.tracklets.is_empty());
        assert_eq!(tracker.frame_count, 1);
    }
}