1use crate::dialects::DialectType;
9use crate::expressions::Expression;
10use crate::generator::{Generator, GeneratorConfig};
11use std::cmp::Ordering;
12use std::collections::{BinaryHeap, HashMap, HashSet};
13
/// A single edit operation produced by diffing a source and a target
/// expression tree.
///
/// `Insert` and `Remove` carry the one node that has no counterpart on the
/// other side; the remaining variants carry a matched `source`/`target` pair.
#[derive(Debug, Clone, PartialEq)]
pub enum Edit {
    /// A target node that was not matched to any source node.
    Insert { expression: Expression },
    /// A source node that was not matched to any target node.
    Remove { expression: Expression },
    /// A matched pair whose source node falls outside the longest common
    /// subsequence of its matched parent's children (i.e. it changed position).
    Move {
        source: Expression,
        target: Expression,
    },
    /// A matched pair whose generated SQL or non-expression attributes differ.
    Update {
        source: Expression,
        target: Expression,
    },
    /// A matched pair with no detected change.
    Keep {
        source: Expression,
        target: Expression,
    },
}
37
38impl Edit {
39 pub fn is_change(&self) -> bool {
41 !matches!(self, Edit::Keep { .. })
42 }
43}
44
/// Tunable thresholds for the diff algorithm.
#[derive(Debug, Clone)]
pub struct DiffConfig {
    /// Minimum Dice coefficient between the generated SQL of two internal
    /// nodes for them to be matched when their leaf similarity is below 0.8.
    pub f: f64,
    /// Minimum Dice coefficient for two leaves to become a match candidate;
    /// also the minimum leaf-similarity ratio for internal nodes when both
    /// sides have more than four leaf descendants (0.4 is used otherwise).
    pub t: f64,
    /// Dialect used when rendering nodes to SQL for comparison; `None` uses
    /// the generator's default rendering.
    pub dialect: Option<DialectType>,
}
55
56impl Default for DiffConfig {
57 fn default() -> Self {
58 Self {
59 f: 0.6,
60 t: 0.6,
61 dialect: None,
62 }
63 }
64}
65
66pub fn diff(source: &Expression, target: &Expression, delta_only: bool) -> Vec<Edit> {
87 let config = DiffConfig::default();
88 diff_with_config(source, target, delta_only, &config)
89}
90
91pub fn diff_with_config(
93 source: &Expression,
94 target: &Expression,
95 delta_only: bool,
96 config: &DiffConfig,
97) -> Vec<Edit> {
98 let mut distiller = ChangeDistiller::new(config.clone());
99 distiller.diff(source, target, delta_only)
100}
101
102pub fn has_changes(edits: &[Edit]) -> bool {
104 edits.iter().any(|e| e.is_change())
105}
106
107pub fn changes_only(edits: Vec<Edit>) -> Vec<Edit> {
109 edits.into_iter().filter(|e| e.is_change()).collect()
110}
111
/// Flattened view of an expression tree in which every node is addressed by
/// its index into `nodes`.
///
/// The three vectors are always the same length: entry `i` of each describes
/// node `i`.
struct IndexedTree {
    /// Cloned expression for each node, in the order nodes were added.
    nodes: Vec<Expression>,
    /// Index of each node's parent; `None` for a node added without a parent.
    parents: Vec<Option<usize>>,
    /// Indices of each node's direct children, in insertion order.
    children_indices: Vec<Vec<usize>>,
}
122
123impl IndexedTree {
124 fn empty() -> Self {
125 Self {
126 nodes: Vec::new(),
127 parents: Vec::new(),
128 children_indices: Vec::new(),
129 }
130 }
131
132 fn build(root: &Expression) -> Self {
133 let mut tree = Self::empty();
134 tree.add_expr(root, None);
135 tree
136 }
137
138 fn add_expr(&mut self, expr: &Expression, parent_idx: Option<usize>) {
139 if matches!(expr, Expression::Identifier(_)) {
141 return;
142 }
143
144 let idx = self.nodes.len();
145 self.nodes.push(expr.clone());
146 self.parents.push(parent_idx);
147 self.children_indices.push(Vec::new());
148
149 if let Some(p) = parent_idx {
150 self.children_indices[p].push(idx);
151 }
152
153 self.add_children(expr, idx);
154 }
155
    /// Adds the child expressions of `expr` to the tree as children of the
    /// node at `parent_idx`.
    ///
    /// The order in which children are added here determines their node
    /// indices and their order in `children_indices`, which the move
    /// detection in the diff relies on. Variants without a dedicated arm fall
    /// back to the generic `children()` traversal.
    fn add_children(&mut self, expr: &Expression, parent_idx: usize) {
        match expr {
            // SELECT: clauses are visited in a fixed order; CTEs, joins and
            // ORDER BY items are wrapped in their own expression variants so
            // they appear as nodes.
            Expression::Select(select) => {
                if let Some(with) = &select.with {
                    for cte in &with.ctes {
                        self.add_expr(&Expression::Cte(Box::new(cte.clone())), Some(parent_idx));
                    }
                }
                for e in &select.expressions {
                    self.add_expr(e, Some(parent_idx));
                }
                if let Some(from) = &select.from {
                    for e in &from.expressions {
                        self.add_expr(e, Some(parent_idx));
                    }
                }
                for join in &select.joins {
                    self.add_expr(
                        &Expression::Join(Box::new(join.clone())),
                        Some(parent_idx),
                    );
                }
                // WHERE / HAVING / LIMIT / OFFSET contribute their inner
                // expression directly, not a wrapper node.
                if let Some(w) = &select.where_clause {
                    self.add_expr(&w.this, Some(parent_idx));
                }
                if let Some(gb) = &select.group_by {
                    for e in &gb.expressions {
                        self.add_expr(e, Some(parent_idx));
                    }
                }
                if let Some(h) = &select.having {
                    self.add_expr(&h.this, Some(parent_idx));
                }
                if let Some(ob) = &select.order_by {
                    for o in &ob.expressions {
                        self.add_expr(
                            &Expression::Ordered(Box::new(o.clone())),
                            Some(parent_idx),
                        );
                    }
                }
                if let Some(limit) = &select.limit {
                    self.add_expr(&limit.this, Some(parent_idx));
                }
                if let Some(offset) = &select.offset {
                    self.add_expr(&offset.this, Some(parent_idx));
                }
            }
            Expression::Alias(alias) => {
                self.add_expr(&alias.this, Some(parent_idx));
            }
            // Binary operators: left operand first, then right.
            Expression::And(op)
            | Expression::Or(op)
            | Expression::Eq(op)
            | Expression::Neq(op)
            | Expression::Lt(op)
            | Expression::Lte(op)
            | Expression::Gt(op)
            | Expression::Gte(op)
            | Expression::Add(op)
            | Expression::Sub(op)
            | Expression::Mul(op)
            | Expression::Div(op)
            | Expression::Mod(op)
            | Expression::BitwiseAnd(op)
            | Expression::BitwiseOr(op)
            | Expression::BitwiseXor(op)
            | Expression::Concat(op) => {
                self.add_expr(&op.left, Some(parent_idx));
                self.add_expr(&op.right, Some(parent_idx));
            }
            Expression::Like(op) | Expression::ILike(op) => {
                self.add_expr(&op.left, Some(parent_idx));
                self.add_expr(&op.right, Some(parent_idx));
            }
            // Unary operators.
            Expression::Not(u) | Expression::Neg(u) | Expression::BitwiseNot(u) => {
                self.add_expr(&u.this, Some(parent_idx));
            }
            // Function-like nodes: only the arguments become children.
            Expression::Function(func) => {
                for arg in &func.args {
                    self.add_expr(arg, Some(parent_idx));
                }
            }
            Expression::AggregateFunction(func) => {
                for arg in &func.args {
                    self.add_expr(arg, Some(parent_idx));
                }
            }
            Expression::Join(j) => {
                self.add_expr(&j.this, Some(parent_idx));
                if let Some(on) = &j.on {
                    self.add_expr(on, Some(parent_idx));
                }
            }
            Expression::Anonymous(a) => {
                for arg in &a.expressions {
                    self.add_expr(arg, Some(parent_idx));
                }
            }
            // Only the wrapped function expression is added for a window
            // function here.
            Expression::WindowFunction(wf) => {
                self.add_expr(&wf.this, Some(parent_idx));
            }
            Expression::Cast(cast) => {
                self.add_expr(&cast.this, Some(parent_idx));
            }
            Expression::Subquery(sq) => {
                self.add_expr(&sq.this, Some(parent_idx));
            }
            Expression::Paren(p) => {
                self.add_expr(&p.this, Some(parent_idx));
            }
            // Set operations: left query first, then right.
            Expression::Union(u) => {
                self.add_expr(&u.left, Some(parent_idx));
                self.add_expr(&u.right, Some(parent_idx));
            }
            Expression::Intersect(i) => {
                self.add_expr(&i.left, Some(parent_idx));
                self.add_expr(&i.right, Some(parent_idx));
            }
            Expression::Except(e) => {
                self.add_expr(&e.left, Some(parent_idx));
                self.add_expr(&e.right, Some(parent_idx));
            }
            Expression::Cte(cte) => {
                self.add_expr(&cte.this, Some(parent_idx));
            }
            // CASE: optional operand, then each WHEN/THEN pair, then ELSE.
            Expression::Case(c) => {
                if let Some(operand) = &c.operand {
                    self.add_expr(operand, Some(parent_idx));
                }
                for (when, then) in &c.whens {
                    self.add_expr(when, Some(parent_idx));
                    self.add_expr(then, Some(parent_idx));
                }
                if let Some(else_) = &c.else_ {
                    self.add_expr(else_, Some(parent_idx));
                }
            }
            Expression::In(i) => {
                self.add_expr(&i.this, Some(parent_idx));
                for e in &i.expressions {
                    self.add_expr(e, Some(parent_idx));
                }
                if let Some(q) = &i.query {
                    self.add_expr(q, Some(parent_idx));
                }
            }
            Expression::Between(b) => {
                self.add_expr(&b.this, Some(parent_idx));
                self.add_expr(&b.low, Some(parent_idx));
                self.add_expr(&b.high, Some(parent_idx));
            }
            Expression::IsNull(i) => {
                self.add_expr(&i.this, Some(parent_idx));
            }
            Expression::Exists(e) => {
                self.add_expr(&e.this, Some(parent_idx));
            }
            Expression::Ordered(o) => {
                self.add_expr(&o.this, Some(parent_idx));
            }
            Expression::Lambda(l) => {
                self.add_expr(&l.body, Some(parent_idx));
            }
            Expression::Coalesce(c) => {
                for e in &c.expressions {
                    self.add_expr(e, Some(parent_idx));
                }
            }
            Expression::Tuple(t) => {
                for e in &t.expressions {
                    self.add_expr(e, Some(parent_idx));
                }
            }
            Expression::Array(a) => {
                for e in &a.expressions {
                    self.add_expr(e, Some(parent_idx));
                }
            }
            // These variants are treated as leaves: nothing is added below them.
            Expression::Literal(_)
            | Expression::Boolean(_)
            | Expression::Null(_)
            | Expression::Column(_)
            | Expression::Table(_)
            | Expression::Star(_)
            | Expression::DataType(_)
            | Expression::CurrentDate(_)
            | Expression::CurrentTime(_)
            | Expression::CurrentTimestamp(_) => {}
            // Fallback for every other variant: use the generic child
            // traversal, skipping identifiers.
            other => {
                use crate::traversal::ExpressionWalk;
                for child in other.children() {
                    if !matches!(child, Expression::Identifier(_)) {
                        self.add_expr(child, Some(parent_idx));
                    }
                }
            }
        }
    }
357
358 fn is_leaf(&self, idx: usize) -> bool {
359 self.children_indices[idx].is_empty()
360 }
361
362 fn leaf_indices(&self) -> Vec<usize> {
363 (0..self.nodes.len())
364 .filter(|&i| self.is_leaf(i))
365 .collect()
366 }
367
368 fn leaf_descendants(&self, idx: usize) -> Vec<usize> {
370 let mut result = Vec::new();
371 let mut stack = vec![idx];
372 while let Some(i) = stack.pop() {
373 if self.is_leaf(i) {
374 result.push(i);
375 }
376 for &child in &self.children_indices[i] {
377 stack.push(child);
378 }
379 }
380 result
381 }
382}
383
/// Computes the Dice coefficient of two strings over their character bigrams.
///
/// The result is `2 * |common bigrams| / (|bigrams(a)| + |bigrams(b)|)`, where
/// repeated bigrams are counted with multiplicity. When either string has
/// fewer than two characters no bigram exists, so the function falls back to
/// exact equality: `1.0` if the strings are identical, `0.0` otherwise.
fn dice_coefficient(a: &str, b: &str) -> f64 {
    // Count characters, not bytes: a single multi-byte character has a byte
    // length of two or more but still yields no bigram, and two different
    // such strings must not be reported as identical.
    if a.chars().nth(1).is_none() || b.chars().nth(1).is_none() {
        return if a == b { 1.0 } else { 0.0 };
    }

    let a_bigrams = bigram_histo(a);
    let b_bigrams = bigram_histo(b);

    let common: usize = a_bigrams
        .iter()
        .map(|(pair, &count)| count.min(b_bigrams.get(pair).copied().unwrap_or(0)))
        .sum();

    // Both inputs have at least two characters here, so `total` is at least 2.
    let total: usize = a_bigrams.values().sum::<usize>() + b_bigrams.values().sum::<usize>();
    2.0 * common as f64 / total as f64
}

/// Builds a histogram of adjacent character pairs in `s`.
///
/// Strings with fewer than two characters produce an empty map.
fn bigram_histo(s: &str) -> HashMap<(char, char), usize> {
    let chars: Vec<char> = s.chars().collect();
    let mut map = HashMap::new();
    for w in chars.windows(2) {
        *map.entry((w[0], w[1])).or_insert(0) += 1;
    }
    map
}
417
418fn node_sql(expr: &Expression, dialect: Option<DialectType>) -> String {
420 match dialect {
421 Some(d) => {
422 let config = GeneratorConfig {
423 dialect: Some(d),
424 ..GeneratorConfig::default()
425 };
426 let mut gen = Generator::with_config(config);
427 gen.generate(expr).unwrap_or_default()
428 }
429 None => Generator::sql(expr).unwrap_or_default(),
430 }
431}
432
433fn is_same_type(a: &Expression, b: &Expression) -> bool {
438 if std::mem::discriminant(a) != std::mem::discriminant(b) {
439 return false;
440 }
441 match (a, b) {
442 (Expression::Join(ja), Expression::Join(jb)) => ja.kind == jb.kind,
443 (Expression::Anonymous(aa), Expression::Anonymous(ab)) => {
444 Generator::sql(&aa.this).unwrap_or_default()
445 == Generator::sql(&ab.this).unwrap_or_default()
446 }
447 _ => true,
448 }
449}
450
451fn parent_similarity_score(
453 src_idx: usize,
454 tgt_idx: usize,
455 src_tree: &IndexedTree,
456 tgt_tree: &IndexedTree,
457 matchings: &HashMap<usize, usize>,
458) -> usize {
459 let mut score = 0;
460 let mut s = src_tree.parents[src_idx];
461 let mut t = tgt_tree.parents[tgt_idx];
462 while let (Some(sp), Some(tp)) = (s, t) {
463 if matchings.get(&sp) == Some(&tp) {
464 score += 1;
465 s = src_tree.parents[sp];
466 t = tgt_tree.parents[tp];
467 } else {
468 break;
469 }
470 }
471 score
472}
473
474fn is_updatable(expr: &Expression) -> bool {
478 matches!(
479 expr,
480 Expression::Alias(_)
481 | Expression::Boolean(_)
482 | Expression::Column(_)
483 | Expression::DataType(_)
484 | Expression::Lambda(_)
485 | Expression::Literal(_)
486 | Expression::Table(_)
487 | Expression::WindowFunction(_)
488 )
489}
490
491fn has_non_expression_leaf_change(a: &Expression, b: &Expression) -> bool {
495 match (a, b) {
496 (Expression::Union(ua), Expression::Union(ub)) => {
497 ua.all != ub.all || ua.distinct != ub.distinct
498 }
499 (Expression::Intersect(ia), Expression::Intersect(ib)) => {
500 ia.all != ib.all || ia.distinct != ib.distinct
501 }
502 (Expression::Except(ea), Expression::Except(eb)) => {
503 ea.all != eb.all || ea.distinct != eb.distinct
504 }
505 (Expression::Ordered(oa), Expression::Ordered(ob)) => {
506 oa.desc != ob.desc || oa.nulls_first != ob.nulls_first
507 }
508 (Expression::Join(ja), Expression::Join(jb)) => ja.kind != jb.kind,
509 _ => false,
510 }
511}
512
/// Computes a longest common subsequence of `a` and `b` under `eq_fn`.
///
/// Returns the matched positions as `(index_in_a, index_in_b)` pairs in
/// ascending order. Uses the classic dynamic-programming table of size
/// `(a.len() + 1) x (b.len() + 1)`.
fn lcs<T, F>(a: &[T], b: &[T], eq_fn: F) -> Vec<(usize, usize)>
where
    F: Fn(&T, &T) -> bool,
{
    let rows = a.len();
    let cols = b.len();

    // table[r][c] holds the LCS length of a[..r] and b[..c].
    let mut table = vec![vec![0usize; cols + 1]; rows + 1];
    for (r, left) in a.iter().enumerate() {
        for (c, right) in b.iter().enumerate() {
            table[r + 1][c + 1] = if eq_fn(left, right) {
                table[r][c] + 1
            } else {
                table[r][c + 1].max(table[r + 1][c])
            };
        }
    }

    // Walk back from the bottom-right corner; on ties step left in `b`.
    let mut pairs = Vec::new();
    let (mut r, mut c) = (rows, cols);
    while r > 0 && c > 0 {
        if eq_fn(&a[r - 1], &b[c - 1]) {
            r -= 1;
            c -= 1;
            pairs.push((r, c));
        } else if table[r - 1][c] > table[r][c - 1] {
            r -= 1;
        } else {
            c -= 1;
        }
    }
    pairs.reverse();
    pairs
}
547
/// A candidate source/target pairing queued in a max-heap.
///
/// Candidates are ordered by `score`, then `parent_sim`, then `counter`, so
/// popping a `BinaryHeap` yields the highest score first and, among full
/// ties, the most recently pushed candidate. `src_idx` and `tgt_idx` are
/// payload only and do not take part in ordering or equality.
struct MatchCandidate {
    /// Similarity score of the pair.
    score: f64,
    /// Number of consecutive already-matched ancestors of the pair.
    parent_sim: usize,
    /// Insertion sequence number, used as the final tie-breaker.
    counter: usize,
    /// Index of the candidate node in the source tree.
    src_idx: usize,
    /// Index of the candidate node in the target tree.
    tgt_idx: usize,
}

// Equality is defined through `cmp` so that `==` and `Ordering::Equal` always
// agree (as the `Ord` contract requires) and so that `Eq` stays reflexive even
// if a score were NaN.
impl PartialEq for MatchCandidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for MatchCandidate {}

impl PartialOrd for MatchCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for MatchCandidate {
    fn cmp(&self, other: &Self) -> Ordering {
        self.score
            .partial_cmp(&other.score)
            // Incomparable scores (NaN) are treated as equal and fall through
            // to the tie-breakers.
            .unwrap_or(Ordering::Equal)
            .then_with(|| self.parent_sim.cmp(&other.parent_sim))
            .then_with(|| self.counter.cmp(&other.counter))
    }
}
578
/// Working state for one diff run: the configuration, both indexed trees, and
/// the node matchings discovered so far.
struct ChangeDistiller {
    /// Thresholds and dialect used for similarity comparisons.
    config: DiffConfig,
    /// Indexed form of the source expression.
    src_tree: IndexedTree,
    /// Indexed form of the target expression.
    tgt_tree: IndexedTree,
    /// Maps a source node index to the target node index it is matched with.
    matchings: HashMap<usize, usize>,
}
589
590impl ChangeDistiller {
591 fn new(config: DiffConfig) -> Self {
592 Self {
593 config,
594 src_tree: IndexedTree::empty(),
595 tgt_tree: IndexedTree::empty(),
596 matchings: HashMap::new(),
597 }
598 }
599
600 fn diff(
601 &mut self,
602 source: &Expression,
603 target: &Expression,
604 delta_only: bool,
605 ) -> Vec<Edit> {
606 self.src_tree = IndexedTree::build(source);
607 self.tgt_tree = IndexedTree::build(target);
608
609 self.match_leaves();
611
612 self.match_internal_nodes();
614
615 self.generate_edits(delta_only)
617 }
618
619 fn match_leaves(&mut self) {
622 let src_leaves = self.src_tree.leaf_indices();
623 let tgt_leaves = self.tgt_tree.leaf_indices();
624
625 let src_sql: Vec<String> = src_leaves
627 .iter()
628 .map(|&i| node_sql(&self.src_tree.nodes[i], self.config.dialect))
629 .collect();
630 let tgt_sql: Vec<String> = tgt_leaves
631 .iter()
632 .map(|&i| node_sql(&self.tgt_tree.nodes[i], self.config.dialect))
633 .collect();
634
635 let mut heap = BinaryHeap::new();
636 let mut counter = 0usize;
637
638 for (si_pos, &si) in src_leaves.iter().enumerate() {
639 for (ti_pos, &ti) in tgt_leaves.iter().enumerate() {
640 if !is_same_type(&self.src_tree.nodes[si], &self.tgt_tree.nodes[ti]) {
641 continue;
642 }
643 let score = dice_coefficient(&src_sql[si_pos], &tgt_sql[ti_pos]);
644 if score >= self.config.t {
645 let parent_sim = parent_similarity_score(
646 si,
647 ti,
648 &self.src_tree,
649 &self.tgt_tree,
650 &self.matchings,
651 );
652 heap.push(MatchCandidate {
653 score,
654 parent_sim,
655 counter,
656 src_idx: si,
657 tgt_idx: ti,
658 });
659 counter += 1;
660 }
661 }
662 }
663
664 let mut matched_src: HashSet<usize> = HashSet::new();
665 let mut matched_tgt: HashSet<usize> = HashSet::new();
666
667 while let Some(m) = heap.pop() {
668 if matched_src.contains(&m.src_idx) || matched_tgt.contains(&m.tgt_idx) {
669 continue;
670 }
671 self.matchings.insert(m.src_idx, m.tgt_idx);
672 matched_src.insert(m.src_idx);
673 matched_tgt.insert(m.tgt_idx);
674 }
675 }
676
677 fn match_internal_nodes(&mut self) {
680 let src_internal: Vec<usize> = (0..self.src_tree.nodes.len())
683 .rev()
684 .filter(|&i| !self.src_tree.is_leaf(i) && !self.matchings.contains_key(&i))
685 .collect();
686
687 let tgt_internal: Vec<usize> = (0..self.tgt_tree.nodes.len())
688 .rev()
689 .filter(|&i| !self.tgt_tree.is_leaf(i))
690 .collect();
691
692 let mut matched_tgt: HashSet<usize> = self.matchings.values().cloned().collect();
693
694 let mut heap = BinaryHeap::new();
695 let mut counter = 0usize;
696
697 for &si in &src_internal {
698 let src_leaves: HashSet<usize> =
699 self.src_tree.leaf_descendants(si).into_iter().collect();
700 let src_sql = node_sql(&self.src_tree.nodes[si], self.config.dialect);
701
702 for &ti in &tgt_internal {
703 if matched_tgt.contains(&ti) {
704 continue;
705 }
706 if !is_same_type(&self.src_tree.nodes[si], &self.tgt_tree.nodes[ti]) {
707 continue;
708 }
709
710 let tgt_leaves: HashSet<usize> =
711 self.tgt_tree.leaf_descendants(ti).into_iter().collect();
712
713 let common = src_leaves
715 .iter()
716 .filter(|&&sl| {
717 self.matchings
718 .get(&sl)
719 .map_or(false, |&tl| tgt_leaves.contains(&tl))
720 })
721 .count();
722
723 let max_leaves = src_leaves.len().max(tgt_leaves.len());
724 if max_leaves == 0 {
725 continue;
726 }
727
728 let leaf_sim = common as f64 / max_leaves as f64;
729
730 let t = if src_leaves.len().min(tgt_leaves.len()) <= 4 {
732 0.4
733 } else {
734 self.config.t
735 };
736
737 let tgt_sql = node_sql(&self.tgt_tree.nodes[ti], self.config.dialect);
738 let dice = dice_coefficient(&src_sql, &tgt_sql);
739
740 if leaf_sim >= 0.8 || (leaf_sim >= t && dice >= self.config.f) {
741 heap.push(MatchCandidate {
742 score: leaf_sim,
743 parent_sim: parent_similarity_score(
744 si,
745 ti,
746 &self.src_tree,
747 &self.tgt_tree,
748 &self.matchings,
749 ),
750 counter,
751 src_idx: si,
752 tgt_idx: ti,
753 });
754 counter += 1;
755 }
756 }
757 }
758
759 while let Some(m) = heap.pop() {
760 if self.matchings.contains_key(&m.src_idx) || matched_tgt.contains(&m.tgt_idx) {
761 continue;
762 }
763 self.matchings.insert(m.src_idx, m.tgt_idx);
764 matched_tgt.insert(m.tgt_idx);
765 }
766 }
767
768 fn generate_edits(&self, delta_only: bool) -> Vec<Edit> {
771 let mut edits = Vec::new();
772 let matched_tgt: HashSet<usize> = self.matchings.values().cloned().collect();
773
774 let reverse_matchings: HashMap<usize, usize> = self
776 .matchings
777 .iter()
778 .map(|(&s, &t)| (t, s))
779 .collect();
780
781 let mut moved_src: HashSet<usize> = HashSet::new();
783
784 for (&src_parent, &tgt_parent) in &self.matchings {
785 if self.src_tree.is_leaf(src_parent) {
786 continue;
787 }
788
789 let src_children = &self.src_tree.children_indices[src_parent];
790 let tgt_children = &self.tgt_tree.children_indices[tgt_parent];
791
792 if src_children.is_empty() || tgt_children.is_empty() {
793 continue;
794 }
795
796 let src_seq: Vec<usize> = src_children
798 .iter()
799 .filter_map(|&sc| self.matchings.get(&sc).cloned())
800 .collect();
801
802 let tgt_seq: Vec<usize> = tgt_children
804 .iter()
805 .filter(|&&tc| reverse_matchings.contains_key(&tc))
806 .cloned()
807 .collect();
808
809 let lcs_pairs = lcs(&src_seq, &tgt_seq, |a, b| a == b);
810 let lcs_tgt_set: HashSet<usize> = lcs_pairs.iter().map(|&(i, _)| src_seq[i]).collect();
811
812 for &sc in src_children {
814 if let Some(&tc) = self.matchings.get(&sc) {
815 if !lcs_tgt_set.contains(&tc) {
816 moved_src.insert(sc);
817 }
818 }
819 }
820 }
821
822 for i in 0..self.src_tree.nodes.len() {
824 if !self.matchings.contains_key(&i) {
825 edits.push(Edit::Remove {
826 expression: self.src_tree.nodes[i].clone(),
827 });
828 }
829 }
830
831 for i in 0..self.tgt_tree.nodes.len() {
833 if !matched_tgt.contains(&i) {
834 edits.push(Edit::Insert {
835 expression: self.tgt_tree.nodes[i].clone(),
836 });
837 }
838 }
839
840 for (&src_idx, &tgt_idx) in &self.matchings {
842 let src_node = &self.src_tree.nodes[src_idx];
843 let tgt_node = &self.tgt_tree.nodes[tgt_idx];
844
845 let src_sql = node_sql(src_node, self.config.dialect);
846 let tgt_sql = node_sql(tgt_node, self.config.dialect);
847
848 if is_updatable(src_node) && src_sql != tgt_sql {
849 edits.push(Edit::Update {
850 source: src_node.clone(),
851 target: tgt_node.clone(),
852 });
853 } else if has_non_expression_leaf_change(src_node, tgt_node) {
854 edits.push(Edit::Update {
855 source: src_node.clone(),
856 target: tgt_node.clone(),
857 });
858 } else if moved_src.contains(&src_idx) {
859 edits.push(Edit::Move {
860 source: src_node.clone(),
861 target: tgt_node.clone(),
862 });
863 } else if !delta_only {
864 edits.push(Edit::Keep {
865 source: src_node.clone(),
866 target: tgt_node.clone(),
867 });
868 }
869 }
870
871 edits
872 }
873}
874
// Unit tests for the diff: each test parses a pair of SQL statements with the
// generic dialect and checks which kinds of edits are produced.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::{Dialect, DialectType};

    /// Parses `sql` with the generic dialect and returns the first statement.
    /// Panics if parsing fails or yields no statement.
    fn parse(sql: &str) -> Expression {
        let dialect = Dialect::get(DialectType::Generic);
        let ast = dialect.parse(sql).unwrap();
        ast.into_iter().next().unwrap()
    }

    /// Counts edits per kind, returned as
    /// `(insert, remove, move, update, keep)`.
    fn count_edits(edits: &[Edit]) -> (usize, usize, usize, usize, usize) {
        let mut insert = 0;
        let mut remove = 0;
        let mut r#move = 0;
        let mut update = 0;
        let mut keep = 0;
        for e in edits {
            match e {
                Edit::Insert { .. } => insert += 1,
                Edit::Remove { .. } => remove += 1,
                Edit::Move { .. } => r#move += 1,
                Edit::Update { .. } => update += 1,
                Edit::Keep { .. } => keep += 1,
            }
        }
        (insert, remove, r#move, update, keep)
    }

    // Identical queries must produce nothing but Keep edits.
    #[test]
    fn test_diff_identical() {
        let source = parse("SELECT a FROM t");
        let target = parse("SELECT a FROM t");

        let edits = diff(&source, &target, false);

        assert!(
            edits.iter().all(|e| matches!(e, Edit::Keep { .. })),
            "Expected only Keep edits, got: {:?}",
            count_edits(&edits)
        );
    }

    // Single-character column names share no bigram, so the columns are not
    // matched and the change shows up as Remove + Insert.
    #[test]
    fn test_diff_simple_change() {
        let source = parse("SELECT a FROM t");
        let target = parse("SELECT b FROM t");

        let edits = diff(&source, &target, true);

        assert!(!edits.is_empty());
        assert!(has_changes(&edits));
        let (ins, rem, _, _, _) = count_edits(&edits);
        assert!(ins > 0 && rem > 0, "Expected Insert+Remove, got ins={ins} rem={rem}");
    }

    // Similar column names are matched and reported as an Update.
    #[test]
    fn test_diff_similar_column_update() {
        let source = parse("SELECT col_a FROM t");
        let target = parse("SELECT col_b FROM t");

        let edits = diff(&source, &target, true);

        assert!(has_changes(&edits));
        assert!(
            edits.iter().any(|e| matches!(e, Edit::Update { .. })),
            "Expected Update for similar column name change"
        );
    }

    // Different operator variants are different node types, so the operator
    // node is removed and a new one inserted.
    #[test]
    fn test_operator_change() {
        let source = parse("SELECT a + b FROM t");
        let target = parse("SELECT a - b FROM t");

        let edits = diff(&source, &target, true);

        assert!(!edits.is_empty());
        let (ins, rem, _, _, _) = count_edits(&edits);
        assert!(
            ins > 0 && rem > 0,
            "Expected Insert and Remove for operator change, got ins={ins} rem={rem}"
        );
    }

    // Adding a projection must yield at least one Insert.
    #[test]
    fn test_column_added() {
        let source = parse("SELECT a, b FROM t");
        let target = parse("SELECT a, b, c FROM t");

        let edits = diff(&source, &target, true);

        assert!(
            edits.iter().any(|e| matches!(e, Edit::Insert { .. })),
            "Expected at least one Insert edit for added column"
        );
    }

    // Dropping a projection must yield at least one Remove.
    #[test]
    fn test_column_removed() {
        let source = parse("SELECT a, b, c FROM t");
        let target = parse("SELECT a, c FROM t");

        let edits = diff(&source, &target, true);

        assert!(
            edits.iter().any(|e| matches!(e, Edit::Remove { .. })),
            "Expected at least one Remove edit for removed column"
        );
    }

    // Similar table names are matched and reported as an Update.
    #[test]
    fn test_table_updated() {
        let source = parse("SELECT a FROM table_one");
        let target = parse("SELECT a FROM table_two");

        let edits = diff(&source, &target, true);

        assert!(!edits.is_empty());
        assert!(has_changes(&edits));
        assert!(
            edits.iter().any(|e| matches!(e, Edit::Update { .. })),
            "Expected Update for table name change"
        );
    }

    // Renaming a lambda parameter must be detected as some change.
    #[test]
    fn test_lambda() {
        let source = parse("SELECT TRANSFORM(arr, a -> a + 1) FROM t");
        let target = parse("SELECT TRANSFORM(arr, b -> b + 1) FROM t");

        let edits = diff(&source, &target, true);

        assert!(has_changes(&edits));
    }

    // Reordering projections must produce at least one Move.
    #[test]
    fn test_node_position_changed() {
        let source = parse("SELECT a, b, c FROM t");
        let target = parse("SELECT c, a, b FROM t");

        let edits = diff(&source, &target, false);

        let (_, _, moves, _, _) = count_edits(&edits);
        assert!(moves > 0, "Expected at least one Move for reordered columns");
    }

    // A literal change inside a CTE body is reported as an Update.
    #[test]
    fn test_cte_changes() {
        let source = parse("WITH cte AS (SELECT a FROM t WHERE a > 1000) SELECT * FROM cte");
        let target = parse("WITH cte AS (SELECT a FROM t WHERE a > 2000) SELECT * FROM cte");

        let edits = diff(&source, &target, true);

        assert!(has_changes(&edits));
        assert!(
            edits.iter().any(|e| matches!(e, Edit::Update { .. })),
            "Expected Update for literal change in CTE"
        );
    }

    // Joins of different kinds are never matched (see `is_same_type`), so a
    // kind change shows up as Remove + Insert.
    #[test]
    fn test_join_changes() {
        let source = parse("SELECT a FROM t LEFT JOIN s ON t.id = s.id");
        let target = parse("SELECT a FROM t RIGHT JOIN s ON t.id = s.id");

        let edits = diff(&source, &target, true);

        assert!(has_changes(&edits));
        let (ins, rem, _, _, _) = count_edits(&edits);
        assert!(
            ins > 0 && rem > 0,
            "Expected Insert+Remove for join kind change, got ins={ins} rem={rem}"
        );
    }

    // Swapping the window function must be detected as some change.
    #[test]
    fn test_window_functions() {
        let source = parse("SELECT ROW_NUMBER() OVER (ORDER BY a) FROM t");
        let target = parse("SELECT RANK() OVER (ORDER BY a) FROM t");

        let edits = diff(&source, &target, true);

        assert!(has_changes(&edits));
    }

    // A change in a non-expression attribute (UNION vs UNION ALL) is reported
    // as an Update on the set-operation node.
    #[test]
    fn test_non_expression_leaf_delta() {
        let source = parse("SELECT a FROM t UNION SELECT b FROM s");
        let target = parse("SELECT a FROM t UNION ALL SELECT b FROM s");

        let edits = diff(&source, &target, true);

        assert!(has_changes(&edits));
        assert!(
            edits.iter().any(|e| matches!(e, Edit::Update { .. })),
            "Expected Update for UNION → UNION ALL"
        );
    }

    // The root of a SELECT is not a leaf, and every reported leaf has no
    // children.
    #[test]
    fn test_is_leaf() {
        let tree = IndexedTree::build(&parse("SELECT a, 1 FROM t"));
        assert!(!tree.is_leaf(0));
        let leaves = tree.leaf_indices();
        assert!(!leaves.is_empty());
        for &l in &leaves {
            assert!(tree.children_indices[l].is_empty());
        }
    }

    // `is_same_type` compares enum variants, with an extra join-kind check.
    #[test]
    fn test_same_type_special_cases() {
        // Number and string literals share the `Literal` variant.
        let a = Expression::Literal(crate::expressions::Literal::Number("1".to_string()));
        let b = Expression::Literal(crate::expressions::Literal::String("abc".to_string()));
        assert!(is_same_type(&a, &b));

        // Different variants are never the same type.
        let c = Expression::Null(crate::expressions::Null);
        assert!(!is_same_type(&a, &c));

        // Joins that differ only in kind are not the same type.
        let join_left = Expression::Join(Box::new(crate::expressions::Join {
            this: Expression::Table(crate::expressions::TableRef::new("t")),
            on: None,
            using: vec![],
            kind: crate::expressions::JoinKind::Left,
            use_inner_keyword: false,
            use_outer_keyword: false,
            deferred_condition: false,
            join_hint: None,
            match_condition: None,
            pivots: vec![],
        }));
        let join_right = Expression::Join(Box::new(crate::expressions::Join {
            this: Expression::Table(crate::expressions::TableRef::new("t")),
            on: None,
            using: vec![],
            kind: crate::expressions::JoinKind::Right,
            use_inner_keyword: false,
            use_outer_keyword: false,
            deferred_condition: false,
            join_hint: None,
            match_condition: None,
            pivots: vec![],
        }));
        assert!(!is_same_type(&join_left, &join_right));
    }

    // NOTE(review): both inputs are the same comment-free SQL, so as written
    // this only checks that identical queries yield no changes — confirm
    // whether one side was meant to contain a SQL comment.
    #[test]
    fn test_comments_excluded() {
        let source = parse("SELECT a FROM t");
        let target = parse("SELECT a FROM t");

        let edits = diff(&source, &target, true);

        assert!(edits.is_empty() || !has_changes(&edits));
    }
}