mod utils;
use clap::Parser;
use image::Rgb;
use oar_ocr::predictors::{TableCellDetectionPredictor, TableCellModelVariant};
use oar_ocr::utils::load_image;
use std::path::PathBuf;
use std::time::Instant;
use tracing::{error, info, warn};
use utils::device_config::parse_device_config;
use utils::visualization::{Detection, DetectionVisConfig, save_rgb_image, visualize_detections};
// Command-line arguments for the table cell detection example.
//
// This is a plain comment rather than a doc comment on purpose: the
// `#[command(about = ...)]` attribute below stays the single source of the
// command description. The `///` comments on the fields, however, are picked
// up by clap and shown as per-option help text in `--help`.
#[derive(Parser)]
#[command(name = "table_cell_detection")]
#[command(about = "Detect table cells using RT-DETR models")]
struct Args {
    /// Path to the table cell detection model file
    #[arg(short, long)]
    model_path: PathBuf,

    /// Image files to process (at least one is required)
    #[arg(required = true)]
    images: Vec<PathBuf>,

    /// Directory where visualizations are written (required when --vis is set)
    #[arg(short, long)]
    output_dir: Option<PathBuf>,

    /// Save a visualization of the detected cells for each image
    #[arg(long)]
    vis: bool,

    /// Model type; inferred from the model filename when omitted
    #[arg(long)]
    model_type: Option<String>,

    /// Score threshold passed to the detector
    #[arg(long, default_value_t = 0.3)]
    score_threshold: f32,

    /// Maximum number of cells per image (currently not applied by this example)
    // NOTE(review): this value is parsed but never read by `main`; either wire
    // it into the predictor or drop the flag once the intended semantics are
    // confirmed.
    #[arg(long, default_value_t = 300)]
    max_cells: usize,

    /// Device to run inference on, e.g. "cpu"
    #[arg(long, default_value = "cpu")]
    device: String,
}
/// Runs table cell detection over the images given on the command line.
///
/// Steps, in order:
/// 1. Parse arguments and validate the `--vis` / `--output-dir` combination.
/// 2. Resolve the model variant (explicit `--model-type`, otherwise inferred
///    from the model path) and build the predictor.
/// 3. Load every input image that exists and can be decoded.
/// 4. Run one batched prediction and log the cells found per image.
/// 5. If `--vis` is set, save one visualization per image that has cells.
///
/// # Errors
///
/// Returns an error when `--vis` is given without `--output-dir`, when the
/// model variant cannot be resolved, when the device configuration or the
/// predictor cannot be built, when no input image could be loaded, when
/// prediction fails, or when a visualization cannot be written.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    utils::init_tracing();
    let args = Args::parse();

    // Fail fast: reject `--vis` without `--output-dir` before paying for model
    // loading and inference instead of after all the work is already done.
    let vis_output_dir = if args.vis {
        Some(
            args.output_dir
                .as_ref()
                .ok_or("--output-dir is required when --vis is enabled")?,
        )
    } else {
        None
    };

    info!("Loading table cell detection model: {:?}", args.model_path);

    // An explicit --model-type wins; otherwise fall back to filename inference.
    let variant = if let Some(ref model_type) = args.model_type {
        parse_model_variant(model_type).map_err(|supported| {
            format!(
                "Unknown model type '{}'. Supported types: {}",
                model_type,
                supported.join(", ")
            )
        })?
    } else {
        TableCellModelVariant::detect_from_path(&args.model_path).ok_or_else(|| {
            format!(
                "Could not infer model type from filename '{}'. Specify --model-type explicitly.",
                args.model_path.display()
            )
        })?
    };
    info!("Detected model type: {}", variant.as_str());

    info!("Using device: {}", args.device);
    let ort_config = parse_device_config(&args.device)?;
    if ort_config.is_some() {
        info!("CUDA execution provider configured successfully");
    }

    let mut predictor_builder = TableCellDetectionPredictor::builder()
        .score_threshold(args.score_threshold)
        .model_variant(variant);
    if let Some(ort_cfg) = ort_config {
        predictor_builder = predictor_builder.with_ort_config(ort_cfg);
    }
    let predictor = predictor_builder.build(&args.model_path)?;

    // Drop paths that do not exist, reporting each one.
    let existing_images: Vec<PathBuf> = args
        .images
        .iter()
        .filter(|path| {
            let exists = path.exists();
            if !exists {
                error!("Image file not found: {}", path.display());
            }
            exists
        })
        .cloned()
        .collect();
    if existing_images.is_empty() {
        error!("No valid image files found");
        return Err("No valid image files found".into());
    }

    // `loaded_paths` and `images` are filled in lockstep so that index i of
    // one always describes index i of the other. A file that exists but fails
    // to decode is skipped in BOTH, which keeps the later zips with the
    // per-image results correctly aligned.
    let mut loaded_paths: Vec<&PathBuf> = Vec::with_capacity(existing_images.len());
    let mut images = Vec::with_capacity(existing_images.len());
    for image_path in &existing_images {
        match load_image(image_path) {
            Ok(img) => {
                info!(
                    "Loaded image: {} ({}x{})",
                    image_path.display(),
                    img.width(),
                    img.height()
                );
                loaded_paths.push(image_path);
                images.push(img);
            }
            Err(e) => {
                error!("Failed to load image {:?}: {}", image_path, e);
            }
        }
    }
    if images.is_empty() {
        error!("No images could be loaded for processing");
        return Err("No images could be loaded".into());
    }

    info!("Running table cell detection...");
    let start = Instant::now();
    // The batch is cloned because the originals are still needed below as the
    // canvas for visualization.
    let output = predictor.predict(images.clone())?;
    let elapsed = start.elapsed();
    info!("Detection completed in {:.2?}", elapsed);

    for (idx, (image_path, cells)) in loaded_paths.iter().zip(output.cells.iter()).enumerate() {
        info!("\n=== Results for image {} ===", idx + 1);
        info!("Image: {}", image_path.display());
        info!("Detected {} table cells", cells.len());
        if cells.is_empty() {
            warn!("No cells detected in this image");
            continue;
        }
        for (cell_idx, cell) in cells.iter().enumerate() {
            if let Some((min_x, min_y, max_x, max_y)) = bbox_bounds(&cell.bbox) {
                info!(
                    " [{}] {}: ({:.0},{:.0})-({:.0},{:.0}), score: {:.3}",
                    cell_idx, cell.label, min_x, min_y, max_x, max_y, cell.score
                );
            } else {
                info!(
                    " [{}] {}: <empty bbox>, score: {:.3}",
                    cell_idx, cell.label, cell.score
                );
            }
        }
    }

    if let Some(output_dir) = vis_output_dir {
        std::fs::create_dir_all(output_dir)?;
        info!("\nSaving visualizations to: {}", output_dir.display());
        let vis_config = DetectionVisConfig::default()
            .with_box_color(Rgb([0, 200, 255]))
            .with_label_color(Rgb([0, 200, 255]));

        for ((image_path, rgb_img), cells) in loaded_paths
            .iter()
            .zip(images.iter())
            .zip(output.cells.iter())
        {
            if cells.is_empty() {
                warn!(
                    " Skipping visualization for {} (no cells detected)",
                    image_path.display()
                );
                continue;
            }
            let vis_detections: Vec<Detection> = cells
                .iter()
                .map(|c| Detection::new(&c.bbox, c.score).with_label(&c.label))
                .collect();
            // The output keeps the input's file name; fall back to a fixed
            // name when the path has no UTF-8 file name component.
            let output_filename = image_path
                .file_name()
                .and_then(|s| s.to_str())
                .unwrap_or("result.png");
            let output_path = output_dir.join(output_filename);
            let visualized = visualize_detections(rgb_img, &vis_detections, &vis_config);
            save_rgb_image(&visualized, &output_path)
                .map_err(|e| format!("Failed to save visualization: {}", e))?;
            info!(" Saved: {}", output_path.display());
        }
    }

    Ok(())
}
/// Resolves a user-supplied model type string to a [`TableCellModelVariant`].
///
/// # Errors
///
/// When `from_model_type` does not recognise `model_type`, returns the names
/// of the variants this example offers, so the caller can list them in its
/// error message.
fn parse_model_variant(model_type: &str) -> Result<TableCellModelVariant, Vec<&'static str>> {
    match TableCellModelVariant::from_model_type(model_type) {
        Some(variant) => Ok(variant),
        None => {
            let wired = TableCellModelVariant::RTDetrLWired.as_str();
            let wireless = TableCellModelVariant::RTDetrLWireless.as_str();
            Err(vec![wired, wireless])
        }
    }
}
/// Computes the axis-aligned extent of a bounding box's points.
///
/// Returns `Some((min_x, min_y, max_x, max_y))`, or `None` when the box has
/// no points.
fn bbox_bounds(bbox: &oar_ocr::processors::BoundingBox) -> Option<(f32, f32, f32, f32)> {
    if bbox.points.is_empty() {
        return None;
    }
    // One pass over the points. Each component starts from the identity of
    // its reduction (+inf for min, -inf for max) and is combined with
    // `f32::min` / `f32::max`.
    let seed = (
        f32::INFINITY,
        f32::INFINITY,
        f32::NEG_INFINITY,
        f32::NEG_INFINITY,
    );
    let extent = bbox.points.iter().fold(seed, |(lo_x, lo_y, hi_x, hi_y), pt| {
        (
            f32::min(lo_x, pt.x),
            f32::min(lo_y, pt.y),
            f32::max(hi_x, pt.x),
            f32::max(hi_y, pt.y),
        )
    });
    Some(extent)
}