1use std::collections::HashMap;
23
24use crate::ast::*;
25
26#[derive(Debug, Clone, PartialEq)]
28pub enum ChangeAction {
29 Remove(AstNode),
31 Insert(AstNode),
33 Keep(AstNode, AstNode),
35 Move(AstNode, AstNode),
37 Update(AstNode, AstNode),
39}
40
41#[derive(Debug, Clone, PartialEq)]
44pub enum AstNode {
45 Statement(Box<Statement>),
46 Expr(Expr),
47 SelectItem(SelectItem),
48 JoinClause(JoinClause),
49 OrderByItem(OrderByItem),
50 Cte(Box<Cte>),
51 ColumnDef(ColumnDef),
52 TableConstraint(TableConstraint),
53}
54
55impl std::fmt::Display for AstNode {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 match self {
58 AstNode::Statement(s) => write!(f, "{s:?}"),
59 AstNode::Expr(e) => write!(f, "{e:?}"),
60 AstNode::SelectItem(si) => write!(f, "{si:?}"),
61 AstNode::JoinClause(j) => write!(f, "{j:?}"),
62 AstNode::OrderByItem(o) => write!(f, "{o:?}"),
63 AstNode::Cte(c) => write!(f, "{c:?}"),
64 AstNode::ColumnDef(cd) => write!(f, "{cd:?}"),
65 AstNode::TableConstraint(tc) => write!(f, "{tc:?}"),
66 }
67 }
68}
69
70#[must_use]
75pub fn diff(source: &Statement, target: &Statement) -> Vec<ChangeAction> {
76 let mut differ = AstDiffer::new();
77 differ.diff_statements(source, target);
78 differ.changes
79}
80
81struct AstDiffer {
83 changes: Vec<ChangeAction>,
84}
85
86impl AstDiffer {
87 fn new() -> Self {
88 Self {
89 changes: Vec::new(),
90 }
91 }
92
93 fn diff_statements(&mut self, source: &Statement, target: &Statement) {
94 use Statement::*;
95
96 match (source, target) {
97 (Select(s), Select(t)) => self.diff_select(s, t),
98 (Insert(s), Insert(t)) => self.diff_insert(s, t),
99 (Update(s), Update(t)) => self.diff_update(s, t),
100 (Delete(s), Delete(t)) => self.diff_delete(s, t),
101 (CreateTable(s), CreateTable(t)) => self.diff_create_table(s, t),
102 (DropTable(s), DropTable(t)) => self.diff_drop_table(s, t),
103 (SetOperation(s), SetOperation(t)) => self.diff_set_operation(s, t),
104 (AlterTable(s), AlterTable(t)) => self.diff_alter_table(s, t),
105 (CreateView(s), CreateView(t)) => self.diff_create_view(s, t),
106 (Expression(s), Expression(t)) => self.diff_exprs(s, t),
107 _ => {
108 self.changes
110 .push(ChangeAction::Remove(AstNode::Statement(Box::new(
111 source.clone(),
112 ))));
113 self.changes
114 .push(ChangeAction::Insert(AstNode::Statement(Box::new(
115 target.clone(),
116 ))));
117 }
118 }
119 }
120
121 fn diff_select(&mut self, source: &SelectStatement, target: &SelectStatement) {
124 self.diff_ctes(&source.ctes, &target.ctes);
126
127 if source.distinct != target.distinct {
129 if target.distinct {
130 self.changes
131 .push(ChangeAction::Insert(AstNode::Expr(Expr::Column {
132 table: None,
133 name: "DISTINCT".to_string(),
134 quote_style: QuoteStyle::None,
135 table_quote_style: QuoteStyle::None,
136 })));
137 } else {
138 self.changes
139 .push(ChangeAction::Remove(AstNode::Expr(Expr::Column {
140 table: None,
141 name: "DISTINCT".to_string(),
142 quote_style: QuoteStyle::None,
143 table_quote_style: QuoteStyle::None,
144 })));
145 }
146 }
147
148 self.diff_select_items(&source.columns, &target.columns);
150
151 match (&source.from, &target.from) {
153 (Some(sf), Some(tf)) => self.diff_table_sources(&sf.source, &tf.source),
154 (None, Some(tf)) => self.insert_table_source(&tf.source),
155 (Some(sf), None) => self.remove_table_source(&sf.source),
156 (None, None) => {}
157 }
158
159 self.diff_joins(&source.joins, &target.joins);
161
162 self.diff_optional_exprs(&source.where_clause, &target.where_clause);
164
165 self.diff_expr_lists(&source.group_by, &target.group_by);
167
168 self.diff_optional_exprs(&source.having, &target.having);
170
171 self.diff_order_by(&source.order_by, &target.order_by);
173
174 self.diff_optional_exprs(&source.limit, &target.limit);
176
177 self.diff_optional_exprs(&source.offset, &target.offset);
179
180 self.diff_optional_exprs(&source.qualify, &target.qualify);
182 }
183
184 fn diff_insert(&mut self, source: &InsertStatement, target: &InsertStatement) {
187 if source.table != target.table {
188 self.changes.push(ChangeAction::Update(
189 AstNode::Expr(table_ref_to_expr(&source.table)),
190 AstNode::Expr(table_ref_to_expr(&target.table)),
191 ));
192 }
193
194 self.diff_string_lists(&source.columns, &target.columns);
196
197 match (&source.source, &target.source) {
199 (InsertSource::Values(sv), InsertSource::Values(tv)) => {
200 for (i, (sr, tr)) in sv.iter().zip(tv.iter()).enumerate() {
201 self.diff_expr_lists(sr, tr);
202 let _ = i;
203 }
204 for extra in sv.iter().skip(tv.len()) {
205 for e in extra {
206 self.changes
207 .push(ChangeAction::Remove(AstNode::Expr(e.clone())));
208 }
209 }
210 for extra in tv.iter().skip(sv.len()) {
211 for e in extra {
212 self.changes
213 .push(ChangeAction::Insert(AstNode::Expr(e.clone())));
214 }
215 }
216 }
217 (InsertSource::Query(sq), InsertSource::Query(tq)) => {
218 self.diff_statements(sq, tq);
219 }
220 _ => {
221 self.changes
222 .push(ChangeAction::Remove(AstNode::Statement(Box::new(
223 Statement::Insert(source.clone()),
224 ))));
225 self.changes
226 .push(ChangeAction::Insert(AstNode::Statement(Box::new(
227 Statement::Insert(target.clone()),
228 ))));
229 }
230 }
231 }
232
233 fn diff_update(&mut self, source: &UpdateStatement, target: &UpdateStatement) {
236 if source.table != target.table {
237 self.changes.push(ChangeAction::Update(
238 AstNode::Expr(table_ref_to_expr(&source.table)),
239 AstNode::Expr(table_ref_to_expr(&target.table)),
240 ));
241 }
242
243 let source_map: HashMap<&str, &Expr> = source
245 .assignments
246 .iter()
247 .map(|(k, v)| (k.as_str(), v))
248 .collect();
249 let target_map: HashMap<&str, &Expr> = target
250 .assignments
251 .iter()
252 .map(|(k, v)| (k.as_str(), v))
253 .collect();
254
255 for (col, src_val) in &source_map {
256 if let Some(tgt_val) = target_map.get(col) {
257 self.diff_exprs(src_val, tgt_val);
258 } else {
259 self.changes
260 .push(ChangeAction::Remove(AstNode::Expr((*src_val).clone())));
261 }
262 }
263 for (col, tgt_val) in &target_map {
264 if !source_map.contains_key(col) {
265 self.changes
266 .push(ChangeAction::Insert(AstNode::Expr((*tgt_val).clone())));
267 }
268 }
269
270 self.diff_optional_exprs(&source.where_clause, &target.where_clause);
271 }
272
273 fn diff_delete(&mut self, source: &DeleteStatement, target: &DeleteStatement) {
276 if source.table != target.table {
277 self.changes.push(ChangeAction::Update(
278 AstNode::Expr(table_ref_to_expr(&source.table)),
279 AstNode::Expr(table_ref_to_expr(&target.table)),
280 ));
281 }
282 self.diff_optional_exprs(&source.where_clause, &target.where_clause);
283 }
284
285 fn diff_create_table(&mut self, source: &CreateTableStatement, target: &CreateTableStatement) {
288 if source.table != target.table {
289 self.changes.push(ChangeAction::Update(
290 AstNode::Expr(table_ref_to_expr(&source.table)),
291 AstNode::Expr(table_ref_to_expr(&target.table)),
292 ));
293 }
294
295 let source_cols: HashMap<&str, &ColumnDef> = source
297 .columns
298 .iter()
299 .map(|c| (c.name.as_str(), c))
300 .collect();
301 let target_cols: HashMap<&str, &ColumnDef> = target
302 .columns
303 .iter()
304 .map(|c| (c.name.as_str(), c))
305 .collect();
306
307 for (name, src_col) in &source_cols {
308 if let Some(tgt_col) = target_cols.get(name) {
309 if src_col != tgt_col {
310 self.changes.push(ChangeAction::Update(
311 AstNode::ColumnDef((*src_col).clone()),
312 AstNode::ColumnDef((*tgt_col).clone()),
313 ));
314 } else {
315 self.changes.push(ChangeAction::Keep(
316 AstNode::ColumnDef((*src_col).clone()),
317 AstNode::ColumnDef((*tgt_col).clone()),
318 ));
319 }
320 } else {
321 self.changes
322 .push(ChangeAction::Remove(AstNode::ColumnDef((*src_col).clone())));
323 }
324 }
325 for (name, tgt_col) in &target_cols {
326 if !source_cols.contains_key(name) {
327 self.changes
328 .push(ChangeAction::Insert(AstNode::ColumnDef((*tgt_col).clone())));
329 }
330 }
331
332 self.diff_constraints(&source.constraints, &target.constraints);
334 }
335
336 fn diff_drop_table(&mut self, source: &DropTableStatement, target: &DropTableStatement) {
339 if source != target {
340 self.changes.push(ChangeAction::Update(
341 AstNode::Statement(Box::new(Statement::DropTable(source.clone()))),
342 AstNode::Statement(Box::new(Statement::DropTable(target.clone()))),
343 ));
344 } else {
345 self.changes.push(ChangeAction::Keep(
346 AstNode::Statement(Box::new(Statement::DropTable(source.clone()))),
347 AstNode::Statement(Box::new(Statement::DropTable(target.clone()))),
348 ));
349 }
350 }
351
352 fn diff_set_operation(
355 &mut self,
356 source: &SetOperationStatement,
357 target: &SetOperationStatement,
358 ) {
359 if source.op != target.op || source.all != target.all {
360 self.changes.push(ChangeAction::Update(
361 AstNode::Statement(Box::new(Statement::SetOperation(source.clone()))),
362 AstNode::Statement(Box::new(Statement::SetOperation(target.clone()))),
363 ));
364 return;
365 }
366 self.diff_statements(&source.left, &target.left);
367 self.diff_statements(&source.right, &target.right);
368 self.diff_order_by(&source.order_by, &target.order_by);
369 self.diff_optional_exprs(&source.limit, &target.limit);
370 self.diff_optional_exprs(&source.offset, &target.offset);
371 }
372
373 fn diff_alter_table(&mut self, source: &AlterTableStatement, target: &AlterTableStatement) {
376 if source.table != target.table {
377 self.changes.push(ChangeAction::Update(
378 AstNode::Expr(table_ref_to_expr(&source.table)),
379 AstNode::Expr(table_ref_to_expr(&target.table)),
380 ));
381 }
382 if source.actions != target.actions {
384 self.changes.push(ChangeAction::Update(
385 AstNode::Statement(Box::new(Statement::AlterTable(source.clone()))),
386 AstNode::Statement(Box::new(Statement::AlterTable(target.clone()))),
387 ));
388 }
389 }
390
391 fn diff_create_view(&mut self, source: &CreateViewStatement, target: &CreateViewStatement) {
394 if source.name != target.name {
395 self.changes.push(ChangeAction::Update(
396 AstNode::Expr(table_ref_to_expr(&source.name)),
397 AstNode::Expr(table_ref_to_expr(&target.name)),
398 ));
399 }
400 self.diff_statements(&source.query, &target.query);
401 }
402
403 fn diff_exprs(&mut self, source: &Expr, target: &Expr) {
406 if source == target {
407 self.changes.push(ChangeAction::Keep(
408 AstNode::Expr(source.clone()),
409 AstNode::Expr(target.clone()),
410 ));
411 return;
412 }
413
414 match (source, target) {
416 (
417 Expr::BinaryOp {
418 left: sl,
419 op: sop,
420 right: sr,
421 },
422 Expr::BinaryOp {
423 left: tl,
424 op: top,
425 right: tr,
426 },
427 ) => {
428 if sop != top {
429 self.changes.push(ChangeAction::Update(
430 AstNode::Expr(source.clone()),
431 AstNode::Expr(target.clone()),
432 ));
433 } else {
434 self.diff_exprs(sl, tl);
435 self.diff_exprs(sr, tr);
436 }
437 }
438 (Expr::UnaryOp { op: sop, expr: se }, Expr::UnaryOp { op: top, expr: te }) => {
439 if sop != top {
440 self.changes.push(ChangeAction::Update(
441 AstNode::Expr(source.clone()),
442 AstNode::Expr(target.clone()),
443 ));
444 } else {
445 self.diff_exprs(se, te);
446 }
447 }
448 (
449 Expr::Function {
450 name: sn,
451 args: sa,
452 distinct: sd,
453 ..
454 },
455 Expr::Function {
456 name: tn,
457 args: ta,
458 distinct: td,
459 ..
460 },
461 ) => {
462 if sn != tn || sd != td {
463 self.changes.push(ChangeAction::Update(
464 AstNode::Expr(source.clone()),
465 AstNode::Expr(target.clone()),
466 ));
467 } else {
468 self.diff_expr_lists(sa, ta);
469 }
470 }
471 (
472 Expr::Cast {
473 expr: se,
474 data_type: sd,
475 },
476 Expr::Cast {
477 expr: te,
478 data_type: td,
479 },
480 ) => {
481 if sd != td {
482 self.changes.push(ChangeAction::Update(
483 AstNode::Expr(source.clone()),
484 AstNode::Expr(target.clone()),
485 ));
486 } else {
487 self.diff_exprs(se, te);
488 }
489 }
490 (
491 Expr::Case {
492 operand: so,
493 when_clauses: sw,
494 else_clause: se,
495 },
496 Expr::Case {
497 operand: to,
498 when_clauses: tw,
499 else_clause: te,
500 },
501 ) => {
502 self.diff_optional_boxed_exprs(so, to);
503 for (i, ((sc, sr), (tc, tr))) in sw.iter().zip(tw.iter()).enumerate() {
505 self.diff_exprs(sc, tc);
506 self.diff_exprs(sr, tr);
507 let _ = i;
508 }
509 for (sc, sr) in sw.iter().skip(tw.len()) {
510 self.changes
511 .push(ChangeAction::Remove(AstNode::Expr(sc.clone())));
512 self.changes
513 .push(ChangeAction::Remove(AstNode::Expr(sr.clone())));
514 }
515 for (tc, tr) in tw.iter().skip(sw.len()) {
516 self.changes
517 .push(ChangeAction::Insert(AstNode::Expr(tc.clone())));
518 self.changes
519 .push(ChangeAction::Insert(AstNode::Expr(tr.clone())));
520 }
521 self.diff_optional_boxed_exprs(se, te);
522 }
523 (Expr::Nested(se), Expr::Nested(te)) => self.diff_exprs(se, te),
524 (
525 Expr::Between {
526 expr: se,
527 low: sl,
528 high: sh,
529 negated: sn,
530 },
531 Expr::Between {
532 expr: te,
533 low: tl,
534 high: th,
535 negated: tn,
536 },
537 ) => {
538 if sn != tn {
539 self.changes.push(ChangeAction::Update(
540 AstNode::Expr(source.clone()),
541 AstNode::Expr(target.clone()),
542 ));
543 } else {
544 self.diff_exprs(se, te);
545 self.diff_exprs(sl, tl);
546 self.diff_exprs(sh, th);
547 }
548 }
549 (
550 Expr::InList {
551 expr: se,
552 list: sl,
553 negated: sn,
554 },
555 Expr::InList {
556 expr: te,
557 list: tl,
558 negated: tn,
559 },
560 ) => {
561 if sn != tn {
562 self.changes.push(ChangeAction::Update(
563 AstNode::Expr(source.clone()),
564 AstNode::Expr(target.clone()),
565 ));
566 } else {
567 self.diff_exprs(se, te);
568 self.diff_expr_lists(sl, tl);
569 }
570 }
571 (
572 Expr::InSubquery {
573 expr: se,
574 subquery: ss,
575 negated: sn,
576 },
577 Expr::InSubquery {
578 expr: te,
579 subquery: ts,
580 negated: tn,
581 },
582 ) => {
583 if sn != tn {
584 self.changes.push(ChangeAction::Update(
585 AstNode::Expr(source.clone()),
586 AstNode::Expr(target.clone()),
587 ));
588 } else {
589 self.diff_exprs(se, te);
590 self.diff_statements(ss, ts);
591 }
592 }
593 (
594 Expr::IsNull {
595 expr: se,
596 negated: sn,
597 },
598 Expr::IsNull {
599 expr: te,
600 negated: tn,
601 },
602 ) => {
603 if sn != tn {
604 self.changes.push(ChangeAction::Update(
605 AstNode::Expr(source.clone()),
606 AstNode::Expr(target.clone()),
607 ));
608 } else {
609 self.diff_exprs(se, te);
610 }
611 }
612 (
613 Expr::Like {
614 expr: se,
615 pattern: sp,
616 negated: sn,
617 ..
618 },
619 Expr::Like {
620 expr: te,
621 pattern: tp,
622 negated: tn,
623 ..
624 },
625 )
626 | (
627 Expr::ILike {
628 expr: se,
629 pattern: sp,
630 negated: sn,
631 ..
632 },
633 Expr::ILike {
634 expr: te,
635 pattern: tp,
636 negated: tn,
637 ..
638 },
639 ) => {
640 if sn != tn {
641 self.changes.push(ChangeAction::Update(
642 AstNode::Expr(source.clone()),
643 AstNode::Expr(target.clone()),
644 ));
645 } else {
646 self.diff_exprs(se, te);
647 self.diff_exprs(sp, tp);
648 }
649 }
650 (Expr::Subquery(ss), Expr::Subquery(ts)) => self.diff_statements(ss, ts),
651 (
652 Expr::Exists {
653 subquery: ss,
654 negated: sn,
655 },
656 Expr::Exists {
657 subquery: ts,
658 negated: tn,
659 },
660 ) => {
661 if sn != tn {
662 self.changes.push(ChangeAction::Update(
663 AstNode::Expr(source.clone()),
664 AstNode::Expr(target.clone()),
665 ));
666 } else {
667 self.diff_statements(ss, ts);
668 }
669 }
670 (Expr::Alias { expr: se, name: sn }, Expr::Alias { expr: te, name: tn }) => {
671 if sn != tn {
672 self.changes.push(ChangeAction::Update(
673 AstNode::Expr(source.clone()),
674 AstNode::Expr(target.clone()),
675 ));
676 } else {
677 self.diff_exprs(se, te);
678 }
679 }
680 (Expr::Coalesce(sa), Expr::Coalesce(ta)) => self.diff_expr_lists(sa, ta),
681 (Expr::ArrayLiteral(sa), Expr::ArrayLiteral(ta)) => self.diff_expr_lists(sa, ta),
682 (Expr::Tuple(sa), Expr::Tuple(ta)) => self.diff_expr_lists(sa, ta),
683 (Expr::TypedFunction { func: sf, .. }, Expr::TypedFunction { func: tf, .. }) => {
684 if std::mem::discriminant(sf) == std::mem::discriminant(tf) && source == target {
685 self.changes.push(ChangeAction::Keep(
686 AstNode::Expr(source.clone()),
687 AstNode::Expr(target.clone()),
688 ));
689 } else {
690 self.changes.push(ChangeAction::Update(
691 AstNode::Expr(source.clone()),
692 AstNode::Expr(target.clone()),
693 ));
694 }
695 }
696 _ => {
698 self.changes.push(ChangeAction::Update(
699 AstNode::Expr(source.clone()),
700 AstNode::Expr(target.clone()),
701 ));
702 }
703 }
704 }
705
706 fn diff_expr_lists(&mut self, source: &[Expr], target: &[Expr]) {
708 let lcs = compute_lcs(source, target);
710 let mut si = 0;
711 let mut ti = 0;
712 let mut li = 0;
713
714 while si < source.len() || ti < target.len() {
715 if li < lcs.len() {
716 let (lcs_si, lcs_ti) = lcs[li];
717
718 while si < lcs_si {
720 self.changes
721 .push(ChangeAction::Remove(AstNode::Expr(source[si].clone())));
722 si += 1;
723 }
724 while ti < lcs_ti {
726 self.changes
727 .push(ChangeAction::Insert(AstNode::Expr(target[ti].clone())));
728 ti += 1;
729 }
730 self.diff_exprs(&source[si], &target[ti]);
732 si += 1;
733 ti += 1;
734 li += 1;
735 } else {
736 while si < source.len() {
738 self.changes
739 .push(ChangeAction::Remove(AstNode::Expr(source[si].clone())));
740 si += 1;
741 }
742 while ti < target.len() {
744 self.changes
745 .push(ChangeAction::Insert(AstNode::Expr(target[ti].clone())));
746 ti += 1;
747 }
748 }
749 }
750 }
751
752 fn diff_select_items(&mut self, source: &[SelectItem], target: &[SelectItem]) {
753 let min_len = source.len().min(target.len());
754 for i in 0..min_len {
755 if source[i] == target[i] {
756 self.changes.push(ChangeAction::Keep(
757 AstNode::SelectItem(source[i].clone()),
758 AstNode::SelectItem(target[i].clone()),
759 ));
760 } else {
761 match (&source[i], &target[i]) {
762 (
763 SelectItem::Expr {
764 expr: se,
765 alias: sa,
766 ..
767 },
768 SelectItem::Expr {
769 expr: te,
770 alias: ta,
771 ..
772 },
773 ) => {
774 if sa != ta {
775 self.changes.push(ChangeAction::Update(
776 AstNode::SelectItem(source[i].clone()),
777 AstNode::SelectItem(target[i].clone()),
778 ));
779 } else {
780 self.diff_exprs(se, te);
781 }
782 }
783 _ => {
784 self.changes.push(ChangeAction::Update(
785 AstNode::SelectItem(source[i].clone()),
786 AstNode::SelectItem(target[i].clone()),
787 ));
788 }
789 }
790 }
791 }
792 for item in source.iter().skip(min_len) {
793 self.changes
794 .push(ChangeAction::Remove(AstNode::SelectItem(item.clone())));
795 }
796 for item in target.iter().skip(min_len) {
797 self.changes
798 .push(ChangeAction::Insert(AstNode::SelectItem(item.clone())));
799 }
800 }
801
802 fn diff_optional_exprs(&mut self, source: &Option<Expr>, target: &Option<Expr>) {
803 match (source, target) {
804 (Some(s), Some(t)) => self.diff_exprs(s, t),
805 (None, Some(t)) => self
806 .changes
807 .push(ChangeAction::Insert(AstNode::Expr(t.clone()))),
808 (Some(s), None) => self
809 .changes
810 .push(ChangeAction::Remove(AstNode::Expr(s.clone()))),
811 (None, None) => {}
812 }
813 }
814
815 fn diff_optional_boxed_exprs(
816 &mut self,
817 source: &Option<Box<Expr>>,
818 target: &Option<Box<Expr>>,
819 ) {
820 match (source, target) {
821 (Some(s), Some(t)) => self.diff_exprs(s, t),
822 (None, Some(t)) => self
823 .changes
824 .push(ChangeAction::Insert(AstNode::Expr((**t).clone()))),
825 (Some(s), None) => self
826 .changes
827 .push(ChangeAction::Remove(AstNode::Expr((**s).clone()))),
828 (None, None) => {}
829 }
830 }
831
832 fn diff_ctes(&mut self, source: &[Cte], target: &[Cte]) {
833 let source_map: HashMap<&str, &Cte> = source.iter().map(|c| (c.name.as_str(), c)).collect();
835 let target_map: HashMap<&str, &Cte> = target.iter().map(|c| (c.name.as_str(), c)).collect();
836
837 for (name, sc) in &source_map {
838 if let Some(tc) = target_map.get(name) {
839 if sc == tc {
840 self.changes.push(ChangeAction::Keep(
841 AstNode::Cte(Box::new((*sc).clone())),
842 AstNode::Cte(Box::new((*tc).clone())),
843 ));
844 } else {
845 self.diff_statements(&sc.query, &tc.query);
846 }
847 } else {
848 self.changes
849 .push(ChangeAction::Remove(AstNode::Cte(Box::new((*sc).clone()))));
850 }
851 }
852 for (name, tc) in &target_map {
853 if !source_map.contains_key(name) {
854 self.changes
855 .push(ChangeAction::Insert(AstNode::Cte(Box::new((*tc).clone()))));
856 }
857 }
858 }
859
860 fn diff_joins(&mut self, source: &[JoinClause], target: &[JoinClause]) {
861 let min_len = source.len().min(target.len());
862 for i in 0..min_len {
863 if source[i] == target[i] {
864 self.changes.push(ChangeAction::Keep(
865 AstNode::JoinClause(source[i].clone()),
866 AstNode::JoinClause(target[i].clone()),
867 ));
868 } else if source[i].join_type == target[i].join_type {
869 self.diff_table_sources(&source[i].table, &target[i].table);
871 self.diff_optional_exprs(&source[i].on, &target[i].on);
872 } else {
873 self.changes.push(ChangeAction::Update(
874 AstNode::JoinClause(source[i].clone()),
875 AstNode::JoinClause(target[i].clone()),
876 ));
877 }
878 }
879 for item in source.iter().skip(min_len) {
880 self.changes
881 .push(ChangeAction::Remove(AstNode::JoinClause(item.clone())));
882 }
883 for item in target.iter().skip(min_len) {
884 self.changes
885 .push(ChangeAction::Insert(AstNode::JoinClause(item.clone())));
886 }
887 }
888
889 fn diff_order_by(&mut self, source: &[OrderByItem], target: &[OrderByItem]) {
890 let min_len = source.len().min(target.len());
891 for i in 0..min_len {
892 if source[i] == target[i] {
893 self.changes.push(ChangeAction::Keep(
894 AstNode::OrderByItem(source[i].clone()),
895 AstNode::OrderByItem(target[i].clone()),
896 ));
897 } else if source[i].ascending == target[i].ascending
898 && source[i].nulls_first == target[i].nulls_first
899 {
900 self.diff_exprs(&source[i].expr, &target[i].expr);
901 } else {
902 self.changes.push(ChangeAction::Update(
903 AstNode::OrderByItem(source[i].clone()),
904 AstNode::OrderByItem(target[i].clone()),
905 ));
906 }
907 }
908 for item in source.iter().skip(min_len) {
909 self.changes
910 .push(ChangeAction::Remove(AstNode::OrderByItem(item.clone())));
911 }
912 for item in target.iter().skip(min_len) {
913 self.changes
914 .push(ChangeAction::Insert(AstNode::OrderByItem(item.clone())));
915 }
916 }
917
918 fn diff_table_sources(&mut self, source: &TableSource, target: &TableSource) {
919 if source == target {
920 return;
921 }
922 match (source, target) {
923 (TableSource::Table(st), TableSource::Table(tt)) => {
924 if st != tt {
925 self.changes.push(ChangeAction::Update(
926 AstNode::Expr(table_ref_to_expr(st)),
927 AstNode::Expr(table_ref_to_expr(tt)),
928 ));
929 }
930 }
931 (TableSource::Subquery { query: sq, .. }, TableSource::Subquery { query: tq, .. }) => {
932 self.diff_statements(sq, tq);
933 }
934 _ => {
935 self.remove_table_source(source);
937 self.insert_table_source(target);
938 }
939 }
940 }
941
942 fn insert_table_source(&mut self, source: &TableSource) {
943 match source {
944 TableSource::Table(t) => {
945 self.changes
946 .push(ChangeAction::Insert(AstNode::Expr(table_ref_to_expr(t))));
947 }
948 TableSource::Subquery { query, .. } => {
949 self.changes
950 .push(ChangeAction::Insert(AstNode::Statement(Box::new(
951 (**query).clone(),
952 ))));
953 }
954 other => {
955 self.changes
956 .push(ChangeAction::Insert(AstNode::Expr(Expr::StringLiteral(
957 format!("{other:?}"),
958 ))));
959 }
960 }
961 }
962
963 fn remove_table_source(&mut self, source: &TableSource) {
964 match source {
965 TableSource::Table(t) => {
966 self.changes
967 .push(ChangeAction::Remove(AstNode::Expr(table_ref_to_expr(t))));
968 }
969 TableSource::Subquery { query, .. } => {
970 self.changes
971 .push(ChangeAction::Remove(AstNode::Statement(Box::new(
972 (**query).clone(),
973 ))));
974 }
975 other => {
976 self.changes
977 .push(ChangeAction::Remove(AstNode::Expr(Expr::StringLiteral(
978 format!("{other:?}"),
979 ))));
980 }
981 }
982 }
983
984 fn diff_constraints(&mut self, source: &[TableConstraint], target: &[TableConstraint]) {
985 let min_len = source.len().min(target.len());
987 for i in 0..min_len {
988 if source[i] == target[i] {
989 self.changes.push(ChangeAction::Keep(
990 AstNode::TableConstraint(source[i].clone()),
991 AstNode::TableConstraint(target[i].clone()),
992 ));
993 } else {
994 self.changes.push(ChangeAction::Update(
995 AstNode::TableConstraint(source[i].clone()),
996 AstNode::TableConstraint(target[i].clone()),
997 ));
998 }
999 }
1000 for item in source.iter().skip(min_len) {
1001 self.changes
1002 .push(ChangeAction::Remove(AstNode::TableConstraint(item.clone())));
1003 }
1004 for item in target.iter().skip(min_len) {
1005 self.changes
1006 .push(ChangeAction::Insert(AstNode::TableConstraint(item.clone())));
1007 }
1008 }
1009
1010 fn diff_string_lists(&mut self, source: &[String], target: &[String]) {
1011 for s in source {
1012 if !target.contains(s) {
1013 self.changes
1014 .push(ChangeAction::Remove(AstNode::Expr(Expr::Column {
1015 table: None,
1016 name: s.clone(),
1017 quote_style: QuoteStyle::None,
1018 table_quote_style: QuoteStyle::None,
1019 })));
1020 }
1021 }
1022 for t in target {
1023 if !source.contains(t) {
1024 self.changes
1025 .push(ChangeAction::Insert(AstNode::Expr(Expr::Column {
1026 table: None,
1027 name: t.clone(),
1028 quote_style: QuoteStyle::None,
1029 table_quote_style: QuoteStyle::None,
1030 })));
1031 }
1032 }
1033 }
1034}
1035
1036fn compute_lcs(source: &[Expr], target: &[Expr]) -> Vec<(usize, usize)> {
1043 let m = source.len();
1044 let n = target.len();
1045 if m == 0 || n == 0 {
1046 return Vec::new();
1047 }
1048
1049 let mut dp = vec![vec![0u32; n + 1]; m + 1];
1051 for i in 1..=m {
1052 for j in 1..=n {
1053 if source[i - 1] == target[j - 1] {
1054 dp[i][j] = dp[i - 1][j - 1] + 1;
1055 } else {
1056 dp[i][j] = dp[i - 1][j].max(dp[i][j - 1]);
1057 }
1058 }
1059 }
1060
1061 let mut result = Vec::new();
1063 let mut i = m;
1064 let mut j = n;
1065 while i > 0 && j > 0 {
1066 if source[i - 1] == target[j - 1] {
1067 result.push((i - 1, j - 1));
1068 i -= 1;
1069 j -= 1;
1070 } else if dp[i - 1][j] >= dp[i][j - 1] {
1071 i -= 1;
1072 } else {
1073 j -= 1;
1074 }
1075 }
1076 result.reverse();
1077 result
1078}
1079
1080fn table_ref_to_expr(table: &TableRef) -> Expr {
1082 let full_name = match (&table.catalog, &table.schema) {
1083 (Some(c), Some(s)) => format!("{c}.{s}.{}", table.name),
1084 (None, Some(s)) => format!("{s}.{}", table.name),
1085 _ => table.name.clone(),
1086 };
1087 Expr::Column {
1088 table: table.schema.clone(),
1089 name: full_name,
1090 quote_style: table.name_quote_style,
1091 table_quote_style: QuoteStyle::None,
1092 }
1093}
1094
1095pub fn diff_sql(
1106 source_sql: &str,
1107 target_sql: &str,
1108 dialect: crate::dialects::Dialect,
1109) -> crate::errors::Result<Vec<ChangeAction>> {
1110 let source = crate::parser::parse(source_sql, dialect)?;
1111 let target = crate::parser::parse(target_sql, dialect)?;
1112 Ok(diff(&source, &target))
1113}
1114
1115#[cfg(test)]
1116mod tests {
1117 use super::*;
1118 use crate::dialects::Dialect;
1119 use crate::parser::parse;
1120
1121 fn count_by_action(changes: &[ChangeAction]) -> (usize, usize, usize, usize, usize) {
1122 let mut keeps = 0;
1123 let mut inserts = 0;
1124 let mut removes = 0;
1125 let mut updates = 0;
1126 let mut moves = 0;
1127 for c in changes {
1128 match c {
1129 ChangeAction::Keep(..) => keeps += 1,
1130 ChangeAction::Insert(..) => inserts += 1,
1131 ChangeAction::Remove(..) => removes += 1,
1132 ChangeAction::Update(..) => updates += 1,
1133 ChangeAction::Move(..) => moves += 1,
1134 }
1135 }
1136 (keeps, inserts, removes, updates, moves)
1137 }
1138
1139 #[test]
1140 fn test_identical_queries_are_all_keep() {
1141 let sql = "SELECT a, b FROM t WHERE a > 1";
1142 let source = parse(sql, Dialect::Ansi).unwrap();
1143 let target = parse(sql, Dialect::Ansi).unwrap();
1144 let changes = diff(&source, &target);
1145 let (keeps, inserts, removes, updates, _moves) = count_by_action(&changes);
1146 assert!(keeps > 0, "should have keep actions");
1147 assert_eq!(inserts, 0, "no inserts for identical queries");
1148 assert_eq!(removes, 0, "no removes for identical queries");
1149 assert_eq!(updates, 0, "no updates for identical queries");
1150 }
1151
1152 #[test]
1153 fn test_column_added() {
1154 let source = parse("SELECT a FROM t", Dialect::Ansi).unwrap();
1155 let target = parse("SELECT a, b FROM t", Dialect::Ansi).unwrap();
1156 let changes = diff(&source, &target);
1157 let (keeps, inserts, removes, _updates, _moves) = count_by_action(&changes);
1158 assert!(keeps > 0);
1159 assert!(inserts > 0, "should have insert for new column b");
1160 assert_eq!(removes, 0);
1161 }
1162
1163 #[test]
1164 fn test_column_removed() {
1165 let source = parse("SELECT a, b FROM t", Dialect::Ansi).unwrap();
1166 let target = parse("SELECT a FROM t", Dialect::Ansi).unwrap();
1167 let changes = diff(&source, &target);
1168 let (keeps, _inserts, removes, _updates, _moves) = count_by_action(&changes);
1169 assert!(keeps > 0);
1170 assert!(removes > 0, "should have remove for column b");
1171 }
1172
1173 #[test]
1174 fn test_column_changed() {
1175 let source = parse("SELECT a, b FROM t", Dialect::Ansi).unwrap();
1176 let target = parse("SELECT a, c FROM t", Dialect::Ansi).unwrap();
1177 let changes = diff(&source, &target);
1178 let (_keeps, _inserts, _removes, updates, _moves) = count_by_action(&changes);
1179 assert!(updates > 0, "should have update for b -> c");
1180 }
1181
1182 #[test]
1183 fn test_where_clause_added() {
1184 let source = parse("SELECT a FROM t", Dialect::Ansi).unwrap();
1185 let target = parse("SELECT a FROM t WHERE a > 1", Dialect::Ansi).unwrap();
1186 let changes = diff(&source, &target);
1187 let (_keeps, inserts, _removes, _updates, _moves) = count_by_action(&changes);
1188 assert!(inserts > 0, "should have insert for WHERE clause");
1189 }
1190
1191 #[test]
1192 fn test_where_clause_removed() {
1193 let source = parse("SELECT a FROM t WHERE a > 1", Dialect::Ansi).unwrap();
1194 let target = parse("SELECT a FROM t", Dialect::Ansi).unwrap();
1195 let changes = diff(&source, &target);
1196 let (_keeps, _inserts, removes, _updates, _moves) = count_by_action(&changes);
1197 assert!(removes > 0, "should have remove for WHERE clause");
1198 }
1199
1200 #[test]
1201 fn test_where_clause_updated() {
1202 let source = parse("SELECT a FROM t WHERE a > 1", Dialect::Ansi).unwrap();
1203 let target = parse("SELECT a FROM t WHERE a > 2", Dialect::Ansi).unwrap();
1204 let changes = diff(&source, &target);
1205 let (_keeps, _inserts, _removes, updates, _moves) = count_by_action(&changes);
1206 assert!(updates > 0, "should have update for WHERE value change");
1207 }
1208
1209 #[test]
1210 fn test_table_changed() {
1211 let source = parse("SELECT a FROM t1", Dialect::Ansi).unwrap();
1212 let target = parse("SELECT a FROM t2", Dialect::Ansi).unwrap();
1213 let changes = diff(&source, &target);
1214 let (_keeps, _inserts, _removes, updates, _moves) = count_by_action(&changes);
1215 assert!(updates > 0, "should have update for table change");
1216 }
1217
1218 #[test]
1219 fn test_join_added() {
1220 let source = parse("SELECT a FROM t1", Dialect::Ansi).unwrap();
1221 let target = parse("SELECT a FROM t1 JOIN t2 ON t1.id = t2.id", Dialect::Ansi).unwrap();
1222 let changes = diff(&source, &target);
1223 let (_keeps, inserts, _removes, _updates, _moves) = count_by_action(&changes);
1224 assert!(inserts > 0, "should have insert for JOIN");
1225 }
1226
1227 #[test]
1228 fn test_order_by_changed() {
1229 let source = parse("SELECT a FROM t ORDER BY a ASC", Dialect::Ansi).unwrap();
1230 let target = parse("SELECT a FROM t ORDER BY a DESC", Dialect::Ansi).unwrap();
1231 let changes = diff(&source, &target);
1232 let (_keeps, _inserts, _removes, updates, _moves) = count_by_action(&changes);
1233 assert!(updates > 0, "should have update for ORDER BY direction");
1234 }
1235
1236 #[test]
1237 fn test_complex_nested_query() {
1238 let source = parse(
1239 "SELECT a, b FROM t1 WHERE a IN (SELECT x FROM t2 WHERE x > 0)",
1240 Dialect::Ansi,
1241 )
1242 .unwrap();
1243 let target = parse(
1244 "SELECT a, c FROM t1 WHERE a IN (SELECT x FROM t2 WHERE x > 5)",
1245 Dialect::Ansi,
1246 )
1247 .unwrap();
1248 let changes = diff(&source, &target);
1249 let (keeps, _inserts, _removes, updates, _moves) = count_by_action(&changes);
1250 assert!(keeps > 0, "unchanged parts should be kept");
1251 assert!(updates > 0, "changed parts should be updated (b->c, 0->5)");
1252 }
1253
1254 #[test]
1255 fn test_different_statement_types() {
1256 let source = parse("SELECT a FROM t", Dialect::Ansi).unwrap();
1257 let target = parse("CREATE TABLE t (a INT)", Dialect::Ansi).unwrap();
1258 let changes = diff(&source, &target);
1259 let (_keeps, inserts, removes, _updates, _moves) = count_by_action(&changes);
1260 assert!(removes > 0, "source should be removed");
1261 assert!(inserts > 0, "target should be inserted");
1262 }
1263
1264 #[test]
1265 fn test_cte_added() {
1266 let source = parse("SELECT a FROM t", Dialect::Ansi).unwrap();
1267 let target = parse("WITH cte AS (SELECT 1 AS x) SELECT a FROM t", Dialect::Ansi).unwrap();
1268 let changes = diff(&source, &target);
1269 let (_keeps, inserts, _removes, _updates, _moves) = count_by_action(&changes);
1270 assert!(inserts > 0, "should have insert for CTE");
1271 }
1272
1273 #[test]
1274 fn test_limit_changed() {
1275 let source = parse("SELECT a FROM t LIMIT 10", Dialect::Ansi).unwrap();
1276 let target = parse("SELECT a FROM t LIMIT 20", Dialect::Ansi).unwrap();
1277 let changes = diff(&source, &target);
1278 let (_keeps, _inserts, _removes, updates, _moves) = count_by_action(&changes);
1279 assert!(updates > 0, "should have update for LIMIT change");
1280 }
1281
1282 #[test]
1283 fn test_group_by_added() {
1284 let source = parse("SELECT a, COUNT(*) FROM t", Dialect::Ansi).unwrap();
1285 let target = parse("SELECT a, COUNT(*) FROM t GROUP BY a", Dialect::Ansi).unwrap();
1286 let changes = diff(&source, &target);
1287 let (_keeps, inserts, _removes, _updates, _moves) = count_by_action(&changes);
1288 assert!(inserts > 0, "should have insert for GROUP BY");
1289 }
1290
1291 #[test]
1292 fn test_diff_sql_convenience() {
1293 let changes = diff_sql("SELECT a FROM t", "SELECT a, b FROM t", Dialect::Ansi).unwrap();
1294 let (_keeps, inserts, _removes, _updates, _moves) = count_by_action(&changes);
1295 assert!(inserts > 0);
1296 }
1297
1298 #[test]
1299 fn test_having_added() {
1300 let source = parse("SELECT a, COUNT(*) FROM t GROUP BY a", Dialect::Ansi).unwrap();
1301 let target = parse(
1302 "SELECT a, COUNT(*) FROM t GROUP BY a HAVING COUNT(*) > 1",
1303 Dialect::Ansi,
1304 )
1305 .unwrap();
1306 let changes = diff(&source, &target);
1307 let (_keeps, inserts, _removes, _updates, _moves) = count_by_action(&changes);
1308 assert!(inserts > 0, "should have insert for HAVING");
1309 }
1310
1311 #[test]
1312 fn test_create_table_column_diff() {
1313 let source = parse("CREATE TABLE t (a INT, b TEXT)", Dialect::Ansi).unwrap();
1314 let target = parse("CREATE TABLE t (a INT, c TEXT)", Dialect::Ansi).unwrap();
1315 let changes = diff(&source, &target);
1316 let (_keeps, inserts, removes, _updates, _moves) = count_by_action(&changes);
1317 assert!(removes > 0, "should remove column b");
1318 assert!(inserts > 0, "should insert column c");
1319 }
1320
1321 #[test]
1322 fn test_union_diff() {
1323 let source = parse("SELECT a FROM t1 UNION SELECT b FROM t2", Dialect::Ansi).unwrap();
1324 let target = parse("SELECT a FROM t1 UNION SELECT c FROM t2", Dialect::Ansi).unwrap();
1325 let changes = diff(&source, &target);
1326 let (keeps, _inserts, _removes, updates, _moves) = count_by_action(&changes);
1327 assert!(keeps > 0);
1328 assert!(updates > 0, "should have update for b -> c");
1329 }
1330}