ast_grep_core/tree_sitter/
mod.rs

1pub mod traversal;
2
3use crate::node::Root;
4use crate::replacer::Replacer;
5use crate::source::{Content, Doc, Edit, SgNode};
6use crate::{node::KindId, Language, Position};
7use crate::{AstGrep, Matcher};
8use std::borrow::Cow;
9use std::collections::HashMap;
10use std::num::NonZero;
11use thiserror::Error;
12pub use traversal::{TsPre, Visitor};
13pub use tree_sitter::Language as TSLanguage;
14use tree_sitter::{InputEdit, LanguageError, Node, Parser, Point, Tree};
15pub use tree_sitter::{Point as TSPoint, Range as TSRange};
16
17/// Represents tree-sitter related error
18#[derive(Debug, Error)]
19pub enum TSParseError {
20  #[error("incompatible `Language` is assigned to a `Parser`.")]
21  Language(#[from] LanguageError),
22  /// A general error when tree sitter fails to parse in time. It can be caused by
23  /// the following reasons but tree-sitter does not provide error detail.
24  /// * The timeout set with [Parser::set_timeout_micros] expired
25  /// * The cancellation flag set with [Parser::set_cancellation_flag] was flipped
26  /// * The parser has not yet had a language assigned with [Parser::set_language]
27  #[error("general error when tree-sitter fails to parse.")]
28  TreeUnavailable,
29}
30
31#[inline]
32fn parse_lang(
33  parse_fn: impl Fn(&mut Parser) -> Option<Tree>,
34  ts_lang: TSLanguage,
35) -> Result<Tree, TSParseError> {
36  let mut parser = Parser::new();
37  parser.set_language(&ts_lang)?;
38  if let Some(tree) = parse_fn(&mut parser) {
39    Ok(tree)
40  } else {
41    Err(TSParseError::TreeUnavailable)
42  }
43}
44
45#[derive(Clone)]
46pub struct StrDoc<L: LanguageExt> {
47  pub src: String,
48  pub lang: L,
49  pub tree: Tree,
50}
51
52impl<L: LanguageExt> StrDoc<L> {
53  pub fn try_new(src: &str, lang: L) -> Result<Self, String> {
54    let src = src.to_string();
55    let ts_lang = lang.get_ts_language();
56    let tree = parse_lang(|p| p.parse(src.as_bytes(), None), ts_lang).map_err(|e| e.to_string())?;
57    Ok(Self { src, lang, tree })
58  }
59  pub fn new(src: &str, lang: L) -> Self {
60    Self::try_new(src, lang).expect("Parser tree error")
61  }
62  fn parse(&self, old_tree: Option<&Tree>) -> Result<Tree, TSParseError> {
63    let source = self.get_source();
64    let lang = self.get_lang().get_ts_language();
65    parse_lang(|p| p.parse(source.as_bytes(), old_tree), lang)
66  }
67}
68
69impl<L: LanguageExt> Doc for StrDoc<L> {
70  type Source = String;
71  type Lang = L;
72  type Node<'r> = Node<'r>;
73  fn get_lang(&self) -> &Self::Lang {
74    &self.lang
75  }
76  fn get_source(&self) -> &Self::Source {
77    &self.src
78  }
79  fn do_edit(&mut self, edit: &Edit<Self::Source>) -> Result<(), String> {
80    let source = &mut self.src;
81    perform_edit(&mut self.tree, source, edit);
82    self.tree = self.parse(Some(&self.tree)).map_err(|e| e.to_string())?;
83    Ok(())
84  }
85  fn root_node(&self) -> Node<'_> {
86    self.tree.root_node()
87  }
88  fn get_node_text<'a>(&'a self, node: &Self::Node<'a>) -> Cow<'a, str> {
89    Cow::Borrowed(
90      node
91        .utf8_text(self.src.as_bytes())
92        .expect("invalid source text encoding"),
93    )
94  }
95}
96
97struct NodeWalker<'tree> {
98  cursor: tree_sitter::TreeCursor<'tree>,
99  count: usize,
100}
101
102impl<'tree> Iterator for NodeWalker<'tree> {
103  type Item = Node<'tree>;
104  fn next(&mut self) -> Option<Self::Item> {
105    if self.count == 0 {
106      return None;
107    }
108    let ret = Some(self.cursor.node());
109    self.cursor.goto_next_sibling();
110    self.count -= 1;
111    ret
112  }
113}
114
115impl ExactSizeIterator for NodeWalker<'_> {
116  fn len(&self) -> usize {
117    self.count
118  }
119}
120
121impl<'r> SgNode<'r> for Node<'r> {
122  fn parent(&self) -> Option<Self> {
123    Node::parent(self)
124  }
125  fn ancestors(&self, root: Self) -> impl Iterator<Item = Self> {
126    let mut ancestor = Some(root);
127    let self_id = self.id();
128    std::iter::from_fn(move || {
129      let inner = ancestor.take()?;
130      if inner.id() == self_id {
131        return None;
132      }
133      ancestor = inner.child_with_descendant(*self);
134      Some(inner)
135    })
136    // We must iterate up the tree to preserve backwards compatibility
137    .collect::<Vec<_>>()
138    .into_iter()
139    .rev()
140  }
141  fn dfs(&self) -> impl Iterator<Item = Self> {
142    TsPre::new(self)
143  }
144  fn child(&self, nth: usize) -> Option<Self> {
145    // TODO remove cast after migrating to tree-sitter
146    Node::child(self, nth)
147  }
148  fn children(&self) -> impl ExactSizeIterator<Item = Self> {
149    let mut cursor = self.walk();
150    cursor.goto_first_child();
151    NodeWalker {
152      cursor,
153      count: self.child_count(),
154    }
155  }
156  fn child_by_field_id(&self, field_id: u16) -> Option<Self> {
157    Node::child_by_field_id(self, field_id)
158  }
159  fn next(&self) -> Option<Self> {
160    self.next_sibling()
161  }
162  fn prev(&self) -> Option<Self> {
163    self.prev_sibling()
164  }
165  fn next_all(&self) -> impl Iterator<Item = Self> {
166    // if root is none, use self as fallback to return a type-stable Iterator
167    let node = self.parent().unwrap_or(*self);
168    let mut cursor = node.walk();
169    cursor.goto_first_child_for_byte(self.start_byte());
170    std::iter::from_fn(move || {
171      if cursor.goto_next_sibling() {
172        Some(cursor.node())
173      } else {
174        None
175      }
176    })
177  }
178  fn prev_all(&self) -> impl Iterator<Item = Self> {
179    // if root is none, use self as fallback to return a type-stable Iterator
180    let node = self.parent().unwrap_or(*self);
181    let mut cursor = node.walk();
182    cursor.goto_first_child_for_byte(self.start_byte());
183    std::iter::from_fn(move || {
184      if cursor.goto_previous_sibling() {
185        Some(cursor.node())
186      } else {
187        None
188      }
189    })
190  }
191  fn is_named(&self) -> bool {
192    Node::is_named(self)
193  }
194  /// N.B. it is different from is_named && is_leaf
195  /// if a node has no named children.
196  fn is_named_leaf(&self) -> bool {
197    self.named_child_count() == 0
198  }
199  fn is_leaf(&self) -> bool {
200    self.child_count() == 0
201  }
202  fn kind(&self) -> Cow<'_, str> {
203    Cow::Borrowed(Node::kind(self))
204  }
205  fn kind_id(&self) -> KindId {
206    Node::kind_id(self)
207  }
208  fn node_id(&self) -> usize {
209    self.id()
210  }
211  fn range(&self) -> std::ops::Range<usize> {
212    self.start_byte()..self.end_byte()
213  }
214  fn start_pos(&self) -> Position {
215    let pos = self.start_position();
216    let byte = self.start_byte();
217    Position::new(pos.row, pos.column, byte)
218  }
219  fn end_pos(&self) -> Position {
220    let pos = self.end_position();
221    let byte = self.end_byte();
222    Position::new(pos.row, pos.column, byte)
223  }
224  // missing node is a tree-sitter specific concept
225  fn is_missing(&self) -> bool {
226    Node::is_missing(self)
227  }
228  fn is_error(&self) -> bool {
229    Node::is_error(self)
230  }
231
232  fn field(&self, name: &str) -> Option<Self> {
233    self.child_by_field_name(name)
234  }
235  fn field_children(&self, field_id: Option<u16>) -> impl Iterator<Item = Self> {
236    let field_id = field_id.and_then(NonZero::new);
237    let mut cursor = self.walk();
238    cursor.goto_first_child();
239    // if field_id is not found, iteration is done
240    let mut done = field_id.is_none();
241
242    std::iter::from_fn(move || {
243      if done {
244        return None;
245      }
246      while cursor.field_id() != field_id {
247        if !cursor.goto_next_sibling() {
248          return None;
249        }
250      }
251      let ret = cursor.node();
252      if !cursor.goto_next_sibling() {
253        done = true;
254      }
255      Some(ret)
256    })
257  }
258}
259
260pub fn perform_edit<S: ContentExt>(tree: &mut Tree, input: &mut S, edit: &Edit<S>) -> InputEdit {
261  let edit = input.accept_edit(edit);
262  tree.edit(&edit);
263  edit
264}
265
266/// tree-sitter specific language trait
267pub trait LanguageExt: Language {
268  /// Create an [`AstGrep`] instance for the language
269  fn ast_grep<S: AsRef<str>>(&self, source: S) -> AstGrep<StrDoc<Self>> {
270    AstGrep::new(source, self.clone())
271  }
272
273  /// tree sitter language to parse the source
274  fn get_ts_language(&self) -> TSLanguage;
275
276  fn injectable_languages(&self) -> Option<&'static [&'static str]> {
277    None
278  }
279
280  /// get injected language regions in the root document. e.g. get JavaScripts in HTML
281  /// it will return a list of tuples of (language, regions).
282  /// The first item is the embedded region language, e.g. javascript
283  /// The second item is a list of regions in tree_sitter.
284  /// also see https://tree-sitter.github.io/tree-sitter/using-parsers#multi-language-documents
285  fn extract_injections<L: LanguageExt>(
286    &self,
287    _root: crate::Node<StrDoc<L>>,
288  ) -> HashMap<String, Vec<TSRange>> {
289    HashMap::new()
290  }
291}
292
293fn position_for_offset(input: &[u8], offset: usize) -> Point {
294  debug_assert!(offset <= input.len());
295  let (mut row, mut col) = (0, 0);
296  for c in &input[0..offset] {
297    if *c as char == '\n' {
298      row += 1;
299      col = 0;
300    } else {
301      col += 1;
302    }
303  }
304  Point::new(row, col)
305}
306
307impl<L: LanguageExt> AstGrep<StrDoc<L>> {
308  pub fn new<S: AsRef<str>>(src: S, lang: L) -> Self {
309    Root::str(src.as_ref(), lang)
310  }
311
312  pub fn source(&self) -> &str {
313    self.doc.get_source().as_str()
314  }
315
316  pub fn generate(self) -> String {
317    self.doc.src
318  }
319}
320
321pub trait ContentExt: Content {
322  fn accept_edit(&mut self, edit: &Edit<Self>) -> InputEdit;
323}
324impl ContentExt for String {
325  fn accept_edit(&mut self, edit: &Edit<Self>) -> InputEdit {
326    let start_byte = edit.position;
327    let old_end_byte = edit.position + edit.deleted_length;
328    let new_end_byte = edit.position + edit.inserted_text.len();
329    let input = unsafe { self.as_mut_vec() };
330    let start_position = position_for_offset(input, start_byte);
331    let old_end_position = position_for_offset(input, old_end_byte);
332    input.splice(start_byte..old_end_byte, edit.inserted_text.clone());
333    let new_end_position = position_for_offset(input, new_end_byte);
334    InputEdit {
335      start_byte,
336      old_end_byte,
337      new_end_byte,
338      start_position,
339      old_end_position,
340      new_end_position,
341    }
342  }
343}
344
345impl<L: LanguageExt> Root<StrDoc<L>> {
346  pub fn str(src: &str, lang: L) -> Self {
347    Self::try_new(src, lang).expect("should parse")
348  }
349  pub fn try_new(src: &str, lang: L) -> Result<Self, String> {
350    let doc = StrDoc::try_new(src, lang)?;
351    Ok(Self { doc })
352  }
353  pub fn get_text(&self) -> &str {
354    &self.doc.src
355  }
356
357  pub fn get_injections<F: Fn(&str) -> Option<L>>(&self, get_lang: F) -> Vec<Self> {
358    let root = self.root();
359    let range = self.lang().extract_injections(root);
360    let roots = range
361      .into_iter()
362      .filter_map(|(lang, ranges)| {
363        let lang = get_lang(&lang)?;
364        let source = self.doc.get_source();
365        let mut parser = Parser::new();
366        parser.set_included_ranges(&ranges).ok()?;
367        parser.set_language(&lang.get_ts_language()).ok()?;
368        let tree = parser.parse(source, None)?;
369        Some(Self {
370          doc: StrDoc {
371            src: self.doc.src.clone(),
372            lang,
373            tree,
374          },
375        })
376      })
377      .collect();
378    roots
379  }
380}
381
/// Text needed to render a match together with its surrounding lines.
pub struct DisplayContext<'r> {
  /// Source text of the matched node.
  pub matched: Cow<'r, str>,
  /// Source text from the start of the first context line up to the match.
  pub leading: &'r str,
  /// Source text from the end of the match through the last context line,
  /// excluding that line's terminating newline.
  pub trailing: &'r str,
  /// Zero-based line number at which `leading` begins.
  pub start_line: usize,
}
392
393/// these methods are only for `StrDoc`
394impl<'r, L: LanguageExt> crate::Node<'r, StrDoc<L>> {
395  #[doc(hidden)]
396  pub fn display_context(&self, before: usize, after: usize) -> DisplayContext<'r> {
397    let source = self.root.doc.get_source().as_str();
398    let bytes = source.as_bytes();
399    let start = self.inner.start_byte();
400    let end = self.inner.end_byte();
401    let (mut leading, mut trailing) = (start, end);
402    let mut lines_before = before + 1;
403    while leading > 0 {
404      if bytes[leading - 1] == b'\n' {
405        lines_before -= 1;
406        if lines_before == 0 {
407          break;
408        }
409      }
410      leading -= 1;
411    }
412    let mut lines_after = after + 1;
413    // tree-sitter will append line ending to source so trailing can be out of bound
414    trailing = trailing.min(bytes.len());
415    while trailing < bytes.len() {
416      if bytes[trailing] == b'\n' {
417        lines_after -= 1;
418        if lines_after == 0 {
419          break;
420        }
421      }
422      trailing += 1;
423    }
424    // lines_before means we matched all context, offset is `before` itself
425    let offset = if lines_before == 0 {
426      before
427    } else {
428      // otherwise, there are fewer than `before` line in src, compute the actual line
429      before + 1 - lines_before
430    };
431    DisplayContext {
432      matched: self.text(),
433      leading: &source[leading..start],
434      trailing: &source[end..trailing],
435      start_line: self.start_pos().line() - offset,
436    }
437  }
438
439  pub fn replace_all<M: Matcher, R: Replacer<StrDoc<L>>>(
440    &self,
441    matcher: M,
442    replacer: R,
443  ) -> Vec<Edit<String>> {
444    // TODO: support nested matches like Some(Some(1)) with pattern Some($A)
445    Visitor::new(&matcher)
446      .reentrant(false)
447      .visit(self.clone())
448      .map(|matched| matched.make_edit(&matcher, &replacer))
449      .collect()
450  }
451}
452
#[cfg(test)]
mod test {
  use super::*;
  use crate::language::Tsx;
  use tree_sitter::Point;

  /// Parse `src` from scratch with the TSX grammar.
  fn parse(src: &str) -> Result<Tree, TSParseError> {
    let grammar = Tsx.get_ts_language();
    parse_lang(|parser| parser.parse(src, None), grammar)
  }

  #[test]
  fn test_tree_sitter() -> Result<(), TSParseError> {
    let tree = parse("var a = 1234")?;
    let program = tree.root_node();
    assert_eq!(program.kind(), "program");
    assert_eq!(program.start_position().column, 0);
    assert_eq!(program.end_position().column, 12);
    assert_eq!(
      program.to_sexp(),
      "(program (variable_declaration (variable_declarator name: (identifier) value: (number))))"
    );
    Ok(())
  }

  #[test]
  fn test_object_literal() -> Result<(), TSParseError> {
    let tree = parse("{a: $X}")?;
    let program = tree.root_node();
    // `{a: $X}` comes out as an object literal rather than a block containing a
    // labeled statement. Strictly that is the wrong reading at statement
    // position, but it is the more useful one in practice.
    assert_eq!(program.to_sexp(), "(program (expression_statement (object (pair key: (property_identifier) value: (identifier)))))");
    Ok(())
  }

  #[test]
  fn test_string() -> Result<(), TSParseError> {
    let tree = parse("'$A'")?;
    let program = tree.root_node();
    assert_eq!(
      program.to_sexp(),
      "(program (expression_statement (string (string_fragment))))"
    );
    Ok(())
  }

  #[test]
  fn test_row_col() -> Result<(), TSParseError> {
    let tree = parse("😄")?;
    let program = tree.root_node();
    assert_eq!(program.start_position(), Point::new(0, 0));
    // tree-sitter columns count bytes, not chars: the emoji is 4 bytes in UTF-8.
    assert_eq!(program.end_position(), Point::new(0, 4));
    Ok(())
  }

  #[test]
  fn test_edit() -> Result<(), TSParseError> {
    let mut src = String::from("a + b");
    let mut tree = parse(&src)?;
    let insertion: Edit<String> = Edit {
      position: 1,
      deleted_length: 0,
      inserted_text: " * b".into(),
    };
    let _ = perform_edit(&mut tree, &mut src, &insertion);
    // Editing the old tree only shifts its ranges; the structure changes only
    // after re-parsing with the edited tree as a starting point.
    let reparsed = parse_lang(|parser| parser.parse(&src, Some(&tree)), Tsx.get_ts_language())?;
    assert_eq!(
      tree.root_node().to_sexp(),
      "(program (expression_statement (binary_expression left: (identifier) right: (identifier))))"
    );
    assert_eq!(reparsed.root_node().to_sexp(), "(program (expression_statement (binary_expression left: (binary_expression left: (identifier) right: (identifier)) right: (identifier))))");
    Ok(())
  }
}