edgefirst_client/coco/
verify.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright © 2025 Au-Zone Technologies. All Rights Reserved.
3
4//! COCO dataset verification utilities.
5//!
6//! Provides functions for comparing COCO datasets and validating annotation
7//! accuracy using Hungarian matching for optimal annotation pairing.
8
9use super::{
10    decode_compressed_rle, decode_rle,
11    types::{CocoAnnotation, CocoDataset, CocoSegmentation},
12};
13use pathfinding::{kuhn_munkres::kuhn_munkres_min, matrix::Matrix};
14use std::{
15    collections::{HashMap, HashSet},
16    fmt,
17};
18
19/// Result of verifying a COCO import against Studio data.
20#[derive(Debug, Clone)]
21pub struct VerificationResult {
22    /// Total images in COCO dataset.
23    pub coco_image_count: usize,
24    /// Images found in Studio.
25    pub studio_image_count: usize,
26    /// Images missing from Studio.
27    pub missing_images: Vec<String>,
28    /// Extra images in Studio not in COCO.
29    pub extra_images: Vec<String>,
30    /// Total annotations in COCO dataset.
31    pub coco_annotation_count: usize,
32    /// Total annotations in Studio.
33    pub studio_annotation_count: usize,
34    /// Bounding box validation results.
35    pub bbox_validation: BboxValidationResult,
36    /// Segmentation mask validation results.
37    pub mask_validation: MaskValidationResult,
38    /// Category validation results.
39    pub category_validation: CategoryValidationResult,
40}
41
42impl VerificationResult {
43    /// Returns true if the verification passed all checks.
44    pub fn is_valid(&self) -> bool {
45        self.missing_images.is_empty()
46            && self.extra_images.is_empty()
47            && self.bbox_validation.is_valid()
48            && self.mask_validation.is_valid()
49    }
50
51    /// Returns a summary of the verification.
52    pub fn summary(&self) -> String {
53        let mut s = String::new();
54        s.push_str(&format!(
55            "Images: {}/{} (missing: {}, extra: {})\n",
56            self.studio_image_count,
57            self.coco_image_count,
58            self.missing_images.len(),
59            self.extra_images.len()
60        ));
61        s.push_str(&format!(
62            "Annotations: {}/{}\n",
63            self.studio_annotation_count, self.coco_annotation_count
64        ));
65        s.push_str(&format!(
66            "Bbox: {:.1}% matched, {:.4} avg IoU\n",
67            self.bbox_validation.match_rate() * 100.0,
68            self.bbox_validation.avg_iou()
69        ));
70        s.push_str(&format!(
71            "Masks: {:.1}% preserved, {:.4} avg bbox IoU\n",
72            self.mask_validation.preservation_rate() * 100.0,
73            self.mask_validation.avg_bbox_iou()
74        ));
75        s
76    }
77}
78
79impl fmt::Display for VerificationResult {
80    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81        writeln!(
82            f,
83            "╔══════════════════════════════════════════════════════════════╗"
84        )?;
85        writeln!(
86            f,
87            "║                  COCO IMPORT VERIFICATION                    ║"
88        )?;
89        writeln!(
90            f,
91            "╠══════════════════════════════════════════════════════════════╣"
92        )?;
93        writeln!(
94            f,
95            "║ Images:      {} in COCO, {} in Studio",
96            self.coco_image_count, self.studio_image_count
97        )?;
98        if !self.missing_images.is_empty() {
99            writeln!(f, "║   Missing:   {} images", self.missing_images.len())?;
100            for name in self.missing_images.iter().take(5) {
101                writeln!(f, "║              - {}", name)?;
102            }
103            if self.missing_images.len() > 5 {
104                writeln!(
105                    f,
106                    "║              ... and {} more",
107                    self.missing_images.len() - 5
108                )?;
109            }
110        }
111        if !self.extra_images.is_empty() {
112            writeln!(f, "║   Extra:     {} images", self.extra_images.len())?;
113            for name in self.extra_images.iter().take(5) {
114                writeln!(f, "║              - {}", name)?;
115            }
116            if self.extra_images.len() > 5 {
117                writeln!(
118                    f,
119                    "║              ... and {} more",
120                    self.extra_images.len() - 5
121                )?;
122            }
123        }
124        writeln!(
125            f,
126            "║ Annotations: {} in COCO, {} in Studio",
127            self.coco_annotation_count, self.studio_annotation_count
128        )?;
129        writeln!(
130            f,
131            "╠══════════════════════════════════════════════════════════════╣"
132        )?;
133        write!(f, "{}", self.bbox_validation)?;
134        writeln!(
135            f,
136            "╠══════════════════════════════════════════════════════════════╣"
137        )?;
138        write!(f, "{}", self.mask_validation)?;
139        writeln!(
140            f,
141            "╠══════════════════════════════════════════════════════════════╣"
142        )?;
143        write!(f, "{}", self.category_validation)?;
144        writeln!(
145            f,
146            "╠══════════════════════════════════════════════════════════════╣"
147        )?;
148        let status = if self.is_valid() {
149            "✓ PASSED"
150        } else {
151            "✗ FAILED"
152        };
153        writeln!(f, "║ Status: {}", status)?;
154        writeln!(
155            f,
156            "╚══════════════════════════════════════════════════════════════╝"
157        )?;
158        Ok(())
159    }
160}
161
/// Bounding box validation results.
#[derive(Debug, Clone, Default)]
pub struct BboxValidationResult {
    /// Number of annotation pairs accepted by the Hungarian matching.
    pub total_matched: usize,
    /// Number of original annotations left without an accepted match.
    pub total_unmatched: usize,
    /// Coordinate error histogram: [<1px, <2px, <5px, <10px, >=10px]
    pub errors_by_range: [usize; 5],
    /// Largest coordinate error seen, in pixels.
    pub max_error: f64,
    /// Running sum of IoU values (divide by `total_matched` to average).
    pub sum_iou: f64,
}

impl BboxValidationResult {
    /// Fraction of compared coordinates (four per matched annotation) that
    /// fall in the given histogram buckets; `1.0` when nothing was matched.
    fn coord_fraction(&self, buckets: &[usize]) -> f64 {
        match self.total_matched * 4 {
            0 => 1.0,
            total_coords => buckets.iter().sum::<usize>() as f64 / total_coords as f64,
        }
    }

    /// Returns the fraction (0.0–1.0) of coordinates with error below 1 pixel.
    pub fn within_1px_rate(&self) -> f64 {
        self.coord_fraction(&self.errors_by_range[..1])
    }

    /// Returns the fraction (0.0–1.0) of coordinates with error below 2 pixels.
    pub fn within_2px_rate(&self) -> f64 {
        self.coord_fraction(&self.errors_by_range[..2])
    }

    /// Returns the mean IoU over matched annotations, or `1.0` if there are
    /// none.
    pub fn avg_iou(&self) -> f64 {
        match self.total_matched {
            0 => 1.0,
            matched => self.sum_iou / matched as f64,
        }
    }

    /// Returns matched / (matched + unmatched), or `1.0` if both are zero.
    pub fn match_rate(&self) -> f64 {
        match self.total_matched + self.total_unmatched {
            0 => 1.0,
            total => self.total_matched as f64 / total as f64,
        }
    }

    /// Returns true if bbox validation passes the quality thresholds:
    /// more than 99% of coordinates within 1px, match rate above 95% and
    /// average IoU above 0.95.
    pub fn is_valid(&self) -> bool {
        let coords_ok = self.within_1px_rate() > 0.99;
        let matches_ok = self.match_rate() > 0.95;
        let iou_ok = self.avg_iou() > 0.95;
        coords_ok && matches_ok && iou_ok
    }
}

impl fmt::Display for BboxValidationResult {
    /// Writes the bounding box section of the report, one `║`-prefixed line
    /// per metric.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let considered = self.total_matched + self.total_unmatched;
        let match_pct = self.match_rate() * 100.0;
        let within_1px_pct = self.within_1px_rate() * 100.0;
        let within_2px_pct = self.within_2px_rate() * 100.0;

        writeln!(f, "║ Bounding Box Validation:")?;
        writeln!(
            f,
            "║   Matched:    {}/{} ({:.1}%)",
            self.total_matched, considered, match_pct
        )?;
        writeln!(f, "║   Avg IoU:    {:.4}", self.avg_iou())?;
        writeln!(f, "║   Within 1px: {:.1}%", within_1px_pct)?;
        writeln!(f, "║   Within 2px: {:.1}%", within_2px_pct)?;
        writeln!(f, "║   Max error:  {:.2}px", self.max_error)
    }
}
240
241/// Segmentation mask validation results.
242#[derive(Debug, Clone, Default)]
243pub struct MaskValidationResult {
244    /// Annotations with segmentation in original.
245    pub original_with_seg: usize,
246    /// Annotations with segmentation in restored.
247    pub restored_with_seg: usize,
248    /// Matched pairs where both have segmentation.
249    pub matched_pairs_with_seg: usize,
250    /// Polygon pairs (for vertex comparison).
251    pub polygon_pairs: usize,
252    /// RLE pairs converted to polygon.
253    pub rle_pairs: usize,
254    /// Pairs where vertex count matches exactly.
255    pub vertex_count_exact_match: usize,
256    /// Pairs where vertex count is within 10%.
257    pub vertex_count_close_match: usize,
258    /// Pairs where part count matches.
259    pub part_count_match: usize,
260    /// Pairs with area within 1%.
261    pub area_within_1pct: usize,
262    /// Pairs with area within 5%.
263    pub area_within_5pct: usize,
264    /// Pairs with bbox IoU >= 0.9.
265    pub bbox_iou_high: usize,
266    /// Pairs with bbox IoU < 0.5.
267    pub bbox_iou_low: usize,
268    /// Sum of area ratios.
269    pub sum_area_ratio: f64,
270    /// Minimum area ratio.
271    pub min_area_ratio: f64,
272    /// Maximum area ratio.
273    pub max_area_ratio: f64,
274    /// Sum of bbox IoU values.
275    pub sum_bbox_iou: f64,
276    /// Count of zero-area segmentations.
277    pub zero_area_count: usize,
278}
279
280impl MaskValidationResult {
281    /// Create a new result with initialized min/max values.
282    pub fn new() -> Self {
283        Self {
284            min_area_ratio: f64::MAX,
285            max_area_ratio: 0.0,
286            ..Default::default()
287        }
288    }
289
290    /// Returns the segmentation preservation rate.
291    pub fn preservation_rate(&self) -> f64 {
292        if self.original_with_seg == 0 {
293            1.0
294        } else {
295            self.restored_with_seg as f64 / self.original_with_seg as f64
296        }
297    }
298
299    /// Returns the average area ratio.
300    pub fn avg_area_ratio(&self) -> f64 {
301        let valid_count = self
302            .matched_pairs_with_seg
303            .saturating_sub(self.zero_area_count);
304        if valid_count == 0 {
305            1.0
306        } else {
307            self.sum_area_ratio / valid_count as f64
308        }
309    }
310
311    /// Returns the average bbox IoU.
312    pub fn avg_bbox_iou(&self) -> f64 {
313        if self.matched_pairs_with_seg == 0 {
314            1.0
315        } else {
316            self.sum_bbox_iou / self.matched_pairs_with_seg as f64
317        }
318    }
319
320    /// Returns true if mask validation passes quality thresholds.
321    pub fn is_valid(&self) -> bool {
322        self.preservation_rate() > 0.95 && self.avg_bbox_iou() > 0.90
323    }
324
325    /// Aggregate a single segmentation comparison into the result.
326    pub fn aggregate_comparison(&mut self, cmp: &SegmentationPairComparison) {
327        self.matched_pairs_with_seg += 1;
328
329        if cmp.is_rle {
330            self.rle_pairs += 1;
331        } else {
332            self.polygon_pairs += 1;
333            if cmp.vertex_exact_match {
334                self.vertex_count_exact_match += 1;
335            }
336            if cmp.vertex_close_match {
337                self.vertex_count_close_match += 1;
338            }
339            if cmp.part_match {
340                self.part_count_match += 1;
341            }
342        }
343
344        if let Some(area_ratio) = cmp.area_ratio {
345            self.sum_area_ratio += area_ratio;
346            self.min_area_ratio = self.min_area_ratio.min(area_ratio);
347            self.max_area_ratio = self.max_area_ratio.max(area_ratio);
348            if (area_ratio - 1.0).abs() <= 0.01 {
349                self.area_within_1pct += 1;
350            }
351            if (area_ratio - 1.0).abs() <= 0.05 {
352                self.area_within_5pct += 1;
353            }
354        } else {
355            self.zero_area_count += 1;
356        }
357
358        self.sum_bbox_iou += cmp.bbox_iou;
359        if cmp.bbox_iou >= 0.9 {
360            self.bbox_iou_high += 1;
361        }
362        if cmp.bbox_iou < 0.5 {
363            self.bbox_iou_low += 1;
364        }
365    }
366}
367
368impl fmt::Display for MaskValidationResult {
369    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
370        writeln!(f, "║ Segmentation Mask Validation:")?;
371        writeln!(
372            f,
373            "║   Preserved:   {}/{} ({:.1}%)",
374            self.restored_with_seg,
375            self.original_with_seg,
376            self.preservation_rate() * 100.0
377        )?;
378        writeln!(
379            f,
380            "║   Matched:     {} ({} polygon, {} RLE→polygon)",
381            self.matched_pairs_with_seg, self.polygon_pairs, self.rle_pairs
382        )?;
383        writeln!(f, "║   Avg bbox IoU: {:.4}", self.avg_bbox_iou())?;
384        writeln!(
385            f,
386            "║   High IoU (>=0.9): {}/{} ({:.1}%)",
387            self.bbox_iou_high,
388            self.matched_pairs_with_seg,
389            if self.matched_pairs_with_seg > 0 {
390                self.bbox_iou_high as f64 / self.matched_pairs_with_seg as f64 * 100.0
391            } else {
392                100.0
393            }
394        )?;
395        if self.polygon_pairs > 0 {
396            writeln!(
397                f,
398                "║   Vertex exact: {}/{} ({:.1}%)",
399                self.vertex_count_exact_match,
400                self.polygon_pairs,
401                self.vertex_count_exact_match as f64 / self.polygon_pairs as f64 * 100.0
402            )?;
403        }
404        Ok(())
405    }
406}
407
/// Category validation results.
#[derive(Debug, Clone, Default)]
pub struct CategoryValidationResult {
    /// Category names in the COCO dataset.
    pub coco_categories: HashSet<String>,
    /// Category names in Studio.
    pub studio_categories: HashSet<String>,
    /// Category names present in COCO but absent from Studio.
    pub missing_categories: Vec<String>,
    /// Category names present in Studio but absent from COCO.
    pub extra_categories: Vec<String>,
}

impl CategoryValidationResult {
    /// Returns true if no category is missing. Extra categories do not
    /// affect the result.
    pub fn is_valid(&self) -> bool {
        self.missing_categories.is_empty()
    }
}

impl fmt::Display for CategoryValidationResult {
    /// Writes the category section of the report; the missing/extra lines
    /// appear only when the respective list is non-empty.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let coco_total = self.coco_categories.len();
        let studio_total = self.studio_categories.len();
        writeln!(
            f,
            "║ Categories:    {} in COCO, {} in Studio",
            coco_total, studio_total
        )?;

        let missing = &self.missing_categories;
        if !missing.is_empty() {
            writeln!(f, "║   Missing:     {:?}", missing)?;
        }

        let extra = &self.extra_categories;
        if !extra.is_empty() {
            writeln!(f, "║   Extra:       {:?}", extra)?;
        }
        Ok(())
    }
}
445
/// Metrics obtained from comparing one original/restored segmentation pair.
#[derive(Clone, Debug, Default)]
pub struct SegmentationPairComparison {
    /// `true` when the original segmentation was RLE rather than polygon.
    pub is_rle: bool,
    /// Vertex counts are equal (polygon originals only).
    pub vertex_exact_match: bool,
    /// Vertex counts are within 10% of each other (polygon originals only).
    pub vertex_close_match: bool,
    /// Part counts are equal (polygon originals only).
    pub part_match: bool,
    /// Restored area divided by original area; `None` if zero area.
    pub area_ratio: Option<f64>,
    /// IoU of the two segmentations' bounding boxes.
    pub bbox_iou: f64,
}
462
463/// Compare two segmentations and return comparison metrics.
464pub fn compare_segmentation_pair(
465    orig_seg: &CocoSegmentation,
466    rest_seg: &CocoSegmentation,
467) -> SegmentationPairComparison {
468    let is_rle = matches!(
469        orig_seg,
470        CocoSegmentation::Rle(_) | CocoSegmentation::CompressedRle(_)
471    );
472
473    let (vertex_exact_match, vertex_close_match, part_match) = if is_rle {
474        (false, false, false)
475    } else {
476        let orig_vertices = count_polygon_vertices(orig_seg);
477        let rest_vertices = count_polygon_vertices(rest_seg);
478        let orig_parts = count_polygon_parts(orig_seg);
479        let rest_parts = count_polygon_parts(rest_seg);
480
481        let vertex_diff = (orig_vertices as f64 - rest_vertices as f64).abs();
482        let vertex_threshold = (orig_vertices as f64 * 0.1).max(1.0);
483
484        (
485            orig_vertices == rest_vertices,
486            vertex_diff <= vertex_threshold,
487            orig_parts == rest_parts,
488        )
489    };
490
491    let orig_area = compute_segmentation_area(orig_seg);
492    let rest_area = compute_segmentation_area(rest_seg);
493    let area_ratio = if orig_area > 0.0 && rest_area > 0.0 {
494        Some(rest_area / orig_area)
495    } else {
496        None
497    };
498
499    let bbox_iou = segmentation_bbox_iou(orig_seg, rest_seg);
500
501    SegmentationPairComparison {
502        is_rle,
503        vertex_exact_match,
504        vertex_close_match,
505        part_match,
506        area_ratio,
507        bbox_iou,
508    }
509}
510
/// Computes the Intersection over Union (IoU) of two COCO bounding boxes.
///
/// Each box is `[x, y, width, height]` with `(x, y)` the top-left corner.
/// Returns `0.0` when the union area is not strictly positive.
pub fn bbox_iou(a: &[f64; 4], b: &[f64; 4]) -> f64 {
    let [ax, ay, aw, ah] = *a;
    let [bx, by, bw, bh] = *b;

    // Overlap along each axis, clamped so disjoint boxes contribute zero.
    let overlap_w = ((ax + aw).min(bx + bw) - ax.max(bx)).max(0.0);
    let overlap_h = ((ay + ah).min(by + bh) - ay.max(by)).max(0.0);
    let intersection = overlap_w * overlap_h;

    let union = aw * ah + bw * bh - intersection;
    if union > 0.0 {
        intersection / union
    } else {
        0.0
    }
}
545
546/// Use Hungarian algorithm to find optimal matching between two sets of
547/// annotations based on bounding box IoU.
548/// Returns pairs of (original_idx, restored_idx) for matched annotations.
549pub fn hungarian_match<'a>(
550    orig_anns: &[&'a CocoAnnotation],
551    rest_anns: &[&'a CocoAnnotation],
552) -> Vec<(usize, usize)> {
553    if orig_anns.is_empty() || rest_anns.is_empty() {
554        return vec![];
555    }
556
557    let n = orig_anns.len();
558    let m = rest_anns.len();
559
560    // Make the matrix square by padding with high-cost dummy entries
561    let size = n.max(m);
562
563    // Build cost matrix: cost = (1 - IoU) * scale
564    // We use i64 for kuhn_munkres, scale by 10000 for precision
565    let scale = 10000i64;
566    let max_cost = scale; // Cost for non-matching (IoU = 0)
567
568    let mut weights = Vec::with_capacity(size * size);
569    for i in 0..size {
570        for j in 0..size {
571            let cost = match (orig_anns.get(i), rest_anns.get(j)) {
572                (Some(orig), Some(rest)) => {
573                    let iou = bbox_iou(&orig.bbox, &rest.bbox);
574                    ((1.0 - iou) * scale as f64) as i64
575                }
576                _ => max_cost, // Dummy entry for padding
577            };
578            weights.push(cost);
579        }
580    }
581
582    let matrix = Matrix::from_vec(size, size, weights).expect("Failed to create matrix");
583    let (_, assignments) = kuhn_munkres_min(&matrix);
584
585    // Filter to only real matches (not dummy) with reasonable IoU
586    let min_iou_threshold = 0.3; // Only accept matches with IoU > 0.3
587    assignments
588        .iter()
589        .enumerate()
590        .filter_map(|(i, &j)| {
591            if i < n && j < m {
592                let iou = bbox_iou(&orig_anns[i].bbox, &rest_anns[j].bbox);
593                if iou >= min_iou_threshold {
594                    Some((i, j))
595                } else {
596                    None
597                }
598            } else {
599                None
600            }
601        })
602        .collect()
603}
604
/// Computes the area of a polygon with the Shoelace formula.
///
/// `coords` is a flat `[x1, y1, x2, y2, ...]` list; a trailing unpaired
/// value is ignored. Fewer than three vertices yield `0.0`. The result is
/// always non-negative, whatever the winding order.
pub fn polygon_area(coords: &[f64]) -> f64 {
    let vertex_count = coords.len() / 2;
    if vertex_count < 3 {
        return 0.0;
    }

    let vertex = |k: usize| (coords[2 * k], coords[2 * k + 1]);
    let twice_signed_area = (0..vertex_count).fold(0.0_f64, |acc, i| {
        let (x0, y0) = vertex(i);
        // Wrap around so the last vertex connects back to the first.
        let (x1, y1) = vertex((i + 1) % vertex_count);
        acc + (x0 * y1 - x1 * y0)
    });
    (twice_signed_area / 2.0).abs()
}
624
625/// Calculate total area of a segmentation.
626pub fn compute_segmentation_area(seg: &CocoSegmentation) -> f64 {
627    match seg {
628        CocoSegmentation::Polygon(polys) => polys.iter().map(|p| polygon_area(p)).sum(),
629        CocoSegmentation::Rle(rle) => {
630            if let Ok((mask, _, _)) = decode_rle(rle) {
631                mask.iter().filter(|&&v| v == 1).count() as f64
632            } else {
633                0.0
634            }
635        }
636        CocoSegmentation::CompressedRle(compressed) => {
637            if let Ok((mask, _, _)) = decode_compressed_rle(compressed) {
638                mask.iter().filter(|&&v| v == 1).count() as f64
639            } else {
640                0.0
641            }
642        }
643    }
644}
645
/// Computes the axis-aligned bounds `(min_x, min_y, max_x, max_y)` of a
/// polygon given as a flat `[x1, y1, x2, y2, ...]` list.
///
/// Returns `None` when fewer than four values (two points) are supplied.
/// A trailing unpaired value is ignored.
pub fn polygon_bounds(coords: &[f64]) -> Option<(f64, f64, f64, f64)> {
    if coords.len() < 4 {
        return None;
    }

    let seed = (f64::MAX, f64::MAX, f64::MIN, f64::MIN);
    let bounds = coords
        .chunks_exact(2)
        .fold(seed, |(min_x, min_y, max_x, max_y), point| {
            let (x, y) = (point[0], point[1]);
            (min_x.min(x), min_y.min(y), max_x.max(x), max_y.max(y))
        });
    Some(bounds)
}
666
667/// Get bounding box for any segmentation type.
668pub fn segmentation_bounds(seg: &CocoSegmentation) -> Option<(f64, f64, f64, f64)> {
669    match seg {
670        CocoSegmentation::Polygon(polys) => {
671            polys
672                .iter()
673                .filter_map(|p| polygon_bounds(p))
674                .fold(None, |acc, b| match acc {
675                    None => Some(b),
676                    Some((min_x, min_y, max_x, max_y)) => Some((
677                        min_x.min(b.0),
678                        min_y.min(b.1),
679                        max_x.max(b.2),
680                        max_y.max(b.3),
681                    )),
682                })
683        }
684        CocoSegmentation::Rle(rle) => {
685            let (mask, height, width) = decode_rle(rle).ok()?;
686            rle_mask_bounds(&mask, height, width)
687        }
688        CocoSegmentation::CompressedRle(compressed) => {
689            let (mask, height, width) = decode_compressed_rle(compressed).ok()?;
690            rle_mask_bounds(&mask, height, width)
691        }
692    }
693}
694
/// Finds the extent of the pixels equal to `1` in a binary mask.
///
/// The mask is indexed as `y * width + x`; indices beyond the end of `mask`
/// are treated as unset. The returned `(min_x, min_y, max_x, max_y)` are
/// the inclusive pixel indices of the extreme set pixels. Returns `None`
/// when no pixel is set.
fn rle_mask_bounds(mask: &[u8], height: u32, width: u32) -> Option<(f64, f64, f64, f64)> {
    let row_len = width as usize;
    let mut extent: Option<(u32, u32, u32, u32)> = None;

    for y in 0..height {
        let row_start = (y as usize) * row_len;
        for x in 0..width {
            if mask.get(row_start + (x as usize)).copied() != Some(1) {
                continue;
            }
            extent = Some(match extent {
                None => (x, y, x, y),
                Some((min_x, min_y, max_x, max_y)) => {
                    (min_x.min(x), min_y.min(y), max_x.max(x), max_y.max(y))
                }
            });
        }
    }

    extent.map(|(min_x, min_y, max_x, max_y)| {
        (min_x as f64, min_y as f64, max_x as f64, max_y as f64)
    })
}
722
723/// Calculate IoU between two segmentation bounding boxes.
724pub fn segmentation_bbox_iou(seg1: &CocoSegmentation, seg2: &CocoSegmentation) -> f64 {
725    let bounds1 = segmentation_bounds(seg1);
726    let bounds2 = segmentation_bounds(seg2);
727
728    match (bounds1, bounds2) {
729        (Some((a_x1, a_y1, a_x2, a_y2)), Some((b_x1, b_y1, b_x2, b_y2))) => {
730            let inter_x1 = a_x1.max(b_x1);
731            let inter_y1 = a_y1.max(b_y1);
732            let inter_x2 = a_x2.min(b_x2);
733            let inter_y2 = a_y2.min(b_y2);
734
735            let inter_w = (inter_x2 - inter_x1).max(0.0);
736            let inter_h = (inter_y2 - inter_y1).max(0.0);
737            let inter_area = inter_w * inter_h;
738
739            let a_area = (a_x2 - a_x1) * (a_y2 - a_y1);
740            let b_area = (b_x2 - b_x1) * (b_y2 - b_y1);
741            let union_area = a_area + b_area - inter_area;
742
743            if union_area > 0.0 {
744                inter_area / union_area
745            } else {
746                0.0
747            }
748        }
749        _ => 0.0,
750    }
751}
752
753/// Count total polygon vertices in a segmentation.
754pub fn count_polygon_vertices(seg: &CocoSegmentation) -> usize {
755    match seg {
756        CocoSegmentation::Polygon(polys) => polys.iter().map(|p| p.len() / 2).sum(),
757        _ => 0,
758    }
759}
760
761/// Count number of polygon parts in a segmentation.
762pub fn count_polygon_parts(seg: &CocoSegmentation) -> usize {
763    match seg {
764        CocoSegmentation::Polygon(polys) => polys.len(),
765        _ => 0,
766    }
767}
768
769/// Build a map of annotations by sample name for efficient lookup.
770pub fn build_annotation_map_by_name(
771    dataset: &CocoDataset,
772) -> HashMap<String, Vec<&CocoAnnotation>> {
773    let image_names: HashMap<u64, String> = dataset
774        .images
775        .iter()
776        .map(|img| {
777            let name = std::path::Path::new(&img.file_name)
778                .file_stem()
779                .and_then(|s| s.to_str())
780                .unwrap_or(&img.file_name)
781                .to_string();
782            (img.id, name)
783        })
784        .collect();
785
786    let mut map: HashMap<String, Vec<_>> = HashMap::new();
787    for ann in &dataset.annotations {
788        if let Some(name) = image_names.get(&ann.image_id) {
789            map.entry(name.clone()).or_default().push(ann);
790        }
791    }
792    map
793}
794
795/// Validate bounding boxes between two datasets using Hungarian matching.
796pub fn validate_bboxes(original: &CocoDataset, restored: &CocoDataset) -> BboxValidationResult {
797    let mut result = BboxValidationResult::default();
798
799    let original_by_name = build_annotation_map_by_name(original);
800    let restored_by_name = build_annotation_map_by_name(restored);
801
802    for (name, orig_anns) in &original_by_name {
803        if let Some(rest_anns) = restored_by_name.get(name) {
804            let matches = hungarian_match(orig_anns, rest_anns);
805
806            for (orig_idx, rest_idx) in &matches {
807                let orig_ann = orig_anns[*orig_idx];
808                let rest_ann = rest_anns[*rest_idx];
809
810                // Track IoU
811                let iou = bbox_iou(&orig_ann.bbox, &rest_ann.bbox);
812                result.sum_iou += iou;
813
814                // Measure coordinate errors
815                for i in 0..4 {
816                    let error = (orig_ann.bbox[i] - rest_ann.bbox[i]).abs();
817                    result.max_error = result.max_error.max(error);
818
819                    if error < 1.0 {
820                        result.errors_by_range[0] += 1;
821                    } else if error < 2.0 {
822                        result.errors_by_range[1] += 1;
823                    } else if error < 5.0 {
824                        result.errors_by_range[2] += 1;
825                    } else if error < 10.0 {
826                        result.errors_by_range[3] += 1;
827                    } else {
828                        result.errors_by_range[4] += 1;
829                    }
830                }
831                result.total_matched += 1;
832            }
833
834            result.total_unmatched += orig_anns.len() - matches.len();
835        } else {
836            result.total_unmatched += orig_anns.len();
837        }
838    }
839
840    result
841}
842
843/// Validate segmentation masks between two datasets using Hungarian matching.
844pub fn validate_masks(original: &CocoDataset, restored: &CocoDataset) -> MaskValidationResult {
845    let mut result = MaskValidationResult::new();
846
847    // Count segmentations in original and restored
848    result.original_with_seg = original
849        .annotations
850        .iter()
851        .filter(|a| a.segmentation.is_some())
852        .count();
853    result.restored_with_seg = restored
854        .annotations
855        .iter()
856        .filter(|a| a.segmentation.is_some())
857        .count();
858
859    let original_by_name = build_annotation_map_by_name(original);
860    let restored_by_name = build_annotation_map_by_name(restored);
861
862    for (name, orig_anns) in &original_by_name {
863        if let Some(rest_anns) = restored_by_name.get(name) {
864            let matches = hungarian_match(orig_anns, rest_anns);
865
866            for (orig_idx, rest_idx) in &matches {
867                let orig_ann = orig_anns[*orig_idx];
868                let rest_ann = rest_anns[*rest_idx];
869
870                if let (Some(orig_seg), Some(rest_seg)) =
871                    (&orig_ann.segmentation, &rest_ann.segmentation)
872                {
873                    let comparison = compare_segmentation_pair(orig_seg, rest_seg);
874                    result.aggregate_comparison(&comparison);
875                }
876            }
877        }
878    }
879
880    result
881}
882
883/// Validate categories between two datasets.
884pub fn validate_categories(
885    original: &CocoDataset,
886    restored: &CocoDataset,
887) -> CategoryValidationResult {
888    let coco_cats: HashSet<String> = original.categories.iter().map(|c| c.name.clone()).collect();
889    let studio_cats: HashSet<String> = restored.categories.iter().map(|c| c.name.clone()).collect();
890
891    let missing: Vec<String> = coco_cats.difference(&studio_cats).cloned().collect();
892    let extra: Vec<String> = studio_cats.difference(&coco_cats).cloned().collect();
893
894    CategoryValidationResult {
895        coco_categories: coco_cats,
896        studio_categories: studio_cats,
897        missing_categories: missing,
898        extra_categories: extra,
899    }
900}
901
#[cfg(test)]
mod tests {
    use super::*;
    use crate::coco::{CocoCategory, CocoImage, CocoRle};

    /// Builds an owned set of names from string literals for category
    /// fixtures.
    fn name_set(names: &[&str]) -> HashSet<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    // =========================================================================
    // bbox_iou tests
    //
    // The expected values below treat boxes as `[x, y, width, height]`.
    // =========================================================================

    /// Identical boxes have an IoU of 1.
    #[test]
    fn test_bbox_iou_perfect_overlap() {
        let a = [0.0, 0.0, 100.0, 100.0];
        let b = [0.0, 0.0, 100.0, 100.0];
        assert!((bbox_iou(&a, &b) - 1.0).abs() < 1e-6);
    }

    /// Disjoint boxes have an IoU of 0.
    #[test]
    fn test_bbox_iou_no_overlap() {
        let a = [0.0, 0.0, 100.0, 100.0];
        let b = [200.0, 200.0, 100.0, 100.0];
        assert!(bbox_iou(&a, &b) < 1e-6);
    }

    #[test]
    fn test_bbox_iou_partial_overlap() {
        let a = [0.0, 0.0, 100.0, 100.0];
        let b = [50.0, 50.0, 100.0, 100.0];
        // Intersection: 50x50 = 2500, Union: 10000 + 10000 - 2500 = 17500
        let expected = 2500.0 / 17500.0;
        assert!((bbox_iou(&a, &b) - expected).abs() < 1e-6);
    }

    #[test]
    fn test_bbox_iou_contained() {
        // b is fully contained in a
        let a = [0.0, 0.0, 100.0, 100.0];
        let b = [25.0, 25.0, 50.0, 50.0];
        // Intersection: 50x50 = 2500, Union: 10000
        let expected = 2500.0 / 10000.0;
        assert!((bbox_iou(&a, &b) - expected).abs() < 1e-6);
    }

    /// A degenerate (zero-area) box must not produce a positive IoU.
    #[test]
    fn test_bbox_iou_zero_area() {
        let a = [0.0, 0.0, 0.0, 0.0];
        let b = [0.0, 0.0, 100.0, 100.0];
        assert!(bbox_iou(&a, &b) < 1e-6);
    }

    // =========================================================================
    // polygon_area tests
    // =========================================================================

    #[test]
    fn test_polygon_area_square() {
        // 10x10 square
        let coords = [0.0, 0.0, 10.0, 0.0, 10.0, 10.0, 0.0, 10.0];
        assert!((polygon_area(&coords) - 100.0).abs() < 1e-6);
    }

    #[test]
    fn test_polygon_area_triangle() {
        // Triangle with base 10 and height 10
        let coords = [0.0, 0.0, 10.0, 0.0, 5.0, 10.0];
        assert!((polygon_area(&coords) - 50.0).abs() < 1e-6);
    }

    #[test]
    fn test_polygon_area_too_small() {
        // Less than 3 points (6 coords)
        let coords = [0.0, 0.0, 10.0, 10.0];
        assert!(polygon_area(&coords) < 1e-6);
    }

    #[test]
    fn test_polygon_area_complex() {
        // L-shaped polygon (can compute as two rectangles)
        // Points: (0,0), (20,0), (20,10), (10,10), (10,20), (0,20)
        let coords = [
            0.0, 0.0, 20.0, 0.0, 20.0, 10.0, 10.0, 10.0, 10.0, 20.0, 0.0, 20.0,
        ];
        // Area = 10*20 + 10*10 = 300
        assert!((polygon_area(&coords) - 300.0).abs() < 1e-6);
    }

    // =========================================================================
    // polygon_bounds tests
    // =========================================================================

    #[test]
    fn test_polygon_bounds_square() {
        let coords = [0.0, 0.0, 100.0, 0.0, 100.0, 100.0, 0.0, 100.0];
        let bounds = polygon_bounds(&coords);
        assert_eq!(bounds, Some((0.0, 0.0, 100.0, 100.0)));
    }

    #[test]
    fn test_polygon_bounds_offset() {
        let coords = [50.0, 60.0, 150.0, 60.0, 150.0, 160.0, 50.0, 160.0];
        let bounds = polygon_bounds(&coords);
        assert_eq!(bounds, Some((50.0, 60.0, 150.0, 160.0)));
    }

    /// A single point yields no bounds.
    #[test]
    fn test_polygon_bounds_too_small() {
        let coords = [0.0, 0.0];
        assert!(polygon_bounds(&coords).is_none());
    }

    // =========================================================================
    // hungarian_match tests
    // =========================================================================

    #[test]
    fn test_hungarian_match_empty_inputs() {
        let orig: Vec<&CocoAnnotation> = vec![];
        let rest: Vec<&CocoAnnotation> = vec![];
        let matches = hungarian_match(&orig, &rest);
        assert!(matches.is_empty());
    }

    #[test]
    fn test_hungarian_match_perfect_match() {
        let ann1 = CocoAnnotation {
            id: 1,
            image_id: 1,
            category_id: 1,
            bbox: [0.0, 0.0, 100.0, 100.0],
            ..Default::default()
        };
        let ann2 = CocoAnnotation {
            id: 2,
            image_id: 1,
            category_id: 1,
            bbox: [0.0, 0.0, 100.0, 100.0], // Same bbox
            ..Default::default()
        };

        let orig = vec![&ann1];
        let rest = vec![&ann2];
        let matches = hungarian_match(&orig, &rest);

        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0], (0, 0));
    }

    /// The restored annotations are listed in the opposite order to the
    /// originals, so a correct matcher must pair them crosswise.
    #[test]
    fn test_hungarian_match_multiple() {
        let ann1 = CocoAnnotation {
            id: 1,
            image_id: 1,
            category_id: 1,
            bbox: [0.0, 0.0, 50.0, 50.0],
            ..Default::default()
        };
        let ann2 = CocoAnnotation {
            id: 2,
            image_id: 1,
            category_id: 1,
            bbox: [100.0, 100.0, 50.0, 50.0],
            ..Default::default()
        };

        let ann3 = CocoAnnotation {
            id: 3,
            image_id: 1,
            category_id: 1,
            bbox: [100.0, 100.0, 50.0, 50.0], // Matches ann2
            ..Default::default()
        };
        let ann4 = CocoAnnotation {
            id: 4,
            image_id: 1,
            category_id: 1,
            bbox: [0.0, 0.0, 50.0, 50.0], // Matches ann1
            ..Default::default()
        };

        let orig = vec![&ann1, &ann2];
        let rest = vec![&ann3, &ann4];
        let matches = hungarian_match(&orig, &rest);

        assert_eq!(matches.len(), 2);
        // Checking the count alone would also accept the wrong pairing; the
        // result order is not relied upon, only the (orig, rest) pairs.
        assert!(matches.contains(&(0, 1))); // ann1 matched to ann4
        assert!(matches.contains(&(1, 0))); // ann2 matched to ann3
    }

    #[test]
    fn test_hungarian_match_unequal_sizes() {
        let ann1 = CocoAnnotation {
            id: 1,
            image_id: 1,
            category_id: 1,
            bbox: [0.0, 0.0, 100.0, 100.0],
            ..Default::default()
        };
        let ann2 = CocoAnnotation {
            id: 2,
            image_id: 1,
            category_id: 1,
            bbox: [200.0, 200.0, 100.0, 100.0],
            ..Default::default()
        };
        let ann3 = CocoAnnotation {
            id: 3,
            image_id: 1,
            category_id: 1,
            bbox: [0.0, 0.0, 100.0, 100.0], // Matches ann1
            ..Default::default()
        };

        let orig = vec![&ann1, &ann2];
        let rest = vec![&ann3]; // Fewer restored
        let matches = hungarian_match(&orig, &rest);

        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0], (0, 0)); // ann1 matched to ann3
    }

    // =========================================================================
    // BboxValidationResult tests
    // =========================================================================

    #[test]
    fn test_bbox_validation_result_rates() {
        let mut result = BboxValidationResult {
            total_matched: 100,
            total_unmatched: 10,
            sum_iou: 95.0,
            ..BboxValidationResult::default()
        };
        result.errors_by_range[0] = 350; // 350/400 = 87.5%
        result.errors_by_range[1] = 40;

        // 100 matched out of 110 total
        assert!((result.match_rate() - 0.909).abs() < 0.01);
        assert!((result.avg_iou() - 0.95).abs() < 0.01);
    }

    /// With no data recorded, every rate reports 1.0.
    #[test]
    fn test_bbox_validation_result_empty() {
        let result = BboxValidationResult::default();
        assert!((result.match_rate() - 1.0).abs() < 1e-6);
        assert!((result.avg_iou() - 1.0).abs() < 1e-6);
        assert!((result.within_1px_rate() - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_bbox_validation_result_is_valid() {
        let mut result = BboxValidationResult {
            total_matched: 100,
            sum_iou: 98.0,
            ..BboxValidationResult::default()
        };
        result.errors_by_range[0] = 400; // All within 1px
        assert!(result.is_valid());
    }

    #[test]
    fn test_bbox_validation_result_not_valid() {
        let mut result = BboxValidationResult {
            total_matched: 100,
            total_unmatched: 50, // Low match rate
            sum_iou: 50.0,
            ..BboxValidationResult::default()
        };
        result.errors_by_range[0] = 200;
        assert!(!result.is_valid());
    }

    // =========================================================================
    // MaskValidationResult tests
    // =========================================================================

    /// `new()` seeds the min/max trackers with their extreme sentinels.
    #[test]
    fn test_mask_validation_result_new() {
        let result = MaskValidationResult::new();
        assert_eq!(result.min_area_ratio, f64::MAX);
        assert_eq!(result.max_area_ratio, 0.0);
    }

    #[test]
    fn test_mask_validation_result_preservation_rate() {
        let mut result = MaskValidationResult::new();
        result.original_with_seg = 100;
        result.restored_with_seg = 95;
        assert!((result.preservation_rate() - 0.95).abs() < 1e-6);
    }

    /// With no data recorded, every rate reports 1.0.
    #[test]
    fn test_mask_validation_result_empty() {
        let result = MaskValidationResult::new();
        assert!((result.preservation_rate() - 1.0).abs() < 1e-6);
        assert!((result.avg_area_ratio() - 1.0).abs() < 1e-6);
        assert!((result.avg_bbox_iou() - 1.0).abs() < 1e-6);
    }

    // =========================================================================
    // CategoryValidationResult tests
    // =========================================================================

    #[test]
    fn test_category_validation_result_is_valid() {
        let result = CategoryValidationResult {
            coco_categories: name_set(&["person", "car"]),
            studio_categories: name_set(&["person", "car"]),
            missing_categories: vec![],
            extra_categories: vec![],
        };
        assert!(result.is_valid());
    }

    #[test]
    fn test_category_validation_result_missing() {
        let result = CategoryValidationResult {
            coco_categories: name_set(&["person", "car"]),
            studio_categories: name_set(&["person"]),
            missing_categories: vec!["car".to_string()],
            extra_categories: vec![],
        };
        assert!(!result.is_valid());
    }

    // =========================================================================
    // compute_segmentation_area tests
    // =========================================================================

    #[test]
    fn test_compute_segmentation_area_polygon() {
        let seg =
            CocoSegmentation::Polygon(vec![vec![0.0, 0.0, 100.0, 0.0, 100.0, 100.0, 0.0, 100.0]]);
        let area = compute_segmentation_area(&seg);
        assert!((area - 10000.0).abs() < 1e-6);
    }

    #[test]
    fn test_compute_segmentation_area_multiple_polygons() {
        // Two 10x10 squares
        let seg = CocoSegmentation::Polygon(vec![
            vec![0.0, 0.0, 10.0, 0.0, 10.0, 10.0, 0.0, 10.0],
            vec![20.0, 20.0, 30.0, 20.0, 30.0, 30.0, 20.0, 30.0],
        ]);
        let area = compute_segmentation_area(&seg);
        assert!((area - 200.0).abs() < 1e-6);
    }

    // =========================================================================
    // count_polygon_vertices/parts tests
    // =========================================================================

    #[test]
    fn test_count_polygon_vertices() {
        let seg = CocoSegmentation::Polygon(vec![vec![0.0, 0.0, 10.0, 0.0, 10.0, 10.0, 0.0, 10.0]]);
        assert_eq!(count_polygon_vertices(&seg), 4);
    }

    #[test]
    fn test_count_polygon_vertices_multiple() {
        let seg = CocoSegmentation::Polygon(vec![
            vec![0.0, 0.0, 10.0, 0.0, 10.0, 10.0], // 3 vertices
            vec![20.0, 20.0, 30.0, 20.0, 30.0, 30.0, 20.0, 30.0], // 4 vertices
        ]);
        assert_eq!(count_polygon_vertices(&seg), 7);
    }

    #[test]
    fn test_count_polygon_parts() {
        let seg = CocoSegmentation::Polygon(vec![
            vec![0.0, 0.0, 10.0, 0.0, 10.0, 10.0],
            vec![20.0, 20.0, 30.0, 20.0, 30.0, 30.0],
        ]);
        assert_eq!(count_polygon_parts(&seg), 2);
    }

    /// RLE segmentations have no polygon vertices to count.
    #[test]
    fn test_count_polygon_vertices_rle() {
        let rle = CocoRle {
            counts: vec![100],
            size: [10, 10],
        };
        let seg = CocoSegmentation::Rle(rle);
        assert_eq!(count_polygon_vertices(&seg), 0);
    }

    // =========================================================================
    // VerificationResult tests
    // =========================================================================

    #[test]
    fn test_verification_result_is_valid() {
        let result = VerificationResult {
            coco_image_count: 100,
            studio_image_count: 100,
            missing_images: vec![],
            extra_images: vec![],
            coco_annotation_count: 500,
            studio_annotation_count: 500,
            bbox_validation: {
                let mut bv = BboxValidationResult {
                    total_matched: 500,
                    sum_iou: 495.0,
                    ..BboxValidationResult::default()
                };
                bv.errors_by_range[0] = 2000; // All within 1px
                bv
            },
            mask_validation: {
                let mut mv = MaskValidationResult::new();
                mv.original_with_seg = 500;
                mv.restored_with_seg = 500;
                mv.matched_pairs_with_seg = 500;
                mv.sum_bbox_iou = 475.0;
                mv
            },
            category_validation: CategoryValidationResult {
                coco_categories: name_set(&["person"]),
                studio_categories: name_set(&["person"]),
                missing_categories: vec![],
                extra_categories: vec![],
            },
        };

        assert!(result.is_valid());
    }

    /// The summary must at least contain the image and annotation sections.
    #[test]
    fn test_verification_result_summary() {
        let result = VerificationResult {
            coco_image_count: 100,
            studio_image_count: 98,
            missing_images: vec!["img1.jpg".to_string(), "img2.jpg".to_string()],
            extra_images: vec![],
            coco_annotation_count: 500,
            studio_annotation_count: 490,
            bbox_validation: BboxValidationResult::default(),
            mask_validation: MaskValidationResult::new(),
            category_validation: CategoryValidationResult::default(),
        };

        let summary = result.summary();
        assert!(summary.contains("Images:"));
        assert!(summary.contains("Annotations:"));
    }

    // =========================================================================
    // build_annotation_map_by_name tests
    // =========================================================================

    /// Annotations are grouped under their image's file name without the
    /// extension (`image1.jpg` -> `image1`).
    #[test]
    fn test_build_annotation_map_by_name() {
        let dataset = CocoDataset {
            images: vec![
                CocoImage {
                    id: 1,
                    file_name: "image1.jpg".to_string(),
                    ..Default::default()
                },
                CocoImage {
                    id: 2,
                    file_name: "image2.jpg".to_string(),
                    ..Default::default()
                },
            ],
            annotations: vec![
                CocoAnnotation {
                    id: 1,
                    image_id: 1,
                    ..Default::default()
                },
                CocoAnnotation {
                    id: 2,
                    image_id: 1,
                    ..Default::default()
                },
                CocoAnnotation {
                    id: 3,
                    image_id: 2,
                    ..Default::default()
                },
            ],
            ..Default::default()
        };

        let map = build_annotation_map_by_name(&dataset);

        assert_eq!(map.len(), 2);
        assert_eq!(map.get("image1").unwrap().len(), 2);
        assert_eq!(map.get("image2").unwrap().len(), 1);
    }

    // =========================================================================
    // validate_categories tests
    // =========================================================================

    #[test]
    fn test_validate_categories_match() {
        let original = CocoDataset {
            categories: vec![
                CocoCategory {
                    id: 1,
                    name: "cat".to_string(),
                    supercategory: None,
                },
                CocoCategory {
                    id: 2,
                    name: "dog".to_string(),
                    supercategory: None,
                },
            ],
            ..Default::default()
        };

        let restored = CocoDataset {
            categories: vec![
                CocoCategory {
                    id: 1,
                    name: "cat".to_string(),
                    supercategory: None,
                },
                CocoCategory {
                    id: 2,
                    name: "dog".to_string(),
                    supercategory: None,
                },
            ],
            ..Default::default()
        };

        let result = validate_categories(&original, &restored);
        assert!(result.is_valid());
        assert!(result.missing_categories.is_empty());
        assert!(result.extra_categories.is_empty());
    }

    #[test]
    fn test_validate_categories_missing_and_extra() {
        let original = CocoDataset {
            categories: vec![
                CocoCategory {
                    id: 1,
                    name: "cat".to_string(),
                    supercategory: None,
                },
                CocoCategory {
                    id: 2,
                    name: "dog".to_string(),
                    supercategory: None,
                },
            ],
            ..Default::default()
        };

        let restored = CocoDataset {
            categories: vec![
                CocoCategory {
                    id: 1,
                    name: "cat".to_string(),
                    supercategory: None,
                },
                CocoCategory {
                    id: 3,
                    name: "bird".to_string(),
                    supercategory: None,
                },
            ],
            ..Default::default()
        };

        let result = validate_categories(&original, &restored);
        assert!(!result.is_valid());
        assert!(result.missing_categories.contains(&"dog".to_string()));
        assert!(result.extra_categories.contains(&"bird".to_string()));
    }

    // =========================================================================
    // compare_segmentation_pair tests
    // =========================================================================

    #[test]
    fn test_compare_segmentation_pair_identical_polygons() {
        let seg =
            CocoSegmentation::Polygon(vec![vec![0.0, 0.0, 100.0, 0.0, 100.0, 100.0, 0.0, 100.0]]);
        let result = compare_segmentation_pair(&seg, &seg);

        assert!(!result.is_rle);
        assert!(result.vertex_exact_match);
        assert!(result.vertex_close_match);
        assert!(result.part_match);
        assert!(result.area_ratio.is_some());
        assert!((result.area_ratio.unwrap() - 1.0).abs() < 1e-6);
        assert!((result.bbox_iou - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_compare_segmentation_pair_different_vertex_count() {
        let seg1 =
            CocoSegmentation::Polygon(vec![vec![0.0, 0.0, 100.0, 0.0, 100.0, 100.0, 0.0, 100.0]]);
        // Triangle with same bounding box
        let seg2 = CocoSegmentation::Polygon(vec![vec![0.0, 0.0, 100.0, 0.0, 50.0, 100.0]]);

        let result = compare_segmentation_pair(&seg1, &seg2);

        assert!(!result.is_rle);
        assert!(!result.vertex_exact_match); // 4 vs 3 vertices
        // Note: vertex_close_match is true because threshold is max(10%, 1) = 1, and
        // diff = 1
        assert!(result.vertex_close_match);
        assert!(result.part_match); // Both have 1 part
    }

    /// A pair involving an RLE mask is flagged as RLE and reports no
    /// vertex/part matches.
    #[test]
    fn test_compare_segmentation_pair_rle() {
        let rle = CocoRle {
            counts: vec![100],
            size: [10, 10],
        };
        let seg = CocoSegmentation::Rle(rle);
        let poly =
            CocoSegmentation::Polygon(vec![vec![0.0, 0.0, 10.0, 0.0, 10.0, 10.0, 0.0, 10.0]]);

        let result = compare_segmentation_pair(&seg, &poly);

        assert!(result.is_rle);
        assert!(!result.vertex_exact_match);
        assert!(!result.vertex_close_match);
        assert!(!result.part_match);
    }

    #[test]
    fn test_compare_segmentation_pair_scaled() {
        let seg1 =
            CocoSegmentation::Polygon(vec![vec![0.0, 0.0, 100.0, 0.0, 100.0, 100.0, 0.0, 100.0]]);
        let seg2 =
            CocoSegmentation::Polygon(vec![vec![0.0, 0.0, 50.0, 0.0, 50.0, 50.0, 0.0, 50.0]]);

        let result = compare_segmentation_pair(&seg1, &seg2);

        assert!(result.area_ratio.is_some());
        // 2500 / 10000 = 0.25
        assert!((result.area_ratio.unwrap() - 0.25).abs() < 0.01);
    }

    // =========================================================================
    // aggregate_comparison tests
    // =========================================================================

    #[test]
    fn test_aggregate_comparison_polygon() {
        let mut result = MaskValidationResult::new();
        let cmp = SegmentationPairComparison {
            is_rle: false,
            vertex_exact_match: true,
            vertex_close_match: true,
            part_match: true,
            area_ratio: Some(1.0),
            bbox_iou: 0.95,
        };

        result.aggregate_comparison(&cmp);

        assert_eq!(result.matched_pairs_with_seg, 1);
        assert_eq!(result.polygon_pairs, 1);
        assert_eq!(result.rle_pairs, 0);
        assert_eq!(result.vertex_count_exact_match, 1);
        assert_eq!(result.vertex_count_close_match, 1);
        assert_eq!(result.part_count_match, 1);
        assert_eq!(result.area_within_1pct, 1);
        assert_eq!(result.area_within_5pct, 1);
        assert_eq!(result.bbox_iou_high, 1);
    }

    #[test]
    fn test_aggregate_comparison_rle() {
        let mut result = MaskValidationResult::new();
        let cmp = SegmentationPairComparison {
            is_rle: true,
            vertex_exact_match: false,
            vertex_close_match: false,
            part_match: false,
            area_ratio: Some(0.98),
            bbox_iou: 0.92,
        };

        result.aggregate_comparison(&cmp);

        assert_eq!(result.matched_pairs_with_seg, 1);
        assert_eq!(result.polygon_pairs, 0);
        assert_eq!(result.rle_pairs, 1);
        assert_eq!(result.vertex_count_exact_match, 0);
        assert_eq!(result.area_within_5pct, 1);
        assert_eq!(result.bbox_iou_high, 1);
    }

    #[test]
    fn test_aggregate_comparison_zero_area() {
        let mut result = MaskValidationResult::new();
        let cmp = SegmentationPairComparison {
            is_rle: false,
            vertex_exact_match: true,
            vertex_close_match: true,
            part_match: true,
            area_ratio: None, // Zero area
            bbox_iou: 0.3,
        };

        result.aggregate_comparison(&cmp);

        assert_eq!(result.zero_area_count, 1);
        assert_eq!(result.area_within_1pct, 0);
        assert_eq!(result.bbox_iou_low, 1);
    }

    /// Aggregating two comparisons accumulates counters and tracks the
    /// min/max/sum of the area ratios.
    #[test]
    fn test_aggregate_comparison_multiple() {
        let mut result = MaskValidationResult::new();

        let cmp1 = SegmentationPairComparison {
            is_rle: false,
            vertex_exact_match: true,
            vertex_close_match: true,
            part_match: true,
            area_ratio: Some(1.0),
            bbox_iou: 0.95,
        };
        let cmp2 = SegmentationPairComparison {
            is_rle: true,
            vertex_exact_match: false,
            vertex_close_match: false,
            part_match: false,
            area_ratio: Some(0.9),
            bbox_iou: 0.85,
        };

        result.aggregate_comparison(&cmp1);
        result.aggregate_comparison(&cmp2);

        assert_eq!(result.matched_pairs_with_seg, 2);
        assert_eq!(result.polygon_pairs, 1);
        assert_eq!(result.rle_pairs, 1);
        assert!((result.sum_area_ratio - 1.9).abs() < 0.01);
        assert!((result.min_area_ratio - 0.9).abs() < 0.01);
        assert!((result.max_area_ratio - 1.0).abs() < 0.01);
    }
}