use std::io::{Read, Write};
use std::net::TcpStream;
use std::sync::atomic::{AtomicU32, Ordering};
use serde::Deserialize;
use serde_json::{json, Value};
use tracing::{debug, info};
use tungstenite::{connect as ws_connect, stream::MaybeTlsStream, Message, WebSocket};
use crate::error::{AXError, AXResult};
/// Axis-aligned rectangle described by an origin corner and a size.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Rect {
    /// Horizontal coordinate of the origin corner.
    pub x: f64,
    /// Vertical coordinate of the origin corner.
    pub y: f64,
    /// Horizontal extent, measured from `x`.
    pub width: f64,
    /// Vertical extent, measured from `y`.
    pub height: f64,
}

impl Rect {
    /// Builds a rectangle from its origin corner and its size.
    #[must_use]
    pub fn new(x: f64, y: f64, width: f64, height: f64) -> Self {
        Rect { x, y, width, height }
    }

    /// Returns the `(x, y)` point halfway along both extents.
    #[must_use]
    pub fn center(&self) -> (f64, f64) {
        let half_width = self.width * 0.5;
        let half_height = self.height * 0.5;
        (self.x + half_width, self.y + half_height)
    }

    /// Returns `width * height`.
    #[must_use]
    pub fn area(&self) -> f64 {
        let Rect { width, height, .. } = *self;
        width * height
    }
}
/// Description of one element obtained over CDP, either from a DOM query or
/// from the accessibility tree.
#[derive(Debug, Clone)]
pub struct ElectronElement {
/// Identifier CDP reported for the node (`0` when an accessibility node
/// carries no usable id).
pub node_id: i64,
/// Lower-cased element name: the DOM `localName` for DOM queries, the
/// accessibility role for AX nodes, or `"unknown"` when neither is present.
pub tag: String,
/// Whitespace-separated entries of the `class` attribute; empty for AX nodes.
pub classes: Vec<String>,
/// Text associated with the element (accessible name for AX nodes).
pub text: String,
/// Content-box rectangle when a box model could be fetched, otherwise `None`.
pub bounds: Option<Rect>,
}
/// A blocking CDP session with an Electron application's remote-debugging port.
pub struct ElectronConnection {
/// WebSocket over which CDP requests are sent and responses are read.
socket: WebSocket<MaybeTlsStream<TcpStream>>,
/// Remote-debugging port this connection was opened against.
debug_port: u16,
/// Source of request ids; starts at 1 and is bumped for every request sent.
next_id: AtomicU32,
}
/// Manual `Debug` that prints only the debugging port; the socket and the id
/// counter are deliberately left out of the output.
impl std::fmt::Debug for ElectronConnection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = f.debug_struct("ElectronConnection");
        repr.field("debug_port", &self.debug_port);
        repr.finish()
    }
}
impl ElectronConnection {
    /// Opens a CDP session with the Electron app listening on `port`.
    ///
    /// The port is probed first, then the target list served at `/json` is
    /// fetched and the WebSocket is opened on a target's
    /// `webSocketDebuggerUrl`. Targets of type `"page"` are preferred because
    /// the DOM and input commands issued by this type are page-scoped; any
    /// other target with a debugger URL is used as a fallback.
    ///
    /// # Errors
    ///
    /// * [`AXError::AppNotFound`] when nothing answering like a CDP endpoint
    ///   listens on `port`, or when no target exposes a debugger URL.
    /// * [`AXError::SystemError`] when the target list cannot be fetched or
    ///   parsed, or when the WebSocket handshake fails.
    pub fn connect(port: u16) -> AXResult<Self> {
        if !probe_cdp_port(port) {
            return Err(AXError::AppNotFound(format!(
                "No CDP endpoint on port {port}"
            )));
        }

        // `/json` itself is a plain HTTP endpoint; the WebSocket endpoint of
        // each target is advertised inside the list it returns.
        let targets = Self::fetch_targets(port)?;
        let ws_url = targets
            .iter()
            .filter(|target| target.target_type == "page")
            .chain(targets.iter().filter(|target| target.target_type != "page"))
            .find_map(|target| target.ws_url.as_deref())
            .ok_or_else(|| {
                AXError::AppNotFound(format!("No debuggable CDP target on port {port}"))
            })?;

        debug!(port, %ws_url, "Connecting to Electron CDP");
        let (socket, _) = ws_connect(ws_url)
            .map_err(|e| AXError::SystemError(format!("CDP WebSocket failed: {e}")))?;
        info!(port, "Electron CDP connection established");

        Ok(Self {
            socket,
            debug_port: port,
            next_id: AtomicU32::new(1),
        })
    }

    /// Returns the remote-debugging port this connection was opened against.
    #[must_use]
    pub fn port(&self) -> u16 {
        self.debug_port
    }

    /// Lists the debuggable targets advertised by the `/json` HTTP endpoint.
    ///
    /// # Errors
    ///
    /// Returns [`AXError::SystemError`] when the HTTP request fails or the
    /// response body is not a JSON array of targets.
    pub fn list_targets(&self) -> AXResult<Vec<CdpTarget>> {
        Self::fetch_targets(self.debug_port)
    }

    /// Fetches and parses the target list served at `/json` on `port`.
    fn fetch_targets(port: u16) -> AXResult<Vec<CdpTarget>> {
        let raw = http_get(port, "/json")
            .map_err(|e| AXError::SystemError(format!("CDP /json failed: {e}")))?;
        serde_json::from_str::<Vec<CdpTarget>>(&raw)
            .map_err(|e| AXError::SystemError(format!("CDP target parse failed: {e}")))
    }
}
impl ElectronConnection {
    /// Sends the CDP command `method` with `params` and blocks until the
    /// response carrying the same request id has been read.
    ///
    /// Non-text frames and text frames whose `id` does not match the pending
    /// request are discarded while waiting.
    ///
    /// # Errors
    ///
    /// * [`AXError::SystemError`] when sending fails, reading fails, or a text
    ///   frame cannot be parsed.
    /// * [`AXError::ActionFailed`] when the matching response carries an
    ///   `error` object.
    pub fn execute(&mut self, method: &str, params: Value) -> AXResult<Value> {
        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
        let request = json!({ "id": id, "method": method, "params": params }).to_string();
        debug!(method, %id, "CDP request");

        self.socket
            .send(Message::Text(request.into()))
            .map_err(|e| AXError::SystemError(format!("CDP send: {e}")))?;

        loop {
            let frame = self
                .socket
                .read()
                .map_err(|e| AXError::SystemError(format!("CDP read: {e}")))?;

            // Only text frames carry CDP JSON; everything else is skipped.
            let Message::Text(text) = frame else {
                continue;
            };
            let resp: CdpResponse = serde_json::from_str(&text)
                .map_err(|e| AXError::SystemError(format!("CDP parse: {e}")))?;

            // Frames for other ids (or without an id) are not ours.
            if resp.id != Some(id) {
                continue;
            }
            if let Some(err) = resp.error {
                return Err(AXError::ActionFailed(format!("CDP: {}", err.message)));
            }
            return Ok(resp.result.unwrap_or(Value::Null));
        }
    }
}
impl ElectronConnection {
    /// Returns every element of the current document matching the CSS
    /// `selector`, each enriched with tag, classes, text and bounds.
    ///
    /// An empty vector is returned when the response has no `nodeIds` array.
    ///
    /// # Errors
    ///
    /// Propagates any CDP failure, and returns [`AXError::SystemError`] when
    /// the document root has no `nodeId`.
    pub fn query_selector(&mut self, selector: &str) -> AXResult<Vec<ElectronElement>> {
        let doc = self.execute("DOM.getDocument", json!({ "depth": 0 }))?;
        let root_id = doc["root"]["nodeId"]
            .as_i64()
            .ok_or_else(|| AXError::SystemError("No root nodeId".into()))?;
        let result = self.execute(
            "DOM.querySelectorAll",
            json!({ "nodeId": root_id, "selector": selector }),
        )?;
        // `result` is an owned value, so its id array can be borrowed while
        // `self` is used for the follow-up requests; no copy is needed.
        let Some(node_ids) = result["nodeIds"].as_array() else {
            return Ok(Vec::new());
        };
        node_ids
            .iter()
            .filter_map(Value::as_i64)
            .map(|nid| self.enrich_element(nid))
            .collect()
    }

    /// Fetches the full accessibility tree and maps each AX node to an
    /// [`ElectronElement`] (role as `tag`, accessible name as `text`).
    ///
    /// # Errors
    ///
    /// Propagates any CDP failure, and returns [`AXError::SystemError`] when
    /// the response has no `nodes` array.
    pub fn get_accessibility_tree(&mut self) -> AXResult<Vec<ElectronElement>> {
        let result = self.execute("Accessibility.getFullAXTree", json!({}))?;
        let nodes = result["nodes"]
            .as_array()
            .ok_or_else(|| AXError::SystemError("AX tree missing 'nodes'".into()))?;
        Ok(nodes.iter().map(ax_node_to_element).collect())
    }

    /// Evaluates the JavaScript expression `expr` and returns its value as text.
    ///
    /// String results are returned verbatim, a null or absent value becomes
    /// `"null"`, and any other value is rendered as JSON. Promises are not
    /// awaited.
    ///
    /// # Errors
    ///
    /// Returns [`AXError::ActionFailed`] with the exception description when
    /// the evaluation throws, and propagates any CDP failure.
    pub fn evaluate_js(&mut self, expr: &str) -> AXResult<String> {
        let result = self.execute(
            "Runtime.evaluate",
            json!({
                "expression": expr,
                "returnByValue": true,
                "awaitPromise": false,
            }),
        )?;
        if let Some(exc) = result.get("exceptionDetails") {
            let msg = exc["exception"]["description"]
                .as_str()
                .or_else(|| exc["text"].as_str())
                .unwrap_or("JavaScript exception");
            return Err(AXError::ActionFailed(msg.into()));
        }
        let value = &result["result"]["value"];
        Ok(match value {
            Value::String(s) => s.clone(),
            Value::Null => "null".into(),
            other => other.to_string(),
        })
    }

    /// Clicks the centre of `element`'s bounds with the left mouse button.
    ///
    /// # Errors
    ///
    /// Returns [`AXError::ActionFailed`] when the element has no bounds, and
    /// propagates any CDP failure.
    pub fn click_element(&mut self, element: &ElectronElement) -> AXResult<()> {
        let bounds = element
            .bounds
            .ok_or_else(|| AXError::ActionFailed("Element has no bounds for click".into()))?;
        let (x, y) = bounds.center();
        self.dispatch_click(x, y)
    }

    /// Inserts `text` via `Input.insertText`.
    ///
    /// # Errors
    ///
    /// Propagates any CDP failure.
    pub fn type_text(&mut self, text: &str) -> AXResult<()> {
        self.execute("Input.insertText", json!({ "text": text }))?;
        Ok(())
    }

    /// Sends a left-button press followed by a release at `(x, y)`.
    fn dispatch_click(&mut self, x: f64, y: f64) -> AXResult<()> {
        let down = json!({
            "type": "mousePressed",
            "x": x, "y": y,
            "button": "left",
            "clickCount": 1,
        });
        let up = json!({
            "type": "mouseReleased",
            "x": x, "y": y,
            "button": "left",
            "clickCount": 1,
        });
        self.execute("Input.dispatchMouseEvent", down)?;
        self.execute("Input.dispatchMouseEvent", up)?;
        Ok(())
    }

    /// Builds an [`ElectronElement`] for the DOM node `node_id`.
    ///
    /// Text and bounds are best-effort: a failure to read either leaves the
    /// text empty or the bounds `None` instead of failing the whole element.
    fn enrich_element(&mut self, node_id: i64) -> AXResult<ElectronElement> {
        let desc = self.execute("DOM.describeNode", json!({ "nodeId": node_id, "depth": 0 }))?;
        let node = &desc["node"];
        let tag = node["localName"]
            .as_str()
            .unwrap_or("unknown")
            .to_lowercase();
        let classes = parse_class_list(node["attributes"].as_array());
        let text = self.node_inner_text(node_id).unwrap_or_default();
        let bounds = self.get_box_model(node_id).ok();
        Ok(ElectronElement {
            node_id,
            tag,
            classes,
            text,
            bounds,
        })
    }

    /// Reads `innerText` of the DOM node `node_id`.
    ///
    /// A CDP `nodeId` is an opaque handle, not a position in the document, so
    /// the node is resolved to a remote object and the property is read on
    /// that object. Nodes without a string `innerText` yield an empty string.
    fn node_inner_text(&mut self, node_id: i64) -> AXResult<String> {
        let resolved = self.execute("DOM.resolveNode", json!({ "nodeId": node_id }))?;
        let object_id = resolved["object"]["objectId"]
            .as_str()
            .ok_or_else(|| AXError::SystemError("No objectId for node".into()))?
            .to_owned();

        let outcome = self.execute(
            "Runtime.callFunctionOn",
            json!({
                "objectId": object_id,
                "functionDeclaration":
                    "function(){return typeof this.innerText==='string'?this.innerText:'';}",
                "returnByValue": true,
            }),
        );
        // Release the remote handle whether or not the call succeeded, so
        // repeated queries do not pile up references inside the page. A
        // failure to release is not worth failing the lookup for.
        let _ = self.execute("Runtime.releaseObject", json!({ "objectId": object_id }));

        let result = outcome?;
        if let Some(exc) = result.get("exceptionDetails") {
            let msg = exc["exception"]["description"]
                .as_str()
                .or_else(|| exc["text"].as_str())
                .unwrap_or("JavaScript exception");
            return Err(AXError::ActionFailed(msg.into()));
        }
        Ok(result["result"]["value"]
            .as_str()
            .unwrap_or_default()
            .to_owned())
    }

    /// Returns the content box of the DOM node `node_id`.
    ///
    /// The rectangle spans from the first point of the content quad
    /// (indices 0 and 1) to its third point (indices 4 and 5); coordinates
    /// missing from the quad are read as `0.0`.
    ///
    /// # Errors
    ///
    /// Propagates any CDP failure, and returns [`AXError::SystemError`] when
    /// the response has no content quad.
    fn get_box_model(&mut self, node_id: i64) -> AXResult<Rect> {
        let result = self.execute("DOM.getBoxModel", json!({ "nodeId": node_id }))?;
        let content = result["model"]["content"]
            .as_array()
            .ok_or_else(|| AXError::SystemError("No content box".into()))?;
        let coord = |index: usize| content.get(index).and_then(Value::as_f64).unwrap_or(0.0);
        let (x0, y0) = (coord(0), coord(1));
        let (x1, y1) = (coord(4), coord(5));
        Ok(Rect::new(x0, y0, x1 - x0, y1 - y0))
    }
}
/// One entry of the target list served by the `/json` HTTP endpoint.
#[derive(Debug, Clone, Deserialize)]
pub struct CdpTarget {
/// Target identifier as reported by the endpoint.
pub id: String,
/// Title reported for the target.
pub title: String,
/// Value of the JSON `type` field (for example `"page"` or `"service_worker"`).
#[serde(rename = "type")]
pub target_type: String,
/// Value of the JSON `webSocketDebuggerUrl` field; `None` when the field is absent.
#[serde(rename = "webSocketDebuggerUrl")]
pub ws_url: Option<String>,
}
/// A text frame received on the CDP socket, reduced to the fields `execute` inspects.
#[derive(Debug, Deserialize)]
struct CdpResponse {
/// Request id echoed by the peer; frames where it is absent or different
/// from the pending request are skipped by `execute`.
id: Option<u32>,
/// Payload of a successful response, if any.
result: Option<Value>,
/// Error object of a failed response, if any.
error: Option<CdpError>,
}
/// The `error` object of a failed CDP response.
#[derive(Debug, Deserialize)]
struct CdpError {
// `code` is deserialized but not read anywhere in this file (only `message`
// is surfaced), hence the lint override.
#[allow(dead_code)]
code: i32,
/// Error text; included in the `ActionFailed` error returned by `execute`.
message: String,
}
/// Reports whether something answering like a CDP endpoint listens on the
/// loopback `port`.
///
/// A `GET /json/version` request is sent and the response is scanned for the
/// markers `Browser` or `webSocketDebuggerUrl`. The response is read in chunks
/// until a marker is seen, the peer closes the connection, a size cap is
/// reached, or the deadline passes, so a response whose headers and body
/// arrive separately is still recognised. All socket I/O is bounded by a
/// timeout so that a port which accepts connections but never answers cannot
/// block the caller indefinitely.
///
/// Returns `false` on any connection, write or read failure.
#[must_use]
pub fn probe_cdp_port(port: u16) -> bool {
    const IO_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(2);
    const MAX_RESPONSE_BYTES: usize = 8 * 1024;

    let Ok(mut stream) = TcpStream::connect(format!("127.0.0.1:{port}")) else {
        return false;
    };
    // Setting a timeout only fails for a zero duration, which IO_TIMEOUT is not.
    let _ = stream.set_read_timeout(Some(IO_TIMEOUT));
    let _ = stream.set_write_timeout(Some(IO_TIMEOUT));

    let req = format!(
        "GET /json/version HTTP/1.1\r\nHost: 127.0.0.1:{port}\r\nConnection: close\r\n\r\n"
    );
    if stream.write_all(req.as_bytes()).is_err() {
        return false;
    }

    let deadline = std::time::Instant::now() + IO_TIMEOUT;
    let mut response = Vec::with_capacity(1024);
    let mut chunk = [0u8; 512];
    while response.len() < MAX_RESPONSE_BYTES && std::time::Instant::now() < deadline {
        match stream.read(&mut chunk) {
            // Peer closed the connection: nothing more will arrive.
            Ok(0) => break,
            Ok(n) => {
                response.extend_from_slice(&chunk[..n]);
                // Scan everything received so far so that a marker split
                // across two reads is still found.
                let text = String::from_utf8_lossy(&response);
                if text.contains("Browser") || text.contains("webSocketDebuggerUrl") {
                    return true;
                }
            }
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => {}
            Err(_) => break,
        }
    }
    false
}
/// Performs a minimal blocking `GET {path}` against `127.0.0.1:{port}` and
/// returns the response text.
///
/// The connection is read until the peer closes it. Everything after the
/// first blank line (`\r\n\r\n`) is returned; when the response contains no
/// such separator the whole response is returned unchanged. The status line
/// is not inspected.
///
/// # Errors
///
/// Returns the underlying I/O error when connecting, writing or reading
/// fails, or when the response is not valid UTF-8.
fn http_get(port: u16, path: &str) -> std::io::Result<String> {
    let address = format!("127.0.0.1:{port}");
    let mut stream = TcpStream::connect(address)?;

    let request =
        format!("GET {path} HTTP/1.1\r\nHost: 127.0.0.1:{port}\r\nConnection: close\r\n\r\n");
    stream.write_all(request.as_bytes())?;

    let mut response = String::new();
    stream.read_to_string(&mut response)?;

    let payload = match response.split_once("\r\n\r\n") {
        Some((_headers, payload)) => payload.to_owned(),
        None => response,
    };
    Ok(payload)
}
/// Extracts the CSS classes from a flat `[name, value, name, value, …]`
/// attribute array.
///
/// The first pair whose name is `"class"` decides the result: its value is
/// split on whitespace. An absent array, no `class` pair, or a `class` pair
/// without a string value all yield an empty vector.
fn parse_class_list(attrs: Option<&Vec<Value>>) -> Vec<String> {
    let Some(attrs) = attrs else {
        return Vec::new();
    };
    for pair in attrs.chunks(2) {
        if pair.first().and_then(Value::as_str) != Some("class") {
            continue;
        }
        // Only the first `class` pair is considered, even if its value is unusable.
        return match pair.get(1).and_then(Value::as_str) {
            Some(names) => names.split_whitespace().map(String::from).collect(),
            None => Vec::new(),
        };
    }
    Vec::new()
}
/// Converts one accessibility-tree node into an [`ElectronElement`].
///
/// * `node_id` comes from the node's `nodeId`. CDP encodes accessibility node
///   ids as strings, so both a JSON number and a string holding an integer
///   are accepted; anything else yields `0`.
/// * `tag` is the lower-cased `role.value`, or `"unknown"` when absent.
/// * `text` is `name.value`, or empty when absent.
/// * `classes` is always empty and `bounds` is always `None`.
fn ax_node_to_element(node: &Value) -> ElectronElement {
    let raw_id = &node["nodeId"];
    let node_id = raw_id
        .as_i64()
        .or_else(|| raw_id.as_str().and_then(|id| id.trim().parse().ok()))
        .unwrap_or(0);
    let tag = node["role"]["value"]
        .as_str()
        .unwrap_or("unknown")
        .to_lowercase();
    let text = node["name"]["value"].as_str().unwrap_or("").to_string();
    ElectronElement {
        node_id,
        tag,
        classes: vec![],
        text,
        bounds: None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::thread;

    /// Returns a loopback port that was free a moment ago: a listener is
    /// bound to an OS-assigned port and dropped again. This avoids relying on
    /// a hard-coded port number happening to be unused on the test machine.
    fn unused_port() -> u16 {
        TcpListener::bind("127.0.0.1:0")
            .and_then(|listener| listener.local_addr())
            .map(|addr| addr.port())
            .expect("loopback listener should bind")
    }

    /// Starts a one-shot loopback server that reads a request and answers
    /// with `response`, then closes the connection. Returns its port.
    fn serve_once(response: &'static str) -> u16 {
        let listener = TcpListener::bind("127.0.0.1:0").expect("loopback listener should bind");
        let port = listener
            .local_addr()
            .expect("listener should have an address")
            .port();
        thread::spawn(move || {
            if let Ok((mut conn, _)) = listener.accept() {
                let mut request = [0u8; 1024];
                let _ = conn.read(&mut request);
                let _ = conn.write_all(response.as_bytes());
            }
        });
        port
    }

    #[test]
    fn rect_center_midpoint_of_bounds() {
        let r = Rect::new(10.0, 20.0, 80.0, 40.0);
        let (cx, cy) = r.center();
        assert_eq!(cx, 50.0);
        assert_eq!(cy, 40.0);
    }

    #[test]
    fn rect_area_width_times_height() {
        let r = Rect::new(0.0, 0.0, 100.0, 50.0);
        assert_eq!(r.area(), 5_000.0);
    }

    #[test]
    fn rect_zero_size_area_is_zero() {
        let r = Rect::new(5.0, 5.0, 0.0, 0.0);
        assert_eq!(r.area(), 0.0);
    }

    #[test]
    fn probe_cdp_port_closed_port_returns_false() {
        assert!(!probe_cdp_port(unused_port()));
    }

    #[test]
    fn probe_cdp_port_invalid_high_port_returns_false() {
        // 65_535 is the highest valid port number; this covers the upper
        // boundary of the `u16` range and assumes nothing listens there.
        assert!(!probe_cdp_port(65_535));
    }

    #[test]
    fn probe_cdp_port_detects_version_payload() {
        let port = serve_once(
            "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n\r\n{\"Browser\":\"Chrome\"}",
        );
        assert!(probe_cdp_port(port));
    }

    #[test]
    fn parse_class_list_extracts_classes() {
        let attrs = vec![
            json!("id"),
            json!("myId"),
            json!("class"),
            json!("btn primary"),
        ];
        let classes = parse_class_list(Some(&attrs));
        assert_eq!(classes, vec!["btn", "primary"]);
    }

    #[test]
    fn parse_class_list_no_class_attr_returns_empty() {
        let attrs = vec![json!("id"), json!("myId")];
        assert!(parse_class_list(Some(&attrs)).is_empty());
    }

    #[test]
    fn parse_class_list_none_returns_empty() {
        assert!(parse_class_list(None).is_empty());
    }

    #[test]
    fn ax_node_to_element_maps_role_and_name() {
        let node = json!({
            "nodeId": 42,
            "role": { "value": "Button" },
            "name": { "value": "Submit" }
        });
        let elem = ax_node_to_element(&node);
        assert_eq!(elem.node_id, 42);
        assert_eq!(elem.tag, "button");
        assert_eq!(elem.text, "Submit");
        assert!(elem.classes.is_empty());
        assert!(elem.bounds.is_none());
    }

    #[test]
    fn ax_node_to_element_missing_fields_use_defaults() {
        let node = json!({});
        let elem = ax_node_to_element(&node);
        assert_eq!(elem.node_id, 0);
        assert_eq!(elem.tag, "unknown");
        assert_eq!(elem.text, "");
    }

    #[test]
    fn http_get_strips_http_headers() {
        // Exercise the real function against a loopback server instead of
        // re-implementing its header split inside the test.
        let port = serve_once(
            "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n\r\n[{\"id\":\"1\"}]",
        );
        let body = http_get(port, "/json").expect("loopback request should succeed");
        assert_eq!(body, "[{\"id\":\"1\"}]");
    }

    #[test]
    fn connect_to_closed_port_returns_error() {
        let result = ElectronConnection::connect(unused_port());
        assert!(result.is_err());
        let err = result.unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("No CDP endpoint") || msg.contains("CDP"),
            "unexpected error message: {msg}"
        );
    }

    #[test]
    fn cdp_target_deserialises_from_json() {
        let json_str = r#"{
            "id": "abc-123",
            "title": "VS Code",
            "type": "page",
            "webSocketDebuggerUrl": "ws://127.0.0.1:9222/devtools/page/abc-123"
        }"#;
        let target: CdpTarget = serde_json::from_str(json_str).unwrap();
        assert_eq!(target.id, "abc-123");
        assert_eq!(target.title, "VS Code");
        assert_eq!(target.target_type, "page");
        assert!(target.ws_url.is_some());
    }

    #[test]
    fn cdp_target_deserialises_without_ws_url() {
        let json_str = r#"{
            "id": "worker-1",
            "title": "service worker",
            "type": "service_worker"
        }"#;
        let target: CdpTarget = serde_json::from_str(json_str).unwrap();
        assert_eq!(target.target_type, "service_worker");
        assert!(target.ws_url.is_none());
    }
}