ra_ap_syntax/
syntax_editor.rs

1//! Syntax Tree editor
2//!
3//! Inspired by Roslyn's [`SyntaxEditor`], but is temporarily built upon mutable syntax tree editing.
4//!
5//! [`SyntaxEditor`]: https://github.com/dotnet/roslyn/blob/43b0b05cc4f492fd5de00f6f6717409091df8daa/src/Workspaces/Core/Portable/Editing/SyntaxEditor.cs
6
7use std::{
8    fmt,
9    num::NonZeroU32,
10    ops::RangeInclusive,
11    sync::atomic::{AtomicU32, Ordering},
12};
13
14use rowan::TextRange;
15use rustc_hash::FxHashMap;
16
17use crate::{SyntaxElement, SyntaxNode, SyntaxToken};
18
19mod edit_algo;
20mod edits;
21mod mapping;
22
23pub use edits::Removable;
24pub use mapping::{SyntaxMapping, SyntaxMappingBuilder};
25
26#[derive(Debug)]
27pub struct SyntaxEditor {
28    root: SyntaxNode,
29    changes: Vec<Change>,
30    mappings: SyntaxMapping,
31    annotations: Vec<(SyntaxElement, SyntaxAnnotation)>,
32}
33
34impl SyntaxEditor {
35    /// Creates a syntax editor to start editing from `root`
36    pub fn new(root: SyntaxNode) -> Self {
37        Self { root, changes: vec![], mappings: SyntaxMapping::default(), annotations: vec![] }
38    }
39
40    pub fn add_annotation(&mut self, element: impl Element, annotation: SyntaxAnnotation) {
41        self.annotations.push((element.syntax_element(), annotation))
42    }
43
44    pub fn merge(&mut self, mut other: SyntaxEditor) {
45        debug_assert!(
46            self.root == other.root || other.root.ancestors().any(|node| node == self.root),
47            "{:?} is not in the same tree as {:?}",
48            other.root,
49            self.root
50        );
51
52        self.changes.append(&mut other.changes);
53        self.mappings.merge(other.mappings);
54        self.annotations.append(&mut other.annotations);
55    }
56
57    pub fn insert(&mut self, position: Position, element: impl Element) {
58        debug_assert!(is_ancestor_or_self(&position.parent(), &self.root));
59        self.changes.push(Change::Insert(position, element.syntax_element()))
60    }
61
62    pub fn insert_all(&mut self, position: Position, elements: Vec<SyntaxElement>) {
63        debug_assert!(is_ancestor_or_self(&position.parent(), &self.root));
64        self.changes.push(Change::InsertAll(position, elements))
65    }
66
67    pub fn delete(&mut self, element: impl Element) {
68        let element = element.syntax_element();
69        debug_assert!(is_ancestor_or_self_of_element(&element, &self.root));
70        debug_assert!(
71            !matches!(&element, SyntaxElement::Node(node) if node == &self.root),
72            "should not delete root node"
73        );
74        self.changes.push(Change::Replace(element.syntax_element(), None));
75    }
76
77    pub fn replace(&mut self, old: impl Element, new: impl Element) {
78        let old = old.syntax_element();
79        debug_assert!(is_ancestor_or_self_of_element(&old, &self.root));
80        self.changes.push(Change::Replace(old.syntax_element(), Some(new.syntax_element())));
81    }
82
83    pub fn replace_with_many(&mut self, old: impl Element, new: Vec<SyntaxElement>) {
84        let old = old.syntax_element();
85        debug_assert!(is_ancestor_or_self_of_element(&old, &self.root));
86        debug_assert!(
87            !(matches!(&old, SyntaxElement::Node(node) if node == &self.root) && new.len() > 1),
88            "cannot replace root node with many elements"
89        );
90        self.changes.push(Change::ReplaceWithMany(old.syntax_element(), new));
91    }
92
93    pub fn replace_all(&mut self, range: RangeInclusive<SyntaxElement>, new: Vec<SyntaxElement>) {
94        if range.start() == range.end() {
95            self.replace_with_many(range.start(), new);
96            return;
97        }
98
99        debug_assert!(is_ancestor_or_self_of_element(range.start(), &self.root));
100        self.changes.push(Change::ReplaceAll(range, new))
101    }
102
103    pub fn finish(self) -> SyntaxEdit {
104        edit_algo::apply_edits(self)
105    }
106
107    pub fn add_mappings(&mut self, other: SyntaxMapping) {
108        self.mappings.merge(other);
109    }
110}
111
112/// Represents a completed [`SyntaxEditor`] operation.
113pub struct SyntaxEdit {
114    old_root: SyntaxNode,
115    new_root: SyntaxNode,
116    changed_elements: Vec<SyntaxElement>,
117    annotations: FxHashMap<SyntaxAnnotation, Vec<SyntaxElement>>,
118}
119
120impl SyntaxEdit {
121    /// Root of the initial unmodified syntax tree.
122    pub fn old_root(&self) -> &SyntaxNode {
123        &self.old_root
124    }
125
126    /// Root of the modified syntax tree.
127    pub fn new_root(&self) -> &SyntaxNode {
128        &self.new_root
129    }
130
131    /// Which syntax elements in the modified syntax tree were inserted or
132    /// modified as part of the edit.
133    ///
134    /// Note that for syntax nodes, only the upper-most parent of a set of
135    /// changes is included, not any child elements that may have been modified.
136    pub fn changed_elements(&self) -> &[SyntaxElement] {
137        self.changed_elements.as_slice()
138    }
139
140    /// Finds which syntax elements have been annotated with the given
141    /// annotation.
142    ///
143    /// Note that an annotation might not appear in the modified syntax tree if
144    /// the syntax elements that were annotated did not make it into the final
145    /// syntax tree.
146    pub fn find_annotation(&self, annotation: SyntaxAnnotation) -> &[SyntaxElement] {
147        self.annotations.get(&annotation).as_ref().map_or(&[], |it| it.as_slice())
148    }
149}
150
/// An opaque tag that can be attached to syntax elements with
/// [`SyntaxEditor::add_annotation`] and looked up after editing with
/// [`SyntaxEdit::find_annotation`].
///
/// Every call to [`SyntaxAnnotation::default`] yields a new, distinct annotation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct SyntaxAnnotation(NonZeroU32);

impl Default for SyntaxAnnotation {
    /// Allocates a fresh annotation with a process-wide unique id.
    ///
    /// # Panics
    ///
    /// Panics once the `u32` id space has been exhausted.
    fn default() -> Self {
        static COUNTER: AtomicU32 = AtomicU32::new(1);

        // `Relaxed` is enough: the read-modify-write is atomic, so every caller observes a
        // distinct value, and no other memory is published through this counter.
        //
        // `checked_add` makes the counter stop at its maximum instead of wrapping to 0.
        // With a plain `fetch_add`, a caught overflow panic would be followed by ids
        // 1, 2, ... being handed out again, silently aliasing live annotations.
        let id = COUNTER
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |id| id.checked_add(1))
            .expect("syntax annotation id overflow");

        // The counter starts at 1 and never wraps, so `id` is always non-zero.
        Self(NonZeroU32::new(id).expect("syntax annotation id overflow"))
    }
}
165
166/// Position describing where to insert elements
167#[derive(Debug)]
168pub struct Position {
169    repr: PositionRepr,
170}
171
172impl Position {
173    pub(crate) fn parent(&self) -> SyntaxNode {
174        self.place().0
175    }
176
177    pub(crate) fn place(&self) -> (SyntaxNode, usize) {
178        match &self.repr {
179            PositionRepr::FirstChild(parent) => (parent.clone(), 0),
180            PositionRepr::After(child) => (child.parent().unwrap(), child.index() + 1),
181        }
182    }
183}
184
185#[derive(Debug)]
186enum PositionRepr {
187    FirstChild(SyntaxNode),
188    After(SyntaxElement),
189}
190
191impl Position {
192    pub fn after(elem: impl Element) -> Position {
193        let repr = PositionRepr::After(elem.syntax_element());
194        Position { repr }
195    }
196
197    pub fn before(elem: impl Element) -> Position {
198        let elem = elem.syntax_element();
199        let repr = match elem.prev_sibling_or_token() {
200            Some(it) => PositionRepr::After(it),
201            None => PositionRepr::FirstChild(elem.parent().unwrap()),
202        };
203        Position { repr }
204    }
205
206    pub fn first_child_of(node: &(impl Into<SyntaxNode> + Clone)) -> Position {
207        let repr = PositionRepr::FirstChild(node.clone().into());
208        Position { repr }
209    }
210
211    pub fn last_child_of(node: &(impl Into<SyntaxNode> + Clone)) -> Position {
212        let node = node.clone().into();
213        let repr = match node.last_child_or_token() {
214            Some(it) => PositionRepr::After(it),
215            None => PositionRepr::FirstChild(node),
216        };
217        Position { repr }
218    }
219}
220
221#[derive(Debug)]
222enum Change {
223    /// Inserts a single element at the specified position.
224    Insert(Position, SyntaxElement),
225    /// Inserts many elements in-order at the specified position.
226    InsertAll(Position, Vec<SyntaxElement>),
227    /// Represents both a replace single element and a delete element operation.
228    Replace(SyntaxElement, Option<SyntaxElement>),
229    /// Replaces a single element with many elements.
230    ReplaceWithMany(SyntaxElement, Vec<SyntaxElement>),
231    /// Replaces a range of elements with another list of elements.
232    /// Range will always have start != end.
233    ReplaceAll(RangeInclusive<SyntaxElement>, Vec<SyntaxElement>),
234}
235
236impl Change {
237    fn target_range(&self) -> TextRange {
238        match self {
239            Change::Insert(target, _) | Change::InsertAll(target, _) => match &target.repr {
240                PositionRepr::FirstChild(parent) => TextRange::at(
241                    parent.first_child_or_token().unwrap().text_range().start(),
242                    0.into(),
243                ),
244                PositionRepr::After(child) => TextRange::at(child.text_range().end(), 0.into()),
245            },
246            Change::Replace(target, _) | Change::ReplaceWithMany(target, _) => target.text_range(),
247            Change::ReplaceAll(range, _) => {
248                range.start().text_range().cover(range.end().text_range())
249            }
250        }
251    }
252
253    fn target_parent(&self) -> SyntaxNode {
254        match self {
255            Change::Insert(target, _) | Change::InsertAll(target, _) => target.parent(),
256            Change::Replace(target, _) | Change::ReplaceWithMany(target, _) => match target {
257                SyntaxElement::Node(target) => target.parent().unwrap_or_else(|| target.clone()),
258                SyntaxElement::Token(target) => target.parent().unwrap(),
259            },
260            Change::ReplaceAll(target, _) => target.start().parent().unwrap(),
261        }
262    }
263
264    fn change_kind(&self) -> ChangeKind {
265        match self {
266            Change::Insert(_, _) | Change::InsertAll(_, _) => ChangeKind::Insert,
267            Change::Replace(_, _) | Change::ReplaceWithMany(_, _) => ChangeKind::Replace,
268            Change::ReplaceAll(_, _) => ChangeKind::ReplaceRange,
269        }
270    }
271}
272
/// Coarse classification of a [`Change`], as returned by `Change::change_kind`.
///
/// NOTE(review): the declaration order is significant. `PartialOrd`/`Ord` are derived,
/// so comparisons follow it (`Insert < ReplaceRange < Replace`). Presumably the edit
/// algorithm relies on this when ordering changes — confirm in `edit_algo` before
/// reordering or adding variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum ChangeKind {
    Insert,
    ReplaceRange,
    Replace,
}
279
280impl fmt::Display for Change {
281    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
282        match self {
283            Change::Insert(position, node_or_token) => {
284                let parent = position.parent();
285                let mut parent_str = parent.to_string();
286                let target_range = self.target_range().start() - parent.text_range().start();
287
288                parent_str.insert_str(
289                    target_range.into(),
290                    &format!("\x1b[42m{node_or_token}\x1b[0m\x1b[K"),
291                );
292                f.write_str(&parent_str)
293            }
294            Change::InsertAll(position, vec) => {
295                let parent = position.parent();
296                let mut parent_str = parent.to_string();
297                let target_range = self.target_range().start() - parent.text_range().start();
298                let insertion: String = vec.iter().map(|it| it.to_string()).collect();
299
300                parent_str
301                    .insert_str(target_range.into(), &format!("\x1b[42m{insertion}\x1b[0m\x1b[K"));
302                f.write_str(&parent_str)
303            }
304            Change::Replace(old, new) => {
305                if let Some(new) = new {
306                    write!(f, "\x1b[41m{old}\x1b[42m{new}\x1b[0m\x1b[K")
307                } else {
308                    write!(f, "\x1b[41m{old}\x1b[0m\x1b[K")
309                }
310            }
311            Change::ReplaceWithMany(old, vec) => {
312                let new: String = vec.iter().map(|it| it.to_string()).collect();
313                write!(f, "\x1b[41m{old}\x1b[42m{new}\x1b[0m\x1b[K")
314            }
315            Change::ReplaceAll(range, vec) => {
316                let parent = range.start().parent().unwrap();
317                let parent_str = parent.to_string();
318                let pre_range =
319                    TextRange::new(parent.text_range().start(), range.start().text_range().start());
320                let old_range = TextRange::new(
321                    range.start().text_range().start(),
322                    range.end().text_range().end(),
323                );
324                let post_range =
325                    TextRange::new(range.end().text_range().end(), parent.text_range().end());
326
327                let pre_str = &parent_str[pre_range - parent.text_range().start()];
328                let old_str = &parent_str[old_range - parent.text_range().start()];
329                let post_str = &parent_str[post_range - parent.text_range().start()];
330                let new: String = vec.iter().map(|it| it.to_string()).collect();
331
332                write!(f, "{pre_str}\x1b[41m{old_str}\x1b[42m{new}\x1b[0m\x1b[K{post_str}")
333            }
334        }
335    }
336}
337
338/// Utility trait to allow calling syntax editor functions with references or owned
339/// nodes. Do not use outside of this module.
340pub trait Element {
341    fn syntax_element(self) -> SyntaxElement;
342}
343
344impl<E: Element + Clone> Element for &'_ E {
345    fn syntax_element(self) -> SyntaxElement {
346        self.clone().syntax_element()
347    }
348}
349
350impl Element for SyntaxElement {
351    fn syntax_element(self) -> SyntaxElement {
352        self
353    }
354}
355
356impl Element for SyntaxNode {
357    fn syntax_element(self) -> SyntaxElement {
358        self.into()
359    }
360}
361
362impl Element for SyntaxToken {
363    fn syntax_element(self) -> SyntaxElement {
364        self.into()
365    }
366}
367
368fn is_ancestor_or_self(node: &SyntaxNode, ancestor: &SyntaxNode) -> bool {
369    node == ancestor || node.ancestors().any(|it| &it == ancestor)
370}
371
372fn is_ancestor_or_self_of_element(node: &SyntaxElement, ancestor: &SyntaxNode) -> bool {
373    matches!(node, SyntaxElement::Node(node) if node == ancestor)
374        || node.ancestors().any(|it| &it == ancestor)
375}
376
// Tests for `SyntaxEditor`: each one builds a small tree with `make`, records edits,
// and compares the printed result of `finish()` against an `expect_test` snapshot.
#[cfg(test)]
mod tests {
    use expect_test::expect;

    use crate::{
        AstNode,
        ast::{self, make, syntax_factory::SyntaxFactory},
    };

    use super::*;

    // Wraps the tuple `(2 + 2, true)` in a block that first binds `2 + 2` to `var_name`.
    // The `2 + 2` inside the tuple is replaced by a reference to the binding. The test
    // also checks that both annotated elements (the new name and name ref) are reported
    // and live in the edited tree.
    #[test]
    fn basic_usage() {
        let root = make::match_arm(
            make::wildcard_pat().into(),
            None,
            make::expr_tuple([
                make::expr_bin_op(
                    make::expr_literal("2").into(),
                    ast::BinaryOp::ArithOp(ast::ArithOp::Add),
                    make::expr_literal("2").into(),
                ),
                make::expr_literal("true").into(),
            ])
            .into(),
        );

        let to_wrap = root.syntax().descendants().find_map(ast::TupleExpr::cast).unwrap();
        let to_replace = root.syntax().descendants().find_map(ast::BinExpr::cast).unwrap();

        let mut editor = SyntaxEditor::new(root.syntax().clone());
        let make = SyntaxFactory::with_mappings();

        let name = make::name("var_name");
        let name_ref = make::name_ref("var_name").clone_for_update();

        // One annotation shared by both new elements; both should be found afterwards.
        let placeholder_snippet = SyntaxAnnotation::default();
        editor.add_annotation(name.syntax(), placeholder_snippet);
        editor.add_annotation(name_ref.syntax(), placeholder_snippet);

        let new_block = make.block_expr(
            [make
                .let_stmt(
                    make.ident_pat(false, false, name.clone()).into(),
                    None,
                    Some(to_replace.clone().into()),
                )
                .into()],
            Some(to_wrap.clone().into()),
        );

        // `to_replace` sits inside `to_wrap`, which is itself being replaced: the inner
        // replacement must still show up in the moved tuple.
        editor.replace(to_replace.syntax(), name_ref.syntax());
        editor.replace(to_wrap.syntax(), new_block.syntax());
        editor.add_mappings(make.finish_with_mappings());

        let edit = editor.finish();

        let expect = expect![[r#"
            _ => {
                let var_name = 2 + 2;
                (var_name, true)
            },"#]];
        expect.assert_eq(&edit.new_root.to_string());

        assert_eq!(edit.find_annotation(placeholder_snippet).len(), 2);
        // Every annotated element must belong to the new tree, not a detached copy.
        assert!(
            edit.annotations
                .iter()
                .flat_map(|(_, elements)| elements)
                .all(|element| element.ancestors().any(|it| &it == edit.new_root()))
        )
    }

    // Two insertions at unrelated positions. As the snapshot shows,
    // `first_child_of(stmt_list)` puts the statement before the stmt list's first token
    // (its `{`), and the editor adds no whitespace around inserted elements.
    #[test]
    fn test_insert_independent() {
        let root = make::block_expr(
            [make::let_stmt(
                make::ext::simple_ident_pat(make::name("second")).into(),
                None,
                Some(make::expr_literal("2").into()),
            )
            .into()],
            None,
        );

        let second_let = root.syntax().descendants().find_map(ast::LetStmt::cast).unwrap();

        let mut editor = SyntaxEditor::new(root.syntax().clone());
        let make = SyntaxFactory::without_mappings();

        editor.insert(
            Position::first_child_of(root.stmt_list().unwrap().syntax()),
            make.let_stmt(
                make::ext::simple_ident_pat(make::name("first")).into(),
                None,
                Some(make::expr_literal("1").into()),
            )
            .syntax(),
        );

        editor.insert(
            Position::after(second_let.syntax()),
            make.let_stmt(
                make::ext::simple_ident_pat(make::name("third")).into(),
                None,
                Some(make::expr_literal("3").into()),
            )
            .syntax(),
        );

        let edit = editor.finish();

        let expect = expect![[r#"
            let first = 1;{
                let second = 2;let third = 3;
            }"#]];
        expect.assert_eq(&edit.new_root.to_string());
    }

    // Insertions into `inner_block`, which is itself moved into a freshly made block by a
    // `replace`. The snapshot checks that both insertions end up in the moved copy.
    // Presumably the `SyntaxFactory::with_mappings` mappings are what let the editor
    // follow `inner_block` to its new location — confirm in `edit_algo`.
    #[test]
    fn test_insert_dependent() {
        let root = make::block_expr(
            [],
            Some(
                make::block_expr(
                    [make::let_stmt(
                        make::ext::simple_ident_pat(make::name("second")).into(),
                        None,
                        Some(make::expr_literal("2").into()),
                    )
                    .into()],
                    None,
                )
                .into(),
            ),
        );

        // `nth(1)`: the first `BlockExpr` among descendants is `root` itself.
        let inner_block =
            root.syntax().descendants().flat_map(ast::BlockExpr::cast).nth(1).unwrap();
        let second_let = root.syntax().descendants().find_map(ast::LetStmt::cast).unwrap();

        let mut editor = SyntaxEditor::new(root.syntax().clone());
        let make = SyntaxFactory::with_mappings();

        let new_block_expr = make.block_expr([], Some(ast::Expr::BlockExpr(inner_block.clone())));

        let first_let = make.let_stmt(
            make::ext::simple_ident_pat(make::name("first")).into(),
            None,
            Some(make::expr_literal("1").into()),
        );

        let third_let = make.let_stmt(
            make::ext::simple_ident_pat(make::name("third")).into(),
            None,
            Some(make::expr_literal("3").into()),
        );

        editor.insert(
            Position::first_child_of(inner_block.stmt_list().unwrap().syntax()),
            first_let.syntax(),
        );
        editor.insert(Position::after(second_let.syntax()), third_let.syntax());
        editor.replace(inner_block.syntax(), new_block_expr.syntax());
        editor.add_mappings(make.finish_with_mappings());

        let edit = editor.finish();

        let expect = expect![[r#"
            {
                {
                let first = 1;{
                let second = 2;let third = 3;
            }
            }
            }"#]];
        expect.assert_eq(&edit.new_root.to_string());
    }

    // Like `test_insert_dependent`, but the block being wrapped is the editor's root
    // itself, so the root is replaced while also receiving an insertion.
    #[test]
    fn test_replace_root_with_dependent() {
        let root = make::block_expr(
            [make::let_stmt(
                make::ext::simple_ident_pat(make::name("second")).into(),
                None,
                Some(make::expr_literal("2").into()),
            )
            .into()],
            None,
        );

        let inner_block = root.clone();

        let mut editor = SyntaxEditor::new(root.syntax().clone());
        let make = SyntaxFactory::with_mappings();

        let new_block_expr = make.block_expr([], Some(ast::Expr::BlockExpr(inner_block.clone())));

        let first_let = make.let_stmt(
            make::ext::simple_ident_pat(make::name("first")).into(),
            None,
            Some(make::expr_literal("1").into()),
        );

        editor.insert(
            Position::first_child_of(inner_block.stmt_list().unwrap().syntax()),
            first_let.syntax(),
        );
        editor.replace(inner_block.syntax(), new_block_expr.syntax());
        editor.add_mappings(make.finish_with_mappings());

        let edit = editor.finish();

        let expect = expect![[r#"
            {
                let first = 1;{
                let second = 2;
            }
            }"#]];
        expect.assert_eq(&edit.new_root.to_string());
    }

    // Deletes the return type together with the whitespace token that follows it, and
    // the body's tail expression, leaving `fn it() {\n    \n}`.
    #[test]
    fn test_replace_token_in_parent() {
        let parent_fn = make::fn_(
            None,
            make::name("it"),
            None,
            None,
            make::param_list(None, []),
            make::block_expr([], Some(make::ext::expr_unit())),
            Some(make::ret_type(make::ty_unit())),
            false,
            false,
            false,
            false,
        );

        let mut editor = SyntaxEditor::new(parent_fn.syntax().clone());

        if let Some(ret_ty) = parent_fn.ret_type() {
            editor.delete(ret_ty.syntax().clone());

            // Also drop the trivia between the return type and the body.
            if let Some(SyntaxElement::Token(token)) = ret_ty.syntax().next_sibling_or_token() {
                if token.kind().is_trivia() {
                    editor.delete(token);
                }
            }
        }

        if let Some(tail) = parent_fn.body().unwrap().tail_expr() {
            editor.delete(tail.syntax().clone());
        }

        let edit = editor.finish();

        let expect = expect![["fn it() {\n    \n}"]];
        expect.assert_eq(&edit.new_root.to_string());
    }
}