// oar_ocr_core/processors/layout_utils.rs

//! Layout parsing utilities.
//!
//! This module provides utilities for layout analysis: associating OCR results
//! with layout regions, sorting layout boxes into reading order, reconciling
//! table cells from structure recognition with detected cells, and removing
//! overlapping layout blocks.
6
7use crate::processors::BoundingBox;
8use std::collections::HashSet;
9
/// Result of partitioning OCR boxes against a set of layout regions.
///
/// As produced by `associate_ocr_with_layout`, both vectors hold indices into
/// the `ocr_boxes` slice that was passed in, each in ascending order, and every
/// OCR box index appears in exactly one of them.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct LayoutOCRAssociation {
    /// Indices of the OCR boxes that were selected: boxes overlapping a layout
    /// region when `flag_within` is `true`, boxes overlapping no layout region
    /// when `flag_within` is `false`.
    pub matched_indices: Vec<usize>,
    /// Indices of all remaining OCR boxes (the complement of `matched_indices`).
    pub unmatched_indices: Vec<usize>,
}
18
19/// Get indices of OCR boxes that overlap with layout regions.
20///
21/// This function checks which OCR boxes have significant overlap with any of the
22/// layout regions. An overlap is considered significant if the intersection has
23/// both width and height greater than the threshold (default: 3 pixels).
24///
25/// This follows standard overlap detection implementation.
26///
27/// # Arguments
28///
29/// * `ocr_boxes` - Slice of OCR bounding boxes
30/// * `layout_regions` - Slice of layout region bounding boxes
31/// * `threshold` - Minimum intersection dimension (default: 3.0 pixels)
32///
33/// # Returns
34///
35/// Vector of indices of OCR boxes that overlap with any layout region
36pub fn get_overlap_boxes_idx(
37    ocr_boxes: &[BoundingBox],
38    layout_regions: &[BoundingBox],
39    threshold: f32,
40) -> Vec<usize> {
41    let mut matched_indices = Vec::new();
42
43    if ocr_boxes.is_empty() || layout_regions.is_empty() {
44        return matched_indices;
45    }
46
47    // For each layout region, find overlapping OCR boxes
48    for layout_region in layout_regions {
49        for (idx, ocr_box) in ocr_boxes.iter().enumerate() {
50            if ocr_box.overlaps_with(layout_region, threshold) {
51                matched_indices.push(idx);
52            }
53        }
54    }
55
56    matched_indices
57}
58
59/// Associate OCR results with layout regions.
60///
61/// This function filters OCR boxes based on whether they are within or outside
62/// the specified layout regions.
63///
64/// This follows standard region association implementation.
65///
66/// # Arguments
67///
68/// * `ocr_boxes` - Slice of OCR bounding boxes
69/// * `layout_regions` - Slice of layout region bounding boxes
70/// * `flag_within` - If true, return boxes within regions; if false, return boxes outside regions
71/// * `threshold` - Minimum intersection dimension for overlap detection
72///
73/// # Returns
74///
75/// `LayoutOCRAssociation` containing matched and unmatched indices
76pub fn associate_ocr_with_layout(
77    ocr_boxes: &[BoundingBox],
78    layout_regions: &[BoundingBox],
79    flag_within: bool,
80    threshold: f32,
81) -> LayoutOCRAssociation {
82    let overlap_indices = get_overlap_boxes_idx(ocr_boxes, layout_regions, threshold);
83    let overlap_set: HashSet<usize> = overlap_indices.into_iter().collect();
84
85    let mut matched_indices = Vec::new();
86    let mut unmatched_indices = Vec::new();
87
88    for (idx, _) in ocr_boxes.iter().enumerate() {
89        let is_overlapping = overlap_set.contains(&idx);
90
91        if flag_within {
92            // Return boxes within regions
93            if is_overlapping {
94                matched_indices.push(idx);
95            } else {
96                unmatched_indices.push(idx);
97            }
98        } else {
99            // Return boxes outside regions
100            if !is_overlapping {
101                matched_indices.push(idx);
102            } else {
103                unmatched_indices.push(idx);
104            }
105        }
106    }
107
108    LayoutOCRAssociation {
109        matched_indices,
110        unmatched_indices,
111    }
112}
113
114/// Layout box with bounding box and label for sorting/processing purposes.
115///
116/// This is a lightweight structure used for layout processing utilities like
117/// sorting and OCR association. For final structured output, use
118/// `LayoutElement` from `domain::structure`.
119#[derive(Debug, Clone)]
120pub struct LayoutBox {
121    /// Bounding box of the layout element
122    pub bbox: BoundingBox,
123    /// Label/type of the layout element (e.g., "text", "title", "table", "figure")
124    pub label: String,
125    /// Optional content text
126    pub content: Option<String>,
127}
128
129impl LayoutBox {
130    /// Create a new layout box.
131    pub fn new(bbox: BoundingBox, label: String) -> Self {
132        Self {
133            bbox,
134            label,
135            content: None,
136        }
137    }
138
139    /// Create a layout box with content.
140    pub fn with_content(bbox: BoundingBox, label: String, content: String) -> Self {
141        Self {
142            bbox,
143            label,
144            content: Some(content),
145        }
146    }
147}
148
149/// Sort layout boxes in reading order with column detection.
150///
151/// This function sorts layout boxes from top to bottom, left to right, with special
152/// handling for two-column layouts. Boxes are first sorted by y-coordinate, then
153/// separated into left and right columns based on their x-coordinate.
154///
155/// The algorithm:
156/// 1. Sort boxes by (y, x) coordinates
157/// 2. Identify left column boxes (x1 < w/4 and x2 < 3w/5)
158/// 3. Identify right column boxes (x1 > 2w/5)
159/// 4. Other boxes are considered full-width
160/// 5. Within each column, sort by y-coordinate
161///
162/// This follows standard layout sorting implementation.
163///
164/// # Arguments
165///
166/// * `elements` - Slice of layout boxes to sort
167/// * `image_width` - Width of the image for column detection
168///
169/// # Returns
170///
171/// A vector of sorted layout boxes
172pub fn sort_layout_boxes(elements: &[LayoutBox], image_width: f32) -> Vec<LayoutBox> {
173    let num_boxes = elements.len();
174
175    if num_boxes <= 1 {
176        return elements.to_vec();
177    }
178
179    // Sort by y-coordinate first, then x-coordinate
180    let mut sorted: Vec<LayoutBox> = elements.to_vec();
181    sorted.sort_by(|a, b| {
182        let a_y = a.bbox.y_min();
183        let a_x = a.bbox.x_min();
184        let b_y = b.bbox.y_min();
185        let b_x = b.bbox.x_min();
186
187        match a_y.partial_cmp(&b_y) {
188            Some(std::cmp::Ordering::Equal) => {
189                a_x.partial_cmp(&b_x).unwrap_or(std::cmp::Ordering::Equal)
190            }
191            other => other.unwrap_or(std::cmp::Ordering::Equal),
192        }
193    });
194
195    let mut result = Vec::new();
196    let mut left_column = Vec::new();
197    let mut right_column = Vec::new();
198
199    let w = image_width;
200    let mut i = 0;
201
202    while i < num_boxes {
203        let elem = &sorted[i];
204        let x1 = elem.bbox.x_min();
205        let x2 = elem.bbox.x_max();
206
207        // Check if box is in left column
208        if x1 < w / 4.0 && x2 < 3.0 * w / 5.0 {
209            left_column.push(elem.clone());
210        }
211        // Check if box is in right column
212        else if x1 > 2.0 * w / 5.0 {
213            right_column.push(elem.clone());
214        }
215        // Full-width box - flush columns and add this box
216        else {
217            // Add accumulated column boxes
218            result.append(&mut left_column);
219            result.append(&mut right_column);
220            result.push(elem.clone());
221        }
222
223        i += 1;
224    }
225
226    // Sort left and right columns by y-coordinate
227    left_column.sort_by(|a, b| {
228        a.bbox
229            .y_min()
230            .partial_cmp(&b.bbox.y_min())
231            .unwrap_or(std::cmp::Ordering::Equal)
232    });
233    right_column.sort_by(|a, b| {
234        a.bbox
235            .y_min()
236            .partial_cmp(&b.bbox.y_min())
237            .unwrap_or(std::cmp::Ordering::Equal)
238    });
239
240    // Add remaining column boxes
241    result.append(&mut left_column);
242    result.append(&mut right_column);
243
244    result
245}
246
247/// Reconciles structure recognition cells with detected cells.
248///
249/// This function aligns the number of output cells with the structure cells (N).
250/// - If multiple detected cells map to one structure cell, they are merged (compressed).
251/// - If a structure cell has no matching detected cell, the original structure box is kept (filled).
252///
253/// # Arguments
254/// * `structure_cells` - Cells derived from structure recognition (provides logical N)
255/// * `detected_cells` - Cells from detection model (provides precise geometry)
256///
257/// # Returns
258/// * `Vec<BoundingBox>` - Reconciled bounding boxes of length N.
259pub fn reconcile_table_cells(
260    structure_cells: &[BoundingBox],
261    detected_cells: &[BoundingBox],
262) -> Vec<BoundingBox> {
263    let n = structure_cells.len();
264    if n == 0 {
265        return Vec::new();
266    }
267    if detected_cells.is_empty() {
268        return structure_cells.to_vec();
269    }
270
271    // If detection produces significantly more cells than the table structure,
272    // reduce them using KMeans-style clustering on box centers.
273    let mut det_boxes: Vec<BoundingBox> = detected_cells.to_vec();
274    if det_boxes.len() > n {
275        det_boxes = combine_rectangles_kmeans(&det_boxes, n);
276    }
277
278    // Assignments: structure_idx -> list of detected_indices
279    let mut assignments: Vec<Vec<usize>> = vec![Vec::new(); n];
280
281    // Assign each detected cell to the best matching structure cell
282    for (det_idx, det_box) in det_boxes.iter().enumerate() {
283        let mut best_ioa = 0.001f32; // Minimal threshold
284        let mut best_struct_idx: Option<usize> = None;
285
286        let det_area = (det_box.x_max() - det_box.x_min()) * (det_box.y_max() - det_box.y_min());
287
288        for (struct_idx, struct_box) in structure_cells.iter().enumerate() {
289            // Use Intersection over Area (IoA) of detection for assignment.
290            // This properly handles cases where the structure cell has rowspan/colspan
291            // and is significantly larger than the detected text bounding box.
292            let inter_x1 = det_box.x_min().max(struct_box.x_min());
293            let inter_y1 = det_box.y_min().max(struct_box.y_min());
294            let inter_x2 = det_box.x_max().min(struct_box.x_max());
295            let inter_y2 = det_box.y_max().min(struct_box.y_max());
296
297            let inter_area = (inter_x2 - inter_x1).max(0.0) * (inter_y2 - inter_y1).max(0.0);
298
299            let ioa = if det_area > 0.0 {
300                inter_area / det_area
301            } else {
302                0.0
303            };
304
305            if ioa > best_ioa {
306                best_ioa = ioa;
307                best_struct_idx = Some(struct_idx);
308            }
309        }
310
311        if let Some(idx) = best_struct_idx {
312            assignments[idx].push(det_idx);
313        }
314    }
315
316    // Build result
317    let mut reconciled = Vec::with_capacity(n);
318    for i in 0..n {
319        let assigned = &assignments[i];
320        if assigned.is_empty() {
321            // Fill: No matching detection, keep original structure box
322            reconciled.push(structure_cells[i].clone());
323        } else if assigned.len() == 1 {
324            // Exact match: Use detected box
325            reconciled.push(det_boxes[assigned[0]].clone());
326        } else {
327            // Compress: Multiple detections map to one structure cell
328            // Merge them by taking the bounding box of all detections
329            let mut merged = det_boxes[assigned[0]].clone();
330            for &idx in &assigned[1..] {
331                merged = merged.union(&det_boxes[idx]);
332            }
333            reconciled.push(merged);
334        }
335    }
336
337    reconciled
338}
339
340/// Reprocesses detected table cell boxes using OCR boxes to better match the
341/// structure model's expected cell count.
342///
343/// This mirrors cell detection results reprocessing in
344/// `table_recognition/pipeline_v2.py`:
345/// - If detected cells > target_n, keep top-N by score.
346/// - Find OCR boxes not sufficiently covered by any cell (IoA >= 0.6).
347/// - If missing OCR boxes exist, supplement/merge boxes with KMeans-style clustering.
348/// - If final count is too small, fall back to clustering OCR boxes.
349///
350/// All boxes must be in the same coordinate system (typically table-crop coords).
351pub fn reprocess_table_cells_with_ocr(
352    detected_cells: &[BoundingBox],
353    detected_scores: &[f32],
354    ocr_boxes: &[BoundingBox],
355    target_n: usize,
356) -> Vec<BoundingBox> {
357    if target_n == 0 {
358        return Vec::new();
359    }
360
361    // If no detected cells, fall back to OCR clustering.
362    if detected_cells.is_empty() {
363        return combine_rectangles_kmeans(ocr_boxes, target_n);
364    }
365
366    // Defensive: scores length mismatch -> assume uniform.
367    let scores: Vec<f32> = if detected_scores.len() == detected_cells.len() {
368        detected_scores.to_vec()
369    } else {
370        vec![1.0; detected_cells.len()]
371    };
372
373    let mut cells: Vec<BoundingBox> = detected_cells.to_vec();
374
375    let mut more_cells_flag = false;
376    if cells.len() == target_n {
377        return cells;
378    } else if cells.len() > target_n {
379        more_cells_flag = true;
380        // Keep top target_n by score (descending).
381        let mut idxs: Vec<usize> = (0..cells.len()).collect();
382        idxs.sort_by(|&a, &b| {
383            scores[b]
384                .partial_cmp(&scores[a])
385                .unwrap_or(std::cmp::Ordering::Equal)
386        });
387        idxs.truncate(target_n);
388        cells = idxs.iter().map(|&i| cells[i].clone()).collect();
389    }
390
391    // Compute IoA (intersection / ocr_area) between OCR and cell boxes.
392    fn ioa_ocr_in_cell(ocr: &BoundingBox, cell: &BoundingBox) -> f32 {
393        let inter = ocr.intersection_area(cell);
394        if inter <= 0.0 {
395            return 0.0;
396        }
397        let area = (ocr.x_max() - ocr.x_min()) * (ocr.y_max() - ocr.y_min());
398        if area <= 0.0 { 0.0 } else { inter / area }
399    }
400
401    let iou_threshold = 0.6f32;
402    let mut ocr_miss_boxes: Vec<BoundingBox> = Vec::new();
403
404    for ocr_box in ocr_boxes {
405        let mut has_large_ioa = false;
406        let mut merge_ioa_sum = 0.0f32;
407        for cell_box in &cells {
408            let ioa = ioa_ocr_in_cell(ocr_box, cell_box);
409            if ioa > 0.0 {
410                merge_ioa_sum += ioa;
411            }
412            if ioa >= iou_threshold || merge_ioa_sum >= iou_threshold {
413                has_large_ioa = true;
414                break;
415            }
416        }
417        if !has_large_ioa {
418            ocr_miss_boxes.push(ocr_box.clone());
419        }
420    }
421
422    let mut final_results: Vec<BoundingBox>;
423
424    if ocr_miss_boxes.is_empty() {
425        final_results = cells;
426    } else if more_cells_flag {
427        // More cells than expected: merge cells + missing OCR boxes to target_n.
428        let mut merged = cells.clone();
429        merged.extend(ocr_miss_boxes);
430        final_results = combine_rectangles_kmeans(&merged, target_n);
431    } else {
432        // Fewer cells than expected: supplement with clustered missing OCR boxes.
433        let need_n = target_n.saturating_sub(cells.len());
434        let supp = combine_rectangles_kmeans(&ocr_miss_boxes, need_n);
435        final_results = cells;
436        final_results.extend(supp);
437    }
438
439    // If still too few, fall back to clustering OCR boxes.
440    if final_results.len() as f32 <= 0.6 * target_n as f32 {
441        final_results = combine_rectangles_kmeans(ocr_boxes, target_n);
442    }
443
444    final_results
445}
446
447/// Combines rectangles into at most `target_n` rectangles using KMeans-style clustering
448/// on box centers.
449///
450/// Uses K-Means++ initialization for better cluster center selection.
451pub fn combine_rectangles_kmeans(rectangles: &[BoundingBox], target_n: usize) -> Vec<BoundingBox> {
452    let num_rects = rectangles.len();
453    if num_rects == 0 || target_n == 0 {
454        return Vec::new();
455    }
456    if target_n >= num_rects {
457        return rectangles.to_vec();
458    }
459
460    // Represent each rectangle by its center point (x, y)
461    let points: Vec<(f32, f32)> = rectangles
462        .iter()
463        .map(|r| {
464            let cx = (r.x_min() + r.x_max()) * 0.5;
465            let cy = (r.y_min() + r.y_max()) * 0.5;
466            (cx, cy)
467        })
468        .collect();
469
470    // Initialize cluster centers using K-Means++ algorithm
471    let centers = kmeans_maxdist_init(&points, target_n);
472    let mut centers = centers;
473    let mut labels: Vec<usize> = vec![0; num_rects];
474
475    let max_iters = 10;
476    for _ in 0..max_iters {
477        let mut changed = false;
478
479        // Assignment step: assign each point to nearest center
480        for (i, &(px, py)) in points.iter().enumerate() {
481            let mut best_idx = 0usize;
482            let mut best_dist = f32::MAX;
483            for (c_idx, &(cx, cy)) in centers.iter().enumerate() {
484                let dx = px - cx;
485                let dy = py - cy;
486                let dist = dx * dx + dy * dy;
487                if dist < best_dist {
488                    best_dist = dist;
489                    best_idx = c_idx;
490                }
491            }
492            if labels[i] != best_idx {
493                labels[i] = best_idx;
494                changed = true;
495            }
496        }
497
498        // Recompute centers
499        let mut sums: Vec<(f32, f32, usize)> = vec![(0.0, 0.0, 0); target_n];
500        for (i, &(px, py)) in points.iter().enumerate() {
501            let l = labels[i];
502            sums[l].0 += px;
503            sums[l].1 += py;
504            sums[l].2 += 1;
505        }
506        for (c_idx, center) in centers.iter_mut().enumerate() {
507            let (sx, sy, count) = sums[c_idx];
508            if count > 0 {
509                center.0 = sx / count as f32;
510                center.1 = sy / count as f32;
511            }
512        }
513
514        if !changed {
515            break;
516        }
517    }
518
519    // Build combined rectangles per cluster
520    let mut combined: Vec<BoundingBox> = Vec::new();
521    for cluster_idx in 0..target_n {
522        let mut first = true;
523        let mut min_x = 0.0f32;
524        let mut min_y = 0.0f32;
525        let mut max_x = 0.0f32;
526        let mut max_y = 0.0f32;
527
528        for (i, rect) in rectangles.iter().enumerate() {
529            if labels[i] == cluster_idx {
530                if first {
531                    min_x = rect.x_min();
532                    min_y = rect.y_min();
533                    max_x = rect.x_max();
534                    max_y = rect.y_max();
535                    first = false;
536                } else {
537                    min_x = min_x.min(rect.x_min());
538                    min_y = min_y.min(rect.y_min());
539                    max_x = max_x.max(rect.x_max());
540                    max_y = max_y.max(rect.y_max());
541                }
542            }
543        }
544
545        if !first {
546            combined.push(BoundingBox::from_coords(min_x, min_y, max_x, max_y));
547        }
548    }
549
550    if combined.is_empty() {
551        rectangles.to_vec()
552    } else {
553        combined
554    }
555}
556
/// Deterministic K-Means seeding by farthest-point selection.
///
/// The first center is the point with the median x-coordinate. Every further
/// center is the point whose squared distance to its nearest already-chosen
/// center is largest (the earliest such point on ties). No random numbers are
/// involved, so the same input always yields the same centers.
///
/// # Arguments
///
/// * `points` - Slice of (x, y) points to cluster.
/// * `k` - Number of clusters.
///
/// # Returns
///
/// Up to `k` initial cluster centers: nothing when `points` is empty or `k` is
/// 0, a copy of `points` when `k >= points.len()`, and fewer than `k` centers
/// when the points do not contain enough distinct positions.
fn kmeans_maxdist_init(points: &[(f32, f32)], k: usize) -> Vec<(f32, f32)> {
    if k == 0 || points.is_empty() {
        return Vec::new();
    }
    if points.len() <= k {
        return points.to_vec();
    }

    // First center: the point sitting at the median x position.
    let mut by_x: Vec<usize> = (0..points.len()).collect();
    by_x.sort_by(|&lhs, &rhs| {
        let (x_lhs, x_rhs) = (points[lhs].0, points[rhs].0);
        x_lhs
            .partial_cmp(&x_rhs)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let mut seeds: Vec<(f32, f32)> = Vec::with_capacity(k);
    seeds.push(points[by_x[by_x.len() / 2]]);

    // Every pass below either adds exactly one center or stops the search.
    while seeds.len() < k {
        // Squared distance from each point to its nearest chosen center.
        let nearest_sq: Vec<f32> = points
            .iter()
            .map(|&(px, py)| {
                seeds
                    .iter()
                    .map(|&(sx, sy)| {
                        let (dx, dy) = (px - sx, py - sy);
                        dx * dx + dy * dy
                    })
                    .fold(f32::MAX, f32::min)
            })
            .collect();
        let spread = nearest_sq.iter().fold(0.0f32, |acc, &dist| acc + dist);

        if spread <= 0.0 {
            // Every point coincides with a chosen center: take any point that
            // is not yet a center, or give up when there is none.
            match points.iter().find(|candidate| !seeds.contains(candidate)) {
                Some(&unused) => seeds.push(unused),
                None => break,
            }
            continue;
        }

        // Farthest point wins; only a strictly larger distance replaces the
        // current pick, so the earliest point is kept on ties.
        let (farthest, _) = nearest_sq
            .iter()
            .enumerate()
            .fold((0usize, 0.0f32), |best, (idx, &dist)| {
                if dist > best.1 { (idx, dist) } else { best }
            });
        seeds.push(points[farthest]);
    }

    seeds
}
642
643/// Calculates Intersection over Area (IoA) - intersection / smaller box area.
644fn calculate_ioa_smaller(a: &BoundingBox, b: &BoundingBox) -> f32 {
645    let inter_x1 = a.x_min().max(b.x_min());
646    let inter_y1 = a.y_min().max(b.y_min());
647    let inter_x2 = a.x_max().min(b.x_max());
648    let inter_y2 = a.y_max().min(b.y_max());
649
650    let inter_area = (inter_x2 - inter_x1).max(0.0) * (inter_y2 - inter_y1).max(0.0);
651
652    let area_a = (a.x_max() - a.x_min()) * (a.y_max() - a.y_min());
653    let area_b = (b.x_max() - b.x_min()) * (b.y_max() - b.y_min());
654
655    let smaller_area = area_a.min(area_b);
656
657    if smaller_area <= 0.0 {
658        0.0
659    } else {
660        inter_area / smaller_area
661    }
662}
663
/// Result of overlap removal.
///
/// As produced by `remove_overlap_blocks`, `kept` preserves the order of the
/// input slice and `removed_indices` is sorted in ascending order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OverlapRemovalResult<T> {
    /// Elements that were kept after overlap removal
    pub kept: Vec<T>,
    /// Indices (into the input slice) of the elements that were removed
    pub removed_indices: Vec<usize>,
}
672
673/// Removes overlapping layout blocks based on overlap ratio threshold.
674///
675/// This follows standard overlap removal implementation in
676/// `layout_parsing/utils.py`. When two blocks overlap significantly:
677/// - If one is an image and one is not, the image is removed (text takes priority)
678/// - Otherwise, the smaller block is removed
679///
680/// # Arguments
681///
682/// * `elements` - Slice of layout elements to process
683/// * `threshold` - Overlap ratio threshold (default: 0.65)
684///   If intersection/smaller_area > threshold, blocks are considered overlapping
685///
686/// # Returns
687///
688/// `OverlapRemovalResult` containing kept elements and indices of removed elements
689///
690/// # Example
691///
692/// ```ignore
693/// use oar_ocr_core::processors::layout_utils::{remove_overlap_blocks, LayoutBox};
694/// use oar_ocr_core::processors::BoundingBox;
695///
696/// let elements = vec![
697///     LayoutBox::new(BoundingBox::from_coords(0.0, 0.0, 100.0, 100.0), "text".to_string()),
698///     LayoutBox::new(BoundingBox::from_coords(10.0, 10.0, 90.0, 90.0), "text".to_string()),
699/// ];
700///
701/// let result = remove_overlap_blocks(&elements, 0.65);
702/// assert_eq!(result.kept.len(), 1); // Smaller overlapping box was removed
703/// ```
704pub fn remove_overlap_blocks(
705    elements: &[LayoutBox],
706    threshold: f32,
707) -> OverlapRemovalResult<LayoutBox> {
708    let n = elements.len();
709    if n <= 1 {
710        return OverlapRemovalResult {
711            kept: elements.to_vec(),
712            removed_indices: Vec::new(),
713        };
714    }
715
716    let mut dropped_indices: HashSet<usize> = HashSet::new();
717
718    // Compare all pairs of elements
719    for i in 0..n {
720        if dropped_indices.contains(&i) {
721            continue;
722        }
723
724        for j in (i + 1)..n {
725            if dropped_indices.contains(&j) {
726                continue;
727            }
728
729            let elem_i = &elements[i];
730            let elem_j = &elements[j];
731
732            // Calculate overlap ratio (intersection / smaller area)
733            let overlap_ratio = calculate_ioa_smaller(&elem_i.bbox, &elem_j.bbox);
734
735            if overlap_ratio > threshold {
736                // Determine which element to remove
737                let is_i_image = elem_i.label == "image";
738                let is_j_image = elem_j.label == "image";
739
740                let drop_index = if is_i_image != is_j_image {
741                    // One is image, one is not: remove the image (text takes priority)
742                    if is_i_image { i } else { j }
743                } else {
744                    // Same type: remove the smaller one
745                    let area_i = (elem_i.bbox.x_max() - elem_i.bbox.x_min())
746                        * (elem_i.bbox.y_max() - elem_i.bbox.y_min());
747                    let area_j = (elem_j.bbox.x_max() - elem_j.bbox.x_min())
748                        * (elem_j.bbox.y_max() - elem_j.bbox.y_min());
749
750                    if area_i < area_j { i } else { j }
751                };
752
753                dropped_indices.insert(drop_index);
754                tracing::debug!(
755                    "Removing overlapping element {} (label={}, overlap={:.2})",
756                    drop_index,
757                    elements[drop_index].label,
758                    overlap_ratio
759                );
760            }
761        }
762    }
763
764    // Build result
765    let mut kept = Vec::new();
766    let mut removed_indices: Vec<usize> = dropped_indices.into_iter().collect();
767    removed_indices.sort();
768
769    for (idx, elem) in elements.iter().enumerate() {
770        if !removed_indices.contains(&idx) {
771            kept.push(elem.clone());
772        }
773    }
774
775    tracing::info!(
776        "Overlap removal: {} elements -> {} kept, {} removed",
777        n,
778        kept.len(),
779        removed_indices.len()
780    );
781
782    OverlapRemovalResult {
783        kept,
784        removed_indices,
785    }
786}
787
788/// Removes overlapping layout blocks, returning only indices to remove.
789///
790/// A lighter-weight version that works with any bbox type implementing the required traits.
791/// This is useful when you want to apply overlap removal to `LayoutElement` from `domain::structure`.
792///
793/// # Arguments
794///
795/// * `bboxes` - Slice of bounding boxes
796/// * `labels` - Slice of labels corresponding to each bbox
797/// * `threshold` - Overlap ratio threshold
798///
799/// # Returns
800///
801/// Set of indices that should be removed
802pub fn get_overlap_removal_indices(
803    bboxes: &[BoundingBox],
804    labels: &[&str],
805    threshold: f32,
806) -> HashSet<usize> {
807    let n = bboxes.len();
808    if n <= 1 || n != labels.len() {
809        return HashSet::new();
810    }
811
812    let mut dropped_indices: HashSet<usize> = HashSet::new();
813
814    for i in 0..n {
815        if dropped_indices.contains(&i) {
816            continue;
817        }
818
819        for j in (i + 1)..n {
820            if dropped_indices.contains(&j) {
821                continue;
822            }
823
824            let overlap_ratio = calculate_ioa_smaller(&bboxes[i], &bboxes[j]);
825
826            if overlap_ratio > threshold {
827                let is_i_image = labels[i] == "image";
828                let is_j_image = labels[j] == "image";
829
830                let drop_index = if is_i_image != is_j_image {
831                    if is_i_image { i } else { j }
832                } else {
833                    let area_i = (bboxes[i].x_max() - bboxes[i].x_min())
834                        * (bboxes[i].y_max() - bboxes[i].y_min());
835                    let area_j = (bboxes[j].x_max() - bboxes[j].x_min())
836                        * (bboxes[j].y_max() - bboxes[j].y_min());
837
838                    if area_i < area_j { i } else { j }
839                };
840
841                dropped_indices.insert(drop_index);
842            }
843        }
844    }
845
846    dropped_indices
847}
848
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture shared by the association tests.
    ///
    /// Returns two OCR boxes followed by a single layout region spanning
    /// (0, 0)-(100, 100). The first OCR box sits inside that region; the
    /// second sits well outside it.
    fn association_fixture() -> (Vec<BoundingBox>, Vec<BoundingBox>) {
        let ocr_boxes = vec![
            BoundingBox::from_coords(10.0, 10.0, 50.0, 30.0),
            BoundingBox::from_coords(200.0, 200.0, 250.0, 220.0),
        ];
        let layout_regions = vec![BoundingBox::from_coords(0.0, 0.0, 100.0, 100.0)];
        (ocr_boxes, layout_regions)
    }

    /// Builds a `LayoutBox` from a bounding box and a borrowed label.
    fn labeled(bbox: BoundingBox, label: &str) -> LayoutBox {
        LayoutBox::new(bbox, label.to_string())
    }

    /// Boxes that overlap the region are reported; the distant one is not.
    #[test]
    fn test_get_overlap_boxes_idx() {
        let boxes = [
            // Both of these lie inside the region below.
            BoundingBox::from_coords(10.0, 10.0, 50.0, 30.0),
            BoundingBox::from_coords(60.0, 60.0, 100.0, 80.0),
            // This one lies far outside it.
            BoundingBox::from_coords(200.0, 200.0, 250.0, 220.0),
        ];
        let regions = [BoundingBox::from_coords(0.0, 0.0, 150.0, 150.0)];

        let hits = get_overlap_boxes_idx(&boxes, &regions, 3.0);

        // Exactly the first two indices are expected, in any order.
        assert_eq!(hits.len(), 2);
        for expected in [0usize, 1] {
            assert!(hits.contains(&expected));
        }
        assert!(!hits.contains(&2));
    }

    /// With `flag_within = true`, the inside box is matched and the outside
    /// box is reported as unmatched.
    #[test]
    fn test_associate_ocr_with_layout_within() {
        let (ocr_boxes, layout_regions) = association_fixture();

        let result = associate_ocr_with_layout(&ocr_boxes, &layout_regions, true, 3.0);

        assert_eq!(result.matched_indices, [0]);
        assert_eq!(result.unmatched_indices, [1]);
    }

    /// With `flag_within = false`, the matched set is the box outside the region.
    #[test]
    fn test_associate_ocr_with_layout_outside() {
        let (ocr_boxes, layout_regions) = association_fixture();

        let result = associate_ocr_with_layout(&ocr_boxes, &layout_regions, false, 3.0);

        assert_eq!(result.matched_indices, [1]);
    }

    /// Two stacked boxes given bottom-first come back top-first.
    #[test]
    fn test_sort_layout_boxes_single_column() {
        let elements = vec![
            // Lower box, listed first on purpose.
            labeled(BoundingBox::from_coords(10.0, 50.0, 200.0, 70.0), "text"),
            // Upper box.
            labeled(BoundingBox::from_coords(10.0, 10.0, 200.0, 30.0), "title"),
        ];

        let sorted = sort_layout_boxes(&elements, 300.0);

        assert_eq!(sorted[0].label, "title");
        assert_eq!(sorted[1].label, "text");
    }

    /// Two-column page with a full-width title on top.
    ///
    /// Asserts that the title is first and that, inside each column, the upper
    /// box precedes the lower one. The relative order of the two columns is not
    /// asserted here; the lookups only require all four column boxes to be
    /// present in the output.
    #[test]
    fn test_sort_layout_boxes_two_columns() {
        let elements = vec![
            // Meant as left-column boxes: with w = 400, x1 = 10 < w/4 and x2 = 90 < 3w/5.
            labeled(
                BoundingBox::from_coords(10.0, 100.0, 90.0, 120.0),
                "left_bottom",
            ),
            labeled(BoundingBox::from_coords(10.0, 50.0, 90.0, 70.0), "left_top"),
            // Meant as right-column boxes: x1 = 250 > 2w/5.
            labeled(
                BoundingBox::from_coords(250.0, 100.0, 390.0, 120.0),
                "right_bottom",
            ),
            labeled(
                BoundingBox::from_coords(250.0, 50.0, 390.0, 70.0),
                "right_top",
            ),
            // Spans nearly the whole width, so it belongs to neither column.
            labeled(BoundingBox::from_coords(10.0, 10.0, 390.0, 30.0), "title"),
        ];

        let sorted = sort_layout_boxes(&elements, 400.0);

        assert_eq!(sorted[0].label, "title");

        // Locate every column element; a missing one aborts the test.
        let left_top_at = sorted
            .iter()
            .position(|item| item.label == "left_top")
            .expect("missing expected left_top element");
        let left_bottom_at = sorted
            .iter()
            .position(|item| item.label == "left_bottom")
            .expect("missing expected left_bottom element");
        let right_top_at = sorted
            .iter()
            .position(|item| item.label == "right_top")
            .expect("missing expected right_top element");
        let right_bottom_at = sorted
            .iter()
            .position(|item| item.label == "right_bottom")
            .expect("missing expected right_bottom element");

        // Reading order inside each column is top to bottom.
        assert!(left_top_at < left_bottom_at);
        assert!(right_top_at < right_bottom_at);
    }

    /// Sorting an empty input yields an empty output.
    #[test]
    fn test_sort_layout_boxes_empty() {
        let no_elements = Vec::<LayoutBox>::new();

        let sorted = sort_layout_boxes(&no_elements, 300.0);

        assert!(sorted.is_empty());
    }

    /// A lone element is returned as the only entry.
    #[test]
    fn test_sort_layout_boxes_single_element() {
        let elements = vec![labeled(
            BoundingBox::from_coords(10.0, 10.0, 100.0, 30.0),
            "text",
        )];

        let sorted = sort_layout_boxes(&elements, 300.0);

        assert_eq!(sorted.len(), 1);
        assert_eq!(sorted[0].label, "text");
    }
}