// oar_ocr_core/processors/layout_utils.rs

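//! Layout parsing utilities.
//!
//! This module provides utilities for layout analysis: associating OCR boxes with
//! layout regions, sorting layout boxes into reading order, reconciling and
//! reprocessing table cells, and removing overlapping layout blocks. Several
//! routines mirror a reference Python pipeline; see the notes on the individual
//! functions.
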
use crate::processors::BoundingBox;
use std::collections::HashSet;

/// Result of splitting OCR boxes against a set of layout regions.
///
/// Produced by [`associate_ocr_with_layout`]. Both vectors hold indices into the
/// `ocr_boxes` slice that was passed in, each in ascending order, and together
/// they contain every OCR box index exactly once.
#[derive(Debug, Clone)]
pub struct LayoutOCRAssociation {
    /// Indices of the OCR boxes that satisfy the requested relation: boxes that
    /// overlap a layout region when `flag_within` is `true`, boxes that overlap
    /// no layout region when `flag_within` is `false`.
    pub matched_indices: Vec<usize>,
    /// Indices of all remaining OCR boxes (the complement of `matched_indices`).
    pub unmatched_indices: Vec<usize>,
}
19/// Get indices of OCR boxes that overlap with layout regions.
20///
21/// This function checks which OCR boxes have significant overlap with any of the
22/// layout regions. An overlap is considered significant if the intersection has
23/// both width and height greater than the threshold (default: 3 pixels).
24///
25/// This follows standard overlap detection implementation.
26///
27/// # Arguments
28///
29/// * `ocr_boxes` - Slice of OCR bounding boxes
30/// * `layout_regions` - Slice of layout region bounding boxes
31/// * `threshold` - Minimum intersection dimension (default: 3.0 pixels)
32///
33/// # Returns
34///
35/// Vector of indices of OCR boxes that overlap with any layout region
36pub fn get_overlap_boxes_idx(
37    ocr_boxes: &[BoundingBox],
38    layout_regions: &[BoundingBox],
39    threshold: f32,
40) -> Vec<usize> {
41    let mut matched_indices = Vec::new();
42
43    if ocr_boxes.is_empty() || layout_regions.is_empty() {
44        return matched_indices;
45    }
46
47    // For each layout region, find overlapping OCR boxes
48    for layout_region in layout_regions {
49        for (idx, ocr_box) in ocr_boxes.iter().enumerate() {
50            if ocr_box.overlaps_with(layout_region, threshold) {
51                matched_indices.push(idx);
52            }
53        }
54    }
55
56    matched_indices
57}
58
59/// Associate OCR results with layout regions.
60///
61/// This function filters OCR boxes based on whether they are within or outside
62/// the specified layout regions.
63///
64/// This follows standard region association implementation.
65///
66/// # Arguments
67///
68/// * `ocr_boxes` - Slice of OCR bounding boxes
69/// * `layout_regions` - Slice of layout region bounding boxes
70/// * `flag_within` - If true, return boxes within regions; if false, return boxes outside regions
71/// * `threshold` - Minimum intersection dimension for overlap detection
72///
73/// # Returns
74///
75/// `LayoutOCRAssociation` containing matched and unmatched indices
76pub fn associate_ocr_with_layout(
77    ocr_boxes: &[BoundingBox],
78    layout_regions: &[BoundingBox],
79    flag_within: bool,
80    threshold: f32,
81) -> LayoutOCRAssociation {
82    let overlap_indices = get_overlap_boxes_idx(ocr_boxes, layout_regions, threshold);
83    let overlap_set: HashSet<usize> = overlap_indices.into_iter().collect();
84
85    let mut matched_indices = Vec::new();
86    let mut unmatched_indices = Vec::new();
87
88    for (idx, _) in ocr_boxes.iter().enumerate() {
89        let is_overlapping = overlap_set.contains(&idx);
90
91        if flag_within {
92            // Return boxes within regions
93            if is_overlapping {
94                matched_indices.push(idx);
95            } else {
96                unmatched_indices.push(idx);
97            }
98        } else {
99            // Return boxes outside regions
100            if !is_overlapping {
101                matched_indices.push(idx);
102            } else {
103                unmatched_indices.push(idx);
104            }
105        }
106    }
107
108    LayoutOCRAssociation {
109        matched_indices,
110        unmatched_indices,
111    }
112}
113
114/// Layout box with bounding box and label for sorting/processing purposes.
115///
116/// This is a lightweight structure used for layout processing utilities like
117/// sorting and OCR association. For final structured output, use
118/// `LayoutElement` from `domain::structure`.
119#[derive(Debug, Clone)]
120pub struct LayoutBox {
121    /// Bounding box of the layout element
122    pub bbox: BoundingBox,
123    /// Label/type of the layout element (e.g., "text", "title", "table", "figure")
124    pub label: String,
125    /// Optional content text
126    pub content: Option<String>,
127}
128
129impl LayoutBox {
130    /// Create a new layout box.
131    pub fn new(bbox: BoundingBox, label: String) -> Self {
132        Self {
133            bbox,
134            label,
135            content: None,
136        }
137    }
138
139    /// Create a layout box with content.
140    pub fn with_content(bbox: BoundingBox, label: String, content: String) -> Self {
141        Self {
142            bbox,
143            label,
144            content: Some(content),
145        }
146    }
147}
148
149/// Sort layout boxes in reading order with column detection.
150///
151/// This function sorts layout boxes from top to bottom, left to right, with special
152/// handling for two-column layouts. Boxes are first sorted by y-coordinate, then
153/// separated into left and right columns based on their x-coordinate.
154///
155/// The algorithm:
156/// 1. Sort boxes by (y, x) coordinates
157/// 2. Identify left column boxes (x1 < w/4 and x2 < 3w/5)
158/// 3. Identify right column boxes (x1 > 2w/5)
159/// 4. Other boxes are considered full-width
160/// 5. Within each column, sort by y-coordinate
161///
162/// This follows standard layout sorting implementation.
163///
164/// # Arguments
165///
166/// * `elements` - Slice of layout boxes to sort
167/// * `image_width` - Width of the image for column detection
168///
169/// # Returns
170///
171/// A vector of sorted layout boxes
172pub fn sort_layout_boxes(elements: &[LayoutBox], image_width: f32) -> Vec<LayoutBox> {
173    let num_boxes = elements.len();
174
175    if num_boxes <= 1 {
176        return elements.to_vec();
177    }
178
179    // Sort by y-coordinate first, then x-coordinate
180    let mut sorted: Vec<LayoutBox> = elements.to_vec();
181    sorted.sort_by(|a, b| {
182        let a_y = a.bbox.y_min();
183        let a_x = a.bbox.x_min();
184        let b_y = b.bbox.y_min();
185        let b_x = b.bbox.x_min();
186
187        match a_y.partial_cmp(&b_y) {
188            Some(std::cmp::Ordering::Equal) => {
189                a_x.partial_cmp(&b_x).unwrap_or(std::cmp::Ordering::Equal)
190            }
191            other => other.unwrap_or(std::cmp::Ordering::Equal),
192        }
193    });
194
195    let mut result = Vec::new();
196    let mut left_column = Vec::new();
197    let mut right_column = Vec::new();
198
199    let w = image_width;
200    let mut i = 0;
201
202    while i < num_boxes {
203        let elem = &sorted[i];
204        let x1 = elem.bbox.x_min();
205        let x2 = elem.bbox.x_max();
206
207        // Check if box is in left column
208        if x1 < w / 4.0 && x2 < 3.0 * w / 5.0 {
209            left_column.push(elem.clone());
210        }
211        // Check if box is in right column
212        else if x1 > 2.0 * w / 5.0 {
213            right_column.push(elem.clone());
214        }
215        // Full-width box - flush columns and add this box
216        else {
217            // Add accumulated column boxes
218            result.append(&mut left_column);
219            result.append(&mut right_column);
220            result.push(elem.clone());
221        }
222
223        i += 1;
224    }
225
226    // Sort left and right columns by y-coordinate
227    left_column.sort_by(|a, b| {
228        a.bbox
229            .y_min()
230            .partial_cmp(&b.bbox.y_min())
231            .unwrap_or(std::cmp::Ordering::Equal)
232    });
233    right_column.sort_by(|a, b| {
234        a.bbox
235            .y_min()
236            .partial_cmp(&b.bbox.y_min())
237            .unwrap_or(std::cmp::Ordering::Equal)
238    });
239
240    // Add remaining column boxes
241    result.append(&mut left_column);
242    result.append(&mut right_column);
243
244    result
245}
246
247/// Reconciles structure recognition cells with detected cells.
248///
249/// This function aligns the number of output cells with the structure cells (N).
250/// - If multiple detected cells map to one structure cell, they are merged (compressed).
251/// - If a structure cell has no matching detected cell, the original structure box is kept (filled).
252///
253/// # Arguments
254/// * `structure_cells` - Cells derived from structure recognition (provides logical N)
255/// * `detected_cells` - Cells from detection model (provides precise geometry)
256///
257/// # Returns
258/// * `Vec<BoundingBox>` - Reconciled bounding boxes of length N.
259pub fn reconcile_table_cells(
260    structure_cells: &[BoundingBox],
261    detected_cells: &[BoundingBox],
262) -> Vec<BoundingBox> {
263    let n = structure_cells.len();
264    if n == 0 {
265        return Vec::new();
266    }
267    if detected_cells.is_empty() {
268        return structure_cells.to_vec();
269    }
270
271    // If detection produces significantly more cells than the table structure,
272    // reduce them using KMeans-style clustering on box centers.
273    let mut det_boxes: Vec<BoundingBox> = detected_cells.to_vec();
274    if det_boxes.len() > n {
275        det_boxes = combine_rectangles_kmeans(&det_boxes, n);
276    }
277
278    // Assignments: structure_idx -> list of detected_indices
279    let mut assignments: Vec<Vec<usize>> = vec![Vec::new(); n];
280
281    // Assign each detected cell to the best matching structure cell
282    for (det_idx, det_box) in det_boxes.iter().enumerate() {
283        let mut best_iou = 0.001f32; // Minimal threshold
284        let mut best_struct_idx: Option<usize> = None;
285
286        for (struct_idx, struct_box) in structure_cells.iter().enumerate() {
287            // Use IoU for assignment
288            // Note: We could also use intersection over detection area to handle
289            // cases where detection is much smaller than structure cell
290            let iou = calculate_iou(det_box, struct_box);
291            if iou > best_iou {
292                best_iou = iou;
293                best_struct_idx = Some(struct_idx);
294            }
295        }
296
297        if let Some(idx) = best_struct_idx {
298            assignments[idx].push(det_idx);
299        }
300    }
301
302    // Build result
303    let mut reconciled = Vec::with_capacity(n);
304    for i in 0..n {
305        let assigned = &assignments[i];
306        if assigned.is_empty() {
307            // Fill: No matching detection, keep original structure box
308            reconciled.push(structure_cells[i].clone());
309        } else if assigned.len() == 1 {
310            // Exact match: Use detected box
311            reconciled.push(det_boxes[assigned[0]].clone());
312        } else {
313            // Compress: Multiple detections map to one structure cell
314            // Merge them by taking the bounding box of all detections
315            let mut merged = det_boxes[assigned[0]].clone();
316            for &idx in &assigned[1..] {
317                merged = merged.union(&det_boxes[idx]);
318            }
319            reconciled.push(merged);
320        }
321    }
322
323    reconciled
324}
325
326/// Reprocesses detected table cell boxes using OCR boxes to better match the
327/// structure model's expected cell count.
328///
329/// This mirrors cell detection results reprocessing in
330/// `table_recognition/pipeline_v2.py`:
331/// - If detected cells > target_n, keep top-N by score.
332/// - Find OCR boxes not sufficiently covered by any cell (IoA >= 0.6).
333/// - If missing OCR boxes exist, supplement/merge boxes with KMeans-style clustering.
334/// - If final count is too small, fall back to clustering OCR boxes.
335///
336/// All boxes must be in the same coordinate system (typically table-crop coords).
337pub fn reprocess_table_cells_with_ocr(
338    detected_cells: &[BoundingBox],
339    detected_scores: &[f32],
340    ocr_boxes: &[BoundingBox],
341    target_n: usize,
342) -> Vec<BoundingBox> {
343    if target_n == 0 {
344        return Vec::new();
345    }
346
347    // If no detected cells, fall back to OCR clustering.
348    if detected_cells.is_empty() {
349        return combine_rectangles_kmeans(ocr_boxes, target_n);
350    }
351
352    // Defensive: scores length mismatch -> assume uniform.
353    let scores: Vec<f32> = if detected_scores.len() == detected_cells.len() {
354        detected_scores.to_vec()
355    } else {
356        vec![1.0; detected_cells.len()]
357    };
358
359    let mut cells: Vec<BoundingBox> = detected_cells.to_vec();
360
361    let mut more_cells_flag = false;
362    if cells.len() == target_n {
363        return cells;
364    } else if cells.len() > target_n {
365        more_cells_flag = true;
366        // Keep top target_n by score (descending).
367        let mut idxs: Vec<usize> = (0..cells.len()).collect();
368        idxs.sort_by(|&a, &b| {
369            scores[b]
370                .partial_cmp(&scores[a])
371                .unwrap_or(std::cmp::Ordering::Equal)
372        });
373        idxs.truncate(target_n);
374        cells = idxs.iter().map(|&i| cells[i].clone()).collect();
375    }
376
377    // Compute IoA (intersection / ocr_area) between OCR and cell boxes.
378    fn ioa_ocr_in_cell(ocr: &BoundingBox, cell: &BoundingBox) -> f32 {
379        let inter = ocr.intersection_area(cell);
380        if inter <= 0.0 {
381            return 0.0;
382        }
383        let area = (ocr.x_max() - ocr.x_min()) * (ocr.y_max() - ocr.y_min());
384        if area <= 0.0 { 0.0 } else { inter / area }
385    }
386
387    let iou_threshold = 0.6f32;
388    let mut ocr_miss_boxes: Vec<BoundingBox> = Vec::new();
389
390    for ocr_box in ocr_boxes {
391        let mut has_large_ioa = false;
392        let mut merge_ioa_sum = 0.0f32;
393        for cell_box in &cells {
394            let ioa = ioa_ocr_in_cell(ocr_box, cell_box);
395            if ioa > 0.0 {
396                merge_ioa_sum += ioa;
397            }
398            if ioa >= iou_threshold || merge_ioa_sum >= iou_threshold {
399                has_large_ioa = true;
400                break;
401            }
402        }
403        if !has_large_ioa {
404            ocr_miss_boxes.push(ocr_box.clone());
405        }
406    }
407
408    let mut final_results: Vec<BoundingBox>;
409
410    if ocr_miss_boxes.is_empty() {
411        final_results = cells;
412    } else if more_cells_flag {
413        // More cells than expected: merge cells + missing OCR boxes to target_n.
414        let mut merged = cells.clone();
415        merged.extend(ocr_miss_boxes);
416        final_results = combine_rectangles_kmeans(&merged, target_n);
417    } else {
418        // Fewer cells than expected: supplement with clustered missing OCR boxes.
419        let need_n = target_n.saturating_sub(cells.len());
420        let supp = combine_rectangles_kmeans(&ocr_miss_boxes, need_n);
421        final_results = cells;
422        final_results.extend(supp);
423    }
424
425    // If still too few, fall back to clustering OCR boxes.
426    if final_results.len() as f32 <= 0.6 * target_n as f32 {
427        final_results = combine_rectangles_kmeans(ocr_boxes, target_n);
428    }
429
430    final_results
431}
432
433/// Combines rectangles into at most `target_n` rectangles using KMeans-style clustering
434/// on box centers.
435///
436/// Uses K-Means++ initialization for better cluster center selection.
437pub fn combine_rectangles_kmeans(rectangles: &[BoundingBox], target_n: usize) -> Vec<BoundingBox> {
438    let num_rects = rectangles.len();
439    if num_rects == 0 || target_n == 0 {
440        return Vec::new();
441    }
442    if target_n >= num_rects {
443        return rectangles.to_vec();
444    }
445
446    // Represent each rectangle by its center point (x, y)
447    let points: Vec<(f32, f32)> = rectangles
448        .iter()
449        .map(|r| {
450            let cx = (r.x_min() + r.x_max()) * 0.5;
451            let cy = (r.y_min() + r.y_max()) * 0.5;
452            (cx, cy)
453        })
454        .collect();
455
456    // Initialize cluster centers using K-Means++ algorithm
457    let centers = kmeans_maxdist_init(&points, target_n);
458    let mut centers = centers;
459    let mut labels: Vec<usize> = vec![0; num_rects];
460
461    let max_iters = 10;
462    for _ in 0..max_iters {
463        let mut changed = false;
464
465        // Assignment step: assign each point to nearest center
466        for (i, &(px, py)) in points.iter().enumerate() {
467            let mut best_idx = 0usize;
468            let mut best_dist = f32::MAX;
469            for (c_idx, &(cx, cy)) in centers.iter().enumerate() {
470                let dx = px - cx;
471                let dy = py - cy;
472                let dist = dx * dx + dy * dy;
473                if dist < best_dist {
474                    best_dist = dist;
475                    best_idx = c_idx;
476                }
477            }
478            if labels[i] != best_idx {
479                labels[i] = best_idx;
480                changed = true;
481            }
482        }
483
484        // Recompute centers
485        let mut sums: Vec<(f32, f32, usize)> = vec![(0.0, 0.0, 0); target_n];
486        for (i, &(px, py)) in points.iter().enumerate() {
487            let l = labels[i];
488            sums[l].0 += px;
489            sums[l].1 += py;
490            sums[l].2 += 1;
491        }
492        for (c_idx, center) in centers.iter_mut().enumerate() {
493            let (sx, sy, count) = sums[c_idx];
494            if count > 0 {
495                center.0 = sx / count as f32;
496                center.1 = sy / count as f32;
497            }
498        }
499
500        if !changed {
501            break;
502        }
503    }
504
505    // Build combined rectangles per cluster
506    let mut combined: Vec<BoundingBox> = Vec::new();
507    for cluster_idx in 0..target_n {
508        let mut first = true;
509        let mut min_x = 0.0f32;
510        let mut min_y = 0.0f32;
511        let mut max_x = 0.0f32;
512        let mut max_y = 0.0f32;
513
514        for (i, rect) in rectangles.iter().enumerate() {
515            if labels[i] == cluster_idx {
516                if first {
517                    min_x = rect.x_min();
518                    min_y = rect.y_min();
519                    max_x = rect.x_max();
520                    max_y = rect.y_max();
521                    first = false;
522                } else {
523                    min_x = min_x.min(rect.x_min());
524                    min_y = min_y.min(rect.y_min());
525                    max_x = max_x.max(rect.x_max());
526                    max_y = max_y.max(rect.y_max());
527                }
528            }
529        }
530
531        if !first {
532            combined.push(BoundingBox::from_coords(min_x, min_y, max_x, max_y));
533        }
534    }
535
536    if combined.is_empty() {
537        rectangles.to_vec()
538    } else {
539        combined
540    }
541}
542
543/// Deterministic K-Means initialization using max-distance selection.
544///
545/// This is a simplified variant of K-Means++ that deterministically selects
546/// the point with maximum distance from existing centers, rather than using
547/// probabilistic selection. This avoids random number generation while still
548/// providing good initial cluster spread.
549///
550/// # Arguments
551///
552/// * `points` - Slice of (x, y) points to cluster.
553/// * `k` - Number of clusters.
554///
555/// # Returns
556///
557/// Vector of k initial cluster centers.
558fn kmeans_maxdist_init(points: &[(f32, f32)], k: usize) -> Vec<(f32, f32)> {
559    if points.is_empty() || k == 0 {
560        return Vec::new();
561    }
562
563    if k >= points.len() {
564        return points.to_vec();
565    }
566
567    let mut centers: Vec<(f32, f32)> = Vec::with_capacity(k);
568
569    // Use a simple deterministic selection based on point index for reproducibility
570    // Select the first center as the point with median x-coordinate
571    let mut sorted_by_x: Vec<usize> = (0..points.len()).collect();
572    sorted_by_x.sort_by(|&a, &b| {
573        points[a]
574            .0
575            .partial_cmp(&points[b].0)
576            .unwrap_or(std::cmp::Ordering::Equal)
577    });
578    let first_idx = sorted_by_x[sorted_by_x.len() / 2];
579    centers.push(points[first_idx]);
580
581    // Select remaining centers using K-Means++ selection
582    for _ in 1..k {
583        // Compute squared distances to nearest center for each point
584        let mut distances: Vec<f32> = Vec::with_capacity(points.len());
585        let mut total_dist = 0.0f32;
586
587        for &(px, py) in points {
588            let min_dist_sq = centers
589                .iter()
590                .map(|&(cx, cy)| {
591                    let dx = px - cx;
592                    let dy = py - cy;
593                    dx * dx + dy * dy
594                })
595                .fold(f32::MAX, f32::min);
596
597            distances.push(min_dist_sq);
598            total_dist += min_dist_sq;
599        }
600
601        if total_dist <= 0.0 {
602            // All points are at existing centers, pick any remaining point
603            if let Some(&point) = points.iter().find(|p| !centers.contains(p)) {
604                centers.push(point);
605            } else {
606                break;
607            }
608            continue;
609        }
610
611        // Select next center deterministically: pick the point with maximum distance
612        // This is simpler than probabilistic K-Means++ but still provides good spread
613        let mut max_dist = 0.0f32;
614        let mut max_idx = 0;
615
616        for (i, &dist) in distances.iter().enumerate() {
617            if dist > max_dist {
618                max_dist = dist;
619                max_idx = i;
620            }
621        }
622
623        centers.push(points[max_idx]);
624    }
625
626    centers
627}
628
629/// Calculates Intersection over Union (IoU) between two bounding boxes.
630fn calculate_iou(a: &BoundingBox, b: &BoundingBox) -> f32 {
631    let inter_x1 = a.x_min().max(b.x_min());
632    let inter_y1 = a.y_min().max(b.y_min());
633    let inter_x2 = a.x_max().min(b.x_max());
634    let inter_y2 = a.y_max().min(b.y_max());
635
636    let inter_area = (inter_x2 - inter_x1).max(0.0) * (inter_y2 - inter_y1).max(0.0);
637
638    let area_a = (a.x_max() - a.x_min()) * (a.y_max() - a.y_min());
639    let area_b = (b.x_max() - b.x_min()) * (b.y_max() - b.y_min());
640
641    let union_area = area_a + area_b - inter_area;
642
643    if union_area <= 0.0 {
644        0.0
645    } else {
646        inter_area / union_area
647    }
648}
649
650/// Calculates Intersection over Area (IoA) - intersection / smaller box area.
651fn calculate_ioa_smaller(a: &BoundingBox, b: &BoundingBox) -> f32 {
652    let inter_x1 = a.x_min().max(b.x_min());
653    let inter_y1 = a.y_min().max(b.y_min());
654    let inter_x2 = a.x_max().min(b.x_max());
655    let inter_y2 = a.y_max().min(b.y_max());
656
657    let inter_area = (inter_x2 - inter_x1).max(0.0) * (inter_y2 - inter_y1).max(0.0);
658
659    let area_a = (a.x_max() - a.x_min()) * (a.y_max() - a.y_min());
660    let area_b = (b.x_max() - b.x_min()) * (b.y_max() - b.y_min());
661
662    let smaller_area = area_a.min(area_b);
663
664    if smaller_area <= 0.0 {
665        0.0
666    } else {
667        inter_area / smaller_area
668    }
669}
670
/// Result of overlap removal, as returned by [`remove_overlap_blocks`].
#[derive(Debug, Clone)]
pub struct OverlapRemovalResult<T> {
    /// Elements that were kept after overlap removal, in their input order
    pub kept: Vec<T>,
    /// Indices (into the input slice) of the elements that were removed,
    /// in ascending order
    pub removed_indices: Vec<usize>,
}
680/// Removes overlapping layout blocks based on overlap ratio threshold.
681///
682/// This follows standard overlap removal implementation in
683/// `layout_parsing/utils.py`. When two blocks overlap significantly:
684/// - If one is an image and one is not, the image is removed (text takes priority)
685/// - Otherwise, the smaller block is removed
686///
687/// # Arguments
688///
689/// * `elements` - Slice of layout elements to process
690/// * `threshold` - Overlap ratio threshold (default: 0.65)
691///   If intersection/smaller_area > threshold, blocks are considered overlapping
692///
693/// # Returns
694///
695/// `OverlapRemovalResult` containing kept elements and indices of removed elements
696///
697/// # Example
698///
699/// ```ignore
700/// use oar_ocr_core::processors::layout_utils::{remove_overlap_blocks, LayoutBox};
701/// use oar_ocr_core::processors::BoundingBox;
702///
703/// let elements = vec![
704///     LayoutBox::new(BoundingBox::from_coords(0.0, 0.0, 100.0, 100.0), "text".to_string()),
705///     LayoutBox::new(BoundingBox::from_coords(10.0, 10.0, 90.0, 90.0), "text".to_string()),
706/// ];
707///
708/// let result = remove_overlap_blocks(&elements, 0.65);
709/// assert_eq!(result.kept.len(), 1); // Smaller overlapping box was removed
710/// ```
711pub fn remove_overlap_blocks(
712    elements: &[LayoutBox],
713    threshold: f32,
714) -> OverlapRemovalResult<LayoutBox> {
715    let n = elements.len();
716    if n <= 1 {
717        return OverlapRemovalResult {
718            kept: elements.to_vec(),
719            removed_indices: Vec::new(),
720        };
721    }
722
723    let mut dropped_indices: HashSet<usize> = HashSet::new();
724
725    // Compare all pairs of elements
726    for i in 0..n {
727        if dropped_indices.contains(&i) {
728            continue;
729        }
730
731        for j in (i + 1)..n {
732            if dropped_indices.contains(&j) {
733                continue;
734            }
735
736            let elem_i = &elements[i];
737            let elem_j = &elements[j];
738
739            // Calculate overlap ratio (intersection / smaller area)
740            let overlap_ratio = calculate_ioa_smaller(&elem_i.bbox, &elem_j.bbox);
741
742            if overlap_ratio > threshold {
743                // Determine which element to remove
744                let is_i_image = elem_i.label == "image";
745                let is_j_image = elem_j.label == "image";
746
747                let drop_index = if is_i_image != is_j_image {
748                    // One is image, one is not: remove the image (text takes priority)
749                    if is_i_image { i } else { j }
750                } else {
751                    // Same type: remove the smaller one
752                    let area_i = (elem_i.bbox.x_max() - elem_i.bbox.x_min())
753                        * (elem_i.bbox.y_max() - elem_i.bbox.y_min());
754                    let area_j = (elem_j.bbox.x_max() - elem_j.bbox.x_min())
755                        * (elem_j.bbox.y_max() - elem_j.bbox.y_min());
756
757                    if area_i < area_j { i } else { j }
758                };
759
760                dropped_indices.insert(drop_index);
761                tracing::debug!(
762                    "Removing overlapping element {} (label={}, overlap={:.2})",
763                    drop_index,
764                    elements[drop_index].label,
765                    overlap_ratio
766                );
767            }
768        }
769    }
770
771    // Build result
772    let mut kept = Vec::new();
773    let mut removed_indices: Vec<usize> = dropped_indices.into_iter().collect();
774    removed_indices.sort();
775
776    for (idx, elem) in elements.iter().enumerate() {
777        if !removed_indices.contains(&idx) {
778            kept.push(elem.clone());
779        }
780    }
781
782    tracing::info!(
783        "Overlap removal: {} elements -> {} kept, {} removed",
784        n,
785        kept.len(),
786        removed_indices.len()
787    );
788
789    OverlapRemovalResult {
790        kept,
791        removed_indices,
792    }
793}
794
795/// Removes overlapping layout blocks, returning only indices to remove.
796///
797/// A lighter-weight version that works with any bbox type implementing the required traits.
798/// This is useful when you want to apply overlap removal to `LayoutElement` from `domain::structure`.
799///
800/// # Arguments
801///
802/// * `bboxes` - Slice of bounding boxes
803/// * `labels` - Slice of labels corresponding to each bbox
804/// * `threshold` - Overlap ratio threshold
805///
806/// # Returns
807///
808/// Set of indices that should be removed
809pub fn get_overlap_removal_indices(
810    bboxes: &[BoundingBox],
811    labels: &[&str],
812    threshold: f32,
813) -> HashSet<usize> {
814    let n = bboxes.len();
815    if n <= 1 || n != labels.len() {
816        return HashSet::new();
817    }
818
819    let mut dropped_indices: HashSet<usize> = HashSet::new();
820
821    for i in 0..n {
822        if dropped_indices.contains(&i) {
823            continue;
824        }
825
826        for j in (i + 1)..n {
827            if dropped_indices.contains(&j) {
828                continue;
829            }
830
831            let overlap_ratio = calculate_ioa_smaller(&bboxes[i], &bboxes[j]);
832
833            if overlap_ratio > threshold {
834                let is_i_image = labels[i] == "image";
835                let is_j_image = labels[j] == "image";
836
837                let drop_index = if is_i_image != is_j_image {
838                    if is_i_image { i } else { j }
839                } else {
840                    let area_i = (bboxes[i].x_max() - bboxes[i].x_min())
841                        * (bboxes[i].y_max() - bboxes[i].y_min());
842                    let area_j = (bboxes[j].x_max() - bboxes[j].x_min())
843                        * (bboxes[j].y_max() - bboxes[j].y_min());
844
845                    if area_i < area_j { i } else { j }
846                };
847
848                dropped_indices.insert(drop_index);
849            }
850        }
851    }
852
853    dropped_indices
854}
855
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `LayoutBox` from a bounding box and a borrowed label.
    fn labeled(bbox: BoundingBox, label: &str) -> LayoutBox {
        LayoutBox::new(bbox, label.to_string())
    }

    /// Two OCR boxes shared by the association tests: index 0 lies inside
    /// the region returned by `single_region`, index 1 lies outside it.
    fn near_and_far_ocr_boxes() -> Vec<BoundingBox> {
        vec![
            BoundingBox::from_coords(10.0, 10.0, 50.0, 30.0),
            BoundingBox::from_coords(200.0, 200.0, 250.0, 220.0),
        ]
    }

    /// The single layout region used by the association tests.
    fn single_region() -> Vec<BoundingBox> {
        vec![BoundingBox::from_coords(0.0, 0.0, 100.0, 100.0)]
    }

    /// Boxes overlapping the layout region are reported; the distant one is not.
    #[test]
    fn test_get_overlap_boxes_idx() {
        let layout_regions = vec![BoundingBox::from_coords(0.0, 0.0, 150.0, 150.0)];
        let ocr_boxes = vec![
            // inside region
            BoundingBox::from_coords(10.0, 10.0, 50.0, 30.0),
            // inside region
            BoundingBox::from_coords(60.0, 60.0, 100.0, 80.0),
            // outside region
            BoundingBox::from_coords(200.0, 200.0, 250.0, 220.0),
        ];

        let matched = get_overlap_boxes_idx(&ocr_boxes, &layout_regions, 3.0);

        // Exactly the first two indices are expected in the result.
        assert_eq!(matched.len(), 2);
        for inside in [0, 1] {
            assert!(matched.contains(&inside));
        }
        assert!(!matched.contains(&2));
    }

    /// With `flag_within = true` the box inside the region is matched and the
    /// other one is reported as unmatched.
    #[test]
    fn test_associate_ocr_with_layout_within() {
        let association =
            associate_ocr_with_layout(&near_and_far_ocr_boxes(), &single_region(), true, 3.0);

        assert_eq!(association.matched_indices[..], [0]);
        assert_eq!(association.unmatched_indices[..], [1]);
    }

    /// With `flag_within = false` the matched set is the box outside the region.
    #[test]
    fn test_associate_ocr_with_layout_outside() {
        let association =
            associate_ocr_with_layout(&near_and_far_ocr_boxes(), &single_region(), false, 3.0);

        assert_eq!(association.matched_indices[..], [1]);
    }

    /// Two stacked boxes given bottom-first come back top-first.
    #[test]
    fn test_sort_layout_boxes_single_column() {
        let lower = labeled(BoundingBox::from_coords(10.0, 50.0, 200.0, 70.0), "text");
        let upper = labeled(BoundingBox::from_coords(10.0, 10.0, 200.0, 30.0), "title");
        let elements = vec![lower, upper];

        let sorted = sort_layout_boxes(&elements, 300.0);

        assert_eq!(sorted[0].label, "title");
        assert_eq!(sorted[1].label, "text");
    }

    /// Two-column page with a full-width title on top.
    ///
    /// Only the title position and the top-to-bottom order inside each column
    /// are asserted; the relative order of the two columns is not checked.
    #[test]
    fn test_sort_layout_boxes_two_columns() {
        let image_width = 400.0;

        let elements = vec![
            // Left column boxes (x < w/4 and x2 < 3w/5)
            labeled(
                BoundingBox::from_coords(10.0, 100.0, 90.0, 120.0),
                "left_bottom",
            ),
            labeled(BoundingBox::from_coords(10.0, 50.0, 90.0, 70.0), "left_top"),
            // Right column boxes (x > 2w/5)
            labeled(
                BoundingBox::from_coords(250.0, 100.0, 390.0, 120.0),
                "right_bottom",
            ),
            labeled(
                BoundingBox::from_coords(250.0, 50.0, 390.0, 70.0),
                "right_top",
            ),
            // Full-width box (neither left nor right)
            labeled(BoundingBox::from_coords(10.0, 10.0, 390.0, 30.0), "title"),
        ];

        let sorted = sort_layout_boxes(&elements, image_width);

        assert_eq!(sorted[0].label, "title");

        // Position of a label in the sorted output; a missing label fails the test.
        let index_of = |label: &str| {
            sorted
                .iter()
                .position(|e| e.label == label)
                .unwrap_or_else(|| panic!("missing expected {label} element"))
        };

        let left_top_idx = index_of("left_top");
        let left_bottom_idx = index_of("left_bottom");
        let right_top_idx = index_of("right_top");
        let right_bottom_idx = index_of("right_bottom");

        // Within each column, the upper box must precede the lower one.
        assert!(left_top_idx < left_bottom_idx);
        assert!(right_top_idx < right_bottom_idx);
    }

    /// Sorting no elements yields no elements.
    #[test]
    fn test_sort_layout_boxes_empty() {
        let elements: Vec<LayoutBox> = Vec::new();

        assert!(sort_layout_boxes(&elements, 300.0).is_empty());
    }

    /// A lone element is returned unchanged.
    #[test]
    fn test_sort_layout_boxes_single_element() {
        let elements = vec![labeled(
            BoundingBox::from_coords(10.0, 10.0, 100.0, 30.0),
            "text",
        )];

        let sorted = sort_layout_boxes(&elements, 300.0);

        assert_eq!(sorted.len(), 1);
        assert_eq!(sorted[0].label, "text");
    }
}