ast_grep_core/tree_sitter/
mod.rs1pub mod traversal;
2
3use crate::node::Root;
4use crate::replacer::Replacer;
5use crate::source::{Content, Doc, Edit, SgNode};
6use crate::{node::KindId, Language, Position};
7use crate::{AstGrep, Matcher};
8use std::borrow::Cow;
9use std::collections::HashMap;
10use std::num::NonZero;
11use thiserror::Error;
12pub use traversal::{TsPre, Visitor};
13pub use tree_sitter::Language as TSLanguage;
14use tree_sitter::{InputEdit, LanguageError, Node, Parser, Point, Tree};
15pub use tree_sitter::{Point as TSPoint, Range as TSRange};
16
/// Errors produced while configuring a tree-sitter parser or parsing with it.
#[derive(Debug, Error)]
pub enum TSParseError {
    /// The language was rejected when it was assigned to the parser; wraps the
    /// underlying [`LanguageError`].
    #[error("incompatible `Language` is assigned to a `Parser`.")]
    Language(#[from] LanguageError),
    /// The parse step produced no tree. This variant carries no further detail
    /// about why parsing failed.
    #[error("general error when tree-sitter fails to parse.")]
    TreeUnavailable,
}
30
31#[inline]
32fn parse_lang(
33 parse_fn: impl Fn(&mut Parser) -> Option<Tree>,
34 ts_lang: TSLanguage,
35) -> Result<Tree, TSParseError> {
36 let mut parser = Parser::new();
37 parser.set_language(&ts_lang)?;
38 if let Some(tree) = parse_fn(&mut parser) {
39 Ok(tree)
40 } else {
41 Err(TSParseError::TreeUnavailable)
42 }
43}
44
/// A document that owns its source text as a `String` together with the
/// tree-sitter [`Tree`] parsed from that text.
#[derive(Clone)]
pub struct StrDoc<L: LanguageExt> {
    /// The source text the tree was parsed from.
    pub src: String,
    /// The language used to parse `src`.
    pub lang: L,
    /// The syntax tree for `src`.
    pub tree: Tree,
}
51
52impl<L: LanguageExt> StrDoc<L> {
53 pub fn try_new(src: &str, lang: L) -> Result<Self, String> {
54 let src = src.to_string();
55 let ts_lang = lang.get_ts_language();
56 let tree = parse_lang(|p| p.parse(src.as_bytes(), None), ts_lang).map_err(|e| e.to_string())?;
57 Ok(Self { src, lang, tree })
58 }
59 pub fn new(src: &str, lang: L) -> Self {
60 Self::try_new(src, lang).expect("Parser tree error")
61 }
62 fn parse(&self, old_tree: Option<&Tree>) -> Result<Tree, TSParseError> {
63 let source = self.get_source();
64 let lang = self.get_lang().get_ts_language();
65 parse_lang(|p| p.parse(source.as_bytes(), old_tree), lang)
66 }
67}
68
69impl<L: LanguageExt> Doc for StrDoc<L> {
70 type Source = String;
71 type Lang = L;
72 type Node<'r> = Node<'r>;
73 fn get_lang(&self) -> &Self::Lang {
74 &self.lang
75 }
76 fn get_source(&self) -> &Self::Source {
77 &self.src
78 }
79 fn do_edit(&mut self, edit: &Edit<Self::Source>) -> Result<(), String> {
80 let source = &mut self.src;
81 perform_edit(&mut self.tree, source, edit);
82 self.tree = self.parse(Some(&self.tree)).map_err(|e| e.to_string())?;
83 Ok(())
84 }
85 fn root_node(&self) -> Node<'_> {
86 self.tree.root_node()
87 }
88 fn get_node_text<'a>(&'a self, node: &Self::Node<'a>) -> Cow<'a, str> {
89 Cow::Borrowed(
90 node
91 .utf8_text(self.src.as_bytes())
92 .expect("invalid source text encoding"),
93 )
94 }
95}
96
97struct NodeWalker<'tree> {
98 cursor: tree_sitter::TreeCursor<'tree>,
99 count: usize,
100}
101
102impl<'tree> Iterator for NodeWalker<'tree> {
103 type Item = Node<'tree>;
104 fn next(&mut self) -> Option<Self::Item> {
105 if self.count == 0 {
106 return None;
107 }
108 let ret = Some(self.cursor.node());
109 self.cursor.goto_next_sibling();
110 self.count -= 1;
111 ret
112 }
113}
114
115impl ExactSizeIterator for NodeWalker<'_> {
116 fn len(&self) -> usize {
117 self.count
118 }
119}
120
121impl<'r> SgNode<'r> for Node<'r> {
122 fn parent(&self) -> Option<Self> {
123 Node::parent(self)
124 }
125 fn ancestors(&self, root: Self) -> impl Iterator<Item = Self> {
126 let mut ancestor = Some(root);
127 let self_id = self.id();
128 std::iter::from_fn(move || {
129 let inner = ancestor.take()?;
130 if inner.id() == self_id {
131 return None;
132 }
133 ancestor = inner.child_with_descendant(*self);
134 Some(inner)
135 })
136 .collect::<Vec<_>>()
138 .into_iter()
139 .rev()
140 }
141 fn dfs(&self) -> impl Iterator<Item = Self> {
142 TsPre::new(self)
143 }
144 fn child(&self, nth: usize) -> Option<Self> {
145 Node::child(self, nth)
147 }
148 fn children(&self) -> impl ExactSizeIterator<Item = Self> {
149 let mut cursor = self.walk();
150 cursor.goto_first_child();
151 NodeWalker {
152 cursor,
153 count: self.child_count(),
154 }
155 }
156 fn child_by_field_id(&self, field_id: u16) -> Option<Self> {
157 Node::child_by_field_id(self, field_id)
158 }
159 fn next(&self) -> Option<Self> {
160 self.next_sibling()
161 }
162 fn prev(&self) -> Option<Self> {
163 self.prev_sibling()
164 }
165 fn next_all(&self) -> impl Iterator<Item = Self> {
166 let node = self.parent().unwrap_or(*self);
168 let mut cursor = node.walk();
169 cursor.goto_first_child_for_byte(self.start_byte());
170 std::iter::from_fn(move || {
171 if cursor.goto_next_sibling() {
172 Some(cursor.node())
173 } else {
174 None
175 }
176 })
177 }
178 fn prev_all(&self) -> impl Iterator<Item = Self> {
179 let node = self.parent().unwrap_or(*self);
181 let mut cursor = node.walk();
182 cursor.goto_first_child_for_byte(self.start_byte());
183 std::iter::from_fn(move || {
184 if cursor.goto_previous_sibling() {
185 Some(cursor.node())
186 } else {
187 None
188 }
189 })
190 }
191 fn is_named(&self) -> bool {
192 Node::is_named(self)
193 }
194 fn is_named_leaf(&self) -> bool {
197 self.named_child_count() == 0
198 }
199 fn is_leaf(&self) -> bool {
200 self.child_count() == 0
201 }
202 fn kind(&self) -> Cow<'_, str> {
203 Cow::Borrowed(Node::kind(self))
204 }
205 fn kind_id(&self) -> KindId {
206 Node::kind_id(self)
207 }
208 fn node_id(&self) -> usize {
209 self.id()
210 }
211 fn range(&self) -> std::ops::Range<usize> {
212 self.start_byte()..self.end_byte()
213 }
214 fn start_pos(&self) -> Position {
215 let pos = self.start_position();
216 let byte = self.start_byte();
217 Position::new(pos.row, pos.column, byte)
218 }
219 fn end_pos(&self) -> Position {
220 let pos = self.end_position();
221 let byte = self.end_byte();
222 Position::new(pos.row, pos.column, byte)
223 }
224 fn is_missing(&self) -> bool {
226 Node::is_missing(self)
227 }
228 fn is_error(&self) -> bool {
229 Node::is_error(self)
230 }
231
232 fn field(&self, name: &str) -> Option<Self> {
233 self.child_by_field_name(name)
234 }
235 fn field_children(&self, field_id: Option<u16>) -> impl Iterator<Item = Self> {
236 let field_id = field_id.and_then(NonZero::new);
237 let mut cursor = self.walk();
238 cursor.goto_first_child();
239 let mut done = field_id.is_none();
241
242 std::iter::from_fn(move || {
243 if done {
244 return None;
245 }
246 while cursor.field_id() != field_id {
247 if !cursor.goto_next_sibling() {
248 return None;
249 }
250 }
251 let ret = cursor.node();
252 if !cursor.goto_next_sibling() {
253 done = true;
254 }
255 Some(ret)
256 })
257 }
258}
259
260pub fn perform_edit<S: ContentExt>(tree: &mut Tree, input: &mut S, edit: &Edit<S>) -> InputEdit {
261 let edit = input.accept_edit(edit);
262 tree.edit(&edit);
263 edit
264}
265
/// Extension of [`Language`] for languages that are parsed with tree-sitter.
pub trait LanguageExt: Language {
    /// Parses `source` with a clone of this language and returns the
    /// resulting [`AstGrep`] root.
    fn ast_grep<S: AsRef<str>>(&self, source: S) -> AstGrep<StrDoc<Self>> {
        AstGrep::new(source, self.clone())
    }

    /// Returns the tree-sitter language used to parse source text.
    fn get_ts_language(&self) -> TSLanguage;

    /// Returns the names of languages that can be embedded in this one.
    ///
    /// The default implementation returns `None`.
    fn injectable_languages(&self) -> Option<&'static [&'static str]> {
        None
    }

    /// Returns the embedded-language regions found under `_root`, keyed by
    /// language name, with the ranges each embedded language occupies.
    ///
    /// The default implementation returns an empty map.
    fn extract_injections<L: LanguageExt>(
        &self,
        _root: crate::Node<StrDoc<L>>,
    ) -> HashMap<String, Vec<TSRange>> {
        HashMap::new()
    }
}
292
293fn position_for_offset(input: &[u8], offset: usize) -> Point {
294 debug_assert!(offset <= input.len());
295 let (mut row, mut col) = (0, 0);
296 for c in &input[0..offset] {
297 if *c as char == '\n' {
298 row += 1;
299 col = 0;
300 } else {
301 col += 1;
302 }
303 }
304 Point::new(row, col)
305}
306
307impl<L: LanguageExt> AstGrep<StrDoc<L>> {
308 pub fn new<S: AsRef<str>>(src: S, lang: L) -> Self {
309 Root::str(src.as_ref(), lang)
310 }
311
312 pub fn source(&self) -> &str {
313 self.doc.get_source().as_str()
314 }
315
316 pub fn generate(self) -> String {
317 self.doc.src
318 }
319}
320
321pub trait ContentExt: Content {
322 fn accept_edit(&mut self, edit: &Edit<Self>) -> InputEdit;
323}
324impl ContentExt for String {
325 fn accept_edit(&mut self, edit: &Edit<Self>) -> InputEdit {
326 let start_byte = edit.position;
327 let old_end_byte = edit.position + edit.deleted_length;
328 let new_end_byte = edit.position + edit.inserted_text.len();
329 let input = unsafe { self.as_mut_vec() };
330 let start_position = position_for_offset(input, start_byte);
331 let old_end_position = position_for_offset(input, old_end_byte);
332 input.splice(start_byte..old_end_byte, edit.inserted_text.clone());
333 let new_end_position = position_for_offset(input, new_end_byte);
334 InputEdit {
335 start_byte,
336 old_end_byte,
337 new_end_byte,
338 start_position,
339 old_end_position,
340 new_end_position,
341 }
342 }
343}
344
345impl<L: LanguageExt> Root<StrDoc<L>> {
346 pub fn str(src: &str, lang: L) -> Self {
347 Self::try_new(src, lang).expect("should parse")
348 }
349 pub fn try_new(src: &str, lang: L) -> Result<Self, String> {
350 let doc = StrDoc::try_new(src, lang)?;
351 Ok(Self { doc })
352 }
353 pub fn get_text(&self) -> &str {
354 &self.doc.src
355 }
356
357 pub fn get_injections<F: Fn(&str) -> Option<L>>(&self, get_lang: F) -> Vec<Self> {
358 let root = self.root();
359 let range = self.lang().extract_injections(root);
360 let roots = range
361 .into_iter()
362 .filter_map(|(lang, ranges)| {
363 let lang = get_lang(&lang)?;
364 let source = self.doc.get_source();
365 let mut parser = Parser::new();
366 parser.set_included_ranges(&ranges).ok()?;
367 parser.set_language(&lang.get_ts_language()).ok()?;
368 let tree = parser.parse(source, None)?;
369 Some(Self {
370 doc: StrDoc {
371 src: self.doc.src.clone(),
372 lang,
373 tree,
374 },
375 })
376 })
377 .collect();
378 roots
379 }
380}
381
/// A matched node's text together with the surrounding source text, as
/// produced by `display_context`.
pub struct DisplayContext<'r> {
    /// Text of the matched node.
    pub matched: Cow<'r, str>,
    /// Source text from the start of the context up to the matched node.
    pub leading: &'r str,
    /// Source text from the end of the matched node to the end of the context.
    pub trailing: &'r str,
    /// Line number of the first line of the context, in the same numbering
    /// as the node's start position.
    pub start_line: usize,
}
392
393impl<'r, L: LanguageExt> crate::Node<'r, StrDoc<L>> {
395 #[doc(hidden)]
396 pub fn display_context(&self, before: usize, after: usize) -> DisplayContext<'r> {
397 let source = self.root.doc.get_source().as_str();
398 let bytes = source.as_bytes();
399 let start = self.inner.start_byte();
400 let end = self.inner.end_byte();
401 let (mut leading, mut trailing) = (start, end);
402 let mut lines_before = before + 1;
403 while leading > 0 {
404 if bytes[leading - 1] == b'\n' {
405 lines_before -= 1;
406 if lines_before == 0 {
407 break;
408 }
409 }
410 leading -= 1;
411 }
412 let mut lines_after = after + 1;
413 trailing = trailing.min(bytes.len());
415 while trailing < bytes.len() {
416 if bytes[trailing] == b'\n' {
417 lines_after -= 1;
418 if lines_after == 0 {
419 break;
420 }
421 }
422 trailing += 1;
423 }
424 let offset = if lines_before == 0 {
426 before
427 } else {
428 before + 1 - lines_before
430 };
431 DisplayContext {
432 matched: self.text(),
433 leading: &source[leading..start],
434 trailing: &source[end..trailing],
435 start_line: self.start_pos().line() - offset,
436 }
437 }
438
439 pub fn replace_all<M: Matcher, R: Replacer<StrDoc<L>>>(
440 &self,
441 matcher: M,
442 replacer: R,
443 ) -> Vec<Edit<String>> {
444 Visitor::new(&matcher)
446 .reentrant(false)
447 .visit(self.clone())
448 .map(|matched| matched.make_edit(&matcher, &replacer))
449 .collect()
450 }
451}
452
#[cfg(test)]
mod test {
    use super::*;
    use crate::language::Tsx;
    use tree_sitter::Point;

    /// Parses `src` as TSX through `parse_lang`.
    fn parse(src: &str) -> Result<Tree, TSParseError> {
        parse_lang(|p| p.parse(src, None), Tsx.get_ts_language())
    }

    /// A simple declaration parses into the expected tree shape and extent.
    #[test]
    fn test_tree_sitter() -> Result<(), TSParseError> {
        let tree = parse("var a = 1234")?;
        let program = tree.root_node();
        assert_eq!(program.kind(), "program");
        assert_eq!(program.start_position().column, 0);
        assert_eq!(program.end_position().column, 12);
        let expected =
            "(program (variable_declaration (variable_declarator name: (identifier) value: (number))))";
        assert_eq!(program.to_sexp(), expected);
        Ok(())
    }

    /// A `$X` placeholder in value position parses as an identifier.
    #[test]
    fn test_object_literal() -> Result<(), TSParseError> {
        let tree = parse("{a: $X}")?;
        let expected = "(program (expression_statement (object (pair key: (property_identifier) value: (identifier)))))";
        assert_eq!(tree.root_node().to_sexp(), expected);
        Ok(())
    }

    /// A `$A` inside quotes stays part of the string fragment.
    #[test]
    fn test_string() -> Result<(), TSParseError> {
        let tree = parse("'$A'")?;
        let expected = "(program (expression_statement (string (string_fragment))))";
        assert_eq!(tree.root_node().to_sexp(), expected);
        Ok(())
    }

    /// Columns are counted in bytes: the emoji occupies four of them.
    #[test]
    fn test_row_col() -> Result<(), TSParseError> {
        let tree = parse("😄")?;
        let program = tree.root_node();
        assert_eq!(program.start_position(), Point::new(0, 0));
        assert_eq!(program.end_position(), Point::new(0, 4));
        Ok(())
    }

    /// Editing marks the old tree without changing its shape; re-parsing with
    /// the edited tree produces the new structure.
    #[test]
    fn test_edit() -> Result<(), TSParseError> {
        let mut src = String::from("a + b");
        let mut tree = parse(&src)?;
        let _ = perform_edit(
            &mut tree,
            &mut src,
            &Edit {
                position: 1,
                deleted_length: 0,
                inserted_text: " * b".into(),
            },
        );
        let reparsed = parse_lang(|p| p.parse(&src, Some(&tree)), Tsx.get_ts_language())?;
        let before_reparse =
            "(program (expression_statement (binary_expression left: (identifier) right: (identifier))))";
        let after_reparse = "(program (expression_statement (binary_expression left: (binary_expression left: (identifier) right: (identifier)) right: (identifier))))";
        assert_eq!(tree.root_node().to_sexp(), before_reparse);
        assert_eq!(reparsed.root_node().to_sexp(), after_reparse);
        Ok(())
    }
}