mod utils;
use clap::Parser;
use oar_ocr::predictors::DocumentRectificationPredictor;
use oar_ocr::utils::load_image;
use std::path::PathBuf;
use std::time::Instant;
use tracing::{error, info};
use utils::device_config::parse_device_config;
use utils::visualization::{load_system_font, save_rgb_image};
// Command-line arguments for the document rectification example.
// The `///` comments on each field are used by clap as `--help` text.
#[derive(Parser)]
#[command(name = "document_rectification")]
#[command(about = "Document Rectification Example - corrects distortions in document images")]
struct Args {
    /// Path to the document rectification model file.
    #[arg(short, long)]
    model_path: PathBuf,

    /// One or more input image files to rectify.
    #[arg(required = true)]
    images: Vec<PathBuf>,

    /// Directory where rectified (and comparison) images are written; created if missing.
    #[arg(short, long)]
    output_dir: PathBuf,

    /// Also save a side-by-side original/rectified comparison image for each input.
    #[arg(long)]
    vis: bool,

    /// Device to run inference on (e.g. "cpu").
    #[arg(long, default_value = "cpu")]
    device: String,

    /// Requested model input height; 0 means dynamic (currently informational only).
    #[arg(long, default_value = "0")]
    input_height: usize,

    /// Requested model input width; 0 means dynamic (currently informational only).
    #[arg(long, default_value = "0")]
    input_width: usize,

    /// Log model details and per-image load information.
    #[arg(short, long)]
    verbose: bool,
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
utils::init_tracing();
let args = Args::parse();
info!("Document Rectification Example");
if !args.model_path.exists() {
error!("Model file not found: {}", args.model_path.display());
return Err("Model file not found".into());
}
let existing_images: Vec<PathBuf> = args
.images
.iter()
.filter(|path| {
let exists = path.exists();
if !exists {
error!("Image file not found: {}", path.display());
}
exists
})
.cloned()
.collect();
if existing_images.is_empty() {
error!("No valid image files found");
return Err("No valid image files found".into());
}
info!("Using device: {}", args.device);
let ort_config = parse_device_config(&args.device)?.unwrap_or_default();
if ort_config.execution_providers.is_some() {
info!("CUDA execution provider configured successfully");
}
if args.verbose {
info!("Building rectification predictor...");
info!(" Model: {}", args.model_path.display());
if args.input_height > 0 && args.input_width > 0 {
info!(
" Input shape override: [3, {}, {}]",
args.input_height, args.input_width
);
} else {
info!(" Input shape override: dynamic (use original image size)");
}
}
let predictor = DocumentRectificationPredictor::builder()
.with_ort_config(ort_config)
.build(&args.model_path)?;
info!("Rectification predictor built successfully");
info!("Processing {} images...", existing_images.len());
let mut images = Vec::new();
for image_path in &existing_images {
match load_image(image_path) {
Ok(rgb_img) => {
if args.verbose {
info!(
"Loaded image: {} ({}x{})",
image_path.display(),
rgb_img.width(),
rgb_img.height()
);
}
images.push(rgb_img);
}
Err(e) => {
error!("Failed to load image {}: {}", image_path.display(), e);
continue;
}
}
}
if images.is_empty() {
error!("No images could be loaded for processing");
return Err("No images could be loaded".into());
}
info!("Running document rectification...");
let start = Instant::now();
let output = predictor.predict(images.clone())?;
let duration = start.elapsed();
info!(
"Rectification completed in {:.2}ms ({:.2}ms per image)",
duration.as_secs_f64() * 1000.0,
duration.as_secs_f64() * 1000.0 / existing_images.len() as f64
);
info!("\n=== Rectification Results ===");
for (idx, (image_path, original_img, rectified_img)) in existing_images
.iter()
.zip(images.iter())
.zip(output.images.iter())
.map(|((path, orig), rect)| (path, orig, rect))
.enumerate()
{
info!("\nImage {}: {}", idx + 1, image_path.display());
info!(
" Original size: {}x{}",
original_img.width(),
original_img.height()
);
info!(
" Rectified size: {}x{}",
rectified_img.width(),
rectified_img.height()
);
}
std::fs::create_dir_all(&args.output_dir)?;
info!(
"\nSaving rectified images to: {}",
args.output_dir.display()
);
for (image_path, original_img, rectified_img) in existing_images
.iter()
.zip(images.iter())
.zip(output.images.iter())
.map(|((path, orig), rect)| (path, orig, rect))
{
let output_filename = image_path
.file_name()
.and_then(|s| s.to_str())
.unwrap_or("unknown.jpg");
let rectified_path = args.output_dir.join(output_filename);
save_rgb_image(rectified_img, &rectified_path)
.map_err(|e| format!("Failed to save rectified image: {}", e))?;
info!(" Saved rectified: {}", rectified_path.display());
if args.vis {
let input_filename = image_path
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("unknown");
let comparison_filename = format!("{}_comparison.jpg", input_filename);
let comparison_path = args.output_dir.join(&comparison_filename);
let comparison_img = create_comparison_image(original_img, rectified_img);
save_rgb_image(&comparison_img, &comparison_path)
.map_err(|e| format!("Failed to save comparison image: {}", e))?;
info!(" Saved comparison: {}", comparison_path.display());
}
}
info!("\nRectification complete!");
Ok(())
}
/// Builds a side-by-side comparison image.
///
/// `original` is placed at the left edge and `rectified` to its right, after a
/// white gap. The canvas is as tall as the taller input and is white wherever
/// neither image covers it. When `load_system_font` finds a font, black
/// "Original" and "Rectified" labels are drawn near the top-left of each panel.
fn create_comparison_image(
    original: &image::RgbImage,
    rectified: &image::RgbImage,
) -> image::RgbImage {
    use image::{Rgb, RgbImage};
    use imageproc::drawing::draw_text_mut;

    /// Width in pixels of the white gap between the two panels.
    const GAP: u32 = 20;
    /// Offset of each label from its panel's top-left corner, in pixels.
    const LABEL_MARGIN: i32 = 10;
    /// Font scale passed to `draw_text_mut` for the labels.
    const LABEL_SCALE: f32 = 24.0;

    let height = original.height().max(rectified.height());
    // Derive the canvas width from the same offset used to place the right panel,
    // so every `put_pixel` below is guaranteed to be in bounds.
    let x_offset = original.width() + GAP;
    let width = x_offset + rectified.width();
    let mut output = RgbImage::from_pixel(width, height, Rgb([255, 255, 255]));

    for (x, y, pixel) in original.enumerate_pixels() {
        output.put_pixel(x, y, *pixel);
    }
    for (x, y, pixel) in rectified.enumerate_pixels() {
        output.put_pixel(x + x_offset, y, *pixel);
    }

    if let Some(font) = load_system_font() {
        let text_color = Rgb([0u8, 0u8, 0u8]);
        draw_text_mut(
            &mut output,
            text_color,
            LABEL_MARGIN,
            LABEL_MARGIN,
            LABEL_SCALE,
            &font,
            "Original",
        );
        draw_text_mut(
            &mut output,
            text_color,
            x_offset as i32 + LABEL_MARGIN,
            LABEL_MARGIN,
            LABEL_SCALE,
            &font,
            "Rectified",
        );
    }
    output
}