cecile_supercool_tracker/trackers/sort/metric.rs

1use crate::track::{
2    MetricOutput, MetricQuery, Observation, ObservationAttributes, ObservationMetric,
3    ObservationMetricOk,
4};
5use crate::trackers::kalman_prediction::TrackAttributesKalmanPrediction;
6use crate::trackers::sort::PositionalMetricType;
7use crate::trackers::sort::{SortAttributes, DEFAULT_SORT_IOU_THRESHOLD};
8use crate::utils::bbox::Universal2DBox;
9use crate::utils::kalman::kalman_2d_box::Universal2DBoxKalmanFilter;
10
/// Default floor for a candidate's detection confidence, used by [`SortMetric::default`]
/// as its `min_confidence`.
pub const DEFAULT_MINIMAL_SORT_CONFIDENCE: f32 = 0.05_f32;
12
13#[derive(Clone)]
14pub struct SortMetric {
15    method: PositionalMetricType,
16    min_confidence: f32,
17}
18
19impl Default for SortMetric {
20    fn default() -> Self {
21        Self::new(
22            PositionalMetricType::IoU(DEFAULT_SORT_IOU_THRESHOLD),
23            DEFAULT_MINIMAL_SORT_CONFIDENCE,
24        )
25    }
26}
27
28impl SortMetric {
29    pub fn new(method: PositionalMetricType, min_confidence: f32) -> Self {
30        Self {
31            method,
32            min_confidence,
33        }
34    }
35}
36
37impl ObservationMetric<SortAttributes, Universal2DBox> for SortMetric {
38    fn metric(&self, mq: &MetricQuery<SortAttributes, Universal2DBox>) -> MetricOutput<f32> {
39        let (candidate_bbox, track_bbox) = (
40            mq.candidate_observation.attr().as_ref().unwrap(),
41            mq.track_observation.attr().as_ref().unwrap(),
42        );
43        let conf = if candidate_bbox.confidence < self.min_confidence {
44            self.min_confidence
45        } else {
46            candidate_bbox.confidence
47        };
48
49        if Universal2DBox::too_far(candidate_bbox, track_bbox) {
50            None
51        } else {
52            Some(match self.method {
53                PositionalMetricType::Mahalanobis => {
54                    let state = mq.track_attrs.get_state().unwrap();
55                    let f = Universal2DBoxKalmanFilter::new(
56                        mq.track_attrs.get_position_weight(),
57                        mq.track_attrs.get_velocity_weight(),
58                    );
59                    let dist = f.distance(state, candidate_bbox);
60                    (
61                        Some(Universal2DBoxKalmanFilter::calculate_cost(dist, true) / conf),
62                        None,
63                    )
64                }
65                PositionalMetricType::IoU(threshold) => {
66                    let box_m_opt = Universal2DBox::calculate_metric_object(
67                        &Some(candidate_bbox),
68                        &Some(track_bbox),
69                    );
70                    (
71                        box_m_opt.map(|e| e * conf).filter(|e| *e >= threshold),
72                        None,
73                    )
74                }
75            })
76        }
77    }
78
79    fn optimize(
80        &mut self,
81        _feature_class: u64,
82        _merge_history: &[u64],
83        attrs: &mut SortAttributes,
84        features: &mut Vec<Observation<Universal2DBox>>,
85        _prev_length: usize,
86        _is_merge: bool,
87    ) -> anyhow::Result<()> {
88        let mut observation = features.pop().unwrap();
89        let observation_bbox = observation.attr().as_ref().unwrap();
90        features.clear();
91
92        let mut predicted_bbox = attrs.make_prediction(observation_bbox);
93        attrs.update_history(observation_bbox, &predicted_bbox);
94
95        *observation.attr_mut() = Some(match self.method {
96            PositionalMetricType::Mahalanobis => predicted_bbox,
97            PositionalMetricType::IoU(_) => {
98                predicted_bbox.gen_vertices();
99                predicted_bbox
100            }
101        });
102
103        features.push(observation);
104        Ok(())
105    }
106
107    fn postprocess_distances(
108        &self,
109        unfiltered: Vec<ObservationMetricOk<Universal2DBox>>,
110    ) -> Vec<ObservationMetricOk<Universal2DBox>> {
111        unfiltered
112            .into_iter()
113            .filter(|res| res.attribute_metric.is_some())
114            .collect()
115    }
116}
117
#[cfg(test)]
mod tests {
    use crate::prelude::{BoundingBox, PositionalMetricType};
    use crate::track::{MetricQuery, Observation, ObservationMetric};
    use crate::trackers::sort::metric::{SortMetric, DEFAULT_MINIMAL_SORT_CONFIDENCE};
    use crate::trackers::sort::{
        SortAttributes, SortAttributesOptions, DEFAULT_SORT_IOU_THRESHOLD,
    };
    use crate::trackers::spatio_temporal_constraints::SpatioTemporalConstraints;
    use crate::EPS;
    use std::sync::Arc;

    /// Attribute options shared by every test in this module.
    fn attr_options() -> Arc<SortAttributesOptions> {
        Arc::new(SortAttributesOptions::new(
            None,
            0,
            5,
            SpatioTemporalConstraints::default(),
            1.0 / 20.0,
            1.0 / 160.0,
        ))
    }

    /// IoU metric configured with the crate's default threshold and minimal confidence.
    fn default_iou_metric() -> SortMetric {
        SortMetric::new(
            PositionalMetricType::IoU(DEFAULT_SORT_IOU_THRESHOLD),
            DEFAULT_MINIMAL_SORT_CONFIDENCE,
        )
    }

    #[test]
    fn confidence_preserved_during_optimization() {
        let mut attrs = SortAttributes::new(attr_options());
        let mut metric = default_iou_metric();

        let mut obs = vec![Observation::new(
            Some(BoundingBox::new_with_confidence(0.0, 0.0, 8.0, 10.0, 0.8).as_xyaah()),
            None,
        )];

        metric
            .optimize(0, &[], &mut attrs, &mut obs, 0, true)
            .unwrap();

        // Compare with a tolerance, consistent with the distance test below.
        assert!(
            (obs[0].0.as_ref().unwrap().confidence - 0.8).abs() < EPS,
            "Confidence must be preserved during optimization"
        );
    }

    #[test]
    fn confidence_used_in_distance_calculation() {
        let attr_opts = attr_options();

        let candidate_attrs = SortAttributes::new(attr_opts.clone());
        let track_attrs = SortAttributes::new(attr_opts.clone());

        let metric = default_iou_metric();

        let candidate_obs = Observation::new(
            Some(BoundingBox::new_with_confidence(0.0, 0.0, 8.0, 10.0, 0.8).as_xyaah()),
            None,
        );

        let track_obs = Observation::new(
            Some(BoundingBox::new_with_confidence(0.0, 0.0, 8.0, 10.0, 1.0).as_xyaah()),
            None,
        );

        let mq = MetricQuery {
            feature_class: 0,
            candidate_attrs: &candidate_attrs,
            candidate_observation: &candidate_obs,
            track_attrs: &track_attrs,
            track_observation: &track_obs,
        };

        let res = metric.metric(&mq);
        assert!(
            (res.unwrap().0.unwrap() - 0.8).abs() < EPS,
            "Confidence value in candidate box must be used."
        );

        // Swap the roles: now the 1.0-confidence box is the candidate.
        let mq = MetricQuery {
            feature_class: 0,
            candidate_attrs: &track_attrs,
            candidate_observation: &track_obs,
            track_attrs: &candidate_attrs,
            track_observation: &candidate_obs,
        };

        let res = metric.metric(&mq);
        assert!(
            (res.unwrap().0.unwrap() - 1.0).abs() < EPS,
            "Confidence in track box must NOT be used."
        );
    }
}