ra_ap_syntax/
syntax_editor.rs

1//! Syntax Tree editor
2//!
3//! Inspired by Roslyn's [`SyntaxEditor`], but is temporarily built upon mutable syntax tree editing.
4//!
5//! [`SyntaxEditor`]: https://github.com/dotnet/roslyn/blob/43b0b05cc4f492fd5de00f6f6717409091df8daa/src/Workspaces/Core/Portable/Editing/SyntaxEditor.cs
6
7use std::{
8    fmt, iter,
9    num::NonZeroU32,
10    ops::RangeInclusive,
11    sync::atomic::{AtomicU32, Ordering},
12};
13
14use rowan::TextRange;
15use rustc_hash::FxHashMap;
16
17use crate::{SyntaxElement, SyntaxNode, SyntaxToken};
18
19mod edit_algo;
20mod edits;
21mod mapping;
22
23pub use edits::Removable;
24pub use mapping::{SyntaxMapping, SyntaxMappingBuilder};
25
26#[derive(Debug)]
27pub struct SyntaxEditor {
28    root: SyntaxNode,
29    changes: Vec<Change>,
30    mappings: SyntaxMapping,
31    annotations: Vec<(SyntaxElement, SyntaxAnnotation)>,
32}
33
34impl SyntaxEditor {
35    /// Creates a syntax editor to start editing from `root`
36    pub fn new(root: SyntaxNode) -> Self {
37        Self { root, changes: vec![], mappings: SyntaxMapping::default(), annotations: vec![] }
38    }
39
40    pub fn add_annotation(&mut self, element: impl Element, annotation: SyntaxAnnotation) {
41        self.annotations.push((element.syntax_element(), annotation))
42    }
43
44    pub fn add_annotation_all(
45        &mut self,
46        elements: Vec<impl Element>,
47        annotation: SyntaxAnnotation,
48    ) {
49        self.annotations
50            .extend(elements.into_iter().map(|e| e.syntax_element()).zip(iter::repeat(annotation)));
51    }
52
53    pub fn merge(&mut self, mut other: SyntaxEditor) {
54        debug_assert!(
55            self.root == other.root || other.root.ancestors().any(|node| node == self.root),
56            "{:?} is not in the same tree as {:?}",
57            other.root,
58            self.root
59        );
60
61        self.changes.append(&mut other.changes);
62        self.mappings.merge(other.mappings);
63        self.annotations.append(&mut other.annotations);
64    }
65
66    pub fn insert(&mut self, position: Position, element: impl Element) {
67        debug_assert!(is_ancestor_or_self(&position.parent(), &self.root));
68        self.changes.push(Change::Insert(position, element.syntax_element()))
69    }
70
71    pub fn insert_all(&mut self, position: Position, elements: Vec<SyntaxElement>) {
72        debug_assert!(is_ancestor_or_self(&position.parent(), &self.root));
73        self.changes.push(Change::InsertAll(position, elements))
74    }
75
76    pub fn delete(&mut self, element: impl Element) {
77        let element = element.syntax_element();
78        debug_assert!(is_ancestor_or_self_of_element(&element, &self.root));
79        debug_assert!(
80            !matches!(&element, SyntaxElement::Node(node) if node == &self.root),
81            "should not delete root node"
82        );
83        self.changes.push(Change::Replace(element.syntax_element(), None));
84    }
85
86    pub fn delete_all(&mut self, range: RangeInclusive<SyntaxElement>) {
87        if range.start() == range.end() {
88            self.delete(range.start());
89            return;
90        }
91
92        debug_assert!(is_ancestor_or_self_of_element(range.start(), &self.root));
93        self.changes.push(Change::ReplaceAll(range, Vec::new()))
94    }
95
96    pub fn replace(&mut self, old: impl Element, new: impl Element) {
97        let old = old.syntax_element();
98        debug_assert!(is_ancestor_or_self_of_element(&old, &self.root));
99        self.changes.push(Change::Replace(old.syntax_element(), Some(new.syntax_element())));
100    }
101
102    pub fn replace_with_many(&mut self, old: impl Element, new: Vec<SyntaxElement>) {
103        let old = old.syntax_element();
104        debug_assert!(is_ancestor_or_self_of_element(&old, &self.root));
105        debug_assert!(
106            !(matches!(&old, SyntaxElement::Node(node) if node == &self.root) && new.len() > 1),
107            "cannot replace root node with many elements"
108        );
109        self.changes.push(Change::ReplaceWithMany(old.syntax_element(), new));
110    }
111
112    pub fn replace_all(&mut self, range: RangeInclusive<SyntaxElement>, new: Vec<SyntaxElement>) {
113        if range.start() == range.end() {
114            self.replace_with_many(range.start(), new);
115            return;
116        }
117
118        debug_assert!(is_ancestor_or_self_of_element(range.start(), &self.root));
119        self.changes.push(Change::ReplaceAll(range, new))
120    }
121
122    pub fn finish(self) -> SyntaxEdit {
123        edit_algo::apply_edits(self)
124    }
125
126    pub fn add_mappings(&mut self, other: SyntaxMapping) {
127        self.mappings.merge(other);
128    }
129}
130
131/// Represents a completed [`SyntaxEditor`] operation.
132pub struct SyntaxEdit {
133    old_root: SyntaxNode,
134    new_root: SyntaxNode,
135    changed_elements: Vec<SyntaxElement>,
136    annotations: FxHashMap<SyntaxAnnotation, Vec<SyntaxElement>>,
137}
138
139impl SyntaxEdit {
140    /// Root of the initial unmodified syntax tree.
141    pub fn old_root(&self) -> &SyntaxNode {
142        &self.old_root
143    }
144
145    /// Root of the modified syntax tree.
146    pub fn new_root(&self) -> &SyntaxNode {
147        &self.new_root
148    }
149
150    /// Which syntax elements in the modified syntax tree were inserted or
151    /// modified as part of the edit.
152    ///
153    /// Note that for syntax nodes, only the upper-most parent of a set of
154    /// changes is included, not any child elements that may have been modified.
155    pub fn changed_elements(&self) -> &[SyntaxElement] {
156        self.changed_elements.as_slice()
157    }
158
159    /// Finds which syntax elements have been annotated with the given
160    /// annotation.
161    ///
162    /// Note that an annotation might not appear in the modified syntax tree if
163    /// the syntax elements that were annotated did not make it into the final
164    /// syntax tree.
165    pub fn find_annotation(&self, annotation: SyntaxAnnotation) -> &[SyntaxElement] {
166        self.annotations.get(&annotation).as_ref().map_or(&[], |it| it.as_slice())
167    }
168}
169
/// Opaque tag that can be attached to syntax elements through a `SyntaxEditor`.
///
/// Each call to [`SyntaxAnnotation::default`] hands out a new id, so two annotations created
/// separately never compare equal.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct SyntaxAnnotation(NonZeroU32);

impl Default for SyntaxAnnotation {
    /// Creates an annotation with a fresh id taken from a process-wide counter.
    ///
    /// # Panics
    ///
    /// Panics once the id space is exhausted. The counter is never allowed to wrap, so an
    /// already issued id cannot be handed out a second time, even if an earlier overflow panic
    /// was caught.
    fn default() -> Self {
        static COUNTER: AtomicU32 = AtomicU32::new(1);

        // Only consistency within a thread matters, as SyntaxElements are !Send
        let id = COUNTER
            // `checked_add` refuses to advance past `u32::MAX`; a plain `fetch_add` would wrap
            // to 0 and then silently reissue 1, 2, ... on subsequent calls.
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| current.checked_add(1))
            .ok()
            .and_then(NonZeroU32::new)
            .expect("syntax annotation id overflow");

        Self(id)
    }
}
184
185/// Position describing where to insert elements
186#[derive(Debug)]
187pub struct Position {
188    repr: PositionRepr,
189}
190
191impl Position {
192    pub(crate) fn parent(&self) -> SyntaxNode {
193        self.place().0
194    }
195
196    pub(crate) fn place(&self) -> (SyntaxNode, usize) {
197        match &self.repr {
198            PositionRepr::FirstChild(parent) => (parent.clone(), 0),
199            PositionRepr::After(child) => (child.parent().unwrap(), child.index() + 1),
200        }
201    }
202}
203
204#[derive(Debug)]
205enum PositionRepr {
206    FirstChild(SyntaxNode),
207    After(SyntaxElement),
208}
209
210impl Position {
211    pub fn after(elem: impl Element) -> Position {
212        let repr = PositionRepr::After(elem.syntax_element());
213        Position { repr }
214    }
215
216    pub fn before(elem: impl Element) -> Position {
217        let elem = elem.syntax_element();
218        let repr = match elem.prev_sibling_or_token() {
219            Some(it) => PositionRepr::After(it),
220            None => PositionRepr::FirstChild(elem.parent().unwrap()),
221        };
222        Position { repr }
223    }
224
225    pub fn first_child_of(node: &(impl Into<SyntaxNode> + Clone)) -> Position {
226        let repr = PositionRepr::FirstChild(node.clone().into());
227        Position { repr }
228    }
229
230    pub fn last_child_of(node: &(impl Into<SyntaxNode> + Clone)) -> Position {
231        let node = node.clone().into();
232        let repr = match node.last_child_or_token() {
233            Some(it) => PositionRepr::After(it),
234            None => PositionRepr::FirstChild(node),
235        };
236        Position { repr }
237    }
238}
239
240#[derive(Debug)]
241enum Change {
242    /// Inserts a single element at the specified position.
243    Insert(Position, SyntaxElement),
244    /// Inserts many elements in-order at the specified position.
245    InsertAll(Position, Vec<SyntaxElement>),
246    /// Represents both a replace single element and a delete element operation.
247    Replace(SyntaxElement, Option<SyntaxElement>),
248    /// Replaces a single element with many elements.
249    ReplaceWithMany(SyntaxElement, Vec<SyntaxElement>),
250    /// Replaces a range of elements with another list of elements.
251    /// Range will always have start != end.
252    ReplaceAll(RangeInclusive<SyntaxElement>, Vec<SyntaxElement>),
253}
254
255impl Change {
256    fn target_range(&self) -> TextRange {
257        match self {
258            Change::Insert(target, _) | Change::InsertAll(target, _) => match &target.repr {
259                PositionRepr::FirstChild(parent) => TextRange::at(
260                    parent.first_child_or_token().unwrap().text_range().start(),
261                    0.into(),
262                ),
263                PositionRepr::After(child) => TextRange::at(child.text_range().end(), 0.into()),
264            },
265            Change::Replace(target, _) | Change::ReplaceWithMany(target, _) => target.text_range(),
266            Change::ReplaceAll(range, _) => {
267                range.start().text_range().cover(range.end().text_range())
268            }
269        }
270    }
271
272    fn target_parent(&self) -> SyntaxNode {
273        match self {
274            Change::Insert(target, _) | Change::InsertAll(target, _) => target.parent(),
275            Change::Replace(target, _) | Change::ReplaceWithMany(target, _) => match target {
276                SyntaxElement::Node(target) => target.parent().unwrap_or_else(|| target.clone()),
277                SyntaxElement::Token(target) => target.parent().unwrap(),
278            },
279            Change::ReplaceAll(target, _) => target.start().parent().unwrap(),
280        }
281    }
282
283    fn change_kind(&self) -> ChangeKind {
284        match self {
285            Change::Insert(_, _) | Change::InsertAll(_, _) => ChangeKind::Insert,
286            Change::Replace(_, _) | Change::ReplaceWithMany(_, _) => ChangeKind::Replace,
287            Change::ReplaceAll(_, _) => ChangeKind::ReplaceRange,
288        }
289    }
290}
291
292#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
293enum ChangeKind {
294    Insert,
295    ReplaceRange,
296    Replace,
297}
298
299impl fmt::Display for Change {
300    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
301        match self {
302            Change::Insert(position, node_or_token) => {
303                let parent = position.parent();
304                let mut parent_str = parent.to_string();
305                let target_range = self.target_range().start() - parent.text_range().start();
306
307                parent_str.insert_str(
308                    target_range.into(),
309                    &format!("\x1b[42m{node_or_token}\x1b[0m\x1b[K"),
310                );
311                f.write_str(&parent_str)
312            }
313            Change::InsertAll(position, vec) => {
314                let parent = position.parent();
315                let mut parent_str = parent.to_string();
316                let target_range = self.target_range().start() - parent.text_range().start();
317                let insertion: String = vec.iter().map(|it| it.to_string()).collect();
318
319                parent_str
320                    .insert_str(target_range.into(), &format!("\x1b[42m{insertion}\x1b[0m\x1b[K"));
321                f.write_str(&parent_str)
322            }
323            Change::Replace(old, new) => {
324                if let Some(new) = new {
325                    write!(f, "\x1b[41m{old}\x1b[42m{new}\x1b[0m\x1b[K")
326                } else {
327                    write!(f, "\x1b[41m{old}\x1b[0m\x1b[K")
328                }
329            }
330            Change::ReplaceWithMany(old, vec) => {
331                let new: String = vec.iter().map(|it| it.to_string()).collect();
332                write!(f, "\x1b[41m{old}\x1b[42m{new}\x1b[0m\x1b[K")
333            }
334            Change::ReplaceAll(range, vec) => {
335                let parent = range.start().parent().unwrap();
336                let parent_str = parent.to_string();
337                let pre_range =
338                    TextRange::new(parent.text_range().start(), range.start().text_range().start());
339                let old_range = TextRange::new(
340                    range.start().text_range().start(),
341                    range.end().text_range().end(),
342                );
343                let post_range =
344                    TextRange::new(range.end().text_range().end(), parent.text_range().end());
345
346                let pre_str = &parent_str[pre_range - parent.text_range().start()];
347                let old_str = &parent_str[old_range - parent.text_range().start()];
348                let post_str = &parent_str[post_range - parent.text_range().start()];
349                let new: String = vec.iter().map(|it| it.to_string()).collect();
350
351                write!(f, "{pre_str}\x1b[41m{old_str}\x1b[42m{new}\x1b[0m\x1b[K{post_str}")
352            }
353        }
354    }
355}
356
357/// Utility trait to allow calling syntax editor functions with references or owned
358/// nodes. Do not use outside of this module.
359pub trait Element {
360    fn syntax_element(self) -> SyntaxElement;
361}
362
363impl<E: Element + Clone> Element for &'_ E {
364    fn syntax_element(self) -> SyntaxElement {
365        self.clone().syntax_element()
366    }
367}
368
369impl Element for SyntaxElement {
370    fn syntax_element(self) -> SyntaxElement {
371        self
372    }
373}
374
375impl Element for SyntaxNode {
376    fn syntax_element(self) -> SyntaxElement {
377        self.into()
378    }
379}
380
381impl Element for SyntaxToken {
382    fn syntax_element(self) -> SyntaxElement {
383        self.into()
384    }
385}
386
387fn is_ancestor_or_self(node: &SyntaxNode, ancestor: &SyntaxNode) -> bool {
388    node == ancestor || node.ancestors().any(|it| &it == ancestor)
389}
390
391fn is_ancestor_or_self_of_element(node: &SyntaxElement, ancestor: &SyntaxNode) -> bool {
392    matches!(node, SyntaxElement::Node(node) if node == ancestor)
393        || node.ancestors().any(|it| &it == ancestor)
394}
395
#[cfg(test)]
mod tests {
    use expect_test::expect;

    use crate::{
        AstNode,
        ast::{self, make, syntax_factory::SyntaxFactory},
    };

    use super::*;

    /// Extracts `2 + 2` into a `let`, wraps the tuple in a block, and checks that both annotated
    /// names end up in the resulting tree.
    #[test]
    fn basic_usage() {
        let sum = make::expr_bin_op(
            make::expr_literal("2").into(),
            ast::BinaryOp::ArithOp(ast::ArithOp::Add),
            make::expr_literal("2").into(),
        );
        let tuple = make::expr_tuple([sum, make::expr_literal("true").into()]);
        let root = make::match_arm(make::wildcard_pat().into(), None, tuple.into());

        let to_wrap = root.syntax().descendants().find_map(ast::TupleExpr::cast).unwrap();
        let to_replace = root.syntax().descendants().find_map(ast::BinExpr::cast).unwrap();

        let mut editor = SyntaxEditor::new(root.syntax().clone());
        let factory = SyntaxFactory::with_mappings();

        let name = make::name("var_name");
        let name_ref = make::name_ref("var_name").clone_for_update();

        // Both the declaration and the usage carry the same annotation.
        let placeholder_snippet = SyntaxAnnotation::default();
        editor.add_annotation(name.syntax(), placeholder_snippet);
        editor.add_annotation(name_ref.syntax(), placeholder_snippet);

        let binding = factory.let_stmt(
            factory.ident_pat(false, false, name.clone()).into(),
            None,
            Some(to_replace.clone().into()),
        );
        let new_block = factory.block_expr([binding.into()], Some(to_wrap.clone().into()));

        editor.replace(to_replace.syntax(), name_ref.syntax());
        editor.replace(to_wrap.syntax(), new_block.syntax());
        editor.add_mappings(factory.finish_with_mappings());

        let edit = editor.finish();

        let expected = expect![[r#"
            _ => {
                let var_name = 2 + 2;
                (var_name, true)
            },"#]];
        expected.assert_eq(&edit.new_root().to_string());

        assert_eq!(edit.find_annotation(placeholder_snippet).len(), 2);
        // Every annotated element must live inside the new tree.
        assert!(
            edit.annotations
                .values()
                .flatten()
                .all(|element| element.ancestors().any(|it| &it == edit.new_root()))
        );
    }

    /// Two insertions that do not depend on each other.
    #[test]
    fn test_insert_independent() {
        let second_stmt = make::let_stmt(
            make::ext::simple_ident_pat(make::name("second")).into(),
            None,
            Some(make::expr_literal("2").into()),
        );
        let root = make::block_expr([second_stmt.into()], None);

        let second_let = root.syntax().descendants().find_map(ast::LetStmt::cast).unwrap();

        let mut editor = SyntaxEditor::new(root.syntax().clone());
        let factory = SyntaxFactory::without_mappings();

        let first_let = factory.let_stmt(
            make::ext::simple_ident_pat(make::name("first")).into(),
            None,
            Some(make::expr_literal("1").into()),
        );
        editor.insert(
            Position::first_child_of(root.stmt_list().unwrap().syntax()),
            first_let.syntax(),
        );

        let third_let = factory.let_stmt(
            make::ext::simple_ident_pat(make::name("third")).into(),
            None,
            Some(make::expr_literal("3").into()),
        );
        editor.insert(Position::after(second_let.syntax()), third_let.syntax());

        let edit = editor.finish();

        let expected = expect![[r#"
            let first = 1;{
                let second = 2;let third = 3;
            }"#]];
        expected.assert_eq(&edit.new_root().to_string());
    }

    /// Insertions into a block that is itself replaced by a wrapper containing it.
    #[test]
    fn test_insert_dependent() {
        let second_stmt = make::let_stmt(
            make::ext::simple_ident_pat(make::name("second")).into(),
            None,
            Some(make::expr_literal("2").into()),
        );
        let nested = make::block_expr([second_stmt.into()], None);
        let root = make::block_expr([], Some(nested.into()));

        // `nth(1)` skips the outer block and picks the nested one.
        let inner_block =
            root.syntax().descendants().filter_map(ast::BlockExpr::cast).nth(1).unwrap();
        let second_let = root.syntax().descendants().find_map(ast::LetStmt::cast).unwrap();

        let mut editor = SyntaxEditor::new(root.syntax().clone());
        let factory = SyntaxFactory::with_mappings();

        let new_block_expr =
            factory.block_expr([], Some(ast::Expr::BlockExpr(inner_block.clone())));

        let first_let = factory.let_stmt(
            make::ext::simple_ident_pat(make::name("first")).into(),
            None,
            Some(make::expr_literal("1").into()),
        );

        let third_let = factory.let_stmt(
            make::ext::simple_ident_pat(make::name("third")).into(),
            None,
            Some(make::expr_literal("3").into()),
        );

        let block_start = Position::first_child_of(inner_block.stmt_list().unwrap().syntax());
        editor.insert(block_start, first_let.syntax());
        editor.insert(Position::after(second_let.syntax()), third_let.syntax());
        editor.replace(inner_block.syntax(), new_block_expr.syntax());
        editor.add_mappings(factory.finish_with_mappings());

        let edit = editor.finish();

        let expected = expect![[r#"
            {
                {
                let first = 1;{
                let second = 2;let third = 3;
            }
            }
            }"#]];
        expected.assert_eq(&edit.new_root().to_string());
    }

    /// Same as above, but the replaced block is the root of the edited tree.
    #[test]
    fn test_replace_root_with_dependent() {
        let second_stmt = make::let_stmt(
            make::ext::simple_ident_pat(make::name("second")).into(),
            None,
            Some(make::expr_literal("2").into()),
        );
        let root = make::block_expr([second_stmt.into()], None);

        let inner_block = root.clone();

        let mut editor = SyntaxEditor::new(root.syntax().clone());
        let factory = SyntaxFactory::with_mappings();

        let new_block_expr =
            factory.block_expr([], Some(ast::Expr::BlockExpr(inner_block.clone())));

        let first_let = factory.let_stmt(
            make::ext::simple_ident_pat(make::name("first")).into(),
            None,
            Some(make::expr_literal("1").into()),
        );

        let block_start = Position::first_child_of(inner_block.stmt_list().unwrap().syntax());
        editor.insert(block_start, first_let.syntax());
        editor.replace(inner_block.syntax(), new_block_expr.syntax());
        editor.add_mappings(factory.finish_with_mappings());

        let edit = editor.finish();

        let expected = expect![[r#"
            {
                let first = 1;{
                let second = 2;
            }
            }"#]];
        expected.assert_eq(&edit.new_root().to_string());
    }

    /// Deletes the return type (plus the trivia token following it) and the tail expression.
    #[test]
    fn test_replace_token_in_parent() {
        let parent_fn = make::fn_(
            None,
            None,
            make::name("it"),
            None,
            None,
            make::param_list(None, []),
            make::block_expr([], Some(make::ext::expr_unit())),
            Some(make::ret_type(make::ty_unit())),
            false,
            false,
            false,
            false,
        );

        let mut editor = SyntaxEditor::new(parent_fn.syntax().clone());

        if let Some(ret_ty) = parent_fn.ret_type() {
            editor.delete(ret_ty.syntax().clone());

            // Also drop the token right after the return type when it is trivia.
            match ret_ty.syntax().next_sibling_or_token() {
                Some(SyntaxElement::Token(token)) if token.kind().is_trivia() => {
                    editor.delete(token);
                }
                _ => {}
            }
        }

        if let Some(tail) = parent_fn.body().unwrap().tail_expr() {
            editor.delete(tail.syntax().clone());
        }

        let edit = editor.finish();

        let expected = expect![["fn it() {\n    \n}"]];
        expected.assert_eq(&edit.new_root().to_string());
    }
}
656}