Skip to main content

polyglot_sql/
diff.rs

1//! AST Diff - Compare SQL expressions
2//!
3//! This module provides functionality to compare two SQL ASTs and generate
4//! a list of differences (edit script) between them, using the ChangeDistiller
5//! algorithm with Dice coefficient matching.
6//!
7
8use crate::dialects::DialectType;
9use crate::expressions::Expression;
10use crate::generator::{Generator, GeneratorConfig};
11use serde::{Deserialize, Serialize};
12use std::cmp::Ordering;
13use std::collections::{BinaryHeap, HashMap, HashSet};
14
15/// Types of edits that can occur between two ASTs
16#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
17#[serde(tag = "type", rename_all = "snake_case")]
18pub enum Edit {
19    /// A new node has been inserted
20    Insert { expression: Expression },
21    /// An existing node has been removed
22    Remove { expression: Expression },
23    /// An existing node's position has changed
24    Move {
25        source: Expression,
26        target: Expression,
27    },
28    /// An existing node has been updated (same position, different value)
29    Update {
30        source: Expression,
31        target: Expression,
32    },
33    /// An existing node hasn't been changed
34    Keep {
35        source: Expression,
36        target: Expression,
37    },
38}
39
40impl Edit {
41    /// Check if this edit represents a change (not a Keep)
42    pub fn is_change(&self) -> bool {
43        !matches!(self, Edit::Keep { .. })
44    }
45}
46
47/// Configuration for the diff algorithm
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct DiffConfig {
50    /// Dice coefficient threshold for internal node matching (default 0.6)
51    pub f: f64,
52    /// Leaf similarity threshold (default 0.6)
53    pub t: f64,
54    /// Optional dialect for SQL generation during comparison
55    pub dialect: Option<DialectType>,
56}
57
58impl Default for DiffConfig {
59    fn default() -> Self {
60        Self {
61            f: 0.6,
62            t: 0.6,
63            dialect: None,
64        }
65    }
66}
67
68/// Compare two expressions and return a list of edits
69///
70/// # Arguments
71/// * `source` - The source expression
72/// * `target` - The target expression to compare against
73/// * `delta_only` - If true, exclude Keep edits from the result
74///
75/// # Returns
76/// A vector of Edit operations that transform source into target
77///
78/// # Example
79/// ```ignore
80/// use polyglot_sql::diff::diff;
81/// use polyglot_sql::parse_one;
82/// use polyglot_sql::DialectType;
83///
84/// let source = parse_one("SELECT a + b FROM t", DialectType::Generic).unwrap();
85/// let target = parse_one("SELECT a + c FROM t", DialectType::Generic).unwrap();
86/// let edits = diff(&source, &target, false);
87/// ```
88pub fn diff(source: &Expression, target: &Expression, delta_only: bool) -> Vec<Edit> {
89    let config = DiffConfig::default();
90    diff_with_config(source, target, delta_only, &config)
91}
92
93/// Compare two expressions with custom configuration
94pub fn diff_with_config(
95    source: &Expression,
96    target: &Expression,
97    delta_only: bool,
98    config: &DiffConfig,
99) -> Vec<Edit> {
100    let mut distiller = ChangeDistiller::new(config.clone());
101    distiller.diff(source, target, delta_only)
102}
103
104/// Check if the diff contains any changes
105pub fn has_changes(edits: &[Edit]) -> bool {
106    edits.iter().any(|e| e.is_change())
107}
108
109/// Get only the changes from an edit list
110pub fn changes_only(edits: Vec<Edit>) -> Vec<Edit> {
111    edits.into_iter().filter(|e| e.is_change()).collect()
112}
113
114// ---------------------------------------------------------------------------
// IndexedTree: flat pre-order (depth-first) representation with parent-child tracking
116// ---------------------------------------------------------------------------
117
118/// Flattened tree representation with explicit parent/child index maps.
119struct IndexedTree {
120    nodes: Vec<Expression>,
121    parents: Vec<Option<usize>>,
122    children_indices: Vec<Vec<usize>>,
123}
124
125impl IndexedTree {
126    fn empty() -> Self {
127        Self {
128            nodes: Vec::new(),
129            parents: Vec::new(),
130            children_indices: Vec::new(),
131        }
132    }
133
134    fn build(root: &Expression) -> Self {
135        let mut tree = Self::empty();
136        tree.add_expr(root, None);
137        tree
138    }
139
140    fn add_expr(&mut self, expr: &Expression, parent_idx: Option<usize>) {
141        // Skip bare Identifier nodes — they're names, not diff targets
142        if matches!(expr, Expression::Identifier(_)) {
143            return;
144        }
145
146        let idx = self.nodes.len();
147        self.nodes.push(expr.clone());
148        self.parents.push(parent_idx);
149        self.children_indices.push(Vec::new());
150
151        if let Some(p) = parent_idx {
152            self.children_indices[p].push(idx);
153        }
154
155        self.add_children(expr, idx);
156    }
157
158    fn add_children(&mut self, expr: &Expression, parent_idx: usize) {
159        match expr {
160            Expression::Select(select) => {
161                if let Some(with) = &select.with {
162                    for cte in &with.ctes {
163                        self.add_expr(&Expression::Cte(Box::new(cte.clone())), Some(parent_idx));
164                    }
165                }
166                for e in &select.expressions {
167                    self.add_expr(e, Some(parent_idx));
168                }
169                if let Some(from) = &select.from {
170                    for e in &from.expressions {
171                        self.add_expr(e, Some(parent_idx));
172                    }
173                }
174                for join in &select.joins {
175                    self.add_expr(
176                        &Expression::Join(Box::new(join.clone())),
177                        Some(parent_idx),
178                    );
179                }
180                if let Some(w) = &select.where_clause {
181                    self.add_expr(&w.this, Some(parent_idx));
182                }
183                if let Some(gb) = &select.group_by {
184                    for e in &gb.expressions {
185                        self.add_expr(e, Some(parent_idx));
186                    }
187                }
188                if let Some(h) = &select.having {
189                    self.add_expr(&h.this, Some(parent_idx));
190                }
191                if let Some(ob) = &select.order_by {
192                    for o in &ob.expressions {
193                        self.add_expr(
194                            &Expression::Ordered(Box::new(o.clone())),
195                            Some(parent_idx),
196                        );
197                    }
198                }
199                if let Some(limit) = &select.limit {
200                    self.add_expr(&limit.this, Some(parent_idx));
201                }
202                if let Some(offset) = &select.offset {
203                    self.add_expr(&offset.this, Some(parent_idx));
204                }
205            }
206            Expression::Alias(alias) => {
207                self.add_expr(&alias.this, Some(parent_idx));
208            }
209            Expression::And(op)
210            | Expression::Or(op)
211            | Expression::Eq(op)
212            | Expression::Neq(op)
213            | Expression::Lt(op)
214            | Expression::Lte(op)
215            | Expression::Gt(op)
216            | Expression::Gte(op)
217            | Expression::Add(op)
218            | Expression::Sub(op)
219            | Expression::Mul(op)
220            | Expression::Div(op)
221            | Expression::Mod(op)
222            | Expression::BitwiseAnd(op)
223            | Expression::BitwiseOr(op)
224            | Expression::BitwiseXor(op)
225            | Expression::Concat(op) => {
226                self.add_expr(&op.left, Some(parent_idx));
227                self.add_expr(&op.right, Some(parent_idx));
228            }
229            Expression::Like(op) | Expression::ILike(op) => {
230                self.add_expr(&op.left, Some(parent_idx));
231                self.add_expr(&op.right, Some(parent_idx));
232            }
233            Expression::Not(u) | Expression::Neg(u) | Expression::BitwiseNot(u) => {
234                self.add_expr(&u.this, Some(parent_idx));
235            }
236            Expression::Function(func) => {
237                for arg in &func.args {
238                    self.add_expr(arg, Some(parent_idx));
239                }
240            }
241            Expression::AggregateFunction(func) => {
242                for arg in &func.args {
243                    self.add_expr(arg, Some(parent_idx));
244                }
245            }
246            Expression::Join(j) => {
247                self.add_expr(&j.this, Some(parent_idx));
248                if let Some(on) = &j.on {
249                    self.add_expr(on, Some(parent_idx));
250                }
251            }
252            Expression::Anonymous(a) => {
253                for arg in &a.expressions {
254                    self.add_expr(arg, Some(parent_idx));
255                }
256            }
257            Expression::WindowFunction(wf) => {
258                self.add_expr(&wf.this, Some(parent_idx));
259            }
260            Expression::Cast(cast) => {
261                self.add_expr(&cast.this, Some(parent_idx));
262            }
263            Expression::Subquery(sq) => {
264                self.add_expr(&sq.this, Some(parent_idx));
265            }
266            Expression::Paren(p) => {
267                self.add_expr(&p.this, Some(parent_idx));
268            }
269            Expression::Union(u) => {
270                self.add_expr(&u.left, Some(parent_idx));
271                self.add_expr(&u.right, Some(parent_idx));
272            }
273            Expression::Intersect(i) => {
274                self.add_expr(&i.left, Some(parent_idx));
275                self.add_expr(&i.right, Some(parent_idx));
276            }
277            Expression::Except(e) => {
278                self.add_expr(&e.left, Some(parent_idx));
279                self.add_expr(&e.right, Some(parent_idx));
280            }
281            Expression::Cte(cte) => {
282                self.add_expr(&cte.this, Some(parent_idx));
283            }
284            Expression::Case(c) => {
285                if let Some(operand) = &c.operand {
286                    self.add_expr(operand, Some(parent_idx));
287                }
288                for (when, then) in &c.whens {
289                    self.add_expr(when, Some(parent_idx));
290                    self.add_expr(then, Some(parent_idx));
291                }
292                if let Some(else_) = &c.else_ {
293                    self.add_expr(else_, Some(parent_idx));
294                }
295            }
296            Expression::In(i) => {
297                self.add_expr(&i.this, Some(parent_idx));
298                for e in &i.expressions {
299                    self.add_expr(e, Some(parent_idx));
300                }
301                if let Some(q) = &i.query {
302                    self.add_expr(q, Some(parent_idx));
303                }
304            }
305            Expression::Between(b) => {
306                self.add_expr(&b.this, Some(parent_idx));
307                self.add_expr(&b.low, Some(parent_idx));
308                self.add_expr(&b.high, Some(parent_idx));
309            }
310            Expression::IsNull(i) => {
311                self.add_expr(&i.this, Some(parent_idx));
312            }
313            Expression::Exists(e) => {
314                self.add_expr(&e.this, Some(parent_idx));
315            }
316            Expression::Ordered(o) => {
317                self.add_expr(&o.this, Some(parent_idx));
318            }
319            Expression::Lambda(l) => {
320                self.add_expr(&l.body, Some(parent_idx));
321            }
322            Expression::Coalesce(c) => {
323                for e in &c.expressions {
324                    self.add_expr(e, Some(parent_idx));
325                }
326            }
327            Expression::Tuple(t) => {
328                for e in &t.expressions {
329                    self.add_expr(e, Some(parent_idx));
330                }
331            }
332            Expression::Array(a) => {
333                for e in &a.expressions {
334                    self.add_expr(e, Some(parent_idx));
335                }
336            }
337            // Leaf nodes — no children to add
338            Expression::Literal(_)
339            | Expression::Boolean(_)
340            | Expression::Null(_)
341            | Expression::Column(_)
342            | Expression::Table(_)
343            | Expression::Star(_)
344            | Expression::DataType(_)
345            | Expression::CurrentDate(_)
346            | Expression::CurrentTime(_)
347            | Expression::CurrentTimestamp(_) => {}
348            // Fallback: use ExpressionWalk::children()
349            other => {
350                use crate::traversal::ExpressionWalk;
351                for child in other.children() {
352                    if !matches!(child, Expression::Identifier(_)) {
353                        self.add_expr(child, Some(parent_idx));
354                    }
355                }
356            }
357        }
358    }
359
360    fn is_leaf(&self, idx: usize) -> bool {
361        self.children_indices[idx].is_empty()
362    }
363
364    fn leaf_indices(&self) -> Vec<usize> {
365        (0..self.nodes.len())
366            .filter(|&i| self.is_leaf(i))
367            .collect()
368    }
369
370    /// Get all leaf descendants of a node (including itself if it is a leaf).
371    fn leaf_descendants(&self, idx: usize) -> Vec<usize> {
372        let mut result = Vec::new();
373        let mut stack = vec![idx];
374        while let Some(i) = stack.pop() {
375            if self.is_leaf(i) {
376                result.push(i);
377            }
378            for &child in &self.children_indices[i] {
379                stack.push(child);
380            }
381        }
382        result
383    }
384}
385
386// ---------------------------------------------------------------------------
387// Helper functions
388// ---------------------------------------------------------------------------
389
/// Compute the Dice coefficient over the character bigrams of two strings.
///
/// The result is twice the number of shared bigrams (counted with
/// multiplicity) divided by the total number of bigrams in both strings, so
/// it lies in `0.0..=1.0`. A string with fewer than two characters has no
/// bigrams; in that case the result is `1.0` if the strings are equal and
/// `0.0` otherwise.
fn dice_coefficient(a: &str, b: &str) -> f64 {
    // Measure length in characters, not bytes: bigrams are built from
    // `char`s, so a lone multi-byte character (e.g. "é") has none and must
    // take the exact-equality path.
    let has_bigram = |s: &str| s.chars().nth(1).is_some();
    if !has_bigram(a) || !has_bigram(b) {
        return if a == b { 1.0 } else { 0.0 };
    }

    let a_bigrams = bigram_histo(a);
    let b_bigrams = bigram_histo(b);
    let common: usize = a_bigrams
        .iter()
        .map(|(pair, &count)| count.min(b_bigrams.get(pair).copied().unwrap_or(0)))
        .sum();
    // Both strings contribute at least one bigram here, so `total >= 2`.
    let total: usize = a_bigrams.values().sum::<usize>() + b_bigrams.values().sum::<usize>();
    2.0 * common as f64 / total as f64
}

/// Build a frequency histogram of adjacent character pairs in `s`.
///
/// Strings with fewer than two characters produce an empty map.
fn bigram_histo(s: &str) -> HashMap<(char, char), usize> {
    let mut histo = HashMap::new();
    let mut chars = s.chars();
    if let Some(mut prev) = chars.next() {
        for next in chars {
            *histo.entry((prev, next)).or_insert(0) += 1;
            prev = next;
        }
    }
    histo
}
419
420/// Generate SQL string for an expression, optionally with a dialect.
421fn node_sql(expr: &Expression, dialect: Option<DialectType>) -> String {
422    match dialect {
423        Some(d) => {
424            let config = GeneratorConfig {
425                dialect: Some(d),
426                ..GeneratorConfig::default()
427            };
428            let mut gen = Generator::with_config(config);
429            gen.generate(expr).unwrap_or_default()
430        }
431        None => Generator::sql(expr).unwrap_or_default(),
432    }
433}
434
435/// Check if two expressions are the same type for matching purposes.
436///
437/// Uses discriminant comparison with special cases for Join (must share kind)
438/// and Anonymous (must share function name).
439fn is_same_type(a: &Expression, b: &Expression) -> bool {
440    if std::mem::discriminant(a) != std::mem::discriminant(b) {
441        return false;
442    }
443    match (a, b) {
444        (Expression::Join(ja), Expression::Join(jb)) => ja.kind == jb.kind,
445        (Expression::Anonymous(aa), Expression::Anonymous(ab)) => {
446            Generator::sql(&aa.this).unwrap_or_default()
447                == Generator::sql(&ab.this).unwrap_or_default()
448        }
449        _ => true,
450    }
451}
452
453/// Count matching ancestor chain depth for parent similarity tiebreaker.
454fn parent_similarity_score(
455    src_idx: usize,
456    tgt_idx: usize,
457    src_tree: &IndexedTree,
458    tgt_tree: &IndexedTree,
459    matchings: &HashMap<usize, usize>,
460) -> usize {
461    let mut score = 0;
462    let mut s = src_tree.parents[src_idx];
463    let mut t = tgt_tree.parents[tgt_idx];
464    while let (Some(sp), Some(tp)) = (s, t) {
465        if matchings.get(&sp) == Some(&tp) {
466            score += 1;
467            s = src_tree.parents[sp];
468            t = tgt_tree.parents[tp];
469        } else {
470            break;
471        }
472    }
473    score
474}
475
476/// Check if an expression is an updatable leaf type.
477///
478/// Updatable types emit Update edits when matched but different.
479fn is_updatable(expr: &Expression) -> bool {
480    matches!(
481        expr,
482        Expression::Alias(_)
483            | Expression::Boolean(_)
484            | Expression::Column(_)
485            | Expression::DataType(_)
486            | Expression::Lambda(_)
487            | Expression::Literal(_)
488            | Expression::Table(_)
489            | Expression::WindowFunction(_)
490    )
491}
492
493/// Check if non-expression leaf fields differ between two matched same-type nodes.
494///
495/// These are scalar fields (booleans, enums) that aren't child expressions.
496fn has_non_expression_leaf_change(a: &Expression, b: &Expression) -> bool {
497    match (a, b) {
498        (Expression::Union(ua), Expression::Union(ub)) => {
499            ua.all != ub.all || ua.distinct != ub.distinct
500        }
501        (Expression::Intersect(ia), Expression::Intersect(ib)) => {
502            ia.all != ib.all || ia.distinct != ib.distinct
503        }
504        (Expression::Except(ea), Expression::Except(eb)) => {
505            ea.all != eb.all || ea.distinct != eb.distinct
506        }
507        (Expression::Ordered(oa), Expression::Ordered(ob)) => {
508            oa.desc != ob.desc || oa.nulls_first != ob.nulls_first
509        }
510        (Expression::Join(ja), Expression::Join(jb)) => ja.kind != jb.kind,
511        _ => false,
512    }
513}
514
/// Longest common subsequence of `a` and `b` under `eq_fn`.
///
/// Returns `(index_in_a, index_in_b)` pairs in ascending order.
fn lcs<T, F>(a: &[T], b: &[T], eq_fn: F) -> Vec<(usize, usize)>
where
    F: Fn(&T, &T) -> bool,
{
    let (rows, cols) = (a.len(), b.len());
    let stride = cols + 1;

    // Row-major table: lengths[i * stride + j] is the LCS length of
    // a[..i] and b[..j]. Row 0 and column 0 stay zero.
    let mut lengths = vec![0usize; (rows + 1) * stride];
    for (i, x) in a.iter().enumerate() {
        for (j, y) in b.iter().enumerate() {
            lengths[(i + 1) * stride + (j + 1)] = if eq_fn(x, y) {
                lengths[i * stride + j] + 1
            } else {
                lengths[i * stride + (j + 1)].max(lengths[(i + 1) * stride + j])
            };
        }
    }

    // Walk back from the bottom-right corner, recording matches as we go.
    let mut pairs = Vec::new();
    let (mut i, mut j) = (rows, cols);
    while i > 0 && j > 0 {
        if eq_fn(&a[i - 1], &b[j - 1]) {
            pairs.push((i - 1, j - 1));
            i -= 1;
            j -= 1;
        } else if lengths[(i - 1) * stride + j] > lengths[i * stride + (j - 1)] {
            i -= 1;
        } else {
            j -= 1;
        }
    }
    pairs.reverse();
    pairs
}
549
550// ---------------------------------------------------------------------------
551// BinaryHeap entry for greedy matching
552// ---------------------------------------------------------------------------
553
/// A possible source/target pairing queued in a max-heap during greedy matching.
///
/// Ordering and equality look at `score`, then `parent_sim`, then `counter`;
/// `src_idx` and `tgt_idx` are payload and take no part in comparisons.
struct MatchCandidate {
    score: f64,
    parent_sim: usize,
    counter: usize, // tiebreaker for deterministic ordering
    src_idx: usize,
    tgt_idx: usize,
}

// Equality is defined through `cmp` so that `Eq` and `Ord` always agree and
// stay reflexive even if a score were NaN (a derived `PartialEq` on `f64`
// would make a NaN-scored candidate unequal to itself).
impl PartialEq for MatchCandidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for MatchCandidate {}

impl PartialOrd for MatchCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for MatchCandidate {
    fn cmp(&self, other: &Self) -> Ordering {
        self.score
            .partial_cmp(&other.score)
            // `partial_cmp` only fails when a NaN is involved. Rank NaN below
            // every real score (and equal to another NaN) so the order stays
            // total and a NaN candidate is never popped ahead of a real one.
            .unwrap_or_else(|| other.score.is_nan().cmp(&self.score.is_nan()))
            .then_with(|| self.parent_sim.cmp(&other.parent_sim))
            .then_with(|| self.counter.cmp(&other.counter))
    }
}
580
581// ---------------------------------------------------------------------------
582// ChangeDistiller: three-phase algorithm
583// ---------------------------------------------------------------------------
584
585struct ChangeDistiller {
586    config: DiffConfig,
587    src_tree: IndexedTree,
588    tgt_tree: IndexedTree,
589    matchings: HashMap<usize, usize>, // src_idx -> tgt_idx
590}
591
592impl ChangeDistiller {
593    fn new(config: DiffConfig) -> Self {
594        Self {
595            config,
596            src_tree: IndexedTree::empty(),
597            tgt_tree: IndexedTree::empty(),
598            matchings: HashMap::new(),
599        }
600    }
601
602    fn diff(
603        &mut self,
604        source: &Expression,
605        target: &Expression,
606        delta_only: bool,
607    ) -> Vec<Edit> {
608        self.src_tree = IndexedTree::build(source);
609        self.tgt_tree = IndexedTree::build(target);
610
611        // Phase 1: leaf matching via Dice coefficient
612        self.match_leaves();
613
614        // Phase 2: internal node matching via leaf descendants
615        self.match_internal_nodes();
616
617        // Phase 3: generate edit script with Move detection
618        self.generate_edits(delta_only)
619    }
620
621    // -- Phase 1: Leaf matching -----------------------------------------------
622
623    fn match_leaves(&mut self) {
624        let src_leaves = self.src_tree.leaf_indices();
625        let tgt_leaves = self.tgt_tree.leaf_indices();
626
627        // Pre-compute SQL strings for all leaves
628        let src_sql: Vec<String> = src_leaves
629            .iter()
630            .map(|&i| node_sql(&self.src_tree.nodes[i], self.config.dialect))
631            .collect();
632        let tgt_sql: Vec<String> = tgt_leaves
633            .iter()
634            .map(|&i| node_sql(&self.tgt_tree.nodes[i], self.config.dialect))
635            .collect();
636
637        let mut heap = BinaryHeap::new();
638        let mut counter = 0usize;
639
640        for (si_pos, &si) in src_leaves.iter().enumerate() {
641            for (ti_pos, &ti) in tgt_leaves.iter().enumerate() {
642                if !is_same_type(&self.src_tree.nodes[si], &self.tgt_tree.nodes[ti]) {
643                    continue;
644                }
645                let score = dice_coefficient(&src_sql[si_pos], &tgt_sql[ti_pos]);
646                if score >= self.config.t {
647                    let parent_sim = parent_similarity_score(
648                        si,
649                        ti,
650                        &self.src_tree,
651                        &self.tgt_tree,
652                        &self.matchings,
653                    );
654                    heap.push(MatchCandidate {
655                        score,
656                        parent_sim,
657                        counter,
658                        src_idx: si,
659                        tgt_idx: ti,
660                    });
661                    counter += 1;
662                }
663            }
664        }
665
666        let mut matched_src: HashSet<usize> = HashSet::new();
667        let mut matched_tgt: HashSet<usize> = HashSet::new();
668
669        while let Some(m) = heap.pop() {
670            if matched_src.contains(&m.src_idx) || matched_tgt.contains(&m.tgt_idx) {
671                continue;
672            }
673            self.matchings.insert(m.src_idx, m.tgt_idx);
674            matched_src.insert(m.src_idx);
675            matched_tgt.insert(m.tgt_idx);
676        }
677    }
678
679    // -- Phase 2: Internal node matching -------------------------------------
680
681    fn match_internal_nodes(&mut self) {
682        // Process from deepest to shallowest. In BFS-built tree, higher indices
683        // are generally deeper, so we iterate in reverse.
684        let src_internal: Vec<usize> = (0..self.src_tree.nodes.len())
685            .rev()
686            .filter(|&i| !self.src_tree.is_leaf(i) && !self.matchings.contains_key(&i))
687            .collect();
688
689        let tgt_internal: Vec<usize> = (0..self.tgt_tree.nodes.len())
690            .rev()
691            .filter(|&i| !self.tgt_tree.is_leaf(i))
692            .collect();
693
694        let mut matched_tgt: HashSet<usize> = self.matchings.values().cloned().collect();
695
696        let mut heap = BinaryHeap::new();
697        let mut counter = 0usize;
698
699        for &si in &src_internal {
700            let src_leaves: HashSet<usize> =
701                self.src_tree.leaf_descendants(si).into_iter().collect();
702            let src_sql = node_sql(&self.src_tree.nodes[si], self.config.dialect);
703
704            for &ti in &tgt_internal {
705                if matched_tgt.contains(&ti) {
706                    continue;
707                }
708                if !is_same_type(&self.src_tree.nodes[si], &self.tgt_tree.nodes[ti]) {
709                    continue;
710                }
711
712                let tgt_leaves: HashSet<usize> =
713                    self.tgt_tree.leaf_descendants(ti).into_iter().collect();
714
715                // Count leaf descendants matched to each other
716                let common = src_leaves
717                    .iter()
718                    .filter(|&&sl| {
719                        self.matchings
720                            .get(&sl)
721                            .map_or(false, |&tl| tgt_leaves.contains(&tl))
722                    })
723                    .count();
724
725                let max_leaves = src_leaves.len().max(tgt_leaves.len());
726                if max_leaves == 0 {
727                    continue;
728                }
729
730                let leaf_sim = common as f64 / max_leaves as f64;
731
732                // Adaptive threshold for small subtrees
733                let t = if src_leaves.len().min(tgt_leaves.len()) <= 4 {
734                    0.4
735                } else {
736                    self.config.t
737                };
738
739                let tgt_sql = node_sql(&self.tgt_tree.nodes[ti], self.config.dialect);
740                let dice = dice_coefficient(&src_sql, &tgt_sql);
741
742                if leaf_sim >= 0.8 || (leaf_sim >= t && dice >= self.config.f) {
743                    heap.push(MatchCandidate {
744                        score: leaf_sim,
745                        parent_sim: parent_similarity_score(
746                            si,
747                            ti,
748                            &self.src_tree,
749                            &self.tgt_tree,
750                            &self.matchings,
751                        ),
752                        counter,
753                        src_idx: si,
754                        tgt_idx: ti,
755                    });
756                    counter += 1;
757                }
758            }
759        }
760
761        while let Some(m) = heap.pop() {
762            if self.matchings.contains_key(&m.src_idx) || matched_tgt.contains(&m.tgt_idx) {
763                continue;
764            }
765            self.matchings.insert(m.src_idx, m.tgt_idx);
766            matched_tgt.insert(m.tgt_idx);
767        }
768    }
769
770    // -- Phase 3: Edit script generation with Move detection ------------------
771
772    fn generate_edits(&self, delta_only: bool) -> Vec<Edit> {
773        let mut edits = Vec::new();
774        let matched_tgt: HashSet<usize> = self.matchings.values().cloned().collect();
775
776        // Build reverse mapping: tgt_idx -> src_idx
777        let reverse_matchings: HashMap<usize, usize> = self
778            .matchings
779            .iter()
780            .map(|(&s, &t)| (t, s))
781            .collect();
782
783        // Detect moved nodes via LCS on each matched parent's children
784        let mut moved_src: HashSet<usize> = HashSet::new();
785
786        for (&src_parent, &tgt_parent) in &self.matchings {
787            if self.src_tree.is_leaf(src_parent) {
788                continue;
789            }
790
791            let src_children = &self.src_tree.children_indices[src_parent];
792            let tgt_children = &self.tgt_tree.children_indices[tgt_parent];
793
794            if src_children.is_empty() || tgt_children.is_empty() {
795                continue;
796            }
797
798            // Build sequence of tgt indices for matched src children (in src order)
799            let src_seq: Vec<usize> = src_children
800                .iter()
801                .filter_map(|&sc| self.matchings.get(&sc).cloned())
802                .collect();
803
804            // Build sequence of tgt children that have a reverse match (in tgt order)
805            let tgt_seq: Vec<usize> = tgt_children
806                .iter()
807                .filter(|&&tc| reverse_matchings.contains_key(&tc))
808                .cloned()
809                .collect();
810
811            let lcs_pairs = lcs(&src_seq, &tgt_seq, |a, b| a == b);
812            let lcs_tgt_set: HashSet<usize> = lcs_pairs.iter().map(|&(i, _)| src_seq[i]).collect();
813
814            // Matched children not in the LCS had their position changed
815            for &sc in src_children {
816                if let Some(&tc) = self.matchings.get(&sc) {
817                    if !lcs_tgt_set.contains(&tc) {
818                        moved_src.insert(sc);
819                    }
820                }
821            }
822        }
823
824        // Unmatched source nodes → Remove
825        for i in 0..self.src_tree.nodes.len() {
826            if !self.matchings.contains_key(&i) {
827                edits.push(Edit::Remove {
828                    expression: self.src_tree.nodes[i].clone(),
829                });
830            }
831        }
832
833        // Unmatched target nodes → Insert
834        for i in 0..self.tgt_tree.nodes.len() {
835            if !matched_tgt.contains(&i) {
836                edits.push(Edit::Insert {
837                    expression: self.tgt_tree.nodes[i].clone(),
838                });
839            }
840        }
841
842        // Matched pairs → Update / Move / Keep
843        for (&src_idx, &tgt_idx) in &self.matchings {
844            let src_node = &self.src_tree.nodes[src_idx];
845            let tgt_node = &self.tgt_tree.nodes[tgt_idx];
846
847            let src_sql = node_sql(src_node, self.config.dialect);
848            let tgt_sql = node_sql(tgt_node, self.config.dialect);
849
850            if is_updatable(src_node) && src_sql != tgt_sql {
851                edits.push(Edit::Update {
852                    source: src_node.clone(),
853                    target: tgt_node.clone(),
854                });
855            } else if has_non_expression_leaf_change(src_node, tgt_node) {
856                edits.push(Edit::Update {
857                    source: src_node.clone(),
858                    target: tgt_node.clone(),
859                });
860            } else if moved_src.contains(&src_idx) {
861                edits.push(Edit::Move {
862                    source: src_node.clone(),
863                    target: tgt_node.clone(),
864                });
865            } else if !delta_only {
866                edits.push(Edit::Keep {
867                    source: src_node.clone(),
868                    target: tgt_node.clone(),
869                });
870            }
871        }
872
873        edits
874    }
875}
876
877// ---------------------------------------------------------------------------
878// Tests
879// ---------------------------------------------------------------------------
880
881#[cfg(test)]
882mod tests {
883    use super::*;
884    use crate::dialects::{Dialect, DialectType};
885
886    fn parse(sql: &str) -> Expression {
887        let dialect = Dialect::get(DialectType::Generic);
888        let ast = dialect.parse(sql).unwrap();
889        ast.into_iter().next().unwrap()
890    }
891
892    fn count_edits(edits: &[Edit]) -> (usize, usize, usize, usize, usize) {
893        let mut insert = 0;
894        let mut remove = 0;
895        let mut r#move = 0;
896        let mut update = 0;
897        let mut keep = 0;
898        for e in edits {
899            match e {
900                Edit::Insert { .. } => insert += 1,
901                Edit::Remove { .. } => remove += 1,
902                Edit::Move { .. } => r#move += 1,
903                Edit::Update { .. } => update += 1,
904                Edit::Keep { .. } => keep += 1,
905            }
906        }
907        (insert, remove, r#move, update, keep)
908    }
909
910    #[test]
911    fn test_diff_identical() {
912        let source = parse("SELECT a FROM t");
913        let target = parse("SELECT a FROM t");
914
915        let edits = diff(&source, &target, false);
916
917        // Should only have Keep edits
918        assert!(
919            edits.iter().all(|e| matches!(e, Edit::Keep { .. })),
920            "Expected only Keep edits, got: {:?}",
921            count_edits(&edits)
922        );
923    }
924
925    #[test]
926    fn test_diff_simple_change() {
927        let source = parse("SELECT a FROM t");
928        let target = parse("SELECT b FROM t");
929
930        let edits = diff(&source, &target, true);
931
932        // Column a → column b: single-char names with dice=0 don't match
933        // → Remove(a) + Insert(b)
934        assert!(!edits.is_empty());
935        assert!(has_changes(&edits));
936        let (ins, rem, _, _, _) = count_edits(&edits);
937        assert!(ins > 0 && rem > 0, "Expected Insert+Remove, got ins={ins} rem={rem}");
938    }
939
940    #[test]
941    fn test_diff_similar_column_update() {
942        let source = parse("SELECT col_a FROM t");
943        let target = parse("SELECT col_b FROM t");
944
945        let edits = diff(&source, &target, true);
946
947        // Longer names share bigrams → matched → Update
948        assert!(has_changes(&edits));
949        assert!(
950            edits.iter().any(|e| matches!(e, Edit::Update { .. })),
951            "Expected Update for similar column name change"
952        );
953    }
954
955    #[test]
956    fn test_operator_change() {
957        let source = parse("SELECT a + b FROM t");
958        let target = parse("SELECT a - b FROM t");
959
960        let edits = diff(&source, &target, true);
961
962        // The operator changed from Add to Sub — different discriminants
963        // so they can't be matched → Remove(Add) + Insert(Sub)
964        assert!(!edits.is_empty());
965        let (ins, rem, _, _, _) = count_edits(&edits);
966        assert!(
967            ins > 0 && rem > 0,
968            "Expected Insert and Remove for operator change, got ins={ins} rem={rem}"
969        );
970    }
971
972    #[test]
973    fn test_column_added() {
974        let source = parse("SELECT a, b FROM t");
975        let target = parse("SELECT a, b, c FROM t");
976
977        let edits = diff(&source, &target, true);
978
979        // Column c was added
980        assert!(
981            edits.iter().any(|e| matches!(e, Edit::Insert { .. })),
982            "Expected at least one Insert edit for added column"
983        );
984    }
985
986    #[test]
987    fn test_column_removed() {
988        let source = parse("SELECT a, b, c FROM t");
989        let target = parse("SELECT a, c FROM t");
990
991        let edits = diff(&source, &target, true);
992
993        // Column b was removed
994        assert!(
995            edits.iter().any(|e| matches!(e, Edit::Remove { .. })),
996            "Expected at least one Remove edit for removed column"
997        );
998    }
999
1000    #[test]
1001    fn test_table_updated() {
1002        let source = parse("SELECT a FROM table_one");
1003        let target = parse("SELECT a FROM table_two");
1004
1005        let edits = diff(&source, &target, true);
1006
1007        // Table names share enough bigrams to match → Update
1008        assert!(!edits.is_empty());
1009        assert!(has_changes(&edits));
1010        assert!(
1011            edits.iter().any(|e| matches!(e, Edit::Update { .. })),
1012            "Expected Update for table name change"
1013        );
1014    }
1015
1016    #[test]
1017    fn test_lambda() {
1018        let source = parse("SELECT TRANSFORM(arr, a -> a + 1) FROM t");
1019        let target = parse("SELECT TRANSFORM(arr, b -> b + 1) FROM t");
1020
1021        let edits = diff(&source, &target, true);
1022
1023        // The lambda parameter changed
1024        assert!(has_changes(&edits));
1025    }
1026
1027    #[test]
1028    fn test_node_position_changed() {
1029        let source = parse("SELECT a, b, c FROM t");
1030        let target = parse("SELECT c, a, b FROM t");
1031
1032        let edits = diff(&source, &target, false);
1033
1034        // Some columns should be detected as moved
1035        let (_, _, moves, _, _) = count_edits(&edits);
1036        assert!(moves > 0, "Expected at least one Move for reordered columns");
1037    }
1038
1039    #[test]
1040    fn test_cte_changes() {
1041        let source = parse("WITH cte AS (SELECT a FROM t WHERE a > 1000) SELECT * FROM cte");
1042        let target = parse("WITH cte AS (SELECT a FROM t WHERE a > 2000) SELECT * FROM cte");
1043
1044        let edits = diff(&source, &target, true);
1045
1046        // The literal in the WHERE clause changed (1000 → 2000 share bigrams → Update)
1047        assert!(has_changes(&edits));
1048        assert!(
1049            edits.iter().any(|e| matches!(e, Edit::Update { .. })),
1050            "Expected Update for literal change in CTE"
1051        );
1052    }
1053
1054    #[test]
1055    fn test_join_changes() {
1056        let source = parse("SELECT a FROM t LEFT JOIN s ON t.id = s.id");
1057        let target = parse("SELECT a FROM t RIGHT JOIN s ON t.id = s.id");
1058
1059        let edits = diff(&source, &target, true);
1060
1061        // LEFT vs RIGHT have different JoinKind → not same_type
1062        // The Join nodes produce Remove(LEFT JOIN) + Insert(RIGHT JOIN)
1063        assert!(has_changes(&edits));
1064        let (ins, rem, _, _, _) = count_edits(&edits);
1065        assert!(
1066            ins > 0 && rem > 0,
1067            "Expected Insert+Remove for join kind change, got ins={ins} rem={rem}"
1068        );
1069    }
1070
1071    #[test]
1072    fn test_window_functions() {
1073        let source = parse("SELECT ROW_NUMBER() OVER (ORDER BY a) FROM t");
1074        let target = parse("SELECT RANK() OVER (ORDER BY a) FROM t");
1075
1076        let edits = diff(&source, &target, true);
1077
1078        // Different window functions
1079        assert!(has_changes(&edits));
1080    }
1081
1082    #[test]
1083    fn test_non_expression_leaf_delta() {
1084        let source = parse("SELECT a FROM t UNION SELECT b FROM s");
1085        let target = parse("SELECT a FROM t UNION ALL SELECT b FROM s");
1086
1087        let edits = diff(&source, &target, true);
1088
1089        // UNION vs UNION ALL — non-expression leaf change (all flag)
1090        assert!(has_changes(&edits));
1091        assert!(
1092            edits.iter().any(|e| matches!(e, Edit::Update { .. })),
1093            "Expected Update for UNION → UNION ALL"
1094        );
1095    }
1096
1097    #[test]
1098    fn test_is_leaf() {
1099        let tree = IndexedTree::build(&parse("SELECT a, 1 FROM t"));
1100        // Root (Select) should not be a leaf
1101        assert!(!tree.is_leaf(0));
1102        // Leaf nodes should exist in the tree
1103        let leaves = tree.leaf_indices();
1104        assert!(!leaves.is_empty());
1105        // All leaves should have no children
1106        for &l in &leaves {
1107            assert!(tree.children_indices[l].is_empty());
1108        }
1109    }
1110
1111    #[test]
1112    fn test_same_type_special_cases() {
1113        // Same type — both Literal
1114        let a = Expression::Literal(crate::expressions::Literal::Number("1".to_string()));
1115        let b = Expression::Literal(crate::expressions::Literal::String("abc".to_string()));
1116        assert!(is_same_type(&a, &b));
1117
1118        // Different type — Literal vs Null
1119        let c = Expression::Null(crate::expressions::Null);
1120        assert!(!is_same_type(&a, &c));
1121
1122        // Join kind matters
1123        let join_left = Expression::Join(Box::new(crate::expressions::Join {
1124            this: Expression::Table(crate::expressions::TableRef::new("t")),
1125            on: None,
1126            using: vec![],
1127            kind: crate::expressions::JoinKind::Left,
1128            use_inner_keyword: false,
1129            use_outer_keyword: false,
1130            deferred_condition: false,
1131            join_hint: None,
1132            match_condition: None,
1133            pivots: vec![],
1134        }));
1135        let join_right = Expression::Join(Box::new(crate::expressions::Join {
1136            this: Expression::Table(crate::expressions::TableRef::new("t")),
1137            on: None,
1138            using: vec![],
1139            kind: crate::expressions::JoinKind::Right,
1140            use_inner_keyword: false,
1141            use_outer_keyword: false,
1142            deferred_condition: false,
1143            join_hint: None,
1144            match_condition: None,
1145            pivots: vec![],
1146        }));
1147        assert!(!is_same_type(&join_left, &join_right));
1148    }
1149
1150    #[test]
1151    fn test_comments_excluded() {
1152        // Comments on nodes should not affect the diff
1153        let source = parse("SELECT a FROM t");
1154        let target = parse("SELECT a FROM t");
1155
1156        let edits = diff(&source, &target, true);
1157
1158        // No changes — comments don't matter
1159        assert!(edits.is_empty() || !has_changes(&edits));
1160    }
1161}