Skip to main content

just_lsp/
node_ext.rs

1use super::*;
2
3pub trait NodeExt {
4  fn find(&self, selector: &str) -> Option<Node<'_>>;
5  fn find_all(&self, selector: &str) -> Vec<Node<'_>>;
6  fn get_function(&self, document: &Document) -> Option<Function>;
7  fn get_parent(&self, kind: &str) -> Option<Node<'_>>;
8  fn get_range(&self, document: &Document) -> lsp::Range;
9  fn get_recipe(&self, document: &Document) -> Option<Recipe>;
10  fn has_any_parent(&self, kinds: &[&str]) -> bool;
11  fn siblings(&self) -> impl Iterator<Item = Node<'_>>;
12}
13
14fn collect_nodes_by_kind<'a>(node: Node<'a>, kind: &str) -> Vec<Node<'a>> {
15  let self_match = if node.kind() == kind {
16    vec![node]
17  } else {
18    Vec::new()
19  };
20
21  let children_matches = (0..node.child_count())
22    .filter_map(|i| child_at(&node, i))
23    .flat_map(|child| collect_nodes_by_kind(child, kind))
24    .collect::<Vec<_>>();
25
26  [self_match, children_matches].concat()
27}
28
29fn collect_descendants_by_kind<'a>(
30  node: Node<'a>,
31  kind: &str,
32) -> Vec<Node<'a>> {
33  (0..node.child_count())
34    .filter_map(|i| child_at(&node, i))
35    .flat_map(|child| {
36      let self_match = if child.kind() == kind {
37        vec![child]
38      } else {
39        Vec::new()
40      };
41
42      let descendants = collect_descendants_by_kind(child, kind);
43
44      [self_match, descendants].concat()
45    })
46    .collect()
47}
48
49fn child_at<'a>(node: &Node<'a>, index: usize) -> Option<Node<'a>> {
50  index.try_into().ok().and_then(|index| node.child(index))
51}
52
53impl NodeExt for Node<'_> {
54  fn find(&self, selector: &str) -> Option<Node<'_>> {
55    self.find_all(selector).into_iter().next()
56  }
57
58  fn find_all(&self, selector: &str) -> Vec<Node<'_>> {
59    if selector.contains(',') {
60      return selector
61        .split(',')
62        .map(str::trim)
63        .flat_map(|sub_selector| self.find_all(sub_selector))
64        .collect();
65    }
66
67    if let Some(position_str) = selector.strip_prefix('@') {
68      return position_str
69        .parse::<usize>()
70        .ok()
71        .and_then(|position| child_at(self, position))
72        .map_or_else(Vec::new, |child| vec![child]);
73    }
74
75    if let Some(rest) = selector.strip_prefix('^') {
76      if rest.contains('[') && rest.ends_with(']') {
77        let parts: Vec<&str> = rest.split('[').collect();
78
79        if parts.len() == 2 {
80          let (kind, index_str) = (parts[0], &parts[1][..parts[1].len() - 1]);
81
82          if let Ok(index) = index_str.parse::<usize>() {
83            let direct_children = (0..self.child_count())
84              .filter_map(|i| child_at(self, i))
85              .filter(|child| child.kind() == kind)
86              .collect::<Vec<_>>();
87
88            return direct_children
89              .get(index)
90              .copied()
91              .map_or_else(Vec::new, |node| vec![node]);
92          }
93        }
94      }
95
96      return (0..self.child_count())
97        .filter_map(|i| child_at(self, i))
98        .filter(|child| child.kind() == rest)
99        .collect();
100    }
101
102    if selector.contains('[') && selector.ends_with(']') {
103      let parts: Vec<&str> = selector.split('[').collect();
104
105      if parts.len() == 2 {
106        let (kind, index_str) = (parts[0], &parts[1][..parts[1].len() - 1]);
107
108        if let Ok(index) = index_str.parse::<usize>() {
109          let all_of_kind = self.find_all(kind);
110          return all_of_kind
111            .get(index)
112            .copied()
113            .map_or_else(Vec::new, |node| vec![node]);
114        }
115      }
116    }
117
118    if selector.contains(" > ") {
119      let parts: Vec<&str> = selector.split(" > ").collect();
120
121      return parts.iter().skip(1).fold(
122        self.find_all(parts[0]),
123        |parents, &child_kind| {
124          parents
125            .iter()
126            .flat_map(|parent| {
127              (0..parent.child_count())
128                .filter_map(|i| child_at(parent, i))
129                .filter(|child| child.kind() == child_kind)
130                .collect::<Vec<_>>()
131            })
132            .collect()
133        },
134      );
135    }
136
137    if selector.contains(' ') {
138      let parts: Vec<&str> = selector.split_whitespace().collect();
139
140      return parts.iter().skip(1).fold(
141        self.find_all(parts[0]),
142        |ancestors, &descendant_kind| {
143          ancestors
144            .iter()
145            .flat_map(|&ancestor| {
146              collect_descendants_by_kind(ancestor, descendant_kind)
147            })
148            .collect()
149        },
150      );
151    }
152
153    collect_nodes_by_kind(*self, selector)
154  }
155
156  fn get_function(&self, document: &Document) -> Option<Function> {
157    let function_node = self.get_parent("function_definition")?;
158
159    document.find_function(
160      &document.get_node_text(&function_node.child_by_field_name("name")?),
161    )
162  }
163
164  fn get_parent(&self, kind: &str) -> Option<Node<'_>> {
165    let mut current = *self;
166
167    while let Some(parent) = current.parent() {
168      if parent.kind() == kind {
169        return Some(parent);
170      }
171
172      current = parent;
173    }
174
175    None
176  }
177
178  fn get_range(&self, document: &Document) -> lsp::Range {
179    lsp::Range {
180      start: self.start_position().position(document),
181      end: self.end_position().position(document),
182    }
183  }
184
185  fn get_recipe(&self, document: &Document) -> Option<Recipe> {
186    let recipe_node = self.get_parent("recipe")?;
187
188    document.find_recipe(
189      &document.get_node_text(&recipe_node.find("recipe_header > identifier")?),
190    )
191  }
192
193  fn has_any_parent(&self, kinds: &[&str]) -> bool {
194    kinds.iter().any(|kind| self.get_parent(kind).is_some())
195  }
196
197  fn siblings(&self) -> impl Iterator<Item = Node<'_>> {
198    successors(self.next_sibling(), Node::next_sibling)
199  }
200}
201
#[cfg(test)]
mod tests {
  use {super::*, indoc::indoc, pretty_assertions::assert_eq};

  /// Root node of the parsed syntax tree of `document`.
  fn root_node<'a>(document: &'a Document) -> Node<'a> {
    document.tree.as_ref().unwrap().root_node()
  }

  /// Source text of each node, verbatim.
  fn texts(document: &Document, nodes: &[Node<'_>]) -> Vec<String> {
    nodes
      .iter()
      .map(|node| document.get_node_text(node).to_string())
      .collect()
  }

  /// Source text of each node with surrounding whitespace removed.
  fn trimmed_texts(document: &Document, nodes: &[Node<'_>]) -> Vec<String> {
    nodes
      .iter()
      .map(|node| document.get_node_text(node).trim().to_string())
      .collect()
  }

  #[test]
  fn find_basic_kind() {
    let document = Document::from(indoc! {
      "
      foo:
        echo \"foo\"

      bar:
        echo \"bar\"
      "
    });

    let root = root_node(&document);

    let recipes = root.find_all("recipe");
    assert_eq!(recipes.len(), 2);

    assert_eq!(
      trimmed_texts(&document, &recipes),
      ["foo:\n  echo \"foo\"", "bar:\n  echo \"bar\""]
    );

    let first = root.find("recipe").unwrap();

    assert_eq!(
      document.get_node_text(&first).trim(),
      "foo:\n  echo \"foo\""
    );
  }

  #[test]
  fn find_indexed_nodes() {
    let document = Document::from(indoc! {
      "
      foo:
        echo \"foo\"

      bar:
        echo \"bar\"

      baz:
        echo \"baz\"
      "
    });

    let root = root_node(&document);

    let nth_recipes = ["recipe[0]", "recipe[1]", "recipe[2]"]
      .iter()
      .map(|selector| root.find(selector).unwrap())
      .collect::<Vec<_>>();

    assert_eq!(
      trimmed_texts(&document, &nth_recipes),
      [
        "foo:\n  echo \"foo\"",
        "bar:\n  echo \"bar\"",
        "baz:\n  echo \"baz\""
      ]
    );

    assert!(root.find("recipe[10]").is_none());
  }

  #[test]
  fn find_direct_child() {
    let document = Document::from(indoc! {
      "
      foo:
        echo \"foo\"

      bar arg1 arg2:
        echo \"bar\"
      "
    });

    let root = root_node(&document);

    assert_eq!(
      texts(&document, &root.find_all("recipe_header > identifier")),
      ["foo", "bar"]
    );

    let second_recipe = root.find("recipe[1]").unwrap();
    let header = second_recipe.find("recipe_header").unwrap();

    assert_eq!(
      texts(&document, &header.find_all("parameters > parameter")),
      ["arg1", "arg2"]
    );
  }

  #[test]
  fn find_descendant() {
    let document = Document::from(indoc! {
      "
      foo:
        echo \"foo\"

      bar arg1 arg2:
        echo \"{{ arch() }}\"
      "
    });

    let root = root_node(&document);

    let all_identifiers = texts(&document, &root.find_all("identifier"));

    assert_eq!(all_identifiers, ["foo", "bar", "arg1", "arg2", "arch"]);

    assert_eq!(
      texts(&document, &root.find_all("recipe identifier")),
      all_identifiers
    );

    assert_eq!(
      trimmed_texts(&document, &root.find_all("recipe function_call")),
      ["arch()"]
    );

    assert_eq!(
      texts(&document, &root.find_all("function_call identifier")),
      ["arch"]
    );
  }

  #[test]
  fn find_union() {
    let document = Document::from(indoc! {
      "
      foo := \"value\"

      foo:
        echo \"foo\"

      bar:
        echo \"bar\"
      "
    });

    let root = root_node(&document);

    let found = root.find_all("recipe, assignment");

    assert_eq!(
      found.iter().map(Node::kind).collect::<Vec<_>>(),
      ["recipe", "recipe", "assignment"]
    );

    assert_eq!(
      trimmed_texts(&document, &found),
      [
        "foo:\n  echo \"foo\"",
        "bar:\n  echo \"bar\"",
        "foo := \"value\""
      ]
    );

    let header_or_call_identifiers = root
      .find_all("recipe_header > identifier, function_call > identifier");

    assert_eq!(
      texts(&document, &header_or_call_identifiers),
      ["foo", "bar"]
    );
  }

  #[test]
  fn find_direct_child_marker() {
    let document = Document::from(indoc! {
      "
      foo:
        echo \"foo\"

      bar arg1 arg2:
        echo \"{{ arch() }}\"
      "
    });

    let root = root_node(&document);

    let second_recipe = root.find("recipe[1]").unwrap();
    let header = second_recipe.find("recipe_header").unwrap();
    let parameters = header.find("parameters").unwrap();

    let direct = parameters.find_all("^parameter");

    assert_eq!(direct.len(), 2);
    assert_eq!(texts(&document, &direct), ["arg1", "arg2"]);
  }

  #[test]
  fn find_nonexistent() {
    let document = Document::from(indoc! {
      "
      foo:
        echo \"foo\"
      "
    });

    let root = root_node(&document);

    assert!(root.find("nonexistent_kind").is_none());
    assert!(root.find_all("nonexistent_kind").is_empty());
    assert!(root.find_all("function_call").is_empty());
  }

  #[test]
  fn find_nth_occurrence() {
    let document = Document::from(indoc! {
      "
      alias foo := bar
      "
    });

    let root = root_node(&document);

    let alias = root.find("alias").unwrap();

    let expectations = [("identifier[0]", "foo"), ("identifier[1]", "bar")];

    for (selector, expected) in expectations {
      let identifier = alias.find(selector).unwrap();
      assert_eq!(document.get_node_text(&identifier), expected);
    }
  }

  #[test]
  fn find_nested_child() {
    let document = Document::from(indoc! {
      "
      foo: (bar baz):
        echo foo
      "
    });

    let root = root_node(&document);

    let identifier = root
      .find("dependency_expression > expression > value > identifier")
      .unwrap();

    assert_eq!(document.get_node_text(&identifier), "baz");
  }
}