1use crate::track::{
2 LookupRequest, ObservationsDb, Track, TrackAttributes, TrackAttributesUpdate, TrackStatus,
3};
4use crate::trackers::epoch_db::EpochDb;
5use crate::trackers::kalman_prediction::TrackAttributesKalmanPrediction;
6use crate::trackers::spatio_temporal_constraints::SpatioTemporalConstraints;
7use crate::utils::bbox::Universal2DBox;
8use crate::utils::kalman::kalman_2d_box::DIM_2D_BOX_X2;
9use crate::utils::kalman::KalmanState;
10use anyhow::Result;
11use serde::{Serialize, Deserialize};
12
13use std::collections::{HashMap, VecDeque};
14use std::sync::{Arc, RwLock};
15
16use self::metric::SortMetric;
17
18pub mod metric;
20
21pub mod simple_api;
23
24pub mod voting;
27
28pub mod batch_api;
30
/// Default intersection-over-union threshold, e.g. for `PositionalMetricType::IoU`.
pub const DEFAULT_SORT_IOU_THRESHOLD: f32 = 0.3;
33
34#[derive(Debug)]
35pub struct SortAttributesOptions {
36 epoch_db: Option<RwLock<HashMap<u64, usize>>>,
38 max_idle_epochs: usize,
40 pub history_length: usize,
42 pub spatio_temporal_constraints: SpatioTemporalConstraints,
43 pub position_weight: f32,
44 pub velocity_weight: f32,
45}
46
47impl Default for SortAttributesOptions {
48 fn default() -> Self {
49 Self {
50 epoch_db: None,
51 max_idle_epochs: 0,
52 history_length: 0,
53 spatio_temporal_constraints: SpatioTemporalConstraints::default(),
54 position_weight: 1.0 / 20.0,
55 velocity_weight: 1.0 / 160.0,
56 }
57 }
58}
59
60impl EpochDb for SortAttributesOptions {
61 fn epoch_db(&self) -> &Option<RwLock<HashMap<u64, usize>>> {
62 &self.epoch_db
63 }
64
65 fn max_idle_epochs(&self) -> usize {
66 self.max_idle_epochs
67 }
68}
69
70impl SortAttributesOptions {
71 pub fn new(
72 epoch_db: Option<RwLock<HashMap<u64, usize>>>,
73 max_idle_epochs: usize,
74 history_length: usize,
75 spatio_temporal_constraints: SpatioTemporalConstraints,
76 position_weight: f32,
77 velocity_weight: f32,
78 ) -> Self {
79 Self {
80 epoch_db,
81 max_idle_epochs,
82 history_length,
83 spatio_temporal_constraints,
84 position_weight,
85 velocity_weight,
86 }
87 }
88}
89
90#[derive(Debug, Clone)]
93pub struct SortAttributes {
94 pub predicted_boxes: VecDeque<Universal2DBox>,
96 pub observed_boxes: VecDeque<Universal2DBox>,
98 pub last_updated_epoch: usize,
100 pub track_length: usize,
102 pub scene_id: u64,
104 pub custom_object_id: Option<i64>,
106
107 state: Option<KalmanState<{ DIM_2D_BOX_X2 }>>,
109 opts: Arc<SortAttributesOptions>,
110}
111
112impl TrackAttributesKalmanPrediction for SortAttributes {
113 fn get_state(&self) -> Option<KalmanState<{ DIM_2D_BOX_X2 }>> {
114 self.state
115 }
116
117 fn set_state(&mut self, state: KalmanState<{ DIM_2D_BOX_X2 }>) {
118 self.state = Some(state);
119 }
120
121 fn get_position_weight(&self) -> f32 {
122 self.opts.position_weight
123 }
124
125 fn get_velocity_weight(&self) -> f32 {
126 self.opts.velocity_weight
127 }
128}
129
130impl Default for SortAttributes {
131 fn default() -> Self {
132 Self {
133 predicted_boxes: VecDeque::default(),
134 observed_boxes: VecDeque::default(),
135 last_updated_epoch: 0,
136 track_length: 0,
137 scene_id: 0,
138 state: None,
139 custom_object_id: None,
140 opts: Arc::new(SortAttributesOptions::default()),
141 }
142 }
143}
144
145impl SortAttributes {
146 pub fn new(opts: Arc<SortAttributesOptions>) -> Self {
152 Self {
153 opts,
154 ..Default::default()
155 }
156 }
157
158 fn update_history(
159 &mut self,
160 observation_bbox: &Universal2DBox,
161 predicted_bbox: &Universal2DBox,
162 ) {
163 self.track_length += 1;
164
165 self.observed_boxes.push_back(observation_bbox.clone());
166 self.predicted_boxes.push_back(predicted_bbox.clone());
167
168 if self.opts.history_length > 0 && self.observed_boxes.len() > self.opts.history_length {
169 self.observed_boxes.pop_front();
170 self.predicted_boxes.pop_front();
171 }
172 }
173}
174
/// Update payload applied to `SortAttributes` after a successful association.
#[derive(Clone, Debug, Default)]
pub struct SortAttributesUpdate {
    /// Epoch written to `last_updated_epoch`.
    epoch: usize,
    /// Scene id written to `scene_id`.
    scene_id: u64,
    /// Object id written to `custom_object_id`.
    custom_object_id: Option<i64>,
}
183
/// Lookup queries understood by `SortAttributes`.
#[derive(Clone, Debug)]
pub enum SortLookup {
    /// Tracks in the given scene whose last update epoch differs from that scene's
    /// current epoch.
    IdleLookup(u64),
}
190
191impl LookupRequest<SortAttributes, Universal2DBox> for SortLookup {
192 fn lookup(
193 &self,
194 attributes: &SortAttributes,
195 _observations: &ObservationsDb<Universal2DBox>,
196 _merge_history: &[u64],
197 ) -> bool {
198 match self {
199 SortLookup::IdleLookup(scene_id) => {
200 *scene_id == attributes.scene_id
201 && attributes.last_updated_epoch
202 != attributes
203 .opts
204 .current_epoch_with_scene(attributes.scene_id)
205 .unwrap()
206 }
207 }
208 }
209}
210
211impl SortAttributesUpdate {
212 pub fn new(epoch: usize, custom_object_id: Option<i64>) -> Self {
218 Self {
219 epoch,
220 scene_id: 0,
221 custom_object_id,
222 }
223 }
224 pub fn new_with_scene(epoch: usize, scene_id: u64, custom_object_id: Option<i64>) -> Self {
230 Self {
231 epoch,
232 scene_id,
233 custom_object_id,
234 }
235 }
236}
237
238impl TrackAttributesUpdate<SortAttributes> for SortAttributesUpdate {
239 fn apply(&self, attrs: &mut SortAttributes) -> Result<()> {
240 attrs.last_updated_epoch = self.epoch;
241 attrs.scene_id = self.scene_id;
242 attrs.custom_object_id = self.custom_object_id;
243 Ok(())
244 }
245}
246
247impl TrackAttributes<SortAttributes, Universal2DBox> for SortAttributes {
248 type Update = SortAttributesUpdate;
249 type Lookup = SortLookup;
250
251 fn compatible(&self, other: &SortAttributes) -> bool {
252 if self.scene_id == other.scene_id {
253 let o1 = self.predicted_boxes.back().unwrap();
254 let o2 = other.predicted_boxes.back().unwrap();
255
256 let epoch_delta = (self.last_updated_epoch as i128 - other.last_updated_epoch as i128)
257 .abs()
258 .try_into()
259 .unwrap();
260
261 let center_dist = Universal2DBox::dist_in_2r(o1, o2);
262
263 self.opts.max_idle_epochs() >= epoch_delta
264 && self
265 .opts
266 .spatio_temporal_constraints
267 .validate(epoch_delta, center_dist)
268 } else {
269 false
270 }
271 }
272
273 fn merge(&mut self, other: &SortAttributes) -> Result<()> {
274 self.last_updated_epoch = other.last_updated_epoch;
275 self.custom_object_id = other.custom_object_id;
276 Ok(())
277 }
278
279 fn baked(&self, _observations: &ObservationsDb<Universal2DBox>) -> Result<TrackStatus> {
280 self.opts.baked(self.scene_id, self.last_updated_epoch)
281 }
282}
283
284#[derive(Debug, Clone)]
287pub struct SortTrack {
288 pub id: u64,
291 pub epoch: usize,
294 pub predicted_bbox: Universal2DBox,
297 pub observed_bbox: Universal2DBox,
300 pub scene_id: u64,
303 pub length: usize,
306 pub voting_type: VotingType,
309 pub custom_object_id: Option<i64>,
312}
313
314#[derive(Debug, Clone)]
317pub struct WastedSortTrack {
318 pub id: u64,
321 pub epoch: usize,
324 pub predicted_bbox: Universal2DBox,
327 pub observed_bbox: Universal2DBox,
330 pub scene_id: u64,
333 pub length: usize,
336 pub predicted_boxes: Vec<Universal2DBox>,
339 pub observed_boxes: Vec<Universal2DBox>,
342}
343
344impl From<Track<SortAttributes, SortMetric, Universal2DBox>> for WastedSortTrack {
345 fn from(track: Track<SortAttributes, SortMetric, Universal2DBox>) -> Self {
346 let attrs = track.get_attributes();
347 WastedSortTrack {
348 id: track.get_track_id(),
349 epoch: attrs.last_updated_epoch,
350 scene_id: attrs.scene_id,
351 length: attrs.track_length,
352 observed_bbox: attrs.observed_boxes.back().unwrap().clone(),
353 predicted_bbox: attrs.predicted_boxes.back().unwrap().clone(),
354 predicted_boxes: attrs.predicted_boxes.clone().into_iter().collect(),
355 observed_boxes: attrs.observed_boxes.clone().into_iter().collect(),
356 }
357 }
358}
359
360#[derive(Default, Debug, Clone, Copy, Deserialize, Serialize)]
361pub enum VotingType {
362 #[default]
363 Visual,
364 Positional,
365}
366
367#[derive(Clone, Default, Copy, Debug, Deserialize, Serialize)]
368pub enum PositionalMetricType {
369 #[default]
370 Mahalanobis,
371 IoU(f32),
372}
373
/// Periodic auto-waste bookkeeping: a period and a running counter.
/// Presumably `counter` advances toward `periodicity`; TODO confirm at the use site.
pub struct AutoWaste {
    /// Period between auto-waste runs.
    pub periodicity: usize,
    /// Running counter.
    pub counter: usize,
}
378
/// Default period for auto-waste, presumably used for `AutoWaste::periodicity`.
pub(crate) const DEFAULT_AUTO_WASTE_PERIODICITY: usize = 100;
/// Threshold for new-track decisions under the Mahalanobis metric (used elsewhere in the crate).
pub(crate) const MAHALANOBIS_NEW_TRACK_THRESHOLD: f32 = 1.0;
381
/// Python bindings for the SORT tracker types.
#[cfg(feature = "python")]
pub mod python {
    use pyo3::prelude::*;

    use crate::utils::bbox::python::PyUniversal2DBox;

    use super::{PositionalMetricType, SortTrack, VotingType, WastedSortTrack};

    /// Python wrapper around PositionalMetricType.
    #[pyclass]
    #[pyo3(name = "PositionalMetricType")]
    #[derive(Clone, Debug)]
    pub struct PyPositionalMetricType(pub PositionalMetricType);

    #[pymethods]
    impl PyPositionalMetricType {
        /// Mahalanobis-distance positional metric.
        #[staticmethod]
        pub fn maha() -> Self {
            PyPositionalMetricType(PositionalMetricType::Mahalanobis)
        }

        /// IoU positional metric; panics unless 0.0 < threshold < 1.0.
        #[staticmethod]
        pub fn iou(threshold: f32) -> Self {
            assert!(
                threshold > 0.0 && threshold < 1.0,
                "Threshold must lay between (0.0 and 1.0)"
            );

            PyPositionalMetricType(PositionalMetricType::IoU(threshold))
        }

        #[classattr]
        const __hash__: Option<Py<PyAny>> = None;

        fn __repr__(&self) -> String {
            format!("{self:?}")
        }

        fn __str__(&self) -> String {
            format!("{self:#?}")
        }
    }

    /// Python wrapper around SortTrack.
    #[pyclass]
    #[pyo3(name = "SortTrack")]
    #[derive(Debug, Clone)]
    #[repr(transparent)]
    pub struct PySortTrack(pub(crate) SortTrack);

    #[pymethods]
    impl PySortTrack {
        #[classattr]
        const __hash__: Option<Py<PyAny>> = None;

        fn __repr__(&self) -> String {
            format!("{self:?}")
        }

        fn __str__(&self) -> String {
            format!("{self:#?}")
        }

        #[getter]
        fn get_id(&self) -> u64 {
            self.0.id
        }

        #[getter]
        fn get_epoch(&self) -> usize {
            self.0.epoch
        }

        #[getter]
        fn get_predicted_bbox(&self) -> PyUniversal2DBox {
            PyUniversal2DBox(self.0.predicted_bbox.clone())
        }

        #[getter]
        fn get_observed_bbox(&self) -> PyUniversal2DBox {
            PyUniversal2DBox(self.0.observed_bbox.clone())
        }

        #[getter]
        fn get_scene_id(&self) -> u64 {
            self.0.scene_id
        }

        #[getter]
        fn get_length(&self) -> usize {
            self.0.length
        }

        #[getter]
        fn get_voting_type(&self) -> PyVotingType {
            PyVotingType(self.0.voting_type)
        }

        #[getter]
        fn get_custom_object_id(&self) -> Option<i64> {
            self.0.custom_object_id
        }
    }

    /// Python wrapper around WastedSortTrack.
    #[pyclass]
    #[pyo3(name = "WastedSortTrack")]
    #[derive(Debug, Clone)]
    #[repr(transparent)]
    pub struct PyWastedSortTrack(pub(crate) WastedSortTrack);

    #[pymethods]
    impl PyWastedSortTrack {
        #[classattr]
        const __hash__: Option<Py<PyAny>> = None;

        fn __repr__(&self) -> String {
            format!("{:?}", self.0)
        }

        fn __str__(&self) -> String {
            format!("{:#?}", self.0)
        }

        #[getter]
        fn id(&self) -> u64 {
            self.0.id
        }

        #[getter]
        fn epoch(&self) -> usize {
            self.0.epoch
        }

        #[getter]
        fn predicted_bbox(&self) -> PyUniversal2DBox {
            PyUniversal2DBox(self.0.predicted_bbox.clone())
        }

        #[getter]
        fn observed_bbox(&self) -> PyUniversal2DBox {
            PyUniversal2DBox(self.0.observed_bbox.clone())
        }

        #[getter]
        fn scene_id(&self) -> u64 {
            self.0.scene_id
        }

        #[getter]
        fn length(&self) -> usize {
            self.0.length
        }

        /// Retained predicted boxes, oldest first.
        #[getter]
        fn predicted_boxes(&self) -> Vec<PyUniversal2DBox> {
            // Wrap element-wise. Transmuting Vec<Universal2DBox> into
            // Vec<PyUniversal2DBox> relies on Vec's unspecified layout and is not sound.
            self.0
                .predicted_boxes
                .iter()
                .cloned()
                .map(PyUniversal2DBox)
                .collect()
        }

        /// Retained observed boxes, oldest first.
        #[getter]
        fn observed_boxes(&self) -> Vec<PyUniversal2DBox> {
            // Same safe element-wise wrapping as predicted_boxes (no transmute of Vec).
            self.0
                .observed_boxes
                .iter()
                .cloned()
                .map(PyUniversal2DBox)
                .collect()
        }
    }

    /// Python wrapper around VotingType.
    #[pyclass]
    #[pyo3(name = "VotingType")]
    #[derive(Default, Debug, Clone, Copy)]
    pub struct PyVotingType(pub(crate) VotingType);

    #[pymethods]
    impl PyVotingType {
        #[classattr]
        const __hash__: Option<Py<PyAny>> = None;

        fn __repr__(&self) -> String {
            format!("{self:?}")
        }

        fn __str__(&self) -> String {
            format!("{self:#?}")
        }
    }
}
563
#[cfg(test)]
mod track_tests {
    use crate::prelude::{NoopNotifier, ObservationBuilder, TrackBuilder};
    use crate::trackers::sort::metric::{SortMetric, DEFAULT_MINIMAL_SORT_CONFIDENCE};
    use crate::trackers::sort::PositionalMetricType::IoU;
    use crate::trackers::sort::{SortAttributes, DEFAULT_SORT_IOU_THRESHOLD};
    use crate::utils::bbox::BoundingBox;
    use crate::utils::kalman::kalman_2d_box::Universal2DBoxKalmanFilter;

    /// Builds a track, checks its initial Kalman state and history, then merges a
    /// second track into it and checks that the history grew.
    #[test]
    fn construct() {
        let first_box = BoundingBox::new(1.0, 1.0, 10.0, 15.0);
        let second_box = BoundingBox::new(1.1, 1.3, 10.0, 15.0);

        let filter = Universal2DBoxKalmanFilter::default();
        let initial_state = filter.initiate(&first_box.into());

        // Both tracks share the same attribute/metric/notifier setup; only the id and
        // the seed observation differ.
        let make_track = |track_id: u64, seed: BoundingBox| {
            TrackBuilder::new(track_id)
                .attributes(SortAttributes::default())
                .metric(SortMetric::new(
                    IoU(DEFAULT_SORT_IOU_THRESHOLD),
                    DEFAULT_MINIMAL_SORT_CONFIDENCE,
                ))
                .notifier(NoopNotifier)
                .observation(
                    ObservationBuilder::new(0)
                        .observation_attributes(seed.into())
                        .build(),
                )
                .build()
                .unwrap()
        };

        let mut t1 = make_track(1, first_box);

        let fresh = t1.get_attributes();
        assert!(fresh.state.is_some());
        assert_eq!(fresh.predicted_boxes.len(), 1);
        assert_eq!(fresh.observed_boxes.len(), 1);
        assert_eq!(t1.get_merge_history().len(), 1);
        assert_eq!(fresh.predicted_boxes[0], first_box.into());

        let predicted_state = filter.predict(&initial_state);
        assert_eq!(BoundingBox::try_from(predicted_state).unwrap(), first_box);

        let t2 = make_track(2, second_box);
        t1.merge(&t2, &[0], false).unwrap();

        let merged = t1.get_attributes();
        assert!(merged.state.is_some());
        assert_eq!(merged.predicted_boxes.len(), 2);
        assert_eq!(merged.observed_boxes.len(), 2);
    }
}