Skip to main content

patchworks/diff/
export.rs

1//! SQL export for Patchworks diffs.
2
3use std::collections::{BTreeMap, BTreeSet};
4use std::io::Write;
5use std::path::Path;
6
7use crate::db::inspector::{for_each_row, quote_identifier};
8use crate::db::types::{
9    DatabaseSummary, SchemaDiff, SchemaObjectInfo, SqlValue, TableDataDiff, TableInfo,
10};
11use crate::error::{PatchworksError, Result};
12
13/// Generates a SQL migration script that transforms the left database into the right database.
14///
15/// This is a convenience wrapper around [`write_export`] that collects the output into a `String`.
16/// For large migrations, prefer [`write_export`] with a file or buffered writer to avoid
17/// holding the entire migration in memory.
18pub fn export_diff_as_sql(
19    right_path: &Path,
20    left: &DatabaseSummary,
21    right: &DatabaseSummary,
22    schema_diff: &SchemaDiff,
23    data_diffs: &[TableDataDiff],
24) -> Result<String> {
25    let mut buffer = Vec::new();
26    write_export(
27        &mut buffer,
28        right_path,
29        left,
30        right,
31        schema_diff,
32        data_diffs,
33    )?;
34    String::from_utf8(buffer).map_err(|error| {
35        PatchworksError::InvalidState(format!("generated SQL contained invalid UTF-8: {error}"))
36    })
37}
38
39/// Writes a SQL migration script to any [`Write`] sink, streaming one statement at a time.
40///
41/// This is the bounded-memory export path. Table seeding rows are streamed from disk rather than
42/// materialized into a full in-memory collection, and each SQL statement is flushed individually.
43pub fn write_export<W: Write>(
44    writer: &mut W,
45    right_path: &Path,
46    left: &DatabaseSummary,
47    right: &DatabaseSummary,
48    schema_diff: &SchemaDiff,
49    data_diffs: &[TableDataDiff],
50) -> Result<()> {
51    let left_tables = left
52        .tables
53        .iter()
54        .map(|table| (table.name.clone(), table))
55        .collect::<BTreeMap<_, _>>();
56    let right_tables = right
57        .tables
58        .iter()
59        .map(|table| (table.name.clone(), table))
60        .collect::<BTreeMap<_, _>>();
61    let rebuilt_tables = rebuilt_table_names(schema_diff);
62    let incrementally_changed_tables = incrementally_changed_table_names(schema_diff, data_diffs);
63    let object_changed_tables = schema_object_changed_table_names(schema_diff);
64    let trigger_reset_tables = rebuilt_tables
65        .union(&incrementally_changed_tables)
66        .cloned()
67        .chain(object_changed_tables.iter().cloned())
68        .collect::<BTreeSet<_>>();
69    let index_reset_tables = rebuilt_tables
70        .union(&object_changed_tables)
71        .cloned()
72        .collect::<BTreeSet<_>>();
73
74    writeln!(writer, "PRAGMA foreign_keys=OFF;")?;
75    writeln!(writer, "BEGIN TRANSACTION;")?;
76
77    for trigger in &left.triggers {
78        if trigger_reset_tables.contains(&trigger.table_name) {
79            writeln!(
80                writer,
81                "DROP TRIGGER IF EXISTS {};",
82                quote_identifier(&trigger.name)
83            )?;
84        }
85    }
86
87    for index in &left.indexes {
88        if index_reset_tables.contains(&index.table_name) {
89            writeln!(
90                writer,
91                "DROP INDEX IF EXISTS {};",
92                quote_identifier(&index.name)
93            )?;
94        }
95    }
96
97    for table in &schema_diff.removed_tables {
98        writeln!(
99            writer,
100            "DROP TABLE IF EXISTS {};",
101            quote_identifier(&table.name)
102        )?;
103    }
104
105    for table in &schema_diff.added_tables {
106        stream_create_and_seed(writer, right_path, table, &table.name)?;
107    }
108
109    for table_diff in &schema_diff.modified_tables {
110        let right_table = right_tables.get(&table_diff.table_name).ok_or_else(|| {
111            PatchworksError::InvalidState(format!(
112                "missing right-side table definition for `{}`",
113                table_diff.table_name
114            ))
115        })?;
116        let replacement_name = format!("__patchworks_new_{}", right_table.name);
117        stream_create_and_seed(writer, right_path, right_table, &replacement_name)?;
118        writeln!(
119            writer,
120            "DROP TABLE {};",
121            quote_identifier(&right_table.name)
122        )?;
123        writeln!(
124            writer,
125            "ALTER TABLE {} RENAME TO {};",
126            quote_identifier(&replacement_name),
127            quote_identifier(&right_table.name)
128        )?;
129    }
130
131    for table_name in &schema_diff.unchanged_tables {
132        let table = left_tables.get(table_name).ok_or_else(|| {
133            PatchworksError::InvalidState(format!("missing unchanged table `{table_name}`"))
134        })?;
135        if let Some(data_diff) = data_diffs
136            .iter()
137            .find(|diff| diff.table_name == *table_name)
138        {
139            write_incremental_changes(writer, table, data_diff)?;
140        }
141    }
142
143    for index in &right.indexes {
144        if index_reset_tables.contains(&index.table_name) {
145            writeln!(writer, "{}", schema_object_create_sql(index, "index")?)?;
146        }
147    }
148
149    for trigger in &right.triggers {
150        if trigger_reset_tables.contains(&trigger.table_name) {
151            writeln!(writer, "{}", schema_object_create_sql(trigger, "trigger")?)?;
152        }
153    }
154
155    writeln!(writer, "COMMIT;")?;
156    write!(writer, "PRAGMA foreign_keys=ON;")?;
157    Ok(())
158}
159
160fn rebuilt_table_names(schema_diff: &SchemaDiff) -> BTreeSet<String> {
161    schema_diff
162        .added_tables
163        .iter()
164        .map(|table| table.name.clone())
165        .chain(
166            schema_diff
167                .modified_tables
168                .iter()
169                .map(|table| table.table_name.clone()),
170        )
171        .collect()
172}
173
174fn incrementally_changed_table_names(
175    schema_diff: &SchemaDiff,
176    data_diffs: &[TableDataDiff],
177) -> BTreeSet<String> {
178    let unchanged_tables = schema_diff
179        .unchanged_tables
180        .iter()
181        .cloned()
182        .collect::<BTreeSet<_>>();
183
184    data_diffs
185        .iter()
186        .filter(|diff| diff.stats.added > 0 || diff.stats.removed > 0 || diff.stats.modified > 0)
187        .map(|diff| diff.table_name.clone())
188        .filter(|table_name| unchanged_tables.contains(table_name))
189        .collect()
190}
191
192fn schema_object_changed_table_names(schema_diff: &SchemaDiff) -> BTreeSet<String> {
193    schema_diff
194        .added_indexes
195        .iter()
196        .map(|object| object.table_name.clone())
197        .chain(
198            schema_diff
199                .removed_indexes
200                .iter()
201                .map(|object| object.table_name.clone()),
202        )
203        .chain(
204            schema_diff
205                .modified_indexes
206                .iter()
207                .flat_map(|(left, right)| [left.table_name.clone(), right.table_name.clone()]),
208        )
209        .chain(
210            schema_diff
211                .added_triggers
212                .iter()
213                .map(|object| object.table_name.clone()),
214        )
215        .chain(
216            schema_diff
217                .removed_triggers
218                .iter()
219                .map(|object| object.table_name.clone()),
220        )
221        .chain(
222            schema_diff
223                .modified_triggers
224                .iter()
225                .flat_map(|(left, right)| [left.table_name.clone(), right.table_name.clone()]),
226        )
227        .collect()
228}
229
230/// Streams CREATE TABLE plus INSERT statements for a table, one row at a time.
231fn stream_create_and_seed<W: Write>(
232    writer: &mut W,
233    path: &Path,
234    table: &TableInfo,
235    target_name: &str,
236) -> Result<()> {
237    writeln!(writer, "{}", create_table_sql_for_name(table, target_name)?)?;
238    let column_list = table
239        .columns
240        .iter()
241        .map(|column| quote_identifier(&column.name))
242        .collect::<Vec<_>>()
243        .join(", ");
244    let target_quoted = quote_identifier(target_name);
245
246    for_each_row(path, table, |row| {
247        writeln!(
248            writer,
249            "INSERT INTO {} ({}) VALUES ({});",
250            target_quoted,
251            column_list,
252            row.iter().map(sql_literal).collect::<Vec<_>>().join(", ")
253        )?;
254        Ok(())
255    })?;
256
257    Ok(())
258}
259
260fn write_incremental_changes<W: Write>(
261    writer: &mut W,
262    table: &TableInfo,
263    data_diff: &TableDataDiff,
264) -> Result<()> {
265    let primary_key = export_identity_columns(table)?;
266
267    for (index, row) in data_diff.removed_rows.iter().enumerate() {
268        let key = if table.primary_key.is_empty() {
269            data_diff.removed_row_keys.get(index).unwrap_or(row)
270        } else {
271            row
272        };
273        writeln!(
274            writer,
275            "DELETE FROM {} WHERE {};",
276            quote_identifier(&table.name),
277            where_clause(&table.name, &data_diff.columns, key, &primary_key)?
278        )?;
279    }
280
281    for row in &data_diff.added_rows {
282        writeln!(
283            writer,
284            "INSERT INTO {} ({}) VALUES ({});",
285            quote_identifier(&table.name),
286            data_diff
287                .columns
288                .iter()
289                .map(|column| quote_identifier(column))
290                .collect::<Vec<_>>()
291                .join(", "),
292            row.iter().map(sql_literal).collect::<Vec<_>>().join(", ")
293        )?;
294    }
295
296    for row in &data_diff.modified_rows {
297        let set_clause = row
298            .changes
299            .iter()
300            .map(|change| {
301                format!(
302                    "{} = {}",
303                    quote_identifier(&change.column),
304                    sql_literal(&change.new_value)
305                )
306            })
307            .collect::<Vec<_>>()
308            .join(", ");
309        let where_clause = if table.primary_key.is_empty() {
310            format!("rowid = {}", sql_literal(&row.primary_key[0]))
311        } else {
312            primary_key
313                .iter()
314                .zip(row.primary_key.iter())
315                .map(|(column, value)| {
316                    format!("{} = {}", quote_identifier(column), sql_literal(value))
317                })
318                .collect::<Vec<_>>()
319                .join(" AND ")
320        };
321        writeln!(
322            writer,
323            "UPDATE {} SET {} WHERE {};",
324            quote_identifier(&table.name),
325            set_clause,
326            where_clause
327        )?;
328    }
329
330    Ok(())
331}
332
333fn schema_object_create_sql(object: &SchemaObjectInfo, kind: &str) -> Result<String> {
334    object
335        .create_sql
336        .as_ref()
337        .map(|sql| sql.trim_end_matches(';').to_owned() + ";")
338        .ok_or_else(|| {
339            PatchworksError::InvalidState(format!(
340                "missing CREATE {} SQL for `{}`",
341                kind, object.name
342            ))
343        })
344}
345
346fn where_clause(
347    table_name: &str,
348    columns: &[String],
349    row: &[SqlValue],
350    primary_key: &[String],
351) -> Result<String> {
352    if primary_key.len() == 1 && primary_key[0] == "rowid" {
353        return Ok(format!("rowid = {}", sql_literal(&row[0])));
354    }
355
356    let clauses = primary_key
357        .iter()
358        .map(|key| {
359            let index = columns
360                .iter()
361                .position(|column| column == key)
362                .ok_or_else(|| {
363                    PatchworksError::InvalidState(format!(
364                        "missing primary key column `{key}` while exporting `{table_name}`"
365                    ))
366                })?;
367            let value = row.get(index).ok_or_else(|| {
368                PatchworksError::InvalidState(format!(
369                    "missing primary key value for column `{key}` while exporting `{table_name}`"
370                ))
371            })?;
372            Ok(format!(
373                "{} = {}",
374                quote_identifier(key),
375                sql_literal(value)
376            ))
377        })
378        .collect::<Result<Vec<_>>>()?;
379
380    Ok(clauses.join(" AND "))
381}
382
383fn sql_literal(value: &SqlValue) -> String {
384    match value {
385        SqlValue::Null => "NULL".to_owned(),
386        SqlValue::Integer(value) => value.to_string(),
387        SqlValue::Real(value) => {
388            if value.is_finite() {
389                value.to_string()
390            } else {
391                "NULL".to_owned()
392            }
393        }
394        SqlValue::Text(value) => format!("'{}'", value.replace('\'', "''")),
395        SqlValue::Blob(bytes) => {
396            let hex = bytes
397                .iter()
398                .map(|byte| format!("{byte:02X}"))
399                .collect::<String>();
400            format!("X'{hex}'")
401        }
402    }
403}
404
405fn export_identity_columns(table: &TableInfo) -> Result<Vec<String>> {
406    if table.primary_key.is_empty() {
407        if table_supports_rowid(table) {
408            Ok(vec!["rowid".to_owned()])
409        } else {
410            Err(PatchworksError::InvalidState(format!(
411                "table `{}` has no primary key and cannot use rowid during SQL export",
412                table.name
413            )))
414        }
415    } else {
416        Ok(table.primary_key.clone())
417    }
418}
419
420fn table_supports_rowid(table: &TableInfo) -> bool {
421    table
422        .create_sql
423        .as_ref()
424        .map(|sql| !sql.to_ascii_uppercase().contains("WITHOUT ROWID"))
425        .unwrap_or(true)
426}
427
428fn create_table_sql_for_name(table: &TableInfo, target_name: &str) -> Result<String> {
429    let create_sql = table.create_sql.clone().ok_or_else(|| {
430        PatchworksError::InvalidState(format!("missing CREATE TABLE SQL for `{}`", table.name))
431    })?;
432    let trimmed = create_sql.trim_end_matches(';');
433    let sql = if table.name == target_name {
434        trimmed.to_owned()
435    } else {
436        rewrite_create_table_name(trimmed, target_name)?
437    };
438    Ok(sql + ";")
439}
440
441fn rewrite_create_table_name(create_sql: &str, target_name: &str) -> Result<String> {
442    let name_start = create_table_name_start(create_sql)?;
443    let name_end = create_table_name_end(create_sql, name_start)?;
444
445    if name_end <= name_start {
446        return Err(PatchworksError::InvalidState(
447            "CREATE TABLE SQL has an invalid table-name position while rewriting export".to_owned(),
448        ));
449    }
450
451    Ok(format!(
452        "{}{}{}",
453        &create_sql[..name_start],
454        target_name,
455        &create_sql[name_end..]
456    ))
457}
458
459fn create_table_name_start(create_sql: &str) -> Result<usize> {
460    let mut index = skip_ascii_whitespace(create_sql, 0);
461    index = consume_keyword(create_sql, index, "CREATE").ok_or_else(|| {
462        PatchworksError::InvalidState(
463            "CREATE TABLE SQL did not start with CREATE while rewriting export".to_owned(),
464        )
465    })?;
466    index = skip_ascii_whitespace(create_sql, index);
467
468    if let Some(next) = consume_keyword(create_sql, index, "TEMPORARY") {
469        index = skip_ascii_whitespace(create_sql, next);
470    } else if let Some(next) = consume_keyword(create_sql, index, "TEMP") {
471        index = skip_ascii_whitespace(create_sql, next);
472    }
473
474    index = consume_keyword(create_sql, index, "TABLE").ok_or_else(|| {
475        PatchworksError::InvalidState(
476            "CREATE TABLE SQL did not contain TABLE while rewriting export".to_owned(),
477        )
478    })?;
479    index = skip_ascii_whitespace(create_sql, index);
480
481    if let Some(next) = consume_keyword(create_sql, index, "IF") {
482        index = skip_ascii_whitespace(create_sql, next);
483        index = consume_keyword(create_sql, index, "NOT").ok_or_else(|| {
484            PatchworksError::InvalidState(
485                "CREATE TABLE SQL had IF without NOT while rewriting export".to_owned(),
486            )
487        })?;
488        index = skip_ascii_whitespace(create_sql, index);
489        index = consume_keyword(create_sql, index, "EXISTS").ok_or_else(|| {
490            PatchworksError::InvalidState(
491                "CREATE TABLE SQL had IF NOT without EXISTS while rewriting export".to_owned(),
492            )
493        })?;
494        index = skip_ascii_whitespace(create_sql, index);
495    }
496
497    Ok(index)
498}
499
500fn create_table_name_end(create_sql: &str, start: usize) -> Result<usize> {
501    let bytes = create_sql.as_bytes();
502    let mut index = start;
503    let mut quoted_by: Option<u8> = None;
504
505    while let Some(&byte) = bytes.get(index) {
506        if let Some(quote) = quoted_by {
507            if byte == quote {
508                if matches!(quote, b'"' | b'`') && bytes.get(index + 1) == Some(&quote) {
509                    index += 2;
510                    continue;
511                }
512                quoted_by = None;
513            }
514            index += 1;
515            continue;
516        }
517
518        match byte {
519            b'"' => quoted_by = Some(b'"'),
520            b'`' => quoted_by = Some(b'`'),
521            b'[' => quoted_by = Some(b']'),
522            b'(' => break,
523            _ if byte.is_ascii_whitespace() => break,
524            _ => {}
525        }
526        index += 1;
527    }
528
529    if index == start {
530        Err(PatchworksError::InvalidState(
531            "CREATE TABLE SQL is missing a table name while rewriting export".to_owned(),
532        ))
533    } else {
534        Ok(index)
535    }
536}
537
538fn skip_ascii_whitespace(sql: &str, mut index: usize) -> usize {
539    while let Some(byte) = sql.as_bytes().get(index) {
540        if byte.is_ascii_whitespace() {
541            index += 1;
542        } else {
543            break;
544        }
545    }
546    index
547}
548
549fn consume_keyword(sql: &str, index: usize, keyword: &str) -> Option<usize> {
550    let end = index.checked_add(keyword.len())?;
551    let slice = sql.get(index..end)?;
552    if !slice.eq_ignore_ascii_case(keyword) {
553        return None;
554    }
555
556    match sql[end..].chars().next() {
557        Some(ch) if !ch.is_ascii_whitespace() => None,
558        _ => Some(end),
559    }
560}
561
562#[cfg(test)]
563mod tests {
564    use super::{create_table_sql_for_name, schema_object_create_sql, where_clause, TableInfo};
565    use crate::db::types::{ColumnInfo, SchemaObjectInfo, SqlValue};
566    use crate::error::PatchworksError;
567
568    #[test]
569    fn where_clause_rejects_missing_primary_key_columns() {
570        let error = where_clause(
571            "items",
572            &[String::from("name")],
573            &[SqlValue::Text(String::from("widget"))],
574            &[String::from("id")],
575        )
576        .expect_err("missing primary key column should error");
577
578        assert!(matches!(error, PatchworksError::InvalidState(_)));
579        assert!(error
580            .to_string()
581            .contains("missing primary key column `id` while exporting `items`"));
582    }
583
584    #[test]
585    fn schema_object_create_sql_requires_source_sql() {
586        let error = schema_object_create_sql(
587            &SchemaObjectInfo {
588                name: String::from("items_name_idx"),
589                table_name: String::from("items"),
590                create_sql: None,
591            },
592            "index",
593        )
594        .expect_err("missing sql should error");
595
596        assert!(matches!(error, PatchworksError::InvalidState(_)));
597        assert!(error
598            .to_string()
599            .contains("missing CREATE index SQL for `items_name_idx`"));
600    }
601
602    #[test]
603    fn create_table_sql_for_name_rewrites_table_name_for_rebuilds() {
604        let sql = create_table_sql_for_name(
605            &TableInfo {
606                name: "parents".to_owned(),
607                columns: vec![ColumnInfo {
608                    name: "id".to_owned(),
609                    col_type: "INTEGER".to_owned(),
610                    nullable: false,
611                    default_value: None,
612                    is_primary_key: true,
613                }],
614                row_count: 0,
615                primary_key: vec!["id".to_owned()],
616                create_sql: Some(
617                    "CREATE TABLE IF NOT EXISTS parents (id INTEGER PRIMARY KEY) WITHOUT ROWID"
618                        .to_owned(),
619                ),
620            },
621            "__patchworks_new_parents",
622        )
623        .expect("rewrite create sql");
624
625        assert_eq!(
626            sql,
627            "CREATE TABLE IF NOT EXISTS __patchworks_new_parents (id INTEGER PRIMARY KEY) WITHOUT ROWID;"
628        );
629    }
630}