#![allow(clippy::unnecessary_wraps)]
use serde_json::Value;
use crate::document::types::DocItemLabel;
use crate::document::types_extended::{BoundingBox, Cluster, CoordOrigin, TextCell};
use crate::error::Result;
pub fn detect_layout_from_cells(json_str: &str) -> Result<Vec<Cluster>> {
#[cfg(feature = "docling-ffi")]
{
eprintln!(" đ Attempting ML-based layout detection...");
match detect_layout_with_ml(json_str) {
Ok(clusters) if !clusters.is_empty() => {
eprintln!(
" â
Using ML model (LayoutLMv3 ONNX) - {} regions",
clusters.len()
);
return Ok(clusters);
}
Ok(_) => {
eprintln!(" â ī¸ ML model returned empty, using rule-based");
}
Err(e) => {
eprintln!(" â ī¸ ML model failed: {e}, using rule-based");
}
}
}
#[cfg(not(feature = "docling-ffi"))]
{
eprintln!(" âšī¸ ML features not enabled, using rule-based");
}
detect_layout_with_rules(json_str)
}
#[cfg(feature = "docling-ffi")]
fn detect_layout_with_ml(json_str: &str) -> Result<Vec<Cluster>> {
use std::path::Path;
use crate::ml::layout_model::LayoutModel;
let model_path = Path::new("models/layout_model.onnx");
if !model_path.exists() {
eprintln!(" â ī¸ Model file not found: {}", model_path.display());
return Ok(Vec::new());
}
eprintln!(" đ§ Loading ONNX model...");
let _model = LayoutModel::new(model_path)?;
let json: Value = serde_json::from_str(json_str)?;
let Some(pages) = json["pages"].as_array() else {
return Ok(Vec::new());
};
let mut all_clusters = Vec::new();
let mut cluster_id = 0;
for (page_idx, page) in pages.iter().enumerate() {
let page_width = page["original"]["dimension"]["width"]
.as_f64()
.unwrap_or(612.0) as u32;
let page_height = page["original"]["dimension"]["height"]
.as_f64()
.unwrap_or(792.0) as u32;
eprintln!(
" đ Processing page {} ({}x{})...",
page_idx + 1,
page_width,
page_height
);
let cells = extract_text_cells_for_ml(page)?;
if cells.is_empty() {
eprintln!(" â ī¸ No cells found on page {}", page_idx + 1);
continue;
}
eprintln!(
" đ Page {}: {} cells extracted",
page_idx + 1,
cells.len()
);
if !cells.is_empty() {
eprintln!(
" đ Sample cell: '{}'",
cells[0].text.chars().take(30).collect::<String>()
);
}
let page_clusters =
cluster_cells_geometrically(&cells, &mut cluster_id, page_width, page_height)?;
all_clusters.extend(page_clusters);
eprintln!(
" â
Found {} regions on page {}",
all_clusters.len() - cluster_id,
page_idx + 1
);
}
Ok(all_clusters)
}
fn extract_text_cells_for_ml(page: &Value) -> Result<Vec<TextCell>> {
let mut cells = Vec::new();
if let Some(cells_obj) = page["original"]["cells"].as_object() {
if let Some(cell_data) = cells_obj["data"].as_array() {
for (idx, cell) in cell_data.iter().enumerate() {
if let Some(cell_array) = cell.as_array() {
if let (Some(x0), Some(y0), Some(x1), Some(y1), Some(text)) = (
cell_array.first().and_then(|v| v.as_f64()),
cell_array.get(1).and_then(|v| v.as_f64()),
cell_array.get(2).and_then(|v| v.as_f64()),
cell_array.get(3).and_then(|v| v.as_f64()),
cell_array.get(12).and_then(|v| v.as_str()),
) {
let font_size = cell_array
.get(15)
.and_then(|v| v.as_f64())
.map(|f| f as f32);
let font_name = cell_array
.get(18)
.and_then(|v| v.as_str())
.map(|s| s.to_string());
cells.push(TextCell {
index: idx,
text: text.to_string(),
bbox: BoundingBox {
l: x0,
t: y0,
r: x1,
b: y1,
origin: CoordOrigin::TopLeft,
},
font_name,
font_size,
confidence: 1.0,
from_ocr: false,
});
}
}
}
}
}
Ok(cells)
}
fn cluster_cells_geometrically(
cells: &[TextCell],
cluster_id: &mut usize,
page_width: u32,
_page_height: u32,
) -> Result<Vec<Cluster>> {
let mut clusters = Vec::new();
if cells.is_empty() {
return Ok(clusters);
}
let avg_font_size =
cells.iter().filter_map(|c| c.font_size).sum::<f32>() / cells.len().max(1) as f32;
let mut rows: Vec<Vec<&TextCell>> = Vec::new();
let row_threshold = 10.0;
let mut sorted_cells: Vec<&TextCell> = cells.iter().collect();
sorted_cells.sort_by(|a, b| a.bbox.t.partial_cmp(&b.bbox.t).unwrap());
for cell in sorted_cells {
let mut added = false;
for row in &mut rows {
if let Some(first) = row.first() {
if (cell.bbox.t - first.bbox.t).abs() < row_threshold {
row.push(cell);
added = true;
break;
}
}
}
if !added {
rows.push(vec![cell]);
}
}
let mut current_region_cells = Vec::new();
let mut current_label = DocItemLabel::Paragraph;
for (row_idx, row) in rows.iter().enumerate() {
let row_text: String = row
.iter()
.map(|c| c.text.as_str())
.collect::<Vec<_>>()
.join(" ");
let row_font_size =
row.iter().filter_map(|c| c.font_size).sum::<f32>() / row.len().max(1) as f32;
let label = classify_row(
&row_text,
row_font_size,
avg_font_size,
row_idx,
page_width,
row,
);
let should_split = if current_region_cells.is_empty() {
false
} else if label != current_label {
true
} else if let Some(last_row) = rows.get(row_idx.saturating_sub(1)) {
if let (Some(last_cell), Some(curr_cell)) = (last_row.first(), row.first()) {
(curr_cell.bbox.t - last_cell.bbox.b) > 20.0 } else {
false
}
} else {
false
};
if should_split {
if !current_region_cells.is_empty() {
clusters.push(create_cluster_from_cells(
current_region_cells.clone(),
*cluster_id,
current_label,
));
*cluster_id += 1;
current_region_cells.clear();
}
}
current_label = label;
current_region_cells.extend(row.iter().copied());
}
if !current_region_cells.is_empty() {
clusters.push(create_cluster_from_cells(
current_region_cells,
*cluster_id,
current_label,
));
*cluster_id += 1;
}
Ok(clusters)
}
fn classify_row(
text: &str,
font_size: f32,
avg_font_size: f32,
row_idx: usize,
_page_width: u32,
cells: &[&TextCell],
) -> DocItemLabel {
let text_lower = text.to_lowercase();
if row_idx == 0 && font_size > avg_font_size * 1.5 && text.len() < 100 {
return DocItemLabel::Title;
}
if font_size > avg_font_size * 1.2
&& text.len() < 150
&& (text.ends_with(':') || !text.ends_with('.'))
{
return DocItemLabel::SectionHeader;
}
if text.trim_start().starts_with('-')
|| text.trim_start().starts_with('âĸ')
|| text.trim_start().starts_with('*')
|| (text.len() > 2
&& text.chars().next().unwrap().is_numeric()
&& (text.chars().nth(1) == Some('.') || text.chars().nth(1) == Some(')')))
{
return DocItemLabel::ListItem;
}
if cells.len() > 4 {
return DocItemLabel::Table;
}
if text_lower.contains("figure ")
|| text_lower.contains("table ")
|| text_lower.contains("fig. ")
|| text_lower.starts_with("fig ")
{
return DocItemLabel::Caption;
}
DocItemLabel::Paragraph
}
fn create_cluster_from_cells(cells: Vec<&TextCell>, id: usize, label: DocItemLabel) -> Cluster {
let min_x = cells.iter().map(|c| c.bbox.l).fold(f64::INFINITY, f64::min);
let min_y = cells.iter().map(|c| c.bbox.t).fold(f64::INFINITY, f64::min);
let max_x = cells
.iter()
.map(|c| c.bbox.r)
.fold(f64::NEG_INFINITY, f64::max);
let max_y = cells
.iter()
.map(|c| c.bbox.b)
.fold(f64::NEG_INFINITY, f64::max);
let bbox = BoundingBox {
l: min_x,
t: min_y,
r: max_x,
b: max_y,
origin: CoordOrigin::TopLeft,
};
let owned_cells: Vec<TextCell> = cells.iter().map(|c| (*c).clone()).collect();
Cluster {
id,
label,
bbox,
cells: owned_cells,
confidence: 0.85, }
}
fn detect_layout_with_rules(json_str: &str) -> Result<Vec<Cluster>> {
let json: Value = serde_json::from_str(json_str)?;
let mut clusters = Vec::new();
let mut cluster_id = 0;
if let Some(pages) = json["pages"].as_array() {
for (page_idx, page) in pages.iter().enumerate() {
let page_clusters = detect_page_layout(page, page_idx, &mut cluster_id)?;
clusters.extend(page_clusters);
}
}
Ok(clusters)
}
fn detect_page_layout(
page: &Value,
_page_idx: usize,
cluster_id: &mut usize,
) -> Result<Vec<Cluster>> {
let mut clusters = Vec::new();
let cells = extract_text_cells(page)?;
if cells.is_empty() {
return Ok(clusters);
}
let (_page_width, page_height) = get_page_dimensions(page);
let table_clusters = detect_tables(&cells, *cluster_id);
clusters.extend(table_clusters.iter().cloned());
*cluster_id += table_clusters.len();
let title_clusters = detect_titles(&cells, page_height, *cluster_id);
clusters.extend(title_clusters.iter().cloned());
*cluster_id += title_clusters.len();
let header_clusters = detect_headers(&cells, *cluster_id);
clusters.extend(header_clusters.iter().cloned());
*cluster_id += header_clusters.len();
let list_clusters = detect_lists(&cells, *cluster_id);
clusters.extend(list_clusters.iter().cloned());
*cluster_id += list_clusters.len();
let used_cells: Vec<usize> = clusters
.iter()
.flat_map(|c| c.cells.iter().map(|cell| cell.index))
.collect();
let remaining_cells: Vec<TextCell> = cells
.into_iter()
.filter(|cell| !used_cells.contains(&cell.index))
.collect();
if !remaining_cells.is_empty() {
clusters.push(Cluster {
id: *cluster_id,
label: DocItemLabel::Paragraph,
bbox: compute_bounding_box(&remaining_cells),
cells: remaining_cells,
confidence: 0.9,
});
*cluster_id += 1;
}
Ok(clusters)
}
fn extract_text_cells(page: &Value) -> Result<Vec<TextCell>> {
let mut cells = Vec::new();
if let Some(cells_obj) = page["original"]["cells"].as_object() {
if let Some(cell_data) = cells_obj["data"].as_array() {
for (idx, cell) in cell_data.iter().enumerate() {
if let Some(cell_array) = cell.as_array() {
if let (Some(x0), Some(y0), Some(x1), Some(y1), Some(text)) = (
cell_array.first().and_then(|v| v.as_f64()),
cell_array.get(1).and_then(|v| v.as_f64()),
cell_array.get(2).and_then(|v| v.as_f64()),
cell_array.get(3).and_then(|v| v.as_f64()),
cell_array.get(12).and_then(|v| v.as_str()),
) {
let font_size = cell_array
.get(15)
.and_then(|v| v.as_f64())
.map(|f| f as f32);
let font_name = cell_array
.get(18)
.and_then(|v| v.as_str())
.map(|s| s.to_string());
cells.push(TextCell {
index: idx,
text: text.to_string(),
bbox: BoundingBox {
l: x0,
t: y0,
r: x1,
b: y1,
origin: CoordOrigin::TopLeft,
},
font_name,
font_size,
confidence: 1.0,
from_ocr: false,
});
}
}
}
}
}
Ok(cells)
}
fn get_page_dimensions(page: &Value) -> (f64, f64) {
let width = page["original"]["dimension"]["width"]
.as_f64()
.unwrap_or(612.0);
let height = page["original"]["dimension"]["height"]
.as_f64()
.unwrap_or(792.0);
(width, height)
}
fn detect_tables(_cells: &[TextCell], _start_id: usize) -> Vec<Cluster> {
Vec::new() }
fn detect_titles(cells: &[TextCell], page_height: f64, start_id: usize) -> Vec<Cluster> {
let mut titles = Vec::new();
let top_threshold = page_height * 0.8; let avg_font_size =
cells.iter().filter_map(|c| c.font_size).sum::<f32>() / cells.len().max(1) as f32;
let title_cells: Vec<TextCell> = cells
.iter()
.filter(|cell| {
cell.bbox.t > top_threshold && cell.font_size.unwrap_or(0.0) > avg_font_size * 1.3
})
.cloned()
.collect();
if !title_cells.is_empty() {
titles.push(Cluster {
id: start_id,
label: DocItemLabel::Title,
bbox: compute_bounding_box(&title_cells),
cells: title_cells,
confidence: 0.85,
});
}
titles
}
fn detect_headers(_cells: &[TextCell], _start_id: usize) -> Vec<Cluster> {
Vec::new() }
fn detect_lists(_cells: &[TextCell], _start_id: usize) -> Vec<Cluster> {
Vec::new() }
fn compute_bounding_box(cells: &[TextCell]) -> BoundingBox {
if cells.is_empty() {
return BoundingBox {
l: 0.0,
t: 0.0,
r: 0.0,
b: 0.0,
origin: CoordOrigin::TopLeft,
};
}
let min_x = cells.iter().map(|c| c.bbox.l).fold(f64::INFINITY, f64::min);
let min_y = cells.iter().map(|c| c.bbox.t).fold(f64::INFINITY, f64::min);
let max_x = cells
.iter()
.map(|c| c.bbox.r)
.fold(f64::NEG_INFINITY, f64::max);
let max_y = cells
.iter()
.map(|c| c.bbox.b)
.fold(f64::NEG_INFINITY, f64::max);
BoundingBox {
l: min_x,
t: min_y,
r: max_x,
b: max_y,
origin: CoordOrigin::TopLeft,
}
}