Skip to main content

sql_schema/
diff.rs

1use std::{cmp::Ordering, collections::HashSet, fmt};
2
3use bon::bon;
4use sqlparser::ast::{
5    helpers::attached_token::AttachedToken, AlterTable, AlterTableOperation, AlterType,
6    AlterTypeAddValue, AlterTypeAddValuePosition, AlterTypeOperation, CreateDomain,
7    CreateExtension, CreateIndex, CreateTable, DropDomain, DropExtension, Ident, ObjectName,
8    ObjectType, Statement, UserDefinedTypeRepresentation,
9};
10use thiserror::Error;
11
/// Error returned when two schemas (or two statements) cannot be diffed.
///
/// Holds the reason for the failure plus, when available, the statement(s) involved; the
/// `Display` impl prints them under "Statement A" / "Statement B" headings.
#[derive(Error, Debug)]
pub struct DiffError {
    /// Why the diff failed.
    kind: DiffErrorKind,
    /// Statement printed as "Statement A" in the error message, if any.
    /// Boxed, presumably to keep `DiffError` (and every `Result` carrying it) small.
    statement_a: Option<Box<Statement>>,
    /// Statement printed as "Statement B" in the error message, if any.
    statement_b: Option<Box<Statement>>,
}
18
19impl fmt::Display for DiffError {
20    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
21        write!(
22            f,
23            "Oops, we couldn't diff that: {reason}",
24            reason = self.kind
25        )?;
26        if let Some(statement_a) = &self.statement_a {
27            write!(f, "\n\nStatement A:\n{statement_a}")?;
28        }
29        if let Some(statement_b) = &self.statement_b {
30            write!(f, "\n\nStatement B:\n{statement_b}")?;
31        }
32        Ok(())
33    }
34}
35
36#[bon]
37impl DiffError {
38    #[builder]
39    fn new(
40        kind: DiffErrorKind,
41        statement_a: Option<Statement>,
42        statement_b: Option<Statement>,
43    ) -> Self {
44        Self {
45            kind,
46            statement_a: statement_a.map(Box::new),
47            statement_b: statement_b.map(Box::new),
48        }
49    }
50}
51
/// Reasons a diff can fail. The `#[error]` message of the variant is what `DiffError`'s
/// `Display` impl prints after "Oops, we couldn't diff that: ".
// NOTE(review): `#[non_exhaustive]` only affects other crates; this enum is private, so the
// attribute currently has no effect.
#[derive(Error, Debug)]
#[non_exhaustive]
enum DiffErrorKind {
    /// An index that has to be dropped has no name, so no `DROP INDEX` can be built.
    #[error("can't drop unnamed index")]
    DropUnnamedIndex,
    /// Two differing indexes were matched but at least one of them has no name.
    #[error("can't compare unnamed index")]
    CompareUnnamedIndex,
    /// An enum type in the target schema has fewer labels than in the source schema.
    #[error("removing enum labels is not supported")]
    RemoveEnumLabel,
    /// The statement kind (or the specific change) is not handled by the differ yet.
    #[error("not yet supported")]
    NotImplemented,
}
64
/// Computes the changes needed to turn `self` into `other`.
pub(crate) trait Diff: Sized {
    /// Representation of the changes; for statements this is the list of migration
    /// statements, or `None` when nothing changed.
    type Diff;

    /// Diffs `self` (the current state) against `other` (the desired state).
    ///
    /// # Errors
    ///
    /// Returns a [`DiffError`] when the two values cannot be diffed.
    fn diff(&self, other: &Self) -> Result<Self::Diff, DiffError>;
}
70
71impl Diff for Vec<Statement> {
72    type Diff = Option<Vec<Statement>>;
73
74    fn diff(&self, other: &Self) -> Result<Self::Diff, DiffError> {
75        let res = self
76            .iter()
77            .filter_map(|sa| {
78                match sa {
79                    // CreateTable: compare against another CreateTable with the same name
80                    // TODO: handle renames (e.g. use comments to tag a previous name for a table in a schema)
81                    Statement::CreateTable(a) => find_and_compare_create_table(sa, a, other),
82                    Statement::CreateIndex(a) => find_and_compare_create_index(sa, a, other),
83                    Statement::CreateType { name, .. } => {
84                        find_and_compare_create_type(sa, name, other)
85                    }
86                    Statement::CreateExtension(CreateExtension {
87                        name,
88                        if_not_exists,
89                        cascade,
90                        ..
91                    }) => {
92                        find_and_compare_create_extension(sa, name, *if_not_exists, *cascade, other)
93                    }
94                    Statement::CreateDomain(a) => find_and_compare_create_domain(sa, a, other),
95                    _ => Err(DiffError::builder()
96                        .kind(DiffErrorKind::NotImplemented)
97                        .statement_a(sa.clone())
98                        .build()),
99                }
100                .transpose()
101            })
102            // find resources that are in `other` but not in `self`
103            .chain(other.iter().filter_map(|sb| {
104                match sb {
105                    Statement::CreateTable(b) => Ok(self.iter().find(|sa| match sa {
106                        Statement::CreateTable(a) => a.name == b.name,
107                        _ => false,
108                    })),
109                    Statement::CreateIndex(b) => Ok(self.iter().find(|sa| match sa {
110                        Statement::CreateIndex(a) => a.name == b.name,
111                        _ => false,
112                    })),
113                    Statement::CreateType { name: b_name, .. } => {
114                        Ok(self.iter().find(|sa| match sa {
115                            Statement::CreateType { name: a_name, .. } => a_name == b_name,
116                            _ => false,
117                        }))
118                    }
119                    Statement::CreateExtension(CreateExtension { name: b_name, .. }) => {
120                        Ok(self.iter().find(|sa| match sa {
121                            Statement::CreateExtension(CreateExtension {
122                                name: a_name, ..
123                            }) => a_name == b_name,
124                            _ => false,
125                        }))
126                    }
127                    Statement::CreateDomain(b) => Ok(self.iter().find(|sa| match sa {
128                        Statement::CreateDomain(a) => a.name == b.name,
129                        _ => false,
130                    })),
131                    _ => Err(DiffError::builder()
132                        .kind(DiffErrorKind::NotImplemented)
133                        .statement_a(sb.clone())
134                        .build()),
135                }
136                .transpose()
137                // return the statement if it's not in `self`
138                .map_or_else(|| Some(Ok(vec![sb.clone()])), |_| None)
139            }))
140            .collect::<Result<Vec<_>, _>>()?
141            .into_iter()
142            .flatten()
143            .collect::<Vec<_>>();
144
145        if res.is_empty() {
146            Ok(None)
147        } else {
148            Ok(Some(res))
149        }
150    }
151}
152
153fn find_and_compare<MF, DF>(
154    sa: &Statement,
155    other: &[Statement],
156    match_fn: MF,
157    drop_fn: DF,
158) -> Result<Option<Vec<Statement>>, DiffError>
159where
160    MF: Fn(&&Statement) -> bool,
161    DF: Fn() -> Result<Option<Vec<Statement>>, DiffError>,
162{
163    other.iter().find(match_fn).map_or_else(
164        // drop the statement if it wasn't found in `other`
165        drop_fn,
166        // otherwise diff the two statements
167        |sb| sa.diff(sb),
168    )
169}
170
171fn find_and_compare_create_table(
172    sa: &Statement,
173    a: &CreateTable,
174    other: &[Statement],
175) -> Result<Option<Vec<Statement>>, DiffError> {
176    find_and_compare(
177        sa,
178        other,
179        |sb| match sb {
180            Statement::CreateTable(b) => a.name == b.name,
181            _ => false,
182        },
183        || {
184            Ok(Some(vec![Statement::Drop {
185                object_type: sqlparser::ast::ObjectType::Table,
186                if_exists: a.if_not_exists,
187                names: vec![a.name.clone()],
188                cascade: false,
189                restrict: false,
190                purge: false,
191                temporary: false,
192                table: None,
193            }]))
194        },
195    )
196}
197
198fn find_and_compare_create_index(
199    sa: &Statement,
200    a: &CreateIndex,
201    other: &[Statement],
202) -> Result<Option<Vec<Statement>>, DiffError> {
203    find_and_compare(
204        sa,
205        other,
206        |sb| match sb {
207            Statement::CreateIndex(b) => a.name == b.name,
208            _ => false,
209        },
210        || {
211            let name = a.name.clone().ok_or_else(|| {
212                DiffError::builder()
213                    .kind(DiffErrorKind::DropUnnamedIndex)
214                    .statement_a(sa.clone())
215                    .build()
216            })?;
217
218            Ok(Some(vec![Statement::Drop {
219                object_type: sqlparser::ast::ObjectType::Index,
220                if_exists: a.if_not_exists,
221                names: vec![name],
222                cascade: false,
223                restrict: false,
224                purge: false,
225                temporary: false,
226                table: None,
227            }]))
228        },
229    )
230}
231
232fn find_and_compare_create_type(
233    sa: &Statement,
234    a_name: &ObjectName,
235    other: &[Statement],
236) -> Result<Option<Vec<Statement>>, DiffError> {
237    find_and_compare(
238        sa,
239        other,
240        |sb| match sb {
241            Statement::CreateType { name: b_name, .. } => a_name == b_name,
242            _ => false,
243        },
244        || {
245            Ok(Some(vec![Statement::Drop {
246                object_type: sqlparser::ast::ObjectType::Type,
247                if_exists: false,
248                names: vec![a_name.clone()],
249                cascade: false,
250                restrict: false,
251                purge: false,
252                temporary: false,
253                table: None,
254            }]))
255        },
256    )
257}
258
259fn find_and_compare_create_extension(
260    sa: &Statement,
261    a_name: &Ident,
262    if_not_exists: bool,
263    cascade: bool,
264    other: &[Statement],
265) -> Result<Option<Vec<Statement>>, DiffError> {
266    find_and_compare(
267        sa,
268        other,
269        |sb| match sb {
270            Statement::CreateExtension(CreateExtension { name: b_name, .. }) => a_name == b_name,
271            _ => false,
272        },
273        || {
274            Ok(Some(vec![Statement::DropExtension(DropExtension {
275                names: vec![a_name.clone()],
276                if_exists: if_not_exists,
277                cascade_or_restrict: if cascade {
278                    Some(sqlparser::ast::ReferentialAction::Cascade)
279                } else {
280                    None
281                },
282            })]))
283        },
284    )
285}
286
287fn find_and_compare_create_domain(
288    orig: &Statement,
289    domain: &CreateDomain,
290    other: &[Statement],
291) -> Result<Option<Vec<Statement>>, DiffError> {
292    let res = other
293        .iter()
294        .find(|sb| match sb {
295            Statement::CreateDomain(b) => b.name == domain.name,
296            _ => false,
297        })
298        .map(|sb| orig.diff(sb))
299        .transpose()?
300        .flatten();
301    Ok(res)
302}
303
304impl Diff for Statement {
305    type Diff = Option<Vec<Statement>>;
306
307    fn diff(&self, other: &Self) -> Result<Self::Diff, DiffError> {
308        match self {
309            Self::CreateTable(a) => match other {
310                Self::CreateTable(b) => Ok(compare_create_table(a, b)),
311                _ => Ok(None),
312            },
313            Self::CreateIndex(a) => match other {
314                Self::CreateIndex(b) => compare_create_index(a, b),
315                _ => Ok(None),
316            },
317            Self::CreateType {
318                name: a_name,
319                representation: a_rep,
320            } => match other {
321                Self::CreateType {
322                    name: b_name,
323                    representation: b_rep,
324                } => compare_create_type(self, a_name, a_rep, other, b_name, b_rep),
325                _ => Ok(None),
326            },
327            Self::CreateDomain(a) => match other {
328                Self::CreateDomain(b) => Ok(compare_create_domain(a, b)),
329                _ => Ok(None),
330            },
331            _ => Err(DiffError::builder()
332                .kind(DiffErrorKind::NotImplemented)
333                .statement_a(self.clone())
334                .statement_b(other.clone())
335                .build()),
336        }
337    }
338}
339
340fn compare_create_table(a: &CreateTable, b: &CreateTable) -> Option<Vec<Statement>> {
341    if a == b {
342        return None;
343    }
344
345    let a_column_names: HashSet<_> = a.columns.iter().map(|c| c.name.value.clone()).collect();
346    let b_column_names: HashSet<_> = b.columns.iter().map(|c| c.name.value.clone()).collect();
347
348    let operations: Vec<_> = a
349        .columns
350        .iter()
351        .filter_map(|ac| {
352            if b_column_names.contains(&ac.name.value) {
353                None
354            } else {
355                // drop column if it only exists in `a`
356                Some(AlterTableOperation::DropColumn {
357                    column_names: vec![ac.name.clone()],
358                    if_exists: a.if_not_exists,
359                    drop_behavior: None,
360                    has_column_keyword: true,
361                })
362            }
363        })
364        .chain(b.columns.iter().filter_map(|bc| {
365            if a_column_names.contains(&bc.name.value) {
366                None
367            } else {
368                // add the column if it only exists in `b`
369                Some(AlterTableOperation::AddColumn {
370                    column_keyword: true,
371                    if_not_exists: a.if_not_exists,
372                    column_def: bc.clone(),
373                    column_position: None,
374                })
375            }
376        }))
377        .collect();
378
379    if operations.is_empty() {
380        return None;
381    }
382
383    Some(vec![Statement::AlterTable(AlterTable {
384        table_type: None,
385        name: a.name.clone(),
386        if_exists: a.if_not_exists,
387        only: false,
388        operations,
389        location: None,
390        on_cluster: a.on_cluster.clone(),
391        end_token: AttachedToken::empty(),
392    })])
393}
394
395fn compare_create_index(
396    a: &CreateIndex,
397    b: &CreateIndex,
398) -> Result<Option<Vec<Statement>>, DiffError> {
399    if a == b {
400        return Ok(None);
401    }
402
403    if a.name.is_none() || b.name.is_none() {
404        return Err(DiffError::builder()
405            .kind(DiffErrorKind::CompareUnnamedIndex)
406            .statement_a(Statement::CreateIndex(a.clone()))
407            .statement_b(Statement::CreateIndex(b.clone()))
408            .build());
409    }
410    let name = a.name.clone().unwrap();
411
412    Ok(Some(vec![
413        Statement::Drop {
414            object_type: ObjectType::Index,
415            if_exists: a.if_not_exists,
416            names: vec![name],
417            cascade: false,
418            restrict: false,
419            purge: false,
420            temporary: false,
421            table: None,
422        },
423        Statement::CreateIndex(b.clone()),
424    ]))
425}
426
427fn compare_create_type(
428    a: &Statement,
429    a_name: &ObjectName,
430    a_rep: &Option<UserDefinedTypeRepresentation>,
431    b: &Statement,
432    b_name: &ObjectName,
433    b_rep: &Option<UserDefinedTypeRepresentation>,
434) -> Result<Option<Vec<Statement>>, DiffError> {
435    if a_name == b_name && a_rep == b_rep {
436        return Ok(None);
437    }
438
439    let operations = match a_rep {
440        Some(UserDefinedTypeRepresentation::Enum { labels: a_labels }) => match b_rep {
441            Some(UserDefinedTypeRepresentation::Enum { labels: b_labels }) => {
442                match a_labels.len().cmp(&b_labels.len()) {
443                    Ordering::Equal => {
444                        let rename_labels: Vec<_> = a_labels
445                            .iter()
446                            .zip(b_labels.iter())
447                            .filter_map(|(a, b)| {
448                                if a == b {
449                                    None
450                                } else {
451                                    Some(AlterTypeOperation::RenameValue(
452                                        sqlparser::ast::AlterTypeRenameValue {
453                                            from: a.clone(),
454                                            to: b.clone(),
455                                        },
456                                    ))
457                                }
458                            })
459                            .collect();
460                        rename_labels
461                    }
462                    Ordering::Less => {
463                        let mut a_labels_iter = a_labels.iter().peekable();
464                        let mut operations = Vec::new();
465                        let mut prev = None;
466                        for b in b_labels {
467                            match a_labels_iter.peek() {
468                                Some(a) => {
469                                    let a = *a;
470                                    if a == b {
471                                        prev = Some(a);
472                                        a_labels_iter.next();
473                                        continue;
474                                    }
475
476                                    let position = match prev {
477                                        Some(a) => AlterTypeAddValuePosition::After(a.clone()),
478                                        None => AlterTypeAddValuePosition::Before(a.clone()),
479                                    };
480
481                                    prev = Some(b);
482                                    operations.push(AlterTypeOperation::AddValue(
483                                        AlterTypeAddValue {
484                                            if_not_exists: false,
485                                            value: b.clone(),
486                                            position: Some(position),
487                                        },
488                                    ));
489                                }
490                                None => {
491                                    if a_labels.contains(b) {
492                                        continue;
493                                    }
494                                    // labels occuring after all existing ones get added to the end
495                                    operations.push(AlterTypeOperation::AddValue(
496                                        AlterTypeAddValue {
497                                            if_not_exists: false,
498                                            value: b.clone(),
499                                            position: None,
500                                        },
501                                    ));
502                                }
503                            }
504                        }
505                        operations
506                    }
507                    _ => {
508                        return Err(DiffError::builder()
509                            .kind(DiffErrorKind::RemoveEnumLabel)
510                            .statement_a(a.clone())
511                            .statement_b(b.clone())
512                            .build());
513                    }
514                }
515            }
516            _ => {
517                // TODO: DROP and CREATE type
518                return Err(DiffError::builder()
519                    .kind(DiffErrorKind::NotImplemented)
520                    .statement_a(a.clone())
521                    .statement_b(b.clone())
522                    .build());
523            }
524        },
525        _ => {
526            // TODO: handle diffing composite attributes for CREATE TYPE
527            return Err(DiffError::builder()
528                .kind(DiffErrorKind::NotImplemented)
529                .statement_a(a.clone())
530                .statement_b(b.clone())
531                .build());
532        }
533    };
534
535    if operations.is_empty() {
536        return Ok(None);
537    }
538
539    Ok(Some(
540        operations
541            .into_iter()
542            .map(|operation| {
543                Statement::AlterType(AlterType {
544                    name: a_name.clone(),
545                    operation,
546                })
547            })
548            .collect(),
549    ))
550}
551
552fn compare_create_domain(a: &CreateDomain, b: &CreateDomain) -> Option<Vec<Statement>> {
553    if a == b {
554        return None;
555    }
556
557    Some(vec![
558        Statement::DropDomain(DropDomain {
559            if_exists: true,
560            name: a.name.clone(),
561            drop_behavior: None,
562        }),
563        Statement::CreateDomain(b.clone()),
564    ])
565}