1use crate::prelude::{NoopNotifier, ObservationBuilder, SortTrack, TrackStoreBuilder};
2use crate::store::TrackStore;
3use crate::track::utils::FromVec;
4use crate::track::{Feature, Track};
5use crate::trackers::epoch_db::EpochDb;
6use crate::trackers::sort::VotingType::Positional;
7use crate::trackers::sort::{
8 AutoWaste, PositionalMetricType, SortAttributesOptions, DEFAULT_AUTO_WASTE_PERIODICITY,
9 MAHALANOBIS_NEW_TRACK_THRESHOLD,
10};
11use crate::trackers::tracker_api::TrackerAPI;
12use crate::trackers::visual_sort::metric::{VisualMetric, VisualMetricOptions};
13use crate::trackers::visual_sort::observation_attributes::VisualObservationAttributes;
14use crate::trackers::visual_sort::options::VisualSortOptions;
15use crate::trackers::visual_sort::track_attributes::{
16 VisualAttributes, VisualAttributesUpdate, VisualSortLookup,
17};
18use crate::trackers::visual_sort::voting::VisualVoting;
19use crate::trackers::visual_sort::VisualSortObservation;
20use crate::utils::clipping::bbox_own_areas::{
21 exclusively_owned_areas, exclusively_owned_areas_normalized_shares,
22};
23use crate::voting::Voting;
24use rand::Rng;
25use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
26
/// Visual SORT tracker: combines a positional metric with visual features to match
/// observations to tracks.
///
/// Tracks live in a sharded main store; tracks moved out of it are kept in a separate
/// wasted-track store. Both stores are built with the same metric and default attributes.
pub struct VisualSort {
    /// Main store holding the tracks that observations are matched against.
    store: RwLock<TrackStore<VisualAttributes, VisualMetric, VisualObservationAttributes>>,
    /// Store exposed through the `get_wasted_store*` methods of the `TrackerAPI` impl.
    wasted_store: RwLock<TrackStore<VisualAttributes, VisualMetric, VisualObservationAttributes>>,
    /// Metric options copied from the built visual metric; read while predicting (own-area
    /// thresholds, positional metric kind, minimal vote count).
    metric_opts: Arc<VisualMetricOptions>,
    /// Track-attribute options shared with the default attributes of both stores.
    track_opts: Arc<SortAttributesOptions>,
    /// Countdown and periodicity that decide when `auto_waste` runs during prediction.
    auto_waste: AutoWaste,
    /// Last issued track id; incremented before every new track is inserted.
    track_id: u64,
}
37
38impl VisualSort {
39 pub fn new(shards: usize, opts: &VisualSortOptions) -> Self {
46 let (track_opts, metric) = opts.clone().build();
47 let track_opts = Arc::new(track_opts);
48 let metric_opts = metric.opts.clone();
49 let store = RwLock::new(
50 TrackStoreBuilder::new(shards)
51 .default_attributes(VisualAttributes::new(track_opts.clone()))
52 .metric(metric.clone())
53 .notifier(NoopNotifier)
54 .build(),
55 );
56
57 let wasted_store = RwLock::new(
58 TrackStoreBuilder::new(shards)
59 .default_attributes(VisualAttributes::new(track_opts.clone()))
60 .metric(metric)
61 .notifier(NoopNotifier)
62 .build(),
63 );
64
65 Self {
66 store,
67 wasted_store,
68 track_opts,
69 track_id: 0,
70 metric_opts,
71 auto_waste: AutoWaste {
72 periodicity: DEFAULT_AUTO_WASTE_PERIODICITY,
73 counter: DEFAULT_AUTO_WASTE_PERIODICITY,
74 },
75 }
76 }
77
78 pub fn predict(&mut self, observations: &[VisualSortObservation]) -> Vec<SortTrack> {
85 self.predict_with_scene(0, observations)
86 }
87
88 fn gen_track_id(&mut self) -> u64 {
89 self.track_id += 1;
90 self.track_id
91 }
92
93 pub fn predict_with_scene(
100 &mut self,
101 scene_id: u64,
102 observations: &[VisualSortObservation],
103 ) -> Vec<SortTrack> {
104 if self.auto_waste.counter == 0 {
105 self.auto_waste();
106 self.auto_waste.counter = self.auto_waste.periodicity;
107 } else {
108 self.auto_waste.counter -= 1;
109 }
110
111 let mut percentages = Vec::default();
112 let use_own_area_percentage = self.metric_opts.visual_minimal_own_area_percentage_collect
113 + self.metric_opts.visual_minimal_own_area_percentage_use
114 > 0.0;
115
116 if use_own_area_percentage {
117 percentages.reserve(observations.len());
118 let boxes = observations
119 .iter()
120 .map(|e| &e.bounding_box)
121 .collect::<Vec<_>>();
122
123 percentages = exclusively_owned_areas_normalized_shares(
124 boxes.as_ref(),
125 exclusively_owned_areas(boxes.as_ref()).as_ref(),
126 );
127 }
128
129 let mut rng = rand::thread_rng();
130 let epoch = self.track_opts.next_epoch(scene_id).unwrap();
131
132 let mut tracks = observations
133 .iter()
134 .enumerate()
135 .map(|(i, o)| {
136 self.store
137 .read()
138 .unwrap()
139 .new_track(rng.gen())
140 .observation({
141 let mut obs = ObservationBuilder::new(0).observation_attributes(
142 if use_own_area_percentage {
143 VisualObservationAttributes::with_own_area_percentage(
144 o.feature_quality.unwrap_or(1.0),
145 o.bounding_box.clone(),
146 percentages[i],
147 )
148 } else {
149 VisualObservationAttributes::new(
150 o.feature_quality.unwrap_or(1.0),
151 o.bounding_box.clone(),
152 )
153 },
154 );
155
156 if let Some(feature) = &o.feature {
157 obs = obs.observation(Feature::from_vec(feature.to_vec()));
158 }
159
160 obs.track_attributes_update(VisualAttributesUpdate::new_init_with_scene(
161 epoch,
162 scene_id,
163 o.custom_object_id,
164 ))
165 .build()
166 })
167 .build()
168 .unwrap()
169 })
170 .collect::<Vec<_>>();
171
172 let (dists, errs) =
173 self.store
174 .write()
175 .unwrap()
176 .foreign_track_distances(tracks.clone(), 0, false);
177
178 assert!(errs.all().is_empty());
179 let voting = VisualVoting::new(
180 match self.metric_opts.positional_kind {
181 PositionalMetricType::Mahalanobis => MAHALANOBIS_NEW_TRACK_THRESHOLD,
182 PositionalMetricType::IoU(t) => t,
183 },
184 f32::MAX,
185 self.metric_opts.visual_min_votes,
186 );
187 let winners = voting.winners(dists);
188 let mut res = Vec::default();
189 for t in &mut tracks {
190 let source = t.get_track_id();
191 let track_id: u64 = if let Some(dest) = winners.get(&source) {
192 let (dest, vt) = dest[0];
193 if dest == source {
194 let mut t = t.clone();
195 let track_id = self.gen_track_id();
196 t.set_track_id(track_id);
197 self.store.write().unwrap().add_track(t).unwrap();
198 track_id
199 } else {
200 t.add_observation(
201 0,
202 None,
203 None,
204 Some(VisualAttributesUpdate::new_voting_type(vt)),
205 )
206 .unwrap();
207 self.store
208 .write()
209 .unwrap()
210 .merge_external(dest, t, Some(&[0]), false)
211 .unwrap();
212 dest
213 }
214 } else {
215 let mut t = t.clone();
216 let track_id = self.gen_track_id();
217 t.set_track_id(track_id);
218 self.store.write().unwrap().add_track(t).unwrap();
219 track_id
220 };
221
222 let lock = self.store.read().unwrap();
223 let store = lock.get_store(track_id as usize);
224 let track = store.get(&track_id).unwrap();
225
226 res.push(SortTrack::from(track))
227 }
228
229 res
230 }
231
232 pub fn idle_tracks(&mut self) -> Vec<SortTrack> {
233 self.idle_tracks_with_scene(0)
234 }
235
236 pub fn idle_tracks_with_scene(&mut self, scene_id: u64) -> Vec<SortTrack> {
237 let store = self.store.read().unwrap();
238 store
239 .lookup(VisualSortLookup::IdleLookup(scene_id))
240 .iter()
241 .map(|(track_id, _status)| {
242 let shard = store.get_store(*track_id as usize);
243 let track = shard.get(track_id).unwrap();
244 SortTrack::from(track)
245 })
246 .collect()
247 }
248}
249
250impl
251 TrackerAPI<
252 VisualAttributes,
253 VisualMetric,
254 VisualObservationAttributes,
255 SortAttributesOptions,
256 NoopNotifier,
257 > for VisualSort
258{
259 fn get_auto_waste_obj_mut(&mut self) -> &mut AutoWaste {
260 &mut self.auto_waste
261 }
262
263 fn get_opts(&self) -> &SortAttributesOptions {
264 &self.track_opts
265 }
266
267 fn get_main_store_mut(
268 &mut self,
269 ) -> RwLockWriteGuard<
270 TrackStore<VisualAttributes, VisualMetric, VisualObservationAttributes, NoopNotifier>,
271 > {
272 self.store.write().unwrap()
273 }
274
275 fn get_wasted_store_mut(
276 &mut self,
277 ) -> RwLockWriteGuard<
278 TrackStore<VisualAttributes, VisualMetric, VisualObservationAttributes, NoopNotifier>,
279 > {
280 self.wasted_store.write().unwrap()
281 }
282
283 fn get_main_store(
284 &self,
285 ) -> RwLockReadGuard<
286 TrackStore<VisualAttributes, VisualMetric, VisualObservationAttributes, NoopNotifier>,
287 > {
288 self.store.read().unwrap()
289 }
290
291 fn get_wasted_store(
292 &self,
293 ) -> RwLockReadGuard<
294 TrackStore<VisualAttributes, VisualMetric, VisualObservationAttributes, NoopNotifier>,
295 > {
296 self.wasted_store.read().unwrap()
297 }
298}
299
300impl From<&Track<VisualAttributes, VisualMetric, VisualObservationAttributes>> for SortTrack {
301 fn from(track: &Track<VisualAttributes, VisualMetric, VisualObservationAttributes>) -> Self {
302 let attrs = track.get_attributes();
303 SortTrack {
304 id: track.get_track_id(),
305 custom_object_id: attrs.custom_object_id,
306 voting_type: attrs.voting_type.unwrap_or(Positional),
307 epoch: attrs.last_updated_epoch,
308 scene_id: attrs.scene_id,
309 observed_bbox: attrs.observed_boxes.back().unwrap().clone(),
310 predicted_bbox: attrs.predicted_boxes.back().unwrap().clone(),
311 length: attrs.track_length,
312 }
313 }
314}
315
#[cfg(test)]
mod tests {
    use crate::track::Observation;
    use crate::trackers::sort::{PositionalMetricType, VotingType};
    use crate::trackers::tracker_api::TrackerAPI;
    use crate::trackers::visual_sort::metric::VisualSortMetricType;
    use crate::trackers::visual_sort::observation_attributes::VisualObservationAttributes;
    use crate::trackers::visual_sort::options::VisualSortOptions;
    use crate::trackers::visual_sort::simple_api::VisualSort;
    use crate::trackers::visual_sort::{VisualSortObservation, WastedVisualSortTrack};
    use crate::utils::bbox::BoundingBox;

    /// End-to-end scenario for `VisualSort`.
    ///
    /// Feeds single observations one at a time, mostly in scene 10 with one in scene 1.
    /// After each step it checks track identity, per-scene epochs, voting type and the
    /// bookkeeping counters stored in the track attributes. The steps depend on each
    /// other's state, so their order matters.
    #[test]
    fn visual_sort() {
        let opts = VisualSortOptions::default()
            .max_idle_epochs(3)
            .kept_history_length(3)
            .visual_metric(VisualSortMetricType::Euclidean(1.0))
            .positional_metric(PositionalMetricType::Mahalanobis)
            .visual_minimal_track_length(2)
            .visual_minimal_area(5.0)
            .visual_minimal_quality_use(0.45)
            .visual_minimal_quality_collect(0.7)
            .visual_max_observations(3)
            .visual_min_votes(2);

        let mut tracker = VisualSort::new(1, &opts);

        // Step 1 (scene 10): the first observation starts a new track at epoch 1 with one
        // collected feature.
        let tracks = tracker.predict_with_scene(
            10,
            &[VisualSortObservation::new(
                Some(&vec![1.0, 1.0]),
                Some(0.9),
                BoundingBox::new(1.0, 1.0, 3.0, 5.0).as_xyaah(),
                Some(13),
            )],
        );
        let t = &tracks[0];
        assert_eq!(t.custom_object_id, Some(13));
        assert_eq!(t.scene_id, 10);
        assert!(matches!(t.voting_type, VotingType::Positional));
        assert!(matches!(t.epoch, 1));
        let attrs = {
            let lock = tracker.store.read().unwrap();
            let store = lock.get_store(t.id as usize);
            let track = store.get(&t.id).unwrap();
            track.get_attributes().clone()
        };
        assert_eq!(attrs.visual_features_collected_count, 1);
        assert_eq!(attrs.track_length, 1);
        assert_eq!(attrs.observed_boxes.len(), 1);
        assert_eq!(attrs.predicted_boxes.len(), 1);
        assert_eq!(attrs.observed_features.len(), 1);
        let first_track_id = t.id;

        // Step 2 (scene 1): the same box in another scene. Its epoch is also 1 (epochs are
        // per scene) and it yields a fresh single-observation track.
        {
            let tracks = tracker.predict_with_scene(
                1,
                &[VisualSortObservation::new(
                    Some(&vec![1.0, 1.0]),
                    Some(0.9),
                    BoundingBox::new(1.0, 1.0, 3.0, 5.0).as_xyaah(),
                    Some(133),
                )],
            );
            let t = &tracks[0];
            assert_eq!(t.custom_object_id, Some(133));
            assert_eq!(t.scene_id, 1);
            assert!(matches!(t.voting_type, VotingType::Positional));
            assert!(matches!(t.epoch, 1));
            let attrs = {
                let lock = tracker.store.read().unwrap();
                let store = lock.get_store(t.id as usize);
                let track = store.get(&t.id).unwrap();
                track.get_attributes().clone()
            };
            assert_eq!(attrs.visual_features_collected_count, 1);
            assert_eq!(attrs.track_length, 1);
            assert_eq!(attrs.observed_boxes.len(), 1);
            assert_eq!(attrs.predicted_boxes.len(), 1);
            assert_eq!(attrs.observed_features.len(), 1);
        }

        // Step 3 (scene 10): a slightly moved box with a similar feature joins the first
        // track; the scene-10 epoch advances to 2.
        let tracks = tracker.predict_with_scene(
            10,
            &[VisualSortObservation::new(
                Some(&vec![0.95, 0.95]),
                Some(0.93),
                BoundingBox::new(1.1, 1.1, 3.05, 5.01).as_xyaah(),
                Some(15),
            )],
        );
        let t = &tracks[0];
        assert_eq!(t.id, first_track_id);
        assert_eq!(t.custom_object_id, Some(15));
        assert_eq!(t.scene_id, 10);
        assert!(matches!(t.voting_type, VotingType::Positional));
        assert!(matches!(t.epoch, 2));
        let attrs = {
            let lock = tracker.store.read().unwrap();
            let store = lock.get_store(t.id as usize);
            let track = store.get(&t.id).unwrap();
            track.get_attributes().clone()
        };
        assert_eq!(attrs.visual_features_collected_count, 2);
        assert_eq!(attrs.track_length, 2);
        assert_eq!(attrs.observed_boxes.len(), 2);
        assert_eq!(attrs.predicted_boxes.len(), 2);
        assert_eq!(attrs.observed_features.len(), 2);

        // Step 4: no feature at all. The track is still matched positionally, the collected
        // count stays 2, and the newest observed-feature slot is `None`.
        let tracks = tracker.predict_with_scene(
            10,
            &[VisualSortObservation::new(
                None,
                Some(0.93),
                BoundingBox::new(1.11, 1.15, 3.15, 5.05).as_xyaah(),
                Some(25),
            )],
        );
        let t = &tracks[0];
        assert_eq!(t.id, first_track_id);
        assert_eq!(t.custom_object_id, Some(25));
        assert_eq!(t.scene_id, 10);
        assert!(matches!(t.voting_type, VotingType::Positional));
        assert!(matches!(t.epoch, 3));
        let attrs = {
            let lock = tracker.store.read().unwrap();
            let store = lock.get_store(t.id as usize);
            let track = store.get(&t.id).unwrap();
            track.get_attributes().clone()
        };
        assert_eq!(attrs.visual_features_collected_count, 2);
        assert_eq!(attrs.track_length, 3);
        assert_eq!(attrs.observed_boxes.len(), 3);
        assert_eq!(attrs.predicted_boxes.len(), 3);
        assert_eq!(attrs.observed_features.len(), 3);
        assert!(attrs.observed_features.back().unwrap().is_none());

        // Step 5: track_length keeps growing (4), while the box/feature histories stay at 3.
        // This matches `kept_history_length(3)`.
        let tracks = tracker.predict_with_scene(
            10,
            &[VisualSortObservation::new(
                None,
                Some(0.93),
                BoundingBox::new(1.15, 1.25, 3.10, 5.05).as_xyaah(),
                Some(2),
            )],
        );
        let t = &tracks[0];
        assert_eq!(t.id, first_track_id);
        assert!(matches!(t.voting_type, VotingType::Positional));
        assert!(matches!(t.epoch, 4));
        let attrs = {
            let lock = tracker.store.read().unwrap();
            let store = lock.get_store(t.id as usize);
            let track = store.get(&t.id).unwrap();
            track.get_attributes().clone()
        };
        assert_eq!(attrs.visual_features_collected_count, 2);
        assert_eq!(attrs.track_length, 4);
        assert_eq!(attrs.observed_boxes.len(), 3);
        assert_eq!(attrs.predicted_boxes.len(), 3);
        assert_eq!(attrs.observed_features.len(), 3);
        assert!(attrs.observed_features.back().unwrap().is_none());

        // Step 6: feature quality 0.44 (below the 0.45 "use" and 0.7 "collect" thresholds).
        // The feature is recorded as observed but not counted as collected, and the voting
        // stays positional. Presumably this is because of those thresholds — TODO confirm.
        let tracks = tracker.predict_with_scene(
            10,
            &[VisualSortObservation::new(
                Some(&vec![0.97, 0.97]),
                Some(0.44),
                BoundingBox::new(1.15, 1.25, 3.10, 5.05).as_xyaah(),
                Some(2),
            )],
        );
        let t = &tracks[0];
        assert_eq!(t.id, first_track_id);
        assert!(matches!(t.voting_type, VotingType::Positional));
        let attrs = {
            let lock = tracker.store.read().unwrap();
            let store = lock.get_store(t.id as usize);
            let track = store.get(&t.id).unwrap();
            track.get_attributes().clone()
        };
        assert_eq!(attrs.visual_features_collected_count, 2);
        assert_eq!(attrs.track_length, 5);
        assert!(attrs.observed_features.back().unwrap().is_some());

        // Step 7: quality 0.6 (between the "use" and "collect" thresholds). The match is now
        // decided visually, but the collected count is still 2.
        let tracks = tracker.predict_with_scene(
            10,
            &[VisualSortObservation::new(
                Some(&vec![0.97, 0.97]),
                Some(0.6),
                BoundingBox::new(1.15, 1.25, 3.10, 5.05).as_xyaah(),
                Some(2),
            )],
        );
        let t = &tracks[0];
        assert_eq!(t.id, first_track_id);
        assert!(matches!(t.voting_type, VotingType::Visual));
        let attrs = {
            let lock = tracker.store.read().unwrap();
            let store = lock.get_store(t.id as usize);
            let track = store.get(&t.id).unwrap();
            track.get_attributes().clone()
        };
        assert_eq!(attrs.visual_features_collected_count, 2);
        assert_eq!(attrs.track_length, 6);
        assert!(attrs.observed_features.back().unwrap().is_some());

        // Step 8: quality 0.8 (at least the "collect" threshold) raises the collected count
        // to 3. It also checks the three stored feature observations: only the oldest one
        // still carries a bounding box.
        let tracks = tracker.predict_with_scene(
            10,
            &[VisualSortObservation::new(
                Some(&vec![0.97, 0.97]),
                Some(0.8),
                BoundingBox::new(1.15, 1.25, 3.10, 5.05).as_xyaah(),
                Some(2),
            )],
        );
        let t = &tracks[0];
        assert_eq!(t.id, first_track_id);
        assert!(matches!(t.voting_type, VotingType::Visual));
        let attrs = {
            let lock = tracker.store.read().unwrap();
            let store = lock.get_store(t.id as usize);
            let track = store.get(&t.id).unwrap();
            let observations = track.get_observations(0).unwrap();

            // True when the observation still has a bounding box in its attributes.
            fn bbox_is(b: &Observation<VisualObservationAttributes>) -> bool {
                b.attr().as_ref().unwrap().bbox_opt().is_some()
            }

            assert!(bbox_is(&observations[0]) && observations[0].feature().is_some());
            assert!(!bbox_is(&observations[1]) && observations[1].feature().is_some());
            assert!(!bbox_is(&observations[2]) && observations[2].feature().is_some());

            track.get_attributes().clone()
        };
        assert_eq!(attrs.visual_features_collected_count, 3);
        assert_eq!(attrs.track_length, 7);
        assert!(attrs.observed_features.back().unwrap().is_some());

        // Step 9: a distant box with a dissimilar feature starts a second scene-10 track
        // (epoch 8).
        let tracks = tracker.predict_with_scene(
            10,
            &[VisualSortObservation::new(
                Some(&vec![0.1, 0.1]),
                Some(0.9),
                BoundingBox::new(10.0, 10.0, 3.0, 5.0).as_xyaah(),
                Some(33),
            )],
        );
        let t = &tracks[0];
        assert_eq!(t.custom_object_id, Some(33));
        assert_eq!(t.scene_id, 10);
        assert!(matches!(t.voting_type, VotingType::Positional));
        assert!(matches!(t.epoch, 8));
        assert_ne!(t.id, first_track_id);
        let attrs = {
            let lock = tracker.store.read().unwrap();
            let store = lock.get_store(t.id as usize);
            let track = store.get(&t.id).unwrap();
            track.get_attributes().clone()
        };
        assert_eq!(attrs.visual_features_collected_count, 1);
        assert_eq!(attrs.track_length, 1);
        assert_eq!(attrs.observed_boxes.len(), 1);
        assert_eq!(attrs.predicted_boxes.len(), 1);
        assert_eq!(attrs.observed_features.len(), 1);
        let other_track_id = t.id;

        // Step 10: a nearby observation joins the second track positionally.
        let tracks = tracker.predict_with_scene(
            10,
            &[VisualSortObservation::new(
                Some(&vec![0.12, 0.15]),
                Some(0.88),
                BoundingBox::new(10.1, 10.1, 3.0, 5.0).as_xyaah(),
                Some(35),
            )],
        );
        let t = &tracks[0];
        assert_eq!(t.custom_object_id, Some(35));
        assert_eq!(t.scene_id, 10);
        assert!(matches!(t.voting_type, VotingType::Positional));
        assert!(matches!(t.epoch, 9));
        assert_eq!(t.id, other_track_id);
        let attrs = {
            let lock = tracker.store.read().unwrap();
            let store = lock.get_store(t.id as usize);
            let track = store.get(&t.id).unwrap();
            track.get_attributes().clone()
        };
        assert_eq!(attrs.visual_features_collected_count, 2);
        assert_eq!(attrs.track_length, 2);
        assert_eq!(attrs.observed_boxes.len(), 2);
        assert_eq!(attrs.predicted_boxes.len(), 2);
        assert_eq!(attrs.observed_features.len(), 2);

        // Step 11: with two collected features, the second track is now matched visually.
        // Presumably this depends on `visual_minimal_track_length(2)` /
        // `visual_min_votes(2)` — TODO confirm.
        let tracks = tracker.predict_with_scene(
            10,
            &[VisualSortObservation::new(
                Some(&vec![0.12, 0.14]),
                Some(0.87),
                BoundingBox::new(10.1, 10.1, 3.0, 5.0).as_xyaah(),
                Some(31),
            )],
        );
        let t = &tracks[0];
        assert_eq!(t.custom_object_id, Some(31));
        assert_eq!(t.scene_id, 10);
        assert!(matches!(t.voting_type, VotingType::Visual));
        assert!(matches!(t.epoch, 10));
        assert_eq!(t.id, other_track_id);
        let attrs = {
            let lock = tracker.store.read().unwrap();
            let store = lock.get_store(t.id as usize);
            let track = store.get(&t.id).unwrap();
            track.get_attributes().clone()
        };
        assert_eq!(attrs.visual_features_collected_count, 3);
        assert_eq!(attrs.track_length, 3);
        assert_eq!(attrs.observed_boxes.len(), 3);
        assert_eq!(attrs.predicted_boxes.len(), 3);
        assert_eq!(attrs.observed_features.len(), 3);

        // Final step: skip 5 scene-10 epochs (more than `max_idle_epochs(3)`) and drain the
        // wasted tracks.
        // NOTE(review): the result is only printed, not asserted; consider pinning the
        // expected wasted track ids once the `wasted()` semantics are confirmed.
        tracker.skip_epochs_for_scene(10, 5);
        let tracks = tracker
            .wasted()
            .into_iter()
            .map(WastedVisualSortTrack::from)
            .collect::<Vec<_>>();
        dbg!(&tracks);
    }
}
668
#[cfg(feature = "python")]
pub mod python {
    use pyo3::prelude::*;

    use crate::{
        prelude::VisualSortObservation,
        trackers::{
            sort::python::PySortTrack,
            tracker_api::TrackerAPI,
            visual_sort::{
                options::python::PyVisualSortOptions,
                python::{PyVisualSortObservationSet, PyWastedVisualSortTrack},
                WastedVisualSortTrack,
            },
        },
    };

    use super::VisualSort;

    /// Python-facing wrapper around [`VisualSort`], exported to Python as `VisualSort`.
    #[pyclass]
    #[pyo3(name = "VisualSort")]
    pub struct PyVisualSort(pub(crate) VisualSort);

    #[pymethods]
    impl PyVisualSort {
        /// Creates a tracker with `shards` store shards.
        ///
        /// # Panics
        /// Panics unless `shards > 0`.
        #[new]
        pub fn new(shards: i64, opts: &PyVisualSortOptions) -> Self {
            assert!(shards > 0);
            Self(VisualSort::new(shards.try_into().unwrap(), &opts.0))
        }

        /// Skips `n` epochs through `TrackerAPI::skip_epochs`.
        ///
        /// # Panics
        /// Panics unless `n > 0`.
        #[pyo3(signature = (n))]
        pub fn skip_epochs(&mut self, n: i64) {
            assert!(n > 0);
            self.0.skip_epochs(n.try_into().unwrap())
        }

        /// Skips `n` epochs of `scene_id` through `TrackerAPI::skip_epochs_for_scene`.
        ///
        /// # Panics
        /// Panics unless `n > 0` and `scene_id >= 0`.
        #[pyo3(signature = (scene_id, n))]
        pub fn skip_epochs_for_scene(&mut self, scene_id: i64, n: i64) {
            assert!(n > 0 && scene_id >= 0);
            self.0
                .skip_epochs_for_scene(scene_id.try_into().unwrap(), n.try_into().unwrap())
        }

        /// Returns `TrackerAPI::active_shard_stats` as Python ints, computed with the GIL
        /// released.
        #[pyo3(signature = ())]
        pub fn shard_stats(&self) -> Vec<i64> {
            Python::with_gil(|py| {
                py.allow_threads(|| {
                    self.0
                        .active_shard_stats()
                        .into_iter()
                        .map(|e| i64::try_from(e).unwrap())
                        .collect()
                })
            })
        }

        /// Returns the current epoch of scene `0`.
        // NOTE(review): returns `i64`, while `current_epoch_with_scene` returns `isize`; both
        // are Python ints, kept as-is for interface compatibility.
        #[pyo3(signature = ())]
        pub fn current_epoch(&self) -> i64 {
            self.0.current_epoch_with_scene(0).try_into().unwrap()
        }

        /// Returns the current epoch of `scene_id`.
        ///
        /// # Panics
        /// Panics unless `scene_id >= 0`.
        #[pyo3(signature = (scene_id))]
        pub fn current_epoch_with_scene(&self, scene_id: i64) -> isize {
            assert!(scene_id >= 0);
            self.0
                .current_epoch_with_scene(scene_id.try_into().unwrap())
                .try_into()
                .unwrap()
        }

        /// Runs one tracking step for scene `0`.
        ///
        /// Delegates to `predict_with_scene`, so observations are converted the same way and
        /// the GIL is released while tracking.
        #[pyo3(signature = (observation_set))]
        pub fn predict(
            &mut self,
            observation_set: &PyVisualSortObservationSet,
        ) -> Vec<PySortTrack> {
            self.predict_with_scene(0, observation_set)
        }

        /// Runs one tracking step for `scene_id` with the GIL released.
        ///
        /// # Panics
        /// Panics unless `scene_id >= 0`, or if the underlying tracker panics.
        #[pyo3(signature = (scene_id, observation_set))]
        pub fn predict_with_scene(
            &mut self,
            scene_id: i64,
            observation_set: &PyVisualSortObservationSet,
        ) -> Vec<PySortTrack> {
            assert!(scene_id >= 0);
            // Re-borrow the Python-side observations as core `VisualSortObservation`s.
            let observations = observation_set
                .0
                .inner
                .iter()
                .map(|e| {
                    VisualSortObservation::new(
                        e.feature.as_deref(),
                        e.feature_quality,
                        e.bounding_box.clone(),
                        e.custom_object_id,
                    )
                })
                .collect::<Vec<_>>();

            Python::with_gil(|py| {
                // SAFETY (NOTE(review)): relies on `PySortTrack` being a `#[repr(transparent)]`
                // wrapper around `SortTrack` (not visible here — TODO confirm). Even then, std
                // does not guarantee `Vec<SortTrack>` and `Vec<PySortTrack>` share a layout;
                // an element-wise `into_iter().map(..).collect()` would avoid the transmute.
                py.allow_threads(|| unsafe {
                    std::mem::transmute(
                        self.0
                            .predict_with_scene(scene_id.try_into().unwrap(), &observations),
                    )
                })
            })
        }

        /// Drains the wasted tracks (GIL released) and wraps them for Python.
        #[pyo3(signature = ())]
        pub fn wasted(&mut self) -> Vec<PyWastedVisualSortTrack> {
            Python::with_gil(|py| {
                py.allow_threads(|| {
                    self.0
                        .wasted()
                        .into_iter()
                        .map(WastedVisualSortTrack::from)
                        .map(PyWastedVisualSortTrack)
                        .collect()
                })
            })
        }

        /// Calls `TrackerAPI::clear_wasted` with the GIL released.
        #[pyo3(signature = ())]
        pub fn clear_wasted(&mut self) {
            Python::with_gil(|py| py.allow_threads(|| self.0.clear_wasted()));
        }

        /// Returns the idle tracks of scene `0`.
        ///
        /// Delegates to `idle_tracks_with_scene_py`, so the GIL is released here too.
        #[pyo3(signature = ())]
        pub fn idle_tracks(&mut self) -> Vec<PySortTrack> {
            self.idle_tracks_with_scene_py(0)
        }

        /// Returns the idle tracks of `scene_id`, computed with the GIL released.
        ///
        /// # Panics
        /// Panics unless `scene_id >= 0`.
        #[pyo3(signature = (scene_id))]
        pub fn idle_tracks_with_scene_py(&mut self, scene_id: i64) -> Vec<PySortTrack> {
            assert!(scene_id >= 0);
            Python::with_gil(|py| {
                // SAFETY (NOTE(review)): same `Vec<SortTrack>` -> `Vec<PySortTrack>` layout
                // assumption as in `predict_with_scene` — TODO confirm.
                py.allow_threads(|| unsafe {
                    std::mem::transmute(self.0.idle_tracks_with_scene(scene_id.try_into().unwrap()))
                })
            })
        }
    }
}