Skip to main content

edgefirst_tracker/
bytetrack.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::{
5    kalman::ConstantVelocityXYAHModel2, ActiveTrackInfo, DetectionBox, TrackInfo, Tracker,
6};
7use lapjv::{lapjv, Matrix};
8use log::trace;
9use nalgebra::{Dyn, OMatrix, U4};
10use uuid::Uuid;
11
/// Fluent builder collecting the tuning parameters of a [`ByteTrack`] tracker.
///
/// Start from [`ByteTrackBuilder::new`] (or `Default`), adjust values with the
/// setter methods, then finish with [`ByteTrackBuilder::build`].
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq)]
pub struct ByteTrackBuilder {
    /// Grace period, in nanoseconds, a tracklet survives after its last match.
    track_extra_lifespan: u64,
    /// Minimum detection score treated as "high confidence".
    track_high_conf: f32,
    /// Minimum IoU between a predicted track box and a detection for a match.
    track_iou: f32,
    /// Parameter forwarded to the Kalman filter when a tracklet is created.
    track_update: f32,
}
19
20impl Default for ByteTrackBuilder {
21    fn default() -> Self {
22        Self::new()
23    }
24}
25
26impl ByteTrackBuilder {
27    /// Creates a new ByteTrackBuilder with default parameters.
28    /// These defaults are:
29    /// - track_high_conf: 0.7
30    /// - track_iou: 0.25
31    /// - track_update: 0.25
32    /// - track_extra_lifespan: 500_000_000 (0.5 seconds)
33    /// # Examples
34    /// ```rust
35    /// use edgefirst_tracker::{bytetrack::ByteTrackBuilder, Tracker, MockDetection};
36    /// let mut tracker = ByteTrackBuilder::new().build();
37    /// assert_eq!(tracker.track_high_conf, 0.7);
38    /// assert_eq!(tracker.track_iou, 0.25);
39    /// assert_eq!(tracker.track_update, 0.25);
40    /// assert_eq!(tracker.track_extra_lifespan, 500_000_000);
41    /// # let boxes = Vec::<MockDetection>::new();
42    /// # tracker.update(&boxes, 0);
43    /// ```
44    pub fn new() -> Self {
45        Self {
46            track_high_conf: 0.7,
47            track_iou: 0.25,
48            track_update: 0.25,
49            track_extra_lifespan: 500_000_000,
50        }
51    }
52
53    /// Sets the extra lifespan for tracks in nanoseconds.
54    pub fn track_extra_lifespan(mut self, lifespan: u64) -> Self {
55        self.track_extra_lifespan = lifespan;
56        self
57    }
58
59    /// Sets the high confidence threshold for tracking.
60    pub fn track_high_conf(mut self, conf: f32) -> Self {
61        self.track_high_conf = conf;
62        self
63    }
64
65    /// Sets the IOU threshold for tracking.
66    pub fn track_iou(mut self, iou: f32) -> Self {
67        self.track_iou = iou;
68        self
69    }
70
71    /// Sets the update rate for the Kalman filter.
72    pub fn track_update(mut self, update: f32) -> Self {
73        self.track_update = update;
74        self
75    }
76
77    /// Builds the ByteTrack tracker with the specified parameters.
78    /// # Examples
79    /// ```rust
80    /// use edgefirst_tracker::{bytetrack::ByteTrackBuilder, Tracker, MockDetection};
81    /// let mut tracker = ByteTrackBuilder::new()
82    ///     .track_high_conf(0.8)
83    ///     .track_iou(0.3)
84    ///     .track_update(0.2)
85    ///     .track_extra_lifespan(1_000_000_000)
86    ///     .build();
87    /// assert_eq!(tracker.track_high_conf, 0.8);
88    /// assert_eq!(tracker.track_iou, 0.3);
89    /// assert_eq!(tracker.track_update, 0.2);
90    /// assert_eq!(tracker.track_extra_lifespan, 1_000_000_000);
91    /// # let boxes = Vec::<MockDetection>::new();
92    /// # tracker.update(&boxes, 0);
93    /// ```
94    pub fn build<T: DetectionBox>(self) -> ByteTrack<T> {
95        ByteTrack {
96            track_extra_lifespan: self.track_extra_lifespan,
97            track_high_conf: self.track_high_conf,
98            track_iou: self.track_iou,
99            track_update: self.track_update,
100            tracklets: Vec::new(),
101            frame_count: 0,
102        }
103    }
104}
105
106#[allow(dead_code)]
107#[derive(Default, Debug, Clone)]
108pub struct ByteTrack<T: DetectionBox> {
109    pub track_extra_lifespan: u64,
110    pub track_high_conf: f32,
111    pub track_iou: f32,
112    pub track_update: f32,
113    pub tracklets: Vec<Tracklet<T>>,
114    pub frame_count: i32,
115}
116
117#[derive(Debug, Clone)]
118pub struct Tracklet<T: DetectionBox> {
119    pub id: Uuid,
120    pub filter: ConstantVelocityXYAHModel2<f32>,
121    pub count: i32,
122    pub created: u64,
123    pub last_updated: u64,
124    pub last_box: T,
125}
126
127impl<T: DetectionBox> Tracklet<T> {
128    fn update(&mut self, detect_box: &T, ts: u64) {
129        self.count += 1;
130        self.last_updated = ts;
131        self.filter.update(&xyxy_to_xyah(&detect_box.bbox()));
132        self.last_box = detect_box.clone();
133    }
134
135    pub fn get_predicted_location(&self) -> [f32; 4] {
136        let projected = self.filter.project().0;
137        let predicted_xyah = projected.as_slice();
138        xyah_to_xyxy(predicted_xyah)
139    }
140}
141
142fn xyxy_to_xyah(vaal_box: &[f32; 4]) -> [f32; 4] {
143    let x = (vaal_box[2] + vaal_box[0]) / 2.0;
144    let y = (vaal_box[3] + vaal_box[1]) / 2.0;
145    let w = (vaal_box[2] - vaal_box[0]).max(EPSILON);
146    let h = (vaal_box[3] - vaal_box[1]).max(EPSILON);
147    let a = w / h;
148
149    [x, y, a, h]
150}
151
/// Converts `[cx, cy, aspect, height]` back into corner form `[x1, y1, x2, y2]`.
///
/// Only the first four elements are read, so a longer slice is accepted.
/// `box_cost` passes the filter's whole mean vector, whose extra trailing
/// entries are presumably velocities — TODO confirm.
///
/// # Panics
/// Panics if `xyah` has fewer than four elements.
fn xyah_to_xyxy(xyah: &[f32]) -> [f32; 4] {
    assert!(xyah.len() >= 4);
    let (center_x, center_y, aspect, height) = (xyah[0], xyah[1], xyah[2], xyah[3]);
    let width = height * aspect;
    let (half_w, half_h) = (width / 2.0, height / 2.0);
    [
        center_x - half_w,
        center_y - half_h,
        center_x + half_w,
        center_y + half_h,
    ]
}
160
/// Cost given to pairings that must never be accepted as a match.
const INVALID_MATCH: f32 = 1_000_000.0;
/// Small positive guard against degenerate box sizes and divisions by zero.
const EPSILON: f32 = 1e-5;

/// Intersection-over-union of two corner boxes `[x1, y1, x2, y2]`.
///
/// Returns `0.0` when the union area is at most `EPSILON`. This covers
/// degenerate boxes and avoids dividing by (near) zero.
fn iou(box1: &[f32], box2: &[f32]) -> f32 {
    let overlap_w = (box1[2].min(box2[2]) - box1[0].max(box2[0])).max(0.0);
    let overlap_h = (box1[3].min(box2[3]) - box1[1].max(box2[1])).max(0.0);
    let intersection = overlap_w * overlap_h;

    let area = |b: &[f32]| (b[2] - b[0]) * (b[3] - b[1]);
    let union = area(box1) + area(box2) - intersection;

    if union <= EPSILON {
        0.0
    } else {
        intersection / union
    }
}
178
179fn box_cost<T: DetectionBox>(
180    track: &Tracklet<T>,
181    new_box: &T,
182    distance: f32,
183    score_threshold: f32,
184    iou_threshold: f32,
185) -> f32 {
186    let _ = distance;
187
188    if new_box.score() < score_threshold {
189        return INVALID_MATCH;
190    }
191
192    // use iou between predicted box and real box:
193    let predicted_xyah = track.filter.mean.as_slice();
194    let expected = xyah_to_xyxy(predicted_xyah);
195    let iou = iou(&expected, &new_box.bbox());
196    if iou < iou_threshold {
197        return INVALID_MATCH;
198    }
199    (1.5 - new_box.score()) + (1.5 - iou)
200}
201
202impl<T: DetectionBox> ByteTrack<T> {
203    fn compute_costs(
204        &mut self,
205        boxes: &[T],
206        score_threshold: f32,
207        iou_threshold: f32,
208        box_filter: &[bool],
209        track_filter: &[bool],
210    ) -> Matrix<f32> {
211        // costs matrix must be square
212        let dims = boxes.len().max(self.tracklets.len());
213        let mut measurements = OMatrix::<f32, Dyn, U4>::from_element(boxes.len(), 0.0);
214        for (i, mut row) in measurements.row_iter_mut().enumerate() {
215            row.copy_from_slice(&xyxy_to_xyah(&boxes[i].bbox()));
216        }
217
218        // TODO: use matrix math for IOU, should speed up computation, and store it in
219        // distances
220
221        Matrix::from_shape_fn((dims, dims), |(x, y)| {
222            if x < boxes.len() && y < self.tracklets.len() {
223                if box_filter[x] || track_filter[y] {
224                    INVALID_MATCH
225                } else {
226                    box_cost(
227                        &self.tracklets[y],
228                        &boxes[x],
229                        // distances[(x, y)],
230                        0.0,
231                        score_threshold,
232                        iou_threshold,
233                    )
234                }
235            } else {
236                0.0
237            }
238        })
239    }
240
241    /// Process assignments from linear assignment and update tracking state.
242    /// Returns true if any matches were made.
243    #[allow(clippy::too_many_arguments)]
244    fn process_assignments(
245        &mut self,
246        assignments: &[usize],
247        boxes: &[T],
248        costs: &Matrix<f32>,
249        matched: &mut [bool],
250        tracked: &mut [bool],
251        matched_info: &mut [Option<TrackInfo>],
252        timestamp: u64,
253        log_assignments: bool,
254    ) {
255        for (i, &x) in assignments.iter().enumerate() {
256            if i >= boxes.len() || x >= self.tracklets.len() {
257                continue;
258            }
259
260            // Filter out invalid assignments
261            if costs[(i, x)] >= INVALID_MATCH {
262                continue;
263            }
264
265            // Skip already matched boxes/tracklets
266            if matched[i] || tracked[x] {
267                continue;
268            }
269
270            if log_assignments {
271                trace!(
272                    "Cost: {} Box: {:#?} UUID: {} Mean: {}",
273                    costs[(i, x)],
274                    boxes[i],
275                    self.tracklets[x].id,
276                    self.tracklets[x].filter.mean
277                );
278            }
279
280            matched[i] = true;
281            matched_info[i] = Some(TrackInfo {
282                uuid: self.tracklets[x].id,
283                count: self.tracklets[x].count,
284                created: self.tracklets[x].created,
285                tracked_location: self.tracklets[x].get_predicted_location(),
286                last_updated: timestamp,
287            });
288            tracked[x] = true;
289            self.tracklets[x].update(&boxes[i], timestamp);
290        }
291    }
292
293    /// Remove expired tracklets based on timestamp.
294    fn remove_expired_tracklets(&mut self, timestamp: u64) {
295        // must iterate from the back
296        for i in (0..self.tracklets.len()).rev() {
297            let expiry = self.tracklets[i].last_updated + self.track_extra_lifespan;
298            if expiry < timestamp {
299                trace!("Tracklet removed: {:?}", self.tracklets[i].id);
300                let _ = self.tracklets.swap_remove(i);
301            }
302        }
303    }
304
305    /// Create new tracklets from unmatched high-confidence boxes.
306    fn create_new_tracklets(
307        &mut self,
308        boxes: &[T],
309        high_conf_indices: &[usize],
310        matched: &[bool],
311        matched_info: &mut [Option<TrackInfo>],
312        timestamp: u64,
313    ) {
314        for &i in high_conf_indices {
315            if matched[i] {
316                continue;
317            }
318
319            let id = Uuid::new_v4();
320            let new_tracklet = Tracklet {
321                id,
322                filter: ConstantVelocityXYAHModel2::new(
323                    &xyxy_to_xyah(&boxes[i].bbox()),
324                    self.track_update,
325                ),
326                last_updated: timestamp,
327                count: 1,
328                created: timestamp,
329                last_box: boxes[i].clone(),
330            };
331            matched_info[i] = Some(TrackInfo {
332                uuid: new_tracklet.id,
333                count: new_tracklet.count,
334                created: new_tracklet.created,
335                tracked_location: new_tracklet.get_predicted_location(),
336                last_updated: timestamp,
337            });
338            self.tracklets.push(new_tracklet);
339        }
340    }
341}
342
343impl<T> Tracker<T> for ByteTrack<T>
344where
345    T: DetectionBox,
346{
347    fn update(&mut self, boxes: &[T], timestamp: u64) -> Vec<Option<TrackInfo>> {
348        self.frame_count += 1;
349
350        // Identify high-confidence detections
351        let high_conf_ind: Vec<usize> = boxes
352            .iter()
353            .enumerate()
354            .filter(|(_, b)| b.score() >= self.track_high_conf)
355            .map(|(x, _)| x)
356            .collect();
357
358        let mut matched = vec![false; boxes.len()];
359        let mut tracked = vec![false; self.tracklets.len()];
360        let mut matched_info = vec![None; boxes.len()];
361
362        // First pass: match high-confidence detections
363        if !self.tracklets.is_empty() {
364            for track in &mut self.tracklets {
365                track.filter.predict();
366            }
367
368            let costs = self.compute_costs(
369                boxes,
370                self.track_high_conf,
371                self.track_iou,
372                &matched,
373                &tracked,
374            );
375            if let Ok(ans) = lapjv(&costs) {
376                self.process_assignments(
377                    &ans.0,
378                    boxes,
379                    &costs,
380                    &mut matched,
381                    &mut tracked,
382                    &mut matched_info,
383                    timestamp,
384                    false,
385                );
386            }
387        }
388
389        // Second pass: match remaining tracklets to low-confidence detections
390        if !self.tracklets.is_empty() {
391            let costs = self.compute_costs(boxes, 0.0, self.track_iou, &matched, &tracked);
392            if let Ok(ans) = lapjv(&costs) {
393                self.process_assignments(
394                    &ans.0,
395                    boxes,
396                    &costs,
397                    &mut matched,
398                    &mut tracked,
399                    &mut matched_info,
400                    timestamp,
401                    true,
402                );
403            }
404        }
405
406        // Remove expired tracklets
407        self.remove_expired_tracklets(timestamp);
408
409        // Create new tracklets from unmatched high-confidence boxes
410        self.create_new_tracklets(
411            boxes,
412            &high_conf_ind,
413            &matched,
414            &mut matched_info,
415            timestamp,
416        );
417
418        matched_info
419    }
420
421    fn get_active_tracks(&self) -> Vec<ActiveTrackInfo<T>> {
422        self.tracklets
423            .iter()
424            .map(|t| ActiveTrackInfo {
425                info: TrackInfo {
426                    uuid: t.id,
427                    tracked_location: t.get_predicted_location(),
428                    count: t.count,
429                    created: t.created,
430                    last_updated: t.last_updated,
431                },
432                last_box: t.last_box.clone(),
433            })
434            .collect()
435    }
436}
437
#[cfg(test)]
mod tests {
    use super::*;
    use crate::*;

    impl MockDetection {
        fn new(x1: f32, y1: f32, x2: f32, y2: f32, score: f32) -> Self {
            Self {
                bbox: [x1, y1, x2, y2],
                score,
                label: 0,
            }
        }
    }

    #[test]
    fn test_vaalbox_xyah_roundtrip() {
        let box1 = [0.0134, 0.02135, 0.12438, 0.691];
        let xyah = xyxy_to_xyah(&box1);
        let box2 = xyah_to_xyxy(&xyah);

        assert!((box1[0] - box2[0]).abs() < f32::EPSILON);
        assert!((box1[1] - box2[1]).abs() < f32::EPSILON);
        assert!((box1[2] - box2[2]).abs() < f32::EPSILON);
        assert!((box1[3] - box2[3]).abs() < f32::EPSILON);
    }

    #[test]
    fn test_iou_identical_boxes() {
        let box1 = [0.1, 0.1, 0.5, 0.5];
        let box2 = [0.1, 0.1, 0.5, 0.5];
        let result = iou(&box1, &box2);
        assert!(
            (result - 1.0).abs() < 0.001,
            "IOU of identical boxes should be 1.0"
        );
    }

    #[test]
    fn test_iou_no_overlap() {
        let box1 = [0.0, 0.0, 0.2, 0.2];
        let box2 = [0.5, 0.5, 0.7, 0.7];
        let result = iou(&box1, &box2);
        assert!(result < 0.001, "IOU of non-overlapping boxes should be ~0");
    }

    #[test]
    fn test_iou_partial_overlap() {
        let box1 = [0.0, 0.0, 0.5, 0.5];
        let box2 = [0.25, 0.25, 0.75, 0.75];
        let result = iou(&box1, &box2);
        // Intersection: 0.25 * 0.25 = 0.0625; union: 0.25 + 0.25 - 0.0625 = 0.4375.
        let expected = 0.0625 / 0.4375;
        assert!(
            (result - expected).abs() < 1e-6,
            "IOU should be ~0.142857, got {result}"
        );
    }

    #[test]
    fn test_bytetrack_new() {
        let tracker: ByteTrack<MockDetection> = ByteTrackBuilder::new().build();
        assert_eq!(tracker.frame_count, 0);
        assert!(tracker.tracklets.is_empty());
        assert_eq!(tracker.track_high_conf, 0.7);
        assert_eq!(tracker.track_iou, 0.25);
    }

    #[test]
    fn test_bytetrack_single_detection_creates_tracklet() {
        let mut tracker = ByteTrackBuilder::new().build();
        let detections = vec![MockDetection::new(0.1, 0.1, 0.3, 0.3, 0.9)];

        let results = tracker.update(&detections, 1000);

        assert_eq!(results.len(), 1);
        assert!(
            results[0].is_some(),
            "High-confidence detection should create tracklet"
        );
        assert_eq!(tracker.tracklets.len(), 1);
        assert_eq!(tracker.frame_count, 1);
    }

    #[test]
    fn test_bytetrack_low_confidence_no_tracklet() {
        let mut tracker = ByteTrackBuilder::new().build();
        // Score below track_high_conf (0.7)
        let detections = vec![MockDetection::new(0.1, 0.1, 0.3, 0.3, 0.5)];

        let results = tracker.update(&detections, 1000);

        assert_eq!(results.len(), 1);
        assert!(
            results[0].is_none(),
            "Low-confidence detection should not create tracklet"
        );
        assert!(tracker.tracklets.is_empty());
    }

    #[test]
    fn test_bytetrack_tracking_across_frames() {
        let mut tracker = ByteTrackBuilder::new().build();

        // Frame 1: Create tracklet with a larger box that's easier to track
        let det1 = vec![MockDetection::new(0.2, 0.2, 0.4, 0.4, 0.9)];
        let res1 = tracker.update(&det1, 1000);
        assert!(res1[0].is_some());
        let uuid1 = res1[0].unwrap().uuid;
        assert_eq!(tracker.tracklets.len(), 1);
        // After creation, tracklet count is 1
        assert_eq!(tracker.tracklets[0].count, 1);

        // Frame 2: Same location - should match existing tracklet
        let det2 = vec![MockDetection::new(0.2, 0.2, 0.4, 0.4, 0.9)];
        let res2 = tracker.update(&det2, 2000);
        assert!(res2[0].is_some());
        let info2 = res2[0].unwrap();

        // Verify tracklet was matched, not a new one created
        assert_eq!(tracker.tracklets.len(), 1, "Should still have one tracklet");
        assert_eq!(info2.uuid, uuid1, "Should match same tracklet");
        // After second update, the internal tracklet count should be 2
        assert_eq!(tracker.tracklets[0].count, 2, "Internal count should be 2");
    }

    #[test]
    fn test_bytetrack_low_confidence_extends_existing_tracklet() {
        let mut tracker = ByteTrackBuilder::new().build();

        // Frame 1: a high-confidence detection starts a tracklet.
        let det1 = vec![MockDetection::new(0.2, 0.2, 0.4, 0.4, 0.9)];
        let uuid1 = tracker.update(&det1, 1000)[0]
            .expect("tracklet should be created")
            .uuid;

        // Frame 2: the same object at low confidence is picked up by the
        // second (low-confidence) association pass instead of being dropped.
        let det2 = vec![MockDetection::new(0.2, 0.2, 0.4, 0.4, 0.5)];
        let res2 = tracker.update(&det2, 2000);
        let info2 = res2[0].expect("low-confidence detection should extend the tracklet");

        assert_eq!(info2.uuid, uuid1);
        assert_eq!(tracker.tracklets.len(), 1);
        assert_eq!(tracker.tracklets[0].count, 2);
    }

    #[test]
    fn test_bytetrack_distant_detection_starts_new_tracklet() {
        let mut tracker = ByteTrackBuilder::new().build();

        let det1 = vec![MockDetection::new(0.1, 0.1, 0.2, 0.2, 0.9)];
        let uuid1 = tracker.update(&det1, 1000)[0]
            .expect("tracklet should be created")
            .uuid;

        // Frame 2: no overlap with the existing track (IoU below track_iou), so
        // a fresh tracklet starts; the old one is still within its lifespan.
        let det2 = vec![MockDetection::new(0.7, 0.7, 0.8, 0.8, 0.9)];
        let res2 = tracker.update(&det2, 2000);
        let info2 = res2[0].expect("high-confidence detection should be tracked");

        assert_ne!(info2.uuid, uuid1);
        assert_eq!(tracker.tracklets.len(), 2);
    }

    #[test]
    fn test_bytetrack_multiple_detections() {
        let mut tracker = ByteTrackBuilder::new().build();

        let detections = vec![
            MockDetection::new(0.1, 0.1, 0.2, 0.2, 0.9),
            MockDetection::new(0.5, 0.5, 0.6, 0.6, 0.85),
            MockDetection::new(0.8, 0.8, 0.9, 0.9, 0.95),
        ];

        let results = tracker.update(&detections, 1000);

        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.is_some()));
        assert_eq!(tracker.tracklets.len(), 3);
    }

    #[test]
    fn test_bytetrack_tracklet_expiry() {
        let mut tracker = ByteTrackBuilder::new().build();
        // Timestamps and lifespan are in nanoseconds: this is a 1 µs lifespan.
        tracker.track_extra_lifespan = 1000;

        // Create tracklet
        let det1 = vec![MockDetection::new(0.1, 0.1, 0.3, 0.3, 0.9)];
        tracker.update(&det1, 1000);
        assert_eq!(tracker.tracklets.len(), 1);

        // Update with no detections after the lifespan has elapsed
        let empty: Vec<MockDetection> = vec![];
        tracker.update(&empty, 3000); // 2000 ns after the last match, past the 1000 ns lifespan

        assert!(tracker.tracklets.is_empty(), "Tracklet should have expired");
    }

    #[test]
    fn test_bytetrack_get_active_tracks() {
        let mut tracker = ByteTrackBuilder::new().build();

        let detections = vec![
            MockDetection::new(0.1, 0.1, 0.2, 0.2, 0.9),
            MockDetection::new(0.5, 0.5, 0.6, 0.6, 0.85),
        ];
        tracker.update(&detections, 1000);

        let active = tracker.get_active_tracks();
        assert_eq!(active.len(), 2);
        assert!(active.iter().all(|t| t.info.count == 1));
        assert!(active.iter().all(|t| t.info.created == 1000));
    }

    #[test]
    fn test_bytetrack_empty_detections() {
        let mut tracker = ByteTrackBuilder::new().build();
        let empty: Vec<MockDetection> = vec![];

        let results = tracker.update(&empty, 1000);

        assert!(results.is_empty());
        assert!(tracker.tracklets.is_empty());
        assert_eq!(tracker.frame_count, 1);
    }
}