mod baseline;
mod color;
mod engine;
mod models;
mod region;
mod template;
use std::sync::Arc;
use std::time::Duration;
use crate::atspi::Rect;
use crate::backend::PointerButton;
use crate::error::{Error, Result};
use crate::session::{Session, VisualTextTuning};
/// Left-click at screen coordinates `(cx, cy)` via the session's
/// `cold_start_click` path.
///
/// Visual locators route their clicks through this one helper so they all use
/// the same button and session entry point.
pub(crate) async fn cold_start_click(session: &Arc<Session>, cx: f64, cy: f64) -> Result<()> {
    let button = PointerButton::Left;
    session.cold_start_click(cx, cy, button).await
}
pub use baseline::{compare_to_baseline, diff_to_baseline, BaselineComparison};
pub(crate) use engine::{ensure_engine, EngineResult};
pub use region::{RegionLocator, Shape};
pub use template::ImageLocator;
pub(crate) use region::last_region_only as __region_last_only;
pub(crate) use region::region_at_seed as __region_at_seed;
pub(crate) use region::sweep_regions as __region_sweep;
pub(crate) use list_labelled_regions as __list_labelled_regions;
pub(crate) use list_text as __list_text;
pub(crate) use recognized_text as __recognized_text;
/// A block of OCR-recognised text together with its on-screen bounding box.
#[derive(Clone, Debug)]
pub struct TextHit {
    /// Recognised text of the block, words joined with single spaces.
    pub text: String,
    /// Union of the word rectangles of `text`, in screen coordinates.
    pub bounds: Rect,
}
/// How a [`VisualLocator`] compares its query against recognised text.
///
/// Every mode compares text after Unicode folding (NFKD, diacritics removed,
/// lower-cased).
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum MatchMode {
    /// The query occurs anywhere inside a text block; the default.
    #[default]
    Substring,
    /// The query equals an entire OCR line.
    Exact,
    /// A run of words within a small edit distance of the query.
    Fuzzy,
}
/// Locates on-screen text by OCR-ing a screenshot.
///
/// The builder-style methods (`within`, `with_timeout`, `with_match_mode`,
/// `with_upscale`) return modified copies and leave the receiver untouched.
#[derive(Clone)]
pub struct VisualLocator {
    /// Session providing screenshots, the OCR engine/cache, input and cancellation.
    session: Arc<Session>,
    /// The query text.
    text: String,
    /// Optional scope. OCR is cropped around it and matches must lie inside it.
    region: Option<Rect>,
    /// Budget for waiting operations; `None` falls back to `VISUAL_DEFAULT_TIMEOUT`.
    timeout: Option<Duration>,
    /// How `text` is compared against OCR output.
    match_mode: MatchMode,
    /// OCR upscale override; `None` uses the session's text tuning.
    upscale: Option<u32>,
}
impl std::fmt::Debug for VisualLocator {
    /// Shows the query parameters and omits the session handle.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            text,
            region,
            timeout,
            match_mode,
            upscale,
            ..
        } = self;
        let mut builder = f.debug_struct("VisualLocator");
        builder.field("kind", &"text-label");
        builder.field("text", text);
        builder.field("match_mode", match_mode);
        builder.field("region", region);
        builder.field("timeout", timeout);
        builder.field("upscale", upscale);
        builder.finish()
    }
}
/// Time budget for waiting visual operations when no explicit timeout is set
/// (two minutes).
const VISUAL_DEFAULT_TIMEOUT: Duration = Duration::from_millis(120_000);
impl VisualLocator {
    /// Create a locator for `text` with no scope, the default timeout,
    /// [`MatchMode::Substring`], and the session's OCR upscale factor.
    pub(crate) fn new(session: Arc<Session>, text: impl Into<String>) -> Self {
        Self {
            session,
            text: text.into(),
            region: None,
            timeout: None,
            match_mode: MatchMode::default(),
            upscale: None,
        }
    }

    /// Return a copy that OCRs around `region` and only accepts matches lying
    /// fully inside it.
    pub fn within(&self, region: Rect) -> VisualLocator {
        VisualLocator {
            region: Some(region),
            ..self.clone()
        }
    }

    /// Return a copy whose waiting operations give up after `timeout`.
    pub fn with_timeout(&self, timeout: Duration) -> VisualLocator {
        VisualLocator {
            timeout: Some(timeout),
            ..self.clone()
        }
    }

    /// Return a copy that compares text using `mode`.
    pub fn with_match_mode(&self, mode: MatchMode) -> VisualLocator {
        VisualLocator {
            match_mode: mode,
            ..self.clone()
        }
    }

    /// Return a copy that OCRs at `factor`× resolution instead of the session
    /// default. The OCR pass treats a factor of 0 as 1.
    pub fn with_upscale(&self, factor: u32) -> VisualLocator {
        VisualLocator {
            upscale: Some(factor),
            ..self.clone()
        }
    }

    /// The text this locator searches for.
    pub fn text(&self) -> &str {
        &self.text
    }

    /// The scope set via [`within`](Self::within), if any.
    pub fn region(&self) -> Option<Rect> {
        self.region
    }

    /// The active match mode.
    pub fn match_mode(&self) -> MatchMode {
        self.match_mode
    }

    /// The upscale override, if any; `None` means the session default.
    pub fn upscale(&self) -> Option<u32> {
        self.upscale
    }

    /// Take one screenshot, OCR it, and return the bounding box of every match.
    ///
    /// * `Exact` compares each OCR line as a whole and reports that line's box.
    /// * `Substring`/`Fuzzy` first group lines into blocks so a query can cross
    ///   a line wrap. They search every seam variant of each block and report
    ///   the union of the matched words' boxes, with duplicates removed.
    ///
    /// When a scope is set, matches not fully inside it are dropped.
    async fn matches(&self) -> Result<Vec<Rect>> {
        let needle = self.text.as_str();
        let mode = self.match_mode;
        let upscale = self
            .upscale
            .unwrap_or(self.session.visual_text_tuning.ocr_upscale_factor);
        let ocr = ocr_lines(
            &self.session,
            self.region,
            self.session.take_screenshot().await?,
            upscale,
        )
        .await?;
        let mut out = Vec::new();
        match mode {
            MatchMode::Exact => {
                for line in &ocr.lines {
                    // `find_matches` reports offsets into the *normalised* text,
                    // which can be shorter than `line.joined` (e.g. "é" -> "e").
                    // Mapping those offsets through the raw `line.spans` could
                    // drop trailing words. An exact hit always covers the whole
                    // line, so use the line's own bounding box instead.
                    if find_matches(&line.joined, needle, MatchMode::Exact).is_empty() {
                        continue;
                    }
                    let rect = line.bbox;
                    if let Some(scope) = self.region {
                        if !rect.is_inside(&scope) {
                            continue;
                        }
                    }
                    out.push(rect);
                }
            }
            MatchMode::Substring | MatchMode::Fuzzy => {
                let boundary = BoundaryContext {
                    image: &ocr.image,
                    crop_origin: ocr.crop_origin,
                };
                let blocks = group_lines_into_blocks(
                    ocr.lines.clone(),
                    self.session.visual_text_tuning,
                    Some(boundary),
                );
                for block in &blocks {
                    let variants = block_haystack_variants(block);
                    for variant in &variants {
                        for (m_start, m_end) in find_matches(&variant.joined, needle, mode) {
                            if let Some(rect) =
                                union_bbox_for_match(&block.words, &variant.spans, m_start, m_end)
                            {
                                if let Some(scope) = self.region {
                                    if !rect.is_inside(&scope) {
                                        continue;
                                    }
                                }
                                if !out.contains(&rect) {
                                    out.push(rect);
                                }
                            }
                        }
                    }
                }
            }
        }
        Ok(out)
    }

    /// Number of matches on the current screen (one OCR pass, no waiting).
    pub async fn count(&self) -> Result<usize> {
        Ok(self.matches().await?.len())
    }

    /// The configured timeout, or `VISUAL_DEFAULT_TIMEOUT` if none was set.
    fn effective_timeout(&self) -> Duration {
        self.timeout.unwrap_or(VISUAL_DEFAULT_TIMEOUT)
    }

    /// Timeout error naming the query and the effective timeout.
    fn timeout_err(&self) -> Error {
        Error::Timeout(format!(
            "visual: no match for {:?} within {}ms",
            self.text,
            self.effective_timeout().as_millis()
        ))
    }

    /// Run one [`matches`](Self::matches) pass, bounded by `deadline` and the
    /// session's cancellation token.
    ///
    /// # Errors
    /// A timeout error if the deadline has already passed or is hit mid-pass.
    /// `Error::Cancelled` if the token fires; cancellation is polled first
    /// because of `biased`. Otherwise, any error from the pass itself.
    async fn matches_bounded(&self, deadline: std::time::Instant) -> Result<Vec<Rect>> {
        let remaining = deadline.saturating_duration_since(std::time::Instant::now());
        if remaining.is_zero() {
            return Err(self.timeout_err());
        }
        tokio::select! {
            biased;
            _ = self.session.cancellation_token().cancelled() => Err(Error::Cancelled),
            r = tokio::time::timeout(remaining, self.matches()) => match r {
                Ok(res) => res,
                Err(_) => Err(self.timeout_err()),
            }
        }
    }

    /// Sleep for `d`, returning `Error::Cancelled` early if the session is cancelled.
    async fn sleep_or_cancel(&self, d: Duration) -> Result<()> {
        tokio::select! {
            _ = self.session.cancellation_token().cancelled() => Err(Error::Cancelled),
            _ = tokio::time::sleep(d) => Ok(()),
        }
    }

    /// Poll every 200 ms until exactly one match exists and return its box.
    ///
    /// # Errors
    /// Times out after the effective timeout and honours cancellation. Returns
    /// a visual error immediately, without retrying, if more than one match is
    /// found.
    pub async fn bounds(&self) -> Result<Rect> {
        let deadline = std::time::Instant::now() + self.effective_timeout();
        loop {
            let hits = self.matches_bounded(deadline).await?;
            match hits.len() {
                0 => {
                    if std::time::Instant::now() >= deadline {
                        return Err(self.timeout_err());
                    }
                    self.sleep_or_cancel(Duration::from_millis(200)).await?;
                }
                1 => return Ok(hits[0]),
                n => {
                    return Err(Error::visual(format!(
                        "found {n} visual matches for {:?}; scope with .within(rect) \
                         or use a tighter MatchMode",
                        self.text,
                    )));
                }
            }
        }
    }

    /// Poll every 200 ms until at least one match exists.
    ///
    /// # Errors
    /// Times out after the effective timeout and honours cancellation.
    pub async fn wait_for_exists(&self) -> Result<()> {
        let deadline = std::time::Instant::now() + self.effective_timeout();
        loop {
            if !self.matches_bounded(deadline).await?.is_empty() {
                return Ok(());
            }
            if std::time::Instant::now() >= deadline {
                return Err(self.timeout_err());
            }
            self.sleep_or_cancel(Duration::from_millis(200)).await?;
        }
    }

    /// Left-click the centre of the unique match (see [`bounds`](Self::bounds)).
    pub async fn click(&self) -> Result<()> {
        let r = self.bounds().await?;
        let (cx, cy) = (r.center_x() as f64, r.center_y() as f64);
        tracing::debug!(text = %self.text, cx, cy, bbox = ?r, "visual: click");
        cold_start_click(&self.session, cx, cy).await
    }

    /// Move the pointer to the centre of the unique match.
    pub async fn hover(&self) -> Result<()> {
        let r = self.bounds().await?;
        self.session
            .pointer_motion_absolute(r.center_x() as f64, r.center_y() as f64)
            .await?;
        Ok(())
    }

    /// Find the region around the unique match inside this locator's scope.
    ///
    /// Requires a scope set via [`within`](Self::within). Once the match is
    /// located, a new screenshot is taken and passed to
    /// `region::last_region_only`.
    ///
    /// # Errors
    /// A visual error when no scope is set, plus any error from `bounds` or the
    /// region search.
    pub async fn parent_region(&self) -> Result<RegionLocator> {
        let parent_bounds = self.region.ok_or_else(|| {
            Error::visual(
                "parent_region: VisualLocator has no parent scope. Construct it via \
                 Locator::find_by_text(...) or Session::find_by_text(...).within(rect).",
            )
        })?;
        let inner_bbox = self.bounds().await?;
        let png = self.session.take_screenshot().await?;
        region::last_region_only(
            &self.session,
            parent_bounds,
            inner_bbox,
            &png,
            self.session.visual_region_tuning,
        )
    }
}
/// One recognised OCR line: its words, their boxes, and how they are joined.
#[derive(Debug, Clone)]
struct OcrLine {
    /// Words joined with single spaces (raw OCR text, not normalised).
    joined: String,
    /// Union of all word rectangles, in screen coordinates.
    bbox: Rect,
    /// Each recognised word with its screen-space rectangle.
    words: Vec<(String, Rect)>,
    /// Byte range of each word inside `joined`, parallel to `words`.
    spans: Vec<(usize, usize)>,
}
/// Consecutive OCR lines merged into one logical block, such as a wrapped label.
#[derive(Debug)]
struct OcrBlock {
    /// Every word of every line joined with single spaces (raw, not normalised).
    joined: String,
    /// Union of all word rectangles.
    bbox: Rect,
    /// Words in line order, each with its screen rectangle.
    words: Vec<(String, Rect)>,
    /// Byte range of each word inside `joined`, parallel to `words`.
    spans: Vec<(usize, usize)>,
    /// For each line after the first, the index in `words` of its first word.
    line_break_word_indices: Vec<usize>,
}
/// Pixels available to the boundary checks while grouping OCR lines.
struct BoundaryContext<'a> {
    /// The (possibly cropped, not upscaled) image the OCR ran on.
    image: &'a image::RgbImage,
    /// Screen coordinates of `image`'s top-left pixel.
    crop_origin: (i32, i32),
}
/// Read the pixel at screen coordinates `(screen_x, screen_y)` from the
/// boundary image. Returns `None` when that point lies outside the crop.
fn sample_pixel_at_screen(
    ctx: &BoundaryContext<'_>,
    screen_x: i32,
    screen_y: i32,
) -> Option<image::Rgb<u8>> {
    let (origin_x, origin_y) = ctx.crop_origin;
    // Negative crop coordinates fail the conversion and yield `None`.
    let x = u32::try_from(screen_x - origin_x).ok()?;
    let y = u32::try_from(screen_y - origin_y).ok()?;
    let (width, height) = ctx.image.dimensions();
    if x < width && y < height {
        Some(*ctx.image.get_pixel(x, y))
    } else {
        None
    }
}
fn merge_passes_boundary_check(
prev: &OcrLine,
next: &OcrLine,
ctx: &BoundaryContext<'_>,
tuning: VisualTextTuning,
) -> bool {
let mode = tuning.color_distance;
let tol_sq = color::threshold_sq(tuning.background_color_tolerance, mode);
let overlap_left = prev.bbox.x.max(next.bbox.x);
let overlap_right = (prev.bbox.x + prev.bbox.width).min(next.bbox.x + next.bbox.width);
if overlap_right <= overlap_left {
return true;
}
let mid_x = (overlap_left + overlap_right) / 2;
let prev_bottom = prev.bbox.y + prev.bbox.height;
let next_top = next.bbox.y;
if next_top <= prev_bottom {
return true;
}
let top_sample = sample_background_at_screen(ctx, mid_x, prev_bottom, tuning);
let bot_sample = sample_background_at_screen(ctx, mid_x, next_top - 1, tuning);
if let (Some(top), Some(bot)) = (top_sample, bot_sample) {
if color::distance_sq(top, bot, mode) > tol_sq {
return false;
}
if tuning.divider_detection_enabled {
let samples_per_axis = tuning.boundary_samples_per_axis.max(1);
let majority_threshold = tuning.boundary_majority_threshold;
for row_y in prev_bottom..next_top {
let mut differing = 0usize;
let mut sampled = 0usize;
for i in 0..samples_per_axis {
let x = overlap_left
+ ((overlap_right - overlap_left) as usize * i / samples_per_axis) as i32;
if let Some(p) = sample_pixel_at_screen(ctx, x, row_y) {
sampled += 1;
if color::distance_sq(p, top, mode) > tol_sq
&& color::distance_sq(p, bot, mode) > tol_sq
{
differing += 1;
}
}
}
if sampled > 0 && (differing as f32) / (sampled as f32) >= majority_threshold {
return false; }
}
let gap_height = next_top - prev_bottom;
for col_x in overlap_left..overlap_right {
let mut differing = 0usize;
let mut sampled = 0usize;
for i_y in 0..samples_per_axis {
let y = prev_bottom + (gap_height as usize * i_y / samples_per_axis) as i32;
if let Some(p) = sample_pixel_at_screen(ctx, col_x, y) {
sampled += 1;
if color::distance_sq(p, top, mode) > tol_sq
&& color::distance_sq(p, bot, mode) > tol_sq
{
differing += 1;
}
}
}
if sampled > 0 && (differing as f32) / (sampled as f32) >= majority_threshold {
return false; }
}
}
}
if tuning.connectivity_check_enabled {
if let (Some(top), Some(bot)) = (
sample_pixel_at_screen(ctx, mid_x, prev_bottom),
sample_pixel_at_screen(ctx, mid_x, next_top - 1),
) {
if !connectivity_passes_check(ctx, mid_x, prev_bottom, next_top - 1, top, bot, tuning) {
return false;
}
}
}
true
}
/// Flood-fill from `(screen_x, screen_top_bottom_y)` and report whether the
/// fill reaches `(screen_x, screen_next_top_y)`.
///
/// Returns `true` (the check does not block a merge) when either point lies
/// outside the captured image. The fill uses `background_color_tolerance`,
/// `max_connectivity_pixels` and `color_distance` from `tuning`. The colour
/// arguments are accepted but not used by the current implementation.
fn connectivity_passes_check(
    ctx: &BoundaryContext<'_>,
    screen_x: i32,
    screen_top_bottom_y: i32,
    screen_next_top_y: i32,
    _seed_color: image::Rgb<u8>,
    _target_color: image::Rgb<u8>,
    tuning: VisualTextTuning,
) -> bool {
    let (origin_x, origin_y) = ctx.crop_origin;
    let x = screen_x - origin_x;
    let seed_y = screen_top_bottom_y - origin_y;
    let target_y = screen_next_top_y - origin_y;
    let (width, height) = ctx.image.dimensions();
    let in_range = |v: i32, limit: u32| v >= 0 && v < limit as i32;
    if !(in_range(x, width) && in_range(seed_y, height) && in_range(target_y, height)) {
        return true;
    }
    let flood = region::flood_fill(
        ctx.image,
        (x, seed_y),
        tuning.background_color_tolerance,
        tuning.max_connectivity_pixels,
        tuning.color_distance,
    );
    let target = (target_y as usize) * (flood.image_width as usize) + (x as usize);
    target < flood.visited.len() && flood.visited[target]
}
/// Sample the background colour around screen point `(screen_x, screen_y)`.
///
/// Delegates to `color::sample_window` with radius
/// `tuning.background_sample_radius`, after translating into crop coordinates.
/// NOTE(review): when `None` is returned (presumably out-of-bounds or an empty
/// window) is decided by `sample_window`; confirm there.
fn sample_background_at_screen(
    ctx: &BoundaryContext<'_>,
    screen_x: i32,
    screen_y: i32,
    tuning: VisualTextTuning,
) -> Option<image::Rgb<u8>> {
    let (origin_x, origin_y) = ctx.crop_origin;
    color::sample_window(
        ctx.image,
        screen_x - origin_x,
        screen_y - origin_y,
        tuning.background_sample_radius,
    )
}
/// Group OCR lines into blocks of vertically stacked, horizontally overlapping
/// lines, such as the lines of one wrapped paragraph.
///
/// Lines are visited top to bottom (stable sort on `bbox.y`). Each line is
/// appended to the block whose last line it can follow with the smallest
/// vertical gap; on ties the earliest block wins. If no block accepts it, the
/// line starts a new block. Eligibility rules are in `line_gap_below`.
fn group_lines_into_blocks(
    lines: Vec<OcrLine>,
    tuning: VisualTextTuning,
    boundary: Option<BoundaryContext<'_>>,
) -> Vec<OcrBlock> {
    let mut ordered = lines;
    ordered.sort_by_key(|l| l.bbox.y);
    let mut groups: Vec<Vec<OcrLine>> = Vec::new();
    for line in ordered {
        let (chosen, _) = groups.iter().enumerate().fold(
            (None, i32::MAX),
            |(best, best_gap), (idx, group)| {
                let prev = group.last().expect("block can never be empty");
                match line_gap_below(prev, &line, tuning, boundary.as_ref()) {
                    Some(gap) if gap < best_gap => (Some(idx), gap),
                    _ => (best, best_gap),
                }
            },
        );
        match chosen {
            Some(idx) => groups[idx].push(line),
            None => groups.push(vec![line]),
        }
    }
    groups.into_iter().map(fuse_block_lines).collect()
}

/// Vertical gap in px from the bottom of `prev` to the top of `line`, if
/// `line` may be appended below `prev`; otherwise `None`.
///
/// `line` must start at or below `prev`'s bottom, within
/// `multiline_max_gap_factor` × `prev`'s height. It must overlap `prev`
/// horizontally once widened by `multiline_x_slack_px` on each side. When
/// `boundary` is given, it must also pass `merge_passes_boundary_check`.
fn line_gap_below(
    prev: &OcrLine,
    line: &OcrLine,
    tuning: VisualTextTuning,
    boundary: Option<&BoundaryContext<'_>>,
) -> Option<i32> {
    let gap = line.bbox.y - (prev.bbox.y + prev.bbox.height);
    if gap < 0 {
        return None;
    }
    let max_gap = (prev.bbox.height as f32 * tuning.multiline_max_gap_factor) as i32;
    if gap > max_gap {
        return None;
    }
    let slack = tuning.multiline_x_slack_px;
    let (line_left, line_right) = (line.bbox.x - slack, line.bbox.x + line.bbox.width + slack);
    let (prev_left, prev_right) = (prev.bbox.x, prev.bbox.x + prev.bbox.width);
    if !(line_left < prev_right && prev_left < line_right) {
        return None;
    }
    if let Some(ctx) = boundary {
        if !merge_passes_boundary_check(prev, line, ctx, tuning) {
            return None;
        }
    }
    Some(gap)
}

/// Flatten a block's lines into one `OcrBlock`. Words are joined with single
/// spaces, each word's byte span is recorded, and the first word index of
/// every line after the first is noted.
fn fuse_block_lines(block_lines: Vec<OcrLine>) -> OcrBlock {
    let mut joined = String::new();
    let mut words = Vec::new();
    let mut spans = Vec::new();
    let mut line_break_word_indices = Vec::new();
    let (mut min_x, mut min_y) = (i32::MAX, i32::MAX);
    let (mut max_x, mut max_y) = (i32::MIN, i32::MIN);
    for (line_no, line) in block_lines.into_iter().enumerate() {
        if line_no != 0 {
            line_break_word_indices.push(words.len());
        }
        for (text, rect) in line.words {
            if !joined.is_empty() {
                joined.push(' ');
            }
            let start = joined.len();
            joined.push_str(&text);
            spans.push((start, joined.len()));
            min_x = min_x.min(rect.x);
            min_y = min_y.min(rect.y);
            max_x = max_x.max(rect.x + rect.width);
            max_y = max_y.max(rect.y + rect.height);
            words.push((text, rect));
        }
    }
    OcrBlock {
        joined,
        bbox: Rect {
            x: min_x,
            y: min_y,
            width: max_x - min_x,
            height: max_y - min_y,
        },
        words,
        spans,
        line_break_word_indices,
    }
}
/// Output of one OCR pass, shared through the per-session OCR cache.
pub(crate) struct OcrResult {
    /// Recognised lines, in screen coordinates.
    lines: Vec<OcrLine>,
    /// The (possibly cropped) screenshot the OCR ran on, before upscaling.
    image: image::RgbImage,
    /// Screen coordinates of `image`'s top-left pixel.
    crop_origin: (i32, i32),
}
/// OCR scope as `(x, y, width, height)`, or `None` for the full screen.
type RegionKey = Option<(i32, i32, i32, i32)>;
/// Complete OCR cache key: the scope plus the effective upscale factor.
type OcrCacheKey = (RegionKey, u32);
/// Memo of OCR results for the most recently seen screenshot.
///
/// Only one frame is cached at a time. `frame_hash` identifies that frame and
/// `by_region` holds one result per `(scope, upscale)` pair OCR'd on it.
#[derive(Default)]
pub(crate) struct OcrCache {
    frame_hash: u64,
    by_region: std::collections::HashMap<OcrCacheKey, Arc<OcrResult>>,
}
impl OcrCache {
    /// Look up the result for `(region, upscale)` on frame `hash`.
    ///
    /// If `hash` differs from the cached frame, the cache first switches to
    /// `hash` and drops every entry. A lookup for a new frame is therefore
    /// always a miss, and `get` doubles as the invalidation point.
    fn get(&mut self, hash: u64, region: RegionKey, upscale: u32) -> Option<Arc<OcrResult>> {
        if hash != self.frame_hash {
            self.by_region.clear();
            self.frame_hash = hash;
        }
        self.by_region.get(&(region, upscale)).map(Arc::clone)
    }

    /// Store `result` for `(region, upscale)` if `hash` is still the current
    /// frame. A result computed for a frame the cache has moved past is
    /// discarded.
    fn put(&mut self, hash: u64, region: RegionKey, upscale: u32, result: Arc<OcrResult>) {
        if hash != self.frame_hash {
            return;
        }
        self.by_region.insert((region, upscale), result);
    }
}
/// Hash the raw screenshot bytes so the OCR cache can tell frames apart.
///
/// Uses `DefaultHasher::new()`, so equal bytes always hash equally within one
/// build. It is not a cryptographic digest and is not stable across Rust
/// releases.
fn frame_hash(png_bytes: &[u8]) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    let mut hasher = DefaultHasher::new();
    Hash::hash(png_bytes, &mut hasher);
    hasher.finish()
}
/// OCR a PNG screenshot, optionally restricted to `region`. Returns the
/// recognised lines in screen coordinates together with the image they came
/// from.
///
/// Results are memoised in the session's OCR cache. The key is a hash of
/// `png_bytes`, the region, and the effective upscale factor (`upscale`
/// clamped to at least 1). On a miss, the engine is obtained through
/// `get_or_init(ensure_engine)`. The decode/OCR work then runs inside
/// `spawn_blocking`:
/// 1. decode the PNG;
/// 2. crop to `region` grown by `ocr_context_padding_px` on every side;
/// 3. optionally upscale with Lanczos3;
/// 4. run ocrs word detection, line finding and recognition.
///
/// # Errors
/// Returns a visual error if engine initialisation, decoding, cropping or any
/// ocrs stage fails, or if the blocking task panics.
///
/// # Panics
/// Panics if the OCR cache mutex is poisoned (`.lock().unwrap()`).
async fn ocr_lines(
    session: &Arc<Session>,
    region: Option<Rect>,
    png_bytes: Vec<u8>,
    upscale: u32,
) -> Result<Arc<OcrResult>> {
    use ocrs::{ImageSource, TextItem};
    // Cache key parts. The factor is clamped before lookup, so 0 and 1 share an entry.
    let region_key = region.map(|r| (r.x, r.y, r.width, r.height));
    let hash = frame_hash(&png_bytes);
    let upscale = upscale.max(1);
    if let Some(hit) = session
        .visual_ocr_cache()
        .lock()
        .unwrap()
        .get(hash, region_key, upscale)
    {
        return Ok(hit);
    }
    let engine = session
        .visual_engine()
        .get_or_init(ensure_engine)
        .await
        .clone()
        .map_err(Error::visual)?;
    let context_pad_px = session.visual_text_tuning.ocr_context_padding_px;
    let result = tokio::task::spawn_blocking(move || -> Result<OcrResult> {
        let full = crate::locator::decode_screenshot_png(&png_bytes)
            .map_err(|e| Error::visual(format!("decode screenshot: {e}")))?;
        // Crop to the scope grown by `context_pad_px` on every side.
        // NOTE(review): the origin is clamped at 0, which assumes
        // `crop_to_bounds` clamps the padded rect to the image edges. Confirm.
        let (cropped, origin_x, origin_y) = if let Some(scope) = region {
            let padded = Rect {
                x: scope.x - context_pad_px,
                y: scope.y - context_pad_px,
                width: scope.width + 2 * context_pad_px,
                height: scope.height + 2 * context_pad_px,
            };
            let cropped = crate::locator::crop_to_bounds(full, padded)
                .map_err(|e| Error::visual(format!("crop to region: {e}")))?;
            (cropped, padded.x.max(0), padded.y.max(0))
        } else {
            (full, 0, 0)
        };
        let rgb = cropped.into_rgb8();
        let (w, h) = rgb.dimensions();
        // Declared outside the `if` so the slice borrowed from it outlives the branch.
        let upscaled_img;
        let (ocr_bytes, ocr_w, ocr_h): (&[u8], u32, u32) = if upscale > 1 {
            let f = upscale;
            upscaled_img =
                image::imageops::resize(&rgb, w * f, h * f, image::imageops::FilterType::Lanczos3);
            (upscaled_img.as_raw(), w * f, h * f)
        } else {
            (rgb.as_raw(), w, h)
        };
        let upscale_i = upscale as i32;
        let src = ImageSource::from_bytes(ocr_bytes, (ocr_w, ocr_h))
            .map_err(|e| Error::visual(format!("ocrs ImageSource: {e}")))?;
        let input = engine
            .prepare_input(src)
            .map_err(|e| Error::visual(format!("ocrs prepare_input: {e}")))?;
        let word_rects = engine
            .detect_words(&input)
            .map_err(|e| Error::visual(format!("ocrs detect_words: {e}")))?;
        let line_rects = engine.find_text_lines(&input, &word_rects);
        let lines = engine
            .recognize_text(&input, &line_rects)
            .map_err(|e| Error::visual(format!("ocrs recognize_text: {e}")))?;
        let mut out = Vec::new();
        // `flatten` skips lines the recogniser returned as `None`.
        for line_opt in lines.iter().flatten() {
            let words: Vec<(String, Rect)> = line_opt
                .words()
                .map(|w| {
                    let text: String = w.chars().iter().map(|c| c.char).collect();
                    let r = w.bounding_rect();
                    // Map upscaled crop pixels back to screen pixels. Integer
                    // division truncates, so boxes are kept at least 1px.
                    let rect = Rect {
                        x: r.left() / upscale_i + origin_x,
                        y: r.top() / upscale_i + origin_y,
                        width: (r.width() / upscale_i).max(1),
                        height: (r.height() / upscale_i).max(1),
                    };
                    (text, rect)
                })
                .collect();
            if words.is_empty() {
                continue;
            }
            // Join words with single spaces, recording each word's byte span in `joined`.
            let mut joined = String::new();
            let mut spans: Vec<(usize, usize)> = Vec::with_capacity(words.len());
            for (i, (text, _)) in words.iter().enumerate() {
                if i > 0 {
                    joined.push(' ');
                }
                let start = joined.len();
                joined.push_str(text);
                spans.push((start, joined.len()));
            }
            // The line box is the union of its word boxes.
            let mut min_x = i32::MAX;
            let mut min_y = i32::MAX;
            let mut max_x = i32::MIN;
            let mut max_y = i32::MIN;
            for (_, r) in &words {
                min_x = min_x.min(r.x);
                min_y = min_y.min(r.y);
                max_x = max_x.max(r.x + r.width);
                max_y = max_y.max(r.y + r.height);
            }
            let bbox = Rect {
                x: min_x,
                y: min_y,
                width: max_x - min_x,
                height: max_y - min_y,
            };
            tracing::trace!(line = %joined, ?bbox, "visual: OCR line");
            out.push(OcrLine {
                joined,
                bbox,
                words,
                spans,
            });
        }
        // Keep the un-upscaled crop; callers use it for background/boundary sampling.
        Ok(OcrResult {
            lines: out,
            image: rgb,
            crop_origin: (origin_x, origin_y),
        })
    })
    .await
    .map_err(|e| Error::visual(format!("OCR task panicked: {e}")))??;
    let arc = Arc::new(result);
    // `put` discards the entry if the cache has since switched to another frame.
    session
        .visual_ocr_cache()
        .lock()
        .unwrap()
        .put(hash, region_key, upscale, arc.clone());
    Ok(arc)
}
/// OCR `png` within `scope` and return every text block lying fully inside it.
///
/// Lines are grouped into blocks, so wrapped text comes back as a single hit.
/// Grouping uses the session's text tuning and the background-boundary checks.
pub(crate) async fn list_text(
    session: &Arc<Session>,
    scope: Rect,
    png: Vec<u8>,
) -> Result<Vec<TextHit>> {
    let tuning = session.visual_text_tuning;
    let ocr = ocr_lines(session, Some(scope), png, tuning.ocr_upscale_factor).await?;
    let ctx = BoundaryContext {
        image: &ocr.image,
        crop_origin: ocr.crop_origin,
    };
    let mut hits: Vec<TextHit> = Vec::new();
    for block in group_lines_into_blocks(ocr.lines.clone(), tuning, Some(ctx)) {
        if block.bbox.is_inside(&scope) {
            hits.push(TextHit {
                text: block.joined,
                bounds: block.bbox,
            });
        }
    }
    Ok(hits)
}
pub(crate) async fn recognized_text(session: &Arc<Session>, png: Vec<u8>) -> Result<Vec<TextHit>> {
let ocr = ocr_lines(
session,
None,
png,
session.visual_text_tuning.ocr_upscale_factor,
)
.await?;
let boundary = BoundaryContext {
image: &ocr.image,
crop_origin: ocr.crop_origin,
};
let blocks = group_lines_into_blocks(
ocr.lines.clone(),
session.visual_text_tuning,
Some(boundary),
);
Ok(blocks
.into_iter()
.map(|block| TextHit {
text: block.joined,
bounds: block.bbox,
})
.collect())
}
pub(crate) async fn list_labelled_regions(
session: &Arc<Session>,
scope: Rect,
png: Vec<u8>,
tuning: crate::session::VisualRegionTuning,
) -> Result<Vec<(TextHit, RegionLocator)>> {
let ocr = ocr_lines(
session,
Some(scope),
png.clone(),
session.visual_text_tuning.ocr_upscale_factor,
)
.await?;
let boundary = BoundaryContext {
image: &ocr.image,
crop_origin: ocr.crop_origin,
};
let blocks = group_lines_into_blocks(
ocr.lines.clone(),
session.visual_text_tuning,
Some(boundary),
);
let mut pairs = Vec::new();
for block in blocks {
if !block.bbox.is_inside(&scope) {
continue;
}
let region_loc = region::last_region_only(session, scope, block.bbox, &png, tuning)?;
pairs.push((
TextHit {
text: block.joined,
bounds: block.bbox,
},
region_loc,
));
}
Ok(pairs)
}
/// Fold text for matching. Applies NFKD decomposition (which also splits
/// compatibility forms such as ligatures), drops combining marks (so
/// diacritics disappear), and lower-cases the rest.
fn normalize_for_match(s: &str) -> String {
    use unicode_normalization::{char::is_combining_mark, UnicodeNormalization};
    let mut folded = String::with_capacity(s.len());
    for c in s.nfkd() {
        if !is_combining_mark(c) {
            folded.extend(c.to_lowercase());
        }
    }
    folded
}
/// One candidate haystack for a block: its words glued with a particular
/// choice of separator at each line seam.
#[derive(Debug)]
struct BlockVariant {
    /// The block's words joined for this variant.
    joined: String,
    /// Byte range of each word inside `joined`, parallel to the block's words.
    spans: Vec<(usize, usize)>,
}
/// Largest block, in lines, for which every seam join/split combination is
/// generated. A block with `k` lines yields `2^(k-1)` variants.
const MAX_VARIANT_LINES: usize = 5;
/// Build the normalised haystacks to search for one block.
///
/// Every word goes through `normalize_for_match`. Words within a line are
/// always separated by one space. Each seam between two lines is joined either
/// with a space or with nothing, to catch words split across a wrap, so a
/// block with `k` lines yields `2^(k-1)` variants. Blocks with more than
/// `MAX_VARIANT_LINES` lines get only the all-spaces variant to bound the
/// blow-up.
///
/// In every variant, `spans` are byte ranges into that variant's normalised
/// `joined`, parallel to `block.words`. They therefore line up with the
/// offsets `find_matches` reports, which refer to normalised text.
fn block_haystack_variants(block: &OcrBlock) -> Vec<BlockVariant> {
    if block.words.is_empty() {
        return Vec::new();
    }
    let normalized_words: Vec<String> = block
        .words
        .iter()
        .map(|(t, _)| normalize_for_match(t))
        .collect();
    let seam_count = block.line_break_word_indices.len();
    // Past the cap, emit only mask 0 (every seam joined with a space). It is
    // still built from the normalised words. The raw `block.joined` and
    // `block.spans` would not align with normalised match offsets whenever
    // normalisation changes byte lengths (e.g. "é" -> "e").
    let n_variants = if seam_count + 1 > MAX_VARIANT_LINES {
        1
    } else {
        1usize << seam_count
    };
    (0..n_variants)
        .map(|mask| {
            let mut joined = String::new();
            let mut spans = Vec::with_capacity(normalized_words.len());
            for (wi, word) in normalized_words.iter().enumerate() {
                if wi > 0 {
                    // Index of the seam that starts at word `wi`, if any.
                    let seam = block
                        .line_break_word_indices
                        .iter()
                        .position(|&s| s == wi);
                    let glue = matches!(seam, Some(bit) if (mask >> bit) & 1 == 1);
                    if !glue {
                        joined.push(' ');
                    }
                }
                let start = joined.len();
                joined.push_str(word);
                spans.push((start, joined.len()));
            }
            BlockVariant { joined, spans }
        })
        .collect()
}
/// Find `needle` in `haystack` according to `mode`.
///
/// Both strings are folded with `normalize_for_match`. The returned byte
/// ranges index the *normalised* haystack, not `haystack` itself.
/// * `Exact`: one range covering everything if the folded strings are equal.
/// * `Substring`: every non-overlapping occurrence, left to right.
/// * `Fuzzy`: word windows within a small edit distance (see `fuzzy_find`).
///
/// An empty needle, before or after folding, matches nothing.
fn find_matches(haystack: &str, needle: &str, mode: MatchMode) -> Vec<(usize, usize)> {
    if needle.is_empty() {
        return Vec::new();
    }
    let needle_norm = normalize_for_match(needle);
    if needle_norm.is_empty() {
        return Vec::new();
    }
    let hay_norm = normalize_for_match(haystack);
    match mode {
        MatchMode::Exact if hay_norm == needle_norm => vec![(0, hay_norm.len())],
        MatchMode::Exact => Vec::new(),
        MatchMode::Substring => hay_norm
            .match_indices(needle_norm.as_str())
            .map(|(at, hit)| (at, at + hit.len()))
            .collect(),
        MatchMode::Fuzzy => fuzzy_find(&hay_norm, &needle_norm),
    }
}
/// Levenshtein edit distance between `a` and `b`, counted in `char`s.
///
/// Single-row dynamic programme. Before processing each char of `a`,
/// `row[j]` is the distance between the prefix of `a` seen so far and the
/// first `j` chars of `b`.
fn levenshtein(a: &str, b: &str) -> usize {
    let b_chars: Vec<char> = b.chars().collect();
    let mut row: Vec<usize> = (0..=b_chars.len()).collect();
    for (i, ca) in a.chars().enumerate() {
        // `diag` holds the previous row's value at column j (the substitution cell).
        let mut diag = row[0];
        row[0] = i + 1;
        for (j, &cb) in b_chars.iter().enumerate() {
            let above = row[j + 1];
            let substitute = diag + usize::from(ca != cb);
            row[j + 1] = (above + 1).min(row[j] + 1).min(substitute);
            diag = above;
        }
    }
    row[b_chars.len()]
}
/// Slide a window, as many whitespace-separated words wide as `n`, over `h`.
/// Return every window within `max(1, chars(n) / 5)` edits of `n`.
///
/// Returned ranges are byte offsets into `h` and may overlap. A haystack with
/// fewer words than the needle yields nothing.
fn fuzzy_find(h: &str, n: &str) -> Vec<(usize, usize)> {
    let window = n.split_whitespace().count().max(1);
    let max_edits = (n.chars().count() / 5).max(1);
    // Byte ranges of each whitespace-delimited word in `h`.
    let mut word_spans: Vec<(usize, usize)> = Vec::new();
    let mut open: Option<usize> = None;
    for (at, ch) in h.char_indices() {
        match (ch.is_whitespace(), open) {
            (true, Some(begin)) => {
                word_spans.push((begin, at));
                open = None;
            }
            (false, None) => open = Some(at),
            _ => {}
        }
    }
    if let Some(begin) = open {
        word_spans.push((begin, h.len()));
    }
    word_spans
        .windows(window)
        .map(|run| (run[0].0, run[window - 1].1))
        .filter(|&(s, e)| levenshtein(&h[s..e], n) <= max_edits)
        .collect()
}
/// Union of the boxes of every word whose span overlaps `[m_start, m_end)`.
///
/// `spans[i]` must describe `words[i]`. Returns `None` when no span overlaps
/// the match.
fn union_bbox_for_match(
    words: &[(String, Rect)],
    spans: &[(usize, usize)],
    m_start: usize,
    m_end: usize,
) -> Option<Rect> {
    let (left, top, right, bottom) = spans
        .iter()
        .enumerate()
        .filter(|&(_, &(s, e))| s < m_end && m_start < e)
        .map(|(i, _)| words[i].1)
        .fold(None, |acc: Option<(i32, i32, i32, i32)>, r| {
            let (l, t, rt, b) = acc.unwrap_or((i32::MAX, i32::MAX, i32::MIN, i32::MIN));
            Some((
                l.min(r.x),
                t.min(r.y),
                rt.max(r.x + r.width),
                b.max(r.y + r.height),
            ))
        })?;
    Some(Rect {
        x: left,
        y: top,
        width: right - left,
        height: bottom - top,
    })
}
/// Boolean match test. The locator itself does not call this.
///
/// NOTE(review): unlike `find_matches`, `Exact` here is a raw case-sensitive
/// comparison and `Substring` only lower-cases, without Unicode folding. Only
/// `Fuzzy` defers to `find_matches`.
// Not used by non-test code in this chunk; the lint is silenced to keep the
// helper available. TODO(review): confirm its callers or delete it.
#[allow(dead_code)]
fn text_matches(haystack: &str, needle: &str, mode: MatchMode) -> bool {
    match mode {
        MatchMode::Exact => needle == haystack,
        MatchMode::Substring => {
            let needle_lower = needle.to_lowercase();
            haystack.to_lowercase().contains(needle_lower.as_str())
        }
        MatchMode::Fuzzy => !find_matches(haystack, needle, mode).is_empty(),
    }
}
#[cfg(test)]
mod tests {
use super::*;
fn dummy_result() -> Arc<OcrResult> {
Arc::new(OcrResult {
lines: Vec::new(),
image: image::RgbImage::new(1, 1),
crop_origin: (0, 0),
})
}
#[test]
fn frame_hash_is_deterministic_and_content_sensitive() {
assert_eq!(frame_hash(b"same bytes"), frame_hash(b"same bytes"));
assert_ne!(frame_hash(b"frame a"), frame_hash(b"frame b"));
}
#[test]
fn ocr_cache_hits_same_frame_and_region() {
let mut cache = OcrCache::default();
let h = 42;
assert!(cache.get(h, None, 1).is_none(), "cold cache misses");
cache.put(h, None, 1, dummy_result());
assert!(cache.get(h, None, 1).is_some(), "same frame+region hits");
assert!(cache.get(h, Some((0, 0, 10, 10)), 1).is_none());
assert!(
cache.get(h, None, 3).is_none(),
"different upscale factor is a separate entry"
);
}
#[test]
fn ocr_cache_invalidates_on_new_frame() {
let mut cache = OcrCache::default();
cache.get(1, None, 1); cache.put(1, None, 1, dummy_result());
assert!(cache.get(1, None, 1).is_some());
assert!(cache.get(2, None, 1).is_none(), "new frame busts the memo");
assert!(
cache.get(1, None, 1).is_none(),
"old frame's entry is gone after invalidation"
);
}
#[test]
fn ocr_cache_put_ignored_for_stale_frame() {
let mut cache = OcrCache::default();
assert!(cache.get(2, None, 1).is_none());
cache.put(1, None, 1, dummy_result());
assert!(cache.get(2, None, 1).is_none(), "stale-frame put rejected");
}
#[test]
fn find_matches_substring_in_single_word() {
let hits = find_matches("account", "acc", MatchMode::Substring);
assert_eq!(hits, vec![(0, 3)]);
}
#[test]
fn find_matches_substring_spans_words() {
let hits = find_matches("Add account row", "Add account", MatchMode::Substring);
assert_eq!(hits, vec![(0, 11)]);
}
#[test]
fn find_matches_substring_is_case_insensitive() {
let hits = find_matches("ADD account ROW", "add account", MatchMode::Substring);
assert_eq!(hits, vec![(0, 11)]);
}
#[test]
fn find_matches_substring_multiple_hits() {
let hits = find_matches("foo bar foo", "foo", MatchMode::Substring);
assert_eq!(hits, vec![(0, 3), (8, 11)]);
}
#[test]
fn levenshtein_basic() {
assert_eq!(levenshtein("cursor", "cursor"), 0);
assert_eq!(levenshtein("cursor", "cursar"), 1); assert_eq!(levenshtein("hover-target", "hover-targel"), 1);
assert_eq!(levenshtein("abc", ""), 3);
}
#[test]
fn fuzzy_matches_single_glyph_ocr_error() {
let hits = find_matches("cursar font scrollback", "Cursor", MatchMode::Fuzzy);
assert_eq!(hits, vec![(0, 6)], "fuzzy should locate the mis-read word");
}
#[test]
fn fuzzy_matches_hyphenated_misread() {
let hits = find_matches(
"primary-button mode-toggle",
"hover-targel",
MatchMode::Fuzzy,
);
assert!(hits.is_empty(), "an unrelated word must not fuzzy-match");
let hits = find_matches("hover-targel dc-target", "hover-target", MatchMode::Fuzzy);
assert_eq!(hits, vec![(0, 12)]);
}
#[test]
fn fuzzy_multiword_window() {
let hits = find_matches("open the prefs dialog", "prefs dialeg", MatchMode::Fuzzy);
assert_eq!(hits, vec![(9, 21)]);
}
#[test]
fn fuzzy_rejects_too_many_errors() {
let hits = find_matches("buffer", "Cursor", MatchMode::Fuzzy);
assert!(hits.is_empty());
}
#[test]
fn find_matches_exact_full_string_only() {
assert_eq!(
find_matches("Add account", "Add account", MatchMode::Exact),
vec![(0, 11)]
);
assert!(find_matches("Add account row", "Add account", MatchMode::Exact).is_empty());
}
#[test]
fn find_matches_empty_needle_yields_nothing() {
assert!(find_matches("anything", "", MatchMode::Substring).is_empty());
assert!(find_matches("anything", "", MatchMode::Exact).is_empty());
}
#[test]
fn normalize_strips_diacritics() {
assert_eq!(normalize_for_match("Café"), "cafe");
assert_eq!(normalize_for_match("naïve"), "naive");
assert_eq!(normalize_for_match("ÄÖÜ"), "aou");
}
#[test]
fn normalize_decomposes_ligatures() {
assert_eq!(normalize_for_match("file"), "file");
assert_eq!(normalize_for_match("flux"), "flux");
}
#[test]
fn normalize_is_idempotent() {
let s = "Café file ABC";
let once = normalize_for_match(s);
let twice = normalize_for_match(&once);
assert_eq!(once, twice);
}
#[test]
fn find_matches_substring_handles_diacritics() {
let hits = find_matches("Café latte", "cafe", MatchMode::Substring);
assert_eq!(hits.len(), 1);
}
#[test]
fn block_variants_handles_single_line_block() {
let line = make_line(vec![("Hello", rect(0, 0, 50, 10))]);
let blocks = group_lines_into_blocks(vec![line], default_text_tuning(), None);
assert_eq!(blocks.len(), 1);
let variants = block_haystack_variants(&blocks[0]);
assert_eq!(variants.len(), 1);
assert_eq!(variants[0].joined, "hello");
}
#[test]
fn block_variants_two_lines_produce_space_and_no_space_join() {
let line_a = make_line(vec![("nee", rect(0, 0, 30, 10))]);
let line_b = make_line(vec![("dle", rect(0, 14, 30, 10))]);
let blocks = group_lines_into_blocks(vec![line_a, line_b], default_text_tuning(), None);
assert_eq!(blocks.len(), 1);
let variants = block_haystack_variants(&blocks[0]);
assert_eq!(variants.len(), 2);
let joineds: Vec<&str> = variants.iter().map(|v| v.joined.as_str()).collect();
assert!(joineds.contains(&"nee dle"));
assert!(joineds.contains(&"needle"));
}
#[test]
fn block_variants_query_needle_matches_hyphenated_wrap() {
let line_a = make_line(vec![("nee", rect(0, 0, 30, 10))]);
let line_b = make_line(vec![("dle", rect(0, 14, 30, 10))]);
let blocks = group_lines_into_blocks(vec![line_a, line_b], default_text_tuning(), None);
let variants = block_haystack_variants(&blocks[0]);
let mut any_match = false;
for v in &variants {
if !find_matches(&v.joined, "needle", MatchMode::Substring).is_empty() {
any_match = true;
break;
}
}
assert!(any_match, "expected at least one variant to match 'needle'");
}
#[test]
fn block_variants_capped_at_max_lines() {
let lines: Vec<OcrLine> = (0..6)
.map(|i| {
let y = i * 14;
make_line(vec![("word", rect(0, y, 40, 10))])
})
.collect();
let blocks = group_lines_into_blocks(lines, default_text_tuning(), None);
for block in &blocks {
let n_lines = block.line_break_word_indices.len() + 1;
if n_lines > MAX_VARIANT_LINES {
let variants = block_haystack_variants(block);
assert_eq!(
variants.len(),
1,
"expected fallback to single-variant for {n_lines}-line block"
);
}
}
}
fn rect(x: i32, y: i32, w: i32, h: i32) -> Rect {
Rect {
x,
y,
width: w,
height: h,
}
}
#[test]
fn union_bbox_for_match_single_word() {
let words = vec![
("Add".to_string(), rect(0, 0, 30, 10)),
("account".to_string(), rect(40, 0, 60, 10)),
("row".to_string(), rect(110, 0, 30, 10)),
];
let spans = vec![(0, 3), (4, 11), (12, 15)];
let bbox = union_bbox_for_match(&words, &spans, 4, 11).unwrap();
assert_eq!(bbox, rect(40, 0, 60, 10));
}
#[test]
fn union_bbox_for_match_spans_two_words() {
let words = vec![
("Add".to_string(), rect(0, 0, 30, 10)),
("account".to_string(), rect(40, 0, 60, 10)),
("row".to_string(), rect(110, 0, 30, 10)),
];
let spans = vec![(0, 3), (4, 11), (12, 15)];
let bbox = union_bbox_for_match(&words, &spans, 0, 11).unwrap();
assert_eq!(bbox, rect(0, 0, 100, 10));
}
#[test]
fn union_bbox_for_match_returns_none_for_no_overlap() {
let words = vec![("foo".to_string(), rect(0, 0, 30, 10))];
let spans = vec![(0, 3)];
assert!(union_bbox_for_match(&words, &spans, 100, 200).is_none());
}
fn make_line(words: Vec<(&str, Rect)>) -> OcrLine {
let mut joined = String::new();
let mut spans: Vec<(usize, usize)> = Vec::with_capacity(words.len());
for (i, (text, _)) in words.iter().enumerate() {
if i > 0 {
joined.push(' ');
}
let start = joined.len();
joined.push_str(text);
spans.push((start, joined.len()));
}
let mut min_x = i32::MAX;
let mut min_y = i32::MAX;
let mut max_x = i32::MIN;
let mut max_y = i32::MIN;
for (_, r) in &words {
min_x = min_x.min(r.x);
min_y = min_y.min(r.y);
max_x = max_x.max(r.x + r.width);
max_y = max_y.max(r.y + r.height);
}
let bbox = Rect {
x: min_x,
y: min_y,
width: max_x - min_x,
height: max_y - min_y,
};
let words = words.into_iter().map(|(t, r)| (t.to_string(), r)).collect();
OcrLine {
joined,
bbox,
words,
spans,
}
}
fn default_text_tuning() -> VisualTextTuning {
VisualTextTuning::default()
}
/// Two one-line rows separated by a wide vertical gap (row 1 ends at y=10,
/// row 2 starts at y=40) must stay in separate blocks, in input order.
#[test]
fn unrelated_rows_split_into_separate_blocks() {
    let rows = vec![
        make_line(vec![
            ("Add", rect(0, 0, 30, 10)),
            ("account", rect(40, 0, 60, 10)),
        ]),
        make_line(vec![
            ("Remove", rect(0, 40, 60, 10)),
            ("account", rect(70, 40, 60, 10)),
        ]),
    ];
    let blocks = group_lines_into_blocks(rows, default_text_tuning(), None);
    let texts: Vec<&str> = blocks.iter().map(|b| b.joined.as_str()).collect();
    assert_eq!(texts, ["Add account", "Remove account"]);
}
/// Two closely spaced, left-aligned lines of a wrapped sentence merge into a
/// single block. Its text is space-joined and its bbox encloses both lines.
#[test]
fn wrapped_paragraph_merges_into_one_block() {
    let first = make_line(vec![
        ("Click", rect(0, 0, 50, 10)),
        ("here", rect(60, 0, 40, 10)),
    ]);
    let second = make_line(vec![
        ("to", rect(0, 14, 20, 10)),
        ("learn", rect(30, 14, 50, 10)),
        ("more", rect(90, 14, 40, 10)),
    ]);
    let blocks = group_lines_into_blocks(vec![first, second], default_text_tuning(), None);
    assert_eq!(blocks.len(), 1);
    let merged = &blocks[0];
    assert_eq!(merged.joined, "Click here to learn more");
    // Horizontally 0..130 (the wider second line); vertically 0..24 (top of
    // line 1 to bottom of line 2).
    assert_eq!(merged.bbox, rect(0, 0, 130, 24));
}
/// A substring match that straddles a line break inside a merged block maps
/// back to the union of the matched words across both lines.
#[test]
fn cross_line_match_spans_block() {
    let lines = vec![
        make_line(vec![
            ("Click", rect(0, 0, 50, 10)),
            ("here", rect(60, 0, 40, 10)),
        ]),
        make_line(vec![
            ("to", rect(0, 14, 20, 10)),
            ("learn", rect(30, 14, 50, 10)),
            ("more", rect(90, 14, 40, 10)),
        ]),
    ];
    let blocks = group_lines_into_blocks(lines, default_text_tuning(), None);
    assert_eq!(blocks.len(), 1);
    let paragraph = &blocks[0];

    let hits = find_matches(&paragraph.joined, "here to learn", MatchMode::Substring);
    assert_eq!(hits.len(), 1);
    let (start, end) = hits[0];

    let union = union_bbox_for_match(&paragraph.words, &paragraph.spans, start, end).unwrap();
    // "here" spans x 60..100 on row 1; "to" and "learn" span x 0..80 on row 2.
    assert_eq!(union, rect(0, 0, 100, 24));
}
/// Lines arrive interleaved across two columns (A1, B1, A2, B2). Grouping
/// must pair each line with the line below it in the same column, not with
/// its horizontal neighbour.
#[test]
fn parallel_columns_stay_separate() {
    /// A single-word, 40x10 line at (`x`, `y`).
    fn cell(text: &str, x: i32, y: i32) -> OcrLine {
        make_line(vec![(text, rect(x, y, 40, 10))])
    }

    let interleaved = vec![
        cell("Alpha", 0, 0),
        cell("Beta", 200, 0),
        cell("Apple", 0, 14),
        cell("Berry", 200, 14),
    ];
    let blocks = group_lines_into_blocks(interleaved, default_text_tuning(), None);
    let texts: Vec<&str> = blocks.iter().map(|b| b.joined.as_str()).collect();
    assert_eq!(texts, ["Alpha Apple", "Beta Berry"]);
}
/// Returns a `w`x`h` RGB image in which every pixel is `color`.
///
/// Uses `ImageBuffer::from_pixel`. This replaces a hand-rolled, column-major
/// `put_pixel` loop over a row-major buffer with a single direct fill.
fn solid_image(w: u32, h: u32, color: [u8; 3]) -> image::RgbImage {
    image::RgbImage::from_pixel(w, h, image::Rgb(color))
}
/// A line pair that merges on geometry alone must be split when the
/// background colour changes between the two lines.
#[test]
fn boundary_check_vetoes_merge_on_background_colour_change() {
    // Light grey above row 15, darker grey from row 15 down. "Top" (y 2..8)
    // and "Bot" (y 18..24) sit on different backgrounds.
    let mut img = solid_image(80, 30, [200, 200, 200]);
    for (_, y, px) in img.enumerate_pixels_mut() {
        if y >= 15 {
            *px = image::Rgb([100, 100, 100]);
        }
    }

    let mut tuning = default_text_tuning();
    tuning.multiline_max_gap_factor = 2.0;

    let make_pair = || {
        vec![
            make_line(vec![("Top", rect(0, 2, 60, 6))]),
            make_line(vec![("Bot", rect(0, 18, 60, 6))]),
        ]
    };

    let ctx = BoundaryContext {
        image: &img,
        crop_origin: (0, 0),
    };
    let blocks = group_lines_into_blocks(make_pair(), tuning, Some(ctx));
    assert_eq!(
        blocks.len(),
        2,
        "expected separate blocks on bg-colour change"
    );

    // Control: without pixel context the same geometry must merge. This shows
    // the split above comes from the boundary check, not the gap heuristic.
    let blocks_no_check = group_lines_into_blocks(make_pair(), tuning, None);
    assert_eq!(
        blocks_no_check.len(),
        1,
        "regression: the test setup must geometrically merge without the boundary check"
    );
}
/// A dark horizontal rule drawn in the gap between two lines vetoes a merge
/// that geometry alone would allow.
#[test]
fn boundary_check_vetoes_merge_on_horizontal_divider() {
    let mut img = solid_image(80, 30, [200, 200, 200]);
    let dark = image::Rgb([20, 20, 20]);
    // A 2px rule across the full width at rows 12..14, between "Top" (ends at
    // y=8) and "Bot" (starts at y=18).
    for y in 12..14 {
        for x in 0..img.width() {
            img.put_pixel(x, y, dark);
        }
    }

    let mut tuning = default_text_tuning();
    tuning.multiline_max_gap_factor = 2.0;

    let upper = make_line(vec![("Top", rect(0, 2, 60, 6))]);
    let lower = make_line(vec![("Bot", rect(0, 18, 60, 6))]);
    let blocks = group_lines_into_blocks(
        vec![upper, lower],
        tuning,
        Some(BoundaryContext {
            image: &img,
            crop_origin: (0, 0),
        }),
    );
    assert_eq!(blocks.len(), 2, "horizontal divider should veto merge");
}
/// A dark vertical rule running the full image height, through the two lines'
/// shared horizontal extent, vetoes their merge.
#[test]
fn boundary_check_vetoes_merge_on_vertical_divider() {
    let mut img = solid_image(80, 30, [200, 200, 200]);
    let dark = image::Rgb([20, 20, 20]);
    // A 2px rule at columns 39..41, top to bottom. Both lines span x 0..60.
    for y in 0..img.height() {
        for x in 39..41 {
            img.put_pixel(x, y, dark);
        }
    }

    let mut tuning = default_text_tuning();
    tuning.multiline_max_gap_factor = 2.0;

    let upper = make_line(vec![("Top", rect(0, 2, 60, 6))]);
    let lower = make_line(vec![("Bot", rect(0, 18, 60, 6))]);
    let blocks = group_lines_into_blocks(
        vec![upper, lower],
        tuning,
        Some(BoundaryContext {
            image: &img,
            crop_origin: (0, 0),
        }),
    );
    assert_eq!(blocks.len(), 2, "vertical divider should veto merge");
}
/// When each line is enclosed in its own outlined box, the connectivity check
/// must veto merging them, even though the gap between the boxes is plain
/// background.
#[test]
fn connectivity_check_vetoes_merge_when_lines_are_boxed_separately() {
    /// Draws a 1px dark rectangle outline spanning the full image width, from
    /// row `top` to row `bottom` inclusive.
    fn outline(img: &mut image::RgbImage, top: u32, bottom: u32) {
        let ink = image::Rgb([20, 20, 20]);
        let right = img.width() - 1;
        for x in 0..=right {
            img.put_pixel(x, top, ink);
            img.put_pixel(x, bottom, ink);
        }
        for y in top..=bottom {
            img.put_pixel(0, y, ink);
            img.put_pixel(right, y, ink);
        }
    }

    let mut img = solid_image(80, 40, [200, 200, 200]);
    // Box 1 covers rows 0..=15 and box 2 covers rows 24..=39. Rows 16..=23
    // between them are untouched background.
    outline(&mut img, 0, 15);
    outline(&mut img, 24, 39);

    let mut tuning = default_text_tuning();
    tuning.multiline_max_gap_factor = 2.0;
    // Divider detection is switched off so the connectivity check is the only
    // pixel-based veto exercised here.
    tuning.divider_detection_enabled = false;
    tuning.connectivity_check_enabled = true;

    let inside_first = make_line(vec![("Top", rect(10, 4, 60, 8))]);
    let inside_second = make_line(vec![("Bot", rect(10, 28, 60, 8))]);
    let ctx = BoundaryContext {
        image: &img,
        crop_origin: (0, 0),
    };
    let blocks = group_lines_into_blocks(vec![inside_first, inside_second], tuning, Some(ctx));
    assert_eq!(
        blocks.len(),
        2,
        "connectivity check should veto merge across boxed-in lines"
    );
}
/// On a uniform background nothing separates the two lines, so enabling the
/// connectivity check must not prevent the merge.
#[test]
fn connectivity_check_allows_merge_on_continuous_background() {
    let plain = solid_image(80, 30, [200, 200, 200]);

    let mut tuning = default_text_tuning();
    tuning.connectivity_check_enabled = true;
    tuning.multiline_max_gap_factor = 2.0;

    let upper = make_line(vec![("Top", rect(0, 2, 60, 6))]);
    let lower = make_line(vec![("Bot", rect(0, 18, 60, 6))]);
    let blocks = group_lines_into_blocks(
        vec![upper, lower],
        tuning,
        Some(BoundaryContext {
            image: &plain,
            crop_origin: (0, 0),
        }),
    );
    assert_eq!(
        blocks.len(),
        1,
        "connectivity check should not veto on continuous bg"
    );
}
/// With pixel context but no visual separator, the boundary checks leave a
/// geometrically valid paragraph merged, and the text is joined in order.
#[test]
fn boundary_check_passes_on_clean_paragraph() {
    let plain = solid_image(80, 30, [200, 200, 200]);

    let mut tuning = default_text_tuning();
    tuning.multiline_max_gap_factor = 2.0;

    let upper = make_line(vec![("Top", rect(0, 2, 60, 6))]);
    let lower = make_line(vec![("Bot", rect(0, 18, 60, 6))]);
    let blocks = group_lines_into_blocks(
        vec![upper, lower],
        tuning,
        Some(BoundaryContext {
            image: &plain,
            crop_origin: (0, 0),
        }),
    );
    assert_eq!(blocks.len(), 1, "clean paragraph should still merge");
    let merged = &blocks[0];
    assert_eq!(merged.joined, "Top Bot");
}
}