Skip to main content

polyglot_sql/
diff.rs

1//! AST Diff - Compare SQL expressions
2//!
3//! This module provides functionality to compare two SQL ASTs and generate
4//! a list of differences (edit script) between them, using the ChangeDistiller
5//! algorithm with Dice coefficient matching.
6//!
7
8use crate::dialects::DialectType;
9use crate::expressions::Expression;
10use crate::generator::{Generator, GeneratorConfig};
11use serde::{Deserialize, Serialize};
12use std::cmp::Ordering;
13use std::collections::{BinaryHeap, HashMap, HashSet};
14
/// Types of edits that can occur between two ASTs.
///
/// Serialized as an internally tagged object: the variant name is written in
/// `snake_case` under the `"type"` key (e.g. `{"type": "insert", ...}`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Edit {
    /// A new node has been inserted (present only in the target tree)
    Insert { expression: Expression },
    /// An existing node has been removed (present only in the source tree)
    Remove { expression: Expression },
    /// An existing node's position has changed
    Move {
        /// The node as it appears in the source tree
        source: Expression,
        /// The matched node in the target tree
        target: Expression,
    },
    /// An existing node has been updated (same position, different value)
    Update {
        /// The node as it appears in the source tree
        source: Expression,
        /// The matched node in the target tree
        target: Expression,
    },
    /// An existing node hasn't been changed
    Keep {
        /// The node as it appears in the source tree
        source: Expression,
        /// The matched node in the target tree
        target: Expression,
    },
}
39
40impl Edit {
41    /// Check if this edit represents a change (not a Keep)
42    pub fn is_change(&self) -> bool {
43        !matches!(self, Edit::Keep { .. })
44    }
45}
46
/// Configuration for the diff algorithm
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiffConfig {
    /// Dice coefficient threshold for internal node matching (default 0.6).
    ///
    /// Compared against the Dice coefficient of the two nodes' generated SQL
    /// when their leaf overlap alone is not high enough to accept the match.
    pub f: f64,
    /// Leaf similarity threshold (default 0.6).
    ///
    /// Two leaves are match candidates when the Dice coefficient of their
    /// generated SQL is at least this value. It is also used as the minimum
    /// leaf-overlap ratio for internal nodes, except for small subtrees where
    /// a fixed lower threshold applies.
    pub t: f64,
    /// Optional dialect for SQL generation during comparison
    pub dialect: Option<DialectType>,
}
57
58impl Default for DiffConfig {
59    fn default() -> Self {
60        Self {
61            f: 0.6,
62            t: 0.6,
63            dialect: None,
64        }
65    }
66}
67
68/// Compare two expressions and return a list of edits
69///
70/// # Arguments
71/// * `source` - The source expression
72/// * `target` - The target expression to compare against
73/// * `delta_only` - If true, exclude Keep edits from the result
74///
75/// # Returns
76/// A vector of Edit operations that transform source into target
77///
78/// # Example
79/// ```ignore
80/// use polyglot_sql::diff::diff;
81/// use polyglot_sql::parse_one;
82/// use polyglot_sql::DialectType;
83///
84/// let source = parse_one("SELECT a + b FROM t", DialectType::Generic).unwrap();
85/// let target = parse_one("SELECT a + c FROM t", DialectType::Generic).unwrap();
86/// let edits = diff(&source, &target, false);
87/// ```
88pub fn diff(source: &Expression, target: &Expression, delta_only: bool) -> Vec<Edit> {
89    let config = DiffConfig::default();
90    diff_with_config(source, target, delta_only, &config)
91}
92
93/// Compare two expressions with custom configuration
94pub fn diff_with_config(
95    source: &Expression,
96    target: &Expression,
97    delta_only: bool,
98    config: &DiffConfig,
99) -> Vec<Edit> {
100    let mut distiller = ChangeDistiller::new(config.clone());
101    distiller.diff(source, target, delta_only)
102}
103
104/// Check if the diff contains any changes
105pub fn has_changes(edits: &[Edit]) -> bool {
106    edits.iter().any(|e| e.is_change())
107}
108
109/// Get only the changes from an edit list
110pub fn changes_only(edits: Vec<Edit>) -> Vec<Edit> {
111    edits.into_iter().filter(|e| e.is_change()).collect()
112}
113
114// ---------------------------------------------------------------------------
// IndexedTree: flat depth-first (pre-order) representation with parent-child tracking
116// ---------------------------------------------------------------------------
117
/// Flattened tree representation with explicit parent/child index maps.
///
/// All three vectors are indexed by the same node index and grow together:
/// one entry is pushed to each of them whenever a node is added.
struct IndexedTree {
    /// Cloned expression for each node, in insertion order.
    nodes: Vec<Expression>,
    /// Index of each node's parent; `None` for a node added without a parent.
    parents: Vec<Option<usize>>,
    /// Indices of each node's direct children, in insertion order.
    children_indices: Vec<Vec<usize>>,
}
124
125impl IndexedTree {
126    fn empty() -> Self {
127        Self {
128            nodes: Vec::new(),
129            parents: Vec::new(),
130            children_indices: Vec::new(),
131        }
132    }
133
134    fn build(root: &Expression) -> Self {
135        let mut tree = Self::empty();
136        tree.add_expr(root, None);
137        tree
138    }
139
140    fn add_expr(&mut self, expr: &Expression, parent_idx: Option<usize>) {
141        // Skip bare Identifier nodes — they're names, not diff targets
142        if matches!(expr, Expression::Identifier(_)) {
143            return;
144        }
145
146        let idx = self.nodes.len();
147        self.nodes.push(expr.clone());
148        self.parents.push(parent_idx);
149        self.children_indices.push(Vec::new());
150
151        if let Some(p) = parent_idx {
152            self.children_indices[p].push(idx);
153        }
154
155        self.add_children(expr, idx);
156    }
157
    /// Adds the child expressions of `expr` to the tree as children of the
    /// node at `parent_idx`.
    ///
    /// Common expression kinds are handled explicitly so that only the
    /// sub-expressions listed here are indexed, in the order written below.
    /// Any kind without an explicit arm falls back to
    /// `ExpressionWalk::children()`.
    fn add_children(&mut self, expr: &Expression, parent_idx: usize) {
        match expr {
            // A SELECT is flattened: every clause's expressions become direct
            // children of the Select node, in clause order.
            Expression::Select(select) => {
                if let Some(with) = &select.with {
                    for cte in &with.ctes {
                        // CTEs are wrapped in an `Expression::Cte` so they can be indexed as nodes.
                        self.add_expr(&Expression::Cte(Box::new(cte.clone())), Some(parent_idx));
                    }
                }
                for e in &select.expressions {
                    self.add_expr(e, Some(parent_idx));
                }
                if let Some(from) = &select.from {
                    for e in &from.expressions {
                        self.add_expr(e, Some(parent_idx));
                    }
                }
                for join in &select.joins {
                    // Joins are wrapped in an `Expression::Join` so they can be indexed as nodes.
                    self.add_expr(&Expression::Join(Box::new(join.clone())), Some(parent_idx));
                }
                if let Some(w) = &select.where_clause {
                    self.add_expr(&w.this, Some(parent_idx));
                }
                if let Some(gb) = &select.group_by {
                    for e in &gb.expressions {
                        self.add_expr(e, Some(parent_idx));
                    }
                }
                if let Some(h) = &select.having {
                    self.add_expr(&h.this, Some(parent_idx));
                }
                if let Some(ob) = &select.order_by {
                    for o in &ob.expressions {
                        // Ordering terms are wrapped in an `Expression::Ordered` node.
                        self.add_expr(&Expression::Ordered(Box::new(o.clone())), Some(parent_idx));
                    }
                }
                if let Some(limit) = &select.limit {
                    self.add_expr(&limit.this, Some(parent_idx));
                }
                if let Some(offset) = &select.offset {
                    self.add_expr(&offset.this, Some(parent_idx));
                }
            }
            // Only the aliased expression is indexed, not the alias name.
            Expression::Alias(alias) => {
                self.add_expr(&alias.this, Some(parent_idx));
            }
            // Binary operators: left operand first, then right.
            Expression::And(op)
            | Expression::Or(op)
            | Expression::Eq(op)
            | Expression::Neq(op)
            | Expression::Lt(op)
            | Expression::Lte(op)
            | Expression::Gt(op)
            | Expression::Gte(op)
            | Expression::Add(op)
            | Expression::Sub(op)
            | Expression::Mul(op)
            | Expression::Div(op)
            | Expression::Mod(op)
            | Expression::BitwiseAnd(op)
            | Expression::BitwiseOr(op)
            | Expression::BitwiseXor(op)
            | Expression::Concat(op) => {
                self.add_expr(&op.left, Some(parent_idx));
                self.add_expr(&op.right, Some(parent_idx));
            }
            Expression::Like(op) | Expression::ILike(op) => {
                self.add_expr(&op.left, Some(parent_idx));
                self.add_expr(&op.right, Some(parent_idx));
            }
            // Unary operators: the single operand.
            Expression::Not(u) | Expression::Neg(u) | Expression::BitwiseNot(u) => {
                self.add_expr(&u.this, Some(parent_idx));
            }
            // Function-like nodes: each argument in order.
            Expression::Function(func) => {
                for arg in &func.args {
                    self.add_expr(arg, Some(parent_idx));
                }
            }
            Expression::AggregateFunction(func) => {
                for arg in &func.args {
                    self.add_expr(arg, Some(parent_idx));
                }
            }
            // Join: the joined expression, then the ON condition if present.
            Expression::Join(j) => {
                self.add_expr(&j.this, Some(parent_idx));
                if let Some(on) = &j.on {
                    self.add_expr(on, Some(parent_idx));
                }
            }
            Expression::Anonymous(a) => {
                for arg in &a.expressions {
                    self.add_expr(arg, Some(parent_idx));
                }
            }
            // Only `this` is indexed for window functions.
            Expression::WindowFunction(wf) => {
                self.add_expr(&wf.this, Some(parent_idx));
            }
            // Only the expression being cast is indexed.
            Expression::Cast(cast) => {
                self.add_expr(&cast.this, Some(parent_idx));
            }
            Expression::Subquery(sq) => {
                self.add_expr(&sq.this, Some(parent_idx));
            }
            Expression::Paren(p) => {
                self.add_expr(&p.this, Some(parent_idx));
            }
            // Set operations: left input first, then right.
            Expression::Union(u) => {
                self.add_expr(&u.left, Some(parent_idx));
                self.add_expr(&u.right, Some(parent_idx));
            }
            Expression::Intersect(i) => {
                self.add_expr(&i.left, Some(parent_idx));
                self.add_expr(&i.right, Some(parent_idx));
            }
            Expression::Except(e) => {
                self.add_expr(&e.left, Some(parent_idx));
                self.add_expr(&e.right, Some(parent_idx));
            }
            Expression::Cte(cte) => {
                self.add_expr(&cte.this, Some(parent_idx));
            }
            // CASE: optional operand, then each (when, then) pair, then the optional ELSE.
            Expression::Case(c) => {
                if let Some(operand) = &c.operand {
                    self.add_expr(operand, Some(parent_idx));
                }
                for (when, then) in &c.whens {
                    self.add_expr(when, Some(parent_idx));
                    self.add_expr(then, Some(parent_idx));
                }
                if let Some(else_) = &c.else_ {
                    self.add_expr(else_, Some(parent_idx));
                }
            }
            // IN: tested expression, then the listed values, then the optional query.
            Expression::In(i) => {
                self.add_expr(&i.this, Some(parent_idx));
                for e in &i.expressions {
                    self.add_expr(e, Some(parent_idx));
                }
                if let Some(q) = &i.query {
                    self.add_expr(q, Some(parent_idx));
                }
            }
            Expression::Between(b) => {
                self.add_expr(&b.this, Some(parent_idx));
                self.add_expr(&b.low, Some(parent_idx));
                self.add_expr(&b.high, Some(parent_idx));
            }
            Expression::IsNull(i) => {
                self.add_expr(&i.this, Some(parent_idx));
            }
            Expression::Exists(e) => {
                self.add_expr(&e.this, Some(parent_idx));
            }
            Expression::Ordered(o) => {
                self.add_expr(&o.this, Some(parent_idx));
            }
            // Only the lambda body is indexed.
            Expression::Lambda(l) => {
                self.add_expr(&l.body, Some(parent_idx));
            }
            Expression::Coalesce(c) => {
                for e in &c.expressions {
                    self.add_expr(e, Some(parent_idx));
                }
            }
            Expression::Tuple(t) => {
                for e in &t.expressions {
                    self.add_expr(e, Some(parent_idx));
                }
            }
            Expression::Array(a) => {
                for e in &a.expressions {
                    self.add_expr(e, Some(parent_idx));
                }
            }
            // Leaf nodes — no children to add
            Expression::Literal(_)
            | Expression::Boolean(_)
            | Expression::Null(_)
            | Expression::Column(_)
            | Expression::Table(_)
            | Expression::Star(_)
            | Expression::DataType(_)
            | Expression::CurrentDate(_)
            | Expression::CurrentTime(_)
            | Expression::CurrentTimestamp(_) => {}
            // Fallback: use ExpressionWalk::children()
            other => {
                use crate::traversal::ExpressionWalk;
                for child in other.children() {
                    // `add_expr` also skips identifiers; this check just avoids the call.
                    if !matches!(child, Expression::Identifier(_)) {
                        self.add_expr(child, Some(parent_idx));
                    }
                }
            }
        }
    }
353
354    fn is_leaf(&self, idx: usize) -> bool {
355        self.children_indices[idx].is_empty()
356    }
357
358    fn leaf_indices(&self) -> Vec<usize> {
359        (0..self.nodes.len()).filter(|&i| self.is_leaf(i)).collect()
360    }
361
362    /// Get all leaf descendants of a node (including itself if it is a leaf).
363    fn leaf_descendants(&self, idx: usize) -> Vec<usize> {
364        let mut result = Vec::new();
365        let mut stack = vec![idx];
366        while let Some(i) = stack.pop() {
367            if self.is_leaf(i) {
368                result.push(i);
369            }
370            for &child in &self.children_indices[i] {
371                stack.push(child);
372            }
373        }
374        result
375    }
376}
377
378// ---------------------------------------------------------------------------
379// Helper functions
380// ---------------------------------------------------------------------------
381
/// Compute Dice coefficient on character bigrams of two strings.
///
/// Returns a value in `[0.0, 1.0]`: twice the number of shared bigrams
/// (counted with multiplicity) divided by the total number of bigrams in
/// both strings.
///
/// Strings with fewer than two *characters* have no bigrams, so they are
/// compared by exact equality instead (1.0 when equal, 0.0 otherwise).
fn dice_coefficient(a: &str, b: &str) -> f64 {
    // Decide "too short" by character count, not byte length: a single
    // multi-byte character such as "é" is 2 bytes long but has no bigrams,
    // and two different such characters must not be reported as identical.
    let lacks_bigram = |s: &str| s.chars().nth(1).is_none();
    if lacks_bigram(a) || lacks_bigram(b) {
        return if a == b { 1.0 } else { 0.0 };
    }

    let a_bigrams = bigram_histo(a);
    let b_bigrams = bigram_histo(b);

    // Multiset intersection: a bigram counts as shared as often as it
    // appears in both strings.
    let common: usize = a_bigrams
        .iter()
        .map(|(bigram, &count)| count.min(b_bigrams.get(bigram).copied().unwrap_or(0)))
        .sum();

    // Both strings have at least two characters here, so `total` is >= 2
    // and the division below is always well defined.
    let total: usize = a_bigrams.values().sum::<usize>() + b_bigrams.values().sum::<usize>();
    2.0 * common as f64 / total as f64
}

/// Build a frequency histogram of character bigrams.
///
/// Bigrams are pairs of adjacent `char`s; a string with fewer than two
/// characters yields an empty map.
fn bigram_histo(s: &str) -> HashMap<(char, char), usize> {
    let chars: Vec<char> = s.chars().collect();
    let mut map = HashMap::new();
    for w in chars.windows(2) {
        *map.entry((w[0], w[1])).or_insert(0) += 1;
    }
    map
}
411
412/// Generate SQL string for an expression, optionally with a dialect.
413fn node_sql(expr: &Expression, dialect: Option<DialectType>) -> String {
414    match dialect {
415        Some(d) => {
416            let config = GeneratorConfig {
417                dialect: Some(d),
418                ..GeneratorConfig::default()
419            };
420            let mut gen = Generator::with_config(config);
421            gen.generate(expr).unwrap_or_default()
422        }
423        None => Generator::sql(expr).unwrap_or_default(),
424    }
425}
426
427/// Check if two expressions are the same type for matching purposes.
428///
429/// Uses discriminant comparison with special cases for Join (must share kind)
430/// and Anonymous (must share function name).
431fn is_same_type(a: &Expression, b: &Expression) -> bool {
432    if std::mem::discriminant(a) != std::mem::discriminant(b) {
433        return false;
434    }
435    match (a, b) {
436        (Expression::Join(ja), Expression::Join(jb)) => ja.kind == jb.kind,
437        (Expression::Anonymous(aa), Expression::Anonymous(ab)) => {
438            Generator::sql(&aa.this).unwrap_or_default()
439                == Generator::sql(&ab.this).unwrap_or_default()
440        }
441        _ => true,
442    }
443}
444
445/// Count matching ancestor chain depth for parent similarity tiebreaker.
446fn parent_similarity_score(
447    src_idx: usize,
448    tgt_idx: usize,
449    src_tree: &IndexedTree,
450    tgt_tree: &IndexedTree,
451    matchings: &HashMap<usize, usize>,
452) -> usize {
453    let mut score = 0;
454    let mut s = src_tree.parents[src_idx];
455    let mut t = tgt_tree.parents[tgt_idx];
456    while let (Some(sp), Some(tp)) = (s, t) {
457        if matchings.get(&sp) == Some(&tp) {
458            score += 1;
459            s = src_tree.parents[sp];
460            t = tgt_tree.parents[tp];
461        } else {
462            break;
463        }
464    }
465    score
466}
467
468/// Check if an expression is an updatable leaf type.
469///
470/// Updatable types emit Update edits when matched but different.
471fn is_updatable(expr: &Expression) -> bool {
472    matches!(
473        expr,
474        Expression::Alias(_)
475            | Expression::Boolean(_)
476            | Expression::Column(_)
477            | Expression::DataType(_)
478            | Expression::Lambda(_)
479            | Expression::Literal(_)
480            | Expression::Table(_)
481            | Expression::WindowFunction(_)
482    )
483}
484
485/// Check if non-expression leaf fields differ between two matched same-type nodes.
486///
487/// These are scalar fields (booleans, enums) that aren't child expressions.
488fn has_non_expression_leaf_change(a: &Expression, b: &Expression) -> bool {
489    match (a, b) {
490        (Expression::Union(ua), Expression::Union(ub)) => {
491            ua.all != ub.all || ua.distinct != ub.distinct
492        }
493        (Expression::Intersect(ia), Expression::Intersect(ib)) => {
494            ia.all != ib.all || ia.distinct != ib.distinct
495        }
496        (Expression::Except(ea), Expression::Except(eb)) => {
497            ea.all != eb.all || ea.distinct != eb.distinct
498        }
499        (Expression::Ordered(oa), Expression::Ordered(ob)) => {
500            oa.desc != ob.desc || oa.nulls_first != ob.nulls_first
501        }
502        (Expression::Join(ja), Expression::Join(jb)) => ja.kind != jb.kind,
503        _ => false,
504    }
505}
506
/// Longest common subsequence of `a` and `b` under `eq_fn`.
///
/// Returns the matched positions as `(index in a, index in b)` pairs in
/// ascending order.
fn lcs<T, F>(a: &[T], b: &[T], eq_fn: F) -> Vec<(usize, usize)>
where
    F: Fn(&T, &T) -> bool,
{
    let rows = a.len();
    let cols = b.len();

    // Row-major table of (rows + 1) x (cols + 1) cells; cell (r, c) holds the
    // LCS length of a[..r] and b[..c]. Row 0 and column 0 stay at zero.
    let width = cols + 1;
    let mut table = vec![0usize; (rows + 1) * width];
    for (r, left) in a.iter().enumerate() {
        for (c, right) in b.iter().enumerate() {
            let length = if eq_fn(left, right) {
                table[r * width + c] + 1
            } else {
                let above = table[r * width + (c + 1)];
                let before = table[(r + 1) * width + c];
                above.max(before)
            };
            table[(r + 1) * width + (c + 1)] = length;
        }
    }

    // Walk back from the bottom-right corner, collecting matches in reverse.
    let mut pairs = Vec::new();
    let (mut r, mut c) = (rows, cols);
    while r > 0 && c > 0 {
        if eq_fn(&a[r - 1], &b[c - 1]) {
            pairs.push((r - 1, c - 1));
            r -= 1;
            c -= 1;
        } else if table[(r - 1) * width + c] > table[r * width + (c - 1)] {
            r -= 1;
        } else {
            // Ties step left in `b`.
            c -= 1;
        }
    }
    pairs.reverse();
    pairs
}
541
542// ---------------------------------------------------------------------------
543// BinaryHeap entry for greedy matching
544// ---------------------------------------------------------------------------
545
/// A scored candidate pairing of a source node with a target node.
///
/// Ordering looks at `score` first, then `parent_sim`, then `counter`, so a
/// max-heap pops the best-scoring candidate first.
#[derive(PartialEq)]
struct MatchCandidate {
    score: f64,
    parent_sim: usize,
    counter: usize, // insertion sequence number; makes the ordering total
    src_idx: usize,
    tgt_idx: usize,
}

impl Eq for MatchCandidate {}

impl PartialOrd for MatchCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl Ord for MatchCandidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // Incomparable scores (NaN) are treated as equal so the remaining
        // keys decide.
        let by_score = self
            .score
            .partial_cmp(&other.score)
            .unwrap_or(Ordering::Equal);
        by_score.then_with(|| {
            (self.parent_sim, self.counter).cmp(&(other.parent_sim, other.counter))
        })
    }
}
572
573// ---------------------------------------------------------------------------
574// ChangeDistiller: three-phase algorithm
575// ---------------------------------------------------------------------------
576
/// Working state for one diff run: the two indexed trees plus the node
/// matchings discovered so far.
struct ChangeDistiller {
    /// Thresholds and optional dialect used while matching.
    config: DiffConfig,
    /// Indexed form of the source expression.
    src_tree: IndexedTree,
    /// Indexed form of the target expression.
    tgt_tree: IndexedTree,
    /// Matched node pairs, keyed by index into `src_tree` with the index
    /// into `tgt_tree` as value.
    matchings: HashMap<usize, usize>, // src_idx -> tgt_idx
}
583
584impl ChangeDistiller {
585    fn new(config: DiffConfig) -> Self {
586        Self {
587            config,
588            src_tree: IndexedTree::empty(),
589            tgt_tree: IndexedTree::empty(),
590            matchings: HashMap::new(),
591        }
592    }
593
594    fn diff(&mut self, source: &Expression, target: &Expression, delta_only: bool) -> Vec<Edit> {
595        self.src_tree = IndexedTree::build(source);
596        self.tgt_tree = IndexedTree::build(target);
597
598        // Phase 1: leaf matching via Dice coefficient
599        self.match_leaves();
600
601        // Phase 2: internal node matching via leaf descendants
602        self.match_internal_nodes();
603
604        // Phase 3: generate edit script with Move detection
605        self.generate_edits(delta_only)
606    }
607
608    // -- Phase 1: Leaf matching -----------------------------------------------
609
610    fn match_leaves(&mut self) {
611        let src_leaves = self.src_tree.leaf_indices();
612        let tgt_leaves = self.tgt_tree.leaf_indices();
613
614        // Pre-compute SQL strings for all leaves
615        let src_sql: Vec<String> = src_leaves
616            .iter()
617            .map(|&i| node_sql(&self.src_tree.nodes[i], self.config.dialect))
618            .collect();
619        let tgt_sql: Vec<String> = tgt_leaves
620            .iter()
621            .map(|&i| node_sql(&self.tgt_tree.nodes[i], self.config.dialect))
622            .collect();
623
624        let mut heap = BinaryHeap::new();
625        let mut counter = 0usize;
626
627        for (si_pos, &si) in src_leaves.iter().enumerate() {
628            for (ti_pos, &ti) in tgt_leaves.iter().enumerate() {
629                if !is_same_type(&self.src_tree.nodes[si], &self.tgt_tree.nodes[ti]) {
630                    continue;
631                }
632                let score = dice_coefficient(&src_sql[si_pos], &tgt_sql[ti_pos]);
633                if score >= self.config.t {
634                    let parent_sim = parent_similarity_score(
635                        si,
636                        ti,
637                        &self.src_tree,
638                        &self.tgt_tree,
639                        &self.matchings,
640                    );
641                    heap.push(MatchCandidate {
642                        score,
643                        parent_sim,
644                        counter,
645                        src_idx: si,
646                        tgt_idx: ti,
647                    });
648                    counter += 1;
649                }
650            }
651        }
652
653        let mut matched_src: HashSet<usize> = HashSet::new();
654        let mut matched_tgt: HashSet<usize> = HashSet::new();
655
656        while let Some(m) = heap.pop() {
657            if matched_src.contains(&m.src_idx) || matched_tgt.contains(&m.tgt_idx) {
658                continue;
659            }
660            self.matchings.insert(m.src_idx, m.tgt_idx);
661            matched_src.insert(m.src_idx);
662            matched_tgt.insert(m.tgt_idx);
663        }
664    }
665
666    // -- Phase 2: Internal node matching -------------------------------------
667
668    fn match_internal_nodes(&mut self) {
669        // Process from deepest to shallowest. In BFS-built tree, higher indices
670        // are generally deeper, so we iterate in reverse.
671        let src_internal: Vec<usize> = (0..self.src_tree.nodes.len())
672            .rev()
673            .filter(|&i| !self.src_tree.is_leaf(i) && !self.matchings.contains_key(&i))
674            .collect();
675
676        let tgt_internal: Vec<usize> = (0..self.tgt_tree.nodes.len())
677            .rev()
678            .filter(|&i| !self.tgt_tree.is_leaf(i))
679            .collect();
680
681        let mut matched_tgt: HashSet<usize> = self.matchings.values().cloned().collect();
682
683        let mut heap = BinaryHeap::new();
684        let mut counter = 0usize;
685
686        for &si in &src_internal {
687            let src_leaves: HashSet<usize> =
688                self.src_tree.leaf_descendants(si).into_iter().collect();
689            let src_sql = node_sql(&self.src_tree.nodes[si], self.config.dialect);
690
691            for &ti in &tgt_internal {
692                if matched_tgt.contains(&ti) {
693                    continue;
694                }
695                if !is_same_type(&self.src_tree.nodes[si], &self.tgt_tree.nodes[ti]) {
696                    continue;
697                }
698
699                let tgt_leaves: HashSet<usize> =
700                    self.tgt_tree.leaf_descendants(ti).into_iter().collect();
701
702                // Count leaf descendants matched to each other
703                let common = src_leaves
704                    .iter()
705                    .filter(|&&sl| {
706                        self.matchings
707                            .get(&sl)
708                            .map_or(false, |&tl| tgt_leaves.contains(&tl))
709                    })
710                    .count();
711
712                let max_leaves = src_leaves.len().max(tgt_leaves.len());
713                if max_leaves == 0 {
714                    continue;
715                }
716
717                let leaf_sim = common as f64 / max_leaves as f64;
718
719                // Adaptive threshold for small subtrees
720                let t = if src_leaves.len().min(tgt_leaves.len()) <= 4 {
721                    0.4
722                } else {
723                    self.config.t
724                };
725
726                let tgt_sql = node_sql(&self.tgt_tree.nodes[ti], self.config.dialect);
727                let dice = dice_coefficient(&src_sql, &tgt_sql);
728
729                if leaf_sim >= 0.8 || (leaf_sim >= t && dice >= self.config.f) {
730                    heap.push(MatchCandidate {
731                        score: leaf_sim,
732                        parent_sim: parent_similarity_score(
733                            si,
734                            ti,
735                            &self.src_tree,
736                            &self.tgt_tree,
737                            &self.matchings,
738                        ),
739                        counter,
740                        src_idx: si,
741                        tgt_idx: ti,
742                    });
743                    counter += 1;
744                }
745            }
746        }
747
748        while let Some(m) = heap.pop() {
749            if self.matchings.contains_key(&m.src_idx) || matched_tgt.contains(&m.tgt_idx) {
750                continue;
751            }
752            self.matchings.insert(m.src_idx, m.tgt_idx);
753            matched_tgt.insert(m.tgt_idx);
754        }
755    }
756
757    // -- Phase 3: Edit script generation with Move detection ------------------
758
759    fn generate_edits(&self, delta_only: bool) -> Vec<Edit> {
760        let mut edits = Vec::new();
761        let matched_tgt: HashSet<usize> = self.matchings.values().cloned().collect();
762
763        // Build reverse mapping: tgt_idx -> src_idx
764        let reverse_matchings: HashMap<usize, usize> =
765            self.matchings.iter().map(|(&s, &t)| (t, s)).collect();
766
767        // Detect moved nodes via LCS on each matched parent's children
768        let mut moved_src: HashSet<usize> = HashSet::new();
769
770        for (&src_parent, &tgt_parent) in &self.matchings {
771            if self.src_tree.is_leaf(src_parent) {
772                continue;
773            }
774
775            let src_children = &self.src_tree.children_indices[src_parent];
776            let tgt_children = &self.tgt_tree.children_indices[tgt_parent];
777
778            if src_children.is_empty() || tgt_children.is_empty() {
779                continue;
780            }
781
782            // Build sequence of tgt indices for matched src children (in src order)
783            let src_seq: Vec<usize> = src_children
784                .iter()
785                .filter_map(|&sc| self.matchings.get(&sc).cloned())
786                .collect();
787
788            // Build sequence of tgt children that have a reverse match (in tgt order)
789            let tgt_seq: Vec<usize> = tgt_children
790                .iter()
791                .filter(|&&tc| reverse_matchings.contains_key(&tc))
792                .cloned()
793                .collect();
794
795            let lcs_pairs = lcs(&src_seq, &tgt_seq, |a, b| a == b);
796            let lcs_tgt_set: HashSet<usize> = lcs_pairs.iter().map(|&(i, _)| src_seq[i]).collect();
797
798            // Matched children not in the LCS had their position changed
799            for &sc in src_children {
800                if let Some(&tc) = self.matchings.get(&sc) {
801                    if !lcs_tgt_set.contains(&tc) {
802                        moved_src.insert(sc);
803                    }
804                }
805            }
806        }
807
808        // Unmatched source nodes → Remove
809        for i in 0..self.src_tree.nodes.len() {
810            if !self.matchings.contains_key(&i) {
811                edits.push(Edit::Remove {
812                    expression: self.src_tree.nodes[i].clone(),
813                });
814            }
815        }
816
817        // Unmatched target nodes → Insert
818        for i in 0..self.tgt_tree.nodes.len() {
819            if !matched_tgt.contains(&i) {
820                edits.push(Edit::Insert {
821                    expression: self.tgt_tree.nodes[i].clone(),
822                });
823            }
824        }
825
826        // Matched pairs → Update / Move / Keep
827        for (&src_idx, &tgt_idx) in &self.matchings {
828            let src_node = &self.src_tree.nodes[src_idx];
829            let tgt_node = &self.tgt_tree.nodes[tgt_idx];
830
831            let src_sql = node_sql(src_node, self.config.dialect);
832            let tgt_sql = node_sql(tgt_node, self.config.dialect);
833
834            if is_updatable(src_node) && src_sql != tgt_sql {
835                edits.push(Edit::Update {
836                    source: src_node.clone(),
837                    target: tgt_node.clone(),
838                });
839            } else if has_non_expression_leaf_change(src_node, tgt_node) {
840                edits.push(Edit::Update {
841                    source: src_node.clone(),
842                    target: tgt_node.clone(),
843                });
844            } else if moved_src.contains(&src_idx) {
845                edits.push(Edit::Move {
846                    source: src_node.clone(),
847                    target: tgt_node.clone(),
848                });
849            } else if !delta_only {
850                edits.push(Edit::Keep {
851                    source: src_node.clone(),
852                    target: tgt_node.clone(),
853                });
854            }
855        }
856
857        edits
858    }
859}
860
861// ---------------------------------------------------------------------------
862// Tests
863// ---------------------------------------------------------------------------
864
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::{Dialect, DialectType};

    /// Parses `sql` with the generic dialect and returns the first statement.
    ///
    /// # Panics
    ///
    /// Panics — naming the offending SQL — if parsing fails or yields no
    /// statement, so a broken fixture is easy to locate.
    fn parse(sql: &str) -> Expression {
        let dialect = Dialect::get(DialectType::Generic);
        let statements = dialect
            .parse(sql)
            .unwrap_or_else(|err| panic!("failed to parse test SQL {sql:?}: {err:?}"));
        statements
            .into_iter()
            .next()
            .unwrap_or_else(|| panic!("no statement parsed from test SQL {sql:?}"))
    }

    /// Tallies an edit script by variant.
    ///
    /// Returns the counts as `(insert, remove, move, update, keep)`.
    fn count_edits(edits: &[Edit]) -> (usize, usize, usize, usize, usize) {
        let mut insert = 0;
        let mut remove = 0;
        let mut r#move = 0;
        let mut update = 0;
        let mut keep = 0;
        for e in edits {
            match e {
                Edit::Insert { .. } => insert += 1,
                Edit::Remove { .. } => remove += 1,
                Edit::Move { .. } => r#move += 1,
                Edit::Update { .. } => update += 1,
                Edit::Keep { .. } => keep += 1,
            }
        }
        (insert, remove, r#move, update, keep)
    }

    /// Diffing a query against itself (full script) yields only `Keep` edits.
    #[test]
    fn test_diff_identical() {
        let source = parse("SELECT a FROM t");
        let target = parse("SELECT a FROM t");

        let edits = diff(&source, &target, false);

        // Guard against a vacuous pass: `all` is true for an empty script, but
        // a full (non-delta) diff emits one Keep per matched node pair.
        assert!(
            !edits.is_empty(),
            "Expected Keep edits for identical input, got an empty edit script"
        );

        // Should only have Keep edits
        assert!(
            edits.iter().all(|e| matches!(e, Edit::Keep { .. })),
            "Expected only Keep edits, got: {:?}",
            count_edits(&edits)
        );
    }

    /// Renaming a single-character column is reported as Remove + Insert.
    #[test]
    fn test_diff_simple_change() {
        let source = parse("SELECT a FROM t");
        let target = parse("SELECT b FROM t");

        let edits = diff(&source, &target, true);

        // Column a → column b: single-char names with dice=0 don't match
        // → Remove(a) + Insert(b)
        assert!(!edits.is_empty());
        assert!(has_changes(&edits));
        let (ins, rem, _, _, _) = count_edits(&edits);
        assert!(
            ins > 0 && rem > 0,
            "Expected Insert+Remove, got ins={ins} rem={rem}"
        );
    }

    /// Renaming a column to a similar longer name is reported as an Update.
    #[test]
    fn test_diff_similar_column_update() {
        let source = parse("SELECT col_a FROM t");
        let target = parse("SELECT col_b FROM t");

        let edits = diff(&source, &target, true);

        // Longer names share bigrams → matched → Update
        assert!(has_changes(&edits));
        assert!(
            edits.iter().any(|e| matches!(e, Edit::Update { .. })),
            "Expected Update for similar column name change"
        );
    }

    /// Swapping `+` for `-` is reported as Remove + Insert of the operator node.
    #[test]
    fn test_operator_change() {
        let source = parse("SELECT a + b FROM t");
        let target = parse("SELECT a - b FROM t");

        let edits = diff(&source, &target, true);

        // The operator changed from Add to Sub — different discriminants
        // so they can't be matched → Remove(Add) + Insert(Sub)
        assert!(!edits.is_empty());
        let (ins, rem, _, _, _) = count_edits(&edits);
        assert!(
            ins > 0 && rem > 0,
            "Expected Insert and Remove for operator change, got ins={ins} rem={rem}"
        );
    }

    /// Appending a column to the projection produces at least one Insert.
    #[test]
    fn test_column_added() {
        let source = parse("SELECT a, b FROM t");
        let target = parse("SELECT a, b, c FROM t");

        let edits = diff(&source, &target, true);

        // Column c was added
        assert!(
            edits.iter().any(|e| matches!(e, Edit::Insert { .. })),
            "Expected at least one Insert edit for added column"
        );
    }

    /// Dropping a column from the projection produces at least one Remove.
    #[test]
    fn test_column_removed() {
        let source = parse("SELECT a, b, c FROM t");
        let target = parse("SELECT a, c FROM t");

        let edits = diff(&source, &target, true);

        // Column b was removed
        assert!(
            edits.iter().any(|e| matches!(e, Edit::Remove { .. })),
            "Expected at least one Remove edit for removed column"
        );
    }

    /// Renaming a table to a similar name is reported as an Update.
    #[test]
    fn test_table_updated() {
        let source = parse("SELECT a FROM table_one");
        let target = parse("SELECT a FROM table_two");

        let edits = diff(&source, &target, true);

        // Table names share enough bigrams to match → Update
        assert!(!edits.is_empty());
        assert!(has_changes(&edits));
        assert!(
            edits.iter().any(|e| matches!(e, Edit::Update { .. })),
            "Expected Update for table name change"
        );
    }

    /// Renaming a lambda parameter is detected as a change.
    #[test]
    fn test_lambda() {
        let source = parse("SELECT TRANSFORM(arr, a -> a + 1) FROM t");
        let target = parse("SELECT TRANSFORM(arr, b -> b + 1) FROM t");

        let edits = diff(&source, &target, true);

        // The lambda parameter changed
        assert!(has_changes(&edits));
    }

    /// Reordering projection columns produces at least one Move.
    #[test]
    fn test_node_position_changed() {
        let source = parse("SELECT a, b, c FROM t");
        let target = parse("SELECT c, a, b FROM t");

        let edits = diff(&source, &target, false);

        // Some columns should be detected as moved
        let (_, _, moves, _, _) = count_edits(&edits);
        assert!(
            moves > 0,
            "Expected at least one Move for reordered columns"
        );
    }

    /// Changing a literal inside a CTE body is reported as an Update.
    #[test]
    fn test_cte_changes() {
        let source = parse("WITH cte AS (SELECT a FROM t WHERE a > 1000) SELECT * FROM cte");
        let target = parse("WITH cte AS (SELECT a FROM t WHERE a > 2000) SELECT * FROM cte");

        let edits = diff(&source, &target, true);

        // The literal in the WHERE clause changed (1000 → 2000 share bigrams → Update)
        assert!(has_changes(&edits));
        assert!(
            edits.iter().any(|e| matches!(e, Edit::Update { .. })),
            "Expected Update for literal change in CTE"
        );
    }

    /// Switching LEFT JOIN to RIGHT JOIN is reported as Remove + Insert.
    #[test]
    fn test_join_changes() {
        let source = parse("SELECT a FROM t LEFT JOIN s ON t.id = s.id");
        let target = parse("SELECT a FROM t RIGHT JOIN s ON t.id = s.id");

        let edits = diff(&source, &target, true);

        // LEFT vs RIGHT have different JoinKind → not same_type
        // The Join nodes produce Remove(LEFT JOIN) + Insert(RIGHT JOIN)
        assert!(has_changes(&edits));
        let (ins, rem, _, _, _) = count_edits(&edits);
        assert!(
            ins > 0 && rem > 0,
            "Expected Insert+Remove for join kind change, got ins={ins} rem={rem}"
        );
    }

    /// Replacing one window function with another is detected as a change.
    #[test]
    fn test_window_functions() {
        let source = parse("SELECT ROW_NUMBER() OVER (ORDER BY a) FROM t");
        let target = parse("SELECT RANK() OVER (ORDER BY a) FROM t");

        let edits = diff(&source, &target, true);

        // Different window functions
        assert!(has_changes(&edits));
    }

    /// A change in a non-expression field (UNION → UNION ALL) is an Update.
    #[test]
    fn test_non_expression_leaf_delta() {
        let source = parse("SELECT a FROM t UNION SELECT b FROM s");
        let target = parse("SELECT a FROM t UNION ALL SELECT b FROM s");

        let edits = diff(&source, &target, true);

        // UNION vs UNION ALL — non-expression leaf change (all flag)
        assert!(has_changes(&edits));
        assert!(
            edits.iter().any(|e| matches!(e, Edit::Update { .. })),
            "Expected Update for UNION → UNION ALL"
        );
    }

    /// The indexed tree's root is not a leaf, and every reported leaf is childless.
    #[test]
    fn test_is_leaf() {
        let tree = IndexedTree::build(&parse("SELECT a, 1 FROM t"));
        // Root (Select) should not be a leaf
        assert!(!tree.is_leaf(0));
        // Leaf nodes should exist in the tree
        let leaves = tree.leaf_indices();
        assert!(!leaves.is_empty());
        // All leaves should have no children
        for &l in &leaves {
            assert!(tree.children_indices[l].is_empty());
        }
    }

    /// `is_same_type` compares variants, and additionally the kind for joins.
    #[test]
    fn test_same_type_special_cases() {
        /// Builds a minimal join on table `t` whose only varying field is `kind`.
        fn bare_join(kind: crate::expressions::JoinKind) -> Expression {
            Expression::Join(Box::new(crate::expressions::Join {
                this: Expression::Table(crate::expressions::TableRef::new("t")),
                on: None,
                using: vec![],
                kind,
                use_inner_keyword: false,
                use_outer_keyword: false,
                deferred_condition: false,
                join_hint: None,
                match_condition: None,
                pivots: vec![],
                comments: vec![],
                nesting_group: 0,
                directed: false,
            }))
        }

        // Same type — both Literal
        let a = Expression::Literal(crate::expressions::Literal::Number("1".to_string()));
        let b = Expression::Literal(crate::expressions::Literal::String("abc".to_string()));
        assert!(is_same_type(&a, &b));

        // Different type — Literal vs Null
        let c = Expression::Null(crate::expressions::Null);
        assert!(!is_same_type(&a, &c));

        // Join kind matters
        let join_left = bare_join(crate::expressions::JoinKind::Left);
        let join_right = bare_join(crate::expressions::JoinKind::Right);
        assert!(!is_same_type(&join_left, &join_right));
    }

    /// Intended to check that comments attached to nodes do not affect the diff.
    #[test]
    fn test_comments_excluded() {
        // Comments on nodes should not affect the diff
        // NOTE(review): neither input below actually contains a SQL comment, so
        // this currently only re-checks the identical-input case in delta mode.
        // Consider adding a comment to one side once parser/diff handling of
        // comments is confirmed.
        let source = parse("SELECT a FROM t");
        let target = parse("SELECT a FROM t");

        let edits = diff(&source, &target, true);

        // No changes — comments don't matter
        assert!(edits.is_empty() || !has_changes(&edits));
    }
}