mod utils;
use clap::Parser;
use image::RgbImage;
use oar_ocr::domain::structure::TableType;
use oar_ocr::domain::tasks::{
FormulaRecognitionConfig, LayoutDetectionConfig, TextDetectionConfig, TextRecognitionConfig,
};
use oar_ocr::oarocr::OARStructureBuilder;
use oar_ocr::processors::LimitType;
use std::path::PathBuf;
use std::sync::Arc;
use tracing::{error, info, warn};
use utils::device_config::parse_device_config;
use utils::pdf::{PdfDocument, is_pdf_file};
// Command-line options for the document structure analysis example.
//
// Plain `//` comments are used here on purpose: clap's derive macro turns `///` doc
// comments into help text, so using them would change the `--help` output.
#[derive(Parser)]
#[command(name = "structure")]
#[command(about = "Run document structure analysis with optional table/formula/OCR components")]
struct Args {
    // ---- Layout detection (the only mandatory model) and inputs ----
    // Layout model path; `main` refuses to run when this path does not exist.
    #[arg(long = "layout-model")]
    layout_model: PathBuf,
    #[arg(long = "layout-model-name", default_value = "PP-DocLayout_plus-L")]
    layout_model_name: String,
    // Positional inputs: image files and/or PDF files (at least one is required).
    #[arg(required = true)]
    images: Vec<PathBuf>,

    // ---- Optional document preprocessing and region detection ----
    #[arg(long)]
    orientation_model: Option<PathBuf>,
    #[arg(long)]
    rectification_model: Option<PathBuf>,
    #[arg(long = "region-model")]
    region_model: Option<PathBuf>,
    #[arg(long = "region-model-name", default_value = "PP-DocBlockLayout")]
    region_model_name: String,

    // ---- Optional table analysis ----
    #[arg(long = "table-cls-model")]
    table_cls_model: Option<PathBuf>,
    #[arg(long = "table-orientation-model")]
    table_orientation_model: Option<PathBuf>,
    #[arg(long = "wired-structure-model")]
    wired_structure_model: Option<PathBuf>,
    #[arg(long = "wired-structure-model-name", default_value = "SLANeXt_wired")]
    wired_structure_model_name: String,
    #[arg(long = "wireless-structure-model")]
    wireless_structure_model: Option<PathBuf>,
    #[arg(long = "wireless-structure-model-name", default_value = "SLANet_plus")]
    wireless_structure_model_name: String,
    #[arg(long = "wired-cell-model")]
    wired_cell_model: Option<PathBuf>,
    #[arg(
        long = "wired-cell-model-name",
        default_value = "RT-DETR-L_wired_table_cell_det"
    )]
    wired_cell_model_name: String,
    #[arg(long = "wireless-cell-model")]
    wireless_cell_model: Option<PathBuf>,
    #[arg(
        long = "wireless-cell-model-name",
        default_value = "RT-DETR-L_wireless_table_cell_det"
    )]
    wireless_cell_model_name: String,
    // Required by `main` whenever a wired or wireless structure model is given.
    #[arg(long = "table-structure-dict")]
    table_structure_dict: Option<PathBuf>,
    // The four switches below use `ArgAction::Set`, so they take an explicit
    // `true`/`false` value on the command line.
    #[arg(long, default_value_t = false, action = clap::ArgAction::Set)]
    use_e2e_wired_table_rec: bool,
    #[arg(long, default_value_t = true, action = clap::ArgAction::Set)]
    use_e2e_wireless_table_rec: bool,
    #[arg(long, default_value_t = false, action = clap::ArgAction::Set)]
    use_wired_table_cells_trans_to_html: bool,
    #[arg(long, default_value_t = false, action = clap::ArgAction::Set)]
    use_wireless_table_cells_trans_to_html: bool,

    // ---- Optional formula recognition ----
    // `main` requires the tokenizer and the type whenever the formula model is given.
    #[arg(long = "formula-model")]
    formula_model: Option<PathBuf>,
    #[arg(long = "formula-tokenizer")]
    formula_tokenizer: Option<PathBuf>,
    #[arg(long = "formula-type")]
    formula_type: Option<String>,

    // ---- Optional seal detection and OCR ----
    #[arg(long = "seal-model")]
    seal_model: Option<PathBuf>,
    // OCR is only enabled by `main` when the detection model, the recognition model and
    // the dictionary are all provided.
    #[arg(long = "text-det-model")]
    text_det_model: Option<PathBuf>,
    #[arg(long = "text-det-model-name", default_value = "PP-OCRv5_server_det")]
    text_det_model_name: String,
    #[arg(long = "text-rec-model")]
    text_rec_model: Option<PathBuf>,
    #[arg(long = "text-rec-model-name", default_value = "PP-OCRv5_server_rec")]
    text_rec_model_name: String,
    #[arg(long = "text-dict-path")]
    text_dict_path: Option<PathBuf>,
    #[arg(long = "textline-orientation-model")]
    textline_orientation_model: Option<PathBuf>,

    // ---- Execution device ----
    // Interpreted by `parse_device_config`.
    #[arg(long, default_value = "cuda")]
    device: String,

    // ---- Thresholds and limits ----
    #[arg(long, default_value = "0.5")]
    layout_score_thresh: f32,
    // NOTE(review): not referenced by `main` in this file — confirm whether it should be
    // wired into the layout configuration.
    #[arg(long, default_value = "0.5")]
    layout_nms_thresh: f32,
    // Whether layout NMS is enabled (default: true). A bare `--layout-nms` still means
    // true; pass `--layout-nms=false` (the `=` is required) to turn it off. The previous
    // declaration combined clap's implicit `SetTrue` action for `bool` with a "true"
    // default, so the option could never evaluate to false.
    #[arg(
        long,
        default_value_t = true,
        action = clap::ArgAction::Set,
        num_args = 0..=1,
        require_equals = true,
        default_missing_value = "true"
    )]
    layout_nms: bool,
    #[arg(long, default_value_t = 0.0)]
    formula_score_thresh: f32,
    #[arg(long, default_value_t = 1536)]
    formula_max_length: usize,
    #[arg(long, default_value = "0.3")]
    det_score_thresh: f32,
    #[arg(long, default_value = "0.6")]
    det_box_thresh: f32,
    #[arg(long, default_value = "1.5")]
    det_unclip_ratio: f32,
    #[arg(long, default_value = "1000")]
    det_max_candidates: usize,
    #[arg(long, default_value = "0.0")]
    rec_score_thresh: f32,
    #[arg(long, default_value_t = 320)]
    text_rec_max_length: usize,
    // NOTE(review): the three seal thresholds and `table_det_box_thresh` below are not
    // referenced by `main` in this file — confirm whether they should be applied.
    #[arg(long, default_value = "0.2")]
    seal_det_score_thresh: f32,
    #[arg(long, default_value = "0.6")]
    seal_det_box_thresh: f32,
    #[arg(long, default_value = "0.5")]
    seal_det_unclip_ratio: f32,
    #[arg(long, default_value = "0.4")]
    table_det_box_thresh: f32,

    // ---- Output ----
    #[arg(short, long, default_value = "output/structure_analysis")]
    output_dir: PathBuf,
    #[arg(long = "to-json", default_value_t = false)]
    to_json: bool,
    #[arg(long = "to-markdown", default_value_t = false)]
    to_markdown: bool,
    #[arg(long = "to-html", default_value_t = false)]
    to_html: bool,
    // Write a `<stem>.png` visualization per processed input.
    #[arg(long)]
    vis: bool,
}
/// One unit of work for the analyzer: either an image file that still has to be
/// loaded, or a single PDF page that has already been rendered into memory.
enum InputSource {
/// Path to an image file on disk; it is only loaded when `into_image` is called.
ImageFile(PathBuf),
/// A rendered page of a PDF document.
PdfPage {
/// Path of the PDF file the page was rendered from.
pdf_path: PathBuf,
/// Page number within the PDF (callers in this file count from 1).
page_number: usize,
/// Rendered pixels of the page.
image: Arc<RgbImage>,
},
}
impl InputSource {
    /// Builds the identifier used in logs and stored as the result's input path.
    ///
    /// Image files yield their (lossily converted) path; PDF pages yield
    /// `<pdf path>#<page number>`.
    fn path(&self) -> String {
        match self {
            Self::PdfPage {
                pdf_path,
                page_number,
                ..
            } => {
                let pdf = pdf_path.to_string_lossy();
                format!("{}#{}", pdf, page_number)
            }
            Self::ImageFile(file) => file.to_string_lossy().into_owned(),
        }
    }

    /// Consumes the source and produces its pixels.
    ///
    /// Image files are loaded through `oar_ocr::utils::load_image`, and a load failure is
    /// returned as a boxed error. For PDF pages the already-rendered image is taken out
    /// of its `Arc` when this is the only handle, and cloned otherwise.
    fn into_image(self) -> Result<RgbImage, Box<dyn std::error::Error>> {
        match self {
            Self::PdfPage { image: shared, .. } => {
                let owned = match Arc::try_unwrap(shared) {
                    Ok(sole_owner) => sole_owner,
                    Err(still_shared) => (*still_shared).clone(),
                };
                Ok(owned)
            }
            Self::ImageFile(file) => oar_ocr::utils::load_image(&file).map_err(Into::into),
        }
    }
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
utils::init_tracing();
let args = Args::parse();
info!("Running document structure analysis example");
if !args.layout_model.exists() {
error!("Layout model not found: {}", args.layout_model.display());
return Err("Layout model not found".into());
}
validate_optional_path("orientation model", args.orientation_model.as_ref())?;
validate_optional_path("rectification model", args.rectification_model.as_ref())?;
validate_optional_path("region model", args.region_model.as_ref())?;
validate_optional_path("table classification model", args.table_cls_model.as_ref())?;
validate_optional_path("wired structure model", args.wired_structure_model.as_ref())?;
validate_optional_path(
"wireless structure model",
args.wireless_structure_model.as_ref(),
)?;
validate_optional_path("wired cell model", args.wired_cell_model.as_ref())?;
validate_optional_path("wireless cell model", args.wireless_cell_model.as_ref())?;
validate_optional_path("table structure dict", args.table_structure_dict.as_ref())?;
validate_optional_path("formula model", args.formula_model.as_ref())?;
validate_optional_path("formula tokenizer", args.formula_tokenizer.as_ref())?;
validate_optional_path("seal model", args.seal_model.as_ref())?;
validate_optional_path("text detection model", args.text_det_model.as_ref())?;
validate_optional_path("text recognition model", args.text_rec_model.as_ref())?;
validate_optional_path("text dict", args.text_dict_path.as_ref())?;
validate_optional_path(
"text line orientation model",
args.textline_orientation_model.as_ref(),
)?;
let mut input_sources: Vec<InputSource> = Vec::new();
for input_path in &args.images {
if !input_path.exists() {
error!("Input not found: {}", input_path.display());
continue;
}
if is_pdf_file(input_path) {
info!("Processing PDF file: {}", input_path.display());
let pdf_doc = match PdfDocument::open(input_path) {
Ok(doc) => doc,
Err(e) => {
error!("Failed to open PDF {}: {}", input_path.display(), e);
continue;
}
};
let page_count = pdf_doc.page_count();
info!("PDF has {} page(s)", page_count);
for page_num in 1..=page_count {
match pdf_doc.render_page(page_num, None) {
Ok(rendered) => {
info!(
" Page {} rendered: {}x{}",
page_num, rendered.width, rendered.height
);
input_sources.push(InputSource::PdfPage {
pdf_path: input_path.clone(),
page_number: page_num,
image: Arc::new(rendered.image),
});
}
Err(e) => {
error!(" Failed to render page {}: {}", page_num, e);
}
}
}
} else {
input_sources.push(InputSource::ImageFile(input_path.clone()));
}
}
if input_sources.is_empty() {
return Err("No valid inputs provided".into());
}
let has_table_structure =
args.wired_structure_model.is_some() || args.wireless_structure_model.is_some();
if has_table_structure && args.table_structure_dict.is_none() {
return Err("Table structure recognition requires --table-structure-dict".into());
}
if args.formula_model.is_some()
&& (args.formula_tokenizer.is_none() || args.formula_type.is_none())
{
return Err("Formula recognition requires --formula-tokenizer and --formula-type".into());
}
let has_partial_ocr = args.text_det_model.is_some()
|| args.text_rec_model.is_some()
|| args.text_dict_path.is_some();
if has_partial_ocr
&& (args.text_det_model.is_none()
|| args.text_rec_model.is_none()
|| args.text_dict_path.is_none())
{
warn!(
"OCR integration ignored because detection/recognition/dictionary are not all provided"
);
}
let mut layout_config = LayoutDetectionConfig::with_pp_structurev3_defaults();
layout_config.score_threshold = args.layout_score_thresh;
layout_config.layout_nms = args.layout_nms;
let formula_config = FormulaRecognitionConfig {
score_threshold: args.formula_score_thresh,
max_length: args.formula_max_length,
};
let text_det_config = TextDetectionConfig {
score_threshold: args.det_score_thresh,
box_threshold: args.det_box_thresh,
unclip_ratio: args.det_unclip_ratio,
max_candidates: args.det_max_candidates,
limit_side_len: Some(736),
limit_type: Some(LimitType::Min),
max_side_len: None,
};
let text_rec_config = TextRecognitionConfig {
score_threshold: args.rec_score_thresh,
max_text_length: args.text_rec_max_length,
};
let mut builder =
OARStructureBuilder::new(&args.layout_model).layout_detection_config(layout_config);
builder = builder.layout_model_name(&args.layout_model_name);
if let Some(config) = parse_device_config(&args.device)? {
builder = builder.ort_session(config);
}
if let Some(path) = args.orientation_model {
builder = builder.with_document_orientation(path);
}
if let Some(path) = args.rectification_model {
builder = builder.with_document_rectification(path);
}
if let Some(path) = args.region_model {
builder = builder
.with_region_detection(path)
.region_model_name(&args.region_model_name);
}
if let Some(path) = args.table_cls_model {
builder = builder.with_table_classification(path);
}
if let Some(path) = args.table_orientation_model {
builder = builder.with_table_orientation(path);
}
if let Some(path) = args.wired_structure_model {
builder = builder
.with_wired_table_structure(path)
.wired_table_structure_model_name(&args.wired_structure_model_name);
}
if let Some(path) = args.wireless_structure_model {
builder = builder
.with_wireless_table_structure(path)
.wireless_table_structure_model_name(&args.wireless_structure_model_name);
}
if let Some(path) = args.wired_cell_model {
builder = builder
.with_wired_table_cell_detection(path)
.wired_table_cell_model_name(&args.wired_cell_model_name);
}
if let Some(path) = args.wireless_cell_model {
builder = builder
.with_wireless_table_cell_detection(path)
.wireless_table_cell_model_name(&args.wireless_cell_model_name);
}
if let Some(path) = args.table_structure_dict {
builder = builder.table_structure_dict_path(path);
}
builder = builder.use_e2e_wired_table_rec(args.use_e2e_wired_table_rec);
builder = builder.use_e2e_wireless_table_rec(args.use_e2e_wireless_table_rec);
builder = builder.use_wired_table_cells_trans_to_html(args.use_wired_table_cells_trans_to_html);
builder =
builder.use_wireless_table_cells_trans_to_html(args.use_wireless_table_cells_trans_to_html);
if let Some(path) = args.formula_model {
let Some(tokenizer) = args.formula_tokenizer else {
return Err("Formula recognition requires --formula-tokenizer".into());
};
let Some(model_type) = args.formula_type else {
return Err("Formula recognition requires --formula-type".into());
};
builder = builder
.with_formula_recognition(path, tokenizer, model_type)
.formula_recognition_config(formula_config);
}
if let Some(path) = args.seal_model {
builder = builder.with_seal_text_detection(path);
}
if let Some(path) = args.textline_orientation_model {
builder = builder.with_text_line_orientation(path);
}
if let (Some(text_det_model), Some(text_rec_model), Some(text_dict_path)) = (
&args.text_det_model,
&args.text_rec_model,
&args.text_dict_path,
) {
builder = builder
.with_ocr(
text_det_model.clone(),
text_rec_model.clone(),
text_dict_path.clone(),
)
.text_detection_model_name(&args.text_det_model_name)
.text_recognition_model_name(&args.text_rec_model_name)
.text_detection_config(text_det_config)
.text_recognition_config(text_rec_config);
}
let analyzer = builder.build()?;
let mut all_results: Vec<oar_ocr::domain::structure::StructureResult> = Vec::new();
let mut images: Vec<image::RgbImage> = Vec::new();
let mut source_meta: Vec<(String, String)> = Vec::new();
for source in std::mem::take(&mut input_sources) {
let source_path = source.path();
let source_stem = match &source {
InputSource::ImageFile(p) => p
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("result")
.to_string(),
InputSource::PdfPage {
pdf_path,
page_number,
..
} => {
format!(
"{}_page_{:03}",
pdf_path
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("pdf"),
page_number
)
}
};
match source.into_image() {
Ok(img) => {
images.push(img);
source_meta.push((source_path, source_stem));
}
Err(err) => {
error!("Failed to load image {}: {}", source_path, err);
}
}
}
info!(
"Batch processing {} image(s) with cross-page formula batching",
images.len()
);
let batch_results = analyzer.predict_images(images);
for (idx, (page_result, (source_path, source_stem))) in
batch_results.into_iter().zip(source_meta).enumerate()
{
let mut result = match page_result {
Ok(res) => res,
Err(err) => {
error!("Failed to analyze {}: {}", source_path, err);
continue;
}
};
info!("\nProcessed input {}: {}", idx + 1, source_path);
result.input_path = std::sync::Arc::from(source_path.clone());
all_results.push(result.clone());
if let Err(err) = result.save_results(&args.output_dir, args.to_json, args.to_html) {
error!("Failed to save results for {}: {}", source_path, err);
}
if args.vis {
let vis_path = args.output_dir.join(format!("{}.png", source_stem));
if let Err(err) =
utils::visualization::visualize_structure_results(&result, &vis_path, None)
{
error!("Failed to save visualization: {}", err);
} else {
info!(" Visualization saved to: {}", vis_path.display());
}
}
if let Some(angle) = result.orientation_angle {
info!(" Orientation corrected by {:.0} degrees", angle);
}
info!(" Layout elements: {}", result.layout_elements.len());
for (elem_idx, elem) in result.layout_elements.iter().enumerate() {
let label = elem
.label
.as_deref()
.unwrap_or_else(|| elem.element_type.as_str());
info!(
" [{}] {} ({:.1}%) at [{:.1},{:.1}] - [{:.1},{:.1}]",
elem_idx + 1,
label,
elem.confidence * 100.0,
elem.bbox.x_min(),
elem.bbox.y_min(),
elem.bbox.x_max(),
elem.bbox.y_max()
);
}
if let Some(regions) = &result.region_blocks {
info!(" Region blocks: {}", regions.len());
for (region_idx, region) in regions.iter().enumerate() {
let order = region
.order_index
.map(|v| v.to_string())
.unwrap_or_else(|| "n/a".to_string());
info!(
" [{}] order={} elements={} ({:.1}%) at [{:.1},{:.1}] - [{:.1},{:.1}]",
region_idx + 1,
order,
region.element_indices.len(),
region.confidence * 100.0,
region.bbox.x_min(),
region.bbox.y_min(),
region.bbox.x_max(),
region.bbox.y_max()
);
}
} else {
info!(" Region blocks: not enabled");
}
info!(" Tables: {}", result.tables.len());
for (table_idx, table) in result.tables.iter().enumerate() {
let table_type = match table.table_type {
TableType::Wired => "wired",
TableType::Wireless => "wireless",
TableType::Unknown => "unknown",
};
let cls_conf = table
.classification_confidence
.map(|c| format!("{:.1}%", c * 100.0))
.unwrap_or_else(|| "n/a".to_string());
let html_info = table
.html_structure
.as_ref()
.map(|html| format!("html len {}", html.len()))
.unwrap_or_else(|| "no structure".to_string());
info!(
" [{}] type={} cls={} cells={} {}",
table_idx + 1,
table_type,
cls_conf,
table.cells.len(),
html_info
);
}
info!(" Formulas: {}", result.formulas.len());
for (formula_idx, formula) in result.formulas.iter().enumerate() {
info!(
" [{}] {} ({:.1}%)",
formula_idx + 1,
formula.latex,
formula.confidence * 100.0
);
}
if let Some(text_regions) = &result.text_regions {
info!(" OCR regions: {}", text_regions.len());
for (region_idx, region) in text_regions.iter().enumerate() {
let text = region
.text
.as_ref()
.map(|t| t.to_string())
.unwrap_or_else(|| "<no text>".to_string());
let score = region.confidence.unwrap_or(0.0) * 100.0;
info!(
" [{}] \"{}\" ({:.1}%) at [{:.1},{:.1}] - [{:.1},{:.1}]",
region_idx + 1,
text,
score,
region.bounding_box.x_min(),
region.bounding_box.y_min(),
region.bounding_box.x_max(),
region.bounding_box.y_max()
);
}
} else {
info!(" OCR regions: not enabled");
}
}
if !all_results.is_empty() {
let is_multi_page_pdf =
all_results.len() > 1 && all_results.iter().any(|r| r.input_path.contains('#'));
if is_multi_page_pdf {
info!("\nConcatenating {} pages", all_results.len());
}
let base_name = {
let path_str: &str = &all_results[0].input_path;
let path = if let Some(hash_idx) = path_str.rfind('#') {
std::path::Path::new(&path_str[..hash_idx])
} else {
std::path::Path::new(path_str)
};
path.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("output")
.to_string()
};
if args.to_markdown {
match utils::markdown::export_concatenated_markdown_with_images(
&all_results,
&args.output_dir,
) {
Ok(concat_md) => {
let md_path = args.output_dir.join(format!("{}.md", base_name));
if let Err(err) = std::fs::write(&md_path, concat_md) {
error!("Failed to save markdown: {}", err);
} else {
info!("Markdown saved to: {}", md_path.display());
}
}
Err(err) => {
error!("Failed to generate markdown with images: {}", err);
}
}
}
if args.to_json {
let json_path = args.output_dir.join(format!("{}.json", base_name));
let json_file = match std::fs::File::create(&json_path) {
Ok(f) => f,
Err(e) => {
return Err(format!("Failed to create JSON file: {}", e).into());
}
};
if let Err(e) = serde_json::to_writer_pretty(json_file, &all_results) {
error!("Failed to save JSON: {}", e);
} else {
info!("JSON saved to: {}", json_path.display());
}
}
if is_multi_page_pdf {
info!("=== Multi-page PDF processing complete ===");
}
}
Ok(())
}
/// Checks that an optional, user-supplied path exists on disk.
///
/// `label` names the option in the error message. `None` is always accepted.
///
/// # Errors
/// Returns `"<label> not found: <path>"` when a path is given but does not exist.
fn validate_optional_path(
    label: &str,
    path: Option<&PathBuf>,
) -> Result<(), Box<dyn std::error::Error>> {
    match path {
        Some(candidate) if !candidate.exists() => {
            Err(format!("{label} not found: {}", candidate.display()).into())
        }
        _ => Ok(()),
    }
}