// browser_use/dom/tree.rs

1use crate::dom::element::ElementNode;
2use crate::dom::selector_map::{ElementSelector, SelectorMap};
3use crate::error::{BrowserError, Result};
4use headless_chrome::Tab;
5use std::sync::Arc;
6
7/// Represents the DOM tree of a web page
8#[derive(Debug, Clone)]
9pub struct DomTree {
10    /// Root element of the DOM tree
11    pub root: ElementNode,
12
13    /// Map of indices to element selectors
14    pub selector_map: SelectorMap,
15}
16
17impl DomTree {
18    /// Create a new empty DomTree
19    pub fn new(root: ElementNode) -> Self {
20        Self {
21            root,
22            selector_map: SelectorMap::new(),
23        }
24    }
25
26    /// Build DOM tree from a browser tab
27    pub fn from_tab(tab: &Arc<Tab>) -> Result<Self> {
28        // JavaScript code to extract simplified DOM structure
29        // This returns a JSON string
30        let js_code = include_str!("extract_dom.js");
31
32        // Execute JavaScript to extract DOM
33        let result = tab.evaluate(js_code, false).map_err(|e| {
34            BrowserError::DomParseFailed(format!("Failed to execute DOM extraction script: {}", e))
35        })?;
36
37        // Get the JSON string value
38        let json_value = result.value.ok_or_else(|| {
39            BrowserError::DomParseFailed("No value returned from DOM extraction".to_string())
40        })?;
41
42        // The JavaScript returns a JSON string, so we need to parse it as a string first
43        let json_str: String = serde_json::from_value(json_value).map_err(|e| {
44            BrowserError::DomParseFailed(format!("Failed to get JSON string: {}", e))
45        })?;
46
47        // Then parse the JSON string into ElementNode
48        let root: ElementNode = serde_json::from_str(&json_str).map_err(|e| {
49            BrowserError::DomParseFailed(format!("Failed to parse DOM JSON: {}", e))
50        })?;
51
52        let mut tree = Self::new(root);
53        tree.build_selector_map();
54
55        Ok(tree)
56    }
57
58    /// Build the selector map by traversing the DOM tree
59    fn build_selector_map(&mut self) {
60        self.selector_map.clear();
61        let mut index_counter = 0;
62        Self::traverse_and_index_static(
63            &mut self.root,
64            "body",
65            &mut self.selector_map,
66            &mut index_counter,
67        );
68    }
69
70    /// Static method to recursively traverse and index elements
71    fn traverse_and_index_static(
72        node: &mut ElementNode,
73        css_path: &str,
74        selector_map: &mut SelectorMap,
75        _index_counter: &mut usize,
76    ) {
77        // Compute interactivity for this node
78        node.compute_interactivity();
79
80        // If the element is interactive, assign it an index
81        if node.is_interactive && node.is_visible {
82            let selector = Self::build_selector_static(node, css_path);
83            let index = selector_map.register(selector);
84            node.index = Some(index);
85        }
86
87        // Recursively process children
88        for (i, child) in node.children.iter_mut().enumerate() {
89            let child_path = format!("{} > {}:nth-child({})", css_path, child.tag_name, i + 1);
90            Self::traverse_and_index_static(child, &child_path, selector_map, _index_counter);
91        }
92    }
93
94    /// Build an ElementSelector for a given node (static version)
95    fn build_selector_static(node: &ElementNode, css_path: &str) -> ElementSelector {
96        // Prefer ID selector if available
97        let css_selector = if let Some(id) = &node.id() {
98            format!("#{}", id)
99        } else if let Some(class) = node.get_attribute("class") {
100            format!(
101                "{}.{}",
102                node.tag_name,
103                class.split_whitespace().next().unwrap_or("")
104            )
105        } else {
106            css_path.to_string()
107        };
108
109        let mut selector = ElementSelector::new(css_selector, &node.tag_name);
110
111        if let Some(id) = node.id() {
112            selector = selector.with_id(id);
113        }
114
115        if let Some(text) = &node.text_content {
116            // Truncate text for display
117            let truncated = if text.len() > 50 {
118                format!("{}...", &text[..47])
119            } else {
120                text.clone()
121            };
122            selector = selector.with_text(truncated);
123        }
124
125        selector
126    }
127
128    /// Simplify the DOM tree by removing unnecessary elements
129    pub fn simplify(&mut self) {
130        self.root.simplify();
131        self.build_selector_map(); // Rebuild map after simplification
132    }
133
134    /// Convert the DOM tree to JSON
135    pub fn to_json(&self) -> Result<String> {
136        serde_json::to_string_pretty(&self.root).map_err(|e| {
137            BrowserError::DomParseFailed(format!("Failed to serialize DOM to JSON: {}", e))
138        })
139    }
140
141    /// Get element selector by index
142    pub fn get_selector(&self, index: usize) -> Option<&ElementSelector> {
143        self.selector_map.get(index)
144    }
145
146    /// Get all interactive element indices
147    pub fn interactive_indices(&self) -> Vec<usize> {
148        self.selector_map.indices().copied().collect()
149    }
150
151    /// Count total elements in the tree
152    pub fn count_elements(&self) -> usize {
153        self.count_elements_recursive(&self.root)
154    }
155
156    fn count_elements_recursive(&self, node: &ElementNode) -> usize {
157        1 + node
158            .children
159            .iter()
160            .map(|c| self.count_elements_recursive(c))
161            .sum::<usize>()
162    }
163
164    /// Count interactive elements
165    pub fn count_interactive(&self) -> usize {
166        self.selector_map.len()
167    }
168
169    /// Find element node by index (traverse the tree)
170    pub fn find_node_by_index(&self, index: usize) -> Option<&ElementNode> {
171        self.find_node_by_index_recursive(&self.root, index)
172    }
173
174    fn find_node_by_index_recursive<'a>(
175        &self,
176        node: &'a ElementNode,
177        target_index: usize,
178    ) -> Option<&'a ElementNode> {
179        if node.index == Some(target_index) {
180            return Some(node);
181        }
182
183        for child in &node.children {
184            if let Some(found) = self.find_node_by_index_recursive(child, target_index) {
185                return Some(found);
186            }
187        }
188
189        None
190    }
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture page shared by most tests:
    ///
    /// ```text
    /// body
    /// +-- header
    /// |   +-- button#nav-btn   "Menu"        (visible)
    /// +-- main
    ///     +-- a[href=/page]    "Click here"  (visible)
    ///     +-- div.content      "Some text"
    /// ```
    fn create_test_tree() -> ElementNode {
        let button = {
            let mut el = ElementNode::new("button");
            el.add_attribute("id", "nav-btn");
            el.text_content = Some("Menu".to_string());
            el.is_visible = true;
            el
        };
        let anchor = {
            let mut el = ElementNode::new("a");
            el.add_attribute("href", "/page");
            el.text_content = Some("Click here".to_string());
            el.is_visible = true;
            el
        };
        let content = {
            let mut el = ElementNode::new("div");
            el.add_attribute("class", "content");
            el.text_content = Some("Some text".to_string());
            el
        };

        let mut header = ElementNode::new("header");
        header.add_child(button);

        let mut main = ElementNode::new("main");
        main.add_child(anchor);
        main.add_child(content);

        let mut body = ElementNode::new("body");
        body.add_child(header);
        body.add_child(main);
        body
    }

    #[test]
    fn test_dom_tree_creation() {
        let tree = DomTree::new(create_test_tree());

        assert_eq!(tree.root.tag_name, "body");
        assert_eq!(tree.root.children.len(), 2);
    }

    #[test]
    fn test_build_selector_map() {
        let mut tree = DomTree::new(create_test_tree());
        tree.build_selector_map();

        // Expected entries: the button and the link.
        assert_eq!(tree.count_interactive(), 2);
    }

    #[test]
    fn test_find_node_by_index() {
        let mut tree = DomTree::new(create_test_tree());
        tree.build_selector_map();

        let indices = tree.interactive_indices();
        assert!(!indices.is_empty());

        for index in indices {
            let found = tree
                .find_node_by_index(index)
                .expect("every registered index should map back to a node");
            assert_eq!(found.index, Some(index));
        }
    }

    #[test]
    fn test_count_elements() {
        let tree = DomTree::new(create_test_tree());

        // body, header, button, main, a, div
        assert_eq!(tree.count_elements(), 6);
    }

    #[test]
    fn test_simplify() {
        let mut body = ElementNode::new("body");
        for (tag, text) in [
            ("p", "Content"),
            ("script", "alert('test')"),
            ("style", ".test {}"),
        ] {
            body.add_child(ElementNode::new(tag).with_text(text));
        }

        let mut tree = DomTree::new(body);
        tree.simplify();

        // Only the paragraph is expected to survive simplification.
        assert_eq!(tree.root.children.len(), 1);
        assert!(tree.root.children[0].is_tag("p"));
    }

    #[test]
    fn test_to_json() {
        let mut container = ElementNode::new("div");
        container.add_attribute("id", "container");
        container.add_child(ElementNode::new("span").with_text("Hello"));

        let json = DomTree::new(container).to_json().unwrap();

        for needle in [
            "\"tag_name\": \"div\"",
            "\"id\": \"container\"",
            "\"span\"",
            "Hello",
        ] {
            assert!(json.contains(needle), "expected {:?} in serialized DOM", needle);
        }
    }

    #[test]
    fn test_get_selector() {
        let mut tree = DomTree::new(create_test_tree());
        tree.build_selector_map();

        assert!(tree
            .interactive_indices()
            .into_iter()
            .all(|index| tree.get_selector(index).is_some()));
    }
}