use std::cmp::Ordering;
use crate::intent::{SceneGraph, SceneNode};
use crate::intent_matching::{score_node, tokenise, MatchContext};
/// Stateless entry point for ranking scene-graph nodes against a
/// natural-language [`FindQuery`]; see [`SemanticFinder::find`].
#[derive(Default, Debug)]
pub struct SemanticFinder;
/// A natural-language request describing the element to locate.
///
/// Build one with [`FindQuery::new`] and optionally refine it with
/// [`FindQuery::with_context`]; description and context are matched together.
#[derive(Clone, Debug)]
pub struct FindQuery {
    /// The main description, e.g. `"submit button"`.
    pub description: String,
    /// Extra free text (e.g. `"in the dialog"`) appended to the description
    /// when matching; `None` when no context was supplied.
    pub context: Option<String>,
}

impl FindQuery {
    /// Creates a query from `description` with no additional context.
    #[must_use]
    pub fn new(description: impl Into<String>) -> Self {
        let description: String = description.into();
        Self {
            context: None,
            description,
        }
    }

    /// Returns this query with `context` attached, replacing any earlier context.
    #[must_use]
    pub fn with_context(self, context: impl Into<String>) -> Self {
        Self {
            context: Some(context.into()),
            ..self
        }
    }

    /// The text handed to the matcher: the description, followed by a single
    /// space and the context when one is present.
    fn full_text(&self) -> String {
        self.context.as_deref().map_or_else(
            || self.description.clone(),
            |ctx| format!("{} {ctx}", self.description),
        )
    }
}
/// Output of [`SemanticFinder::find`]: the ranked matches, best first.
#[derive(Default, Debug)]
pub struct FindResult {
    /// Matches ordered by descending score; empty when nothing qualified.
    pub matches: Vec<ElementMatch>,
}
/// A single scene node that matched a [`FindQuery`], with its score.
#[derive(Clone, Debug)]
pub struct ElementMatch {
    /// Accessibility role of the node, or `"AXUnknown"` when it had none.
    pub role: String,
    /// First text label of the node, or `"<no label>"` when it had none.
    pub label: String,
    /// Final score in `[0.0, 1.0]`; higher is a better match.
    pub score: f64,
    /// Human-readable explanation produced while scoring the node.
    pub reasoning: String,
    /// The node's bounds, copied verbatim. The size hint reads the last two
    /// components as width and height; presumably the layout is
    /// `(x, y, width, height)` — confirm against `SceneNode`.
    pub bounds: Option<(f64, f64, f64, f64)>,
}
impl SemanticFinder {
    /// Finds the scene nodes that best match `query`.
    ///
    /// The query's description and optional context are combined and each
    /// node is scored against them with `score_node`. The score is then
    /// adjusted by any position/size hints in the text (e.g. "top", "large")
    /// and clamped to `[0.0, 1.0]`. Nodes below `MIN_SCORE` are dropped. The
    /// rest are sorted by descending score and truncated to `MAX_RESULTS`.
    #[must_use]
    pub fn find(&self, scene: &SceneGraph, query: &FindQuery) -> FindResult {
        let full = query.full_text();
        let ctx = MatchContext::from_query(&full);
        let hints = QueryHints::parse(&full);
        let mut scored: Vec<(f64, String, &SceneNode)> = scene
            .iter()
            .filter_map(|node| self.score_candidate(node, &ctx, &hints, scene))
            .collect();
        // `sort_by` is stable, so equally scored nodes keep scene iteration
        // order. Scores reaching this point are never NaN (NaN fails the
        // `>= MIN_SCORE` check), so the `Equal` fallback is purely defensive.
        scored.sort_by(|(a, _, _), (b, _, _)| b.partial_cmp(a).unwrap_or(Ordering::Equal));
        scored.truncate(MAX_RESULTS);
        FindResult {
            matches: scored
                .into_iter()
                .map(|(score, reasoning, node)| build_match(node, score, reasoning))
                .collect(),
        }
    }

    /// Scores one node and returns `None` when it should not be reported.
    ///
    /// Hint bonuses only amplify existing relevance. A node whose base score
    /// from `score_node` is not positive is rejected before any bonus is
    /// applied. Without this gate, `POSITION_BONUS` and `SIZE_BONUS` would
    /// each clear `MIN_SCORE` on their own (currently 0.08 and 0.05 vs 0.05).
    /// A query such as "the large button at the top" would then return every
    /// unrelated element that merely sits in the hinted region or has the
    /// hinted size.
    fn score_candidate<'a>(
        &self,
        node: &'a SceneNode,
        ctx: &MatchContext,
        hints: &QueryHints,
        scene: &SceneGraph,
    ) -> Option<(f64, String, &'a SceneNode)> {
        let (base, reason) = score_node(node, ctx, scene);
        if base <= 0.0 {
            return None;
        }
        let adjusted = base + hints.position_bonus(node) + hints.size_bonus(node);
        let clamped = adjusted.clamp(0.0_f64, 1.0_f64);
        // A NaN `clamped` fails this comparison and is filtered out here.
        (clamped >= MIN_SCORE).then_some((clamped, reason, node))
    }
}
/// Upper bound on the number of matches `SemanticFinder::find` returns.
const MAX_RESULTS: usize = 20;
/// Smallest final (clamped) score a node needs to be reported at all.
const MIN_SCORE: f64 = 0.05;
/// Added to a node's score when its centre satisfies the query's position hint.
const POSITION_BONUS: f64 = 0.08;
/// Added to a node's score when its area satisfies the query's size hint.
const SIZE_BONUS: f64 = 0.05;
/// Position and size cues extracted from a query's text.
///
/// Each field holds the first matching keyword found in the query. `None`
/// means the query carried no such cue, and the corresponding bonus is zero.
#[derive(Default, Debug)]
struct QueryHints {
    /// Coarse position cue (top/bottom/left/right), if the query stated one.
    position: Option<PositionHint>,
    /// Coarse size cue (large/small), if the query stated one.
    size: Option<SizeHint>,
}
/// Where the query says the target element sits.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum PositionHint {
    /// Parsed from the tokens `top` or `upper`.
    Top,
    /// Parsed from the tokens `bottom` or `lower`.
    Bottom,
    /// Parsed from the token `left`.
    Left,
    /// Parsed from the token `right`.
    Right,
}
/// How big the query says the target element is.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum SizeHint {
    /// Parsed from the tokens `large`, `big` or `wide`.
    Large,
    /// Parsed from the tokens `small`, `tiny` or `mini`.
    Small,
}
impl QueryHints {
    /// `Top` is satisfied when the node centre's y-coordinate is below this.
    const TOP_MAX_Y: f64 = 300.0;
    /// `Bottom` is satisfied when the node centre's y-coordinate exceeds this.
    const BOTTOM_MIN_Y: f64 = 600.0;
    /// `Left` is satisfied when the node centre's x-coordinate is below this.
    const LEFT_MAX_X: f64 = 400.0;
    /// `Right` is satisfied when the node centre's x-coordinate exceeds this.
    const RIGHT_MIN_X: f64 = 800.0;
    /// `Large` is satisfied when width × height exceeds this.
    const LARGE_MIN_AREA: f64 = 4_000.0;
    /// `Small` is satisfied when width × height is below this.
    const SMALL_MAX_AREA: f64 = 600.0;

    /// Tokenises `query` and extracts its position and size cues.
    fn parse(query: &str) -> Self {
        let tokens = tokenise(query);
        let position = Self::parse_position(&tokens);
        let size = Self::parse_size(&tokens);
        Self { position, size }
    }

    /// Returns the hint for the first position keyword in `tokens`, if any.
    fn parse_position(tokens: &[String]) -> Option<PositionHint> {
        for token in tokens {
            let hint = match token.as_str() {
                "top" | "upper" => PositionHint::Top,
                "bottom" | "lower" => PositionHint::Bottom,
                "left" => PositionHint::Left,
                "right" => PositionHint::Right,
                _ => continue,
            };
            return Some(hint);
        }
        None
    }

    /// Returns the hint for the first size keyword in `tokens`, if any.
    fn parse_size(tokens: &[String]) -> Option<SizeHint> {
        for token in tokens {
            let hint = match token.as_str() {
                "large" | "big" | "wide" => SizeHint::Large,
                "small" | "tiny" | "mini" => SizeHint::Small,
                _ => continue,
            };
            return Some(hint);
        }
        None
    }

    /// `POSITION_BONUS` when the node's centre agrees with the position hint,
    /// otherwise `0.0` (also when there is no hint or no centre).
    fn position_bonus(&self, node: &SceneNode) -> f64 {
        let hint = match self.position {
            Some(hint) => hint,
            None => return 0.0,
        };
        let (cx, cy) = match node.center() {
            Some(centre) => centre,
            None => return 0.0,
        };
        let satisfied = match hint {
            PositionHint::Top => cy < Self::TOP_MAX_Y,
            PositionHint::Bottom => cy > Self::BOTTOM_MIN_Y,
            PositionHint::Left => cx < Self::LEFT_MAX_X,
            PositionHint::Right => cx > Self::RIGHT_MIN_X,
        };
        Self::bonus_when(satisfied, POSITION_BONUS)
    }

    /// `SIZE_BONUS` when the node's area agrees with the size hint, otherwise
    /// `0.0` (also when there is no hint or no bounds).
    fn size_bonus(&self, node: &SceneNode) -> f64 {
        let hint = match self.size {
            Some(hint) => hint,
            None => return 0.0,
        };
        let (width, height) = match node.bounds {
            Some((_, _, width, height)) => (width, height),
            None => return 0.0,
        };
        let area = width * height;
        let satisfied = match hint {
            SizeHint::Large => area > Self::LARGE_MIN_AREA,
            SizeHint::Small => area < Self::SMALL_MAX_AREA,
        };
        Self::bonus_when(satisfied, SIZE_BONUS)
    }

    /// Returns `bonus` if `satisfied`, otherwise `0.0`.
    fn bonus_when(satisfied: bool, bonus: f64) -> f64 {
        if satisfied {
            bonus
        } else {
            0.0
        }
    }
}
/// Converts a scored node into the public [`ElementMatch`] form.
///
/// The first entry of `node.text_labels()` becomes the label, falling back to
/// `"<no label>"`. A missing role is reported as `"AXUnknown"`. Bounds are
/// copied unchanged.
fn build_match(node: &SceneNode, score: f64, reasoning: String) -> ElementMatch {
    let labels = node.text_labels();
    let label = match labels.first() {
        Some(first) => (*first).to_string(),
        None => String::from("<no label>"),
    };
    let role = node.role.as_deref().unwrap_or("AXUnknown").to_owned();
    ElementMatch {
        role,
        label,
        score,
        reasoning,
        bounds: node.bounds,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::intent::{build_scene_from_nodes, NodeId, SceneNode};

    /// Builds an enabled `AXButton` node titled `title` with the given bounds.
    fn button(id: usize, title: &str, bounds: (f64, f64, f64, f64)) -> SceneNode {
        SceneNode {
            id: NodeId(id),
            parent: None,
            children: vec![],
            role: Some("AXButton".into()),
            title: Some(title.into()),
            label: None,
            value: None,
            description: None,
            identifier: None,
            bounds: Some(bounds),
            enabled: true,
            depth: 1,
        }
    }

    /// Builds an enabled `AXTextField` node labelled `label` at fixed bounds.
    fn field(id: usize, label: &str) -> SceneNode {
        SceneNode {
            id: NodeId(id),
            parent: None,
            children: vec![],
            role: Some("AXTextField".into()),
            title: None,
            label: Some(label.into()),
            value: None,
            description: None,
            identifier: None,
            bounds: Some((0.0, 500.0, 200.0, 25.0)),
            enabled: true,
            depth: 1,
        }
    }

    #[test]
    fn find_query_new_sets_description_clears_context() {
        let q = FindQuery::new("submit button");
        assert_eq!(q.description, "submit button");
        assert!(q.context.is_none());
    }

    #[test]
    fn find_query_with_context_appended_to_full_text() {
        let q = FindQuery::new("close button").with_context("in the dialog");
        let text = q.full_text();
        assert!(text.contains("close button"));
        assert!(text.contains("in the dialog"));
    }

    #[test]
    fn find_query_full_text_joins_with_single_space() {
        let q = FindQuery::new("close button").with_context("in the dialog");
        assert_eq!(q.full_text(), "close button in the dialog");
    }

    #[test]
    fn find_query_full_text_without_context_equals_description() {
        let q = FindQuery::new("search bar");
        assert_eq!(q.full_text(), "search bar");
    }

    #[test]
    fn query_hints_parses_top_position() {
        let h = QueryHints::parse("the button at the top");
        assert_eq!(h.position, Some(PositionHint::Top));
    }

    #[test]
    fn query_hints_parses_large_size() {
        let h = QueryHints::parse("click the large ok button");
        assert_eq!(h.size, Some(SizeHint::Large));
    }

    #[test]
    fn query_hints_parses_small_size_synonym() {
        let h = QueryHints::parse("the tiny icon");
        assert_eq!(h.size, Some(SizeHint::Small));
    }

    #[test]
    fn query_hints_no_hints_when_absent() {
        let h = QueryHints::parse("submit form");
        assert!(h.position.is_none());
        assert!(h.size.is_none());
    }

    #[test]
    fn query_hints_default_gives_no_bonus() {
        let node = button(0, "Close", (10.0, 5.0, 80.0, 30.0));
        let hints = QueryHints::default();
        assert_eq!(hints.position_bonus(&node), 0.0);
        assert_eq!(hints.size_bonus(&node), 0.0);
    }

    #[test]
    fn query_hints_position_bonus_matches_top_node() {
        let node = button(0, "Close", (10.0, 5.0, 80.0, 30.0));
        let hints = QueryHints {
            position: Some(PositionHint::Top),
            size: None,
        };
        assert!(hints.position_bonus(&node) > 0.0);
    }

    #[test]
    fn query_hints_position_bonus_no_match_returns_zero() {
        let node = button(0, "Footer", (10.0, 685.0, 200.0, 30.0));
        let hints = QueryHints {
            position: Some(PositionHint::Top),
            size: None,
        };
        assert_eq!(hints.position_bonus(&node), 0.0);
    }

    #[test]
    fn query_hints_size_bonus_large_element_matches() {
        let node = button(0, "Banner", (0.0, 0.0, 100.0, 80.0));
        let hints = QueryHints {
            position: None,
            size: Some(SizeHint::Large),
        };
        assert!(hints.size_bonus(&node) > 0.0);
    }

    #[test]
    fn query_hints_size_bonus_small_element_matches() {
        // 20 x 20 = 400, comfortably under the small-area threshold.
        let node = button(0, "Icon", (0.0, 0.0, 20.0, 20.0));
        let hints = QueryHints {
            position: None,
            size: Some(SizeHint::Small),
        };
        assert_eq!(hints.size_bonus(&node), SIZE_BONUS);
    }

    #[test]
    fn find_empty_scene_returns_empty_result() {
        let finder = SemanticFinder;
        let scene = SceneGraph::empty();
        let query = FindQuery::new("submit button");
        let result = finder.find(&scene, &query);
        assert!(result.matches.is_empty());
    }

    #[test]
    fn find_returns_at_most_twenty_matches() {
        let nodes: Vec<SceneNode> = (0_u32..25)
            .map(|i| button(i as usize, "Submit", (0.0, f64::from(i) * 40.0, 100.0, 30.0)))
            .collect();
        let scene = build_scene_from_nodes(nodes);
        let finder = SemanticFinder;
        let query = FindQuery::new("submit");
        let result = finder.find(&scene, &query);
        assert!(result.matches.len() <= MAX_RESULTS);
    }

    #[test]
    fn find_ranks_exact_title_match_first() {
        let scene = build_scene_from_nodes(vec![
            button(0, "Search", (0.0, 0.0, 100.0, 30.0)),
            button(1, "Cancel", (0.0, 40.0, 100.0, 30.0)),
        ]);
        let finder = SemanticFinder;
        let query = FindQuery::new("search");
        let result = finder.find(&scene, &query);
        assert!(!result.matches.is_empty());
        assert_eq!(result.matches[0].label, "Search");
    }

    #[test]
    fn find_results_sorted_descending_by_score() {
        let scene = build_scene_from_nodes(vec![
            button(0, "Submit", (0.0, 0.0, 100.0, 30.0)),
            button(1, "Cancel", (0.0, 40.0, 100.0, 30.0)),
            button(2, "Submit Form", (0.0, 80.0, 100.0, 30.0)),
        ]);
        let finder = SemanticFinder;
        let query = FindQuery::new("submit button");
        let result = finder.find(&scene, &query);
        for pair in result.matches.windows(2) {
            assert!(pair[0].score >= pair[1].score);
        }
    }

    #[test]
    fn find_all_scores_within_unit_interval() {
        let scene = build_scene_from_nodes(vec![
            button(0, "OK", (0.0, 0.0, 60.0, 25.0)),
            field(1, "Email address"),
            button(2, "Cancel", (0.0, 40.0, 60.0, 25.0)),
        ]);
        let finder = SemanticFinder;
        let query = FindQuery::new("email input field");
        let result = finder.find(&scene, &query);
        for m in &result.matches {
            assert!(
                (0.0..=1.0).contains(&m.score),
                "score {} out of range",
                m.score
            );
        }
    }

    #[test]
    fn find_with_role_hint_prefers_text_field_over_button() {
        let scene = build_scene_from_nodes(vec![
            button(0, "Email", (0.0, 0.0, 100.0, 30.0)),
            field(1, "Email"),
        ]);
        let finder = SemanticFinder;
        let query = FindQuery::new("email input field");
        let result = finder.find(&scene, &query);
        assert!(!result.matches.is_empty());
        assert_eq!(result.matches[0].role, "AXTextField");
    }

    #[test]
    fn find_with_position_hint_boosts_top_button() {
        let top_bounds = (10.0, 5.0, 80.0, 30.0);
        let bottom_bounds = (10.0, 700.0, 80.0, 30.0);
        let scene = build_scene_from_nodes(vec![
            button(0, "Close", top_bounds),
            button(1, "Close", bottom_bounds),
        ]);
        let finder = SemanticFinder;
        let query = FindQuery::new("close button at the top");
        let result = finder.find(&scene, &query);
        assert!(result.matches.len() >= 2);
        let rank_of = |bounds: (f64, f64, f64, f64)| {
            result
                .matches
                .iter()
                .position(|m| m.bounds == Some(bounds))
        };
        // Both buttons must be present; previously a missing one silently
        // skipped the ranking assertion.
        let top_idx = rank_of(top_bounds).expect("top button should be matched");
        let bot_idx = rank_of(bottom_bounds).expect("bottom button should be matched");
        assert!(
            top_idx < bot_idx,
            "top button should rank ahead of bottom button"
        );
    }

    #[test]
    fn find_reasoning_non_empty_for_every_match() {
        let scene = build_scene_from_nodes(vec![button(0, "Save", (0.0, 0.0, 80.0, 28.0))]);
        let finder = SemanticFinder;
        let query = FindQuery::new("save");
        let result = finder.find(&scene, &query);
        for m in &result.matches {
            assert!(!m.reasoning.is_empty(), "reasoning must not be blank");
        }
    }
}