mod utils;
use clap::Parser;
use oar_ocr::predictors::FormulaRecognitionPredictor;
use oar_ocr::utils::load_image;
use std::path::PathBuf;
use std::time::Instant;
use tracing::{error, info, warn};
use utils::device_config::parse_device_config;
use utils::visualization::{ClassificationVisConfig, save_rgb_image, visualize_classification};
// Command-line arguments for the formula recognition example.
//
// The struct-level description is deliberately a plain comment: the explicit
// `about` attribute below stays the only text clap shows for the command.
// Field-level doc comments, in contrast, are picked up by clap as the help
// text of each argument.
#[derive(Parser)]
#[command(name = "formula_recognition")]
#[command(about = "Formula Recognition Example - recognizes mathematical formulas in images")]
struct Args {
    /// Path to the formula recognition model file
    #[arg(short, long)]
    model_path: PathBuf,

    /// Path to the tokenizer file used by the predictor
    #[arg(short, long)]
    tokenizer_path: PathBuf,

    /// Paths to the input images to process
    #[arg(required = true)]
    images: Vec<PathBuf>,

    /// Directory to save visualization results (required when --vis is set)
    #[arg(short, long)]
    output_dir: Option<PathBuf>,

    /// Save an annotated copy of every image for which a formula was recognized
    #[arg(long)]
    vis: bool,

    /// Device to use for inference (e.g. 'cpu')
    #[arg(long, default_value = "cpu")]
    device: String,

    /// Score threshold passed to the predictor
    #[arg(long, default_value = "0.0")]
    score_thresh: f32,

    // NOTE(review): main() only logs the target size in verbose mode; it is not
    // forwarded to the predictor builder. Confirm whether that is intended.
    /// Target image width override (0 = auto-detect from model input)
    #[arg(long, default_value = "0")]
    target_width: u32,

    /// Target image height override (0 = auto-detect from model input)
    #[arg(long, default_value = "0")]
    target_height: u32,

    // NOTE(review): main() only logs this value in verbose mode; it is not
    // forwarded to the predictor builder. Confirm whether that is intended.
    /// Maximum formula length
    #[arg(long, default_value = "1536")]
    max_length: usize,

    /// Model name passed to the predictor builder
    #[arg(long, default_value = "FormulaRecognition")]
    model_name: String,

    /// Log the configuration and per-image loading details
    #[arg(short, long)]
    verbose: bool,
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
utils::init_tracing();
let args = Args::parse();
info!("Formula Recognition Example");
if !args.model_path.exists() {
error!("Model file not found: {}", args.model_path.display());
return Err("Model file not found".into());
}
let existing_images: Vec<PathBuf> = args
.images
.iter()
.filter(|path| {
let exists = path.exists();
if !exists {
error!("Image file not found: {}", path.display());
}
exists
})
.cloned()
.collect();
if existing_images.is_empty() {
error!("No valid image files found");
return Err("No valid image files found".into());
}
info!("Using device: {}", args.device);
let ort_config = parse_device_config(&args.device)?.unwrap_or_default();
if ort_config.execution_providers.is_some() {
info!("CUDA execution provider configured successfully");
}
if args.verbose {
info!("Formula Recognition Configuration:");
info!(" Score threshold: {}", args.score_thresh);
info!(" Max formula length: {}", args.max_length);
if args.target_width > 0 && args.target_height > 0 {
info!(
" Target size override: {}x{}",
args.target_width, args.target_height
);
} else {
info!(" Target size: auto-detect from model input");
}
}
if args.verbose {
info!("Building formula recognition predictor...");
info!(" Model: {}", args.model_path.display());
info!(" Tokenizer: {}", args.tokenizer_path.display());
}
let predictor = FormulaRecognitionPredictor::builder()
.score_threshold(args.score_thresh)
.model_name(&args.model_name)
.tokenizer_path(&args.tokenizer_path)
.with_ort_config(ort_config)
.build(&args.model_path)?;
info!("Formula recognition predictor built successfully");
info!("Processing {} images...", existing_images.len());
let mut images = Vec::new();
for image_path in &existing_images {
match load_image(image_path) {
Ok(rgb_img) => {
if args.verbose {
info!(
"Loaded image: {} ({}x{})",
image_path.display(),
rgb_img.width(),
rgb_img.height()
);
}
images.push(rgb_img);
}
Err(e) => {
error!("Failed to load image {}: {}", image_path.display(), e);
continue;
}
}
}
if images.is_empty() {
error!("No images could be loaded for processing");
return Err("No images could be loaded".into());
}
info!("Running formula recognition...");
let start = Instant::now();
let output = predictor.predict(images.clone())?;
let duration = start.elapsed();
info!(
"Recognition completed in {:.2}ms",
duration.as_secs_f64() * 1000.0
);
info!("\n=== Formula Recognition Results ===");
for (idx, (image_path, formula, score)) in existing_images
.iter()
.zip(output.formulas.iter())
.zip(output.scores.iter())
.map(|((path, formula), score)| (path, formula, score))
.enumerate()
{
info!("\nImage {}: {}", idx + 1, image_path.display());
if formula.is_empty() {
warn!(" No formula recognized (below threshold or invalid)");
} else {
info!(" LaTeX: {}", formula);
if let Some(s) = score {
info!(" Confidence: {:.2}%", s * 100.0);
}
}
}
if args.vis {
let output_dir = args
.output_dir
.as_ref()
.ok_or("--output-dir is required when --vis is enabled")?;
std::fs::create_dir_all(output_dir)?;
info!("\nSaving visualizations to: {}", output_dir.display());
let vis_config = ClassificationVisConfig::default();
for (image_path, rgb_img, formula, score) in existing_images
.iter()
.zip(images.iter())
.zip(output.formulas.iter())
.zip(output.scores.iter())
.map(|(((path, img), formula), score)| (path, img, formula, score))
{
if !formula.is_empty() {
let output_filename = image_path
.file_name()
.and_then(|s| s.to_str())
.unwrap_or("unknown.jpg");
let output_path = output_dir.join(output_filename);
let display_formula = if let Some((idx, _)) = formula.char_indices().nth(50) {
format!("{}...", &formula[..idx])
} else {
formula.clone()
};
let confidence = score.unwrap_or(1.0);
let visualized = visualize_classification(
rgb_img,
&display_formula,
confidence,
"LaTeX",
&vis_config,
);
save_rgb_image(&visualized, &output_path)
.map_err(|e| format!("Failed to save visualization: {}", e))?;
info!(" Saved: {}", output_path.display());
} else {
warn!(
" Skipping visualization for {} (no formula recognized)",
image_path.display()
);
}
}
}
Ok(())
}