1use std::collections::{HashMap, HashSet};
12
13use crate::types::Dataset;
14use serde::Serialize;
15
16#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
18pub enum Layer {
19 Structural,
20 Quality,
21 Distribution,
22 Compatibility,
23}
24
25#[derive(Debug, Clone, Serialize)]
27pub struct Finding {
28 pub code: &'static str,
30 pub message: String,
32 pub affected_ids: Vec<u64>,
35 pub layer: Layer,
37}
38
39#[derive(Debug, Clone, Serialize)]
41pub struct DatasetSummary {
42 pub num_images: usize,
43 pub num_annotations: usize,
44 pub num_categories: usize,
45 pub images_without_annotations: usize,
46 pub category_counts: Vec<(String, usize)>,
48 pub imbalance_ratio: f64,
50}
51
52#[derive(Debug, Clone, Serialize)]
54pub struct HealthReport {
55 pub errors: Vec<Finding>,
56 pub warnings: Vec<Finding>,
57 pub summary: DatasetSummary,
58}
59
60pub fn healthcheck(dataset: &Dataset) -> HealthReport {
65 let mut errors = Vec::new();
66 let mut warnings = Vec::new();
67
68 check_structural(dataset, &mut errors);
69 check_quality(dataset, &mut warnings);
70 let summary = build_summary(dataset, &mut warnings);
71
72 HealthReport {
73 errors,
74 warnings,
75 summary,
76 }
77}
78
79pub fn healthcheck_compatibility(gt: &Dataset, dt: &Dataset) -> HealthReport {
84 let mut report = healthcheck(gt);
85 check_compatibility(gt, dt, &mut report.errors, &mut report.warnings);
86 report
87}
88
/// Returns every ID (as extracted by `id_fn`) that occurs more than once in
/// `items`, each listed exactly once, in the order its first repeat is seen.
///
/// For example, IDs `[3, 1, 3, 3, 1]` yield `[3, 1]`.
fn find_duplicate_ids<T, F>(items: &[T], id_fn: F) -> Vec<u64>
where
    F: Fn(&T) -> u64,
{
    let mut seen = HashSet::with_capacity(items.len());
    // Separate set so an ID repeated N times is reported once rather than
    // N - 1 times in `affected_ids`.
    let mut reported = HashSet::new();
    items
        .iter()
        .map(id_fn)
        .filter(|&id| !seen.insert(id) && reported.insert(id))
        .collect()
}
102
103fn push_if_nonempty(
104 findings: &mut Vec<Finding>,
105 ids: Vec<u64>,
106 code: &'static str,
107 message: String,
108 layer: Layer,
109) {
110 if !ids.is_empty() {
111 findings.push(Finding {
112 code,
113 message,
114 affected_ids: ids,
115 layer,
116 });
117 }
118}
119
120fn check_structural(dataset: &Dataset, errors: &mut Vec<Finding>) {
121 let dup_img = find_duplicate_ids(&dataset.images, |i| i.id);
122 push_if_nonempty(
123 errors,
124 dup_img,
125 "duplicate_image_id",
126 "Duplicate image ID(s) found. Each image must have a unique ID.".into(),
127 Layer::Structural,
128 );
129
130 let dup_ann = find_duplicate_ids(&dataset.annotations, |a| a.id);
131 push_if_nonempty(
132 errors,
133 dup_ann,
134 "duplicate_ann_id",
135 "Duplicate annotation ID(s) found. Each annotation must have a unique ID.".into(),
136 Layer::Structural,
137 );
138
139 let dup_cat = find_duplicate_ids(&dataset.categories, |c| c.id);
140 push_if_nonempty(
141 errors,
142 dup_cat,
143 "duplicate_category_id",
144 "Duplicate category ID(s) found. Each category must have a unique ID.".into(),
145 Layer::Structural,
146 );
147
148 let image_ids: HashSet<u64> = dataset.images.iter().map(|img| img.id).collect();
149 let orphan_img: Vec<u64> = dataset
150 .annotations
151 .iter()
152 .filter(|ann| !image_ids.contains(&ann.image_id))
153 .map(|ann| ann.id)
154 .collect();
155 push_if_nonempty(
156 errors,
157 orphan_img,
158 "orphan_image_id",
159 "Annotation(s) reference image IDs not present in images.".into(),
160 Layer::Structural,
161 );
162
163 let cat_ids: HashSet<u64> = dataset.categories.iter().map(|c| c.id).collect();
164 let orphan_cat: Vec<u64> = dataset
165 .annotations
166 .iter()
167 .filter(|ann| !cat_ids.contains(&ann.category_id))
168 .map(|ann| ann.id)
169 .collect();
170 push_if_nonempty(
171 errors,
172 orphan_cat,
173 "orphan_category_id",
174 "Annotation(s) reference category IDs not present in categories.".into(),
175 Layer::Structural,
176 );
177
178 let missing_geom: Vec<u64> = dataset
179 .annotations
180 .iter()
181 .filter(|ann| ann.bbox.is_none() && ann.segmentation.is_none() && ann.keypoints.is_none())
182 .map(|ann| ann.id)
183 .collect();
184 push_if_nonempty(
185 errors,
186 missing_geom,
187 "missing_geometry",
188 "Annotation(s) have no bbox, segmentation, or keypoints.".into(),
189 Layer::Structural,
190 );
191
192 let zero_dim: Vec<u64> = dataset
193 .images
194 .iter()
195 .filter(|img| img.height == 0 || img.width == 0)
196 .map(|img| img.id)
197 .collect();
198 push_if_nonempty(
199 errors,
200 zero_dim,
201 "zero_dimensions",
202 "Image(s) have zero height or width.".into(),
203 Layer::Structural,
204 );
205}
206
/// Intersection-over-union of two `[x, y, width, height]` boxes.
///
/// Returns `0.0` when the union area is not positive.
fn bbox_iou_pair(a: &[f64; 4], b: &[f64; 4]) -> f64 {
    let [ax, ay, aw, ah] = *a;
    let [bx, by, bw, bh] = *b;

    // Overlap on each axis, clamped to zero when the boxes are disjoint.
    let overlap_w = ((ax + aw).min(bx + bw) - ax.max(bx)).max(0.0);
    let overlap_h = ((ay + ah).min(by + bh) - ay.max(by)).max(0.0);
    let intersection = overlap_w * overlap_h;

    let union = aw * ah + bw * bh - intersection;
    if union <= 0.0 {
        return 0.0;
    }
    intersection / union
}
224
225fn check_quality(dataset: &Dataset, warnings: &mut Vec<Finding>) {
226 let img_dims: HashMap<u64, (u32, u32)> = dataset
228 .images
229 .iter()
230 .map(|img| (img.id, (img.width, img.height)))
231 .collect();
232
233 let mut degenerate_ids = Vec::new();
234 let mut zero_area_ids = Vec::new();
235 let mut oob_ids = Vec::new();
236 let mut extreme_ar_ids = Vec::new();
237
238 struct AnnBbox {
240 id: u64,
241 bbox: [f64; 4],
242 }
243 let mut groups: HashMap<(u64, u64), Vec<AnnBbox>> = HashMap::new();
244
245 for ann in &dataset.annotations {
246 if let Some(bbox) = &ann.bbox {
247 let w = bbox[2];
248 let h = bbox[3];
249
250 if w <= 0.0 || h <= 0.0 {
252 degenerate_ids.push(ann.id);
253 continue;
254 }
255
256 if let Some(area) = ann.area {
258 if area == 0.0 {
259 zero_area_ids.push(ann.id);
260 }
261 }
262
263 if let Some(&(img_w, img_h)) = img_dims.get(&ann.image_id) {
265 let x2 = bbox[0] + w;
266 let y2 = bbox[1] + h;
267 if x2 > img_w as f64 || y2 > img_h as f64 {
268 oob_ids.push(ann.id);
269 }
270 }
271
272 let ar = if w > h { w / h } else { h / w };
274 if ar > 20.0 {
275 extreme_ar_ids.push(ann.id);
276 }
277
278 groups
280 .entry((ann.image_id, ann.category_id))
281 .or_default()
282 .push(AnnBbox {
283 id: ann.id,
284 bbox: *bbox,
285 });
286 }
287 }
288
289 let n = degenerate_ids.len();
290 push_if_nonempty(
291 warnings,
292 degenerate_ids,
293 "degenerate_bbox",
294 format!("{n} annotation(s) have degenerate bboxes (width or height <= 0)."),
295 Layer::Quality,
296 );
297
298 let n = zero_area_ids.len();
299 push_if_nonempty(
300 warnings,
301 zero_area_ids,
302 "zero_area",
303 format!("{n} annotation(s) have zero area."),
304 Layer::Quality,
305 );
306
307 let n = oob_ids.len();
308 push_if_nonempty(
309 warnings,
310 oob_ids,
311 "bbox_out_of_bounds",
312 format!("{n} annotation(s) have bboxes extending outside the image boundary."),
313 Layer::Quality,
314 );
315
316 let n = extreme_ar_ids.len();
317 push_if_nonempty(
318 warnings,
319 extreme_ar_ids,
320 "extreme_aspect_ratio",
321 format!("{n} annotation(s) have extreme aspect ratios (>20:1)."),
322 Layer::Quality,
323 );
324
325 let mut near_dup_ids = HashSet::new();
327 let mut skipped_img_ids = Vec::new();
328 for ((img_id, _), anns) in &groups {
329 if anns.len() > 100 {
331 skipped_img_ids.push(*img_id);
332 continue;
333 }
334 for i in 0..anns.len() {
335 for j in (i + 1)..anns.len() {
336 if bbox_iou_pair(&anns[i].bbox, &anns[j].bbox) > 0.95 {
337 near_dup_ids.insert(anns[i].id);
338 near_dup_ids.insert(anns[j].id);
339 }
340 }
341 }
342 }
343
344 if !near_dup_ids.is_empty() {
345 let mut ids: Vec<u64> = near_dup_ids.into_iter().collect();
346 ids.sort_unstable();
347 warnings.push(Finding {
348 code: "near_duplicate",
349 message: format!(
350 "{} annotation(s) appear to be near-duplicates (same class, same image, IoU > 0.95).",
351 ids.len()
352 ),
353 affected_ids: ids,
354 layer: Layer::Quality,
355 });
356 }
357
358 skipped_img_ids.sort_unstable();
359 skipped_img_ids.dedup();
360 let n = skipped_img_ids.len();
361 push_if_nonempty(
362 warnings,
363 skipped_img_ids,
364 "near_duplicate_check_skipped",
365 format!("{n} image(s) have >100 same-class annotations; near-duplicate check was skipped."),
366 Layer::Quality,
367 );
368}
369
370fn build_summary(dataset: &Dataset, warnings: &mut Vec<Finding>) -> DatasetSummary {
371 let annotated_img_ids: HashSet<u64> =
373 dataset.annotations.iter().map(|ann| ann.image_id).collect();
374 let images_without_annotations = dataset
375 .images
376 .iter()
377 .filter(|img| !annotated_img_ids.contains(&img.id))
378 .count();
379
380 let cat_name_map: HashMap<u64, &str> = dataset
382 .categories
383 .iter()
384 .map(|c| (c.id, c.name.as_str()))
385 .collect();
386
387 let mut cat_counts: HashMap<u64, usize> = HashMap::new();
388 for cat in &dataset.categories {
389 cat_counts.insert(cat.id, 0);
390 }
391 for ann in &dataset.annotations {
392 if let Some(count) = cat_counts.get_mut(&ann.category_id) {
393 *count += 1;
394 }
395 }
396
397 let mut category_counts: Vec<(String, usize)> = cat_counts
398 .iter()
399 .filter_map(|(&cat_id, &count)| {
400 cat_name_map
401 .get(&cat_id)
402 .map(|name| ((*name).to_string(), count))
403 })
404 .collect();
405 category_counts.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
406
407 let nonzero_counts: Vec<usize> = category_counts
409 .iter()
410 .map(|(_, c)| *c)
411 .filter(|&c| c > 0)
412 .collect();
413 let imbalance_ratio = if nonzero_counts.len() < 2 {
414 1.0
415 } else {
416 let max = *nonzero_counts.iter().max().expect("len >= 2") as f64;
417 let min = *nonzero_counts.iter().min().expect("len >= 2") as f64;
418 max / min
419 };
420
421 let zero_cats: Vec<u64> = cat_counts
423 .iter()
424 .filter(|&(_, &count)| count == 0)
425 .map(|(&id, _)| id)
426 .collect();
427 if !zero_cats.is_empty() {
428 warnings.push(Finding {
429 code: "zero_instance_category",
430 message: format!(
431 "{} category/categories have zero annotation instances.",
432 zero_cats.len()
433 ),
434 affected_ids: zero_cats,
435 layer: Layer::Distribution,
436 });
437 }
438
439 let low_cats: Vec<u64> = cat_counts
441 .iter()
442 .filter(|&(_, &count)| count > 0 && count < 10)
443 .map(|(&id, _)| id)
444 .collect();
445 if !low_cats.is_empty() {
446 warnings.push(Finding {
447 code: "low_instance_category",
448 message: format!(
449 "{} category/categories have fewer than 10 annotation instances.",
450 low_cats.len()
451 ),
452 affected_ids: low_cats,
453 layer: Layer::Distribution,
454 });
455 }
456
457 DatasetSummary {
458 num_images: dataset.images.len(),
459 num_annotations: dataset.annotations.len(),
460 num_categories: dataset.categories.len(),
461 images_without_annotations,
462 category_counts,
463 imbalance_ratio,
464 }
465}
466
467fn check_compatibility(
468 gt: &Dataset,
469 dt: &Dataset,
470 errors: &mut Vec<Finding>,
471 warnings: &mut Vec<Finding>,
472) {
473 let gt_image_ids: HashSet<u64> = gt.images.iter().map(|img| img.id).collect();
474 let gt_cat_ids: HashSet<u64> = gt.categories.iter().map(|c| c.id).collect();
475
476 let orphan_img: Vec<u64> = dt
478 .annotations
479 .iter()
480 .filter(|ann| !gt_image_ids.contains(&ann.image_id))
481 .map(|ann| ann.id)
482 .collect();
483 let n = orphan_img.len();
484 push_if_nonempty(
485 errors,
486 orphan_img,
487 "dt_orphan_image_id",
488 format!("{n} detection(s) reference image IDs not present in ground truth."),
489 Layer::Compatibility,
490 );
491
492 let orphan_cat: Vec<u64> = dt
494 .annotations
495 .iter()
496 .filter(|ann| !gt_cat_ids.contains(&ann.category_id))
497 .map(|ann| ann.id)
498 .collect();
499 let n = orphan_cat.len();
500 push_if_nonempty(
501 errors,
502 orphan_cat,
503 "dt_orphan_category_id",
504 format!("{n} detection(s) reference category IDs not present in ground truth."),
505 Layer::Compatibility,
506 );
507
508 let missing_score: Vec<u64> = dt
510 .annotations
511 .iter()
512 .filter(|ann| ann.score.is_none())
513 .map(|ann| ann.id)
514 .collect();
515 let n = missing_score.len();
516 push_if_nonempty(
517 warnings,
518 missing_score,
519 "dt_missing_score",
520 format!("{n} detection(s) are missing a confidence score."),
521 Layer::Compatibility,
522 );
523
524 let bad_score: Vec<u64> = dt
526 .annotations
527 .iter()
528 .filter(|ann| {
529 if let Some(score) = ann.score {
530 !(0.0..=1.0).contains(&score)
531 } else {
532 false
533 }
534 })
535 .map(|ann| ann.id)
536 .collect();
537 let n = bad_score.len();
538 push_if_nonempty(
539 warnings,
540 bad_score,
541 "dt_score_out_of_range",
542 format!("{n} detection(s) have scores outside the [0, 1] range."),
543 Layer::Compatibility,
544 );
545}