ast_grep_core/tree_sitter/
mod.rs1pub mod traversal;
2
3use crate::node::Root;
4use crate::replacer::Replacer;
5use crate::source::{Content, Doc, Edit, SgNode};
6use crate::{node::KindId, Language, Position};
7use crate::{AstGrep, Matcher};
8use std::borrow::Cow;
9use std::num::NonZero;
10use thiserror::Error;
11pub use traversal::{TsPre, Visitor};
12pub use tree_sitter::Language as TSLanguage;
13use tree_sitter::{InputEdit, LanguageError, Node, Parser, Point, Tree};
14pub use tree_sitter::{Point as TSPoint, Range as TSRange};
15
16#[derive(Debug, Error)]
18pub enum TSParseError {
19 #[error("incompatible `Language` is assigned to a `Parser`.")]
20 Language(#[from] LanguageError),
21 #[error("general error when tree-sitter fails to parse.")]
27 TreeUnavailable,
28}
29
30#[inline]
31fn parse_lang(
32 parse_fn: impl Fn(&mut Parser) -> Option<Tree>,
33 ts_lang: TSLanguage,
34) -> Result<Tree, TSParseError> {
35 let mut parser = Parser::new();
36 parser.set_language(&ts_lang)?;
37 if let Some(tree) = parse_fn(&mut parser) {
38 Ok(tree)
39 } else {
40 Err(TSParseError::TreeUnavailable)
41 }
42}
43
44#[derive(Clone)]
45pub struct StrDoc<L: LanguageExt> {
46 pub src: String,
47 pub lang: L,
48 pub tree: Tree,
49}
50
51impl<L: LanguageExt> StrDoc<L> {
52 pub fn try_new(src: &str, lang: L) -> Result<Self, String> {
53 let src = src.to_string();
54 let ts_lang = lang.get_ts_language();
55 let tree = parse_lang(|p| p.parse(src.as_bytes(), None), ts_lang).map_err(|e| e.to_string())?;
56 Ok(Self { src, lang, tree })
57 }
58 pub fn new(src: &str, lang: L) -> Self {
59 Self::try_new(src, lang).expect("Parser tree error")
60 }
61 fn parse(&self, old_tree: Option<&Tree>) -> Result<Tree, TSParseError> {
62 let source = self.get_source();
63 let lang = self.get_lang().get_ts_language();
64 parse_lang(|p| p.parse(source.as_bytes(), old_tree), lang)
65 }
66}
67
68impl<L: LanguageExt> Doc for StrDoc<L> {
69 type Source = String;
70 type Lang = L;
71 type Node<'r> = Node<'r>;
72 fn get_lang(&self) -> &Self::Lang {
73 &self.lang
74 }
75 fn get_source(&self) -> &Self::Source {
76 &self.src
77 }
78 fn do_edit(&mut self, edit: &Edit<Self::Source>) -> Result<(), String> {
79 let source = &mut self.src;
80 perform_edit(&mut self.tree, source, edit);
81 self.tree = self.parse(Some(&self.tree)).map_err(|e| e.to_string())?;
82 Ok(())
83 }
84 fn root_node(&self) -> Node<'_> {
85 self.tree.root_node()
86 }
87 fn get_node_text<'a>(&'a self, node: &Self::Node<'a>) -> Cow<'a, str> {
88 Cow::Borrowed(
89 node
90 .utf8_text(self.src.as_bytes())
91 .expect("invalid source text encoding"),
92 )
93 }
94}
95
96struct NodeWalker<'tree> {
97 cursor: tree_sitter::TreeCursor<'tree>,
98 count: usize,
99}
100
101impl<'tree> Iterator for NodeWalker<'tree> {
102 type Item = Node<'tree>;
103 fn next(&mut self) -> Option<Self::Item> {
104 if self.count == 0 {
105 return None;
106 }
107 let ret = Some(self.cursor.node());
108 self.cursor.goto_next_sibling();
109 self.count -= 1;
110 ret
111 }
112}
113
114impl ExactSizeIterator for NodeWalker<'_> {
115 fn len(&self) -> usize {
116 self.count
117 }
118}
119
120impl<'r> SgNode<'r> for Node<'r> {
121 fn parent(&self) -> Option<Self> {
122 Node::parent(self)
123 }
124 fn ancestors(&self, root: Self) -> impl Iterator<Item = Self> {
125 let mut ancestor = Some(root);
126 let self_id = self.id();
127 std::iter::from_fn(move || {
128 let inner = ancestor.take()?;
129 if inner.id() == self_id {
130 return None;
131 }
132 ancestor = inner.child_with_descendant(*self);
133 Some(inner)
134 })
135 .collect::<Vec<_>>()
137 .into_iter()
138 .rev()
139 }
140 fn dfs(&self) -> impl Iterator<Item = Self> {
141 TsPre::new(self)
142 }
143 fn child(&self, nth: usize) -> Option<Self> {
144 Node::child(self, nth as u32)
145 }
146 fn children(&self) -> impl ExactSizeIterator<Item = Self> {
147 let mut cursor = self.walk();
148 cursor.goto_first_child();
149 NodeWalker {
150 cursor,
151 count: self.child_count(),
152 }
153 }
154 fn child_by_field_id(&self, field_id: u16) -> Option<Self> {
155 Node::child_by_field_id(self, field_id)
156 }
157 fn next(&self) -> Option<Self> {
158 self.next_sibling()
159 }
160 fn prev(&self) -> Option<Self> {
161 self.prev_sibling()
162 }
163 fn next_all(&self) -> impl Iterator<Item = Self> {
164 let node = self.parent().unwrap_or(*self);
166 let mut cursor = node.walk();
167 cursor.goto_first_child_for_byte(self.start_byte());
168 std::iter::from_fn(move || {
169 if cursor.goto_next_sibling() {
170 Some(cursor.node())
171 } else {
172 None
173 }
174 })
175 }
176 fn prev_all(&self) -> impl Iterator<Item = Self> {
177 let node = self.parent().unwrap_or(*self);
179 let mut cursor = node.walk();
180 cursor.goto_first_child_for_byte(self.start_byte());
181 std::iter::from_fn(move || {
182 if cursor.goto_previous_sibling() {
183 Some(cursor.node())
184 } else {
185 None
186 }
187 })
188 }
189 fn is_named(&self) -> bool {
190 Node::is_named(self)
191 }
192 fn is_named_leaf(&self) -> bool {
195 self.named_child_count() == 0
196 }
197 fn is_leaf(&self) -> bool {
198 self.child_count() == 0
199 }
200 fn kind(&self) -> Cow<'_, str> {
201 Cow::Borrowed(Node::kind(self))
202 }
203 fn kind_id(&self) -> KindId {
204 Node::kind_id(self)
205 }
206 fn node_id(&self) -> usize {
207 self.id()
208 }
209 fn range(&self) -> std::ops::Range<usize> {
210 self.start_byte()..self.end_byte()
211 }
212 fn start_pos(&self) -> Position {
213 let pos = self.start_position();
214 let byte = self.start_byte();
215 Position::new(pos.row, pos.column, byte)
216 }
217 fn end_pos(&self) -> Position {
218 let pos = self.end_position();
219 let byte = self.end_byte();
220 Position::new(pos.row, pos.column, byte)
221 }
222 fn is_missing(&self) -> bool {
224 Node::is_missing(self)
225 }
226 fn is_error(&self) -> bool {
227 Node::is_error(self)
228 }
229
230 fn field(&self, name: &str) -> Option<Self> {
231 self.child_by_field_name(name)
232 }
233 fn field_children(&self, field_id: Option<u16>) -> impl Iterator<Item = Self> {
234 let field_id = field_id.and_then(NonZero::new);
235 let mut cursor = self.walk();
236 let has_children = cursor.goto_first_child();
237 let mut done = field_id.is_none() || !has_children;
239
240 std::iter::from_fn(move || {
241 if done {
242 return None;
243 }
244 while cursor.field_id() != field_id {
245 if !cursor.goto_next_sibling() {
246 return None;
247 }
248 }
249 let ret = cursor.node();
250 if !cursor.goto_next_sibling() {
251 done = true;
252 }
253 Some(ret)
254 })
255 }
256}
257
258pub fn perform_edit<S: ContentExt>(tree: &mut Tree, input: &mut S, edit: &Edit<S>) -> InputEdit {
259 let edit = input.accept_edit(edit);
260 tree.edit(&edit);
261 edit
262}
263
264pub trait LanguageExt: Language {
266 fn ast_grep<S: AsRef<str>>(&self, source: S) -> AstGrep<StrDoc<Self>> {
268 AstGrep::new(source, self.clone())
269 }
270
271 fn get_ts_language(&self) -> TSLanguage;
273
274 fn injectable_languages(&self) -> Option<&'static [&'static str]> {
275 None
276 }
277
278 fn extract_injections<L: LanguageExt>(
283 &self,
284 _root: crate::Node<StrDoc<L>>,
285 ) -> Vec<(String, Vec<TSRange>)> {
286 Vec::new()
287 }
288}
289
290fn position_for_offset(input: &[u8], offset: usize) -> Point {
291 debug_assert!(offset <= input.len());
292 let (mut row, mut col) = (0, 0);
293 for c in &input[0..offset] {
294 if *c as char == '\n' {
295 row += 1;
296 col = 0;
297 } else {
298 col += 1;
299 }
300 }
301 Point::new(row, col)
302}
303
304impl<L: LanguageExt> AstGrep<StrDoc<L>> {
305 pub fn new<S: AsRef<str>>(src: S, lang: L) -> Self {
306 Root::str(src.as_ref(), lang)
307 }
308
309 pub fn source(&self) -> &str {
310 self.doc.get_source().as_str()
311 }
312
313 pub fn generate(self) -> String {
314 self.doc.src
315 }
316}
317
318pub trait ContentExt: Content {
319 fn accept_edit(&mut self, edit: &Edit<Self>) -> InputEdit;
320}
321impl ContentExt for String {
322 fn accept_edit(&mut self, edit: &Edit<Self>) -> InputEdit {
323 let start_byte = edit.position;
324 let old_end_byte = edit.position + edit.deleted_length;
325 let new_end_byte = edit.position + edit.inserted_text.len();
326 let input = unsafe { self.as_mut_vec() };
327 let start_position = position_for_offset(input, start_byte);
328 let old_end_position = position_for_offset(input, old_end_byte);
329 input.splice(start_byte..old_end_byte, edit.inserted_text.clone());
330 let new_end_position = position_for_offset(input, new_end_byte);
331 InputEdit {
332 start_byte,
333 old_end_byte,
334 new_end_byte,
335 start_position,
336 old_end_position,
337 new_end_position,
338 }
339 }
340}
341
342impl<L: LanguageExt> Root<StrDoc<L>> {
343 pub fn str(src: &str, lang: L) -> Self {
344 Self::try_new(src, lang).expect("should parse")
345 }
346 pub fn try_new(src: &str, lang: L) -> Result<Self, String> {
347 let doc = StrDoc::try_new(src, lang)?;
348 Ok(Self { doc })
349 }
350 pub fn get_text(&self) -> &str {
351 &self.doc.src
352 }
353
354 pub fn get_injections<F: Fn(&str) -> Option<L>>(&self, get_lang: F) -> Vec<Self> {
355 let root = self.root();
356 let range = self.lang().extract_injections(root);
357 let roots = range
358 .into_iter()
359 .filter_map(|(lang_str, ranges)| {
360 let lang = get_lang(&lang_str)?;
361 let source = self.doc.get_source();
362 let mut parser = Parser::new();
363 parser.set_included_ranges(&ranges).ok()?;
364 parser.set_language(&lang.get_ts_language()).ok()?;
365 let tree = parser.parse(source, None)?;
366 Some(Self {
367 doc: StrDoc {
368 src: self.doc.src.clone(),
369 lang,
370 tree,
371 },
372 })
373 })
374 .collect();
375 roots
376 }
377}
378
379pub struct DisplayContext<'r> {
380 pub matched: Cow<'r, str>,
382 pub leading: &'r str,
384 pub trailing: &'r str,
386 pub start_line: usize,
388}
389
390impl<'r, L: LanguageExt> crate::Node<'r, StrDoc<L>> {
392 #[doc(hidden)]
393 pub fn display_context(&self, before: usize, after: usize) -> DisplayContext<'r> {
394 let source = self.root.doc.get_source().as_str();
395 let bytes = source.as_bytes();
396 let start = self.inner.start_byte();
397 let end = self.inner.end_byte();
398 let (mut leading, mut trailing) = (start, end);
399 let mut lines_before = before + 1;
400 while leading > 0 {
401 if bytes[leading - 1] == b'\n' {
402 lines_before -= 1;
403 if lines_before == 0 {
404 break;
405 }
406 }
407 leading -= 1;
408 }
409 let mut lines_after = after + 1;
410 trailing = trailing.min(bytes.len());
412 while trailing < bytes.len() {
413 if bytes[trailing] == b'\n' {
414 lines_after -= 1;
415 if lines_after == 0 {
416 break;
417 }
418 }
419 trailing += 1;
420 }
421 let offset = if lines_before == 0 {
423 before
424 } else {
425 before + 1 - lines_before
427 };
428 DisplayContext {
429 matched: self.text(),
430 leading: &source[leading..start],
431 trailing: &source[end..trailing],
432 start_line: self.start_pos().line() - offset,
433 }
434 }
435
436 pub fn replace_all<M: Matcher, R: Replacer<StrDoc<L>>>(
437 &self,
438 matcher: M,
439 replacer: R,
440 ) -> Vec<Edit<String>> {
441 Visitor::new(&matcher)
443 .reentrant(false)
444 .visit(self.clone())
445 .map(|matched| matched.make_edit(&matcher, &replacer))
446 .collect()
447 }
448}
449
#[cfg(test)]
mod test {
  use super::*;
  use crate::language::Tsx;
  use tree_sitter::Point;

  /// Parses `src` as TSX from scratch.
  fn parse(src: &str) -> Result<Tree, TSParseError> {
    parse_lang(|parser| parser.parse(src, None), Tsx.get_ts_language())
  }

  /// A simple declaration parses into the expected tree and extents.
  #[test]
  fn test_tree_sitter() -> Result<(), TSParseError> {
    let tree = parse("var a = 1234")?;
    let program = tree.root_node();
    assert_eq!(program.kind(), "program");
    assert_eq!(program.start_position().column, 0);
    assert_eq!(program.end_position().column, 12);
    assert_eq!(
      program.to_sexp(),
      "(program (variable_declaration (variable_declarator name: (identifier) value: (number))))"
    );
    Ok(())
  }

  /// A `$X` metavariable in value position parses as an identifier.
  #[test]
  fn test_object_literal() -> Result<(), TSParseError> {
    let tree = parse("{a: $X}")?;
    let program = tree.root_node();
    assert_eq!(
      program.to_sexp(),
      "(program (expression_statement (object (pair key: (property_identifier) value: (identifier)))))"
    );
    Ok(())
  }

  /// A metavariable inside quotes stays part of the string fragment.
  #[test]
  fn test_string() -> Result<(), TSParseError> {
    let tree = parse("'$A'")?;
    let program = tree.root_node();
    assert_eq!(
      program.to_sexp(),
      "(program (expression_statement (string (string_fragment))))"
    );
    Ok(())
  }

  /// Columns are counted in bytes: the emoji occupies four of them.
  #[test]
  fn test_row_col() -> Result<(), TSParseError> {
    let tree = parse("😄")?;
    let program = tree.root_node();
    assert_eq!(program.start_position(), Point::new(0, 0));
    assert_eq!(program.end_position(), Point::new(0, 4));
    Ok(())
  }

  /// Editing the tree and re-parsing with it yields the updated structure,
  /// while the edited old tree keeps its original shape.
  #[test]
  fn test_edit() -> Result<(), TSParseError> {
    let mut src = "a + b".to_string();
    let mut old_tree = parse(&src)?;
    let insertion = Edit {
      position: 1,
      deleted_length: 0,
      inserted_text: " * b".into(),
    };
    let _ = perform_edit(&mut old_tree, &mut src, &insertion);
    let new_tree = parse_lang(
      |parser| parser.parse(&src, Some(&old_tree)),
      Tsx.get_ts_language(),
    )?;
    assert_eq!(
      old_tree.root_node().to_sexp(),
      "(program (expression_statement (binary_expression left: (identifier) right: (identifier))))"
    );
    assert_eq!(
      new_tree.root_node().to_sexp(),
      "(program (expression_statement (binary_expression left: (binary_expression left: (identifier) right: (identifier)) right: (identifier))))"
    );
    Ok(())
  }
}