mod utils;
use clap::Parser;
use oar_ocr::domain::tasks::{TextDetectionConfig, TextRecognitionConfig};
use oar_ocr::oarocr::OAROCRBuilder;
use oar_ocr::processors::LimitType;
use oar_ocr::utils::load_image;
use std::path::PathBuf;
use std::time::Instant;
use tracing::{error, info, warn};
use utils::device_config::parse_device_config;
use utils::visualization::{VisualizationConfig, create_ocr_visualization};
// Command-line arguments for the OCR example.
//
// The `///` comments on the fields are picked up by clap as the per-option
// help text shown by `--help`. This header deliberately uses `//` so that the
// explicit `about` string below stays the only program description.
#[derive(Parser)]
#[command(name = "ocr")]
#[command(about = "Run the high-level OCR pipeline (detection + recognition)")]
struct Args {
    /// Path to the text detection model file
    #[arg(long = "det-model")]
    det_model: PathBuf,

    /// Path to the text recognition model file
    #[arg(long = "rec-model")]
    rec_model: PathBuf,

    /// Path to the dictionary file
    #[arg(long = "dict-path")]
    dict_path: PathBuf,

    /// One or more input image files
    #[arg(required = true)]
    images: Vec<PathBuf>,

    /// Path to the document image orientation model (optional)
    #[arg(long = "document-image-orientation-model")]
    document_image_orientation_model: Option<PathBuf>,

    /// Path to the text line orientation model (optional)
    #[arg(long = "text-line-orientation-model")]
    text_line_orientation_model: Option<PathBuf>,

    /// Path to the rectification model (optional)
    #[arg(long)]
    rectification_model: Option<PathBuf>,

    /// Request word boxes in addition to text regions
    #[arg(long, default_value_t = false)]
    return_word_box: bool,

    /// Device specification for inference
    #[arg(long, default_value = "cpu")]
    device: String,

    /// Score threshold for text detection
    #[arg(long, default_value_t = 0.3)]
    det_score_thresh: f32,

    /// Box threshold for text detection
    #[arg(long, default_value_t = 0.6)]
    det_box_thresh: f32,

    /// Unclip ratio for text detection
    #[arg(long, default_value_t = 1.5)]
    det_unclip: f32,

    /// Maximum number of text detection candidates
    #[arg(long, default_value_t = 1000)]
    det_max_candidates: usize,

    /// Side-length limit for text detection
    #[arg(long)]
    det_limit_side_len: Option<u32>,

    /// Limit type for text detection: min, max, or resize_long
    #[arg(long)]
    det_limit_type: Option<String>,

    /// Maximum side length for text detection
    #[arg(long)]
    det_max_side_len: Option<u32>,

    /// Score threshold for text recognition
    #[arg(long, default_value_t = 0.0)]
    rec_score_thresh: f32,

    /// Maximum text length for text recognition
    #[arg(long, default_value_t = 100)]
    rec_max_text_length: usize,

    /// Batch size for images
    #[arg(long)]
    image_batch_size: Option<usize>,

    /// Batch size for text regions
    #[arg(long)]
    region_batch_size: Option<usize>,

    /// Directory for visualization output (required with --vis)
    #[arg(short = 'o', long = "output-dir")]
    output_dir: Option<PathBuf>,

    /// Save a visualization image for each processed input
    #[arg(long)]
    vis: bool,

    /// Font file for visualization; a system font is used if it cannot be loaded
    #[arg(long = "vis-font-path")]
    vis_font_path: Option<PathBuf>,
}
/// Runs the OCR example end to end: validates the inputs, builds the pipeline,
/// loads the images, runs prediction, logs the recognized regions and, when
/// `--vis` is set, writes one visualization image per processed input.
///
/// # Errors
///
/// Returns an error when a required model/dictionary file is missing, when no
/// input image exists or can be loaded, when `--vis` is given without
/// `--output-dir`, when `--det-limit-type` is not a recognized value, or when
/// any pipeline, filesystem or visualization call fails.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    utils::init_tracing();
    let args = Args::parse();
    info!("Running OCR pipeline example");

    if !args.det_model.exists() {
        error!("Detection model not found: {}", args.det_model.display());
        return Err("Detection model not found".into());
    }
    if !args.rec_model.exists() {
        error!("Recognition model not found: {}", args.rec_model.display());
        return Err("Recognition model not found".into());
    }
    if !args.dict_path.exists() {
        error!("Dictionary file not found: {}", args.dict_path.display());
        return Err("Dictionary file not found".into());
    }

    // Fail fast: reject `--vis` without `--output-dir` before any model is
    // built or image processed, instead of after the whole OCR run.
    let vis_output_dir = if args.vis {
        Some(
            args.output_dir
                .as_ref()
                .ok_or("--output-dir is required when --vis is enabled")?,
        )
    } else {
        None
    };

    // NOTE(review): these options are parsed but nothing in this function
    // hands them to the builder. Tell the user instead of silently ignoring
    // them; wiring them into `OAROCRBuilder` needs builder methods that are
    // not visible from this file — TODO confirm and connect.
    for (flag, value) in [
        (
            "--document-image-orientation-model",
            &args.document_image_orientation_model,
        ),
        (
            "--text-line-orientation-model",
            &args.text_line_orientation_model,
        ),
        ("--rectification-model", &args.rectification_model),
    ] {
        if let Some(path) = value {
            warn!(
                "{} was provided ({}) but is not applied by this example",
                flag,
                path.display()
            );
        }
    }

    // Keep only the inputs that exist on disk, reporting the missing ones.
    let existing_images: Vec<PathBuf> = args
        .images
        .iter()
        .filter(|path| {
            let exists = path.exists();
            if !exists {
                error!("Image not found: {}", path.display());
            }
            exists
        })
        .cloned()
        .collect();
    if existing_images.is_empty() {
        return Err("No valid input images provided".into());
    }

    let ort_config = parse_device_config(&args.device)?;

    let det_config = TextDetectionConfig {
        score_threshold: args.det_score_thresh,
        box_threshold: args.det_box_thresh,
        unclip_ratio: args.det_unclip,
        max_candidates: args.det_max_candidates,
        limit_side_len: args.det_limit_side_len,
        // Case-insensitive parse of the optional limit type; an unknown value
        // aborts the run with an InvalidInput error.
        limit_type: args
            .det_limit_type
            .as_deref()
            .map(|s| match s.to_lowercase().as_str() {
                "min" => Ok(LimitType::Min),
                "max" => Ok(LimitType::Max),
                "resize_long" | "resizelong" | "resize-long" => Ok(LimitType::ResizeLong),
                other => Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    format!("Invalid --det-limit-type value: '{}'", other),
                )),
            })
            .transpose()?,
        max_side_len: args.det_max_side_len,
    };
    let rec_config = TextRecognitionConfig {
        score_threshold: args.rec_score_thresh,
        max_text_length: args.rec_max_text_length,
    };

    let mut builder = OAROCRBuilder::new(&args.det_model, &args.rec_model, &args.dict_path)
        .text_detection_config(det_config)
        .text_recognition_config(rec_config)
        .return_word_box(args.return_word_box);
    // `ort_config` is not used again, so it is moved rather than cloned.
    if let Some(config) = ort_config {
        builder = builder.ort_session(config);
    }
    if let Some(size) = args.image_batch_size {
        builder = builder.image_batch_size(size);
    }
    if let Some(size) = args.region_batch_size {
        builder = builder.region_batch_size(size);
    }

    let build_start = Instant::now();
    let ocr = builder.build()?;
    info!(
        "OCR pipeline built in {:.2}ms",
        build_start.elapsed().as_secs_f64() * 1000.0
    );

    // Record the path of every image that actually loaded, in the same order
    // as `images`. Results are later paired with these paths; pairing them
    // with `existing_images` would shift every result after a failed load
    // onto the wrong file name.
    let mut loaded_paths: Vec<&PathBuf> = Vec::with_capacity(existing_images.len());
    let mut images = Vec::with_capacity(existing_images.len());
    for path in &existing_images {
        match load_image(path) {
            Ok(img) => {
                info!(
                    "Loaded image {} ({}x{})",
                    path.display(),
                    img.width(),
                    img.height()
                );
                loaded_paths.push(path);
                images.push(img);
            }
            Err(err) => warn!("Failed to load {}: {}", path.display(), err),
        }
    }
    if images.is_empty() {
        return Err("No images could be loaded".into());
    }

    let start = Instant::now();
    let results = ocr.predict(images)?;
    info!(
        "OCR completed in {:.2}ms",
        start.elapsed().as_secs_f64() * 1000.0
    );

    info!("\n=== OCR Results ===");
    for (idx, (path, result)) in loaded_paths.iter().zip(results.iter()).enumerate() {
        info!("\nImage {}: {}", idx + 1, path.display());
        if let Some(angle) = result.orientation_angle {
            info!(" Overall image orientation: {} degrees", angle);
        }
        info!(" {} text regions", result.text_regions.len());

        for (region_idx, region) in result.text_regions.iter().enumerate() {
            let bbox = &region.bounding_box;
            let text = region
                .text
                .as_ref()
                .map(|t| t.to_string())
                .unwrap_or_else(|| "<no text>".to_string());
            let score = region.confidence.unwrap_or(0.0) * 100.0;
            let line_orientation = region
                .orientation_angle
                .map_or_else(|| "N/A".to_string(), |a| format!("{:.1}°", a));
            info!(
                " [{}] \"{}\" ({:.1}%) at [{:.1},{:.1}] - [{:.1},{:.1}] (Line Orientation: {})",
                region_idx + 1,
                text,
                score,
                bbox.x_min(),
                bbox.y_min(),
                bbox.x_max(),
                bbox.y_max(),
                line_orientation
            );

            if let Some(word_boxes) = &region.word_boxes {
                info!(" {} word boxes", word_boxes.len());
                if let Some(full_text) = region.text.as_ref() {
                    // Pair the i-th box with the i-th character of the text;
                    // `zip` stops at the shorter side, so boxes without a
                    // matching character are not printed.
                    for (i, (word_bbox, char_content)) in
                        word_boxes.iter().zip(full_text.chars()).enumerate()
                    {
                        info!(
                            " Word Box {}: '{}' at [{:.1},{:.1}] - [{:.1},{:.1}]",
                            i + 1,
                            char_content,
                            word_bbox.x_min(),
                            word_bbox.y_min(),
                            word_bbox.x_max(),
                            word_bbox.y_max()
                        );
                    }
                }
            }
        }
    }

    if let Some(output_dir) = vis_output_dir {
        std::fs::create_dir_all(output_dir)?;

        // Prefer the user-supplied font; fall back to a system font when it
        // cannot be loaded.
        let vis_config = if let Some(font_path) = args.vis_font_path {
            VisualizationConfig::with_font_path(&font_path).unwrap_or_else(|err| {
                warn!(
                    "Failed to load font {}: {}. Falling back to system font.",
                    font_path.display(),
                    err
                );
                VisualizationConfig::with_system_font()
            })
        } else {
            VisualizationConfig::with_system_font()
        };

        info!("\nSaving visualizations to: {}", output_dir.display());
        for (path, result) in loaded_paths.iter().zip(results.iter()) {
            let vis_img = create_ocr_visualization(result, &vis_config)?;
            // The output keeps the input's file name; inputs that share a
            // file name therefore write to the same output file.
            let filename = path
                .file_name()
                .and_then(|s| s.to_str())
                .map(|s| s.to_string())
                .unwrap_or_else(|| "visualization.jpg".to_string());
            let output_path = output_dir.join(filename);
            vis_img.save(&output_path)?;
            info!(" Saved: {}", output_path.display());
        }
    }

    Ok(())
}