1use crate::dialects::DialectType;
9use crate::expressions::Expression;
10use crate::generator::{Generator, GeneratorConfig};
11use serde::{Deserialize, Serialize};
12use std::cmp::Ordering;
13use std::collections::{BinaryHeap, HashMap, HashSet};
14
15#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
17#[serde(tag = "type", rename_all = "snake_case")]
18pub enum Edit {
19 Insert { expression: Expression },
21 Remove { expression: Expression },
23 Move {
25 source: Expression,
26 target: Expression,
27 },
28 Update {
30 source: Expression,
31 target: Expression,
32 },
33 Keep {
35 source: Expression,
36 target: Expression,
37 },
38}
39
40impl Edit {
41 pub fn is_change(&self) -> bool {
43 !matches!(self, Edit::Keep { .. })
44 }
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct DiffConfig {
50 pub f: f64,
52 pub t: f64,
54 pub dialect: Option<DialectType>,
56}
57
58impl Default for DiffConfig {
59 fn default() -> Self {
60 Self {
61 f: 0.6,
62 t: 0.6,
63 dialect: None,
64 }
65 }
66}
67
68pub fn diff(source: &Expression, target: &Expression, delta_only: bool) -> Vec<Edit> {
89 let config = DiffConfig::default();
90 diff_with_config(source, target, delta_only, &config)
91}
92
93pub fn diff_with_config(
95 source: &Expression,
96 target: &Expression,
97 delta_only: bool,
98 config: &DiffConfig,
99) -> Vec<Edit> {
100 let mut distiller = ChangeDistiller::new(config.clone());
101 distiller.diff(source, target, delta_only)
102}
103
104pub fn has_changes(edits: &[Edit]) -> bool {
106 edits.iter().any(|e| e.is_change())
107}
108
109pub fn changes_only(edits: Vec<Edit>) -> Vec<Edit> {
111 edits.into_iter().filter(|e| e.is_change()).collect()
112}
113
114struct IndexedTree {
120 nodes: Vec<Expression>,
121 parents: Vec<Option<usize>>,
122 children_indices: Vec<Vec<usize>>,
123}
124
125impl IndexedTree {
126 fn empty() -> Self {
127 Self {
128 nodes: Vec::new(),
129 parents: Vec::new(),
130 children_indices: Vec::new(),
131 }
132 }
133
134 fn build(root: &Expression) -> Self {
135 let mut tree = Self::empty();
136 tree.add_expr(root, None);
137 tree
138 }
139
140 fn add_expr(&mut self, expr: &Expression, parent_idx: Option<usize>) {
141 if matches!(expr, Expression::Identifier(_)) {
143 return;
144 }
145
146 let idx = self.nodes.len();
147 self.nodes.push(expr.clone());
148 self.parents.push(parent_idx);
149 self.children_indices.push(Vec::new());
150
151 if let Some(p) = parent_idx {
152 self.children_indices[p].push(idx);
153 }
154
155 self.add_children(expr, idx);
156 }
157
158 fn add_children(&mut self, expr: &Expression, parent_idx: usize) {
159 match expr {
160 Expression::Select(select) => {
161 if let Some(with) = &select.with {
162 for cte in &with.ctes {
163 self.add_expr(&Expression::Cte(Box::new(cte.clone())), Some(parent_idx));
164 }
165 }
166 for e in &select.expressions {
167 self.add_expr(e, Some(parent_idx));
168 }
169 if let Some(from) = &select.from {
170 for e in &from.expressions {
171 self.add_expr(e, Some(parent_idx));
172 }
173 }
174 for join in &select.joins {
175 self.add_expr(&Expression::Join(Box::new(join.clone())), Some(parent_idx));
176 }
177 if let Some(w) = &select.where_clause {
178 self.add_expr(&w.this, Some(parent_idx));
179 }
180 if let Some(gb) = &select.group_by {
181 for e in &gb.expressions {
182 self.add_expr(e, Some(parent_idx));
183 }
184 }
185 if let Some(h) = &select.having {
186 self.add_expr(&h.this, Some(parent_idx));
187 }
188 if let Some(ob) = &select.order_by {
189 for o in &ob.expressions {
190 self.add_expr(&Expression::Ordered(Box::new(o.clone())), Some(parent_idx));
191 }
192 }
193 if let Some(limit) = &select.limit {
194 self.add_expr(&limit.this, Some(parent_idx));
195 }
196 if let Some(offset) = &select.offset {
197 self.add_expr(&offset.this, Some(parent_idx));
198 }
199 }
200 Expression::Alias(alias) => {
201 self.add_expr(&alias.this, Some(parent_idx));
202 }
203 Expression::And(op)
204 | Expression::Or(op)
205 | Expression::Eq(op)
206 | Expression::Neq(op)
207 | Expression::Lt(op)
208 | Expression::Lte(op)
209 | Expression::Gt(op)
210 | Expression::Gte(op)
211 | Expression::Add(op)
212 | Expression::Sub(op)
213 | Expression::Mul(op)
214 | Expression::Div(op)
215 | Expression::Mod(op)
216 | Expression::BitwiseAnd(op)
217 | Expression::BitwiseOr(op)
218 | Expression::BitwiseXor(op)
219 | Expression::Concat(op) => {
220 self.add_expr(&op.left, Some(parent_idx));
221 self.add_expr(&op.right, Some(parent_idx));
222 }
223 Expression::Like(op) | Expression::ILike(op) => {
224 self.add_expr(&op.left, Some(parent_idx));
225 self.add_expr(&op.right, Some(parent_idx));
226 }
227 Expression::Not(u) | Expression::Neg(u) | Expression::BitwiseNot(u) => {
228 self.add_expr(&u.this, Some(parent_idx));
229 }
230 Expression::Function(func) => {
231 for arg in &func.args {
232 self.add_expr(arg, Some(parent_idx));
233 }
234 }
235 Expression::AggregateFunction(func) => {
236 for arg in &func.args {
237 self.add_expr(arg, Some(parent_idx));
238 }
239 }
240 Expression::Join(j) => {
241 self.add_expr(&j.this, Some(parent_idx));
242 if let Some(on) = &j.on {
243 self.add_expr(on, Some(parent_idx));
244 }
245 }
246 Expression::Anonymous(a) => {
247 for arg in &a.expressions {
248 self.add_expr(arg, Some(parent_idx));
249 }
250 }
251 Expression::WindowFunction(wf) => {
252 self.add_expr(&wf.this, Some(parent_idx));
253 }
254 Expression::Cast(cast) => {
255 self.add_expr(&cast.this, Some(parent_idx));
256 }
257 Expression::Subquery(sq) => {
258 self.add_expr(&sq.this, Some(parent_idx));
259 }
260 Expression::Paren(p) => {
261 self.add_expr(&p.this, Some(parent_idx));
262 }
263 Expression::Union(u) => {
264 self.add_expr(&u.left, Some(parent_idx));
265 self.add_expr(&u.right, Some(parent_idx));
266 }
267 Expression::Intersect(i) => {
268 self.add_expr(&i.left, Some(parent_idx));
269 self.add_expr(&i.right, Some(parent_idx));
270 }
271 Expression::Except(e) => {
272 self.add_expr(&e.left, Some(parent_idx));
273 self.add_expr(&e.right, Some(parent_idx));
274 }
275 Expression::Cte(cte) => {
276 self.add_expr(&cte.this, Some(parent_idx));
277 }
278 Expression::Case(c) => {
279 if let Some(operand) = &c.operand {
280 self.add_expr(operand, Some(parent_idx));
281 }
282 for (when, then) in &c.whens {
283 self.add_expr(when, Some(parent_idx));
284 self.add_expr(then, Some(parent_idx));
285 }
286 if let Some(else_) = &c.else_ {
287 self.add_expr(else_, Some(parent_idx));
288 }
289 }
290 Expression::In(i) => {
291 self.add_expr(&i.this, Some(parent_idx));
292 for e in &i.expressions {
293 self.add_expr(e, Some(parent_idx));
294 }
295 if let Some(q) = &i.query {
296 self.add_expr(q, Some(parent_idx));
297 }
298 }
299 Expression::Between(b) => {
300 self.add_expr(&b.this, Some(parent_idx));
301 self.add_expr(&b.low, Some(parent_idx));
302 self.add_expr(&b.high, Some(parent_idx));
303 }
304 Expression::IsNull(i) => {
305 self.add_expr(&i.this, Some(parent_idx));
306 }
307 Expression::Exists(e) => {
308 self.add_expr(&e.this, Some(parent_idx));
309 }
310 Expression::Ordered(o) => {
311 self.add_expr(&o.this, Some(parent_idx));
312 }
313 Expression::Lambda(l) => {
314 self.add_expr(&l.body, Some(parent_idx));
315 }
316 Expression::Coalesce(c) => {
317 for e in &c.expressions {
318 self.add_expr(e, Some(parent_idx));
319 }
320 }
321 Expression::Tuple(t) => {
322 for e in &t.expressions {
323 self.add_expr(e, Some(parent_idx));
324 }
325 }
326 Expression::Array(a) => {
327 for e in &a.expressions {
328 self.add_expr(e, Some(parent_idx));
329 }
330 }
331 Expression::Literal(_)
333 | Expression::Boolean(_)
334 | Expression::Null(_)
335 | Expression::Column(_)
336 | Expression::Table(_)
337 | Expression::Star(_)
338 | Expression::DataType(_)
339 | Expression::CurrentDate(_)
340 | Expression::CurrentTime(_)
341 | Expression::CurrentTimestamp(_) => {}
342 other => {
344 use crate::traversal::ExpressionWalk;
345 for child in other.children() {
346 if !matches!(child, Expression::Identifier(_)) {
347 self.add_expr(child, Some(parent_idx));
348 }
349 }
350 }
351 }
352 }
353
354 fn is_leaf(&self, idx: usize) -> bool {
355 self.children_indices[idx].is_empty()
356 }
357
358 fn leaf_indices(&self) -> Vec<usize> {
359 (0..self.nodes.len()).filter(|&i| self.is_leaf(i)).collect()
360 }
361
362 fn leaf_descendants(&self, idx: usize) -> Vec<usize> {
364 let mut result = Vec::new();
365 let mut stack = vec![idx];
366 while let Some(i) = stack.pop() {
367 if self.is_leaf(i) {
368 result.push(i);
369 }
370 for &child in &self.children_indices[i] {
371 stack.push(child);
372 }
373 }
374 result
375 }
376}
377
/// Sørensen–Dice similarity of two strings, computed over character bigrams.
///
/// Returns a value in `[0.0, 1.0]`: twice the number of shared bigrams (counted
/// with multiplicity) divided by the total number of bigrams in both strings.
/// When neither string has a bigram (each is shorter than two *characters*), the
/// result is `1.0` for equal strings and `0.0` otherwise.
fn dice_coefficient(a: &str, b: &str) -> f64 {
    let a_bigrams = bigram_histo(a);
    let b_bigrams = bigram_histo(b);

    let total: usize = a_bigrams.values().sum::<usize>() + b_bigrams.values().sum::<usize>();
    if total == 0 {
        // Decided by equality rather than a byte-length check: two different
        // single multi-byte characters (e.g. "é" vs "ü") have >= 2 bytes each
        // but no bigrams, and must not be reported as identical.
        return if a == b { 1.0 } else { 0.0 };
    }

    let common: usize = a_bigrams
        .iter()
        .map(|(pair, &count)| count.min(b_bigrams.get(pair).copied().unwrap_or(0)))
        .sum();
    2.0 * common as f64 / total as f64
}

/// Counts each adjacent character pair in `s`.
fn bigram_histo(s: &str) -> HashMap<(char, char), usize> {
    let mut histo = HashMap::new();
    for pair in s.chars().zip(s.chars().skip(1)) {
        *histo.entry(pair).or_insert(0) += 1;
    }
    histo
}
411
412fn node_sql(expr: &Expression, dialect: Option<DialectType>) -> String {
414 match dialect {
415 Some(d) => {
416 let config = GeneratorConfig {
417 dialect: Some(d),
418 ..GeneratorConfig::default()
419 };
420 let mut gen = Generator::with_config(config);
421 gen.generate(expr).unwrap_or_default()
422 }
423 None => Generator::sql(expr).unwrap_or_default(),
424 }
425}
426
427fn is_same_type(a: &Expression, b: &Expression) -> bool {
432 if std::mem::discriminant(a) != std::mem::discriminant(b) {
433 return false;
434 }
435 match (a, b) {
436 (Expression::Join(ja), Expression::Join(jb)) => ja.kind == jb.kind,
437 (Expression::Anonymous(aa), Expression::Anonymous(ab)) => {
438 Generator::sql(&aa.this).unwrap_or_default()
439 == Generator::sql(&ab.this).unwrap_or_default()
440 }
441 _ => true,
442 }
443}
444
445fn parent_similarity_score(
447 src_idx: usize,
448 tgt_idx: usize,
449 src_tree: &IndexedTree,
450 tgt_tree: &IndexedTree,
451 matchings: &HashMap<usize, usize>,
452) -> usize {
453 let mut score = 0;
454 let mut s = src_tree.parents[src_idx];
455 let mut t = tgt_tree.parents[tgt_idx];
456 while let (Some(sp), Some(tp)) = (s, t) {
457 if matchings.get(&sp) == Some(&tp) {
458 score += 1;
459 s = src_tree.parents[sp];
460 t = tgt_tree.parents[tp];
461 } else {
462 break;
463 }
464 }
465 score
466}
467
468fn is_updatable(expr: &Expression) -> bool {
472 matches!(
473 expr,
474 Expression::Alias(_)
475 | Expression::Boolean(_)
476 | Expression::Column(_)
477 | Expression::DataType(_)
478 | Expression::Lambda(_)
479 | Expression::Literal(_)
480 | Expression::Table(_)
481 | Expression::WindowFunction(_)
482 )
483}
484
485fn has_non_expression_leaf_change(a: &Expression, b: &Expression) -> bool {
489 match (a, b) {
490 (Expression::Union(ua), Expression::Union(ub)) => {
491 ua.all != ub.all || ua.distinct != ub.distinct
492 }
493 (Expression::Intersect(ia), Expression::Intersect(ib)) => {
494 ia.all != ib.all || ia.distinct != ib.distinct
495 }
496 (Expression::Except(ea), Expression::Except(eb)) => {
497 ea.all != eb.all || ea.distinct != eb.distinct
498 }
499 (Expression::Ordered(oa), Expression::Ordered(ob)) => {
500 oa.desc != ob.desc || oa.nulls_first != ob.nulls_first
501 }
502 (Expression::Join(ja), Expression::Join(jb)) => ja.kind != jb.kind,
503 _ => false,
504 }
505}
506
/// Longest common subsequence of `a` and `b` under the equality predicate `eq_fn`.
///
/// Returns the matched index pairs `(i, j)` in increasing order. When the
/// elements at the current backtracking position differ and both directions
/// keep the same LCS length, backtracking steps back in `b`.
fn lcs<T, F>(a: &[T], b: &[T], eq_fn: F) -> Vec<(usize, usize)>
where
    F: Fn(&T, &T) -> bool,
{
    // Row-major (a.len() + 1) x (b.len() + 1) table; cell (i, j) holds the LCS
    // length of a[..i] and b[..j].
    let width = b.len() + 1;
    let at = |i: usize, j: usize| i * width + j;
    let mut table = vec![0usize; (a.len() + 1) * width];

    for (i, x) in a.iter().enumerate() {
        for (j, y) in b.iter().enumerate() {
            table[at(i + 1, j + 1)] = if eq_fn(x, y) {
                table[at(i, j)] + 1
            } else {
                table[at(i, j + 1)].max(table[at(i + 1, j)])
            };
        }
    }

    let mut pairs = Vec::new();
    let (mut i, mut j) = (a.len(), b.len());
    while i > 0 && j > 0 {
        if eq_fn(&a[i - 1], &b[j - 1]) {
            pairs.push((i - 1, j - 1));
            i -= 1;
            j -= 1;
        } else if table[at(i - 1, j)] > table[at(i, j - 1)] {
            i -= 1;
        } else {
            j -= 1;
        }
    }
    pairs.reverse();
    pairs
}
541
/// A proposed (source node, target node) pairing, ordered for a max-heap.
///
/// `BinaryHeap` pops the greatest element first, so ordering is: higher
/// `score` first, then higher `parent_sim`, then *lower* `counter` (the
/// candidate pushed earliest wins ties).
struct MatchCandidate {
    /// Similarity score; higher is better.
    score: f64,
    /// Number of consecutively matched ancestor pairs; higher is better.
    parent_sim: usize,
    /// Insertion sequence number, used only to break ties deterministically.
    counter: usize,
    /// Source node index.
    src_idx: usize,
    /// Target node index.
    tgt_idx: usize,
}

// Equality is defined through `Ord` so the two agree, as `Ord` requires.
impl PartialEq for MatchCandidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for MatchCandidate {}

impl PartialOrd for MatchCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for MatchCandidate {
    fn cmp(&self, other: &Self) -> Ordering {
        self.score
            .partial_cmp(&other.score)
            .unwrap_or(Ordering::Equal)
            .then_with(|| self.parent_sim.cmp(&other.parent_sim))
            // Reversed: in a max-heap, an ascending counter comparison would pop the
            // most recently pushed candidate first (LIFO). Comparing `other` to
            // `self` makes earlier candidates win ties (FIFO).
            .then_with(|| other.counter.cmp(&self.counter))
    }
}
572
573struct ChangeDistiller {
578 config: DiffConfig,
579 src_tree: IndexedTree,
580 tgt_tree: IndexedTree,
581 matchings: HashMap<usize, usize>, }
583
584impl ChangeDistiller {
585 fn new(config: DiffConfig) -> Self {
586 Self {
587 config,
588 src_tree: IndexedTree::empty(),
589 tgt_tree: IndexedTree::empty(),
590 matchings: HashMap::new(),
591 }
592 }
593
594 fn diff(&mut self, source: &Expression, target: &Expression, delta_only: bool) -> Vec<Edit> {
595 self.src_tree = IndexedTree::build(source);
596 self.tgt_tree = IndexedTree::build(target);
597
598 self.match_leaves();
600
601 self.match_internal_nodes();
603
604 self.generate_edits(delta_only)
606 }
607
608 fn match_leaves(&mut self) {
611 let src_leaves = self.src_tree.leaf_indices();
612 let tgt_leaves = self.tgt_tree.leaf_indices();
613
614 let src_sql: Vec<String> = src_leaves
616 .iter()
617 .map(|&i| node_sql(&self.src_tree.nodes[i], self.config.dialect))
618 .collect();
619 let tgt_sql: Vec<String> = tgt_leaves
620 .iter()
621 .map(|&i| node_sql(&self.tgt_tree.nodes[i], self.config.dialect))
622 .collect();
623
624 let mut heap = BinaryHeap::new();
625 let mut counter = 0usize;
626
627 for (si_pos, &si) in src_leaves.iter().enumerate() {
628 for (ti_pos, &ti) in tgt_leaves.iter().enumerate() {
629 if !is_same_type(&self.src_tree.nodes[si], &self.tgt_tree.nodes[ti]) {
630 continue;
631 }
632 let score = dice_coefficient(&src_sql[si_pos], &tgt_sql[ti_pos]);
633 if score >= self.config.t {
634 let parent_sim = parent_similarity_score(
635 si,
636 ti,
637 &self.src_tree,
638 &self.tgt_tree,
639 &self.matchings,
640 );
641 heap.push(MatchCandidate {
642 score,
643 parent_sim,
644 counter,
645 src_idx: si,
646 tgt_idx: ti,
647 });
648 counter += 1;
649 }
650 }
651 }
652
653 let mut matched_src: HashSet<usize> = HashSet::new();
654 let mut matched_tgt: HashSet<usize> = HashSet::new();
655
656 while let Some(m) = heap.pop() {
657 if matched_src.contains(&m.src_idx) || matched_tgt.contains(&m.tgt_idx) {
658 continue;
659 }
660 self.matchings.insert(m.src_idx, m.tgt_idx);
661 matched_src.insert(m.src_idx);
662 matched_tgt.insert(m.tgt_idx);
663 }
664 }
665
666 fn match_internal_nodes(&mut self) {
669 let src_internal: Vec<usize> = (0..self.src_tree.nodes.len())
672 .rev()
673 .filter(|&i| !self.src_tree.is_leaf(i) && !self.matchings.contains_key(&i))
674 .collect();
675
676 let tgt_internal: Vec<usize> = (0..self.tgt_tree.nodes.len())
677 .rev()
678 .filter(|&i| !self.tgt_tree.is_leaf(i))
679 .collect();
680
681 let mut matched_tgt: HashSet<usize> = self.matchings.values().cloned().collect();
682
683 let mut heap = BinaryHeap::new();
684 let mut counter = 0usize;
685
686 for &si in &src_internal {
687 let src_leaves: HashSet<usize> =
688 self.src_tree.leaf_descendants(si).into_iter().collect();
689 let src_sql = node_sql(&self.src_tree.nodes[si], self.config.dialect);
690
691 for &ti in &tgt_internal {
692 if matched_tgt.contains(&ti) {
693 continue;
694 }
695 if !is_same_type(&self.src_tree.nodes[si], &self.tgt_tree.nodes[ti]) {
696 continue;
697 }
698
699 let tgt_leaves: HashSet<usize> =
700 self.tgt_tree.leaf_descendants(ti).into_iter().collect();
701
702 let common = src_leaves
704 .iter()
705 .filter(|&&sl| {
706 self.matchings
707 .get(&sl)
708 .map_or(false, |&tl| tgt_leaves.contains(&tl))
709 })
710 .count();
711
712 let max_leaves = src_leaves.len().max(tgt_leaves.len());
713 if max_leaves == 0 {
714 continue;
715 }
716
717 let leaf_sim = common as f64 / max_leaves as f64;
718
719 let t = if src_leaves.len().min(tgt_leaves.len()) <= 4 {
721 0.4
722 } else {
723 self.config.t
724 };
725
726 let tgt_sql = node_sql(&self.tgt_tree.nodes[ti], self.config.dialect);
727 let dice = dice_coefficient(&src_sql, &tgt_sql);
728
729 if leaf_sim >= 0.8 || (leaf_sim >= t && dice >= self.config.f) {
730 heap.push(MatchCandidate {
731 score: leaf_sim,
732 parent_sim: parent_similarity_score(
733 si,
734 ti,
735 &self.src_tree,
736 &self.tgt_tree,
737 &self.matchings,
738 ),
739 counter,
740 src_idx: si,
741 tgt_idx: ti,
742 });
743 counter += 1;
744 }
745 }
746 }
747
748 while let Some(m) = heap.pop() {
749 if self.matchings.contains_key(&m.src_idx) || matched_tgt.contains(&m.tgt_idx) {
750 continue;
751 }
752 self.matchings.insert(m.src_idx, m.tgt_idx);
753 matched_tgt.insert(m.tgt_idx);
754 }
755 }
756
757 fn generate_edits(&self, delta_only: bool) -> Vec<Edit> {
760 let mut edits = Vec::new();
761 let matched_tgt: HashSet<usize> = self.matchings.values().cloned().collect();
762
763 let reverse_matchings: HashMap<usize, usize> =
765 self.matchings.iter().map(|(&s, &t)| (t, s)).collect();
766
767 let mut moved_src: HashSet<usize> = HashSet::new();
769
770 for (&src_parent, &tgt_parent) in &self.matchings {
771 if self.src_tree.is_leaf(src_parent) {
772 continue;
773 }
774
775 let src_children = &self.src_tree.children_indices[src_parent];
776 let tgt_children = &self.tgt_tree.children_indices[tgt_parent];
777
778 if src_children.is_empty() || tgt_children.is_empty() {
779 continue;
780 }
781
782 let src_seq: Vec<usize> = src_children
784 .iter()
785 .filter_map(|&sc| self.matchings.get(&sc).cloned())
786 .collect();
787
788 let tgt_seq: Vec<usize> = tgt_children
790 .iter()
791 .filter(|&&tc| reverse_matchings.contains_key(&tc))
792 .cloned()
793 .collect();
794
795 let lcs_pairs = lcs(&src_seq, &tgt_seq, |a, b| a == b);
796 let lcs_tgt_set: HashSet<usize> = lcs_pairs.iter().map(|&(i, _)| src_seq[i]).collect();
797
798 for &sc in src_children {
800 if let Some(&tc) = self.matchings.get(&sc) {
801 if !lcs_tgt_set.contains(&tc) {
802 moved_src.insert(sc);
803 }
804 }
805 }
806 }
807
808 for i in 0..self.src_tree.nodes.len() {
810 if !self.matchings.contains_key(&i) {
811 edits.push(Edit::Remove {
812 expression: self.src_tree.nodes[i].clone(),
813 });
814 }
815 }
816
817 for i in 0..self.tgt_tree.nodes.len() {
819 if !matched_tgt.contains(&i) {
820 edits.push(Edit::Insert {
821 expression: self.tgt_tree.nodes[i].clone(),
822 });
823 }
824 }
825
826 for (&src_idx, &tgt_idx) in &self.matchings {
828 let src_node = &self.src_tree.nodes[src_idx];
829 let tgt_node = &self.tgt_tree.nodes[tgt_idx];
830
831 let src_sql = node_sql(src_node, self.config.dialect);
832 let tgt_sql = node_sql(tgt_node, self.config.dialect);
833
834 if is_updatable(src_node) && src_sql != tgt_sql {
835 edits.push(Edit::Update {
836 source: src_node.clone(),
837 target: tgt_node.clone(),
838 });
839 } else if has_non_expression_leaf_change(src_node, tgt_node) {
840 edits.push(Edit::Update {
841 source: src_node.clone(),
842 target: tgt_node.clone(),
843 });
844 } else if moved_src.contains(&src_idx) {
845 edits.push(Edit::Move {
846 source: src_node.clone(),
847 target: tgt_node.clone(),
848 });
849 } else if !delta_only {
850 edits.push(Edit::Keep {
851 source: src_node.clone(),
852 target: tgt_node.clone(),
853 });
854 }
855 }
856
857 edits
858 }
859}
860
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::{Dialect, DialectType};
    use crate::expressions::{Join, JoinKind, Literal, Null, TableRef};

    /// Parses `sql` with the generic dialect and returns its first statement.
    fn parse(sql: &str) -> Expression {
        let dialect = Dialect::get(DialectType::Generic);
        let statements = dialect.parse(sql).unwrap();
        statements.into_iter().next().unwrap()
    }

    /// Parses both statements and diffs them.
    fn diff_sql(source: &str, target: &str, delta_only: bool) -> Vec<Edit> {
        diff(&parse(source), &parse(target), delta_only)
    }

    /// Tallies edits as `(insert, remove, move, update, keep)`.
    fn count_edits(edits: &[Edit]) -> (usize, usize, usize, usize, usize) {
        edits
            .iter()
            .fold((0, 0, 0, 0, 0), |(ins, rem, mov, upd, keep), edit| match edit {
                Edit::Insert { .. } => (ins + 1, rem, mov, upd, keep),
                Edit::Remove { .. } => (ins, rem + 1, mov, upd, keep),
                Edit::Move { .. } => (ins, rem, mov + 1, upd, keep),
                Edit::Update { .. } => (ins, rem, mov, upd + 1, keep),
                Edit::Keep { .. } => (ins, rem, mov, upd, keep + 1),
            })
    }

    fn has_insert(edits: &[Edit]) -> bool {
        edits.iter().any(|e| matches!(e, Edit::Insert { .. }))
    }

    fn has_remove(edits: &[Edit]) -> bool {
        edits.iter().any(|e| matches!(e, Edit::Remove { .. }))
    }

    fn has_update(edits: &[Edit]) -> bool {
        edits.iter().any(|e| matches!(e, Edit::Update { .. }))
    }

    /// Builds a join on table `t` with the given kind and all other fields defaulted.
    fn join_of_kind(kind: JoinKind) -> Expression {
        Expression::Join(Box::new(Join {
            this: Expression::Table(TableRef::new("t")),
            on: None,
            using: vec![],
            kind,
            use_inner_keyword: false,
            use_outer_keyword: false,
            deferred_condition: false,
            join_hint: None,
            match_condition: None,
            pivots: vec![],
            comments: vec![],
            nesting_group: 0,
            directed: false,
        }))
    }

    #[test]
    fn test_diff_identical() {
        let edits = diff_sql("SELECT a FROM t", "SELECT a FROM t", false);
        // "No change" is equivalent to "every edit is a Keep".
        assert!(
            !has_changes(&edits),
            "Expected only Keep edits, got: {:?}",
            count_edits(&edits)
        );
    }

    #[test]
    fn test_diff_simple_change() {
        let edits = diff_sql("SELECT a FROM t", "SELECT b FROM t", true);
        assert!(!edits.is_empty());
        assert!(has_changes(&edits));
        let (ins, rem, ..) = count_edits(&edits);
        assert!(
            ins > 0 && rem > 0,
            "Expected Insert+Remove, got ins={ins} rem={rem}"
        );
    }

    #[test]
    fn test_diff_similar_column_update() {
        let edits = diff_sql("SELECT col_a FROM t", "SELECT col_b FROM t", true);
        assert!(has_changes(&edits));
        assert!(
            has_update(&edits),
            "Expected Update for similar column name change"
        );
    }

    #[test]
    fn test_operator_change() {
        let edits = diff_sql("SELECT a + b FROM t", "SELECT a - b FROM t", true);
        assert!(!edits.is_empty());
        let (ins, rem, ..) = count_edits(&edits);
        assert!(
            ins > 0 && rem > 0,
            "Expected Insert and Remove for operator change, got ins={ins} rem={rem}"
        );
    }

    #[test]
    fn test_column_added() {
        let edits = diff_sql("SELECT a, b FROM t", "SELECT a, b, c FROM t", true);
        assert!(
            has_insert(&edits),
            "Expected at least one Insert edit for added column"
        );
    }

    #[test]
    fn test_column_removed() {
        let edits = diff_sql("SELECT a, b, c FROM t", "SELECT a, c FROM t", true);
        assert!(
            has_remove(&edits),
            "Expected at least one Remove edit for removed column"
        );
    }

    #[test]
    fn test_table_updated() {
        let edits = diff_sql("SELECT a FROM table_one", "SELECT a FROM table_two", true);
        assert!(!edits.is_empty());
        assert!(has_changes(&edits));
        assert!(has_update(&edits), "Expected Update for table name change");
    }

    #[test]
    fn test_lambda() {
        let edits = diff_sql(
            "SELECT TRANSFORM(arr, a -> a + 1) FROM t",
            "SELECT TRANSFORM(arr, b -> b + 1) FROM t",
            true,
        );
        assert!(has_changes(&edits));
    }

    #[test]
    fn test_node_position_changed() {
        let edits = diff_sql("SELECT a, b, c FROM t", "SELECT c, a, b FROM t", false);
        let (_, _, moves, _, _) = count_edits(&edits);
        assert!(
            moves > 0,
            "Expected at least one Move for reordered columns"
        );
    }

    #[test]
    fn test_cte_changes() {
        let edits = diff_sql(
            "WITH cte AS (SELECT a FROM t WHERE a > 1000) SELECT * FROM cte",
            "WITH cte AS (SELECT a FROM t WHERE a > 2000) SELECT * FROM cte",
            true,
        );
        assert!(has_changes(&edits));
        assert!(has_update(&edits), "Expected Update for literal change in CTE");
    }

    #[test]
    fn test_join_changes() {
        let edits = diff_sql(
            "SELECT a FROM t LEFT JOIN s ON t.id = s.id",
            "SELECT a FROM t RIGHT JOIN s ON t.id = s.id",
            true,
        );
        assert!(has_changes(&edits));
        let (ins, rem, ..) = count_edits(&edits);
        assert!(
            ins > 0 && rem > 0,
            "Expected Insert+Remove for join kind change, got ins={ins} rem={rem}"
        );
    }

    #[test]
    fn test_window_functions() {
        let edits = diff_sql(
            "SELECT ROW_NUMBER() OVER (ORDER BY a) FROM t",
            "SELECT RANK() OVER (ORDER BY a) FROM t",
            true,
        );
        assert!(has_changes(&edits));
    }

    #[test]
    fn test_non_expression_leaf_delta() {
        let edits = diff_sql(
            "SELECT a FROM t UNION SELECT b FROM s",
            "SELECT a FROM t UNION ALL SELECT b FROM s",
            true,
        );
        assert!(has_changes(&edits));
        assert!(has_update(&edits), "Expected Update for UNION → UNION ALL");
    }

    #[test]
    fn test_is_leaf() {
        let tree = IndexedTree::build(&parse("SELECT a, 1 FROM t"));
        // The SELECT root always has children.
        assert!(!tree.is_leaf(0));

        let leaves = tree.leaf_indices();
        assert!(!leaves.is_empty());
        assert!(leaves
            .iter()
            .all(|&leaf| tree.children_indices[leaf].is_empty()));
    }

    #[test]
    fn test_same_type_special_cases() {
        let number = Expression::Literal(Literal::Number("1".to_string()));
        let string = Expression::Literal(Literal::String("abc".to_string()));
        // Literal kinds do not matter, only the `Expression` variant.
        assert!(is_same_type(&number, &string));

        let null = Expression::Null(Null);
        assert!(!is_same_type(&number, &null));

        // Joins additionally require the same kind.
        let left = join_of_kind(JoinKind::Left);
        let right = join_of_kind(JoinKind::Right);
        assert!(!is_same_type(&left, &right));
    }

    #[test]
    fn test_comments_excluded() {
        let edits = diff_sql("SELECT a FROM t", "SELECT a FROM t", true);
        assert!(edits.is_empty() || !has_changes(&edits));
    }
}