ast_grep_core/tree_sitter/
mod.rs1pub mod traversal;
2
3use crate::node::Root;
4use crate::replacer::Replacer;
5use crate::source::{Content, Doc, Edit, SgNode};
6use crate::{node::KindId, Language, Position};
7use crate::{AstGrep, Matcher};
8use std::borrow::Cow;
9use std::collections::HashMap;
10use thiserror::Error;
11pub use traversal::{TsPre, Visitor};
12pub use tree_sitter::Language as TSLanguage;
13use tree_sitter::{InputEdit, LanguageError, Node, Parser, ParserError, Point, Tree};
14pub use tree_sitter::{Point as TSPoint, Range as TSRange};
15
16#[derive(Debug, Error)]
18pub enum TSParseError {
19 #[error("web-tree-sitter parser is not available")]
20 Parse(#[from] ParserError),
21 #[error("incompatible `Language` is assigned to a `Parser`.")]
22 Language(#[from] LanguageError),
23 #[error("general error when tree-sitter fails to parse.")]
29 TreeUnavailable,
30}
31
32#[inline]
33fn parse_lang(
34 parse_fn: impl Fn(&mut Parser) -> Result<Option<Tree>, ParserError>,
35 ts_lang: TSLanguage,
36) -> Result<Tree, TSParseError> {
37 let mut parser = Parser::new()?;
38 parser.set_language(&ts_lang)?;
39 if let Some(tree) = parse_fn(&mut parser)? {
40 Ok(tree)
41 } else {
42 Err(TSParseError::TreeUnavailable)
43 }
44}
45
46#[derive(Clone)]
47pub struct StrDoc<L: LanguageExt> {
48 pub src: String,
49 pub lang: L,
50 pub tree: Tree,
51}
52
53impl<L: LanguageExt> StrDoc<L> {
54 pub fn try_new(src: &str, lang: L) -> Result<Self, String> {
55 let src = src.to_string();
56 let ts_lang = lang.get_ts_language();
57 let tree = parse_lang(|p| p.parse(src.as_bytes(), None), ts_lang).map_err(|e| e.to_string())?;
58 Ok(Self { src, lang, tree })
59 }
60 pub fn new(src: &str, lang: L) -> Self {
61 Self::try_new(src, lang).expect("Parser tree error")
62 }
63 fn parse(&self, old_tree: Option<&Tree>) -> Result<Tree, TSParseError> {
64 let source = self.get_source();
65 let lang = self.get_lang().get_ts_language();
66 parse_lang(|p| p.parse(source.as_bytes(), old_tree), lang)
67 }
68}
69
70impl<L: LanguageExt> Doc for StrDoc<L> {
71 type Source = String;
72 type Lang = L;
73 type Node<'r> = Node<'r>;
74 fn get_lang(&self) -> &Self::Lang {
75 &self.lang
76 }
77 fn get_source(&self) -> &Self::Source {
78 &self.src
79 }
80 fn do_edit(&mut self, edit: &Edit<Self::Source>) -> Result<(), String> {
81 let source = &mut self.src;
82 perform_edit(&mut self.tree, source, edit);
83 self.tree = self.parse(Some(&self.tree)).map_err(|e| e.to_string())?;
84 Ok(())
85 }
86 fn root_node(&self) -> Node<'_> {
87 self.tree.root_node()
88 }
89 fn get_node_text<'a>(&'a self, node: &Self::Node<'a>) -> Cow<'a, str> {
90 node
91 .utf8_text(self.src.as_bytes())
92 .expect("invalid source text encoding")
93 }
94}
95
96struct NodeWalker<'tree> {
97 cursor: tree_sitter::TreeCursor<'tree>,
98 count: usize,
99}
100
101impl<'tree> Iterator for NodeWalker<'tree> {
102 type Item = Node<'tree>;
103 fn next(&mut self) -> Option<Self::Item> {
104 if self.count == 0 {
105 return None;
106 }
107 let ret = Some(self.cursor.node());
108 self.cursor.goto_next_sibling();
109 self.count -= 1;
110 ret
111 }
112}
113
114impl ExactSizeIterator for NodeWalker<'_> {
115 fn len(&self) -> usize {
116 self.count
117 }
118}
119
120impl<'r> SgNode<'r> for Node<'r> {
121 fn parent(&self) -> Option<Self> {
122 Node::parent(self)
123 }
124 fn ancestors(&self, root: Self) -> impl Iterator<Item = Self> {
125 let mut ancestor = Some(root);
126 let self_id = self.id();
127 std::iter::from_fn(move || {
128 let inner = ancestor.take()?;
129 if inner.id() == self_id {
130 return None;
131 }
132 ancestor = inner.child_with_descendant(self.clone());
133 Some(inner)
134 })
135 .collect::<Vec<_>>()
137 .into_iter()
138 .rev()
139 }
140 fn dfs(&self) -> impl Iterator<Item = Self> {
141 TsPre::new(self)
142 }
143 fn child(&self, nth: usize) -> Option<Self> {
144 Node::child(self, nth as u32)
146 }
147 fn children(&self) -> impl ExactSizeIterator<Item = Self> {
148 let mut cursor = self.walk();
149 cursor.goto_first_child();
150 NodeWalker {
151 cursor,
152 count: self.child_count() as usize,
153 }
154 }
155 fn child_by_field_id(&self, field_id: u16) -> Option<Self> {
156 Node::child_by_field_id(self, field_id)
157 }
158 fn next(&self) -> Option<Self> {
159 self.next_sibling()
160 }
161 fn prev(&self) -> Option<Self> {
162 self.prev_sibling()
163 }
164 fn next_all(&self) -> impl Iterator<Item = Self> {
165 let node = self.parent().unwrap_or_else(|| self.clone());
167 let mut cursor = node.walk();
168 cursor.goto_first_child_for_byte(self.start_byte());
169 std::iter::from_fn(move || {
170 if cursor.goto_next_sibling() {
171 Some(cursor.node())
172 } else {
173 None
174 }
175 })
176 }
177 fn prev_all(&self) -> impl Iterator<Item = Self> {
178 let node = self.parent().unwrap_or_else(|| self.clone());
180 let mut cursor = node.walk();
181 cursor.goto_first_child_for_byte(self.start_byte());
182 std::iter::from_fn(move || {
183 if cursor.goto_previous_sibling() {
184 Some(cursor.node())
185 } else {
186 None
187 }
188 })
189 }
190 fn is_named(&self) -> bool {
191 Node::is_named(self)
192 }
193 fn is_named_leaf(&self) -> bool {
196 self.named_child_count() == 0
197 }
198 fn is_leaf(&self) -> bool {
199 self.child_count() == 0
200 }
201 fn kind(&self) -> Cow<str> {
202 Node::kind(self)
203 }
204 fn kind_id(&self) -> KindId {
205 Node::kind_id(self)
206 }
207 fn node_id(&self) -> usize {
208 self.id()
209 }
210 fn range(&self) -> std::ops::Range<usize> {
211 (self.start_byte() as usize)..(self.end_byte() as usize)
212 }
213 fn start_pos(&self) -> Position {
214 let pos = self.start_position();
215 let byte = self.start_byte() as usize;
216 Position::new(pos.row() as usize, pos.column() as usize, byte)
217 }
218 fn end_pos(&self) -> Position {
219 let pos = self.end_position();
220 let byte = self.end_byte() as usize;
221 Position::new(pos.row() as usize, pos.column() as usize, byte)
222 }
223 fn is_missing(&self) -> bool {
225 Node::is_missing(self)
226 }
227 fn is_error(&self) -> bool {
228 Node::is_error(self)
229 }
230
231 fn field(&self, name: &str) -> Option<Self> {
232 self.child_by_field_name(name)
233 }
234 fn field_children(&self, field_id: Option<u16>) -> impl Iterator<Item = Self> {
235 let mut cursor = self.walk();
236 cursor.goto_first_child();
237 let mut done = field_id.is_none();
239
240 std::iter::from_fn(move || {
241 if done {
242 return None;
243 }
244 while cursor.field_id() != field_id {
245 if !cursor.goto_next_sibling() {
246 return None;
247 }
248 }
249 let ret = cursor.node();
250 if !cursor.goto_next_sibling() {
251 done = true;
252 }
253 Some(ret)
254 })
255 }
256}
257
258pub fn perform_edit<S: ContentExt>(tree: &mut Tree, input: &mut S, edit: &Edit<S>) -> InputEdit {
259 let edit = input.accept_edit(edit);
260 tree.edit(&edit);
261 edit
262}
263
264pub trait LanguageExt: Language {
266 fn ast_grep<S: AsRef<str>>(&self, source: S) -> AstGrep<StrDoc<Self>> {
268 AstGrep::new(source, self.clone())
269 }
270
271 fn get_ts_language(&self) -> TSLanguage;
273
274 fn injectable_languages(&self) -> Option<&'static [&'static str]> {
275 None
276 }
277
278 fn extract_injections<L: LanguageExt>(
284 &self,
285 _root: crate::Node<StrDoc<L>>,
286 ) -> HashMap<String, Vec<TSRange>> {
287 HashMap::new()
288 }
289}
290
291fn position_for_offset(input: &[u8], offset: usize) -> Point {
292 debug_assert!(offset <= input.len());
293 let (mut row, mut col) = (0, 0);
294 for c in &input[0..offset] {
295 if *c as char == '\n' {
296 row += 1;
297 col = 0;
298 } else {
299 col += 1;
300 }
301 }
302 Point::new(row, col)
303}
304
305impl<L: LanguageExt> AstGrep<StrDoc<L>> {
306 pub fn new<S: AsRef<str>>(src: S, lang: L) -> Self {
307 Root::str(src.as_ref(), lang)
308 }
309
310 pub fn source(&self) -> &str {
311 self.doc.get_source().as_str()
312 }
313
314 pub fn generate(self) -> String {
315 self.doc.src
316 }
317}
318
319pub trait ContentExt: Content {
320 fn accept_edit(&mut self, edit: &Edit<Self>) -> InputEdit;
321}
322impl ContentExt for String {
323 fn accept_edit(&mut self, edit: &Edit<Self>) -> InputEdit {
324 let start_byte = edit.position;
325 let old_end_byte = edit.position + edit.deleted_length;
326 let new_end_byte = edit.position + edit.inserted_text.len();
327 let input = unsafe { self.as_mut_vec() };
328 let start_position = position_for_offset(input, start_byte);
329 let old_end_position = position_for_offset(input, old_end_byte);
330 input.splice(start_byte..old_end_byte, edit.inserted_text.clone());
331 let new_end_position = position_for_offset(input, new_end_byte);
332 InputEdit::new(
333 start_byte as u32,
334 old_end_byte as u32,
335 new_end_byte as u32,
336 &start_position,
337 &old_end_position,
338 &new_end_position,
339 )
340 }
341}
342
343impl<L: LanguageExt> Root<StrDoc<L>> {
344 pub fn str(src: &str, lang: L) -> Self {
345 Self::try_new(src, lang).expect("should parse")
346 }
347 pub fn try_new(src: &str, lang: L) -> Result<Self, String> {
348 let doc = StrDoc::try_new(src, lang)?;
349 Ok(Self { doc })
350 }
351 pub fn get_text(&self) -> &str {
352 &self.doc.src
353 }
354
355 pub fn get_injections<F: Fn(&str) -> Option<L>>(&self, get_lang: F) -> Vec<Self> {
356 let root = self.root();
357 let range = self.lang().extract_injections(root);
358 let roots = range
359 .into_iter()
360 .filter_map(|(lang, ranges)| {
361 let lang = get_lang(&lang)?;
362 let source = self.doc.get_source();
363 let mut parser = tree_sitter::Parser::new().ok()?;
364 parser.set_included_ranges(&ranges).ok()?;
365 parser.set_language(&lang.get_ts_language()).ok()?;
366 let tree = parser.parse(source, None).ok()?;
367 tree.map(|t| Self {
368 doc: StrDoc {
369 src: self.doc.src.clone(),
370 lang,
371 tree: t,
372 },
373 })
374 })
375 .collect();
376 roots
377 }
378}
379
/// A matched node's text plus the surrounding source, as produced by
/// `Node::display_context`.
pub struct DisplayContext<'r> {
    /// Text of the matched node itself.
    pub matched: Cow<'r, str>,
    /// Source before the match: up to the requested number of whole lines, followed by
    /// the part of the match's first line that precedes it.
    pub leading: &'r str,
    /// Source after the match: the rest of the match's last line, followed by up to the
    /// requested number of lines (the final newline is excluded).
    pub trailing: &'r str,
    /// Line number of the first line covered by `leading`, in the numbering used by
    /// `start_pos().line()`.
    pub start_line: usize,
}
390
391impl<'r, L: LanguageExt> crate::Node<'r, StrDoc<L>> {
393 #[doc(hidden)]
394 pub fn display_context(&self, before: usize, after: usize) -> DisplayContext<'r> {
395 let source = self.root.doc.get_source().as_str();
396 let bytes = source.as_bytes();
397 let start = self.inner.start_byte() as usize;
398 let end = self.inner.end_byte() as usize;
399 let (mut leading, mut trailing) = (start, end);
400 let mut lines_before = before + 1;
401 while leading > 0 {
402 if bytes[leading - 1] == b'\n' {
403 lines_before -= 1;
404 if lines_before == 0 {
405 break;
406 }
407 }
408 leading -= 1;
409 }
410 let mut lines_after = after + 1;
411 trailing = trailing.min(bytes.len());
413 while trailing < bytes.len() {
414 if bytes[trailing] == b'\n' {
415 lines_after -= 1;
416 if lines_after == 0 {
417 break;
418 }
419 }
420 trailing += 1;
421 }
422 let offset = if lines_before == 0 {
424 before
425 } else {
426 before + 1 - lines_before
428 };
429 DisplayContext {
430 matched: self.text(),
431 leading: &source[leading..start],
432 trailing: &source[end..trailing],
433 start_line: self.start_pos().line() - offset,
434 }
435 }
436
437 pub fn replace_all<M: Matcher, R: Replacer<StrDoc<L>>>(
438 &self,
439 matcher: M,
440 replacer: R,
441 ) -> Vec<Edit<String>> {
442 Visitor::new(&matcher)
444 .reentrant(false)
445 .visit(self.clone())
446 .map(|matched| matched.make_edit(&matcher, &replacer))
447 .collect()
448 }
449}
450
#[cfg(test)]
mod test {
    use super::*;
    use crate::language::Tsx;
    use tree_sitter::Point;

    /// Parses `src` as TSX from scratch.
    fn parse(src: &str) -> Result<Tree, TSParseError> {
        parse_lang(|p| p.parse(src, None), Tsx.get_ts_language())
    }

    #[test]
    fn test_tree_sitter() -> Result<(), TSParseError> {
        let tree = parse("var a = 1234")?;
        let root_node = tree.root_node();
        assert_eq!(root_node.kind(), "program");
        assert_eq!(root_node.start_position().column(), 0);
        assert_eq!(root_node.end_position().column(), 12);
        assert_eq!(
            root_node.to_sexp(),
            "(program (variable_declaration (variable_declarator name: (identifier) value: (number))))"
        );
        Ok(())
    }

    #[test]
    fn test_object_literal() -> Result<(), TSParseError> {
        let tree = parse("{a: $X}")?;
        let root_node = tree.root_node();
        assert_eq!(root_node.to_sexp(), "(program (expression_statement (object (pair key: (property_identifier) value: (identifier)))))");
        Ok(())
    }

    #[test]
    fn test_string() -> Result<(), TSParseError> {
        let tree = parse("'$A'")?;
        let root_node = tree.root_node();
        assert_eq!(
            root_node.to_sexp(),
            "(program (expression_statement (string (string_fragment))))"
        );
        Ok(())
    }

    #[test]
    fn test_row_col() -> Result<(), TSParseError> {
        let tree = parse("😄")?;
        let root = tree.root_node();
        assert_eq!(root.start_position(), Point::new(0, 0));
        // Columns are byte-based: the emoji is four bytes wide.
        assert_eq!(root.end_position(), Point::new(0, 4));
        Ok(())
    }

    #[test]
    fn test_edit() -> Result<(), TSParseError> {
        let mut src = "a + b".to_string();
        let mut tree = parse(&src)?;
        let _ = perform_edit(
            &mut tree,
            &mut src,
            &Edit {
                position: 1,
                deleted_length: 0,
                inserted_text: " * b".into(),
            },
        );
        let tree2 = parse_lang(|p| p.parse(&src, Some(&tree)), Tsx.get_ts_language())?;
        // The edited (but not re-parsed) tree keeps its old shape.
        assert_eq!(
            tree.root_node().to_sexp(),
            "(program (expression_statement (binary_expression left: (identifier) right: (identifier))))"
        );
        assert_eq!(tree2.root_node().to_sexp(), "(program (expression_statement (binary_expression left: (binary_expression left: (identifier) right: (identifier)) right: (identifier))))");
        Ok(())
    }

    #[test]
    fn test_position_for_offset() {
        let src = "ab\ncd\n😄x".as_bytes();
        assert_eq!(position_for_offset(src, 0), Point::new(0, 0));
        assert_eq!(position_for_offset(src, 2), Point::new(0, 2));
        assert_eq!(position_for_offset(src, 3), Point::new(1, 0));
        assert_eq!(position_for_offset(src, 5), Point::new(1, 2));
        // The emoji occupies four bytes, so `x`'s end sits at byte column 5.
        assert_eq!(position_for_offset(src, src.len()), Point::new(2, 5));
    }

    #[test]
    fn test_edit_multiline() -> Result<(), TSParseError> {
        let mut src = "a\nbc".to_string();
        let mut tree = parse(&src)?;
        let _ = perform_edit(
            &mut tree,
            &mut src,
            &Edit {
                position: 2,
                deleted_length: 1,
                inserted_text: "x\ny".into(),
            },
        );
        assert_eq!(src, "a\nx\nyc");
        Ok(())
    }
}