cecile_supercool_tracker/trackers/sort/
metric.rs1use crate::track::{
2 MetricOutput, MetricQuery, Observation, ObservationAttributes, ObservationMetric,
3 ObservationMetricOk,
4};
5use crate::trackers::kalman_prediction::TrackAttributesKalmanPrediction;
6use crate::trackers::sort::PositionalMetricType;
7use crate::trackers::sort::{SortAttributes, DEFAULT_SORT_IOU_THRESHOLD};
8use crate::utils::bbox::Universal2DBox;
9use crate::utils::kalman::kalman_2d_box::Universal2DBoxKalmanFilter;
10
/// Default lower bound on a candidate's detection confidence, used by `SortMetric::default`.
///
/// When the metric is computed, candidate confidences below the configured minimum are raised
/// to it before they weight the positional score.
pub const DEFAULT_MINIMAL_SORT_CONFIDENCE: f32 = 5e-2;
12
13#[derive(Clone)]
14pub struct SortMetric {
15 method: PositionalMetricType,
16 min_confidence: f32,
17}
18
19impl Default for SortMetric {
20 fn default() -> Self {
21 Self::new(
22 PositionalMetricType::IoU(DEFAULT_SORT_IOU_THRESHOLD),
23 DEFAULT_MINIMAL_SORT_CONFIDENCE,
24 )
25 }
26}
27
28impl SortMetric {
29 pub fn new(method: PositionalMetricType, min_confidence: f32) -> Self {
30 Self {
31 method,
32 min_confidence,
33 }
34 }
35}
36
37impl ObservationMetric<SortAttributes, Universal2DBox> for SortMetric {
38 fn metric(&self, mq: &MetricQuery<SortAttributes, Universal2DBox>) -> MetricOutput<f32> {
39 let (candidate_bbox, track_bbox) = (
40 mq.candidate_observation.attr().as_ref().unwrap(),
41 mq.track_observation.attr().as_ref().unwrap(),
42 );
43 let conf = if candidate_bbox.confidence < self.min_confidence {
44 self.min_confidence
45 } else {
46 candidate_bbox.confidence
47 };
48
49 if Universal2DBox::too_far(candidate_bbox, track_bbox) {
50 None
51 } else {
52 Some(match self.method {
53 PositionalMetricType::Mahalanobis => {
54 let state = mq.track_attrs.get_state().unwrap();
55 let f = Universal2DBoxKalmanFilter::new(
56 mq.track_attrs.get_position_weight(),
57 mq.track_attrs.get_velocity_weight(),
58 );
59 let dist = f.distance(state, candidate_bbox);
60 (
61 Some(Universal2DBoxKalmanFilter::calculate_cost(dist, true) / conf),
62 None,
63 )
64 }
65 PositionalMetricType::IoU(threshold) => {
66 let box_m_opt = Universal2DBox::calculate_metric_object(
67 &Some(candidate_bbox),
68 &Some(track_bbox),
69 );
70 (
71 box_m_opt.map(|e| e * conf).filter(|e| *e >= threshold),
72 None,
73 )
74 }
75 })
76 }
77 }
78
79 fn optimize(
80 &mut self,
81 _feature_class: u64,
82 _merge_history: &[u64],
83 attrs: &mut SortAttributes,
84 features: &mut Vec<Observation<Universal2DBox>>,
85 _prev_length: usize,
86 _is_merge: bool,
87 ) -> anyhow::Result<()> {
88 let mut observation = features.pop().unwrap();
89 let observation_bbox = observation.attr().as_ref().unwrap();
90 features.clear();
91
92 let mut predicted_bbox = attrs.make_prediction(observation_bbox);
93 attrs.update_history(observation_bbox, &predicted_bbox);
94
95 *observation.attr_mut() = Some(match self.method {
96 PositionalMetricType::Mahalanobis => predicted_bbox,
97 PositionalMetricType::IoU(_) => {
98 predicted_bbox.gen_vertices();
99 predicted_bbox
100 }
101 });
102
103 features.push(observation);
104 Ok(())
105 }
106
107 fn postprocess_distances(
108 &self,
109 unfiltered: Vec<ObservationMetricOk<Universal2DBox>>,
110 ) -> Vec<ObservationMetricOk<Universal2DBox>> {
111 unfiltered
112 .into_iter()
113 .filter(|res| res.attribute_metric.is_some())
114 .collect()
115 }
116}
117
#[cfg(test)]
mod tests {
    use crate::prelude::{BoundingBox, PositionalMetricType};
    use crate::track::{MetricQuery, Observation, ObservationMetric};
    use crate::trackers::sort::metric::{SortMetric, DEFAULT_MINIMAL_SORT_CONFIDENCE};
    use crate::trackers::sort::{
        SortAttributes, SortAttributesOptions, DEFAULT_SORT_IOU_THRESHOLD,
    };
    use crate::trackers::spatio_temporal_constraints::SpatioTemporalConstraints;
    use crate::EPS;
    use std::sync::Arc;

    /// Attribute options shared by the tests in this module.
    fn sort_options() -> Arc<SortAttributesOptions> {
        Arc::new(SortAttributesOptions::new(
            None,
            0,
            5,
            SpatioTemporalConstraints::default(),
            1.0 / 20.0,
            1.0 / 160.0,
        ))
    }

    /// IoU metric configured with the crate defaults.
    fn iou_metric() -> SortMetric {
        SortMetric::new(
            PositionalMetricType::IoU(DEFAULT_SORT_IOU_THRESHOLD),
            DEFAULT_MINIMAL_SORT_CONFIDENCE,
        )
    }

    #[test]
    fn confidence_preserved_during_optimization() {
        let mut attrs = SortAttributes::new(sort_options());
        let mut metric = iou_metric();
        let mut obs = vec![Observation::new(
            Some(BoundingBox::new_with_confidence(0.0, 0.0, 8.0, 10.0, 0.8).as_xyaah()),
            None,
        )];

        metric
            .optimize(0, &[], &mut attrs, &mut obs, 0, true)
            .unwrap();

        let optimized = obs[0].0.as_ref().unwrap();
        assert_eq!(
            optimized.confidence, 0.8,
            "Confidence must be preserved during optimization"
        );
    }

    #[test]
    fn confidence_used_in_distance_calculation() {
        let options = sort_options();
        let candidate_attrs = SortAttributes::new(options.clone());
        let track_attrs = SortAttributes::new(options);
        let metric = iou_metric();

        // Identical boxes that differ only in their confidence.
        let make_obs = |confidence| {
            Observation::new(
                Some(BoundingBox::new_with_confidence(0.0, 0.0, 8.0, 10.0, confidence).as_xyaah()),
                None,
            )
        };
        let candidate_obs = make_obs(0.8);
        let track_obs = make_obs(1.0);

        // The candidate carries 0.8, so the score must be weighted by 0.8.
        let forward = MetricQuery {
            feature_class: 0,
            candidate_attrs: &candidate_attrs,
            candidate_observation: &candidate_obs,
            track_attrs: &track_attrs,
            track_observation: &track_obs,
        };
        let forward_score = metric.metric(&forward).unwrap().0.unwrap();
        assert!(
            (forward_score - 0.8).abs() < EPS,
            "Confidence value in candidate box must be used."
        );

        // Roles swapped: the candidate now carries 1.0 and the track's 0.8 must be ignored.
        let swapped = MetricQuery {
            feature_class: 0,
            candidate_attrs: &track_attrs,
            candidate_observation: &track_obs,
            track_attrs: &candidate_attrs,
            track_observation: &candidate_obs,
        };
        let swapped_score = metric.metric(&swapped).unwrap().0.unwrap();
        assert!(
            (swapped_score - 1.0).abs() < EPS,
            "Confidence in track box must NOT be used."
        );
    }
}