1use crate::{
5 kalman::ConstantVelocityXYAHModel2, ActiveTrackInfo, DetectionBox, TrackInfo, Tracker,
6};
7use lapjv::{lapjv, Matrix};
8use log::trace;
9use nalgebra::{Dyn, OMatrix, U4};
10use uuid::Uuid;
11
/// Configuration for a [`ByteTrack`] tracker, assembled with chained setters.
///
/// Start from [`ByteTrackBuilder::new`] (or `Default`), adjust values with the
/// setter methods, then call [`ByteTrackBuilder::build`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ByteTrackBuilder {
    /// How long a tracklet may go without a matching detection before it is
    /// dropped, in the same units as the timestamps passed to `update`.
    track_extra_lifespan: u64,
    /// Minimum detection score for first-stage matching and for spawning new
    /// tracklets.
    track_high_conf: f32,
    /// Minimum IoU between a tracklet's predicted box and a detection for the
    /// pair to be matched.
    track_iou: f32,
    /// Second argument handed to `ConstantVelocityXYAHModel2::new` when a
    /// tracklet is created.
    // NOTE(review): presumably a filter noise/update weight — confirm against the kalman module.
    track_update: f32,
}
19
20impl Default for ByteTrackBuilder {
21 fn default() -> Self {
22 Self::new()
23 }
24}
25
26impl ByteTrackBuilder {
27 pub fn new() -> Self {
45 Self {
46 track_high_conf: 0.7,
47 track_iou: 0.25,
48 track_update: 0.25,
49 track_extra_lifespan: 500_000_000,
50 }
51 }
52
53 pub fn track_extra_lifespan(mut self, lifespan: u64) -> Self {
55 self.track_extra_lifespan = lifespan;
56 self
57 }
58
59 pub fn track_high_conf(mut self, conf: f32) -> Self {
61 self.track_high_conf = conf;
62 self
63 }
64
65 pub fn track_iou(mut self, iou: f32) -> Self {
67 self.track_iou = iou;
68 self
69 }
70
71 pub fn track_update(mut self, update: f32) -> Self {
73 self.track_update = update;
74 self
75 }
76
77 pub fn build<T: DetectionBox>(self) -> ByteTrack<T> {
95 ByteTrack {
96 track_extra_lifespan: self.track_extra_lifespan,
97 track_high_conf: self.track_high_conf,
98 track_iou: self.track_iou,
99 track_update: self.track_update,
100 tracklets: Vec::new(),
101 frame_count: 0,
102 }
103 }
104}
105
/// ByteTrack multi-object tracker.
///
/// Each `update` associates detections with Kalman-filtered tracklets in two
/// stages (high-confidence detections first, then the remainder), expires stale
/// tracklets and spawns new ones. Construct it with [`ByteTrackBuilder`].
// NOTE(review): the reason for `allow(dead_code)` is not visible here; every field
// is `pub`, so it may be a leftover — confirm before removing.
#[allow(dead_code)]
#[derive(Default, Debug, Clone)]
pub struct ByteTrack<T: DetectionBox> {
    /// Unmatched lifetime after which a tracklet is dropped (timestamp units).
    pub track_extra_lifespan: u64,
    /// Score threshold for first-stage matching and for spawning tracklets.
    pub track_high_conf: f32,
    /// Minimum IoU for a detection/tracklet match.
    pub track_iou: f32,
    /// Value given to the Kalman filter constructor for new tracklets.
    pub track_update: f32,
    /// Currently live tracklets.
    pub tracklets: Vec<Tracklet<T>>,
    /// Number of `update` calls processed so far.
    pub frame_count: i32,
}
116
/// A single tracked object: a Kalman filter over its box plus bookkeeping.
#[derive(Debug, Clone)]
pub struct Tracklet<T: DetectionBox> {
    /// Random (v4) identifier assigned when the tracklet is created.
    pub id: Uuid,
    /// Kalman filter fed with boxes in `[center_x, center_y, aspect, height]` form.
    pub filter: ConstantVelocityXYAHModel2<f32>,
    /// Number of detections associated with this tracklet (1 at creation).
    pub count: i32,
    /// Timestamp of the frame that created the tracklet.
    pub created: u64,
    /// Timestamp of the most recent frame in which a detection was associated.
    pub last_updated: u64,
    /// The most recently associated detection.
    pub last_box: T,
}
126
127impl<T: DetectionBox> Tracklet<T> {
128 fn update(&mut self, detect_box: &T, ts: u64) {
129 self.count += 1;
130 self.last_updated = ts;
131 self.filter.update(&xyxy_to_xyah(&detect_box.bbox()));
132 self.last_box = detect_box.clone();
133 }
134
135 pub fn get_predicted_location(&self) -> [f32; 4] {
136 let projected = self.filter.project().0;
137 let predicted_xyah = projected.as_slice();
138 xyah_to_xyxy(predicted_xyah)
139 }
140}
141
142fn xyxy_to_xyah(vaal_box: &[f32; 4]) -> [f32; 4] {
143 let x = (vaal_box[2] + vaal_box[0]) / 2.0;
144 let y = (vaal_box[3] + vaal_box[1]) / 2.0;
145 let w = (vaal_box[2] - vaal_box[0]).max(EPSILON);
146 let h = (vaal_box[3] - vaal_box[1]).max(EPSILON);
147 let a = w / h;
148
149 [x, y, a, h]
150}
151
/// Converts `[center_x, center_y, aspect, height, ..]` back into a corner box
/// `[xmin, ymin, xmax, ymax]`.
///
/// Only the first four elements are read, so a longer slice (such as a filter
/// state vector) may be passed directly.
///
/// # Panics
/// Panics if `xyah` has fewer than four elements.
fn xyah_to_xyxy(xyah: &[f32]) -> [f32; 4] {
    assert!(xyah.len() >= 4);
    let (center_x, center_y, aspect, height) = (xyah[0], xyah[1], xyah[2], xyah[3]);
    let half_width = height * aspect / 2.0;
    let half_height = height / 2.0;
    [
        center_x - half_width,
        center_y - half_height,
        center_x + half_width,
        center_y + half_height,
    ]
}
160
/// Cost given to detection/tracklet pairs that must not be matched; any solved
/// assignment whose cost reaches this value is discarded.
const INVALID_MATCH: f32 = 1000000.0;
/// Minimum width/height used when converting boxes, and the union area at or
/// below which IoU is reported as 0.
const EPSILON: f32 = 0.00001;
163
164fn iou(box1: &[f32], box2: &[f32]) -> f32 {
165 let intersection = (box1[2].min(box2[2]) - box1[0].max(box2[0])).max(0.0)
166 * (box1[3].min(box2[3]) - box1[1].max(box2[1])).max(0.0);
167
168 let union = (box1[2] - box1[0]) * (box1[3] - box1[1])
169 + (box2[2] - box2[0]) * (box2[3] - box2[1])
170 - intersection;
171
172 if union <= EPSILON {
173 return 0.0;
174 }
175
176 intersection / union
177}
178
179fn box_cost<T: DetectionBox>(
180 track: &Tracklet<T>,
181 new_box: &T,
182 distance: f32,
183 score_threshold: f32,
184 iou_threshold: f32,
185) -> f32 {
186 let _ = distance;
187
188 if new_box.score() < score_threshold {
189 return INVALID_MATCH;
190 }
191
192 let predicted_xyah = track.filter.mean.as_slice();
194 let expected = xyah_to_xyxy(predicted_xyah);
195 let iou = iou(&expected, &new_box.bbox());
196 if iou < iou_threshold {
197 return INVALID_MATCH;
198 }
199 (1.5 - new_box.score()) + (1.5 - iou)
200}
201
202impl<T: DetectionBox> ByteTrack<T> {
203 fn compute_costs(
204 &mut self,
205 boxes: &[T],
206 score_threshold: f32,
207 iou_threshold: f32,
208 box_filter: &[bool],
209 track_filter: &[bool],
210 ) -> Matrix<f32> {
211 let dims = boxes.len().max(self.tracklets.len());
213 let mut measurements = OMatrix::<f32, Dyn, U4>::from_element(boxes.len(), 0.0);
214 for (i, mut row) in measurements.row_iter_mut().enumerate() {
215 row.copy_from_slice(&xyxy_to_xyah(&boxes[i].bbox()));
216 }
217
218 Matrix::from_shape_fn((dims, dims), |(x, y)| {
222 if x < boxes.len() && y < self.tracklets.len() {
223 if box_filter[x] || track_filter[y] {
224 INVALID_MATCH
225 } else {
226 box_cost(
227 &self.tracklets[y],
228 &boxes[x],
229 0.0,
231 score_threshold,
232 iou_threshold,
233 )
234 }
235 } else {
236 0.0
237 }
238 })
239 }
240
241 #[allow(clippy::too_many_arguments)]
244 fn process_assignments(
245 &mut self,
246 assignments: &[usize],
247 boxes: &[T],
248 costs: &Matrix<f32>,
249 matched: &mut [bool],
250 tracked: &mut [bool],
251 matched_info: &mut [Option<TrackInfo>],
252 timestamp: u64,
253 log_assignments: bool,
254 ) {
255 for (i, &x) in assignments.iter().enumerate() {
256 if i >= boxes.len() || x >= self.tracklets.len() {
257 continue;
258 }
259
260 if costs[(i, x)] >= INVALID_MATCH {
262 continue;
263 }
264
265 if matched[i] || tracked[x] {
267 continue;
268 }
269
270 if log_assignments {
271 trace!(
272 "Cost: {} Box: {:#?} UUID: {} Mean: {}",
273 costs[(i, x)],
274 boxes[i],
275 self.tracklets[x].id,
276 self.tracklets[x].filter.mean
277 );
278 }
279
280 matched[i] = true;
281 matched_info[i] = Some(TrackInfo {
282 uuid: self.tracklets[x].id,
283 count: self.tracklets[x].count,
284 created: self.tracklets[x].created,
285 tracked_location: self.tracklets[x].get_predicted_location(),
286 last_updated: timestamp,
287 });
288 tracked[x] = true;
289 self.tracklets[x].update(&boxes[i], timestamp);
290 }
291 }
292
293 fn remove_expired_tracklets(&mut self, timestamp: u64) {
295 for i in (0..self.tracklets.len()).rev() {
297 let expiry = self.tracklets[i].last_updated + self.track_extra_lifespan;
298 if expiry < timestamp {
299 trace!("Tracklet removed: {:?}", self.tracklets[i].id);
300 let _ = self.tracklets.swap_remove(i);
301 }
302 }
303 }
304
305 fn create_new_tracklets(
307 &mut self,
308 boxes: &[T],
309 high_conf_indices: &[usize],
310 matched: &[bool],
311 matched_info: &mut [Option<TrackInfo>],
312 timestamp: u64,
313 ) {
314 for &i in high_conf_indices {
315 if matched[i] {
316 continue;
317 }
318
319 let id = Uuid::new_v4();
320 let new_tracklet = Tracklet {
321 id,
322 filter: ConstantVelocityXYAHModel2::new(
323 &xyxy_to_xyah(&boxes[i].bbox()),
324 self.track_update,
325 ),
326 last_updated: timestamp,
327 count: 1,
328 created: timestamp,
329 last_box: boxes[i].clone(),
330 };
331 matched_info[i] = Some(TrackInfo {
332 uuid: new_tracklet.id,
333 count: new_tracklet.count,
334 created: new_tracklet.created,
335 tracked_location: new_tracklet.get_predicted_location(),
336 last_updated: timestamp,
337 });
338 self.tracklets.push(new_tracklet);
339 }
340 }
341}
342
343impl<T> Tracker<T> for ByteTrack<T>
344where
345 T: DetectionBox,
346{
347 fn update(&mut self, boxes: &[T], timestamp: u64) -> Vec<Option<TrackInfo>> {
348 let span = tracing::trace_span!(
349 "tracker_update",
350 n_detections = boxes.len(),
351 n_tracklets = self.tracklets.len(),
352 timestamp,
353 );
354 let _enter = span.enter();
355
356 self.frame_count += 1;
357
358 let high_conf_ind: Vec<usize> = boxes
360 .iter()
361 .enumerate()
362 .filter(|(_, b)| b.score() >= self.track_high_conf)
363 .map(|(x, _)| x)
364 .collect();
365
366 let mut matched = vec![false; boxes.len()];
367 let mut tracked = vec![false; self.tracklets.len()];
368 let mut matched_info = vec![None; boxes.len()];
369
370 if !self.tracklets.is_empty() {
372 let _s = tracing::trace_span!("predict").entered();
373 for track in &mut self.tracklets {
374 track.filter.predict();
375 }
376 }
377
378 if !self.tracklets.is_empty() {
379 let _s = tracing::trace_span!("match_high_conf").entered();
380 let costs = self.compute_costs(
381 boxes,
382 self.track_high_conf,
383 self.track_iou,
384 &matched,
385 &tracked,
386 );
387 if let Ok(ans) = lapjv(&costs) {
388 self.process_assignments(
389 &ans.0,
390 boxes,
391 &costs,
392 &mut matched,
393 &mut tracked,
394 &mut matched_info,
395 timestamp,
396 false,
397 );
398 }
399 }
400
401 if !self.tracklets.is_empty() {
403 let _s = tracing::trace_span!("match_low_conf").entered();
404 let costs = self.compute_costs(boxes, 0.0, self.track_iou, &matched, &tracked);
405 if let Ok(ans) = lapjv(&costs) {
406 self.process_assignments(
407 &ans.0,
408 boxes,
409 &costs,
410 &mut matched,
411 &mut tracked,
412 &mut matched_info,
413 timestamp,
414 true,
415 );
416 }
417 }
418
419 self.remove_expired_tracklets(timestamp);
421
422 self.create_new_tracklets(
424 boxes,
425 &high_conf_ind,
426 &matched,
427 &mut matched_info,
428 timestamp,
429 );
430
431 matched_info
432 }
433
434 fn get_active_tracks(&self) -> Vec<ActiveTrackInfo<T>> {
435 self.tracklets
436 .iter()
437 .map(|t| ActiveTrackInfo {
438 info: TrackInfo {
439 uuid: t.id,
440 tracked_location: t.get_predicted_location(),
441 count: t.count,
442 created: t.created,
443 last_updated: t.last_updated,
444 },
445 last_box: t.last_box.clone(),
446 })
447 .collect()
448 }
449}
450
#[cfg(test)]
mod tests {
    use super::*;
    use crate::*;

    #[test]
    fn test_vaalbox_xyah_roundtrip() {
        let box1 = [0.0134, 0.02135, 0.12438, 0.691];
        let xyah = xyxy_to_xyah(&box1);
        let box2 = xyah_to_xyxy(&xyah);

        // The divide/multiply round trip can be off by a couple of ULPs, so use a
        // small absolute tolerance instead of `f32::EPSILON`.
        for (a, b) in box1.iter().zip(box2.iter()) {
            assert!((a - b).abs() < 1e-6, "round trip drifted: {a} vs {b}");
        }
    }

    #[test]
    #[should_panic]
    fn test_xyah_to_xyxy_rejects_short_slice() {
        let _ = xyah_to_xyxy(&[0.5, 0.5, 1.0]);
    }

    #[test]
    fn test_iou_identical_boxes() {
        let box1 = [0.1, 0.1, 0.5, 0.5];
        let box2 = [0.1, 0.1, 0.5, 0.5];
        let result = iou(&box1, &box2);
        assert!(
            (result - 1.0).abs() < 0.001,
            "IOU of identical boxes should be 1.0"
        );
    }

    #[test]
    fn test_iou_no_overlap() {
        let box1 = [0.0, 0.0, 0.2, 0.2];
        let box2 = [0.5, 0.5, 0.7, 0.7];
        let result = iou(&box1, &box2);
        assert!(result < 0.001, "IOU of non-overlapping boxes should be ~0");
    }

    #[test]
    fn test_iou_partial_overlap() {
        let box1 = [0.0, 0.0, 0.5, 0.5];
        let box2 = [0.25, 0.25, 0.75, 0.75];
        let result = iou(&box1, &box2);
        assert!(result > 0.1 && result < 0.2, "IOU should be ~0.14");
    }

    #[test]
    fn test_iou_symmetric() {
        let a = [0.0, 0.0, 0.5, 0.5];
        let b = [0.25, 0.1, 0.9, 0.6];
        assert_eq!(iou(&a, &b), iou(&b, &a), "IOU must not depend on argument order");
    }

    #[test]
    fn test_bytetrack_new() {
        let tracker: ByteTrack<MockDetection> = ByteTrackBuilder::new().build();
        assert_eq!(tracker.frame_count, 0);
        assert!(tracker.tracklets.is_empty());
        assert_eq!(tracker.track_high_conf, 0.7);
        assert_eq!(tracker.track_iou, 0.25);
        assert_eq!(tracker.track_update, 0.25);
        assert_eq!(tracker.track_extra_lifespan, 500_000_000);
    }

    #[test]
    fn test_builder_default_matches_new() {
        assert_eq!(ByteTrackBuilder::default(), ByteTrackBuilder::new());
    }

    #[test]
    fn test_builder_track_update() {
        let tracker: ByteTrack<MockDetection> =
            ByteTrackBuilder::new().track_update(0.5).build();
        assert_eq!(tracker.track_update, 0.5);
        assert_eq!(tracker.track_high_conf, 0.7, "Other settings must be untouched");
        assert_eq!(tracker.track_iou, 0.25, "Other settings must be untouched");
    }

    #[test]
    fn test_bytetrack_single_detection_creates_tracklet() {
        let mut tracker = ByteTrackBuilder::new().build();
        let detections = vec![MockDetection::new([0.1, 0.1, 0.3, 0.3], 0.9, 0)];

        let results = tracker.update(&detections, 1000);

        assert_eq!(results.len(), 1);
        assert!(
            results[0].is_some(),
            "High-confidence detection should create tracklet"
        );
        assert_eq!(tracker.tracklets.len(), 1);
        assert_eq!(tracker.frame_count, 1);
    }

    #[test]
    fn test_bytetrack_low_confidence_no_tracklet() {
        let mut tracker = ByteTrackBuilder::new().build();
        let detections = vec![MockDetection::new([0.1, 0.1, 0.3, 0.3], 0.5, 0)];

        let results = tracker.update(&detections, 1000);

        assert_eq!(results.len(), 1);
        assert!(
            results[0].is_none(),
            "Low-confidence detection should not create tracklet"
        );
        assert!(tracker.tracklets.is_empty());
    }

    #[test]
    fn test_bytetrack_tracking_across_frames() {
        let mut tracker = ByteTrackBuilder::new().build();

        let det1 = vec![MockDetection::new([0.2, 0.2, 0.4, 0.4], 0.9, 0)];
        let res1 = tracker.update(&det1, 1000);
        assert!(res1[0].is_some());
        let uuid1 = res1[0].unwrap().uuid;
        assert_eq!(tracker.tracklets.len(), 1);
        assert_eq!(tracker.tracklets[0].count, 1);

        let det2 = vec![MockDetection::new([0.2, 0.2, 0.4, 0.4], 0.9, 0)];
        let res2 = tracker.update(&det2, 2000);
        assert!(res2[0].is_some());
        let info2 = res2[0].unwrap();

        assert_eq!(tracker.tracklets.len(), 1, "Should still have one tracklet");
        assert_eq!(info2.uuid, uuid1, "Should match same tracklet");
        assert_eq!(tracker.tracklets[0].count, 2, "Internal count should be 2");
    }

    #[test]
    fn test_bytetrack_multiple_detections() {
        let mut tracker = ByteTrackBuilder::new().build();

        let detections = vec![
            MockDetection::new([0.1, 0.1, 0.2, 0.2], 0.9, 0),
            MockDetection::new([0.5, 0.5, 0.6, 0.6], 0.85, 0),
            MockDetection::new([0.8, 0.8, 0.9, 0.9], 0.95, 0),
        ];

        let results = tracker.update(&detections, 1000);

        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.is_some()));
        assert_eq!(tracker.tracklets.len(), 3);
    }

    #[test]
    fn test_bytetrack_tracklet_expiry() {
        let mut tracker = ByteTrackBuilder::new().build();
        tracker.track_extra_lifespan = 1000;
        let det1 = vec![MockDetection::new([0.1, 0.1, 0.3, 0.3], 0.9, 0)];
        tracker.update(&det1, 1000);
        assert_eq!(tracker.tracklets.len(), 1);

        let empty: Vec<MockDetection> = vec![];
        tracker.update(&empty, 3000);
        assert!(tracker.tracklets.is_empty(), "Tracklet should have expired");
    }

    #[test]
    fn test_bytetrack_get_active_tracks() {
        let mut tracker = ByteTrackBuilder::new().build();

        let detections = vec![
            MockDetection::new([0.1, 0.1, 0.2, 0.2], 0.9, 0),
            MockDetection::new([0.5, 0.5, 0.6, 0.6], 0.85, 0),
        ];
        tracker.update(&detections, 1000);

        let active = tracker.get_active_tracks();
        assert_eq!(active.len(), 2);
        assert!(active.iter().all(|t| t.info.count == 1));
        assert!(active.iter().all(|t| t.info.created == 1000));
    }

    #[test]
    fn test_bytetrack_empty_detections() {
        let mut tracker = ByteTrackBuilder::new().build();
        let empty: Vec<MockDetection> = vec![];

        let results = tracker.update(&empty, 1000);

        assert!(results.is_empty());
        assert!(tracker.tracklets.is_empty());
        assert_eq!(tracker.frame_count, 1);
    }

    #[test]
    fn test_two_stage_matching() {
        let mut tracker = ByteTrackBuilder::new().build();

        let det1 = vec![MockDetection::new([0.2, 0.2, 0.4, 0.4], 0.9, 0)];
        let res1 = tracker.update(&det1, 1_000_000);
        assert!(res1[0].is_some());
        let uuid1 = res1[0].unwrap().uuid;
        assert_eq!(tracker.tracklets.len(), 1);

        let det2 = vec![MockDetection::new([0.2, 0.2, 0.4, 0.4], 0.3, 0)];
        let res2 = tracker.update(&det2, 2_000_000);
        assert!(
            res2[0].is_some(),
            "Low-conf detection should match existing tracklet via second stage"
        );
        assert_eq!(
            res2[0].unwrap().uuid,
            uuid1,
            "Should match the same tracklet"
        );
        assert_eq!(
            tracker.tracklets.len(),
            1,
            "No new tracklet should be created"
        );
        assert_eq!(
            tracker.tracklets[0].count, 2,
            "Tracklet count should increment"
        );
    }

    #[test]
    fn test_builder_track_extra_lifespan() {
        let lifespan_default = 500_000_000;
        let lifespan_extended = 2_000_000_000;
        let mut tracker_default: ByteTrack<MockDetection> = ByteTrackBuilder::new().build();
        let mut tracker_extended: ByteTrack<MockDetection> = ByteTrackBuilder::new()
            .track_extra_lifespan(lifespan_extended)
            .build();

        assert_eq!(tracker_default.track_extra_lifespan, lifespan_default);
        assert_eq!(tracker_extended.track_extra_lifespan, lifespan_extended);

        let ts_start = 1_000_000_000u64;
        let det = vec![MockDetection::new([0.2, 0.2, 0.4, 0.4], 0.9, 0)];

        tracker_default.update(&det, ts_start);
        tracker_extended.update(&det, ts_start);
        assert_eq!(tracker_default.tracklets.len(), 1);
        assert_eq!(tracker_extended.tracklets.len(), 1);

        let ts_after = ts_start + 1_000_000_000;
        let empty: Vec<MockDetection> = vec![];
        tracker_default.update(&empty, ts_after);
        tracker_extended.update(&empty, ts_after);

        assert!(
            tracker_default.tracklets.is_empty(),
            "Default tracker should have expired the tracklet"
        );
        assert_eq!(
            tracker_extended.tracklets.len(),
            1,
            "Extended tracker should still have the tracklet"
        );
    }

    #[test]
    fn test_builder_track_high_conf() {
        let mut tracker: ByteTrack<MockDetection> =
            ByteTrackBuilder::new().track_high_conf(0.9).build();
        assert_eq!(tracker.track_high_conf, 0.9);

        let det_low = vec![MockDetection::new([0.1, 0.1, 0.3, 0.3], 0.8, 0)];
        let res = tracker.update(&det_low, 1000);
        assert!(
            res[0].is_none(),
            "Score 0.8 should not create a tracklet with threshold 0.9"
        );
        assert!(tracker.tracklets.is_empty());

        let det_high = vec![MockDetection::new([0.1, 0.1, 0.3, 0.3], 0.95, 0)];
        let res = tracker.update(&det_high, 2000);
        assert!(
            res[0].is_some(),
            "Score 0.95 should create a tracklet with threshold 0.9"
        );
        assert_eq!(tracker.tracklets.len(), 1);
    }

    #[test]
    fn test_builder_track_iou() {
        let mut tracker: ByteTrack<MockDetection> = ByteTrackBuilder::new().track_iou(0.8).build();

        let det1 = vec![
            MockDetection::new([0.1, 0.1, 0.3, 0.3], 0.9, 0),
            MockDetection::new([0.5, 0.5, 0.7, 0.7], 0.9, 0),
        ];
        tracker.update(&det1, 1000);
        assert_eq!(tracker.tracklets.len(), 2);

        let det2 = vec![
            MockDetection::new([0.15, 0.15, 0.35, 0.35], 0.9, 0),
            MockDetection::new([0.5, 0.5, 0.7, 0.7], 0.9, 0),
        ];
        let res2 = tracker.update(&det2, 2000);
        assert_eq!(res2.len(), 2);

        assert!(
            tracker.tracklets.len() >= 3,
            "Shifted detection should create a new tracklet with tight IOU threshold, got {} tracklets",
            tracker.tracklets.len()
        );
    }

    #[test]
    fn test_degenerate_zero_area_box() {
        let mut tracker = ByteTrackBuilder::new().build();
        let det = vec![
            MockDetection::new([0.5, 0.1, 0.5, 0.3], 0.9, 0),
            MockDetection::new([0.1, 0.1, 0.3, 0.3], 0.9, 0),
        ];
        let results = tracker.update(&det, 1000);
        assert_eq!(results.len(), 2);

        let zero_box = [0.5, 0.1, 0.5, 0.3];
        let normal_box = [0.1, 0.1, 0.3, 0.3];
        let iou_val = iou(&zero_box, &normal_box);
        assert!(
            iou_val < EPSILON,
            "IOU with a zero-area box should be ~0, got {iou_val}"
        );
    }

    #[test]
    fn test_degenerate_high_velocity() {
        let mut tracker = ByteTrackBuilder::new().build();

        let det1 = vec![MockDetection::new([0.1, 0.1, 0.2, 0.2], 0.9, 0)];
        let res1 = tracker.update(&det1, 1_000_000);
        assert!(res1[0].is_some());
        let uuid1 = res1[0].unwrap().uuid;
        assert_eq!(tracker.tracklets.len(), 1);

        let det2 = vec![MockDetection::new([0.8, 0.8, 0.9, 0.9], 0.9, 0)];
        let res2 = tracker.update(&det2, 2_000_000);
        assert!(res2[0].is_some());

        assert_eq!(
            tracker.tracklets.len(),
            2,
            "Far-displaced detection should create a new tracklet"
        );
        assert_ne!(
            res2[0].unwrap().uuid,
            uuid1,
            "New detection should have a different UUID"
        );
    }

    #[test]
    fn test_many_detections_100() {
        let mut tracker = ByteTrackBuilder::new().build();

        let detections: Vec<MockDetection> = (0..100)
            .map(|i| {
                let x = (i % 10) as f32 * 0.1;
                let y = (i / 10) as f32 * 0.1;
                MockDetection::new([x, y, x + 0.05, y + 0.05], 0.9, 0)
            })
            .collect();

        let results = tracker.update(&detections, 1000);
        assert_eq!(results.len(), 100);
        assert!(
            results.iter().all(|r| r.is_some()),
            "All 100 high-confidence detections should create tracklets"
        );
        assert_eq!(
            tracker.tracklets.len(),
            100,
            "Should have 100 active tracklets"
        );
    }

    #[test]
    fn test_tracklet_count_increments_each_frame() {
        let mut tracker = ByteTrackBuilder::new().build();
        let det = vec![MockDetection::new([0.2, 0.2, 0.4, 0.4], 0.9, 0)];

        for frame in 1..=5 {
            tracker.update(&det, frame * 1000);
        }

        assert_eq!(tracker.tracklets.len(), 1);
        assert_eq!(
            tracker.tracklets[0].count, 5,
            "Tracklet count should equal number of frames it was matched"
        );
    }

    #[test]
    fn test_tracklet_created_timestamp_preserved() {
        let mut tracker = ByteTrackBuilder::new().build();
        let det = vec![MockDetection::new([0.2, 0.2, 0.4, 0.4], 0.9, 0)];

        tracker.update(&det, 1000);
        tracker.update(&det, 2000);
        tracker.update(&det, 3000);

        let active = tracker.get_active_tracks();
        assert_eq!(active.len(), 1);
        assert_eq!(
            active[0].info.created, 1000,
            "Created timestamp should remain at the first frame"
        );
        assert_eq!(
            active[0].info.last_updated, 3000,
            "Last updated should be the most recent frame"
        );
    }

    #[test]
    fn test_mixed_confidence_detections() {
        let mut tracker = ByteTrackBuilder::new().build();
        let det = vec![
            MockDetection::new([0.1, 0.1, 0.2, 0.2], 0.9, 0),
            MockDetection::new([0.3, 0.3, 0.4, 0.4], 0.3, 0),
            MockDetection::new([0.5, 0.5, 0.6, 0.6], 0.85, 0),
            MockDetection::new([0.7, 0.7, 0.8, 0.8], 0.1, 0),
        ];

        let results = tracker.update(&det, 1000);
        assert_eq!(results.len(), 4);

        assert!(
            results[0].is_some(),
            "High-conf detection should create tracklet"
        );
        assert!(
            results[1].is_none(),
            "Low-conf detection should not create tracklet"
        );
        assert!(
            results[2].is_some(),
            "High-conf detection should create tracklet"
        );
        assert!(
            results[3].is_none(),
            "Low-conf detection should not create tracklet"
        );
        assert_eq!(tracker.tracklets.len(), 2);
    }

    #[test]
    fn test_iou_contained_box() {
        let outer = [0.0, 0.0, 1.0, 1.0];
        let inner = [0.25, 0.25, 0.75, 0.75];
        let result = iou(&outer, &inner);
        assert!(
            (result - 0.25).abs() < 0.01,
            "IOU of contained box should be inner_area/outer_area = 0.25, got {result}"
        );
    }

    #[test]
    fn test_xyxy_to_xyah_square_box() {
        let square = [0.1, 0.2, 0.3, 0.4];
        let xyah = xyxy_to_xyah(&square);
        assert!((xyah[0] - 0.2).abs() < 1e-5, "Center x should be 0.2");
        assert!((xyah[1] - 0.3).abs() < 1e-5, "Center y should be 0.3");
        assert!(
            (xyah[2] - 1.0).abs() < 1e-5,
            "Aspect ratio of square should be 1.0"
        );
        assert!((xyah[3] - 0.2).abs() < 1e-5, "Height should be 0.2");
    }

    #[test]
    fn test_frame_count_increments() {
        let mut tracker = ByteTrackBuilder::new().build();
        let empty: Vec<MockDetection> = vec![];

        for _ in 0..10 {
            tracker.update(&empty, 0);
        }

        assert_eq!(
            tracker.frame_count, 10,
            "Frame count should increment each update"
        );
    }

    #[test]
    fn test_tracklet_predicted_location_near_detection() {
        let mut tracker = ByteTrackBuilder::new().build();
        let det = vec![MockDetection::new([0.2, 0.2, 0.4, 0.4], 0.9, 0)];
        tracker.update(&det, 1000);

        let pred = tracker.tracklets[0].get_predicted_location();
        assert!(
            (pred[0] - 0.2).abs() < 0.1,
            "Predicted xmin should be near 0.2, got {}",
            pred[0]
        );
        assert!(
            (pred[1] - 0.2).abs() < 0.1,
            "Predicted ymin should be near 0.2, got {}",
            pred[1]
        );
        assert!(
            (pred[2] - 0.4).abs() < 0.1,
            "Predicted xmax should be near 0.4, got {}",
            pred[2]
        );
        assert!(
            (pred[3] - 0.4).abs() < 0.1,
            "Predicted ymax should be near 0.4, got {}",
            pred[3]
        );
    }
}