#![allow(clippy::useless_conversion)]
use pyo3::prelude::*;
use std::process::Command;
use std::sync::Arc;
use std::time::{Duration, Instant};
use crate::accessibility::{
self, attributes, create_application_element, get_attribute, AXUIElementRef,
};
use crate::element::AXElement;
use crate::error::{AXError, AXResult};
use crate::sync::SyncEngine;
/// Handle to a running application, addressed through its accessibility element.
///
/// Exposed to Python through `#[pyclass]`.
#[pyclass]
pub struct AXApp {
/// Process id the application element was created for.
pub(crate) pid: i32,
/// Bundle identifier supplied by the caller of `connect`, if any.
pub(crate) bundle_id: Option<String>,
/// Application name supplied by the caller of `connect`, if any.
pub(crate) name: Option<String>,
/// Application-level accessibility element; released when the `AXApp` is dropped.
pub(crate) element: AXUIElementRef,
/// Shared sync engine built from `pid` and `element`; backs `wait_for_idle` and `is_idle`.
sync_engine: Arc<SyncEngine>,
}
impl std::fmt::Debug for AXApp {
    /// Renders the identifying fields of the app; the sync engine is summarised
    /// by its current mode rather than printed in full.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("AXApp");
        builder.field("pid", &self.pid);
        builder.field("bundle_id", &self.bundle_id);
        builder.field("name", &self.name);
        builder.field("element", &self.element);
        builder.field("sync_mode", &self.sync_engine.mode());
        builder.finish()
    }
}
// SAFETY: NOTE(review) — `AXApp` holds a raw `AXUIElementRef`, and these impls
// assert that it may be moved to and shared between threads. Nothing in this
// file demonstrates why that is sound; confirm against the accessibility API's
// threading guarantees.
unsafe impl Send for AXApp {}
unsafe impl Sync for AXApp {}
#[pymethods]
impl AXApp {
#[getter]
fn pid(&self) -> i32 {
self.pid
}
#[getter]
fn bundle_id(&self) -> Option<String> {
self.bundle_id.clone()
}
fn is_running(&self) -> bool {
std::fs::metadata(format!("/proc/{}", self.pid)).is_ok()
|| Command::new("kill")
.args(["-0", &self.pid.to_string()])
.output()
.map(|o| o.status.success())
.unwrap_or(false)
}
#[pyo3(signature = (query, timeout_ms=None))]
fn find(&self, query: &str, timeout_ms: Option<u64>) -> PyResult<AXElement> {
let timeout = timeout_ms.map(Duration::from_millis);
self.find_element(query, timeout)
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
}
#[pyo3(signature = (role, title=None, identifier=None, label=None))]
fn find_by_role(
&self,
role: &str,
title: Option<&str>,
identifier: Option<&str>,
label: Option<&str>,
) -> PyResult<AXElement> {
self.find_element_by_role(role, title, identifier, label)
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
}
#[pyo3(signature = (query, timeout_ms=5000))]
fn wait_for_element(&self, query: &str, timeout_ms: u64) -> PyResult<AXElement> {
self.find_element(query, Some(Duration::from_millis(timeout_ms)))
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
}
#[pyo3(signature = (timeout_ms=5000))]
fn wait_for_idle(&self, timeout_ms: u64) -> bool {
self.sync_engine
.wait_for_idle(Duration::from_millis(timeout_ms))
}
fn is_idle(&self) -> bool {
self.sync_engine.is_idle()
}
fn screenshot(&self) -> PyResult<Vec<u8>> {
self.capture_screenshot()
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
}
fn windows(&self) -> PyResult<Vec<AXElement>> {
self.get_windows()
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
}
fn main_window(&self) -> PyResult<AXElement> {
self.get_main_window()
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
}
fn terminate(&self) -> PyResult<()> {
Command::new("kill")
.arg(self.pid.to_string())
.output()
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
Ok(())
}
}
impl AXApp {
/// Connects to a running application identified by `pid`, `bundle_id` or `name`.
///
/// The identifiers are tried in that order and the first one present wins.
/// `bundle_id` and `name` are stored as given, regardless of which identifier
/// was used for the lookup.
///
/// # Errors
///
/// Returns a `PyValueError` when no identifier is supplied or when `pid` does
/// not fit in an `i32`, and a `PyRuntimeError` when the process lookup or the
/// creation of the application element fails.
pub fn connect(
    name: Option<&str>,
    bundle_id: Option<&str>,
    pid: Option<u32>,
) -> PyResult<Self> {
    use std::convert::TryFrom;

    let resolved_pid = if let Some(p) = pid {
        // A plain `as` cast would silently wrap large values to a negative pid.
        i32::try_from(p).map_err(|_| {
            pyo3::exceptions::PyValueError::new_err(format!("pid out of range: {p}"))
        })?
    } else if let Some(bid) = bundle_id {
        Self::pid_from_bundle_id(bid)?
    } else if let Some(n) = name {
        Self::pid_from_name(n)?
    } else {
        return Err(pyo3::exceptions::PyValueError::new_err(
            "Must provide name, bundle_id, or pid",
        ));
    };

    let element = create_application_element(resolved_pid)
        .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
    let sync_engine = Arc::new(SyncEngine::new(resolved_pid, element));

    Ok(Self {
        pid: resolved_pid,
        bundle_id: bundle_id.map(String::from),
        name: name.map(String::from),
        element,
        sync_engine,
    })
}
/// Resolves a bundle identifier to a process id by asking System Events via `osascript`.
///
/// When several processes share the bundle identifier, the first pid reported
/// is used.
///
/// # Errors
///
/// Returns a `PyRuntimeError` when `osascript` cannot be run, when no process
/// matches, or when the reported pid cannot be parsed.
fn pid_from_bundle_id(bundle_id: &str) -> PyResult<i32> {
    // Escape backslashes and double quotes so the identifier cannot terminate
    // the AppleScript string literal and inject script code.
    let escaped = bundle_id.replace('\\', "\\\\").replace('"', "\\\"");
    let script = format!(
        "tell application \"System Events\" to unix id of (processes whose bundle identifier is \"{escaped}\")"
    );
    let output = Command::new("osascript")
        .args(["-e", &script])
        .output()
        .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;

    let stdout = String::from_utf8_lossy(&output.stdout);
    // The `whose` filter yields a list, which osascript prints as "123, 456"
    // when more than one process matches; take the first entry.
    let pid_str = stdout.split(',').next().unwrap_or("").trim();
    if pid_str.is_empty() || pid_str == "missing value" {
        return Err(pyo3::exceptions::PyRuntimeError::new_err(format!(
            "Application not found: {bundle_id}"
        )));
    }

    pid_str
        .parse::<i32>()
        .map_err(|_| pyo3::exceptions::PyRuntimeError::new_err("Failed to parse PID"))
}
/// Resolves an exact process name to a pid using `pgrep -x`, taking the first
/// line of its output.
///
/// Fails with a `PyRuntimeError` when `pgrep` cannot be run, prints nothing, or
/// prints something that is not an integer.
fn pid_from_name(name: &str) -> PyResult<i32> {
    let pgrep = Command::new("pgrep")
        .arg("-x")
        .arg(name)
        .output()
        .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
    let listing = String::from_utf8_lossy(&pgrep.stdout);

    let first_pid = listing
        .lines()
        .next()
        .map(str::trim)
        .filter(|line| !line.is_empty());
    let Some(first_pid) = first_pid else {
        return Err(pyo3::exceptions::PyRuntimeError::new_err(format!(
            "Application not found: {name}"
        )));
    };

    first_pid
        .parse::<i32>()
        .map_err(|_| pyo3::exceptions::PyRuntimeError::new_err("Failed to parse PID"))
}
/// Searches for `query`, retrying until a match is found or `timeout` elapses.
///
/// `timeout` defaults to 100 ms. The search always runs at least once, and
/// failed attempts are retried roughly every 50 ms.
///
/// # Errors
///
/// Returns `AXError::InvalidQuery` immediately when the query cannot be parsed
/// (retrying cannot help), and `AXError::ElementNotFound` carrying the query
/// when the timeout expires without a match.
fn find_element(&self, query: &str, timeout: Option<Duration>) -> AXResult<AXElement> {
    const POLL_INTERVAL: Duration = Duration::from_millis(50);

    let timeout = timeout.unwrap_or(Duration::from_millis(100));
    let start = Instant::now();
    loop {
        match self.search_element(query) {
            Ok(element) => return Ok(element),
            // A malformed query will never start matching; report it as-is
            // instead of polling until the deadline.
            Err(err @ AXError::InvalidQuery(_)) => return Err(err),
            Err(_) => {
                let elapsed = start.elapsed();
                if elapsed >= timeout {
                    return Err(AXError::ElementNotFound(query.to_string()));
                }
                // Never sleep past the deadline.
                std::thread::sleep(POLL_INTERVAL.min(timeout - elapsed));
            }
        }
    }
}
/// Runs a single search pass: parses `query` into criteria and looks for the
/// first matching element. Parse and search errors are returned unchanged.
fn search_element(&self, query: &str) -> AXResult<AXElement> {
    SearchCriteria::parse(query).and_then(|criteria| self.breadth_first_search(&criteria))
}
/// Searches for the first element with the given `role` whose optional
/// `title`, `identifier` and `label` criteria also match.
fn find_element_by_role(
    &self,
    role: &str,
    title: Option<&str>,
    identifier: Option<&str>,
    label: Option<&str>,
) -> AXResult<AXElement> {
    let wanted = SearchCriteria {
        role: Some(role.to_owned()),
        title: title.map(str::to_owned),
        identifier: identifier.map(str::to_owned),
        label: label.map(str::to_owned),
    };
    self.breadth_first_search(&wanted)
}
/// Captures the application's first window with `screencapture` and returns
/// the bytes of the resulting PNG file.
///
/// The image is written to a per-pid file under `/tmp`, which is removed
/// again on every path once `screencapture` has run.
///
/// # Errors
///
/// Returns `AXError::SystemError` when the window id cannot be determined,
/// when `screencapture` cannot be run or exits unsuccessfully, or when the
/// image file cannot be read.
fn capture_screenshot(&self) -> AXResult<Vec<u8>> {
    let temp_path = format!("/tmp/axterminator_screenshot_{}.png", self.pid);
    let window_id = self.window_id()?;
    let output = Command::new("screencapture")
        .args(["-l", &window_id, "-o", "-x", &temp_path])
        .output()
        .map_err(|e| AXError::SystemError(e.to_string()))?;

    let data = if output.status.success() {
        std::fs::read(&temp_path).map_err(|e| AXError::SystemError(e.to_string()))
    } else {
        Err(AXError::SystemError("Screenshot failed".into()))
    };

    // Best-effort cleanup, also after a failed capture or read, so that no
    // stale image is left behind.
    let _ = std::fs::remove_file(&temp_path);
    data
}
/// Asks System Events (via `osascript`) for the id of the first window of the
/// process with this app's pid and returns it as text.
///
/// # Errors
///
/// Returns `AXError::SystemError` when `osascript` cannot be run, exits
/// unsuccessfully, or prints no id.
fn window_id(&self) -> AXResult<String> {
    let script = format!(
        "tell application \"System Events\" to id of window 1 of (processes whose unix id is {})",
        self.pid
    );
    let output = Command::new("osascript")
        .args(["-e", &script])
        .output()
        .map_err(|e| AXError::SystemError(e.to_string()))?;

    // Without these checks a script failure would be passed on as an empty
    // window id and only show up later as an unexplained capture failure.
    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        return Err(AXError::SystemError(format!(
            "Failed to query window id for pid {}: {}",
            self.pid,
            stderr.trim()
        )));
    }

    let stdout = String::from_utf8_lossy(&output.stdout);
    let id = stdout.trim();
    if id.is_empty() {
        return Err(AXError::SystemError(format!(
            "No window id reported for pid {}",
            self.pid
        )));
    }
    Ok(id.to_string())
}
/// Reads the application's windows attribute and wraps every entry in an
/// `AXElement`.
///
/// Fails with `AXError::SystemError` when the attribute value cannot be
/// converted to a list; attribute lookup errors are passed through.
fn get_windows(&self) -> AXResult<Vec<AXElement>> {
    let array_ref = get_attribute(self.element, attributes::AX_WINDOWS)?;
    let Some(window_refs) = cf_array_to_vec(array_ref) else {
        return Err(AXError::SystemError("Failed to get windows array".into()));
    };
    // The entries were retained individually, so the array itself can go.
    accessibility::release_cf(array_ref);

    let mut wrapped = Vec::with_capacity(window_refs.len());
    for window_ref in window_refs {
        wrapped.push(AXElement::new(window_ref));
    }
    Ok(wrapped)
}
/// Reads the application's main-window attribute and wraps it in an `AXElement`.
/// Attribute lookup errors are passed through.
fn get_main_window(&self) -> AXResult<AXElement> {
    get_attribute(self.element, attributes::AX_MAIN_WINDOW)
        .map(|window_ref| AXElement::new(window_ref as AXUIElementRef))
}
/// Walks the accessibility tree level by level, starting at the application
/// element, and returns the first element satisfying `criteria`.
///
/// Every queued element carries the retain added by `cf_array_to_vec`: the
/// match keeps its retain and is handed to `AXElement::new` (presumably taking
/// ownership — confirm), all other visited or still-queued elements are
/// released here.
///
/// Returns `AXError::ElementNotFound` (with the debug form of the criteria)
/// when the whole tree has been visited without a match.
fn breadth_first_search(&self, criteria: &SearchCriteria) -> AXResult<AXElement> {
    use core_foundation::base::CFTypeRef;
    use std::collections::VecDeque;

    // The root is owned by `self`, so the returned wrapper gets its own retain.
    if self.element_matches(self.element, criteria) {
        let _ = accessibility::retain_cf(self.element as CFTypeRef);
        return Ok(AXElement::new(self.element));
    }

    // Appends the (retained) children of `parent` to the back of the queue.
    let enqueue_children = |parent: AXUIElementRef, pending: &mut VecDeque<AXUIElementRef>| {
        if let Ok(children_ref) = get_attribute(parent, attributes::AX_CHILDREN) {
            if let Some(children) = cf_array_to_vec(children_ref) {
                pending.extend(children);
            }
            accessibility::release_cf(children_ref);
        }
    };

    let mut pending: VecDeque<AXUIElementRef> = VecDeque::new();
    enqueue_children(self.element, &mut pending);

    while let Some(candidate) = pending.pop_front() {
        if self.element_matches(candidate, criteria) {
            // Drop the retains of everything that will never be visited.
            for leftover in pending.drain(..) {
                accessibility::release_cf(leftover as CFTypeRef);
            }
            return Ok(AXElement::new(candidate));
        }
        enqueue_children(candidate, &mut pending);
        accessibility::release_cf(candidate as CFTypeRef);
    }

    Err(AXError::ElementNotFound(format!("{criteria:?}")))
}
/// Reports whether `element` satisfies every criterion that is set.
///
/// Role and identifier must be equal to the requested value; title and label
/// only need to contain it. A criterion whose attribute cannot be read, or
/// whose value does not convert to a string, counts as a mismatch. Checks run
/// in the order role, title, identifier, label and stop at the first failure.
fn element_matches(&self, element: AXUIElementRef, criteria: &SearchCriteria) -> bool {
    // Evaluates one criterion; an unset criterion is trivially satisfied.
    let satisfies = |attribute, expected: &Option<String>, exact: bool| -> bool {
        let Some(expected) = expected else {
            return true;
        };
        let Ok(value_ref) = get_attribute(element, attribute) else {
            return false;
        };
        let matched = cf_string_to_string(value_ref).is_some_and(|actual| {
            if exact {
                actual == *expected
            } else {
                actual.contains(expected.as_str())
            }
        });
        accessibility::release_cf(value_ref);
        matched
    };

    satisfies(attributes::AX_ROLE, &criteria.role, true)
        && satisfies(attributes::AX_TITLE, &criteria.title, false)
        && satisfies(attributes::AX_IDENTIFIER, &criteria.identifier, true)
        && satisfies(attributes::AX_LABEL, &criteria.label, false)
}
}
impl Drop for AXApp {
    /// Releases the application element held by this handle.
    fn drop(&mut self) {
        let element = self.element;
        accessibility::release_cf(element.cast());
    }
}
/// Attribute constraints an element has to satisfy to count as a match.
/// Fields left as `None` are not checked.
#[derive(Debug, Clone)]
struct SearchCriteria {
/// Required role; compared for equality.
role: Option<String>,
/// Required title; matched as a substring.
title: Option<String>,
/// Required identifier; compared for equality.
identifier: Option<String>,
/// Required label; matched as a substring.
label: Option<String>,
}
impl SearchCriteria {
/// Parses a query string into search criteria.
///
/// After trimming, a query starting with `//` is parsed as an XPath-like
/// expression, a query containing `:` as `key:value` pairs, and anything else
/// as bare text that is used for title, identifier and label alike.
///
/// NOTE(review): `AXApp::element_matches` requires every populated criterion
/// to match, so a bare-text query only finds elements whose title, identifier
/// and label all match the text — confirm that this is intended.
fn parse(query: &str) -> AXResult<Self> {
    let trimmed = query.trim();
    if trimmed.starts_with("//") {
        Self::parse_xpath(trimmed)
    } else if trimmed.contains(':') {
        Self::parse_key_value(trimmed)
    } else {
        let text = trimmed.to_string();
        Ok(Self {
            role: None,
            title: Some(text.clone()),
            identifier: Some(text.clone()),
            label: Some(text),
        })
    }
}
/// Parses an XPath-like query such as `//AXButton[@AXTitle='Save']`.
///
/// The text between the leading `//` and the first `[` (or the end of the
/// query) becomes the role when it is not blank. Each `[@Key=value]` group
/// sets the title (`AXTitle`), identifier (`AXIdentifier`) or label
/// (`AXLabel`); surrounding quotes are stripped from the value, and groups
/// with other keys, without `=`, or without a closing `]` are ignored.
fn parse_xpath(query: &str) -> AXResult<Self> {
    let role_end = query.find('[').unwrap_or(query.len());
    let role = query[2..role_end].trim();

    let mut parsed = Self {
        role: (!role.is_empty()).then(|| role.to_string()),
        title: None,
        identifier: None,
        label: None,
    };

    for (open, _) in query.match_indices("[@") {
        let rest = &query[open + 2..];
        let Some(close) = rest.find(']') else {
            continue;
        };
        let Some((key, raw_value)) = rest[..close].split_once('=') else {
            continue;
        };
        let value = raw_value
            .trim()
            .trim_matches(|c| c == '\'' || c == '"');
        let slot = match key.trim() {
            "AXTitle" => &mut parsed.title,
            "AXIdentifier" => &mut parsed.identifier,
            "AXLabel" => &mut parsed.label,
            _ => continue,
        };
        *slot = Some(value.to_string());
    }

    Ok(parsed)
}
/// Parses whitespace-separated `key:value` tokens.
///
/// Recognised keys are `role`, `title`, `identifier` (alias `id`) and `label`;
/// a later token for the same key overwrites an earlier one. Tokens without a
/// `:` are skipped, so values cannot contain whitespace.
///
/// Returns `AXError::InvalidQuery` for any other key.
fn parse_key_value(query: &str) -> AXResult<Self> {
    let mut parsed = Self {
        role: None,
        title: None,
        identifier: None,
        label: None,
    };

    let pairs = query
        .split_whitespace()
        .filter_map(|token| token.split_once(':'));
    for (key, value) in pairs {
        let slot = match key.trim() {
            "role" => &mut parsed.role,
            "title" => &mut parsed.title,
            "identifier" | "id" => &mut parsed.identifier,
            "label" => &mut parsed.label,
            _ => return Err(AXError::InvalidQuery(format!("Unknown key: {key}"))),
        };
        *slot = Some(value.trim().to_string());
    }

    Ok(parsed)
}
}
/// Converts a Core Foundation string reference into an owned Rust `String`.
///
/// Returns `None` when `cf_ref` is null or does not refer to a `CFString`.
/// The caller's reference is left untouched: it is wrapped under the get
/// rule, so the caller remains responsible for releasing it.
fn cf_string_to_string(cf_ref: core_foundation::base::CFTypeRef) -> Option<String> {
    use core_foundation::base::{CFGetTypeID, TCFType};
    use core_foundation::string::CFString;

    if cf_ref.is_null() {
        return None;
    }
    // SAFETY: `cf_ref` is non-null and is assumed to point to a live Core
    // Foundation object supplied by the caller. Its type id is checked before
    // it is treated as a `CFString`, so attribute values of another type
    // (numbers, arrays, ...) are rejected instead of being reinterpreted.
    unsafe {
        if CFGetTypeID(cf_ref) != CFString::type_id() {
            return None;
        }
        let cf_string = CFString::wrap_under_get_rule(cf_ref.cast());
        Some(cf_string.to_string())
    }
}
/// Copies the entries of a Core Foundation array into a `Vec` of element refs.
///
/// Every non-null entry is retained before it is pushed, so the caller owns
/// one retain per returned element and may release the array afterwards.
/// Returns `None` when `cf_ref` is null or does not refer to a `CFArray`; in
/// both cases the caller's reference is left untouched.
///
/// NOTE(review): entries are cast to `AXUIElementRef` without checking their
/// type — callers are assumed to pass arrays of accessibility elements.
fn cf_array_to_vec(cf_ref: core_foundation::base::CFTypeRef) -> Option<Vec<AXUIElementRef>> {
    use core_foundation::array::CFArray;
    use core_foundation::base::{CFGetTypeID, CFType, CFTypeRef, TCFType};

    if cf_ref.is_null() {
        return None;
    }
    // SAFETY: `cf_ref` is non-null and is assumed to point to a live Core
    // Foundation object supplied by the caller. Its type id is checked before
    // it is treated as a `CFArray`, and the wrapper is created under the get
    // rule so the caller's reference count is unchanged when it is dropped.
    unsafe {
        if CFGetTypeID(cf_ref) != CFArray::<CFType>::type_id() {
            return None;
        }
        let cf_array: CFArray<CFType> = CFArray::wrap_under_get_rule(cf_ref.cast());
        let count = cf_array.len();
        let mut result = Vec::with_capacity(count as usize);
        for i in 0..count {
            if let Some(element_ref) = cf_array.get(i) {
                let element_ptr = element_ref.as_concrete_TypeRef() as AXUIElementRef;
                if !element_ptr.is_null() {
                    // The array only lends its entries; take a retain so the
                    // element outlives the array.
                    let _ = accessibility::retain_cf(element_ptr as CFTypeRef);
                    result.push(element_ptr);
                }
            }
        }
        Some(result)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `query`, panicking if the parser rejects it.
    fn parsed(query: &str) -> SearchCriteria {
        SearchCriteria::parse(query).unwrap()
    }

    // Bare text is applied to title, identifier and label, never to the role.
    #[test]
    fn test_search_criteria_parse_simple_text() {
        let criteria = parsed("Save");
        assert!(criteria.role.is_none());
        assert_eq!(criteria.title.as_deref(), Some("Save"));
        assert_eq!(criteria.identifier.as_deref(), Some("Save"));
        assert_eq!(criteria.label.as_deref(), Some("Save"));
    }

    #[test]
    fn test_search_criteria_parse_role_only() {
        let criteria = parsed("role:AXButton");
        assert_eq!(criteria.role.as_deref(), Some("AXButton"));
        assert!(criteria.title.is_none());
        assert!(criteria.identifier.is_none());
        assert!(criteria.label.is_none());
    }

    #[test]
    fn test_search_criteria_parse_combined() {
        let criteria = parsed("role:AXButton title:Save");
        assert_eq!(criteria.role.as_deref(), Some("AXButton"));
        assert_eq!(criteria.title.as_deref(), Some("Save"));
        assert!(criteria.identifier.is_none());
        assert!(criteria.label.is_none());
    }

    #[test]
    fn test_search_criteria_parse_xpath_role_only() {
        let criteria = parsed("//AXButton");
        assert_eq!(criteria.role.as_deref(), Some("AXButton"));
        assert!(criteria.title.is_none());
    }

    #[test]
    fn test_search_criteria_parse_xpath_with_title() {
        let criteria = parsed("//AXButton[@AXTitle='Save']");
        assert_eq!(criteria.role.as_deref(), Some("AXButton"));
        assert_eq!(criteria.title.as_deref(), Some("Save"));
    }

    #[test]
    fn test_search_criteria_parse_xpath_multiple_attributes() {
        let criteria = parsed("//AXButton[@AXTitle='Save'][@AXIdentifier='save_btn']");
        assert_eq!(criteria.role.as_deref(), Some("AXButton"));
        assert_eq!(criteria.title.as_deref(), Some("Save"));
        assert_eq!(criteria.identifier.as_deref(), Some("save_btn"));
    }

    // `id` is accepted as shorthand for `identifier`.
    #[test]
    fn test_search_criteria_parse_identifier_alias() {
        let criteria = parsed("role:AXButton id:save_btn");
        assert_eq!(criteria.identifier.as_deref(), Some("save_btn"));
    }

    // Unknown keys are rejected, and the error names the offending key.
    #[test]
    fn test_search_criteria_parse_invalid_key() {
        let outcome = SearchCriteria::parse("role:AXButton invalid:value");
        assert!(outcome.is_err());
        match outcome {
            Err(AXError::InvalidQuery(msg)) => assert!(msg.contains("invalid")),
            _ => panic!("Expected InvalidQuery error"),
        }
    }

    #[test]
    fn test_cf_string_conversion_null_safety() {
        let null_ref: core_foundation::base::CFTypeRef = std::ptr::null();
        assert!(cf_string_to_string(null_ref).is_none());
    }

    #[test]
    fn test_cf_array_conversion_null_safety() {
        let null_ref: core_foundation::base::CFTypeRef = std::ptr::null();
        assert!(cf_array_to_vec(null_ref).is_none());
    }
}