1use crate::{
5 kalman::ConstantVelocityXYAHModel2, ActiveTrackInfo, DetectionBox, TrackInfo, Tracker,
6};
7use lapjv::{lapjv, Matrix};
8use log::trace;
9use nalgebra::{Dyn, OMatrix, U4};
10use uuid::Uuid;
11
/// Builder for [`ByteTrack`] trackers.
///
/// Start from [`ByteTrackBuilder::new`] (or `Default::default()`), adjust the
/// parameters with the chained setters, and finish with `build`, which copies
/// every field below into the new tracker.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ByteTrackBuilder {
    /// Copied to [`ByteTrack::track_extra_lifespan`] by `build`.
    track_extra_lifespan: u64,
    /// Copied to [`ByteTrack::track_high_conf`] by `build`.
    track_high_conf: f32,
    /// Copied to [`ByteTrack::track_iou`] by `build`.
    track_iou: f32,
    /// Copied to [`ByteTrack::track_update`] by `build`.
    track_update: f32,
}
19
20impl Default for ByteTrackBuilder {
21 fn default() -> Self {
22 Self::new()
23 }
24}
25
26impl ByteTrackBuilder {
27 pub fn new() -> Self {
45 Self {
46 track_high_conf: 0.7,
47 track_iou: 0.25,
48 track_update: 0.25,
49 track_extra_lifespan: 500_000_000,
50 }
51 }
52
53 pub fn track_extra_lifespan(mut self, lifespan: u64) -> Self {
55 self.track_extra_lifespan = lifespan;
56 self
57 }
58
59 pub fn track_high_conf(mut self, conf: f32) -> Self {
61 self.track_high_conf = conf;
62 self
63 }
64
65 pub fn track_iou(mut self, iou: f32) -> Self {
67 self.track_iou = iou;
68 self
69 }
70
71 pub fn track_update(mut self, update: f32) -> Self {
73 self.track_update = update;
74 self
75 }
76
77 pub fn build<T: DetectionBox>(self) -> ByteTrack<T> {
95 ByteTrack {
96 track_extra_lifespan: self.track_extra_lifespan,
97 track_high_conf: self.track_high_conf,
98 track_iou: self.track_iou,
99 track_update: self.track_update,
100 tracklets: Vec::new(),
101 frame_count: 0,
102 }
103 }
104}
105
/// ByteTrack-style multi-object tracker.
///
/// Each call to `Tracker::update` does four things:
/// * predicts every live [`Tracklet`] forward;
/// * matches detections to tracklets in two stages;
/// * drops tracklets that have gone unmatched for too long;
/// * starts new tracklets for unmatched high-confidence detections.
///
/// Construct one with [`ByteTrackBuilder`].
// NOTE(review): the reason for `allow(dead_code)` is not visible in this file — confirm it is still needed.
#[allow(dead_code)]
#[derive(Default, Debug, Clone)]
pub struct ByteTrack<T: DetectionBox> {
    /// A tracklet is removed once `last_updated + track_extra_lifespan` is
    /// earlier than the current frame's timestamp (same units as those timestamps).
    pub track_extra_lifespan: u64,
    /// Score threshold for first-stage matching.
    /// Detections at or above it may also start new tracklets.
    pub track_high_conf: f32,
    /// Minimum IoU between a tracklet's predicted box and a detection for a match.
    pub track_iou: f32,
    /// Forwarded as the second argument of `ConstantVelocityXYAHModel2::new`
    /// whenever a tracklet is created.
    pub track_update: f32,
    /// Currently live tracklets.
    pub tracklets: Vec<Tracklet<T>>,
    /// Incremented once per `update` call.
    pub frame_count: i32,
}
116
/// One tracked object: its identity, Kalman filter state and bookkeeping.
#[derive(Debug, Clone)]
pub struct Tracklet<T: DetectionBox> {
    /// Random (v4) identifier assigned when the tracklet is created.
    pub id: Uuid,
    /// Kalman filter over the box in `[center_x, center_y, aspect, height]` form.
    pub filter: ConstantVelocityXYAHModel2<f32>,
    /// Number of detections folded into this tracklet; starts at 1 on creation.
    pub count: i32,
    /// Timestamp of the frame that created this tracklet.
    pub created: u64,
    /// Timestamp of the most recent detection that created or updated this tracklet.
    pub last_updated: u64,
    /// The most recent detection that created or updated this tracklet.
    pub last_box: T,
}
126
127impl<T: DetectionBox> Tracklet<T> {
128 fn update(&mut self, detect_box: &T, ts: u64) {
129 self.count += 1;
130 self.last_updated = ts;
131 self.filter.update(&xyxy_to_xyah(&detect_box.bbox()));
132 self.last_box = detect_box.clone();
133 }
134
135 pub fn get_predicted_location(&self) -> [f32; 4] {
136 let projected = self.filter.project().0;
137 let predicted_xyah = projected.as_slice();
138 xyah_to_xyxy(predicted_xyah)
139 }
140}
141
142fn xyxy_to_xyah(vaal_box: &[f32; 4]) -> [f32; 4] {
143 let x = (vaal_box[2] + vaal_box[0]) / 2.0;
144 let y = (vaal_box[3] + vaal_box[1]) / 2.0;
145 let w = (vaal_box[2] - vaal_box[0]).max(EPSILON);
146 let h = (vaal_box[3] - vaal_box[1]).max(EPSILON);
147 let a = w / h;
148
149 [x, y, a, h]
150}
151
/// Converts `[center_x, center_y, aspect, height]` into corner form
/// `[x1, y1, x2, y2]`, where `width = height * aspect`.
///
/// Only the first four elements of `xyah` are read, so a longer filter state
/// slice may be passed directly.
///
/// # Panics
/// Panics if `xyah` has fewer than four elements.
fn xyah_to_xyxy(xyah: &[f32]) -> [f32; 4] {
    assert!(xyah.len() >= 4);
    let (cx, cy, aspect, height) = (xyah[0], xyah[1], xyah[2], xyah[3]);
    let half_width = height * aspect / 2.0;
    let half_height = height / 2.0;
    [
        cx - half_width,
        cy - half_height,
        cx + half_width,
        cy + half_height,
    ]
}
160
/// Cost given to box/tracklet pairs that must not be matched; any assignment
/// whose cost reaches this value is discarded.
const INVALID_MATCH: f32 = 1_000_000.0;
/// Small positive floor used to clamp box sizes and guard divisions.
const EPSILON: f32 = 1e-5;
163
164fn iou(box1: &[f32], box2: &[f32]) -> f32 {
165 let intersection = (box1[2].min(box2[2]) - box1[0].max(box2[0])).max(0.0)
166 * (box1[3].min(box2[3]) - box1[1].max(box2[1])).max(0.0);
167
168 let union = (box1[2] - box1[0]) * (box1[3] - box1[1])
169 + (box2[2] - box2[0]) * (box2[3] - box2[1])
170 - intersection;
171
172 if union <= EPSILON {
173 return 0.0;
174 }
175
176 intersection / union
177}
178
179fn box_cost<T: DetectionBox>(
180 track: &Tracklet<T>,
181 new_box: &T,
182 distance: f32,
183 score_threshold: f32,
184 iou_threshold: f32,
185) -> f32 {
186 let _ = distance;
187
188 if new_box.score() < score_threshold {
189 return INVALID_MATCH;
190 }
191
192 let predicted_xyah = track.filter.mean.as_slice();
194 let expected = xyah_to_xyxy(predicted_xyah);
195 let iou = iou(&expected, &new_box.bbox());
196 if iou < iou_threshold {
197 return INVALID_MATCH;
198 }
199 (1.5 - new_box.score()) + (1.5 - iou)
200}
201
202impl<T: DetectionBox> ByteTrack<T> {
203 fn compute_costs(
204 &mut self,
205 boxes: &[T],
206 score_threshold: f32,
207 iou_threshold: f32,
208 box_filter: &[bool],
209 track_filter: &[bool],
210 ) -> Matrix<f32> {
211 let dims = boxes.len().max(self.tracklets.len());
213 let mut measurements = OMatrix::<f32, Dyn, U4>::from_element(boxes.len(), 0.0);
214 for (i, mut row) in measurements.row_iter_mut().enumerate() {
215 row.copy_from_slice(&xyxy_to_xyah(&boxes[i].bbox()));
216 }
217
218 Matrix::from_shape_fn((dims, dims), |(x, y)| {
222 if x < boxes.len() && y < self.tracklets.len() {
223 if box_filter[x] || track_filter[y] {
224 INVALID_MATCH
225 } else {
226 box_cost(
227 &self.tracklets[y],
228 &boxes[x],
229 0.0,
231 score_threshold,
232 iou_threshold,
233 )
234 }
235 } else {
236 0.0
237 }
238 })
239 }
240
241 #[allow(clippy::too_many_arguments)]
244 fn process_assignments(
245 &mut self,
246 assignments: &[usize],
247 boxes: &[T],
248 costs: &Matrix<f32>,
249 matched: &mut [bool],
250 tracked: &mut [bool],
251 matched_info: &mut [Option<TrackInfo>],
252 timestamp: u64,
253 log_assignments: bool,
254 ) {
255 for (i, &x) in assignments.iter().enumerate() {
256 if i >= boxes.len() || x >= self.tracklets.len() {
257 continue;
258 }
259
260 if costs[(i, x)] >= INVALID_MATCH {
262 continue;
263 }
264
265 if matched[i] || tracked[x] {
267 continue;
268 }
269
270 if log_assignments {
271 trace!(
272 "Cost: {} Box: {:#?} UUID: {} Mean: {}",
273 costs[(i, x)],
274 boxes[i],
275 self.tracklets[x].id,
276 self.tracklets[x].filter.mean
277 );
278 }
279
280 matched[i] = true;
281 matched_info[i] = Some(TrackInfo {
282 uuid: self.tracklets[x].id,
283 count: self.tracklets[x].count,
284 created: self.tracklets[x].created,
285 tracked_location: self.tracklets[x].get_predicted_location(),
286 last_updated: timestamp,
287 });
288 tracked[x] = true;
289 self.tracklets[x].update(&boxes[i], timestamp);
290 }
291 }
292
293 fn remove_expired_tracklets(&mut self, timestamp: u64) {
295 for i in (0..self.tracklets.len()).rev() {
297 let expiry = self.tracklets[i].last_updated + self.track_extra_lifespan;
298 if expiry < timestamp {
299 trace!("Tracklet removed: {:?}", self.tracklets[i].id);
300 let _ = self.tracklets.swap_remove(i);
301 }
302 }
303 }
304
305 fn create_new_tracklets(
307 &mut self,
308 boxes: &[T],
309 high_conf_indices: &[usize],
310 matched: &[bool],
311 matched_info: &mut [Option<TrackInfo>],
312 timestamp: u64,
313 ) {
314 for &i in high_conf_indices {
315 if matched[i] {
316 continue;
317 }
318
319 let id = Uuid::new_v4();
320 let new_tracklet = Tracklet {
321 id,
322 filter: ConstantVelocityXYAHModel2::new(
323 &xyxy_to_xyah(&boxes[i].bbox()),
324 self.track_update,
325 ),
326 last_updated: timestamp,
327 count: 1,
328 created: timestamp,
329 last_box: boxes[i].clone(),
330 };
331 matched_info[i] = Some(TrackInfo {
332 uuid: new_tracklet.id,
333 count: new_tracklet.count,
334 created: new_tracklet.created,
335 tracked_location: new_tracklet.get_predicted_location(),
336 last_updated: timestamp,
337 });
338 self.tracklets.push(new_tracklet);
339 }
340 }
341}
342
343impl<T> Tracker<T> for ByteTrack<T>
344where
345 T: DetectionBox,
346{
347 fn update(&mut self, boxes: &[T], timestamp: u64) -> Vec<Option<TrackInfo>> {
348 self.frame_count += 1;
349
350 let high_conf_ind: Vec<usize> = boxes
352 .iter()
353 .enumerate()
354 .filter(|(_, b)| b.score() >= self.track_high_conf)
355 .map(|(x, _)| x)
356 .collect();
357
358 let mut matched = vec![false; boxes.len()];
359 let mut tracked = vec![false; self.tracklets.len()];
360 let mut matched_info = vec![None; boxes.len()];
361
362 if !self.tracklets.is_empty() {
364 for track in &mut self.tracklets {
365 track.filter.predict();
366 }
367
368 let costs = self.compute_costs(
369 boxes,
370 self.track_high_conf,
371 self.track_iou,
372 &matched,
373 &tracked,
374 );
375 if let Ok(ans) = lapjv(&costs) {
376 self.process_assignments(
377 &ans.0,
378 boxes,
379 &costs,
380 &mut matched,
381 &mut tracked,
382 &mut matched_info,
383 timestamp,
384 false,
385 );
386 }
387 }
388
389 if !self.tracklets.is_empty() {
391 let costs = self.compute_costs(boxes, 0.0, self.track_iou, &matched, &tracked);
392 if let Ok(ans) = lapjv(&costs) {
393 self.process_assignments(
394 &ans.0,
395 boxes,
396 &costs,
397 &mut matched,
398 &mut tracked,
399 &mut matched_info,
400 timestamp,
401 true,
402 );
403 }
404 }
405
406 self.remove_expired_tracklets(timestamp);
408
409 self.create_new_tracklets(
411 boxes,
412 &high_conf_ind,
413 &matched,
414 &mut matched_info,
415 timestamp,
416 );
417
418 matched_info
419 }
420
421 fn get_active_tracks(&self) -> Vec<ActiveTrackInfo<T>> {
422 self.tracklets
423 .iter()
424 .map(|t| ActiveTrackInfo {
425 info: TrackInfo {
426 uuid: t.id,
427 tracked_location: t.get_predicted_location(),
428 count: t.count,
429 created: t.created,
430 last_updated: t.last_updated,
431 },
432 last_box: t.last_box.clone(),
433 })
434 .collect()
435 }
436}
437
// Unit tests for the box conversions, IoU, and end-to-end ByteTrack behaviour.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::*;

    #[test]
    fn test_vaalbox_xyah_roundtrip() {
        let box1 = [0.0134, 0.02135, 0.12438, 0.691];
        let xyah = xyxy_to_xyah(&box1);
        let box2 = xyah_to_xyxy(&xyah);

        assert!((box1[0] - box2[0]).abs() < f32::EPSILON);
        assert!((box1[1] - box2[1]).abs() < f32::EPSILON);
        assert!((box1[2] - box2[2]).abs() < f32::EPSILON);
        assert!((box1[3] - box2[3]).abs() < f32::EPSILON);
    }

    #[test]
    fn test_iou_identical_boxes() {
        let box1 = [0.1, 0.1, 0.5, 0.5];
        let box2 = [0.1, 0.1, 0.5, 0.5];
        let result = iou(&box1, &box2);
        assert!(
            (result - 1.0).abs() < 0.001,
            "IOU of identical boxes should be 1.0"
        );
    }

    #[test]
    fn test_iou_no_overlap() {
        let box1 = [0.0, 0.0, 0.2, 0.2];
        let box2 = [0.5, 0.5, 0.7, 0.7];
        let result = iou(&box1, &box2);
        assert!(result < 0.001, "IOU of non-overlapping boxes should be ~0");
    }

    #[test]
    fn test_iou_partial_overlap() {
        let box1 = [0.0, 0.0, 0.5, 0.5];
        let box2 = [0.25, 0.25, 0.75, 0.75];
        let result = iou(&box1, &box2);
        assert!(result > 0.1 && result < 0.2, "IOU should be ~0.14");
    }

    #[test]
    fn test_bytetrack_new() {
        let tracker: ByteTrack<MockDetection> = ByteTrackBuilder::new().build();
        assert_eq!(tracker.frame_count, 0);
        assert!(tracker.tracklets.is_empty());
        assert_eq!(tracker.track_high_conf, 0.7);
        assert_eq!(tracker.track_iou, 0.25);
    }

    #[test]
    fn test_bytetrack_single_detection_creates_tracklet() {
        let mut tracker = ByteTrackBuilder::new().build();
        let detections = vec![MockDetection::new([0.1, 0.1, 0.3, 0.3], 0.9, 0)];

        let results = tracker.update(&detections, 1000);

        assert_eq!(results.len(), 1);
        assert!(
            results[0].is_some(),
            "High-confidence detection should create tracklet"
        );
        assert_eq!(tracker.tracklets.len(), 1);
        assert_eq!(tracker.frame_count, 1);
    }

    #[test]
    fn test_bytetrack_low_confidence_no_tracklet() {
        let mut tracker = ByteTrackBuilder::new().build();
        let detections = vec![MockDetection::new([0.1, 0.1, 0.3, 0.3], 0.5, 0)];

        let results = tracker.update(&detections, 1000);

        assert_eq!(results.len(), 1);
        assert!(
            results[0].is_none(),
            "Low-confidence detection should not create tracklet"
        );
        assert!(tracker.tracklets.is_empty());
    }

    #[test]
    fn test_bytetrack_tracking_across_frames() {
        let mut tracker = ByteTrackBuilder::new().build();

        let det1 = vec![MockDetection::new([0.2, 0.2, 0.4, 0.4], 0.9, 0)];
        let res1 = tracker.update(&det1, 1000);
        assert!(res1[0].is_some());
        let uuid1 = res1[0].unwrap().uuid;
        assert_eq!(tracker.tracklets.len(), 1);
        assert_eq!(tracker.tracklets[0].count, 1);

        let det2 = vec![MockDetection::new([0.2, 0.2, 0.4, 0.4], 0.9, 0)];
        let res2 = tracker.update(&det2, 2000);
        assert!(res2[0].is_some());
        let info2 = res2[0].unwrap();

        assert_eq!(tracker.tracklets.len(), 1, "Should still have one tracklet");
        assert_eq!(info2.uuid, uuid1, "Should match same tracklet");
        assert_eq!(tracker.tracklets[0].count, 2, "Internal count should be 2");
    }

    #[test]
    fn test_bytetrack_multiple_detections() {
        let mut tracker = ByteTrackBuilder::new().build();

        let detections = vec![
            MockDetection::new([0.1, 0.1, 0.2, 0.2], 0.9, 0),
            MockDetection::new([0.5, 0.5, 0.6, 0.6], 0.85, 0),
            MockDetection::new([0.8, 0.8, 0.9, 0.9], 0.95, 0),
        ];

        let results = tracker.update(&detections, 1000);

        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.is_some()));
        assert_eq!(tracker.tracklets.len(), 3);
    }

    #[test]
    fn test_bytetrack_tracklet_expiry() {
        let mut tracker = ByteTrackBuilder::new().build();
        tracker.track_extra_lifespan = 1000;
        let det1 = vec![MockDetection::new([0.1, 0.1, 0.3, 0.3], 0.9, 0)];
        tracker.update(&det1, 1000);
        assert_eq!(tracker.tracklets.len(), 1);

        let empty: Vec<MockDetection> = vec![];
        tracker.update(&empty, 3000);
        assert!(tracker.tracklets.is_empty(), "Tracklet should have expired");
    }

    #[test]
    fn test_bytetrack_get_active_tracks() {
        let mut tracker = ByteTrackBuilder::new().build();

        let detections = vec![
            MockDetection::new([0.1, 0.1, 0.2, 0.2], 0.9, 0),
            MockDetection::new([0.5, 0.5, 0.6, 0.6], 0.85, 0),
        ];
        tracker.update(&detections, 1000);

        let active = tracker.get_active_tracks();
        assert_eq!(active.len(), 2);
        assert!(active.iter().all(|t| t.info.count == 1));
        assert!(active.iter().all(|t| t.info.created == 1000));
    }

    #[test]
    fn test_bytetrack_empty_detections() {
        let mut tracker = ByteTrackBuilder::new().build();
        let empty: Vec<MockDetection> = vec![];

        let results = tracker.update(&empty, 1000);

        assert!(results.is_empty());
        assert!(tracker.tracklets.is_empty());
        assert_eq!(tracker.frame_count, 1);
    }

    #[test]
    fn test_two_stage_matching() {
        let mut tracker = ByteTrackBuilder::new().build();

        let det1 = vec![MockDetection::new([0.2, 0.2, 0.4, 0.4], 0.9, 0)];
        let res1 = tracker.update(&det1, 1_000_000);
        assert!(res1[0].is_some());
        let uuid1 = res1[0].unwrap().uuid;
        assert_eq!(tracker.tracklets.len(), 1);

        let det2 = vec![MockDetection::new([0.2, 0.2, 0.4, 0.4], 0.3, 0)];
        let res2 = tracker.update(&det2, 2_000_000);
        assert!(
            res2[0].is_some(),
            "Low-conf detection should match existing tracklet via second stage"
        );
        assert_eq!(
            res2[0].unwrap().uuid,
            uuid1,
            "Should match the same tracklet"
        );
        assert_eq!(
            tracker.tracklets.len(),
            1,
            "No new tracklet should be created"
        );
        assert_eq!(
            tracker.tracklets[0].count, 2,
            "Tracklet count should increment"
        );
    }

    #[test]
    fn test_builder_track_extra_lifespan() {
        let lifespan_default = 500_000_000;
        let lifespan_extended = 2_000_000_000;
        let mut tracker_default: ByteTrack<MockDetection> = ByteTrackBuilder::new().build();
        let mut tracker_extended: ByteTrack<MockDetection> = ByteTrackBuilder::new()
            .track_extra_lifespan(lifespan_extended)
            .build();

        assert_eq!(tracker_default.track_extra_lifespan, lifespan_default);
        assert_eq!(tracker_extended.track_extra_lifespan, lifespan_extended);

        let ts_start = 1_000_000_000u64;
        let det = vec![MockDetection::new([0.2, 0.2, 0.4, 0.4], 0.9, 0)];

        tracker_default.update(&det, ts_start);
        tracker_extended.update(&det, ts_start);
        assert_eq!(tracker_default.tracklets.len(), 1);
        assert_eq!(tracker_extended.tracklets.len(), 1);

        let ts_after = ts_start + 1_000_000_000;
        let empty: Vec<MockDetection> = vec![];
        tracker_default.update(&empty, ts_after);
        tracker_extended.update(&empty, ts_after);

        assert!(
            tracker_default.tracklets.is_empty(),
            "Default tracker should have expired the tracklet"
        );
        assert_eq!(
            tracker_extended.tracklets.len(),
            1,
            "Extended tracker should still have the tracklet"
        );
    }

    #[test]
    fn test_builder_track_high_conf() {
        let mut tracker: ByteTrack<MockDetection> =
            ByteTrackBuilder::new().track_high_conf(0.9).build();
        assert_eq!(tracker.track_high_conf, 0.9);

        let det_low = vec![MockDetection::new([0.1, 0.1, 0.3, 0.3], 0.8, 0)];
        let res = tracker.update(&det_low, 1000);
        assert!(
            res[0].is_none(),
            "Score 0.8 should not create a tracklet with threshold 0.9"
        );
        assert!(tracker.tracklets.is_empty());

        let det_high = vec![MockDetection::new([0.1, 0.1, 0.3, 0.3], 0.95, 0)];
        let res = tracker.update(&det_high, 2000);
        assert!(
            res[0].is_some(),
            "Score 0.95 should create a tracklet with threshold 0.9"
        );
        assert_eq!(tracker.tracklets.len(), 1);
    }

    #[test]
    fn test_builder_track_iou() {
        let mut tracker: ByteTrack<MockDetection> = ByteTrackBuilder::new().track_iou(0.8).build();

        let det1 = vec![
            MockDetection::new([0.1, 0.1, 0.3, 0.3], 0.9, 0),
            MockDetection::new([0.5, 0.5, 0.7, 0.7], 0.9, 0),
        ];
        tracker.update(&det1, 1000);
        assert_eq!(tracker.tracklets.len(), 2);

        let det2 = vec![
            MockDetection::new([0.15, 0.15, 0.35, 0.35], 0.9, 0),
            MockDetection::new([0.5, 0.5, 0.7, 0.7], 0.9, 0),
        ];
        let res2 = tracker.update(&det2, 2000);
        assert_eq!(res2.len(), 2);

        assert!(
            tracker.tracklets.len() >= 3,
            "Shifted detection should create a new tracklet with tight IOU threshold, got {} tracklets",
            tracker.tracklets.len()
        );
    }

    #[test]
    fn test_degenerate_zero_area_box() {
        let mut tracker = ByteTrackBuilder::new().build();
        let det = vec![
            MockDetection::new([0.5, 0.1, 0.5, 0.3], 0.9, 0),
            MockDetection::new([0.1, 0.1, 0.3, 0.3], 0.9, 0),
        ];
        let results = tracker.update(&det, 1000);
        assert_eq!(results.len(), 2);

        let zero_box = [0.5, 0.1, 0.5, 0.3];
        let normal_box = [0.1, 0.1, 0.3, 0.3];
        let iou_val = iou(&zero_box, &normal_box);
        assert!(
            iou_val < EPSILON,
            "IOU with a zero-area box should be ~0, got {iou_val}"
        );
    }

    #[test]
    fn test_degenerate_high_velocity() {
        let mut tracker = ByteTrackBuilder::new().build();

        let det1 = vec![MockDetection::new([0.1, 0.1, 0.2, 0.2], 0.9, 0)];
        let res1 = tracker.update(&det1, 1_000_000);
        assert!(res1[0].is_some());
        let uuid1 = res1[0].unwrap().uuid;
        assert_eq!(tracker.tracklets.len(), 1);

        let det2 = vec![MockDetection::new([0.8, 0.8, 0.9, 0.9], 0.9, 0)];
        let res2 = tracker.update(&det2, 2_000_000);
        assert!(res2[0].is_some());

        assert_eq!(
            tracker.tracklets.len(),
            2,
            "Far-displaced detection should create a new tracklet"
        );
        assert_ne!(
            res2[0].unwrap().uuid,
            uuid1,
            "New detection should have a different UUID"
        );
    }

    #[test]
    fn test_many_detections_100() {
        let mut tracker = ByteTrackBuilder::new().build();

        let detections: Vec<MockDetection> = (0..100)
            .map(|i| {
                let x = (i % 10) as f32 * 0.1;
                let y = (i / 10) as f32 * 0.1;
                MockDetection::new([x, y, x + 0.05, y + 0.05], 0.9, 0)
            })
            .collect();

        let results = tracker.update(&detections, 1000);
        assert_eq!(results.len(), 100);
        assert!(
            results.iter().all(|r| r.is_some()),
            "All 100 high-confidence detections should create tracklets"
        );
        assert_eq!(
            tracker.tracklets.len(),
            100,
            "Should have 100 active tracklets"
        );
    }

    #[test]
    fn test_tracklet_count_increments_each_frame() {
        let mut tracker = ByteTrackBuilder::new().build();
        let det = vec![MockDetection::new([0.2, 0.2, 0.4, 0.4], 0.9, 0)];

        for frame in 1..=5 {
            tracker.update(&det, frame * 1000);
        }

        assert_eq!(tracker.tracklets.len(), 1);
        assert_eq!(
            tracker.tracklets[0].count, 5,
            "Tracklet count should equal number of frames it was matched"
        );
    }

    #[test]
    fn test_tracklet_created_timestamp_preserved() {
        let mut tracker = ByteTrackBuilder::new().build();
        let det = vec![MockDetection::new([0.2, 0.2, 0.4, 0.4], 0.9, 0)];

        tracker.update(&det, 1000);
        tracker.update(&det, 2000);
        tracker.update(&det, 3000);

        let active = tracker.get_active_tracks();
        assert_eq!(active.len(), 1);
        assert_eq!(
            active[0].info.created, 1000,
            "Created timestamp should remain at the first frame"
        );
        assert_eq!(
            active[0].info.last_updated, 3000,
            "Last updated should be the most recent frame"
        );
    }

    #[test]
    fn test_mixed_confidence_detections() {
        let mut tracker = ByteTrackBuilder::new().build();
        let det = vec![
            MockDetection::new([0.1, 0.1, 0.2, 0.2], 0.9, 0),
            MockDetection::new([0.3, 0.3, 0.4, 0.4], 0.3, 0),
            MockDetection::new([0.5, 0.5, 0.6, 0.6], 0.85, 0),
            MockDetection::new([0.7, 0.7, 0.8, 0.8], 0.1, 0),
        ];

        let results = tracker.update(&det, 1000);
        assert_eq!(results.len(), 4);

        assert!(
            results[0].is_some(),
            "High-conf detection should create tracklet"
        );
        assert!(
            results[1].is_none(),
            "Low-conf detection should not create tracklet"
        );
        assert!(
            results[2].is_some(),
            "High-conf detection should create tracklet"
        );
        assert!(
            results[3].is_none(),
            "Low-conf detection should not create tracklet"
        );
        assert_eq!(tracker.tracklets.len(), 2);
    }

    #[test]
    fn test_iou_contained_box() {
        let outer = [0.0, 0.0, 1.0, 1.0];
        let inner = [0.25, 0.25, 0.75, 0.75];
        let result = iou(&outer, &inner);
        assert!(
            (result - 0.25).abs() < 0.01,
            "IOU of contained box should be inner_area/outer_area = 0.25, got {result}"
        );
    }

    #[test]
    fn test_xyxy_to_xyah_square_box() {
        let square = [0.1, 0.2, 0.3, 0.4];
        let xyah = xyxy_to_xyah(&square);
        assert!((xyah[0] - 0.2).abs() < 1e-5, "Center x should be 0.2");
        assert!((xyah[1] - 0.3).abs() < 1e-5, "Center y should be 0.3");
        assert!(
            (xyah[2] - 1.0).abs() < 1e-5,
            "Aspect ratio of square should be 1.0"
        );
        assert!((xyah[3] - 0.2).abs() < 1e-5, "Height should be 0.2");
    }

    #[test]
    fn test_frame_count_increments() {
        let mut tracker = ByteTrackBuilder::new().build();
        let empty: Vec<MockDetection> = vec![];

        for _ in 0..10 {
            tracker.update(&empty, 0);
        }

        assert_eq!(
            tracker.frame_count, 10,
            "Frame count should increment each update"
        );
    }

    #[test]
    fn test_tracklet_predicted_location_near_detection() {
        let mut tracker = ByteTrackBuilder::new().build();
        let det = vec![MockDetection::new([0.2, 0.2, 0.4, 0.4], 0.9, 0)];
        tracker.update(&det, 1000);

        let pred = tracker.tracklets[0].get_predicted_location();
        assert!(
            (pred[0] - 0.2).abs() < 0.1,
            "Predicted xmin should be near 0.2, got {}",
            pred[0]
        );
        assert!(
            (pred[1] - 0.2).abs() < 0.1,
            "Predicted ymin should be near 0.2, got {}",
            pred[1]
        );
        assert!(
            (pred[2] - 0.4).abs() < 0.1,
            "Predicted xmax should be near 0.4, got {}",
            pred[2]
        );
        assert!(
            (pred[3] - 0.4).abs() < 0.1,
            "Predicted ymax should be near 0.4, got {}",
            pred[3]
        );
    }

    #[test]
    fn test_xyah_to_xyxy_uses_first_four_components() {
        // Extra trailing components (e.g. filter velocities) must be ignored.
        let xyxy = xyah_to_xyxy(&[4.0, 6.0, 2.0, 2.0, 9.0, 9.0, 9.0, 9.0]);
        assert_eq!(xyxy, [2.0, 5.0, 6.0, 7.0]);
    }

    #[test]
    #[should_panic]
    fn test_xyah_to_xyxy_rejects_short_slice() {
        let _ = xyah_to_xyxy(&[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_iou_zero_union_is_zero() {
        let point = [0.5, 0.5, 0.5, 0.5];
        assert_eq!(iou(&point, &point), 0.0);
    }

    #[test]
    fn test_tracklet_kept_until_expiry_passes() {
        let mut tracker: ByteTrack<MockDetection> =
            ByteTrackBuilder::new().track_extra_lifespan(1000).build();
        let det = vec![MockDetection::new([0.1, 0.1, 0.3, 0.3], 0.9, 0)];
        tracker.update(&det, 1000);

        let empty: Vec<MockDetection> = vec![];
        // last_updated + lifespan == timestamp: not yet expired.
        tracker.update(&empty, 2000);
        assert_eq!(tracker.tracklets.len(), 1);

        // One tick later the expiry is strictly in the past.
        tracker.update(&empty, 2001);
        assert!(tracker.tracklets.is_empty());
    }
}
963}