// cecile_supercool_tracker/trackers/sort.rs

1use crate::track::{
2    LookupRequest, ObservationsDb, Track, TrackAttributes, TrackAttributesUpdate, TrackStatus,
3};
4use crate::trackers::epoch_db::EpochDb;
5use crate::trackers::kalman_prediction::TrackAttributesKalmanPrediction;
6use crate::trackers::spatio_temporal_constraints::SpatioTemporalConstraints;
7use crate::utils::bbox::Universal2DBox;
8use crate::utils::kalman::kalman_2d_box::DIM_2D_BOX_X2;
9use crate::utils::kalman::KalmanState;
10use anyhow::Result;
11use serde::{Serialize, Deserialize};
12
13use std::collections::{HashMap, VecDeque};
14use std::sync::{Arc, RwLock};
15
16use self::metric::SortMetric;
17
18/// SORT metric implementation with IoU and Mahalanobis distances
19pub mod metric;
20
21/// SORT implementation with a very tiny interface
22pub mod simple_api;
23
24/// Voting engine with Hungarian algorithm
25///
26pub mod voting;
27
28/// SORT tracker with Batch API
29pub mod batch_api;
30
/// IoU threshold used by default — the value picked by the SORT author in the
/// original reference repository.
pub const DEFAULT_SORT_IOU_THRESHOLD: f32 = 0.3;
33
34#[derive(Debug)]
35pub struct SortAttributesOptions {
36    /// The map that stores current epochs for the scene_id
37    epoch_db: Option<RwLock<HashMap<u64, usize>>>,
38    /// The maximum number of epochs without update while the track is alive
39    max_idle_epochs: usize,
40    /// The maximum length of collected objects for the track
41    pub history_length: usize,
42    pub spatio_temporal_constraints: SpatioTemporalConstraints,
43    pub position_weight: f32,
44    pub velocity_weight: f32,
45}
46
47impl Default for SortAttributesOptions {
48    fn default() -> Self {
49        Self {
50            epoch_db: None,
51            max_idle_epochs: 0,
52            history_length: 0,
53            spatio_temporal_constraints: SpatioTemporalConstraints::default(),
54            position_weight: 1.0 / 20.0,
55            velocity_weight: 1.0 / 160.0,
56        }
57    }
58}
59
60impl EpochDb for SortAttributesOptions {
61    fn epoch_db(&self) -> &Option<RwLock<HashMap<u64, usize>>> {
62        &self.epoch_db
63    }
64
65    fn max_idle_epochs(&self) -> usize {
66        self.max_idle_epochs
67    }
68}
69
70impl SortAttributesOptions {
71    pub fn new(
72        epoch_db: Option<RwLock<HashMap<u64, usize>>>,
73        max_idle_epochs: usize,
74        history_length: usize,
75        spatio_temporal_constraints: SpatioTemporalConstraints,
76        position_weight: f32,
77        velocity_weight: f32,
78    ) -> Self {
79        Self {
80            epoch_db,
81            max_idle_epochs,
82            history_length,
83            spatio_temporal_constraints,
84            position_weight,
85            velocity_weight,
86        }
87    }
88}
89
90/// Attributes associated with SORT track
91///
92#[derive(Debug, Clone)]
93pub struct SortAttributes {
94    /// The lastly predicted boxes
95    pub predicted_boxes: VecDeque<Universal2DBox>,
96    /// The lastly observed boxes
97    pub observed_boxes: VecDeque<Universal2DBox>,
98    /// The epoch when the track was lastly updated
99    pub last_updated_epoch: usize,
100    /// The length of the track
101    pub track_length: usize,
102    /// Customer-specific scene identifier that splits the objects by classes, realms, etc.
103    pub scene_id: u64,
104    /// Custom object id
105    pub custom_object_id: Option<i64>,
106
107    /// Kalman filter predicted state
108    state: Option<KalmanState<{ DIM_2D_BOX_X2 }>>,
109    opts: Arc<SortAttributesOptions>,
110}
111
112impl TrackAttributesKalmanPrediction for SortAttributes {
113    fn get_state(&self) -> Option<KalmanState<{ DIM_2D_BOX_X2 }>> {
114        self.state
115    }
116
117    fn set_state(&mut self, state: KalmanState<{ DIM_2D_BOX_X2 }>) {
118        self.state = Some(state);
119    }
120
121    fn get_position_weight(&self) -> f32 {
122        self.opts.position_weight
123    }
124
125    fn get_velocity_weight(&self) -> f32 {
126        self.opts.velocity_weight
127    }
128}
129
130impl Default for SortAttributes {
131    fn default() -> Self {
132        Self {
133            predicted_boxes: VecDeque::default(),
134            observed_boxes: VecDeque::default(),
135            last_updated_epoch: 0,
136            track_length: 0,
137            scene_id: 0,
138            state: None,
139            custom_object_id: None,
140            opts: Arc::new(SortAttributesOptions::default()),
141        }
142    }
143}
144
145impl SortAttributes {
146    /// Creates new attributes with limited history
147    ///
148    /// # Parameters
149    /// * `opts` - options
150    ///
151    pub fn new(opts: Arc<SortAttributesOptions>) -> Self {
152        Self {
153            opts,
154            ..Default::default()
155        }
156    }
157
158    fn update_history(
159        &mut self,
160        observation_bbox: &Universal2DBox,
161        predicted_bbox: &Universal2DBox,
162    ) {
163        self.track_length += 1;
164
165        self.observed_boxes.push_back(observation_bbox.clone());
166        self.predicted_boxes.push_back(predicted_bbox.clone());
167
168        if self.opts.history_length > 0 && self.observed_boxes.len() > self.opts.history_length {
169            self.observed_boxes.pop_front();
170            self.predicted_boxes.pop_front();
171        }
172    }
173}
174
/// Update for [`SortAttributes`]: overwrites the epoch, scene and custom object id.
#[derive(Clone, Debug, Default)]
pub struct SortAttributesUpdate {
    /// Epoch written to `last_updated_epoch`
    epoch: usize,
    /// Scene written to `scene_id`
    scene_id: u64,
    /// Value written to `custom_object_id`
    custom_object_id: Option<i64>,
}
183
/// Lookup requests understood by [`SortAttributes`].
#[derive(Clone, Debug)]
pub enum SortLookup {
    /// Selects tracks of the given scene whose last update epoch differs from
    /// that scene's current epoch
    IdleLookup(u64),
}
190
191impl LookupRequest<SortAttributes, Universal2DBox> for SortLookup {
192    fn lookup(
193        &self,
194        attributes: &SortAttributes,
195        _observations: &ObservationsDb<Universal2DBox>,
196        _merge_history: &[u64],
197    ) -> bool {
198        match self {
199            SortLookup::IdleLookup(scene_id) => {
200                *scene_id == attributes.scene_id
201                    && attributes.last_updated_epoch
202                        != attributes
203                            .opts
204                            .current_epoch_with_scene(attributes.scene_id)
205                            .unwrap()
206            }
207        }
208    }
209}
210
211impl SortAttributesUpdate {
212    /// update epoch with scene_id == 0
213    ///
214    /// # Parameters
215    /// * `epoch` - epoch update
216    ///
217    pub fn new(epoch: usize, custom_object_id: Option<i64>) -> Self {
218        Self {
219            epoch,
220            scene_id: 0,
221            custom_object_id,
222        }
223    }
224    /// update epoch for a specific scene_id
225    ///
226    /// # Parameters
227    /// * `epoch` - epoch
228    /// * `scene_id` - scene_id
229    pub fn new_with_scene(epoch: usize, scene_id: u64, custom_object_id: Option<i64>) -> Self {
230        Self {
231            epoch,
232            scene_id,
233            custom_object_id,
234        }
235    }
236}
237
238impl TrackAttributesUpdate<SortAttributes> for SortAttributesUpdate {
239    fn apply(&self, attrs: &mut SortAttributes) -> Result<()> {
240        attrs.last_updated_epoch = self.epoch;
241        attrs.scene_id = self.scene_id;
242        attrs.custom_object_id = self.custom_object_id;
243        Ok(())
244    }
245}
246
247impl TrackAttributes<SortAttributes, Universal2DBox> for SortAttributes {
248    type Update = SortAttributesUpdate;
249    type Lookup = SortLookup;
250
251    fn compatible(&self, other: &SortAttributes) -> bool {
252        if self.scene_id == other.scene_id {
253            let o1 = self.predicted_boxes.back().unwrap();
254            let o2 = other.predicted_boxes.back().unwrap();
255
256            let epoch_delta = (self.last_updated_epoch as i128 - other.last_updated_epoch as i128)
257                .abs()
258                .try_into()
259                .unwrap();
260
261            let center_dist = Universal2DBox::dist_in_2r(o1, o2);
262
263            self.opts.max_idle_epochs() >= epoch_delta
264                && self
265                    .opts
266                    .spatio_temporal_constraints
267                    .validate(epoch_delta, center_dist)
268        } else {
269            false
270        }
271    }
272
273    fn merge(&mut self, other: &SortAttributes) -> Result<()> {
274        self.last_updated_epoch = other.last_updated_epoch;
275        self.custom_object_id = other.custom_object_id;
276        Ok(())
277    }
278
279    fn baked(&self, _observations: &ObservationsDb<Universal2DBox>) -> Result<TrackStatus> {
280        self.opts.baked(self.scene_id, self.last_updated_epoch)
281    }
282}
283
284/// Online track structure that contains tracking information for the last tracker epoch
285///
286#[derive(Debug, Clone)]
287pub struct SortTrack {
288    /// id of the track
289    ///
290    pub id: u64,
291    /// when the track was lastly updated
292    ///
293    pub epoch: usize,
294    /// the bbox predicted by KF
295    ///
296    pub predicted_bbox: Universal2DBox,
297    /// the bbox passed by detector
298    ///
299    pub observed_bbox: Universal2DBox,
300    /// user-defined scene id that splits tracking space on isolated realms
301    ///
302    pub scene_id: u64,
303    /// current track length
304    ///
305    pub length: usize,
306    /// what kind of voting was led to the current merge
307    ///
308    pub voting_type: VotingType,
309    /// custom object id passed by the user to find the track easily
310    ///
311    pub custom_object_id: Option<i64>,
312}
313
314/// Online track structure that contains tracking information for the last tracker epoch
315///
316#[derive(Debug, Clone)]
317pub struct WastedSortTrack {
318    /// id of the track
319    ///
320    pub id: u64,
321    /// when the track was lastly updated
322    ///
323    pub epoch: usize,
324    /// the bbox predicted by KF
325    ///
326    pub predicted_bbox: Universal2DBox,
327    /// the bbox passed by detector
328    ///
329    pub observed_bbox: Universal2DBox,
330    /// user-defined scene id that splits tracking space on isolated realms
331    ///
332    pub scene_id: u64,
333    /// current track length
334    ///
335    pub length: usize,
336    /// history of predicted boxes
337    ///
338    pub predicted_boxes: Vec<Universal2DBox>,
339    /// history of observed boxes
340    ///
341    pub observed_boxes: Vec<Universal2DBox>,
342}
343
344impl From<Track<SortAttributes, SortMetric, Universal2DBox>> for WastedSortTrack {
345    fn from(track: Track<SortAttributes, SortMetric, Universal2DBox>) -> Self {
346        let attrs = track.get_attributes();
347        WastedSortTrack {
348            id: track.get_track_id(),
349            epoch: attrs.last_updated_epoch,
350            scene_id: attrs.scene_id,
351            length: attrs.track_length,
352            observed_bbox: attrs.observed_boxes.back().unwrap().clone(),
353            predicted_bbox: attrs.predicted_boxes.back().unwrap().clone(),
354            predicted_boxes: attrs.predicted_boxes.clone().into_iter().collect(),
355            observed_boxes: attrs.observed_boxes.clone().into_iter().collect(),
356        }
357    }
358}
359
360#[derive(Default, Debug, Clone, Copy, Deserialize, Serialize)]
361pub enum VotingType {
362    #[default]
363    Visual,
364    Positional,
365}
366
367#[derive(Clone, Default, Copy, Debug, Deserialize, Serialize)]
368pub enum PositionalMetricType {
369    #[default]
370    Mahalanobis,
371    IoU(f32),
372}
373
/// Bookkeeping for periodic automatic waste collection.
///
/// NOTE(review): field semantics below are inferred from names — confirm
/// against the trackers that use this type.
#[derive(Debug, Clone, Copy)]
pub struct AutoWaste {
    /// Presumably the number of calls between automatic waste runs
    pub periodicity: usize,
    /// Presumably the number of calls since the last waste run
    pub counter: usize,
}

/// Default value for [`AutoWaste::periodicity`] (assumed; verify at use sites).
pub(crate) const DEFAULT_AUTO_WASTE_PERIODICITY: usize = 100;
/// Mahalanobis threshold used when deciding on new tracks (meaning assumed from
/// the name — confirm in the SORT trackers).
pub(crate) const MAHALANOBIS_NEW_TRACK_THRESHOLD: f32 = 1.0;
381
/// PyO3 wrappers that expose the SORT track types to Python.
#[cfg(feature = "python")]
pub mod python {
    use pyo3::prelude::*;

    use crate::utils::bbox::python::PyUniversal2DBox;

    use super::{PositionalMetricType, SortTrack, VotingType, WastedSortTrack};

    /// Python-facing wrapper around [`PositionalMetricType`].
    #[pyclass]
    #[pyo3(name = "PositionalMetricType")]
    #[derive(Clone, Debug)]
    pub struct PyPositionalMetricType(pub PositionalMetricType);

    #[pymethods]
    impl PyPositionalMetricType {
        /// Mahalanobis positional metric.
        #[staticmethod]
        pub fn maha() -> Self {
            PyPositionalMetricType(PositionalMetricType::Mahalanobis)
        }

        /// IoU positional metric with the given threshold.
        ///
        /// # Panics
        /// Panics unless `0.0 < threshold < 1.0`; NaN fails the check too.
        #[staticmethod]
        pub fn iou(threshold: f32) -> Self {
            assert!(
                threshold > 0.0 && threshold < 1.0,
                "Threshold must lay between (0.0 and 1.0)"
            );

            PyPositionalMetricType(PositionalMetricType::IoU(threshold))
        }

        // `__hash__ = None` makes the class unhashable on the Python side.
        #[classattr]
        const __hash__: Option<Py<PyAny>> = None;

        fn __repr__(&self) -> String {
            format!("{self:?}")
        }

        fn __str__(&self) -> String {
            format!("{self:#?}")
        }
    }

    /// Python-facing wrapper around [`SortTrack`].
    #[pyclass]
    #[pyo3(name = "SortTrack")]
    #[derive(Debug, Clone)]
    #[repr(transparent)]
    pub struct PySortTrack(pub(crate) SortTrack);

    #[pymethods]
    impl PySortTrack {
        // `__hash__ = None` makes the class unhashable on the Python side.
        #[classattr]
        const __hash__: Option<Py<PyAny>> = None;

        fn __repr__(&self) -> String {
            format!("{self:?}")
        }

        fn __str__(&self) -> String {
            format!("{self:#?}")
        }

        /// Track identifier.
        #[getter]
        fn get_id(&self) -> u64 {
            self.0.id
        }

        /// Epoch of the most recent update.
        #[getter]
        fn get_epoch(&self) -> usize {
            self.0.epoch
        }

        /// Box predicted by the Kalman filter.
        #[getter]
        fn get_predicted_bbox(&self) -> PyUniversal2DBox {
            PyUniversal2DBox(self.0.predicted_bbox.clone())
        }

        /// Box supplied by the detector.
        #[getter]
        fn get_observed_bbox(&self) -> PyUniversal2DBox {
            PyUniversal2DBox(self.0.observed_bbox.clone())
        }

        /// User-defined scene id.
        #[getter]
        fn get_scene_id(&self) -> u64 {
            self.0.scene_id
        }

        /// Current track length.
        #[getter]
        fn get_length(&self) -> usize {
            self.0.length
        }

        /// Kind of voting that led to the current merge.
        #[getter]
        fn get_voting_type(&self) -> PyVotingType {
            PyVotingType(self.0.voting_type)
        }

        /// Custom object id supplied by the user, if any.
        #[getter]
        fn get_custom_object_id(&self) -> Option<i64> {
            self.0.custom_object_id
        }
    }

    /// Python-facing wrapper around [`WastedSortTrack`].
    #[pyclass]
    #[pyo3(name = "WastedSortTrack")]
    #[derive(Debug, Clone)]
    #[repr(transparent)]
    pub struct PyWastedSortTrack(pub(crate) WastedSortTrack);

    #[pymethods]
    impl PyWastedSortTrack {
        // `__hash__ = None` makes the class unhashable on the Python side.
        #[classattr]
        const __hash__: Option<Py<PyAny>> = None;

        fn __repr__(&self) -> String {
            format!("{:?}", self.0)
        }

        fn __str__(&self) -> String {
            format!("{:#?}", self.0)
        }

        /// Track identifier.
        #[getter]
        fn id(&self) -> u64 {
            self.0.id
        }

        /// Epoch of the last update.
        #[getter]
        fn epoch(&self) -> usize {
            self.0.epoch
        }

        /// Latest box predicted by the Kalman filter.
        #[getter]
        fn predicted_bbox(&self) -> PyUniversal2DBox {
            PyUniversal2DBox(self.0.predicted_bbox.clone())
        }

        /// Latest box supplied by the detector.
        #[getter]
        fn observed_bbox(&self) -> PyUniversal2DBox {
            PyUniversal2DBox(self.0.observed_bbox.clone())
        }

        /// User-defined scene id.
        #[getter]
        fn scene_id(&self) -> u64 {
            self.0.scene_id
        }

        /// Track length at conversion time.
        #[getter]
        fn length(&self) -> usize {
            self.0.length
        }

        /// Retained predicted boxes, oldest first.
        #[getter]
        fn predicted_boxes(&self) -> Vec<PyUniversal2DBox> {
            // Wrap element by element: transmuting `Vec<Universal2DBox>` into
            // `Vec<PyUniversal2DBox>` depends on `Vec`'s unspecified layout.
            self.0
                .predicted_boxes
                .iter()
                .cloned()
                .map(PyUniversal2DBox)
                .collect()
        }

        /// Retained observed boxes, oldest first.
        #[getter]
        fn observed_boxes(&self) -> Vec<PyUniversal2DBox> {
            // See `predicted_boxes`: safe element-wise wrapping instead of transmute.
            self.0
                .observed_boxes
                .iter()
                .cloned()
                .map(PyUniversal2DBox)
                .collect()
        }
    }

    /// Python-facing wrapper around [`VotingType`].
    #[pyclass]
    #[pyo3(name = "VotingType")]
    #[derive(Default, Debug, Clone, Copy)]
    pub struct PyVotingType(pub(crate) VotingType);

    #[pymethods]
    impl PyVotingType {
        // `__hash__ = None` makes the class unhashable on the Python side.
        #[classattr]
        const __hash__: Option<Py<PyAny>> = None;

        fn __repr__(&self) -> String {
            format!("{self:?}")
        }

        fn __str__(&self) -> String {
            format!("{self:#?}")
        }
    }
}
563
#[cfg(test)]
mod track_tests {
    use crate::prelude::{NoopNotifier, ObservationBuilder, TrackBuilder};
    use crate::trackers::sort::metric::{SortMetric, DEFAULT_MINIMAL_SORT_CONFIDENCE};
    use crate::trackers::sort::PositionalMetricType::IoU;
    use crate::trackers::sort::{SortAttributes, DEFAULT_SORT_IOU_THRESHOLD};
    use crate::utils::bbox::BoundingBox;
    use crate::utils::kalman::kalman_2d_box::Universal2DBoxKalmanFilter;

    /// Builds two single-observation tracks, checks the first one, then merges
    /// the second into it and checks that the history grew.
    #[test]
    fn construct() {
        let first_box = BoundingBox::new(1.0, 1.0, 10.0, 15.0);
        let second_box = BoundingBox::new(1.1, 1.3, 10.0, 15.0);

        let filter = Universal2DBoxKalmanFilter::default();
        let initial_state = filter.initiate(&first_box.into());

        // All tracks here share the same attributes, metric and notifier.
        let build_track = |track_id, bbox: BoundingBox| {
            TrackBuilder::new(track_id)
                .attributes(SortAttributes::default())
                .metric(SortMetric::new(
                    IoU(DEFAULT_SORT_IOU_THRESHOLD),
                    DEFAULT_MINIMAL_SORT_CONFIDENCE,
                ))
                .notifier(NoopNotifier)
                .observation(
                    ObservationBuilder::new(0)
                        .observation_attributes(bbox.into())
                        .build(),
                )
                .build()
                .unwrap()
        };

        let mut t1 = build_track(1, first_box);

        {
            let attrs = t1.get_attributes();
            assert!(attrs.state.is_some());
            assert_eq!(attrs.predicted_boxes.len(), 1);
            assert_eq!(attrs.observed_boxes.len(), 1);
        }
        assert_eq!(t1.get_merge_history().len(), 1);
        assert_eq!(t1.get_attributes().predicted_boxes[0], first_box.into());

        // Predicting from the initial state must reproduce the initial box.
        let predicted_state = filter.predict(&initial_state);
        assert_eq!(BoundingBox::try_from(predicted_state).unwrap(), first_box);

        let t2 = build_track(2, second_box);
        t1.merge(&t2, &[0], false).unwrap();

        let merged = t1.get_attributes();
        assert!(merged.state.is_some());
        assert_eq!(merged.predicted_boxes.len(), 2);
        assert_eq!(merged.observed_boxes.len(), 2);
    }
}