use crate::dom::element::{AriaChild, AriaNode};
use crate::error::{BrowserError, Result};
use headless_chrome::Tab;
use std::sync::Arc;
/// An ARIA node tree together with the per-index lookup tables kept alongside it.
#[derive(Clone, Debug)]
pub struct DomTree {
    /// Root node of the tree.
    pub root: AriaNode,
    /// Selector strings addressed by node index; an empty string is treated as
    /// "no selector" by [`DomTree::get_selector`].
    pub selectors: Vec<String>,
    /// Indices of the nodes whose role is `"iframe"`.
    pub iframe_indices: Vec<usize>,
}
/// Wire format of the JSON payload parsed in `DomTree::from_tab_with_prefix`.
///
/// Keys are camelCase on the wire (`root`, `selectors`, `iframeIndices`).
#[derive(Debug, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct SnapshotResponse {
    /// Root node of the extracted tree.
    root: AriaNode,
    /// Selector strings, one slot per node index.
    selectors: Vec<String>,
    /// Indices of iframe nodes, sent as `iframeIndices`.
    iframe_indices: Vec<usize>,
}
impl DomTree {
pub fn new(root: AriaNode) -> Self {
let mut tree = Self {
root,
selectors: Vec::new(),
iframe_indices: Vec::new(),
};
tree.rebuild_maps();
tree
}
pub fn from_tab(tab: &Arc<Tab>) -> Result<Self> {
Self::from_tab_with_prefix(tab, "")
}
pub fn from_tab_with_prefix(tab: &Arc<Tab>, _ref_prefix: &str) -> Result<Self> {
let js_code = include_str!("extract_dom.js");
let result = tab.evaluate(js_code, false).map_err(|e| {
BrowserError::DomParseFailed(format!("Failed to execute DOM extraction script: {}", e))
})?;
let json_value = result.value.ok_or_else(|| {
BrowserError::DomParseFailed("No value returned from DOM extraction".to_string())
})?;
let json_str: String = serde_json::from_value(json_value).map_err(|e| {
BrowserError::DomParseFailed(format!("Failed to get JSON string: {}", e))
})?;
let response: SnapshotResponse = serde_json::from_str(&json_str).map_err(|e| {
BrowserError::DomParseFailed(format!("Failed to parse snapshot JSON: {}", e))
})?;
Ok(Self {
root: response.root,
selectors: response.selectors,
iframe_indices: response.iframe_indices,
})
}
fn rebuild_maps(&mut self) {
self.iframe_indices.clear();
let max_index = self.find_max_index(&self.root.clone());
if let Some(max_idx) = max_index {
if self.selectors.len() <= max_idx {
self.selectors.resize(max_idx + 1, String::new());
}
}
let root = self.root.clone();
self.collect_iframe_indices(&root);
}
fn find_max_index(&self, node: &AriaNode) -> Option<usize> {
let mut max = node.index;
for child in &node.children {
if let AriaChild::Node(child_node) = child {
if let Some(child_max) = self.find_max_index(child_node) {
max = match max {
Some(current) => Some(current.max(child_max)),
None => Some(child_max),
};
}
}
}
max
}
fn collect_iframe_indices(&mut self, node: &AriaNode) {
if let Some(index) = node.index {
if node.role == "iframe" {
self.iframe_indices.push(index);
}
}
for child in &node.children {
if let AriaChild::Node(child_node) = child {
self.collect_iframe_indices(child_node);
}
}
}
pub fn get_selector(&self, index: usize) -> Option<&String> {
self.selectors.get(index).filter(|s| !s.is_empty())
}
pub fn interactive_indices(&self) -> Vec<usize> {
let mut indices = Vec::new();
self.collect_indices(&self.root, &mut indices);
indices.sort();
indices
}
fn collect_indices(&self, node: &AriaNode, indices: &mut Vec<usize>) {
if let Some(index) = node.index {
indices.push(index);
}
for child in &node.children {
if let AriaChild::Node(child_node) = child {
self.collect_indices(child_node, indices);
}
}
}
pub fn count_nodes(&self) -> usize {
self.root.count_nodes()
}
pub fn count_interactive(&self) -> usize {
self.root.count_interactive()
}
pub fn find_node_by_index(&self, index: usize) -> Option<&AriaNode> {
self.root.find_by_index(index)
}
pub fn find_node_by_index_mut(&mut self, index: usize) -> Option<&mut AriaNode> {
self.root.find_by_index_mut(index)
}
pub fn get_iframe_indices(&self) -> &[usize] {
&self.iframe_indices
}
pub fn to_json(&self) -> Result<String> {
serde_json::to_string_pretty(&self.root).map_err(|e| {
BrowserError::DomParseFailed(format!("Failed to serialize DOM to JSON: {}", e))
})
}
pub fn inject_iframe_content(&mut self, iframe_index: usize, iframe_snapshot: DomTree) {
if let Some(iframe_node) = self.find_node_by_index_mut(iframe_index) {
iframe_node.children = iframe_snapshot.root.children;
let offset = self.selectors.len();
for selector in iframe_snapshot.selectors {
if !selector.is_empty() {
self.selectors.push(selector);
}
}
for idx in iframe_snapshot.iframe_indices {
self.iframe_indices.push(idx + offset);
}
}
}
pub fn assemble_with_iframes<F>(mut self, mut get_iframe_snapshot: F) -> Self
where
F: FnMut(usize) -> Option<DomTree>,
{
let iframe_indices = self.iframe_indices.clone();
for iframe_index in iframe_indices {
if let Some(iframe_snapshot) = get_iframe_snapshot(iframe_index) {
self.inject_iframe_content(iframe_index, iframe_snapshot);
}
}
self
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fragment root holding an indexed button, an indexed link, and an unindexed
    /// paragraph with a single text child.
    fn create_test_tree() -> AriaNode {
        let button = AriaNode::new("button", "Click me")
            .with_index(0)
            .with_box(true, Some("pointer".to_string()));
        let link = AriaNode::new("link", "Go to page")
            .with_index(1)
            .with_box(true, None);
        let paragraph =
            AriaNode::new("paragraph", "").with_child(AriaChild::Text("Some text".to_string()));

        let mut root = AriaNode::fragment();
        for node in [button, link, paragraph] {
            root.children.push(AriaChild::Node(Box::new(node)));
        }
        root
    }

    #[test]
    fn test_find_node_by_index() {
        let tree = DomTree::new(create_test_tree());

        let button = tree
            .find_node_by_index(0)
            .expect("index 0 should resolve to a node");
        assert_eq!(button.role, "button");
        assert_eq!(button.name, "Click me");

        assert!(tree.find_node_by_index(999).is_none());
    }

    #[test]
    fn test_count_nodes() {
        let tree = DomTree::new(create_test_tree());
        assert_eq!(tree.count_nodes(), 4);
    }

    #[test]
    fn test_interactive_indices() {
        let indices = DomTree::new(create_test_tree()).interactive_indices();
        assert_eq!(indices.len(), 2);
        for expected in [0, 1] {
            assert!(indices.contains(&expected));
        }
    }

    #[test]
    fn test_inject_iframe_content() {
        // Host document: a fragment with a single iframe at index 0.
        let mut main_tree = AriaNode::fragment();
        main_tree.children.push(AriaChild::Node(Box::new(
            AriaNode::new("iframe", "").with_index(0),
        )));

        // Frame document: a fragment with a single button.
        let mut iframe_tree = AriaNode::fragment();
        iframe_tree.children.push(AriaChild::Node(Box::new(
            AriaNode::new("button", "Inside iframe").with_index(0),
        )));

        let mut main = DomTree::new(main_tree);
        main.inject_iframe_content(0, DomTree::new(iframe_tree));

        let iframe_node = main.find_node_by_index(0).unwrap();
        assert_eq!(iframe_node.children.len(), 1);

        let injected = match &iframe_node.children[0] {
            AriaChild::Node(n) => n,
            _ => panic!("Expected node child"),
        };
        assert_eq!(injected.role, "button");
        assert_eq!(injected.name, "Inside iframe");
    }
}