use crate::domain::structure::TableCell;
use crate::domain::text_region::TextRegion;
use crate::processors::BoundingBox;
use std::sync::Arc;
/// Tuning knobs for splitting OCR text boxes that straddle several table cells.
#[derive(Debug, Clone)]
pub struct SplitConfig {
    /// A cell counts as overlapped when `intersection_area / ocr_box_area` is strictly
    /// greater than this value.
    pub min_overlap_ratio: f32,
    /// Minimum number of overlapped cells before an OCR box is considered for splitting.
    pub min_cells_to_split: usize,
    /// When `true`, cell x-edges inside the OCR box are collected so the box can be cut
    /// into left/right pieces.
    pub split_horizontal: bool,
    /// When `true`, cell y-edges inside the OCR box are collected so the box can be cut
    /// into top/bottom pieces.
    pub split_vertical: bool,
}

impl Default for SplitConfig {
    /// 5% overlap threshold, at least two overlapped cells, both split directions enabled.
    fn default() -> Self {
        let (split_horizontal, split_vertical) = (true, true);
        SplitConfig {
            min_overlap_ratio: 0.05,
            min_cells_to_split: 2,
            split_horizontal,
            split_vertical,
        }
    }
}
/// A text region that overlaps enough table cells to be split, together with the
/// cut lines computed for it by `detect_cross_cell_ocr_boxes`.
#[derive(Debug, Clone)]
pub struct CrossCellDetection {
    /// Index of the region within the `text_regions` slice that was scanned.
    pub ocr_index: usize,
    /// Indices (into the `cells` slice) of every cell whose overlap with the region
    /// exceeded the configured ratio, in ascending order.
    pub affected_cell_indices: Vec<usize>,
    /// Sorted x-coordinates of cell edges lying strictly inside the region. When built by
    /// the detector, only populated for horizontal (left/right) splits.
    pub x_boundaries: Vec<f32>,
    /// Sorted y-coordinates of cell edges lying strictly inside the region. When built by
    /// the detector, only populated for vertical (top/bottom) splits.
    pub y_boundaries: Vec<f32>,
    /// `true` to cut along `x_boundaries`, `false` to cut along `y_boundaries`.
    pub is_horizontal_split: bool,
}
/// One piece of a split OCR box.
#[derive(Debug, Clone)]
pub struct SplitSegment {
    /// Sub-rectangle of the original box covered by this piece.
    pub bbox: BoundingBox,
    /// Portion of the original text assigned to this piece (may be empty).
    pub text: String,
    /// Index (into the `cells` slice) of the cell this piece was matched to.
    pub cell_index: usize,
}
/// Outcome of splitting a single OCR box: the original data plus the produced pieces.
#[derive(Debug, Clone)]
pub struct SplitOcrResult {
    /// Bounding box of the region before splitting.
    pub original_bbox: BoundingBox,
    /// Full recognized text of the region (empty string if the region had no text).
    pub original_text: String,
    /// Recognition confidence copied from the source region.
    pub confidence: Option<f32>,
    /// Pieces produced by the split; empty when there was nothing to split.
    pub segments: Vec<SplitSegment>,
}
/// Finds OCR text regions that straddle two or more table cells and computes where
/// each one should be cut.
///
/// A cell counts as overlapped when `intersection_area / region_area` is strictly
/// greater than `config.min_overlap_ratio`. Regions without text, or with a
/// non-positive area, are skipped. A detection is emitted only when at least
/// `config.min_cells_to_split` cells are overlapped and at least one cell edge lies
/// strictly inside the region.
///
/// Returns one [`CrossCellDetection`] per qualifying region, in region order.
pub fn detect_cross_cell_ocr_boxes(
    text_regions: &[TextRegion],
    cells: &[TableCell],
    config: &SplitConfig,
) -> Vec<CrossCellDetection> {
    if cells.is_empty() || text_regions.is_empty() {
        return Vec::new();
    }
    let mut detections = Vec::new();
    for (ocr_idx, region) in text_regions.iter().enumerate() {
        if region.text.is_none() {
            continue;
        }
        let ocr_bbox = &region.bounding_box;
        let ocr_area = calculate_bbox_area(ocr_bbox);
        if ocr_area <= 0.0 {
            continue;
        }
        // Cells are enumerated in order, so the collected indices are already ascending.
        let affected_cell_indices: Vec<usize> = cells
            .iter()
            .enumerate()
            .filter(|(_, cell)| {
                ocr_bbox.intersection_area(&cell.bbox) / ocr_area > config.min_overlap_ratio
            })
            .map(|(cell_idx, _)| cell_idx)
            .collect();
        if affected_cell_indices.len() < config.min_cells_to_split {
            continue;
        }
        let (x_boundaries, y_boundaries, is_horizontal_split) =
            compute_split_boundaries(ocr_bbox, &affected_cell_indices, cells, config);
        // No cell edge actually cuts through the box: nothing to split.
        if x_boundaries.is_empty() && y_boundaries.is_empty() {
            continue;
        }
        detections.push(CrossCellDetection {
            ocr_index: ocr_idx,
            affected_cell_indices,
            x_boundaries,
            y_boundaries,
            is_horizontal_split,
        });
    }
    detections
}
/// Collects the cell edges that cut through `ocr_bbox` and picks a split direction.
///
/// For each cell in `cell_indices`:
/// - its left/right edges are gathered when `config.split_horizontal` is set;
/// - its top/bottom edges are gathered when `config.split_vertical` is set.
///
/// Only edges strictly inside the OCR box are kept. Each list is then sorted, and an
/// edge closer than 1.0 to the previously kept edge is dropped.
///
/// Returns `(x_edges, y_edges, is_horizontal)`, where only the list for the chosen
/// direction is populated. When both directions have edges, the wider side of the OCR
/// box decides: width >= height means horizontal.
fn compute_split_boundaries(
    ocr_bbox: &BoundingBox,
    cell_indices: &[usize],
    cells: &[TableCell],
    config: &SplitConfig,
) -> (Vec<f32>, Vec<f32>, bool) {
    if cell_indices.is_empty() {
        return (Vec::new(), Vec::new(), true);
    }
    let (left, right) = (ocr_bbox.x_min(), ocr_bbox.x_max());
    let (top, bottom) = (ocr_bbox.y_min(), ocr_bbox.y_max());
    let strictly_inside = |v: f32, lo: f32, hi: f32| v > lo && v < hi;
    let sorted_unique = |mut edges: Vec<f32>| {
        edges.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        // `dedup_by` passes (later, earlier); the later near-duplicate is removed.
        edges.dedup_by(|later, earlier| (*later - *earlier).abs() < 1.0);
        edges
    };

    let mut vertical_lines = Vec::new();
    let mut horizontal_lines = Vec::new();
    for cell in cell_indices.iter().map(|&idx| &cells[idx]) {
        if config.split_horizontal {
            for edge in [cell.bbox.x_min(), cell.bbox.x_max()] {
                if strictly_inside(edge, left, right) {
                    vertical_lines.push(edge);
                }
            }
        }
        if config.split_vertical {
            for edge in [cell.bbox.y_min(), cell.bbox.y_max()] {
                if strictly_inside(edge, top, bottom) {
                    horizontal_lines.push(edge);
                }
            }
        }
    }
    let x_edges = sorted_unique(vertical_lines);
    let y_edges = sorted_unique(horizontal_lines);

    let horizontal = match (x_edges.is_empty(), y_edges.is_empty()) {
        (false, false) => right - left >= bottom - top,
        (x_empty, _) => !x_empty,
    };
    if horizontal {
        (x_edges, Vec::new(), true)
    } else {
        (Vec::new(), y_edges, false)
    }
}
pub fn split_ocr_box_at_cell_boundaries(
region: &TextRegion,
detection: &CrossCellDetection,
cells: &[TableCell],
) -> SplitOcrResult {
let original_bbox = region.bounding_box.clone();
let original_text = region
.text
.as_ref()
.map(|s| s.to_string())
.unwrap_or_default();
let confidence = region.confidence;
if original_text.is_empty() || detection.affected_cell_indices.is_empty() {
return SplitOcrResult {
original_bbox,
original_text,
confidence,
segments: Vec::new(),
};
}
let segments = if detection.is_horizontal_split && !detection.x_boundaries.is_empty() {
split_horizontally(
&original_bbox,
&original_text,
&detection.x_boundaries,
&detection.affected_cell_indices,
cells,
)
} else if !detection.y_boundaries.is_empty() {
split_vertically(
&original_bbox,
&original_text,
&detection.y_boundaries,
&detection.affected_cell_indices,
cells,
)
} else {
vec![SplitSegment {
bbox: original_bbox.clone(),
text: original_text.clone(),
cell_index: detection.affected_cell_indices[0],
}]
};
SplitOcrResult {
original_bbox,
original_text,
confidence,
segments,
}
}
/// Cuts `ocr_bbox` into left-to-right strips at `x_boundaries`, then distributes `text`
/// across the strips in proportion to their widths.
///
/// A boundary is used only if it lies after the previously accepted cut and before
/// the box's right edge. Every strip keeps the box's full height and is assigned to
/// the candidate cell with the highest IoU. Returns no segments for a box of
/// non-positive width.
fn split_horizontally(
    ocr_bbox: &BoundingBox,
    text: &str,
    x_boundaries: &[f32],
    cell_indices: &[usize],
    cells: &[TableCell],
) -> Vec<SplitSegment> {
    let (left, right) = (ocr_bbox.x_min(), ocr_bbox.x_max());
    let (top, bottom) = (ocr_bbox.y_min(), ocr_bbox.y_max());
    if right - left <= 0.0 {
        return Vec::new();
    }

    let mut strips: Vec<(f32, f32)> = Vec::new();
    let mut cursor = left;
    for &cut in x_boundaries {
        if cut > cursor && cut < right {
            strips.push((cursor, cut));
            cursor = cut;
        }
    }
    if cursor < right {
        strips.push((cursor, right));
    }
    if strips.is_empty() {
        return Vec::new();
    }

    let widths: Vec<f32> = strips.iter().map(|&(a, b)| b - a).collect();
    let span: f32 = widths.iter().sum();
    let ratios: Vec<f32> = widths.iter().map(|w| w / span).collect();
    let pieces = split_text_by_ratio(text, &ratios);

    strips
        .into_iter()
        .zip(pieces)
        .map(|((x1, x2), piece)| {
            let bbox = BoundingBox::from_coords(x1, top, x2, bottom);
            let cell_index = find_best_matching_cell(&bbox, cell_indices, cells);
            SplitSegment {
                bbox,
                text: piece,
                cell_index,
            }
        })
        .collect()
}
/// Cuts `ocr_bbox` into top-to-bottom bands at `y_boundaries`, then distributes `text`
/// across the bands.
///
/// - If the text has at least as many lines as there are bands, whole lines are dealt
///   out. Each band gets `lines / bands` lines, joined with `\n` and not trimmed; the
///   last band takes the remainder.
/// - Otherwise the text is split by character count in proportion to band heights,
///   via `split_text_by_ratio`.
///
/// Every band spans the box's full width and is assigned to the candidate cell with
/// the highest IoU. Returns no segments for a box of non-positive height.
fn split_vertically(
    ocr_bbox: &BoundingBox,
    text: &str,
    y_boundaries: &[f32],
    cell_indices: &[usize],
    cells: &[TableCell],
) -> Vec<SplitSegment> {
    let (left, right) = (ocr_bbox.x_min(), ocr_bbox.x_max());
    let (top, bottom) = (ocr_bbox.y_min(), ocr_bbox.y_max());
    if bottom - top <= 0.0 {
        return Vec::new();
    }

    let mut bands: Vec<(f32, f32)> = Vec::new();
    let mut cursor = top;
    for &cut in y_boundaries {
        if cut > cursor && cut < bottom {
            bands.push((cursor, cut));
            cursor = cut;
        }
    }
    if cursor < bottom {
        bands.push((cursor, bottom));
    }
    if bands.is_empty() {
        return Vec::new();
    }

    let lines: Vec<&str> = text.lines().collect();
    let pieces: Vec<String> = if lines.len() >= bands.len() {
        let per_band = lines.len() / bands.len();
        let last_band = bands.len() - 1;
        let mut rest: &[&str] = &lines;
        let mut dealt = Vec::with_capacity(bands.len());
        for band_no in 0..bands.len() {
            let take = if band_no == last_band { rest.len() } else { per_band };
            let (head, tail) = rest.split_at(take);
            dealt.push(head.join("\n"));
            rest = tail;
        }
        dealt
    } else {
        let heights: Vec<f32> = bands.iter().map(|&(a, b)| b - a).collect();
        let span: f32 = heights.iter().sum();
        let ratios: Vec<f32> = heights.iter().map(|h| h / span).collect();
        split_text_by_ratio(text, &ratios)
    };

    bands
        .into_iter()
        .zip(pieces)
        .map(|((y1, y2), piece)| {
            let bbox = BoundingBox::from_coords(left, y1, right, y2);
            let cell_index = find_best_matching_cell(&bbox, cell_indices, cells);
            SplitSegment {
                bbox,
                text: piece,
                cell_index,
            }
        })
        .collect()
}
/// Picks, from `candidate_indices`, the cell whose box has the highest IoU with
/// `segment_bbox`.
///
/// - Indices that are out of range for `cells` are skipped.
/// - On ties, the earliest candidate wins.
/// - If no candidate has a positive IoU, the first candidate is returned, or `0` when
///   the candidate list is empty.
fn find_best_matching_cell(
    segment_bbox: &BoundingBox,
    candidate_indices: &[usize],
    cells: &[TableCell],
) -> usize {
    let fallback = candidate_indices.first().copied().unwrap_or(0);
    candidate_indices
        .iter()
        .filter_map(|&idx| {
            cells
                .get(idx)
                .map(|cell| (idx, segment_bbox.iou(&cell.bbox)))
        })
        .fold((fallback, 0.0f32), |(best_idx, best_iou), (idx, iou)| {
            if iou > best_iou {
                (idx, iou)
            } else {
                (best_idx, best_iou)
            }
        })
        .0
}
/// Splits `text` into `ratios.len()` pieces whose character counts follow `ratios`.
///
/// - Ratios are normalized by their sum. If the sum is not positive (or is NaN), the
///   text is split evenly.
/// - Cut points are computed from the *cumulative* ratio, so rounding and
///   word-boundary adjustments never accumulate into later segments.
/// - Each interior cut may move slightly earlier to land just after nearby whitespace,
///   `,` or `.`.
/// - Every piece is trimmed of surrounding whitespace; the last piece takes whatever
///   characters remain.
///
/// Returns `[text]` when fewer than two ratios are given, and `ratios.len()` empty
/// strings when `text` is empty.
pub fn split_text_by_ratio(text: &str, ratios: &[f32]) -> Vec<String> {
    if ratios.len() <= 1 {
        return vec![text.to_string()];
    }
    let chars: Vec<char> = text.chars().collect();
    let total_chars = chars.len();
    if total_chars == 0 {
        return vec![String::new(); ratios.len()];
    }
    let total_ratio: f32 = ratios.iter().sum();
    let normalized_ratios: Vec<f32> = if total_ratio > 0.0 {
        ratios.iter().map(|r| r / total_ratio).collect()
    } else {
        vec![1.0 / ratios.len() as f32; ratios.len()]
    };

    let last = normalized_ratios.len() - 1;
    let mut result = Vec::with_capacity(ratios.len());
    let mut start_idx = 0usize;
    let mut cumulative = 0.0f32;
    for (i, ratio) in normalized_ratios.iter().enumerate() {
        cumulative += ratio;
        let end_idx = if i == last {
            total_chars
        } else {
            // Absolute target position; clamped so a previous boundary snap can never
            // make this segment run backwards. A NaN position casts to 0, then clamps.
            ((total_chars as f32 * cumulative).round() as usize).clamp(start_idx, total_chars)
        };
        let adjusted_end_idx = if end_idx < total_chars && end_idx > start_idx {
            find_word_boundary(&chars, start_idx, end_idx)
        } else {
            end_idx
        };
        let segment: String = chars[start_idx..adjusted_end_idx].iter().collect();
        result.push(segment.trim().to_string());
        start_idx = adjusted_end_idx;
    }
    result
}

/// Looks back from `target_end` (at most 5 positions, never reaching `start`) for
/// whitespace, `,` or `.`, and returns the index just after the first one found.
///
/// Returns `target_end` unchanged when no such separator is found.
fn find_word_boundary(chars: &[char], start: usize, target_end: usize) -> usize {
    let window = 5.min(target_end.saturating_sub(start));
    (0..window)
        // `offset < window <= target_end - start`, so this cannot underflow.
        .map(|offset| target_end - offset)
        .find(|&idx| {
            idx > start
                && chars
                    .get(idx)
                    .is_some_and(|&c| c.is_whitespace() || c == ',' || c == '.')
        })
        .map_or(target_end, |idx| idx + 1)
}
/// Width times height of `bbox`, taken from its min/max extents.
///
/// A negative product is clamped to `0.0`.
fn calculate_bbox_area(bbox: &BoundingBox) -> f32 {
    let area = (bbox.x_max() - bbox.x_min()) * (bbox.y_max() - bbox.y_min());
    f32::max(area, 0.0)
}
/// Builds replacement OCR regions for every text region that straddles several cells.
///
/// Detection uses `config`, or `SplitConfig::default()` when `None`. Each detected
/// region is split, and one new `TextRegion` is emitted per segment with non-empty
/// text; it carries the source region's confidence.
///
/// Returns:
/// - the new regions;
/// - the indices (into `text_regions`) of every detected region. A detected index is
///   recorded even if its split yielded no non-empty segment.
pub fn create_expanded_ocr_for_table(
    text_regions: &[TextRegion],
    cells: &[TableCell],
    config: Option<&SplitConfig>,
) -> (Vec<TextRegion>, std::collections::HashSet<usize>) {
    let fallback_config = SplitConfig::default();
    let detections =
        detect_cross_cell_ocr_boxes(text_regions, cells, config.unwrap_or(&fallback_config));

    let processed_indices: std::collections::HashSet<usize> =
        detections.iter().map(|d| d.ocr_index).collect();

    let expanded_regions: Vec<TextRegion> = detections
        .iter()
        .flat_map(|detection| {
            let split = split_ocr_box_at_cell_boundaries(
                &text_regions[detection.ocr_index],
                detection,
                cells,
            );
            let confidence = split.confidence;
            split
                .segments
                .into_iter()
                .filter(|segment| !segment.text.is_empty())
                .map(move |segment| {
                    TextRegion::with_recognition(
                        segment.bbox,
                        Some(Arc::from(segment.text.as_str())),
                        confidence,
                    )
                })
        })
        .collect();

    (expanded_regions, processed_indices)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a recognized text region with a fixed confidence of 0.9.
    fn make_region(x1: f32, y1: f32, x2: f32, y2: f32, text: &str) -> TextRegion {
        TextRegion::with_recognition(
            BoundingBox::from_coords(x1, y1, x2, y2),
            Some(Arc::from(text)),
            Some(0.9),
        )
    }

    /// Builds a table cell from corner coordinates, passing 0.9 as the second
    /// `TableCell::new` argument.
    fn make_cell(x1: f32, y1: f32, x2: f32, y2: f32) -> TableCell {
        TableCell::new(BoundingBox::from_coords(x1, y1, x2, y2), 0.9)
    }

    #[test]
    fn test_detect_no_cross_cell_ocr() {
        let regions = vec![make_region(10.0, 10.0, 90.0, 40.0, "Hello World")];
        let cells = vec![
            make_cell(0.0, 0.0, 100.0, 50.0),
            make_cell(100.0, 0.0, 200.0, 50.0),
        ];
        let config = SplitConfig::default();
        let detections = detect_cross_cell_ocr_boxes(&regions, &cells, &config);
        assert!(
            detections.is_empty(),
            "Should not detect cross-cell for box fully inside one cell"
        );
    }

    #[test]
    fn test_detect_cross_cell_horizontal() {
        let regions = vec![make_region(50.0, 10.0, 150.0, 40.0, "Header Text")];
        let cells = vec![
            make_cell(0.0, 0.0, 100.0, 50.0),
            make_cell(100.0, 0.0, 200.0, 50.0),
        ];
        let config = SplitConfig::default();
        let detections = detect_cross_cell_ocr_boxes(&regions, &cells, &config);
        assert_eq!(detections.len(), 1, "Should detect one cross-cell OCR box");
        assert_eq!(detections[0].affected_cell_indices.len(), 2);
        assert!(detections[0].is_horizontal_split);
    }

    #[test]
    fn test_split_text_by_ratio_equal() {
        let text = "ABCDEFGHIJ";
        let ratios = vec![0.5, 0.5];
        let parts = split_text_by_ratio(text, &ratios);
        assert_eq!(parts.len(), 2);
        let total_len: usize = parts.iter().map(|s| s.len()).sum();
        assert_eq!(total_len, text.len());
    }

    #[test]
    fn test_split_text_by_ratio_unequal() {
        let text = "Hello World";
        let ratios = vec![0.3, 0.7];
        let parts = split_text_by_ratio(text, &ratios);
        assert_eq!(parts.len(), 2);
        assert!(!parts[0].is_empty() || !parts[1].is_empty());
    }

    #[test]
    fn test_split_text_prefers_word_boundary() {
        let parts = split_text_by_ratio("Hello World", &[0.5, 0.5]);
        assert_eq!(parts, vec!["Hello".to_string(), "World".to_string()]);
    }

    #[test]
    fn test_split_text_empty() {
        let text = "";
        let ratios = vec![0.5, 0.5];
        let parts = split_text_by_ratio(text, &ratios);
        assert_eq!(parts.len(), 2);
        assert!(parts[0].is_empty());
        assert!(parts[1].is_empty());
    }

    #[test]
    fn test_split_ocr_box_horizontal() {
        let region = make_region(50.0, 10.0, 150.0, 40.0, "Col1 Col2");
        let cells = vec![
            make_cell(0.0, 0.0, 100.0, 50.0),
            make_cell(100.0, 0.0, 200.0, 50.0),
        ];
        let detection = CrossCellDetection {
            ocr_index: 0,
            affected_cell_indices: vec![0, 1],
            x_boundaries: vec![100.0],
            y_boundaries: Vec::new(),
            is_horizontal_split: true,
        };
        let result = split_ocr_box_at_cell_boundaries(&region, &detection, &cells);
        assert_eq!(result.segments.len(), 2, "Should produce 2 segments");
        let seg1_x_max = result.segments[0].bbox.x_max();
        let seg2_x_min = result.segments[1].bbox.x_min();
        assert!(
            seg1_x_max <= seg2_x_min + 1.0,
            "Segments should not overlap"
        );
    }

    #[test]
    fn test_create_expanded_ocr_for_table() {
        let regions = vec![
            // Fully inside cell 0: must be left alone.
            make_region(10.0, 10.0, 90.0, 40.0, "Cell1 Only"),
            // Straddles cells 0 and 1: must be split.
            make_region(50.0, 10.0, 150.0, 40.0, "Across Cells"),
        ];
        let cells = vec![
            make_cell(0.0, 0.0, 100.0, 50.0),
            make_cell(100.0, 0.0, 200.0, 50.0),
        ];
        let config = SplitConfig::default();
        let (expanded, processed) =
            create_expanded_ocr_for_table(&regions, &cells, Some(&config));
        assert!(processed.contains(&1));
        assert!(!processed.contains(&0));
        assert!(!expanded.is_empty());
    }
}