cecile_supercool_tracker/trackers/sort/
simple_api.rs

1use std::collections::HashMap;
2use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
3
4use rand::Rng;
5
6use crate::prelude::{NoopNotifier, ObservationBuilder, TrackStoreBuilder};
7use crate::store::TrackStore;
8use crate::track::Track;
9use crate::trackers::epoch_db::EpochDb;
10use crate::trackers::sort::{
11    metric::SortMetric, voting::SortVoting, AutoWaste, PositionalMetricType, SortAttributes,
12    SortAttributesOptions, SortAttributesUpdate, SortLookup, SortTrack, VotingType,
13    DEFAULT_AUTO_WASTE_PERIODICITY, MAHALANOBIS_NEW_TRACK_THRESHOLD,
14};
15use crate::trackers::spatio_temporal_constraints::SpatioTemporalConstraints;
16use crate::trackers::tracker_api::TrackerAPI;
17use crate::utils::bbox::Universal2DBox;
18use crate::voting::Voting;
19
20/// Easy to use SORT tracker implementation
21///
22pub struct Sort {
23    store: RwLock<TrackStore<SortAttributes, SortMetric, Universal2DBox>>,
24    wasted_store: RwLock<TrackStore<SortAttributes, SortMetric, Universal2DBox>>,
25    method: PositionalMetricType,
26    opts: Arc<SortAttributesOptions>,
27    auto_waste: AutoWaste,
28    track_id: u64,
29}
30
31impl Sort {
32    /// Creates new tracker
33    ///
34    /// # Parameters
35    /// * `shards` - amount of cpu threads to process the data, keep 1 for up to 100 simultaneously tracked objects, try it before setting high - higher numbers may lead to unexpected latencies.
36    /// * `bbox_history` - how many last bboxes are kept within stored track (valuable for offline trackers), for online - keep 1
37    /// * `max_idle_epochs` - how long track survives without being updated
38    /// * `threshold` - how low IoU must be to establish a new track (default from the authors of SORT is 0.3)
39    ///
40    #[allow(clippy::too_many_arguments)]
41    pub fn new(
42        shards: usize,
43        bbox_history: usize,
44        max_idle_epochs: usize,
45        method: PositionalMetricType,
46        min_confidence: f32,
47        spatio_temporal_constraints: Option<SpatioTemporalConstraints>,
48        kalman_position_weight: f32,
49        kalman_velocity_weight: f32,
50    ) -> Self {
51        assert!(bbox_history > 0);
52        let epoch_db = RwLock::new(HashMap::default());
53        let opts = Arc::new(SortAttributesOptions::new(
54            Some(epoch_db),
55            max_idle_epochs,
56            bbox_history,
57            spatio_temporal_constraints.unwrap_or_default(),
58            kalman_position_weight,
59            kalman_velocity_weight,
60        ));
61        let store = RwLock::new(
62            TrackStoreBuilder::new(shards)
63                .default_attributes(SortAttributes::new(opts.clone()))
64                .metric(SortMetric::new(method, min_confidence))
65                .notifier(NoopNotifier)
66                .build(),
67        );
68
69        let wasted_store = RwLock::new(
70            TrackStoreBuilder::new(shards)
71                .default_attributes(SortAttributes::new(opts.clone()))
72                .metric(SortMetric::new(method, min_confidence))
73                .notifier(NoopNotifier)
74                .build(),
75        );
76
77        Self {
78            store,
79            track_id: 0,
80            wasted_store,
81            method,
82            opts,
83            auto_waste: AutoWaste {
84                periodicity: DEFAULT_AUTO_WASTE_PERIODICITY,
85                counter: DEFAULT_AUTO_WASTE_PERIODICITY,
86            },
87        }
88    }
89
90    /// Receive tracking information for observed bboxes of `scene_id` == 0
91    ///
92    /// # Parameters
93    /// * `bboxes` - bounding boxes received from a detector
94    ///
95    pub fn predict(&mut self, bboxes: &[(Universal2DBox, Option<i64>)]) -> Vec<SortTrack> {
96        self.predict_with_scene(0, bboxes)
97    }
98
99    fn gen_track_id(&mut self) -> u64 {
100        self.track_id += 1;
101        self.track_id
102    }
103
104    /// Receive tracking information for observed bboxes of `scene_id`
105    ///
106    /// # Parameters
107    /// * `scene_id` - scene id provided by a user (class, camera id, etc...)
108    /// * `bboxes` - bounding boxes received from a detector
109    ///
110    pub fn predict_with_scene(
111        &mut self,
112        scene_id: u64,
113        bboxes: &[(Universal2DBox, Option<i64>)],
114    ) -> Vec<SortTrack> {
115        if self.auto_waste.counter == 0 {
116            self.auto_waste();
117            self.auto_waste.counter = self.auto_waste.periodicity;
118        } else {
119            self.auto_waste.counter -= 1;
120        }
121
122        let mut rng = rand::thread_rng();
123        let epoch = self.opts.next_epoch(scene_id).unwrap();
124
125        let tracks = bboxes
126            .iter()
127            .map(|(bb, custom_object_id)| {
128                self.store
129                    .read()
130                    .unwrap()
131                    .new_track(rng.gen())
132                    .observation(
133                        ObservationBuilder::new(0)
134                            .observation_attributes(bb.clone())
135                            .track_attributes_update(SortAttributesUpdate::new_with_scene(
136                                epoch,
137                                scene_id,
138                                *custom_object_id,
139                            ))
140                            .build(),
141                    )
142                    .build()
143                    .unwrap()
144            })
145            .collect::<Vec<_>>();
146        let num_candidates = tracks.len();
147        let (dists, errs) =
148            self.store
149                .write()
150                .unwrap()
151                .foreign_track_distances(tracks.clone(), 0, false);
152        assert!(errs.all().is_empty());
153        let dists = dists.all();
154        let voting = SortVoting::new(
155            match self.method {
156                PositionalMetricType::Mahalanobis => MAHALANOBIS_NEW_TRACK_THRESHOLD,
157                PositionalMetricType::IoU(t) => t,
158            },
159            num_candidates,
160            self.store.read().unwrap().shard_stats().iter().sum(),
161        );
162        let winners = voting.winners(dists);
163        let mut res = Vec::default();
164
165        for mut t in tracks {
166            let source = t.get_track_id();
167            let track_id: u64 = if let Some(dest) = winners.get(&source) {
168                let dest = dest[0];
169                if dest == source {
170                    let track_id = self.gen_track_id();
171                    t.set_track_id(track_id);
172                    self.store.write().unwrap().add_track(t).unwrap();
173                    track_id
174                } else {
175                    self.store
176                        .write()
177                        .unwrap()
178                        .merge_external(dest, &t, Some(&[0]), false)
179                        .unwrap();
180                    dest
181                }
182            } else {
183                let track_id = self.gen_track_id();
184                t.set_track_id(track_id);
185                self.store.write().unwrap().add_track(t).unwrap();
186                track_id
187            };
188
189            let lock = self.store.read().unwrap();
190            let store = lock.get_store(track_id as usize);
191            let track = store.get(&track_id).unwrap();
192            res.push(SortTrack::from(track));
193        }
194
195        res
196    }
197
198    pub fn idle_tracks(&mut self) -> Vec<SortTrack> {
199        self.idle_tracks_with_scene(0)
200    }
201
202    pub fn idle_tracks_with_scene(&mut self, scene_id: u64) -> Vec<SortTrack> {
203        let store = self.store.read().unwrap();
204
205        store
206            .lookup(SortLookup::IdleLookup(scene_id))
207            .iter()
208            .map(|(track_id, _status)| {
209                let shard = store.get_store(*track_id as usize);
210                let track = shard.get(track_id).unwrap();
211                SortTrack::from(track)
212            })
213            .collect()
214    }
215}
216
217impl TrackerAPI<SortAttributes, SortMetric, Universal2DBox, SortAttributesOptions, NoopNotifier>
218    for Sort
219{
220    fn get_auto_waste_obj_mut(&mut self) -> &mut AutoWaste {
221        &mut self.auto_waste
222    }
223
224    fn get_opts(&self) -> &SortAttributesOptions {
225        &self.opts
226    }
227
228    fn get_main_store_mut(
229        &mut self,
230    ) -> RwLockWriteGuard<TrackStore<SortAttributes, SortMetric, Universal2DBox, NoopNotifier>>
231    {
232        self.store.write().unwrap()
233    }
234
235    fn get_wasted_store_mut(
236        &mut self,
237    ) -> RwLockWriteGuard<TrackStore<SortAttributes, SortMetric, Universal2DBox, NoopNotifier>>
238    {
239        self.wasted_store.write().unwrap()
240    }
241
242    fn get_main_store(
243        &self,
244    ) -> RwLockReadGuard<TrackStore<SortAttributes, SortMetric, Universal2DBox, NoopNotifier>> {
245        self.store.read().unwrap()
246    }
247
248    fn get_wasted_store(
249        &self,
250    ) -> RwLockReadGuard<TrackStore<SortAttributes, SortMetric, Universal2DBox, NoopNotifier>> {
251        self.wasted_store.read().unwrap()
252    }
253}
254
255impl From<&Track<SortAttributes, SortMetric, Universal2DBox>> for SortTrack {
256    fn from(track: &Track<SortAttributes, SortMetric, Universal2DBox>) -> Self {
257        let attrs = track.get_attributes();
258        SortTrack {
259            id: track.get_track_id(),
260            custom_object_id: attrs.custom_object_id,
261            voting_type: VotingType::Positional,
262            epoch: attrs.last_updated_epoch,
263            scene_id: attrs.scene_id,
264            observed_bbox: attrs.observed_boxes.back().unwrap().clone(),
265            predicted_bbox: attrs.predicted_boxes.back().unwrap().clone(),
266            length: attrs.track_length,
267        }
268    }
269}
270
#[cfg(test)]
mod tests {
    use crate::trackers::sort::metric::DEFAULT_MINIMAL_SORT_CONFIDENCE;
    use crate::trackers::sort::simple_api::Sort;
    use crate::trackers::sort::PositionalMetricType::IoU;
    use crate::trackers::sort::DEFAULT_SORT_IOU_THRESHOLD;
    use crate::trackers::tracker_api::TrackerAPI;
    use crate::utils::bbox::BoundingBox;

    /// Tracker shared by every test: one shard, 10 bboxes of history, `max_idle_epochs` = 2,
    /// IoU metric with the default threshold and the default Kalman weights.
    fn new_tracker() -> Sort {
        Sort::new(
            1,
            10,
            2,
            IoU(DEFAULT_SORT_IOU_THRESHOLD),
            DEFAULT_MINIMAL_SORT_CONFIDENCE,
            None,
            1.0 / 20.0,
            1.0 / 160.0,
        )
    }

    /// Number of tracks currently held in the wasted store, summed over all shards.
    fn wasted_store_len(t: &Sort) -> usize {
        t.wasted_store.read().unwrap().shard_stats().iter().sum()
    }

    #[test]
    fn sort() {
        let mut t = new_tracker();
        assert_eq!(t.current_epoch(), 0);
        let bb = BoundingBox::new(0.0, 0.0, 10.0, 20.0);
        let v = t.predict(&[(bb.into(), None)]);
        let wasted = t.wasted();
        assert!(wasted.is_empty());
        assert_eq!(v.len(), 1);
        let v = v[0].clone();
        let track_id = v.id;
        assert_eq!(v.custom_object_id, None);
        assert_eq!(v.length, 1);
        assert_eq!(v.observed_bbox, bb.into());
        assert_eq!(v.epoch, 1);
        assert_eq!(t.current_epoch(), 1);

        let bb = BoundingBox::new(0.1, 0.1, 10.1, 20.0);
        let v = t.predict(&[(bb.into(), Some(2))]);
        let wasted = t.wasted();
        assert!(wasted.is_empty());
        assert_eq!(v.len(), 1);
        let v = v[0].clone();
        assert_eq!(v.custom_object_id, Some(2));
        assert_eq!(v.id, track_id);
        assert_eq!(v.length, 2);
        assert_eq!(v.observed_bbox, bb.into());
        assert_eq!(v.epoch, 2);
        assert_eq!(t.current_epoch(), 2);

        let bb = BoundingBox::new(10.1, 10.1, 10.1, 20.0);
        let v = t.predict(&[(bb.into(), Some(3))]);
        assert_eq!(v.len(), 1);
        let v = v[0].clone();
        assert_eq!(v.custom_object_id, Some(3));
        assert_ne!(v.id, track_id);
        let wasted = t.wasted();
        assert!(wasted.is_empty());
        assert_eq!(t.current_epoch(), 3);

        let bb = t.predict(&[]);
        assert!(bb.is_empty());
        let wasted = t.wasted();
        assert!(wasted.is_empty());
        assert_eq!(t.current_epoch(), 4);

        // The first track has now been idle for longer than `max_idle_epochs`.
        let bb = t.predict(&[]);
        assert!(bb.is_empty());
        let wasted = t.wasted();
        assert_eq!(wasted.len(), 1);
        assert_eq!(wasted[0].get_track_id(), track_id);
        assert_eq!(t.current_epoch(), 5);
    }

    #[test]
    fn sort_with_scenes() {
        let mut t = new_tracker();
        let bb = BoundingBox::new(0.0, 0.0, 10.0, 20.0);
        assert_eq!(t.current_epoch_with_scene(1), 0);
        assert_eq!(t.current_epoch_with_scene(2), 0);

        let _v = t.predict_with_scene(1, &[(bb.into(), Some(4))]);
        let _v = t.predict_with_scene(1, &[(bb.into(), Some(5))]);

        assert_eq!(t.current_epoch_with_scene(1), 2);
        assert_eq!(t.current_epoch_with_scene(2), 0);

        let _v = t.predict_with_scene(2, &[(bb.into(), Some(6))]);

        assert_eq!(t.current_epoch_with_scene(1), 2);
        assert_eq!(t.current_epoch_with_scene(2), 1);
    }

    #[test]
    fn idle_tracks() {
        let mut t = new_tracker();
        let bb = BoundingBox::new(0.0, 0.0, 10.0, 20.0);

        let _v = t.predict_with_scene(1, &[(bb.into(), Some(4))]);
        let idle = t.idle_tracks_with_scene(1);
        assert!(idle.is_empty());

        let _v = t.predict_with_scene(1, &[]);

        let idle = t.idle_tracks_with_scene(1);
        assert_eq!(idle.len(), 1);
        assert_eq!(idle[0].id, 1);
    }

    #[test]
    fn clear_wasted_tracks() {
        let mut t = new_tracker();
        let bb = BoundingBox::new(0.0, 0.0, 10.0, 20.0);

        let _v = t.predict_with_scene(1, &[(bb.into(), Some(4))]);
        t.skip_epochs_for_scene(1, 3);
        assert_eq!(wasted_store_len(&t), 1);
        t.clear_wasted();
        assert_eq!(wasted_store_len(&t), 0);
    }
}
434
#[cfg(feature = "python")]
pub mod python {
    use pyo3::prelude::*;

    use crate::{
        prelude::Universal2DBox,
        trackers::{
            sort::{
                python::{PyPositionalMetricType, PySortTrack, PyWastedSortTrack},
                WastedSortTrack,
            },
            spatio_temporal_constraints::python::PySpatioTemporalConstraints,
            tracker_api::TrackerAPI,
        },
        utils::bbox::python::PyUniversal2DBox,
    };

    use super::Sort;

    /// Python wrapper around [`Sort`], exposed to Python as `Sort`.
    #[pyclass]
    #[pyo3(name = "Sort")]
    pub struct PySort(pub Sort);

    #[pymethods]
    impl PySort {
        /// Creates a new tracker; see `Sort::new` for the meaning of every parameter.
        ///
        /// `method = None` selects `PyPositionalMetricType::maha()`.
        ///
        /// # Panics
        /// Panics with "Positive number expected" when `shards`, `bbox_history` or
        /// `max_idle_epochs` is negative, and (inside `Sort::new`) when `bbox_history` is zero.
        #[new]
        #[pyo3(signature = (
            shards = 4,
            bbox_history = 1,
            max_idle_epochs = 5,
            method = None,
            min_confidence = 0.05,
            spatio_temporal_constraints = None,
            kalman_position_weight = 1.0 / 20.0,
            kalman_velocity_weight = 1.0 / 160.0
        ))]
        #[allow(clippy::too_many_arguments)]
        pub fn new_py(
            shards: i64,
            bbox_history: i64,
            max_idle_epochs: i64,
            method: Option<PyPositionalMetricType>,
            min_confidence: f32,
            spatio_temporal_constraints: Option<PySpatioTemporalConstraints>,
            kalman_position_weight: f32,
            kalman_velocity_weight: f32,
        ) -> Self {
            Self(Sort::new(
                shards.try_into().expect("Positive number expected"),
                bbox_history.try_into().expect("Positive number expected"),
                max_idle_epochs
                    .try_into()
                    .expect("Positive number expected"),
                method.unwrap_or(PyPositionalMetricType::maha()).0,
                min_confidence,
                spatio_temporal_constraints.map(|x| x.0),
                kalman_position_weight,
                kalman_velocity_weight,
            ))
        }

        /// Skip `n` epochs (`n` must be positive).
        #[pyo3(signature = (n))]
        pub fn skip_epochs(&mut self, n: i64) {
            assert!(n > 0);
            self.0.skip_epochs(n.try_into().unwrap())
        }

        /// Skip `n` epochs for `scene_id` (`n` must be positive, `scene_id` non-negative).
        #[pyo3(signature = (scene_id, n))]
        pub fn skip_epochs_for_scene(&mut self, scene_id: i64, n: i64) {
            assert!(n > 0 && scene_id >= 0);
            self.0
                .skip_epochs_for_scene(scene_id.try_into().unwrap(), n.try_into().unwrap())
        }

        /// Get the amount of stored tracks per shard
        ///
        #[pyo3(signature = ())]
        pub fn shard_stats(&self) -> Vec<i64> {
            Python::with_gil(|py| {
                py.allow_threads(|| {
                    self.0
                        .store
                        .read()
                        .unwrap()
                        .shard_stats()
                        .into_iter()
                        .map(|e| i64::try_from(e).unwrap())
                        .collect()
                })
            })
        }

        /// Get the current epoch for `scene_id` == 0
        ///
        #[pyo3(signature = ())]
        pub fn current_epoch(&self) -> i64 {
            self.0.current_epoch_with_scene(0).try_into().unwrap()
        }

        /// Get the current epoch for `scene_id`
        ///
        /// # Parameters
        /// * `scene_id` - scene id (must be non-negative)
        ///
        #[pyo3(signature = (scene_id))]
        pub fn current_epoch_with_scene(&self, scene_id: i64) -> isize {
            assert!(scene_id >= 0);
            self.0
                .current_epoch_with_scene(scene_id.try_into().unwrap())
                .try_into()
                .unwrap()
        }

        /// Receive tracking information for observed bboxes of `scene_id` == 0
        ///
        /// # Parameters
        /// * `bboxes` - bounding boxes received from a detector
        ///
        #[pyo3(signature = (bboxes))]
        pub fn predict(
            &mut self,
            bboxes: Vec<(PyUniversal2DBox, Option<i64>)>,
        ) -> Vec<PySortTrack> {
            self.predict_with_scene(0, bboxes)
        }

        /// Receive tracking information for observed bboxes of `scene_id`
        ///
        /// # Parameters
        /// * `scene_id` - scene id provided by a user (class, camera id, etc...), non-negative
        /// * `bboxes` - bounding boxes received from a detector
        ///
        #[pyo3(signature = (scene_id, bboxes))]
        pub fn predict_with_scene(
            &mut self,
            scene_id: i64,
            bboxes: Vec<(PyUniversal2DBox, Option<i64>)>,
        ) -> Vec<PySortTrack> {
            assert!(scene_id >= 0);
            // Unwrap the Python newtypes element by element instead of transmuting the whole
            // `Vec`: the language does not guarantee that `Vec<(A, B)>` and `Vec<(C, B)>` share
            // a layout, even when `A` and `C` are layout-compatible.
            let bboxes: Vec<(Universal2DBox, Option<i64>)> = bboxes
                .into_iter()
                .map(|(bbox, custom_object_id)| (bbox.0, custom_object_id))
                .collect();

            Python::with_gil(|py| {
                py.allow_threads(|| {
                    self.0
                        .predict_with_scene(scene_id.try_into().unwrap(), &bboxes)
                        .into_iter()
                        .map(PySortTrack)
                        .collect()
                })
            })
        }

        /// Fetch and remove all the tracks with expired life
        ///
        #[pyo3(signature = ())]
        pub fn wasted(&mut self) -> Vec<PyWastedSortTrack> {
            Python::with_gil(|py| {
                py.allow_threads(|| {
                    self.0
                        .wasted()
                        .into_iter()
                        .map(WastedSortTrack::from)
                        .map(PyWastedSortTrack)
                        .collect()
                })
            })
        }

        /// Clear all tracks with expired life
        ///
        #[pyo3(signature = ())]
        pub fn clear_wasted(&mut self) {
            Python::with_gil(|py| {
                py.allow_threads(|| self.0.clear_wasted());
            })
        }

        /// Get idle tracks with not expired life for `scene_id` == 0
        ///
        #[pyo3(signature = ())]
        pub fn idle_tracks(&mut self) -> Vec<PySortTrack> {
            self.idle_tracks_with_scene(0)
        }

        /// Get idle tracks with not expired life for `scene_id` (must be non-negative)
        ///
        #[pyo3(signature = (scene_id))]
        pub fn idle_tracks_with_scene(&mut self, scene_id: i64) -> Vec<PySortTrack> {
            assert!(scene_id >= 0);
            Python::with_gil(|py| {
                py.allow_threads(|| {
                    self.0
                        .idle_tracks_with_scene(scene_id.try_into().unwrap())
                        .into_iter()
                        .map(PySortTrack)
                        .collect()
                })
            })
        }
    }
}