1use crate::{kalman::ConstantVelocityXYAHModel2, DetectionBox, TrackInfo, Tracker};
5use lapjv::{lapjv, Matrix};
6use log::{debug, trace};
7use nalgebra::{Dyn, OMatrix, U4};
8use uuid::Uuid;
9
10#[allow(dead_code)]
11#[derive(Default)]
12pub struct ByteTrack {
13 pub track_extra_lifespan: u64,
14 pub track_high_conf: f32,
15 pub track_iou: f32,
16 pub track_update: f32,
17
18 pub tracklets: Vec<Tracklet>,
19 pub frame_count: i32,
20}
21
/// A single tracked object: a constant-velocity filter over its box plus bookkeeping.
#[derive(Debug, Clone)]
pub struct Tracklet {
    /// Identifier assigned when the tracklet is created (a random `Uuid::new_v4`).
    pub id: Uuid,
    /// Motion filter; the first four elements of its state mean are read as
    /// `[x, y, a, h]` (center x, center y, aspect ratio w/h, height) to produce the predicted box.
    pub filter: ConstantVelocityXYAHModel2<f32>,
    /// Number of detections associated with this tracklet, starting at 1 for the detection
    /// that created it.
    pub count: i32,
    /// Timestamp of the frame in which the tracklet was created.
    pub created: u64,
    /// Timestamp of the most recent frame in which a detection was associated.
    pub last_updated: u64,
}
30
31impl Tracklet {
32 fn update<T: DetectionBox>(&mut self, detect_box: &T, ts: u64) {
33 self.count += 1;
34 self.last_updated = ts;
35 self.filter.update(&vaalbox_to_xyah(&detect_box.bbox()));
36 }
37
38 pub fn get_predicted_location(&self) -> [f32; 4] {
39 let predicted_xyah = self.filter.mean.as_slice();
40 xyah_to_vaalbox(predicted_xyah)
41 }
42}
43
44fn vaalbox_to_xyah(vaal_box: &[f32; 4]) -> [f32; 4] {
45 let x = (vaal_box[2] + vaal_box[0]) / 2.0;
46 let y = (vaal_box[3] + vaal_box[1]) / 2.0;
47 let w = (vaal_box[2] - vaal_box[0]).max(EPSILON);
48 let h = (vaal_box[3] - vaal_box[1]).max(EPSILON);
49 let a = w / h;
50
51 [x, y, a, h]
52}
53
/// Converts `[cx, cy, aspect, height]` into an `[x1, y1, x2, y2]` corner box.
///
/// Only the first four elements of `xyah` are read; any trailing elements are ignored.
///
/// # Panics
/// Panics if `xyah` has fewer than four elements.
fn xyah_to_vaalbox(xyah: &[f32]) -> [f32; 4] {
    assert!(xyah.len() >= 4);
    let &[cx, cy, aspect, height, ..] = xyah else {
        unreachable!()
    };
    let half_width = height * aspect / 2.0;
    let half_height = height / 2.0;
    [
        cx - half_width,
        cy - half_height,
        cx + half_width,
        cy + half_height,
    ]
}
62
/// Cost given to detection/tracklet pairs that must not be associated; any solved
/// assignment whose cost is at or above this value is discarded.
const INVALID_MATCH: f32 = 1_000_000.0;
/// Lower bound applied to box width/height and to the IoU union area, so those values
/// are never used as a zero divisor.
const EPSILON: f32 = 1e-5;
65
66fn iou(box1: &[f32], box2: &[f32]) -> f32 {
67 let intersection = (box1[2].min(box2[2]) - box1[0].max(box2[0])).max(0.0)
68 * (box1[3].min(box2[3]) - box1[1].max(box2[1])).max(0.0);
69
70 let union = (box1[2] - box1[0]) * (box1[3] - box1[1])
71 + (box2[2] - box2[0]) * (box2[3] - box2[1])
72 - intersection;
73
74 if union <= EPSILON {
75 return 0.0;
76 }
77
78 intersection / union
79}
80
81fn box_cost<T: DetectionBox>(
82 track: &Tracklet,
83 new_box: &T,
84 distance: f32,
85 score_threshold: f32,
86 iou_threshold: f32,
87) -> f32 {
88 let _ = distance;
89
90 if new_box.score() < score_threshold {
91 return INVALID_MATCH;
92 }
93
94 let predicted_xyah = track.filter.mean.as_slice();
96 let expected = xyah_to_vaalbox(predicted_xyah);
97 let iou = iou(&expected, &new_box.bbox());
98 if iou < iou_threshold {
99 return INVALID_MATCH;
100 }
101 (1.5 - new_box.score()) + (1.5 - iou)
102}
103
104impl ByteTrack {
105 pub fn new() -> ByteTrack {
106 ByteTrack {
107 track_extra_lifespan: 500_000_000,
108 track_high_conf: 0.7,
109 track_iou: 0.25,
110 track_update: 0.25,
111 tracklets: Vec::new(),
112 frame_count: 0,
113 }
114 }
115
116 fn compute_costs<T: DetectionBox>(
117 &mut self,
118 boxes: &[T],
119 score_threshold: f32,
120 iou_threshold: f32,
121 box_filter: &[bool],
122 track_filter: &[bool],
123 ) -> Matrix<f32> {
124 let dims = boxes.len().max(self.tracklets.len());
126 let mut measurements = OMatrix::<f32, Dyn, U4>::from_element(boxes.len(), 0.0);
127 for (i, mut row) in measurements.row_iter_mut().enumerate() {
128 row.copy_from_slice(&vaalbox_to_xyah(&boxes[i].bbox()));
129 }
130
131 Matrix::from_shape_fn((dims, dims), |(x, y)| {
135 if x < boxes.len() && y < self.tracklets.len() {
136 if box_filter[x] || track_filter[y] {
137 INVALID_MATCH
138 } else {
139 box_cost(
140 &self.tracklets[y],
141 &boxes[x],
142 0.0,
144 score_threshold,
145 iou_threshold,
146 )
147 }
148 } else {
149 0.0
150 }
151 })
152 }
153
154 #[allow(clippy::too_many_arguments)]
157 fn process_assignments<T: DetectionBox>(
158 &mut self,
159 assignments: &[usize],
160 boxes: &[T],
161 costs: &Matrix<f32>,
162 matched: &mut [bool],
163 tracked: &mut [bool],
164 matched_info: &mut [Option<TrackInfo>],
165 timestamp: u64,
166 skip_already_matched: bool,
167 ) {
168 for (i, &x) in assignments.iter().enumerate() {
169 if i >= boxes.len() || x >= self.tracklets.len() {
170 continue;
171 }
172
173 if costs[(i, x)] >= INVALID_MATCH {
175 continue;
176 }
177
178 if skip_already_matched && (matched[i] || tracked[x]) {
180 continue;
181 }
182
183 if skip_already_matched {
184 trace!(
185 "Cost: {} Box: {:#?} UUID: {} Mean: {}",
186 costs[(i, x)],
187 boxes[i],
188 self.tracklets[x].id,
189 self.tracklets[x].filter.mean
190 );
191 }
192
193 matched[i] = true;
194 matched_info[i] = Some(TrackInfo {
195 uuid: self.tracklets[x].id,
196 count: self.tracklets[x].count,
197 created: self.tracklets[x].created,
198 tracked_location: self.tracklets[x].get_predicted_location(),
199 last_updated: timestamp,
200 });
201 assert!(!tracked[x]);
202 tracked[x] = true;
203 self.tracklets[x].update(&boxes[i], timestamp);
204 }
205 }
206
207 fn remove_expired_tracklets(&mut self, timestamp: u64) {
209 for i in (0..self.tracklets.len()).rev() {
211 let expiry = self.tracklets[i].last_updated + self.track_extra_lifespan;
212 if expiry < timestamp {
213 debug!("Tracklet removed: {:?}", self.tracklets[i].id);
214 let _ = self.tracklets.swap_remove(i);
215 }
216 }
217 }
218
219 fn create_new_tracklets<T: DetectionBox>(
221 &mut self,
222 boxes: &[T],
223 high_conf_indices: &[usize],
224 matched: &[bool],
225 matched_info: &mut [Option<TrackInfo>],
226 timestamp: u64,
227 ) {
228 for &i in high_conf_indices {
229 if matched[i] {
230 continue;
231 }
232
233 let id = Uuid::new_v4();
234 let new_tracklet = Tracklet {
235 id,
236 filter: ConstantVelocityXYAHModel2::new(
237 &vaalbox_to_xyah(&boxes[i].bbox()),
238 self.track_update,
239 ),
240 last_updated: timestamp,
241 count: 1,
242 created: timestamp,
243 };
244 matched_info[i] = Some(TrackInfo {
245 uuid: new_tracklet.id,
246 count: new_tracklet.count,
247 created: new_tracklet.created,
248 tracked_location: new_tracklet.get_predicted_location(),
249 last_updated: timestamp,
250 });
251 self.tracklets.push(new_tracklet);
252 }
253 }
254}
255
256impl<T> Tracker<T> for ByteTrack
257where
258 T: DetectionBox,
259{
260 fn update(&mut self, boxes: &[T], timestamp: u64) -> Vec<Option<TrackInfo>> {
261 self.frame_count += 1;
262
263 let high_conf_ind: Vec<usize> = boxes
265 .iter()
266 .enumerate()
267 .filter(|(_, b)| b.score() >= self.track_high_conf)
268 .map(|(x, _)| x)
269 .collect();
270
271 let mut matched = vec![false; boxes.len()];
272 let mut tracked = vec![false; self.tracklets.len()];
273 let mut matched_info = vec![None; boxes.len()];
274
275 if !self.tracklets.is_empty() {
277 for track in &mut self.tracklets {
278 track.filter.predict();
279 }
280
281 let costs = self.compute_costs(
282 boxes,
283 self.track_high_conf,
284 self.track_iou,
285 &matched,
286 &tracked,
287 );
288 let ans = lapjv(&costs).unwrap();
289 self.process_assignments(
290 &ans.0,
291 boxes,
292 &costs,
293 &mut matched,
294 &mut tracked,
295 &mut matched_info,
296 timestamp,
297 false,
298 );
299 }
300
301 if !self.tracklets.is_empty() {
303 let costs = self.compute_costs(boxes, 0.0, self.track_iou, &matched, &tracked);
304 let ans = lapjv(&costs).unwrap();
305 self.process_assignments(
306 &ans.0,
307 boxes,
308 &costs,
309 &mut matched,
310 &mut tracked,
311 &mut matched_info,
312 timestamp,
313 true,
314 );
315 }
316
317 self.remove_expired_tracklets(timestamp);
319
320 self.create_new_tracklets(
322 boxes,
323 &high_conf_ind,
324 &matched,
325 &mut matched_info,
326 timestamp,
327 );
328
329 matched_info
330 }
331
332 fn get_active_tracks(&self) -> Vec<TrackInfo> {
333 self.tracklets
334 .iter()
335 .map(|t| TrackInfo {
336 uuid: t.id,
337 tracked_location: t.get_predicted_location(),
338 count: t.count,
339 created: t.created,
340 last_updated: t.last_updated,
341 })
342 .collect()
343 }
344}
345
#[cfg(test)]
mod tests {
    use super::{iou, vaalbox_to_xyah, xyah_to_vaalbox, ByteTrack};
    use crate::{DetectionBox, Tracker};

    /// Absolute tolerance for comparing `f32` values of order 1.
    ///
    /// `f32::EPSILON` is the spacing of floats just above 1.0. Used as an absolute bound, it
    /// tolerates only about one rounding step, which makes round-trip checks brittle.
    const TOL: f32 = 1e-6;

    /// Minimal `DetectionBox` used to drive the tracker in tests.
    #[derive(Debug, Clone)]
    struct MockDetection {
        bbox: [f32; 4],
        score: f32,
        label: usize,
    }

    impl MockDetection {
        fn new(x1: f32, y1: f32, x2: f32, y2: f32, score: f32) -> Self {
            Self {
                bbox: [x1, y1, x2, y2],
                score,
                label: 0,
            }
        }
    }

    impl DetectionBox for MockDetection {
        fn bbox(&self) -> [f32; 4] {
            self.bbox
        }

        fn score(&self) -> f32 {
            self.score
        }

        fn label(&self) -> usize {
            self.label
        }
    }

    #[test]
    fn test_vaalbox_xyah_roundtrip() {
        let box1 = [0.0134, 0.02135, 0.12438, 0.691];
        let xyah = vaalbox_to_xyah(&box1);
        let box2 = xyah_to_vaalbox(&xyah);

        for (orig, back) in box1.iter().zip(box2.iter()) {
            assert!(
                (orig - back).abs() < TOL,
                "round-trip mismatch: {orig} vs {back}"
            );
        }
    }

    #[test]
    fn test_iou_identical_boxes() {
        let box1 = [0.1, 0.1, 0.5, 0.5];
        let box2 = [0.1, 0.1, 0.5, 0.5];
        let result = iou(&box1, &box2);
        assert!(
            (result - 1.0).abs() < 0.001,
            "IOU of identical boxes should be 1.0"
        );
    }

    #[test]
    fn test_iou_no_overlap() {
        let box1 = [0.0, 0.0, 0.2, 0.2];
        let box2 = [0.5, 0.5, 0.7, 0.7];
        let result = iou(&box1, &box2);
        assert!(result < 0.001, "IOU of non-overlapping boxes should be ~0");
    }

    #[test]
    fn test_iou_partial_overlap() {
        // Intersection 0.0625, union 0.4375 -> 1/7.
        let box1 = [0.0, 0.0, 0.5, 0.5];
        let box2 = [0.25, 0.25, 0.75, 0.75];
        let result = iou(&box1, &box2);
        assert!(result > 0.1 && result < 0.2, "IOU should be ~0.14");
    }

    #[test]
    fn test_iou_degenerate_boxes() {
        let point = [0.3, 0.3, 0.3, 0.3];
        assert_eq!(
            iou(&point, &point),
            0.0,
            "IOU with zero union area should be 0"
        );
    }

    #[test]
    fn test_bytetrack_new() {
        let tracker = ByteTrack::new();
        assert_eq!(tracker.frame_count, 0);
        assert!(tracker.tracklets.is_empty());
        assert_eq!(tracker.track_high_conf, 0.7);
        assert_eq!(tracker.track_iou, 0.25);
    }

    #[test]
    fn test_bytetrack_single_detection_creates_tracklet() {
        let mut tracker = ByteTrack::new();
        let detections = vec![MockDetection::new(0.1, 0.1, 0.3, 0.3, 0.9)];

        let results = tracker.update(&detections, 1000);

        assert_eq!(results.len(), 1);
        assert!(
            results[0].is_some(),
            "High-confidence detection should create tracklet"
        );
        assert_eq!(tracker.tracklets.len(), 1);
        assert_eq!(tracker.frame_count, 1);
    }

    #[test]
    fn test_bytetrack_low_confidence_no_tracklet() {
        let mut tracker = ByteTrack::new();
        // Score below the default `track_high_conf` of 0.7.
        let detections = vec![MockDetection::new(0.1, 0.1, 0.3, 0.3, 0.5)];

        let results = tracker.update(&detections, 1000);

        assert_eq!(results.len(), 1);
        assert!(
            results[0].is_none(),
            "Low-confidence detection should not create tracklet"
        );
        assert!(tracker.tracklets.is_empty());
    }

    #[test]
    fn test_bytetrack_tracking_across_frames() {
        let mut tracker = ByteTrack::new();

        let det1 = vec![MockDetection::new(0.2, 0.2, 0.4, 0.4, 0.9)];
        let res1 = tracker.update(&det1, 1000);
        assert!(res1[0].is_some());
        let uuid1 = res1[0].unwrap().uuid;
        assert_eq!(tracker.tracklets.len(), 1);
        assert_eq!(tracker.tracklets[0].count, 1);

        let det2 = vec![MockDetection::new(0.2, 0.2, 0.4, 0.4, 0.9)];
        let res2 = tracker.update(&det2, 2000);
        assert!(res2[0].is_some());
        let info2 = res2[0].unwrap();

        assert_eq!(tracker.tracklets.len(), 1, "Should still have one tracklet");
        assert_eq!(info2.uuid, uuid1, "Should match same tracklet");
        assert_eq!(tracker.tracklets[0].count, 2, "Internal count should be 2");
    }

    #[test]
    fn test_bytetrack_low_confidence_extends_existing_track() {
        let mut tracker = ByteTrack::new();

        let det1 = vec![MockDetection::new(0.2, 0.2, 0.4, 0.4, 0.9)];
        let uuid1 = tracker.update(&det1, 1000)[0]
            .expect("high-confidence box should create a tracklet")
            .uuid;

        // Below `track_high_conf`, so only the second association pass can pick it up.
        let det2 = vec![MockDetection::new(0.2, 0.2, 0.4, 0.4, 0.3)];
        let res2 = tracker.update(&det2, 2000);

        let info2 = res2[0].expect("low-confidence box should extend the existing track");
        assert_eq!(info2.uuid, uuid1);
        assert_eq!(
            tracker.tracklets.len(),
            1,
            "No new tracklet for a low-confidence box"
        );
        assert_eq!(tracker.tracklets[0].count, 2);
    }

    #[test]
    fn test_bytetrack_multiple_detections() {
        let mut tracker = ByteTrack::new();

        let detections = vec![
            MockDetection::new(0.1, 0.1, 0.2, 0.2, 0.9),
            MockDetection::new(0.5, 0.5, 0.6, 0.6, 0.85),
            MockDetection::new(0.8, 0.8, 0.9, 0.9, 0.95),
        ];

        let results = tracker.update(&detections, 1000);

        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.is_some()));
        assert_eq!(tracker.tracklets.len(), 3);
    }

    #[test]
    fn test_bytetrack_tracklet_expiry() {
        let mut tracker = ByteTrack::new();
        tracker.track_extra_lifespan = 1000;

        let det1 = vec![MockDetection::new(0.1, 0.1, 0.3, 0.3, 0.9)];
        tracker.update(&det1, 1000);
        assert_eq!(tracker.tracklets.len(), 1);

        // 1000 (last update) + 1000 (lifespan) < 3000, so the tracklet must be gone.
        let empty: Vec<MockDetection> = vec![];
        tracker.update(&empty, 3000);
        assert!(tracker.tracklets.is_empty(), "Tracklet should have expired");
    }

    #[test]
    fn test_bytetrack_get_active_tracks() {
        let mut tracker = ByteTrack::new();

        let detections = vec![
            MockDetection::new(0.1, 0.1, 0.2, 0.2, 0.9),
            MockDetection::new(0.5, 0.5, 0.6, 0.6, 0.85),
        ];
        tracker.update(&detections, 1000);

        let active = <ByteTrack as Tracker<MockDetection>>::get_active_tracks(&tracker);
        assert_eq!(active.len(), 2);
        assert!(active.iter().all(|t| t.count == 1));
        assert!(active.iter().all(|t| t.created == 1000));
    }

    #[test]
    fn test_bytetrack_empty_detections() {
        let mut tracker = ByteTrack::new();
        let empty: Vec<MockDetection> = vec![];

        let results = tracker.update(&empty, 1000);

        assert!(results.is_empty());
        assert!(tracker.tracklets.is_empty());
        assert_eq!(tracker.frame_count, 1);
    }
}
551}