1use crate::dom::element::{AriaChild, AriaNode};
2use crate::error::{BrowserError, Result};
3use headless_chrome::Tab;
4use std::sync::Arc;
5
6#[derive(Debug, Clone)]
9pub struct DomTree {
10 pub root: AriaNode,
12
13 pub selectors: Vec<String>,
15
16 pub iframe_indices: Vec<usize>,
18}
19
20#[derive(Debug, serde::Deserialize)]
22struct SnapshotResponse {
23 root: AriaNode,
24 selectors: Vec<String>,
25 #[serde(rename = "iframeIndices")]
26 iframe_indices: Vec<usize>,
27}
28
29impl DomTree {
30 pub fn new(root: AriaNode) -> Self {
32 let mut tree = Self {
33 root,
34 selectors: Vec::new(),
35 iframe_indices: Vec::new(),
36 };
37 tree.rebuild_maps();
38 tree
39 }
40
41 pub fn from_tab(tab: &Arc<Tab>) -> Result<Self> {
43 Self::from_tab_with_prefix(tab, "")
44 }
45
46 pub fn from_tab_with_prefix(tab: &Arc<Tab>, _ref_prefix: &str) -> Result<Self> {
48 let js_code = include_str!("extract_dom.js");
51
52 let result = tab.evaluate(js_code, false).map_err(|e| {
54 BrowserError::DomParseFailed(format!("Failed to execute DOM extraction script: {}", e))
55 })?;
56
57 let json_value = result.value.ok_or_else(|| {
59 BrowserError::DomParseFailed("No value returned from DOM extraction".to_string())
60 })?;
61
62 let json_str: String = serde_json::from_value(json_value).map_err(|e| {
64 BrowserError::DomParseFailed(format!("Failed to get JSON string: {}", e))
65 })?;
66
67 let response: SnapshotResponse = serde_json::from_str(&json_str).map_err(|e| {
69 BrowserError::DomParseFailed(format!("Failed to parse snapshot JSON: {}", e))
70 })?;
71
72 Ok(Self {
73 root: response.root,
74 selectors: response.selectors,
75 iframe_indices: response.iframe_indices,
76 })
77 }
78
79 fn rebuild_maps(&mut self) {
83 self.iframe_indices.clear();
84
85 let max_index = self.find_max_index(&self.root.clone());
87
88 if let Some(max_idx) = max_index {
90 if self.selectors.len() <= max_idx {
91 self.selectors.resize(max_idx + 1, String::new());
92 }
93 }
94
95 let root = self.root.clone();
97 self.collect_iframe_indices(&root);
98 }
99
100 fn find_max_index(&self, node: &AriaNode) -> Option<usize> {
101 let mut max = node.index;
102
103 for child in &node.children {
104 if let AriaChild::Node(child_node) = child {
105 if let Some(child_max) = self.find_max_index(child_node) {
106 max = match max {
107 Some(current) => Some(current.max(child_max)),
108 None => Some(child_max),
109 };
110 }
111 }
112 }
113
114 max
115 }
116
117 fn collect_iframe_indices(&mut self, node: &AriaNode) {
118 if let Some(index) = node.index {
119 if node.role == "iframe" {
120 self.iframe_indices.push(index);
121 }
122 }
123
124 for child in &node.children {
125 if let AriaChild::Node(child_node) = child {
126 self.collect_iframe_indices(child_node);
127 }
128 }
129 }
130
131 pub fn get_selector(&self, index: usize) -> Option<&String> {
133 self.selectors.get(index).filter(|s| !s.is_empty())
134 }
135
136 pub fn interactive_indices(&self) -> Vec<usize> {
138 let mut indices = Vec::new();
139 self.collect_indices(&self.root, &mut indices);
140 indices.sort();
141 indices
142 }
143
144 fn collect_indices(&self, node: &AriaNode, indices: &mut Vec<usize>) {
145 if let Some(index) = node.index {
146 indices.push(index);
147 }
148 for child in &node.children {
149 if let AriaChild::Node(child_node) = child {
150 self.collect_indices(child_node, indices);
151 }
152 }
153 }
154
155 pub fn count_nodes(&self) -> usize {
157 self.root.count_nodes()
158 }
159
160 pub fn count_interactive(&self) -> usize {
162 self.root.count_interactive()
163 }
164
165 pub fn find_node_by_index(&self, index: usize) -> Option<&AriaNode> {
167 self.root.find_by_index(index)
168 }
169
170 pub fn find_node_by_index_mut(&mut self, index: usize) -> Option<&mut AriaNode> {
172 self.root.find_by_index_mut(index)
173 }
174
175 pub fn get_iframe_indices(&self) -> &[usize] {
177 &self.iframe_indices
178 }
179
180 pub fn to_json(&self) -> Result<String> {
182 serde_json::to_string_pretty(&self.root).map_err(|e| {
183 BrowserError::DomParseFailed(format!("Failed to serialize DOM to JSON: {}", e))
184 })
185 }
186
187 pub fn inject_iframe_content(&mut self, iframe_index: usize, iframe_snapshot: DomTree) {
190 if let Some(iframe_node) = self.find_node_by_index_mut(iframe_index) {
191 iframe_node.children = iframe_snapshot.root.children;
193
194 let offset = self.selectors.len();
196 for selector in iframe_snapshot.selectors {
197 if !selector.is_empty() {
198 self.selectors.push(selector);
199 }
200 }
201
202 for idx in iframe_snapshot.iframe_indices {
204 self.iframe_indices.push(idx + offset);
205 }
206 }
207 }
208
209 pub fn assemble_with_iframes<F>(mut self, mut get_iframe_snapshot: F) -> Self
212 where
213 F: FnMut(usize) -> Option<DomTree>,
214 {
215 let iframe_indices = self.iframe_indices.clone();
216
217 for iframe_index in iframe_indices {
218 if let Some(iframe_snapshot) = get_iframe_snapshot(iframe_index) {
219 self.inject_iframe_content(iframe_index, iframe_snapshot);
220 }
221 }
222
223 self
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `node` as a boxed child entry.
    fn child(node: AriaNode) -> AriaChild {
        AriaChild::Node(Box::new(node))
    }

    /// Fragment containing an indexed button (0), an indexed link (1) and an unindexed
    /// paragraph that holds a single text child.
    fn create_test_tree() -> AriaNode {
        let button = AriaNode::new("button", "Click me")
            .with_index(0)
            .with_box(true, Some("pointer".to_string()));
        let link = AriaNode::new("link", "Go to page")
            .with_index(1)
            .with_box(true, None);
        let paragraph =
            AriaNode::new("paragraph", "").with_child(AriaChild::Text("Some text".to_string()));

        let mut root = AriaNode::fragment();
        root.children
            .extend(vec![child(button), child(link), child(paragraph)]);
        root
    }

    #[test]
    fn test_find_node_by_index() {
        let tree = DomTree::new(create_test_tree());

        let button = tree
            .find_node_by_index(0)
            .expect("index 0 should resolve to the button");
        assert_eq!(button.role, "button");
        assert_eq!(button.name, "Click me");

        assert!(tree.find_node_by_index(999).is_none());
    }

    #[test]
    fn test_count_nodes() {
        let tree = DomTree::new(create_test_tree());
        assert_eq!(tree.count_nodes(), 4);
    }

    #[test]
    fn test_interactive_indices() {
        let tree = DomTree::new(create_test_tree());

        // `interactive_indices` returns sorted output, so this pins both length and contents.
        assert_eq!(tree.interactive_indices(), vec![0, 1]);
    }

    #[test]
    fn test_inject_iframe_content() {
        let mut host_root = AriaNode::fragment();
        host_root
            .children
            .push(child(AriaNode::new("iframe", "").with_index(0)));

        let mut frame_root = AriaNode::fragment();
        frame_root
            .children
            .push(child(AriaNode::new("button", "Inside iframe").with_index(0)));

        let mut host = DomTree::new(host_root);
        let frame = DomTree::new(frame_root);

        host.inject_iframe_content(0, frame);

        let iframe_node = host
            .find_node_by_index(0)
            .expect("iframe node should still be reachable by index 0");
        assert_eq!(iframe_node.children.len(), 1);

        if let AriaChild::Node(n) = &iframe_node.children[0] {
            assert_eq!(n.role, "button");
            assert_eq!(n.name, "Inside iframe");
        } else {
            panic!("Expected node child");
        }
    }
}