cecile_supercool_tracker/trackers/visual_sort/
simple_api.rs

1use crate::prelude::{NoopNotifier, ObservationBuilder, SortTrack, TrackStoreBuilder};
2use crate::store::TrackStore;
3use crate::track::utils::FromVec;
4use crate::track::{Feature, Track};
5use crate::trackers::epoch_db::EpochDb;
6use crate::trackers::sort::VotingType::Positional;
7use crate::trackers::sort::{
8    AutoWaste, PositionalMetricType, SortAttributesOptions, DEFAULT_AUTO_WASTE_PERIODICITY,
9    MAHALANOBIS_NEW_TRACK_THRESHOLD,
10};
11use crate::trackers::tracker_api::TrackerAPI;
12use crate::trackers::visual_sort::metric::{VisualMetric, VisualMetricOptions};
13use crate::trackers::visual_sort::observation_attributes::VisualObservationAttributes;
14use crate::trackers::visual_sort::options::VisualSortOptions;
15use crate::trackers::visual_sort::track_attributes::{
16    VisualAttributes, VisualAttributesUpdate, VisualSortLookup,
17};
18use crate::trackers::visual_sort::voting::VisualVoting;
19use crate::trackers::visual_sort::VisualSortObservation;
20use crate::utils::clipping::bbox_own_areas::{
21    exclusively_owned_areas, exclusively_owned_areas_normalized_shares,
22};
23use crate::voting::Voting;
24use rand::Rng;
25use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
26
27// /// Easy to use Visual SORT tracker implementation
28// ///
29pub struct VisualSort {
30    store: RwLock<TrackStore<VisualAttributes, VisualMetric, VisualObservationAttributes>>,
31    wasted_store: RwLock<TrackStore<VisualAttributes, VisualMetric, VisualObservationAttributes>>,
32    metric_opts: Arc<VisualMetricOptions>,
33    track_opts: Arc<SortAttributesOptions>,
34    auto_waste: AutoWaste,
35    track_id: u64,
36}
37
38impl VisualSort {
39    /// Creates new tracker
40    ///
41    /// # Parameters
42    /// * `shards` - amount of cpu threads to process the data, keep 1 for up to 100 simultaneously tracked objects, try it before setting high - higher numbers may lead to unexpected latencies.
43    /// * `opts` - tracker options
44    ///
45    pub fn new(shards: usize, opts: &VisualSortOptions) -> Self {
46        let (track_opts, metric) = opts.clone().build();
47        let track_opts = Arc::new(track_opts);
48        let metric_opts = metric.opts.clone();
49        let store = RwLock::new(
50            TrackStoreBuilder::new(shards)
51                .default_attributes(VisualAttributes::new(track_opts.clone()))
52                .metric(metric.clone())
53                .notifier(NoopNotifier)
54                .build(),
55        );
56
57        let wasted_store = RwLock::new(
58            TrackStoreBuilder::new(shards)
59                .default_attributes(VisualAttributes::new(track_opts.clone()))
60                .metric(metric)
61                .notifier(NoopNotifier)
62                .build(),
63        );
64
65        Self {
66            store,
67            wasted_store,
68            track_opts,
69            track_id: 0,
70            metric_opts,
71            auto_waste: AutoWaste {
72                periodicity: DEFAULT_AUTO_WASTE_PERIODICITY,
73                counter: DEFAULT_AUTO_WASTE_PERIODICITY,
74            },
75        }
76    }
77
78    /// Receive tracking information for observed bboxes of `scene_id == 0`
79    ///
80    /// # Parameters
81    /// * `scene_id` - custom identifier for the group of observed objects;
82    /// * `observations` - object observations with (feature, feature_quality and bounding box).
83    ///
84    pub fn predict(&mut self, observations: &[VisualSortObservation]) -> Vec<SortTrack> {
85        self.predict_with_scene(0, observations)
86    }
87
88    fn gen_track_id(&mut self) -> u64 {
89        self.track_id += 1;
90        self.track_id
91    }
92
93    /// Receive tracking information for observed bboxes of `scene_id`
94    ///
95    /// # Parameters
96    /// * `scene_id` - custom identifier for the group of observed objects;
97    /// * `observations` - object observations with (feature, feature_quality and bounding box).
98    ///
99    pub fn predict_with_scene(
100        &mut self,
101        scene_id: u64,
102        observations: &[VisualSortObservation],
103    ) -> Vec<SortTrack> {
104        if self.auto_waste.counter == 0 {
105            self.auto_waste();
106            self.auto_waste.counter = self.auto_waste.periodicity;
107        } else {
108            self.auto_waste.counter -= 1;
109        }
110
111        let mut percentages = Vec::default();
112        let use_own_area_percentage = self.metric_opts.visual_minimal_own_area_percentage_collect
113            + self.metric_opts.visual_minimal_own_area_percentage_use
114            > 0.0;
115
116        if use_own_area_percentage {
117            percentages.reserve(observations.len());
118            let boxes = observations
119                .iter()
120                .map(|e| &e.bounding_box)
121                .collect::<Vec<_>>();
122
123            percentages = exclusively_owned_areas_normalized_shares(
124                boxes.as_ref(),
125                exclusively_owned_areas(boxes.as_ref()).as_ref(),
126            );
127        }
128
129        let mut rng = rand::thread_rng();
130        let epoch = self.track_opts.next_epoch(scene_id).unwrap();
131
132        let mut tracks = observations
133            .iter()
134            .enumerate()
135            .map(|(i, o)| {
136                self.store
137                    .read()
138                    .unwrap()
139                    .new_track(rng.gen())
140                    .observation({
141                        let mut obs = ObservationBuilder::new(0).observation_attributes(
142                            if use_own_area_percentage {
143                                VisualObservationAttributes::with_own_area_percentage(
144                                    o.feature_quality.unwrap_or(1.0),
145                                    o.bounding_box.clone(),
146                                    percentages[i],
147                                )
148                            } else {
149                                VisualObservationAttributes::new(
150                                    o.feature_quality.unwrap_or(1.0),
151                                    o.bounding_box.clone(),
152                                )
153                            },
154                        );
155
156                        if let Some(feature) = &o.feature {
157                            obs = obs.observation(Feature::from_vec(feature.to_vec()));
158                        }
159
160                        obs.track_attributes_update(VisualAttributesUpdate::new_init_with_scene(
161                            epoch,
162                            scene_id,
163                            o.custom_object_id,
164                        ))
165                        .build()
166                    })
167                    .build()
168                    .unwrap()
169            })
170            .collect::<Vec<_>>();
171
172        let (dists, errs) =
173            self.store
174                .write()
175                .unwrap()
176                .foreign_track_distances(tracks.clone(), 0, false);
177
178        assert!(errs.all().is_empty());
179        let voting = VisualVoting::new(
180            match self.metric_opts.positional_kind {
181                PositionalMetricType::Mahalanobis => MAHALANOBIS_NEW_TRACK_THRESHOLD,
182                PositionalMetricType::IoU(t) => t,
183            },
184            f32::MAX,
185            self.metric_opts.visual_min_votes,
186        );
187        let winners = voting.winners(dists);
188        let mut res = Vec::default();
189        for t in &mut tracks {
190            let source = t.get_track_id();
191            let track_id: u64 = if let Some(dest) = winners.get(&source) {
192                let (dest, vt) = dest[0];
193                if dest == source {
194                    let mut t = t.clone();
195                    let track_id = self.gen_track_id();
196                    t.set_track_id(track_id);
197                    self.store.write().unwrap().add_track(t).unwrap();
198                    track_id
199                } else {
200                    t.add_observation(
201                        0,
202                        None,
203                        None,
204                        Some(VisualAttributesUpdate::new_voting_type(vt)),
205                    )
206                    .unwrap();
207                    self.store
208                        .write()
209                        .unwrap()
210                        .merge_external(dest, t, Some(&[0]), false)
211                        .unwrap();
212                    dest
213                }
214            } else {
215                let mut t = t.clone();
216                let track_id = self.gen_track_id();
217                t.set_track_id(track_id);
218                self.store.write().unwrap().add_track(t).unwrap();
219                track_id
220            };
221
222            let lock = self.store.read().unwrap();
223            let store = lock.get_store(track_id as usize);
224            let track = store.get(&track_id).unwrap();
225
226            res.push(SortTrack::from(track))
227        }
228
229        res
230    }
231
232    pub fn idle_tracks(&mut self) -> Vec<SortTrack> {
233        self.idle_tracks_with_scene(0)
234    }
235
236    pub fn idle_tracks_with_scene(&mut self, scene_id: u64) -> Vec<SortTrack> {
237        let store = self.store.read().unwrap();
238        store
239            .lookup(VisualSortLookup::IdleLookup(scene_id))
240            .iter()
241            .map(|(track_id, _status)| {
242                let shard = store.get_store(*track_id as usize);
243                let track = shard.get(track_id).unwrap();
244                SortTrack::from(track)
245            })
246            .collect()
247    }
248}
249
250impl
251    TrackerAPI<
252        VisualAttributes,
253        VisualMetric,
254        VisualObservationAttributes,
255        SortAttributesOptions,
256        NoopNotifier,
257    > for VisualSort
258{
259    fn get_auto_waste_obj_mut(&mut self) -> &mut AutoWaste {
260        &mut self.auto_waste
261    }
262
263    fn get_opts(&self) -> &SortAttributesOptions {
264        &self.track_opts
265    }
266
267    fn get_main_store_mut(
268        &mut self,
269    ) -> RwLockWriteGuard<
270        TrackStore<VisualAttributes, VisualMetric, VisualObservationAttributes, NoopNotifier>,
271    > {
272        self.store.write().unwrap()
273    }
274
275    fn get_wasted_store_mut(
276        &mut self,
277    ) -> RwLockWriteGuard<
278        TrackStore<VisualAttributes, VisualMetric, VisualObservationAttributes, NoopNotifier>,
279    > {
280        self.wasted_store.write().unwrap()
281    }
282
283    fn get_main_store(
284        &self,
285    ) -> RwLockReadGuard<
286        TrackStore<VisualAttributes, VisualMetric, VisualObservationAttributes, NoopNotifier>,
287    > {
288        self.store.read().unwrap()
289    }
290
291    fn get_wasted_store(
292        &self,
293    ) -> RwLockReadGuard<
294        TrackStore<VisualAttributes, VisualMetric, VisualObservationAttributes, NoopNotifier>,
295    > {
296        self.wasted_store.read().unwrap()
297    }
298}
299
300impl From<&Track<VisualAttributes, VisualMetric, VisualObservationAttributes>> for SortTrack {
301    fn from(track: &Track<VisualAttributes, VisualMetric, VisualObservationAttributes>) -> Self {
302        let attrs = track.get_attributes();
303        SortTrack {
304            id: track.get_track_id(),
305            custom_object_id: attrs.custom_object_id,
306            voting_type: attrs.voting_type.unwrap_or(Positional),
307            epoch: attrs.last_updated_epoch,
308            scene_id: attrs.scene_id,
309            observed_bbox: attrs.observed_boxes.back().unwrap().clone(),
310            predicted_bbox: attrs.predicted_boxes.back().unwrap().clone(),
311            length: attrs.track_length,
312        }
313    }
314}
315
#[cfg(test)]
mod tests {
    use crate::track::Observation;
    use crate::trackers::sort::{PositionalMetricType, VotingType};
    use crate::trackers::tracker_api::TrackerAPI;
    use crate::trackers::visual_sort::metric::VisualSortMetricType;
    use crate::trackers::visual_sort::observation_attributes::VisualObservationAttributes;
    use crate::trackers::visual_sort::options::VisualSortOptions;
    use crate::trackers::visual_sort::simple_api::VisualSort;
    use crate::trackers::visual_sort::track_attributes::VisualAttributes;
    use crate::trackers::visual_sort::{VisualSortObservation, WastedVisualSortTrack};
    use crate::utils::bbox::BoundingBox;

    /// Returns a copy of the attributes of track `id` from the tracker's main store.
    fn attributes_of(tracker: &VisualSort, id: u64) -> VisualAttributes {
        let lock = tracker.store.read().unwrap();
        let store = lock.get_store(id as usize);
        let track = store.get(&id).unwrap();
        track.get_attributes().clone()
    }

    #[test]
    fn visual_sort() {
        let opts = VisualSortOptions::default()
            .max_idle_epochs(3)
            .kept_history_length(3)
            .visual_metric(VisualSortMetricType::Euclidean(1.0))
            .positional_metric(PositionalMetricType::Mahalanobis)
            .visual_minimal_track_length(2)
            .visual_minimal_area(5.0)
            .visual_minimal_quality_use(0.45)
            .visual_minimal_quality_collect(0.7)
            .visual_max_observations(3)
            .visual_min_votes(2);

        let mut tracker = VisualSort::new(1, &opts);

        // new track to be initialized
        //
        let tracks = tracker.predict_with_scene(
            10,
            &[VisualSortObservation::new(
                Some(&vec![1.0, 1.0]),
                Some(0.9),
                BoundingBox::new(1.0, 1.0, 3.0, 5.0).as_xyaah(),
                Some(13),
            )],
        );
        let t = &tracks[0];
        assert_eq!(t.custom_object_id, Some(13));
        assert_eq!(t.scene_id, 10);
        assert!(matches!(t.voting_type, VotingType::Positional));
        assert_eq!(t.epoch, 1);
        let attrs = attributes_of(&tracker, t.id);
        assert_eq!(attrs.visual_features_collected_count, 1);
        assert_eq!(attrs.track_length, 1);
        assert_eq!(attrs.observed_boxes.len(), 1);
        assert_eq!(attrs.predicted_boxes.len(), 1);
        assert_eq!(attrs.observed_features.len(), 1);
        let first_track_id = t.id;

        {
            // another scene - new track
            let tracks = tracker.predict_with_scene(
                1,
                &[VisualSortObservation::new(
                    Some(&vec![1.0, 1.0]),
                    Some(0.9),
                    BoundingBox::new(1.0, 1.0, 3.0, 5.0).as_xyaah(),
                    Some(133),
                )],
            );
            let t = &tracks[0];
            assert_eq!(t.custom_object_id, Some(133));
            assert_eq!(t.scene_id, 1);
            assert!(matches!(t.voting_type, VotingType::Positional));
            assert_eq!(t.epoch, 1);
            let attrs = attributes_of(&tracker, t.id);
            assert_eq!(attrs.visual_features_collected_count, 1);
            assert_eq!(attrs.track_length, 1);
            assert_eq!(attrs.observed_boxes.len(), 1);
            assert_eq!(attrs.predicted_boxes.len(), 1);
            assert_eq!(attrs.observed_features.len(), 1);
        }

        // add the segment to the track (merge by bbox pos)
        //
        let tracks = tracker.predict_with_scene(
            10,
            &[VisualSortObservation::new(
                Some(&vec![0.95, 0.95]),
                Some(0.93),
                BoundingBox::new(1.1, 1.1, 3.05, 5.01).as_xyaah(),
                Some(15),
            )],
        );
        let t = &tracks[0];
        assert_eq!(t.id, first_track_id);
        assert_eq!(t.custom_object_id, Some(15));
        assert_eq!(t.scene_id, 10);
        assert!(matches!(t.voting_type, VotingType::Positional));
        assert_eq!(t.epoch, 2);
        let attrs = attributes_of(&tracker, t.id);
        assert_eq!(attrs.visual_features_collected_count, 2);
        assert_eq!(attrs.track_length, 2);
        assert_eq!(attrs.observed_boxes.len(), 2);
        assert_eq!(attrs.predicted_boxes.len(), 2);
        assert_eq!(attrs.observed_features.len(), 2);

        // add the segment to the track (no visual_sort feature)
        //
        let tracks = tracker.predict_with_scene(
            10,
            &[VisualSortObservation::new(
                None,
                Some(0.93),
                BoundingBox::new(1.11, 1.15, 3.15, 5.05).as_xyaah(),
                Some(25),
            )],
        );
        let t = &tracks[0];
        assert_eq!(t.id, first_track_id);
        assert_eq!(t.custom_object_id, Some(25));
        assert_eq!(t.scene_id, 10);
        assert!(matches!(t.voting_type, VotingType::Positional));
        assert_eq!(t.epoch, 3);
        let attrs = attributes_of(&tracker, t.id);
        assert_eq!(attrs.visual_features_collected_count, 2);
        assert_eq!(attrs.track_length, 3);
        assert_eq!(attrs.observed_boxes.len(), 3);
        assert_eq!(attrs.predicted_boxes.len(), 3);
        assert_eq!(attrs.observed_features.len(), 3);
        assert!(attrs.observed_features.back().unwrap().is_none());

        // add the segment to the track (no visual_sort feature); history is capped at 3
        //
        let tracks = tracker.predict_with_scene(
            10,
            &[VisualSortObservation::new(
                None,
                Some(0.93),
                BoundingBox::new(1.15, 1.25, 3.10, 5.05).as_xyaah(),
                Some(2),
            )],
        );
        let t = &tracks[0];
        assert_eq!(t.id, first_track_id);
        assert!(matches!(t.voting_type, VotingType::Positional));
        assert_eq!(t.epoch, 4);
        let attrs = attributes_of(&tracker, t.id);
        assert_eq!(attrs.visual_features_collected_count, 2);
        assert_eq!(attrs.track_length, 4);
        assert_eq!(attrs.observed_boxes.len(), 3);
        assert_eq!(attrs.predicted_boxes.len(), 3);
        assert_eq!(attrs.observed_features.len(), 3);
        assert!(attrs.observed_features.back().unwrap().is_none());

        // add the segment to the track (with visual_sort feature but low quality - no use, no collect)
        //
        let tracks = tracker.predict_with_scene(
            10,
            &[VisualSortObservation::new(
                Some(&vec![0.97, 0.97]),
                Some(0.44),
                BoundingBox::new(1.15, 1.25, 3.10, 5.05).as_xyaah(),
                Some(2),
            )],
        );
        let t = &tracks[0];
        assert_eq!(t.id, first_track_id);
        assert!(matches!(t.voting_type, VotingType::Positional));
        let attrs = attributes_of(&tracker, t.id);
        assert_eq!(attrs.visual_features_collected_count, 2);
        assert_eq!(attrs.track_length, 5);
        assert!(attrs.observed_features.back().unwrap().is_some());

        // add the segment to the track (with visual_sort feature but low quality - use, but no collect)
        //
        let tracks = tracker.predict_with_scene(
            10,
            &[VisualSortObservation::new(
                Some(&vec![0.97, 0.97]),
                Some(0.6),
                BoundingBox::new(1.15, 1.25, 3.10, 5.05).as_xyaah(),
                Some(2),
            )],
        );
        let t = &tracks[0];
        assert_eq!(t.id, first_track_id);
        assert!(matches!(t.voting_type, VotingType::Visual));
        let attrs = attributes_of(&tracker, t.id);
        assert_eq!(attrs.visual_features_collected_count, 2);
        assert_eq!(attrs.track_length, 6);
        assert!(attrs.observed_features.back().unwrap().is_some());

        // add the segment to the track (with visual_sort feature of normal quality - use, collect)
        //
        let tracks = tracker.predict_with_scene(
            10,
            &[VisualSortObservation::new(
                Some(&vec![0.97, 0.97]),
                Some(0.8),
                BoundingBox::new(1.15, 1.25, 3.10, 5.05).as_xyaah(),
                Some(2),
            )],
        );
        let t = &tracks[0];
        assert_eq!(t.id, first_track_id);
        assert!(matches!(t.voting_type, VotingType::Visual));
        {
            // Only the first kept observation still carries its bounding box.
            let lock = tracker.store.read().unwrap();
            let store = lock.get_store(t.id as usize);
            let track = store.get(&t.id).unwrap();
            let observations = track.get_observations(0).unwrap();

            fn bbox_is(b: &Observation<VisualObservationAttributes>) -> bool {
                b.attr().as_ref().unwrap().bbox_opt().is_some()
            }

            assert!(bbox_is(&observations[0]) && observations[0].feature().is_some());
            assert!(!bbox_is(&observations[1]) && observations[1].feature().is_some());
            assert!(!bbox_is(&observations[2]) && observations[2].feature().is_some());
        }
        let attrs = attributes_of(&tracker, t.id);
        assert_eq!(attrs.visual_features_collected_count, 3);
        assert_eq!(attrs.track_length, 7);
        assert!(attrs.observed_features.back().unwrap().is_some());

        // new track to be initialized
        //
        let tracks = tracker.predict_with_scene(
            10,
            &[VisualSortObservation::new(
                Some(&vec![0.1, 0.1]),
                Some(0.9),
                BoundingBox::new(10.0, 10.0, 3.0, 5.0).as_xyaah(),
                Some(33),
            )],
        );
        let t = &tracks[0];
        assert_eq!(t.custom_object_id, Some(33));
        assert_eq!(t.scene_id, 10);
        assert!(matches!(t.voting_type, VotingType::Positional));
        assert_eq!(t.epoch, 8);
        assert_ne!(t.id, first_track_id);
        let attrs = attributes_of(&tracker, t.id);
        assert_eq!(attrs.visual_features_collected_count, 1);
        assert_eq!(attrs.track_length, 1);
        assert_eq!(attrs.observed_boxes.len(), 1);
        assert_eq!(attrs.predicted_boxes.len(), 1);
        assert_eq!(attrs.observed_features.len(), 1);
        let other_track_id = t.id;

        // add segment to be initialized
        //
        let tracks = tracker.predict_with_scene(
            10,
            &[VisualSortObservation::new(
                Some(&vec![0.12, 0.15]),
                Some(0.88),
                BoundingBox::new(10.1, 10.1, 3.0, 5.0).as_xyaah(),
                Some(35),
            )],
        );
        let t = &tracks[0];
        assert_eq!(t.custom_object_id, Some(35));
        assert_eq!(t.scene_id, 10);
        assert!(matches!(t.voting_type, VotingType::Positional));
        assert_eq!(t.epoch, 9);
        assert_eq!(t.id, other_track_id);
        let attrs = attributes_of(&tracker, t.id);
        assert_eq!(attrs.visual_features_collected_count, 2);
        assert_eq!(attrs.track_length, 2);
        assert_eq!(attrs.observed_boxes.len(), 2);
        assert_eq!(attrs.predicted_boxes.len(), 2);
        assert_eq!(attrs.observed_features.len(), 2);

        // add segment to be initialized
        //
        let tracks = tracker.predict_with_scene(
            10,
            &[VisualSortObservation::new(
                Some(&vec![0.12, 0.14]),
                Some(0.87),
                BoundingBox::new(10.1, 10.1, 3.0, 5.0).as_xyaah(),
                Some(31),
            )],
        );
        let t = &tracks[0];
        assert_eq!(t.custom_object_id, Some(31));
        assert_eq!(t.scene_id, 10);
        assert!(matches!(t.voting_type, VotingType::Visual));
        assert_eq!(t.epoch, 10);
        assert_eq!(t.id, other_track_id);
        let attrs = attributes_of(&tracker, t.id);
        assert_eq!(attrs.visual_features_collected_count, 3);
        assert_eq!(attrs.track_length, 3);
        assert_eq!(attrs.observed_boxes.len(), 3);
        assert_eq!(attrs.predicted_boxes.len(), 3);
        assert_eq!(attrs.observed_features.len(), 3);

        // Skip past `max_idle_epochs(3)` for scene 10, then make sure the wasted tracks
        // can be collected and converted without panicking.
        tracker.skip_epochs_for_scene(10, 5);
        let _wasted = tracker
            .wasted()
            .into_iter()
            .map(WastedVisualSortTrack::from)
            .collect::<Vec<_>>();
    }
}
668
#[cfg(feature = "python")]
pub mod python {
    use pyo3::prelude::*;

    use crate::{
        prelude::{SortTrack, VisualSortObservation},
        trackers::{
            sort::python::PySortTrack,
            tracker_api::TrackerAPI,
            visual_sort::{
                options::python::PyVisualSortOptions,
                python::{PyVisualSortObservationSet, PyWastedVisualSortTrack},
                WastedVisualSortTrack,
            },
        },
    };

    use super::VisualSort;

    /// Converts tracker results into their Python wrappers.
    fn into_py_tracks(tracks: Vec<SortTrack>) -> Vec<PySortTrack> {
        // SAFETY: assumes `PySortTrack` is a `#[repr(transparent)]` newtype over
        // `SortTrack` (declared in `sort::python`, not visible here — confirm). The
        // transmute between the two `Vec` types is only sound under that layout.
        // NOTE(review): `tracks.into_iter().map(PySortTrack).collect()` would avoid
        // `unsafe` entirely if the wrapped field is reachable from this module.
        unsafe { std::mem::transmute::<Vec<SortTrack>, Vec<PySortTrack>>(tracks) }
    }

    /// Python binding for the Visual SORT tracker, exported as `VisualSort`.
    #[pyclass]
    #[pyo3(name = "VisualSort")]
    pub struct PyVisualSort(pub(crate) VisualSort);

    #[pymethods]
    impl PyVisualSort {
        /// Creates a new tracker; `shards` must be positive.
        #[new]
        pub fn new(shards: i64, opts: &PyVisualSortOptions) -> Self {
            assert!(shards > 0);
            Self(VisualSort::new(shards.try_into().unwrap(), &opts.0))
        }

        /// Skip `n` (> 0) epochs for `scene_id` == 0
        ///
        #[pyo3(signature = (n))]
        pub fn skip_epochs(&mut self, n: i64) {
            assert!(n > 0);
            self.0.skip_epochs(n.try_into().unwrap())
        }

        /// Skip `n` (> 0) epochs for the non-negative `scene_id`
        ///
        #[pyo3(signature = (scene_id, n))]
        pub fn skip_epochs_for_scene(&mut self, scene_id: i64, n: i64) {
            assert!(n > 0 && scene_id >= 0);
            self.0
                .skip_epochs_for_scene(scene_id.try_into().unwrap(), n.try_into().unwrap())
        }

        /// Get the amount of stored tracks per shard
        ///
        #[pyo3(signature = ())]
        pub fn shard_stats(&self) -> Vec<i64> {
            Python::with_gil(|py| {
                py.allow_threads(|| {
                    self.0
                        .active_shard_stats()
                        .into_iter()
                        .map(|e| i64::try_from(e).unwrap())
                        .collect()
                })
            })
        }

        /// Get the current epoch for `scene_id` == 0
        ///
        #[pyo3(signature = ())]
        pub fn current_epoch(&self) -> i64 {
            self.0.current_epoch_with_scene(0).try_into().unwrap()
        }

        /// Get the current epoch for `scene_id`
        ///
        /// # Parameters
        /// * `scene_id` - scene id
        ///
        #[pyo3(signature = (scene_id))]
        pub fn current_epoch_with_scene(&self, scene_id: i64) -> isize {
            assert!(scene_id >= 0);
            self.0
                .current_epoch_with_scene(scene_id.try_into().unwrap())
                .try_into()
                .unwrap()
        }

        /// Receive tracking information for observed bboxes of `scene_id` == 0
        ///
        /// # Parameters
        /// * `observation_set` - observation set
        ///
        #[pyo3(signature = (observation_set))]
        pub fn predict(
            &mut self,
            observation_set: &PyVisualSortObservationSet,
        ) -> Vec<PySortTrack> {
            // Delegate so this path also runs with the GIL released.
            self.predict_with_scene(0, observation_set)
        }

        /// Receive tracking information for observed bboxes of `scene_id`
        ///
        /// # Parameters
        /// * `scene_id` - scene id provided by a user (class, camera id, etc...)
        /// * `observation_set` - observation set
        ///
        #[pyo3(signature = (scene_id, observation_set))]
        pub fn predict_with_scene(
            &mut self,
            scene_id: i64,
            observation_set: &PyVisualSortObservationSet,
        ) -> Vec<PySortTrack> {
            assert!(scene_id >= 0);
            // Rebuild the observations locally (presumably so that nothing borrowed from
            // the Python-owned set is moved into the GIL-free closure below).
            let observations = observation_set
                .0
                .inner
                .iter()
                .map(|e| {
                    VisualSortObservation::new(
                        e.feature.as_deref(),
                        e.feature_quality,
                        e.bounding_box.clone(),
                        e.custom_object_id,
                    )
                })
                .collect::<Vec<_>>();

            Python::with_gil(|py| {
                py.allow_threads(|| {
                    into_py_tracks(
                        self.0
                            .predict_with_scene(scene_id.try_into().unwrap(), &observations),
                    )
                })
            })
        }

        /// Remove all the tracks with expired life
        ///
        #[pyo3(signature = ())]
        pub fn wasted(&mut self) -> Vec<PyWastedVisualSortTrack> {
            Python::with_gil(|py| {
                py.allow_threads(|| {
                    self.0
                        .wasted()
                        .into_iter()
                        .map(WastedVisualSortTrack::from)
                        .map(PyWastedVisualSortTrack)
                        .collect()
                })
            })
        }

        /// Clear all tracks with expired life
        ///
        #[pyo3(signature = ())]
        pub fn clear_wasted(&mut self) {
            Python::with_gil(|py| py.allow_threads(|| self.0.clear_wasted()));
        }

        /// Get idle tracks with not expired life for `scene_id` == 0
        ///
        #[pyo3(signature = ())]
        pub fn idle_tracks(&mut self) -> Vec<PySortTrack> {
            // Delegate so this path also runs with the GIL released.
            self.idle_tracks_with_scene_py(0)
        }

        /// Get idle tracks with not expired life for the non-negative `scene_id`
        ///
        #[pyo3(signature = (scene_id))]
        pub fn idle_tracks_with_scene_py(&mut self, scene_id: i64) -> Vec<PySortTrack> {
            assert!(scene_id >= 0);
            Python::with_gil(|py| {
                py.allow_threads(|| {
                    into_py_tracks(self.0.idle_tracks_with_scene(scene_id.try_into().unwrap()))
                })
            })
        }
    }
}