mod utils;
use clap::Parser;
use oar_ocr::predictors::TableStructureRecognitionPredictor;
use oar_ocr::utils::load_image;
use std::path::PathBuf;
use std::time::Instant;
use tracing::{error, info};
use utils::device_config::parse_device_config;
// Command-line arguments for the table structure recognition example.
//
// The field doc comments below are picked up by clap's derive and shown as
// the per-option descriptions in `--help`.
#[derive(Parser)]
#[command(
    name = "table_structure_recognition",
    about = "Table Structure Recognition Example - recognizes table structure as HTML"
)]
struct Args {
    /// Path to the table structure recognition model file
    #[arg(short, long)]
    model_path: PathBuf,

    /// Paths to the input images to process (at least one is required)
    #[arg(required = true)]
    images: Vec<PathBuf>,

    /// Path to the dictionary file used by the predictor
    #[arg(long)]
    dict_path: PathBuf,

    /// Model preset name; when omitted it is auto-detected from the model path
    #[arg(long)]
    model_name: Option<String>,

    /// Device specification used to configure inference (e.g. "cpu")
    #[arg(long, default_value = "cpu")]
    device: String,

    /// Score threshold passed to the predictor
    #[arg(long, default_value = "0.5")]
    score_thresh: f32,

    /// Maximum structure length
    #[arg(long, default_value = "500")]
    max_length: usize,

    /// Input height override; must be given together with --input-width
    #[arg(long)]
    input_height: Option<u32>,

    /// Input width override; must be given together with --input-height
    #[arg(long)]
    input_width: Option<u32>,

    /// Directory where the recognized structure HTML files are written
    #[arg(short = 'o', long = "output-dir")]
    output_dir: Option<PathBuf>,
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
let args = Args::parse();
utils::init_tracing();
info!("Table Structure Recognition Example");
if !args.model_path.exists() {
error!("Model file not found: {}", args.model_path.display());
return Err("Model file not found".into());
}
if !args.dict_path.exists() {
error!("Dictionary file not found: {}", args.dict_path.display());
return Err("Dictionary file not found".into());
}
let existing_images: Vec<PathBuf> = args
.images
.iter()
.filter(|path| {
let exists = path.exists();
if !exists {
error!("Image file not found: {}", path.display());
}
exists
})
.cloned()
.collect();
if existing_images.is_empty() {
error!("No valid image files found");
return Err("No valid image files found".into());
}
info!("Using device: {}", args.device);
let ort_config = parse_device_config(&args.device)?.unwrap_or_default();
if ort_config.execution_providers.is_some() {
info!("CUDA execution provider configured successfully");
}
info!("Recognition Configuration:");
info!(" Score threshold: {}", args.score_thresh);
info!(" Max structure length: {}", args.max_length);
if let Some(ref model_name) = args.model_name {
info!(" Model preset: {}", model_name);
} else {
info!(" Model preset: <auto-detect from path>");
}
match (args.input_height, args.input_width) {
(Some(height), Some(width)) => info!(" Input shape override: ({}, {})", height, width),
(None, None) => info!(" Input shape override: <adapter default>"),
_ => {
return Err("Both --input-height and --input-width must be provided together".into());
}
}
info!(" Dictionary: {}", args.dict_path.display());
info!("Building table structure recognition predictor...");
info!(" Model: {}", args.model_path.display());
let start_build = Instant::now();
let mut predictor_builder = TableStructureRecognitionPredictor::builder()
.score_threshold(args.score_thresh)
.dict_path(&args.dict_path)
.with_ort_config(ort_config);
if let Some(ref model_name) = args.model_name {
predictor_builder = predictor_builder.model_name(model_name);
}
if let (Some(height), Some(width)) = (args.input_height, args.input_width) {
predictor_builder = predictor_builder.input_shape(height, width);
}
let predictor = predictor_builder.build(&args.model_path)?;
info!(
"Predictor built in {:.2}ms",
start_build.elapsed().as_secs_f64() * 1000.0
);
info!("Processing {} images...", existing_images.len());
let mut images = Vec::new();
for image_path in &existing_images {
match load_image(image_path) {
Ok(rgb_img) => {
info!(
"Loaded image: {} ({}x{})",
image_path.display(),
rgb_img.width(),
rgb_img.height()
);
images.push(rgb_img);
}
Err(e) => {
error!("Failed to load image {}: {}", image_path.display(), e);
continue;
}
}
}
if images.is_empty() {
error!("No images could be loaded for processing");
return Err("No images could be loaded".into());
}
info!("Running table structure recognition...");
let start = Instant::now();
let output = predictor.predict(images)?;
let duration = start.elapsed();
info!(
"Recognition completed in {:.2}ms",
duration.as_secs_f64() * 1000.0
);
info!("\n=== Structure Recognition Results ===");
for (idx, (structure, bboxes)) in output
.structures
.iter()
.zip(output.bboxes.iter())
.enumerate()
{
info!(
"\nImage {}: {}",
idx,
existing_images
.get(idx)
.map(|p| p.display().to_string())
.unwrap_or_else(|| "N/A".to_string())
);
info!(" Structure tokens ({}): {:?}", structure.len(), structure);
info!(" Cell bboxes ({}): {:?}", bboxes.len(), bboxes);
}
if let Some(ref output_dir) = args.output_dir {
std::fs::create_dir_all(output_dir)?;
for (idx, structure) in output.structures.iter().enumerate() {
let structure_html = structure.join("");
let html_stem = existing_images
.get(idx)
.and_then(|path| path.file_stem())
.and_then(|name| name.to_str())
.unwrap_or("table_structure");
let html_path = output_dir.join(format!("{}_{}_structure.html", html_stem, idx));
if let Err(e) = std::fs::write(&html_path, structure_html) {
error!(
"Failed to write structure HTML {}: {}",
html_path.display(),
e
);
} else {
info!("Structure HTML saved to: {}", html_path.display());
}
}
}
Ok(())
}