Skip to main content

edgefirst_client/coco/
verify.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright © 2025 Au-Zone Technologies. All Rights Reserved.
3
4//! COCO dataset verification utilities.
5//!
6//! Provides functions for comparing COCO datasets and validating annotation
7//! accuracy using Hungarian matching for optimal annotation pairing.
8
9use super::{
10    decode_compressed_rle, decode_rle,
11    types::{CocoAnnotation, CocoDataset, CocoSegmentation},
12};
13use pathfinding::{kuhn_munkres::kuhn_munkres_min, matrix::Matrix};
14use std::{
15    collections::{HashMap, HashSet},
16    fmt,
17};
18
19/// Result of verifying a COCO import against Studio data.
20#[derive(Debug, Clone)]
21pub struct VerificationResult {
22    /// Total images in COCO dataset.
23    pub coco_image_count: usize,
24    /// Images found in Studio.
25    pub studio_image_count: usize,
26    /// Images missing from Studio.
27    pub missing_images: Vec<String>,
28    /// Extra images in Studio not in COCO.
29    pub extra_images: Vec<String>,
30    /// Total annotations in COCO dataset.
31    pub coco_annotation_count: usize,
32    /// Total annotations in Studio.
33    pub studio_annotation_count: usize,
34    /// Bounding box validation results.
35    pub bbox_validation: BboxValidationResult,
36    /// Segmentation mask validation results.
37    pub mask_validation: MaskValidationResult,
38    /// Category validation results.
39    pub category_validation: CategoryValidationResult,
40}
41
42impl VerificationResult {
43    /// Returns true if the verification passed all checks.
44    pub fn is_valid(&self) -> bool {
45        self.missing_images.is_empty()
46            && self.extra_images.is_empty()
47            && self.bbox_validation.is_valid()
48            && self.mask_validation.is_valid()
49    }
50
51    /// Returns a summary of the verification.
52    pub fn summary(&self) -> String {
53        let mut s = String::new();
54        s.push_str(&format!(
55            "Images: {}/{} (missing: {}, extra: {})\n",
56            self.studio_image_count,
57            self.coco_image_count,
58            self.missing_images.len(),
59            self.extra_images.len()
60        ));
61        s.push_str(&format!(
62            "Annotations: {}/{}\n",
63            self.studio_annotation_count, self.coco_annotation_count
64        ));
65        s.push_str(&format!(
66            "Bbox: {:.1}% matched, {:.4} avg IoU\n",
67            self.bbox_validation.match_rate() * 100.0,
68            self.bbox_validation.avg_iou()
69        ));
70        s.push_str(&format!(
71            "Masks: {:.1}% preserved, {:.4} avg bbox IoU\n",
72            self.mask_validation.preservation_rate() * 100.0,
73            self.mask_validation.avg_bbox_iou()
74        ));
75        s
76    }
77}
78
79impl fmt::Display for VerificationResult {
80    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81        writeln!(
82            f,
83            "╔══════════════════════════════════════════════════════════════╗"
84        )?;
85        writeln!(
86            f,
87            "║                  COCO IMPORT VERIFICATION                    ║"
88        )?;
89        writeln!(
90            f,
91            "╠══════════════════════════════════════════════════════════════╣"
92        )?;
93        writeln!(
94            f,
95            "║ Images:      {} in COCO, {} in Studio",
96            self.coco_image_count, self.studio_image_count
97        )?;
98        if !self.missing_images.is_empty() {
99            writeln!(f, "║   Missing:   {} images", self.missing_images.len())?;
100            for name in self.missing_images.iter().take(5) {
101                writeln!(f, "║              - {}", name)?;
102            }
103            if self.missing_images.len() > 5 {
104                writeln!(
105                    f,
106                    "║              ... and {} more",
107                    self.missing_images.len() - 5
108                )?;
109            }
110        }
111        if !self.extra_images.is_empty() {
112            writeln!(f, "║   Extra:     {} images", self.extra_images.len())?;
113            for name in self.extra_images.iter().take(5) {
114                writeln!(f, "║              - {}", name)?;
115            }
116            if self.extra_images.len() > 5 {
117                writeln!(
118                    f,
119                    "║              ... and {} more",
120                    self.extra_images.len() - 5
121                )?;
122            }
123        }
124        writeln!(
125            f,
126            "║ Annotations: {} in COCO, {} in Studio",
127            self.coco_annotation_count, self.studio_annotation_count
128        )?;
129        writeln!(
130            f,
131            "╠══════════════════════════════════════════════════════════════╣"
132        )?;
133        write!(f, "{}", self.bbox_validation)?;
134        writeln!(
135            f,
136            "╠══════════════════════════════════════════════════════════════╣"
137        )?;
138        write!(f, "{}", self.mask_validation)?;
139        writeln!(
140            f,
141            "╠══════════════════════════════════════════════════════════════╣"
142        )?;
143        write!(f, "{}", self.category_validation)?;
144        writeln!(
145            f,
146            "╠══════════════════════════════════════════════════════════════╣"
147        )?;
148        let status = if self.is_valid() {
149            "✓ PASSED"
150        } else {
151            "✗ FAILED"
152        };
153        writeln!(f, "║ Status: {}", status)?;
154        writeln!(
155            f,
156            "╚══════════════════════════════════════════════════════════════╝"
157        )?;
158        Ok(())
159    }
160}
161
/// Aggregated statistics from comparing matched bounding boxes.
#[derive(Clone, Debug, Default)]
pub struct BboxValidationResult {
    /// Number of annotations that were matched using the Hungarian algorithm.
    pub total_matched: usize,
    /// Number of annotations that could not be matched (IoU too low).
    pub total_unmatched: usize,
    /// Coordinate error histogram: [<1px, <2px, <5px, <10px, >=10px].
    pub errors_by_range: [usize; 5],
    /// Largest coordinate error seen, in pixels.
    pub max_error: f64,
    /// Running sum of IoU values; divided by `total_matched` for the average.
    pub sum_iou: f64,
}

impl BboxValidationResult {
    /// Fraction (0.0..=1.0) of compared coordinates with an error below one
    /// pixel.
    ///
    /// Four coordinates are counted per matched annotation. Returns `1.0`
    /// when nothing was matched.
    pub fn within_1px_rate(&self) -> f64 {
        match self.total_matched * 4 {
            0 => 1.0,
            coords => self.errors_by_range[0] as f64 / coords as f64,
        }
    }

    /// Fraction (0.0..=1.0) of compared coordinates with an error below two
    /// pixels.
    ///
    /// Returns `1.0` when nothing was matched.
    pub fn within_2px_rate(&self) -> f64 {
        match self.total_matched * 4 {
            0 => 1.0,
            coords => {
                let below_two = self.errors_by_range[0] + self.errors_by_range[1];
                below_two as f64 / coords as f64
            }
        }
    }

    /// Mean IoU over all matched annotations, or `1.0` when nothing was
    /// matched.
    pub fn avg_iou(&self) -> f64 {
        match self.total_matched {
            0 => 1.0,
            matched => self.sum_iou / matched as f64,
        }
    }

    /// Share of annotations that were matched (matched / (matched +
    /// unmatched)), or `1.0` when there were no annotations at all.
    pub fn match_rate(&self) -> f64 {
        match self.total_matched + self.total_unmatched {
            0 => 1.0,
            considered => self.total_matched as f64 / considered as f64,
        }
    }

    /// Returns true when all quality thresholds are exceeded: more than 99%
    /// of coordinates within one pixel, a match rate above 95%, and an
    /// average IoU above 0.95.
    pub fn is_valid(&self) -> bool {
        let coordinates_ok = self.within_1px_rate() > 0.99;
        let matching_ok = self.match_rate() > 0.95;
        let overlap_ok = self.avg_iou() > 0.95;
        coordinates_ok && matching_ok && overlap_ok
    }
}

/// Renders the bounding-box section of the verification report.
impl fmt::Display for BboxValidationResult {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let considered = self.total_matched + self.total_unmatched;
        let as_percent = |rate: f64| rate * 100.0;

        writeln!(f, "║ Bounding Box Validation:")?;
        writeln!(
            f,
            "║   Matched:    {}/{} ({:.1}%)",
            self.total_matched,
            considered,
            as_percent(self.match_rate())
        )?;
        writeln!(f, "║   Avg IoU:    {:.4}", self.avg_iou())?;
        writeln!(f, "║   Within 1px: {:.1}%", as_percent(self.within_1px_rate()))?;
        writeln!(f, "║   Within 2px: {:.1}%", as_percent(self.within_2px_rate()))?;
        writeln!(f, "║   Max error:  {:.2}px", self.max_error)
    }
}
240
241/// Segmentation mask validation results.
242#[derive(Debug, Clone, Default)]
243pub struct MaskValidationResult {
244    /// Annotations with segmentation in original.
245    pub original_with_seg: usize,
246    /// Annotations with segmentation in restored.
247    pub restored_with_seg: usize,
248    /// Matched pairs where both have segmentation.
249    pub matched_pairs_with_seg: usize,
250    /// Polygon pairs (for vertex comparison).
251    pub polygon_pairs: usize,
252    /// RLE pairs converted to polygon.
253    pub rle_pairs: usize,
254    /// Pairs where vertex count matches exactly.
255    pub vertex_count_exact_match: usize,
256    /// Pairs where vertex count is within 10%.
257    pub vertex_count_close_match: usize,
258    /// Pairs where part count matches.
259    pub part_count_match: usize,
260    /// Pairs with area within 1%.
261    pub area_within_1pct: usize,
262    /// Pairs with area within 5%.
263    pub area_within_5pct: usize,
264    /// Pairs with bbox IoU >= 0.9.
265    pub bbox_iou_high: usize,
266    /// Pairs with bbox IoU < 0.5.
267    pub bbox_iou_low: usize,
268    /// Sum of area ratios.
269    pub sum_area_ratio: f64,
270    /// Minimum area ratio.
271    pub min_area_ratio: f64,
272    /// Maximum area ratio.
273    pub max_area_ratio: f64,
274    /// Sum of bbox IoU values.
275    pub sum_bbox_iou: f64,
276    /// Count of zero-area segmentations.
277    pub zero_area_count: usize,
278}
279
280impl MaskValidationResult {
281    /// Create a new result with initialized min/max values.
282    pub fn new() -> Self {
283        Self {
284            min_area_ratio: f64::MAX,
285            max_area_ratio: 0.0,
286            ..Default::default()
287        }
288    }
289
290    /// Returns the segmentation preservation rate.
291    pub fn preservation_rate(&self) -> f64 {
292        if self.original_with_seg == 0 {
293            1.0
294        } else {
295            self.restored_with_seg as f64 / self.original_with_seg as f64
296        }
297    }
298
299    /// Returns the average area ratio.
300    pub fn avg_area_ratio(&self) -> f64 {
301        let valid_count = self
302            .matched_pairs_with_seg
303            .saturating_sub(self.zero_area_count);
304        if valid_count == 0 {
305            1.0
306        } else {
307            self.sum_area_ratio / valid_count as f64
308        }
309    }
310
311    /// Returns the average bbox IoU.
312    pub fn avg_bbox_iou(&self) -> f64 {
313        if self.matched_pairs_with_seg == 0 {
314            1.0
315        } else {
316            self.sum_bbox_iou / self.matched_pairs_with_seg as f64
317        }
318    }
319
320    /// Returns true if mask validation passes quality thresholds.
321    pub fn is_valid(&self) -> bool {
322        self.preservation_rate() > 0.95 && self.avg_bbox_iou() > 0.90
323    }
324
325    /// Aggregate a single segmentation comparison into the result.
326    pub fn aggregate_comparison(&mut self, cmp: &SegmentationPairComparison) {
327        self.matched_pairs_with_seg += 1;
328
329        if cmp.is_rle {
330            self.rle_pairs += 1;
331        } else {
332            self.polygon_pairs += 1;
333            if cmp.vertex_exact_match {
334                self.vertex_count_exact_match += 1;
335            }
336            if cmp.vertex_close_match {
337                self.vertex_count_close_match += 1;
338            }
339            if cmp.part_match {
340                self.part_count_match += 1;
341            }
342        }
343
344        if let Some(area_ratio) = cmp.area_ratio {
345            self.sum_area_ratio += area_ratio;
346            self.min_area_ratio = self.min_area_ratio.min(area_ratio);
347            self.max_area_ratio = self.max_area_ratio.max(area_ratio);
348            if (area_ratio - 1.0).abs() <= 0.01 {
349                self.area_within_1pct += 1;
350            }
351            if (area_ratio - 1.0).abs() <= 0.05 {
352                self.area_within_5pct += 1;
353            }
354        } else {
355            self.zero_area_count += 1;
356        }
357
358        self.sum_bbox_iou += cmp.bbox_iou;
359        if cmp.bbox_iou >= 0.9 {
360            self.bbox_iou_high += 1;
361        }
362        if cmp.bbox_iou < 0.5 {
363            self.bbox_iou_low += 1;
364        }
365    }
366}
367
368impl fmt::Display for MaskValidationResult {
369    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
370        writeln!(f, "║ Segmentation Mask Validation:")?;
371        writeln!(
372            f,
373            "║   Preserved:   {}/{} ({:.1}%)",
374            self.restored_with_seg,
375            self.original_with_seg,
376            self.preservation_rate() * 100.0
377        )?;
378        writeln!(
379            f,
380            "║   Matched:     {} ({} polygon, {} RLE→polygon)",
381            self.matched_pairs_with_seg, self.polygon_pairs, self.rle_pairs
382        )?;
383        writeln!(f, "║   Avg bbox IoU: {:.4}", self.avg_bbox_iou())?;
384        writeln!(
385            f,
386            "║   High IoU (>=0.9): {}/{} ({:.1}%)",
387            self.bbox_iou_high,
388            self.matched_pairs_with_seg,
389            if self.matched_pairs_with_seg > 0 {
390                self.bbox_iou_high as f64 / self.matched_pairs_with_seg as f64 * 100.0
391            } else {
392                100.0
393            }
394        )?;
395        if self.polygon_pairs > 0 {
396            writeln!(
397                f,
398                "║   Vertex exact: {}/{} ({:.1}%)",
399                self.vertex_count_exact_match,
400                self.polygon_pairs,
401                self.vertex_count_exact_match as f64 / self.polygon_pairs as f64 * 100.0
402            )?;
403        }
404        Ok(())
405    }
406}
407
/// Comparison of the category names found in COCO and in Studio.
#[derive(Clone, Debug, Default)]
pub struct CategoryValidationResult {
    /// Category names in the COCO dataset.
    pub coco_categories: HashSet<String>,
    /// Category names in Studio.
    pub studio_categories: HashSet<String>,
    /// Categories that are missing from Studio.
    pub missing_categories: Vec<String>,
    /// Categories that exist only in Studio.
    pub extra_categories: Vec<String>,
}

impl CategoryValidationResult {
    /// Returns true when no category is missing.
    ///
    /// Extra categories do not affect the result.
    pub fn is_valid(&self) -> bool {
        let missing = &self.missing_categories;
        missing.is_empty()
    }
}

/// Renders the category section of the verification report.
///
/// The missing and extra lists are printed only when they are non-empty.
impl fmt::Display for CategoryValidationResult {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(
            f,
            "║ Categories:    {} in COCO, {} in Studio",
            self.coco_categories.len(),
            self.studio_categories.len()
        )?;

        let groups = [
            ("║   Missing:     ", &self.missing_categories),
            ("║   Extra:       ", &self.extra_categories),
        ];
        for (heading, names) in groups.iter() {
            if !names.is_empty() {
                writeln!(f, "{}{:?}", heading, names)?;
            }
        }
        Ok(())
    }
}
445
/// Metrics produced by comparing one original/restored segmentation pair.
#[derive(Clone, Debug, Default)]
pub struct SegmentationPairComparison {
    /// True when the original segmentation was RLE rather than a polygon.
    pub is_rle: bool,
    /// Vertex counts are identical (polygon originals only).
    pub vertex_exact_match: bool,
    /// Vertex counts are within 10% of each other (polygon originals only).
    pub vertex_close_match: bool,
    /// Part counts are identical (polygon originals only).
    pub part_match: bool,
    /// Restored area divided by original area; `None` if either is zero.
    pub area_ratio: Option<f64>,
    /// IoU of the two segmentations' bounding boxes.
    pub bbox_iou: f64,
}
462
463/// Compare two segmentations and return comparison metrics.
464pub fn compare_segmentation_pair(
465    orig_seg: &CocoSegmentation,
466    rest_seg: &CocoSegmentation,
467) -> SegmentationPairComparison {
468    let is_rle = matches!(
469        orig_seg,
470        CocoSegmentation::Rle(_) | CocoSegmentation::CompressedRle(_)
471    );
472
473    let (vertex_exact_match, vertex_close_match, part_match) = if is_rle {
474        (false, false, false)
475    } else {
476        let orig_vertices = count_polygon_vertices(orig_seg);
477        let rest_vertices = count_polygon_vertices(rest_seg);
478        let orig_parts = count_polygon_parts(orig_seg);
479        let rest_parts = count_polygon_parts(rest_seg);
480
481        let vertex_diff = (orig_vertices as f64 - rest_vertices as f64).abs();
482        let vertex_threshold = (orig_vertices as f64 * 0.1).max(1.0);
483
484        (
485            orig_vertices == rest_vertices,
486            vertex_diff <= vertex_threshold,
487            orig_parts == rest_parts,
488        )
489    };
490
491    let orig_area = compute_segmentation_area(orig_seg);
492    let rest_area = compute_segmentation_area(rest_seg);
493    let area_ratio = if orig_area > 0.0 && rest_area > 0.0 {
494        Some(rest_area / orig_area)
495    } else {
496        None
497    };
498
499    let bbox_iou = segmentation_bbox_iou(orig_seg, rest_seg);
500
501    SegmentationPairComparison {
502        is_rle,
503        vertex_exact_match,
504        vertex_close_match,
505        part_match,
506        area_ratio,
507        bbox_iou,
508    }
509}
510
/// Computes the Intersection over Union (IoU) of two COCO bounding boxes.
///
/// Both boxes use the COCO layout `[x, y, width, height]`, where `(x, y)` is
/// the top-left corner. Returns `0.0` when the union area is not positive.
pub fn bbox_iou(a: &[f64; 4], b: &[f64; 4]) -> f64 {
    let [a_x, a_y, a_w, a_h] = *a;
    let [b_x, b_y, b_w, b_h] = *b;

    // Overlap along one axis: smaller end minus larger start, floored at 0.
    let overlap = |start_a: f64, len_a: f64, start_b: f64, len_b: f64| {
        let end = (start_a + len_a).min(start_b + len_b);
        let start = start_a.max(start_b);
        (end - start).max(0.0)
    };

    let intersection = overlap(a_x, a_w, b_x, b_w) * overlap(a_y, a_h, b_y, b_h);
    let union = a_w * a_h + b_w * b_h - intersection;

    if union > 0.0 {
        intersection / union
    } else {
        0.0
    }
}
545
546/// Use Hungarian algorithm to find optimal matching between two sets of
547/// annotations based on bounding box IoU.
548/// Returns pairs of (original_idx, restored_idx) for matched annotations.
549pub fn hungarian_match<'a>(
550    orig_anns: &[&'a CocoAnnotation],
551    rest_anns: &[&'a CocoAnnotation],
552) -> Vec<(usize, usize)> {
553    if orig_anns.is_empty() || rest_anns.is_empty() {
554        return vec![];
555    }
556
557    let n = orig_anns.len();
558    let m = rest_anns.len();
559
560    // Make the matrix square by padding with high-cost dummy entries
561    let size = n.max(m);
562
563    // Build cost matrix: cost = (1 - IoU) * scale
564    // We use i64 for kuhn_munkres, scale by 10000 for precision
565    let scale = 10000i64;
566    let max_cost = scale; // Cost for non-matching (IoU = 0)
567
568    let mut weights = Vec::with_capacity(size * size);
569    for i in 0..size {
570        for j in 0..size {
571            let cost = match (orig_anns.get(i), rest_anns.get(j)) {
572                (Some(orig), Some(rest)) => {
573                    let iou = bbox_iou(&orig.bbox, &rest.bbox);
574                    ((1.0 - iou) * scale as f64) as i64
575                }
576                _ => max_cost, // Dummy entry for padding
577            };
578            weights.push(cost);
579        }
580    }
581
582    let matrix = Matrix::from_vec(size, size, weights).expect("Failed to create matrix");
583    let (_, assignments) = kuhn_munkres_min(&matrix);
584
585    // Filter to only real matches (not dummy) with reasonable IoU
586    let min_iou_threshold = 0.3; // Only accept matches with IoU > 0.3
587    assignments
588        .iter()
589        .enumerate()
590        .filter_map(|(i, &j)| {
591            if i < n && j < m {
592                let iou = bbox_iou(&orig_anns[i].bbox, &rest_anns[j].bbox);
593                if iou >= min_iou_threshold {
594                    Some((i, j))
595                } else {
596                    None
597                }
598            } else {
599                None
600            }
601        })
602        .collect()
603}
604
/// Computes the area of a polygon with the Shoelace formula.
///
/// `coords` is a flat list `[x1, y1, x2, y2, ...]`; a trailing unpaired
/// value is ignored. The result is non-negative regardless of winding
/// order, and `0.0` when fewer than three vertices are given.
pub fn polygon_area(coords: &[f64]) -> f64 {
    let vertex_count = coords.len() / 2;
    if vertex_count < 3 {
        return 0.0;
    }

    let vertex = |k: usize| (coords[2 * k], coords[2 * k + 1]);

    // Sum the cross products of each vertex with its successor, wrapping
    // from the last vertex back to the first.
    let doubled_signed_area = (0..vertex_count).fold(0.0, |acc, i| {
        let (x_i, y_i) = vertex(i);
        let (x_j, y_j) = vertex((i + 1) % vertex_count);
        acc + (x_i * y_j - x_j * y_i)
    });

    (doubled_signed_area / 2.0).abs()
}
624
625/// Calculate total area of a segmentation.
626pub fn compute_segmentation_area(seg: &CocoSegmentation) -> f64 {
627    match seg {
628        CocoSegmentation::Polygon(polys) => polys.iter().map(|p| polygon_area(p)).sum(),
629        CocoSegmentation::Rle(rle) => {
630            if let Ok((mask, _, _)) = decode_rle(rle) {
631                mask.iter().filter(|&&v| v == 1).count() as f64
632            } else {
633                0.0
634            }
635        }
636        CocoSegmentation::CompressedRle(compressed) => {
637            if let Ok((mask, _, _)) = decode_compressed_rle(compressed) {
638                mask.iter().filter(|&&v| v == 1).count() as f64
639            } else {
640                0.0
641            }
642        }
643    }
644}
645
/// Computes the axis-aligned bounds of a polygon as
/// `(min_x, min_y, max_x, max_y)`.
///
/// `coords` is a flat list `[x1, y1, x2, y2, ...]`; a trailing unpaired
/// value is ignored. Returns `None` when fewer than four values (two
/// points) are supplied.
pub fn polygon_bounds(coords: &[f64]) -> Option<(f64, f64, f64, f64)> {
    if coords.len() < 4 {
        return None;
    }

    // Start from an inverted extent so the first point replaces every bound.
    let inverted = (f64::MAX, f64::MAX, f64::MIN, f64::MIN);
    let extent = coords
        .chunks_exact(2)
        .fold(inverted, |(lo_x, lo_y, hi_x, hi_y), point| {
            (
                lo_x.min(point[0]),
                lo_y.min(point[1]),
                hi_x.max(point[0]),
                hi_y.max(point[1]),
            )
        });

    Some(extent)
}
666
667/// Get bounding box for any segmentation type.
668pub fn segmentation_bounds(seg: &CocoSegmentation) -> Option<(f64, f64, f64, f64)> {
669    match seg {
670        CocoSegmentation::Polygon(polys) => {
671            polys
672                .iter()
673                .filter_map(|p| polygon_bounds(p))
674                .fold(None, |acc, b| match acc {
675                    None => Some(b),
676                    Some((min_x, min_y, max_x, max_y)) => Some((
677                        min_x.min(b.0),
678                        min_y.min(b.1),
679                        max_x.max(b.2),
680                        max_y.max(b.3),
681                    )),
682                })
683        }
684        CocoSegmentation::Rle(rle) => {
685            let (mask, height, width) = decode_rle(rle).ok()?;
686            rle_mask_bounds(&mask, height, width)
687        }
688        CocoSegmentation::CompressedRle(compressed) => {
689            let (mask, height, width) = decode_compressed_rle(compressed).ok()?;
690            rle_mask_bounds(&mask, height, width)
691        }
692    }
693}
694
/// Finds the extent of the entries equal to 1 in a binary mask.
///
/// The mask is indexed as `y * width + x`; positions beyond the end of
/// `mask` are treated as unset. The returned `(min_x, min_y, max_x, max_y)`
/// holds the smallest and largest pixel indices found, or `None` when no
/// entry equals 1.
fn rle_mask_bounds(mask: &[u8], height: u32, width: u32) -> Option<(f64, f64, f64, f64)> {
    let row_len = width as usize;
    // (min_x, min_y, max_x, max_y) of the set pixels seen so far.
    let mut extent: Option<(u32, u32, u32, u32)> = None;

    for y in 0..height {
        let row_start = (y as usize) * row_len;
        for x in 0..width {
            if mask.get(row_start + (x as usize)) != Some(&1) {
                continue;
            }
            extent = Some(match extent {
                None => (x, y, x, y),
                Some((lo_x, lo_y, hi_x, hi_y)) => {
                    (lo_x.min(x), lo_y.min(y), hi_x.max(x), hi_y.max(y))
                }
            });
        }
    }

    extent.map(|(lo_x, lo_y, hi_x, hi_y)| {
        (
            f64::from(lo_x),
            f64::from(lo_y),
            f64::from(hi_x),
            f64::from(hi_y),
        )
    })
}
722
723/// Calculate IoU between two segmentation bounding boxes.
724pub fn segmentation_bbox_iou(seg1: &CocoSegmentation, seg2: &CocoSegmentation) -> f64 {
725    let bounds1 = segmentation_bounds(seg1);
726    let bounds2 = segmentation_bounds(seg2);
727
728    match (bounds1, bounds2) {
729        (Some((a_x1, a_y1, a_x2, a_y2)), Some((b_x1, b_y1, b_x2, b_y2))) => {
730            let inter_x1 = a_x1.max(b_x1);
731            let inter_y1 = a_y1.max(b_y1);
732            let inter_x2 = a_x2.min(b_x2);
733            let inter_y2 = a_y2.min(b_y2);
734
735            let inter_w = (inter_x2 - inter_x1).max(0.0);
736            let inter_h = (inter_y2 - inter_y1).max(0.0);
737            let inter_area = inter_w * inter_h;
738
739            let a_area = (a_x2 - a_x1) * (a_y2 - a_y1);
740            let b_area = (b_x2 - b_x1) * (b_y2 - b_y1);
741            let union_area = a_area + b_area - inter_area;
742
743            if union_area > 0.0 {
744                inter_area / union_area
745            } else {
746                0.0
747            }
748        }
749        _ => 0.0,
750    }
751}
752
753/// Count total polygon vertices in a segmentation.
754pub fn count_polygon_vertices(seg: &CocoSegmentation) -> usize {
755    match seg {
756        CocoSegmentation::Polygon(polys) => polys.iter().map(|p| p.len() / 2).sum(),
757        _ => 0,
758    }
759}
760
761/// Count number of polygon parts in a segmentation.
762pub fn count_polygon_parts(seg: &CocoSegmentation) -> usize {
763    match seg {
764        CocoSegmentation::Polygon(polys) => polys.len(),
765        _ => 0,
766    }
767}
768
769/// Build a map of annotations by sample name for efficient lookup.
770pub fn build_annotation_map_by_name(
771    dataset: &CocoDataset,
772) -> HashMap<String, Vec<&CocoAnnotation>> {
773    let image_names: HashMap<u64, String> = dataset
774        .images
775        .iter()
776        .map(|img| {
777            let name = std::path::Path::new(&img.file_name)
778                .file_stem()
779                .and_then(|s| s.to_str())
780                .unwrap_or(&img.file_name)
781                .to_string();
782            (img.id, name)
783        })
784        .collect();
785
786    let mut map: HashMap<String, Vec<_>> = HashMap::new();
787    for ann in &dataset.annotations {
788        if let Some(name) = image_names.get(&ann.image_id) {
789            map.entry(name.clone()).or_default().push(ann);
790        }
791    }
792    map
793}
794
795/// Validate bounding boxes between two datasets using Hungarian matching.
796pub fn validate_bboxes(original: &CocoDataset, restored: &CocoDataset) -> BboxValidationResult {
797    let mut result = BboxValidationResult::default();
798
799    let original_by_name = build_annotation_map_by_name(original);
800    let restored_by_name = build_annotation_map_by_name(restored);
801
802    for (name, orig_anns) in &original_by_name {
803        if let Some(rest_anns) = restored_by_name.get(name) {
804            let matches = hungarian_match(orig_anns, rest_anns);
805
806            for (orig_idx, rest_idx) in &matches {
807                let orig_ann = orig_anns[*orig_idx];
808                let rest_ann = rest_anns[*rest_idx];
809
810                // Track IoU
811                let iou = bbox_iou(&orig_ann.bbox, &rest_ann.bbox);
812                result.sum_iou += iou;
813
814                // Measure coordinate errors
815                for i in 0..4 {
816                    let error = (orig_ann.bbox[i] - rest_ann.bbox[i]).abs();
817                    result.max_error = result.max_error.max(error);
818
819                    if error < 1.0 {
820                        result.errors_by_range[0] += 1;
821                    } else if error < 2.0 {
822                        result.errors_by_range[1] += 1;
823                    } else if error < 5.0 {
824                        result.errors_by_range[2] += 1;
825                    } else if error < 10.0 {
826                        result.errors_by_range[3] += 1;
827                    } else {
828                        result.errors_by_range[4] += 1;
829                    }
830                }
831                result.total_matched += 1;
832            }
833
834            result.total_unmatched += orig_anns.len() - matches.len();
835        } else {
836            result.total_unmatched += orig_anns.len();
837        }
838    }
839
840    result
841}
842
843/// Validate segmentation masks between two datasets using Hungarian matching.
844pub fn validate_masks(original: &CocoDataset, restored: &CocoDataset) -> MaskValidationResult {
845    let mut result = MaskValidationResult::new();
846
847    // Count segmentations in original and restored
848    result.original_with_seg = original
849        .annotations
850        .iter()
851        .filter(|a| a.segmentation.is_some())
852        .count();
853    result.restored_with_seg = restored
854        .annotations
855        .iter()
856        .filter(|a| a.segmentation.is_some())
857        .count();
858
859    let original_by_name = build_annotation_map_by_name(original);
860    let restored_by_name = build_annotation_map_by_name(restored);
861
862    for (name, orig_anns) in &original_by_name {
863        if let Some(rest_anns) = restored_by_name.get(name) {
864            let matches = hungarian_match(orig_anns, rest_anns);
865
866            for (orig_idx, rest_idx) in &matches {
867                let orig_ann = orig_anns[*orig_idx];
868                let rest_ann = rest_anns[*rest_idx];
869
870                if let (Some(orig_seg), Some(rest_seg)) =
871                    (&orig_ann.segmentation, &rest_ann.segmentation)
872                {
873                    let comparison = compare_segmentation_pair(orig_seg, rest_seg);
874                    result.aggregate_comparison(&comparison);
875                }
876            }
877        }
878    }
879
880    result
881}
882
883/// Validate categories between two datasets.
884pub fn validate_categories(
885    original: &CocoDataset,
886    restored: &CocoDataset,
887) -> CategoryValidationResult {
888    let coco_cats: HashSet<String> = original.categories.iter().map(|c| c.name.clone()).collect();
889    let studio_cats: HashSet<String> = restored.categories.iter().map(|c| c.name.clone()).collect();
890
891    let missing: Vec<String> = coco_cats.difference(&studio_cats).cloned().collect();
892    let extra: Vec<String> = studio_cats.difference(&coco_cats).cloned().collect();
893
894    CategoryValidationResult {
895        coco_categories: coco_cats,
896        studio_categories: studio_cats,
897        missing_categories: missing,
898        extra_categories: extra,
899    }
900}
901
// Unit tests for the verification helpers in this module. Each section below
// exercises one helper or result type using small hand-built fixtures.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::coco::{CocoCategory, CocoImage, CocoRle};

    // =========================================================================
    // bbox_iou tests
    // =========================================================================

    /// Identical boxes give an IoU of 1.
    #[test]
    fn test_bbox_iou_perfect_overlap() {
        let a = [0.0, 0.0, 100.0, 100.0];
        let b = [0.0, 0.0, 100.0, 100.0];
        assert!((bbox_iou(&a, &b) - 1.0).abs() < 1e-6);
    }

    /// Disjoint boxes give an IoU of 0.
    #[test]
    fn test_bbox_iou_no_overlap() {
        let a = [0.0, 0.0, 100.0, 100.0];
        let b = [200.0, 200.0, 100.0, 100.0];
        assert!(bbox_iou(&a, &b) < 1e-6);
    }

    /// Two 100x100 boxes offset by (50, 50) overlap in a 50x50 region.
    #[test]
    fn test_bbox_iou_partial_overlap() {
        let a = [0.0, 0.0, 100.0, 100.0];
        let b = [50.0, 50.0, 100.0, 100.0];
        // Intersection: 50x50 = 2500, Union: 10000 + 10000 - 2500 = 17500
        let expected = 2500.0 / 17500.0;
        assert!((bbox_iou(&a, &b) - expected).abs() < 1e-6);
    }

    /// A box fully inside another: IoU is the ratio of the two areas.
    #[test]
    fn test_bbox_iou_contained() {
        // b is fully contained in a
        let a = [0.0, 0.0, 100.0, 100.0];
        let b = [25.0, 25.0, 50.0, 50.0];
        // Intersection: 50x50 = 2500, Union: 10000
        let expected = 2500.0 / 10000.0;
        assert!((bbox_iou(&a, &b) - expected).abs() < 1e-6);
    }

    /// A degenerate (zero-area) box gives an IoU of 0.
    #[test]
    fn test_bbox_iou_zero_area() {
        let a = [0.0, 0.0, 0.0, 0.0];
        let b = [0.0, 0.0, 100.0, 100.0];
        assert!(bbox_iou(&a, &b) < 1e-6);
    }

    // =========================================================================
    // polygon_area tests
    // =========================================================================

    /// Area of an axis-aligned square given as flat x/y pairs.
    #[test]
    fn test_polygon_area_square() {
        // 10x10 square
        let coords = [0.0, 0.0, 10.0, 0.0, 10.0, 10.0, 0.0, 10.0];
        assert!((polygon_area(&coords) - 100.0).abs() < 1e-6);
    }

    /// Area of a triangle: base * height / 2.
    #[test]
    fn test_polygon_area_triangle() {
        // Triangle with base 10 and height 10
        let coords = [0.0, 0.0, 10.0, 0.0, 5.0, 10.0];
        assert!((polygon_area(&coords) - 50.0).abs() < 1e-6);
    }

    /// Fewer than three points cannot enclose any area.
    #[test]
    fn test_polygon_area_too_small() {
        // Less than 3 points (6 coords)
        let coords = [0.0, 0.0, 10.0, 10.0];
        assert!(polygon_area(&coords) < 1e-6);
    }

    /// Area of a concave (L-shaped) polygon.
    #[test]
    fn test_polygon_area_complex() {
        // L-shaped polygon (can compute as two rectangles)
        // Points: (0,0), (20,0), (20,10), (10,10), (10,20), (0,20)
        let coords = [
            0.0, 0.0, 20.0, 0.0, 20.0, 10.0, 10.0, 10.0, 10.0, 20.0, 0.0, 20.0,
        ];
        // Area = 10*20 + 10*10 = 300
        assert!((polygon_area(&coords) - 300.0).abs() < 1e-6);
    }

    // =========================================================================
    // polygon_bounds tests
    // =========================================================================

    /// Bounds of a square anchored at the origin.
    #[test]
    fn test_polygon_bounds_square() {
        let coords = [0.0, 0.0, 100.0, 0.0, 100.0, 100.0, 0.0, 100.0];
        let bounds = polygon_bounds(&coords);
        assert_eq!(bounds, Some((0.0, 0.0, 100.0, 100.0)));
    }

    /// Bounds are reported as (min_x, min_y, max_x, max_y), not as a size.
    #[test]
    fn test_polygon_bounds_offset() {
        let coords = [50.0, 60.0, 150.0, 60.0, 150.0, 160.0, 50.0, 160.0];
        let bounds = polygon_bounds(&coords);
        assert_eq!(bounds, Some((50.0, 60.0, 150.0, 160.0)));
    }

    /// A single point has no bounds.
    #[test]
    fn test_polygon_bounds_too_small() {
        let coords = [0.0, 0.0];
        assert!(polygon_bounds(&coords).is_none());
    }

    // =========================================================================
    // hungarian_match tests
    // =========================================================================

    /// Matching two empty lists yields no pairs.
    #[test]
    fn test_hungarian_match_empty_inputs() {
        let orig: Vec<&CocoAnnotation> = vec![];
        let rest: Vec<&CocoAnnotation> = vec![];
        let matches = hungarian_match(&orig, &rest);
        assert!(matches.is_empty());
    }

    /// A single original and a single restored annotation with the same bbox
    /// are paired together.
    #[test]
    fn test_hungarian_match_perfect_match() {
        let ann1 = CocoAnnotation {
            id: 1,
            image_id: 1,
            category_id: 1,
            bbox: [0.0, 0.0, 100.0, 100.0],
            ..Default::default()
        };
        let ann2 = CocoAnnotation {
            id: 2,
            image_id: 1,
            category_id: 1,
            bbox: [0.0, 0.0, 100.0, 100.0], // Same bbox
            ..Default::default()
        };

        let orig = vec![&ann1];
        let rest = vec![&ann2];
        let matches = hungarian_match(&orig, &rest);

        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0], (0, 0));
    }

    /// Two originals and two restored annotations listed in the opposite
    /// order: the matcher must pair by bbox, not by position.
    #[test]
    fn test_hungarian_match_multiple() {
        let ann1 = CocoAnnotation {
            id: 1,
            image_id: 1,
            category_id: 1,
            bbox: [0.0, 0.0, 50.0, 50.0],
            ..Default::default()
        };
        let ann2 = CocoAnnotation {
            id: 2,
            image_id: 1,
            category_id: 1,
            bbox: [100.0, 100.0, 50.0, 50.0],
            ..Default::default()
        };

        let ann3 = CocoAnnotation {
            id: 3,
            image_id: 1,
            category_id: 1,
            bbox: [100.0, 100.0, 50.0, 50.0], // Matches ann2
            ..Default::default()
        };
        let ann4 = CocoAnnotation {
            id: 4,
            image_id: 1,
            category_id: 1,
            bbox: [0.0, 0.0, 50.0, 50.0], // Matches ann1
            ..Default::default()
        };

        let orig = vec![&ann1, &ann2];
        let rest = vec![&ann3, &ann4];
        let matches = hungarian_match(&orig, &rest);

        assert_eq!(matches.len(), 2);
        // Verify the actual pairing rather than just the count; the order of
        // the returned pairs is not assumed.
        assert!(matches.contains(&(0, 1))); // ann1 <-> ann4
        assert!(matches.contains(&(1, 0))); // ann2 <-> ann3
    }

    /// With fewer restored annotations than originals, only the overlapping
    /// original is matched.
    #[test]
    fn test_hungarian_match_unequal_sizes() {
        let ann1 = CocoAnnotation {
            id: 1,
            image_id: 1,
            category_id: 1,
            bbox: [0.0, 0.0, 100.0, 100.0],
            ..Default::default()
        };
        let ann2 = CocoAnnotation {
            id: 2,
            image_id: 1,
            category_id: 1,
            bbox: [200.0, 200.0, 100.0, 100.0],
            ..Default::default()
        };
        let ann3 = CocoAnnotation {
            id: 3,
            image_id: 1,
            category_id: 1,
            bbox: [0.0, 0.0, 100.0, 100.0], // Matches ann1
            ..Default::default()
        };

        let orig = vec![&ann1, &ann2];
        let rest = vec![&ann3]; // Fewer restored
        let matches = hungarian_match(&orig, &rest);

        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0], (0, 0)); // ann1 matched to ann3
    }

    // =========================================================================
    // BboxValidationResult tests
    // =========================================================================

    /// Match rate is matched / (matched + unmatched); average IoU is
    /// sum_iou / matched.
    #[test]
    fn test_bbox_validation_result_rates() {
        let mut result = BboxValidationResult {
            total_matched: 100,
            total_unmatched: 10,
            sum_iou: 95.0,
            ..Default::default()
        };
        result.errors_by_range[0] = 350; // 350/400 = 87.5%
        result.errors_by_range[1] = 40;

        assert!((result.match_rate() - 0.909).abs() < 0.01);
        assert!((result.avg_iou() - 0.95).abs() < 0.01);
    }

    /// A default (empty) result reports every rate as 1.0.
    #[test]
    fn test_bbox_validation_result_empty() {
        let result = BboxValidationResult::default();
        assert!((result.match_rate() - 1.0).abs() < 1e-6);
        assert!((result.avg_iou() - 1.0).abs() < 1e-6);
        assert!((result.within_1px_rate() - 1.0).abs() < 1e-6);
    }

    /// Everything matched, high IoU and all errors in the first bucket is
    /// considered valid.
    #[test]
    fn test_bbox_validation_result_is_valid() {
        let mut result = BboxValidationResult {
            total_matched: 100,
            sum_iou: 98.0,
            ..Default::default()
        };
        result.errors_by_range[0] = 400; // All within 1px
        assert!(result.is_valid());
    }

    /// Many unmatched annotations and a low average IoU is not valid.
    #[test]
    fn test_bbox_validation_result_not_valid() {
        let mut result = BboxValidationResult {
            total_matched: 100,
            total_unmatched: 50, // Low match rate
            sum_iou: 50.0,
            ..Default::default()
        };
        result.errors_by_range[0] = 200;
        assert!(!result.is_valid());
    }

    // =========================================================================
    // MaskValidationResult tests
    // =========================================================================

    /// `new()` seeds the min/max area ratios so that the first aggregated
    /// value replaces them.
    #[test]
    fn test_mask_validation_result_new() {
        let result = MaskValidationResult::new();
        assert_eq!(result.min_area_ratio, f64::MAX);
        assert_eq!(result.max_area_ratio, 0.0);
    }

    /// Preservation rate is restored_with_seg / original_with_seg.
    #[test]
    fn test_mask_validation_result_preservation_rate() {
        let mut result = MaskValidationResult::new();
        result.original_with_seg = 100;
        result.restored_with_seg = 95;
        assert!((result.preservation_rate() - 0.95).abs() < 1e-6);
    }

    /// A fresh result reports every rate/average as 1.0.
    #[test]
    fn test_mask_validation_result_empty() {
        let result = MaskValidationResult::new();
        assert!((result.preservation_rate() - 1.0).abs() < 1e-6);
        assert!((result.avg_area_ratio() - 1.0).abs() < 1e-6);
        assert!((result.avg_bbox_iou() - 1.0).abs() < 1e-6);
    }

    // =========================================================================
    // CategoryValidationResult tests
    // =========================================================================

    /// No missing and no extra categories is valid.
    #[test]
    fn test_category_validation_result_is_valid() {
        let result = CategoryValidationResult {
            coco_categories: ["person", "car"].iter().map(|s| s.to_string()).collect(),
            studio_categories: ["person", "car"].iter().map(|s| s.to_string()).collect(),
            missing_categories: vec![],
            extra_categories: vec![],
        };
        assert!(result.is_valid());
    }

    /// A missing category makes the result invalid.
    #[test]
    fn test_category_validation_result_missing() {
        let result = CategoryValidationResult {
            coco_categories: ["person", "car"].iter().map(|s| s.to_string()).collect(),
            studio_categories: ["person"].iter().map(|s| s.to_string()).collect(),
            missing_categories: vec!["car".to_string()],
            extra_categories: vec![],
        };
        assert!(!result.is_valid());
    }

    // =========================================================================
    // compute_segmentation_area tests
    // =========================================================================

    /// Area of a single-part polygon segmentation.
    #[test]
    fn test_compute_segmentation_area_polygon() {
        let seg =
            CocoSegmentation::Polygon(vec![vec![0.0, 0.0, 100.0, 0.0, 100.0, 100.0, 0.0, 100.0]]);
        let area = compute_segmentation_area(&seg);
        assert!((area - 10000.0).abs() < 1e-6);
    }

    /// Areas of multiple polygon parts are summed.
    #[test]
    fn test_compute_segmentation_area_multiple_polygons() {
        // Two 10x10 squares
        let seg = CocoSegmentation::Polygon(vec![
            vec![0.0, 0.0, 10.0, 0.0, 10.0, 10.0, 0.0, 10.0],
            vec![20.0, 20.0, 30.0, 20.0, 30.0, 30.0, 20.0, 30.0],
        ]);
        let area = compute_segmentation_area(&seg);
        assert!((area - 200.0).abs() < 1e-6);
    }

    // =========================================================================
    // count_polygon_vertices/parts tests
    // =========================================================================

    /// Eight coordinates make four vertices.
    #[test]
    fn test_count_polygon_vertices() {
        let seg = CocoSegmentation::Polygon(vec![vec![0.0, 0.0, 10.0, 0.0, 10.0, 10.0, 0.0, 10.0]]);
        assert_eq!(count_polygon_vertices(&seg), 4);
    }

    /// Vertex counts are summed across all polygon parts.
    #[test]
    fn test_count_polygon_vertices_multiple() {
        let seg = CocoSegmentation::Polygon(vec![
            vec![0.0, 0.0, 10.0, 0.0, 10.0, 10.0], // 3 vertices
            vec![20.0, 20.0, 30.0, 20.0, 30.0, 30.0, 20.0, 30.0], // 4 vertices
        ]);
        assert_eq!(count_polygon_vertices(&seg), 7);
    }

    /// Part count is the number of polygons in the segmentation.
    #[test]
    fn test_count_polygon_parts() {
        let seg = CocoSegmentation::Polygon(vec![
            vec![0.0, 0.0, 10.0, 0.0, 10.0, 10.0],
            vec![20.0, 20.0, 30.0, 20.0, 30.0, 30.0],
        ]);
        assert_eq!(count_polygon_parts(&seg), 2);
    }

    /// RLE segmentations have no polygon vertices.
    #[test]
    fn test_count_polygon_vertices_rle() {
        let rle = CocoRle {
            counts: vec![100],
            size: [10, 10],
        };
        let seg = CocoSegmentation::Rle(rle);
        assert_eq!(count_polygon_vertices(&seg), 0);
    }

    // =========================================================================
    // VerificationResult tests
    // =========================================================================

    /// A result with no missing/extra images and healthy bbox/mask statistics
    /// is valid.
    #[test]
    fn test_verification_result_is_valid() {
        let result = VerificationResult {
            coco_image_count: 100,
            studio_image_count: 100,
            missing_images: vec![],
            extra_images: vec![],
            coco_annotation_count: 500,
            studio_annotation_count: 500,
            bbox_validation: {
                let mut bv = BboxValidationResult {
                    total_matched: 500,
                    sum_iou: 495.0,
                    ..Default::default()
                };
                bv.errors_by_range[0] = 2000; // All within 1px
                bv
            },
            mask_validation: {
                let mut mv = MaskValidationResult::new();
                mv.original_with_seg = 500;
                mv.restored_with_seg = 500;
                mv.matched_pairs_with_seg = 500;
                mv.sum_bbox_iou = 475.0;
                mv
            },
            category_validation: CategoryValidationResult {
                coco_categories: ["person"].iter().map(|s| s.to_string()).collect(),
                studio_categories: ["person"].iter().map(|s| s.to_string()).collect(),
                missing_categories: vec![],
                extra_categories: vec![],
            },
        };

        assert!(result.is_valid());
    }

    /// The summary text contains the image and annotation sections.
    #[test]
    fn test_verification_result_summary() {
        let result = VerificationResult {
            coco_image_count: 100,
            studio_image_count: 98,
            missing_images: vec!["img1.jpg".to_string(), "img2.jpg".to_string()],
            extra_images: vec![],
            coco_annotation_count: 500,
            studio_annotation_count: 490,
            bbox_validation: BboxValidationResult::default(),
            mask_validation: MaskValidationResult::new(),
            category_validation: CategoryValidationResult::default(),
        };

        let summary = result.summary();
        assert!(summary.contains("Images:"));
        assert!(summary.contains("Annotations:"));
    }

    // =========================================================================
    // build_annotation_map_by_name tests
    // =========================================================================

    /// Annotations are grouped under their image's name; the keys asserted
    /// here are the file names without the `.jpg` extension.
    #[test]
    fn test_build_annotation_map_by_name() {
        let dataset = CocoDataset {
            images: vec![
                CocoImage {
                    id: 1,
                    file_name: "image1.jpg".to_string(),
                    ..Default::default()
                },
                CocoImage {
                    id: 2,
                    file_name: "image2.jpg".to_string(),
                    ..Default::default()
                },
            ],
            annotations: vec![
                CocoAnnotation {
                    id: 1,
                    image_id: 1,
                    ..Default::default()
                },
                CocoAnnotation {
                    id: 2,
                    image_id: 1,
                    ..Default::default()
                },
                CocoAnnotation {
                    id: 3,
                    image_id: 2,
                    ..Default::default()
                },
            ],
            ..Default::default()
        };

        let map = build_annotation_map_by_name(&dataset);

        assert_eq!(map.len(), 2);
        assert_eq!(map.get("image1").unwrap().len(), 2);
        assert_eq!(map.get("image2").unwrap().len(), 1);
    }

    // =========================================================================
    // validate_categories tests
    // =========================================================================

    /// Datasets with the same category names validate cleanly.
    #[test]
    fn test_validate_categories_match() {
        let original = CocoDataset {
            categories: vec![
                CocoCategory {
                    id: 1,
                    name: "cat".to_string(),
                    supercategory: None,
                    ..Default::default()
                },
                CocoCategory {
                    id: 2,
                    name: "dog".to_string(),
                    supercategory: None,
                    ..Default::default()
                },
            ],
            ..Default::default()
        };

        let restored = CocoDataset {
            categories: vec![
                CocoCategory {
                    id: 1,
                    name: "cat".to_string(),
                    supercategory: None,
                    ..Default::default()
                },
                CocoCategory {
                    id: 2,
                    name: "dog".to_string(),
                    supercategory: None,
                    ..Default::default()
                },
            ],
            ..Default::default()
        };

        let result = validate_categories(&original, &restored);
        assert!(result.is_valid());
        assert!(result.missing_categories.is_empty());
        assert!(result.extra_categories.is_empty());
    }

    /// "dog" exists only in the original and "bird" only in the restored
    /// dataset, so they are reported as missing and extra respectively.
    #[test]
    fn test_validate_categories_missing_and_extra() {
        let original = CocoDataset {
            categories: vec![
                CocoCategory {
                    id: 1,
                    name: "cat".to_string(),
                    supercategory: None,
                    ..Default::default()
                },
                CocoCategory {
                    id: 2,
                    name: "dog".to_string(),
                    supercategory: None,
                    ..Default::default()
                },
            ],
            ..Default::default()
        };

        let restored = CocoDataset {
            categories: vec![
                CocoCategory {
                    id: 1,
                    name: "cat".to_string(),
                    supercategory: None,
                    ..Default::default()
                },
                CocoCategory {
                    id: 3,
                    name: "bird".to_string(),
                    supercategory: None,
                    ..Default::default()
                },
            ],
            ..Default::default()
        };

        let result = validate_categories(&original, &restored);
        assert!(!result.is_valid());
        assert!(result.missing_categories.contains(&"dog".to_string()));
        assert!(result.extra_categories.contains(&"bird".to_string()));
    }

    // =========================================================================
    // compare_segmentation_pair tests
    // =========================================================================

    /// Comparing a polygon with itself matches on every metric.
    #[test]
    fn test_compare_segmentation_pair_identical_polygons() {
        let seg =
            CocoSegmentation::Polygon(vec![vec![0.0, 0.0, 100.0, 0.0, 100.0, 100.0, 0.0, 100.0]]);
        let result = compare_segmentation_pair(&seg, &seg);

        assert!(!result.is_rle);
        assert!(result.vertex_exact_match);
        assert!(result.vertex_close_match);
        assert!(result.part_match);
        assert!(result.area_ratio.is_some());
        assert!((result.area_ratio.unwrap() - 1.0).abs() < 1e-6);
        assert!((result.bbox_iou - 1.0).abs() < 1e-6);
    }

    /// A square vs. a triangle: vertex counts differ by one, which is not an
    /// exact match but still counts as a close match.
    #[test]
    fn test_compare_segmentation_pair_different_vertex_count() {
        let seg1 =
            CocoSegmentation::Polygon(vec![vec![0.0, 0.0, 100.0, 0.0, 100.0, 100.0, 0.0, 100.0]]);
        // Triangle with same bounding box
        let seg2 = CocoSegmentation::Polygon(vec![vec![0.0, 0.0, 100.0, 0.0, 50.0, 100.0]]);

        let result = compare_segmentation_pair(&seg1, &seg2);

        assert!(!result.is_rle);
        assert!(!result.vertex_exact_match); // 4 vs 3 vertices
        // Note: vertex_close_match is true because threshold is max(10%, 1) = 1, and
        // diff = 1
        assert!(result.vertex_close_match);
        assert!(result.part_match); // Both have 1 part
    }

    /// When one side is RLE the pair is flagged as RLE and none of the
    /// polygon-specific checks match.
    #[test]
    fn test_compare_segmentation_pair_rle() {
        let rle = CocoRle {
            counts: vec![100],
            size: [10, 10],
        };
        let seg = CocoSegmentation::Rle(rle);
        let poly =
            CocoSegmentation::Polygon(vec![vec![0.0, 0.0, 10.0, 0.0, 10.0, 10.0, 0.0, 10.0]]);

        let result = compare_segmentation_pair(&seg, &poly);

        assert!(result.is_rle);
        assert!(!result.vertex_exact_match);
        assert!(!result.vertex_close_match);
        assert!(!result.part_match);
    }

    /// Halving both sides of a square quarters its area.
    #[test]
    fn test_compare_segmentation_pair_scaled() {
        let seg1 =
            CocoSegmentation::Polygon(vec![vec![0.0, 0.0, 100.0, 0.0, 100.0, 100.0, 0.0, 100.0]]);
        let seg2 =
            CocoSegmentation::Polygon(vec![vec![0.0, 0.0, 50.0, 0.0, 50.0, 50.0, 0.0, 50.0]]);

        let result = compare_segmentation_pair(&seg1, &seg2);

        assert!(result.area_ratio.is_some());
        // 2500 / 10000 = 0.25
        assert!((result.area_ratio.unwrap() - 0.25).abs() < 0.01);
    }

    // =========================================================================
    // aggregate_comparison tests
    // =========================================================================

    /// A fully matching polygon pair increments every polygon counter.
    #[test]
    fn test_aggregate_comparison_polygon() {
        let mut result = MaskValidationResult::new();
        let cmp = SegmentationPairComparison {
            is_rle: false,
            vertex_exact_match: true,
            vertex_close_match: true,
            part_match: true,
            area_ratio: Some(1.0),
            bbox_iou: 0.95,
        };

        result.aggregate_comparison(&cmp);

        assert_eq!(result.matched_pairs_with_seg, 1);
        assert_eq!(result.polygon_pairs, 1);
        assert_eq!(result.rle_pairs, 0);
        assert_eq!(result.vertex_count_exact_match, 1);
        assert_eq!(result.vertex_count_close_match, 1);
        assert_eq!(result.part_count_match, 1);
        assert_eq!(result.area_within_1pct, 1);
        assert_eq!(result.area_within_5pct, 1);
        assert_eq!(result.bbox_iou_high, 1);
    }

    /// An RLE pair is counted under `rle_pairs` and skips vertex counters.
    #[test]
    fn test_aggregate_comparison_rle() {
        let mut result = MaskValidationResult::new();
        let cmp = SegmentationPairComparison {
            is_rle: true,
            vertex_exact_match: false,
            vertex_close_match: false,
            part_match: false,
            area_ratio: Some(0.98),
            bbox_iou: 0.92,
        };

        result.aggregate_comparison(&cmp);

        assert_eq!(result.matched_pairs_with_seg, 1);
        assert_eq!(result.polygon_pairs, 0);
        assert_eq!(result.rle_pairs, 1);
        assert_eq!(result.vertex_count_exact_match, 0);
        assert_eq!(result.area_within_5pct, 1);
        assert_eq!(result.bbox_iou_high, 1);
    }

    /// A missing area ratio is tallied as zero-area, and a low bbox IoU lands
    /// in the low bucket.
    #[test]
    fn test_aggregate_comparison_zero_area() {
        let mut result = MaskValidationResult::new();
        let cmp = SegmentationPairComparison {
            is_rle: false,
            vertex_exact_match: true,
            vertex_close_match: true,
            part_match: true,
            area_ratio: None, // Zero area
            bbox_iou: 0.3,
        };

        result.aggregate_comparison(&cmp);

        assert_eq!(result.zero_area_count, 1);
        assert_eq!(result.area_within_1pct, 0);
        assert_eq!(result.bbox_iou_low, 1);
    }

    /// Aggregating several comparisons accumulates counts and tracks the
    /// sum/min/max of the area ratios.
    #[test]
    fn test_aggregate_comparison_multiple() {
        let mut result = MaskValidationResult::new();

        let cmp1 = SegmentationPairComparison {
            is_rle: false,
            vertex_exact_match: true,
            vertex_close_match: true,
            part_match: true,
            area_ratio: Some(1.0),
            bbox_iou: 0.95,
        };
        let cmp2 = SegmentationPairComparison {
            is_rle: true,
            vertex_exact_match: false,
            vertex_close_match: false,
            part_match: false,
            area_ratio: Some(0.9),
            bbox_iou: 0.85,
        };

        result.aggregate_comparison(&cmp1);
        result.aggregate_comparison(&cmp2);

        assert_eq!(result.matched_pairs_with_seg, 2);
        assert_eq!(result.polygon_pairs, 1);
        assert_eq!(result.rle_pairs, 1);
        assert!((result.sum_area_ratio - 1.9).abs() < 0.01);
        assert!((result.min_area_ratio - 0.9).abs() < 0.01);
        assert!((result.max_area_ratio - 1.0).abs() < 0.01);
    }
}