Skip to main content

sqlglot_rust/diff/
mod.rs

1//! AST Diff — semantic comparison of SQL expression trees.
2//!
3//! Implements a tree edit distance algorithm inspired by the Change Distiller
4//! approach used in Python sqlglot's `diff.py`. Computes a sequence of
5//! [`ChangeAction`]s that transform one AST into another.
6//!
7//! # Example
8//!
9//! ```rust
10//! use sqlglot_rust::{parse, Dialect};
11//! use sqlglot_rust::diff::{diff, ChangeAction};
12//!
13//! let source = parse("SELECT a, b FROM t WHERE a > 1", Dialect::Ansi).unwrap();
14//! let target = parse("SELECT a, c FROM t WHERE a > 2", Dialect::Ansi).unwrap();
15//! let changes = diff(&source, &target);
16//!
17//! for change in &changes {
18//!     println!("{change:?}");
19//! }
20//! ```
21
22use std::collections::HashMap;
23
24use crate::ast::*;
25
26/// A change action describing a single difference between two ASTs.
27#[derive(Debug, Clone, PartialEq)]
28pub enum ChangeAction {
29    /// A node present in `source` that was removed.
30    Remove(AstNode),
31    /// A node inserted into `target` that was not in `source`.
32    Insert(AstNode),
33    /// A node that is structurally identical in both trees.
34    Keep(AstNode, AstNode),
35    /// A node that was moved to a different position in the tree.
36    Move(AstNode, AstNode),
37    /// A node in `source` that was replaced by a different node in `target`.
38    Update(AstNode, AstNode),
39}
40
41/// A wrapper around an AST node that can represent either statements or
42/// expressions, enabling uniform diff output.
43#[derive(Debug, Clone, PartialEq)]
44pub enum AstNode {
45    Statement(Box<Statement>),
46    Expr(Expr),
47    SelectItem(SelectItem),
48    JoinClause(JoinClause),
49    OrderByItem(OrderByItem),
50    Cte(Box<Cte>),
51    ColumnDef(ColumnDef),
52    TableConstraint(TableConstraint),
53}
54
55impl std::fmt::Display for AstNode {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        match self {
58            AstNode::Statement(s) => write!(f, "{s:?}"),
59            AstNode::Expr(e) => write!(f, "{e:?}"),
60            AstNode::SelectItem(si) => write!(f, "{si:?}"),
61            AstNode::JoinClause(j) => write!(f, "{j:?}"),
62            AstNode::OrderByItem(o) => write!(f, "{o:?}"),
63            AstNode::Cte(c) => write!(f, "{c:?}"),
64            AstNode::ColumnDef(cd) => write!(f, "{cd:?}"),
65            AstNode::TableConstraint(tc) => write!(f, "{tc:?}"),
66        }
67    }
68}
69
70/// Compute the semantic diff between two SQL statements.
71///
72/// Returns a list of [`ChangeAction`]s describing the minimal set of changes
73/// needed to transform `source` into `target`.
74#[must_use]
75pub fn diff(source: &Statement, target: &Statement) -> Vec<ChangeAction> {
76    let mut differ = AstDiffer::new();
77    differ.diff_statements(source, target);
78    differ.changes
79}
80
81/// Internal differ state that accumulates change actions.
82struct AstDiffer {
83    changes: Vec<ChangeAction>,
84}
85
86impl AstDiffer {
87    fn new() -> Self {
88        Self {
89            changes: Vec::new(),
90        }
91    }
92
93    fn diff_statements(&mut self, source: &Statement, target: &Statement) {
94        use Statement::*;
95
96        match (source, target) {
97            (Select(s), Select(t)) => self.diff_select(s, t),
98            (Insert(s), Insert(t)) => self.diff_insert(s, t),
99            (Update(s), Update(t)) => self.diff_update(s, t),
100            (Delete(s), Delete(t)) => self.diff_delete(s, t),
101            (CreateTable(s), CreateTable(t)) => self.diff_create_table(s, t),
102            (DropTable(s), DropTable(t)) => self.diff_drop_table(s, t),
103            (SetOperation(s), SetOperation(t)) => self.diff_set_operation(s, t),
104            (AlterTable(s), AlterTable(t)) => self.diff_alter_table(s, t),
105            (CreateView(s), CreateView(t)) => self.diff_create_view(s, t),
106            (Expression(s), Expression(t)) => self.diff_exprs(s, t),
107            _ => {
108                // Different statement types → remove old, insert new
109                self.changes
110                    .push(ChangeAction::Remove(AstNode::Statement(Box::new(
111                        source.clone(),
112                    ))));
113                self.changes
114                    .push(ChangeAction::Insert(AstNode::Statement(Box::new(
115                        target.clone(),
116                    ))));
117            }
118        }
119    }
120
121    // ── SELECT ─────────────────────────────────────────────────────────
122
123    fn diff_select(&mut self, source: &SelectStatement, target: &SelectStatement) {
124        // CTEs
125        self.diff_ctes(&source.ctes, &target.ctes);
126
127        // DISTINCT
128        if source.distinct != target.distinct {
129            if target.distinct {
130                self.changes
131                    .push(ChangeAction::Insert(AstNode::Expr(Expr::Column {
132                        table: None,
133                        name: "DISTINCT".to_string(),
134                        quote_style: QuoteStyle::None,
135                        table_quote_style: QuoteStyle::None,
136                    })));
137            } else {
138                self.changes
139                    .push(ChangeAction::Remove(AstNode::Expr(Expr::Column {
140                        table: None,
141                        name: "DISTINCT".to_string(),
142                        quote_style: QuoteStyle::None,
143                        table_quote_style: QuoteStyle::None,
144                    })));
145            }
146        }
147
148        // SELECT columns (ordered)
149        self.diff_select_items(&source.columns, &target.columns);
150
151        // FROM
152        match (&source.from, &target.from) {
153            (Some(sf), Some(tf)) => self.diff_table_sources(&sf.source, &tf.source),
154            (None, Some(tf)) => self.insert_table_source(&tf.source),
155            (Some(sf), None) => self.remove_table_source(&sf.source),
156            (None, None) => {}
157        }
158
159        // JOINs
160        self.diff_joins(&source.joins, &target.joins);
161
162        // WHERE
163        self.diff_optional_exprs(&source.where_clause, &target.where_clause);
164
165        // GROUP BY
166        self.diff_expr_lists(&source.group_by, &target.group_by);
167
168        // HAVING
169        self.diff_optional_exprs(&source.having, &target.having);
170
171        // ORDER BY
172        self.diff_order_by(&source.order_by, &target.order_by);
173
174        // LIMIT
175        self.diff_optional_exprs(&source.limit, &target.limit);
176
177        // OFFSET
178        self.diff_optional_exprs(&source.offset, &target.offset);
179
180        // QUALIFY
181        self.diff_optional_exprs(&source.qualify, &target.qualify);
182    }
183
184    // ── INSERT ─────────────────────────────────────────────────────────
185
186    fn diff_insert(&mut self, source: &InsertStatement, target: &InsertStatement) {
187        if source.table != target.table {
188            self.changes.push(ChangeAction::Update(
189                AstNode::Expr(table_ref_to_expr(&source.table)),
190                AstNode::Expr(table_ref_to_expr(&target.table)),
191            ));
192        }
193
194        // Column list
195        self.diff_string_lists(&source.columns, &target.columns);
196
197        // Source
198        match (&source.source, &target.source) {
199            (InsertSource::Values(sv), InsertSource::Values(tv)) => {
200                for (i, (sr, tr)) in sv.iter().zip(tv.iter()).enumerate() {
201                    self.diff_expr_lists(sr, tr);
202                    let _ = i;
203                }
204                for extra in sv.iter().skip(tv.len()) {
205                    for e in extra {
206                        self.changes
207                            .push(ChangeAction::Remove(AstNode::Expr(e.clone())));
208                    }
209                }
210                for extra in tv.iter().skip(sv.len()) {
211                    for e in extra {
212                        self.changes
213                            .push(ChangeAction::Insert(AstNode::Expr(e.clone())));
214                    }
215                }
216            }
217            (InsertSource::Query(sq), InsertSource::Query(tq)) => {
218                self.diff_statements(sq, tq);
219            }
220            _ => {
221                self.changes
222                    .push(ChangeAction::Remove(AstNode::Statement(Box::new(
223                        Statement::Insert(source.clone()),
224                    ))));
225                self.changes
226                    .push(ChangeAction::Insert(AstNode::Statement(Box::new(
227                        Statement::Insert(target.clone()),
228                    ))));
229            }
230        }
231    }
232
233    // ── UPDATE ─────────────────────────────────────────────────────────
234
235    fn diff_update(&mut self, source: &UpdateStatement, target: &UpdateStatement) {
236        if source.table != target.table {
237            self.changes.push(ChangeAction::Update(
238                AstNode::Expr(table_ref_to_expr(&source.table)),
239                AstNode::Expr(table_ref_to_expr(&target.table)),
240            ));
241        }
242
243        // Assignments (ordered by column name matching)
244        let source_map: HashMap<&str, &Expr> = source
245            .assignments
246            .iter()
247            .map(|(k, v)| (k.as_str(), v))
248            .collect();
249        let target_map: HashMap<&str, &Expr> = target
250            .assignments
251            .iter()
252            .map(|(k, v)| (k.as_str(), v))
253            .collect();
254
255        for (col, src_val) in &source_map {
256            if let Some(tgt_val) = target_map.get(col) {
257                self.diff_exprs(src_val, tgt_val);
258            } else {
259                self.changes
260                    .push(ChangeAction::Remove(AstNode::Expr((*src_val).clone())));
261            }
262        }
263        for (col, tgt_val) in &target_map {
264            if !source_map.contains_key(col) {
265                self.changes
266                    .push(ChangeAction::Insert(AstNode::Expr((*tgt_val).clone())));
267            }
268        }
269
270        self.diff_optional_exprs(&source.where_clause, &target.where_clause);
271    }
272
273    // ── DELETE ─────────────────────────────────────────────────────────
274
275    fn diff_delete(&mut self, source: &DeleteStatement, target: &DeleteStatement) {
276        if source.table != target.table {
277            self.changes.push(ChangeAction::Update(
278                AstNode::Expr(table_ref_to_expr(&source.table)),
279                AstNode::Expr(table_ref_to_expr(&target.table)),
280            ));
281        }
282        self.diff_optional_exprs(&source.where_clause, &target.where_clause);
283    }
284
285    // ── CREATE TABLE ───────────────────────────────────────────────────
286
287    fn diff_create_table(&mut self, source: &CreateTableStatement, target: &CreateTableStatement) {
288        if source.table != target.table {
289            self.changes.push(ChangeAction::Update(
290                AstNode::Expr(table_ref_to_expr(&source.table)),
291                AstNode::Expr(table_ref_to_expr(&target.table)),
292            ));
293        }
294
295        // Column definitions (match by name)
296        let source_cols: HashMap<&str, &ColumnDef> = source
297            .columns
298            .iter()
299            .map(|c| (c.name.as_str(), c))
300            .collect();
301        let target_cols: HashMap<&str, &ColumnDef> = target
302            .columns
303            .iter()
304            .map(|c| (c.name.as_str(), c))
305            .collect();
306
307        for (name, src_col) in &source_cols {
308            if let Some(tgt_col) = target_cols.get(name) {
309                if src_col != tgt_col {
310                    self.changes.push(ChangeAction::Update(
311                        AstNode::ColumnDef((*src_col).clone()),
312                        AstNode::ColumnDef((*tgt_col).clone()),
313                    ));
314                } else {
315                    self.changes.push(ChangeAction::Keep(
316                        AstNode::ColumnDef((*src_col).clone()),
317                        AstNode::ColumnDef((*tgt_col).clone()),
318                    ));
319                }
320            } else {
321                self.changes
322                    .push(ChangeAction::Remove(AstNode::ColumnDef((*src_col).clone())));
323            }
324        }
325        for (name, tgt_col) in &target_cols {
326            if !source_cols.contains_key(name) {
327                self.changes
328                    .push(ChangeAction::Insert(AstNode::ColumnDef((*tgt_col).clone())));
329            }
330        }
331
332        // Constraints
333        self.diff_constraints(&source.constraints, &target.constraints);
334    }
335
336    // ── DROP TABLE ─────────────────────────────────────────────────────
337
338    fn diff_drop_table(&mut self, source: &DropTableStatement, target: &DropTableStatement) {
339        if source != target {
340            self.changes.push(ChangeAction::Update(
341                AstNode::Statement(Box::new(Statement::DropTable(source.clone()))),
342                AstNode::Statement(Box::new(Statement::DropTable(target.clone()))),
343            ));
344        } else {
345            self.changes.push(ChangeAction::Keep(
346                AstNode::Statement(Box::new(Statement::DropTable(source.clone()))),
347                AstNode::Statement(Box::new(Statement::DropTable(target.clone()))),
348            ));
349        }
350    }
351
352    // ── SET OPERATION ──────────────────────────────────────────────────
353
354    fn diff_set_operation(
355        &mut self,
356        source: &SetOperationStatement,
357        target: &SetOperationStatement,
358    ) {
359        if source.op != target.op || source.all != target.all {
360            self.changes.push(ChangeAction::Update(
361                AstNode::Statement(Box::new(Statement::SetOperation(source.clone()))),
362                AstNode::Statement(Box::new(Statement::SetOperation(target.clone()))),
363            ));
364            return;
365        }
366        self.diff_statements(&source.left, &target.left);
367        self.diff_statements(&source.right, &target.right);
368        self.diff_order_by(&source.order_by, &target.order_by);
369        self.diff_optional_exprs(&source.limit, &target.limit);
370        self.diff_optional_exprs(&source.offset, &target.offset);
371    }
372
373    // ── ALTER TABLE ────────────────────────────────────────────────────
374
375    fn diff_alter_table(&mut self, source: &AlterTableStatement, target: &AlterTableStatement) {
376        if source.table != target.table {
377            self.changes.push(ChangeAction::Update(
378                AstNode::Expr(table_ref_to_expr(&source.table)),
379                AstNode::Expr(table_ref_to_expr(&target.table)),
380            ));
381        }
382        // Actions compared for equality
383        if source.actions != target.actions {
384            self.changes.push(ChangeAction::Update(
385                AstNode::Statement(Box::new(Statement::AlterTable(source.clone()))),
386                AstNode::Statement(Box::new(Statement::AlterTable(target.clone()))),
387            ));
388        }
389    }
390
391    // ── CREATE VIEW ────────────────────────────────────────────────────
392
393    fn diff_create_view(&mut self, source: &CreateViewStatement, target: &CreateViewStatement) {
394        if source.name != target.name {
395            self.changes.push(ChangeAction::Update(
396                AstNode::Expr(table_ref_to_expr(&source.name)),
397                AstNode::Expr(table_ref_to_expr(&target.name)),
398            ));
399        }
400        self.diff_statements(&source.query, &target.query);
401    }
402
403    // ── Shared helpers ─────────────────────────────────────────────────
404
405    fn diff_exprs(&mut self, source: &Expr, target: &Expr) {
406        if source == target {
407            self.changes.push(ChangeAction::Keep(
408                AstNode::Expr(source.clone()),
409                AstNode::Expr(target.clone()),
410            ));
411            return;
412        }
413
414        // Same top-level variant → recurse into children
415        match (source, target) {
416            (
417                Expr::BinaryOp {
418                    left: sl,
419                    op: sop,
420                    right: sr,
421                },
422                Expr::BinaryOp {
423                    left: tl,
424                    op: top,
425                    right: tr,
426                },
427            ) => {
428                if sop != top {
429                    self.changes.push(ChangeAction::Update(
430                        AstNode::Expr(source.clone()),
431                        AstNode::Expr(target.clone()),
432                    ));
433                } else {
434                    self.diff_exprs(sl, tl);
435                    self.diff_exprs(sr, tr);
436                }
437            }
438            (Expr::UnaryOp { op: sop, expr: se }, Expr::UnaryOp { op: top, expr: te }) => {
439                if sop != top {
440                    self.changes.push(ChangeAction::Update(
441                        AstNode::Expr(source.clone()),
442                        AstNode::Expr(target.clone()),
443                    ));
444                } else {
445                    self.diff_exprs(se, te);
446                }
447            }
448            (
449                Expr::Function {
450                    name: sn,
451                    args: sa,
452                    distinct: sd,
453                    ..
454                },
455                Expr::Function {
456                    name: tn,
457                    args: ta,
458                    distinct: td,
459                    ..
460                },
461            ) => {
462                if sn != tn || sd != td {
463                    self.changes.push(ChangeAction::Update(
464                        AstNode::Expr(source.clone()),
465                        AstNode::Expr(target.clone()),
466                    ));
467                } else {
468                    self.diff_expr_lists(sa, ta);
469                }
470            }
471            (
472                Expr::Cast {
473                    expr: se,
474                    data_type: sd,
475                },
476                Expr::Cast {
477                    expr: te,
478                    data_type: td,
479                },
480            ) => {
481                if sd != td {
482                    self.changes.push(ChangeAction::Update(
483                        AstNode::Expr(source.clone()),
484                        AstNode::Expr(target.clone()),
485                    ));
486                } else {
487                    self.diff_exprs(se, te);
488                }
489            }
490            (
491                Expr::Case {
492                    operand: so,
493                    when_clauses: sw,
494                    else_clause: se,
495                },
496                Expr::Case {
497                    operand: to,
498                    when_clauses: tw,
499                    else_clause: te,
500                },
501            ) => {
502                self.diff_optional_boxed_exprs(so, to);
503                // when clauses — ordered
504                for (i, ((sc, sr), (tc, tr))) in sw.iter().zip(tw.iter()).enumerate() {
505                    self.diff_exprs(sc, tc);
506                    self.diff_exprs(sr, tr);
507                    let _ = i;
508                }
509                for (sc, sr) in sw.iter().skip(tw.len()) {
510                    self.changes
511                        .push(ChangeAction::Remove(AstNode::Expr(sc.clone())));
512                    self.changes
513                        .push(ChangeAction::Remove(AstNode::Expr(sr.clone())));
514                }
515                for (tc, tr) in tw.iter().skip(sw.len()) {
516                    self.changes
517                        .push(ChangeAction::Insert(AstNode::Expr(tc.clone())));
518                    self.changes
519                        .push(ChangeAction::Insert(AstNode::Expr(tr.clone())));
520                }
521                self.diff_optional_boxed_exprs(se, te);
522            }
523            (Expr::Nested(se), Expr::Nested(te)) => self.diff_exprs(se, te),
524            (
525                Expr::Between {
526                    expr: se,
527                    low: sl,
528                    high: sh,
529                    negated: sn,
530                },
531                Expr::Between {
532                    expr: te,
533                    low: tl,
534                    high: th,
535                    negated: tn,
536                },
537            ) => {
538                if sn != tn {
539                    self.changes.push(ChangeAction::Update(
540                        AstNode::Expr(source.clone()),
541                        AstNode::Expr(target.clone()),
542                    ));
543                } else {
544                    self.diff_exprs(se, te);
545                    self.diff_exprs(sl, tl);
546                    self.diff_exprs(sh, th);
547                }
548            }
549            (
550                Expr::InList {
551                    expr: se,
552                    list: sl,
553                    negated: sn,
554                },
555                Expr::InList {
556                    expr: te,
557                    list: tl,
558                    negated: tn,
559                },
560            ) => {
561                if sn != tn {
562                    self.changes.push(ChangeAction::Update(
563                        AstNode::Expr(source.clone()),
564                        AstNode::Expr(target.clone()),
565                    ));
566                } else {
567                    self.diff_exprs(se, te);
568                    self.diff_expr_lists(sl, tl);
569                }
570            }
571            (
572                Expr::InSubquery {
573                    expr: se,
574                    subquery: ss,
575                    negated: sn,
576                },
577                Expr::InSubquery {
578                    expr: te,
579                    subquery: ts,
580                    negated: tn,
581                },
582            ) => {
583                if sn != tn {
584                    self.changes.push(ChangeAction::Update(
585                        AstNode::Expr(source.clone()),
586                        AstNode::Expr(target.clone()),
587                    ));
588                } else {
589                    self.diff_exprs(se, te);
590                    self.diff_statements(ss, ts);
591                }
592            }
593            (
594                Expr::IsNull {
595                    expr: se,
596                    negated: sn,
597                },
598                Expr::IsNull {
599                    expr: te,
600                    negated: tn,
601                },
602            ) => {
603                if sn != tn {
604                    self.changes.push(ChangeAction::Update(
605                        AstNode::Expr(source.clone()),
606                        AstNode::Expr(target.clone()),
607                    ));
608                } else {
609                    self.diff_exprs(se, te);
610                }
611            }
612            (
613                Expr::Like {
614                    expr: se,
615                    pattern: sp,
616                    negated: sn,
617                    ..
618                },
619                Expr::Like {
620                    expr: te,
621                    pattern: tp,
622                    negated: tn,
623                    ..
624                },
625            )
626            | (
627                Expr::ILike {
628                    expr: se,
629                    pattern: sp,
630                    negated: sn,
631                    ..
632                },
633                Expr::ILike {
634                    expr: te,
635                    pattern: tp,
636                    negated: tn,
637                    ..
638                },
639            ) => {
640                if sn != tn {
641                    self.changes.push(ChangeAction::Update(
642                        AstNode::Expr(source.clone()),
643                        AstNode::Expr(target.clone()),
644                    ));
645                } else {
646                    self.diff_exprs(se, te);
647                    self.diff_exprs(sp, tp);
648                }
649            }
650            (Expr::Subquery(ss), Expr::Subquery(ts)) => self.diff_statements(ss, ts),
651            (
652                Expr::Exists {
653                    subquery: ss,
654                    negated: sn,
655                },
656                Expr::Exists {
657                    subquery: ts,
658                    negated: tn,
659                },
660            ) => {
661                if sn != tn {
662                    self.changes.push(ChangeAction::Update(
663                        AstNode::Expr(source.clone()),
664                        AstNode::Expr(target.clone()),
665                    ));
666                } else {
667                    self.diff_statements(ss, ts);
668                }
669            }
670            (Expr::Alias { expr: se, name: sn }, Expr::Alias { expr: te, name: tn }) => {
671                if sn != tn {
672                    self.changes.push(ChangeAction::Update(
673                        AstNode::Expr(source.clone()),
674                        AstNode::Expr(target.clone()),
675                    ));
676                } else {
677                    self.diff_exprs(se, te);
678                }
679            }
680            (Expr::Coalesce(sa), Expr::Coalesce(ta)) => self.diff_expr_lists(sa, ta),
681            (Expr::ArrayLiteral(sa), Expr::ArrayLiteral(ta)) => self.diff_expr_lists(sa, ta),
682            (Expr::Tuple(sa), Expr::Tuple(ta)) => self.diff_expr_lists(sa, ta),
683            (Expr::TypedFunction { func: sf, .. }, Expr::TypedFunction { func: tf, .. }) => {
684                if std::mem::discriminant(sf) == std::mem::discriminant(tf) && source == target {
685                    self.changes.push(ChangeAction::Keep(
686                        AstNode::Expr(source.clone()),
687                        AstNode::Expr(target.clone()),
688                    ));
689                } else {
690                    self.changes.push(ChangeAction::Update(
691                        AstNode::Expr(source.clone()),
692                        AstNode::Expr(target.clone()),
693                    ));
694                }
695            }
696            // Different variant types → leaf-level update
697            _ => {
698                self.changes.push(ChangeAction::Update(
699                    AstNode::Expr(source.clone()),
700                    AstNode::Expr(target.clone()),
701                ));
702            }
703        }
704    }
705
706    /// Diff two ordered expression lists (e.g., SELECT columns, function args).
707    fn diff_expr_lists(&mut self, source: &[Expr], target: &[Expr]) {
708        // Use longest common subsequence for ordered diff
709        let lcs = compute_lcs(source, target);
710        let mut si = 0;
711        let mut ti = 0;
712        let mut li = 0;
713
714        while si < source.len() || ti < target.len() {
715            if li < lcs.len() {
716                let (lcs_si, lcs_ti) = lcs[li];
717
718                // Remove items before the next LCS match in source
719                while si < lcs_si {
720                    self.changes
721                        .push(ChangeAction::Remove(AstNode::Expr(source[si].clone())));
722                    si += 1;
723                }
724                // Insert items before the next LCS match in target
725                while ti < lcs_ti {
726                    self.changes
727                        .push(ChangeAction::Insert(AstNode::Expr(target[ti].clone())));
728                    ti += 1;
729                }
730                // Matched pair — recurse to find deeper changes
731                self.diff_exprs(&source[si], &target[ti]);
732                si += 1;
733                ti += 1;
734                li += 1;
735            } else {
736                // Remaining source items are removed
737                while si < source.len() {
738                    self.changes
739                        .push(ChangeAction::Remove(AstNode::Expr(source[si].clone())));
740                    si += 1;
741                }
742                // Remaining target items are inserted
743                while ti < target.len() {
744                    self.changes
745                        .push(ChangeAction::Insert(AstNode::Expr(target[ti].clone())));
746                    ti += 1;
747                }
748            }
749        }
750    }
751
752    fn diff_select_items(&mut self, source: &[SelectItem], target: &[SelectItem]) {
753        let min_len = source.len().min(target.len());
754        for i in 0..min_len {
755            if source[i] == target[i] {
756                self.changes.push(ChangeAction::Keep(
757                    AstNode::SelectItem(source[i].clone()),
758                    AstNode::SelectItem(target[i].clone()),
759                ));
760            } else {
761                match (&source[i], &target[i]) {
762                    (
763                        SelectItem::Expr {
764                            expr: se,
765                            alias: sa,
766                        },
767                        SelectItem::Expr {
768                            expr: te,
769                            alias: ta,
770                        },
771                    ) => {
772                        if sa != ta {
773                            self.changes.push(ChangeAction::Update(
774                                AstNode::SelectItem(source[i].clone()),
775                                AstNode::SelectItem(target[i].clone()),
776                            ));
777                        } else {
778                            self.diff_exprs(se, te);
779                        }
780                    }
781                    _ => {
782                        self.changes.push(ChangeAction::Update(
783                            AstNode::SelectItem(source[i].clone()),
784                            AstNode::SelectItem(target[i].clone()),
785                        ));
786                    }
787                }
788            }
789        }
790        for item in source.iter().skip(min_len) {
791            self.changes
792                .push(ChangeAction::Remove(AstNode::SelectItem(item.clone())));
793        }
794        for item in target.iter().skip(min_len) {
795            self.changes
796                .push(ChangeAction::Insert(AstNode::SelectItem(item.clone())));
797        }
798    }
799
800    fn diff_optional_exprs(&mut self, source: &Option<Expr>, target: &Option<Expr>) {
801        match (source, target) {
802            (Some(s), Some(t)) => self.diff_exprs(s, t),
803            (None, Some(t)) => self
804                .changes
805                .push(ChangeAction::Insert(AstNode::Expr(t.clone()))),
806            (Some(s), None) => self
807                .changes
808                .push(ChangeAction::Remove(AstNode::Expr(s.clone()))),
809            (None, None) => {}
810        }
811    }
812
813    fn diff_optional_boxed_exprs(
814        &mut self,
815        source: &Option<Box<Expr>>,
816        target: &Option<Box<Expr>>,
817    ) {
818        match (source, target) {
819            (Some(s), Some(t)) => self.diff_exprs(s, t),
820            (None, Some(t)) => self
821                .changes
822                .push(ChangeAction::Insert(AstNode::Expr((**t).clone()))),
823            (Some(s), None) => self
824                .changes
825                .push(ChangeAction::Remove(AstNode::Expr((**s).clone()))),
826            (None, None) => {}
827        }
828    }
829
830    fn diff_ctes(&mut self, source: &[Cte], target: &[Cte]) {
831        // Match CTEs by name
832        let source_map: HashMap<&str, &Cte> = source.iter().map(|c| (c.name.as_str(), c)).collect();
833        let target_map: HashMap<&str, &Cte> = target.iter().map(|c| (c.name.as_str(), c)).collect();
834
835        for (name, sc) in &source_map {
836            if let Some(tc) = target_map.get(name) {
837                if sc == tc {
838                    self.changes.push(ChangeAction::Keep(
839                        AstNode::Cte(Box::new((*sc).clone())),
840                        AstNode::Cte(Box::new((*tc).clone())),
841                    ));
842                } else {
843                    self.diff_statements(&sc.query, &tc.query);
844                }
845            } else {
846                self.changes
847                    .push(ChangeAction::Remove(AstNode::Cte(Box::new((*sc).clone()))));
848            }
849        }
850        for (name, tc) in &target_map {
851            if !source_map.contains_key(name) {
852                self.changes
853                    .push(ChangeAction::Insert(AstNode::Cte(Box::new((*tc).clone()))));
854            }
855        }
856    }
857
858    fn diff_joins(&mut self, source: &[JoinClause], target: &[JoinClause]) {
859        let min_len = source.len().min(target.len());
860        for i in 0..min_len {
861            if source[i] == target[i] {
862                self.changes.push(ChangeAction::Keep(
863                    AstNode::JoinClause(source[i].clone()),
864                    AstNode::JoinClause(target[i].clone()),
865                ));
866            } else if source[i].join_type == target[i].join_type {
867                // Same join type, diff the contents
868                self.diff_table_sources(&source[i].table, &target[i].table);
869                self.diff_optional_exprs(&source[i].on, &target[i].on);
870            } else {
871                self.changes.push(ChangeAction::Update(
872                    AstNode::JoinClause(source[i].clone()),
873                    AstNode::JoinClause(target[i].clone()),
874                ));
875            }
876        }
877        for item in source.iter().skip(min_len) {
878            self.changes
879                .push(ChangeAction::Remove(AstNode::JoinClause(item.clone())));
880        }
881        for item in target.iter().skip(min_len) {
882            self.changes
883                .push(ChangeAction::Insert(AstNode::JoinClause(item.clone())));
884        }
885    }
886
887    fn diff_order_by(&mut self, source: &[OrderByItem], target: &[OrderByItem]) {
888        let min_len = source.len().min(target.len());
889        for i in 0..min_len {
890            if source[i] == target[i] {
891                self.changes.push(ChangeAction::Keep(
892                    AstNode::OrderByItem(source[i].clone()),
893                    AstNode::OrderByItem(target[i].clone()),
894                ));
895            } else if source[i].ascending == target[i].ascending
896                && source[i].nulls_first == target[i].nulls_first
897            {
898                self.diff_exprs(&source[i].expr, &target[i].expr);
899            } else {
900                self.changes.push(ChangeAction::Update(
901                    AstNode::OrderByItem(source[i].clone()),
902                    AstNode::OrderByItem(target[i].clone()),
903                ));
904            }
905        }
906        for item in source.iter().skip(min_len) {
907            self.changes
908                .push(ChangeAction::Remove(AstNode::OrderByItem(item.clone())));
909        }
910        for item in target.iter().skip(min_len) {
911            self.changes
912                .push(ChangeAction::Insert(AstNode::OrderByItem(item.clone())));
913        }
914    }
915
916    fn diff_table_sources(&mut self, source: &TableSource, target: &TableSource) {
917        if source == target {
918            return;
919        }
920        match (source, target) {
921            (TableSource::Table(st), TableSource::Table(tt)) => {
922                if st != tt {
923                    self.changes.push(ChangeAction::Update(
924                        AstNode::Expr(table_ref_to_expr(st)),
925                        AstNode::Expr(table_ref_to_expr(tt)),
926                    ));
927                }
928            }
929            (TableSource::Subquery { query: sq, .. }, TableSource::Subquery { query: tq, .. }) => {
930                self.diff_statements(sq, tq);
931            }
932            _ => {
933                // Different source types
934                self.remove_table_source(source);
935                self.insert_table_source(target);
936            }
937        }
938    }
939
940    fn insert_table_source(&mut self, source: &TableSource) {
941        match source {
942            TableSource::Table(t) => {
943                self.changes
944                    .push(ChangeAction::Insert(AstNode::Expr(table_ref_to_expr(t))));
945            }
946            TableSource::Subquery { query, .. } => {
947                self.changes
948                    .push(ChangeAction::Insert(AstNode::Statement(Box::new(
949                        (**query).clone(),
950                    ))));
951            }
952            other => {
953                self.changes
954                    .push(ChangeAction::Insert(AstNode::Expr(Expr::StringLiteral(
955                        format!("{other:?}"),
956                    ))));
957            }
958        }
959    }
960
961    fn remove_table_source(&mut self, source: &TableSource) {
962        match source {
963            TableSource::Table(t) => {
964                self.changes
965                    .push(ChangeAction::Remove(AstNode::Expr(table_ref_to_expr(t))));
966            }
967            TableSource::Subquery { query, .. } => {
968                self.changes
969                    .push(ChangeAction::Remove(AstNode::Statement(Box::new(
970                        (**query).clone(),
971                    ))));
972            }
973            other => {
974                self.changes
975                    .push(ChangeAction::Remove(AstNode::Expr(Expr::StringLiteral(
976                        format!("{other:?}"),
977                    ))));
978            }
979        }
980    }
981
982    fn diff_constraints(&mut self, source: &[TableConstraint], target: &[TableConstraint]) {
983        // Simple positional diff for constraints
984        let min_len = source.len().min(target.len());
985        for i in 0..min_len {
986            if source[i] == target[i] {
987                self.changes.push(ChangeAction::Keep(
988                    AstNode::TableConstraint(source[i].clone()),
989                    AstNode::TableConstraint(target[i].clone()),
990                ));
991            } else {
992                self.changes.push(ChangeAction::Update(
993                    AstNode::TableConstraint(source[i].clone()),
994                    AstNode::TableConstraint(target[i].clone()),
995                ));
996            }
997        }
998        for item in source.iter().skip(min_len) {
999            self.changes
1000                .push(ChangeAction::Remove(AstNode::TableConstraint(item.clone())));
1001        }
1002        for item in target.iter().skip(min_len) {
1003            self.changes
1004                .push(ChangeAction::Insert(AstNode::TableConstraint(item.clone())));
1005        }
1006    }
1007
1008    fn diff_string_lists(&mut self, source: &[String], target: &[String]) {
1009        for s in source {
1010            if !target.contains(s) {
1011                self.changes
1012                    .push(ChangeAction::Remove(AstNode::Expr(Expr::Column {
1013                        table: None,
1014                        name: s.clone(),
1015                        quote_style: QuoteStyle::None,
1016                        table_quote_style: QuoteStyle::None,
1017                    })));
1018            }
1019        }
1020        for t in target {
1021            if !source.contains(t) {
1022                self.changes
1023                    .push(ChangeAction::Insert(AstNode::Expr(Expr::Column {
1024                        table: None,
1025                        name: t.clone(),
1026                        quote_style: QuoteStyle::None,
1027                        table_quote_style: QuoteStyle::None,
1028                    })));
1029            }
1030        }
1031    }
1032}
1033
1034// ═══════════════════════════════════════════════════════════════════════
1035// LCS — Longest Common Subsequence for ordered diff
1036// ═══════════════════════════════════════════════════════════════════════
1037
1038/// Compute the longest common subsequence of two expression slices,
1039/// returning pairs of (source_index, target_index).
1040fn compute_lcs(source: &[Expr], target: &[Expr]) -> Vec<(usize, usize)> {
1041    let m = source.len();
1042    let n = target.len();
1043    if m == 0 || n == 0 {
1044        return Vec::new();
1045    }
1046
1047    // Build DP table
1048    let mut dp = vec![vec![0u32; n + 1]; m + 1];
1049    for i in 1..=m {
1050        for j in 1..=n {
1051            if source[i - 1] == target[j - 1] {
1052                dp[i][j] = dp[i - 1][j - 1] + 1;
1053            } else {
1054                dp[i][j] = dp[i - 1][j].max(dp[i][j - 1]);
1055            }
1056        }
1057    }
1058
1059    // Backtrack to find the actual subsequence indices
1060    let mut result = Vec::new();
1061    let mut i = m;
1062    let mut j = n;
1063    while i > 0 && j > 0 {
1064        if source[i - 1] == target[j - 1] {
1065            result.push((i - 1, j - 1));
1066            i -= 1;
1067            j -= 1;
1068        } else if dp[i - 1][j] >= dp[i][j - 1] {
1069            i -= 1;
1070        } else {
1071            j -= 1;
1072        }
1073    }
1074    result.reverse();
1075    result
1076}
1077
1078/// Convert a `TableRef` to an `Expr::Column` for uniform representation.
1079fn table_ref_to_expr(table: &TableRef) -> Expr {
1080    let full_name = match (&table.catalog, &table.schema) {
1081        (Some(c), Some(s)) => format!("{c}.{s}.{}", table.name),
1082        (None, Some(s)) => format!("{s}.{}", table.name),
1083        _ => table.name.clone(),
1084    };
1085    Expr::Column {
1086        table: table.schema.clone(),
1087        name: full_name,
1088        quote_style: table.name_quote_style,
1089        table_quote_style: QuoteStyle::None,
1090    }
1091}
1092
1093// ═══════════════════════════════════════════════════════════════════════
1094// Convenience: diff from SQL strings
1095// ═══════════════════════════════════════════════════════════════════════
1096
1097/// Parse two SQL strings and compute their diff.
1098///
1099/// # Errors
1100///
1101/// Returns a [`SqlglotError`](crate::errors::SqlglotError) if either
1102/// string fails to parse.
1103pub fn diff_sql(
1104    source_sql: &str,
1105    target_sql: &str,
1106    dialect: crate::dialects::Dialect,
1107) -> crate::errors::Result<Vec<ChangeAction>> {
1108    let source = crate::parser::parse(source_sql, dialect)?;
1109    let target = crate::parser::parse(target_sql, dialect)?;
1110    Ok(diff(&source, &target))
1111}
1112
1113#[cfg(test)]
1114mod tests {
1115    use super::*;
1116    use crate::dialects::Dialect;
1117    use crate::parser::parse;
1118
1119    fn count_by_action(changes: &[ChangeAction]) -> (usize, usize, usize, usize, usize) {
1120        let mut keeps = 0;
1121        let mut inserts = 0;
1122        let mut removes = 0;
1123        let mut updates = 0;
1124        let mut moves = 0;
1125        for c in changes {
1126            match c {
1127                ChangeAction::Keep(..) => keeps += 1,
1128                ChangeAction::Insert(..) => inserts += 1,
1129                ChangeAction::Remove(..) => removes += 1,
1130                ChangeAction::Update(..) => updates += 1,
1131                ChangeAction::Move(..) => moves += 1,
1132            }
1133        }
1134        (keeps, inserts, removes, updates, moves)
1135    }
1136
1137    #[test]
1138    fn test_identical_queries_are_all_keep() {
1139        let sql = "SELECT a, b FROM t WHERE a > 1";
1140        let source = parse(sql, Dialect::Ansi).unwrap();
1141        let target = parse(sql, Dialect::Ansi).unwrap();
1142        let changes = diff(&source, &target);
1143        let (keeps, inserts, removes, updates, _moves) = count_by_action(&changes);
1144        assert!(keeps > 0, "should have keep actions");
1145        assert_eq!(inserts, 0, "no inserts for identical queries");
1146        assert_eq!(removes, 0, "no removes for identical queries");
1147        assert_eq!(updates, 0, "no updates for identical queries");
1148    }
1149
1150    #[test]
1151    fn test_column_added() {
1152        let source = parse("SELECT a FROM t", Dialect::Ansi).unwrap();
1153        let target = parse("SELECT a, b FROM t", Dialect::Ansi).unwrap();
1154        let changes = diff(&source, &target);
1155        let (keeps, inserts, removes, _updates, _moves) = count_by_action(&changes);
1156        assert!(keeps > 0);
1157        assert!(inserts > 0, "should have insert for new column b");
1158        assert_eq!(removes, 0);
1159    }
1160
1161    #[test]
1162    fn test_column_removed() {
1163        let source = parse("SELECT a, b FROM t", Dialect::Ansi).unwrap();
1164        let target = parse("SELECT a FROM t", Dialect::Ansi).unwrap();
1165        let changes = diff(&source, &target);
1166        let (keeps, _inserts, removes, _updates, _moves) = count_by_action(&changes);
1167        assert!(keeps > 0);
1168        assert!(removes > 0, "should have remove for column b");
1169    }
1170
1171    #[test]
1172    fn test_column_changed() {
1173        let source = parse("SELECT a, b FROM t", Dialect::Ansi).unwrap();
1174        let target = parse("SELECT a, c FROM t", Dialect::Ansi).unwrap();
1175        let changes = diff(&source, &target);
1176        let (_keeps, _inserts, _removes, updates, _moves) = count_by_action(&changes);
1177        assert!(updates > 0, "should have update for b -> c");
1178    }
1179
1180    #[test]
1181    fn test_where_clause_added() {
1182        let source = parse("SELECT a FROM t", Dialect::Ansi).unwrap();
1183        let target = parse("SELECT a FROM t WHERE a > 1", Dialect::Ansi).unwrap();
1184        let changes = diff(&source, &target);
1185        let (_keeps, inserts, _removes, _updates, _moves) = count_by_action(&changes);
1186        assert!(inserts > 0, "should have insert for WHERE clause");
1187    }
1188
1189    #[test]
1190    fn test_where_clause_removed() {
1191        let source = parse("SELECT a FROM t WHERE a > 1", Dialect::Ansi).unwrap();
1192        let target = parse("SELECT a FROM t", Dialect::Ansi).unwrap();
1193        let changes = diff(&source, &target);
1194        let (_keeps, _inserts, removes, _updates, _moves) = count_by_action(&changes);
1195        assert!(removes > 0, "should have remove for WHERE clause");
1196    }
1197
1198    #[test]
1199    fn test_where_clause_updated() {
1200        let source = parse("SELECT a FROM t WHERE a > 1", Dialect::Ansi).unwrap();
1201        let target = parse("SELECT a FROM t WHERE a > 2", Dialect::Ansi).unwrap();
1202        let changes = diff(&source, &target);
1203        let (_keeps, _inserts, _removes, updates, _moves) = count_by_action(&changes);
1204        assert!(updates > 0, "should have update for WHERE value change");
1205    }
1206
1207    #[test]
1208    fn test_table_changed() {
1209        let source = parse("SELECT a FROM t1", Dialect::Ansi).unwrap();
1210        let target = parse("SELECT a FROM t2", Dialect::Ansi).unwrap();
1211        let changes = diff(&source, &target);
1212        let (_keeps, _inserts, _removes, updates, _moves) = count_by_action(&changes);
1213        assert!(updates > 0, "should have update for table change");
1214    }
1215
1216    #[test]
1217    fn test_join_added() {
1218        let source = parse("SELECT a FROM t1", Dialect::Ansi).unwrap();
1219        let target = parse("SELECT a FROM t1 JOIN t2 ON t1.id = t2.id", Dialect::Ansi).unwrap();
1220        let changes = diff(&source, &target);
1221        let (_keeps, inserts, _removes, _updates, _moves) = count_by_action(&changes);
1222        assert!(inserts > 0, "should have insert for JOIN");
1223    }
1224
1225    #[test]
1226    fn test_order_by_changed() {
1227        let source = parse("SELECT a FROM t ORDER BY a ASC", Dialect::Ansi).unwrap();
1228        let target = parse("SELECT a FROM t ORDER BY a DESC", Dialect::Ansi).unwrap();
1229        let changes = diff(&source, &target);
1230        let (_keeps, _inserts, _removes, updates, _moves) = count_by_action(&changes);
1231        assert!(updates > 0, "should have update for ORDER BY direction");
1232    }
1233
1234    #[test]
1235    fn test_complex_nested_query() {
1236        let source = parse(
1237            "SELECT a, b FROM t1 WHERE a IN (SELECT x FROM t2 WHERE x > 0)",
1238            Dialect::Ansi,
1239        )
1240        .unwrap();
1241        let target = parse(
1242            "SELECT a, c FROM t1 WHERE a IN (SELECT x FROM t2 WHERE x > 5)",
1243            Dialect::Ansi,
1244        )
1245        .unwrap();
1246        let changes = diff(&source, &target);
1247        let (keeps, _inserts, _removes, updates, _moves) = count_by_action(&changes);
1248        assert!(keeps > 0, "unchanged parts should be kept");
1249        assert!(updates > 0, "changed parts should be updated (b->c, 0->5)");
1250    }
1251
1252    #[test]
1253    fn test_different_statement_types() {
1254        let source = parse("SELECT a FROM t", Dialect::Ansi).unwrap();
1255        let target = parse("CREATE TABLE t (a INT)", Dialect::Ansi).unwrap();
1256        let changes = diff(&source, &target);
1257        let (_keeps, inserts, removes, _updates, _moves) = count_by_action(&changes);
1258        assert!(removes > 0, "source should be removed");
1259        assert!(inserts > 0, "target should be inserted");
1260    }
1261
1262    #[test]
1263    fn test_cte_added() {
1264        let source = parse("SELECT a FROM t", Dialect::Ansi).unwrap();
1265        let target = parse("WITH cte AS (SELECT 1 AS x) SELECT a FROM t", Dialect::Ansi).unwrap();
1266        let changes = diff(&source, &target);
1267        let (_keeps, inserts, _removes, _updates, _moves) = count_by_action(&changes);
1268        assert!(inserts > 0, "should have insert for CTE");
1269    }
1270
1271    #[test]
1272    fn test_limit_changed() {
1273        let source = parse("SELECT a FROM t LIMIT 10", Dialect::Ansi).unwrap();
1274        let target = parse("SELECT a FROM t LIMIT 20", Dialect::Ansi).unwrap();
1275        let changes = diff(&source, &target);
1276        let (_keeps, _inserts, _removes, updates, _moves) = count_by_action(&changes);
1277        assert!(updates > 0, "should have update for LIMIT change");
1278    }
1279
1280    #[test]
1281    fn test_group_by_added() {
1282        let source = parse("SELECT a, COUNT(*) FROM t", Dialect::Ansi).unwrap();
1283        let target = parse("SELECT a, COUNT(*) FROM t GROUP BY a", Dialect::Ansi).unwrap();
1284        let changes = diff(&source, &target);
1285        let (_keeps, inserts, _removes, _updates, _moves) = count_by_action(&changes);
1286        assert!(inserts > 0, "should have insert for GROUP BY");
1287    }
1288
1289    #[test]
1290    fn test_diff_sql_convenience() {
1291        let changes = diff_sql("SELECT a FROM t", "SELECT a, b FROM t", Dialect::Ansi).unwrap();
1292        let (_keeps, inserts, _removes, _updates, _moves) = count_by_action(&changes);
1293        assert!(inserts > 0);
1294    }
1295
1296    #[test]
1297    fn test_having_added() {
1298        let source = parse("SELECT a, COUNT(*) FROM t GROUP BY a", Dialect::Ansi).unwrap();
1299        let target = parse(
1300            "SELECT a, COUNT(*) FROM t GROUP BY a HAVING COUNT(*) > 1",
1301            Dialect::Ansi,
1302        )
1303        .unwrap();
1304        let changes = diff(&source, &target);
1305        let (_keeps, inserts, _removes, _updates, _moves) = count_by_action(&changes);
1306        assert!(inserts > 0, "should have insert for HAVING");
1307    }
1308
1309    #[test]
1310    fn test_create_table_column_diff() {
1311        let source = parse("CREATE TABLE t (a INT, b TEXT)", Dialect::Ansi).unwrap();
1312        let target = parse("CREATE TABLE t (a INT, c TEXT)", Dialect::Ansi).unwrap();
1313        let changes = diff(&source, &target);
1314        let (_keeps, inserts, removes, _updates, _moves) = count_by_action(&changes);
1315        assert!(removes > 0, "should remove column b");
1316        assert!(inserts > 0, "should insert column c");
1317    }
1318
1319    #[test]
1320    fn test_union_diff() {
1321        let source = parse("SELECT a FROM t1 UNION SELECT b FROM t2", Dialect::Ansi).unwrap();
1322        let target = parse("SELECT a FROM t1 UNION SELECT c FROM t2", Dialect::Ansi).unwrap();
1323        let changes = diff(&source, &target);
1324        let (keeps, _inserts, _removes, updates, _moves) = count_by_action(&changes);
1325        assert!(keeps > 0);
1326        assert!(updates > 0, "should have update for b -> c");
1327    }
1328}