Skip to main content

polyglot_sql/
diff.rs

1//! AST Diff - Compare SQL expressions
2//!
3//! This module provides functionality to compare two SQL ASTs and generate
4//! a list of differences (edit script) between them, using the ChangeDistiller
5//! algorithm with Dice coefficient matching.
6//!
7
8use crate::dialects::DialectType;
9use crate::expressions::Expression;
10use crate::generator::{Generator, GeneratorConfig};
11use std::cmp::Ordering;
12use std::collections::{BinaryHeap, HashMap, HashSet};
13
14/// Types of edits that can occur between two ASTs
15#[derive(Debug, Clone, PartialEq)]
16pub enum Edit {
17    /// A new node has been inserted
18    Insert { expression: Expression },
19    /// An existing node has been removed
20    Remove { expression: Expression },
21    /// An existing node's position has changed
22    Move {
23        source: Expression,
24        target: Expression,
25    },
26    /// An existing node has been updated (same position, different value)
27    Update {
28        source: Expression,
29        target: Expression,
30    },
31    /// An existing node hasn't been changed
32    Keep {
33        source: Expression,
34        target: Expression,
35    },
36}
37
38impl Edit {
39    /// Check if this edit represents a change (not a Keep)
40    pub fn is_change(&self) -> bool {
41        !matches!(self, Edit::Keep { .. })
42    }
43}
44
45/// Configuration for the diff algorithm
46#[derive(Debug, Clone)]
47pub struct DiffConfig {
48    /// Dice coefficient threshold for internal node matching (default 0.6)
49    pub f: f64,
50    /// Leaf similarity threshold (default 0.6)
51    pub t: f64,
52    /// Optional dialect for SQL generation during comparison
53    pub dialect: Option<DialectType>,
54}
55
56impl Default for DiffConfig {
57    fn default() -> Self {
58        Self {
59            f: 0.6,
60            t: 0.6,
61            dialect: None,
62        }
63    }
64}
65
66/// Compare two expressions and return a list of edits
67///
68/// # Arguments
69/// * `source` - The source expression
70/// * `target` - The target expression to compare against
71/// * `delta_only` - If true, exclude Keep edits from the result
72///
73/// # Returns
74/// A vector of Edit operations that transform source into target
75///
76/// # Example
77/// ```ignore
78/// use polyglot_sql::diff::diff;
79/// use polyglot_sql::parse_one;
80/// use polyglot_sql::DialectType;
81///
82/// let source = parse_one("SELECT a + b FROM t", DialectType::Generic).unwrap();
83/// let target = parse_one("SELECT a + c FROM t", DialectType::Generic).unwrap();
84/// let edits = diff(&source, &target, false);
85/// ```
86pub fn diff(source: &Expression, target: &Expression, delta_only: bool) -> Vec<Edit> {
87    let config = DiffConfig::default();
88    diff_with_config(source, target, delta_only, &config)
89}
90
91/// Compare two expressions with custom configuration
92pub fn diff_with_config(
93    source: &Expression,
94    target: &Expression,
95    delta_only: bool,
96    config: &DiffConfig,
97) -> Vec<Edit> {
98    let mut distiller = ChangeDistiller::new(config.clone());
99    distiller.diff(source, target, delta_only)
100}
101
102/// Check if the diff contains any changes
103pub fn has_changes(edits: &[Edit]) -> bool {
104    edits.iter().any(|e| e.is_change())
105}
106
107/// Get only the changes from an edit list
108pub fn changes_only(edits: Vec<Edit>) -> Vec<Edit> {
109    edits.into_iter().filter(|e| e.is_change()).collect()
110}
111
112// ---------------------------------------------------------------------------
// IndexedTree: flat pre-order (depth-first) representation with parent-child tracking
114// ---------------------------------------------------------------------------
115
116/// Flattened tree representation with explicit parent/child index maps.
117struct IndexedTree {
118    nodes: Vec<Expression>,
119    parents: Vec<Option<usize>>,
120    children_indices: Vec<Vec<usize>>,
121}
122
123impl IndexedTree {
124    fn empty() -> Self {
125        Self {
126            nodes: Vec::new(),
127            parents: Vec::new(),
128            children_indices: Vec::new(),
129        }
130    }
131
132    fn build(root: &Expression) -> Self {
133        let mut tree = Self::empty();
134        tree.add_expr(root, None);
135        tree
136    }
137
138    fn add_expr(&mut self, expr: &Expression, parent_idx: Option<usize>) {
139        // Skip bare Identifier nodes — they're names, not diff targets
140        if matches!(expr, Expression::Identifier(_)) {
141            return;
142        }
143
144        let idx = self.nodes.len();
145        self.nodes.push(expr.clone());
146        self.parents.push(parent_idx);
147        self.children_indices.push(Vec::new());
148
149        if let Some(p) = parent_idx {
150            self.children_indices[p].push(idx);
151        }
152
153        self.add_children(expr, idx);
154    }
155
156    fn add_children(&mut self, expr: &Expression, parent_idx: usize) {
157        match expr {
158            Expression::Select(select) => {
159                if let Some(with) = &select.with {
160                    for cte in &with.ctes {
161                        self.add_expr(&Expression::Cte(Box::new(cte.clone())), Some(parent_idx));
162                    }
163                }
164                for e in &select.expressions {
165                    self.add_expr(e, Some(parent_idx));
166                }
167                if let Some(from) = &select.from {
168                    for e in &from.expressions {
169                        self.add_expr(e, Some(parent_idx));
170                    }
171                }
172                for join in &select.joins {
173                    self.add_expr(
174                        &Expression::Join(Box::new(join.clone())),
175                        Some(parent_idx),
176                    );
177                }
178                if let Some(w) = &select.where_clause {
179                    self.add_expr(&w.this, Some(parent_idx));
180                }
181                if let Some(gb) = &select.group_by {
182                    for e in &gb.expressions {
183                        self.add_expr(e, Some(parent_idx));
184                    }
185                }
186                if let Some(h) = &select.having {
187                    self.add_expr(&h.this, Some(parent_idx));
188                }
189                if let Some(ob) = &select.order_by {
190                    for o in &ob.expressions {
191                        self.add_expr(
192                            &Expression::Ordered(Box::new(o.clone())),
193                            Some(parent_idx),
194                        );
195                    }
196                }
197                if let Some(limit) = &select.limit {
198                    self.add_expr(&limit.this, Some(parent_idx));
199                }
200                if let Some(offset) = &select.offset {
201                    self.add_expr(&offset.this, Some(parent_idx));
202                }
203            }
204            Expression::Alias(alias) => {
205                self.add_expr(&alias.this, Some(parent_idx));
206            }
207            Expression::And(op)
208            | Expression::Or(op)
209            | Expression::Eq(op)
210            | Expression::Neq(op)
211            | Expression::Lt(op)
212            | Expression::Lte(op)
213            | Expression::Gt(op)
214            | Expression::Gte(op)
215            | Expression::Add(op)
216            | Expression::Sub(op)
217            | Expression::Mul(op)
218            | Expression::Div(op)
219            | Expression::Mod(op)
220            | Expression::BitwiseAnd(op)
221            | Expression::BitwiseOr(op)
222            | Expression::BitwiseXor(op)
223            | Expression::Concat(op) => {
224                self.add_expr(&op.left, Some(parent_idx));
225                self.add_expr(&op.right, Some(parent_idx));
226            }
227            Expression::Like(op) | Expression::ILike(op) => {
228                self.add_expr(&op.left, Some(parent_idx));
229                self.add_expr(&op.right, Some(parent_idx));
230            }
231            Expression::Not(u) | Expression::Neg(u) | Expression::BitwiseNot(u) => {
232                self.add_expr(&u.this, Some(parent_idx));
233            }
234            Expression::Function(func) => {
235                for arg in &func.args {
236                    self.add_expr(arg, Some(parent_idx));
237                }
238            }
239            Expression::AggregateFunction(func) => {
240                for arg in &func.args {
241                    self.add_expr(arg, Some(parent_idx));
242                }
243            }
244            Expression::Join(j) => {
245                self.add_expr(&j.this, Some(parent_idx));
246                if let Some(on) = &j.on {
247                    self.add_expr(on, Some(parent_idx));
248                }
249            }
250            Expression::Anonymous(a) => {
251                for arg in &a.expressions {
252                    self.add_expr(arg, Some(parent_idx));
253                }
254            }
255            Expression::WindowFunction(wf) => {
256                self.add_expr(&wf.this, Some(parent_idx));
257            }
258            Expression::Cast(cast) => {
259                self.add_expr(&cast.this, Some(parent_idx));
260            }
261            Expression::Subquery(sq) => {
262                self.add_expr(&sq.this, Some(parent_idx));
263            }
264            Expression::Paren(p) => {
265                self.add_expr(&p.this, Some(parent_idx));
266            }
267            Expression::Union(u) => {
268                self.add_expr(&u.left, Some(parent_idx));
269                self.add_expr(&u.right, Some(parent_idx));
270            }
271            Expression::Intersect(i) => {
272                self.add_expr(&i.left, Some(parent_idx));
273                self.add_expr(&i.right, Some(parent_idx));
274            }
275            Expression::Except(e) => {
276                self.add_expr(&e.left, Some(parent_idx));
277                self.add_expr(&e.right, Some(parent_idx));
278            }
279            Expression::Cte(cte) => {
280                self.add_expr(&cte.this, Some(parent_idx));
281            }
282            Expression::Case(c) => {
283                if let Some(operand) = &c.operand {
284                    self.add_expr(operand, Some(parent_idx));
285                }
286                for (when, then) in &c.whens {
287                    self.add_expr(when, Some(parent_idx));
288                    self.add_expr(then, Some(parent_idx));
289                }
290                if let Some(else_) = &c.else_ {
291                    self.add_expr(else_, Some(parent_idx));
292                }
293            }
294            Expression::In(i) => {
295                self.add_expr(&i.this, Some(parent_idx));
296                for e in &i.expressions {
297                    self.add_expr(e, Some(parent_idx));
298                }
299                if let Some(q) = &i.query {
300                    self.add_expr(q, Some(parent_idx));
301                }
302            }
303            Expression::Between(b) => {
304                self.add_expr(&b.this, Some(parent_idx));
305                self.add_expr(&b.low, Some(parent_idx));
306                self.add_expr(&b.high, Some(parent_idx));
307            }
308            Expression::IsNull(i) => {
309                self.add_expr(&i.this, Some(parent_idx));
310            }
311            Expression::Exists(e) => {
312                self.add_expr(&e.this, Some(parent_idx));
313            }
314            Expression::Ordered(o) => {
315                self.add_expr(&o.this, Some(parent_idx));
316            }
317            Expression::Lambda(l) => {
318                self.add_expr(&l.body, Some(parent_idx));
319            }
320            Expression::Coalesce(c) => {
321                for e in &c.expressions {
322                    self.add_expr(e, Some(parent_idx));
323                }
324            }
325            Expression::Tuple(t) => {
326                for e in &t.expressions {
327                    self.add_expr(e, Some(parent_idx));
328                }
329            }
330            Expression::Array(a) => {
331                for e in &a.expressions {
332                    self.add_expr(e, Some(parent_idx));
333                }
334            }
335            // Leaf nodes — no children to add
336            Expression::Literal(_)
337            | Expression::Boolean(_)
338            | Expression::Null(_)
339            | Expression::Column(_)
340            | Expression::Table(_)
341            | Expression::Star(_)
342            | Expression::DataType(_)
343            | Expression::CurrentDate(_)
344            | Expression::CurrentTime(_)
345            | Expression::CurrentTimestamp(_) => {}
346            // Fallback: use ExpressionWalk::children()
347            other => {
348                use crate::traversal::ExpressionWalk;
349                for child in other.children() {
350                    if !matches!(child, Expression::Identifier(_)) {
351                        self.add_expr(child, Some(parent_idx));
352                    }
353                }
354            }
355        }
356    }
357
358    fn is_leaf(&self, idx: usize) -> bool {
359        self.children_indices[idx].is_empty()
360    }
361
362    fn leaf_indices(&self) -> Vec<usize> {
363        (0..self.nodes.len())
364            .filter(|&i| self.is_leaf(i))
365            .collect()
366    }
367
368    /// Get all leaf descendants of a node (including itself if it is a leaf).
369    fn leaf_descendants(&self, idx: usize) -> Vec<usize> {
370        let mut result = Vec::new();
371        let mut stack = vec![idx];
372        while let Some(i) = stack.pop() {
373            if self.is_leaf(i) {
374                result.push(i);
375            }
376            for &child in &self.children_indices[i] {
377                stack.push(child);
378            }
379        }
380        result
381    }
382}
383
384// ---------------------------------------------------------------------------
385// Helper functions
386// ---------------------------------------------------------------------------
387
/// Compute the Dice coefficient over the character bigrams of two strings.
///
/// The result is `2 * |common bigrams| / (|bigrams(a)| + |bigrams(b)|)`,
/// where bigrams are counted with multiplicity, so it lies in `[0.0, 1.0]`.
///
/// A string with fewer than two *characters* has no bigrams; in that case
/// the function falls back to exact equality (1.0 if equal, 0.0 otherwise).
fn dice_coefficient(a: &str, b: &str) -> f64 {
    // Bigrams are built from `char`s, so the "too short" test must count
    // chars as well. Using the byte length would let a single multi-byte
    // character (e.g. "é") through with an empty histogram.
    let too_short = |s: &str| s.chars().nth(1).is_none();
    if too_short(a) || too_short(b) {
        return if a == b { 1.0 } else { 0.0 };
    }

    let a_bigrams = bigram_histo(a);
    let b_bigrams = bigram_histo(b);

    // Multiset intersection: each bigram counts as often as it occurs in both.
    let common: usize = a_bigrams
        .iter()
        .map(|(pair, &count)| count.min(b_bigrams.get(pair).copied().unwrap_or(0)))
        .sum();

    // Both inputs have at least two chars here, so `total` is at least 2.
    let total: usize = a_bigrams.values().sum::<usize>() + b_bigrams.values().sum::<usize>();
    2.0 * common as f64 / total as f64
}

/// Build a frequency histogram of adjacent character pairs in `s`.
///
/// Returns an empty map when `s` has fewer than two characters.
fn bigram_histo(s: &str) -> HashMap<(char, char), usize> {
    let chars: Vec<char> = s.chars().collect();
    let mut histogram = HashMap::new();
    for pair in chars.windows(2) {
        *histogram.entry((pair[0], pair[1])).or_insert(0) += 1;
    }
    histogram
}
417
418/// Generate SQL string for an expression, optionally with a dialect.
419fn node_sql(expr: &Expression, dialect: Option<DialectType>) -> String {
420    match dialect {
421        Some(d) => {
422            let config = GeneratorConfig {
423                dialect: Some(d),
424                ..GeneratorConfig::default()
425            };
426            let mut gen = Generator::with_config(config);
427            gen.generate(expr).unwrap_or_default()
428        }
429        None => Generator::sql(expr).unwrap_or_default(),
430    }
431}
432
433/// Check if two expressions are the same type for matching purposes.
434///
435/// Uses discriminant comparison with special cases for Join (must share kind)
436/// and Anonymous (must share function name).
437fn is_same_type(a: &Expression, b: &Expression) -> bool {
438    if std::mem::discriminant(a) != std::mem::discriminant(b) {
439        return false;
440    }
441    match (a, b) {
442        (Expression::Join(ja), Expression::Join(jb)) => ja.kind == jb.kind,
443        (Expression::Anonymous(aa), Expression::Anonymous(ab)) => {
444            Generator::sql(&aa.this).unwrap_or_default()
445                == Generator::sql(&ab.this).unwrap_or_default()
446        }
447        _ => true,
448    }
449}
450
451/// Count matching ancestor chain depth for parent similarity tiebreaker.
452fn parent_similarity_score(
453    src_idx: usize,
454    tgt_idx: usize,
455    src_tree: &IndexedTree,
456    tgt_tree: &IndexedTree,
457    matchings: &HashMap<usize, usize>,
458) -> usize {
459    let mut score = 0;
460    let mut s = src_tree.parents[src_idx];
461    let mut t = tgt_tree.parents[tgt_idx];
462    while let (Some(sp), Some(tp)) = (s, t) {
463        if matchings.get(&sp) == Some(&tp) {
464            score += 1;
465            s = src_tree.parents[sp];
466            t = tgt_tree.parents[tp];
467        } else {
468            break;
469        }
470    }
471    score
472}
473
474/// Check if an expression is an updatable leaf type.
475///
476/// Updatable types emit Update edits when matched but different.
477fn is_updatable(expr: &Expression) -> bool {
478    matches!(
479        expr,
480        Expression::Alias(_)
481            | Expression::Boolean(_)
482            | Expression::Column(_)
483            | Expression::DataType(_)
484            | Expression::Lambda(_)
485            | Expression::Literal(_)
486            | Expression::Table(_)
487            | Expression::WindowFunction(_)
488    )
489}
490
491/// Check if non-expression leaf fields differ between two matched same-type nodes.
492///
493/// These are scalar fields (booleans, enums) that aren't child expressions.
494fn has_non_expression_leaf_change(a: &Expression, b: &Expression) -> bool {
495    match (a, b) {
496        (Expression::Union(ua), Expression::Union(ub)) => {
497            ua.all != ub.all || ua.distinct != ub.distinct
498        }
499        (Expression::Intersect(ia), Expression::Intersect(ib)) => {
500            ia.all != ib.all || ia.distinct != ib.distinct
501        }
502        (Expression::Except(ea), Expression::Except(eb)) => {
503            ea.all != eb.all || ea.distinct != eb.distinct
504        }
505        (Expression::Ordered(oa), Expression::Ordered(ob)) => {
506            oa.desc != ob.desc || oa.nulls_first != ob.nulls_first
507        }
508        (Expression::Join(ja), Expression::Join(jb)) => ja.kind != jb.kind,
509        _ => false,
510    }
511}
512
/// Longest common subsequence of `a` and `b` under `eq_fn`.
///
/// Returns the matched positions as `(index_in_a, index_in_b)` pairs in
/// ascending order.
fn lcs<T, F>(a: &[T], b: &[T], eq_fn: F) -> Vec<(usize, usize)>
where
    F: Fn(&T, &T) -> bool,
{
    let (rows, cols) = (a.len(), b.len());

    // table[r][c] = LCS length of a[..r] and b[..c]; row 0 / column 0 stay zero.
    let mut table = vec![vec![0usize; cols + 1]; rows + 1];
    for (r, item_a) in a.iter().enumerate() {
        for (c, item_b) in b.iter().enumerate() {
            table[r + 1][c + 1] = if eq_fn(item_a, item_b) {
                table[r][c] + 1
            } else {
                table[r][c + 1].max(table[r + 1][c])
            };
        }
    }

    // Walk back from the bottom-right corner, collecting matches in reverse.
    let mut pairs = Vec::new();
    let (mut r, mut c) = (rows, cols);
    while r > 0 && c > 0 {
        if eq_fn(&a[r - 1], &b[c - 1]) {
            pairs.push((r - 1, c - 1));
            r -= 1;
            c -= 1;
        } else if table[r - 1][c] > table[r][c - 1] {
            r -= 1;
        } else {
            // Ties step left in `b`.
            c -= 1;
        }
    }
    pairs.reverse();
    pairs
}
547
548// ---------------------------------------------------------------------------
549// BinaryHeap entry for greedy matching
550// ---------------------------------------------------------------------------
551
/// A candidate source/target pairing queued in a `BinaryHeap`.
///
/// Candidates are ordered by `score`, then `parent_sim`, then `counter`, so
/// the heap (a max-heap) pops the highest-scoring pairing first. `src_idx`
/// and `tgt_idx` are payload only and take no part in ordering or equality.
struct MatchCandidate {
    /// Similarity score of the pairing.
    score: f64,
    /// Number of matched ancestors shared by the pair.
    parent_sim: usize,
    /// Insertion counter used as the final tiebreaker for deterministic ordering.
    counter: usize,
    /// Index of the node in the source tree.
    src_idx: usize,
    /// Index of the node in the target tree.
    tgt_idx: usize,
}

// Equality is defined through `cmp` so that `a == b` holds exactly when
// `a.cmp(&b) == Ordering::Equal`, as the `Ord`/`Eq` contract requires.
// A derived `PartialEq` would also compare the payload indices (and treat
// NaN scores as unequal to themselves), disagreeing with `Ord`.
impl PartialEq for MatchCandidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for MatchCandidate {}

impl PartialOrd for MatchCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for MatchCandidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // Incomparable scores (NaN) are treated as equal and fall through
        // to the tiebreakers.
        self.score
            .partial_cmp(&other.score)
            .unwrap_or(Ordering::Equal)
            .then_with(|| self.parent_sim.cmp(&other.parent_sim))
            .then_with(|| self.counter.cmp(&other.counter))
    }
}
578
579// ---------------------------------------------------------------------------
580// ChangeDistiller: three-phase algorithm
581// ---------------------------------------------------------------------------
582
583struct ChangeDistiller {
584    config: DiffConfig,
585    src_tree: IndexedTree,
586    tgt_tree: IndexedTree,
587    matchings: HashMap<usize, usize>, // src_idx -> tgt_idx
588}
589
590impl ChangeDistiller {
591    fn new(config: DiffConfig) -> Self {
592        Self {
593            config,
594            src_tree: IndexedTree::empty(),
595            tgt_tree: IndexedTree::empty(),
596            matchings: HashMap::new(),
597        }
598    }
599
600    fn diff(
601        &mut self,
602        source: &Expression,
603        target: &Expression,
604        delta_only: bool,
605    ) -> Vec<Edit> {
606        self.src_tree = IndexedTree::build(source);
607        self.tgt_tree = IndexedTree::build(target);
608
609        // Phase 1: leaf matching via Dice coefficient
610        self.match_leaves();
611
612        // Phase 2: internal node matching via leaf descendants
613        self.match_internal_nodes();
614
615        // Phase 3: generate edit script with Move detection
616        self.generate_edits(delta_only)
617    }
618
619    // -- Phase 1: Leaf matching -----------------------------------------------
620
621    fn match_leaves(&mut self) {
622        let src_leaves = self.src_tree.leaf_indices();
623        let tgt_leaves = self.tgt_tree.leaf_indices();
624
625        // Pre-compute SQL strings for all leaves
626        let src_sql: Vec<String> = src_leaves
627            .iter()
628            .map(|&i| node_sql(&self.src_tree.nodes[i], self.config.dialect))
629            .collect();
630        let tgt_sql: Vec<String> = tgt_leaves
631            .iter()
632            .map(|&i| node_sql(&self.tgt_tree.nodes[i], self.config.dialect))
633            .collect();
634
635        let mut heap = BinaryHeap::new();
636        let mut counter = 0usize;
637
638        for (si_pos, &si) in src_leaves.iter().enumerate() {
639            for (ti_pos, &ti) in tgt_leaves.iter().enumerate() {
640                if !is_same_type(&self.src_tree.nodes[si], &self.tgt_tree.nodes[ti]) {
641                    continue;
642                }
643                let score = dice_coefficient(&src_sql[si_pos], &tgt_sql[ti_pos]);
644                if score >= self.config.t {
645                    let parent_sim = parent_similarity_score(
646                        si,
647                        ti,
648                        &self.src_tree,
649                        &self.tgt_tree,
650                        &self.matchings,
651                    );
652                    heap.push(MatchCandidate {
653                        score,
654                        parent_sim,
655                        counter,
656                        src_idx: si,
657                        tgt_idx: ti,
658                    });
659                    counter += 1;
660                }
661            }
662        }
663
664        let mut matched_src: HashSet<usize> = HashSet::new();
665        let mut matched_tgt: HashSet<usize> = HashSet::new();
666
667        while let Some(m) = heap.pop() {
668            if matched_src.contains(&m.src_idx) || matched_tgt.contains(&m.tgt_idx) {
669                continue;
670            }
671            self.matchings.insert(m.src_idx, m.tgt_idx);
672            matched_src.insert(m.src_idx);
673            matched_tgt.insert(m.tgt_idx);
674        }
675    }
676
677    // -- Phase 2: Internal node matching -------------------------------------
678
679    fn match_internal_nodes(&mut self) {
680        // Process from deepest to shallowest. In BFS-built tree, higher indices
681        // are generally deeper, so we iterate in reverse.
682        let src_internal: Vec<usize> = (0..self.src_tree.nodes.len())
683            .rev()
684            .filter(|&i| !self.src_tree.is_leaf(i) && !self.matchings.contains_key(&i))
685            .collect();
686
687        let tgt_internal: Vec<usize> = (0..self.tgt_tree.nodes.len())
688            .rev()
689            .filter(|&i| !self.tgt_tree.is_leaf(i))
690            .collect();
691
692        let mut matched_tgt: HashSet<usize> = self.matchings.values().cloned().collect();
693
694        let mut heap = BinaryHeap::new();
695        let mut counter = 0usize;
696
697        for &si in &src_internal {
698            let src_leaves: HashSet<usize> =
699                self.src_tree.leaf_descendants(si).into_iter().collect();
700            let src_sql = node_sql(&self.src_tree.nodes[si], self.config.dialect);
701
702            for &ti in &tgt_internal {
703                if matched_tgt.contains(&ti) {
704                    continue;
705                }
706                if !is_same_type(&self.src_tree.nodes[si], &self.tgt_tree.nodes[ti]) {
707                    continue;
708                }
709
710                let tgt_leaves: HashSet<usize> =
711                    self.tgt_tree.leaf_descendants(ti).into_iter().collect();
712
713                // Count leaf descendants matched to each other
714                let common = src_leaves
715                    .iter()
716                    .filter(|&&sl| {
717                        self.matchings
718                            .get(&sl)
719                            .map_or(false, |&tl| tgt_leaves.contains(&tl))
720                    })
721                    .count();
722
723                let max_leaves = src_leaves.len().max(tgt_leaves.len());
724                if max_leaves == 0 {
725                    continue;
726                }
727
728                let leaf_sim = common as f64 / max_leaves as f64;
729
730                // Adaptive threshold for small subtrees
731                let t = if src_leaves.len().min(tgt_leaves.len()) <= 4 {
732                    0.4
733                } else {
734                    self.config.t
735                };
736
737                let tgt_sql = node_sql(&self.tgt_tree.nodes[ti], self.config.dialect);
738                let dice = dice_coefficient(&src_sql, &tgt_sql);
739
740                if leaf_sim >= 0.8 || (leaf_sim >= t && dice >= self.config.f) {
741                    heap.push(MatchCandidate {
742                        score: leaf_sim,
743                        parent_sim: parent_similarity_score(
744                            si,
745                            ti,
746                            &self.src_tree,
747                            &self.tgt_tree,
748                            &self.matchings,
749                        ),
750                        counter,
751                        src_idx: si,
752                        tgt_idx: ti,
753                    });
754                    counter += 1;
755                }
756            }
757        }
758
759        while let Some(m) = heap.pop() {
760            if self.matchings.contains_key(&m.src_idx) || matched_tgt.contains(&m.tgt_idx) {
761                continue;
762            }
763            self.matchings.insert(m.src_idx, m.tgt_idx);
764            matched_tgt.insert(m.tgt_idx);
765        }
766    }
767
768    // -- Phase 3: Edit script generation with Move detection ------------------
769
770    fn generate_edits(&self, delta_only: bool) -> Vec<Edit> {
771        let mut edits = Vec::new();
772        let matched_tgt: HashSet<usize> = self.matchings.values().cloned().collect();
773
774        // Build reverse mapping: tgt_idx -> src_idx
775        let reverse_matchings: HashMap<usize, usize> = self
776            .matchings
777            .iter()
778            .map(|(&s, &t)| (t, s))
779            .collect();
780
781        // Detect moved nodes via LCS on each matched parent's children
782        let mut moved_src: HashSet<usize> = HashSet::new();
783
784        for (&src_parent, &tgt_parent) in &self.matchings {
785            if self.src_tree.is_leaf(src_parent) {
786                continue;
787            }
788
789            let src_children = &self.src_tree.children_indices[src_parent];
790            let tgt_children = &self.tgt_tree.children_indices[tgt_parent];
791
792            if src_children.is_empty() || tgt_children.is_empty() {
793                continue;
794            }
795
796            // Build sequence of tgt indices for matched src children (in src order)
797            let src_seq: Vec<usize> = src_children
798                .iter()
799                .filter_map(|&sc| self.matchings.get(&sc).cloned())
800                .collect();
801
802            // Build sequence of tgt children that have a reverse match (in tgt order)
803            let tgt_seq: Vec<usize> = tgt_children
804                .iter()
805                .filter(|&&tc| reverse_matchings.contains_key(&tc))
806                .cloned()
807                .collect();
808
809            let lcs_pairs = lcs(&src_seq, &tgt_seq, |a, b| a == b);
810            let lcs_tgt_set: HashSet<usize> = lcs_pairs.iter().map(|&(i, _)| src_seq[i]).collect();
811
812            // Matched children not in the LCS had their position changed
813            for &sc in src_children {
814                if let Some(&tc) = self.matchings.get(&sc) {
815                    if !lcs_tgt_set.contains(&tc) {
816                        moved_src.insert(sc);
817                    }
818                }
819            }
820        }
821
822        // Unmatched source nodes → Remove
823        for i in 0..self.src_tree.nodes.len() {
824            if !self.matchings.contains_key(&i) {
825                edits.push(Edit::Remove {
826                    expression: self.src_tree.nodes[i].clone(),
827                });
828            }
829        }
830
831        // Unmatched target nodes → Insert
832        for i in 0..self.tgt_tree.nodes.len() {
833            if !matched_tgt.contains(&i) {
834                edits.push(Edit::Insert {
835                    expression: self.tgt_tree.nodes[i].clone(),
836                });
837            }
838        }
839
840        // Matched pairs → Update / Move / Keep
841        for (&src_idx, &tgt_idx) in &self.matchings {
842            let src_node = &self.src_tree.nodes[src_idx];
843            let tgt_node = &self.tgt_tree.nodes[tgt_idx];
844
845            let src_sql = node_sql(src_node, self.config.dialect);
846            let tgt_sql = node_sql(tgt_node, self.config.dialect);
847
848            if is_updatable(src_node) && src_sql != tgt_sql {
849                edits.push(Edit::Update {
850                    source: src_node.clone(),
851                    target: tgt_node.clone(),
852                });
853            } else if has_non_expression_leaf_change(src_node, tgt_node) {
854                edits.push(Edit::Update {
855                    source: src_node.clone(),
856                    target: tgt_node.clone(),
857                });
858            } else if moved_src.contains(&src_idx) {
859                edits.push(Edit::Move {
860                    source: src_node.clone(),
861                    target: tgt_node.clone(),
862                });
863            } else if !delta_only {
864                edits.push(Edit::Keep {
865                    source: src_node.clone(),
866                    target: tgt_node.clone(),
867                });
868            }
869        }
870
871        edits
872    }
873}
874
875// ---------------------------------------------------------------------------
876// Tests
877// ---------------------------------------------------------------------------
878
879#[cfg(test)]
880mod tests {
881    use super::*;
882    use crate::dialects::{Dialect, DialectType};
883
884    fn parse(sql: &str) -> Expression {
885        let dialect = Dialect::get(DialectType::Generic);
886        let ast = dialect.parse(sql).unwrap();
887        ast.into_iter().next().unwrap()
888    }
889
890    fn count_edits(edits: &[Edit]) -> (usize, usize, usize, usize, usize) {
891        let mut insert = 0;
892        let mut remove = 0;
893        let mut r#move = 0;
894        let mut update = 0;
895        let mut keep = 0;
896        for e in edits {
897            match e {
898                Edit::Insert { .. } => insert += 1,
899                Edit::Remove { .. } => remove += 1,
900                Edit::Move { .. } => r#move += 1,
901                Edit::Update { .. } => update += 1,
902                Edit::Keep { .. } => keep += 1,
903            }
904        }
905        (insert, remove, r#move, update, keep)
906    }
907
908    #[test]
909    fn test_diff_identical() {
910        let source = parse("SELECT a FROM t");
911        let target = parse("SELECT a FROM t");
912
913        let edits = diff(&source, &target, false);
914
915        // Should only have Keep edits
916        assert!(
917            edits.iter().all(|e| matches!(e, Edit::Keep { .. })),
918            "Expected only Keep edits, got: {:?}",
919            count_edits(&edits)
920        );
921    }
922
923    #[test]
924    fn test_diff_simple_change() {
925        let source = parse("SELECT a FROM t");
926        let target = parse("SELECT b FROM t");
927
928        let edits = diff(&source, &target, true);
929
930        // Column a → column b: single-char names with dice=0 don't match
931        // → Remove(a) + Insert(b)
932        assert!(!edits.is_empty());
933        assert!(has_changes(&edits));
934        let (ins, rem, _, _, _) = count_edits(&edits);
935        assert!(ins > 0 && rem > 0, "Expected Insert+Remove, got ins={ins} rem={rem}");
936    }
937
938    #[test]
939    fn test_diff_similar_column_update() {
940        let source = parse("SELECT col_a FROM t");
941        let target = parse("SELECT col_b FROM t");
942
943        let edits = diff(&source, &target, true);
944
945        // Longer names share bigrams → matched → Update
946        assert!(has_changes(&edits));
947        assert!(
948            edits.iter().any(|e| matches!(e, Edit::Update { .. })),
949            "Expected Update for similar column name change"
950        );
951    }
952
953    #[test]
954    fn test_operator_change() {
955        let source = parse("SELECT a + b FROM t");
956        let target = parse("SELECT a - b FROM t");
957
958        let edits = diff(&source, &target, true);
959
960        // The operator changed from Add to Sub — different discriminants
961        // so they can't be matched → Remove(Add) + Insert(Sub)
962        assert!(!edits.is_empty());
963        let (ins, rem, _, _, _) = count_edits(&edits);
964        assert!(
965            ins > 0 && rem > 0,
966            "Expected Insert and Remove for operator change, got ins={ins} rem={rem}"
967        );
968    }
969
970    #[test]
971    fn test_column_added() {
972        let source = parse("SELECT a, b FROM t");
973        let target = parse("SELECT a, b, c FROM t");
974
975        let edits = diff(&source, &target, true);
976
977        // Column c was added
978        assert!(
979            edits.iter().any(|e| matches!(e, Edit::Insert { .. })),
980            "Expected at least one Insert edit for added column"
981        );
982    }
983
984    #[test]
985    fn test_column_removed() {
986        let source = parse("SELECT a, b, c FROM t");
987        let target = parse("SELECT a, c FROM t");
988
989        let edits = diff(&source, &target, true);
990
991        // Column b was removed
992        assert!(
993            edits.iter().any(|e| matches!(e, Edit::Remove { .. })),
994            "Expected at least one Remove edit for removed column"
995        );
996    }
997
998    #[test]
999    fn test_table_updated() {
1000        let source = parse("SELECT a FROM table_one");
1001        let target = parse("SELECT a FROM table_two");
1002
1003        let edits = diff(&source, &target, true);
1004
1005        // Table names share enough bigrams to match → Update
1006        assert!(!edits.is_empty());
1007        assert!(has_changes(&edits));
1008        assert!(
1009            edits.iter().any(|e| matches!(e, Edit::Update { .. })),
1010            "Expected Update for table name change"
1011        );
1012    }
1013
1014    #[test]
1015    fn test_lambda() {
1016        let source = parse("SELECT TRANSFORM(arr, a -> a + 1) FROM t");
1017        let target = parse("SELECT TRANSFORM(arr, b -> b + 1) FROM t");
1018
1019        let edits = diff(&source, &target, true);
1020
1021        // The lambda parameter changed
1022        assert!(has_changes(&edits));
1023    }
1024
1025    #[test]
1026    fn test_node_position_changed() {
1027        let source = parse("SELECT a, b, c FROM t");
1028        let target = parse("SELECT c, a, b FROM t");
1029
1030        let edits = diff(&source, &target, false);
1031
1032        // Some columns should be detected as moved
1033        let (_, _, moves, _, _) = count_edits(&edits);
1034        assert!(moves > 0, "Expected at least one Move for reordered columns");
1035    }
1036
1037    #[test]
1038    fn test_cte_changes() {
1039        let source = parse("WITH cte AS (SELECT a FROM t WHERE a > 1000) SELECT * FROM cte");
1040        let target = parse("WITH cte AS (SELECT a FROM t WHERE a > 2000) SELECT * FROM cte");
1041
1042        let edits = diff(&source, &target, true);
1043
1044        // The literal in the WHERE clause changed (1000 → 2000 share bigrams → Update)
1045        assert!(has_changes(&edits));
1046        assert!(
1047            edits.iter().any(|e| matches!(e, Edit::Update { .. })),
1048            "Expected Update for literal change in CTE"
1049        );
1050    }
1051
1052    #[test]
1053    fn test_join_changes() {
1054        let source = parse("SELECT a FROM t LEFT JOIN s ON t.id = s.id");
1055        let target = parse("SELECT a FROM t RIGHT JOIN s ON t.id = s.id");
1056
1057        let edits = diff(&source, &target, true);
1058
1059        // LEFT vs RIGHT have different JoinKind → not same_type
1060        // The Join nodes produce Remove(LEFT JOIN) + Insert(RIGHT JOIN)
1061        assert!(has_changes(&edits));
1062        let (ins, rem, _, _, _) = count_edits(&edits);
1063        assert!(
1064            ins > 0 && rem > 0,
1065            "Expected Insert+Remove for join kind change, got ins={ins} rem={rem}"
1066        );
1067    }
1068
1069    #[test]
1070    fn test_window_functions() {
1071        let source = parse("SELECT ROW_NUMBER() OVER (ORDER BY a) FROM t");
1072        let target = parse("SELECT RANK() OVER (ORDER BY a) FROM t");
1073
1074        let edits = diff(&source, &target, true);
1075
1076        // Different window functions
1077        assert!(has_changes(&edits));
1078    }
1079
1080    #[test]
1081    fn test_non_expression_leaf_delta() {
1082        let source = parse("SELECT a FROM t UNION SELECT b FROM s");
1083        let target = parse("SELECT a FROM t UNION ALL SELECT b FROM s");
1084
1085        let edits = diff(&source, &target, true);
1086
1087        // UNION vs UNION ALL — non-expression leaf change (all flag)
1088        assert!(has_changes(&edits));
1089        assert!(
1090            edits.iter().any(|e| matches!(e, Edit::Update { .. })),
1091            "Expected Update for UNION → UNION ALL"
1092        );
1093    }
1094
1095    #[test]
1096    fn test_is_leaf() {
1097        let tree = IndexedTree::build(&parse("SELECT a, 1 FROM t"));
1098        // Root (Select) should not be a leaf
1099        assert!(!tree.is_leaf(0));
1100        // Leaf nodes should exist in the tree
1101        let leaves = tree.leaf_indices();
1102        assert!(!leaves.is_empty());
1103        // All leaves should have no children
1104        for &l in &leaves {
1105            assert!(tree.children_indices[l].is_empty());
1106        }
1107    }
1108
1109    #[test]
1110    fn test_same_type_special_cases() {
1111        // Same type — both Literal
1112        let a = Expression::Literal(crate::expressions::Literal::Number("1".to_string()));
1113        let b = Expression::Literal(crate::expressions::Literal::String("abc".to_string()));
1114        assert!(is_same_type(&a, &b));
1115
1116        // Different type — Literal vs Null
1117        let c = Expression::Null(crate::expressions::Null);
1118        assert!(!is_same_type(&a, &c));
1119
1120        // Join kind matters
1121        let join_left = Expression::Join(Box::new(crate::expressions::Join {
1122            this: Expression::Table(crate::expressions::TableRef::new("t")),
1123            on: None,
1124            using: vec![],
1125            kind: crate::expressions::JoinKind::Left,
1126            use_inner_keyword: false,
1127            use_outer_keyword: false,
1128            deferred_condition: false,
1129            join_hint: None,
1130            match_condition: None,
1131            pivots: vec![],
1132        }));
1133        let join_right = Expression::Join(Box::new(crate::expressions::Join {
1134            this: Expression::Table(crate::expressions::TableRef::new("t")),
1135            on: None,
1136            using: vec![],
1137            kind: crate::expressions::JoinKind::Right,
1138            use_inner_keyword: false,
1139            use_outer_keyword: false,
1140            deferred_condition: false,
1141            join_hint: None,
1142            match_condition: None,
1143            pivots: vec![],
1144        }));
1145        assert!(!is_same_type(&join_left, &join_right));
1146    }
1147
1148    #[test]
1149    fn test_comments_excluded() {
1150        // Comments on nodes should not affect the diff
1151        let source = parse("SELECT a FROM t");
1152        let target = parse("SELECT a FROM t");
1153
1154        let edits = diff(&source, &target, true);
1155
1156        // No changes — comments don't matter
1157        assert!(edits.is_empty() || !has_changes(&edits));
1158    }
1159}