// cecile_supercool_tracker/trackers/sort/batch_api.rs

1use crate::prelude::{
2    NoopNotifier, ObservationBuilder, PositionalMetricType, SortTrack, TrackStoreBuilder,
3    Universal2DBox,
4};
5use crate::store::track_distance::TrackDistanceOkIterator;
6use crate::store::TrackStore;
7use crate::track::Track;
8use crate::trackers::batch::{PredictionBatchRequest, PredictionBatchResult, SceneTracks};
9use crate::trackers::epoch_db::EpochDb;
10use crate::trackers::sort::metric::SortMetric;
11use crate::trackers::sort::voting::SortVoting;
12use crate::trackers::sort::{
13    AutoWaste, SortAttributes, SortAttributesOptions, SortAttributesUpdate, SortLookup,
14    DEFAULT_AUTO_WASTE_PERIODICITY, MAHALANOBIS_NEW_TRACK_THRESHOLD,
15};
16
17use crate::trackers::spatio_temporal_constraints::SpatioTemporalConstraints;
18use crate::trackers::tracker_api::TrackerAPI;
19use crate::voting::Voting;
20use crossbeam::channel::{Receiver, Sender};
21use log::warn;
22use rand::Rng;
23use std::collections::HashMap;
24use std::mem;
25use std::sync::{Arc, Condvar, Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard};
26use std::thread::{spawn, JoinHandle};
27
28type VotingSenderChannel = Sender<VotingCommands>;
29type VotingReceiverChannel = Receiver<VotingCommands>;
30
31type MiddlewareSortTrackStore = TrackStore<SortAttributes, SortMetric, Universal2DBox>;
32type MiddlewareSortTrack = Track<SortAttributes, SortMetric, Universal2DBox>;
33type BatchBusyMonitor = Arc<(Mutex<usize>, Condvar)>;
34
35enum VotingCommands {
36    Distances {
37        scene_id: u64,
38        distances: TrackDistanceOkIterator<Universal2DBox>,
39        channel: Sender<SceneTracks>,
40        tracks: Vec<MiddlewareSortTrack>,
41        monitor: BatchBusyMonitor,
42    },
43    Exit,
44}
45
46pub struct BatchSort {
47    monitor: Option<BatchBusyMonitor>,
48    store: Arc<RwLock<MiddlewareSortTrackStore>>,
49    wasted_store: RwLock<MiddlewareSortTrackStore>,
50    opts: Arc<SortAttributesOptions>,
51    voting_threads: Vec<(VotingSenderChannel, JoinHandle<()>)>,
52    auto_waste: AutoWaste,
53}
54
55impl Drop for BatchSort {
56    fn drop(&mut self) {
57        let voting_threads = mem::take(&mut self.voting_threads);
58        for (tx, t) in voting_threads {
59            tx.send(VotingCommands::Exit)
60                .expect("Voting thread must be alive.");
61            drop(tx);
62            t.join()
63                .expect("Voting thread is expected to shutdown successfully.");
64        }
65    }
66}
67
68fn voting_thread(
69    store: Arc<RwLock<MiddlewareSortTrackStore>>,
70    rx: VotingReceiverChannel,
71    method: PositionalMetricType,
72    track_id: Arc<RwLock<u64>>,
73) {
74    while let Ok(command) = rx.recv() {
75        match command {
76            VotingCommands::Distances {
77                scene_id,
78                distances,
79                channel,
80                tracks,
81                monitor,
82            } => {
83                let candidates_num = tracks.len();
84                let tracks_num = {
85                    let store = store.read().expect("Access to store must always succeed");
86                    store.shard_stats().iter().sum()
87                };
88
89                let voting = SortVoting::new(
90                    match method {
91                        PositionalMetricType::Mahalanobis => MAHALANOBIS_NEW_TRACK_THRESHOLD,
92                        PositionalMetricType::IoU(t) => t,
93                    },
94                    candidates_num,
95                    tracks_num,
96                );
97
98                let winners = voting.winners(distances);
99                let mut res = Vec::default();
100                for mut t in tracks {
101                    let source = t.get_track_id();
102                    let tid = {
103                        let mut track_id = track_id.write().unwrap();
104                        *track_id += 1;
105                        *track_id
106                    };
107                    let track_id: u64 = if let Some(dest) = winners.get(&source) {
108                        let dest = dest[0];
109                        if dest == source {
110                            t.set_track_id(tid);
111                            store
112                                .write()
113                                .expect("Access to store must always succeed")
114                                .add_track(t)
115                                .unwrap();
116                            tid
117                        } else {
118                            store
119                                .write()
120                                .expect("Access to store must always succeed")
121                                .merge_external(dest, &t, Some(&[0]), false)
122                                .unwrap();
123                            dest
124                        }
125                    } else {
126                        t.set_track_id(tid);
127                        store
128                            .write()
129                            .expect("Access to store must always succeed")
130                            .add_track(t)
131                            .unwrap();
132                        tid
133                    };
134
135                    let store = store.read().expect("Access to store must always succeed");
136                    let shard = store.get_store(track_id as usize);
137                    let track = shard.get(&track_id).unwrap();
138
139                    res.push(SortTrack::from(track))
140                }
141                let res = channel.send((scene_id, res));
142                if let Err(e) = res {
143                    warn!("Unable to send results to a caller, likely the caller already closed the channel. Error is: {:?}", e);
144                }
145                let (lock, cvar) = &*monitor;
146                let mut lock = lock.lock().unwrap();
147                *lock -= 1;
148                cvar.notify_one();
149            }
150            VotingCommands::Exit => break,
151        }
152    }
153}
154
155impl BatchSort {
156    #[allow(clippy::too_many_arguments)]
157    pub fn new(
158        distance_shards: usize,
159        voting_shards: usize,
160        bbox_history: usize,
161        max_idle_epochs: usize,
162        method: PositionalMetricType,
163        min_confidence: f32,
164        spatio_temporal_constraints: Option<SpatioTemporalConstraints>,
165        kalman_position_weight: f32,
166        kalman_velocity_weight: f32,
167    ) -> Self {
168        assert!(bbox_history > 0);
169        let epoch_db = RwLock::new(HashMap::default());
170        let opts = Arc::new(SortAttributesOptions::new(
171            Some(epoch_db),
172            max_idle_epochs,
173            bbox_history,
174            spatio_temporal_constraints.unwrap_or_default(),
175            kalman_position_weight,
176            kalman_velocity_weight,
177        ));
178
179        let store = Arc::new(RwLock::new(
180            TrackStoreBuilder::new(distance_shards)
181                .default_attributes(SortAttributes::new(opts.clone()))
182                .metric(SortMetric::new(method, min_confidence))
183                .notifier(NoopNotifier)
184                .build(),
185        ));
186
187        let wasted_store = RwLock::new(
188            TrackStoreBuilder::new(distance_shards)
189                .default_attributes(SortAttributes::new(opts.clone()))
190                .metric(SortMetric::new(method, min_confidence))
191                .notifier(NoopNotifier)
192                .build(),
193        );
194
195        let track_id = Arc::new(RwLock::new(0));
196
197        let voting_threads = (0..voting_shards)
198            .map(|_e| {
199                let (tx, rx) = crossbeam::channel::unbounded();
200                let thread_store = store.clone();
201                let thread_track_id = track_id.clone();
202                (
203                    tx,
204                    spawn(move || voting_thread(thread_store, rx, method, thread_track_id)),
205                )
206            })
207            .collect::<Vec<_>>();
208
209        Self {
210            monitor: None,
211            store,
212            wasted_store,
213            opts,
214            voting_threads,
215            auto_waste: AutoWaste {
216                periodicity: DEFAULT_AUTO_WASTE_PERIODICITY,
217                counter: DEFAULT_AUTO_WASTE_PERIODICITY,
218            },
219        }
220    }
221
222    pub fn predict(
223        &mut self,
224        batch_request: PredictionBatchRequest<(Universal2DBox, Option<i64>)>,
225    ) {
226        if self.auto_waste.counter == 0 {
227            self.auto_waste();
228            self.auto_waste.counter = self.auto_waste.periodicity;
229        } else {
230            self.auto_waste.counter -= 1;
231        }
232
233        if let Some(m) = &self.monitor {
234            let (lock, cvar) = &**m;
235            let _guard = cvar.wait_while(lock.lock().unwrap(), |v| *v > 0).unwrap();
236        }
237
238        self.monitor = Some(Arc::new((
239            Mutex::new(batch_request.batch_size()),
240            Condvar::new(),
241        )));
242
243        for (i, (scene_id, bboxes)) in batch_request.get_batch().iter().enumerate() {
244            let mut rng = rand::thread_rng();
245            let epoch = self.opts.next_epoch(*scene_id).unwrap();
246
247            let tracks = bboxes
248                .iter()
249                .map(|(bb, custom_object_id)| {
250                    self.store
251                        .read()
252                        .expect("Access to store must always succeed")
253                        .new_track(rng.gen())
254                        .observation(
255                            ObservationBuilder::new(0)
256                                .observation_attributes(bb.clone())
257                                .track_attributes_update(SortAttributesUpdate::new_with_scene(
258                                    epoch,
259                                    *scene_id,
260                                    *custom_object_id,
261                                ))
262                                .build(),
263                        )
264                        .build()
265                        .expect("Track creation must always succeed!")
266                })
267                .collect::<Vec<_>>();
268
269            let (dists, errs) = {
270                let mut store = self
271                    .store
272                    .write()
273                    .expect("Access to store must always succeed");
274                store.foreign_track_distances(tracks.clone(), 0, false)
275            };
276
277            assert!(errs.all().is_empty());
278            let thread_id = i % self.voting_threads.len();
279            self.voting_threads[thread_id]
280                .0
281                .send(VotingCommands::Distances {
282                    monitor: self.monitor.as_ref().unwrap().clone(),
283                    scene_id: *scene_id,
284                    distances: dists.into_iter(),
285                    channel: batch_request.get_sender(),
286                    tracks,
287                })
288                .expect("Sending voting request to voting thread must not fail");
289        }
290    }
291
292    pub fn idle_tracks(&mut self) -> Vec<SortTrack> {
293        self.idle_tracks_with_scene(0)
294    }
295
296    pub fn idle_tracks_with_scene(&mut self, scene_id: u64) -> Vec<SortTrack> {
297        let store = self.store.read().unwrap();
298
299        store
300            .lookup(SortLookup::IdleLookup(scene_id))
301            .iter()
302            .map(|(track_id, _status)| {
303                let shard = store.get_store(*track_id as usize);
304                let track = shard.get(track_id).unwrap();
305                SortTrack::from(track)
306            })
307            .collect()
308    }
309}
310
311impl TrackerAPI<SortAttributes, SortMetric, Universal2DBox, SortAttributesOptions, NoopNotifier>
312    for BatchSort
313{
314    fn get_auto_waste_obj_mut(&mut self) -> &mut AutoWaste {
315        &mut self.auto_waste
316    }
317
318    fn get_opts(&self) -> &SortAttributesOptions {
319        &self.opts
320    }
321
322    fn get_main_store_mut(&mut self) -> RwLockWriteGuard<MiddlewareSortTrackStore> {
323        self.store.write().unwrap()
324    }
325
326    fn get_wasted_store_mut(&mut self) -> RwLockWriteGuard<MiddlewareSortTrackStore> {
327        self.wasted_store.write().unwrap()
328    }
329
330    fn get_main_store(&self) -> RwLockReadGuard<MiddlewareSortTrackStore> {
331        self.store.read().unwrap()
332    }
333
334    fn get_wasted_store(&self) -> RwLockReadGuard<MiddlewareSortTrackStore> {
335        self.wasted_store.read().unwrap()
336    }
337}
338
339#[derive(Debug, Clone)]
340pub struct SortPredictionBatchRequest {
341    pub batch: PredictionBatchRequest<(Universal2DBox, Option<i64>)>,
342    pub result: Option<PredictionBatchResult>,
343}
344
345impl SortPredictionBatchRequest {
346    pub fn new() -> Self {
347        let (batch, result) = PredictionBatchRequest::new();
348
349        Self {
350            batch,
351            result: Some(result),
352        }
353    }
354
355    pub fn add(&mut self, scene_id: u64, bbox: Universal2DBox, custom_object_id: Option<i64>) {
356        self.batch.add(scene_id, (bbox, custom_object_id))
357    }
358}
359
360impl Default for SortPredictionBatchRequest {
361    fn default() -> Self {
362        Self::new()
363    }
364}
365
#[cfg(feature = "python")]
pub mod python {
    use crate::{
        trackers::{
            batch::python::PyPredictionBatchResult,
            sort::{
                python::{PyPositionalMetricType, PySortTrack, PyWastedSortTrack},
                WastedSortTrack,
            },
            spatio_temporal_constraints::python::PySpatioTemporalConstraints,
            tracker_api::TrackerAPI,
        },
        utils::bbox::python::PyUniversal2DBox,
    };

    use super::{BatchSort, SortPredictionBatchRequest};
    use pyo3::prelude::*;

    /// Python binding for the batched SORT tracker.
    #[pyclass]
    #[pyo3(name = "BatchSort")]
    pub struct PyBatchSort(pub(crate) BatchSort);

    #[pymethods]
    impl PyBatchSort {
        // Python constructor: arguments are forwarded to `BatchSort::new`; `method` defaults
        // to Mahalanobis. Panics if an integer argument does not fit `usize` (e.g. negative).
        #[new]
        #[pyo3(signature = (
            distance_shards = 4,
            voting_shards = 4,
            bbox_history = 1,
            max_idle_epochs = 5,
            method = None,
            min_confidence = 0.05,
            spatio_temporal_constraints = None,
            kalman_position_weight = 1.0 / 20.0,
            kalman_velocity_weight = 1.0 / 160.0
        ))]
        #[allow(clippy::too_many_arguments)]
        pub fn new(
            distance_shards: i64,
            voting_shards: i64,
            bbox_history: i64,
            max_idle_epochs: i64,
            method: Option<PyPositionalMetricType>,
            min_confidence: f32,
            spatio_temporal_constraints: Option<PySpatioTemporalConstraints>,
            kalman_position_weight: f32,
            kalman_velocity_weight: f32,
        ) -> Self {
            Self(BatchSort::new(
                distance_shards
                    .try_into()
                    .expect("Positive number expected"),
                voting_shards.try_into().expect("Positive number expected"),
                bbox_history.try_into().expect("Positive number expected"),
                max_idle_epochs
                    .try_into()
                    .expect("Positive number expected"),
                method.unwrap_or(PyPositionalMetricType::maha()).0,
                min_confidence,
                spatio_temporal_constraints.map(|x| x.0),
                kalman_position_weight,
                kalman_velocity_weight,
            ))
        }

        /// Skip `n` epochs (`n` must be positive)
        ///
        #[pyo3(signature = (n))]
        fn skip_epochs(&mut self, n: i64) {
            assert!(n > 0);
            self.0.skip_epochs(n.try_into().unwrap())
        }

        /// Skip `n` epochs of `scene_id` (`n` must be positive, `scene_id` non-negative)
        ///
        #[pyo3(signature = (scene_id, n))]
        fn skip_epochs_for_scene(&mut self, scene_id: i64, n: i64) {
            assert!(n > 0 && scene_id >= 0);
            self.0
                .skip_epochs_for_scene(scene_id.try_into().unwrap(), n.try_into().unwrap())
        }

        /// Get the amount of stored tracks per shard
        ///
        #[pyo3(signature = ())]
        fn shard_stats(&self) -> Vec<i64> {
            Python::with_gil(|py| {
                py.allow_threads(|| {
                    self.0
                        .store
                        .read()
                        .unwrap()
                        .shard_stats()
                        .into_iter()
                        .map(|e| i64::try_from(e).unwrap())
                        .collect()
                })
            })
        }

        /// Get the current epoch for `scene_id` == 0
        ///
        #[pyo3(signature = ())]
        fn current_epoch(&self) -> i64 {
            self.0.current_epoch_with_scene(0).try_into().unwrap()
        }

        /// Get the current epoch for `scene_id`
        ///
        /// # Parameters
        /// * `scene_id` - scene id
        ///
        #[pyo3(signature = (scene_id))]
        fn current_epoch_with_scene(&self, scene_id: i64) -> isize {
            assert!(scene_id >= 0);
            self.0
                .current_epoch_with_scene(scene_id.try_into().unwrap())
                .try_into()
                .unwrap()
        }

        /// Submit a batch of observations for several scenes and return the handle used
        /// to collect the per-scene results
        ///
        /// The GIL is released while the batch is dispatched, because the call may block
        /// until the voting threads have finished the previously submitted batch.
        ///
        /// # Parameters
        /// * `batch` - the prepared batch request
        ///
        #[pyo3(signature = (batch))]
        fn predict(&mut self, mut batch: PySortPredictionBatchRequest) -> PyPredictionBatchResult {
            let request = batch.0.batch;
            Python::with_gil(|py| py.allow_threads(|| self.0.predict(request)));
            PyPredictionBatchResult(batch.0.result.take().unwrap())
        }

        /// Remove all the tracks with expired life
        ///
        #[pyo3(signature = ())]
        fn wasted(&mut self) -> Vec<PyWastedSortTrack> {
            Python::with_gil(|py| {
                py.allow_threads(|| {
                    self.0
                        .wasted()
                        .into_iter()
                        .map(WastedSortTrack::from)
                        .map(PyWastedSortTrack)
                        .collect()
                })
            })
        }

        /// Clear all tracks with expired life
        ///
        #[pyo3(signature = ())]
        pub fn clear_wasted(&mut self) {
            Python::with_gil(|py| {
                py.allow_threads(|| self.0.clear_wasted());
            })
        }

        /// Get idle tracks with not expired life
        ///
        #[pyo3(signature = (scene_id))]
        pub fn idle_tracks(&mut self, scene_id: i64) -> Vec<PySortTrack> {
            Python::with_gil(|py| {
                py.allow_threads(|| {
                    // Wrap element by element: transmuting `Vec<SortTrack>` into
                    // `Vec<PySortTrack>` is not guaranteed to be sound, while this
                    // conversion is safe (and can reuse the allocation in place).
                    self.0
                        .idle_tracks_with_scene(scene_id.try_into().unwrap())
                        .into_iter()
                        .map(PySortTrack)
                        .collect()
                })
            })
        }
    }

    /// Python binding for a batch prediction request.
    #[pyclass]
    #[pyo3(name = "SortPredictionBatchRequest")]
    #[derive(Debug, Clone)]
    pub struct PySortPredictionBatchRequest(pub(crate) SortPredictionBatchRequest);

    #[pymethods]
    impl PySortPredictionBatchRequest {
        #[new]
        fn new() -> Self {
            Self(SortPredictionBatchRequest::new())
        }

        /// Add observation `bbox` for scene `scene_id`, optionally tagged with `custom_object_id`
        ///
        fn add(&mut self, scene_id: u64, bbox: PyUniversal2DBox, custom_object_id: Option<i64>) {
            self.0.add(scene_id, bbox.0, custom_object_id)
        }
    }
}
550
#[cfg(test)]
mod tests {
    use crate::prelude::BoundingBox;
    use crate::prelude::PositionalMetricType::Mahalanobis;
    use crate::trackers::batch::PredictionBatchRequest;
    use crate::trackers::sort::batch_api::BatchSort;
    use crate::trackers::sort::metric::DEFAULT_MINIMAL_SORT_CONFIDENCE;

    /// Runs one batch of two single-box scenes, checks that every scene reports exactly
    /// one track, then drops the tracker to exercise the voting-thread shutdown.
    #[test]
    fn new_drop() {
        let mut bs = BatchSort::new(
            1,
            1,
            1,
            1,
            Mahalanobis,
            DEFAULT_MINIMAL_SORT_CONFIDENCE,
            None,
            1.0 / 20.0,
            1.0 / 160.0,
        );
        let (mut batch, res) = PredictionBatchRequest::new();
        batch.add(0, (BoundingBox::new(0.0, 0.0, 5.0, 10.0).into(), Some(1)));
        batch.add(1, (BoundingBox::new(0.0, 0.0, 5.0, 10.0).into(), Some(2)));

        bs.predict(batch);

        let mut scenes = (0..res.batch_size())
            .map(|_| {
                let (scene_id, tracks) = res.get();
                assert_eq!(tracks.len(), 1, "one box per scene must yield one track");
                scene_id
            })
            .collect::<Vec<_>>();
        scenes.sort_unstable();
        assert_eq!(scenes, vec![0, 1]);

        drop(bs);
    }
}