ast_grep_core/tree_sitter/
mod.rs

1pub mod traversal;
2
3use crate::node::Root;
4use crate::replacer::Replacer;
5use crate::source::{Content, Doc, Edit, SgNode};
6use crate::{node::KindId, Language, Position};
7use crate::{AstGrep, Matcher};
8use std::borrow::Cow;
9use std::collections::HashMap;
10use std::num::NonZero;
11use thiserror::Error;
12pub use traversal::{TsPre, Visitor};
13pub use tree_sitter::Language as TSLanguage;
14use tree_sitter::{InputEdit, LanguageError, Node, Parser, Point, Tree};
15pub use tree_sitter::{Point as TSPoint, Range as TSRange};
16
17/// Represents tree-sitter related error
18#[derive(Debug, Error)]
19pub enum TSParseError {
20  #[error("incompatible `Language` is assigned to a `Parser`.")]
21  Language(#[from] LanguageError),
22  /// A general error when tree sitter fails to parse in time. It can be caused by
23  /// the following reasons but tree-sitter does not provide error detail.
24  /// * The timeout set with [Parser::set_timeout_micros] expired
25  /// * The cancellation flag set with [Parser::set_cancellation_flag] was flipped
26  /// * The parser has not yet had a language assigned with [Parser::set_language]
27  #[error("general error when tree-sitter fails to parse.")]
28  TreeUnavailable,
29}
30
31#[inline]
32fn parse_lang(
33  parse_fn: impl Fn(&mut Parser) -> Option<Tree>,
34  ts_lang: TSLanguage,
35) -> Result<Tree, TSParseError> {
36  let mut parser = Parser::new();
37  parser.set_language(&ts_lang)?;
38  if let Some(tree) = parse_fn(&mut parser) {
39    Ok(tree)
40  } else {
41    Err(TSParseError::TreeUnavailable)
42  }
43}
44
45#[derive(Clone)]
46pub struct StrDoc<L: LanguageExt> {
47  pub src: String,
48  pub lang: L,
49  pub tree: Tree,
50}
51
52impl<L: LanguageExt> StrDoc<L> {
53  pub fn try_new(src: &str, lang: L) -> Result<Self, String> {
54    let src = src.to_string();
55    let ts_lang = lang.get_ts_language();
56    let tree = parse_lang(|p| p.parse(src.as_bytes(), None), ts_lang).map_err(|e| e.to_string())?;
57    Ok(Self { src, lang, tree })
58  }
59  pub fn new(src: &str, lang: L) -> Self {
60    Self::try_new(src, lang).expect("Parser tree error")
61  }
62  fn parse(&self, old_tree: Option<&Tree>) -> Result<Tree, TSParseError> {
63    let source = self.get_source();
64    let lang = self.get_lang().get_ts_language();
65    parse_lang(|p| p.parse(source.as_bytes(), old_tree), lang)
66  }
67}
68
69impl<L: LanguageExt> Doc for StrDoc<L> {
70  type Source = String;
71  type Lang = L;
72  type Node<'r> = Node<'r>;
73  fn get_lang(&self) -> &Self::Lang {
74    &self.lang
75  }
76  fn get_source(&self) -> &Self::Source {
77    &self.src
78  }
79  fn do_edit(&mut self, edit: &Edit<Self::Source>) -> Result<(), String> {
80    let source = &mut self.src;
81    perform_edit(&mut self.tree, source, edit);
82    self.tree = self.parse(Some(&self.tree)).map_err(|e| e.to_string())?;
83    Ok(())
84  }
85  fn root_node(&self) -> Node<'_> {
86    self.tree.root_node()
87  }
88  fn get_node_text<'a>(&'a self, node: &Self::Node<'a>) -> Cow<'a, str> {
89    Cow::Borrowed(
90      node
91        .utf8_text(self.src.as_bytes())
92        .expect("invalid source text encoding"),
93    )
94  }
95}
96
97struct NodeWalker<'tree> {
98  cursor: tree_sitter::TreeCursor<'tree>,
99  count: usize,
100}
101
102impl<'tree> Iterator for NodeWalker<'tree> {
103  type Item = Node<'tree>;
104  fn next(&mut self) -> Option<Self::Item> {
105    if self.count == 0 {
106      return None;
107    }
108    let ret = Some(self.cursor.node());
109    self.cursor.goto_next_sibling();
110    self.count -= 1;
111    ret
112  }
113}
114
115impl ExactSizeIterator for NodeWalker<'_> {
116  fn len(&self) -> usize {
117    self.count
118  }
119}
120
121impl<'r> SgNode<'r> for Node<'r> {
122  fn parent(&self) -> Option<Self> {
123    Node::parent(self)
124  }
125  fn ancestors(&self, root: Self) -> impl Iterator<Item = Self> {
126    let mut ancestor = Some(root);
127    let self_id = self.id();
128    std::iter::from_fn(move || {
129      let inner = ancestor.take()?;
130      if inner.id() == self_id {
131        return None;
132      }
133      ancestor = inner.child_with_descendant(*self);
134      Some(inner)
135    })
136    // We must iterate up the tree to preserve backwards compatibility
137    .collect::<Vec<_>>()
138    .into_iter()
139    .rev()
140  }
141  fn dfs(&self) -> impl Iterator<Item = Self> {
142    TsPre::new(self)
143  }
144  fn child(&self, nth: usize) -> Option<Self> {
145    Node::child(self, nth as u32)
146  }
147  fn children(&self) -> impl ExactSizeIterator<Item = Self> {
148    let mut cursor = self.walk();
149    cursor.goto_first_child();
150    NodeWalker {
151      cursor,
152      count: self.child_count(),
153    }
154  }
155  fn child_by_field_id(&self, field_id: u16) -> Option<Self> {
156    Node::child_by_field_id(self, field_id)
157  }
158  fn next(&self) -> Option<Self> {
159    self.next_sibling()
160  }
161  fn prev(&self) -> Option<Self> {
162    self.prev_sibling()
163  }
164  fn next_all(&self) -> impl Iterator<Item = Self> {
165    // if root is none, use self as fallback to return a type-stable Iterator
166    let node = self.parent().unwrap_or(*self);
167    let mut cursor = node.walk();
168    cursor.goto_first_child_for_byte(self.start_byte());
169    std::iter::from_fn(move || {
170      if cursor.goto_next_sibling() {
171        Some(cursor.node())
172      } else {
173        None
174      }
175    })
176  }
177  fn prev_all(&self) -> impl Iterator<Item = Self> {
178    // if root is none, use self as fallback to return a type-stable Iterator
179    let node = self.parent().unwrap_or(*self);
180    let mut cursor = node.walk();
181    cursor.goto_first_child_for_byte(self.start_byte());
182    std::iter::from_fn(move || {
183      if cursor.goto_previous_sibling() {
184        Some(cursor.node())
185      } else {
186        None
187      }
188    })
189  }
190  fn is_named(&self) -> bool {
191    Node::is_named(self)
192  }
193  /// N.B. it is different from is_named && is_leaf
194  /// if a node has no named children.
195  fn is_named_leaf(&self) -> bool {
196    self.named_child_count() == 0
197  }
198  fn is_leaf(&self) -> bool {
199    self.child_count() == 0
200  }
201  fn kind(&self) -> Cow<'_, str> {
202    Cow::Borrowed(Node::kind(self))
203  }
204  fn kind_id(&self) -> KindId {
205    Node::kind_id(self)
206  }
207  fn node_id(&self) -> usize {
208    self.id()
209  }
210  fn range(&self) -> std::ops::Range<usize> {
211    self.start_byte()..self.end_byte()
212  }
213  fn start_pos(&self) -> Position {
214    let pos = self.start_position();
215    let byte = self.start_byte();
216    Position::new(pos.row, pos.column, byte)
217  }
218  fn end_pos(&self) -> Position {
219    let pos = self.end_position();
220    let byte = self.end_byte();
221    Position::new(pos.row, pos.column, byte)
222  }
223  // missing node is a tree-sitter specific concept
224  fn is_missing(&self) -> bool {
225    Node::is_missing(self)
226  }
227  fn is_error(&self) -> bool {
228    Node::is_error(self)
229  }
230
231  fn field(&self, name: &str) -> Option<Self> {
232    self.child_by_field_name(name)
233  }
234  fn field_children(&self, field_id: Option<u16>) -> impl Iterator<Item = Self> {
235    let field_id = field_id.and_then(NonZero::new);
236    let mut cursor = self.walk();
237    cursor.goto_first_child();
238    // if field_id is not found, iteration is done
239    let mut done = field_id.is_none();
240
241    std::iter::from_fn(move || {
242      if done {
243        return None;
244      }
245      while cursor.field_id() != field_id {
246        if !cursor.goto_next_sibling() {
247          return None;
248        }
249      }
250      let ret = cursor.node();
251      if !cursor.goto_next_sibling() {
252        done = true;
253      }
254      Some(ret)
255    })
256  }
257}
258
259pub fn perform_edit<S: ContentExt>(tree: &mut Tree, input: &mut S, edit: &Edit<S>) -> InputEdit {
260  let edit = input.accept_edit(edit);
261  tree.edit(&edit);
262  edit
263}
264
265/// tree-sitter specific language trait
266pub trait LanguageExt: Language {
267  /// Create an [`AstGrep`] instance for the language
268  fn ast_grep<S: AsRef<str>>(&self, source: S) -> AstGrep<StrDoc<Self>> {
269    AstGrep::new(source, self.clone())
270  }
271
272  /// tree sitter language to parse the source
273  fn get_ts_language(&self) -> TSLanguage;
274
275  fn injectable_languages(&self) -> Option<&'static [&'static str]> {
276    None
277  }
278
279  /// get injected language regions in the root document. e.g. get JavaScripts in HTML
280  /// it will return a list of tuples of (language, regions).
281  /// The first item is the embedded region language, e.g. javascript
282  /// The second item is a list of regions in tree_sitter.
283  /// also see https://tree-sitter.github.io/tree-sitter/using-parsers#multi-language-documents
284  fn extract_injections<L: LanguageExt>(
285    &self,
286    _root: crate::Node<StrDoc<L>>,
287  ) -> HashMap<String, Vec<TSRange>> {
288    HashMap::new()
289  }
290}
291
292fn position_for_offset(input: &[u8], offset: usize) -> Point {
293  debug_assert!(offset <= input.len());
294  let (mut row, mut col) = (0, 0);
295  for c in &input[0..offset] {
296    if *c as char == '\n' {
297      row += 1;
298      col = 0;
299    } else {
300      col += 1;
301    }
302  }
303  Point::new(row, col)
304}
305
306impl<L: LanguageExt> AstGrep<StrDoc<L>> {
307  pub fn new<S: AsRef<str>>(src: S, lang: L) -> Self {
308    Root::str(src.as_ref(), lang)
309  }
310
311  pub fn source(&self) -> &str {
312    self.doc.get_source().as_str()
313  }
314
315  pub fn generate(self) -> String {
316    self.doc.src
317  }
318}
319
320pub trait ContentExt: Content {
321  fn accept_edit(&mut self, edit: &Edit<Self>) -> InputEdit;
322}
323impl ContentExt for String {
324  fn accept_edit(&mut self, edit: &Edit<Self>) -> InputEdit {
325    let start_byte = edit.position;
326    let old_end_byte = edit.position + edit.deleted_length;
327    let new_end_byte = edit.position + edit.inserted_text.len();
328    let input = unsafe { self.as_mut_vec() };
329    let start_position = position_for_offset(input, start_byte);
330    let old_end_position = position_for_offset(input, old_end_byte);
331    input.splice(start_byte..old_end_byte, edit.inserted_text.clone());
332    let new_end_position = position_for_offset(input, new_end_byte);
333    InputEdit {
334      start_byte,
335      old_end_byte,
336      new_end_byte,
337      start_position,
338      old_end_position,
339      new_end_position,
340    }
341  }
342}
343
344impl<L: LanguageExt> Root<StrDoc<L>> {
345  pub fn str(src: &str, lang: L) -> Self {
346    Self::try_new(src, lang).expect("should parse")
347  }
348  pub fn try_new(src: &str, lang: L) -> Result<Self, String> {
349    let doc = StrDoc::try_new(src, lang)?;
350    Ok(Self { doc })
351  }
352  pub fn get_text(&self) -> &str {
353    &self.doc.src
354  }
355
356  pub fn get_injections<F: Fn(&str) -> Option<L>>(&self, get_lang: F) -> Vec<Self> {
357    let root = self.root();
358    let range = self.lang().extract_injections(root);
359    let roots = range
360      .into_iter()
361      .filter_map(|(lang, ranges)| {
362        let lang = get_lang(&lang)?;
363        let source = self.doc.get_source();
364        let mut parser = Parser::new();
365        parser.set_included_ranges(&ranges).ok()?;
366        parser.set_language(&lang.get_ts_language()).ok()?;
367        let tree = parser.parse(source, None)?;
368        Some(Self {
369          doc: StrDoc {
370            src: self.doc.src.clone(),
371            lang,
372            tree,
373          },
374        })
375      })
376      .collect();
377    roots
378  }
379}
380
/// A matched node's text together with the surrounding source text.
pub struct DisplayContext<'r> {
  /// Text of the matched node.
  pub matched: Cow<'r, str>,
  /// Source text immediately preceding the matched node.
  pub leading: &'r str,
  /// Source text immediately following the matched node.
  pub trailing: &'r str,
  /// Zero-based line on which the context starts.
  pub start_line: usize,
}
391
392/// these methods are only for `StrDoc`
393impl<'r, L: LanguageExt> crate::Node<'r, StrDoc<L>> {
394  #[doc(hidden)]
395  pub fn display_context(&self, before: usize, after: usize) -> DisplayContext<'r> {
396    let source = self.root.doc.get_source().as_str();
397    let bytes = source.as_bytes();
398    let start = self.inner.start_byte();
399    let end = self.inner.end_byte();
400    let (mut leading, mut trailing) = (start, end);
401    let mut lines_before = before + 1;
402    while leading > 0 {
403      if bytes[leading - 1] == b'\n' {
404        lines_before -= 1;
405        if lines_before == 0 {
406          break;
407        }
408      }
409      leading -= 1;
410    }
411    let mut lines_after = after + 1;
412    // tree-sitter will append line ending to source so trailing can be out of bound
413    trailing = trailing.min(bytes.len());
414    while trailing < bytes.len() {
415      if bytes[trailing] == b'\n' {
416        lines_after -= 1;
417        if lines_after == 0 {
418          break;
419        }
420      }
421      trailing += 1;
422    }
423    // lines_before means we matched all context, offset is `before` itself
424    let offset = if lines_before == 0 {
425      before
426    } else {
427      // otherwise, there are fewer than `before` line in src, compute the actual line
428      before + 1 - lines_before
429    };
430    DisplayContext {
431      matched: self.text(),
432      leading: &source[leading..start],
433      trailing: &source[end..trailing],
434      start_line: self.start_pos().line() - offset,
435    }
436  }
437
438  pub fn replace_all<M: Matcher, R: Replacer<StrDoc<L>>>(
439    &self,
440    matcher: M,
441    replacer: R,
442  ) -> Vec<Edit<String>> {
443    // TODO: support nested matches like Some(Some(1)) with pattern Some($A)
444    Visitor::new(&matcher)
445      .reentrant(false)
446      .visit(self.clone())
447      .map(|matched| matched.make_edit(&matcher, &replacer))
448      .collect()
449  }
450}
451
#[cfg(test)]
mod test {
  use super::*;
  use crate::language::Tsx;
  use tree_sitter::Point;

  /// Parses `src` with the TSX grammar.
  fn parse(src: &str) -> Result<Tree, TSParseError> {
    let ts_lang = Tsx.get_ts_language();
    parse_lang(|parser| parser.parse(src, None), ts_lang)
  }

  /// Parses `src` with the TSX grammar and renders the root as an S-expression.
  fn sexp(src: &str) -> Result<String, TSParseError> {
    let tree = parse(src)?;
    Ok(tree.root_node().to_sexp())
  }

  #[test]
  fn test_tree_sitter() -> Result<(), TSParseError> {
    let tree = parse("var a = 1234")?;
    let program = tree.root_node();
    assert_eq!(program.kind(), "program");
    assert_eq!(program.start_position().column, 0);
    assert_eq!(program.end_position().column, 12);
    assert_eq!(
      program.to_sexp(),
      "(program (variable_declaration (variable_declarator name: (identifier) value: (number))))"
    );
    Ok(())
  }

  #[test]
  fn test_object_literal() -> Result<(), TSParseError> {
    // wow this is not label. technically it is wrong but practically it is better LOL
    assert_eq!(
      sexp("{a: $X}")?,
      "(program (expression_statement (object (pair key: (property_identifier) value: (identifier)))))"
    );
    Ok(())
  }

  #[test]
  fn test_string() -> Result<(), TSParseError> {
    assert_eq!(
      sexp("'$A'")?,
      "(program (expression_statement (string (string_fragment))))"
    );
    Ok(())
  }

  #[test]
  fn test_row_col() -> Result<(), TSParseError> {
    let tree = parse("😄")?;
    let program = tree.root_node();
    assert_eq!(program.start_position(), Point::new(0, 0));
    // NOTE: Point in tree-sitter is counted in bytes instead of char
    assert_eq!(program.end_position(), Point::new(0, 4));
    Ok(())
  }

  #[test]
  fn test_edit() -> Result<(), TSParseError> {
    let mut src = String::from("a + b");
    let mut tree = parse(&src)?;
    let insertion = Edit {
      position: 1,
      deleted_length: 0,
      inserted_text: " * b".into(),
    };
    let _ = perform_edit(&mut tree, &mut src, &insertion);
    let reparsed = parse_lang(|parser| parser.parse(&src, Some(&tree)), Tsx.get_ts_language())?;
    // The edited tree keeps its old shape until it is re-parsed.
    assert_eq!(
      tree.root_node().to_sexp(),
      "(program (expression_statement (binary_expression left: (identifier) right: (identifier))))"
    );
    assert_eq!(
      reparsed.root_node().to_sexp(),
      "(program (expression_statement (binary_expression left: (binary_expression left: (identifier) right: (identifier)) right: (identifier))))"
    );
    Ok(())
  }
}