// ast_grep_core/tree_sitter/mod.rs

1pub mod traversal;
2
3use crate::node::Root;
4use crate::replacer::Replacer;
5use crate::source::{Content, Doc, Edit, SgNode};
6use crate::{node::KindId, Language, Position};
7use crate::{AstGrep, Matcher};
8use std::borrow::Cow;
9use std::collections::HashMap;
10use thiserror::Error;
11pub use traversal::{TsPre, Visitor};
12pub use tree_sitter::Language as TSLanguage;
13use tree_sitter::{InputEdit, LanguageError, Node, Parser, ParserError, Point, Tree};
14pub use tree_sitter::{Point as TSPoint, Range as TSRange};
15
/// Represents tree-sitter related error
///
/// Returned by the parsing helpers in this module. The `Parse` and `Language`
/// variants are produced automatically from the underlying tree-sitter errors
/// by `?`, thanks to their `#[from]` attributes.
#[derive(Debug, Error)]
pub enum TSParseError {
  /// Creating or running the tree-sitter parser failed (wraps [`ParserError`]).
  #[error("web-tree-sitter parser is not available")]
  Parse(#[from] ParserError),
  /// Assigning the grammar to the parser failed (wraps [`LanguageError`]).
  #[error("incompatible `Language` is assigned to a `Parser`.")]
  Language(#[from] LanguageError),
  /// A general error when tree sitter fails to parse in time. It can be caused by
  /// the following reasons but tree-sitter does not provide error detail.
  /// * The timeout set with [Parser::set_timeout_micros] expired
  /// * The cancellation flag set with [Parser::set_cancellation_flag] was flipped
  /// * The parser has not yet had a language assigned with [Parser::set_language]
  #[error("general error when tree-sitter fails to parse.")]
  TreeUnavailable,
}
31
32#[inline]
33fn parse_lang(
34  parse_fn: impl Fn(&mut Parser) -> Result<Option<Tree>, ParserError>,
35  ts_lang: TSLanguage,
36) -> Result<Tree, TSParseError> {
37  let mut parser = Parser::new()?;
38  parser.set_language(&ts_lang)?;
39  if let Some(tree) = parse_fn(&mut parser)? {
40    Ok(tree)
41  } else {
42    Err(TSParseError::TreeUnavailable)
43  }
44}
45
/// A document backed by an owned `String` together with its tree-sitter parse tree.
#[derive(Clone)]
pub struct StrDoc<L: LanguageExt> {
  /// Source text; modified in place by `Doc::do_edit`.
  pub src: String,
  /// Language whose grammar is used to parse `src`.
  pub lang: L,
  /// Parse tree of `src`; `Doc::do_edit` edits it and reparses to keep it in sync.
  pub tree: Tree,
}
52
53impl<L: LanguageExt> StrDoc<L> {
54  pub fn try_new(src: &str, lang: L) -> Result<Self, String> {
55    let src = src.to_string();
56    let ts_lang = lang.get_ts_language();
57    let tree = parse_lang(|p| p.parse(src.as_bytes(), None), ts_lang).map_err(|e| e.to_string())?;
58    Ok(Self { src, lang, tree })
59  }
60  pub fn new(src: &str, lang: L) -> Self {
61    Self::try_new(src, lang).expect("Parser tree error")
62  }
63  fn parse(&self, old_tree: Option<&Tree>) -> Result<Tree, TSParseError> {
64    let source = self.get_source();
65    let lang = self.get_lang().get_ts_language();
66    parse_lang(|p| p.parse(source.as_bytes(), old_tree), lang)
67  }
68}
69
70impl<L: LanguageExt> Doc for StrDoc<L> {
71  type Source = String;
72  type Lang = L;
73  type Node<'r> = Node<'r>;
74  fn get_lang(&self) -> &Self::Lang {
75    &self.lang
76  }
77  fn get_source(&self) -> &Self::Source {
78    &self.src
79  }
80  fn do_edit(&mut self, edit: &Edit<Self::Source>) -> Result<(), String> {
81    let source = &mut self.src;
82    perform_edit(&mut self.tree, source, edit);
83    self.tree = self.parse(Some(&self.tree)).map_err(|e| e.to_string())?;
84    Ok(())
85  }
86  fn root_node(&self) -> Node<'_> {
87    self.tree.root_node()
88  }
89  fn get_node_text<'a>(&'a self, node: &Self::Node<'a>) -> Cow<'a, str> {
90    node
91      .utf8_text(self.src.as_bytes())
92      .expect("invalid source text encoding")
93  }
94}
95
96struct NodeWalker<'tree> {
97  cursor: tree_sitter::TreeCursor<'tree>,
98  count: usize,
99}
100
101impl<'tree> Iterator for NodeWalker<'tree> {
102  type Item = Node<'tree>;
103  fn next(&mut self) -> Option<Self::Item> {
104    if self.count == 0 {
105      return None;
106    }
107    let ret = Some(self.cursor.node());
108    self.cursor.goto_next_sibling();
109    self.count -= 1;
110    ret
111  }
112}
113
114impl ExactSizeIterator for NodeWalker<'_> {
115  fn len(&self) -> usize {
116    self.count
117  }
118}
119
120impl<'r> SgNode<'r> for Node<'r> {
121  fn parent(&self) -> Option<Self> {
122    Node::parent(self)
123  }
124  fn ancestors(&self, root: Self) -> impl Iterator<Item = Self> {
125    let mut ancestor = Some(root);
126    let self_id = self.id();
127    std::iter::from_fn(move || {
128      let inner = ancestor.take()?;
129      if inner.id() == self_id {
130        return None;
131      }
132      ancestor = inner.child_with_descendant(self.clone());
133      Some(inner)
134    })
135    // We must iterate up the tree to preserve backwards compatibility
136    .collect::<Vec<_>>()
137    .into_iter()
138    .rev()
139  }
140  fn dfs(&self) -> impl Iterator<Item = Self> {
141    TsPre::new(self)
142  }
143  fn child(&self, nth: usize) -> Option<Self> {
144    // TODO remove cast after migrating to tree-sitter
145    Node::child(self, nth as u32)
146  }
147  fn children(&self) -> impl ExactSizeIterator<Item = Self> {
148    let mut cursor = self.walk();
149    cursor.goto_first_child();
150    NodeWalker {
151      cursor,
152      count: self.child_count() as usize,
153    }
154  }
155  fn child_by_field_id(&self, field_id: u16) -> Option<Self> {
156    Node::child_by_field_id(self, field_id)
157  }
158  fn next(&self) -> Option<Self> {
159    self.next_sibling()
160  }
161  fn prev(&self) -> Option<Self> {
162    self.prev_sibling()
163  }
164  fn next_all(&self) -> impl Iterator<Item = Self> {
165    // if root is none, use self as fallback to return a type-stable Iterator
166    let node = self.parent().unwrap_or_else(|| self.clone());
167    let mut cursor = node.walk();
168    cursor.goto_first_child_for_byte(self.start_byte());
169    std::iter::from_fn(move || {
170      if cursor.goto_next_sibling() {
171        Some(cursor.node())
172      } else {
173        None
174      }
175    })
176  }
177  fn prev_all(&self) -> impl Iterator<Item = Self> {
178    // if root is none, use self as fallback to return a type-stable Iterator
179    let node = self.parent().unwrap_or_else(|| self.clone());
180    let mut cursor = node.walk();
181    cursor.goto_first_child_for_byte(self.start_byte());
182    std::iter::from_fn(move || {
183      if cursor.goto_previous_sibling() {
184        Some(cursor.node())
185      } else {
186        None
187      }
188    })
189  }
190  fn is_named(&self) -> bool {
191    Node::is_named(self)
192  }
193  /// N.B. it is different from is_named && is_leaf
194  /// if a node has no named children.
195  fn is_named_leaf(&self) -> bool {
196    self.named_child_count() == 0
197  }
198  fn is_leaf(&self) -> bool {
199    self.child_count() == 0
200  }
201  fn kind(&self) -> Cow<str> {
202    Node::kind(self)
203  }
204  fn kind_id(&self) -> KindId {
205    Node::kind_id(self)
206  }
207  fn node_id(&self) -> usize {
208    self.id()
209  }
210  fn range(&self) -> std::ops::Range<usize> {
211    (self.start_byte() as usize)..(self.end_byte() as usize)
212  }
213  fn start_pos(&self) -> Position {
214    let pos = self.start_position();
215    let byte = self.start_byte() as usize;
216    Position::new(pos.row() as usize, pos.column() as usize, byte)
217  }
218  fn end_pos(&self) -> Position {
219    let pos = self.end_position();
220    let byte = self.end_byte() as usize;
221    Position::new(pos.row() as usize, pos.column() as usize, byte)
222  }
223  // missing node is a tree-sitter specific concept
224  fn is_missing(&self) -> bool {
225    Node::is_missing(self)
226  }
227  fn is_error(&self) -> bool {
228    Node::is_error(self)
229  }
230
231  fn field(&self, name: &str) -> Option<Self> {
232    self.child_by_field_name(name)
233  }
234  fn field_children(&self, field_id: Option<u16>) -> impl Iterator<Item = Self> {
235    let mut cursor = self.walk();
236    cursor.goto_first_child();
237    // if field_id is not found, iteration is done
238    let mut done = field_id.is_none();
239
240    std::iter::from_fn(move || {
241      if done {
242        return None;
243      }
244      while cursor.field_id() != field_id {
245        if !cursor.goto_next_sibling() {
246          return None;
247        }
248      }
249      let ret = cursor.node();
250      if !cursor.goto_next_sibling() {
251        done = true;
252      }
253      Some(ret)
254    })
255  }
256}
257
258pub fn perform_edit<S: ContentExt>(tree: &mut Tree, input: &mut S, edit: &Edit<S>) -> InputEdit {
259  let edit = input.accept_edit(edit);
260  tree.edit(&edit);
261  edit
262}
263
264/// tree-sitter specific language trait
265pub trait LanguageExt: Language {
266  /// Create an [`AstGrep`] instance for the language
267  fn ast_grep<S: AsRef<str>>(&self, source: S) -> AstGrep<StrDoc<Self>> {
268    AstGrep::new(source, self.clone())
269  }
270
271  /// tree sitter language to parse the source
272  fn get_ts_language(&self) -> TSLanguage;
273
274  fn injectable_languages(&self) -> Option<&'static [&'static str]> {
275    None
276  }
277
278  /// get injected language regions in the root document. e.g. get JavaScripts in HTML
279  /// it will return a list of tuples of (language, regions).
280  /// The first item is the embedded region language, e.g. javascript
281  /// The second item is a list of regions in tree_sitter.
282  /// also see https://tree-sitter.github.io/tree-sitter/using-parsers#multi-language-documents
283  fn extract_injections<L: LanguageExt>(
284    &self,
285    _root: crate::Node<StrDoc<L>>,
286  ) -> HashMap<String, Vec<TSRange>> {
287    HashMap::new()
288  }
289}
290
291fn position_for_offset(input: &[u8], offset: usize) -> Point {
292  debug_assert!(offset <= input.len());
293  let (mut row, mut col) = (0, 0);
294  for c in &input[0..offset] {
295    if *c as char == '\n' {
296      row += 1;
297      col = 0;
298    } else {
299      col += 1;
300    }
301  }
302  Point::new(row, col)
303}
304
305impl<L: LanguageExt> AstGrep<StrDoc<L>> {
306  pub fn new<S: AsRef<str>>(src: S, lang: L) -> Self {
307    Root::str(src.as_ref(), lang)
308  }
309
310  pub fn source(&self) -> &str {
311    self.doc.get_source().as_str()
312  }
313
314  pub fn generate(self) -> String {
315    self.doc.src
316  }
317}
318
319pub trait ContentExt: Content {
320  fn accept_edit(&mut self, edit: &Edit<Self>) -> InputEdit;
321}
322impl ContentExt for String {
323  fn accept_edit(&mut self, edit: &Edit<Self>) -> InputEdit {
324    let start_byte = edit.position;
325    let old_end_byte = edit.position + edit.deleted_length;
326    let new_end_byte = edit.position + edit.inserted_text.len();
327    let input = unsafe { self.as_mut_vec() };
328    let start_position = position_for_offset(input, start_byte);
329    let old_end_position = position_for_offset(input, old_end_byte);
330    input.splice(start_byte..old_end_byte, edit.inserted_text.clone());
331    let new_end_position = position_for_offset(input, new_end_byte);
332    InputEdit::new(
333      start_byte as u32,
334      old_end_byte as u32,
335      new_end_byte as u32,
336      &start_position,
337      &old_end_position,
338      &new_end_position,
339    )
340  }
341}
342
343impl<L: LanguageExt> Root<StrDoc<L>> {
344  pub fn str(src: &str, lang: L) -> Self {
345    Self::try_new(src, lang).expect("should parse")
346  }
347  pub fn try_new(src: &str, lang: L) -> Result<Self, String> {
348    let doc = StrDoc::try_new(src, lang)?;
349    Ok(Self { doc })
350  }
351  pub fn get_text(&self) -> &str {
352    &self.doc.src
353  }
354
355  pub fn get_injections<F: Fn(&str) -> Option<L>>(&self, get_lang: F) -> Vec<Self> {
356    let root = self.root();
357    let range = self.lang().extract_injections(root);
358    let roots = range
359      .into_iter()
360      .filter_map(|(lang, ranges)| {
361        let lang = get_lang(&lang)?;
362        let source = self.doc.get_source();
363        let mut parser = tree_sitter::Parser::new().ok()?;
364        parser.set_included_ranges(&ranges).ok()?;
365        parser.set_language(&lang.get_ts_language()).ok()?;
366        let tree = parser.parse(source, None).ok()?;
367        tree.map(|t| Self {
368          doc: StrDoc {
369            src: self.doc.src.clone(),
370            lang,
371            tree: t,
372          },
373        })
374      })
375      .collect();
376    roots
377  }
378}
379
/// A matched node's text plus the surrounding source lines, as produced by
/// `display_context`.
pub struct DisplayContext<'r> {
  /// content for the matched node
  pub matched: Cow<'r, str>,
  /// content before the matched node, starting at the beginning of the first
  /// context line
  pub leading: &'r str,
  /// content after the matched node, up to (not including) the newline that
  /// ends the last context line, or the end of the source
  pub trailing: &'r str,
  /// zero-based start line of the context, i.e. the line where `leading` begins
  pub start_line: usize,
}
390
391/// these methods are only for `StrDoc`
392impl<'r, L: LanguageExt> crate::Node<'r, StrDoc<L>> {
393  #[doc(hidden)]
394  pub fn display_context(&self, before: usize, after: usize) -> DisplayContext<'r> {
395    let source = self.root.doc.get_source().as_str();
396    let bytes = source.as_bytes();
397    let start = self.inner.start_byte() as usize;
398    let end = self.inner.end_byte() as usize;
399    let (mut leading, mut trailing) = (start, end);
400    let mut lines_before = before + 1;
401    while leading > 0 {
402      if bytes[leading - 1] == b'\n' {
403        lines_before -= 1;
404        if lines_before == 0 {
405          break;
406        }
407      }
408      leading -= 1;
409    }
410    let mut lines_after = after + 1;
411    // tree-sitter will append line ending to source so trailing can be out of bound
412    trailing = trailing.min(bytes.len());
413    while trailing < bytes.len() {
414      if bytes[trailing] == b'\n' {
415        lines_after -= 1;
416        if lines_after == 0 {
417          break;
418        }
419      }
420      trailing += 1;
421    }
422    // lines_before means we matched all context, offset is `before` itself
423    let offset = if lines_before == 0 {
424      before
425    } else {
426      // otherwise, there are fewer than `before` line in src, compute the actual line
427      before + 1 - lines_before
428    };
429    DisplayContext {
430      matched: self.text(),
431      leading: &source[leading..start],
432      trailing: &source[end..trailing],
433      start_line: self.start_pos().line() - offset,
434    }
435  }
436
437  pub fn replace_all<M: Matcher, R: Replacer<StrDoc<L>>>(
438    &self,
439    matcher: M,
440    replacer: R,
441  ) -> Vec<Edit<String>> {
442    // TODO: support nested matches like Some(Some(1)) with pattern Some($A)
443    Visitor::new(&matcher)
444      .reentrant(false)
445      .visit(self.clone())
446      .map(|matched| matched.make_edit(&matcher, &replacer))
447      .collect()
448  }
449}
450
// Unit tests for the parsing helper and incremental editing, using the TSX grammar.
#[cfg(test)]
mod test {
  use super::*;
  use crate::language::Tsx;
  use tree_sitter::Point;

  /// Parses `src` as TSX from scratch.
  fn parse(src: &str) -> Result<Tree, TSParseError> {
    parse_lang(|p| p.parse(src, None), Tsx.get_ts_language())
  }

  #[test]
  fn test_tree_sitter() -> Result<(), TSParseError> {
    let tree = parse("var a = 1234")?;
    let root_node = tree.root_node();
    assert_eq!(root_node.kind(), "program");
    assert_eq!(root_node.start_position().column(), 0);
    assert_eq!(root_node.end_position().column(), 12);
    assert_eq!(
      root_node.to_sexp(),
      "(program (variable_declaration (variable_declarator name: (identifier) value: (number))))"
    );
    Ok(())
  }

  #[test]
  fn test_object_literal() -> Result<(), TSParseError> {
    let tree = parse("{a: $X}")?;
    let root_node = tree.root_node();
    // At statement position JS would read this as a block with label `a`, but the
    // grammar yields an object literal. Technically wrong, but more useful for patterns.
    assert_eq!(root_node.to_sexp(), "(program (expression_statement (object (pair key: (property_identifier) value: (identifier)))))");
    Ok(())
  }

  #[test]
  fn test_string() -> Result<(), TSParseError> {
    // A metavariable inside quotes stays plain string content.
    let tree = parse("'$A'")?;
    let root_node = tree.root_node();
    assert_eq!(
      root_node.to_sexp(),
      "(program (expression_statement (string (string_fragment))))"
    );
    Ok(())
  }

  #[test]
  fn test_row_col() -> Result<(), TSParseError> {
    // The emoji is a single char but four UTF-8 bytes.
    let tree = parse("😄")?;
    let root = tree.root_node();
    assert_eq!(root.start_position(), Point::new(0, 0));
    // NOTE: Point in tree-sitter is counted in bytes instead of char
    assert_eq!(root.end_position(), Point::new(0, 4));
    Ok(())
  }

  #[test]
  fn test_edit() -> Result<(), TSParseError> {
    let mut src = "a + b".to_string();
    let mut tree = parse(&src)?;
    // Insert " * b" after `a`, turning the source into "a * b + b".
    let _ = perform_edit(
      &mut tree,
      &mut src,
      &Edit {
        position: 1,
        deleted_length: 0,
        inserted_text: " * b".into(),
      },
    );
    // Incremental reparse reusing the edited old tree.
    let tree2 = parse_lang(|p| p.parse(&src, Some(&tree)), Tsx.get_ts_language())?;
    // The edited old tree keeps its original shape; only the reparse sees the new text.
    assert_eq!(
      tree.root_node().to_sexp(),
      "(program (expression_statement (binary_expression left: (identifier) right: (identifier))))"
    );
    assert_eq!(tree2.root_node().to_sexp(), "(program (expression_statement (binary_expression left: (binary_expression left: (identifier) right: (identifier)) right: (identifier))))");
    Ok(())
  }
}