mod utils;
use clap::Parser;
use oar_ocr::domain::tasks::layout_detection::LayoutDetectionConfig;
use oar_ocr::predictors::LayoutDetectionPredictor;
use oar_ocr::utils::load_image;
use serde_json::json;
use std::path::PathBuf;
use std::time::Instant;
use tracing::{error, info, warn};
use utils::device_config::parse_device_config;
use utils::visualization::{LayoutItem, save_image, visualize_layout};
use std::fs;
// Command-line options for the layout detection example.
//
// Kept as a plain comment (not `///`) so the explicit `about` attribute below
// remains the sole source of the program description. The `///` comments on
// the fields are picked up by clap's derive macro and shown as `--help` text.
#[derive(Parser)]
#[command(name = "layout_detection")]
#[command(about = "Layout Detection Example - detects document structure elements")]
struct Args {
    /// Path to the layout detection model file. Unless --model-name is given,
    /// the model type is auto-detected from this file's name.
    #[arg(short, long)]
    model_path: PathBuf,

    /// One or more input image files to process
    #[arg(required = true)]
    images: Vec<PathBuf>,

    /// Directory where visualization images are written (created if missing);
    /// required when --vis is enabled
    #[arg(short, long)]
    output_dir: Option<PathBuf>,

    /// Save a visualization of the detected layout for each image
    #[arg(long)]
    vis: bool,

    /// Model type name; overrides auto-detection from the model file name
    #[arg(long)]
    model_name: Option<String>,

    /// Score threshold to set in the detection configuration
    /// (the configuration's own default is used when omitted)
    #[arg(long)]
    score_threshold: Option<f32>,

    /// Value for the detection configuration's `max_elements` setting
    #[arg(long, default_value_t = 100)]
    max_elements: usize,

    /// Print the detected elements of each image as pretty-printed JSON on stdout
    #[arg(long)]
    dump_json: bool,

    /// Device specification used to configure the inference backend
    #[arg(long, default_value = "cpu")]
    device: String,
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
utils::init_tracing();
let args = Args::parse();
info!("Loading layout detection model: {:?}", args.model_path);
let model_type = if let Some(ref mn) = args.model_name {
mn.clone()
} else {
let filename = args
.model_path
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("")
.to_lowercase();
if filename.contains("pp-docblocklayout") || filename.contains("pp_docblocklayout") {
"pp_docblocklayout".to_string()
} else if filename.contains("pp-doclayout_plus-l")
|| filename.contains("pp_doclayout_plus_l")
{
"pp_doclayout_plus_l".to_string()
} else if filename.contains("pp-doclayout-l") || filename.contains("pp_doclayout_l") {
"pp_doclayout_l".to_string()
} else if filename.contains("pp-doclayout-m") || filename.contains("pp_doclayout_m") {
"pp_doclayout_m".to_string()
} else if filename.contains("pp-doclayout-s") || filename.contains("pp_doclayout_s") {
"pp_doclayout_s".to_string()
} else if filename.contains("picodet-s_layout_3cls")
|| filename.contains("picodet_s_layout_3cls")
{
"picodet_s_layout_3cls".to_string()
} else if filename.contains("picodet-l_layout_3cls")
|| filename.contains("picodet_l_layout_3cls")
{
"picodet_l_layout_3cls".to_string()
} else if filename.contains("picodet-s_layout_17cls")
|| filename.contains("picodet_s_layout_17cls")
{
"picodet_s_layout_17cls".to_string()
} else if filename.contains("picodet-l_layout_17cls")
|| filename.contains("picodet_l_layout_17cls")
{
"picodet_l_layout_17cls".to_string()
} else if filename.contains("picodet_layout_1x_table") {
"picodet_layout_1x_table".to_string()
} else if filename.contains("rt-detr-h_layout_3cls")
|| filename.contains("rtdetr_h_layout_3cls")
{
"rtdetr_h_layout_3cls".to_string()
} else if filename.contains("rt-detr-h_layout_17cls")
|| filename.contains("rtdetr_h_layout_17cls")
{
"rtdetr_h_layout_17cls".to_string()
} else if filename.contains("picodet_layout_1x") {
"picodet_layout_1x".to_string()
} else {
warn!(
"Could not auto-detect model type from filename '{}', using default 'picodet_layout_1x'",
filename
);
"picodet_layout_1x".to_string()
}
};
info!("Detected model type: {}", model_type);
info!("Using device: {}", args.device);
let ort_config = parse_device_config(&args.device)?;
if ort_config.is_some() {
info!("CUDA execution provider configured successfully");
}
let mut base_config = match model_type.as_str() {
"pp_doclayoutv2" => LayoutDetectionConfig::with_pp_doclayoutv2_defaults(),
"pp_doclayoutv3" => LayoutDetectionConfig::with_pp_doclayoutv3_defaults(),
_ => LayoutDetectionConfig::default(),
};
base_config.max_elements = args.max_elements;
if let Some(threshold) = args.score_threshold {
base_config.score_threshold = threshold;
}
let mut predictor_builder = LayoutDetectionPredictor::builder()
.with_config(base_config)
.model_name(model_type);
if let Some(ort_cfg) = ort_config {
predictor_builder = predictor_builder.with_ort_config(ort_cfg);
}
let predictor = predictor_builder.build(&args.model_path)?;
if let Some(ref output_dir) = args.output_dir {
fs::create_dir_all(output_dir)?;
}
for (img_idx, image_path) in args.images.iter().enumerate() {
info!(
"Processing image {}/{}: {:?}",
img_idx + 1,
args.images.len(),
image_path
);
let img = match load_image(image_path) {
Ok(img) => img,
Err(e) => {
error!("Failed to load image {:?}: {}", image_path, e);
continue;
}
};
let (width, height) = (img.width(), img.height());
info!("Image size: {}x{}", width, height);
let img_for_vis = img.clone();
let start = Instant::now();
let output = match predictor.predict(vec![img]) {
Ok(output) => output,
Err(e) => {
error!("Layout detection failed for {:?}: {}", image_path, e);
continue;
}
};
let duration = start.elapsed();
info!("Detection completed in {:.2?}", duration);
if !output.elements.is_empty() {
let elements = &output.elements[0];
info!("Detected {} layout elements", elements.len());
let mut type_counts = std::collections::HashMap::new();
for element in elements {
*type_counts.entry(element.element_type.clone()).or_insert(0) += 1;
}
info!("Layout element summary:");
for (element_type, count) in type_counts {
let type_name = format_element_type(&element_type);
info!(" {}: {}", type_name, count);
}
for (idx, element) in elements.iter().enumerate() {
let type_name = format_element_type(&element.element_type);
let bbox = &element.bbox;
if !bbox.points.is_empty() {
let min_x = bbox
.points
.iter()
.map(|p| p.x)
.fold(f32::INFINITY, f32::min);
let min_y = bbox
.points
.iter()
.map(|p| p.y)
.fold(f32::INFINITY, f32::min);
let max_x = bbox
.points
.iter()
.map(|p| p.x)
.fold(f32::NEG_INFINITY, f32::max);
let max_y = bbox
.points
.iter()
.map(|p| p.y)
.fold(f32::NEG_INFINITY, f32::max);
info!(
" [{}] {}: ({:.0},{:.0})-({:.0},{:.0}), score: {:.3}",
idx, type_name, min_x, min_y, max_x, max_y, element.score
);
}
}
if args.dump_json {
let elements_json: Vec<_> = elements
.iter()
.enumerate()
.map(|(idx, element)| {
let bbox = &element.bbox;
json!({
"order": idx + 1,
"label": element.element_type.as_str(),
"score": element.score,
"bbox": [
bbox.x_min().round() as i32,
bbox.y_min().round() as i32,
bbox.x_max().round() as i32,
bbox.y_max().round() as i32
]
})
})
.collect();
let payload = json!({
"image": image_path,
"width": width,
"height": height,
"elements": elements_json
});
println!("{}", serde_json::to_string_pretty(&payload)?);
}
if args.vis {
let output_dir = args
.output_dir
.as_ref()
.ok_or("--output-dir is required when --vis is enabled")?;
let output_filename = image_path
.file_name()
.and_then(|s| s.to_str())
.unwrap_or("result.png");
let output_path = output_dir.join(output_filename);
let layout_items: Vec<LayoutItem> = elements
.iter()
.map(|e| LayoutItem {
bbox: &e.bbox,
element_type: &e.element_type,
score: e.score,
})
.collect();
let vis_img = visualize_layout(&img_for_vis, &layout_items, 2, true);
save_image(&vis_img, &output_path)
.map_err(|e| format!("Failed to save visualization: {}", e))?;
info!("Visualization saved to: {:?}", output_path);
}
} else {
warn!("No layout elements detected in {:?}", image_path);
}
}
Ok(())
}
/// Produces a display label for a layout element type.
///
/// The first character is upper-cased (using its full Unicode upper-case
/// mapping, which may be more than one character) and the remainder is kept
/// unchanged. An empty input yields `"Unknown"`.
fn format_element_type(element_type: &str) -> String {
    let mut rest = element_type.chars();
    match rest.next() {
        Some(head) => {
            let mut label = String::with_capacity(element_type.len());
            label.extend(head.to_uppercase());
            // `rest` has already advanced past the first character.
            label.push_str(rest.as_str());
            label
        }
        None => String::from("Unknown"),
    }
}