// cecile_supercool_tracker/trackers/batch.rs

1use crate::prelude::SortTrack;
2use crossbeam::channel::{Receiver, Sender};
3use log::debug;
4
5use std::collections::HashMap;
6use std::sync::{Arc, Mutex};
7
/// Elements of a batch grouped by scene: each scene id maps to the elements queued for it.
pub type BatchRecords<T> = HashMap<u64, Vec<T>>;
9pub type SceneTracks = (u64, Vec<SortTrack>);
10
11#[derive(Debug, Clone)]
12pub struct PredictionBatchRequest<T> {
13    batch: BatchRecords<T>,
14    sender: Sender<SceneTracks>,
15    batch_size: Arc<Mutex<usize>>,
16}
17
18#[derive(Clone, Debug)]
19pub struct PredictionBatchResult {
20    receiver: Receiver<SceneTracks>,
21    batch_size: Arc<Mutex<usize>>,
22}
23
24impl PredictionBatchResult {
25    pub fn ready(&self) -> bool {
26        !self.receiver.is_empty()
27    }
28
29    pub fn get(&self) -> SceneTracks {
30        self.receiver
31            .recv()
32            .expect("Receiver must always receive batch computation result")
33    }
34
35    pub fn batch_size(&self) -> usize {
36        *self.batch_size.lock().unwrap()
37    }
38}
39
#[cfg(feature = "python")]
pub mod python {
    use crate::prelude::SortTrack;
    use crate::trackers::sort::python::PySortTrack;

    use super::{PredictionBatchResult, SceneTracks};
    use pyo3::prelude::*;

    /// Python-facing counterpart of [`SceneTracks`]: a scene id plus its tracks as
    /// [`PySortTrack`] values.
    pub type PySceneTracks = (u64, Vec<PySortTrack>);

    /// Reinterprets the tracks of a [`SceneTracks`] value as [`PySortTrack`]s without copying.
    ///
    /// Rust guarantees the layout of neither tuples nor `Vec<T>`. Transmuting
    /// `(u64, Vec<SortTrack>)` into `(u64, Vec<PySortTrack>)` wholesale is therefore unsound.
    /// Instead, the tuple is destructured and the vector is rebuilt from its raw parts. This
    /// relies only on the element types sharing size and alignment.
    ///
    /// # Panics
    ///
    /// Panics if `SortTrack` and `PySortTrack` differ in size or alignment.
    fn into_py_scene_tracks((scene_id, tracks): SceneTracks) -> PySceneTracks {
        assert_eq!(
            std::mem::size_of::<SortTrack>(),
            std::mem::size_of::<PySortTrack>(),
            "PySortTrack must have the same size as SortTrack"
        );
        assert_eq!(
            std::mem::align_of::<SortTrack>(),
            std::mem::align_of::<PySortTrack>(),
            "PySortTrack must have the same alignment as SortTrack"
        );
        // Keep the source Vec from freeing the buffer that the new Vec takes ownership of.
        let mut tracks = std::mem::ManuallyDrop::new(tracks);
        let (ptr, len, capacity) = (tracks.as_mut_ptr(), tracks.len(), tracks.capacity());
        // SAFETY: `ptr`, `len` and `capacity` come from a live `Vec<SortTrack>` that is never
        // dropped (it is wrapped in `ManuallyDrop`), so ownership of the allocation moves to
        // the new Vec exactly once. The assertions above guarantee the allocation layout is
        // identical for `PySortTrack`. Element validity additionally requires `PySortTrack`
        // to be a layout-compatible wrapper around `SortTrack`.
        // NOTE(review): confirm `PySortTrack` is `#[repr(transparent)]` over `SortTrack`.
        let tracks = unsafe { Vec::from_raw_parts(ptr.cast::<PySortTrack>(), len, capacity) };
        (scene_id, tracks)
    }

    /// Python wrapper around [`PredictionBatchResult`], exposed to Python as
    /// `PredictionBatchResult`.
    #[pyclass]
    #[derive(Clone, Debug)]
    #[pyo3(name = "PredictionBatchResult")]
    pub struct PyPredictionBatchResult(pub(crate) PredictionBatchResult);

    #[pymethods]
    impl PyPredictionBatchResult {
        /// Returns `True` when a computed result is waiting to be fetched.
        pub fn ready(&self) -> bool {
            self.0.ready()
        }

        /// Blocks until the batch result is available and returns it.
        ///
        /// The GIL is released while waiting, so other Python threads can keep running.
        #[pyo3(signature = ())]
        fn get(&self) -> PySceneTracks {
            Python::with_gil(|py| py.allow_threads(|| into_py_scene_tracks(self.0.get())))
        }

        /// Number of distinct scene ids queued on the paired request.
        pub fn batch_size(&self) -> usize {
            self.0.batch_size()
        }
    }
}

71impl<T> PredictionBatchRequest<T> {
72    pub fn get_sender(&self) -> Sender<SceneTracks> {
73        self.sender.clone()
74    }
75
76    #[allow(dead_code)]
77    pub(crate) fn send(&self, res: SceneTracks) -> bool {
78        let res = self.sender.send(res);
79        if let Err(e) = res {
80            debug!(
81                "Error occurred when sending results to the batch result object. Error is: {:?}",
82                e
83            );
84            false
85        } else {
86            true
87        }
88    }
89
90    pub fn batch_size(&self) -> usize {
91        *self.batch_size.lock().unwrap()
92    }
93
94    pub fn add(&mut self, scene_id: u64, elt: T) {
95        let vec = self.batch.get_mut(&scene_id);
96        if let Some(vec) = vec {
97            vec.push(elt);
98        } else {
99            self.batch.insert(scene_id, vec![elt]);
100        }
101        let mut batch_size = self.batch_size.lock().unwrap();
102        *batch_size = self.batch.len();
103    }
104
105    pub fn new() -> (Self, PredictionBatchResult) {
106        let (sender, receiver) = crossbeam::channel::bounded(1);
107        let batch_size = Arc::new(Mutex::new(0));
108        (
109            Self {
110                batch: BatchRecords::default(),
111                sender,
112                batch_size: batch_size.clone(),
113            },
114            PredictionBatchResult {
115                receiver,
116                batch_size,
117            },
118        )
119    }
120
121    pub fn get_batch(&self) -> &BatchRecords<T> {
122        &self.batch
123    }
124}
125
#[cfg(test)]
mod tests {
    use crate::prelude::Universal2DBox;
    use crate::trackers::batch::PredictionBatchRequest;

    /// End-to-end check of a request/result pair. Covers batching, result delivery, and
    /// sending after the result side has been dropped.
    #[test]
    fn test() {
        let (mut request, result) = PredictionBatchRequest::<Universal2DBox>::new();
        // Freshly created: nothing queued, nothing sent.
        assert!(!result.ready());
        assert_eq!(result.batch_size(), 0);

        request.add(0, Universal2DBox::new(0.0, 0.0, Some(0.5), 1.0, 5.0));
        request.add(0, Universal2DBox::new(5.0, 5.0, Some(0.0), 1.5, 10.0));
        request.add(1, Universal2DBox::new(0.0, 0.0, Some(1.0), 0.7, 5.1));

        // Elements are grouped per scene; the batch size counts scenes, not elements.
        let batch = request.get_batch();
        assert_eq!(batch.len(), 2);
        assert_eq!(batch[&0].len(), 2);
        assert_eq!(batch[&1].len(), 1);
        assert_eq!(request.batch_size(), 2);
        assert_eq!(result.batch_size(), 2);

        assert!(request.send((0, vec![])));
        assert!(result.ready());
        let res = result.get();
        assert_eq!(res.0, 0);
        assert!(res.1.is_empty());
        // The single pending result has been consumed.
        assert!(!result.ready());

        // Once the result side is gone, sending reports failure instead of succeeding.
        drop(result);
        assert!(!request.send((0, vec![])));
    }
}
148}