1use crate::language::Language;
2use crate::matcher::{FindAllNodes, Matcher, NodeMatch};
3use crate::replacer::Replacer;
4use crate::source::{perform_edit, Content, Edit as E, TSParseError};
5use crate::traversal::{Pre, Visitor};
6use crate::{Doc, StrDoc};
7
8type Edit<D> = E<<D as Doc>::Source>;
9
10use std::borrow::Cow;
11
/// A location inside the source text, as produced by `Node::start_pos` and
/// `Node::end_pos`.
///
/// The column and offset are stored in bytes; use [`Position::column`] to
/// obtain a character-based column.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Position {
  /// Line number. The tests in this file expect `0` for the first line.
  line: usize,
  /// Column within the line, measured in bytes.
  byte_column: usize,
  /// Byte offset of this position, as reported by the underlying node.
  byte_offset: usize,
}
25
26impl Position {
27 fn new(line: u32, byte_column: u32, byte_offset: u32) -> Self {
28 Self {
29 line: line as usize,
30 byte_column: byte_column as usize,
31 byte_offset: byte_offset as usize,
32 }
33 }
34 pub fn line(&self) -> usize {
35 self.line
36 }
37 pub fn column<D: Doc>(&self, node: &Node<D>) -> usize {
39 let source = node.root.doc.get_source();
40 source.get_char_column(self.byte_column, self.byte_offset)
41 }
42 pub fn ts_point(&self) -> tree_sitter::Point {
44 tree_sitter::Point::new(self.line as u32, self.byte_column as u32)
45 }
46}
47
/// A parsed document: the syntax tree together with the document it was
/// parsed from.
///
/// `Node`s hold a reference back to their `Root` so they can reach the source
/// text and the language.
#[derive(Clone)]
pub struct Root<D: Doc> {
  /// The syntax tree produced by parsing `doc`.
  pub(crate) inner: tree_sitter::Tree,
  /// The document (source text plus language) the tree belongs to.
  pub(crate) doc: D,
}
55
56impl<L: Language> Root<StrDoc<L>> {
57 pub fn str(src: &str, lang: L) -> Self {
58 Self::try_new(src, lang).expect("should parse")
59 }
60 pub fn get_text(&self) -> &str {
61 &self.doc.src
62 }
63}
64
65impl<D: Doc> Root<D> {
66 pub fn try_new(src: &str, lang: D::Lang) -> Result<Self, TSParseError> {
67 let doc = D::from_str(src, lang);
68 let inner = doc.parse(None)?;
69 Ok(Self { inner, doc })
70 }
71
72 pub fn new(src: &str, lang: D::Lang) -> Self {
73 Self::try_new(src, lang).expect("should parse")
74 }
75 pub fn try_doc(doc: D) -> Result<Self, TSParseError> {
76 let inner = doc.parse(None)?;
77 Ok(Self { inner, doc })
78 }
79
80 pub fn doc(doc: D) -> Self {
81 Self::try_doc(doc).expect("Parse doc error")
82 }
83
84 pub fn lang(&self) -> &D::Lang {
85 self.doc.get_lang()
86 }
87 pub fn root(&self) -> Node<D> {
89 Node {
90 inner: self.inner.root_node(),
91 root: self,
92 }
93 }
94
95 pub fn do_edit(&mut self, edit: Edit<D>) -> Result<(), TSParseError> {
97 let source = self.doc.get_source_mut();
98 let input_edit = perform_edit(&mut self.inner, source, &edit);
99 self.inner.edit(&input_edit);
100 self.inner = self.doc.parse(Some(&self.inner))?;
101 Ok(())
102 }
103
104 pub fn adopt<'r>(&'r self, inner: tree_sitter::Node<'r>) -> Node<'r, D> {
107 debug_assert!(self.check_lineage(&inner));
108 Node { inner, root: self }
109 }
110
111 fn check_lineage(&self, inner: &tree_sitter::Node<'_>) -> bool {
112 let mut node = inner.clone();
113 while let Some(n) = node.parent() {
114 node = n;
115 }
116 node == self.inner.root_node()
117 }
118
119 #[doc(hidden)]
121 pub unsafe fn readopt<'a: 'b, 'b>(&'a self, node: &mut Node<'b, D>) {
122 debug_assert!(self.check_lineage(&node.inner));
123 node.root = self;
124 }
125
126 pub fn get_injections<F: Fn(&str) -> Option<D::Lang>>(&self, get_lang: F) -> Vec<Root<D>> {
127 let root = self.root();
128 let range = self.lang().extract_injections(root);
129 let roots = range
130 .into_iter()
131 .filter_map(|(lang, ranges)| {
132 let lang = get_lang(&lang)?;
133 let source = self.doc.get_source();
134 let mut parser = tree_sitter::Parser::new().ok()?;
135 parser.set_included_ranges(&ranges).ok()?;
136 parser.set_language(&lang.get_ts_language()).ok()?;
137 let tree = source.parse_tree_sitter(&mut parser, None).ok()?;
138 tree.map(|t| Self {
139 inner: t,
140 doc: self.doc.clone_with_lang(lang),
141 })
142 })
143 .collect();
144 roots
145 }
146}
147
/// A node of the syntax tree, paired with the `Root` it belongs to.
///
/// The `'r` lifetime ties the node to its root, which owns the tree and the
/// source text.
#[derive(Clone)]
pub struct Node<'r, D: Doc> {
  /// The underlying tree-sitter node.
  pub(crate) inner: tree_sitter::Node<'r>,
  /// The root this node was obtained from.
  pub(crate) root: &'r Root<D>,
}
/// Numeric identifier of a node kind, as returned by `Node::kind_id`.
pub type KindId = u16;
155
156struct NodeWalker<'tree, D: Doc> {
157 cursor: tree_sitter::TreeCursor<'tree>,
158 root: &'tree Root<D>,
159 count: usize,
160}
161
162impl<'tree, D: Doc> Iterator for NodeWalker<'tree, D> {
163 type Item = Node<'tree, D>;
164 fn next(&mut self) -> Option<Self::Item> {
165 if self.count == 0 {
166 return None;
167 }
168 let ret = Some(Node {
169 inner: self.cursor.node(),
170 root: self.root,
171 });
172 self.cursor.goto_next_sibling();
173 self.count -= 1;
174 ret
175 }
176}
177
178impl<D: Doc> ExactSizeIterator for NodeWalker<'_, D> {
179 fn len(&self) -> usize {
180 self.count
181 }
182}
183
184impl<'r, D: Doc> Node<'r, D> {
186 pub fn node_id(&self) -> usize {
187 self.inner.id()
188 }
189 pub fn is_leaf(&self) -> bool {
190 self.inner.child_count() == 0
191 }
192 pub fn is_named_leaf(&self) -> bool {
196 self.inner.named_child_count() == 0
197 }
198 pub fn is_error(&self) -> bool {
199 self.inner.is_error()
200 }
201 pub fn kind(&self) -> Cow<str> {
202 self.inner.kind()
203 }
204 pub fn kind_id(&self) -> KindId {
205 self.inner.kind_id()
206 }
207
208 pub fn is_named(&self) -> bool {
209 self.inner.is_named()
210 }
211
212 pub fn get_ts_node(&self) -> tree_sitter::Node<'r> {
214 self.inner.clone()
215 }
216
217 pub fn range(&self) -> std::ops::Range<usize> {
219 (self.inner.start_byte() as usize)..(self.inner.end_byte() as usize)
220 }
221
222 pub fn start_pos(&self) -> Position {
224 let pos = self.inner.start_position();
225 let byte = self.inner.start_byte();
226 Position::new(pos.row(), pos.column(), byte)
227 }
228
229 pub fn end_pos(&self) -> Position {
231 let pos = self.inner.end_position();
232 let byte = self.inner.end_byte();
233 Position::new(pos.row(), pos.column(), byte)
234 }
235
236 pub fn text(&self) -> Cow<'r, str> {
237 let source = self.root.doc.get_source();
238 source.get_text(&self.inner)
239 }
240
241 pub fn to_sexp(&self) -> Cow<'_, str> {
243 self.inner.to_sexp()
244 }
245
246 pub fn lang(&self) -> &'r D::Lang {
247 self.root.lang()
248 }
249}
250
251impl<'r, L: Language> Node<'r, StrDoc<L>> {
253 #[doc(hidden)]
254 pub fn display_context(&self, before: usize, after: usize) -> DisplayContext<'r> {
255 let source = self.root.doc.get_source().as_str();
256 let bytes = source.as_bytes();
257 let start = self.inner.start_byte() as usize;
258 let end = self.inner.end_byte() as usize;
259 let (mut leading, mut trailing) = (start, end);
260 let mut lines_before = before + 1;
261 while leading > 0 {
262 if bytes[leading - 1] == b'\n' {
263 lines_before -= 1;
264 if lines_before == 0 {
265 break;
266 }
267 }
268 leading -= 1;
269 }
270 let mut lines_after = after + 1;
271 trailing = trailing.min(bytes.len());
273 while trailing < bytes.len() {
274 if bytes[trailing] == b'\n' {
275 lines_after -= 1;
276 if lines_after == 0 {
277 break;
278 }
279 }
280 trailing += 1;
281 }
282 let offset = if lines_before == 0 {
284 before
285 } else {
286 before + 1 - lines_before
288 };
289 DisplayContext {
290 matched: self.text(),
291 leading: &source[leading..start],
292 trailing: &source[end..trailing],
293 start_line: self.start_pos().line() - offset,
294 }
295 }
296
297 pub fn root(&self) -> &'r Root<StrDoc<L>> {
298 self.root
299 }
300}
301
302impl<D: Doc> Node<'_, D> {
306 pub fn matches<M: Matcher<D::Lang>>(&self, m: M) -> bool {
307 m.match_node(self.clone()).is_some()
308 }
309
310 pub fn inside<M: Matcher<D::Lang>>(&self, m: M) -> bool {
311 self.ancestors().find_map(|n| m.match_node(n)).is_some()
312 }
313
314 pub fn has<M: Matcher<D::Lang>>(&self, m: M) -> bool {
315 self.dfs().skip(1).find_map(|n| m.match_node(n)).is_some()
316 }
317
318 pub fn precedes<M: Matcher<D::Lang>>(&self, m: M) -> bool {
319 self.next_all().find_map(|n| m.match_node(n)).is_some()
320 }
321
322 pub fn follows<M: Matcher<D::Lang>>(&self, m: M) -> bool {
323 self.prev_all().find_map(|n| m.match_node(n)).is_some()
324 }
325}
326
/// Text of a node plus the surrounding source, as produced by
/// `Node::display_context`.
pub struct DisplayContext<'r> {
  /// Text of the node itself.
  pub matched: Cow<'r, str>,
  /// Source text from the start of the context up to the node.
  pub leading: &'r str,
  /// Source text from the end of the node up to the end of the context.
  pub trailing: &'r str,
  /// Line number of the first line of `leading`.
  pub start_line: usize,
}
337
338impl<'r, D: Doc> Node<'r, D> {
340 #[must_use]
341 pub fn parent(&self) -> Option<Self> {
342 let inner = self.inner.parent()?;
343 Some(Node {
344 inner,
345 root: self.root,
346 })
347 }
348
349 pub fn children<'s>(&'s self) -> impl ExactSizeIterator<Item = Node<'r, D>> + 's {
350 let mut cursor = self.inner.walk();
351 cursor.goto_first_child();
352 NodeWalker {
353 cursor,
354 root: self.root,
355 count: self.inner.child_count() as usize,
356 }
357 }
358
359 #[must_use]
360 pub fn child(&self, nth: usize) -> Option<Self> {
361 let inner = self.inner.child(nth as u32)?;
363 Some(Node {
364 inner,
365 root: self.root,
366 })
367 }
368
369 pub fn field(&self, name: &str) -> Option<Self> {
370 let inner = self.inner.child_by_field_name(name)?;
371 Some(Node {
372 inner,
373 root: self.root,
374 })
375 }
376
377 pub fn child_by_field_id(&self, field_id: u16) -> Option<Self> {
378 let inner = self.inner.child_by_field_id(field_id)?;
379 Some(Node {
380 inner,
381 root: self.root,
382 })
383 }
384
385 pub fn field_children(&self, name: &str) -> impl Iterator<Item = Node<'r, D>> {
386 let field_id = self
387 .root
388 .lang()
389 .get_ts_language()
390 .field_id_for_name(name)
391 .unwrap_or(0);
392 let root = self.root;
393 let mut cursor = self.inner.walk();
394 cursor.goto_first_child();
395 let mut done = false;
396 std::iter::from_fn(move || {
397 if done {
398 return None;
399 }
400 while cursor.field_id() != Some(field_id) {
401 if !cursor.goto_next_sibling() {
402 return None;
403 }
404 }
405 let inner = cursor.node();
406 if !cursor.goto_next_sibling() {
407 done = true;
408 }
409 Some(Node { inner, root })
410 })
411 }
412
413 pub fn ancestors(&self) -> impl Iterator<Item = Node<'r, D>> + '_ {
417 let mut parent = self.inner.parent();
418 std::iter::from_fn(move || {
419 let inner = parent.clone()?;
420 let ret = Some(Node {
421 inner: inner.clone(),
422 root: self.root,
423 });
424 parent = inner.parent();
425 ret
426 })
427 }
428 #[must_use]
429 pub fn next(&self) -> Option<Self> {
430 let inner = self.inner.next_sibling()?;
431 Some(Node {
432 inner,
433 root: self.root,
434 })
435 }
436
437 #[cfg(not(target_arch = "wasm32"))]
442 pub fn next_all(&self) -> impl Iterator<Item = Node<'r, D>> + '_ {
443 let node = self.parent().unwrap_or_else(|| self.clone());
445 let mut cursor = node.inner.walk();
446 cursor.goto_first_child_for_byte(self.inner.start_byte());
447 std::iter::from_fn(move || {
448 if cursor.goto_next_sibling() {
449 Some(self.root.adopt(cursor.node()))
450 } else {
451 None
452 }
453 })
454 }
455
456 #[cfg(target_arch = "wasm32")]
458 pub fn next_all(&self) -> impl Iterator<Item = Node<'r, D>> + '_ {
459 let mut node = self.clone();
460 std::iter::from_fn(move || {
461 node.next().map(|n| {
462 node = n.clone();
463 n
464 })
465 })
466 }
467
468 #[must_use]
469 pub fn prev(&self) -> Option<Node<'r, D>> {
470 let inner = self.inner.prev_sibling()?;
471 Some(Node {
472 inner,
473 root: self.root,
474 })
475 }
476
477 #[cfg(not(target_arch = "wasm32"))]
478 pub fn prev_all(&self) -> impl Iterator<Item = Node<'r, D>> + '_ {
479 let node = self.parent().unwrap_or_else(|| self.clone());
481 let mut cursor = node.inner.walk();
482 cursor.goto_first_child_for_byte(self.inner.start_byte());
483 std::iter::from_fn(move || {
484 if cursor.goto_previous_sibling() {
485 Some(self.root.adopt(cursor.node()))
486 } else {
487 None
488 }
489 })
490 }
491
492 #[cfg(target_arch = "wasm32")]
494 pub fn prev_all(&self) -> impl Iterator<Item = Node<'r, D>> + '_ {
495 let mut node = self.clone();
496 std::iter::from_fn(move || {
497 node.prev().map(|n| {
498 node = n.clone();
499 n
500 })
501 })
502 }
503
504 pub fn dfs<'s>(&'s self) -> Pre<'r, D> {
505 Pre::new(self)
506 }
507
508 #[must_use]
509 pub fn find<M: Matcher<D::Lang>>(&self, pat: M) -> Option<NodeMatch<'r, D>> {
510 pat.find_node(self.clone())
511 }
512
513 pub fn find_all<M: Matcher<D::Lang>>(&self, pat: M) -> impl Iterator<Item = NodeMatch<'r, D>> {
514 FindAllNodes::new(pat, self.clone())
515 }
516}
517
518impl<D: Doc> Node<'_, D> {
520 pub fn replace<M: Matcher<D::Lang>, R: Replacer<D>>(
521 &self,
522 matcher: M,
523 replacer: R,
524 ) -> Option<Edit<D>> {
525 let matched = matcher.find_node(self.clone())?;
526 let edit = matched.make_edit(&matcher, &replacer);
527 Some(edit)
528 }
529
530 pub fn replace_all<M: Matcher<D::Lang>, R: Replacer<D>>(
531 &self,
532 matcher: M,
533 replacer: R,
534 ) -> Vec<Edit<D>> {
535 Visitor::new(&matcher)
537 .reentrant(false)
538 .visit(self.clone())
539 .map(|matched| matched.make_edit(&matcher, &replacer))
540 .collect()
541 }
542
543 pub fn after(&self) -> Edit<D> {
544 todo!()
545 }
546 pub fn before(&self) -> Edit<D> {
547 todo!()
548 }
549 pub fn append(&self) -> Edit<D> {
550 todo!()
551 }
552 pub fn prepend(&self) -> Edit<D> {
553 todo!()
554 }
555
556 pub fn empty(&self) -> Option<Edit<D>> {
558 let mut children = self.children().peekable();
559 let start = children.peek()?.range().start;
560 let end = children.last()?.range().end;
561 Some(Edit::<D> {
562 position: start,
563 deleted_length: end - start,
564 inserted_text: Vec::new(),
565 })
566 }
567
568 pub fn remove(&self) -> Edit<D> {
570 let range = self.range();
571 Edit::<D> {
572 position: range.start,
573 deleted_length: range.end - range.start,
574 inserted_text: Vec::new(),
575 }
576 }
577}
578
// Unit tests for `Node`: structure queries, display context, edit
// construction, relational predicates, field access and positions.
#[cfg(test)]
mod test {
  use crate::language::{Language, Tsx};
  // The root of a non-empty program has children, so it is not a leaf.
  #[test]
  fn test_is_leaf() {
    let root = Tsx.ast_grep("let a = 123");
    let node = root.root();
    assert!(!node.is_leaf());
  }

  // `children` yields direct children only: one statement under the root,
  // and two children under that statement.
  #[test]
  fn test_children() {
    let root = Tsx.ast_grep("let a = 123");
    let node = root.root();
    let children: Vec<_> = node.children().collect();
    assert_eq!(children.len(), 1);
    let texts: Vec<_> = children[0]
      .children()
      .map(|c| c.text().to_string())
      .collect();
    assert_eq!(texts, vec!["let", "a = 123"]);
  }
  // `empty` deletes from the first child's start to the last child's end and
  // inserts nothing.
  #[test]
  fn test_empty() {
    let root = Tsx.ast_grep("let a = 123");
    let node = root.root();
    let edit = node.empty().unwrap();
    assert_eq!(edit.inserted_text.len(), 0);
    assert_eq!(edit.deleted_length, 11);
    assert_eq!(edit.position, 0);
  }

  // `field_children` returns the children stored under the named field.
  #[test]
  fn test_field_children() {
    let root = Tsx.ast_grep("let a = 123");
    let node = root.root().find("let a = $A").unwrap();
    let children: Vec<_> = node.field_children("kind").collect();
    assert_eq!(children.len(), 1);
    assert_eq!(children[0].text(), "let");
  }

  // Multi-line fixture shared by the display-context tests.
  const MULTI_LINE: &str = "
if (a) {
 test(1)
} else {
 x
}
";

  // With no extra lines requested, the context is the rest of the matched
  // node's own line.
  #[test]
  fn test_display_context() {
    // Each case is [source, matcher, expected leading, expected trailing].
    let cases = [
      ["i()", "i()", "", ""],
      ["i()", "i", "", "()"],
      [MULTI_LINE, "test", " ", "(1)"],
    ];
    for [src, matcher, lead, trail] in cases {
      let root = Tsx.ast_grep(src);
      let node = root.root().find(matcher).expect("should match");
      let display = node.display_context(0, 0);
      assert_eq!(display.leading, lead);
      assert_eq!(display.trailing, trail);
    }
  }

  // With one extra line on each side, the neighboring lines are included
  // when the source has them.
  #[test]
  fn test_multi_line_context() {
    // Each case is [source, matcher, expected leading, expected trailing].
    let cases = [
      ["i()", "i()", "", ""],
      [MULTI_LINE, "test", "if (a) {\n ", "(1)\n} else {"],
    ];
    for [src, matcher, lead, trail] in cases {
      let root = Tsx.ast_grep(src);
      let node = root.root().find(matcher).expect("should match");
      let display = node.display_context(1, 1);
      assert_eq!(display.leading, lead);
      assert_eq!(display.trailing, trail);
    }
  }

  // A match nested inside another match produces no separate edit.
  #[test]
  fn test_replace_all_nested() {
    let root = Tsx.ast_grep("Some(Some(1))");
    let node = root.root();
    let edits = node.replace_all("Some($A)", "$A");
    assert_eq!(edits.len(), 1);
    assert_eq!(edits[0].inserted_text, "Some(1)".as_bytes());
  }

  // Edits for separate matches come back in source order.
  #[test]
  fn test_replace_all_multiple_sorted() {
    let root = Tsx.ast_grep("Some(Some(1)); Some(2)");
    let node = root.root();
    let edits = node.replace_all("Some($A)", "$A");
    assert_eq!(edits.len(), 2);
    // The outer match comes first; the nested one is not replaced separately.
    assert_eq!(edits[0].inserted_text, "Some(1)".as_bytes());
    assert_eq!(edits[1].inserted_text, "2".as_bytes());
  }

  // `inside` looks at ancestors.
  #[test]
  fn test_inside() {
    let root = Tsx.ast_grep("Some(Some(1)); Some(2)");
    let root = root.root();
    let node = root.find("Some(1)").expect("should exist");
    assert!(node.inside("Some($A)"));
  }
  // `has` looks at descendants.
  #[test]
  fn test_has() {
    let root = Tsx.ast_grep("Some(Some(1)); Some(2)");
    let root = root.root();
    let node = root.find("Some($A)").expect("should exist");
    assert!(node.has("Some(1)"));
  }
  // `precedes` looks at following siblings.
  #[test]
  fn precedes() {
    let root = Tsx.ast_grep("Some(Some(1)); Some(2);");
    let root = root.root();
    let node = root.find("Some($A);").expect("should exist");
    assert!(node.precedes("Some(2);"));
  }
  // `follows` looks at preceding siblings.
  #[test]
  fn follows() {
    let root = Tsx.ast_grep("Some(Some(1)); Some(2);");
    let root = root.root();
    let node = root.find("Some(2);").expect("should exist");
    assert!(node.follows("Some(Some(1));"));
  }

  // `field` resolves a child by field name and returns `None` for a name
  // the node does not have.
  #[test]
  fn test_field() {
    let root = Tsx.ast_grep("class A{}");
    let root = root.root();
    let node = root.find("class $C {}").expect("should exist");
    assert!(node.field("name").is_some());
    assert!(node.field("none").is_none());
  }
  // `child_by_field_id` resolves a child by numeric field id.
  #[test]
  fn test_child_by_field_id() {
    let root = Tsx.ast_grep("class A{}");
    let root = root.root();
    let node = root.find("class $C {}").expect("should exist");
    let id = Tsx.get_ts_language().field_id_for_name("name").unwrap();
    assert!(node.child_by_field_id(id).is_some());
    assert!(node.child_by_field_id(id + 1).is_none());
  }

  // `remove` deletes exactly the node's byte range.
  #[test]
  fn test_remove() {
    let root = Tsx.ast_grep("Some(Some(1)); Some(2);");
    let root = root.root();
    let node = root.find("Some(2);").expect("should exist");
    let edit = node.remove();
    assert_eq!(edit.position, 15);
    assert_eq!(edit.deleted_length, 8);
  }

  // For ASCII text, lines and columns start at zero.
  #[test]
  fn test_ascii_pos() {
    let root = Tsx.ast_grep("a");
    let root = root.root();
    let node = root.find("$A").expect("should exist");
    assert_eq!(node.start_pos().line(), 0);
    assert_eq!(node.start_pos().column(&node), 0);
    assert_eq!(node.end_pos().line(), 0);
    assert_eq!(node.end_pos().column(&node), 1);
  }

  // Columns count characters, not bytes: a multi-byte emoji advances the
  // column by one.
  #[test]
  fn test_unicode_pos() {
    let root = Tsx.ast_grep("🦀");
    let root = root.root();
    let node = root.find("$A").expect("should exist");
    assert_eq!(node.start_pos().line(), 0);
    assert_eq!(node.start_pos().column(&node), 0);
    assert_eq!(node.end_pos().line(), 0);
    assert_eq!(node.end_pos().column(&node), 1);
    let root = Tsx.ast_grep("\n 🦀🦀");
    let root = root.root();
    let node = root.find("$A").expect("should exist");
    assert_eq!(node.start_pos().line(), 1);
    assert_eq!(node.start_pos().column(&node), 2);
    assert_eq!(node.end_pos().line(), 1);
    assert_eq!(node.end_pos().column(&node), 4);
  }
}