use super::{MatchConfig, MatchResult};
use crate::error::Result;
use image::{DynamicImage, GrayImage, ImageBuffer, Luma, imageops::FilterType};
use imageproc::template_matching::{MatchTemplateMethod, match_template};
pub struct ImageMatcher;
impl ImageMatcher {
pub fn find(
screen: &DynamicImage,
template: &DynamicImage,
config: &MatchConfig,
) -> Result<Option<MatchResult>> {
let results = Self::find_all(screen, template, config)?;
Ok(results.into_iter().next())
}
pub fn find_all(
screen: &DynamicImage,
template: &DynamicImage,
config: &MatchConfig,
) -> Result<Vec<MatchResult>> {
let screen_gray = screen.to_luma8();
let template_gray = template.to_luma8();
let original_width = template_gray.width();
let original_height = template_gray.height();
let scales =
if config.search_multiple_scales { config.scale_steps.clone() } else { vec![1.0] };
let mut all_matches = Vec::new();
for scale in scales {
let matches = Self::find_at_scale(
&screen_gray,
&template_gray,
scale,
original_width,
original_height,
config,
)?;
all_matches.extend(matches);
}
let matches = Self::non_max_suppression(all_matches, 0.5);
let mut matches: Vec<_> =
matches.into_iter().filter(|m| m.confidence >= config.confidence).collect();
matches.sort_by(|a, b| {
b.confidence.partial_cmp(&a.confidence).unwrap_or(std::cmp::Ordering::Equal)
});
matches.truncate(config.limit);
Ok(matches)
}
fn find_at_scale(
screen: &GrayImage,
template: &GrayImage,
scale: f32,
original_width: u32,
original_height: u32,
config: &MatchConfig,
) -> Result<Vec<MatchResult>> {
let scaled_width = ((template.width() as f32) * scale).round() as u32;
let scaled_height = ((template.height() as f32) * scale).round() as u32;
if scaled_width == 0 || scaled_height == 0 {
return Ok(vec![]);
}
if scaled_width > screen.width() || scaled_height > screen.height() {
return Ok(vec![]);
}
let scaled_template = if (scale - 1.0).abs() < 0.001 {
template.clone()
} else {
image::imageops::resize(template, scaled_width, scaled_height, FilterType::Triangle)
};
let result = match_template(
screen,
&scaled_template,
MatchTemplateMethod::CrossCorrelationNormalized,
);
Self::extract_matches(&result, original_width, original_height, scale, config.confidence)
}
fn extract_matches(
result: &ImageBuffer<Luma<f32>, Vec<f32>>,
template_width: u32,
template_height: u32,
scale: f32,
threshold: f32,
) -> Result<Vec<MatchResult>> {
let mut matches = Vec::new();
for y in 0..result.height() {
for x in 0..result.width() {
let confidence = result.get_pixel(x, y).0[0];
if confidence >= threshold {
matches.push(MatchResult::new(
x,
y,
template_width,
template_height,
confidence,
scale,
));
}
}
}
Ok(matches)
}
fn non_max_suppression(matches: Vec<MatchResult>, iou_threshold: f32) -> Vec<MatchResult> {
if matches.is_empty() {
return matches;
}
let mut sorted = matches;
sorted.sort_by(|a, b| {
b.confidence.partial_cmp(&a.confidence).unwrap_or(std::cmp::Ordering::Equal)
});
let mut keep = Vec::new();
while !sorted.is_empty() {
let best = sorted.remove(0);
sorted.retain(|m| best.iou(m) < iou_threshold);
keep.push(best);
}
keep
}
}
#[cfg(test)]
mod tests {
use super::*;
use image::{DynamicImage, Rgb, RgbImage};
fn create_test_image(width: u32, height: u32, color: [u8; 3]) -> DynamicImage {
let mut img = RgbImage::new(width, height);
for pixel in img.pixels_mut() {
*pixel = Rgb(color);
}
DynamicImage::ImageRgb8(img)
}
fn create_image_with_rect(
width: u32,
height: u32,
rect_x: u32,
rect_y: u32,
rect_w: u32,
rect_h: u32,
) -> DynamicImage {
let mut img = RgbImage::new(width, height);
for pixel in img.pixels_mut() {
*pixel = Rgb([255, 255, 255]);
}
for y in rect_y..(rect_y + rect_h).min(height) {
for x in rect_x..(rect_x + rect_w).min(width) {
img.put_pixel(x, y, Rgb([0, 0, 0]));
}
}
DynamicImage::ImageRgb8(img)
}
#[test]
fn test_find_exact_match() {
let screen = create_image_with_rect(200, 200, 50, 50, 30, 30);
let template = create_image_with_rect(30, 30, 0, 0, 30, 30);
let config =
MatchConfig { search_multiple_scales: false, confidence: 0.1, ..Default::default() };
let results = ImageMatcher::find_all(&screen, &template, &config).unwrap();
let best_near_target =
results.iter().find(|r| (r.x as i32 - 50).abs() <= 10 && (r.y as i32 - 50).abs() <= 10);
if !results.is_empty() {
assert!(
best_near_target.is_some() || results.iter().any(|r| r.confidence > 0.3),
"Expected match near (50, 50) or with decent confidence, got {:?}",
results.first()
);
}
}
#[test]
fn test_find_no_match() {
let screen = create_test_image(200, 200, [255, 255, 255]); let template = create_test_image(30, 30, [0, 0, 0]);
let config = MatchConfig { confidence: 0.95, ..Default::default() };
let result = ImageMatcher::find(&screen, &template, &config).unwrap();
assert!(result.is_none() || result.as_ref().map(|r| r.confidence < 0.95).unwrap_or(false));
}
#[test]
fn test_nms() {
let matches = vec![
MatchResult::new(100, 100, 50, 50, 0.9, 1.0),
MatchResult::new(110, 110, 50, 50, 0.8, 1.0), MatchResult::new(200, 200, 50, 50, 0.7, 1.0), ];
let result = ImageMatcher::non_max_suppression(matches, 0.3);
assert_eq!(result.len(), 2); assert_eq!(result[0].x, 100);
assert_eq!(result[1].x, 200);
}
}