1use std::{
8 fmt, iter,
9 num::NonZeroU32,
10 ops::RangeInclusive,
11 sync::atomic::{AtomicU32, Ordering},
12};
13
14use rowan::TextRange;
15use rustc_hash::FxHashMap;
16
17use crate::{SyntaxElement, SyntaxNode, SyntaxToken};
18
19mod edit_algo;
20mod edits;
21mod mapping;
22
23pub use edits::Removable;
24pub use mapping::{SyntaxMapping, SyntaxMappingBuilder};
25
26#[derive(Debug)]
27pub struct SyntaxEditor {
28 root: SyntaxNode,
29 changes: Vec<Change>,
30 mappings: SyntaxMapping,
31 annotations: Vec<(SyntaxElement, SyntaxAnnotation)>,
32}
33
34impl SyntaxEditor {
35 pub fn new(root: SyntaxNode) -> Self {
37 Self { root, changes: vec![], mappings: SyntaxMapping::default(), annotations: vec![] }
38 }
39
40 pub fn add_annotation(&mut self, element: impl Element, annotation: SyntaxAnnotation) {
41 self.annotations.push((element.syntax_element(), annotation))
42 }
43
44 pub fn add_annotation_all(
45 &mut self,
46 elements: Vec<impl Element>,
47 annotation: SyntaxAnnotation,
48 ) {
49 self.annotations
50 .extend(elements.into_iter().map(|e| e.syntax_element()).zip(iter::repeat(annotation)));
51 }
52
53 pub fn merge(&mut self, mut other: SyntaxEditor) {
54 debug_assert!(
55 self.root == other.root || other.root.ancestors().any(|node| node == self.root),
56 "{:?} is not in the same tree as {:?}",
57 other.root,
58 self.root
59 );
60
61 self.changes.append(&mut other.changes);
62 self.mappings.merge(other.mappings);
63 self.annotations.append(&mut other.annotations);
64 }
65
66 pub fn insert(&mut self, position: Position, element: impl Element) {
67 debug_assert!(is_ancestor_or_self(&position.parent(), &self.root));
68 self.changes.push(Change::Insert(position, element.syntax_element()))
69 }
70
71 pub fn insert_all(&mut self, position: Position, elements: Vec<SyntaxElement>) {
72 debug_assert!(is_ancestor_or_self(&position.parent(), &self.root));
73 self.changes.push(Change::InsertAll(position, elements))
74 }
75
76 pub fn delete(&mut self, element: impl Element) {
77 let element = element.syntax_element();
78 debug_assert!(is_ancestor_or_self_of_element(&element, &self.root));
79 debug_assert!(
80 !matches!(&element, SyntaxElement::Node(node) if node == &self.root),
81 "should not delete root node"
82 );
83 self.changes.push(Change::Replace(element.syntax_element(), None));
84 }
85
86 pub fn delete_all(&mut self, range: RangeInclusive<SyntaxElement>) {
87 if range.start() == range.end() {
88 self.delete(range.start());
89 return;
90 }
91
92 debug_assert!(is_ancestor_or_self_of_element(range.start(), &self.root));
93 self.changes.push(Change::ReplaceAll(range, Vec::new()))
94 }
95
96 pub fn replace(&mut self, old: impl Element, new: impl Element) {
97 let old = old.syntax_element();
98 debug_assert!(is_ancestor_or_self_of_element(&old, &self.root));
99 self.changes.push(Change::Replace(old.syntax_element(), Some(new.syntax_element())));
100 }
101
102 pub fn replace_with_many(&mut self, old: impl Element, new: Vec<SyntaxElement>) {
103 let old = old.syntax_element();
104 debug_assert!(is_ancestor_or_self_of_element(&old, &self.root));
105 debug_assert!(
106 !(matches!(&old, SyntaxElement::Node(node) if node == &self.root) && new.len() > 1),
107 "cannot replace root node with many elements"
108 );
109 self.changes.push(Change::ReplaceWithMany(old.syntax_element(), new));
110 }
111
112 pub fn replace_all(&mut self, range: RangeInclusive<SyntaxElement>, new: Vec<SyntaxElement>) {
113 if range.start() == range.end() {
114 self.replace_with_many(range.start(), new);
115 return;
116 }
117
118 debug_assert!(is_ancestor_or_self_of_element(range.start(), &self.root));
119 self.changes.push(Change::ReplaceAll(range, new))
120 }
121
122 pub fn finish(self) -> SyntaxEdit {
123 edit_algo::apply_edits(self)
124 }
125
126 pub fn add_mappings(&mut self, other: SyntaxMapping) {
127 self.mappings.merge(other);
128 }
129}
130
131pub struct SyntaxEdit {
133 old_root: SyntaxNode,
134 new_root: SyntaxNode,
135 changed_elements: Vec<SyntaxElement>,
136 annotations: FxHashMap<SyntaxAnnotation, Vec<SyntaxElement>>,
137}
138
139impl SyntaxEdit {
140 pub fn old_root(&self) -> &SyntaxNode {
142 &self.old_root
143 }
144
145 pub fn new_root(&self) -> &SyntaxNode {
147 &self.new_root
148 }
149
150 pub fn changed_elements(&self) -> &[SyntaxElement] {
156 self.changed_elements.as_slice()
157 }
158
159 pub fn find_annotation(&self, annotation: SyntaxAnnotation) -> &[SyntaxElement] {
166 self.annotations.get(&annotation).as_ref().map_or(&[], |it| it.as_slice())
167 }
168}
169
/// An opaque tag used to locate elements again after an edit has been applied.
///
/// Fresh tags come from [`SyntaxAnnotation::default`], which hands out a new id from a
/// global counter on every call. The `NonZeroU32` payload keeps
/// `Option<SyntaxAnnotation>` the same size as the annotation itself.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct SyntaxAnnotation(NonZeroU32);

impl Default for SyntaxAnnotation {
    /// Allocates the next annotation id. Ids start at 1 and increase by one per call.
    ///
    /// # Panics
    ///
    /// Panics with "syntax annotation id overflow" if the 32-bit counter wraps to 0.
    fn default() -> Self {
        static NEXT_ID: AtomicU32 = AtomicU32::new(1);

        let raw = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        match NonZeroU32::new(raw) {
            Some(id) => Self(id),
            None => panic!("syntax annotation id overflow"),
        }
    }
}
184
185#[derive(Debug)]
187pub struct Position {
188 repr: PositionRepr,
189}
190
191impl Position {
192 pub(crate) fn parent(&self) -> SyntaxNode {
193 self.place().0
194 }
195
196 pub(crate) fn place(&self) -> (SyntaxNode, usize) {
197 match &self.repr {
198 PositionRepr::FirstChild(parent) => (parent.clone(), 0),
199 PositionRepr::After(child) => (child.parent().unwrap(), child.index() + 1),
200 }
201 }
202}
203
204#[derive(Debug)]
205enum PositionRepr {
206 FirstChild(SyntaxNode),
207 After(SyntaxElement),
208}
209
210impl Position {
211 pub fn after(elem: impl Element) -> Position {
212 let repr = PositionRepr::After(elem.syntax_element());
213 Position { repr }
214 }
215
216 pub fn before(elem: impl Element) -> Position {
217 let elem = elem.syntax_element();
218 let repr = match elem.prev_sibling_or_token() {
219 Some(it) => PositionRepr::After(it),
220 None => PositionRepr::FirstChild(elem.parent().unwrap()),
221 };
222 Position { repr }
223 }
224
225 pub fn first_child_of(node: &(impl Into<SyntaxNode> + Clone)) -> Position {
226 let repr = PositionRepr::FirstChild(node.clone().into());
227 Position { repr }
228 }
229
230 pub fn last_child_of(node: &(impl Into<SyntaxNode> + Clone)) -> Position {
231 let node = node.clone().into();
232 let repr = match node.last_child_or_token() {
233 Some(it) => PositionRepr::After(it),
234 None => PositionRepr::FirstChild(node),
235 };
236 Position { repr }
237 }
238}
239
240#[derive(Debug)]
241enum Change {
242 Insert(Position, SyntaxElement),
244 InsertAll(Position, Vec<SyntaxElement>),
246 Replace(SyntaxElement, Option<SyntaxElement>),
248 ReplaceWithMany(SyntaxElement, Vec<SyntaxElement>),
250 ReplaceAll(RangeInclusive<SyntaxElement>, Vec<SyntaxElement>),
253}
254
255impl Change {
256 fn target_range(&self) -> TextRange {
257 match self {
258 Change::Insert(target, _) | Change::InsertAll(target, _) => match &target.repr {
259 PositionRepr::FirstChild(parent) => TextRange::at(
260 parent.first_child_or_token().unwrap().text_range().start(),
261 0.into(),
262 ),
263 PositionRepr::After(child) => TextRange::at(child.text_range().end(), 0.into()),
264 },
265 Change::Replace(target, _) | Change::ReplaceWithMany(target, _) => target.text_range(),
266 Change::ReplaceAll(range, _) => {
267 range.start().text_range().cover(range.end().text_range())
268 }
269 }
270 }
271
272 fn target_parent(&self) -> SyntaxNode {
273 match self {
274 Change::Insert(target, _) | Change::InsertAll(target, _) => target.parent(),
275 Change::Replace(target, _) | Change::ReplaceWithMany(target, _) => match target {
276 SyntaxElement::Node(target) => target.parent().unwrap_or_else(|| target.clone()),
277 SyntaxElement::Token(target) => target.parent().unwrap(),
278 },
279 Change::ReplaceAll(target, _) => target.start().parent().unwrap(),
280 }
281 }
282
283 fn change_kind(&self) -> ChangeKind {
284 match self {
285 Change::Insert(_, _) | Change::InsertAll(_, _) => ChangeKind::Insert,
286 Change::Replace(_, _) | Change::ReplaceWithMany(_, _) => ChangeKind::Replace,
287 Change::ReplaceAll(_, _) => ChangeKind::ReplaceRange,
288 }
289 }
290}
291
/// Coarse classification of a [`Change`], as returned by `Change::change_kind`.
///
/// The derived ordering follows declaration order (`Insert < ReplaceRange < Replace`).
/// NOTE(review): keep the variant order stable in case callers sort changes by kind.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash)]
enum ChangeKind {
    /// `Change::Insert` and `Change::InsertAll`.
    Insert,
    /// `Change::ReplaceAll`.
    ReplaceRange,
    /// `Change::Replace` and `Change::ReplaceWithMany`.
    Replace,
}
298
299impl fmt::Display for Change {
300 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
301 match self {
302 Change::Insert(position, node_or_token) => {
303 let parent = position.parent();
304 let mut parent_str = parent.to_string();
305 let target_range = self.target_range().start() - parent.text_range().start();
306
307 parent_str.insert_str(
308 target_range.into(),
309 &format!("\x1b[42m{node_or_token}\x1b[0m\x1b[K"),
310 );
311 f.write_str(&parent_str)
312 }
313 Change::InsertAll(position, vec) => {
314 let parent = position.parent();
315 let mut parent_str = parent.to_string();
316 let target_range = self.target_range().start() - parent.text_range().start();
317 let insertion: String = vec.iter().map(|it| it.to_string()).collect();
318
319 parent_str
320 .insert_str(target_range.into(), &format!("\x1b[42m{insertion}\x1b[0m\x1b[K"));
321 f.write_str(&parent_str)
322 }
323 Change::Replace(old, new) => {
324 if let Some(new) = new {
325 write!(f, "\x1b[41m{old}\x1b[42m{new}\x1b[0m\x1b[K")
326 } else {
327 write!(f, "\x1b[41m{old}\x1b[0m\x1b[K")
328 }
329 }
330 Change::ReplaceWithMany(old, vec) => {
331 let new: String = vec.iter().map(|it| it.to_string()).collect();
332 write!(f, "\x1b[41m{old}\x1b[42m{new}\x1b[0m\x1b[K")
333 }
334 Change::ReplaceAll(range, vec) => {
335 let parent = range.start().parent().unwrap();
336 let parent_str = parent.to_string();
337 let pre_range =
338 TextRange::new(parent.text_range().start(), range.start().text_range().start());
339 let old_range = TextRange::new(
340 range.start().text_range().start(),
341 range.end().text_range().end(),
342 );
343 let post_range =
344 TextRange::new(range.end().text_range().end(), parent.text_range().end());
345
346 let pre_str = &parent_str[pre_range - parent.text_range().start()];
347 let old_str = &parent_str[old_range - parent.text_range().start()];
348 let post_str = &parent_str[post_range - parent.text_range().start()];
349 let new: String = vec.iter().map(|it| it.to_string()).collect();
350
351 write!(f, "{pre_str}\x1b[41m{old_str}\x1b[42m{new}\x1b[0m\x1b[K{post_str}")
352 }
353 }
354 }
355}
356
357pub trait Element {
360 fn syntax_element(self) -> SyntaxElement;
361}
362
363impl<E: Element + Clone> Element for &'_ E {
364 fn syntax_element(self) -> SyntaxElement {
365 self.clone().syntax_element()
366 }
367}
368
369impl Element for SyntaxElement {
370 fn syntax_element(self) -> SyntaxElement {
371 self
372 }
373}
374
375impl Element for SyntaxNode {
376 fn syntax_element(self) -> SyntaxElement {
377 self.into()
378 }
379}
380
381impl Element for SyntaxToken {
382 fn syntax_element(self) -> SyntaxElement {
383 self.into()
384 }
385}
386
387fn is_ancestor_or_self(node: &SyntaxNode, ancestor: &SyntaxNode) -> bool {
388 node == ancestor || node.ancestors().any(|it| &it == ancestor)
389}
390
391fn is_ancestor_or_self_of_element(node: &SyntaxElement, ancestor: &SyntaxNode) -> bool {
392 matches!(node, SyntaxElement::Node(node) if node == ancestor)
393 || node.ancestors().any(|it| &it == ancestor)
394}
395
#[cfg(test)]
mod tests {
    use expect_test::expect;

    use crate::{
        AstNode,
        ast::{self, make, syntax_factory::SyntaxFactory},
    };

    use super::*;

    /// Wraps a tuple expression in a block that first binds one of its fields with `let`.
    /// Also checks that annotations added before `finish` resolve to elements of the new
    /// tree.
    #[test]
    fn basic_usage() {
        let root = make::match_arm(
            make::wildcard_pat().into(),
            None,
            make::expr_tuple([
                make::expr_bin_op(
                    make::expr_literal("2").into(),
                    ast::BinaryOp::ArithOp(ast::ArithOp::Add),
                    make::expr_literal("2").into(),
                ),
                make::expr_literal("true").into(),
            ])
            .into(),
        );

        let to_wrap = root.syntax().descendants().find_map(ast::TupleExpr::cast).unwrap();
        let to_replace = root.syntax().descendants().find_map(ast::BinExpr::cast).unwrap();

        let mut editor = SyntaxEditor::new(root.syntax().clone());
        let make = SyntaxFactory::with_mappings();

        let name = make::name("var_name");
        let name_ref = make::name_ref("var_name").clone_for_update();

        // The binding and its use share a single annotation.
        let placeholder_snippet = SyntaxAnnotation::default();
        editor.add_annotation(name.syntax(), placeholder_snippet);
        editor.add_annotation(name_ref.syntax(), placeholder_snippet);

        // `to_replace` becomes the `let` initializer; `to_wrap` (which contains it) becomes
        // the block's tail expression.
        let new_block = make.block_expr(
            [make
                .let_stmt(
                    make.ident_pat(false, false, name.clone()).into(),
                    None,
                    Some(to_replace.clone().into()),
                )
                .into()],
            Some(to_wrap.clone().into()),
        );

        // Two edits, one nested inside the other's target; the factory's mappings are
        // handed to the editor before `finish`.
        editor.replace(to_replace.syntax(), name_ref.syntax());
        editor.replace(to_wrap.syntax(), new_block.syntax());
        editor.add_mappings(make.finish_with_mappings());

        let edit = editor.finish();

        let expect = expect![[r#"
            _ => {
                let var_name = 2 + 2;
                (var_name, true)
            },"#]];
        expect.assert_eq(&edit.new_root.to_string());

        // Both annotated elements are found, and all of them live in the new tree.
        assert_eq!(edit.find_annotation(placeholder_snippet).len(), 2);
        assert!(
            edit.annotations
                .iter()
                .flat_map(|(_, elements)| elements)
                .all(|element| element.ancestors().any(|it| &it == edit.new_root()))
        )
    }

    /// Two insertions at unrelated positions of the same tree.
    #[test]
    fn test_insert_independent() {
        let root = make::block_expr(
            [make::let_stmt(
                make::ext::simple_ident_pat(make::name("second")).into(),
                None,
                Some(make::expr_literal("2").into()),
            )
            .into()],
            None,
        );

        let second_let = root.syntax().descendants().find_map(ast::LetStmt::cast).unwrap();

        let mut editor = SyntaxEditor::new(root.syntax().clone());
        let make = SyntaxFactory::without_mappings();

        // The first child of the statement list is its `{` token, so this lands before `{`.
        editor.insert(
            Position::first_child_of(root.stmt_list().unwrap().syntax()),
            make.let_stmt(
                make::ext::simple_ident_pat(make::name("first")).into(),
                None,
                Some(make::expr_literal("1").into()),
            )
            .syntax(),
        );

        editor.insert(
            Position::after(second_let.syntax()),
            make.let_stmt(
                make::ext::simple_ident_pat(make::name("third")).into(),
                None,
                Some(make::expr_literal("3").into()),
            )
            .syntax(),
        );

        let edit = editor.finish();

        let expect = expect![[r#"
            let first = 1;{
                let second = 2;let third = 3;
            }"#]];
        expect.assert_eq(&edit.new_root.to_string());
    }

    /// Insertions inside a block that is itself moved into a newly built block; the new
    /// block is created with mappings and replaces the original inner block.
    #[test]
    fn test_insert_dependent() {
        let root = make::block_expr(
            [],
            Some(
                make::block_expr(
                    [make::let_stmt(
                        make::ext::simple_ident_pat(make::name("second")).into(),
                        None,
                        Some(make::expr_literal("2").into()),
                    )
                    .into()],
                    None,
                )
                .into(),
            ),
        );

        // `nth(1)` skips `root` itself, which is the first `BlockExpr` among its descendants.
        let inner_block =
            root.syntax().descendants().flat_map(ast::BlockExpr::cast).nth(1).unwrap();
        let second_let = root.syntax().descendants().find_map(ast::LetStmt::cast).unwrap();

        let mut editor = SyntaxEditor::new(root.syntax().clone());
        let make = SyntaxFactory::with_mappings();

        let new_block_expr = make.block_expr([], Some(ast::Expr::BlockExpr(inner_block.clone())));

        let first_let = make.let_stmt(
            make::ext::simple_ident_pat(make::name("first")).into(),
            None,
            Some(make::expr_literal("1").into()),
        );

        let third_let = make.let_stmt(
            make::ext::simple_ident_pat(make::name("third")).into(),
            None,
            Some(make::expr_literal("3").into()),
        );

        editor.insert(
            Position::first_child_of(inner_block.stmt_list().unwrap().syntax()),
            first_let.syntax(),
        );
        editor.insert(Position::after(second_let.syntax()), third_let.syntax());
        editor.replace(inner_block.syntax(), new_block_expr.syntax());
        editor.add_mappings(make.finish_with_mappings());

        let edit = editor.finish();

        // NOTE(review): relative indentation reconstructed from the `make` builders' output;
        // confirm against `UPDATE_EXPECT=1`.
        let expect = expect![[r#"
            {
                {
                let first = 1;{
                let second = 2;let third = 3;
            }
            }
            }"#]];
        expect.assert_eq(&edit.new_root.to_string());
    }

    /// Replaces the editor's root with a new block containing it, and inserts inside the
    /// old root in the same edit.
    #[test]
    fn test_replace_root_with_dependent() {
        let root = make::block_expr(
            [make::let_stmt(
                make::ext::simple_ident_pat(make::name("second")).into(),
                None,
                Some(make::expr_literal("2").into()),
            )
            .into()],
            None,
        );

        let inner_block = root.clone();

        let mut editor = SyntaxEditor::new(root.syntax().clone());
        let make = SyntaxFactory::with_mappings();

        let new_block_expr = make.block_expr([], Some(ast::Expr::BlockExpr(inner_block.clone())));

        let first_let = make.let_stmt(
            make::ext::simple_ident_pat(make::name("first")).into(),
            None,
            Some(make::expr_literal("1").into()),
        );

        editor.insert(
            Position::first_child_of(inner_block.stmt_list().unwrap().syntax()),
            first_let.syntax(),
        );
        editor.replace(inner_block.syntax(), new_block_expr.syntax());
        editor.add_mappings(make.finish_with_mappings());

        let edit = editor.finish();

        // NOTE(review): relative indentation reconstructed from the `make` builders' output;
        // confirm against `UPDATE_EXPECT=1`.
        let expect = expect![[r#"
            {
                let first = 1;{
                let second = 2;
            }
            }"#]];
        expect.assert_eq(&edit.new_root.to_string());
    }

    /// Deletes a node, then the trivia token following it, then another node, all in the
    /// same parent function.
    #[test]
    fn test_replace_token_in_parent() {
        let parent_fn = make::fn_(
            None,
            None,
            make::name("it"),
            None,
            None,
            make::param_list(None, []),
            make::block_expr([], Some(make::ext::expr_unit())),
            Some(make::ret_type(make::ty_unit())),
            false,
            false,
            false,
            false,
        );

        let mut editor = SyntaxEditor::new(parent_fn.syntax().clone());

        if let Some(ret_ty) = parent_fn.ret_type() {
            editor.delete(ret_ty.syntax().clone());

            // Also drop the whitespace between the return type and the body.
            if let Some(SyntaxElement::Token(token)) = ret_ty.syntax().next_sibling_or_token()
                && token.kind().is_trivia()
            {
                editor.delete(token);
            }
        }

        if let Some(tail) = parent_fn.body().unwrap().tail_expr() {
            editor.delete(tail.syntax().clone());
        }

        let edit = editor.finish();

        let expect = expect![["fn it() {\n    \n}"]];
        expect.assert_eq(&edit.new_root.to_string());
    }
}
656}