use std::sync::Arc;
use std::time::Duration;
use image::DynamicImage;
use imageproc::template_matching::{match_template, MatchTemplateMethod};
use crate::atspi::Rect;
use crate::backend::PointerButton;
use crate::error::{Error, Result};
use crate::session::Session;
/// Minimum match score a freshly constructed `ImageLocator` requires; it stays in effect
/// until `ImageLocator::with_threshold` replaces it.
const DEFAULT_THRESHOLD: f32 = 0.85;
/// Locates an on-screen element by matching a reference (template) image against a
/// screenshot taken through the session.
///
/// The session and the decoded template are both held behind `Arc`, so a clone shares them
/// with the original instead of copying the pixel data.
#[derive(Clone)]
pub struct ImageLocator {
/// Session used to take screenshots and to deliver pointer input.
pub(crate) session: Arc<Session>,
/// Decoded template image in 8-bit RGB.
template_rgb: Arc<image::RgbImage>,
/// Optional screen rectangle the search is restricted to; `None` searches the whole screenshot.
region: Option<Rect>,
/// Minimum match score for a position to count as a hit.
threshold: f32,
/// Per-locator wait timeout; `None` falls back to the session's default timeout.
timeout: Option<Duration>,
}
/// Debug output summarises the locator: the template is shown by its dimensions only and
/// the session is left out entirely.
impl std::fmt::Debug for ImageLocator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Render the template as "<width>x<height>" instead of dumping its pixels.
        let template_size = format!(
            "{}x{}",
            self.template_rgb.width(),
            self.template_rgb.height()
        );

        let mut out = f.debug_struct("ImageLocator");
        out.field("kind", &"image-template");
        out.field("template_size", &template_size);
        out.field("threshold", &self.threshold);
        out.field("region", &self.region);
        out.field("timeout", &self.timeout);
        out.finish()
    }
}
impl ImageLocator {
pub(crate) fn new(
session: Arc<Session>,
png_bytes: &[u8],
region: Option<Rect>,
) -> Result<Self> {
let img = image::load_from_memory(png_bytes)
.map_err(|e| Error::visual(format!("decode template image: {e}")))?;
let rgb = img.into_rgb8();
let (w, h) = rgb.dimensions();
if w == 0 || h == 0 {
return Err(Error::visual("template image is empty"));
}
Ok(Self {
session,
template_rgb: Arc::new(rgb),
region,
threshold: DEFAULT_THRESHOLD,
timeout: None,
})
}
pub fn with_threshold(mut self, threshold: f32) -> Self {
self.threshold = threshold.clamp(0.0, 1.0);
self
}
pub fn with_timeout(mut self, timeout: Duration) -> Self {
self.timeout = Some(timeout);
self
}
pub fn within(mut self, region: Rect) -> Self {
self.region = Some(region);
self
}
pub async fn matches(&self) -> Result<Vec<Rect>> {
let png = self.session.take_screenshot().await?;
let region = self.region;
let template = self.template_rgb.clone();
let threshold = self.threshold;
tokio::task::spawn_blocking(move || -> Result<Vec<Rect>> {
let full = crate::locator::decode_screenshot_png(&png)
.map_err(|e| Error::visual(format!("decode screenshot: {e}")))?;
let (haystack_rgb, origin_x, origin_y) = if let Some(scope) = region {
let cropped = crate::locator::crop_to_bounds(full, scope)
.map_err(|e| Error::visual(format!("crop to region: {e}")))?;
(cropped.into_rgb8(), scope.x.max(0), scope.y.max(0))
} else {
(full.into_rgb8(), 0, 0)
};
find_template_matches(&haystack_rgb, &template, threshold, origin_x, origin_y)
})
.await
.map_err(|e| Error::visual(format!("template-matching task panicked: {e}")))?
}
pub async fn count(&self) -> Result<usize> {
Ok(self.matches().await?.len())
}
pub async fn bounds(&self) -> Result<Rect> {
let deadline = std::time::Instant::now()
+ self
.timeout
.unwrap_or_else(|| self.session.default_timeout());
loop {
let hits = self.matches().await?;
match hits.len() {
0 => {
if std::time::Instant::now() >= deadline {
return Err(Error::ElementNotFound {
xpath: format!("image-template (threshold={})", self.threshold),
});
}
tokio::time::sleep(Duration::from_millis(200)).await;
}
1 => return Ok(hits[0]),
n => {
return Err(Error::visual(format!(
"found {n} image-template matches at threshold {}; \
scope with .within(rect) or raise the threshold to disambiguate",
self.threshold,
)));
}
}
}
}
pub async fn wait_for_visible(&self) -> Result<()> {
let deadline = std::time::Instant::now()
+ self
.timeout
.unwrap_or_else(|| self.session.default_timeout());
loop {
if !self.matches().await?.is_empty() {
return Ok(());
}
if std::time::Instant::now() >= deadline {
return Err(Error::ElementNotFound {
xpath: format!("image-template (threshold={})", self.threshold),
});
}
tokio::time::sleep(Duration::from_millis(200)).await;
}
}
pub async fn click(&self) -> Result<()> {
let r = self.bounds().await?;
let (cx, cy) = (r.center_x() as f64, r.center_y() as f64);
tracing::debug!(?r, cx, cy, "image: click");
super::cold_start_click(&self.session, cx, cy).await
}
pub async fn hover(&self) -> Result<()> {
let r = self.bounds().await?;
self.session
.pointer_motion_absolute(r.center_x() as f64, r.center_y() as f64)
.await?;
Ok(())
}
}
/// Finds the positions in `haystack_rgb` where `template_rgb` matches with a normalized
/// cross-correlation score of at least `threshold`.
///
/// Both images are converted to 8-bit grayscale before matching. Candidates are visited
/// from the highest score down. A candidate is dropped when an already accepted position
/// lies within half of the template's smaller side (at least 1 px) on both axes, so a
/// cluster of neighbouring peaks collapses into its best one.
///
/// Each returned rectangle has the template's size; its top-left corner is the match
/// position inside the haystack shifted by (`origin_x`, `origin_y`).
///
/// # Errors
///
/// Returns a visual error when the template has a zero dimension or does not fit inside
/// the haystack.
fn find_template_matches(
    haystack_rgb: &image::RgbImage,
    template_rgb: &image::RgbImage,
    threshold: f32,
    origin_x: i32,
    origin_y: i32,
) -> Result<Vec<Rect>> {
    let (hw, hh) = haystack_rgb.dimensions();
    let (tw, th) = template_rgb.dimensions();
    if tw == 0 || th == 0 {
        return Err(Error::visual("template has zero width or height"));
    }
    if tw > hw || th > hh {
        return Err(Error::visual(format!(
            "template ({tw}x{th}) larger than search area ({hw}x{hh})"
        )));
    }

    let to_gray = |rgb: &image::RgbImage| DynamicImage::ImageRgb8(rgb.clone()).into_luma8();
    let scores = match_template(
        &to_gray(haystack_rgb),
        &to_gray(template_rgb),
        MatchTemplateMethod::CrossCorrelationNormalized,
    );

    // Collect every position scoring at or above the threshold, in row-major order.
    let mut candidates: Vec<(f32, i32, i32)> = scores
        .enumerate_pixels()
        .map(|(x, y, px)| (px[0], x as i32, y as i32))
        .filter(|&(score, _, _)| score >= threshold)
        .collect();
    // Best score first; the stable sort keeps row-major order among equal scores.
    candidates.sort_by(|lhs, rhs| {
        rhs.0
            .partial_cmp(&lhs.0)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let nms_radius = (tw.min(th) as i32 / 2).max(1);
    let mut accepted: Vec<(i32, i32)> = Vec::new();
    let mut rects: Vec<Rect> = Vec::new();
    for (_, x, y) in candidates {
        let suppressed = accepted
            .iter()
            .any(|&(ax, ay)| (ax - x).abs() <= nms_radius && (ay - y).abs() <= nms_radius);
        if !suppressed {
            accepted.push((x, y));
            rects.push(Rect {
                x: x + origin_x,
                y: y + origin_y,
                width: tw as i32,
                height: th as i32,
            });
        }
    }
    Ok(rects)
}
// Anonymous re-import: references `PointerButton` without binding a new name.
// NOTE(review): presumably kept so the `PointerButton` import at the top of this file is
// not reported as unused — confirm, and drop both lines if the type is no longer needed.
#[allow(unused_imports)]
use PointerButton as _;
#[cfg(test)]
mod tests {
    use super::*;
    use image::{Rgb, RgbImage};

    const WHITE: Rgb<u8> = Rgb([255, 255, 255]);

    /// Builds a `w` x `h` image filled with a single colour.
    fn solid(w: u32, h: u32, color: [u8; 3]) -> RgbImage {
        RgbImage::from_pixel(w, h, Rgb(color))
    }

    /// Builds a solid image with one full-width white row, giving the template structure.
    fn striped(w: u32, h: u32, color: [u8; 3], row: u32) -> RgbImage {
        let mut img = solid(w, h, color);
        for x in 0..w {
            img.put_pixel(x, row, WHITE);
        }
        img
    }

    /// Builds a deterministic patterned background and pastes `template` at (`x`, `y`).
    fn embed(haystack_w: u32, haystack_h: u32, template: &RgbImage, x: u32, y: u32) -> RgbImage {
        let mut hay = RgbImage::from_fn(haystack_w, haystack_h, |px, py| {
            let v = ((px.wrapping_mul(73) ^ py.wrapping_mul(31)) & 0xff) as u8;
            Rgb([v, v.wrapping_add(40), v.wrapping_add(80)])
        });
        for (dx, dy, pixel) in template.enumerate_pixels() {
            hay.put_pixel(x + dx, y + dy, *pixel);
        }
        hay
    }

    #[test]
    fn finds_exact_template_at_known_position() {
        let template = striped(20, 10, [50, 100, 200], 4);
        let haystack = embed(200, 100, &template, 73, 41);

        let hits = find_template_matches(&haystack, &template, 0.95, 0, 0).expect("matching ok");

        assert_eq!(hits.len(), 1, "expected exactly one peak above threshold");
        let hit = hits[0];
        assert_eq!((hit.x, hit.y), (73, 41));
        assert_eq!((hit.width, hit.height), (20, 10));
    }

    #[test]
    fn returns_empty_when_template_not_present() {
        let template = striped(20, 10, [50, 100, 200], 4);
        let haystack = solid(200, 100, [10, 10, 10]);

        let hits = find_template_matches(&haystack, &template, 0.95, 0, 0).expect("matching ok");

        assert!(hits.is_empty(), "expected no peaks; got {hits:?}");
    }

    #[test]
    fn screen_coord_translation_via_origin() {
        // A white diagonal makes the template distinctive in both directions.
        let mut template = solid(8, 8, [200, 50, 50]);
        for i in 0..8 {
            template.put_pixel(i, i, WHITE);
        }
        let haystack = embed(100, 100, &template, 20, 30);

        let hits =
            find_template_matches(&haystack, &template, 0.95, 500, 600).expect("matching ok");

        assert_eq!(hits.len(), 1);
        assert_eq!((hits[0].x, hits[0].y), (520, 630));
    }

    #[test]
    fn errors_when_template_larger_than_haystack() {
        let template = solid(40, 40, [50, 100, 200]);
        let haystack = solid(20, 20, [50, 100, 200]);

        let err = find_template_matches(&haystack, &template, 0.5, 0, 0).unwrap_err();

        let msg = err.to_string();
        assert!(
            msg.contains("larger than search area"),
            "unexpected error: {msg}"
        );
    }

    #[test]
    fn nms_suppresses_neighbouring_peaks() {
        let template = striped(20, 20, [80, 80, 80], 10);
        let haystack = embed(200, 200, &template, 50, 50);

        let hits = find_template_matches(&haystack, &template, 0.95, 0, 0).expect("matching ok");

        assert_eq!(
            hits.len(),
            1,
            "NMS should collapse near-peaks; got {hits:?}"
        );
        assert_eq!((hits[0].x, hits[0].y), (50, 50));
    }
}