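// oar_ocr_core/processors/layout_utils.rs

//! Layout parsing utilities.
//!
//! This module provides utilities for layout analysis: sorting layout boxes into
//! reading order, associating OCR results with layout regions, reconciling and
//! reprocessing table cells, and removing overlapping layout blocks. The algorithms
//! follow the reference layout-parsing and table-recognition pipelines
//! (`layout_parsing/utils.py`, `table_recognition/pipeline_v2.py`).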
6
7use crate::processors::BoundingBox;
8use std::collections::HashSet;
9
/// Result of associating OCR boxes with layout regions.
///
/// Both lists hold indices into the OCR box slice that was passed to
/// `associate_ocr_with_layout`; which group counts as "matched" depends on that
/// function's `flag_within` argument.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct LayoutOCRAssociation {
    /// Indices of OCR boxes in the requested group (inside or outside the regions)
    pub matched_indices: Vec<usize>,
    /// Indices of the remaining OCR boxes
    pub unmatched_indices: Vec<usize>,
}
18
19/// Get indices of OCR boxes that overlap with layout regions.
20///
21/// This function checks which OCR boxes have significant overlap with any of the
22/// layout regions. An overlap is considered significant if the intersection has
23/// both width and height greater than the threshold (default: 3 pixels).
24///
25/// This follows standard overlap detection implementation.
26///
27/// # Arguments
28///
29/// * `ocr_boxes` - Slice of OCR bounding boxes
30/// * `layout_regions` - Slice of layout region bounding boxes
31/// * `threshold` - Minimum intersection dimension (default: 3.0 pixels)
32///
33/// # Returns
34///
35/// Vector of indices of OCR boxes that overlap with any layout region
36pub fn get_overlap_boxes_idx(
37    ocr_boxes: &[BoundingBox],
38    layout_regions: &[BoundingBox],
39    threshold: f32,
40) -> Vec<usize> {
41    let mut matched_indices = Vec::new();
42
43    if ocr_boxes.is_empty() || layout_regions.is_empty() {
44        return matched_indices;
45    }
46
47    // For each layout region, find overlapping OCR boxes
48    for layout_region in layout_regions {
49        for (idx, ocr_box) in ocr_boxes.iter().enumerate() {
50            if ocr_box.overlaps_with(layout_region, threshold) {
51                matched_indices.push(idx);
52            }
53        }
54    }
55
56    matched_indices
57}
58
59/// Associate OCR results with layout regions.
60///
61/// This function filters OCR boxes based on whether they are within or outside
62/// the specified layout regions.
63///
64/// This follows standard region association implementation.
65///
66/// # Arguments
67///
68/// * `ocr_boxes` - Slice of OCR bounding boxes
69/// * `layout_regions` - Slice of layout region bounding boxes
70/// * `flag_within` - If true, return boxes within regions; if false, return boxes outside regions
71/// * `threshold` - Minimum intersection dimension for overlap detection
72///
73/// # Returns
74///
75/// `LayoutOCRAssociation` containing matched and unmatched indices
76pub fn associate_ocr_with_layout(
77    ocr_boxes: &[BoundingBox],
78    layout_regions: &[BoundingBox],
79    flag_within: bool,
80    threshold: f32,
81) -> LayoutOCRAssociation {
82    let overlap_indices = get_overlap_boxes_idx(ocr_boxes, layout_regions, threshold);
83    let overlap_set: HashSet<usize> = overlap_indices.into_iter().collect();
84
85    let mut matched_indices = Vec::new();
86    let mut unmatched_indices = Vec::new();
87
88    for (idx, _) in ocr_boxes.iter().enumerate() {
89        let is_overlapping = overlap_set.contains(&idx);
90
91        if flag_within {
92            // Return boxes within regions
93            if is_overlapping {
94                matched_indices.push(idx);
95            } else {
96                unmatched_indices.push(idx);
97            }
98        } else {
99            // Return boxes outside regions
100            if !is_overlapping {
101                matched_indices.push(idx);
102            } else {
103                unmatched_indices.push(idx);
104            }
105        }
106    }
107
108    LayoutOCRAssociation {
109        matched_indices,
110        unmatched_indices,
111    }
112}
113
114/// Layout box with bounding box and label for sorting/processing purposes.
115///
116/// This is a lightweight structure used for layout processing utilities like
117/// sorting and OCR association. For final structured output, use
118/// `LayoutElement` from `domain::structure`.
119#[derive(Debug, Clone)]
120pub struct LayoutBox {
121    /// Bounding box of the layout element
122    pub bbox: BoundingBox,
123    /// Label/type of the layout element (e.g., "text", "title", "table", "figure")
124    pub label: String,
125    /// Optional content text
126    pub content: Option<String>,
127}
128
129impl LayoutBox {
130    /// Create a new layout box.
131    pub fn new(bbox: BoundingBox, label: String) -> Self {
132        Self {
133            bbox,
134            label,
135            content: None,
136        }
137    }
138
139    /// Create a layout box with content.
140    pub fn with_content(bbox: BoundingBox, label: String, content: String) -> Self {
141        Self {
142            bbox,
143            label,
144            content: Some(content),
145        }
146    }
147}
148
149/// Sort layout boxes in reading order with column detection.
150///
151/// This function sorts layout boxes from top to bottom, left to right, with special
152/// handling for two-column layouts. Boxes are first sorted by y-coordinate, then
153/// separated into left and right columns based on their x-coordinate.
154///
155/// The algorithm:
156/// 1. Sort boxes by (y, x) coordinates
157/// 2. Identify left column boxes (x1 < w/4 and x2 < 3w/5)
158/// 3. Identify right column boxes (x1 > 2w/5)
159/// 4. Other boxes are considered full-width
160/// 5. Within each column, sort by y-coordinate
161///
162/// This follows standard layout sorting implementation.
163///
164/// # Arguments
165///
166/// * `elements` - Slice of layout boxes to sort
167/// * `image_width` - Width of the image for column detection
168///
169/// # Returns
170///
171/// A vector of sorted layout boxes
172pub fn sort_layout_boxes(elements: &[LayoutBox], image_width: f32) -> Vec<LayoutBox> {
173    let num_boxes = elements.len();
174
175    if num_boxes <= 1 {
176        return elements.to_vec();
177    }
178
179    // Sort by y-coordinate first, then x-coordinate
180    let mut sorted: Vec<LayoutBox> = elements.to_vec();
181    sorted.sort_by(|a, b| {
182        let a_y = a.bbox.y_min();
183        let a_x = a.bbox.x_min();
184        let b_y = b.bbox.y_min();
185        let b_x = b.bbox.x_min();
186
187        match a_y.partial_cmp(&b_y) {
188            Some(std::cmp::Ordering::Equal) => {
189                a_x.partial_cmp(&b_x).unwrap_or(std::cmp::Ordering::Equal)
190            }
191            other => other.unwrap_or(std::cmp::Ordering::Equal),
192        }
193    });
194
195    let mut result = Vec::new();
196    let mut left_column = Vec::new();
197    let mut right_column = Vec::new();
198
199    let w = image_width;
200    let mut i = 0;
201
202    while i < num_boxes {
203        let elem = &sorted[i];
204        let x1 = elem.bbox.x_min();
205        let x2 = elem.bbox.x_max();
206
207        // Check if box is in left column
208        if x1 < w / 4.0 && x2 < 3.0 * w / 5.0 {
209            left_column.push(elem.clone());
210        }
211        // Check if box is in right column
212        else if x1 > 2.0 * w / 5.0 {
213            right_column.push(elem.clone());
214        }
215        // Full-width box - flush columns and add this box
216        else {
217            // Add accumulated column boxes
218            result.append(&mut left_column);
219            result.append(&mut right_column);
220            result.push(elem.clone());
221        }
222
223        i += 1;
224    }
225
226    // Sort left and right columns by y-coordinate
227    left_column.sort_by(|a, b| {
228        a.bbox
229            .y_min()
230            .partial_cmp(&b.bbox.y_min())
231            .unwrap_or(std::cmp::Ordering::Equal)
232    });
233    right_column.sort_by(|a, b| {
234        a.bbox
235            .y_min()
236            .partial_cmp(&b.bbox.y_min())
237            .unwrap_or(std::cmp::Ordering::Equal)
238    });
239
240    // Add remaining column boxes
241    result.append(&mut left_column);
242    result.append(&mut right_column);
243
244    result
245}
246
247/// Reconciles structure recognition cells with detected cells.
248///
249/// This function aligns the number of output cells with the structure cells (N).
250/// - If multiple detected cells map to one structure cell, they are merged (compressed).
251/// - If a structure cell has no matching detected cell, the original structure box is kept (filled).
252///
253/// # Arguments
254/// * `structure_cells` - Cells derived from structure recognition (provides logical N)
255/// * `detected_cells` - Cells from detection model (provides precise geometry)
256///
257/// # Returns
258/// * `Vec<BoundingBox>` - Reconciled bounding boxes of length N.
259pub fn reconcile_table_cells(
260    structure_cells: &[BoundingBox],
261    detected_cells: &[BoundingBox],
262) -> Vec<BoundingBox> {
263    let n = structure_cells.len();
264    if n == 0 {
265        return Vec::new();
266    }
267    if detected_cells.is_empty() {
268        return structure_cells.to_vec();
269    }
270
271    // If detection produces significantly more cells than the table structure,
272    // reduce them using KMeans-style clustering on box centers.
273    let mut det_boxes: Vec<BoundingBox> = detected_cells.to_vec();
274    if det_boxes.len() > n {
275        det_boxes = combine_rectangles_kmeans(&det_boxes, n);
276    }
277
278    // Assignments: structure_idx -> list of detected_indices
279    let mut assignments: Vec<Vec<usize>> = vec![Vec::new(); n];
280
281    // Assign each detected cell to the best matching structure cell
282    for (det_idx, det_box) in det_boxes.iter().enumerate() {
283        let mut best_iou = 0.001f32; // Minimal threshold
284        let mut best_struct_idx: Option<usize> = None;
285
286        for (struct_idx, struct_box) in structure_cells.iter().enumerate() {
287            // Use IoU for assignment
288            // Note: We could also use intersection over detection area to handle
289            // cases where detection is much smaller than structure cell
290            let iou = calculate_iou(det_box, struct_box);
291            if iou > best_iou {
292                best_iou = iou;
293                best_struct_idx = Some(struct_idx);
294            }
295        }
296
297        if let Some(idx) = best_struct_idx {
298            assignments[idx].push(det_idx);
299        }
300    }
301
302    // Build result
303    let mut reconciled = Vec::with_capacity(n);
304    for i in 0..n {
305        let assigned = &assignments[i];
306        if assigned.is_empty() {
307            // Fill: No matching detection, keep original structure box
308            reconciled.push(structure_cells[i].clone());
309        } else if assigned.len() == 1 {
310            // Exact match: Use detected box
311            reconciled.push(det_boxes[assigned[0]].clone());
312        } else {
313            // Compress: Multiple detections map to one structure cell
314            // Merge them by taking the bounding box of all detections
315            let mut merged = det_boxes[assigned[0]].clone();
316            for &idx in &assigned[1..] {
317                merged = merged.union(&det_boxes[idx]);
318            }
319            reconciled.push(merged);
320        }
321    }
322
323    reconciled
324}
325
326/// Reprocesses detected table cell boxes using OCR boxes to better match the
327/// structure model's expected cell count.
328///
329/// This mirrors cell detection results reprocessing in
330/// `table_recognition/pipeline_v2.py`:
331/// - If detected cells > target_n, keep top-N by score.
332/// - Find OCR boxes not sufficiently covered by any cell (IoA >= 0.6).
333/// - If missing OCR boxes exist, supplement/merge boxes with KMeans-style clustering.
334/// - If final count is too small, fall back to clustering OCR boxes.
335///
336/// All boxes must be in the same coordinate system (typically table-crop coords).
337pub fn reprocess_table_cells_with_ocr(
338    detected_cells: &[BoundingBox],
339    detected_scores: &[f32],
340    ocr_boxes: &[BoundingBox],
341    target_n: usize,
342) -> Vec<BoundingBox> {
343    if target_n == 0 {
344        return Vec::new();
345    }
346
347    // If no detected cells, fall back to OCR clustering.
348    if detected_cells.is_empty() {
349        return combine_rectangles_kmeans(ocr_boxes, target_n);
350    }
351
352    // Defensive: scores length mismatch -> assume uniform.
353    let scores: Vec<f32> = if detected_scores.len() == detected_cells.len() {
354        detected_scores.to_vec()
355    } else {
356        vec![1.0; detected_cells.len()]
357    };
358
359    let mut cells: Vec<BoundingBox> = detected_cells.to_vec();
360
361    let mut more_cells_flag = false;
362    if cells.len() == target_n {
363        return cells;
364    } else if cells.len() > target_n {
365        more_cells_flag = true;
366        // Keep top target_n by score (descending).
367        let mut idxs: Vec<usize> = (0..cells.len()).collect();
368        idxs.sort_by(|&a, &b| {
369            scores[b]
370                .partial_cmp(&scores[a])
371                .unwrap_or(std::cmp::Ordering::Equal)
372        });
373        idxs.truncate(target_n);
374        cells = idxs.iter().map(|&i| cells[i].clone()).collect();
375    }
376
377    // Compute IoA (intersection / ocr_area) between OCR and cell boxes.
378    fn ioa_ocr_in_cell(ocr: &BoundingBox, cell: &BoundingBox) -> f32 {
379        let inter = ocr.intersection_area(cell);
380        if inter <= 0.0 {
381            return 0.0;
382        }
383        let area = (ocr.x_max() - ocr.x_min()) * (ocr.y_max() - ocr.y_min());
384        if area <= 0.0 { 0.0 } else { inter / area }
385    }
386
387    let iou_threshold = 0.6f32;
388    let mut ocr_miss_boxes: Vec<BoundingBox> = Vec::new();
389
390    for ocr_box in ocr_boxes {
391        let mut has_large_ioa = false;
392        let mut merge_ioa_sum = 0.0f32;
393        for cell_box in &cells {
394            let ioa = ioa_ocr_in_cell(ocr_box, cell_box);
395            if ioa > 0.0 {
396                merge_ioa_sum += ioa;
397            }
398            if ioa >= iou_threshold || merge_ioa_sum >= iou_threshold {
399                has_large_ioa = true;
400                break;
401            }
402        }
403        if !has_large_ioa {
404            ocr_miss_boxes.push(ocr_box.clone());
405        }
406    }
407
408    let mut final_results: Vec<BoundingBox>;
409
410    if ocr_miss_boxes.is_empty() {
411        final_results = cells;
412    } else if more_cells_flag {
413        // More cells than expected: merge cells + missing OCR boxes to target_n.
414        let mut merged = cells.clone();
415        merged.extend(ocr_miss_boxes);
416        final_results = combine_rectangles_kmeans(&merged, target_n);
417    } else {
418        // Fewer cells than expected: supplement with clustered missing OCR boxes.
419        let need_n = target_n.saturating_sub(cells.len());
420        let supp = combine_rectangles_kmeans(&ocr_miss_boxes, need_n);
421        final_results = cells;
422        final_results.extend(supp);
423    }
424
425    // If still too few, fall back to clustering OCR boxes.
426    if final_results.len() as f32 <= 0.6 * target_n as f32 {
427        final_results = combine_rectangles_kmeans(ocr_boxes, target_n);
428    }
429
430    final_results
431}
432
433/// Combines rectangles into at most `target_n` rectangles using KMeans-style clustering
434/// on box centers.
435pub fn combine_rectangles_kmeans(rectangles: &[BoundingBox], target_n: usize) -> Vec<BoundingBox> {
436    let num_rects = rectangles.len();
437    if num_rects == 0 || target_n == 0 {
438        return Vec::new();
439    }
440    if target_n >= num_rects {
441        return rectangles.to_vec();
442    }
443
444    // Represent each rectangle by its center point (x, y)
445    let points: Vec<(f32, f32)> = rectangles
446        .iter()
447        .map(|r| {
448            let cx = (r.x_min() + r.x_max()) * 0.5;
449            let cy = (r.y_min() + r.y_max()) * 0.5;
450            (cx, cy)
451        })
452        .collect();
453
454    // Initialize cluster centers using the first target_n points
455    let mut centers: Vec<(f32, f32)> = points.iter().take(target_n).cloned().collect();
456    let mut labels: Vec<usize> = vec![0; num_rects];
457
458    let max_iters = 10;
459    for _ in 0..max_iters {
460        let mut changed = false;
461
462        // Assignment step: assign each point to nearest center
463        for (i, &(px, py)) in points.iter().enumerate() {
464            let mut best_idx = 0usize;
465            let mut best_dist = f32::MAX;
466            for (c_idx, &(cx, cy)) in centers.iter().enumerate() {
467                let dx = px - cx;
468                let dy = py - cy;
469                let dist = dx * dx + dy * dy;
470                if dist < best_dist {
471                    best_dist = dist;
472                    best_idx = c_idx;
473                }
474            }
475            if labels[i] != best_idx {
476                labels[i] = best_idx;
477                changed = true;
478            }
479        }
480
481        // Recompute centers
482        let mut sums: Vec<(f32, f32, usize)> = vec![(0.0, 0.0, 0); target_n];
483        for (i, &(px, py)) in points.iter().enumerate() {
484            let l = labels[i];
485            sums[l].0 += px;
486            sums[l].1 += py;
487            sums[l].2 += 1;
488        }
489        for (c_idx, center) in centers.iter_mut().enumerate() {
490            let (sx, sy, count) = sums[c_idx];
491            if count > 0 {
492                center.0 = sx / count as f32;
493                center.1 = sy / count as f32;
494            }
495        }
496
497        if !changed {
498            break;
499        }
500    }
501
502    // Build combined rectangles per cluster
503    let mut combined: Vec<BoundingBox> = Vec::new();
504    for cluster_idx in 0..target_n {
505        let mut first = true;
506        let mut min_x = 0.0f32;
507        let mut min_y = 0.0f32;
508        let mut max_x = 0.0f32;
509        let mut max_y = 0.0f32;
510
511        for (i, rect) in rectangles.iter().enumerate() {
512            if labels[i] == cluster_idx {
513                if first {
514                    min_x = rect.x_min();
515                    min_y = rect.y_min();
516                    max_x = rect.x_max();
517                    max_y = rect.y_max();
518                    first = false;
519                } else {
520                    min_x = min_x.min(rect.x_min());
521                    min_y = min_y.min(rect.y_min());
522                    max_x = max_x.max(rect.x_max());
523                    max_y = max_y.max(rect.y_max());
524                }
525            }
526        }
527
528        if !first {
529            combined.push(BoundingBox::from_coords(min_x, min_y, max_x, max_y));
530        }
531    }
532
533    if combined.is_empty() {
534        rectangles.to_vec()
535    } else {
536        combined
537    }
538}
539
540/// Calculates Intersection over Union (IoU) between two bounding boxes.
541fn calculate_iou(a: &BoundingBox, b: &BoundingBox) -> f32 {
542    let inter_x1 = a.x_min().max(b.x_min());
543    let inter_y1 = a.y_min().max(b.y_min());
544    let inter_x2 = a.x_max().min(b.x_max());
545    let inter_y2 = a.y_max().min(b.y_max());
546
547    let inter_area = (inter_x2 - inter_x1).max(0.0) * (inter_y2 - inter_y1).max(0.0);
548
549    let area_a = (a.x_max() - a.x_min()) * (a.y_max() - a.y_min());
550    let area_b = (b.x_max() - b.x_min()) * (b.y_max() - b.y_min());
551
552    let union_area = area_a + area_b - inter_area;
553
554    if union_area <= 0.0 {
555        0.0
556    } else {
557        inter_area / union_area
558    }
559}
560
561/// Calculates Intersection over Area (IoA) - intersection / smaller box area.
562fn calculate_ioa_smaller(a: &BoundingBox, b: &BoundingBox) -> f32 {
563    let inter_x1 = a.x_min().max(b.x_min());
564    let inter_y1 = a.y_min().max(b.y_min());
565    let inter_x2 = a.x_max().min(b.x_max());
566    let inter_y2 = a.y_max().min(b.y_max());
567
568    let inter_area = (inter_x2 - inter_x1).max(0.0) * (inter_y2 - inter_y1).max(0.0);
569
570    let area_a = (a.x_max() - a.x_min()) * (a.y_max() - a.y_min());
571    let area_b = (b.x_max() - b.x_min()) * (b.y_max() - b.y_min());
572
573    let smaller_area = area_a.min(area_b);
574
575    if smaller_area <= 0.0 {
576        0.0
577    } else {
578        inter_area / smaller_area
579    }
580}
581
/// Result of overlap removal.
///
/// `kept` preserves the input order of the surviving elements.
/// `removed_indices` refers to positions in the original input.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct OverlapRemovalResult<T> {
    /// Elements that were kept after overlap removal
    pub kept: Vec<T>,
    /// Indices of elements that were removed
    pub removed_indices: Vec<usize>,
}
590
591/// Removes overlapping layout blocks based on overlap ratio threshold.
592///
593/// This follows standard overlap removal implementation in
594/// `layout_parsing/utils.py`. When two blocks overlap significantly:
595/// - If one is an image and one is not, the image is removed (text takes priority)
596/// - Otherwise, the smaller block is removed
597///
598/// # Arguments
599///
600/// * `elements` - Slice of layout elements to process
601/// * `threshold` - Overlap ratio threshold (default: 0.65)
602///   If intersection/smaller_area > threshold, blocks are considered overlapping
603///
604/// # Returns
605///
606/// `OverlapRemovalResult` containing kept elements and indices of removed elements
607///
608/// # Example
609///
610/// ```ignore
611/// use oar_ocr_core::processors::layout_utils::{remove_overlap_blocks, LayoutBox};
612/// use oar_ocr_core::processors::BoundingBox;
613///
614/// let elements = vec![
615///     LayoutBox::new(BoundingBox::from_coords(0.0, 0.0, 100.0, 100.0), "text".to_string()),
616///     LayoutBox::new(BoundingBox::from_coords(10.0, 10.0, 90.0, 90.0), "text".to_string()),
617/// ];
618///
619/// let result = remove_overlap_blocks(&elements, 0.65);
620/// assert_eq!(result.kept.len(), 1); // Smaller overlapping box was removed
621/// ```
622pub fn remove_overlap_blocks(
623    elements: &[LayoutBox],
624    threshold: f32,
625) -> OverlapRemovalResult<LayoutBox> {
626    let n = elements.len();
627    if n <= 1 {
628        return OverlapRemovalResult {
629            kept: elements.to_vec(),
630            removed_indices: Vec::new(),
631        };
632    }
633
634    let mut dropped_indices: HashSet<usize> = HashSet::new();
635
636    // Compare all pairs of elements
637    for i in 0..n {
638        if dropped_indices.contains(&i) {
639            continue;
640        }
641
642        for j in (i + 1)..n {
643            if dropped_indices.contains(&j) {
644                continue;
645            }
646
647            let elem_i = &elements[i];
648            let elem_j = &elements[j];
649
650            // Calculate overlap ratio (intersection / smaller area)
651            let overlap_ratio = calculate_ioa_smaller(&elem_i.bbox, &elem_j.bbox);
652
653            if overlap_ratio > threshold {
654                // Determine which element to remove
655                let is_i_image = elem_i.label == "image";
656                let is_j_image = elem_j.label == "image";
657
658                let drop_index = if is_i_image != is_j_image {
659                    // One is image, one is not: remove the image (text takes priority)
660                    if is_i_image { i } else { j }
661                } else {
662                    // Same type: remove the smaller one
663                    let area_i = (elem_i.bbox.x_max() - elem_i.bbox.x_min())
664                        * (elem_i.bbox.y_max() - elem_i.bbox.y_min());
665                    let area_j = (elem_j.bbox.x_max() - elem_j.bbox.x_min())
666                        * (elem_j.bbox.y_max() - elem_j.bbox.y_min());
667
668                    if area_i < area_j { i } else { j }
669                };
670
671                dropped_indices.insert(drop_index);
672                tracing::debug!(
673                    "Removing overlapping element {} (label={}, overlap={:.2})",
674                    drop_index,
675                    elements[drop_index].label,
676                    overlap_ratio
677                );
678            }
679        }
680    }
681
682    // Build result
683    let mut kept = Vec::new();
684    let mut removed_indices: Vec<usize> = dropped_indices.into_iter().collect();
685    removed_indices.sort();
686
687    for (idx, elem) in elements.iter().enumerate() {
688        if !removed_indices.contains(&idx) {
689            kept.push(elem.clone());
690        }
691    }
692
693    tracing::info!(
694        "Overlap removal: {} elements -> {} kept, {} removed",
695        n,
696        kept.len(),
697        removed_indices.len()
698    );
699
700    OverlapRemovalResult {
701        kept,
702        removed_indices,
703    }
704}
705
706/// Removes overlapping layout blocks, returning only indices to remove.
707///
708/// A lighter-weight version that works with any bbox type implementing the required traits.
709/// This is useful when you want to apply overlap removal to `LayoutElement` from `domain::structure`.
710///
711/// # Arguments
712///
713/// * `bboxes` - Slice of bounding boxes
714/// * `labels` - Slice of labels corresponding to each bbox
715/// * `threshold` - Overlap ratio threshold
716///
717/// # Returns
718///
719/// Set of indices that should be removed
720pub fn get_overlap_removal_indices(
721    bboxes: &[BoundingBox],
722    labels: &[&str],
723    threshold: f32,
724) -> HashSet<usize> {
725    let n = bboxes.len();
726    if n <= 1 || n != labels.len() {
727        return HashSet::new();
728    }
729
730    let mut dropped_indices: HashSet<usize> = HashSet::new();
731
732    for i in 0..n {
733        if dropped_indices.contains(&i) {
734            continue;
735        }
736
737        for j in (i + 1)..n {
738            if dropped_indices.contains(&j) {
739                continue;
740            }
741
742            let overlap_ratio = calculate_ioa_smaller(&bboxes[i], &bboxes[j]);
743
744            if overlap_ratio > threshold {
745                let is_i_image = labels[i] == "image";
746                let is_j_image = labels[j] == "image";
747
748                let drop_index = if is_i_image != is_j_image {
749                    if is_i_image { i } else { j }
750                } else {
751                    let area_i = (bboxes[i].x_max() - bboxes[i].x_min())
752                        * (bboxes[i].y_max() - bboxes[i].y_min());
753                    let area_j = (bboxes[j].x_max() - bboxes[j].x_min())
754                        * (bboxes[j].y_max() - bboxes[j].y_min());
755
756                    if area_i < area_j { i } else { j }
757                };
758
759                dropped_indices.insert(drop_index);
760            }
761        }
762    }
763
764    dropped_indices
765}
766
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for `BoundingBox::from_coords`.
    fn bbox(x1: f32, y1: f32, x2: f32, y2: f32) -> BoundingBox {
        BoundingBox::from_coords(x1, y1, x2, y2)
    }

    /// Shorthand for a content-less `LayoutBox`.
    fn layout(x1: f32, y1: f32, x2: f32, y2: f32, label: &str) -> LayoutBox {
        LayoutBox::new(bbox(x1, y1, x2, y2), label.to_string())
    }

    /// Position of the element labelled `label` in `boxes`.
    fn position_of(boxes: &[LayoutBox], label: &str) -> usize {
        boxes.iter().position(|e| e.label == label).unwrap()
    }

    #[test]
    fn test_get_overlap_boxes_idx() {
        let ocr_boxes = [
            bbox(10.0, 10.0, 50.0, 30.0),     // inside region
            bbox(60.0, 60.0, 100.0, 80.0),    // inside region
            bbox(200.0, 200.0, 250.0, 220.0), // outside region
        ];
        let layout_regions = [bbox(0.0, 0.0, 150.0, 150.0)];

        let matched = get_overlap_boxes_idx(&ocr_boxes, &layout_regions, 3.0);

        // Only the first two boxes fall inside the region.
        assert_eq!(matched.len(), 2);
        assert!(matched.contains(&0) && matched.contains(&1));
        assert!(!matched.contains(&2));
    }

    #[test]
    fn test_associate_ocr_with_layout_within() {
        let ocr_boxes = [
            bbox(10.0, 10.0, 50.0, 30.0),
            bbox(200.0, 200.0, 250.0, 220.0),
        ];
        let layout_regions = [bbox(0.0, 0.0, 100.0, 100.0)];

        let association = associate_ocr_with_layout(&ocr_boxes, &layout_regions, true, 3.0);

        assert_eq!(association.matched_indices, vec![0]);
        assert_eq!(association.unmatched_indices, vec![1]);
    }

    #[test]
    fn test_associate_ocr_with_layout_outside() {
        let ocr_boxes = [
            bbox(10.0, 10.0, 50.0, 30.0),
            bbox(200.0, 200.0, 250.0, 220.0),
        ];
        let layout_regions = [bbox(0.0, 0.0, 100.0, 100.0)];

        // flag_within=false reports the boxes outside the regions as matched.
        let association = associate_ocr_with_layout(&ocr_boxes, &layout_regions, false, 3.0);

        assert_eq!(association.matched_indices, vec![1]);
    }

    #[test]
    fn test_sort_layout_boxes_single_column() {
        let elements = [
            layout(10.0, 50.0, 200.0, 70.0, "text"),  // bottom
            layout(10.0, 10.0, 200.0, 30.0, "title"), // top
        ];

        let sorted = sort_layout_boxes(&elements, 300.0);
        let labels: Vec<&str> = sorted.iter().map(|e| e.label.as_str()).collect();

        // Top element first, bottom element second.
        assert_eq!(labels[..2], ["title", "text"]);
    }

    #[test]
    fn test_sort_layout_boxes_two_columns() {
        let image_width = 400.0;

        let elements = [
            // Left column: x1 < w/4 and x2 < 3w/5
            layout(10.0, 100.0, 90.0, 120.0, "left_bottom"),
            layout(10.0, 50.0, 90.0, 70.0, "left_top"),
            // Right column: x1 > 2w/5
            layout(250.0, 100.0, 390.0, 120.0, "right_bottom"),
            layout(250.0, 50.0, 390.0, 70.0, "right_top"),
            // Full width: neither left nor right
            layout(10.0, 10.0, 390.0, 30.0, "title"),
        ];

        let sorted = sort_layout_boxes(&elements, image_width);

        // The full-width title at the top comes first.
        assert_eq!(sorted[0].label, "title");

        // Within each column, the top box precedes the bottom box.
        assert!(position_of(&sorted, "left_top") < position_of(&sorted, "left_bottom"));
        assert!(position_of(&sorted, "right_top") < position_of(&sorted, "right_bottom"));
    }

    #[test]
    fn test_sort_layout_boxes_empty() {
        assert!(sort_layout_boxes(&[], 300.0).is_empty());
    }

    #[test]
    fn test_sort_layout_boxes_single_element() {
        let sorted = sort_layout_boxes(&[layout(10.0, 10.0, 100.0, 30.0, "text")], 300.0);

        assert_eq!(sorted.len(), 1);
        assert_eq!(sorted[0].label, "text");
    }
}