use core_foundation::array::CFArray;
use core_foundation::base::{CFType, TCFType};
use core_foundation::string::CFString;
use pyo3::prelude::*;
use std::collections::HashMap;
use std::sync::RwLock;
use std::time::{Duration, Instant};
use crate::accessibility::{attributes, get_attribute, AXUIElementRef};
use crate::element::AXElement;
use crate::error::{AXError, AXResult};
/// The individual ways `find_with_healing` can try to locate an element.
///
/// Configured strategy names are mapped to variants by `parse_strategy`, and each variant
/// is dispatched by `try_strategy` to the matching `try_by_*` function.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HealStrategy {
    /// `AXIdentifier` equal to `original_id`, or to `text_hint` when there is no id.
    DataTestId,
    /// `AXLabel` or `AXDescription` exactly equal to `text_hint`.
    AriaLabel,
    /// `AXIdentifier` equal to `original_id`.
    Identifier,
    /// `AXTitle` fuzzily matching `text_hint` (`fuzzy_match` threshold 0.8).
    Title,
    /// Simplified XPath taken from `path`.
    XPath,
    /// Element nearest to `position` (distance < 50).
    /// NOTE(review): relies on `get_bounds`, which currently always returns `None`.
    Position,
    /// Visual matching; not implemented (`try_by_visual` always returns `None`).
    VisualVLM,
}
/// Tuning knobs for self-healing element lookup, exposed to Python as a class.
#[pyclass]
#[derive(Debug, Clone)]
pub struct HealingConfig {
    /// Strategy names, tried in order. Names are matched case-insensitively by
    /// `parse_strategy`; unrecognized names are treated as `"title"`.
    #[pyo3(get, set)]
    pub strategies: Vec<String>,
    /// Time budget in milliseconds that `find_with_healing` checks between strategy attempts.
    #[pyo3(get, set)]
    pub max_heal_time_ms: u64,
    /// When true, successful queries are remembered in the healing cache, keyed by
    /// `ElementQuery::original`.
    #[pyo3(get, set)]
    pub cache_healed: bool,
}

/// Strategy names used when none are configured, in the order they are attempted.
const DEFAULT_STRATEGIES: [&str; 7] = [
    "data_testid",
    "aria_label",
    "identifier",
    "title",
    "xpath",
    "position",
    "visual_vlm",
];

/// Owned copy of [`DEFAULT_STRATEGIES`]. This is the single source for both constructors.
fn default_strategy_names() -> Vec<String> {
    DEFAULT_STRATEGIES.into_iter().map(String::from).collect()
}

#[pymethods]
impl HealingConfig {
    /// Python constructor: `HealingConfig(strategies=None, max_heal_time_ms=100, cache_healed=True)`.
    ///
    /// `strategies=None` selects the default strategy list.
    #[new]
    #[pyo3(signature = (strategies=None, max_heal_time_ms=100, cache_healed=true))]
    fn new(strategies: Option<Vec<String>>, max_heal_time_ms: u64, cache_healed: bool) -> Self {
        Self {
            strategies: strategies.unwrap_or_else(default_strategy_names),
            max_heal_time_ms,
            cache_healed,
        }
    }
}

impl Default for HealingConfig {
    /// Same values as the Python constructor's defaults.
    // Keep `100` / `true` in sync with the `#[pyo3(signature = ...)]` defaults on `new`.
    fn default() -> Self {
        Self {
            strategies: default_strategy_names(),
            max_heal_time_ms: 100,
            cache_healed: true,
        }
    }
}
static GLOBAL_CONFIG: RwLock<Option<HealingConfig>> = RwLock::new(None);
static HEALING_CACHE: std::sync::LazyLock<RwLock<HashMap<String, ElementQuery>>> =
std::sync::LazyLock::new(|| RwLock::new(HashMap::new()));
pub fn set_global_config(config: HealingConfig) -> PyResult<()> {
let mut global = GLOBAL_CONFIG
.write()
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
*global = Some(config);
Ok(())
}
pub fn get_global_config() -> HealingConfig {
GLOBAL_CONFIG
.read()
.ok()
.and_then(|g| g.clone())
.unwrap_or_default()
}
/// Description of an element to locate.
///
/// It carries several independent hints so that `find_with_healing` can fall back across
/// strategies when one of them stops matching. Derives `Default`, so callers can set only the
/// hints they have, for example `ElementQuery { original: key, ..Default::default() }`.
#[derive(Debug, Clone, Default)]
pub struct ElementQuery {
    /// Caller's key for this query. It is also the healing-cache key and the payload of
    /// `AXError::ElementNotFoundAfterHealing`.
    pub original: String,
    /// Expected `AXIdentifier`, used by the `identifier` and `data_testid` strategies.
    pub original_id: Option<String>,
    /// Text hint with three uses:
    /// - exact match against `AXLabel`/`AXDescription` (`aria_label`);
    /// - fuzzy match against `AXTitle` (`title`);
    /// - fallback identifier for `data_testid`.
    pub text_hint: Option<String>,
    /// Simplified XPath consumed by the `xpath` strategy.
    pub path: Option<String>,
    /// Target `(x, y)` point compared with element bounds by the `position` strategy.
    pub position: Option<(f64, f64)>,
    /// Raw screenshot bytes; presumably for `visual_vlm`, though no strategy in this file reads it yet.
    pub screenshot: Option<Vec<u8>>,
    /// Free-text description; presumably for `visual_vlm`, currently unread in this file.
    pub description: Option<String>,
}
/// Locates the element described by `query` under `root`, falling back across the
/// configured healing strategies.
///
/// Uses the current global [`HealingConfig`]:
/// 1. If `cache_healed` is set and a query was cached earlier under `query.original`, the
///    configured strategies are tried first with that cached query.
/// 2. Otherwise, or if that pass finds nothing, the strategies are tried with `query`.
///    On success with `cache_healed` set, `query` is stored in the cache under `query.original`.
///
/// Both passes share one `max_heal_time_ms` budget, checked before every strategy attempt.
/// A strategy that is already running is not interrupted, so total time can exceed the
/// budget by up to one strategy's duration. With a budget of 0, no strategy runs.
///
/// # Errors
/// Returns `AXError::ElementNotFoundAfterHealing(query.original)` when no strategy matched
/// within the budget.
pub fn find_with_healing(query: &ElementQuery, root: AXUIElementRef) -> AXResult<AXElement> {
    let config = get_global_config();
    // Start the clock before the cache pass so both passes count against the same budget.
    let start = Instant::now();
    let timeout = Duration::from_millis(config.max_heal_time_ms);

    // Tries each configured strategy in order against `q`, giving up once the budget is spent.
    let run_strategies = |q: &ElementQuery| -> Option<AXElement> {
        for name in &config.strategies {
            if start.elapsed() >= timeout {
                return None;
            }
            if let Some(element) = try_strategy(parse_strategy(name), q, root) {
                return Some(element);
            }
        }
        None
    };

    if config.cache_healed {
        // Clone the entry out so the read lock is released before the (potentially slow)
        // accessibility tree walks. Holding it would block cache writers for the whole search.
        let cached = HEALING_CACHE
            .read()
            .ok()
            .and_then(|cache| cache.get(&query.original).cloned());
        if let Some(cached_query) = cached {
            if let Some(element) = run_strategies(&cached_query) {
                return Ok(element);
            }
        }
    }

    if let Some(element) = run_strategies(query) {
        if config.cache_healed {
            if let Ok(mut cache) = HEALING_CACHE.write() {
                cache.insert(query.original.clone(), query.clone());
            }
        }
        return Ok(element);
    }

    Err(AXError::ElementNotFoundAfterHealing(query.original.clone()))
}
/// Maps a configured strategy name (case-insensitive) to its [`HealStrategy`].
///
/// Any unrecognized name falls back to [`HealStrategy::Title`].
fn parse_strategy(name: &str) -> HealStrategy {
    const KNOWN: [(&str, HealStrategy); 7] = [
        ("data_testid", HealStrategy::DataTestId),
        ("aria_label", HealStrategy::AriaLabel),
        ("identifier", HealStrategy::Identifier),
        ("title", HealStrategy::Title),
        ("xpath", HealStrategy::XPath),
        ("position", HealStrategy::Position),
        ("visual_vlm", HealStrategy::VisualVLM),
    ];
    let key = name.to_lowercase();
    KNOWN
        .iter()
        .find(|(candidate, _)| *candidate == key)
        .map_or(HealStrategy::Title, |&(_, strategy)| strategy)
}
fn try_strategy(
strategy: HealStrategy,
query: &ElementQuery,
root: AXUIElementRef,
) -> Option<AXElement> {
match strategy {
HealStrategy::DataTestId => try_by_data_testid(query, root),
HealStrategy::AriaLabel => try_by_aria_label(query, root),
HealStrategy::Identifier => try_by_identifier(query, root),
HealStrategy::Title => try_by_title(query, root),
HealStrategy::XPath => try_by_xpath(query, root),
HealStrategy::Position => try_by_position(query, root),
HealStrategy::VisualVLM => try_by_visual(query, root),
}
}
/// Reads attribute `attr` of `element` as a Rust `String`.
///
/// Returns `None` if the lookup fails or the value is not a `CFString`. For example, a
/// boolean- or number-valued attribute yields `None`.
///
/// NOTE(review): `wrap_under_get_rule` adds its own retain, which the `CFType` releases on drop.
/// If `get_attribute` hands back an owned (+1, "Copy rule") reference, as
/// `AXUIElementCopyAttributeValue` does, that reference is never balanced and leaks once per
/// call. `wrap_under_create_rule` would then be correct. TODO confirm against
/// `crate::accessibility::get_attribute`.
fn get_string_attr(element: AXUIElementRef, attr: &str) -> Option<String> {
    get_attribute(element, attr).ok().and_then(|cf_ref| {
        // SAFETY (review): assumes `cf_ref` is a valid, live CF object returned by `get_attribute`.
        let cf_type = unsafe { CFType::wrap_under_get_rule(cf_ref) };
        cf_type.downcast::<CFString>().map(|s| s.to_string())
    })
}
/// Returns the raw `AXChildren` element refs of `element`.
///
/// Returns an empty `Vec` if the attribute lookup fails or the value is not a `CFArray`.
///
/// NOTE(review): the returned pointers are not retained here. The `CFArray` wrapper is
/// dropped at the end of the `.map` closure, so each child ref stays valid only while
/// something else keeps the array or its elements alive. That may be the same unbalanced
/// retain flagged on `get_string_attr`. Confirm the ownership rule of `get_attribute` before
/// relying on these pointers.
fn get_children(element: AXUIElementRef) -> Vec<AXUIElementRef> {
    get_attribute(element, attributes::AX_CHILDREN)
        .ok()
        .and_then(|cf_ref| {
            // SAFETY (review): assumes `cf_ref` is a valid, live CF object returned by `get_attribute`.
            let cf_type = unsafe { CFType::wrap_under_get_rule(cf_ref) };
            cf_type.downcast::<CFArray>()
        })
        .map(|array| {
            (0..array.len())
                .filter_map(|i| {
                    // Unchecked reinterpretation of each array item as an AX element ref.
                    array.get(i).map(|item_ref| *item_ref as AXUIElementRef)
                })
                .collect()
        })
        .unwrap_or_default()
}
/// Intended to return an element's bounds as `(x, y, width, height)`.
///
/// NOTE(review): this is currently a stub. It fetches `AXPosition` and `AXSize`, returning
/// `None` early if either lookup fails, but never decodes them and always returns `None`.
/// As a result `try_by_position` can never match. Decoding presumably needs
/// `AXValueGetValue` (CGPoint / CGSize). If `get_attribute` returns owned references, the two
/// fetched values are also never released. TODO confirm.
fn get_bounds(element: AXUIElementRef) -> Option<(f64, f64, f64, f64)> {
    let _position = get_attribute(element, attributes::AX_POSITION).ok()?;
    let _size = get_attribute(element, attributes::AX_SIZE).ok()?;
    None
}
/// Depth-first pre-order walk of the accessibility tree rooted at `element`.
///
/// `visitor` is called on each node. Returning `true` stops the walk immediately, and the
/// function then returns `true`. `max_depth` counts levels including `element` itself: with
/// `max_depth == 0` nothing is visited. Returns `false` if the walk finished without the
/// visitor asking to stop.
fn walk_tree<F>(element: AXUIElementRef, visitor: &mut F, max_depth: usize) -> bool
where
    F: FnMut(AXUIElementRef) -> bool,
{
    match max_depth {
        0 => false,
        remaining => {
            visitor(element)
                || get_children(element)
                    .into_iter()
                    .any(|child| walk_tree(child, &mut *visitor, remaining - 1))
        }
    }
}
/// Case-insensitive fuzzy comparison of `text` against `pattern`.
///
/// After lowercasing both strings, returns `true` when any of these holds:
/// - the strings are equal;
/// - `pattern` is a non-empty substring of `text`;
/// - `1 - levenshtein / max(chars(text), chars(pattern))` is at least `threshold`.
///
/// An empty `pattern` only matches an empty `text`. Otherwise every string would "contain" it,
/// and the first element visited would be reported as a match.
fn fuzzy_match(text: &str, pattern: &str, threshold: f64) -> bool {
    let text_lower = text.to_lowercase();
    let pattern_lower = pattern.to_lowercase();
    if text_lower == pattern_lower {
        return true;
    }
    if pattern_lower.is_empty() {
        return false;
    }
    if text_lower.contains(&pattern_lower) {
        return true;
    }
    // Normalize by character count, the same unit `levenshtein_distance` counts in.
    // Byte length would overstate the similarity of non-ASCII strings.
    // `longest` is > 0 here because `pattern_lower` is non-empty.
    let longest = text_lower
        .chars()
        .count()
        .max(pattern_lower.chars().count());
    let distance = levenshtein_distance(&text_lower, &pattern_lower);
    let similarity = 1.0 - distance as f64 / longest as f64;
    similarity >= threshold
}

/// Levenshtein edit distance between `s1` and `s2`, counted in `char`s.
///
/// Insertions, deletions and substitutions each cost 1. Uses the standard dynamic-programming
/// recurrence, keeping only the previous and current rows (O(chars(s2)) extra memory).
fn levenshtein_distance(s1: &str, s2: &str) -> usize {
    let s2_chars: Vec<char> = s2.chars().collect();
    // `prev[j]` = distance between the first i chars of s1 and the first j chars of s2.
    let mut prev: Vec<usize> = (0..=s2_chars.len()).collect();
    let mut curr = vec![0; s2_chars.len() + 1];
    for (i, c1) in s1.chars().enumerate() {
        curr[0] = i + 1;
        for (j, &c2) in s2_chars.iter().enumerate() {
            let cost = usize::from(c1 != c2);
            curr[j + 1] = (prev[j + 1] + 1).min(curr[j] + 1).min(prev[j] + cost);
        }
        std::mem::swap(&mut prev, &mut curr);
    }
    prev[s2_chars.len()]
}
/// One step of a simplified XPath-like locator, e.g. `AXButton[@AXTitle='Save']`.
#[derive(Debug)]
struct XPathSegment {
    /// Required value of the element's `AXRole` attribute.
    role: String,
    /// `(attribute, expected value)` pairs. A leading `@` is stripped from attribute names.
    predicates: Vec<(String, String)>,
}

/// Parses a simplified XPath such as `//AXWindow[@AXTitle='Editor']/AXButton[@AXTitle='Save']`.
///
/// Supported syntax:
/// - `/` and `//` are both plain step separators; no axis is recorded, and empty steps are skipped.
/// - Each step is a role optionally followed by one or more `[...]` predicate groups.
/// - Inside a group, comparisons take the form `@attr='value'`, `@attr="value"` or `@attr=value`,
///   joined by ` and `.
///
/// Separators (`/`, ` and `, `[`, `]`) inside quoted values are literal text, so values such as
/// `'Save and Close'` or `'Input/Output'` survive intact. Predicates without `=` (e.g.
/// positional `[1]`) are ignored. A missing final `]` is tolerated.
fn parse_xpath(xpath: &str) -> Vec<XPathSegment> {
    split_outside_quotes(xpath, "/")
        .into_iter()
        .filter(|step| !step.trim().is_empty())
        .map(parse_xpath_segment)
        .collect()
}

/// Parses one step (`Role[pred][pred]...`) into an [`XPathSegment`].
fn parse_xpath_segment(step: &str) -> XPathSegment {
    let (role, rest) = match step.find('[') {
        Some(idx) => step.split_at(idx),
        None => (step, ""),
    };
    let predicates = bracket_groups(rest)
        .into_iter()
        .flat_map(|group| split_outside_quotes(group, " and "))
        .filter_map(|pred| {
            let (attr, val) = pred.split_once('=')?;
            let attr = attr.trim().trim_start_matches('@').to_string();
            Some((attr, unquote(val).to_string()))
        })
        .collect();
    XPathSegment {
        role: role.trim().to_string(),
        predicates,
    }
}

/// Splits `input` on `sep`, ignoring occurrences inside quotes or `[...]` groups.
///
/// `sep` must not start with a quote or bracket character.
fn split_outside_quotes<'a>(input: &'a str, sep: &str) -> Vec<&'a str> {
    let mut parts = Vec::new();
    let mut quote: Option<char> = None;
    let mut depth = 0usize;
    let mut start = 0;
    let mut skip_until = 0;
    for (i, c) in input.char_indices() {
        if i < skip_until {
            // Still inside a separator that was just consumed.
            continue;
        }
        if let Some(q) = quote {
            if c == q {
                quote = None;
            }
            continue;
        }
        match c {
            '\'' | '"' => quote = Some(c),
            '[' => depth += 1,
            ']' => depth = depth.saturating_sub(1),
            _ if depth == 0 && input[i..].starts_with(sep) => {
                parts.push(&input[start..i]);
                start = i + sep.len();
                skip_until = start;
            }
            _ => {}
        }
    }
    parts.push(&input[start..]);
    parts
}

/// Returns the contents of each top-level `[...]` group in `input`, ignoring brackets in quotes.
fn bracket_groups(input: &str) -> Vec<&str> {
    let mut groups = Vec::new();
    let mut quote: Option<char> = None;
    let mut depth = 0usize;
    let mut start = 0;
    for (i, c) in input.char_indices() {
        if let Some(q) = quote {
            if c == q {
                quote = None;
            }
            continue;
        }
        match c {
            '\'' | '"' => quote = Some(c),
            '[' => {
                if depth == 0 {
                    start = i + 1;
                }
                depth += 1;
            }
            ']' if depth > 0 => {
                depth -= 1;
                if depth == 0 {
                    groups.push(&input[start..i]);
                }
            }
            _ => {}
        }
    }
    // Be lenient with a missing closing bracket, as the previous parser was.
    if depth > 0 {
        groups.push(&input[start..]);
    }
    groups
}

/// Trims `val` and removes one matching pair of surrounding `'` or `"` quotes, if present.
fn unquote(val: &str) -> &str {
    let val = val.trim();
    for q in ['\'', '"'] {
        if let Some(inner) = val.strip_prefix(q).and_then(|v| v.strip_suffix(q)) {
            return inner;
        }
    }
    val
}
/// Returns `true` if `element` satisfies one XPath step.
///
/// The element's `AXRole` must equal `segment.role`, and every predicate attribute must read
/// back, as a string, equal to its expected value. Known attribute names are mapped to the
/// `attributes::*` constants; any other name is passed through verbatim. A missing attribute
/// fails the predicate.
///
/// NOTE(review): values are read with `get_string_attr`, which returns `None` for non-`CFString`
/// values. Predicates on such attributes (presumably booleans like `AXEnabled`) can therefore
/// never match.
fn matches_xpath_segment(element: AXUIElementRef, segment: &XPathSegment) -> bool {
    let role_ok = get_string_attr(element, attributes::AX_ROLE)
        .is_some_and(|role| role == segment.role);
    if !role_ok {
        return false;
    }
    segment.predicates.iter().all(|(attr, expected)| {
        let lookup_name = match attr.as_str() {
            "AXTitle" => attributes::AX_TITLE,
            "AXIdentifier" => attributes::AX_IDENTIFIER,
            "AXLabel" => attributes::AX_LABEL,
            "AXDescription" => attributes::AX_DESCRIPTION,
            "AXValue" => attributes::AX_VALUE,
            other => other,
        };
        get_string_attr(element, lookup_name).as_deref() == Some(expected.as_str())
    })
}
/// `data_testid` strategy: the first element (depth-first, depth ≤ 50) whose `AXIdentifier`
/// equals `original_id`, or `text_hint` when there is no id.
///
/// Returns `None` if neither hint is set or nothing matches.
fn try_by_data_testid(query: &ElementQuery, root: AXUIElementRef) -> Option<AXElement> {
    let wanted = match (&query.original_id, &query.text_hint) {
        (Some(id), _) => id,
        (None, Some(hint)) => hint,
        (None, None) => return None,
    };
    let mut hit: Option<AXUIElementRef> = None;
    let mut visit = |candidate: AXUIElementRef| -> bool {
        let is_target = get_string_attr(candidate, attributes::AX_IDENTIFIER)
            .is_some_and(|identifier| identifier == *wanted);
        if is_target {
            hit = Some(candidate);
        }
        is_target
    };
    walk_tree(root, &mut visit, 50);
    hit.map(AXElement::new)
}
/// `aria_label` strategy: the first element (depth-first, depth ≤ 50) whose `AXLabel` or
/// `AXDescription` exactly equals `text_hint`.
///
/// `AXLabel` is checked first; `AXDescription` is read only if the label does not match.
/// Returns `None` without a hint or a match.
fn try_by_aria_label(query: &ElementQuery, root: AXUIElementRef) -> Option<AXElement> {
    let Some(wanted) = query.text_hint.as_deref() else {
        return None;
    };
    let mut hit = None;
    walk_tree(
        root,
        &mut |candidate| {
            let matches = [attributes::AX_LABEL, attributes::AX_DESCRIPTION]
                .into_iter()
                .any(|attr| get_string_attr(candidate, attr).as_deref() == Some(wanted));
            if matches {
                hit = Some(candidate);
            }
            matches
        },
        50,
    );
    hit.map(AXElement::new)
}
/// `identifier` strategy: the first element (depth-first, depth ≤ 50) whose `AXIdentifier`
/// equals `original_id`.
///
/// Returns `None` without an id or a match.
fn try_by_identifier(query: &ElementQuery, root: AXUIElementRef) -> Option<AXElement> {
    let wanted = query.original_id.as_deref()?;
    let mut hit = None;
    walk_tree(
        root,
        &mut |candidate| {
            if get_string_attr(candidate, attributes::AX_IDENTIFIER).as_deref() == Some(wanted) {
                hit = Some(candidate);
                true
            } else {
                false
            }
        },
        50,
    );
    hit.map(AXElement::new)
}
/// `title` strategy: the first element (depth-first, depth ≤ 50) whose `AXTitle`
/// `fuzzy_match`es `text_hint` at threshold 0.8.
///
/// This is the first acceptable match in walk order, not necessarily the closest one.
/// Returns `None` without a hint or a match.
fn try_by_title(query: &ElementQuery, root: AXUIElementRef) -> Option<AXElement> {
    let hint = match &query.text_hint {
        Some(hint) => hint,
        None => return None,
    };
    let mut hit: Option<AXUIElementRef> = None;
    let mut visit = |candidate: AXUIElementRef| -> bool {
        let Some(title) = get_string_attr(candidate, attributes::AX_TITLE) else {
            return false;
        };
        if !fuzzy_match(&title, hint, 0.8) {
            return false;
        }
        hit = Some(candidate);
        true
    };
    walk_tree(root, &mut visit, 50);
    hit.map(AXElement::new)
}
/// `xpath` strategy: resolves `query.path` (parsed by `parse_xpath`) against the tree under `root`.
///
/// Every step is matched as a descendant search: a step may match the current element or any
/// node below the element that matched the previous step. Returns the first element, in
/// depth-first order, that matches the final step. Returns `None` when `query.path` is absent,
/// parses to no steps, or nothing matches.
///
/// The search is bounded to the same depth as the `walk_tree(.., 50)`-based strategies, and
/// each visited node's children are fetched once.
fn try_by_xpath(query: &ElementQuery, root: AXUIElementRef) -> Option<AXElement> {
    /// Depth bound shared with the `walk_tree`-based strategies (`root` counts as one level).
    const MAX_DEPTH: usize = 50;

    /// Searches for an element matching `segments[current_idx..]`, starting at `element`,
    /// with `depth_left` levels (including `element` itself) still allowed.
    fn search_path(
        element: AXUIElementRef,
        segments: &[XPathSegment],
        current_idx: usize,
        depth_left: usize,
    ) -> Option<AXUIElementRef> {
        if depth_left == 0 {
            return None;
        }
        let segment = segments.get(current_idx)?;
        let matched = matches_xpath_segment(element, segment);
        if matched && current_idx + 1 == segments.len() {
            return Some(element);
        }
        // Fetch children once and reuse them for both branches below.
        let children = get_children(element);
        if matched {
            // Continue the path below this element...
            if let Some(found) = children
                .iter()
                .find_map(|&child| search_path(child, segments, current_idx + 1, depth_left - 1))
            {
                return Some(found);
            }
        }
        // ...or look for the current step deeper in the tree.
        children
            .iter()
            .find_map(|&child| search_path(child, segments, current_idx, depth_left - 1))
    }

    let path = query.path.as_ref()?;
    let segments = parse_xpath(path);
    if segments.is_empty() {
        return None;
    }
    search_path(root, &segments, 0, MAX_DEPTH).map(AXElement::new)
}
/// `position` strategy: the element (depth ≤ 50) whose bounds origin is nearest to
/// `query.position`, accepted only if that distance is below 50.
///
/// The entire tree is visited, because the nearest node cannot be known early.
/// NOTE(review): `get_bounds` currently always returns `None`, so this never finds anything.
fn try_by_position(query: &ElementQuery, root: AXUIElementRef) -> Option<AXElement> {
    let Some((target_x, target_y)) = query.position else {
        return None;
    };
    let mut best_dist = f64::MAX;
    let mut best: Option<AXUIElementRef> = None;
    let mut consider = |candidate: AXUIElementRef| -> bool {
        if let Some((x, y, _, _)) = get_bounds(candidate) {
            let dx = x - target_x;
            let dy = y - target_y;
            let dist = (dx.powi(2) + dy.powi(2)).sqrt();
            if dist < best_dist {
                best_dist = dist;
                best = Some(candidate);
            }
        }
        // Never stop early: every node must be measured.
        false
    };
    walk_tree(root, &mut consider, 50);
    best.filter(|_| best_dist < 50.0).map(AXElement::new)
}

/// `visual_vlm` strategy placeholder.
///
/// Not implemented: it ignores its inputs and never finds an element.
fn try_by_visual(_query: &ElementQuery, _root: AXUIElementRef) -> Option<AXElement> {
    Option::None
}
#[cfg(test)]
mod tests {
    // Unit tests for config defaults, strategy-name parsing, fuzzy matching and XPath parsing.
    // None of them touch a live accessibility tree; `test_healing_cache_isolation` does write
    // to the process-global HEALING_CACHE.
    use super::*;
    #[test]
    fn test_default_config() {
        let config = HealingConfig::default();
        assert_eq!(config.strategies.len(), 7);
        assert_eq!(config.max_heal_time_ms, 100);
        assert!(config.cache_healed);
    }
    #[test]
    fn test_parse_strategy() {
        assert_eq!(parse_strategy("data_testid"), HealStrategy::DataTestId);
        assert_eq!(parse_strategy("aria_label"), HealStrategy::AriaLabel);
        assert_eq!(parse_strategy("identifier"), HealStrategy::Identifier);
        assert_eq!(parse_strategy("title"), HealStrategy::Title);
        assert_eq!(parse_strategy("xpath"), HealStrategy::XPath);
        assert_eq!(parse_strategy("position"), HealStrategy::Position);
        assert_eq!(parse_strategy("visual_vlm"), HealStrategy::VisualVLM);
    }
    #[test]
    fn test_levenshtein_distance() {
        assert_eq!(levenshtein_distance("", ""), 0);
        assert_eq!(levenshtein_distance("abc", "abc"), 0);
        assert_eq!(levenshtein_distance("abc", ""), 3);
        assert_eq!(levenshtein_distance("", "abc"), 3);
        assert_eq!(levenshtein_distance("kitten", "sitting"), 3);
        assert_eq!(levenshtein_distance("saturday", "sunday"), 3);
    }
    #[test]
    fn test_fuzzy_match_exact() {
        assert!(fuzzy_match("Save", "Save", 0.8));
        assert!(fuzzy_match("save", "SAVE", 0.8));
    }
    #[test]
    fn test_fuzzy_match_contains() {
        assert!(fuzzy_match("Save Button", "Save", 0.8));
        assert!(fuzzy_match("Click to Save", "Save", 0.8));
    }
    #[test]
    fn test_fuzzy_match_similar() {
        // One edit each: 1 - 1/6 ≈ 0.83 and 1 - 1/5 = 0.8, both at or above the threshold.
        assert!(fuzzy_match("Button", "Buton", 0.8));
        assert!(fuzzy_match("Click", "Clik", 0.8));
    }
    #[test]
    fn test_fuzzy_match_no_match() {
        assert!(!fuzzy_match("Save", "Cancel", 0.8));
        assert!(!fuzzy_match("Button", "Window", 0.8));
    }
    #[test]
    fn test_parse_xpath_simple() {
        let segments = parse_xpath("//AXWindow/AXButton");
        assert_eq!(segments.len(), 2);
        assert_eq!(segments[0].role, "AXWindow");
        assert_eq!(segments[0].predicates.len(), 0);
        assert_eq!(segments[1].role, "AXButton");
        assert_eq!(segments[1].predicates.len(), 0);
    }
    #[test]
    fn test_parse_xpath_with_predicates() {
        let segments = parse_xpath("//AXWindow/AXButton[@AXTitle='Save']");
        assert_eq!(segments.len(), 2);
        assert_eq!(segments[1].role, "AXButton");
        assert_eq!(segments[1].predicates.len(), 1);
        assert_eq!(segments[1].predicates[0].0, "AXTitle");
        assert_eq!(segments[1].predicates[0].1, "Save");
    }
    #[test]
    fn test_parse_xpath_multiple_predicates() {
        let segments = parse_xpath("//AXButton[@AXTitle='Save' and @AXEnabled='true']");
        assert_eq!(segments.len(), 1);
        assert_eq!(segments[0].role, "AXButton");
        assert_eq!(segments[0].predicates.len(), 2);
        assert_eq!(segments[0].predicates[0].0, "AXTitle");
        assert_eq!(segments[0].predicates[0].1, "Save");
        assert_eq!(segments[0].predicates[1].0, "AXEnabled");
        assert_eq!(segments[0].predicates[1].1, "true");
    }
    #[test]
    fn test_parse_xpath_double_quotes() {
        let segments = parse_xpath(r#"//AXButton[@AXTitle="Save"]"#);
        assert_eq!(segments.len(), 1);
        assert_eq!(segments[0].role, "AXButton");
        assert_eq!(segments[0].predicates.len(), 1);
        assert_eq!(segments[0].predicates[0].1, "Save");
    }
    #[test]
    fn test_parse_xpath_complex() {
        let segments =
            parse_xpath("//AXWindow[@AXTitle='Editor']/AXGroup/AXButton[@AXTitle='Save']");
        assert_eq!(segments.len(), 3);
        assert_eq!(segments[0].role, "AXWindow");
        assert_eq!(segments[0].predicates[0].0, "AXTitle");
        assert_eq!(segments[0].predicates[0].1, "Editor");
        assert_eq!(segments[1].role, "AXGroup");
        assert_eq!(segments[2].role, "AXButton");
        assert_eq!(segments[2].predicates[0].1, "Save");
    }
    // NOTE(review): only checks that struct fields round-trip; no strategy is exercised.
    #[test]
    fn test_element_query_builder() {
        let query = ElementQuery {
            original: "button_save".to_string(),
            original_id: Some("save_btn".to_string()),
            text_hint: Some("Save".to_string()),
            path: None,
            position: None,
            screenshot: None,
            description: None,
        };
        assert_eq!(query.original, "button_save");
        assert_eq!(query.original_id, Some("save_btn".to_string()));
        assert_eq!(query.text_hint, Some("Save".to_string()));
    }
    #[test]
    fn test_healing_config_custom() {
        let config = HealingConfig::new(
            Some(vec!["identifier".to_string(), "title".to_string()]),
            200,
            false,
        );
        assert_eq!(config.strategies.len(), 2);
        assert_eq!(config.strategies[0], "identifier");
        assert_eq!(config.strategies[1], "title");
        assert_eq!(config.max_heal_time_ms, 200);
        assert!(!config.cache_healed);
    }
    // Writes to the process-global HEALING_CACHE using keys unique to this test. If either lock
    // acquisition failed (poisoned), the `if let Ok` blocks skip silently and the asserts do not run.
    #[test]
    fn test_healing_cache_isolation() {
        let query1 = ElementQuery {
            original: "query1".to_string(),
            original_id: Some("id1".to_string()),
            text_hint: None,
            path: None,
            position: None,
            screenshot: None,
            description: None,
        };
        let query2 = ElementQuery {
            original: "query2".to_string(),
            original_id: Some("id2".to_string()),
            text_hint: None,
            path: None,
            position: None,
            screenshot: None,
            description: None,
        };
        if let Ok(mut cache) = HEALING_CACHE.write() {
            cache.insert(query1.original.clone(), query1.clone());
            cache.insert(query2.original.clone(), query2.clone());
        }
        if let Ok(cache) = HEALING_CACHE.read() {
            assert!(cache.contains_key("query1"));
            assert!(cache.contains_key("query2"));
        }
    }
    #[test]
    fn test_xpath_segment_attribute_mapping() {
        let segments = parse_xpath("//AXButton[@AXTitle='Save' and @AXIdentifier='btn1']");
        assert_eq!(segments[0].predicates[0].0, "AXTitle");
        assert_eq!(segments[0].predicates[1].0, "AXIdentifier");
    }
    // NOTE(review): despite the name, this does not call `try_by_position` or compute any
    // distance; it only checks that the `position` field round-trips.
    #[test]
    fn test_position_distance_calculation() {
        let query = ElementQuery {
            original: "element_at_pos".to_string(),
            original_id: None,
            text_hint: None,
            path: None,
            position: Some((100.0, 200.0)),
            screenshot: None,
            description: None,
        };
        assert!(query.position.is_some());
        let (x, y) = query.position.unwrap();
        assert_eq!(x, 100.0);
        assert_eq!(y, 200.0);
    }
    // NOTE(review): only checks the query's own fields; `try_by_visual` is not called.
    #[test]
    fn test_visual_strategy_requires_data() {
        let query_no_data = ElementQuery {
            original: "visual_element".to_string(),
            original_id: None,
            text_hint: None,
            path: None,
            position: None,
            screenshot: None,
            description: None,
        };
        assert!(query_no_data.screenshot.is_none());
        assert!(query_no_data.description.is_none());
    }
    #[test]
    fn test_strategy_order_matters() {
        let config = HealingConfig::default();
        assert_eq!(config.strategies[0], "data_testid");
        assert_eq!(config.strategies[1], "aria_label");
        assert_eq!(config.strategies[2], "identifier");
        assert_eq!(config.strategies[3], "title");
        assert_eq!(config.strategies[4], "xpath");
        assert_eq!(config.strategies[5], "position");
        assert_eq!(config.strategies[6], "visual_vlm");
    }
    #[test]
    fn test_levenshtein_unicode() {
        // Distances are counted in chars, so one accented letter is a single edit.
        assert_eq!(levenshtein_distance("café", "cafe"), 1);
        assert_eq!(levenshtein_distance("hello", "héllo"), 1);
    }
    #[test]
    fn test_xpath_empty_path() {
        let segments = parse_xpath("");
        assert_eq!(segments.len(), 0);
        let segments = parse_xpath("//");
        assert_eq!(segments.len(), 0);
    }
}