1use crate::{
5 kalman::ConstantVelocityXYAHModel2, ActiveTrackInfo, DetectionBox, TrackInfo, Tracker,
6};
7use lapjv::{lapjv, Matrix};
8use log::trace;
9use nalgebra::{Dyn, OMatrix, U4};
10use uuid::Uuid;
11
/// Builder for [`ByteTrack`] trackers.
///
/// Start from [`ByteTrackBuilder::new`] (or `Default`), override individual parameters with
/// the chained setters, then call [`ByteTrackBuilder::build`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ByteTrackBuilder {
    /// Time a tracklet may go unmatched before it is removed, in the units of the `timestamp`
    /// passed to `Tracker::update`.
    track_extra_lifespan: u64,
    /// Minimum detection score for the first association pass and for spawning new tracklets.
    track_high_conf: f32,
    /// Minimum IoU between a tracklet's predicted box and a detection for them to match.
    track_iou: f32,
    /// Forwarded to `ConstantVelocityXYAHModel2::new` for every newly created tracklet.
    track_update: f32,
}
19
20impl Default for ByteTrackBuilder {
21 fn default() -> Self {
22 Self::new()
23 }
24}
25
26impl ByteTrackBuilder {
27 pub fn new() -> Self {
45 Self {
46 track_high_conf: 0.7,
47 track_iou: 0.25,
48 track_update: 0.25,
49 track_extra_lifespan: 500_000_000,
50 }
51 }
52
53 pub fn track_extra_lifespan(mut self, lifespan: u64) -> Self {
55 self.track_extra_lifespan = lifespan;
56 self
57 }
58
59 pub fn track_high_conf(mut self, conf: f32) -> Self {
61 self.track_high_conf = conf;
62 self
63 }
64
65 pub fn track_iou(mut self, iou: f32) -> Self {
67 self.track_iou = iou;
68 self
69 }
70
71 pub fn track_update(mut self, update: f32) -> Self {
73 self.track_update = update;
74 self
75 }
76
77 pub fn build<T: DetectionBox>(self) -> ByteTrack<T> {
95 ByteTrack {
96 track_extra_lifespan: self.track_extra_lifespan,
97 track_high_conf: self.track_high_conf,
98 track_iou: self.track_iou,
99 track_update: self.track_update,
100 tracklets: Vec::new(),
101 frame_count: 0,
102 }
103 }
104}
105
/// ByteTrack-style multi-object tracker.
///
/// Each call to [`Tracker::update`] does the following, in order:
/// - advances every tracklet's filter;
/// - associates detections to tracklets in two passes: first detections scoring at least
///   `track_high_conf`, then the still-unmatched detections with a score threshold of 0.0;
/// - drops tracklets that have been unmatched for longer than `track_extra_lifespan`;
/// - starts new tracklets from unmatched high-confidence detections.
///
/// `derive(Default)` on this generic struct only applies when `T: Default`.
// NOTE(review): the reason for `allow(dead_code)` is not visible in this file. Presumably some
// field is only read from tests or other crates; confirm and document it.
#[allow(dead_code)]
#[derive(Default, Debug, Clone)]
pub struct ByteTrack<T: DetectionBox> {
    /// Extra time a tracklet survives after its last match. It is removed once
    /// `last_updated + track_extra_lifespan < timestamp`.
    pub track_extra_lifespan: u64,
    /// Score threshold for the first association pass and for creating new tracklets.
    pub track_high_conf: f32,
    /// Minimum IoU between a predicted tracklet box and a detection for them to be matched.
    pub track_iou: f32,
    /// Forwarded to `ConstantVelocityXYAHModel2::new` when a tracklet is created.
    /// Its meaning is defined by the Kalman model (TODO confirm).
    pub track_update: f32,
    /// Currently live tracklets.
    pub tracklets: Vec<Tracklet<T>>,
    /// Number of `update` calls processed so far.
    pub frame_count: i32,
}
116
/// A single tracked object: its Kalman filter state plus bookkeeping.
#[derive(Debug, Clone)]
pub struct Tracklet<T: DetectionBox> {
    /// Random (v4) identifier assigned when the tracklet is created.
    pub id: Uuid,
    /// Kalman filter fed with boxes converted to (center x, center y, aspect, height) form.
    /// Per its type name it is a constant-velocity model; see that type for details.
    pub filter: ConstantVelocityXYAHModel2<f32>,
    /// Number of detections folded into this tracklet, including the one that created it.
    pub count: i32,
    /// Timestamp of the update call that created the tracklet.
    pub created: u64,
    /// Timestamp of the most recent matched detection.
    pub last_updated: u64,
    /// Most recent detection matched to (or that created) this tracklet.
    pub last_box: T,
}
126
127impl<T: DetectionBox> Tracklet<T> {
128 fn update(&mut self, detect_box: &T, ts: u64) {
129 self.count += 1;
130 self.last_updated = ts;
131 self.filter.update(&xyxy_to_xyah(&detect_box.bbox()));
132 self.last_box = detect_box.clone();
133 }
134
135 pub fn get_predicted_location(&self) -> [f32; 4] {
136 let projected = self.filter.project().0;
137 let predicted_xyah = projected.as_slice();
138 xyah_to_xyxy(predicted_xyah)
139 }
140}
141
142fn xyxy_to_xyah(vaal_box: &[f32; 4]) -> [f32; 4] {
143 let x = (vaal_box[2] + vaal_box[0]) / 2.0;
144 let y = (vaal_box[3] + vaal_box[1]) / 2.0;
145 let w = (vaal_box[2] - vaal_box[0]).max(EPSILON);
146 let h = (vaal_box[3] - vaal_box[1]).max(EPSILON);
147 let a = w / h;
148
149 [x, y, a, h]
150}
151
/// Converts `[center_x, center_y, aspect, height, ..]` back into `[x1, y1, x2, y2]`.
///
/// Only the first four values are read, so a longer filter state slice may be passed.
/// Width is recovered as `height * aspect`.
///
/// # Panics
/// Panics if `xyah` has fewer than four elements.
fn xyah_to_xyxy(xyah: &[f32]) -> [f32; 4] {
    assert!(xyah.len() >= 4);
    let (cx, cy, aspect, height) = (xyah[0], xyah[1], xyah[2], xyah[3]);
    let half_w = height * aspect / 2.0;
    let half_h = height / 2.0;
    [cx - half_w, cy - half_h, cx + half_w, cy + half_h]
}
160
/// Cost assigned to detection/tracklet pairs that must not be matched.
///
/// Assignments from the solver whose cost is at or above this value are discarded.
const INVALID_MATCH: f32 = 1000000.0;
/// Small positive floor. It keeps converted box widths and heights positive and guards the
/// IoU division against a (near-)zero union.
const EPSILON: f32 = 0.00001;
163
164fn iou(box1: &[f32], box2: &[f32]) -> f32 {
165 let intersection = (box1[2].min(box2[2]) - box1[0].max(box2[0])).max(0.0)
166 * (box1[3].min(box2[3]) - box1[1].max(box2[1])).max(0.0);
167
168 let union = (box1[2] - box1[0]) * (box1[3] - box1[1])
169 + (box2[2] - box2[0]) * (box2[3] - box2[1])
170 - intersection;
171
172 if union <= EPSILON {
173 return 0.0;
174 }
175
176 intersection / union
177}
178
179fn box_cost<T: DetectionBox>(
180 track: &Tracklet<T>,
181 new_box: &T,
182 distance: f32,
183 score_threshold: f32,
184 iou_threshold: f32,
185) -> f32 {
186 let _ = distance;
187
188 if new_box.score() < score_threshold {
189 return INVALID_MATCH;
190 }
191
192 let predicted_xyah = track.filter.mean.as_slice();
194 let expected = xyah_to_xyxy(predicted_xyah);
195 let iou = iou(&expected, &new_box.bbox());
196 if iou < iou_threshold {
197 return INVALID_MATCH;
198 }
199 (1.5 - new_box.score()) + (1.5 - iou)
200}
201
202impl<T: DetectionBox> ByteTrack<T> {
203 fn compute_costs(
204 &mut self,
205 boxes: &[T],
206 score_threshold: f32,
207 iou_threshold: f32,
208 box_filter: &[bool],
209 track_filter: &[bool],
210 ) -> Matrix<f32> {
211 let dims = boxes.len().max(self.tracklets.len());
213 let mut measurements = OMatrix::<f32, Dyn, U4>::from_element(boxes.len(), 0.0);
214 for (i, mut row) in measurements.row_iter_mut().enumerate() {
215 row.copy_from_slice(&xyxy_to_xyah(&boxes[i].bbox()));
216 }
217
218 Matrix::from_shape_fn((dims, dims), |(x, y)| {
222 if x < boxes.len() && y < self.tracklets.len() {
223 if box_filter[x] || track_filter[y] {
224 INVALID_MATCH
225 } else {
226 box_cost(
227 &self.tracklets[y],
228 &boxes[x],
229 0.0,
231 score_threshold,
232 iou_threshold,
233 )
234 }
235 } else {
236 0.0
237 }
238 })
239 }
240
241 #[allow(clippy::too_many_arguments)]
244 fn process_assignments(
245 &mut self,
246 assignments: &[usize],
247 boxes: &[T],
248 costs: &Matrix<f32>,
249 matched: &mut [bool],
250 tracked: &mut [bool],
251 matched_info: &mut [Option<TrackInfo>],
252 timestamp: u64,
253 log_assignments: bool,
254 ) {
255 for (i, &x) in assignments.iter().enumerate() {
256 if i >= boxes.len() || x >= self.tracklets.len() {
257 continue;
258 }
259
260 if costs[(i, x)] >= INVALID_MATCH {
262 continue;
263 }
264
265 if matched[i] || tracked[x] {
267 continue;
268 }
269
270 if log_assignments {
271 trace!(
272 "Cost: {} Box: {:#?} UUID: {} Mean: {}",
273 costs[(i, x)],
274 boxes[i],
275 self.tracklets[x].id,
276 self.tracklets[x].filter.mean
277 );
278 }
279
280 matched[i] = true;
281 matched_info[i] = Some(TrackInfo {
282 uuid: self.tracklets[x].id,
283 count: self.tracklets[x].count,
284 created: self.tracklets[x].created,
285 tracked_location: self.tracklets[x].get_predicted_location(),
286 last_updated: timestamp,
287 });
288 tracked[x] = true;
289 self.tracklets[x].update(&boxes[i], timestamp);
290 }
291 }
292
293 fn remove_expired_tracklets(&mut self, timestamp: u64) {
295 for i in (0..self.tracklets.len()).rev() {
297 let expiry = self.tracklets[i].last_updated + self.track_extra_lifespan;
298 if expiry < timestamp {
299 trace!("Tracklet removed: {:?}", self.tracklets[i].id);
300 let _ = self.tracklets.swap_remove(i);
301 }
302 }
303 }
304
305 fn create_new_tracklets(
307 &mut self,
308 boxes: &[T],
309 high_conf_indices: &[usize],
310 matched: &[bool],
311 matched_info: &mut [Option<TrackInfo>],
312 timestamp: u64,
313 ) {
314 for &i in high_conf_indices {
315 if matched[i] {
316 continue;
317 }
318
319 let id = Uuid::new_v4();
320 let new_tracklet = Tracklet {
321 id,
322 filter: ConstantVelocityXYAHModel2::new(
323 &xyxy_to_xyah(&boxes[i].bbox()),
324 self.track_update,
325 ),
326 last_updated: timestamp,
327 count: 1,
328 created: timestamp,
329 last_box: boxes[i].clone(),
330 };
331 matched_info[i] = Some(TrackInfo {
332 uuid: new_tracklet.id,
333 count: new_tracklet.count,
334 created: new_tracklet.created,
335 tracked_location: new_tracklet.get_predicted_location(),
336 last_updated: timestamp,
337 });
338 self.tracklets.push(new_tracklet);
339 }
340 }
341}
342
343impl<T> Tracker<T> for ByteTrack<T>
344where
345 T: DetectionBox,
346{
347 fn update(&mut self, boxes: &[T], timestamp: u64) -> Vec<Option<TrackInfo>> {
348 self.frame_count += 1;
349
350 let high_conf_ind: Vec<usize> = boxes
352 .iter()
353 .enumerate()
354 .filter(|(_, b)| b.score() >= self.track_high_conf)
355 .map(|(x, _)| x)
356 .collect();
357
358 let mut matched = vec![false; boxes.len()];
359 let mut tracked = vec![false; self.tracklets.len()];
360 let mut matched_info = vec![None; boxes.len()];
361
362 if !self.tracklets.is_empty() {
364 for track in &mut self.tracklets {
365 track.filter.predict();
366 }
367
368 let costs = self.compute_costs(
369 boxes,
370 self.track_high_conf,
371 self.track_iou,
372 &matched,
373 &tracked,
374 );
375 if let Ok(ans) = lapjv(&costs) {
376 self.process_assignments(
377 &ans.0,
378 boxes,
379 &costs,
380 &mut matched,
381 &mut tracked,
382 &mut matched_info,
383 timestamp,
384 false,
385 );
386 }
387 }
388
389 if !self.tracklets.is_empty() {
391 let costs = self.compute_costs(boxes, 0.0, self.track_iou, &matched, &tracked);
392 if let Ok(ans) = lapjv(&costs) {
393 self.process_assignments(
394 &ans.0,
395 boxes,
396 &costs,
397 &mut matched,
398 &mut tracked,
399 &mut matched_info,
400 timestamp,
401 true,
402 );
403 }
404 }
405
406 self.remove_expired_tracklets(timestamp);
408
409 self.create_new_tracklets(
411 boxes,
412 &high_conf_ind,
413 &matched,
414 &mut matched_info,
415 timestamp,
416 );
417
418 matched_info
419 }
420
421 fn get_active_tracks(&self) -> Vec<ActiveTrackInfo<T>> {
422 self.tracklets
423 .iter()
424 .map(|t| ActiveTrackInfo {
425 info: TrackInfo {
426 uuid: t.id,
427 tracked_location: t.get_predicted_location(),
428 count: t.count,
429 created: t.created,
430 last_updated: t.last_updated,
431 },
432 last_box: t.last_box.clone(),
433 })
434 .collect()
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;
    use crate::*;

    // Test-only convenience constructor for the crate's `MockDetection`.
    impl MockDetection {
        fn new(x1: f32, y1: f32, x2: f32, y2: f32, score: f32) -> Self {
            Self {
                bbox: [x1, y1, x2, y2],
                score,
                label: 0,
            }
        }
    }

    // Converting xyxy -> xyah -> xyxy should reproduce the input box.
    #[test]
    fn test_vaalbox_xyah_roundtrip() {
        let box1 = [0.0134, 0.02135, 0.12438, 0.691];
        let xyah = xyxy_to_xyah(&box1);
        let box2 = xyah_to_xyxy(&xyah);

        assert!((box1[0] - box2[0]).abs() < f32::EPSILON);
        assert!((box1[1] - box2[1]).abs() < f32::EPSILON);
        assert!((box1[2] - box2[2]).abs() < f32::EPSILON);
        assert!((box1[3] - box2[3]).abs() < f32::EPSILON);
    }

    #[test]
    fn test_iou_identical_boxes() {
        let box1 = [0.1, 0.1, 0.5, 0.5];
        let box2 = [0.1, 0.1, 0.5, 0.5];
        let result = iou(&box1, &box2);
        assert!(
            (result - 1.0).abs() < 0.001,
            "IOU of identical boxes should be 1.0"
        );
    }

    #[test]
    fn test_iou_no_overlap() {
        let box1 = [0.0, 0.0, 0.2, 0.2];
        let box2 = [0.5, 0.5, 0.7, 0.7];
        let result = iou(&box1, &box2);
        assert!(result < 0.001, "IOU of non-overlapping boxes should be ~0");
    }

    // Intersection 0.0625, union 0.4375, so IoU = 1/7 ≈ 0.143.
    #[test]
    fn test_iou_partial_overlap() {
        let box1 = [0.0, 0.0, 0.5, 0.5];
        let box2 = [0.25, 0.25, 0.75, 0.75];
        let result = iou(&box1, &box2);
        assert!(result > 0.1 && result < 0.2, "IOU should be ~0.14");
    }

    // The builder defaults are copied into the tracker, which starts empty.
    #[test]
    fn test_bytetrack_new() {
        let tracker: ByteTrack<MockDetection> = ByteTrackBuilder::new().build();
        assert_eq!(tracker.frame_count, 0);
        assert!(tracker.tracklets.is_empty());
        assert_eq!(tracker.track_high_conf, 0.7);
        assert_eq!(tracker.track_iou, 0.25);
    }

    #[test]
    fn test_bytetrack_single_detection_creates_tracklet() {
        let mut tracker = ByteTrackBuilder::new().build();
        let detections = vec![MockDetection::new(0.1, 0.1, 0.3, 0.3, 0.9)];

        let results = tracker.update(&detections, 1000);

        assert_eq!(results.len(), 1);
        assert!(
            results[0].is_some(),
            "High-confidence detection should create tracklet"
        );
        assert_eq!(tracker.tracklets.len(), 1);
        assert_eq!(tracker.frame_count, 1);
    }

    // Score 0.5 is below the default high-confidence threshold of 0.7.
    #[test]
    fn test_bytetrack_low_confidence_no_tracklet() {
        let mut tracker = ByteTrackBuilder::new().build();
        let detections = vec![MockDetection::new(0.1, 0.1, 0.3, 0.3, 0.5)];

        let results = tracker.update(&detections, 1000);

        assert_eq!(results.len(), 1);
        assert!(
            results[0].is_none(),
            "Low-confidence detection should not create tracklet"
        );
        assert!(tracker.tracklets.is_empty());
    }

    // A stationary detection seen in two frames should be matched to the same tracklet.
    #[test]
    fn test_bytetrack_tracking_across_frames() {
        let mut tracker = ByteTrackBuilder::new().build();

        let det1 = vec![MockDetection::new(0.2, 0.2, 0.4, 0.4, 0.9)];
        let res1 = tracker.update(&det1, 1000);
        assert!(res1[0].is_some());
        let uuid1 = res1[0].unwrap().uuid;
        assert_eq!(tracker.tracklets.len(), 1);
        assert_eq!(tracker.tracklets[0].count, 1);

        let det2 = vec![MockDetection::new(0.2, 0.2, 0.4, 0.4, 0.9)];
        let res2 = tracker.update(&det2, 2000);
        assert!(res2[0].is_some());
        let info2 = res2[0].unwrap();

        assert_eq!(tracker.tracklets.len(), 1, "Should still have one tracklet");
        assert_eq!(info2.uuid, uuid1, "Should match same tracklet");
        assert_eq!(tracker.tracklets[0].count, 2, "Internal count should be 2");
    }

    #[test]
    fn test_bytetrack_multiple_detections() {
        let mut tracker = ByteTrackBuilder::new().build();

        let detections = vec![
            MockDetection::new(0.1, 0.1, 0.2, 0.2, 0.9),
            MockDetection::new(0.5, 0.5, 0.6, 0.6, 0.85),
            MockDetection::new(0.8, 0.8, 0.9, 0.9, 0.95),
        ];

        let results = tracker.update(&detections, 1000);

        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.is_some()));
        assert_eq!(tracker.tracklets.len(), 3);
    }

    // Lifespan 1000 with last update at 1000 gives expiry 2000, which is before 3000.
    #[test]
    fn test_bytetrack_tracklet_expiry() {
        let mut tracker = ByteTrackBuilder::new().build();
        tracker.track_extra_lifespan = 1000;
        let det1 = vec![MockDetection::new(0.1, 0.1, 0.3, 0.3, 0.9)];
        tracker.update(&det1, 1000);
        assert_eq!(tracker.tracklets.len(), 1);

        let empty: Vec<MockDetection> = vec![];
        tracker.update(&empty, 3000);
        assert!(tracker.tracklets.is_empty(), "Tracklet should have expired");
    }

    #[test]
    fn test_bytetrack_get_active_tracks() {
        let mut tracker = ByteTrackBuilder::new().build();

        let detections = vec![
            MockDetection::new(0.1, 0.1, 0.2, 0.2, 0.9),
            MockDetection::new(0.5, 0.5, 0.6, 0.6, 0.85),
        ];
        tracker.update(&detections, 1000);

        let active = tracker.get_active_tracks();
        assert_eq!(active.len(), 2);
        assert!(active.iter().all(|t| t.info.count == 1));
        assert!(active.iter().all(|t| t.info.created == 1000));
    }

    // An empty frame still advances `frame_count`.
    #[test]
    fn test_bytetrack_empty_detections() {
        let mut tracker = ByteTrackBuilder::new().build();
        let empty: Vec<MockDetection> = vec![];

        let results = tracker.update(&empty, 1000);

        assert!(results.is_empty());
        assert!(tracker.tracklets.is_empty());
        assert_eq!(tracker.frame_count, 1);
    }
}