1use yscv_detect::{BoundingBox, Detection, iou};
2
3use crate::EvalError;
4use crate::util::{harmonic_mean, safe_ratio, validate_iou_threshold, validate_score_threshold};
5
/// A ground-truth bounding box together with the class it belongs to.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LabeledBox {
    /// Location of the annotated object.
    pub bbox: BoundingBox,
    /// Class identifier; a prediction can only match a box with the same id.
    pub class_id: usize,
}
11
/// Borrowed view of one frame: its ground-truth boxes and the detector's
/// predictions for it.
///
/// Predictions are only ever matched against ground truth from the same frame.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DetectionFrame<'a> {
    /// Annotated boxes for this frame.
    pub ground_truth: &'a [LabeledBox],
    /// Detector output for this frame, in any order.
    pub predictions: &'a [Detection],
}
17
/// Thresholds that control how predictions are matched to ground truth.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DetectionEvalConfig {
    /// Minimum IoU (inclusive) between a prediction and a ground-truth box for
    /// the pair to count as a match.
    pub iou_threshold: f32,
    /// Predictions whose score is below this value are discarded before
    /// matching.
    pub score_threshold: f32,
}
23
24impl Default for DetectionEvalConfig {
25 fn default() -> Self {
26 Self {
27 iou_threshold: 0.5,
28 score_threshold: 0.0,
29 }
30 }
31}
32
33impl DetectionEvalConfig {
34 pub fn validate(&self) -> Result<(), EvalError> {
35 validate_iou_threshold(self.iou_threshold)?;
36 validate_score_threshold(self.score_threshold)?;
37 Ok(())
38 }
39}
40
/// Aggregate detection metrics accumulated over every evaluated frame.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DetectionMetrics {
    /// Predictions that were matched to a ground-truth box.
    pub true_positives: u64,
    /// Predictions at or above the score threshold that matched nothing.
    pub false_positives: u64,
    /// Ground-truth boxes that no prediction was matched to.
    pub false_negatives: u64,
    /// Ratio of true positives to true positives plus false positives.
    pub precision: f32,
    /// Ratio of true positives to true positives plus false negatives.
    pub recall: f32,
    /// Harmonic mean of `precision` and `recall`.
    pub f1: f32,
    /// Area under the precision-recall curve built from all predictions
    /// ranked by score across frames.
    pub average_precision: f32,
}
51
/// Owned counterpart of [`DetectionFrame`]: one frame's ground truth and
/// predictions stored in `Vec`s.
#[derive(Debug, Clone, PartialEq)]
pub struct DetectionDatasetFrame {
    /// Annotated boxes for this frame.
    pub ground_truth: Vec<LabeledBox>,
    /// Detector output for this frame.
    pub predictions: Vec<Detection>,
}
57
58impl DetectionDatasetFrame {
59 pub fn as_view(&self) -> DetectionFrame<'_> {
60 DetectionFrame {
61 ground_truth: &self.ground_truth,
62 predictions: &self.predictions,
63 }
64 }
65}
66
67pub fn detection_frames_as_view(frames: &[DetectionDatasetFrame]) -> Vec<DetectionFrame<'_>> {
68 frames.iter().map(DetectionDatasetFrame::as_view).collect()
69}
70
71pub fn evaluate_detections_from_dataset(
72 frames: &[DetectionDatasetFrame],
73 config: DetectionEvalConfig,
74) -> Result<DetectionMetrics, EvalError> {
75 let borrowed = detection_frames_as_view(frames);
76 evaluate_detections(&borrowed, config)
77}
78
79pub fn evaluate_detections(
80 frames: &[DetectionFrame<'_>],
81 config: DetectionEvalConfig,
82) -> Result<DetectionMetrics, EvalError> {
83 config.validate()?;
84
85 let mut true_positives = 0u64;
86 let mut false_positives = 0u64;
87 let mut false_negatives = 0u64;
88
89 for frame in frames {
90 let mut predictions: Vec<Detection> = frame
91 .predictions
92 .iter()
93 .copied()
94 .filter(|prediction| prediction.score >= config.score_threshold)
95 .collect();
96 predictions.sort_by(|a, b| b.score.total_cmp(&a.score));
97
98 let mut gt_taken = vec![false; frame.ground_truth.len()];
99 for prediction in predictions {
100 if let Some(best_gt_idx) = best_gt_match(
101 prediction,
102 frame.ground_truth,
103 >_taken,
104 config.iou_threshold,
105 ) {
106 gt_taken[best_gt_idx] = true;
107 true_positives += 1;
108 } else {
109 false_positives += 1;
110 }
111 }
112
113 false_negatives += gt_taken.iter().filter(|matched| !**matched).count() as u64;
114 }
115
116 let precision = safe_ratio(true_positives, true_positives + false_positives);
117 let recall = safe_ratio(true_positives, true_positives + false_negatives);
118 let f1 = harmonic_mean(precision, recall);
119 let average_precision = average_precision(frames, config);
120
121 Ok(DetectionMetrics {
122 true_positives,
123 false_positives,
124 false_negatives,
125 precision,
126 recall,
127 f1,
128 average_precision,
129 })
130}
131
132fn best_gt_match(
133 prediction: Detection,
134 ground_truth: &[LabeledBox],
135 gt_taken: &[bool],
136 iou_threshold: f32,
137) -> Option<usize> {
138 let mut best_iou = iou_threshold;
139 let mut best_idx = None;
140
141 for (idx, gt) in ground_truth.iter().enumerate() {
142 if gt_taken[idx] || gt.class_id != prediction.class_id {
143 continue;
144 }
145 let overlap = iou(gt.bbox, prediction.bbox);
146 if overlap >= best_iou {
147 best_iou = overlap;
148 best_idx = Some(idx);
149 }
150 }
151 best_idx
152}
153
154fn average_precision(frames: &[DetectionFrame<'_>], config: DetectionEvalConfig) -> f32 {
155 let total_ground_truth = frames
156 .iter()
157 .map(|frame| frame.ground_truth.len() as u64)
158 .sum::<u64>();
159 if total_ground_truth == 0 {
160 return 0.0;
161 }
162
163 let mut ranked_predictions = Vec::new();
164 for (frame_idx, frame) in frames.iter().enumerate() {
165 for prediction in frame.predictions {
166 if prediction.score >= config.score_threshold {
167 ranked_predictions.push((frame_idx, *prediction));
168 }
169 }
170 }
171 ranked_predictions.sort_by(|a, b| b.1.score.total_cmp(&a.1.score));
172
173 if ranked_predictions.is_empty() {
174 return 0.0;
175 }
176
177 let mut gt_taken: Vec<Vec<bool>> = frames
178 .iter()
179 .map(|frame| vec![false; frame.ground_truth.len()])
180 .collect();
181 let mut precisions = Vec::with_capacity(ranked_predictions.len());
182 let mut recalls = Vec::with_capacity(ranked_predictions.len());
183
184 let mut true_positives = 0u64;
185 let mut false_positives = 0u64;
186
187 for (frame_idx, prediction) in ranked_predictions {
188 if let Some(best_gt_idx) = best_gt_match(
189 prediction,
190 frames[frame_idx].ground_truth,
191 >_taken[frame_idx],
192 config.iou_threshold,
193 ) {
194 gt_taken[frame_idx][best_gt_idx] = true;
195 true_positives += 1;
196 } else {
197 false_positives += 1;
198 }
199
200 precisions.push(safe_ratio(true_positives, true_positives + false_positives));
201 recalls.push(safe_ratio(true_positives, total_ground_truth));
202 }
203
204 let mut monotonic_precisions = Vec::with_capacity(precisions.len() + 2);
205 let mut padded_recalls = Vec::with_capacity(recalls.len() + 2);
206
207 padded_recalls.push(0.0);
208 padded_recalls.extend(recalls.iter().copied());
209 padded_recalls.push(1.0);
210
211 monotonic_precisions.push(0.0);
212 monotonic_precisions.extend(precisions.iter().copied());
213 monotonic_precisions.push(0.0);
214
215 for idx in (0..monotonic_precisions.len() - 1).rev() {
216 monotonic_precisions[idx] = monotonic_precisions[idx].max(monotonic_precisions[idx + 1]);
217 }
218
219 let mut ap = 0.0f32;
220 for idx in 0..padded_recalls.len() - 1 {
221 let recall_delta = padded_recalls[idx + 1] - padded_recalls[idx];
222 if recall_delta > 0.0 {
223 ap += recall_delta * monotonic_precisions[idx + 1];
224 }
225 }
226 ap.clamp(0.0, 1.0)
227}
228
/// IoU thresholds used for COCO-style averaging: 0.50 to 0.95 in steps of 0.05.
const COCO_IOU_THRESHOLDS: [f32; 10] = [0.50, 0.55, 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95];

/// Exclusive upper bound on box area for the "small" size bucket (32 x 32).
/// NOTE(review): presumably in squared pixels — confirm against `BoundingBox` units.
const SMALL_AREA_MAX: f32 = 32.0 * 32.0;
/// Exclusive upper bound on box area for the "medium" size bucket (96 x 96);
/// boxes at or above this area fall in the "large" bucket.
const MEDIUM_AREA_MAX: f32 = 96.0 * 96.0;
237
/// COCO-style summary metrics produced by [`evaluate_detections_coco`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct CocoMetrics {
    /// Average precision, averaged over the ten IoU thresholds 0.50..=0.95.
    pub ap: f32,
    /// Average precision at IoU threshold 0.50.
    pub ap50: f32,
    /// Average precision at IoU threshold 0.75.
    pub ap75: f32,
    /// Threshold-averaged AP against ground truth with box area below 32 x 32.
    pub ap_small: f32,
    /// Threshold-averaged AP against ground truth with box area in
    /// `[32 x 32, 96 x 96)`.
    pub ap_medium: f32,
    /// Threshold-averaged AP against ground truth with box area of at least
    /// 96 x 96.
    pub ap_large: f32,
    /// Recall at the requested score threshold, averaged over the ten IoU
    /// thresholds.
    pub ar: f32,
}
255
256fn box_area(b: &BoundingBox) -> f32 {
257 (b.x2 - b.x1) * (b.y2 - b.y1)
258}
259
260fn filter_gt_by<F>(frames: &[DetectionFrame<'_>], pred: F) -> Vec<DetectionDatasetFrame>
263where
264 F: Fn(&LabeledBox) -> bool,
265{
266 frames
267 .iter()
268 .map(|frame| {
269 let ground_truth: Vec<LabeledBox> = frame
270 .ground_truth
271 .iter()
272 .filter(|lb| pred(lb))
273 .copied()
274 .collect();
275 DetectionDatasetFrame {
276 ground_truth,
277 predictions: frame.predictions.to_vec(),
278 }
279 })
280 .collect()
281}
282
283pub fn evaluate_detections_coco(
285 frames: &[DetectionFrame<'_>],
286 score_threshold: f32,
287) -> Result<CocoMetrics, EvalError> {
288 validate_score_threshold(score_threshold)?;
289
290 let mut aps = [0.0f32; 10];
292 let mut recalls = [0.0f32; 10];
293
294 for (i, &iou_thresh) in COCO_IOU_THRESHOLDS.iter().enumerate() {
295 let config = DetectionEvalConfig {
296 iou_threshold: iou_thresh,
297 score_threshold,
298 };
299 let m = evaluate_detections(frames, config)?;
300 aps[i] = m.average_precision;
301 recalls[i] = m.recall;
302 }
303
304 let ap = aps.iter().sum::<f32>() / aps.len() as f32;
305 let ap50 = aps[0]; let ap75 = aps[5]; let ar = recalls.iter().sum::<f32>() / recalls.len() as f32;
308
309 let ap_small = size_ap(frames, score_threshold, |a| a < SMALL_AREA_MAX)?;
311 let ap_medium = size_ap(frames, score_threshold, |a| {
312 (SMALL_AREA_MAX..MEDIUM_AREA_MAX).contains(&a)
313 })?;
314 let ap_large = size_ap(frames, score_threshold, |a| a >= MEDIUM_AREA_MAX)?;
315
316 Ok(CocoMetrics {
317 ap,
318 ap50,
319 ap75,
320 ap_small,
321 ap_medium,
322 ap_large,
323 ar,
324 })
325}
326
327fn size_ap<F>(
328 frames: &[DetectionFrame<'_>],
329 score_threshold: f32,
330 area_filter: F,
331) -> Result<f32, EvalError>
332where
333 F: Fn(f32) -> bool,
334{
335 let owned = filter_gt_by(frames, |lb| area_filter(box_area(&lb.bbox)));
336 let views = detection_frames_as_view(&owned);
337
338 let mut sum = 0.0f32;
339 for &iou_thresh in &COCO_IOU_THRESHOLDS {
340 let config = DetectionEvalConfig {
341 iou_threshold: iou_thresh,
342 score_threshold,
343 };
344 let m = evaluate_detections(&views, config)?;
345 sum += m.average_precision;
346 }
347 Ok(sum / COCO_IOU_THRESHOLDS.len() as f32)
348}