use crate::dom::views::{
DEFAULT_INCLUDE_ATTRIBUTES, DOMInteractedElement, EnhancedDOMTreeNode, NodeType,
SerializedDOMState,
};
use std::collections::HashMap;
/// One entry of the simplified DOM tree that `DOMTreeSerializer` builds
/// before rendering it to text.
#[derive(Debug, Clone)]
pub struct SimplifiedNode {
/// Clone of the DOM node this entry was created from.
pub original_node: EnhancedDOMTreeNode,
/// Simplified entries for the node's regular children, then its shadow
/// roots, then its content document (the order used when the tree is built).
pub children: Vec<SimplifiedNode>,
/// When `false`, the node itself is left out of the rendered output, but
/// its children are still visited and rendered.
pub should_display: bool,
/// Set to `true` when the node has been given an interactive index.
pub is_interactive: bool,
/// Index assigned to an interactive node; it is also the key of the node's
/// entry in the serializer's selector map. `None` for every other node.
pub interactive_index: Option<u32>,
}
impl SimplifiedNode {
pub fn new(node: EnhancedDOMTreeNode) -> Self {
Self {
original_node: node,
children: Vec::new(),
should_display: true,
is_interactive: false,
interactive_index: None,
}
}
}
/// Turns an enhanced DOM tree into a text rendering plus a map of the
/// interactive elements found in it.
pub struct DOMTreeSerializer {
/// Root of the DOM tree to serialize.
root_node: EnhancedDOMTreeNode,
/// Next interactive index to hand out; numbering starts at 1.
interactive_counter: u32,
/// Interactive index -> description of the element that received it.
selector_map: HashMap<u32, DOMInteractedElement>,
}
impl DOMTreeSerializer {
pub fn new(root_node: EnhancedDOMTreeNode) -> Self {
Self {
root_node,
interactive_counter: 1,
selector_map: HashMap::new(),
}
}
/// Consumes the serializer and produces the serialized DOM state.
///
/// The interactive counter and selector map are reset, the simplified tree
/// is built from the root node, interactive nodes are numbered, and the tree
/// is rendered with `DEFAULT_INCLUDE_ATTRIBUTES` starting at depth 0.
///
/// # Returns
///
/// A tuple of:
/// * the state, where the rendered string is stored in both `text` and
///   `markdown`, `html` is `None`, `elements` is empty and `selector_map`
///   holds the numbered interactive elements;
/// * a map that is always empty.
pub fn serialize_accessible_elements(mut self) -> (SerializedDOMState, HashMap<String, f64>) {
    self.selector_map.clear();
    self.interactive_counter = 1;

    let mut tree = self._create_simplified_tree(&self.root_node);
    self._assign_interactive_indices(&mut tree);

    let rendered = Self::serialize_tree(&tree, DEFAULT_INCLUDE_ATTRIBUTES, 0);
    let state = SerializedDOMState {
        html: None,
        // `text` takes a copy so that `markdown` can take ownership below.
        text: Some(rendered.clone()),
        markdown: Some(rendered),
        elements: vec![],
        selector_map: self.selector_map,
    };

    (state, HashMap::new())
}
/// Recursively mirrors `node` and everything below it as `SimplifiedNode`s.
///
/// Each entry stores a clone of its DOM node and the display verdict from
/// `_should_display_node`. Children are appended in a fixed order: regular
/// child nodes, then shadow roots, then the content document (if any).
fn _create_simplified_tree(&self, node: &EnhancedDOMTreeNode) -> SimplifiedNode {
    let mut entry = SimplifiedNode::new(node.clone());
    entry.should_display = self._should_display_node(node);

    for child in node.children_nodes.iter().flatten() {
        entry.children.push(self._create_simplified_tree(child));
    }

    for shadow_root in node.shadow_roots.iter().flatten() {
        entry.children.push(self._create_simplified_tree(shadow_root));
    }

    if let Some(content_doc) = node.content_document.as_ref() {
        entry.children.push(self._create_simplified_tree(content_doc));
    }

    entry
}
fn _should_display_node(&self, node: &EnhancedDOMTreeNode) -> bool {
if let Some(attrs) = node.attributes.get("disabled") {
if attrs.as_str() == "true" || attrs.as_str() == "disabled" {
return false;
}
}
if let Some(ref snapshot) = node.snapshot_node {
if let Some(ref styles) = snapshot.computed_styles {
if let Some(display) = styles.get("display") {
if display == "none" {
return false;
}
}
if let Some(visibility) = styles.get("visibility") {
if visibility == "hidden" {
return false;
}
}
}
}
let tag = node.tag_name();
if matches!(
tag.as_str(),
"script" | "style" | "head" | "meta" | "link" | "title"
) {
return false;
}
true
}
/// Walks the simplified tree depth-first (parent before children) and
/// numbers every displayed node that is interactive.
///
/// A displayed node counts as interactive when its snapshot flags it as
/// clickable or `_is_interactive_element` accepts it. Such a node receives
/// the next counter value, is marked on the `SimplifiedNode`, and gets a
/// `DOMInteractedElement` entry in `selector_map` under that index.
/// Nodes that are not displayed are never numbered, but their children are
/// still visited.
fn _assign_interactive_indices(&mut self, simplified: &mut SimplifiedNode) {
    if simplified.should_display {
        let element = &simplified.original_node;
        let flagged_clickable = element
            .snapshot_node
            .as_ref()
            .and_then(|snapshot| snapshot.is_clickable)
            .unwrap_or(false);

        if flagged_clickable || self._is_interactive_element(element) {
            let index = self.interactive_counter;
            self.interactive_counter += 1;

            let entry = DOMInteractedElement {
                index,
                backend_node_id: Some(element.backend_node_id as u32),
                tag: element.tag_name(),
                text: self._get_element_text(element),
                attributes: element.attributes.clone(),
                selector: Some(self.generate_xpath_selector(element)),
            };
            self.selector_map.insert(index, entry);

            simplified.is_interactive = true;
            simplified.interactive_index = Some(index);
        }
    }

    // Children are visited whether or not this node was displayed.
    for child in simplified.children.iter_mut() {
        self._assign_interactive_indices(child);
    }
}
/// Reports whether `node` is interactive judging by its tag or ARIA role.
///
/// Accepts the tags `a`, `button`, `input`, `select`, `textarea`, `label`,
/// or a `role` attribute equal to `button`, `link`, `menuitem`, `tab` or
/// `option`.
fn _is_interactive_element(&self, node: &EnhancedDOMTreeNode) -> bool {
    const INTERACTIVE_TAGS: [&str; 6] = ["a", "button", "input", "select", "textarea", "label"];
    const INTERACTIVE_ROLES: [&str; 5] = ["button", "link", "menuitem", "tab", "option"];

    let tag = node.tag_name();
    if INTERACTIVE_TAGS.contains(&tag.as_str()) {
        return true;
    }

    match node.attributes.get("role") {
        Some(role) => INTERACTIVE_ROLES.contains(&role.as_str()),
        None => false,
    }
}
fn _get_element_text(&self, node: &EnhancedDOMTreeNode) -> Option<String> {
if let Some(label) = node.attributes.get("aria-label") {
if !label.is_empty() {
return Some(label.clone());
}
}
if let Some(value) = node.attributes.get("value") {
if !value.is_empty() {
return Some(value.clone());
}
}
if let Some(placeholder) = node.attributes.get("placeholder") {
if !placeholder.is_empty() {
return Some(placeholder.clone());
}
}
if node.node_type == NodeType::TextNode && !node.node_value.trim().is_empty() {
return Some(node.node_value.trim().to_string());
}
None
}
/// Renders `node` and its subtree as tab-indented lines joined by `\n`.
///
/// * A node that is not displayed contributes nothing itself; its children
///   are rendered at the same `depth`.
/// * An element node emits `tag`, then the selected attributes (when any
///   are non-empty), then `[index]` when it has an interactive index, all
///   separated by single spaces, followed by its children one level deeper.
/// * A text node emits its trimmed value when that is longer than one byte,
///   and its children are not visited.
/// * Any other node type emits only its children, one level deeper.
///
/// Child renderings that are blank after trimming are dropped.
pub fn serialize_tree(
    node: &SimplifiedNode,
    include_attributes: &[&str],
    depth: usize,
) -> String {
    if !node.should_display {
        return Self::_serialize_children(node, include_attributes, depth);
    }

    let indent = "\t".repeat(depth);
    let mut lines: Vec<String> = Vec::new();

    // Each arm emits the node's own line (if any) and reports whether the
    // children should be rendered underneath it.
    let descend = match node.original_node.node_type {
        NodeType::ElementNode => {
            let mut header = format!("{}{}", indent, node.original_node.tag_name());
            let attrs = Self::_build_attributes_string(&node.original_node, include_attributes);
            if !attrs.is_empty() {
                header.push(' ');
                header.push_str(&attrs);
            }
            if let Some(index) = node.interactive_index {
                header.push_str(&format!(" [{index}]"));
            }
            lines.push(header);
            true
        }
        NodeType::TextNode => {
            let text = node.original_node.node_value.trim();
            if text.len() > 1 {
                lines.push(format!("{indent}{text}"));
            }
            false
        }
        _ => true,
    };

    if descend {
        for child in &node.children {
            let rendered = Self::serialize_tree(child, include_attributes, depth + 1);
            if !rendered.trim().is_empty() {
                lines.push(rendered);
            }
        }
    }

    lines.join("\n")
}
/// Renders only the children of `node`, each at the given `depth`, and joins
/// the non-blank results with `\n`.
fn _serialize_children(
    node: &SimplifiedNode,
    include_attributes: &[&str],
    depth: usize,
) -> String {
    node.children
        .iter()
        .map(|child| Self::serialize_tree(child, include_attributes, depth))
        .filter(|rendered| !rendered.trim().is_empty())
        .collect::<Vec<String>>()
        .join("\n")
}
/// Formats the attributes of `node` that are named in `include_attributes`
/// as space-separated `name="value"` pairs.
///
/// Pairs follow the order of `include_attributes`; attributes that are
/// missing or have an empty value are skipped. Values are inserted verbatim.
fn _build_attributes_string(node: &EnhancedDOMTreeNode, include_attributes: &[&str]) -> String {
    include_attributes
        .iter()
        .filter_map(|attr_name| {
            node.attributes
                .get(*attr_name)
                .filter(|value| !value.is_empty())
                .map(|value| format!("{attr_name}=\"{value}\""))
        })
        .collect::<Vec<String>>()
        .join(" ")
}
/// Looks up the selector-map entry whose `backend_node_id` matches that of
/// `node` (after the same `as u32` conversion used when entries are stored).
///
/// Returns `None` when no entry matches.
fn _find_interacted_element(
    &self,
    node: &EnhancedDOMTreeNode,
) -> Option<&DOMInteractedElement> {
    let wanted = Some(node.backend_node_id as u32);
    for candidate in self.selector_map.values() {
        if candidate.backend_node_id == wanted {
            return Some(candidate);
        }
    }
    None
}
/// Builds an XPath selector for `node`, preferring the most specific hint.
///
/// Tried in order, using the first that applies:
/// 1. a non-empty `id` attribute: `//*[@id=...]`;
/// 2. for `input`, `select` and `textarea`, a non-empty `name` attribute:
///    `//tag[@name=...]`;
/// 3. the first non-empty attribute among `data-testid`, `data-cy`,
///    `data-test`, `data-qa`: `//*[@attr=...]`;
/// 4. the node's position among same-tag siblings: `//tag[n]`;
/// 5. the bare tag: `//tag`.
///
/// Attribute values are quoted with `_escape_xpath_string`.
pub fn generate_xpath_selector(&self, node: &EnhancedDOMTreeNode) -> String {
    let tag_name = node.tag_name();

    let id = node.attributes.get("id").filter(|id| !id.is_empty());
    if let Some(id) = id {
        return format!("//*[@id={}]", Self::_escape_xpath_string(id));
    }

    if matches!(tag_name.as_str(), "input" | "select" | "textarea") {
        let name = node.attributes.get("name").filter(|name| !name.is_empty());
        if let Some(name) = name {
            return format!("//{}[@name={}]", tag_name, Self::_escape_xpath_string(name));
        }
    }

    let by_test_hook = ["data-testid", "data-cy", "data-test", "data-qa"]
        .iter()
        .find_map(|attr| {
            node.attributes
                .get(*attr)
                .filter(|value| !value.is_empty())
                .map(|value| format!("//*[@{}={}]", attr, Self::_escape_xpath_string(value)))
        });
    if let Some(selector) = by_test_hook {
        return selector;
    }

    match self._get_node_position(node) {
        Some(position) => format!("//{}[{}]", tag_name, position),
        None => format!("//{}", tag_name),
    }
}
/// Returns the 1-based position of `node` among its parent's children that
/// share its tag name, matching siblings by `backend_node_id`.
///
/// Returns `None` when the node has no parent, the parent has no child
/// list, or the node is not found in that list.
fn _get_node_position(&self, node: &EnhancedDOMTreeNode) -> Option<usize> {
    let parent = node.parent_node.as_ref()?;
    let siblings = parent.children_nodes.as_ref()?;
    let tag_name = node.tag_name();

    siblings
        .iter()
        .filter(|sibling| sibling.tag_name() == tag_name)
        .position(|sibling| sibling.backend_node_id == node.backend_node_id)
        .map(|zero_based| zero_based + 1)
}
/// Quotes `s` as an XPath 1.0 string expression.
///
/// XPath 1.0 string literals have no escape character, so the quoting
/// depends on which quote characters the value contains:
/// * no `'`: wrapped in single quotes;
/// * `'` but no `"`: wrapped in double quotes;
/// * both: split on `'` and rebuilt with `concat(...)`, where each piece is
///   single-quoted and every original apostrophe is re-inserted as the
///   double-quoted literal `"'"`.
///
/// For example `a'b"c` becomes `concat('a', "'", 'b"c')`.
fn _escape_xpath_string(s: &str) -> String {
    let has_single = s.contains('\'');
    let has_double = s.contains('"');

    match (has_single, has_double) {
        (false, _) => format!("'{s}'"),
        (true, false) => format!("\"{s}\""),
        (true, true) => {
            let pieces: Vec<String> = s
                .split('\'')
                .map(|piece| format!("'{piece}'"))
                .collect();
            // The separator must be the apostrophe that `split` removed;
            // any other literal would change the value being matched.
            format!("concat({})", pieces.join(", \"'\", "))
        }
    }
}
}