1use super::*;
2
/// Selector-based lookup plus ancestor and sibling helpers for tree-sitter
/// [`Node`]s.
pub trait NodeExt {
  /// Returns the first node matched by `selector`, i.e. the first element of
  /// [`NodeExt::find_all`]'s result, or `None` when nothing matches.
  fn find(&self, selector: &str) -> Option<Node<'_>>;

  /// Returns every node matched by `selector`, searching this node and its
  /// descendants.
  ///
  /// Selector forms: a bare node kind, `a > b` (child combinator), `a b`
  /// (descendant combinator), `sel[n]` (the `n`th result of `sel`),
  /// `^kind` / `^kind[n]` (direct children of this node only), `@n` (the
  /// `n`th direct child regardless of kind), and `a, b` (union).
  fn find_all(&self, selector: &str) -> Vec<Node<'_>>;

  /// Resolves the nearest enclosing `function_definition` node to a
  /// [`Function`] through `document.find_function`, keyed by the text of the
  /// node's `name` field.
  fn get_function(&self, document: &Document) -> Option<Function>;

  /// Returns the nearest strict ancestor (never this node itself) whose kind
  /// equals `kind`.
  fn get_parent(&self, kind: &str) -> Option<Node<'_>>;

  /// Builds an `lsp::Range` from this node's start and end positions, each
  /// converted with `position(document)`.
  fn get_range(&self, document: &Document) -> lsp::Range;

  /// Resolves the nearest enclosing `recipe` node to a [`Recipe`] through
  /// `document.find_recipe`, keyed by the text of its
  /// `recipe_header > identifier`.
  fn get_recipe(&self, document: &Document) -> Option<Recipe>;

  /// Returns `true` if any strict ancestor has one of the given kinds.
  fn has_any_parent(&self, kinds: &[&str]) -> bool;

  /// Iterates over the siblings that follow this node, nearest first. It
  /// does not yield this node or any preceding sibling.
  fn siblings(&self) -> impl Iterator<Item = Node<'_>>;
}
13
14fn collect_nodes_by_kind<'a>(node: Node<'a>, kind: &str) -> Vec<Node<'a>> {
15 let self_match = if node.kind() == kind {
16 vec![node]
17 } else {
18 Vec::new()
19 };
20
21 let children_matches = (0..node.child_count())
22 .filter_map(|i| child_at(&node, i))
23 .flat_map(|child| collect_nodes_by_kind(child, kind))
24 .collect::<Vec<_>>();
25
26 [self_match, children_matches].concat()
27}
28
29fn collect_descendants_by_kind<'a>(
30 node: Node<'a>,
31 kind: &str,
32) -> Vec<Node<'a>> {
33 (0..node.child_count())
34 .filter_map(|i| child_at(&node, i))
35 .flat_map(|child| {
36 let self_match = if child.kind() == kind {
37 vec![child]
38 } else {
39 Vec::new()
40 };
41
42 let descendants = collect_descendants_by_kind(child, kind);
43
44 [self_match, descendants].concat()
45 })
46 .collect()
47}
48
49fn child_at<'a>(node: &Node<'a>, index: usize) -> Option<Node<'a>> {
50 index.try_into().ok().and_then(|index| node.child(index))
51}
52
53impl NodeExt for Node<'_> {
54 fn find(&self, selector: &str) -> Option<Node<'_>> {
55 self.find_all(selector).into_iter().next()
56 }
57
58 fn find_all(&self, selector: &str) -> Vec<Node<'_>> {
59 if selector.contains(',') {
60 return selector
61 .split(',')
62 .map(str::trim)
63 .flat_map(|sub_selector| self.find_all(sub_selector))
64 .collect();
65 }
66
67 if let Some(position_str) = selector.strip_prefix('@') {
68 return position_str
69 .parse::<usize>()
70 .ok()
71 .and_then(|position| child_at(self, position))
72 .map_or_else(Vec::new, |child| vec![child]);
73 }
74
75 if let Some(rest) = selector.strip_prefix('^') {
76 if rest.contains('[') && rest.ends_with(']') {
77 let parts: Vec<&str> = rest.split('[').collect();
78
79 if parts.len() == 2 {
80 let (kind, index_str) = (parts[0], &parts[1][..parts[1].len() - 1]);
81
82 if let Ok(index) = index_str.parse::<usize>() {
83 let direct_children = (0..self.child_count())
84 .filter_map(|i| child_at(self, i))
85 .filter(|child| child.kind() == kind)
86 .collect::<Vec<_>>();
87
88 return direct_children
89 .get(index)
90 .copied()
91 .map_or_else(Vec::new, |node| vec![node]);
92 }
93 }
94 }
95
96 return (0..self.child_count())
97 .filter_map(|i| child_at(self, i))
98 .filter(|child| child.kind() == rest)
99 .collect();
100 }
101
102 if selector.contains('[') && selector.ends_with(']') {
103 let parts: Vec<&str> = selector.split('[').collect();
104
105 if parts.len() == 2 {
106 let (kind, index_str) = (parts[0], &parts[1][..parts[1].len() - 1]);
107
108 if let Ok(index) = index_str.parse::<usize>() {
109 let all_of_kind = self.find_all(kind);
110 return all_of_kind
111 .get(index)
112 .copied()
113 .map_or_else(Vec::new, |node| vec![node]);
114 }
115 }
116 }
117
118 if selector.contains(" > ") {
119 let parts: Vec<&str> = selector.split(" > ").collect();
120
121 return parts.iter().skip(1).fold(
122 self.find_all(parts[0]),
123 |parents, &child_kind| {
124 parents
125 .iter()
126 .flat_map(|parent| {
127 (0..parent.child_count())
128 .filter_map(|i| child_at(parent, i))
129 .filter(|child| child.kind() == child_kind)
130 .collect::<Vec<_>>()
131 })
132 .collect()
133 },
134 );
135 }
136
137 if selector.contains(' ') {
138 let parts: Vec<&str> = selector.split_whitespace().collect();
139
140 return parts.iter().skip(1).fold(
141 self.find_all(parts[0]),
142 |ancestors, &descendant_kind| {
143 ancestors
144 .iter()
145 .flat_map(|&ancestor| {
146 collect_descendants_by_kind(ancestor, descendant_kind)
147 })
148 .collect()
149 },
150 );
151 }
152
153 collect_nodes_by_kind(*self, selector)
154 }
155
156 fn get_function(&self, document: &Document) -> Option<Function> {
157 let function_node = self.get_parent("function_definition")?;
158
159 document.find_function(
160 &document.get_node_text(&function_node.child_by_field_name("name")?),
161 )
162 }
163
164 fn get_parent(&self, kind: &str) -> Option<Node<'_>> {
165 let mut current = *self;
166
167 while let Some(parent) = current.parent() {
168 if parent.kind() == kind {
169 return Some(parent);
170 }
171
172 current = parent;
173 }
174
175 None
176 }
177
178 fn get_range(&self, document: &Document) -> lsp::Range {
179 lsp::Range {
180 start: self.start_position().position(document),
181 end: self.end_position().position(document),
182 }
183 }
184
185 fn get_recipe(&self, document: &Document) -> Option<Recipe> {
186 let recipe_node = self.get_parent("recipe")?;
187
188 document.find_recipe(
189 &document.get_node_text(&recipe_node.find("recipe_header > identifier")?),
190 )
191 }
192
193 fn has_any_parent(&self, kinds: &[&str]) -> bool {
194 kinds.iter().any(|kind| self.get_parent(kind).is_some())
195 }
196
197 fn siblings(&self) -> impl Iterator<Item = Node<'_>> {
198 successors(self.next_sibling(), Node::next_sibling)
199 }
200}
201
// Tests for the selector language implemented by `NodeExt::find_all` /
// `NodeExt::find`. Each test parses a small justfile with `Document::from`.
// Recipe bodies sit one column past their headers, so the trimmed node text
// matches the `"name:\n echo ..."` literals asserted below.
#[cfg(test)]
mod tests {
  use {super::*, indoc::indoc, pretty_assertions::assert_eq};

  // A bare kind selector returns every node of that kind in document order;
  // `find` returns the first of them.
  #[test]
  fn find_basic_kind() {
    let document = Document::from(indoc! {
      "
      foo:
       echo \"foo\"

      bar:
       echo \"bar\"
      "
    });

    let root = document.tree.as_ref().unwrap().root_node();

    let recipes = root.find_all("recipe");

    assert_eq!(recipes.len(), 2);

    let recipe_texts = recipes
      .iter()
      .map(|recipe| document.get_node_text(recipe).trim().to_string())
      .collect::<Vec<_>>();

    assert_eq!(
      recipe_texts,
      vec![
        "foo:\n echo \"foo\"".to_string(),
        "bar:\n echo \"bar\"".to_string()
      ]
    );

    let first_recipe = root.find("recipe").unwrap();

    assert_eq!(
      document.get_node_text(&first_recipe).trim(),
      "foo:\n echo \"foo\""
    );
  }

  // `kind[n]` selects the n-th match (0-based); an out-of-range index yields
  // nothing.
  #[test]
  fn find_indexed_nodes() {
    let document = Document::from(indoc! {
      "
      foo:
       echo \"foo\"

      bar:
       echo \"bar\"

      baz:
       echo \"baz\"
      "
    });

    let root = document.tree.as_ref().unwrap().root_node();

    let selectors = ["recipe[0]", "recipe[1]", "recipe[2]"];

    let recipe_texts = selectors
      .iter()
      .map(|selector| {
        document
          .get_node_text(&root.find(selector).unwrap())
          .trim()
          .to_string()
      })
      .collect::<Vec<_>>();

    assert_eq!(
      recipe_texts,
      vec![
        "foo:\n echo \"foo\"".to_string(),
        "bar:\n echo \"bar\"".to_string(),
        "baz:\n echo \"baz\"".to_string()
      ]
    );

    assert!(root.find("recipe[10]").is_none());
  }

  // `a > b` matches only `b` nodes that are direct children of an `a` node.
  #[test]
  fn find_direct_child() {
    let document = Document::from(indoc! {
      "
      foo:
       echo \"foo\"

      bar arg1 arg2:
       echo \"bar\"
      "
    });

    let root = document.tree.as_ref().unwrap().root_node();

    let identifiers = root.find_all("recipe_header > identifier");

    let identifier_texts = identifiers
      .iter()
      .map(|node| document.get_node_text(node))
      .collect::<Vec<_>>();

    assert_eq!(identifier_texts, vec!["foo".to_string(), "bar".to_string()]);

    let second_recipe = root.find("recipe[1]").unwrap();

    let recipe_header = second_recipe.find("recipe_header").unwrap();

    let parameters = recipe_header.find_all("parameters > parameter");

    let parameter_texts = parameters
      .iter()
      .map(|node| document.get_node_text(node))
      .collect::<Vec<_>>();

    assert_eq!(
      parameter_texts,
      vec!["arg1".to_string(), "arg2".to_string()]
    );
  }

  // `a b` matches `b` nodes anywhere below an `a` node, not just direct
  // children.
  #[test]
  fn find_descendant() {
    let document = Document::from(indoc! {
      "
      foo:
       echo \"foo\"

      bar arg1 arg2:
       echo \"{{ arch() }}\"
      "
    });

    let root = document.tree.as_ref().unwrap().root_node();

    let identifier_texts = root
      .find_all("identifier")
      .iter()
      .map(|node| document.get_node_text(node))
      .collect::<Vec<_>>();

    assert_eq!(
      identifier_texts,
      vec![
        "foo".to_string(),
        "bar".to_string(),
        "arg1".to_string(),
        "arg2".to_string(),
        "arch".to_string()
      ]
    );

    // Every identifier here lives inside a recipe, so scoping to `recipe`
    // must not change the result.
    let recipe_identifier_texts = root
      .find_all("recipe identifier")
      .iter()
      .map(|node| document.get_node_text(node))
      .collect::<Vec<_>>();

    assert_eq!(recipe_identifier_texts, identifier_texts);

    let function_call_texts = root
      .find_all("recipe function_call")
      .iter()
      .map(|node| document.get_node_text(node).trim().to_string())
      .collect::<Vec<_>>();

    assert_eq!(function_call_texts, vec!["arch()".to_string()]);

    let function_identifier_texts = root
      .find_all("function_call identifier")
      .iter()
      .map(|node| document.get_node_text(node))
      .collect::<Vec<_>>();

    assert_eq!(function_identifier_texts, vec!["arch".to_string()]);
  }

  // `a, b` concatenates results per sub-selector in selector order, not in
  // document order: the assignment appears before the recipes in the source
  // but after them in the result.
  #[test]
  fn find_union() {
    let document = Document::from(indoc! {
      "
      foo := \"value\"

      foo:
       echo \"foo\"

      bar:
       echo \"bar\"
      "
    });

    let root = document.tree.as_ref().unwrap().root_node();

    let recipes_and_assignments = root.find_all("recipe, assignment");

    let kinds = recipes_and_assignments
      .iter()
      .map(Node::kind)
      .collect::<Vec<_>>();

    assert_eq!(kinds, ["recipe", "recipe", "assignment"]);

    let node_texts = recipes_and_assignments
      .iter()
      .map(|node| document.get_node_text(node).trim().to_string())
      .collect::<Vec<_>>();

    assert_eq!(
      node_texts,
      vec![
        "foo:\n echo \"foo\"".to_string(),
        "bar:\n echo \"bar\"".to_string(),
        "foo := \"value\"".to_string()
      ]
    );

    // The second sub-selector matches nothing here (no function calls), so
    // only the recipe-name identifiers remain.
    let identifier_texts = root
      .find_all("recipe_header > identifier, function_call > identifier")
      .iter()
      .map(|node| document.get_node_text(node))
      .collect::<Vec<_>>();

    assert_eq!(identifier_texts, vec!["foo".to_string(), "bar".to_string()]);
  }

  // `^kind` restricts matching to the direct children of the receiving node.
  #[test]
  fn find_direct_child_marker() {
    let document = Document::from(indoc! {
      "
      foo:
       echo \"foo\"

      bar arg1 arg2:
       echo \"{{ arch() }}\"
      "
    });

    let root = document.tree.as_ref().unwrap().root_node();

    let second_recipe = root.find("recipe[1]").unwrap();

    let recipe_header = second_recipe.find("recipe_header").unwrap();
    let parameters_node = recipe_header.find("parameters").unwrap();

    let direct_parameters = parameters_node.find_all("^parameter");

    assert_eq!(direct_parameters.len(), 2);

    let parameter_texts = direct_parameters
      .iter()
      .map(|node| document.get_node_text(node))
      .collect::<Vec<_>>();

    assert_eq!(
      parameter_texts,
      vec!["arg1".to_string(), "arg2".to_string()]
    );
  }

  // Unknown kinds, and kinds absent from the document, produce `None` or an
  // empty vector rather than panicking.
  #[test]
  fn find_nonexistent() {
    let document = Document::from(indoc! {
      "
      foo:
       echo \"foo\"
      "
    });

    let tree = document.tree.as_ref().unwrap();
    let root = tree.root_node();

    let nonexistent = root.find("nonexistent_kind");
    assert!(nonexistent.is_none());

    let empty_results = root.find_all("nonexistent_kind");
    assert_eq!(empty_results.len(), 0);

    let no_function_calls = root.find_all("function_call");
    assert_eq!(no_function_calls.len(), 0);
  }

  // Indexing is relative to the receiving node: within the alias, the first
  // identifier is the alias name and the second is its target.
  #[test]
  fn find_nth_occurrence() {
    let document = Document::from(indoc! {
      "
      alias foo := bar
      "
    });

    let root = document.tree.as_ref().unwrap().root_node();

    let alias = root.find("alias").unwrap();

    let first_identifier = alias.find("identifier[0]").unwrap();
    let second_identifier = alias.find("identifier[1]").unwrap();

    assert_eq!(document.get_node_text(&first_identifier), "foo");
    assert_eq!(document.get_node_text(&second_identifier), "bar");
  }

  // A chain of several `>` steps walks down one child level per step.
  #[test]
  fn find_nested_child() {
    let document = Document::from(indoc! {
      "
      foo: (bar baz):
       echo foo
      "
    });

    let root = document.tree.as_ref().unwrap().root_node();

    let identifier =
      root.find("dependency_expression > expression > value > identifier");

    let identifier = identifier.unwrap();

    assert_eq!(document.get_node_text(&identifier), "baz");
  }
}