1use std::collections::HashMap;
2use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
3
4use rand::Rng;
5
6use crate::prelude::{NoopNotifier, ObservationBuilder, TrackStoreBuilder};
7use crate::store::TrackStore;
8use crate::track::Track;
9use crate::trackers::epoch_db::EpochDb;
10use crate::trackers::sort::{
11 metric::SortMetric, voting::SortVoting, AutoWaste, PositionalMetricType, SortAttributes,
12 SortAttributesOptions, SortAttributesUpdate, SortLookup, SortTrack, VotingType,
13 DEFAULT_AUTO_WASTE_PERIODICITY, MAHALANOBIS_NEW_TRACK_THRESHOLD,
14};
15use crate::trackers::spatio_temporal_constraints::SpatioTemporalConstraints;
16use crate::trackers::tracker_api::TrackerAPI;
17use crate::utils::bbox::Universal2DBox;
18use crate::voting::Voting;
19
20pub struct Sort {
23 store: RwLock<TrackStore<SortAttributes, SortMetric, Universal2DBox>>,
24 wasted_store: RwLock<TrackStore<SortAttributes, SortMetric, Universal2DBox>>,
25 method: PositionalMetricType,
26 opts: Arc<SortAttributesOptions>,
27 auto_waste: AutoWaste,
28 track_id: u64,
29}
30
31impl Sort {
32 #[allow(clippy::too_many_arguments)]
41 pub fn new(
42 shards: usize,
43 bbox_history: usize,
44 max_idle_epochs: usize,
45 method: PositionalMetricType,
46 min_confidence: f32,
47 spatio_temporal_constraints: Option<SpatioTemporalConstraints>,
48 kalman_position_weight: f32,
49 kalman_velocity_weight: f32,
50 ) -> Self {
51 assert!(bbox_history > 0);
52 let epoch_db = RwLock::new(HashMap::default());
53 let opts = Arc::new(SortAttributesOptions::new(
54 Some(epoch_db),
55 max_idle_epochs,
56 bbox_history,
57 spatio_temporal_constraints.unwrap_or_default(),
58 kalman_position_weight,
59 kalman_velocity_weight,
60 ));
61 let store = RwLock::new(
62 TrackStoreBuilder::new(shards)
63 .default_attributes(SortAttributes::new(opts.clone()))
64 .metric(SortMetric::new(method, min_confidence))
65 .notifier(NoopNotifier)
66 .build(),
67 );
68
69 let wasted_store = RwLock::new(
70 TrackStoreBuilder::new(shards)
71 .default_attributes(SortAttributes::new(opts.clone()))
72 .metric(SortMetric::new(method, min_confidence))
73 .notifier(NoopNotifier)
74 .build(),
75 );
76
77 Self {
78 store,
79 track_id: 0,
80 wasted_store,
81 method,
82 opts,
83 auto_waste: AutoWaste {
84 periodicity: DEFAULT_AUTO_WASTE_PERIODICITY,
85 counter: DEFAULT_AUTO_WASTE_PERIODICITY,
86 },
87 }
88 }
89
90 pub fn predict(&mut self, bboxes: &[(Universal2DBox, Option<i64>)]) -> Vec<SortTrack> {
96 self.predict_with_scene(0, bboxes)
97 }
98
99 fn gen_track_id(&mut self) -> u64 {
100 self.track_id += 1;
101 self.track_id
102 }
103
104 pub fn predict_with_scene(
111 &mut self,
112 scene_id: u64,
113 bboxes: &[(Universal2DBox, Option<i64>)],
114 ) -> Vec<SortTrack> {
115 if self.auto_waste.counter == 0 {
116 self.auto_waste();
117 self.auto_waste.counter = self.auto_waste.periodicity;
118 } else {
119 self.auto_waste.counter -= 1;
120 }
121
122 let mut rng = rand::thread_rng();
123 let epoch = self.opts.next_epoch(scene_id).unwrap();
124
125 let tracks = bboxes
126 .iter()
127 .map(|(bb, custom_object_id)| {
128 self.store
129 .read()
130 .unwrap()
131 .new_track(rng.gen())
132 .observation(
133 ObservationBuilder::new(0)
134 .observation_attributes(bb.clone())
135 .track_attributes_update(SortAttributesUpdate::new_with_scene(
136 epoch,
137 scene_id,
138 *custom_object_id,
139 ))
140 .build(),
141 )
142 .build()
143 .unwrap()
144 })
145 .collect::<Vec<_>>();
146 let num_candidates = tracks.len();
147 let (dists, errs) =
148 self.store
149 .write()
150 .unwrap()
151 .foreign_track_distances(tracks.clone(), 0, false);
152 assert!(errs.all().is_empty());
153 let dists = dists.all();
154 let voting = SortVoting::new(
155 match self.method {
156 PositionalMetricType::Mahalanobis => MAHALANOBIS_NEW_TRACK_THRESHOLD,
157 PositionalMetricType::IoU(t) => t,
158 },
159 num_candidates,
160 self.store.read().unwrap().shard_stats().iter().sum(),
161 );
162 let winners = voting.winners(dists);
163 let mut res = Vec::default();
164
165 for mut t in tracks {
166 let source = t.get_track_id();
167 let track_id: u64 = if let Some(dest) = winners.get(&source) {
168 let dest = dest[0];
169 if dest == source {
170 let track_id = self.gen_track_id();
171 t.set_track_id(track_id);
172 self.store.write().unwrap().add_track(t).unwrap();
173 track_id
174 } else {
175 self.store
176 .write()
177 .unwrap()
178 .merge_external(dest, &t, Some(&[0]), false)
179 .unwrap();
180 dest
181 }
182 } else {
183 let track_id = self.gen_track_id();
184 t.set_track_id(track_id);
185 self.store.write().unwrap().add_track(t).unwrap();
186 track_id
187 };
188
189 let lock = self.store.read().unwrap();
190 let store = lock.get_store(track_id as usize);
191 let track = store.get(&track_id).unwrap();
192 res.push(SortTrack::from(track));
193 }
194
195 res
196 }
197
198 pub fn idle_tracks(&mut self) -> Vec<SortTrack> {
199 self.idle_tracks_with_scene(0)
200 }
201
202 pub fn idle_tracks_with_scene(&mut self, scene_id: u64) -> Vec<SortTrack> {
203 let store = self.store.read().unwrap();
204
205 store
206 .lookup(SortLookup::IdleLookup(scene_id))
207 .iter()
208 .map(|(track_id, _status)| {
209 let shard = store.get_store(*track_id as usize);
210 let track = shard.get(track_id).unwrap();
211 SortTrack::from(track)
212 })
213 .collect()
214 }
215}
216
217impl TrackerAPI<SortAttributes, SortMetric, Universal2DBox, SortAttributesOptions, NoopNotifier>
218 for Sort
219{
220 fn get_auto_waste_obj_mut(&mut self) -> &mut AutoWaste {
221 &mut self.auto_waste
222 }
223
224 fn get_opts(&self) -> &SortAttributesOptions {
225 &self.opts
226 }
227
228 fn get_main_store_mut(
229 &mut self,
230 ) -> RwLockWriteGuard<TrackStore<SortAttributes, SortMetric, Universal2DBox, NoopNotifier>>
231 {
232 self.store.write().unwrap()
233 }
234
235 fn get_wasted_store_mut(
236 &mut self,
237 ) -> RwLockWriteGuard<TrackStore<SortAttributes, SortMetric, Universal2DBox, NoopNotifier>>
238 {
239 self.wasted_store.write().unwrap()
240 }
241
242 fn get_main_store(
243 &self,
244 ) -> RwLockReadGuard<TrackStore<SortAttributes, SortMetric, Universal2DBox, NoopNotifier>> {
245 self.store.read().unwrap()
246 }
247
248 fn get_wasted_store(
249 &self,
250 ) -> RwLockReadGuard<TrackStore<SortAttributes, SortMetric, Universal2DBox, NoopNotifier>> {
251 self.wasted_store.read().unwrap()
252 }
253}
254
255impl From<&Track<SortAttributes, SortMetric, Universal2DBox>> for SortTrack {
256 fn from(track: &Track<SortAttributes, SortMetric, Universal2DBox>) -> Self {
257 let attrs = track.get_attributes();
258 SortTrack {
259 id: track.get_track_id(),
260 custom_object_id: attrs.custom_object_id,
261 voting_type: VotingType::Positional,
262 epoch: attrs.last_updated_epoch,
263 scene_id: attrs.scene_id,
264 observed_bbox: attrs.observed_boxes.back().unwrap().clone(),
265 predicted_bbox: attrs.predicted_boxes.back().unwrap().clone(),
266 length: attrs.track_length,
267 }
268 }
269}
270
#[cfg(test)]
mod tests {
    use crate::trackers::sort::metric::DEFAULT_MINIMAL_SORT_CONFIDENCE;
    use crate::trackers::sort::simple_api::Sort;
    use crate::trackers::sort::PositionalMetricType::IoU;
    use crate::trackers::sort::DEFAULT_SORT_IOU_THRESHOLD;
    use crate::trackers::tracker_api::TrackerAPI;
    use crate::utils::bbox::BoundingBox;

    /// Tracker configuration shared by every test: 1 shard, a 10-box history,
    /// `max_idle_epochs = 2`, IoU matching and the default Kalman weights.
    fn new_tracker() -> Sort {
        Sort::new(
            1,
            10,
            2,
            IoU(DEFAULT_SORT_IOU_THRESHOLD),
            DEFAULT_MINIMAL_SORT_CONFIDENCE,
            None,
            1.0 / 20.0,
            1.0 / 160.0,
        )
    }

    /// Total number of tracks currently held in the wasted store.
    fn wasted_store_len(t: &Sort) -> usize {
        t.wasted_store
            .read()
            .unwrap()
            .shard_stats()
            .iter()
            .sum::<usize>()
    }

    #[test]
    fn sort() {
        let mut t = new_tracker();
        assert_eq!(t.current_epoch(), 0);

        // Epoch 1: the first box opens a new track.
        let bb = BoundingBox::new(0.0, 0.0, 10.0, 20.0);
        let v = t.predict(&[(bb.into(), None)]);
        assert!(t.wasted().is_empty());
        assert_eq!(v.len(), 1);
        let v = v[0].clone();
        let track_id = v.id;
        assert_eq!(v.custom_object_id, None);
        assert_eq!(v.length, 1);
        assert_eq!(v.observed_bbox, bb.into());
        assert_eq!(v.epoch, 1);
        assert_eq!(t.current_epoch(), 1);

        // Epoch 2: a slightly shifted box continues the same track.
        let bb = BoundingBox::new(0.1, 0.1, 10.1, 20.0);
        let v = t.predict(&[(bb.into(), Some(2))]);
        assert!(t.wasted().is_empty());
        assert_eq!(v.len(), 1);
        let v = v[0].clone();
        assert_eq!(v.custom_object_id, Some(2));
        assert_eq!(v.id, track_id);
        assert_eq!(v.length, 2);
        assert_eq!(v.observed_bbox, bb.into());
        assert_eq!(v.epoch, 2);
        assert_eq!(t.current_epoch(), 2);

        // Epoch 3: a different box starts another track.
        let bb = BoundingBox::new(10.1, 10.1, 10.1, 20.0);
        let v = t.predict(&[(bb.into(), Some(3))]);
        assert_eq!(v.len(), 1);
        let v = v[0].clone();
        assert_eq!(v.custom_object_id, Some(3));
        assert_ne!(v.id, track_id);
        assert!(t.wasted().is_empty());
        assert_eq!(t.current_epoch(), 3);

        // Epoch 4: no detections and nothing wasted yet.
        let v = t.predict(&[]);
        assert!(v.is_empty());
        assert!(t.wasted().is_empty());
        assert_eq!(t.current_epoch(), 4);

        // Epoch 5: the first track is now reported as wasted.
        let v = t.predict(&[]);
        assert!(v.is_empty());
        let wasted = t.wasted();
        assert_eq!(wasted.len(), 1);
        assert_eq!(wasted[0].get_track_id(), track_id);
        assert_eq!(t.current_epoch(), 5);
    }

    #[test]
    fn sort_with_scenes() {
        let mut t = new_tracker();
        let bb = BoundingBox::new(0.0, 0.0, 10.0, 20.0);
        assert_eq!(t.current_epoch_with_scene(1), 0);
        assert_eq!(t.current_epoch_with_scene(2), 0);

        let _v = t.predict_with_scene(1, &[(bb.into(), Some(4))]);
        let _v = t.predict_with_scene(1, &[(bb.into(), Some(5))]);

        // Epochs are tracked independently per scene.
        assert_eq!(t.current_epoch_with_scene(1), 2);
        assert_eq!(t.current_epoch_with_scene(2), 0);

        let _v = t.predict_with_scene(2, &[(bb.into(), Some(6))]);

        assert_eq!(t.current_epoch_with_scene(1), 2);
        assert_eq!(t.current_epoch_with_scene(2), 1);
    }

    #[test]
    fn idle_tracks() {
        let mut t = new_tracker();
        let bb = BoundingBox::new(0.0, 0.0, 10.0, 20.0);

        let _v = t.predict_with_scene(1, &[(bb.into(), Some(4))]);
        assert!(t.idle_tracks_with_scene(1).is_empty());

        // An epoch without detections leaves the only track idle.
        let _v = t.predict_with_scene(1, &[]);

        let idle = t.idle_tracks_with_scene(1);
        assert_eq!(idle.len(), 1);
        assert_eq!(idle[0].id, 1);
    }

    #[test]
    fn clear_wasted_tracks() {
        let mut t = new_tracker();
        let bb = BoundingBox::new(0.0, 0.0, 10.0, 20.0);

        let _v = t.predict_with_scene(1, &[(bb.into(), Some(4))]);
        t.skip_epochs_for_scene(1, 3);
        assert_eq!(wasted_store_len(&t), 1);

        t.clear_wasted();
        assert_eq!(wasted_store_len(&t), 0);
    }
}
434
#[cfg(feature = "python")]
pub mod python {
    use pyo3::prelude::*;

    use crate::{
        prelude::Universal2DBox,
        trackers::{
            sort::{
                python::{PyPositionalMetricType, PySortTrack, PyWastedSortTrack},
                WastedSortTrack,
            },
            spatio_temporal_constraints::python::PySpatioTemporalConstraints,
            tracker_api::TrackerAPI,
        },
        utils::bbox::python::PyUniversal2DBox,
    };

    use super::Sort;

    /// Python binding for the SORT tracker, exported to Python as `Sort`.
    ///
    /// Most methods release the GIL (`allow_threads`) while the Rust tracker runs.
    #[pyclass]
    #[pyo3(name = "Sort")]
    pub struct PySort(pub Sort);

    #[pymethods]
    impl PySort {
        /// Creates the tracker. `method = None` selects the Mahalanobis metric.
        ///
        /// Panics if `shards`, `bbox_history` or `max_idle_epochs` is negative, or if
        /// `bbox_history` is 0.
        #[new]
        #[pyo3(signature = (
            shards = 4,
            bbox_history = 1,
            max_idle_epochs = 5,
            method = None,
            min_confidence = 0.05,
            spatio_temporal_constraints = None,
            kalman_position_weight = 1.0 / 20.0,
            kalman_velocity_weight = 1.0 / 160.0
        ))]
        #[allow(clippy::too_many_arguments)]
        pub fn new_py(
            shards: i64,
            bbox_history: i64,
            max_idle_epochs: i64,
            method: Option<PyPositionalMetricType>,
            min_confidence: f32,
            spatio_temporal_constraints: Option<PySpatioTemporalConstraints>,
            kalman_position_weight: f32,
            kalman_velocity_weight: f32,
        ) -> Self {
            Self(Sort::new(
                shards.try_into().expect("Positive number expected"),
                bbox_history.try_into().expect("Positive number expected"),
                max_idle_epochs
                    .try_into()
                    .expect("Positive number expected"),
                method.unwrap_or(PyPositionalMetricType::maha()).0,
                min_confidence,
                spatio_temporal_constraints.map(|x| x.0),
                kalman_position_weight,
                kalman_velocity_weight,
            ))
        }

        /// Skips `n` epochs through `TrackerAPI::skip_epochs`. Panics unless `n > 0`.
        #[pyo3(signature = (n))]
        pub fn skip_epochs(&mut self, n: i64) {
            assert!(n > 0);
            self.0.skip_epochs(n.try_into().unwrap())
        }

        /// Skips `n` epochs for `scene_id`.
        ///
        /// Panics unless `n > 0` and `scene_id >= 0`.
        #[pyo3(signature = (scene_id, n))]
        pub fn skip_epochs_for_scene(&mut self, scene_id: i64, n: i64) {
            assert!(n > 0 && scene_id >= 0);
            self.0
                .skip_epochs_for_scene(scene_id.try_into().unwrap(), n.try_into().unwrap())
        }

        /// Per-shard track counts of the main store.
        #[pyo3(signature = ())]
        pub fn shard_stats(&self) -> Vec<i64> {
            Python::with_gil(|py| {
                py.allow_threads(|| {
                    self.0
                        .store
                        .read()
                        .unwrap()
                        .shard_stats()
                        .into_iter()
                        .map(|e| i64::try_from(e).unwrap())
                        .collect()
                })
            })
        }

        /// Current epoch of the default scene `0`.
        #[pyo3(signature = ())]
        pub fn current_epoch(&self) -> i64 {
            self.0.current_epoch_with_scene(0).try_into().unwrap()
        }

        /// Current epoch of `scene_id`. Panics if `scene_id` is negative.
        #[pyo3(signature = (scene_id))]
        pub fn current_epoch_with_scene(&self, scene_id: i64) -> isize {
            assert!(scene_id >= 0);
            self.0
                .current_epoch_with_scene(scene_id.try_into().unwrap())
                .try_into()
                .unwrap()
        }

        /// Runs `predict_with_scene` for the default scene `0`.
        #[pyo3(signature = (bboxes))]
        pub fn predict(
            &mut self,
            bboxes: Vec<(PyUniversal2DBox, Option<i64>)>,
        ) -> Vec<PySortTrack> {
            self.predict_with_scene(0, bboxes)
        }

        /// Matches `(box, custom_object_id)` pairs against the tracks of `scene_id`.
        ///
        /// Returns one track per input box, in input order. Panics if `scene_id` is
        /// negative.
        #[pyo3(signature = (scene_id, bboxes))]
        pub fn predict_with_scene(
            &mut self,
            scene_id: i64,
            bboxes: Vec<(PyUniversal2DBox, Option<i64>)>,
        ) -> Vec<PySortTrack> {
            assert!(scene_id >= 0);
            // SAFETY: assumes `PyUniversal2DBox` is a layout-compatible
            // (`#[repr(transparent)]`) wrapper around `Universal2DBox`.
            // NOTE(review): neither `Vec<T>` -> `Vec<U>` transmutes nor tuple layouts are
            // guaranteed by the language. Confirm the wrapper's repr, and consider an
            // element-wise `into_iter().map(..).collect()` conversion instead.
            let bboxes: Vec<(Universal2DBox, Option<i64>)> = unsafe { std::mem::transmute(bboxes) };

            Python::with_gil(|py| {
                py.allow_threads(|| {
                    let tracks = self
                        .0
                        .predict_with_scene(scene_id.try_into().unwrap(), &bboxes);
                    // SAFETY: assumes `PySortTrack` is a `#[repr(transparent)]` wrapper
                    // around `SortTrack`. The same `Vec` layout caveat applies.
                    unsafe { std::mem::transmute(tracks) }
                })
            })
        }

        /// Returns the wasted tracks, converted to `WastedSortTrack`.
        #[pyo3(signature = ())]
        pub fn wasted(&mut self) -> Vec<PyWastedSortTrack> {
            Python::with_gil(|py| {
                py.allow_threads(|| {
                    self.0
                        .wasted()
                        .into_iter()
                        .map(WastedSortTrack::from)
                        .map(PyWastedSortTrack)
                        .collect()
                })
            })
        }

        /// Drops every track held in the wasted store.
        #[pyo3(signature = ())]
        pub fn clear_wasted(&mut self) {
            Python::with_gil(|py| {
                py.allow_threads(|| self.0.clear_wasted());
            })
        }

        /// Runs `idle_tracks_with_scene` for the default scene `0`.
        #[pyo3(signature = ())]
        pub fn idle_tracks(&mut self) -> Vec<PySortTrack> {
            self.idle_tracks_with_scene(0)
        }

        /// Returns the idle tracks of `scene_id`. Panics if `scene_id` is negative.
        #[pyo3(signature = (scene_id))]
        pub fn idle_tracks_with_scene(&mut self, scene_id: i64) -> Vec<PySortTrack> {
            // Same precondition check as the other `*_with_scene` methods.
            assert!(scene_id >= 0);
            Python::with_gil(|py| {
                py.allow_threads(|| {
                    let tracks = self.0.idle_tracks_with_scene(scene_id.try_into().unwrap());
                    // SAFETY: assumes `PySortTrack` is a `#[repr(transparent)]` wrapper
                    // around `SortTrack`.
                    // NOTE(review): the `Vec` layout itself is not guaranteed; confirm.
                    unsafe { std::mem::transmute(tracks) }
                })
            })
        }
    }
}