Skip to main content

edgefirst_tracker/
bytetrack.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::{
5    kalman::ConstantVelocityXYAHModel2, ActiveTrackInfo, DetectionBox, TrackInfo, Tracker,
6};
7use lapjv::{lapjv, Matrix};
8use log::trace;
9use nalgebra::{Dyn, OMatrix, U4};
10use uuid::Uuid;
11
/// Builder for [`ByteTrack`] trackers.
///
/// Start from [`ByteTrackBuilder::new`] (or `Default::default()`), adjust the
/// parameters with the chained setters, then call [`ByteTrackBuilder::build`].
/// The builder is inert on its own, so it is `#[must_use]`: dropping it without
/// calling `build()` is almost certainly a mistake.
#[must_use = "a ByteTrackBuilder does nothing until `build()` is called"]
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ByteTrackBuilder {
    /// How long, in nanoseconds, a tracklet survives after its last update.
    track_extra_lifespan: u64,
    /// Minimum detection score used for first-stage matching and for spawning
    /// new tracklets.
    track_high_conf: f32,
    /// Minimum IoU between a tracklet's predicted box and a detection for the
    /// pair to be considered a match.
    track_iou: f32,
    /// Kalman filter update rate handed to the filter of each new tracklet.
    track_update: f32,
}
19
20impl Default for ByteTrackBuilder {
21    fn default() -> Self {
22        Self::new()
23    }
24}
25
26impl ByteTrackBuilder {
27    /// Creates a new ByteTrackBuilder with default parameters.
28    /// These defaults are:
29    /// - track_high_conf: 0.7
30    /// - track_iou: 0.25
31    /// - track_update: 0.25
32    /// - track_extra_lifespan: 500_000_000 (0.5 seconds)
33    /// # Examples
34    /// ```rust
35    /// use edgefirst_tracker::{bytetrack::ByteTrackBuilder, Tracker, MockDetection};
36    /// let mut tracker = ByteTrackBuilder::new().build();
37    /// assert_eq!(tracker.track_high_conf, 0.7);
38    /// assert_eq!(tracker.track_iou, 0.25);
39    /// assert_eq!(tracker.track_update, 0.25);
40    /// assert_eq!(tracker.track_extra_lifespan, 500_000_000);
41    /// # let boxes = Vec::<MockDetection>::new();
42    /// # tracker.update(&boxes, 0);
43    /// ```
44    pub fn new() -> Self {
45        Self {
46            track_high_conf: 0.7,
47            track_iou: 0.25,
48            track_update: 0.25,
49            track_extra_lifespan: 500_000_000,
50        }
51    }
52
53    /// Sets the extra lifespan for tracks in nanoseconds.
54    pub fn track_extra_lifespan(mut self, lifespan: u64) -> Self {
55        self.track_extra_lifespan = lifespan;
56        self
57    }
58
59    /// Sets the high confidence threshold for tracking.
60    pub fn track_high_conf(mut self, conf: f32) -> Self {
61        self.track_high_conf = conf;
62        self
63    }
64
65    /// Sets the IOU threshold for tracking.
66    pub fn track_iou(mut self, iou: f32) -> Self {
67        self.track_iou = iou;
68        self
69    }
70
71    /// Sets the update rate for the Kalman filter.
72    pub fn track_update(mut self, update: f32) -> Self {
73        self.track_update = update;
74        self
75    }
76
77    /// Builds the ByteTrack tracker with the specified parameters.
78    /// # Examples
79    /// ```rust
80    /// use edgefirst_tracker::{bytetrack::ByteTrackBuilder, Tracker, MockDetection};
81    /// let mut tracker = ByteTrackBuilder::new()
82    ///     .track_high_conf(0.8)
83    ///     .track_iou(0.3)
84    ///     .track_update(0.2)
85    ///     .track_extra_lifespan(1_000_000_000)
86    ///     .build();
87    /// assert_eq!(tracker.track_high_conf, 0.8);
88    /// assert_eq!(tracker.track_iou, 0.3);
89    /// assert_eq!(tracker.track_update, 0.2);
90    /// assert_eq!(tracker.track_extra_lifespan, 1_000_000_000);
91    /// # let boxes = Vec::<MockDetection>::new();
92    /// # tracker.update(&boxes, 0);
93    /// ```
94    pub fn build<T: DetectionBox>(self) -> ByteTrack<T> {
95        ByteTrack {
96            track_extra_lifespan: self.track_extra_lifespan,
97            track_high_conf: self.track_high_conf,
98            track_iou: self.track_iou,
99            track_update: self.track_update,
100            tracklets: Vec::new(),
101            frame_count: 0,
102        }
103    }
104}
105
106#[allow(dead_code)]
107#[derive(Default, Debug, Clone)]
108pub struct ByteTrack<T: DetectionBox> {
109    pub track_extra_lifespan: u64,
110    pub track_high_conf: f32,
111    pub track_iou: f32,
112    pub track_update: f32,
113    pub tracklets: Vec<Tracklet<T>>,
114    pub frame_count: i32,
115}
116
/// A single tracked object: its identity, Kalman filter state and bookkeeping.
#[derive(Debug, Clone)]
pub struct Tracklet<T: DetectionBox> {
    /// Random (UUID v4) identifier assigned when the tracklet is created.
    pub id: Uuid,
    /// Kalman filter whose measurements are boxes in
    /// `[center_x, center_y, width / height, height]` form.
    pub filter: ConstantVelocityXYAHModel2<f32>,
    /// Number of detections associated with this tracklet (1 at creation,
    /// incremented on every subsequent match).
    pub count: i32,
    /// Timestamp of the detection that created the tracklet.
    pub created: u64,
    /// Timestamp of the most recently associated detection.
    pub last_updated: u64,
    /// Copy of the most recently associated detection.
    pub last_box: T,
}
126
127impl<T: DetectionBox> Tracklet<T> {
128    fn update(&mut self, detect_box: &T, ts: u64) {
129        self.count += 1;
130        self.last_updated = ts;
131        self.filter.update(&xyxy_to_xyah(&detect_box.bbox()));
132        self.last_box = detect_box.clone();
133    }
134
135    pub fn get_predicted_location(&self) -> [f32; 4] {
136        let projected = self.filter.project().0;
137        let predicted_xyah = projected.as_slice();
138        xyah_to_xyxy(predicted_xyah)
139    }
140}
141
142fn xyxy_to_xyah(vaal_box: &[f32; 4]) -> [f32; 4] {
143    let x = (vaal_box[2] + vaal_box[0]) / 2.0;
144    let y = (vaal_box[3] + vaal_box[1]) / 2.0;
145    let w = (vaal_box[2] - vaal_box[0]).max(EPSILON);
146    let h = (vaal_box[3] - vaal_box[1]).max(EPSILON);
147    let a = w / h;
148
149    [x, y, a, h]
150}
151
/// Converts `[center_x, center_y, aspect, height]` (with `aspect = width / height`)
/// into `[xmin, ymin, xmax, ymax]`. Elements beyond the first four are ignored.
///
/// # Panics
/// Panics if `xyah` holds fewer than four elements.
fn xyah_to_xyxy(xyah: &[f32]) -> [f32; 4] {
    assert!(xyah.len() >= 4);
    let (cx, cy, aspect, height) = (xyah[0], xyah[1], xyah[2], xyah[3]);

    let half_w = height * aspect / 2.0;
    let half_h = height / 2.0;
    [cx - half_w, cy - half_h, cx + half_w, cy + half_h]
}
160
/// Cost given to detection/tracklet pairs that must never be matched; any
/// assignment whose cost is at or above this value is discarded.
const INVALID_MATCH: f32 = 1_000_000.0;
/// Small positive floor: the minimum box width/height in `xyxy_to_xyah` and the
/// union area at or below which `iou` reports no overlap.
const EPSILON: f32 = 1e-5;
163
164fn iou(box1: &[f32], box2: &[f32]) -> f32 {
165    let intersection = (box1[2].min(box2[2]) - box1[0].max(box2[0])).max(0.0)
166        * (box1[3].min(box2[3]) - box1[1].max(box2[1])).max(0.0);
167
168    let union = (box1[2] - box1[0]) * (box1[3] - box1[1])
169        + (box2[2] - box2[0]) * (box2[3] - box2[1])
170        - intersection;
171
172    if union <= EPSILON {
173        return 0.0;
174    }
175
176    intersection / union
177}
178
179fn box_cost<T: DetectionBox>(
180    track: &Tracklet<T>,
181    new_box: &T,
182    distance: f32,
183    score_threshold: f32,
184    iou_threshold: f32,
185) -> f32 {
186    let _ = distance;
187
188    if new_box.score() < score_threshold {
189        return INVALID_MATCH;
190    }
191
192    // use iou between predicted box and real box:
193    let predicted_xyah = track.filter.mean.as_slice();
194    let expected = xyah_to_xyxy(predicted_xyah);
195    let iou = iou(&expected, &new_box.bbox());
196    if iou < iou_threshold {
197        return INVALID_MATCH;
198    }
199    (1.5 - new_box.score()) + (1.5 - iou)
200}
201
202impl<T: DetectionBox> ByteTrack<T> {
203    fn compute_costs(
204        &mut self,
205        boxes: &[T],
206        score_threshold: f32,
207        iou_threshold: f32,
208        box_filter: &[bool],
209        track_filter: &[bool],
210    ) -> Matrix<f32> {
211        // costs matrix must be square
212        let dims = boxes.len().max(self.tracklets.len());
213        let mut measurements = OMatrix::<f32, Dyn, U4>::from_element(boxes.len(), 0.0);
214        for (i, mut row) in measurements.row_iter_mut().enumerate() {
215            row.copy_from_slice(&xyxy_to_xyah(&boxes[i].bbox()));
216        }
217
218        // TODO: use matrix math for IOU, should speed up computation, and store it in
219        // distances
220
221        Matrix::from_shape_fn((dims, dims), |(x, y)| {
222            if x < boxes.len() && y < self.tracklets.len() {
223                if box_filter[x] || track_filter[y] {
224                    INVALID_MATCH
225                } else {
226                    box_cost(
227                        &self.tracklets[y],
228                        &boxes[x],
229                        // distances[(x, y)],
230                        0.0,
231                        score_threshold,
232                        iou_threshold,
233                    )
234                }
235            } else {
236                0.0
237            }
238        })
239    }
240
241    /// Process assignments from linear assignment and update tracking state.
242    /// Returns true if any matches were made.
243    #[allow(clippy::too_many_arguments)]
244    fn process_assignments(
245        &mut self,
246        assignments: &[usize],
247        boxes: &[T],
248        costs: &Matrix<f32>,
249        matched: &mut [bool],
250        tracked: &mut [bool],
251        matched_info: &mut [Option<TrackInfo>],
252        timestamp: u64,
253        log_assignments: bool,
254    ) {
255        for (i, &x) in assignments.iter().enumerate() {
256            if i >= boxes.len() || x >= self.tracklets.len() {
257                continue;
258            }
259
260            // Filter out invalid assignments
261            if costs[(i, x)] >= INVALID_MATCH {
262                continue;
263            }
264
265            // Skip already matched boxes/tracklets
266            if matched[i] || tracked[x] {
267                continue;
268            }
269
270            if log_assignments {
271                trace!(
272                    "Cost: {} Box: {:#?} UUID: {} Mean: {}",
273                    costs[(i, x)],
274                    boxes[i],
275                    self.tracklets[x].id,
276                    self.tracklets[x].filter.mean
277                );
278            }
279
280            matched[i] = true;
281            matched_info[i] = Some(TrackInfo {
282                uuid: self.tracklets[x].id,
283                count: self.tracklets[x].count,
284                created: self.tracklets[x].created,
285                tracked_location: self.tracklets[x].get_predicted_location(),
286                last_updated: timestamp,
287            });
288            tracked[x] = true;
289            self.tracklets[x].update(&boxes[i], timestamp);
290        }
291    }
292
293    /// Remove expired tracklets based on timestamp.
294    fn remove_expired_tracklets(&mut self, timestamp: u64) {
295        // must iterate from the back
296        for i in (0..self.tracklets.len()).rev() {
297            let expiry = self.tracklets[i].last_updated + self.track_extra_lifespan;
298            if expiry < timestamp {
299                trace!("Tracklet removed: {:?}", self.tracklets[i].id);
300                let _ = self.tracklets.swap_remove(i);
301            }
302        }
303    }
304
305    /// Create new tracklets from unmatched high-confidence boxes.
306    fn create_new_tracklets(
307        &mut self,
308        boxes: &[T],
309        high_conf_indices: &[usize],
310        matched: &[bool],
311        matched_info: &mut [Option<TrackInfo>],
312        timestamp: u64,
313    ) {
314        for &i in high_conf_indices {
315            if matched[i] {
316                continue;
317            }
318
319            let id = Uuid::new_v4();
320            let new_tracklet = Tracklet {
321                id,
322                filter: ConstantVelocityXYAHModel2::new(
323                    &xyxy_to_xyah(&boxes[i].bbox()),
324                    self.track_update,
325                ),
326                last_updated: timestamp,
327                count: 1,
328                created: timestamp,
329                last_box: boxes[i].clone(),
330            };
331            matched_info[i] = Some(TrackInfo {
332                uuid: new_tracklet.id,
333                count: new_tracklet.count,
334                created: new_tracklet.created,
335                tracked_location: new_tracklet.get_predicted_location(),
336                last_updated: timestamp,
337            });
338            self.tracklets.push(new_tracklet);
339        }
340    }
341}
342
343impl<T> Tracker<T> for ByteTrack<T>
344where
345    T: DetectionBox,
346{
347    fn update(&mut self, boxes: &[T], timestamp: u64) -> Vec<Option<TrackInfo>> {
348        self.frame_count += 1;
349
350        // Identify high-confidence detections
351        let high_conf_ind: Vec<usize> = boxes
352            .iter()
353            .enumerate()
354            .filter(|(_, b)| b.score() >= self.track_high_conf)
355            .map(|(x, _)| x)
356            .collect();
357
358        let mut matched = vec![false; boxes.len()];
359        let mut tracked = vec![false; self.tracklets.len()];
360        let mut matched_info = vec![None; boxes.len()];
361
362        // First pass: match high-confidence detections
363        if !self.tracklets.is_empty() {
364            for track in &mut self.tracklets {
365                track.filter.predict();
366            }
367
368            let costs = self.compute_costs(
369                boxes,
370                self.track_high_conf,
371                self.track_iou,
372                &matched,
373                &tracked,
374            );
375            if let Ok(ans) = lapjv(&costs) {
376                self.process_assignments(
377                    &ans.0,
378                    boxes,
379                    &costs,
380                    &mut matched,
381                    &mut tracked,
382                    &mut matched_info,
383                    timestamp,
384                    false,
385                );
386            }
387        }
388
389        // Second pass: match remaining tracklets to low-confidence detections
390        if !self.tracklets.is_empty() {
391            let costs = self.compute_costs(boxes, 0.0, self.track_iou, &matched, &tracked);
392            if let Ok(ans) = lapjv(&costs) {
393                self.process_assignments(
394                    &ans.0,
395                    boxes,
396                    &costs,
397                    &mut matched,
398                    &mut tracked,
399                    &mut matched_info,
400                    timestamp,
401                    true,
402                );
403            }
404        }
405
406        // Remove expired tracklets
407        self.remove_expired_tracklets(timestamp);
408
409        // Create new tracklets from unmatched high-confidence boxes
410        self.create_new_tracklets(
411            boxes,
412            &high_conf_ind,
413            &matched,
414            &mut matched_info,
415            timestamp,
416        );
417
418        matched_info
419    }
420
421    fn get_active_tracks(&self) -> Vec<ActiveTrackInfo<T>> {
422        self.tracklets
423            .iter()
424            .map(|t| ActiveTrackInfo {
425                info: TrackInfo {
426                    uuid: t.id,
427                    tracked_location: t.get_predicted_location(),
428                    count: t.count,
429                    created: t.created,
430                    last_updated: t.last_updated,
431                },
432                last_box: t.last_box.clone(),
433            })
434            .collect()
435    }
436}
437
#[cfg(test)]
mod tests {
    use super::*;
    use crate::*;

    #[test]
    fn test_vaalbox_xyah_roundtrip() {
        let box1 = [0.0134, 0.02135, 0.12438, 0.691];
        let xyah = xyxy_to_xyah(&box1);
        let box2 = xyah_to_xyxy(&xyah);

        // The round trip goes through a division and a multiplication, so allow
        // a few ULPs of error. `f32::EPSILON` is the spacing at 1.0, which is only
        // about two ULPs at 0.691 and therefore too tight to be a reliable bound.
        const TOL: f32 = 1e-6;
        for (orig, back) in box1.iter().zip(box2.iter()) {
            assert!((orig - back).abs() < TOL, "expected {orig}, got {back}");
        }
    }

    #[test]
    fn test_iou_identical_boxes() {
        let box1 = [0.1, 0.1, 0.5, 0.5];
        let box2 = [0.1, 0.1, 0.5, 0.5];
        let result = iou(&box1, &box2);
        assert!(
            (result - 1.0).abs() < 0.001,
            "IOU of identical boxes should be 1.0"
        );
    }

    #[test]
    fn test_iou_no_overlap() {
        let box1 = [0.0, 0.0, 0.2, 0.2];
        let box2 = [0.5, 0.5, 0.7, 0.7];
        let result = iou(&box1, &box2);
        assert!(result < 0.001, "IOU of non-overlapping boxes should be ~0");
    }

    #[test]
    fn test_iou_partial_overlap() {
        let box1 = [0.0, 0.0, 0.5, 0.5];
        let box2 = [0.25, 0.25, 0.75, 0.75];
        let result = iou(&box1, &box2);
        // Each box has area 0.25. Intersection: 0.25*0.25 = 0.0625,
        // Union: 0.25+0.25-0.0625 = 0.4375, so IOU = 1/7 ≈ 0.142857.
        let expected = 0.0625 / 0.4375;
        assert!(
            (result - expected).abs() < 1e-5,
            "IOU should be ~0.142857, got {result}"
        );
    }

    #[test]
    fn test_bytetrack_new() {
        let tracker: ByteTrack<MockDetection> = ByteTrackBuilder::new().build();
        assert_eq!(tracker.frame_count, 0);
        assert!(tracker.tracklets.is_empty());
        assert_eq!(tracker.track_high_conf, 0.7);
        assert_eq!(tracker.track_iou, 0.25);
    }

    #[test]
    fn test_bytetrack_single_detection_creates_tracklet() {
        let mut tracker = ByteTrackBuilder::new().build();
        let detections = vec![MockDetection::new([0.1, 0.1, 0.3, 0.3], 0.9, 0)];

        let results = tracker.update(&detections, 1000);

        assert_eq!(results.len(), 1);
        assert!(
            results[0].is_some(),
            "High-confidence detection should create tracklet"
        );
        assert_eq!(tracker.tracklets.len(), 1);
        assert_eq!(tracker.frame_count, 1);
    }

    #[test]
    fn test_bytetrack_low_confidence_no_tracklet() {
        let mut tracker = ByteTrackBuilder::new().build();
        // Score below track_high_conf (0.7)
        let detections = vec![MockDetection::new([0.1, 0.1, 0.3, 0.3], 0.5, 0)];

        let results = tracker.update(&detections, 1000);

        assert_eq!(results.len(), 1);
        assert!(
            results[0].is_none(),
            "Low-confidence detection should not create tracklet"
        );
        assert!(tracker.tracklets.is_empty());
    }

    #[test]
    fn test_bytetrack_tracking_across_frames() {
        let mut tracker = ByteTrackBuilder::new().build();

        // Frame 1: Create tracklet with a larger box that's easier to track
        let det1 = vec![MockDetection::new([0.2, 0.2, 0.4, 0.4], 0.9, 0)];
        let res1 = tracker.update(&det1, 1000);
        assert!(res1[0].is_some());
        let uuid1 = res1[0].unwrap().uuid;
        assert_eq!(tracker.tracklets.len(), 1);
        // After creation, tracklet count is 1
        assert_eq!(tracker.tracklets[0].count, 1);

        // Frame 2: Same location - should match existing tracklet
        let det2 = vec![MockDetection::new([0.2, 0.2, 0.4, 0.4], 0.9, 0)];
        let res2 = tracker.update(&det2, 2000);
        assert!(res2[0].is_some());
        let info2 = res2[0].unwrap();

        // Verify tracklet was matched, not a new one created
        assert_eq!(tracker.tracklets.len(), 1, "Should still have one tracklet");
        assert_eq!(info2.uuid, uuid1, "Should match same tracklet");
        // After second update, the internal tracklet count should be 2
        assert_eq!(tracker.tracklets[0].count, 2, "Internal count should be 2");
    }

    #[test]
    fn test_bytetrack_multiple_detections() {
        let mut tracker = ByteTrackBuilder::new().build();

        let detections = vec![
            MockDetection::new([0.1, 0.1, 0.2, 0.2], 0.9, 0),
            MockDetection::new([0.5, 0.5, 0.6, 0.6], 0.85, 0),
            MockDetection::new([0.8, 0.8, 0.9, 0.9], 0.95, 0),
        ];

        let results = tracker.update(&detections, 1000);

        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.is_some()));
        assert_eq!(tracker.tracklets.len(), 3);
    }

    #[test]
    fn test_bytetrack_tracklet_expiry() {
        let mut tracker = ByteTrackBuilder::new().build();
        // Timestamps and lifespans are in nanoseconds: this is a 1 µs lifespan.
        tracker.track_extra_lifespan = 1000;

        // Create tracklet at t = 1000 ns
        let det1 = vec![MockDetection::new([0.1, 0.1, 0.3, 0.3], 0.9, 0)];
        tracker.update(&det1, 1000);
        assert_eq!(tracker.tracklets.len(), 1);

        // Update with no detections 2000 ns (2 µs) later, past the 1 µs lifespan
        let empty: Vec<MockDetection> = vec![];
        tracker.update(&empty, 3000);

        assert!(tracker.tracklets.is_empty(), "Tracklet should have expired");
    }

    #[test]
    fn test_bytetrack_get_active_tracks() {
        let mut tracker = ByteTrackBuilder::new().build();

        let detections = vec![
            MockDetection::new([0.1, 0.1, 0.2, 0.2], 0.9, 0),
            MockDetection::new([0.5, 0.5, 0.6, 0.6], 0.85, 0),
        ];
        tracker.update(&detections, 1000);

        let active = tracker.get_active_tracks();
        assert_eq!(active.len(), 2);
        assert!(active.iter().all(|t| t.info.count == 1));
        assert!(active.iter().all(|t| t.info.created == 1000));
    }

    #[test]
    fn test_bytetrack_empty_detections() {
        let mut tracker = ByteTrackBuilder::new().build();
        let empty: Vec<MockDetection> = vec![];

        let results = tracker.update(&empty, 1000);

        assert!(results.is_empty());
        assert!(tracker.tracklets.is_empty());
        assert_eq!(tracker.frame_count, 1);
    }

    #[test]
    fn test_two_stage_matching() {
        // The core ByteTrack innovation: low-confidence detections are matched
        // to existing tracklets in a second stage.
        let mut tracker = ByteTrackBuilder::new().build();

        // Frame 1: high-confidence detection creates a tracklet
        let det1 = vec![MockDetection::new([0.2, 0.2, 0.4, 0.4], 0.9, 0)];
        let res1 = tracker.update(&det1, 1_000_000);
        assert!(res1[0].is_some());
        let uuid1 = res1[0].unwrap().uuid;
        assert_eq!(tracker.tracklets.len(), 1);

        // Frame 2: same location but low confidence (0.3, below track_high_conf=0.7).
        // Second-stage matching should still associate it with the existing tracklet.
        let det2 = vec![MockDetection::new([0.2, 0.2, 0.4, 0.4], 0.3, 0)];
        let res2 = tracker.update(&det2, 2_000_000);
        assert!(
            res2[0].is_some(),
            "Low-conf detection should match existing tracklet via second stage"
        );
        assert_eq!(
            res2[0].unwrap().uuid,
            uuid1,
            "Should match the same tracklet"
        );
        assert_eq!(
            tracker.tracklets.len(),
            1,
            "No new tracklet should be created"
        );
        assert_eq!(
            tracker.tracklets[0].count, 2,
            "Tracklet count should increment"
        );
    }

    #[test]
    fn test_builder_track_extra_lifespan() {
        let lifespan_default = 500_000_000; // 0.5 seconds (default)
        let lifespan_extended = 2_000_000_000; // 2 seconds

        let mut tracker_default: ByteTrack<MockDetection> = ByteTrackBuilder::new().build();
        let mut tracker_extended: ByteTrack<MockDetection> = ByteTrackBuilder::new()
            .track_extra_lifespan(lifespan_extended)
            .build();

        assert_eq!(tracker_default.track_extra_lifespan, lifespan_default);
        assert_eq!(tracker_extended.track_extra_lifespan, lifespan_extended);

        let ts_start = 1_000_000_000u64; // 1 second
        let det = vec![MockDetection::new([0.2, 0.2, 0.4, 0.4], 0.9, 0)];

        tracker_default.update(&det, ts_start);
        tracker_extended.update(&det, ts_start);
        assert_eq!(tracker_default.tracklets.len(), 1);
        assert_eq!(tracker_extended.tracklets.len(), 1);

        // Advance to 1s + 1s = 2s. Default lifespan (0.5s) should have expired,
        // extended lifespan (2s) should still be active.
        let ts_after = ts_start + 1_000_000_000;
        let empty: Vec<MockDetection> = vec![];
        tracker_default.update(&empty, ts_after);
        tracker_extended.update(&empty, ts_after);

        assert!(
            tracker_default.tracklets.is_empty(),
            "Default tracker should have expired the tracklet"
        );
        assert_eq!(
            tracker_extended.tracklets.len(),
            1,
            "Extended tracker should still have the tracklet"
        );
    }

    #[test]
    fn test_builder_track_high_conf() {
        let mut tracker: ByteTrack<MockDetection> =
            ByteTrackBuilder::new().track_high_conf(0.9).build();
        assert_eq!(tracker.track_high_conf, 0.9);

        // Detection with score 0.8 is below the 0.9 threshold
        let det_low = vec![MockDetection::new([0.1, 0.1, 0.3, 0.3], 0.8, 0)];
        let res = tracker.update(&det_low, 1000);
        assert!(
            res[0].is_none(),
            "Score 0.8 should not create a tracklet with threshold 0.9"
        );
        assert!(tracker.tracklets.is_empty());

        // Detection with score 0.95 is above the 0.9 threshold
        let det_high = vec![MockDetection::new([0.1, 0.1, 0.3, 0.3], 0.95, 0)];
        let res = tracker.update(&det_high, 2000);
        assert!(
            res[0].is_some(),
            "Score 0.95 should create a tracklet with threshold 0.9"
        );
        assert_eq!(tracker.tracklets.len(), 1);
    }

    #[test]
    fn test_builder_track_iou() {
        // Tight IOU threshold: shifted detection should NOT match
        let mut tracker: ByteTrack<MockDetection> = ByteTrackBuilder::new().track_iou(0.8).build();

        // Frame 1: two well-separated detections
        let det1 = vec![
            MockDetection::new([0.1, 0.1, 0.3, 0.3], 0.9, 0),
            MockDetection::new([0.5, 0.5, 0.7, 0.7], 0.9, 0),
        ];
        let res1 = tracker.update(&det1, 1000);
        assert_eq!(tracker.tracklets.len(), 2);
        let shifted_uuid = res1[0].unwrap().uuid;
        let stable_uuid = res1[1].unwrap().uuid;

        // Frame 2: shift the first detection slightly. Its IOU with the original
        // box is 0.0225 / 0.0575 ≈ 0.39, well below the 0.8 threshold.
        let det2 = vec![
            MockDetection::new([0.15, 0.15, 0.35, 0.35], 0.9, 0),
            MockDetection::new([0.5, 0.5, 0.7, 0.7], 0.9, 0),
        ];
        let res2 = tracker.update(&det2, 2000);
        assert_eq!(res2.len(), 2);

        // The unchanged detection should still match its tracklet.
        assert_eq!(
            res2[1].unwrap().uuid,
            stable_uuid,
            "Unchanged detection should keep its tracklet"
        );
        // The shifted one fails the tight IOU threshold and starts a new tracklet.
        // The old tracklet has not expired (1000 ns elapsed vs. the 0.5 s default
        // lifespan), so exactly three tracklets remain.
        assert_ne!(
            res2[0].unwrap().uuid,
            shifted_uuid,
            "Shifted detection should not reuse the old tracklet"
        );
        assert_eq!(
            tracker.tracklets.len(),
            3,
            "Shifted detection should create a new tracklet with tight IOU threshold"
        );
    }

    #[test]
    fn test_degenerate_zero_area_box() {
        // A zero-area box (xmin == xmax) should not panic
        let mut tracker = ByteTrackBuilder::new().build();
        let det = vec![
            MockDetection::new([0.5, 0.1, 0.5, 0.3], 0.9, 0), // zero width
            MockDetection::new([0.1, 0.1, 0.3, 0.3], 0.9, 0), // normal box
        ];
        let results = tracker.update(&det, 1000);
        assert_eq!(results.len(), 2);

        // IOU between a zero-area box and a normal box should be 0
        let zero_box = [0.5, 0.1, 0.5, 0.3];
        let normal_box = [0.1, 0.1, 0.3, 0.3];
        let iou_val = iou(&zero_box, &normal_box);
        assert!(
            iou_val < EPSILON,
            "IOU with a zero-area box should be ~0, got {iou_val}"
        );
    }

    #[test]
    fn test_degenerate_high_velocity() {
        let mut tracker = ByteTrackBuilder::new().build();

        // Frame 1: detection at top-left
        let det1 = vec![MockDetection::new([0.1, 0.1, 0.2, 0.2], 0.9, 0)];
        let res1 = tracker.update(&det1, 1_000_000);
        assert!(res1[0].is_some());
        let uuid1 = res1[0].unwrap().uuid;
        assert_eq!(tracker.tracklets.len(), 1);

        // Frame 2: detection at bottom-right (huge displacement)
        let det2 = vec![MockDetection::new([0.8, 0.8, 0.9, 0.9], 0.9, 0)];
        let res2 = tracker.update(&det2, 2_000_000);
        assert!(res2[0].is_some());

        // With default IOU threshold the far-away detection should not match;
        // a new tracklet is created instead.
        assert_eq!(
            tracker.tracklets.len(),
            2,
            "Far-displaced detection should create a new tracklet"
        );
        assert_ne!(
            res2[0].unwrap().uuid,
            uuid1,
            "New detection should have a different UUID"
        );
    }

    #[test]
    fn test_many_detections_100() {
        let mut tracker = ByteTrackBuilder::new().build();

        // Generate 100 non-overlapping small boxes spread across [0, 1]
        let detections: Vec<MockDetection> = (0..100)
            .map(|i| {
                let x = (i % 10) as f32 * 0.1;
                let y = (i / 10) as f32 * 0.1;
                MockDetection::new([x, y, x + 0.05, y + 0.05], 0.9, 0)
            })
            .collect();

        let results = tracker.update(&detections, 1000);
        assert_eq!(results.len(), 100);
        assert!(
            results.iter().all(|r| r.is_some()),
            "All 100 high-confidence detections should create tracklets"
        );
        assert_eq!(
            tracker.tracklets.len(),
            100,
            "Should have 100 active tracklets"
        );
    }

    #[test]
    fn test_tracklet_count_increments_each_frame() {
        let mut tracker = ByteTrackBuilder::new().build();
        let det = vec![MockDetection::new([0.2, 0.2, 0.4, 0.4], 0.9, 0)];

        for frame in 1..=5 {
            tracker.update(&det, frame * 1000);
        }

        assert_eq!(tracker.tracklets.len(), 1);
        assert_eq!(
            tracker.tracklets[0].count, 5,
            "Tracklet count should equal number of frames it was matched"
        );
    }

    #[test]
    fn test_tracklet_created_timestamp_preserved() {
        let mut tracker = ByteTrackBuilder::new().build();
        let det = vec![MockDetection::new([0.2, 0.2, 0.4, 0.4], 0.9, 0)];

        tracker.update(&det, 1000);
        tracker.update(&det, 2000);
        tracker.update(&det, 3000);

        let active = tracker.get_active_tracks();
        assert_eq!(active.len(), 1);
        assert_eq!(
            active[0].info.created, 1000,
            "Created timestamp should remain at the first frame"
        );
        assert_eq!(
            active[0].info.last_updated, 3000,
            "Last updated should be the most recent frame"
        );
    }

    #[test]
    fn test_mixed_confidence_detections() {
        // Mix of high and low confidence detections in a single frame
        let mut tracker = ByteTrackBuilder::new().build();
        let det = vec![
            MockDetection::new([0.1, 0.1, 0.2, 0.2], 0.9, 0), // high
            MockDetection::new([0.3, 0.3, 0.4, 0.4], 0.3, 0), // low
            MockDetection::new([0.5, 0.5, 0.6, 0.6], 0.85, 0), // high
            MockDetection::new([0.7, 0.7, 0.8, 0.8], 0.1, 0), // low
        ];

        let results = tracker.update(&det, 1000);
        assert_eq!(results.len(), 4);

        // Only the high-confidence ones should create tracklets
        assert!(
            results[0].is_some(),
            "High-conf detection should create tracklet"
        );
        assert!(
            results[1].is_none(),
            "Low-conf detection should not create tracklet"
        );
        assert!(
            results[2].is_some(),
            "High-conf detection should create tracklet"
        );
        assert!(
            results[3].is_none(),
            "Low-conf detection should not create tracklet"
        );
        assert_eq!(tracker.tracklets.len(), 2);
    }

    #[test]
    fn test_iou_contained_box() {
        // One box fully contains the other
        let outer = [0.0, 0.0, 1.0, 1.0];
        let inner = [0.25, 0.25, 0.75, 0.75];
        let result = iou(&outer, &inner);
        // inner area = 0.25, outer area = 1.0, intersection = 0.25, union = 1.0
        assert!(
            (result - 0.25).abs() < 0.01,
            "IOU of contained box should be inner_area/outer_area = 0.25, got {result}"
        );
    }

    #[test]
    fn test_xyxy_to_xyah_square_box() {
        // A square box should have aspect ratio 1.0
        let square = [0.1, 0.2, 0.3, 0.4];
        let xyah = xyxy_to_xyah(&square);
        assert!((xyah[0] - 0.2).abs() < 1e-5, "Center x should be 0.2");
        assert!((xyah[1] - 0.3).abs() < 1e-5, "Center y should be 0.3");
        assert!(
            (xyah[2] - 1.0).abs() < 1e-5,
            "Aspect ratio of square should be 1.0"
        );
        assert!((xyah[3] - 0.2).abs() < 1e-5, "Height should be 0.2");
    }

    #[test]
    fn test_frame_count_increments() {
        let mut tracker = ByteTrackBuilder::new().build();
        let empty: Vec<MockDetection> = vec![];

        for _ in 0..10 {
            tracker.update(&empty, 0);
        }

        assert_eq!(
            tracker.frame_count, 10,
            "Frame count should increment each update"
        );
    }

    #[test]
    fn test_tracklet_predicted_location_near_detection() {
        let mut tracker = ByteTrackBuilder::new().build();
        let det = vec![MockDetection::new([0.2, 0.2, 0.4, 0.4], 0.9, 0)];
        tracker.update(&det, 1000);

        let pred = tracker.tracklets[0].get_predicted_location();
        // The predicted location should be close to the original detection
        assert!(
            (pred[0] - 0.2).abs() < 0.1,
            "Predicted xmin should be near 0.2, got {}",
            pred[0]
        );
        assert!(
            (pred[1] - 0.2).abs() < 0.1,
            "Predicted ymin should be near 0.2, got {}",
            pred[1]
        );
        assert!(
            (pred[2] - 0.4).abs() < 0.1,
            "Predicted xmax should be near 0.4, got {}",
            pred[2]
        );
        assert!(
            (pred[3] - 0.4).abs() < 0.1,
            "Predicted ymax should be near 0.4, got {}",
            pred[3]
        );
    }
}