Skip to main content

ast_grep_core/tree_sitter/
mod.rs

1pub mod traversal;
2
3use crate::node::Root;
4use crate::replacer::Replacer;
5use crate::source::{Content, Doc, Edit, SgNode};
6use crate::{node::KindId, Language, Position};
7use crate::{AstGrep, Matcher};
8use std::borrow::Cow;
9use std::num::NonZero;
10use thiserror::Error;
11pub use traversal::{TsPre, Visitor};
12pub use tree_sitter::Language as TSLanguage;
13use tree_sitter::{InputEdit, LanguageError, Node, Parser, Point, Tree};
14pub use tree_sitter::{Point as TSPoint, Range as TSRange};
15
16/// Represents tree-sitter related error
17#[derive(Debug, Error)]
18pub enum TSParseError {
19  #[error("incompatible `Language` is assigned to a `Parser`.")]
20  Language(#[from] LanguageError),
21  /// A general error when tree sitter fails to parse in time. It can be caused by
22  /// the following reasons but tree-sitter does not provide error detail.
23  /// * The timeout set with [Parser::set_timeout_micros] expired
24  /// * The cancellation flag set with [Parser::set_cancellation_flag] was flipped
25  /// * The parser has not yet had a language assigned with [Parser::set_language]
26  #[error("general error when tree-sitter fails to parse.")]
27  TreeUnavailable,
28}
29
30#[inline]
31fn parse_lang(
32  parse_fn: impl Fn(&mut Parser) -> Option<Tree>,
33  ts_lang: TSLanguage,
34) -> Result<Tree, TSParseError> {
35  let mut parser = Parser::new();
36  parser.set_language(&ts_lang)?;
37  if let Some(tree) = parse_fn(&mut parser) {
38    Ok(tree)
39  } else {
40    Err(TSParseError::TreeUnavailable)
41  }
42}
43
44#[derive(Clone)]
45pub struct StrDoc<L: LanguageExt> {
46  pub src: String,
47  pub lang: L,
48  pub tree: Tree,
49}
50
51impl<L: LanguageExt> StrDoc<L> {
52  pub fn try_new(src: &str, lang: L) -> Result<Self, String> {
53    let src = src.to_string();
54    let ts_lang = lang.get_ts_language();
55    let tree = parse_lang(|p| p.parse(src.as_bytes(), None), ts_lang).map_err(|e| e.to_string())?;
56    Ok(Self { src, lang, tree })
57  }
58  pub fn new(src: &str, lang: L) -> Self {
59    Self::try_new(src, lang).expect("Parser tree error")
60  }
61  fn parse(&self, old_tree: Option<&Tree>) -> Result<Tree, TSParseError> {
62    let source = self.get_source();
63    let lang = self.get_lang().get_ts_language();
64    parse_lang(|p| p.parse(source.as_bytes(), old_tree), lang)
65  }
66}
67
68impl<L: LanguageExt> Doc for StrDoc<L> {
69  type Source = String;
70  type Lang = L;
71  type Node<'r> = Node<'r>;
72  fn get_lang(&self) -> &Self::Lang {
73    &self.lang
74  }
75  fn get_source(&self) -> &Self::Source {
76    &self.src
77  }
78  fn do_edit(&mut self, edit: &Edit<Self::Source>) -> Result<(), String> {
79    let source = &mut self.src;
80    perform_edit(&mut self.tree, source, edit);
81    self.tree = self.parse(Some(&self.tree)).map_err(|e| e.to_string())?;
82    Ok(())
83  }
84  fn root_node(&self) -> Node<'_> {
85    self.tree.root_node()
86  }
87  fn get_node_text<'a>(&'a self, node: &Self::Node<'a>) -> Cow<'a, str> {
88    Cow::Borrowed(
89      node
90        .utf8_text(self.src.as_bytes())
91        .expect("invalid source text encoding"),
92    )
93  }
94}
95
96struct NodeWalker<'tree> {
97  cursor: tree_sitter::TreeCursor<'tree>,
98  count: usize,
99}
100
101impl<'tree> Iterator for NodeWalker<'tree> {
102  type Item = Node<'tree>;
103  fn next(&mut self) -> Option<Self::Item> {
104    if self.count == 0 {
105      return None;
106    }
107    let ret = Some(self.cursor.node());
108    self.cursor.goto_next_sibling();
109    self.count -= 1;
110    ret
111  }
112}
113
114impl ExactSizeIterator for NodeWalker<'_> {
115  fn len(&self) -> usize {
116    self.count
117  }
118}
119
120impl<'r> SgNode<'r> for Node<'r> {
121  fn parent(&self) -> Option<Self> {
122    Node::parent(self)
123  }
124  fn ancestors(&self, root: Self) -> impl Iterator<Item = Self> {
125    let mut ancestor = Some(root);
126    let self_id = self.id();
127    std::iter::from_fn(move || {
128      let inner = ancestor.take()?;
129      if inner.id() == self_id {
130        return None;
131      }
132      ancestor = inner.child_with_descendant(*self);
133      Some(inner)
134    })
135    // We must iterate up the tree to preserve backwards compatibility
136    .collect::<Vec<_>>()
137    .into_iter()
138    .rev()
139  }
140  fn dfs(&self) -> impl Iterator<Item = Self> {
141    TsPre::new(self)
142  }
143  fn child(&self, nth: usize) -> Option<Self> {
144    Node::child(self, nth as u32)
145  }
146  fn children(&self) -> impl ExactSizeIterator<Item = Self> {
147    let mut cursor = self.walk();
148    cursor.goto_first_child();
149    NodeWalker {
150      cursor,
151      count: self.child_count(),
152    }
153  }
154  fn child_by_field_id(&self, field_id: u16) -> Option<Self> {
155    Node::child_by_field_id(self, field_id)
156  }
157  fn next(&self) -> Option<Self> {
158    self.next_sibling()
159  }
160  fn prev(&self) -> Option<Self> {
161    self.prev_sibling()
162  }
163  fn next_all(&self) -> impl Iterator<Item = Self> {
164    // if root is none, use self as fallback to return a type-stable Iterator
165    let node = self.parent().unwrap_or(*self);
166    let mut cursor = node.walk();
167    cursor.goto_first_child_for_byte(self.start_byte());
168    std::iter::from_fn(move || {
169      if cursor.goto_next_sibling() {
170        Some(cursor.node())
171      } else {
172        None
173      }
174    })
175  }
176  fn prev_all(&self) -> impl Iterator<Item = Self> {
177    // if root is none, use self as fallback to return a type-stable Iterator
178    let node = self.parent().unwrap_or(*self);
179    let mut cursor = node.walk();
180    cursor.goto_first_child_for_byte(self.start_byte());
181    std::iter::from_fn(move || {
182      if cursor.goto_previous_sibling() {
183        Some(cursor.node())
184      } else {
185        None
186      }
187    })
188  }
189  fn is_named(&self) -> bool {
190    Node::is_named(self)
191  }
192  /// N.B. it is different from is_named && is_leaf
193  /// if a node has no named children.
194  fn is_named_leaf(&self) -> bool {
195    self.named_child_count() == 0
196  }
197  fn is_leaf(&self) -> bool {
198    self.child_count() == 0
199  }
200  fn kind(&self) -> Cow<'_, str> {
201    Cow::Borrowed(Node::kind(self))
202  }
203  fn kind_id(&self) -> KindId {
204    Node::kind_id(self)
205  }
206  fn node_id(&self) -> usize {
207    self.id()
208  }
209  fn range(&self) -> std::ops::Range<usize> {
210    self.start_byte()..self.end_byte()
211  }
212  fn start_pos(&self) -> Position {
213    let pos = self.start_position();
214    let byte = self.start_byte();
215    Position::new(pos.row, pos.column, byte)
216  }
217  fn end_pos(&self) -> Position {
218    let pos = self.end_position();
219    let byte = self.end_byte();
220    Position::new(pos.row, pos.column, byte)
221  }
222  // missing node is a tree-sitter specific concept
223  fn is_missing(&self) -> bool {
224    Node::is_missing(self)
225  }
226  fn is_error(&self) -> bool {
227    Node::is_error(self)
228  }
229
230  fn field(&self, name: &str) -> Option<Self> {
231    self.child_by_field_name(name)
232  }
233  fn field_children(&self, field_id: Option<u16>) -> impl Iterator<Item = Self> {
234    let field_id = field_id.and_then(NonZero::new);
235    let mut cursor = self.walk();
236    let has_children = cursor.goto_first_child();
237    // if field_id is not found, iteration is done
238    let mut done = field_id.is_none() || !has_children;
239
240    std::iter::from_fn(move || {
241      if done {
242        return None;
243      }
244      while cursor.field_id() != field_id {
245        if !cursor.goto_next_sibling() {
246          return None;
247        }
248      }
249      let ret = cursor.node();
250      if !cursor.goto_next_sibling() {
251        done = true;
252      }
253      Some(ret)
254    })
255  }
256}
257
258pub fn perform_edit<S: ContentExt>(tree: &mut Tree, input: &mut S, edit: &Edit<S>) -> InputEdit {
259  let edit = input.accept_edit(edit);
260  tree.edit(&edit);
261  edit
262}
263
264/// tree-sitter specific language trait
265pub trait LanguageExt: Language {
266  /// Create an [`AstGrep`] instance for the language
267  fn ast_grep<S: AsRef<str>>(&self, source: S) -> AstGrep<StrDoc<Self>> {
268    AstGrep::new(source, self.clone())
269  }
270
271  /// tree sitter language to parse the source
272  fn get_ts_language(&self) -> TSLanguage;
273
274  fn injectable_languages(&self) -> Option<&'static [&'static str]> {
275    None
276  }
277
278  /// Get injected language regions in the root document. e.g. get JavaScripts in HTML.
279  /// Each entry is parsed as an **independent** tree-sitter document.
280  /// Multiple entries for the same language produce separate parse trees.
281  /// Also see <https://tree-sitter.github.io/tree-sitter/using-parsers#multi-language-documents>
282  fn extract_injections<L: LanguageExt>(
283    &self,
284    _root: crate::Node<StrDoc<L>>,
285  ) -> Vec<(String, Vec<TSRange>)> {
286    Vec::new()
287  }
288}
289
290fn position_for_offset(input: &[u8], offset: usize) -> Point {
291  debug_assert!(offset <= input.len());
292  let (mut row, mut col) = (0, 0);
293  for c in &input[0..offset] {
294    if *c as char == '\n' {
295      row += 1;
296      col = 0;
297    } else {
298      col += 1;
299    }
300  }
301  Point::new(row, col)
302}
303
304impl<L: LanguageExt> AstGrep<StrDoc<L>> {
305  pub fn new<S: AsRef<str>>(src: S, lang: L) -> Self {
306    Root::str(src.as_ref(), lang)
307  }
308
309  pub fn source(&self) -> &str {
310    self.doc.get_source().as_str()
311  }
312
313  pub fn generate(self) -> String {
314    self.doc.src
315  }
316}
317
318pub trait ContentExt: Content {
319  fn accept_edit(&mut self, edit: &Edit<Self>) -> InputEdit;
320}
321impl ContentExt for String {
322  fn accept_edit(&mut self, edit: &Edit<Self>) -> InputEdit {
323    let start_byte = edit.position;
324    let old_end_byte = edit.position + edit.deleted_length;
325    let new_end_byte = edit.position + edit.inserted_text.len();
326    let input = unsafe { self.as_mut_vec() };
327    let start_position = position_for_offset(input, start_byte);
328    let old_end_position = position_for_offset(input, old_end_byte);
329    input.splice(start_byte..old_end_byte, edit.inserted_text.clone());
330    let new_end_position = position_for_offset(input, new_end_byte);
331    InputEdit {
332      start_byte,
333      old_end_byte,
334      new_end_byte,
335      start_position,
336      old_end_position,
337      new_end_position,
338    }
339  }
340}
341
342impl<L: LanguageExt> Root<StrDoc<L>> {
343  pub fn str(src: &str, lang: L) -> Self {
344    Self::try_new(src, lang).expect("should parse")
345  }
346  pub fn try_new(src: &str, lang: L) -> Result<Self, String> {
347    let doc = StrDoc::try_new(src, lang)?;
348    Ok(Self { doc })
349  }
350  pub fn get_text(&self) -> &str {
351    &self.doc.src
352  }
353
354  pub fn get_injections<F: Fn(&str) -> Option<L>>(&self, get_lang: F) -> Vec<Self> {
355    let root = self.root();
356    let range = self.lang().extract_injections(root);
357    let roots = range
358      .into_iter()
359      .filter_map(|(lang_str, ranges)| {
360        let lang = get_lang(&lang_str)?;
361        let source = self.doc.get_source();
362        let mut parser = Parser::new();
363        parser.set_included_ranges(&ranges).ok()?;
364        parser.set_language(&lang.get_ts_language()).ok()?;
365        let tree = parser.parse(source, None)?;
366        Some(Self {
367          doc: StrDoc {
368            src: self.doc.src.clone(),
369            lang,
370            tree,
371          },
372        })
373      })
374      .collect();
375    roots
376  }
377}
378
/// A matched node's text together with the surrounding source used to display it.
pub struct DisplayContext<'r> {
  /// Source text of the matched node itself.
  pub matched: Cow<'r, str>,
  /// Source text before the match: the requested context lines plus the part of the
  /// match's first line that precedes it.
  pub leading: &'r str,
  /// Source text after the match, up to the end of the last requested context line.
  pub trailing: &'r str,
  /// Zero-based line number at which `leading` begins.
  pub start_line: usize,
}
389
390/// these methods are only for `StrDoc`
391impl<'r, L: LanguageExt> crate::Node<'r, StrDoc<L>> {
392  #[doc(hidden)]
393  pub fn display_context(&self, before: usize, after: usize) -> DisplayContext<'r> {
394    let source = self.root.doc.get_source().as_str();
395    let bytes = source.as_bytes();
396    let start = self.inner.start_byte();
397    let end = self.inner.end_byte();
398    let (mut leading, mut trailing) = (start, end);
399    let mut lines_before = before + 1;
400    while leading > 0 {
401      if bytes[leading - 1] == b'\n' {
402        lines_before -= 1;
403        if lines_before == 0 {
404          break;
405        }
406      }
407      leading -= 1;
408    }
409    let mut lines_after = after + 1;
410    // tree-sitter will append line ending to source so trailing can be out of bound
411    trailing = trailing.min(bytes.len());
412    while trailing < bytes.len() {
413      if bytes[trailing] == b'\n' {
414        lines_after -= 1;
415        if lines_after == 0 {
416          break;
417        }
418      }
419      trailing += 1;
420    }
421    // lines_before means we matched all context, offset is `before` itself
422    let offset = if lines_before == 0 {
423      before
424    } else {
425      // otherwise, there are fewer than `before` line in src, compute the actual line
426      before + 1 - lines_before
427    };
428    DisplayContext {
429      matched: self.text(),
430      leading: &source[leading..start],
431      trailing: &source[end..trailing],
432      start_line: self.start_pos().line() - offset,
433    }
434  }
435
436  pub fn replace_all<M: Matcher, R: Replacer<StrDoc<L>>>(
437    &self,
438    matcher: M,
439    replacer: R,
440  ) -> Vec<Edit<String>> {
441    // TODO: support nested matches like Some(Some(1)) with pattern Some($A)
442    Visitor::new(&matcher)
443      .reentrant(false)
444      .visit(self.clone())
445      .map(|matched| matched.make_edit(&matcher, &replacer))
446      .collect()
447  }
448}
449
#[cfg(test)]
mod test {
  use super::*;
  use crate::language::Tsx;
  use tree_sitter::Point;

  /// Parses `src` as TSX from scratch (no previous tree to reuse).
  fn parse(src: &str) -> Result<Tree, TSParseError> {
    let grammar = Tsx.get_ts_language();
    parse_lang(|parser| parser.parse(src, None), grammar)
  }

  #[test]
  fn test_tree_sitter() -> Result<(), TSParseError> {
    let tree = parse("var a = 1234")?;
    let root = tree.root_node();
    assert_eq!(root.kind(), "program");
    assert_eq!(root.start_position().column, 0);
    assert_eq!(root.end_position().column, 12);
    assert_eq!(
      root.to_sexp(),
      "(program (variable_declaration (variable_declarator name: (identifier) value: (number))))"
    );
    Ok(())
  }

  #[test]
  fn test_object_literal() -> Result<(), TSParseError> {
    let tree = parse("{a: $X}")?;
    let root = tree.root_node();
    // Parsed as an object literal rather than a block with a labeled statement. That is
    // not the strict statement-level reading, but it is the more useful one for patterns.
    let expected = "(program (expression_statement (object (pair key: (property_identifier) value: (identifier)))))";
    assert_eq!(root.to_sexp(), expected);
    Ok(())
  }

  #[test]
  fn test_string() -> Result<(), TSParseError> {
    let tree = parse("'$A'")?;
    let root = tree.root_node();
    assert_eq!(
      root.to_sexp(),
      "(program (expression_statement (string (string_fragment))))"
    );
    Ok(())
  }

  #[test]
  fn test_row_col() -> Result<(), TSParseError> {
    let tree = parse("😄")?;
    let root = tree.root_node();
    assert_eq!(root.start_position(), Point::new(0, 0));
    // NOTE: Point in tree-sitter is counted in bytes instead of char
    assert_eq!(root.end_position(), Point::new(0, 4));
    Ok(())
  }

  #[test]
  fn test_edit() -> Result<(), TSParseError> {
    let mut src = "a + b".to_string();
    let mut tree = parse(&src)?;
    let insertion = Edit {
      position: 1,
      deleted_length: 0,
      inserted_text: " * b".into(),
    };
    let _ = perform_edit(&mut tree, &mut src, &insertion);
    let reparsed = parse_lang(|parser| parser.parse(&src, Some(&tree)), Tsx.get_ts_language())?;
    // `Tree::edit` only shifts positions; the old tree keeps its shape until re-parsed.
    assert_eq!(
      tree.root_node().to_sexp(),
      "(program (expression_statement (binary_expression left: (identifier) right: (identifier))))"
    );
    assert_eq!(reparsed.root_node().to_sexp(), "(program (expression_statement (binary_expression left: (binary_expression left: (identifier) right: (identifier)) right: (identifier))))");
    Ok(())
  }
}