// cecile_supercool_tracker/trackers/visual_sort.rs

1use std::borrow::Cow;
2
3use crate::{
4    track::{utils::FromVec, Track},
5    utils::bbox::Universal2DBox,
6};
7
8use self::{
9    metric::VisualMetric, observation_attributes::VisualObservationAttributes,
10    track_attributes::VisualAttributes,
11};
12
13/// Track metric implementation
14pub mod metric;
15
16/// Cascade voting engine for visual_sort tracker. Combines TopN voting first for features and
17/// Hungarian voting for the rest of unmatched (objects, tracks)
18pub mod voting;
19
20/// Track attributes for visual_sort tracker
21pub mod track_attributes;
22
23/// Observation attributes for visual_sort tracker
24pub mod observation_attributes;
25
26/// Implementation of Visual tracker with simple API
27pub mod simple_api;
28
29/// Batched API that accepts the batch with multiple scenes at once
30pub mod batch_api;
31/// Options object to configure the tracker
32pub mod options;
33
34#[derive(Debug, Clone)]
35pub struct VisualSortObservation<'a> {
36    feature: Option<Cow<'a, [f32]>>,
37    feature_quality: Option<f32>,
38    bounding_box: Universal2DBox,
39    custom_object_id: Option<i64>,
40}
41
42impl<'a> VisualSortObservation<'a> {
43    pub fn new(
44        feature: Option<&'a [f32]>,
45        feature_quality: Option<f32>,
46        bounding_box: Universal2DBox,
47        custom_object_id: Option<i64>,
48    ) -> Self {
49        Self {
50            feature: feature.map(Cow::Borrowed),
51            feature_quality,
52            bounding_box,
53            custom_object_id,
54        }
55    }
56}
57
58#[derive(Debug, Clone)]
59pub struct VisualSortObservationSet<'a> {
60    pub inner: Vec<VisualSortObservation<'a>>,
61}
62
63impl<'a> VisualSortObservationSet<'a> {
64    pub fn new() -> Self {
65        Self {
66            inner: Vec::default(),
67        }
68    }
69
70    pub fn add(&mut self, observation: VisualSortObservation<'a>) {
71        self.inner.push(observation);
72    }
73}
74
75impl<'a> Default for VisualSortObservationSet<'a> {
76    fn default() -> Self {
77        Self::new()
78    }
79}
80
81/// Online track structure that contains tracking information for the last tracker epoch
82///
83#[derive(Debug, Clone)]
84pub struct WastedVisualSortTrack {
85    /// id of the track
86    ///
87    pub id: u64,
88
89    /// when the track was lastly updated
90    ///
91    pub epoch: usize,
92
93    /// the bbox predicted by KF
94    ///
95    pub predicted_bbox: Universal2DBox,
96
97    /// the bbox passed by detector
98    ///
99    pub observed_bbox: Universal2DBox,
100
101    /// user-defined scene id that splits tracking space on isolated realms
102    ///
103    pub scene_id: u64,
104
105    /// current track length
106    ///
107    pub length: usize,
108
109    /// history of predicted boxes
110    ///
111    pub predicted_boxes: Vec<Universal2DBox>,
112
113    /// history of observed boxes
114    ///
115    pub observed_boxes: Vec<Universal2DBox>,
116
117    /// history of features
118    ///
119    pub observed_features: Vec<Option<Vec<f32>>>,
120}
121
122impl From<Track<VisualAttributes, VisualMetric, VisualObservationAttributes>>
123    for WastedVisualSortTrack
124{
125    fn from(track: Track<VisualAttributes, VisualMetric, VisualObservationAttributes>) -> Self {
126        let attrs = track.get_attributes();
127        WastedVisualSortTrack {
128            id: track.get_track_id(),
129            epoch: attrs.last_updated_epoch,
130            scene_id: attrs.scene_id,
131            length: attrs.track_length,
132            observed_bbox: attrs.observed_boxes.back().unwrap().clone(),
133            predicted_bbox: attrs.predicted_boxes.back().unwrap().clone(),
134            predicted_boxes: attrs.predicted_boxes.clone().into_iter().collect(),
135            observed_boxes: attrs.observed_boxes.clone().into_iter().collect(),
136            observed_features: attrs
137                .observed_features
138                .clone()
139                .iter()
140                .map(|f_opt| f_opt.as_ref().map(Vec::from_vec))
141                .collect(),
142        }
143    }
144}
145
/// Python bindings for the visual_sort observation types and wasted tracks
/// (compiled only with the `python` feature).
#[cfg(feature = "python")]
pub mod python {
    use super::{VisualSortObservation, VisualSortObservationSet, WastedVisualSortTrack};
    use crate::utils::bbox::python::PyUniversal2DBox;
    use pyo3::prelude::*;
    use std::borrow::Cow;

    /// Python wrapper around [`WastedVisualSortTrack`] exposing read-only getters.
    #[pyclass]
    #[pyo3(name = "WastedVisualSortTrack")]
    pub struct PyWastedVisualSortTrack(pub(crate) WastedVisualSortTrack);

    #[pymethods]
    impl PyWastedVisualSortTrack {
        // `__hash__ = None` makes instances unhashable on the Python side.
        #[classattr]
        const __hash__: Option<Py<PyAny>> = None;

        fn __repr__(&self) -> String {
            format!("{:?}", self.0)
        }

        fn __str__(&self) -> String {
            format!("{:#?}", self.0)
        }

        /// Identifier of the track.
        #[getter]
        fn id(&self) -> u64 {
            self.0.id
        }

        /// Epoch at which the track was last updated.
        #[getter]
        fn epoch(&self) -> usize {
            self.0.epoch
        }

        /// Most recent KF-predicted bounding box.
        #[getter]
        fn predicted_bbox(&self) -> PyUniversal2DBox {
            PyUniversal2DBox(self.0.predicted_bbox.clone())
        }

        /// Most recent detector-observed bounding box.
        #[getter]
        fn observed_bbox(&self) -> PyUniversal2DBox {
            PyUniversal2DBox(self.0.observed_bbox.clone())
        }

        /// Scene id the track belongs to.
        #[getter]
        fn scene_id(&self) -> u64 {
            self.0.scene_id
        }

        /// Current track length.
        #[getter]
        fn length(&self) -> usize {
            self.0.length
        }

        /// History of KF-predicted boxes.
        #[getter]
        fn predicted_boxes(&self) -> Vec<PyUniversal2DBox> {
            // Wrap element-wise instead of transmuting the Vec: the layout of `Vec<T>` is
            // unspecified, so transmuting between Vecs of different element types is unsound.
            self.0
                .predicted_boxes
                .iter()
                .cloned()
                .map(PyUniversal2DBox)
                .collect()
        }

        /// History of detector-observed boxes.
        #[getter]
        fn observed_boxes(&self) -> Vec<PyUniversal2DBox> {
            // Same reasoning as `predicted_boxes`: safe element-wise wrapping, no transmute.
            self.0
                .observed_boxes
                .iter()
                .cloned()
                .map(PyUniversal2DBox)
                .collect()
        }

        /// History of observed features (`None` where no feature was recorded).
        #[getter]
        fn observed_features(&self) -> Vec<Option<Vec<f32>>> {
            self.0.observed_features.clone()
        }
    }

    /// Python wrapper around an owned (`'static`) [`VisualSortObservation`].
    #[pyclass]
    #[derive(Debug, Clone)]
    #[pyo3(name = "VisualSortObservation")]
    pub struct PyVisualSortObservation(pub(crate) VisualSortObservation<'static>);

    #[pymethods]
    impl PyVisualSortObservation {
        // Python-side constructor; the feature vector is moved in and stored as `Cow::Owned`,
        // which is what allows the `'static` lifetime.
        #[new]
        #[pyo3(signature = (feature, feature_quality, bounding_box, custom_object_id))]
        pub fn new(
            feature: Option<Vec<f32>>,
            feature_quality: Option<f32>,
            bounding_box: PyUniversal2DBox,
            custom_object_id: Option<i64>,
        ) -> Self {
            Self(VisualSortObservation {
                feature: feature.map(Cow::Owned),
                feature_quality,
                bounding_box: bounding_box.0,
                custom_object_id,
            })
        }

        // `__hash__ = None` makes instances unhashable on the Python side.
        #[classattr]
        const __hash__: Option<Py<PyAny>> = None;

        fn __repr__(&self) -> String {
            format!("{self:?}")
        }

        fn __str__(&self) -> String {
            format!("{self:#?}")
        }
    }

    /// Python wrapper around an owned (`'static`) [`VisualSortObservationSet`].
    #[pyclass]
    #[derive(Debug)]
    #[pyo3(name = "VisualSortObservationSet")]
    pub struct PyVisualSortObservationSet(pub(crate) VisualSortObservationSet<'static>);

    #[pymethods]
    impl PyVisualSortObservationSet {
        // Python-side constructor producing an empty set.
        #[new]
        fn new() -> Self {
            Self(VisualSortObservationSet::new())
        }

        /// Appends an observation to the set.
        #[pyo3(text_signature = "($self, observation)")]
        fn add(&mut self, observation: PyVisualSortObservation) {
            self.0.add(observation.0);
        }

        // `__hash__ = None` makes instances unhashable on the Python side.
        #[classattr]
        const __hash__: Option<Py<PyAny>> = None;

        fn __repr__(&self) -> String {
            format!("{self:?}")
        }

        fn __str__(&self) -> String {
            format!("{self:#?}")
        }
    }
}