ast_grep_core/tree_sitter/
mod.rs1pub mod traversal;
2
3use crate::node::Root;
4use crate::replacer::Replacer;
5use crate::source::{Content, Doc, Edit, SgNode};
6use crate::{node::KindId, Language, Position};
7use crate::{AstGrep, Matcher};
8use std::borrow::Cow;
9use std::collections::HashMap;
10use std::num::NonZero;
11use thiserror::Error;
12pub use traversal::{TsPre, Visitor};
13pub use tree_sitter::Language as TSLanguage;
14use tree_sitter::{InputEdit, LanguageError, Node, Parser, Point, Tree};
15pub use tree_sitter::{Point as TSPoint, Range as TSRange};
16
/// Errors that can occur while producing a tree-sitter parse tree.
#[derive(Debug, Error)]
pub enum TSParseError {
  /// `Parser::set_language` rejected the grammar (converted from [`LanguageError`]).
  #[error("incompatible `Language` is assigned to a `Parser`.")]
  Language(#[from] LanguageError),
  /// tree-sitter's parse call returned no tree. tree-sitter gives no reason for this;
  /// see the `Parser::parse` documentation for the conditions under which it yields `None`.
  #[error("general error when tree-sitter fails to parse.")]
  TreeUnavailable,
}
30
31#[inline]
32fn parse_lang(
33 parse_fn: impl Fn(&mut Parser) -> Option<Tree>,
34 ts_lang: TSLanguage,
35) -> Result<Tree, TSParseError> {
36 let mut parser = Parser::new();
37 parser.set_language(&ts_lang)?;
38 if let Some(tree) = parse_fn(&mut parser) {
39 Ok(tree)
40 } else {
41 Err(TSParseError::TreeUnavailable)
42 }
43}
44
/// A document backed by an owned `String` together with its tree-sitter parse tree.
#[derive(Clone)]
pub struct StrDoc<L: LanguageExt> {
  /// Full source text; the tree's byte offsets index into this buffer.
  pub src: String,
  /// Language whose grammar produced `tree`.
  pub lang: L,
  /// Parse tree for `src`; `Doc::do_edit` edits and re-parses it after each edit.
  pub tree: Tree,
}
51
52impl<L: LanguageExt> StrDoc<L> {
53 pub fn try_new(src: &str, lang: L) -> Result<Self, String> {
54 let src = src.to_string();
55 let ts_lang = lang.get_ts_language();
56 let tree = parse_lang(|p| p.parse(src.as_bytes(), None), ts_lang).map_err(|e| e.to_string())?;
57 Ok(Self { src, lang, tree })
58 }
59 pub fn new(src: &str, lang: L) -> Self {
60 Self::try_new(src, lang).expect("Parser tree error")
61 }
62 fn parse(&self, old_tree: Option<&Tree>) -> Result<Tree, TSParseError> {
63 let source = self.get_source();
64 let lang = self.get_lang().get_ts_language();
65 parse_lang(|p| p.parse(source.as_bytes(), old_tree), lang)
66 }
67}
68
69impl<L: LanguageExt> Doc for StrDoc<L> {
70 type Source = String;
71 type Lang = L;
72 type Node<'r> = Node<'r>;
73 fn get_lang(&self) -> &Self::Lang {
74 &self.lang
75 }
76 fn get_source(&self) -> &Self::Source {
77 &self.src
78 }
79 fn do_edit(&mut self, edit: &Edit<Self::Source>) -> Result<(), String> {
80 let source = &mut self.src;
81 perform_edit(&mut self.tree, source, edit);
82 self.tree = self.parse(Some(&self.tree)).map_err(|e| e.to_string())?;
83 Ok(())
84 }
85 fn root_node(&self) -> Node<'_> {
86 self.tree.root_node()
87 }
88 fn get_node_text<'a>(&'a self, node: &Self::Node<'a>) -> Cow<'a, str> {
89 Cow::Borrowed(
90 node
91 .utf8_text(self.src.as_bytes())
92 .expect("invalid source text encoding"),
93 )
94 }
95}
96
/// Iterator over a node's direct children, driven by a tree cursor.
struct NodeWalker<'tree> {
  /// Cursor positioned on the next child to yield.
  cursor: tree_sitter::TreeCursor<'tree>,
  /// Number of children that have not been yielded yet.
  count: usize,
}
101
102impl<'tree> Iterator for NodeWalker<'tree> {
103 type Item = Node<'tree>;
104 fn next(&mut self) -> Option<Self::Item> {
105 if self.count == 0 {
106 return None;
107 }
108 let ret = Some(self.cursor.node());
109 self.cursor.goto_next_sibling();
110 self.count -= 1;
111 ret
112 }
113}
114
/// `len` is the number of children not yet yielded, tracked in `count`.
impl ExactSizeIterator for NodeWalker<'_> {
  fn len(&self) -> usize {
    self.count
  }
}
120
121impl<'r> SgNode<'r> for Node<'r> {
122 fn parent(&self) -> Option<Self> {
123 Node::parent(self)
124 }
125 fn ancestors(&self, root: Self) -> impl Iterator<Item = Self> {
126 let mut ancestor = Some(root);
127 let self_id = self.id();
128 std::iter::from_fn(move || {
129 let inner = ancestor.take()?;
130 if inner.id() == self_id {
131 return None;
132 }
133 ancestor = inner.child_with_descendant(*self);
134 Some(inner)
135 })
136 .collect::<Vec<_>>()
138 .into_iter()
139 .rev()
140 }
141 fn dfs(&self) -> impl Iterator<Item = Self> {
142 TsPre::new(self)
143 }
144 fn child(&self, nth: usize) -> Option<Self> {
145 Node::child(self, nth as u32)
146 }
147 fn children(&self) -> impl ExactSizeIterator<Item = Self> {
148 let mut cursor = self.walk();
149 cursor.goto_first_child();
150 NodeWalker {
151 cursor,
152 count: self.child_count(),
153 }
154 }
155 fn child_by_field_id(&self, field_id: u16) -> Option<Self> {
156 Node::child_by_field_id(self, field_id)
157 }
158 fn next(&self) -> Option<Self> {
159 self.next_sibling()
160 }
161 fn prev(&self) -> Option<Self> {
162 self.prev_sibling()
163 }
164 fn next_all(&self) -> impl Iterator<Item = Self> {
165 let node = self.parent().unwrap_or(*self);
167 let mut cursor = node.walk();
168 cursor.goto_first_child_for_byte(self.start_byte());
169 std::iter::from_fn(move || {
170 if cursor.goto_next_sibling() {
171 Some(cursor.node())
172 } else {
173 None
174 }
175 })
176 }
177 fn prev_all(&self) -> impl Iterator<Item = Self> {
178 let node = self.parent().unwrap_or(*self);
180 let mut cursor = node.walk();
181 cursor.goto_first_child_for_byte(self.start_byte());
182 std::iter::from_fn(move || {
183 if cursor.goto_previous_sibling() {
184 Some(cursor.node())
185 } else {
186 None
187 }
188 })
189 }
190 fn is_named(&self) -> bool {
191 Node::is_named(self)
192 }
193 fn is_named_leaf(&self) -> bool {
196 self.named_child_count() == 0
197 }
198 fn is_leaf(&self) -> bool {
199 self.child_count() == 0
200 }
201 fn kind(&self) -> Cow<'_, str> {
202 Cow::Borrowed(Node::kind(self))
203 }
204 fn kind_id(&self) -> KindId {
205 Node::kind_id(self)
206 }
207 fn node_id(&self) -> usize {
208 self.id()
209 }
210 fn range(&self) -> std::ops::Range<usize> {
211 self.start_byte()..self.end_byte()
212 }
213 fn start_pos(&self) -> Position {
214 let pos = self.start_position();
215 let byte = self.start_byte();
216 Position::new(pos.row, pos.column, byte)
217 }
218 fn end_pos(&self) -> Position {
219 let pos = self.end_position();
220 let byte = self.end_byte();
221 Position::new(pos.row, pos.column, byte)
222 }
223 fn is_missing(&self) -> bool {
225 Node::is_missing(self)
226 }
227 fn is_error(&self) -> bool {
228 Node::is_error(self)
229 }
230
231 fn field(&self, name: &str) -> Option<Self> {
232 self.child_by_field_name(name)
233 }
234 fn field_children(&self, field_id: Option<u16>) -> impl Iterator<Item = Self> {
235 let field_id = field_id.and_then(NonZero::new);
236 let mut cursor = self.walk();
237 cursor.goto_first_child();
238 let mut done = field_id.is_none();
240
241 std::iter::from_fn(move || {
242 if done {
243 return None;
244 }
245 while cursor.field_id() != field_id {
246 if !cursor.goto_next_sibling() {
247 return None;
248 }
249 }
250 let ret = cursor.node();
251 if !cursor.goto_next_sibling() {
252 done = true;
253 }
254 Some(ret)
255 })
256 }
257}
258
259pub fn perform_edit<S: ContentExt>(tree: &mut Tree, input: &mut S, edit: &Edit<S>) -> InputEdit {
260 let edit = input.accept_edit(edit);
261 tree.edit(&edit);
262 edit
263}
264
/// tree-sitter specific extension of [`Language`].
pub trait LanguageExt: Language {
  /// Parses `source` into an [`AstGrep`] tree using a clone of this language.
  fn ast_grep<S: AsRef<str>>(&self, source: S) -> AstGrep<StrDoc<Self>> {
    AstGrep::new(source, self.clone())
  }

  /// The tree-sitter grammar used to parse documents in this language.
  fn get_ts_language(&self) -> TSLanguage;

  /// Names of languages that may be embedded in this one; `None` by default.
  // NOTE(review): presumably these correspond to the keys produced by
  // `extract_injections` — confirm with implementors.
  fn injectable_languages(&self) -> Option<&'static [&'static str]> {
    None
  }

  /// Collects embedded-language regions under `_root`, keyed by language name.
  ///
  /// `Root::get_injections` resolves each key through its `get_lang` callback and
  /// parses the associated ranges with that language. The default finds nothing.
  fn extract_injections<L: LanguageExt>(
    &self,
    _root: crate::Node<StrDoc<L>>,
  ) -> HashMap<String, Vec<TSRange>> {
    HashMap::new()
  }
}
291
292fn position_for_offset(input: &[u8], offset: usize) -> Point {
293 debug_assert!(offset <= input.len());
294 let (mut row, mut col) = (0, 0);
295 for c in &input[0..offset] {
296 if *c as char == '\n' {
297 row += 1;
298 col = 0;
299 } else {
300 col += 1;
301 }
302 }
303 Point::new(row, col)
304}
305
306impl<L: LanguageExt> AstGrep<StrDoc<L>> {
307 pub fn new<S: AsRef<str>>(src: S, lang: L) -> Self {
308 Root::str(src.as_ref(), lang)
309 }
310
311 pub fn source(&self) -> &str {
312 self.doc.get_source().as_str()
313 }
314
315 pub fn generate(self) -> String {
316 self.doc.src
317 }
318}
319
/// Content that can apply an [`Edit`] to itself in place.
pub trait ContentExt: Content {
  /// Applies `edit` to `self` and returns the matching [`InputEdit`] to pass to `Tree::edit`.
  fn accept_edit(&mut self, edit: &Edit<Self>) -> InputEdit;
}
323impl ContentExt for String {
324 fn accept_edit(&mut self, edit: &Edit<Self>) -> InputEdit {
325 let start_byte = edit.position;
326 let old_end_byte = edit.position + edit.deleted_length;
327 let new_end_byte = edit.position + edit.inserted_text.len();
328 let input = unsafe { self.as_mut_vec() };
329 let start_position = position_for_offset(input, start_byte);
330 let old_end_position = position_for_offset(input, old_end_byte);
331 input.splice(start_byte..old_end_byte, edit.inserted_text.clone());
332 let new_end_position = position_for_offset(input, new_end_byte);
333 InputEdit {
334 start_byte,
335 old_end_byte,
336 new_end_byte,
337 start_position,
338 old_end_position,
339 new_end_position,
340 }
341 }
342}
343
344impl<L: LanguageExt> Root<StrDoc<L>> {
345 pub fn str(src: &str, lang: L) -> Self {
346 Self::try_new(src, lang).expect("should parse")
347 }
348 pub fn try_new(src: &str, lang: L) -> Result<Self, String> {
349 let doc = StrDoc::try_new(src, lang)?;
350 Ok(Self { doc })
351 }
352 pub fn get_text(&self) -> &str {
353 &self.doc.src
354 }
355
356 pub fn get_injections<F: Fn(&str) -> Option<L>>(&self, get_lang: F) -> Vec<Self> {
357 let root = self.root();
358 let range = self.lang().extract_injections(root);
359 let roots = range
360 .into_iter()
361 .filter_map(|(lang, ranges)| {
362 let lang = get_lang(&lang)?;
363 let source = self.doc.get_source();
364 let mut parser = Parser::new();
365 parser.set_included_ranges(&ranges).ok()?;
366 parser.set_language(&lang.get_ts_language()).ok()?;
367 let tree = parser.parse(source, None)?;
368 Some(Self {
369 doc: StrDoc {
370 src: self.doc.src.clone(),
371 lang,
372 tree,
373 },
374 })
375 })
376 .collect();
377 roots
378 }
379}
380
/// A matched node's text plus surrounding source lines, as built by `display_context`.
pub struct DisplayContext<'r> {
  /// Text of the matched node.
  pub matched: Cow<'r, str>,
  /// Source from the start of the first context line up to the start of the match.
  pub leading: &'r str,
  /// Source from the end of the match up to (not including) the newline that ends
  /// the last context line, or to the end of input.
  pub trailing: &'r str,
  /// Line number of the first line covered by `leading`.
  pub start_line: usize,
}
391
392impl<'r, L: LanguageExt> crate::Node<'r, StrDoc<L>> {
394 #[doc(hidden)]
395 pub fn display_context(&self, before: usize, after: usize) -> DisplayContext<'r> {
396 let source = self.root.doc.get_source().as_str();
397 let bytes = source.as_bytes();
398 let start = self.inner.start_byte();
399 let end = self.inner.end_byte();
400 let (mut leading, mut trailing) = (start, end);
401 let mut lines_before = before + 1;
402 while leading > 0 {
403 if bytes[leading - 1] == b'\n' {
404 lines_before -= 1;
405 if lines_before == 0 {
406 break;
407 }
408 }
409 leading -= 1;
410 }
411 let mut lines_after = after + 1;
412 trailing = trailing.min(bytes.len());
414 while trailing < bytes.len() {
415 if bytes[trailing] == b'\n' {
416 lines_after -= 1;
417 if lines_after == 0 {
418 break;
419 }
420 }
421 trailing += 1;
422 }
423 let offset = if lines_before == 0 {
425 before
426 } else {
427 before + 1 - lines_before
429 };
430 DisplayContext {
431 matched: self.text(),
432 leading: &source[leading..start],
433 trailing: &source[end..trailing],
434 start_line: self.start_pos().line() - offset,
435 }
436 }
437
438 pub fn replace_all<M: Matcher, R: Replacer<StrDoc<L>>>(
439 &self,
440 matcher: M,
441 replacer: R,
442 ) -> Vec<Edit<String>> {
443 Visitor::new(&matcher)
445 .reentrant(false)
446 .visit(self.clone())
447 .map(|matched| matched.make_edit(&matcher, &replacer))
448 .collect()
449 }
450}
451
#[cfg(test)]
mod test {
  use super::*;
  use crate::language::Tsx;
  use tree_sitter::Point;

  /// Parses `src` as TSX from scratch.
  fn parse(src: &str) -> Result<Tree, TSParseError> {
    let tsx = Tsx.get_ts_language();
    parse_lang(|parser| parser.parse(src, None), tsx)
  }

  #[test]
  fn test_tree_sitter() -> Result<(), TSParseError> {
    let tree = parse("var a = 1234")?;
    let root = tree.root_node();
    assert_eq!(root.kind(), "program");
    assert_eq!(root.start_position().column, 0);
    assert_eq!(root.end_position().column, 12);
    let expected =
      "(program (variable_declaration (variable_declarator name: (identifier) value: (number))))";
    assert_eq!(root.to_sexp(), expected);
    Ok(())
  }

  #[test]
  fn test_object_literal() -> Result<(), TSParseError> {
    let sexp = parse("{a: $X}")?.root_node().to_sexp();
    assert_eq!(
      sexp,
      "(program (expression_statement (object (pair key: (property_identifier) value: (identifier)))))"
    );
    Ok(())
  }

  #[test]
  fn test_string() -> Result<(), TSParseError> {
    let sexp = parse("'$A'")?.root_node().to_sexp();
    assert_eq!(
      sexp,
      "(program (expression_statement (string (string_fragment))))"
    );
    Ok(())
  }

  /// Columns are byte counts: a 4-byte emoji ends at column 4.
  #[test]
  fn test_row_col() -> Result<(), TSParseError> {
    let tree = parse("😄")?;
    let root = tree.root_node();
    assert_eq!(root.start_position(), Point::new(0, 0));
    assert_eq!(root.end_position(), Point::new(0, 4));
    Ok(())
  }

  /// An edited tree keeps its old shape until re-parsed with it as the old tree.
  #[test]
  fn test_edit() -> Result<(), TSParseError> {
    let mut src = String::from("a + b");
    let mut tree = parse(&src)?;
    let _ = perform_edit(
      &mut tree,
      &mut src,
      &Edit {
        position: 1,
        deleted_length: 0,
        inserted_text: " * b".into(),
      },
    );
    let reparsed = parse_lang(|p| p.parse(&src, Some(&tree)), Tsx.get_ts_language())?;
    assert_eq!(
      tree.root_node().to_sexp(),
      "(program (expression_statement (binary_expression left: (identifier) right: (identifier))))"
    );
    assert_eq!(reparsed.root_node().to_sexp(), "(program (expression_statement (binary_expression left: (binary_expression left: (identifier) right: (identifier)) right: (identifier))))");
    Ok(())
  }
}