Skip to main content

edgefirst_tracker/
bytetrack.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::{
5    kalman::ConstantVelocityXYAHModel2, ActiveTrackInfo, DetectionBox, TrackInfo, Tracker,
6};
7use lapjv::{lapjv, Matrix};
8use log::trace;
9use nalgebra::{Dyn, OMatrix, U4};
10use uuid::Uuid;
11
/// Builder collecting the tuning parameters of a [`ByteTrack`] tracker.
///
/// Start from [`ByteTrackBuilder::new`] (or `Default`), adjust values with the
/// chained setters, then call `build`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ByteTrackBuilder {
    /// Time (nanoseconds) a tracklet may go without a matching detection
    /// before it is removed.
    track_extra_lifespan: u64,
    /// Minimum detection score for the first matching pass and for starting
    /// new tracklets.
    track_high_conf: f32,
    /// Minimum IOU between a tracklet's estimated box and a detection for the
    /// pair to be matchable.
    track_iou: f32,
    /// Factor forwarded to `ConstantVelocityXYAHModel2::new` when a tracklet
    /// is created.
    track_update: f32,
}
19
20impl Default for ByteTrackBuilder {
21    fn default() -> Self {
22        Self::new()
23    }
24}
25
26impl ByteTrackBuilder {
27    /// Creates a new ByteTrackBuilder with default parameters.
28    /// These defaults are:
29    /// - track_high_conf: 0.7
30    /// - track_iou: 0.25
31    /// - track_update: 0.25
32    /// - track_extra_lifespan: 500_000_000 (0.5 seconds)
33    /// # Examples
34    /// ```rust
35    /// use edgefirst_tracker::{bytetrack::ByteTrackBuilder, Tracker, MockDetection};
36    /// let mut tracker = ByteTrackBuilder::new().build();
37    /// assert_eq!(tracker.track_high_conf, 0.7);
38    /// assert_eq!(tracker.track_iou, 0.25);
39    /// assert_eq!(tracker.track_update, 0.25);
40    /// assert_eq!(tracker.track_extra_lifespan, 500_000_000);
41    /// # let boxes = Vec::<MockDetection>::new();
42    /// # tracker.update(&boxes, 0);
43    /// ```
44    pub fn new() -> Self {
45        Self {
46            track_high_conf: 0.7,
47            track_iou: 0.25,
48            track_update: 0.25,
49            track_extra_lifespan: 500_000_000,
50        }
51    }
52
53    /// Sets the extra lifespan for tracks in nanoseconds.
54    pub fn track_extra_lifespan(mut self, lifespan: u64) -> Self {
55        self.track_extra_lifespan = lifespan;
56        self
57    }
58
59    /// Sets the high confidence threshold for tracking.
60    pub fn track_high_conf(mut self, conf: f32) -> Self {
61        self.track_high_conf = conf;
62        self
63    }
64
65    /// Sets the IOU threshold for tracking.
66    pub fn track_iou(mut self, iou: f32) -> Self {
67        self.track_iou = iou;
68        self
69    }
70
71    /// Sets the update rate for the Kalman filter.
72    pub fn track_update(mut self, update: f32) -> Self {
73        self.track_update = update;
74        self
75    }
76
77    /// Builds the ByteTrack tracker with the specified parameters.
78    /// # Examples
79    /// ```rust
80    /// use edgefirst_tracker::{bytetrack::ByteTrackBuilder, Tracker, MockDetection};
81    /// let mut tracker = ByteTrackBuilder::new()
82    ///     .track_high_conf(0.8)
83    ///     .track_iou(0.3)
84    ///     .track_update(0.2)
85    ///     .track_extra_lifespan(1_000_000_000)
86    ///     .build();
87    /// assert_eq!(tracker.track_high_conf, 0.8);
88    /// assert_eq!(tracker.track_iou, 0.3);
89    /// assert_eq!(tracker.track_update, 0.2);
90    /// assert_eq!(tracker.track_extra_lifespan, 1_000_000_000);
91    /// # let boxes = Vec::<MockDetection>::new();
92    /// # tracker.update(&boxes, 0);
93    /// ```
94    pub fn build<T: DetectionBox>(self) -> ByteTrack<T> {
95        ByteTrack {
96            track_extra_lifespan: self.track_extra_lifespan,
97            track_high_conf: self.track_high_conf,
98            track_iou: self.track_iou,
99            track_update: self.track_update,
100            tracklets: Vec::new(),
101            frame_count: 0,
102        }
103    }
104}
105
106#[allow(dead_code)]
107#[derive(Default, Debug, Clone)]
108pub struct ByteTrack<T: DetectionBox> {
109    pub track_extra_lifespan: u64,
110    pub track_high_conf: f32,
111    pub track_iou: f32,
112    pub track_update: f32,
113    pub tracklets: Vec<Tracklet<T>>,
114    pub frame_count: i32,
115}
116
117#[derive(Debug, Clone)]
118pub struct Tracklet<T: DetectionBox> {
119    pub id: Uuid,
120    pub filter: ConstantVelocityXYAHModel2<f32>,
121    pub count: i32,
122    pub created: u64,
123    pub last_updated: u64,
124    pub last_box: T,
125}
126
127impl<T: DetectionBox> Tracklet<T> {
128    fn update(&mut self, detect_box: &T, ts: u64) {
129        self.count += 1;
130        self.last_updated = ts;
131        self.filter.update(&xyxy_to_xyah(&detect_box.bbox()));
132        self.last_box = detect_box.clone();
133    }
134
135    pub fn get_predicted_location(&self) -> [f32; 4] {
136        let projected = self.filter.project().0;
137        let predicted_xyah = projected.as_slice();
138        xyah_to_xyxy(predicted_xyah)
139    }
140}
141
142fn xyxy_to_xyah(vaal_box: &[f32; 4]) -> [f32; 4] {
143    let x = (vaal_box[2] + vaal_box[0]) / 2.0;
144    let y = (vaal_box[3] + vaal_box[1]) / 2.0;
145    let w = (vaal_box[2] - vaal_box[0]).max(EPSILON);
146    let h = (vaal_box[3] - vaal_box[1]).max(EPSILON);
147    let a = w / h;
148
149    [x, y, a, h]
150}
151
/// Converts the first four entries of an XYAH vector
/// `[center_x, center_y, aspect (w / h), height]` into a corner box
/// `[xmin, ymin, xmax, ymax]`. Any further entries are ignored.
///
/// # Panics
/// Panics if `xyah` has fewer than four elements.
fn xyah_to_xyxy(xyah: &[f32]) -> [f32; 4] {
    assert!(xyah.len() >= 4);
    let (cx, cy, aspect, height) = (xyah[0], xyah[1], xyah[2], xyah[3]);
    let half_w = height * aspect / 2.0;
    let half_h = height / 2.0;
    [cx - half_w, cy - half_h, cx + half_w, cy + half_h]
}
160
/// Cost for pairs that must never be matched. Assignments whose cost reaches
/// this value are discarded after the linear assignment step.
const INVALID_MATCH: f32 = 1_000_000.0;
/// Small positive floor guarding against zero widths/heights and near-empty
/// unions.
const EPSILON: f32 = 1e-5;
163
164fn iou(box1: &[f32], box2: &[f32]) -> f32 {
165    let intersection = (box1[2].min(box2[2]) - box1[0].max(box2[0])).max(0.0)
166        * (box1[3].min(box2[3]) - box1[1].max(box2[1])).max(0.0);
167
168    let union = (box1[2] - box1[0]) * (box1[3] - box1[1])
169        + (box2[2] - box2[0]) * (box2[3] - box2[1])
170        - intersection;
171
172    if union <= EPSILON {
173        return 0.0;
174    }
175
176    intersection / union
177}
178
179fn box_cost<T: DetectionBox>(
180    track: &Tracklet<T>,
181    new_box: &T,
182    distance: f32,
183    score_threshold: f32,
184    iou_threshold: f32,
185) -> f32 {
186    let _ = distance;
187
188    if new_box.score() < score_threshold {
189        return INVALID_MATCH;
190    }
191
192    // use iou between predicted box and real box:
193    let predicted_xyah = track.filter.mean.as_slice();
194    let expected = xyah_to_xyxy(predicted_xyah);
195    let iou = iou(&expected, &new_box.bbox());
196    if iou < iou_threshold {
197        return INVALID_MATCH;
198    }
199    (1.5 - new_box.score()) + (1.5 - iou)
200}
201
202impl<T: DetectionBox> ByteTrack<T> {
203    fn compute_costs(
204        &mut self,
205        boxes: &[T],
206        score_threshold: f32,
207        iou_threshold: f32,
208        box_filter: &[bool],
209        track_filter: &[bool],
210    ) -> Matrix<f32> {
211        // costs matrix must be square
212        let dims = boxes.len().max(self.tracklets.len());
213        let mut measurements = OMatrix::<f32, Dyn, U4>::from_element(boxes.len(), 0.0);
214        for (i, mut row) in measurements.row_iter_mut().enumerate() {
215            row.copy_from_slice(&xyxy_to_xyah(&boxes[i].bbox()));
216        }
217
218        // TODO: use matrix math for IOU, should speed up computation, and store it in
219        // distances
220
221        Matrix::from_shape_fn((dims, dims), |(x, y)| {
222            if x < boxes.len() && y < self.tracklets.len() {
223                if box_filter[x] || track_filter[y] {
224                    INVALID_MATCH
225                } else {
226                    box_cost(
227                        &self.tracklets[y],
228                        &boxes[x],
229                        // distances[(x, y)],
230                        0.0,
231                        score_threshold,
232                        iou_threshold,
233                    )
234                }
235            } else {
236                0.0
237            }
238        })
239    }
240
241    /// Process assignments from linear assignment and update tracking state.
242    /// Returns true if any matches were made.
243    #[allow(clippy::too_many_arguments)]
244    fn process_assignments(
245        &mut self,
246        assignments: &[usize],
247        boxes: &[T],
248        costs: &Matrix<f32>,
249        matched: &mut [bool],
250        tracked: &mut [bool],
251        matched_info: &mut [Option<TrackInfo>],
252        timestamp: u64,
253        log_assignments: bool,
254    ) {
255        for (i, &x) in assignments.iter().enumerate() {
256            if i >= boxes.len() || x >= self.tracklets.len() {
257                continue;
258            }
259
260            // Filter out invalid assignments
261            if costs[(i, x)] >= INVALID_MATCH {
262                continue;
263            }
264
265            // Skip already matched boxes/tracklets
266            if matched[i] || tracked[x] {
267                continue;
268            }
269
270            if log_assignments {
271                trace!(
272                    "Cost: {} Box: {:#?} UUID: {} Mean: {}",
273                    costs[(i, x)],
274                    boxes[i],
275                    self.tracklets[x].id,
276                    self.tracklets[x].filter.mean
277                );
278            }
279
280            matched[i] = true;
281            matched_info[i] = Some(TrackInfo {
282                uuid: self.tracklets[x].id,
283                count: self.tracklets[x].count,
284                created: self.tracklets[x].created,
285                tracked_location: self.tracklets[x].get_predicted_location(),
286                last_updated: timestamp,
287            });
288            tracked[x] = true;
289            self.tracklets[x].update(&boxes[i], timestamp);
290        }
291    }
292
293    /// Remove expired tracklets based on timestamp.
294    fn remove_expired_tracklets(&mut self, timestamp: u64) {
295        // must iterate from the back
296        for i in (0..self.tracklets.len()).rev() {
297            let expiry = self.tracklets[i].last_updated + self.track_extra_lifespan;
298            if expiry < timestamp {
299                trace!("Tracklet removed: {:?}", self.tracklets[i].id);
300                let _ = self.tracklets.swap_remove(i);
301            }
302        }
303    }
304
305    /// Create new tracklets from unmatched high-confidence boxes.
306    fn create_new_tracklets(
307        &mut self,
308        boxes: &[T],
309        high_conf_indices: &[usize],
310        matched: &[bool],
311        matched_info: &mut [Option<TrackInfo>],
312        timestamp: u64,
313    ) {
314        for &i in high_conf_indices {
315            if matched[i] {
316                continue;
317            }
318
319            let id = Uuid::new_v4();
320            let new_tracklet = Tracklet {
321                id,
322                filter: ConstantVelocityXYAHModel2::new(
323                    &xyxy_to_xyah(&boxes[i].bbox()),
324                    self.track_update,
325                ),
326                last_updated: timestamp,
327                count: 1,
328                created: timestamp,
329                last_box: boxes[i].clone(),
330            };
331            matched_info[i] = Some(TrackInfo {
332                uuid: new_tracklet.id,
333                count: new_tracklet.count,
334                created: new_tracklet.created,
335                tracked_location: new_tracklet.get_predicted_location(),
336                last_updated: timestamp,
337            });
338            self.tracklets.push(new_tracklet);
339        }
340    }
341}
342
343impl<T> Tracker<T> for ByteTrack<T>
344where
345    T: DetectionBox,
346{
347    fn update(&mut self, boxes: &[T], timestamp: u64) -> Vec<Option<TrackInfo>> {
348        let span = tracing::trace_span!(
349            "tracker_update",
350            n_detections = boxes.len(),
351            n_tracklets = self.tracklets.len(),
352            timestamp,
353        );
354        let _enter = span.enter();
355
356        self.frame_count += 1;
357
358        // Identify high-confidence detections
359        let high_conf_ind: Vec<usize> = boxes
360            .iter()
361            .enumerate()
362            .filter(|(_, b)| b.score() >= self.track_high_conf)
363            .map(|(x, _)| x)
364            .collect();
365
366        let mut matched = vec![false; boxes.len()];
367        let mut tracked = vec![false; self.tracklets.len()];
368        let mut matched_info = vec![None; boxes.len()];
369
370        // First pass: match high-confidence detections
371        if !self.tracklets.is_empty() {
372            let _s = tracing::trace_span!("predict").entered();
373            for track in &mut self.tracklets {
374                track.filter.predict();
375            }
376        }
377
378        if !self.tracklets.is_empty() {
379            let _s = tracing::trace_span!("match_high_conf").entered();
380            let costs = self.compute_costs(
381                boxes,
382                self.track_high_conf,
383                self.track_iou,
384                &matched,
385                &tracked,
386            );
387            if let Ok(ans) = lapjv(&costs) {
388                self.process_assignments(
389                    &ans.0,
390                    boxes,
391                    &costs,
392                    &mut matched,
393                    &mut tracked,
394                    &mut matched_info,
395                    timestamp,
396                    false,
397                );
398            }
399        }
400
401        // Second pass: match remaining tracklets to low-confidence detections
402        if !self.tracklets.is_empty() {
403            let _s = tracing::trace_span!("match_low_conf").entered();
404            let costs = self.compute_costs(boxes, 0.0, self.track_iou, &matched, &tracked);
405            if let Ok(ans) = lapjv(&costs) {
406                self.process_assignments(
407                    &ans.0,
408                    boxes,
409                    &costs,
410                    &mut matched,
411                    &mut tracked,
412                    &mut matched_info,
413                    timestamp,
414                    true,
415                );
416            }
417        }
418
419        // Remove expired tracklets
420        self.remove_expired_tracklets(timestamp);
421
422        // Create new tracklets from unmatched high-confidence boxes
423        self.create_new_tracklets(
424            boxes,
425            &high_conf_ind,
426            &matched,
427            &mut matched_info,
428            timestamp,
429        );
430
431        matched_info
432    }
433
434    fn get_active_tracks(&self) -> Vec<ActiveTrackInfo<T>> {
435        self.tracklets
436            .iter()
437            .map(|t| ActiveTrackInfo {
438                info: TrackInfo {
439                    uuid: t.id,
440                    tracked_location: t.get_predicted_location(),
441                    count: t.count,
442                    created: t.created,
443                    last_updated: t.last_updated,
444                },
445                last_box: t.last_box.clone(),
446            })
447            .collect()
448    }
449}
450
// Unit tests for the geometry helpers and the ByteTrack tracker, driven by
// `MockDetection` inputs.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::*;

    #[test]
    fn test_vaalbox_xyah_roundtrip() {
        let box1 = [0.0134, 0.02135, 0.12438, 0.691];
        let xyah = xyxy_to_xyah(&box1);
        let box2 = xyah_to_xyxy(&xyah);

        assert!((box1[0] - box2[0]).abs() < f32::EPSILON);
        assert!((box1[1] - box2[1]).abs() < f32::EPSILON);
        assert!((box1[2] - box2[2]).abs() < f32::EPSILON);
        assert!((box1[3] - box2[3]).abs() < f32::EPSILON);
    }

    #[test]
    fn test_iou_identical_boxes() {
        let box1 = [0.1, 0.1, 0.5, 0.5];
        let box2 = [0.1, 0.1, 0.5, 0.5];
        let result = iou(&box1, &box2);
        assert!(
            (result - 1.0).abs() < 0.001,
            "IOU of identical boxes should be 1.0"
        );
    }

    #[test]
    fn test_iou_no_overlap() {
        let box1 = [0.0, 0.0, 0.2, 0.2];
        let box2 = [0.5, 0.5, 0.7, 0.7];
        let result = iou(&box1, &box2);
        assert!(result < 0.001, "IOU of non-overlapping boxes should be ~0");
    }

    #[test]
    fn test_iou_partial_overlap() {
        let box1 = [0.0, 0.0, 0.5, 0.5];
        let box2 = [0.25, 0.25, 0.75, 0.75];
        let result = iou(&box1, &box2);
        // Intersection: 0.25*0.25 = 0.0625, Union: 0.25+0.25-0.0625 = 0.4375
        assert!(result > 0.1 && result < 0.2, "IOU should be ~0.14");
    }

    #[test]
    fn test_bytetrack_new() {
        let tracker: ByteTrack<MockDetection> = ByteTrackBuilder::new().build();
        assert_eq!(tracker.frame_count, 0);
        assert!(tracker.tracklets.is_empty());
        assert_eq!(tracker.track_high_conf, 0.7);
        assert_eq!(tracker.track_iou, 0.25);
    }

    #[test]
    fn test_bytetrack_single_detection_creates_tracklet() {
        let mut tracker = ByteTrackBuilder::new().build();
        let detections = vec![MockDetection::new([0.1, 0.1, 0.3, 0.3], 0.9, 0)];

        let results = tracker.update(&detections, 1000);

        assert_eq!(results.len(), 1);
        assert!(
            results[0].is_some(),
            "High-confidence detection should create tracklet"
        );
        assert_eq!(tracker.tracklets.len(), 1);
        assert_eq!(tracker.frame_count, 1);
    }

    #[test]
    fn test_bytetrack_low_confidence_no_tracklet() {
        let mut tracker = ByteTrackBuilder::new().build();
        // Score below track_high_conf (0.7)
        let detections = vec![MockDetection::new([0.1, 0.1, 0.3, 0.3], 0.5, 0)];

        let results = tracker.update(&detections, 1000);

        assert_eq!(results.len(), 1);
        assert!(
            results[0].is_none(),
            "Low-confidence detection should not create tracklet"
        );
        assert!(tracker.tracklets.is_empty());
    }

    #[test]
    fn test_bytetrack_tracking_across_frames() {
        let mut tracker = ByteTrackBuilder::new().build();

        // Frame 1: Create tracklet with a larger box that's easier to track
        let det1 = vec![MockDetection::new([0.2, 0.2, 0.4, 0.4], 0.9, 0)];
        let res1 = tracker.update(&det1, 1000);
        assert!(res1[0].is_some());
        let uuid1 = res1[0].unwrap().uuid;
        assert_eq!(tracker.tracklets.len(), 1);
        // After creation, tracklet count is 1
        assert_eq!(tracker.tracklets[0].count, 1);

        // Frame 2: Same location - should match existing tracklet
        let det2 = vec![MockDetection::new([0.2, 0.2, 0.4, 0.4], 0.9, 0)];
        let res2 = tracker.update(&det2, 2000);
        assert!(res2[0].is_some());
        let info2 = res2[0].unwrap();

        // Verify tracklet was matched, not a new one created
        assert_eq!(tracker.tracklets.len(), 1, "Should still have one tracklet");
        assert_eq!(info2.uuid, uuid1, "Should match same tracklet");
        // After second update, the internal tracklet count should be 2
        assert_eq!(tracker.tracklets[0].count, 2, "Internal count should be 2");
    }

    #[test]
    fn test_bytetrack_multiple_detections() {
        let mut tracker = ByteTrackBuilder::new().build();

        let detections = vec![
            MockDetection::new([0.1, 0.1, 0.2, 0.2], 0.9, 0),
            MockDetection::new([0.5, 0.5, 0.6, 0.6], 0.85, 0),
            MockDetection::new([0.8, 0.8, 0.9, 0.9], 0.95, 0),
        ];

        let results = tracker.update(&detections, 1000);

        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.is_some()));
        assert_eq!(tracker.tracklets.len(), 3);
    }

    #[test]
    fn test_bytetrack_tracklet_expiry() {
        let mut tracker = ByteTrackBuilder::new().build();
        tracker.track_extra_lifespan = 1000; // 1 second

        // Create tracklet
        let det1 = vec![MockDetection::new([0.1, 0.1, 0.3, 0.3], 0.9, 0)];
        tracker.update(&det1, 1000);
        assert_eq!(tracker.tracklets.len(), 1);

        // Update with no detections after lifespan expires
        let empty: Vec<MockDetection> = vec![];
        tracker.update(&empty, 3000); // 2 seconds later

        assert!(tracker.tracklets.is_empty(), "Tracklet should have expired");
    }

    #[test]
    fn test_bytetrack_get_active_tracks() {
        let mut tracker = ByteTrackBuilder::new().build();

        let detections = vec![
            MockDetection::new([0.1, 0.1, 0.2, 0.2], 0.9, 0),
            MockDetection::new([0.5, 0.5, 0.6, 0.6], 0.85, 0),
        ];
        tracker.update(&detections, 1000);

        let active = tracker.get_active_tracks();
        assert_eq!(active.len(), 2);
        assert!(active.iter().all(|t| t.info.count == 1));
        assert!(active.iter().all(|t| t.info.created == 1000));
    }

    #[test]
    fn test_bytetrack_empty_detections() {
        let mut tracker = ByteTrackBuilder::new().build();
        let empty: Vec<MockDetection> = vec![];

        let results = tracker.update(&empty, 1000);

        assert!(results.is_empty());
        assert!(tracker.tracklets.is_empty());
        assert_eq!(tracker.frame_count, 1);
    }

    #[test]
    fn test_two_stage_matching() {
        // The core ByteTrack innovation: low-confidence detections are matched
        // to existing tracklets in a second stage.
        let mut tracker = ByteTrackBuilder::new().build();

        // Frame 1: high-confidence detection creates a tracklet
        let det1 = vec![MockDetection::new([0.2, 0.2, 0.4, 0.4], 0.9, 0)];
        let res1 = tracker.update(&det1, 1_000_000);
        assert!(res1[0].is_some());
        let uuid1 = res1[0].unwrap().uuid;
        assert_eq!(tracker.tracklets.len(), 1);

        // Frame 2: same location but low confidence (0.3, below track_high_conf=0.7).
        // Second-stage matching should still associate it with the existing tracklet.
        let det2 = vec![MockDetection::new([0.2, 0.2, 0.4, 0.4], 0.3, 0)];
        let res2 = tracker.update(&det2, 2_000_000);
        assert!(
            res2[0].is_some(),
            "Low-conf detection should match existing tracklet via second stage"
        );
        assert_eq!(
            res2[0].unwrap().uuid,
            uuid1,
            "Should match the same tracklet"
        );
        assert_eq!(
            tracker.tracklets.len(),
            1,
            "No new tracklet should be created"
        );
        assert_eq!(
            tracker.tracklets[0].count, 2,
            "Tracklet count should increment"
        );
    }

    #[test]
    fn test_builder_track_extra_lifespan() {
        let lifespan_default = 500_000_000; // 0.5 seconds (default)
        let lifespan_extended = 2_000_000_000; // 2 seconds

        let mut tracker_default: ByteTrack<MockDetection> = ByteTrackBuilder::new().build();
        let mut tracker_extended: ByteTrack<MockDetection> = ByteTrackBuilder::new()
            .track_extra_lifespan(lifespan_extended)
            .build();

        assert_eq!(tracker_default.track_extra_lifespan, lifespan_default);
        assert_eq!(tracker_extended.track_extra_lifespan, lifespan_extended);

        let ts_start = 1_000_000_000u64; // 1 second
        let det = vec![MockDetection::new([0.2, 0.2, 0.4, 0.4], 0.9, 0)];

        tracker_default.update(&det, ts_start);
        tracker_extended.update(&det, ts_start);
        assert_eq!(tracker_default.tracklets.len(), 1);
        assert_eq!(tracker_extended.tracklets.len(), 1);

        // Advance to 1s + 1s = 2s. Default lifespan (0.5s) should have expired,
        // extended lifespan (2s) should still be active.
        let ts_after = ts_start + 1_000_000_000;
        let empty: Vec<MockDetection> = vec![];
        tracker_default.update(&empty, ts_after);
        tracker_extended.update(&empty, ts_after);

        assert!(
            tracker_default.tracklets.is_empty(),
            "Default tracker should have expired the tracklet"
        );
        assert_eq!(
            tracker_extended.tracklets.len(),
            1,
            "Extended tracker should still have the tracklet"
        );
    }

    #[test]
    fn test_builder_track_high_conf() {
        let mut tracker: ByteTrack<MockDetection> =
            ByteTrackBuilder::new().track_high_conf(0.9).build();
        assert_eq!(tracker.track_high_conf, 0.9);

        // Detection with score 0.8 is below the 0.9 threshold
        let det_low = vec![MockDetection::new([0.1, 0.1, 0.3, 0.3], 0.8, 0)];
        let res = tracker.update(&det_low, 1000);
        assert!(
            res[0].is_none(),
            "Score 0.8 should not create a tracklet with threshold 0.9"
        );
        assert!(tracker.tracklets.is_empty());

        // Detection with score 0.95 is above the 0.9 threshold
        let det_high = vec![MockDetection::new([0.1, 0.1, 0.3, 0.3], 0.95, 0)];
        let res = tracker.update(&det_high, 2000);
        assert!(
            res[0].is_some(),
            "Score 0.95 should create a tracklet with threshold 0.9"
        );
        assert_eq!(tracker.tracklets.len(), 1);
    }

    #[test]
    fn test_builder_track_iou() {
        // Tight IOU threshold: shifted detection should NOT match
        let mut tracker: ByteTrack<MockDetection> = ByteTrackBuilder::new().track_iou(0.8).build();

        // Frame 1: two well-separated detections
        let det1 = vec![
            MockDetection::new([0.1, 0.1, 0.3, 0.3], 0.9, 0),
            MockDetection::new([0.5, 0.5, 0.7, 0.7], 0.9, 0),
        ];
        let res1 = tracker.update(&det1, 1000);
        assert_eq!(tracker.tracklets.len(), 2);
        let shifted_uuid = res1[0].unwrap().uuid;
        let steady_uuid = res1[1].unwrap().uuid;

        // Frame 2: shift the first detection slightly. With IOU threshold 0.8
        // the overlap (~0.39) won't be enough for a match, so it creates a new tracklet.
        let det2 = vec![
            MockDetection::new([0.15, 0.15, 0.35, 0.35], 0.9, 0),
            MockDetection::new([0.5, 0.5, 0.7, 0.7], 0.9, 0),
        ];
        let res2 = tracker.update(&det2, 2000);
        assert_eq!(res2.len(), 2);

        // The second detection (unchanged) must still match its tracklet.
        assert_eq!(
            res2[1].unwrap().uuid,
            steady_uuid,
            "Unchanged detection should keep its tracklet"
        );
        // The first (shifted) must fail the tight IOU threshold and start a new tracklet.
        let new_uuid = res2[0].unwrap().uuid;
        assert_ne!(
            new_uuid, shifted_uuid,
            "Shifted detection should not match its old tracklet"
        );
        assert_ne!(
            new_uuid, steady_uuid,
            "Shifted detection should not match the other tracklet"
        );
        assert_eq!(
            tracker.tracklets.len(),
            3,
            "Expected two original tracklets plus one new one, got {} tracklets",
            tracker.tracklets.len()
        );
    }

    #[test]
    fn test_degenerate_zero_area_box() {
        // A zero-area box (xmin == xmax) should not panic
        let mut tracker = ByteTrackBuilder::new().build();
        let det = vec![
            MockDetection::new([0.5, 0.1, 0.5, 0.3], 0.9, 0), // zero width
            MockDetection::new([0.1, 0.1, 0.3, 0.3], 0.9, 0), // normal box
        ];
        let results = tracker.update(&det, 1000);
        assert_eq!(results.len(), 2);

        // IOU between a zero-area box and a normal box should be 0
        let zero_box = [0.5, 0.1, 0.5, 0.3];
        let normal_box = [0.1, 0.1, 0.3, 0.3];
        let iou_val = iou(&zero_box, &normal_box);
        assert!(
            iou_val < EPSILON,
            "IOU with a zero-area box should be ~0, got {iou_val}"
        );
    }

    #[test]
    fn test_degenerate_high_velocity() {
        let mut tracker = ByteTrackBuilder::new().build();

        // Frame 1: detection at top-left
        let det1 = vec![MockDetection::new([0.1, 0.1, 0.2, 0.2], 0.9, 0)];
        let res1 = tracker.update(&det1, 1_000_000);
        assert!(res1[0].is_some());
        let uuid1 = res1[0].unwrap().uuid;
        assert_eq!(tracker.tracklets.len(), 1);

        // Frame 2: detection at bottom-right (huge displacement)
        let det2 = vec![MockDetection::new([0.8, 0.8, 0.9, 0.9], 0.9, 0)];
        let res2 = tracker.update(&det2, 2_000_000);
        assert!(res2[0].is_some());

        // With default IOU threshold the far-away detection should not match;
        // a new tracklet is created instead.
        assert_eq!(
            tracker.tracklets.len(),
            2,
            "Far-displaced detection should create a new tracklet"
        );
        assert_ne!(
            res2[0].unwrap().uuid,
            uuid1,
            "New detection should have a different UUID"
        );
    }

    #[test]
    fn test_many_detections_100() {
        let mut tracker = ByteTrackBuilder::new().build();

        // Generate 100 non-overlapping small boxes spread across [0, 1]
        let detections: Vec<MockDetection> = (0..100)
            .map(|i| {
                let x = (i % 10) as f32 * 0.1;
                let y = (i / 10) as f32 * 0.1;
                MockDetection::new([x, y, x + 0.05, y + 0.05], 0.9, 0)
            })
            .collect();

        let results = tracker.update(&detections, 1000);
        assert_eq!(results.len(), 100);
        assert!(
            results.iter().all(|r| r.is_some()),
            "All 100 high-confidence detections should create tracklets"
        );
        assert_eq!(
            tracker.tracklets.len(),
            100,
            "Should have 100 active tracklets"
        );
    }

    #[test]
    fn test_tracklet_count_increments_each_frame() {
        let mut tracker = ByteTrackBuilder::new().build();
        let det = vec![MockDetection::new([0.2, 0.2, 0.4, 0.4], 0.9, 0)];

        for frame in 1..=5 {
            tracker.update(&det, frame * 1000);
        }

        assert_eq!(tracker.tracklets.len(), 1);
        assert_eq!(
            tracker.tracklets[0].count, 5,
            "Tracklet count should equal number of frames it was matched"
        );
    }

    #[test]
    fn test_tracklet_created_timestamp_preserved() {
        let mut tracker = ByteTrackBuilder::new().build();
        let det = vec![MockDetection::new([0.2, 0.2, 0.4, 0.4], 0.9, 0)];

        tracker.update(&det, 1000);
        tracker.update(&det, 2000);
        tracker.update(&det, 3000);

        let active = tracker.get_active_tracks();
        assert_eq!(active.len(), 1);
        assert_eq!(
            active[0].info.created, 1000,
            "Created timestamp should remain at the first frame"
        );
        assert_eq!(
            active[0].info.last_updated, 3000,
            "Last updated should be the most recent frame"
        );
    }

    #[test]
    fn test_mixed_confidence_detections() {
        // Mix of high and low confidence detections in a single frame
        let mut tracker = ByteTrackBuilder::new().build();
        let det = vec![
            MockDetection::new([0.1, 0.1, 0.2, 0.2], 0.9, 0), // high
            MockDetection::new([0.3, 0.3, 0.4, 0.4], 0.3, 0), // low
            MockDetection::new([0.5, 0.5, 0.6, 0.6], 0.85, 0), // high
            MockDetection::new([0.7, 0.7, 0.8, 0.8], 0.1, 0), // low
        ];

        let results = tracker.update(&det, 1000);
        assert_eq!(results.len(), 4);

        // Only the high-confidence ones should create tracklets
        assert!(
            results[0].is_some(),
            "High-conf detection should create tracklet"
        );
        assert!(
            results[1].is_none(),
            "Low-conf detection should not create tracklet"
        );
        assert!(
            results[2].is_some(),
            "High-conf detection should create tracklet"
        );
        assert!(
            results[3].is_none(),
            "Low-conf detection should not create tracklet"
        );
        assert_eq!(tracker.tracklets.len(), 2);
    }

    #[test]
    fn test_iou_contained_box() {
        // One box fully contains the other
        let outer = [0.0, 0.0, 1.0, 1.0];
        let inner = [0.25, 0.25, 0.75, 0.75];
        let result = iou(&outer, &inner);
        // inner area = 0.25, outer area = 1.0, intersection = 0.25, union = 1.0
        assert!(
            (result - 0.25).abs() < 0.01,
            "IOU of contained box should be inner_area/outer_area = 0.25, got {result}"
        );
    }

    #[test]
    fn test_xyxy_to_xyah_square_box() {
        // A square box should have aspect ratio 1.0
        let square = [0.1, 0.2, 0.3, 0.4];
        let xyah = xyxy_to_xyah(&square);
        assert!((xyah[0] - 0.2).abs() < 1e-5, "Center x should be 0.2");
        assert!((xyah[1] - 0.3).abs() < 1e-5, "Center y should be 0.3");
        assert!(
            (xyah[2] - 1.0).abs() < 1e-5,
            "Aspect ratio of square should be 1.0"
        );
        assert!((xyah[3] - 0.2).abs() < 1e-5, "Height should be 0.2");
    }

    #[test]
    fn test_frame_count_increments() {
        let mut tracker = ByteTrackBuilder::new().build();
        let empty: Vec<MockDetection> = vec![];

        for _ in 0..10 {
            tracker.update(&empty, 0);
        }

        assert_eq!(
            tracker.frame_count, 10,
            "Frame count should increment each update"
        );
    }

    #[test]
    fn test_tracklet_predicted_location_near_detection() {
        let mut tracker = ByteTrackBuilder::new().build();
        let det = vec![MockDetection::new([0.2, 0.2, 0.4, 0.4], 0.9, 0)];
        tracker.update(&det, 1000);

        let pred = tracker.tracklets[0].get_predicted_location();
        // The predicted location should be close to the original detection
        assert!(
            (pred[0] - 0.2).abs() < 0.1,
            "Predicted xmin should be near 0.2, got {}",
            pred[0]
        );
        assert!(
            (pred[1] - 0.2).abs() < 0.1,
            "Predicted ymin should be near 0.2, got {}",
            pred[1]
        );
        assert!(
            (pred[2] - 0.4).abs() < 0.1,
            "Predicted xmax should be near 0.4, got {}",
            pred[2]
        );
        assert!(
            (pred[3] - 0.4).abs() < 0.1,
            "Predicted ymax should be near 0.4, got {}",
            pred[3]
        );
    }
}