accessibility/
lib.rs

1pub mod action;
2pub mod attribute;
3pub mod ui_element;
4mod util;
5
6use accessibility_sys::{error_string, AXError};
7use core_foundation::{
8    array::CFArray,
9    base::CFTypeID,
10    base::{CFCopyTypeIDDescription, TCFType},
11    string::CFString,
12};
13use std::{
14    cell::{Cell, RefCell},
15    thread,
16    time::{Duration, Instant},
17};
18use thiserror::Error as TError;
19
20pub use action::*;
21pub use attribute::*;
22pub use ui_element::*;
23
24#[non_exhaustive]
25#[derive(Debug, TError)]
26pub enum Error {
27    #[error("element not found")]
28    NotFound,
29    #[error(
30        "expected attribute type {} but got {}",
31        type_name(*expected),
32        type_name(*received),
33    )]
34    UnexpectedType {
35        expected: CFTypeID,
36        received: CFTypeID,
37    },
38    #[error("accessibility error {}", error_string(*.0))]
39    Ax(AXError),
40}
41
42fn type_name(type_id: CFTypeID) -> CFString {
43    unsafe { CFString::wrap_under_create_rule(CFCopyTypeIDDescription(type_id)) }
44}
45
46pub trait TreeVisitor {
47    fn enter_element(&self, element: &AXUIElement) -> TreeWalkerFlow;
48    fn exit_element(&self, element: &AXUIElement);
49}
50
51pub struct TreeWalker {
52    attr_children: AXAttribute<CFArray<AXUIElement>>,
53}
54
/// Tells a [`TreeWalker`] how to proceed after a [`TreeVisitor`] has entered an element.
///
/// `Debug` is derived so flows can be logged and compared with `assert_eq!`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum TreeWalkerFlow {
    /// Visit the element's children, then carry on with its siblings.
    Continue,
    /// Do not visit the element's children, but carry on with its siblings.
    SkipSubtree,
    /// Stop the whole walk: no further elements are entered (the walker still calls
    /// `exit_element` on the elements it is currently inside).
    Exit,
}
61
62impl Default for TreeWalker {
63    fn default() -> Self {
64        Self {
65            attr_children: AXAttribute::children(),
66        }
67    }
68}
69
70impl TreeWalker {
71    pub fn new() -> Self {
72        Self::default()
73    }
74
75    pub fn walk(&self, root: &AXUIElement, visitor: &dyn TreeVisitor) {
76        let _ = self.walk_one(root, visitor);
77    }
78
79    fn walk_one(&self, root: &AXUIElement, visitor: &dyn TreeVisitor) -> TreeWalkerFlow {
80        let mut flow = visitor.enter_element(root);
81
82        if flow == TreeWalkerFlow::Continue {
83            if let Ok(children) = root.attribute(&self.attr_children) {
84                for child in children.into_iter() {
85                    let child_flow = self.walk_one(&child, visitor);
86
87                    if child_flow == TreeWalkerFlow::Exit {
88                        flow = child_flow;
89                        break;
90                    }
91                }
92            }
93        }
94
95        visitor.exit_element(root);
96        flow
97    }
98}
99
100pub struct ElementFinder {
101    root: AXUIElement,
102    implicit_wait: Option<Duration>,
103    predicate: Box<dyn Fn(&AXUIElement) -> bool>,
104    depth: Cell<usize>,
105    cached: RefCell<Option<AXUIElement>>,
106}
107
108impl ElementFinder {
109    pub fn new<F>(root: &AXUIElement, predicate: F, implicit_wait: Option<Duration>) -> Self
110    where
111        F: 'static + Fn(&AXUIElement) -> bool,
112    {
113        Self {
114            root: root.clone(),
115            predicate: Box::new(predicate),
116            implicit_wait,
117            depth: Cell::new(0),
118            cached: RefCell::new(None),
119        }
120    }
121
122    pub fn find(&self) -> Result<AXUIElement, Error> {
123        if let Some(result) = &*self.cached.borrow() {
124            return Ok(result.clone());
125        }
126
127        let mut deadline = Instant::now();
128        let walker = TreeWalker::new();
129
130        if let Some(implicit_wait) = &self.implicit_wait {
131            deadline += *implicit_wait;
132        }
133
134        loop {
135            if let Some(result) = &*self.cached.borrow() {
136                return Ok(result.clone());
137            }
138
139            walker.walk(&self.root, self);
140            let now = Instant::now();
141
142            if now >= deadline {
143                return Err(Error::NotFound);
144            } else {
145                let time_left = deadline.saturating_duration_since(now);
146                thread::sleep(std::cmp::min(time_left, Duration::from_millis(250)));
147            }
148        }
149    }
150
151    pub fn reset(&self) {
152        self.cached.replace(None);
153    }
154
155    pub fn attribute<T: TCFType>(&self, attribute: &AXAttribute<T>) -> Result<T, Error> {
156        self.find()?.attribute(attribute)
157    }
158
159    pub fn set_attribute<T: TCFType>(
160        &self,
161        attribute: &AXAttribute<T>,
162        value: impl Into<T>,
163    ) -> Result<(), Error> {
164        self.find()?.set_attribute(attribute, value)
165    }
166
167    pub fn perform_action(&self, name: &CFString) -> Result<(), Error> {
168        self.find()?.perform_action(name)
169    }
170}
171
172const MAX_DEPTH: usize = 100;
173
174impl TreeVisitor for ElementFinder {
175    fn enter_element(&self, element: &AXUIElement) -> TreeWalkerFlow {
176        self.depth.set(self.depth.get() + 1);
177
178        if (self.predicate)(element) {
179            self.cached.replace(Some(element.clone()));
180            return TreeWalkerFlow::Exit;
181        }
182
183        if self.depth.get() > MAX_DEPTH {
184            TreeWalkerFlow::SkipSubtree
185        } else {
186            TreeWalkerFlow::Continue
187        }
188    }
189
190    fn exit_element(&self, _element: &AXUIElement) {
191        self.depth.set(self.depth.get() - 1)
192    }
193}