1use std::{
8 fmt,
9 num::NonZeroU32,
10 ops::RangeInclusive,
11 sync::atomic::{AtomicU32, Ordering},
12};
13
14use rowan::TextRange;
15use rustc_hash::FxHashMap;
16
17use crate::{SyntaxElement, SyntaxNode, SyntaxToken};
18
19mod edit_algo;
20mod edits;
21mod mapping;
22
23pub use edits::Removable;
24pub use mapping::{SyntaxMapping, SyntaxMappingBuilder};
25
/// Collects a batch of edits against the syntax tree rooted at `root`.
///
/// Edits are only recorded here; [`SyntaxEditor::finish`] hands the whole editor to
/// `edit_algo::apply_edits`, which produces the resulting [`SyntaxEdit`].
#[derive(Debug)]
pub struct SyntaxEditor {
    /// Root of the tree being edited; debug assertions require edit targets to live under it.
    root: SyntaxNode,
    /// Recorded changes, in the order they were requested.
    changes: Vec<Change>,
    /// Mappings accumulated through `add_mappings` and `merge`.
    mappings: SyntaxMapping,
    /// Elements paired with the annotation attached to them.
    annotations: Vec<(SyntaxElement, SyntaxAnnotation)>,
}
33
34impl SyntaxEditor {
35 pub fn new(root: SyntaxNode) -> Self {
37 Self { root, changes: vec![], mappings: SyntaxMapping::default(), annotations: vec![] }
38 }
39
40 pub fn add_annotation(&mut self, element: impl Element, annotation: SyntaxAnnotation) {
41 self.annotations.push((element.syntax_element(), annotation))
42 }
43
44 pub fn merge(&mut self, mut other: SyntaxEditor) {
45 debug_assert!(
46 self.root == other.root || other.root.ancestors().any(|node| node == self.root),
47 "{:?} is not in the same tree as {:?}",
48 other.root,
49 self.root
50 );
51
52 self.changes.append(&mut other.changes);
53 self.mappings.merge(other.mappings);
54 self.annotations.append(&mut other.annotations);
55 }
56
57 pub fn insert(&mut self, position: Position, element: impl Element) {
58 debug_assert!(is_ancestor_or_self(&position.parent(), &self.root));
59 self.changes.push(Change::Insert(position, element.syntax_element()))
60 }
61
62 pub fn insert_all(&mut self, position: Position, elements: Vec<SyntaxElement>) {
63 debug_assert!(is_ancestor_or_self(&position.parent(), &self.root));
64 self.changes.push(Change::InsertAll(position, elements))
65 }
66
67 pub fn delete(&mut self, element: impl Element) {
68 let element = element.syntax_element();
69 debug_assert!(is_ancestor_or_self_of_element(&element, &self.root));
70 debug_assert!(
71 !matches!(&element, SyntaxElement::Node(node) if node == &self.root),
72 "should not delete root node"
73 );
74 self.changes.push(Change::Replace(element.syntax_element(), None));
75 }
76
77 pub fn replace(&mut self, old: impl Element, new: impl Element) {
78 let old = old.syntax_element();
79 debug_assert!(is_ancestor_or_self_of_element(&old, &self.root));
80 self.changes.push(Change::Replace(old.syntax_element(), Some(new.syntax_element())));
81 }
82
83 pub fn replace_with_many(&mut self, old: impl Element, new: Vec<SyntaxElement>) {
84 let old = old.syntax_element();
85 debug_assert!(is_ancestor_or_self_of_element(&old, &self.root));
86 debug_assert!(
87 !(matches!(&old, SyntaxElement::Node(node) if node == &self.root) && new.len() > 1),
88 "cannot replace root node with many elements"
89 );
90 self.changes.push(Change::ReplaceWithMany(old.syntax_element(), new));
91 }
92
93 pub fn replace_all(&mut self, range: RangeInclusive<SyntaxElement>, new: Vec<SyntaxElement>) {
94 if range.start() == range.end() {
95 self.replace_with_many(range.start(), new);
96 return;
97 }
98
99 debug_assert!(is_ancestor_or_self_of_element(range.start(), &self.root));
100 self.changes.push(Change::ReplaceAll(range, new))
101 }
102
103 pub fn finish(self) -> SyntaxEdit {
104 edit_algo::apply_edits(self)
105 }
106
107 pub fn add_mappings(&mut self, other: SyntaxMapping) {
108 self.mappings.merge(other);
109 }
110}
111
/// Result type of [`SyntaxEditor::finish`].
pub struct SyntaxEdit {
    /// Exposed through `old_root()`.
    old_root: SyntaxNode,
    /// Exposed through `new_root()`.
    new_root: SyntaxNode,
    /// Exposed through `changed_elements()`.
    changed_elements: Vec<SyntaxElement>,
    /// Elements grouped by annotation; queried through `find_annotation()`.
    annotations: FxHashMap<SyntaxAnnotation, Vec<SyntaxElement>>,
}
119
120impl SyntaxEdit {
121 pub fn old_root(&self) -> &SyntaxNode {
123 &self.old_root
124 }
125
126 pub fn new_root(&self) -> &SyntaxNode {
128 &self.new_root
129 }
130
131 pub fn changed_elements(&self) -> &[SyntaxElement] {
137 self.changed_elements.as_slice()
138 }
139
140 pub fn find_annotation(&self, annotation: SyntaxAnnotation) -> &[SyntaxElement] {
147 self.annotations.get(&annotation).as_ref().map_or(&[], |it| it.as_slice())
148 }
149}
150
/// Opaque identifier that can be attached to syntax elements.
///
/// Every value produced through [`Default`] carries an id that has not been handed out before
/// in the current process.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct SyntaxAnnotation(NonZeroU32);

impl SyntaxAnnotation {
    /// Takes the next unused id from `counter`.
    ///
    /// `counter` stores the id to hand out next, with `0` meaning "exhausted". Once the counter
    /// reaches that state it is never advanced again, so an id can not be issued twice even if
    /// a caller recovers from the overflow panic.
    ///
    /// # Panics
    ///
    /// Panics when every id up to `u32::MAX` has already been handed out.
    fn next_id(counter: &AtomicU32) -> NonZeroU32 {
        // Only uniqueness matters and no other memory is published through the counter, so the
        // atomic read-modify-write is sufficient with relaxed ordering.
        counter
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                NonZeroU32::new(current).map(|_| current.wrapping_add(1))
            })
            .ok()
            .and_then(NonZeroU32::new)
            .expect("syntax annotation id overflow")
    }
}

impl Default for SyntaxAnnotation {
    /// Creates an annotation with a fresh id.
    fn default() -> Self {
        static COUNTER: AtomicU32 = AtomicU32::new(1);

        Self(Self::next_id(&COUNTER))
    }
}
165
/// A place inside a syntax tree at which elements can be inserted.
///
/// Built with `Position::after`, `Position::before`, `Position::first_child_of` or
/// `Position::last_child_of`.
#[derive(Debug)]
pub struct Position {
    /// How the place is anchored: at the front of a parent, or right after a sibling.
    repr: PositionRepr,
}
171
172impl Position {
173 pub(crate) fn parent(&self) -> SyntaxNode {
174 self.place().0
175 }
176
177 pub(crate) fn place(&self) -> (SyntaxNode, usize) {
178 match &self.repr {
179 PositionRepr::FirstChild(parent) => (parent.clone(), 0),
180 PositionRepr::After(child) => (child.parent().unwrap(), child.index() + 1),
181 }
182 }
183}
184
/// Anchor backing a [`Position`].
#[derive(Debug)]
enum PositionRepr {
    /// Index `0` among the children of the given node.
    FirstChild(SyntaxNode),
    /// Directly after the given element, i.e. at its index plus one within its parent.
    After(SyntaxElement),
}
190
191impl Position {
192 pub fn after(elem: impl Element) -> Position {
193 let repr = PositionRepr::After(elem.syntax_element());
194 Position { repr }
195 }
196
197 pub fn before(elem: impl Element) -> Position {
198 let elem = elem.syntax_element();
199 let repr = match elem.prev_sibling_or_token() {
200 Some(it) => PositionRepr::After(it),
201 None => PositionRepr::FirstChild(elem.parent().unwrap()),
202 };
203 Position { repr }
204 }
205
206 pub fn first_child_of(node: &(impl Into<SyntaxNode> + Clone)) -> Position {
207 let repr = PositionRepr::FirstChild(node.clone().into());
208 Position { repr }
209 }
210
211 pub fn last_child_of(node: &(impl Into<SyntaxNode> + Clone)) -> Position {
212 let node = node.clone().into();
213 let repr = match node.last_child_or_token() {
214 Some(it) => PositionRepr::After(it),
215 None => PositionRepr::FirstChild(node),
216 };
217 Position { repr }
218 }
219}
220
/// A single edit recorded by the editor.
#[derive(Debug)]
enum Change {
    /// Insert one element at the given position.
    Insert(Position, SyntaxElement),
    /// Insert several elements, in order, at the given position.
    InsertAll(Position, Vec<SyntaxElement>),
    /// Replace the element with another one, or remove it when the replacement is `None`.
    Replace(SyntaxElement, Option<SyntaxElement>),
    /// Replace the element with a sequence of elements.
    ReplaceWithMany(SyntaxElement, Vec<SyntaxElement>),
    /// Replace an inclusive range of elements with a sequence of elements.
    ReplaceAll(RangeInclusive<SyntaxElement>, Vec<SyntaxElement>),
}
235
236impl Change {
237 fn target_range(&self) -> TextRange {
238 match self {
239 Change::Insert(target, _) | Change::InsertAll(target, _) => match &target.repr {
240 PositionRepr::FirstChild(parent) => TextRange::at(
241 parent.first_child_or_token().unwrap().text_range().start(),
242 0.into(),
243 ),
244 PositionRepr::After(child) => TextRange::at(child.text_range().end(), 0.into()),
245 },
246 Change::Replace(target, _) | Change::ReplaceWithMany(target, _) => target.text_range(),
247 Change::ReplaceAll(range, _) => {
248 range.start().text_range().cover(range.end().text_range())
249 }
250 }
251 }
252
253 fn target_parent(&self) -> SyntaxNode {
254 match self {
255 Change::Insert(target, _) | Change::InsertAll(target, _) => target.parent(),
256 Change::Replace(target, _) | Change::ReplaceWithMany(target, _) => match target {
257 SyntaxElement::Node(target) => target.parent().unwrap_or_else(|| target.clone()),
258 SyntaxElement::Token(target) => target.parent().unwrap(),
259 },
260 Change::ReplaceAll(target, _) => target.start().parent().unwrap(),
261 }
262 }
263
264 fn change_kind(&self) -> ChangeKind {
265 match self {
266 Change::Insert(_, _) | Change::InsertAll(_, _) => ChangeKind::Insert,
267 Change::Replace(_, _) | Change::ReplaceWithMany(_, _) => ChangeKind::Replace,
268 Change::ReplaceAll(_, _) => ChangeKind::ReplaceRange,
269 }
270 }
271}
272
/// Coarse category of a [`Change`], as reported by `Change::change_kind`.
///
/// `Ord` is derived, so variants compare in declaration order:
/// `Insert < ReplaceRange < Replace`.
// NOTE(review): presumably the edit algorithm relies on this ordering — confirm before
// reordering the variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum ChangeKind {
    /// `Change::Insert` and `Change::InsertAll`.
    Insert,
    /// `Change::ReplaceAll`.
    ReplaceRange,
    /// `Change::Replace` and `Change::ReplaceWithMany`.
    Replace,
}
279
280impl fmt::Display for Change {
281 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
282 match self {
283 Change::Insert(position, node_or_token) => {
284 let parent = position.parent();
285 let mut parent_str = parent.to_string();
286 let target_range = self.target_range().start() - parent.text_range().start();
287
288 parent_str.insert_str(
289 target_range.into(),
290 &format!("\x1b[42m{node_or_token}\x1b[0m\x1b[K"),
291 );
292 f.write_str(&parent_str)
293 }
294 Change::InsertAll(position, vec) => {
295 let parent = position.parent();
296 let mut parent_str = parent.to_string();
297 let target_range = self.target_range().start() - parent.text_range().start();
298 let insertion: String = vec.iter().map(|it| it.to_string()).collect();
299
300 parent_str
301 .insert_str(target_range.into(), &format!("\x1b[42m{insertion}\x1b[0m\x1b[K"));
302 f.write_str(&parent_str)
303 }
304 Change::Replace(old, new) => {
305 if let Some(new) = new {
306 write!(f, "\x1b[41m{old}\x1b[42m{new}\x1b[0m\x1b[K")
307 } else {
308 write!(f, "\x1b[41m{old}\x1b[0m\x1b[K")
309 }
310 }
311 Change::ReplaceWithMany(old, vec) => {
312 let new: String = vec.iter().map(|it| it.to_string()).collect();
313 write!(f, "\x1b[41m{old}\x1b[42m{new}\x1b[0m\x1b[K")
314 }
315 Change::ReplaceAll(range, vec) => {
316 let parent = range.start().parent().unwrap();
317 let parent_str = parent.to_string();
318 let pre_range =
319 TextRange::new(parent.text_range().start(), range.start().text_range().start());
320 let old_range = TextRange::new(
321 range.start().text_range().start(),
322 range.end().text_range().end(),
323 );
324 let post_range =
325 TextRange::new(range.end().text_range().end(), parent.text_range().end());
326
327 let pre_str = &parent_str[pre_range - parent.text_range().start()];
328 let old_str = &parent_str[old_range - parent.text_range().start()];
329 let post_str = &parent_str[post_range - parent.text_range().start()];
330 let new: String = vec.iter().map(|it| it.to_string()).collect();
331
332 write!(f, "{pre_str}\x1b[41m{old_str}\x1b[42m{new}\x1b[0m\x1b[K{post_str}")
333 }
334 }
335 }
336}
337
338pub trait Element {
341 fn syntax_element(self) -> SyntaxElement;
342}
343
344impl<E: Element + Clone> Element for &'_ E {
345 fn syntax_element(self) -> SyntaxElement {
346 self.clone().syntax_element()
347 }
348}
349
350impl Element for SyntaxElement {
351 fn syntax_element(self) -> SyntaxElement {
352 self
353 }
354}
355
356impl Element for SyntaxNode {
357 fn syntax_element(self) -> SyntaxElement {
358 self.into()
359 }
360}
361
362impl Element for SyntaxToken {
363 fn syntax_element(self) -> SyntaxElement {
364 self.into()
365 }
366}
367
368fn is_ancestor_or_self(node: &SyntaxNode, ancestor: &SyntaxNode) -> bool {
369 node == ancestor || node.ancestors().any(|it| &it == ancestor)
370}
371
372fn is_ancestor_or_self_of_element(node: &SyntaxElement, ancestor: &SyntaxNode) -> bool {
373 matches!(node, SyntaxElement::Node(node) if node == ancestor)
374 || node.ancestors().any(|it| &it == ancestor)
375}
376
377#[cfg(test)]
378mod tests {
379 use expect_test::expect;
380
381 use crate::{
382 AstNode,
383 ast::{self, make, syntax_factory::SyntaxFactory},
384 };
385
386 use super::*;
387
    // Replaces the `2 + 2` operand with a name reference and the enclosing tuple with a block
    // that binds that name, then checks the rendered tree and the annotation bookkeeping.
    #[test]
    fn basic_usage() {
        let root = make::match_arm(
            make::wildcard_pat().into(),
            None,
            make::expr_tuple([
                make::expr_bin_op(
                    make::expr_literal("2").into(),
                    ast::BinaryOp::ArithOp(ast::ArithOp::Add),
                    make::expr_literal("2").into(),
                ),
                make::expr_literal("true").into(),
            ])
            .into(),
        );

        let to_wrap = root.syntax().descendants().find_map(ast::TupleExpr::cast).unwrap();
        let to_replace = root.syntax().descendants().find_map(ast::BinExpr::cast).unwrap();

        let mut editor = SyntaxEditor::new(root.syntax().clone());
        let make = SyntaxFactory::with_mappings();

        let name = make::name("var_name");
        let name_ref = make::name_ref("var_name").clone_for_update();

        // Both the binding and its use get the same annotation.
        let placeholder_snippet = SyntaxAnnotation::default();
        editor.add_annotation(name.syntax(), placeholder_snippet);
        editor.add_annotation(name_ref.syntax(), placeholder_snippet);

        // The new block is built from the original nodes that the editor is about to replace.
        let new_block = make.block_expr(
            [make
                .let_stmt(
                    make.ident_pat(false, false, name.clone()).into(),
                    None,
                    Some(to_replace.clone().into()),
                )
                .into()],
            Some(to_wrap.clone().into()),
        );

        editor.replace(to_replace.syntax(), name_ref.syntax());
        editor.replace(to_wrap.syntax(), new_block.syntax());
        editor.add_mappings(make.finish_with_mappings());

        let edit = editor.finish();

        // NOTE(review): the literal below is whitespace-sensitive — confirm its indentation.
        let expect = expect![[r#"
            _ => {
                let var_name = 2 + 2;
                (var_name, true)
            },"#]];
        expect.assert_eq(&edit.new_root.to_string());

        // Both annotated elements must be found, and every annotated element must live in the
        // new tree.
        assert_eq!(edit.find_annotation(placeholder_snippet).len(), 2);
        assert!(
            edit.annotations
                .iter()
                .flat_map(|(_, elements)| elements)
                .all(|element| element.ancestors().any(|it| &it == edit.new_root()))
        )
    }
449
    // Two insertions into the same statement list: one at the front, one after the existing
    // `let`.
    #[test]
    fn test_insert_independent() {
        let root = make::block_expr(
            [make::let_stmt(
                make::ext::simple_ident_pat(make::name("second")).into(),
                None,
                Some(make::expr_literal("2").into()),
            )
            .into()],
            None,
        );

        let second_let = root.syntax().descendants().find_map(ast::LetStmt::cast).unwrap();

        let mut editor = SyntaxEditor::new(root.syntax().clone());
        let make = SyntaxFactory::without_mappings();

        editor.insert(
            Position::first_child_of(root.stmt_list().unwrap().syntax()),
            make.let_stmt(
                make::ext::simple_ident_pat(make::name("first")).into(),
                None,
                Some(make::expr_literal("1").into()),
            )
            .syntax(),
        );

        editor.insert(
            Position::after(second_let.syntax()),
            make.let_stmt(
                make::ext::simple_ident_pat(make::name("third")).into(),
                None,
                Some(make::expr_literal("3").into()),
            )
            .syntax(),
        );

        let edit = editor.finish();

        // NOTE(review): the literal below is whitespace-sensitive — confirm its indentation.
        let expect = expect![[r#"
            let first = 1;{
                let second = 2;let third = 3;
            }"#]];
        expect.assert_eq(&edit.new_root.to_string());
    }
495
    // Inserts into the inner block while that same inner block is also replaced by a new block
    // that wraps it, so the insertions have to show up inside the replacement.
    #[test]
    fn test_insert_dependent() {
        let root = make::block_expr(
            [],
            Some(
                make::block_expr(
                    [make::let_stmt(
                        make::ext::simple_ident_pat(make::name("second")).into(),
                        None,
                        Some(make::expr_literal("2").into()),
                    )
                    .into()],
                    None,
                )
                .into(),
            ),
        );

        // `nth(1)` skips the outer block, which is the first `BlockExpr` among the descendants.
        let inner_block =
            root.syntax().descendants().flat_map(ast::BlockExpr::cast).nth(1).unwrap();
        let second_let = root.syntax().descendants().find_map(ast::LetStmt::cast).unwrap();

        let mut editor = SyntaxEditor::new(root.syntax().clone());
        let make = SyntaxFactory::with_mappings();

        let new_block_expr = make.block_expr([], Some(ast::Expr::BlockExpr(inner_block.clone())));

        let first_let = make.let_stmt(
            make::ext::simple_ident_pat(make::name("first")).into(),
            None,
            Some(make::expr_literal("1").into()),
        );

        let third_let = make.let_stmt(
            make::ext::simple_ident_pat(make::name("third")).into(),
            None,
            Some(make::expr_literal("3").into()),
        );

        editor.insert(
            Position::first_child_of(inner_block.stmt_list().unwrap().syntax()),
            first_let.syntax(),
        );
        editor.insert(Position::after(second_let.syntax()), third_let.syntax());
        editor.replace(inner_block.syntax(), new_block_expr.syntax());
        editor.add_mappings(make.finish_with_mappings());

        let edit = editor.finish();

        // NOTE(review): the literal below is whitespace-sensitive — confirm its indentation.
        let expect = expect![[r#"
            {
                {
                let first = 1;{
                let second = 2;let third = 3;
            }
            }
            }"#]];
        expect.assert_eq(&edit.new_root.to_string());
    }
555
    // Replaces the root itself with a block that wraps it, while also inserting into the
    // original root's statement list.
    #[test]
    fn test_replace_root_with_dependent() {
        let root = make::block_expr(
            [make::let_stmt(
                make::ext::simple_ident_pat(make::name("second")).into(),
                None,
                Some(make::expr_literal("2").into()),
            )
            .into()],
            None,
        );

        // Here the "inner" block is the root node itself.
        let inner_block = root.clone();

        let mut editor = SyntaxEditor::new(root.syntax().clone());
        let make = SyntaxFactory::with_mappings();

        let new_block_expr = make.block_expr([], Some(ast::Expr::BlockExpr(inner_block.clone())));

        let first_let = make.let_stmt(
            make::ext::simple_ident_pat(make::name("first")).into(),
            None,
            Some(make::expr_literal("1").into()),
        );

        editor.insert(
            Position::first_child_of(inner_block.stmt_list().unwrap().syntax()),
            first_let.syntax(),
        );
        editor.replace(inner_block.syntax(), new_block_expr.syntax());
        editor.add_mappings(make.finish_with_mappings());

        let edit = editor.finish();

        // NOTE(review): the literal below is whitespace-sensitive — confirm its indentation.
        let expect = expect![[r#"
            {
                let first = 1;{
                let second = 2;
            }
            }"#]];
        expect.assert_eq(&edit.new_root.to_string());
    }
598
    // Deletes the return type, a trivia token directly following it, and the tail expression
    // of a function, mixing node and token deletions under different parents.
    #[test]
    fn test_replace_token_in_parent() {
        let parent_fn = make::fn_(
            None,
            make::name("it"),
            None,
            None,
            make::param_list(None, []),
            make::block_expr([], Some(make::ext::expr_unit())),
            Some(make::ret_type(make::ty_unit())),
            false,
            false,
            false,
            false,
        );

        let mut editor = SyntaxEditor::new(parent_fn.syntax().clone());

        if let Some(ret_ty) = parent_fn.ret_type() {
            editor.delete(ret_ty.syntax().clone());

            // Also drop the token right after the return type, but only if it is trivia.
            if let Some(SyntaxElement::Token(token)) = ret_ty.syntax().next_sibling_or_token() {
                if token.kind().is_trivia() {
                    editor.delete(token);
                }
            }
        }

        if let Some(tail) = parent_fn.body().unwrap().tail_expr() {
            editor.delete(tail.syntax().clone());
        }

        let edit = editor.finish();

        // NOTE(review): the whitespace between the two `\n` is significant — confirm that it
        // matches the indentation the deleted tail expression leaves behind.
        let expect = expect![["fn it() {\n \n}"]];
        expect.assert_eq(&edit.new_root.to_string());
    }
635 }
636}