1use crate::prelude::{
2 NoopNotifier, ObservationBuilder, PositionalMetricType, SortTrack, TrackStoreBuilder,
3 Universal2DBox,
4};
5use crate::store::track_distance::TrackDistanceOkIterator;
6use crate::store::TrackStore;
7use crate::track::Track;
8use crate::trackers::batch::{PredictionBatchRequest, PredictionBatchResult, SceneTracks};
9use crate::trackers::epoch_db::EpochDb;
10use crate::trackers::sort::metric::SortMetric;
11use crate::trackers::sort::voting::SortVoting;
12use crate::trackers::sort::{
13 AutoWaste, SortAttributes, SortAttributesOptions, SortAttributesUpdate, SortLookup,
14 DEFAULT_AUTO_WASTE_PERIODICITY, MAHALANOBIS_NEW_TRACK_THRESHOLD,
15};
16
17use crate::trackers::spatio_temporal_constraints::SpatioTemporalConstraints;
18use crate::trackers::tracker_api::TrackerAPI;
19use crate::voting::Voting;
20use crossbeam::channel::{Receiver, Sender};
21use log::warn;
22use rand::Rng;
23use std::collections::HashMap;
24use std::mem;
25use std::sync::{Arc, Condvar, Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard};
26use std::thread::{spawn, JoinHandle};
27
28type VotingSenderChannel = Sender<VotingCommands>;
29type VotingReceiverChannel = Receiver<VotingCommands>;
30
31type MiddlewareSortTrackStore = TrackStore<SortAttributes, SortMetric, Universal2DBox>;
32type MiddlewareSortTrack = Track<SortAttributes, SortMetric, Universal2DBox>;
33type BatchBusyMonitor = Arc<(Mutex<usize>, Condvar)>;
34
35enum VotingCommands {
36 Distances {
37 scene_id: u64,
38 distances: TrackDistanceOkIterator<Universal2DBox>,
39 channel: Sender<SceneTracks>,
40 tracks: Vec<MiddlewareSortTrack>,
41 monitor: BatchBusyMonitor,
42 },
43 Exit,
44}
45
46pub struct BatchSort {
47 monitor: Option<BatchBusyMonitor>,
48 store: Arc<RwLock<MiddlewareSortTrackStore>>,
49 wasted_store: RwLock<MiddlewareSortTrackStore>,
50 opts: Arc<SortAttributesOptions>,
51 voting_threads: Vec<(VotingSenderChannel, JoinHandle<()>)>,
52 auto_waste: AutoWaste,
53}
54
55impl Drop for BatchSort {
56 fn drop(&mut self) {
57 let voting_threads = mem::take(&mut self.voting_threads);
58 for (tx, t) in voting_threads {
59 tx.send(VotingCommands::Exit)
60 .expect("Voting thread must be alive.");
61 drop(tx);
62 t.join()
63 .expect("Voting thread is expected to shutdown successfully.");
64 }
65 }
66}
67
68fn voting_thread(
69 store: Arc<RwLock<MiddlewareSortTrackStore>>,
70 rx: VotingReceiverChannel,
71 method: PositionalMetricType,
72 track_id: Arc<RwLock<u64>>,
73) {
74 while let Ok(command) = rx.recv() {
75 match command {
76 VotingCommands::Distances {
77 scene_id,
78 distances,
79 channel,
80 tracks,
81 monitor,
82 } => {
83 let candidates_num = tracks.len();
84 let tracks_num = {
85 let store = store.read().expect("Access to store must always succeed");
86 store.shard_stats().iter().sum()
87 };
88
89 let voting = SortVoting::new(
90 match method {
91 PositionalMetricType::Mahalanobis => MAHALANOBIS_NEW_TRACK_THRESHOLD,
92 PositionalMetricType::IoU(t) => t,
93 },
94 candidates_num,
95 tracks_num,
96 );
97
98 let winners = voting.winners(distances);
99 let mut res = Vec::default();
100 for mut t in tracks {
101 let source = t.get_track_id();
102 let tid = {
103 let mut track_id = track_id.write().unwrap();
104 *track_id += 1;
105 *track_id
106 };
107 let track_id: u64 = if let Some(dest) = winners.get(&source) {
108 let dest = dest[0];
109 if dest == source {
110 t.set_track_id(tid);
111 store
112 .write()
113 .expect("Access to store must always succeed")
114 .add_track(t)
115 .unwrap();
116 tid
117 } else {
118 store
119 .write()
120 .expect("Access to store must always succeed")
121 .merge_external(dest, &t, Some(&[0]), false)
122 .unwrap();
123 dest
124 }
125 } else {
126 t.set_track_id(tid);
127 store
128 .write()
129 .expect("Access to store must always succeed")
130 .add_track(t)
131 .unwrap();
132 tid
133 };
134
135 let store = store.read().expect("Access to store must always succeed");
136 let shard = store.get_store(track_id as usize);
137 let track = shard.get(&track_id).unwrap();
138
139 res.push(SortTrack::from(track))
140 }
141 let res = channel.send((scene_id, res));
142 if let Err(e) = res {
143 warn!("Unable to send results to a caller, likely the caller already closed the channel. Error is: {:?}", e);
144 }
145 let (lock, cvar) = &*monitor;
146 let mut lock = lock.lock().unwrap();
147 *lock -= 1;
148 cvar.notify_one();
149 }
150 VotingCommands::Exit => break,
151 }
152 }
153}
154
155impl BatchSort {
156 #[allow(clippy::too_many_arguments)]
157 pub fn new(
158 distance_shards: usize,
159 voting_shards: usize,
160 bbox_history: usize,
161 max_idle_epochs: usize,
162 method: PositionalMetricType,
163 min_confidence: f32,
164 spatio_temporal_constraints: Option<SpatioTemporalConstraints>,
165 kalman_position_weight: f32,
166 kalman_velocity_weight: f32,
167 ) -> Self {
168 assert!(bbox_history > 0);
169 let epoch_db = RwLock::new(HashMap::default());
170 let opts = Arc::new(SortAttributesOptions::new(
171 Some(epoch_db),
172 max_idle_epochs,
173 bbox_history,
174 spatio_temporal_constraints.unwrap_or_default(),
175 kalman_position_weight,
176 kalman_velocity_weight,
177 ));
178
179 let store = Arc::new(RwLock::new(
180 TrackStoreBuilder::new(distance_shards)
181 .default_attributes(SortAttributes::new(opts.clone()))
182 .metric(SortMetric::new(method, min_confidence))
183 .notifier(NoopNotifier)
184 .build(),
185 ));
186
187 let wasted_store = RwLock::new(
188 TrackStoreBuilder::new(distance_shards)
189 .default_attributes(SortAttributes::new(opts.clone()))
190 .metric(SortMetric::new(method, min_confidence))
191 .notifier(NoopNotifier)
192 .build(),
193 );
194
195 let track_id = Arc::new(RwLock::new(0));
196
197 let voting_threads = (0..voting_shards)
198 .map(|_e| {
199 let (tx, rx) = crossbeam::channel::unbounded();
200 let thread_store = store.clone();
201 let thread_track_id = track_id.clone();
202 (
203 tx,
204 spawn(move || voting_thread(thread_store, rx, method, thread_track_id)),
205 )
206 })
207 .collect::<Vec<_>>();
208
209 Self {
210 monitor: None,
211 store,
212 wasted_store,
213 opts,
214 voting_threads,
215 auto_waste: AutoWaste {
216 periodicity: DEFAULT_AUTO_WASTE_PERIODICITY,
217 counter: DEFAULT_AUTO_WASTE_PERIODICITY,
218 },
219 }
220 }
221
222 pub fn predict(
223 &mut self,
224 batch_request: PredictionBatchRequest<(Universal2DBox, Option<i64>)>,
225 ) {
226 if self.auto_waste.counter == 0 {
227 self.auto_waste();
228 self.auto_waste.counter = self.auto_waste.periodicity;
229 } else {
230 self.auto_waste.counter -= 1;
231 }
232
233 if let Some(m) = &self.monitor {
234 let (lock, cvar) = &**m;
235 let _guard = cvar.wait_while(lock.lock().unwrap(), |v| *v > 0).unwrap();
236 }
237
238 self.monitor = Some(Arc::new((
239 Mutex::new(batch_request.batch_size()),
240 Condvar::new(),
241 )));
242
243 for (i, (scene_id, bboxes)) in batch_request.get_batch().iter().enumerate() {
244 let mut rng = rand::thread_rng();
245 let epoch = self.opts.next_epoch(*scene_id).unwrap();
246
247 let tracks = bboxes
248 .iter()
249 .map(|(bb, custom_object_id)| {
250 self.store
251 .read()
252 .expect("Access to store must always succeed")
253 .new_track(rng.gen())
254 .observation(
255 ObservationBuilder::new(0)
256 .observation_attributes(bb.clone())
257 .track_attributes_update(SortAttributesUpdate::new_with_scene(
258 epoch,
259 *scene_id,
260 *custom_object_id,
261 ))
262 .build(),
263 )
264 .build()
265 .expect("Track creation must always succeed!")
266 })
267 .collect::<Vec<_>>();
268
269 let (dists, errs) = {
270 let mut store = self
271 .store
272 .write()
273 .expect("Access to store must always succeed");
274 store.foreign_track_distances(tracks.clone(), 0, false)
275 };
276
277 assert!(errs.all().is_empty());
278 let thread_id = i % self.voting_threads.len();
279 self.voting_threads[thread_id]
280 .0
281 .send(VotingCommands::Distances {
282 monitor: self.monitor.as_ref().unwrap().clone(),
283 scene_id: *scene_id,
284 distances: dists.into_iter(),
285 channel: batch_request.get_sender(),
286 tracks,
287 })
288 .expect("Sending voting request to voting thread must not fail");
289 }
290 }
291
292 pub fn idle_tracks(&mut self) -> Vec<SortTrack> {
293 self.idle_tracks_with_scene(0)
294 }
295
296 pub fn idle_tracks_with_scene(&mut self, scene_id: u64) -> Vec<SortTrack> {
297 let store = self.store.read().unwrap();
298
299 store
300 .lookup(SortLookup::IdleLookup(scene_id))
301 .iter()
302 .map(|(track_id, _status)| {
303 let shard = store.get_store(*track_id as usize);
304 let track = shard.get(track_id).unwrap();
305 SortTrack::from(track)
306 })
307 .collect()
308 }
309}
310
311impl TrackerAPI<SortAttributes, SortMetric, Universal2DBox, SortAttributesOptions, NoopNotifier>
312 for BatchSort
313{
314 fn get_auto_waste_obj_mut(&mut self) -> &mut AutoWaste {
315 &mut self.auto_waste
316 }
317
318 fn get_opts(&self) -> &SortAttributesOptions {
319 &self.opts
320 }
321
322 fn get_main_store_mut(&mut self) -> RwLockWriteGuard<MiddlewareSortTrackStore> {
323 self.store.write().unwrap()
324 }
325
326 fn get_wasted_store_mut(&mut self) -> RwLockWriteGuard<MiddlewareSortTrackStore> {
327 self.wasted_store.write().unwrap()
328 }
329
330 fn get_main_store(&self) -> RwLockReadGuard<MiddlewareSortTrackStore> {
331 self.store.read().unwrap()
332 }
333
334 fn get_wasted_store(&self) -> RwLockReadGuard<MiddlewareSortTrackStore> {
335 self.wasted_store.read().unwrap()
336 }
337}
338
339#[derive(Debug, Clone)]
340pub struct SortPredictionBatchRequest {
341 pub batch: PredictionBatchRequest<(Universal2DBox, Option<i64>)>,
342 pub result: Option<PredictionBatchResult>,
343}
344
345impl SortPredictionBatchRequest {
346 pub fn new() -> Self {
347 let (batch, result) = PredictionBatchRequest::new();
348
349 Self {
350 batch,
351 result: Some(result),
352 }
353 }
354
355 pub fn add(&mut self, scene_id: u64, bbox: Universal2DBox, custom_object_id: Option<i64>) {
356 self.batch.add(scene_id, (bbox, custom_object_id))
357 }
358}
359
360impl Default for SortPredictionBatchRequest {
361 fn default() -> Self {
362 Self::new()
363 }
364}
365
#[cfg(feature = "python")]
pub mod python {
    //! Python bindings for [`BatchSort`] and [`SortPredictionBatchRequest`].

    use crate::{
        trackers::{
            batch::python::PyPredictionBatchResult,
            sort::{
                python::{PyPositionalMetricType, PySortTrack, PyWastedSortTrack},
                WastedSortTrack,
            },
            spatio_temporal_constraints::python::PySpatioTemporalConstraints,
            tracker_api::TrackerAPI,
        },
        utils::bbox::python::PyUniversal2DBox,
    };

    use super::{BatchSort, SortPredictionBatchRequest};
    use pyo3::prelude::*;

    /// Python wrapper of the batched SORT tracker.
    #[pyclass]
    #[pyo3(name = "BatchSort")]
    pub struct PyBatchSort(pub(crate) BatchSort);

    #[pymethods]
    impl PyBatchSort {
        /// Creates the tracker; the Mahalanobis metric is used when `method` is `None`.
        /// Negative counts are rejected with a panic.
        #[new]
        #[pyo3(signature = (
            distance_shards = 4,
            voting_shards = 4,
            bbox_history = 1,
            max_idle_epochs = 5,
            method = None,
            min_confidence = 0.05,
            spatio_temporal_constraints = None,
            kalman_position_weight = 1.0 / 20.0,
            kalman_velocity_weight = 1.0 / 160.0
        ))]
        #[allow(clippy::too_many_arguments)]
        pub fn new(
            distance_shards: i64,
            voting_shards: i64,
            bbox_history: i64,
            max_idle_epochs: i64,
            method: Option<PyPositionalMetricType>,
            min_confidence: f32,
            spatio_temporal_constraints: Option<PySpatioTemporalConstraints>,
            kalman_position_weight: f32,
            kalman_velocity_weight: f32,
        ) -> Self {
            Self(BatchSort::new(
                distance_shards
                    .try_into()
                    .expect("Positive number expected"),
                voting_shards.try_into().expect("Positive number expected"),
                bbox_history.try_into().expect("Positive number expected"),
                max_idle_epochs
                    .try_into()
                    .expect("Positive number expected"),
                method.unwrap_or(PyPositionalMetricType::maha()).0,
                min_confidence,
                spatio_temporal_constraints.map(|x| x.0),
                kalman_position_weight,
                kalman_velocity_weight,
            ))
        }

        /// Skips `n` epochs; `n` must be positive.
        #[pyo3(signature = (n))]
        fn skip_epochs(&mut self, n: i64) {
            assert!(n > 0);
            self.0.skip_epochs(n.try_into().unwrap())
        }

        /// Skips `n` epochs of scene `scene_id`; `n` must be positive and `scene_id`
        /// non-negative.
        #[pyo3(signature = (scene_id, n))]
        fn skip_epochs_for_scene(&mut self, scene_id: i64, n: i64) {
            assert!(n > 0 && scene_id >= 0);
            self.0
                .skip_epochs_for_scene(scene_id.try_into().unwrap(), n.try_into().unwrap())
        }

        /// Returns the per-shard statistics of the main track store, with the GIL released.
        #[pyo3(signature = ())]
        fn shard_stats(&self) -> Vec<i64> {
            Python::with_gil(|py| {
                py.allow_threads(|| {
                    self.0
                        .store
                        .read()
                        .unwrap()
                        .shard_stats()
                        .into_iter()
                        .map(|e| i64::try_from(e).unwrap())
                        .collect()
                })
            })
        }

        /// Returns the current epoch of scene `0`.
        #[pyo3(signature = ())]
        fn current_epoch(&self) -> i64 {
            self.0.current_epoch_with_scene(0).try_into().unwrap()
        }

        /// Returns the current epoch of scene `scene_id`; `scene_id` must be non-negative.
        #[pyo3(
            signature = (scene_id)
        )]
        fn current_epoch_with_scene(&self, scene_id: i64) -> isize {
            assert!(scene_id >= 0);
            self.0
                .current_epoch_with_scene(scene_id.try_into().unwrap())
                .try_into()
                .unwrap()
        }

        /// Submits `batch` and returns the handle the per-scene results are read from.
        #[pyo3(signature = (batch))]
        fn predict(&mut self, mut batch: PySortPredictionBatchRequest) -> PyPredictionBatchResult {
            self.0.predict(batch.0.batch);
            PyPredictionBatchResult(batch.0.result.take().unwrap())
        }

        /// Returns the wasted tracks, with the GIL released.
        #[pyo3(signature = ())]
        fn wasted(&mut self) -> Vec<PyWastedSortTrack> {
            Python::with_gil(|py| {
                py.allow_threads(|| {
                    self.0
                        .wasted()
                        .into_iter()
                        .map(WastedSortTrack::from)
                        .map(PyWastedSortTrack)
                        .collect()
                })
            })
        }

        /// Clears the wasted tracks, with the GIL released.
        #[pyo3(signature = ())]
        pub fn clear_wasted(&mut self) {
            Python::with_gil(|py| {
                py.allow_threads(|| self.0.clear_wasted());
            })
        }

        /// Returns the idle tracks of scene `scene_id`, with the GIL released; `scene_id` must
        /// be non-negative.
        #[pyo3(signature = (scene_id))]
        pub fn idle_tracks(&mut self, scene_id: i64) -> Vec<PySortTrack> {
            Python::with_gil(|py| {
                py.allow_threads(|| {
                    // Wrap each track explicitly: transmuting `Vec<SortTrack>` into
                    // `Vec<PySortTrack>` relies on an unspecified `Vec` layout and is unsound.
                    self.0
                        .idle_tracks_with_scene(scene_id.try_into().unwrap())
                        .into_iter()
                        .map(PySortTrack)
                        .collect()
                })
            })
        }
    }

    /// Python wrapper of [`SortPredictionBatchRequest`].
    #[pyclass]
    #[pyo3(name = "SortPredictionBatchRequest")]
    #[derive(Debug, Clone)]
    pub struct PySortPredictionBatchRequest(pub(crate) SortPredictionBatchRequest);

    #[pymethods]
    impl PySortPredictionBatchRequest {
        /// Creates an empty request.
        #[new]
        fn new() -> Self {
            Self(SortPredictionBatchRequest::new())
        }

        /// Adds the detection `bbox` to scene `scene_id` with an optional caller-defined id.
        fn add(&mut self, scene_id: u64, bbox: PyUniversal2DBox, custom_object_id: Option<i64>) {
            self.0.add(scene_id, bbox.0, custom_object_id)
        }
    }
}
550
#[cfg(test)]
mod tests {
    use crate::prelude::BoundingBox;
    use crate::prelude::PositionalMetricType::Mahalanobis;
    use crate::trackers::batch::PredictionBatchRequest;
    use crate::trackers::sort::batch_api::BatchSort;
    use crate::trackers::sort::metric::DEFAULT_MINIMAL_SORT_CONFIDENCE;

    /// Builds a tracker, submits one batch with two scenes and drains every scene's result
    /// before the tracker and its voting threads are dropped.
    #[test]
    fn new_drop() {
        let mut tracker = BatchSort::new(
            1,
            1,
            1,
            1,
            Mahalanobis,
            DEFAULT_MINIMAL_SORT_CONFIDENCE,
            None,
            1.0 / 20.0,
            1.0 / 160.0,
        );

        let (mut request, results) = PredictionBatchRequest::new();
        for (scene_id, custom_object_id) in [(0, 1), (1, 2)] {
            let bbox = BoundingBox::new(0.0, 0.0, 5.0, 10.0).into();
            request.add(scene_id, (bbox, Some(custom_object_id)));
        }

        tracker.predict(request);

        let expected_scenes = results.batch_size();
        (0..expected_scenes).for_each(|_| {
            let data = results.get();
            dbg!(data);
        });
    }
}