use std::collections::HashMap;
use std::future::Future;
use std::sync::Arc;
use std::time::{Duration, Instant};
use atspi::connection::AccessibilityConnection;
use crate::atspi as atspi_client;
use crate::atspi::ElementInfo;
use crate::error::{Error, Result};
use crate::session::Session;
/// Pause used after the first polling attempt; the polling loop doubles it
/// after each subsequent retry.
const INITIAL_POLL_DELAY: Duration = Duration::from_millis(50);
/// Ceiling for the pause between two consecutive polling attempts.
const MAX_POLL_DELAY: Duration = Duration::from_millis(500);
/// A lazily-resolved handle to zero or more elements of an application's
/// accessibility tree.
///
/// A `Locator` stores only an XPath expression, an optional wait budget and
/// the session to evaluate against; nothing is looked up until one of its
/// async methods is awaited. Cloning bumps the session refcount and copies
/// the XPath string.
#[derive(Clone)]
pub struct Locator {
    /// XPath expression evaluated against a snapshot of the application tree.
    xpath: String,
    /// Per-locator wait budget; `None` falls back to the session default.
    timeout: Option<Duration>,
    /// Session providing the AT-SPI connection and the target application.
    session: Arc<Session>,
}
impl Locator {
pub(crate) fn new(session: Arc<Session>, xpath: String) -> Self {
Self {
session,
xpath,
timeout: None,
}
}
pub fn xpath(&self) -> &str {
&self.xpath
}
pub fn with_timeout(&self, timeout: Duration) -> Locator {
Locator {
session: self.session.clone(),
xpath: self.xpath.clone(),
timeout: Some(timeout),
}
}
pub fn locate(&self, sub: &str) -> Locator {
let trimmed = sub.trim();
let new_xpath = if trimmed.starts_with('/') {
trimmed.to_string()
} else {
format!("({})//{}", self.xpath, trimmed)
};
self.with_xpath(new_xpath)
}
pub fn nth(&self, n: usize) -> Locator {
self.with_xpath(format!("({})[{}]", self.xpath, n + 1))
}
pub fn first(&self) -> Locator {
self.nth(0)
}
pub fn last(&self) -> Locator {
self.with_xpath(format!("({})[last()]", self.xpath))
}
pub fn parent(&self) -> Locator {
self.with_xpath(format!("({})/..", self.xpath))
}
fn with_xpath(&self, xpath: String) -> Locator {
Locator {
session: self.session.clone(),
xpath,
timeout: self.timeout,
}
}
pub async fn count(&self) -> Result<usize> {
Ok(self.resolve_all_once().await?.len())
}
pub async fn all(&self) -> Result<Vec<Locator>> {
let n = self.count().await?;
Ok((0..n).map(|i| self.nth(i)).collect())
}
pub async fn inspect_all(&self) -> Result<Vec<ElementInfo>> {
let a11y = self.a11y()?;
let xml =
atspi_client::snapshot_tree(a11y, &self.session.app_bus_name, &self.session.app_path)
.await?;
atspi_client::evaluate_xpath_detailed(&xml, &self.xpath)
}
pub async fn name(&self) -> Result<Option<String>> {
Ok(self.wait_for_existing().await?.name)
}
pub async fn role(&self) -> Result<String> {
let info = self.wait_for_existing().await?;
Ok(info.role_raw.unwrap_or(info.role))
}
pub async fn attribute(&self, key: &str) -> Result<Option<String>> {
Ok(self.wait_for_existing().await?.attributes.remove(key))
}
pub async fn attributes(&self) -> Result<HashMap<String, String>> {
Ok(self.wait_for_existing().await?.attributes)
}
pub async fn is_showing(&self) -> Result<bool> {
self.has_state("showing").await
}
pub async fn is_enabled(&self) -> Result<bool> {
let info = self.wait_for_existing().await?;
Ok(is_enabled_in(&info.states))
}
pub async fn text(&self) -> Result<String> {
let info = self.wait_for_existing().await?;
let a11y = self.a11y()?;
let (bus, path) = info.ref_;
atspi_client::read_text_on(a11y, &self.xpath, &bus, &path).await
}
pub async fn click(&self) -> Result<()> {
let info = self.wait_for_actionable().await?;
let (bus, path) = info.ref_;
let a11y = self.a11y()?;
atspi_client::do_action_on(a11y, &self.xpath, &bus, &path).await
}
pub async fn set_text(&self, text: &str) -> Result<()> {
let info = self.wait_for_actionable().await?;
let (bus, path) = info.ref_;
let a11y = self.a11y()?;
atspi_client::set_text_on(a11y, &self.xpath, &bus, &path, text).await
}
pub async fn focus(&self) -> Result<()> {
let info = self.wait_for_focusable().await?;
let (bus, path) = info.ref_;
let a11y = self.a11y()?;
atspi_client::grab_focus_on(a11y, &self.xpath, &bus, &path).await
}
pub async fn wait_for_visible(&self) -> Result<()> {
let xpath = self.xpath.clone();
poll_with_retry(self.effective_timeout(), &xpath, || async {
let info = self.resolve_once_info().await?;
if info.states.iter().any(|s| s == "showing") {
Ok(Some(()))
} else {
Ok(None)
}
})
.await
}
pub async fn wait_for_hidden(&self) -> Result<()> {
let xpath = self.xpath.clone();
poll_with_retry(self.effective_timeout(), &xpath, || async {
match self.resolve_once_info().await {
Ok(info) => {
if info.states.iter().any(|s| s == "showing") {
Ok(None) } else {
Ok(Some(()))
}
}
Err(Error::ElementNotFound { .. }) => Ok(Some(())), Err(e) => Err(e),
}
})
.await
}
pub async fn wait_for_enabled(&self) -> Result<()> {
let xpath = self.xpath.clone();
poll_with_retry(self.effective_timeout(), &xpath, || async {
let info = self.resolve_once_info().await?;
if is_enabled_in(&info.states) {
Ok(Some(()))
} else {
Ok(None)
}
})
.await
}
pub async fn wait_for_count(&self, n: usize) -> Result<()> {
let xpath = self.xpath.clone();
poll_with_retry(self.effective_timeout(), &xpath, || async {
let hits = self.resolve_all_once().await?;
if hits.len() == n {
Ok(Some(()))
} else {
Ok(None)
}
})
.await
}
pub async fn wait_for_text<F>(&self, pred: F) -> Result<String>
where
F: Fn(&str) -> bool,
{
let xpath = self.xpath.clone();
poll_with_retry(self.effective_timeout(), &xpath, || async {
let info = self.resolve_once_info().await?;
let a11y = self.a11y()?;
let (bus, path) = info.ref_;
let text = atspi_client::read_text_on(a11y, &self.xpath, &bus, &path).await?;
if pred(&text) {
Ok(Some(text))
} else {
Ok(None)
}
})
.await
}
async fn has_state(&self, state: &str) -> Result<bool> {
Ok(self
.wait_for_existing()
.await?
.states
.iter()
.any(|s| s == state))
}
fn a11y(&self) -> Result<&AccessibilityConnection> {
self.session
.a11y_connection
.as_ref()
.ok_or_else(|| Error::Atspi("session has no AT-SPI connection".into()))
}
fn effective_timeout(&self) -> Duration {
self.timeout
.unwrap_or_else(|| self.session.default_timeout())
}
async fn snapshot(&self) -> Result<String> {
let a11y = self.a11y()?;
atspi_client::snapshot_tree(a11y, &self.session.app_bus_name, &self.session.app_path).await
}
async fn resolve_all_once(&self) -> Result<Vec<(String, String)>> {
let xml = self.snapshot().await?;
atspi_client::evaluate_xpath(&xml, &self.xpath)
}
async fn resolve_once_info(&self) -> Result<ElementInfo> {
let xml = self.snapshot().await?;
let mut hits = atspi_client::evaluate_xpath_detailed(&xml, &self.xpath)?;
select_exactly_one(&self.xpath, hits.len())?;
Ok(hits.pop().unwrap())
}
async fn wait_for_existing(&self) -> Result<ElementInfo> {
let xpath = self.xpath.clone();
poll_with_retry(self.effective_timeout(), &xpath, || async {
Ok(Some(self.resolve_once_info().await?))
})
.await
}
async fn wait_for_actionable(&self) -> Result<ElementInfo> {
let xpath = self.xpath.clone();
poll_with_retry(self.effective_timeout(), &xpath, || async {
let info = self.resolve_once_info().await?;
let showing = info.states.iter().any(|s| s == "showing");
if showing && is_enabled_in(&info.states) {
Ok(Some(info))
} else {
Ok(None)
}
})
.await
}
async fn wait_for_focusable(&self) -> Result<ElementInfo> {
let xpath = self.xpath.clone();
poll_with_retry(self.effective_timeout(), &xpath, || async {
let info = self.resolve_once_info().await?;
let showing = info.states.iter().any(|s| s == "showing");
let focusable = info.states.iter().any(|s| s == "focusable");
if showing && focusable {
Ok(Some(info))
} else {
Ok(None)
}
})
.await
}
}
/// Reports whether an element's state list marks it as interactable.
///
/// Either the `enabled` or the `sensitive` state is sufficient.
fn is_enabled_in(states: &[String]) -> bool {
    states
        .iter()
        .map(String::as_str)
        .any(|state| matches!(state, "enabled" | "sensitive"))
}
fn select_exactly_one(xpath: &str, count: usize) -> Result<()> {
match count {
0 => Err(Error::ElementNotFound {
xpath: xpath.to_string(),
}),
1 => Ok(()),
n => Err(Error::AmbiguousSelector {
xpath: xpath.to_string(),
count: n,
}),
}
}
/// Repeatedly awaits `f` until it yields a value, fails hard, or the
/// `timeout` budget is spent.
///
/// Each call to `f` is one attempt:
/// - `Ok(Some(v))` stops polling and returns `v`;
/// - `Ok(None)` means "not yet" and polling continues;
/// - a retriable error (see [`is_retriable`]) is remembered and polling
///   continues;
/// - any other error is returned immediately.
///
/// Between attempts the pause starts at [`INITIAL_POLL_DELAY`] and doubles up
/// to [`MAX_POLL_DELAY`]. It is clamped to the remaining budget, so the final
/// attempt happens at the deadline rather than up to `MAX_POLL_DELAY` after it.
/// `f` is always called at least once, even with a zero `timeout`. A budget
/// too large to add to the current `Instant` (e.g. `Duration::MAX`) means
/// "no deadline".
///
/// # Errors
///
/// On timeout, returns the error from the final attempt if it was retriable,
/// otherwise an [`Error::Timeout`] naming `xpath`, the attempt count and the
/// budget. Non-retriable errors are returned as soon as they occur.
pub(crate) async fn poll_with_retry<T, F, Fut>(
    timeout: Duration,
    xpath: &str,
    mut f: F,
) -> Result<T>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<Option<T>>>,
{
    // `None` = overflow; treat as unbounded instead of panicking on `+`.
    let deadline = Instant::now().checked_add(timeout);
    let mut delay = INITIAL_POLL_DELAY;
    let mut attempts: u32 = 0;
    loop {
        attempts += 1;
        // Only the latest attempt decides what a timeout reports: a pending
        // `Ok(None)` clears any retriable error seen earlier.
        let last_err = match f().await {
            Ok(Some(v)) => return Ok(v),
            Ok(None) => None,
            Err(e) if is_retriable(&e) => Some(e),
            Err(e) => return Err(e),
        };
        let remaining = match deadline {
            Some(d) => d.saturating_duration_since(Instant::now()),
            None => Duration::MAX,
        };
        if remaining.is_zero() {
            return Err(last_err.unwrap_or_else(|| {
                Error::Timeout(format!(
                    "wait for '{xpath}' timed out after {attempts} attempt(s) \
                     ({}ms budget)",
                    timeout.as_millis()
                ))
            }));
        }
        // Never sleep past the deadline: the next attempt is then the last one.
        tokio::time::sleep(delay.min(remaining)).await;
        delay = (delay * 2).min(MAX_POLL_DELAY);
    }
}
fn is_retriable(e: &Error) -> bool {
matches!(
e,
Error::ElementNotFound { .. } | Error::ElementStale { .. }
)
}
#[cfg(test)]
mod tests {
    //! Unit tests for locator XPath composition, single-match selection,
    //! enabled-state detection and the polling/retry helper.
    //!
    //! Sessions are built from stub backends without an AT-SPI connection, so
    //! no live accessibility bus is required.

    use std::path::{Path, PathBuf};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    use async_trait::async_trait;

    use super::{is_enabled_in, is_retriable, poll_with_retry, select_exactly_one, Error};
    use crate::backend::{CaptureBackend, CompositorRuntime, InputBackend, PipeWireStream};
    use crate::error::Result as WdResult;
    use crate::session::Session;

    // --- Pure XPath composition mirrors -------------------------------------
    //
    // These mirror the string formatting in `Locator::{locate, nth, parent}`
    // without needing a session; the `locator_*` tests further down exercise
    // the real methods.

    fn compose_locate(outer: &str, sub: &str) -> String {
        let trimmed = sub.trim();
        if trimmed.starts_with('/') {
            trimmed.to_string()
        } else {
            format!("({outer})//{trimmed}")
        }
    }

    fn compose_nth(outer: &str, n: usize) -> String {
        format!("({outer})[{}]", n + 1)
    }

    fn compose_parent(outer: &str) -> String {
        format!("({outer})/..")
    }

    #[test]
    fn locate_relative_scopes() {
        assert_eq!(
            compose_locate("//Dialog[@name='X']", "PushButton"),
            "(//Dialog[@name='X'])//PushButton"
        );
    }

    #[test]
    fn locate_absolute_replaces() {
        assert_eq!(compose_locate("//Dialog", "//Menu"), "//Menu");
    }

    #[test]
    fn nth_is_one_indexed_in_xpath() {
        assert_eq!(compose_nth("//PushButton", 0), "(//PushButton)[1]");
        assert_eq!(compose_nth("//PushButton", 4), "(//PushButton)[5]");
    }

    #[test]
    fn parent_appends_dot_dot() {
        assert_eq!(
            compose_parent("//PushButton[@name='OK']"),
            "(//PushButton[@name='OK'])/.."
        );
    }

    // --- Single-match selection ---------------------------------------------

    #[test]
    fn select_exactly_one_zero_is_not_found() {
        let err = select_exactly_one("//Missing", 0).unwrap_err();
        assert!(matches!(err, Error::ElementNotFound { .. }));
        assert!(err.to_string().contains("//Missing"));
    }

    #[test]
    fn select_exactly_one_one_is_ok() {
        assert!(select_exactly_one("//PushButton[@name='OK']", 1).is_ok());
    }

    #[test]
    fn select_exactly_one_many_is_ambiguous_with_count() {
        let err = select_exactly_one("//PushButton", 7).unwrap_err();
        match err {
            Error::AmbiguousSelector { count, xpath } => {
                assert_eq!(count, 7);
                assert_eq!(xpath, "//PushButton");
            }
            other => panic!("expected AmbiguousSelector, got {other:?}"),
        }
    }

    // --- Enabled-state detection ----------------------------------------------

    #[test]
    fn is_enabled_in_accepts_enabled_or_sensitive() {
        let states = |xs: &[&str]| xs.iter().map(|s| s.to_string()).collect::<Vec<String>>();
        assert!(is_enabled_in(&states(&["showing", "enabled"])));
        assert!(is_enabled_in(&states(&["sensitive"])));
        assert!(!is_enabled_in(&states(&["showing", "focusable"])));
        assert!(!is_enabled_in(&[]));
    }

    // --- Stub backends for building a test session --------------------------

    struct StubCompositor;

    #[async_trait]
    impl CompositorRuntime for StubCompositor {
        async fn start(&mut self, _resolution: Option<&str>) -> WdResult<()> {
            Ok(())
        }
        async fn stop(&mut self) -> WdResult<()> {
            Ok(())
        }
        fn id(&self) -> &str {
            "stub"
        }
        fn wayland_display(&self) -> &str {
            "wayland-stub"
        }
        fn runtime_dir(&self) -> &Path {
            Path::new("/tmp")
        }
    }

    struct StubInput;

    #[async_trait]
    impl InputBackend for StubInput {
        async fn press_keysym(&self, _keysym: u32) -> WdResult<()> {
            Ok(())
        }
        async fn key_down(&self, _keysym: u32) -> WdResult<()> {
            Ok(())
        }
        async fn key_up(&self, _keysym: u32) -> WdResult<()> {
            Ok(())
        }
        async fn pointer_motion_relative(&self, _dx: f64, _dy: f64) -> WdResult<()> {
            Ok(())
        }
        async fn pointer_button(&self, _button: u32) -> WdResult<()> {
            Ok(())
        }
    }

    struct StubCapture;

    #[async_trait]
    impl CaptureBackend for StubCapture {
        async fn start_stream(&self) -> WdResult<PipeWireStream> {
            unimplemented!("not used in composition tests")
        }
        async fn stop_stream(&self, _stream: PipeWireStream) -> WdResult<()> {
            Ok(())
        }
        fn pipewire_socket(&self) -> PathBuf {
            PathBuf::from("/tmp/stub")
        }
    }

    /// Builds a session backed by no-op stubs and no AT-SPI connection.
    fn test_session() -> Arc<Session> {
        Arc::new(Session::new_for_test(
            "stub".into(),
            "app".into(),
            Box::new(StubInput),
            Box::new(StubCapture),
            Box::new(StubCompositor),
        ))
    }

    // --- Session / Locator composition --------------------------------------

    #[tokio::test]
    async fn session_locate_carries_xpath_verbatim() {
        let s = test_session();
        let loc = s.locate("//PushButton[@name='OK']");
        assert_eq!(loc.xpath(), "//PushButton[@name='OK']");
    }

    #[tokio::test]
    async fn session_root_locator_uses_wildcard() {
        let s = test_session();
        assert_eq!(s.root().xpath(), "/*");
    }

    #[tokio::test]
    async fn session_find_by_id_composes_xpath() {
        let s = test_session();
        assert_eq!(s.find_by_id("submit").xpath(), "//*[@id='submit']");
    }

    #[tokio::test]
    async fn session_find_by_name_composes_xpath() {
        let s = test_session();
        assert_eq!(s.find_by_name("OK").xpath(), "//*[@name='OK']");
    }

    #[tokio::test]
    async fn session_find_by_role_name_composes_xpath() {
        let s = test_session();
        assert_eq!(
            s.find_by_role_name("PushButton", "OK").xpath(),
            "//PushButton[@name='OK']"
        );
    }

    #[tokio::test]
    async fn locator_locate_appends_descendant_when_relative() {
        let s = test_session();
        let dialog = s.locate("//Dialog[@name='Confirm']");
        let inner = dialog.locate("PushButton");
        assert_eq!(inner.xpath(), "(//Dialog[@name='Confirm'])//PushButton");
    }

    #[tokio::test]
    async fn locator_locate_absolute_replaces_scope() {
        let s = test_session();
        let dialog = s.locate("//Dialog");
        assert_eq!(dialog.locate("//Menu").xpath(), "//Menu");
    }

    #[tokio::test]
    async fn locator_nth_wraps_with_one_indexed_predicate() {
        let s = test_session();
        let loc = s.locate("//PushButton").nth(2);
        assert_eq!(loc.xpath(), "(//PushButton)[3]");
    }

    #[tokio::test]
    async fn locator_first_is_nth_zero() {
        let s = test_session();
        let loc = s.locate("//PushButton").first();
        assert_eq!(loc.xpath(), "(//PushButton)[1]");
    }

    #[tokio::test]
    async fn locator_last_uses_last_function() {
        let s = test_session();
        let loc = s.locate("//PushButton").last();
        assert_eq!(loc.xpath(), "(//PushButton)[last()]");
    }

    #[tokio::test]
    async fn locator_parent_appends_dot_dot() {
        let s = test_session();
        let loc = s.locate("//PushButton[@name='OK']").parent();
        assert_eq!(loc.xpath(), "(//PushButton[@name='OK'])/..");
    }

    #[tokio::test]
    async fn locator_composition_chains() {
        let s = test_session();
        let loc = s
            .locate("//Dialog[@name='Confirm']")
            .locate("PushButton")
            .nth(1);
        assert_eq!(loc.xpath(), "((//Dialog[@name='Confirm'])//PushButton)[2]");
    }

    #[tokio::test]
    async fn locator_clone_preserves_xpath() {
        let s = test_session();
        let loc = s.locate("//PushButton");
        let cloned = loc.clone();
        assert_eq!(cloned.xpath(), "//PushButton");
    }

    // --- Behaviour without an AT-SPI connection -----------------------------

    #[tokio::test]
    async fn locator_click_on_session_without_a11y_errors_cleanly() {
        let s = test_session();
        let err = s.locate("//PushButton").click().await.unwrap_err();
        assert!(matches!(err, Error::Atspi(_)));
        assert!(err.to_string().contains("no AT-SPI connection"));
    }

    #[tokio::test]
    async fn session_dump_tree_without_a11y_errors_cleanly() {
        let s = test_session();
        let err = s.dump_tree().await.unwrap_err();
        assert!(matches!(err, Error::Atspi(_)));
        assert!(err.to_string().contains("no AT-SPI connection"));
    }

    #[tokio::test]
    async fn with_timeout_overrides_session_default() {
        let s = test_session();
        let base = s.locate("//PushButton");
        let quick = base.with_timeout(Duration::from_millis(100));
        assert_eq!(quick.xpath(), base.xpath());
        // The override must change the effective wait budget...
        assert_eq!(quick.effective_timeout(), Duration::from_millis(100));
        // ...leave the original locator on the session default...
        assert_eq!(base.effective_timeout(), s.default_timeout());
        // ...and survive further composition.
        assert_eq!(
            quick.locate("Label").nth(0).effective_timeout(),
            Duration::from_millis(100)
        );
    }

    // --- poll_with_retry ----------------------------------------------------

    #[tokio::test]
    async fn poll_returns_value_on_first_try() {
        let result: Result<i32, Error> =
            poll_with_retry(Duration::from_secs(5), "x", || async { Ok(Some(42)) }).await;
        assert_eq!(result.unwrap(), 42);
    }

    #[tokio::test]
    async fn poll_succeeds_after_retries() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let attempts_cloned = attempts.clone();
        let result: Result<&'static str, Error> =
            poll_with_retry(Duration::from_secs(5), "x", move || {
                let a = attempts_cloned.clone();
                async move {
                    let n = a.fetch_add(1, Ordering::SeqCst);
                    if n < 2 {
                        Err(Error::ElementNotFound { xpath: "x".into() })
                    } else {
                        Ok(Some("found"))
                    }
                }
            })
            .await;
        assert_eq!(result.unwrap(), "found");
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn poll_surfaces_last_retriable_error_on_timeout() {
        let result: Result<&'static str, Error> =
            poll_with_retry(Duration::from_millis(50), "//Missing", || async {
                Err::<Option<&'static str>, _>(Error::ElementNotFound {
                    xpath: "//Missing".into(),
                })
            })
            .await;
        let err = result.unwrap_err();
        assert!(
            matches!(err, Error::ElementNotFound { .. }),
            "expected ElementNotFound, got {err}"
        );
    }

    #[tokio::test]
    async fn poll_returns_timeout_when_predicate_keeps_saying_none() {
        let result: Result<i32, Error> =
            poll_with_retry(Duration::from_millis(50), "//Pending", || async {
                Ok::<Option<i32>, Error>(None)
            })
            .await;
        let err = result.unwrap_err();
        match err {
            Error::Timeout(msg) => assert!(
                msg.contains("//Pending"),
                "timeout message should include the xpath: {msg}"
            ),
            other => panic!("expected Timeout, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn poll_bails_immediately_on_non_retriable_error() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let attempts_cloned = attempts.clone();
        let result: Result<&'static str, Error> =
            poll_with_retry(Duration::from_secs(5), "//Bad", move || {
                let a = attempts_cloned.clone();
                async move {
                    a.fetch_add(1, Ordering::SeqCst);
                    Err(Error::InvalidSelector {
                        xpath: "//Bad".into(),
                        reason: "oops".into(),
                    })
                }
            })
            .await;
        let err = result.unwrap_err();
        assert!(matches!(err, Error::InvalidSelector { .. }));
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn poll_ambiguous_selector_is_not_retriable() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let attempts_cloned = attempts.clone();
        let result: Result<&'static str, Error> =
            poll_with_retry(Duration::from_secs(5), "//PushButton", move || {
                let a = attempts_cloned.clone();
                async move {
                    a.fetch_add(1, Ordering::SeqCst);
                    Err(Error::AmbiguousSelector {
                        xpath: "//PushButton".into(),
                        count: 3,
                    })
                }
            })
            .await;
        assert!(matches!(
            result.unwrap_err(),
            Error::AmbiguousSelector { count: 3, .. }
        ));
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn poll_zero_timeout_is_single_shot() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let attempts_cloned = attempts.clone();
        let start = std::time::Instant::now();
        let _: Result<i32, Error> = poll_with_retry(Duration::ZERO, "//X", move || {
            let a = attempts_cloned.clone();
            async move {
                a.fetch_add(1, Ordering::SeqCst);
                Err(Error::ElementNotFound {
                    xpath: "//X".into(),
                })
            }
        })
        .await;
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
        assert!(
            start.elapsed() < Duration::from_millis(100),
            "zero-timeout poll should not sleep, took {:?}",
            start.elapsed()
        );
    }

    #[test]
    fn is_retriable_matches_expected_errors() {
        assert!(is_retriable(&Error::ElementNotFound { xpath: "x".into() }));
        assert!(is_retriable(&Error::ElementStale {
            xpath: "x".into(),
            bus: "b".into(),
            path: "/p".into(),
        }));
        assert!(!is_retriable(&Error::AmbiguousSelector {
            xpath: "x".into(),
            count: 2,
        }));
        assert!(!is_retriable(&Error::InvalidSelector {
            xpath: "x".into(),
            reason: "r".into(),
        }));
        assert!(!is_retriable(&Error::Atspi("boom".into())));
        assert!(!is_retriable(&Error::Timeout("nope".into())));
    }
}