cecile_supercool_tracker/trackers/
batch.rs1use crate::prelude::SortTrack;
2use crossbeam::channel::{Receiver, Sender};
3use log::debug;
4
5use std::collections::HashMap;
6use std::sync::{Arc, Mutex};
7
8pub type BatchRecords<T> = HashMap<u64, Vec<T>>;
9pub type SceneTracks = (u64, Vec<SortTrack>);
10
11#[derive(Debug, Clone)]
12pub struct PredictionBatchRequest<T> {
13 batch: BatchRecords<T>,
14 sender: Sender<SceneTracks>,
15 batch_size: Arc<Mutex<usize>>,
16}
17
18#[derive(Clone, Debug)]
19pub struct PredictionBatchResult {
20 receiver: Receiver<SceneTracks>,
21 batch_size: Arc<Mutex<usize>>,
22}
23
24impl PredictionBatchResult {
25 pub fn ready(&self) -> bool {
26 !self.receiver.is_empty()
27 }
28
29 pub fn get(&self) -> SceneTracks {
30 self.receiver
31 .recv()
32 .expect("Receiver must always receive batch computation result")
33 }
34
35 pub fn batch_size(&self) -> usize {
36 *self.batch_size.lock().unwrap()
37 }
38}
39
#[cfg(feature = "python")]
pub mod python {
    use crate::prelude::SortTrack;
    use crate::trackers::sort::python::PySortTrack;

    use super::PredictionBatchResult;
    use pyo3::prelude::*;

    /// Python-facing counterpart of `super::SceneTracks`.
    pub type PySceneTracks = (u64, Vec<PySortTrack>);

    /// Python wrapper around [`PredictionBatchResult`], exposed to Python as
    /// `PredictionBatchResult`.
    #[pyclass]
    #[derive(Clone, Debug)]
    #[pyo3(name = "PredictionBatchResult")]
    pub struct PyPredictionBatchResult(pub(crate) PredictionBatchResult);

    #[pymethods]
    impl PyPredictionBatchResult {
        /// Returns `true` when a result is queued and can be fetched.
        pub fn ready(&self) -> bool {
            self.0.ready()
        }

        /// Blocks until the batch result is available and returns it.
        ///
        /// The GIL is released only while waiting on the channel. The conversion
        /// to Python-facing tracks happens afterwards.
        #[pyo3(signature = ())]
        fn get(&self) -> PySceneTracks {
            let (scene_id, tracks) = Python::with_gil(|py| py.allow_threads(|| self.0.get()));
            (scene_id, into_py_tracks(tracks))
        }

        /// Current number of distinct scenes in the paired request.
        pub fn batch_size(&self) -> usize {
            self.0.batch_size()
        }
    }

    /// Reuses the allocation of `tracks` as a `Vec<PySortTrack>` without copying.
    ///
    /// The previous implementation `transmute`d the whole `(u64, Vec<_>)` tuple.
    /// That depends on the unspecified layout of both the tuple and `Vec`.
    /// Rebuilding via `Vec::from_raw_parts` only depends on the element types
    /// being layout-compatible.
    fn into_py_tracks(tracks: Vec<SortTrack>) -> Vec<PySortTrack> {
        // Compile-time guard: reusing the buffer is only valid when element size
        // and alignment match exactly.
        const _: () = {
            assert!(std::mem::size_of::<PySortTrack>() == std::mem::size_of::<SortTrack>());
            assert!(std::mem::align_of::<PySortTrack>() == std::mem::align_of::<SortTrack>());
        };
        // Prevent the original Vec from freeing the buffer; ownership moves below.
        let mut tracks = std::mem::ManuallyDrop::new(tracks);
        let (ptr, len, cap) = (tracks.as_mut_ptr(), tracks.len(), tracks.capacity());
        // SAFETY: `ptr`, `len` and `cap` come from a live Vec whose destructor is
        // suppressed by `ManuallyDrop`, so the allocation is transferred exactly
        // once. Element size and alignment are asserted equal above, so the
        // allocation layout is unchanged. This also requires every `SortTrack`
        // to be a valid `PySortTrack` value, which the original transmute relied
        // on as well. NOTE(review): confirm `PySortTrack` is a
        // `#[repr(transparent)]` wrapper over `SortTrack` in trackers::sort::python.
        unsafe { Vec::from_raw_parts(ptr.cast::<PySortTrack>(), len, cap) }
    }
}

71impl<T> PredictionBatchRequest<T> {
72 pub fn get_sender(&self) -> Sender<SceneTracks> {
73 self.sender.clone()
74 }
75
76 #[allow(dead_code)]
77 pub(crate) fn send(&self, res: SceneTracks) -> bool {
78 let res = self.sender.send(res);
79 if let Err(e) = res {
80 debug!(
81 "Error occurred when sending results to the batch result object. Error is: {:?}",
82 e
83 );
84 false
85 } else {
86 true
87 }
88 }
89
90 pub fn batch_size(&self) -> usize {
91 *self.batch_size.lock().unwrap()
92 }
93
94 pub fn add(&mut self, scene_id: u64, elt: T) {
95 let vec = self.batch.get_mut(&scene_id);
96 if let Some(vec) = vec {
97 vec.push(elt);
98 } else {
99 self.batch.insert(scene_id, vec![elt]);
100 }
101 let mut batch_size = self.batch_size.lock().unwrap();
102 *batch_size = self.batch.len();
103 }
104
105 pub fn new() -> (Self, PredictionBatchResult) {
106 let (sender, receiver) = crossbeam::channel::bounded(1);
107 let batch_size = Arc::new(Mutex::new(0));
108 (
109 Self {
110 batch: BatchRecords::default(),
111 sender,
112 batch_size: batch_size.clone(),
113 },
114 PredictionBatchResult {
115 receiver,
116 batch_size,
117 },
118 )
119 }
120
121 pub fn get_batch(&self) -> &BatchRecords<T> {
122 &self.batch
123 }
124}
125
#[cfg(test)]
mod tests {
    use crate::prelude::Universal2DBox;
    use crate::trackers::batch::PredictionBatchRequest;

    #[test]
    fn test() {
        let (mut request, result) = PredictionBatchRequest::<Universal2DBox>::new();
        request.add(0, Universal2DBox::new(0.0, 0.0, Some(0.5), 1.0, 5.0));
        request.add(0, Universal2DBox::new(5.0, 5.0, Some(0.0), 1.5, 10.0));
        request.add(1, Universal2DBox::new(0.0, 0.0, Some(1.0), 0.7, 5.1));

        // Two distinct scenes: scene 0 holds two boxes, scene 1 holds one.
        {
            let batch = request.get_batch();
            assert_eq!(batch.len(), 2);
            assert_eq!(batch.get(&0).map(Vec::len), Some(2));
            assert_eq!(batch.get(&1).map(Vec::len), Some(1));
        }
        // Both handles observe the same shared scene count.
        assert_eq!(result.batch_size(), 2);
        assert_eq!(request.batch_size(), 2);

        assert!(request.send((0, vec![])));
        assert!(result.ready());
        let (scene_id, tracks) = result.get();
        assert_eq!(scene_id, 0);
        assert!(tracks.is_empty());
        // The single queued result has been consumed.
        assert!(!result.ready());

        // Once the result handle is dropped the channel is disconnected and
        // sending reports failure instead of panicking.
        drop(result);
        assert!(!request.send((0, vec![])));
    }
}