use crate::processors::BoundingBox;
use std::collections::HashSet;
/// Partition of OCR box indices produced by [`associate_ocr_with_layout`].
///
/// Every index of the OCR slice handed to that function ends up in exactly one
/// of the two vectors, and each vector is in ascending index order.
#[derive(Debug, Clone)]
pub struct LayoutOCRAssociation {
/// Indices selected by the association: the overlapping boxes when the
/// caller asked for boxes within the regions, the non-overlapping ones
/// otherwise.
pub matched_indices: Vec<usize>,
/// Indices that were not selected; the complement of `matched_indices`.
pub unmatched_indices: Vec<usize>,
}
/// Returns the indices of the OCR boxes that overlap at least one layout region.
///
/// Overlap is decided by `BoundingBox::overlaps_with`, to which `threshold` is
/// forwarded unchanged.
///
/// Each index is reported at most once, even when the box overlaps several
/// regions. Indices appear in first-seen order: regions are visited in slice
/// order and, within a region, OCR boxes in slice order.
///
/// Returns an empty vector when either slice is empty.
pub fn get_overlap_boxes_idx(
    ocr_boxes: &[BoundingBox],
    layout_regions: &[BoundingBox],
    threshold: f32,
) -> Vec<usize> {
    // One flag per OCR box so a box that overlaps several regions is only
    // emitted the first time it matches.
    let mut already_matched = vec![false; ocr_boxes.len()];
    let mut matched_indices = Vec::new();

    for layout_region in layout_regions {
        for (idx, ocr_box) in ocr_boxes.iter().enumerate() {
            if !already_matched[idx] && ocr_box.overlaps_with(layout_region, threshold) {
                already_matched[idx] = true;
                matched_indices.push(idx);
            }
        }
    }

    matched_indices
}
/// Splits the OCR box indices into "matched" and "unmatched" relative to the
/// layout regions.
///
/// A box counts as overlapping when [`get_overlap_boxes_idx`] reports its
/// index for the given `threshold`.
///
/// * `flag_within == true`: overlapping boxes are matched, the rest unmatched.
/// * `flag_within == false`: the roles are swapped, so boxes that overlap no
///   region are the matched ones.
///
/// Both index lists are in ascending order and together cover every OCR box.
pub fn associate_ocr_with_layout(
    ocr_boxes: &[BoundingBox],
    layout_regions: &[BoundingBox],
    flag_within: bool,
    threshold: f32,
) -> LayoutOCRAssociation {
    let overlapping: HashSet<usize> = get_overlap_boxes_idx(ocr_boxes, layout_regions, threshold)
        .into_iter()
        .collect();

    // A box is "matched" exactly when its overlap status equals the requested
    // relation; `partition` keeps the ascending index order on both sides.
    let (matched_indices, unmatched_indices): (Vec<usize>, Vec<usize>) = (0..ocr_boxes.len())
        .partition(|idx| overlapping.contains(idx) == flag_within);

    LayoutOCRAssociation {
        matched_indices,
        unmatched_indices,
    }
}
/// A labelled layout element with an optional piece of attached content.
#[derive(Debug, Clone)]
pub struct LayoutBox {
    /// Geometry of the element.
    pub bbox: BoundingBox,
    /// Label of the element; the overlap-removal code in this file gives the
    /// label `"image"` special treatment.
    pub label: String,
    /// Content attached to the element, if any.
    pub content: Option<String>,
}

impl LayoutBox {
    /// Creates an element with the given geometry and label and no content.
    pub fn new(bbox: BoundingBox, label: String) -> Self {
        LayoutBox {
            content: None,
            bbox,
            label,
        }
    }

    /// Creates an element with the given geometry and label, carrying `content`.
    pub fn with_content(bbox: BoundingBox, label: String, content: String) -> Self {
        LayoutBox {
            content: Some(content),
            ..Self::new(bbox, label)
        }
    }
}
/// Orders layout elements for reading on a page that may have two columns.
///
/// The elements are first sorted top-to-bottom (`y_min`, ties broken by
/// `x_min`). They are then walked in that order and classified against
/// `image_width`:
///
/// * **left column**: `x_min < width / 4` and `x_max < 3 * width / 5`;
/// * **right column**: otherwise, if `x_min > 2 * width / 5`;
/// * **spanning**: anything else. A spanning element first flushes the pending
///   left column, then the pending right column, and is then emitted itself.
///
/// Whatever remains in the two columns at the end is emitted left first, then
/// right. Inputs with zero or one element are returned as-is.
///
/// Coordinates are compared with `f32::total_cmp`, so a NaN coordinate cannot
/// make the comparator inconsistent (which `sort_by` is allowed to panic on).
pub fn sort_layout_boxes(elements: &[LayoutBox], image_width: f32) -> Vec<LayoutBox> {
    if elements.len() <= 1 {
        return elements.to_vec();
    }

    let mut sorted: Vec<LayoutBox> = elements.to_vec();
    sorted.sort_by(|a, b| {
        a.bbox
            .y_min()
            .total_cmp(&b.bbox.y_min())
            .then_with(|| a.bbox.x_min().total_cmp(&b.bbox.x_min()))
    });

    let left_start_limit = image_width / 4.0;
    let left_end_limit = 3.0 * image_width / 5.0;
    let right_start_limit = 2.0 * image_width / 5.0;

    let mut result = Vec::with_capacity(sorted.len());
    let mut left_column = Vec::new();
    let mut right_column = Vec::new();

    // Consume the sorted vector so each element is moved, not cloned again.
    for elem in sorted {
        let x1 = elem.bbox.x_min();
        let x2 = elem.bbox.x_max();
        if x1 < left_start_limit && x2 < left_end_limit {
            left_column.push(elem);
        } else if x1 > right_start_limit {
            right_column.push(elem);
        } else {
            result.append(&mut left_column);
            result.append(&mut right_column);
            result.push(elem);
        }
    }

    // Both columns are subsequences of a list already ordered by `y_min`
    // under a total order, so they need no further sorting before the flush.
    result.append(&mut left_column);
    result.append(&mut right_column);
    result
}
/// Reconciles table cells predicted by the structure model with cells found by
/// the cell detector, producing one box per structure cell.
///
/// * No structure cells: returns an empty vector.
/// * No detected cells: returns the structure cells unchanged.
/// * More detections than structure cells: the detections are first reduced
///   with [`combine_rectangles_kmeans`].
///
/// Each (possibly reduced) detection is assigned to the structure cell that
/// covers the largest fraction of the detection's own area, as long as that
/// fraction exceeds `0.001`; ties go to the earliest structure cell. For every
/// structure cell the output then holds the union of the detections assigned
/// to it, or the structure cell itself when nothing was assigned.
pub fn reconcile_table_cells(
    structure_cells: &[BoundingBox],
    detected_cells: &[BoundingBox],
) -> Vec<BoundingBox> {
    // A detection must cover more than this fraction of itself to be assigned.
    const MIN_IOA: f32 = 0.001;

    if structure_cells.is_empty() {
        return Vec::new();
    }
    if detected_cells.is_empty() {
        return structure_cells.to_vec();
    }

    let cell_count = structure_cells.len();
    let det_boxes: Vec<BoundingBox> = if detected_cells.len() > cell_count {
        combine_rectangles_kmeans(detected_cells, cell_count)
    } else {
        detected_cells.to_vec()
    };

    // Per structure cell: the running union of the detections assigned to it.
    let mut replacements: Vec<Option<BoundingBox>> = vec![None; cell_count];

    for det_box in &det_boxes {
        let det_area = (det_box.x_max() - det_box.x_min()) * (det_box.y_max() - det_box.y_min());

        // Fraction of the detection's area lying inside `cell`.
        let ioa_with = |cell: &BoundingBox| -> f32 {
            let overlap_w =
                (det_box.x_max().min(cell.x_max()) - det_box.x_min().max(cell.x_min())).max(0.0);
            let overlap_h =
                (det_box.y_max().min(cell.y_max()) - det_box.y_min().max(cell.y_min())).max(0.0);
            if det_area > 0.0 {
                overlap_w * overlap_h / det_area
            } else {
                0.0
            }
        };

        let (_, winner) = structure_cells.iter().enumerate().fold(
            (MIN_IOA, None),
            |(best, winner), (idx, cell)| {
                let ioa = ioa_with(cell);
                if ioa > best {
                    (ioa, Some(idx))
                } else {
                    (best, winner)
                }
            },
        );

        if let Some(idx) = winner {
            let slot = &mut replacements[idx];
            *slot = Some(match slot.take() {
                Some(merged) => merged.union(det_box),
                None => det_box.clone(),
            });
        }
    }

    replacements
        .into_iter()
        .zip(structure_cells)
        .map(|(replacement, fallback)| replacement.unwrap_or_else(|| fallback.clone()))
        .collect()
}
/// Adjusts detected table cells towards `target_n` cells, using OCR boxes to
/// fill in what the detector missed.
///
/// Steps:
/// 1. `target_n == 0` yields an empty vector; with no detected cells the OCR
///    boxes are clustered into `target_n` boxes via
///    [`combine_rectangles_kmeans`]; with exactly `target_n` detected cells
///    they are returned unchanged.
/// 2. With too many cells, only the `target_n` highest-scoring ones are kept
///    (all scores count as `1.0` when `detected_scores` does not have one
///    entry per cell).
/// 3. An OCR box is "covered" when the fraction of its area inside a single
///    cell, or the running sum of those fractions over the cells visited so
///    far, reaches `0.6`. Uncovered OCR boxes are the missed ones.
/// 4. With no missed boxes the cells are the result. Otherwise, if cells had
///    to be discarded in step 2, cells and missed boxes are clustered together
///    into `target_n`; if there were too few cells, the missed boxes are
///    clustered into the number of missing cells and appended.
/// 5. If the result holds at most `0.6 * target_n` boxes, it is replaced by a
///    clustering of all OCR boxes into `target_n`.
pub fn reprocess_table_cells_with_ocr(
    detected_cells: &[BoundingBox],
    detected_scores: &[f32],
    ocr_boxes: &[BoundingBox],
    target_n: usize,
) -> Vec<BoundingBox> {
    // Coverage at which an OCR box counts as already represented by the cells.
    const COVERAGE_THRESHOLD: f32 = 0.6;
    // Results with at most this share of `target_n` boxes are discarded.
    const MIN_FILL_RATIO: f32 = 0.6;

    /// Fraction of `ocr`'s own area that lies inside `cell`.
    fn ocr_coverage(ocr: &BoundingBox, cell: &BoundingBox) -> f32 {
        let inter = ocr.intersection_area(cell);
        if inter <= 0.0 {
            return 0.0;
        }
        let area = (ocr.x_max() - ocr.x_min()) * (ocr.y_max() - ocr.y_min());
        if area <= 0.0 {
            0.0
        } else {
            inter / area
        }
    }

    /// Whether one cell, or the cells visited so far together, cover `ocr_box`.
    fn is_covered(ocr_box: &BoundingBox, cells: &[BoundingBox]) -> bool {
        let mut accumulated = 0.0f32;
        for cell in cells {
            let ioa = ocr_coverage(ocr_box, cell);
            if ioa > 0.0 {
                accumulated += ioa;
            }
            if ioa >= COVERAGE_THRESHOLD || accumulated >= COVERAGE_THRESHOLD {
                return true;
            }
        }
        false
    }

    if target_n == 0 {
        return Vec::new();
    }
    if detected_cells.is_empty() {
        return combine_rectangles_kmeans(ocr_boxes, target_n);
    }
    if detected_cells.len() == target_n {
        return detected_cells.to_vec();
    }

    let had_surplus = detected_cells.len() > target_n;
    let cells: Vec<BoundingBox> = if had_surplus {
        let scores_usable = detected_scores.len() == detected_cells.len();
        let score_of = |i: usize| -> f32 {
            if scores_usable {
                detected_scores[i]
            } else {
                1.0
            }
        };
        // Stable sort, highest score first, so ties keep detection order.
        let mut order: Vec<usize> = (0..detected_cells.len()).collect();
        order.sort_by(|&a, &b| {
            score_of(b)
                .partial_cmp(&score_of(a))
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        order
            .into_iter()
            .take(target_n)
            .map(|i| detected_cells[i].clone())
            .collect()
    } else {
        detected_cells.to_vec()
    };

    let ocr_miss_boxes: Vec<BoundingBox> = ocr_boxes
        .iter()
        .filter(|ocr_box| !is_covered(*ocr_box, &cells))
        .cloned()
        .collect();

    let final_results = if ocr_miss_boxes.is_empty() {
        cells
    } else if had_surplus {
        let mut pool = cells;
        pool.extend(ocr_miss_boxes);
        combine_rectangles_kmeans(&pool, target_n)
    } else {
        let missing = target_n.saturating_sub(cells.len());
        let mut patched = cells;
        patched.extend(combine_rectangles_kmeans(&ocr_miss_boxes, missing));
        patched
    };

    if final_results.len() as f32 <= MIN_FILL_RATIO * target_n as f32 {
        combine_rectangles_kmeans(ocr_boxes, target_n)
    } else {
        final_results
    }
}
pub fn combine_rectangles_kmeans(rectangles: &[BoundingBox], target_n: usize) -> Vec<BoundingBox> {
let num_rects = rectangles.len();
if num_rects == 0 || target_n == 0 {
return Vec::new();
}
if target_n >= num_rects {
return rectangles.to_vec();
}
let points: Vec<(f32, f32)> = rectangles
.iter()
.map(|r| {
let cx = (r.x_min() + r.x_max()) * 0.5;
let cy = (r.y_min() + r.y_max()) * 0.5;
(cx, cy)
})
.collect();
let centers = kmeans_maxdist_init(&points, target_n);
let mut centers = centers;
let mut labels: Vec<usize> = vec![0; num_rects];
let max_iters = 10;
for _ in 0..max_iters {
let mut changed = false;
for (i, &(px, py)) in points.iter().enumerate() {
let mut best_idx = 0usize;
let mut best_dist = f32::MAX;
for (c_idx, &(cx, cy)) in centers.iter().enumerate() {
let dx = px - cx;
let dy = py - cy;
let dist = dx * dx + dy * dy;
if dist < best_dist {
best_dist = dist;
best_idx = c_idx;
}
}
if labels[i] != best_idx {
labels[i] = best_idx;
changed = true;
}
}
let mut sums: Vec<(f32, f32, usize)> = vec![(0.0, 0.0, 0); target_n];
for (i, &(px, py)) in points.iter().enumerate() {
let l = labels[i];
sums[l].0 += px;
sums[l].1 += py;
sums[l].2 += 1;
}
for (c_idx, center) in centers.iter_mut().enumerate() {
let (sx, sy, count) = sums[c_idx];
if count > 0 {
center.0 = sx / count as f32;
center.1 = sy / count as f32;
}
}
if !changed {
break;
}
}
let mut combined: Vec<BoundingBox> = Vec::new();
for cluster_idx in 0..target_n {
let mut first = true;
let mut min_x = 0.0f32;
let mut min_y = 0.0f32;
let mut max_x = 0.0f32;
let mut max_y = 0.0f32;
for (i, rect) in rectangles.iter().enumerate() {
if labels[i] == cluster_idx {
if first {
min_x = rect.x_min();
min_y = rect.y_min();
max_x = rect.x_max();
max_y = rect.y_max();
first = false;
} else {
min_x = min_x.min(rect.x_min());
min_y = min_y.min(rect.y_min());
max_x = max_x.max(rect.x_max());
max_y = max_y.max(rect.y_max());
}
}
}
if !first {
combined.push(BoundingBox::from_coords(min_x, min_y, max_x, max_y));
}
}
if combined.is_empty() {
rectangles.to_vec()
} else {
combined
}
}
/// Picks up to `k` initial cluster centres from `points`, deterministically.
///
/// * Empty input or `k == 0` yields an empty vector.
/// * When `k` is at least the number of points, the points themselves are
///   returned.
///
/// Otherwise the first centre is the point with the median `x` coordinate
/// (element `len / 2` of the x-sorted order). Each further centre is the point
/// whose squared distance to its nearest already-chosen centre is largest
/// (earliest point wins ties). If every point coincides with a centre, a point
/// not yet among the centres is used instead, and selection stops early when
/// none exists, so fewer than `k` centres may be returned.
fn kmeans_maxdist_init(points: &[(f32, f32)], k: usize) -> Vec<(f32, f32)> {
    if points.is_empty() || k == 0 {
        return Vec::new();
    }
    if k >= points.len() {
        return points.to_vec();
    }

    // `total_cmp` keeps the comparator a total order even if an x is NaN.
    let mut order_by_x: Vec<usize> = (0..points.len()).collect();
    order_by_x.sort_by(|&a, &b| points[a].0.total_cmp(&points[b].0));

    let mut newest = points[order_by_x[order_by_x.len() / 2]];
    let mut centers: Vec<(f32, f32)> = Vec::with_capacity(k);
    centers.push(newest);

    // Squared distance from each point to its nearest centre so far. Only the
    // most recently added centre can lower a value, so each round folds in
    // just that one centre instead of rescanning all of them.
    let mut nearest_sq = vec![f32::MAX; points.len()];

    for _ in 1..k {
        let (cx, cy) = newest;
        let mut total = 0.0f32;
        let mut farthest_idx = 0usize;
        let mut farthest_sq = 0.0f32;

        for (i, (slot, &(px, py))) in nearest_sq.iter_mut().zip(points).enumerate() {
            let dx = px - cx;
            let dy = py - cy;
            let updated = f32::min(*slot, dx * dx + dy * dy);
            *slot = updated;
            total += updated;
            if updated > farthest_sq {
                farthest_sq = updated;
                farthest_idx = i;
            }
        }

        newest = if total <= 0.0 {
            // Every point sits on a centre: fall back to any point that is not
            // yet a centre, or stop if there is none.
            match points.iter().find(|p| !centers.contains(p)) {
                Some(&point) => point,
                None => break,
            }
        } else {
            points[farthest_idx]
        };
        centers.push(newest);
    }

    centers
}
/// Intersection area of `a` and `b` divided by the area of the smaller box.
///
/// Returns `0.0` when the smaller area is zero or negative, and `0.0` for
/// boxes that do not intersect.
fn calculate_ioa_smaller(a: &BoundingBox, b: &BoundingBox) -> f32 {
    let area_of = |r: &BoundingBox| -> f32 { (r.x_max() - r.x_min()) * (r.y_max() - r.y_min()) };

    let smaller_area = area_of(a).min(area_of(b));
    if smaller_area <= 0.0 {
        return 0.0;
    }

    // Clamp each overlap extent at zero so disjoint boxes contribute nothing.
    let overlap_w = (a.x_max().min(b.x_max()) - a.x_min().max(b.x_min())).max(0.0);
    let overlap_h = (a.y_max().min(b.y_max()) - a.y_min().max(b.y_min())).max(0.0);
    overlap_w * overlap_h / smaller_area
}
/// Outcome of an overlap-removal pass such as [`remove_overlap_blocks`].
#[derive(Debug, Clone)]
pub struct OverlapRemovalResult<T> {
/// Elements that survived the pass, in their original relative order.
pub kept: Vec<T>,
/// Indices into the input slice of the elements that were removed, sorted
/// in ascending order.
pub removed_indices: Vec<usize>,
}
/// Removes layout elements that overlap another element too strongly.
///
/// Every pair `(i, j)` with `i < j` of still-present elements is compared. The
/// overlap ratio is the intersection area divided by the smaller box's area
/// ([`calculate_ioa_smaller`]); a pair conflicts when the ratio is strictly
/// greater than `threshold`. For a conflicting pair:
///
/// * if exactly one element is labelled `"image"`, that element is dropped;
/// * otherwise the element with the smaller area is dropped (`j` on a tie).
///
/// An element stops taking part in comparisons as soon as it is dropped, so a
/// removed element never causes further removals.
///
/// Returns the surviving elements in input order together with the ascending
/// indices of the removed ones. Inputs with zero or one element are returned
/// unchanged.
pub fn remove_overlap_blocks(
    elements: &[LayoutBox],
    threshold: f32,
) -> OverlapRemovalResult<LayoutBox> {
    let n = elements.len();
    if n <= 1 {
        return OverlapRemovalResult {
            kept: elements.to_vec(),
            removed_indices: Vec::new(),
        };
    }

    let area_of = |elem: &LayoutBox| -> f32 {
        (elem.bbox.x_max() - elem.bbox.x_min()) * (elem.bbox.y_max() - elem.bbox.y_min())
    };

    // One flag per element: O(1) membership checks and an ascending index
    // list for free when the flags are read back in order.
    let mut dropped = vec![false; n];

    for i in 0..n {
        if dropped[i] {
            continue;
        }
        for j in (i + 1)..n {
            if dropped[j] {
                continue;
            }
            let elem_i = &elements[i];
            let elem_j = &elements[j];
            let overlap_ratio = calculate_ioa_smaller(&elem_i.bbox, &elem_j.bbox);
            if overlap_ratio > threshold {
                let is_i_image = elem_i.label == "image";
                let is_j_image = elem_j.label == "image";
                let drop_index = if is_i_image != is_j_image {
                    if is_i_image {
                        i
                    } else {
                        j
                    }
                } else if area_of(elem_i) < area_of(elem_j) {
                    i
                } else {
                    j
                };
                dropped[drop_index] = true;
                tracing::debug!(
                    "Removing overlapping element {} (label={}, overlap={:.2})",
                    drop_index,
                    elements[drop_index].label,
                    overlap_ratio
                );
                if drop_index == i {
                    // `i` is gone; it must not knock out any later element.
                    break;
                }
            }
        }
    }

    let mut kept = Vec::with_capacity(n);
    let mut removed_indices: Vec<usize> = Vec::new();
    for (idx, elem) in elements.iter().enumerate() {
        if dropped[idx] {
            removed_indices.push(idx);
        } else {
            kept.push(elem.clone());
        }
    }

    tracing::info!(
        "Overlap removal: {} elements -> {} kept, {} removed",
        n,
        kept.len(),
        removed_indices.len()
    );

    OverlapRemovalResult {
        kept,
        removed_indices,
    }
}
/// Computes which boxes should be removed because they overlap another box too
/// strongly, without building the filtered list.
///
/// `labels[k]` is the label of `bboxes[k]`. An empty set is returned when there
/// are fewer than two boxes or when the two slices differ in length.
///
/// Every pair `(i, j)` with `i < j` of still-present boxes is compared. The
/// overlap ratio is the intersection area divided by the smaller box's area
/// ([`calculate_ioa_smaller`]); a pair conflicts when the ratio is strictly
/// greater than `threshold`. For a conflicting pair:
///
/// * if exactly one box is labelled `"image"`, that box is dropped;
/// * otherwise the box with the smaller area is dropped (`j` on a tie).
///
/// A box stops taking part in comparisons as soon as it is dropped, so a
/// removed box never causes further removals.
pub fn get_overlap_removal_indices(
    bboxes: &[BoundingBox],
    labels: &[&str],
    threshold: f32,
) -> HashSet<usize> {
    let mut dropped_indices = HashSet::new();
    let n = bboxes.len();
    if n <= 1 || n != labels.len() {
        return dropped_indices;
    }

    let area_of = |b: &BoundingBox| -> f32 { (b.x_max() - b.x_min()) * (b.y_max() - b.y_min()) };

    for i in 0..n {
        if dropped_indices.contains(&i) {
            continue;
        }
        for j in (i + 1)..n {
            if dropped_indices.contains(&j) {
                continue;
            }
            let overlap_ratio = calculate_ioa_smaller(&bboxes[i], &bboxes[j]);
            if overlap_ratio > threshold {
                let drop_index = match (labels[i] == "image", labels[j] == "image") {
                    (true, false) => i,
                    (false, true) => j,
                    _ => {
                        if area_of(&bboxes[i]) < area_of(&bboxes[j]) {
                            i
                        } else {
                            j
                        }
                    }
                };
                dropped_indices.insert(drop_index);
                if drop_index == i {
                    // `i` is gone; it must not knock out any later box.
                    break;
                }
            }
        }
    }

    dropped_indices
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a box given as `(x_min, y_min, x_max, y_max)`.
    fn bbox(x1: f32, y1: f32, x2: f32, y2: f32) -> BoundingBox {
        BoundingBox::from_coords(x1, y1, x2, y2)
    }

    /// Content-less layout element with the given corners and label.
    fn labeled(x1: f32, y1: f32, x2: f32, y2: f32, label: &str) -> LayoutBox {
        LayoutBox::new(bbox(x1, y1, x2, y2), label.to_string())
    }

    /// Position of the element carrying `label`; panics when it is absent.
    fn position_of(sorted: &[LayoutBox], label: &str) -> usize {
        match sorted.iter().position(|e| e.label == label) {
            Some(idx) => idx,
            None => panic!("missing expected {} element", label),
        }
    }

    #[test]
    fn test_get_overlap_boxes_idx() {
        // The first two boxes lie inside the region, the third far outside.
        let ocr_boxes = vec![
            bbox(10.0, 10.0, 50.0, 30.0),
            bbox(60.0, 60.0, 100.0, 80.0),
            bbox(200.0, 200.0, 250.0, 220.0),
        ];
        let layout_regions = vec![bbox(0.0, 0.0, 150.0, 150.0)];

        let matched = get_overlap_boxes_idx(&ocr_boxes, &layout_regions, 3.0);

        assert_eq!(matched.len(), 2);
        assert!(matched.contains(&0));
        assert!(matched.contains(&1));
        assert!(!matched.contains(&2));
    }

    #[test]
    fn test_associate_ocr_with_layout_within() {
        let ocr_boxes = vec![
            bbox(10.0, 10.0, 50.0, 30.0),
            bbox(200.0, 200.0, 250.0, 220.0),
        ];
        let layout_regions = vec![bbox(0.0, 0.0, 100.0, 100.0)];

        let association = associate_ocr_with_layout(&ocr_boxes, &layout_regions, true, 3.0);

        assert_eq!(association.matched_indices.len(), 1);
        assert_eq!(association.matched_indices[0], 0);
        assert_eq!(association.unmatched_indices.len(), 1);
        assert_eq!(association.unmatched_indices[0], 1);
    }

    #[test]
    fn test_associate_ocr_with_layout_outside() {
        let ocr_boxes = vec![
            bbox(10.0, 10.0, 50.0, 30.0),
            bbox(200.0, 200.0, 250.0, 220.0),
        ];
        let layout_regions = vec![bbox(0.0, 0.0, 100.0, 100.0)];

        let association = associate_ocr_with_layout(&ocr_boxes, &layout_regions, false, 3.0);

        assert_eq!(association.matched_indices.len(), 1);
        assert_eq!(association.matched_indices[0], 1);
    }

    #[test]
    fn test_sort_layout_boxes_single_column() {
        // Given bottom-first; the sort must put the upper element first.
        let elements = vec![
            labeled(10.0, 50.0, 200.0, 70.0, "text"),
            labeled(10.0, 10.0, 200.0, 30.0, "title"),
        ];

        let sorted = sort_layout_boxes(&elements, 300.0);

        assert_eq!(sorted[0].label, "title");
        assert_eq!(sorted[1].label, "text");
    }

    #[test]
    fn test_sort_layout_boxes_two_columns() {
        let image_width = 400.0;
        let elements = vec![
            labeled(10.0, 100.0, 90.0, 120.0, "left_bottom"),
            labeled(10.0, 50.0, 90.0, 70.0, "left_top"),
            labeled(250.0, 100.0, 390.0, 120.0, "right_bottom"),
            labeled(250.0, 50.0, 390.0, 70.0, "right_top"),
            labeled(10.0, 10.0, 390.0, 30.0, "title"),
        ];

        let sorted = sort_layout_boxes(&elements, image_width);

        assert_eq!(sorted[0].label, "title");
        let left_top_idx = position_of(&sorted, "left_top");
        let left_bottom_idx = position_of(&sorted, "left_bottom");
        let right_top_idx = position_of(&sorted, "right_top");
        let right_bottom_idx = position_of(&sorted, "right_bottom");
        assert!(left_top_idx < left_bottom_idx);
        assert!(right_top_idx < right_bottom_idx);
    }

    #[test]
    fn test_sort_layout_boxes_empty() {
        let elements: Vec<LayoutBox> = Vec::new();

        let sorted = sort_layout_boxes(&elements, 300.0);

        assert!(sorted.is_empty());
    }

    #[test]
    fn test_sort_layout_boxes_single_element() {
        let elements = vec![labeled(10.0, 10.0, 100.0, 30.0, "text")];

        let sorted = sort_layout_boxes(&elements, 300.0);

        assert_eq!(sorted.len(), 1);
        assert_eq!(sorted[0].label, "text");
    }
}