use std::collections::HashMap;
use std::fs;
use std::path::Path;
use std::thread;
use std::time::Instant;
use anyhow::Context;
use clap::Parser;
use quasivision::config::Config;
use quasivision::element;
use quasivision::export;
use quasivision::object_detector;
use quasivision::pipeline;
use quasivision::text_detection;
// Command-line arguments. Boolean options use `ArgAction::Set`, so they take an
// explicit value (e.g. `--ocr false`) instead of acting as presence flags.
#[derive(Parser, Debug)]
#[command(name = "quasivision", version, about)]
struct Args {
    /// Image file to process, or a directory whose images are processed in batch mode
    #[arg(short, long)]
    input: String,
    /// Root output directory; each image's results go in a subdirectory named after its file stem
    #[arg(short, long, default_value = "output")]
    output: String,
    /// Gradient threshold (Config::gradient_threshold)
    #[arg(long, default_value_t = 4)]
    gradient: u8,
    /// Minimum object area (Config::obj_min_area)
    #[arg(long, default_value_t = 55)]
    min_area: u32,
    /// Enable the pipeline's paragraph option
    #[arg(long, action = clap::ArgAction::Set, default_value_t = false)]
    paragraph: bool,
    /// Enable the pipeline's remove_bar option
    #[arg(long, action = clap::ArgAction::Set, default_value_t = true)]
    remove_bar: bool,
    /// Enable the pipeline's sub_component option
    #[arg(long, action = clap::ArgAction::Set, default_value_t = true)]
    sub_component: bool,
    /// Run text detection (OCR) on a background thread
    #[arg(long, action = clap::ArgAction::Set, default_value_t = true)]
    ocr: bool,
    /// Enable the pipeline's synthesize_text option
    #[arg(long, action = clap::ArgAction::Set, default_value_t = false)]
    synthesize_text: bool,
    /// In batch mode, also search subdirectories for images
    #[arg(long, action = clap::ArgAction::Set, default_value_t = false)]
    recursive: bool,
    /// Comma-separated list of file extensions to include in batch mode
    #[arg(long, default_value = "png,jpg,jpeg,jfif")]
    extensions: String,
    /// Line thickness (Config::line_thickness)
    #[arg(long, default_value_t = 8)]
    line_thickness: u32,
    /// Minimum line length ratio (Config::line_min_length_ratio)
    #[arg(long, default_value_t = 0.95)]
    line_min_length: f64,
    /// Minimum evenness (Config::rec_min_evenness)
    #[arg(long, default_value_t = 0.7)]
    rec_evenness: f64,
    /// Maximum dent ratio (Config::rec_max_dent_ratio)
    #[arg(long, default_value_t = 0.25)]
    rec_dent: f64,
    /// Corner skip ratio (Config::rec_corner_skip_ratio)
    #[arg(long, default_value_t = 0.08)]
    rec_corner_skip: f64,
    /// Block side length (Config::block_side_length)
    #[arg(long, default_value_t = 0.15)]
    block_side: f64,
    /// Block gradient threshold (Config::block_gradient_threshold)
    #[arg(long, default_value_t = 5)]
    block_grad: u8,
    /// Maximum text height (Config::text_max_height)
    #[arg(long, default_value_t = 0.08)]
    text_max_h: f64,
    /// Maximum gap between words (Config::text_max_word_gap)
    #[arg(long, default_value_t = 10)]
    text_gap: u32,
    /// Classify detected icons (skipped when no icons are found)
    #[arg(long, action = clap::ArgAction::Set, default_value_t = true)]
    icon_classify: bool,
    /// Run object detection on a background thread
    #[arg(long, action = clap::ArgAction::Set, default_value_t = true)]
    object_detect: bool,
    /// Confidence threshold passed to object detection
    #[arg(long, default_value_t = 0.2)]
    detect_conf: f32,
    /// Model directory; object-detection models are read from its object-detection/ subdirectory
    #[arg(long, default_value = "resources")]
    models_dir: String,
}
/// Settings derived from the command line, shared by every `run_pipeline`
/// call (one call per image in batch mode).
struct RunOptions {
    /// Detector parameters; handed to the pipeline as its `ui_config`.
    cfg: Config,
    /// Directory models are initialised from and model paths are built under.
    models_dir: String,
    /// Confidence threshold forwarded to object detection.
    detect_conf: f32,

    // Pipeline behaviour toggles (copied into `PipelineConfig`).
    is_paragraph: bool,
    is_remove_bar: bool,
    enable_sub_component: bool,
    enable_synthesize: bool,

    // Optional stages that can be switched off from the CLI.
    enable_ocr: bool,
    enable_icon_classify: bool,
    enable_object_detect: bool,
}
impl From<&Args> for RunOptions {
fn from(args: &Args) -> Self {
let cfg = Config {
gradient_threshold: args.gradient,
obj_min_area: args.min_area,
rec_min_evenness: args.rec_evenness,
rec_max_dent_ratio: args.rec_dent,
rec_corner_skip_ratio: args.rec_corner_skip,
line_thickness: args.line_thickness,
line_min_length_ratio: args.line_min_length,
text_max_word_gap: args.text_gap,
text_max_height: args.text_max_h,
block_side_length: args.block_side,
block_gradient_threshold: args.block_grad,
..Config::default()
};
Self {
cfg,
is_paragraph: args.paragraph,
is_remove_bar: args.remove_bar,
enable_sub_component: args.sub_component,
enable_ocr: args.ocr,
enable_synthesize: args.synthesize_text,
enable_icon_classify: args.icon_classify,
enable_object_detect: args.object_detect,
detect_conf: args.detect_conf,
models_dir: args.models_dir.clone(),
}
}
}
/// Entry point: processes a single image, or every matching image in a
/// directory (batch mode).
///
/// # Errors
///
/// Returns an error (and therefore a non-zero exit status) when:
/// - the input path is neither a file nor a directory;
/// - the single image fails;
/// - the directory cannot be listed;
/// - any image in a batch fails.
///
/// In batch mode each failure is reported as it happens and processing
/// continues. The error is only returned after the summary has been printed.
fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    let opts = RunOptions::from(&args);
    let input_path = Path::new(&args.input);
    if input_path.is_dir() {
        println!("[Batch Mode] Processing directory: {}", args.input);
        let extensions: Vec<&str> = args.extensions.split(',').map(|s| s.trim()).collect();
        let entries = collect_images(input_path, &extensions, args.recursive)?;
        if entries.is_empty() {
            println!("[Batch] No images found in: {}", args.input);
            return Ok(());
        }
        println!("[Batch] Found {} images to process", entries.len());
        let mut success_count = 0usize;
        let mut fail_count = 0usize;
        for entry in &entries {
            println!("\n{} Processing: {}", "=".repeat(60), entry);
            match run_pipeline(entry, &args.output, &opts) {
                Ok(()) => success_count += 1,
                Err(e) => {
                    // `{:#}` prints the full anyhow context chain, so the root
                    // cause (e.g. the underlying I/O error) is not hidden.
                    eprintln!("[ERROR] Failed to process {}: {:#}", entry, e);
                    fail_count += 1;
                }
            }
        }
        println!("\n{} Batch Complete {}", "=".repeat(25), "=".repeat(25));
        println!(
            " Total: {}, Success: {}, Failed: {}",
            entries.len(),
            success_count,
            fail_count
        );
        // Surface partial failure through the exit status so scripts can detect it.
        if fail_count > 0 {
            anyhow::bail!(
                "{} of {} images failed to process",
                fail_count,
                entries.len()
            );
        }
    } else if input_path.is_file() {
        run_pipeline(&args.input, &args.output, &opts)?;
    } else {
        // `is_dir`/`is_file` also return false when the metadata cannot be read
        // (e.g. permissions), or when the path is some other kind of file.
        anyhow::bail!(
            "Input path does not exist or is not a file/directory: {}",
            args.input
        );
    }
    Ok(())
}
fn collect_images(dir: &Path, extensions: &[&str], recursive: bool) -> anyhow::Result<Vec<String>> {
let mut images = Vec::new();
for entry in
fs::read_dir(dir).with_context(|| format!("Failed to read directory: {}", dir.display()))?
{
let entry = entry?;
let path = entry.path();
if path.is_file() {
if let Some(ext) = path.extension() {
let ext_lower = ext.to_str().unwrap_or("").to_lowercase();
if extensions.contains(&ext_lower.as_str()) {
images.push(path.to_string_lossy().to_string());
}
}
} else if recursive && path.is_dir() {
let sub_images = collect_images(&path, extensions, true)?;
images.extend(sub_images);
}
}
images.sort();
Ok(images)
}
fn run_pipeline(img_path: &str, output_root: &str, opts: &RunOptions) -> anyhow::Result<()> {
let t0 = Instant::now();
let models_dir = opts.models_dir.trim_end_matches('/');
let pipeline_cfg = pipeline::PipelineConfig {
ui_config: opts.cfg.clone(),
models_dir: opts.models_dir.clone(),
paragraph: opts.is_paragraph,
remove_bar: opts.is_remove_bar,
sub_component: opts.enable_sub_component,
synthesize_text: opts.enable_synthesize,
detect_model_path: format!("{}/object-detection/yoloe-26n-seg.onnx", models_dir),
detect_labels_path: format!("{}/object-detection/yoloe-26n_classes.txt", models_dir),
detect_conf: opts.detect_conf,
};
if let Err(e) = quasivision::init_models(models_dir) {
eprintln!(" [Warning] Model init failed (partial functionality):\n {e}");
}
let t_step = Instant::now();
println!("[Step 1/8] Reading image: {}", img_path);
let (img, _gray) = pipeline_cfg
.read_image(img_path)
.with_context(|| format!("Failed to read image: {}", img_path))?;
let img_shape = (img.height(), img.width());
println!(
" → Image size: {} x {} ({:.1}ms)",
img_shape.1,
img_shape.0,
t_step.elapsed().as_secs_f64() * 1000.0
);
let ocr_handle = if opts.enable_ocr {
let img_for_ocr = img.clone();
println!(" → OCR thread spawned (running in background)");
Some(thread::spawn(move || {
text_detection::detect_text(&img_for_ocr)
}))
} else {
None
};
let object_detect_handle = if opts.enable_object_detect {
let img_for_detect = img.clone();
let model = pipeline_cfg.detect_model_path.clone();
let labels = pipeline_cfg.detect_labels_path.clone();
let conf = opts.detect_conf;
println!(" → Object detection thread spawned (running in background)");
Some(thread::spawn(move || {
object_detector::run_object_detection(&img_for_detect, &model, &labels, conf)
}))
} else {
None
};
let t_step = Instant::now();
println!("[Step 2-4/8] Component detection & classification...");
let comps = pipeline_cfg.detect_components(&img)?;
let mut class_counts: HashMap<String, u32> = HashMap::new();
for c in &comps {
*class_counts.entry(c.category.clone()).or_insert(0) += 1;
}
for (k, v) in &class_counts {
println!(" → {}: {}", k, v);
}
println!(
" → Total: {} components ({:.1}ms)",
comps.len(),
t_step.elapsed().as_secs_f64() * 1000.0
);
let t_step = Instant::now();
println!("[Step 5/8] Text detection...");
let text_result = if let Some(handle) = ocr_handle {
let result = handle.join().expect("OCR thread panicked");
println!(
" → OCR result ready (waited {:.1}ms)",
t_step.elapsed().as_secs_f64() * 1000.0
);
result
} else {
println!(" → OCR disabled by user");
text_detection::TextResult { texts: Vec::new() }
};
println!(" → Found {} text elements", text_result.texts.len());
let t_step = Instant::now();
println!("[Step 6/8] Merging components and texts...");
let mut elements = pipeline_cfg.merge(&img, &comps, &text_result)?;
println!(
" → Final: {} elements ({:.1}ms)",
elements.len(),
t_step.elapsed().as_secs_f64() * 1000.0
);
let t_sub = Instant::now();
element::compute_prominence(&mut elements);
let prominent_count = elements
.iter()
.filter(|e| e.prominence.map_or(false, |p| p >= 0.5))
.count();
println!(
" → 6b2/8. Prominence computed: {} prominent (≥0.5), {} total ({:.1}ms)",
prominent_count,
elements.len(),
t_sub.elapsed().as_secs_f64() * 1000.0
);
if opts.enable_icon_classify {
let t_sub = Instant::now();
let icon_count = elements.iter().filter(|e| e.class == "Icon").count();
if icon_count > 0 {
match pipeline_cfg.classify_icons(&img, &mut elements) {
Ok(_) => {
println!(
" → 6c/8. Icon classification done ({:.1}ms)",
t_sub.elapsed().as_secs_f64() * 1000.0
);
}
Err(e) => {
eprintln!(
" → 6c/8. Failed to initialize IconClassifier: {} (skipping)",
e
);
}
}
} else {
println!(
" → 6c/8. No icons to classify ({:.1}ms)",
t_sub.elapsed().as_secs_f64() * 1000.0
);
}
}
let object_detections = if let Some(handle) = object_detect_handle {
let t_sub = Instant::now();
println!(" → 6d/8. Waiting for object detection...");
let detections = handle.join().expect("Object detection thread panicked");
println!(
" → 6d/8. Object detection: {} objects found ({:.1}ms)",
detections.len(),
t_sub.elapsed().as_secs_f64() * 1000.0
);
detections
} else {
Vec::new()
};
let t_step = Instant::now();
println!("[Step 8/8] Exporting results...");
let img_name = Path::new(img_path)
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("output");
let output_dir = Path::new(output_root).join(img_name);
fs::create_dir_all(&output_dir)?;
{
let json_path = output_dir.join("elements.tree.json");
export::save_tree_json(&elements, img_shape, json_path.to_str().unwrap())?;
let txt_path = output_dir.join("elements.tree.txt");
export::save_tree_text(&elements, img_shape, txt_path.to_str().unwrap())?;
}
let detection_roots = object_detector::build_detection_tree(&object_detections);
if !detection_roots.is_empty() {
let det_tree_path = output_dir.join("objects.tree.json");
export::save_detection_tree_json(
&detection_roots,
img_shape,
det_tree_path.to_str().unwrap(),
)?;
let det_txt_path = output_dir.join("objects.tree.txt");
export::save_detection_tree_text(
&detection_roots,
img_shape,
det_txt_path.to_str().unwrap(),
)?;
} else {
let det_txt_path = output_dir.join("objects.tree.txt");
export::save_detection_tree_text(
&detection_roots,
img_shape,
det_txt_path.to_str().unwrap(),
)?;
}
let vis_path = output_dir.join("visualization.jpg");
export::save_visualization(&img, &elements, vis_path.to_str().unwrap())?;
if !detection_roots.is_empty() {
let det_vis_path = output_dir.join("objects.jpg");
export::save_object_detection_visualization(
&img,
&detection_roots,
det_vis_path.to_str().unwrap(),
)?;
}
println!(
" → Export done ({:.1}ms)",
t_step.elapsed().as_secs_f64() * 1000.0
);
let mut compo_count = 0;
let mut text_count = 0;
let mut block_count = 0;
let mut btn_count = 0;
let mut img_count = 0;
let mut icon_count = 0;
for e in &elements {
match e.class.as_str() {
"Text" => text_count += 1,
"Block" => block_count += 1,
"Button" => btn_count += 1,
"Image" => img_count += 1,
"Icon" => icon_count += 1,
_ => compo_count += 1,
}
}
println!();
println!("=== Result Summary ===");
println!(" Blocks: {}", block_count);
println!(" Buttons: {}", btn_count);
println!(" Icons: {}", icon_count);
println!(" Images: {}", img_count);
println!(" Texts: {}", text_count);
println!(" Components: {}", compo_count);
println!(" Total: {}", elements.len());
if !object_detections.is_empty() {
println!(" Objects detected: {}", object_detections.len());
}
println!(" Output: {}", output_dir.display());
println!(" ⏱ Total time: {:.1}s", t0.elapsed().as_secs_f64());
Ok(())
}