#![allow(clippy::unwrap_used)]
use captcha_engine::CaptchaModel;
use std::env;
use std::fs;
use std::path::Path;
/// Minimum accuracy (in percent) the model must reach for this check to pass.
const MIN_ACCURACY: f64 = 90.0;

/// File extensions (compared case-insensitively) that are treated as test captcha images.
const IMAGE_EXTENSIONS: [&str; 3] = ["png", "jpg", "jpeg"];

/// Evaluates a captcha model against the labelled images in `test-captcha/`.
///
/// Usage: `test_model [model_path]`. With no argument, the bundled default model is used
/// if it exists; otherwise usage is printed and the process exits with status 1.
///
/// Each image's expected label is its file stem with an optional `captcha-` prefix
/// removed, lowercased. A per-image table is printed, followed by overall accuracy and a
/// list of failed predictions.
///
/// # Errors
///
/// Returns an error if the `test-captcha` directory is missing or unreadable, contains no
/// images, or if the model fails to load. Exits with status 1 when accuracy is below
/// [`MIN_ACCURACY`].
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let model_path = resolve_model_path();

    let test_dir = Path::new("test-captcha");
    if !test_dir.is_dir() {
        return Err(format!("test-captcha directory not found at {}", test_dir.display()).into());
    }

    let images = collect_images(test_dir)?;
    // Without this check an empty set is reported as "0.0% below threshold", which hides
    // the real problem.
    if images.is_empty() {
        return Err(format!(
            "no .png/.jpg/.jpeg images found in {}",
            test_dir.display()
        )
        .into());
    }

    println!("Loading model from: {model_path}");
    let model = CaptchaModel::load(model_path)?;
    println!("Model loaded successfully!\n");

    let total = images.len();
    println!("Testing {total} captcha images...\n");
    println!("{:30} {:15} {:15} Result", "File", "Expected", "Predicted");
    println!("{}", "-".repeat(70));

    let mut correct: usize = 0;
    // (file name, expected label, predicted label) for every wrong or failed prediction.
    let mut failures: Vec<(String, String, String)> = Vec::new();

    for path in &images {
        let name = display_name(path);
        let expected = expected_label(path);
        match model.predict_file(path) {
            Ok(predicted) => {
                let is_correct = predicted == expected;
                let status = if is_correct { "✅" } else { "❌" };
                println!("{name:30} {expected:15} {predicted:15} {status}");
                if is_correct {
                    correct += 1;
                } else {
                    failures.push((name, expected, predicted));
                }
            }
            Err(e) => {
                let shown = format!("ERROR: {e}");
                println!("{name:30} {expected:15} {shown:15} ❌");
                failures.push((name, expected, "ERROR".to_string()));
            }
        }
    }

    println!("\n{}", "=".repeat(70));
    let accuracy = accuracy_percent(correct, total);
    println!("Results: {correct}/{total} correct ({accuracy:.1}% accuracy)");

    if !failures.is_empty() {
        println!("\n❌ Failed predictions:");
        for (name, expected, predicted) in &failures {
            println!(" {name} → expected '{expected}', got '{predicted}'");
        }
    }

    if accuracy < MIN_ACCURACY {
        eprintln!("Accuracy {accuracy:.1}% is below threshold {MIN_ACCURACY:.1}%");
        std::process::exit(1);
    }
    Ok(())
}

/// Chooses the model path: the first CLI argument, else the bundled default if present.
///
/// Prints usage to stderr and exits with status 1 when neither is available.
fn resolve_model_path() -> String {
    const DEFAULT_MODEL_PATH: &str = "crates/captcha-engine/assets/captcha.rten";

    let mut args = env::args();
    // argv[0] is not guaranteed to exist, so avoid indexing it directly.
    let program = args.next().unwrap_or_else(|| "test_model".to_string());
    if let Some(path) = args.next() {
        return path;
    }
    if Path::new(DEFAULT_MODEL_PATH).exists() {
        println!("Using default model: {DEFAULT_MODEL_PATH}");
        return DEFAULT_MODEL_PATH.to_string();
    }
    eprintln!("Usage: {program} <model_path>");
    eprintln!("Example: cargo run --bin test_model -- training/captcha.rten");
    eprintln!("Default model at '{DEFAULT_MODEL_PATH}' not found.");
    std::process::exit(1)
}

/// Returns every image file in `dir` (see [`IMAGE_EXTENSIONS`]), sorted by path.
///
/// Directory entries that cannot be read are skipped.
///
/// # Errors
///
/// Propagates the error from [`fs::read_dir`] if `dir` cannot be listed.
fn collect_images(dir: &Path) -> std::io::Result<Vec<std::path::PathBuf>> {
    let mut images: Vec<std::path::PathBuf> = fs::read_dir(dir)?
        .filter_map(Result::ok)
        .map(|entry| entry.path())
        .filter(|path| is_image(path))
        .collect();
    // Sorting the owned paths once avoids re-allocating a PathBuf on every comparison.
    images.sort_unstable();
    Ok(images)
}

/// Whether `path` has one of the [`IMAGE_EXTENSIONS`], ignoring ASCII/Unicode case.
fn is_image(path: &Path) -> bool {
    path.extension()
        .map(|ext| ext.to_string_lossy().to_lowercase())
        .is_some_and(|ext| IMAGE_EXTENSIONS.contains(&ext.as_str()))
}

/// Derives the ground-truth label from a file name.
///
/// The label is the stem with an optional `captcha-` prefix removed, then lowercased.
/// For example, `captcha-AB12.png` becomes `ab12`.
fn expected_label(path: &Path) -> String {
    let stem = path.file_stem().unwrap_or_default().to_string_lossy();
    stem.strip_prefix("captcha-").unwrap_or(&stem).to_lowercase()
}

/// File name used in the report, falling back to the whole path if it has no file name.
fn display_name(path: &Path) -> String {
    path.file_name()
        .unwrap_or(path.as_os_str())
        .to_string_lossy()
        .into_owned()
}

/// Percentage of correct predictions out of `total`; `0.0` when `total` is zero.
// The counts are numbers of test images, far below 2^53, so the usize -> f64 casts are exact.
#[allow(clippy::cast_precision_loss)]
fn accuracy_percent(correct: usize, total: usize) -> f64 {
    if total == 0 {
        0.0
    } else {
        correct as f64 / total as f64 * 100.0
    }
}