browser_use/dom/
tree.rs

1use crate::dom::element::{AriaChild, AriaNode};
2use crate::error::{BrowserError, Result};
3use headless_chrome::Tab;
4use std::sync::Arc;
5
6/// Represents the ARIA snapshot of a web page
7/// Based on Playwright's AriaSnapshot structure
8#[derive(Debug, Clone)]
9pub struct DomTree {
10    /// Root AriaNode (usually a fragment)
11    pub root: AriaNode,
12
13    /// Array of CSS selectors indexed by element index
14    pub selectors: Vec<String>,
15
16    /// List of iframe indices (for multi-frame snapshots)
17    pub iframe_indices: Vec<usize>,
18}
19
20/// Snapshot extraction response from JavaScript
21#[derive(Debug, serde::Deserialize)]
22struct SnapshotResponse {
23    root: AriaNode,
24    selectors: Vec<String>,
25    #[serde(rename = "iframeIndices")]
26    iframe_indices: Vec<usize>,
27}
28
29impl DomTree {
30    /// Create a new DomTree from an AriaNode
31    pub fn new(root: AriaNode) -> Self {
32        let mut tree = Self {
33            root,
34            selectors: Vec::new(),
35            iframe_indices: Vec::new(),
36        };
37        tree.rebuild_maps();
38        tree
39    }
40
41    /// Build DOM tree from a browser tab
42    pub fn from_tab(tab: &Arc<Tab>) -> Result<Self> {
43        Self::from_tab_with_prefix(tab, "")
44    }
45
46    /// Build DOM tree from a browser tab with a ref prefix (for iframe handling)
47    pub fn from_tab_with_prefix(tab: &Arc<Tab>, _ref_prefix: &str) -> Result<Self> {
48        // Note: ref_prefix is deprecated but kept for API compatibility
49        // JavaScript code to extract ARIA snapshot
50        let js_code = include_str!("extract_dom.js");
51
52        // Execute JavaScript to extract DOM
53        let result = tab.evaluate(js_code, false).map_err(|e| {
54            BrowserError::DomParseFailed(format!("Failed to execute DOM extraction script: {}", e))
55        })?;
56
57        // Get the JSON string value
58        let json_value = result.value.ok_or_else(|| {
59            BrowserError::DomParseFailed("No value returned from DOM extraction".to_string())
60        })?;
61
62        // The JavaScript returns a JSON string, so we need to parse it as a string first
63        let json_str: String = serde_json::from_value(json_value).map_err(|e| {
64            BrowserError::DomParseFailed(format!("Failed to get JSON string: {}", e))
65        })?;
66
67        // Then parse the JSON string into SnapshotResponse
68        let response: SnapshotResponse = serde_json::from_str(&json_str).map_err(|e| {
69            BrowserError::DomParseFailed(format!("Failed to parse snapshot JSON: {}", e))
70        })?;
71
72        Ok(Self {
73            root: response.root,
74            selectors: response.selectors,
75            iframe_indices: response.iframe_indices,
76        })
77    }
78
79    /// Rebuild the selectors array by traversing the tree
80    /// Note: This only resizes the array based on indices found.
81    /// Actual selectors are populated from JavaScript extraction.
82    fn rebuild_maps(&mut self) {
83        self.iframe_indices.clear();
84
85        // Find the maximum index in the tree
86        let max_index = self.find_max_index(&self.root.clone());
87
88        // Resize selectors array if needed
89        if let Some(max_idx) = max_index {
90            if self.selectors.len() <= max_idx {
91                self.selectors.resize(max_idx + 1, String::new());
92            }
93        }
94
95        // Collect iframe indices
96        let root = self.root.clone();
97        self.collect_iframe_indices(&root);
98    }
99
100    fn find_max_index(&self, node: &AriaNode) -> Option<usize> {
101        let mut max = node.index;
102
103        for child in &node.children {
104            if let AriaChild::Node(child_node) = child {
105                if let Some(child_max) = self.find_max_index(child_node) {
106                    max = match max {
107                        Some(current) => Some(current.max(child_max)),
108                        None => Some(child_max),
109                    };
110                }
111            }
112        }
113
114        max
115    }
116
117    fn collect_iframe_indices(&mut self, node: &AriaNode) {
118        if let Some(index) = node.index {
119            if node.role == "iframe" {
120                self.iframe_indices.push(index);
121            }
122        }
123
124        for child in &node.children {
125            if let AriaChild::Node(child_node) = child {
126                self.collect_iframe_indices(child_node);
127            }
128        }
129    }
130
131    /// Get CSS selector for a given index
132    pub fn get_selector(&self, index: usize) -> Option<&String> {
133        self.selectors.get(index).filter(|s| !s.is_empty())
134    }
135
136    /// Get all interactive element indices
137    pub fn interactive_indices(&self) -> Vec<usize> {
138        let mut indices = Vec::new();
139        self.collect_indices(&self.root, &mut indices);
140        indices.sort();
141        indices
142    }
143
144    fn collect_indices(&self, node: &AriaNode, indices: &mut Vec<usize>) {
145        if let Some(index) = node.index {
146            indices.push(index);
147        }
148        for child in &node.children {
149            if let AriaChild::Node(child_node) = child {
150                self.collect_indices(child_node, indices);
151            }
152        }
153    }
154
155    /// Count total nodes in the tree
156    pub fn count_nodes(&self) -> usize {
157        self.root.count_nodes()
158    }
159
160    /// Count interactive elements (elements with indices)
161    pub fn count_interactive(&self) -> usize {
162        self.root.count_interactive()
163    }
164
165    /// Find node by index
166    pub fn find_node_by_index(&self, index: usize) -> Option<&AriaNode> {
167        self.root.find_by_index(index)
168    }
169
170    /// Find node by index (mutable)
171    pub fn find_node_by_index_mut(&mut self, index: usize) -> Option<&mut AriaNode> {
172        self.root.find_by_index_mut(index)
173    }
174
175    /// Get all iframe indices for multi-frame snapshot handling
176    pub fn get_iframe_indices(&self) -> &[usize] {
177        &self.iframe_indices
178    }
179
180    /// Convert the DOM tree to JSON
181    pub fn to_json(&self) -> Result<String> {
182        serde_json::to_string_pretty(&self.root).map_err(|e| {
183            BrowserError::DomParseFailed(format!("Failed to serialize DOM to JSON: {}", e))
184        })
185    }
186
187    /// Replace an iframe node's children with content from another snapshot
188    /// Used for multi-frame snapshot assembly
189    pub fn inject_iframe_content(&mut self, iframe_index: usize, iframe_snapshot: DomTree) {
190        if let Some(iframe_node) = self.find_node_by_index_mut(iframe_index) {
191            // Replace iframe's children with the snapshot's root children
192            iframe_node.children = iframe_snapshot.root.children;
193
194            // Merge selectors (offset by current length)
195            let offset = self.selectors.len();
196            for selector in iframe_snapshot.selectors {
197                if !selector.is_empty() {
198                    self.selectors.push(selector);
199                }
200            }
201
202            // Update iframe indices with offset
203            for idx in iframe_snapshot.iframe_indices {
204                self.iframe_indices.push(idx + offset);
205            }
206        }
207    }
208
209    /// Create a snapshot with multiple frames assembled
210    /// Takes a function that can retrieve snapshots for iframe elements
211    pub fn assemble_with_iframes<F>(mut self, mut get_iframe_snapshot: F) -> Self
212    where
213        F: FnMut(usize) -> Option<DomTree>,
214    {
215        let iframe_indices = self.iframe_indices.clone();
216
217        for iframe_index in iframe_indices {
218            if let Some(iframe_snapshot) = get_iframe_snapshot(iframe_index) {
219                self.inject_iframe_content(iframe_index, iframe_snapshot);
220            }
221        }
222
223        self
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Fragment holding an indexed button (0), an indexed link (1) and an
    /// unindexed paragraph that wraps a text child.
    fn create_test_tree() -> AriaNode {
        let mut root = AriaNode::fragment();
        let elements = vec![
            AriaNode::new("button", "Click me")
                .with_index(0)
                .with_box(true, Some("pointer".to_string())),
            AriaNode::new("link", "Go to page")
                .with_index(1)
                .with_box(true, None),
            AriaNode::new("paragraph", "").with_child(AriaChild::Text("Some text".to_string())),
        ];
        root.children
            .extend(elements.into_iter().map(|node| AriaChild::Node(Box::new(node))));
        root
    }

    #[test]
    fn test_find_node_by_index() {
        let tree = DomTree::new(create_test_tree());

        let button = tree
            .find_node_by_index(0)
            .expect("index 0 should resolve to the button");
        assert_eq!(button.role, "button");
        assert_eq!(button.name, "Click me");

        assert!(tree.find_node_by_index(999).is_none());
    }

    #[test]
    fn test_count_nodes() {
        let tree = DomTree::new(create_test_tree());

        // The fragment root plus button, link and paragraph; the text child is not a node.
        assert_eq!(tree.count_nodes(), 4);
    }

    #[test]
    fn test_interactive_indices() {
        let tree = DomTree::new(create_test_tree());

        // The result is sorted, so exactly the button and link indices, in order.
        assert_eq!(tree.interactive_indices(), vec![0, 1]);
    }

    #[test]
    fn test_inject_iframe_content() {
        let mut host_root = AriaNode::fragment();
        host_root.children.push(AriaChild::Node(Box::new(
            AriaNode::new("iframe", "").with_index(0),
        )));

        let mut frame_root = AriaNode::fragment();
        frame_root.children.push(AriaChild::Node(Box::new(
            AriaNode::new("button", "Inside iframe").with_index(0),
        )));

        let mut host = DomTree::new(host_root);
        host.inject_iframe_content(0, DomTree::new(frame_root));

        // The iframe must now own the grafted button as its only child.
        let iframe_node = host
            .find_node_by_index(0)
            .expect("iframe node should still be addressable");
        assert_eq!(iframe_node.children.len(), 1);

        if let AriaChild::Node(grafted) = &iframe_node.children[0] {
            assert_eq!(grafted.role, "button");
            assert_eq!(grafted.name, "Inside iframe");
        } else {
            panic!("Expected node child");
        }
    }
}