1use std::collections::{BTreeMap, BTreeSet};
4use std::io::Write;
5use std::path::Path;
6
7use crate::db::inspector::{for_each_row, quote_identifier};
8use crate::db::types::{
9 DatabaseSummary, SchemaDiff, SchemaObjectInfo, SqlValue, TableDataDiff, TableInfo,
10};
11use crate::error::{PatchworksError, Result};
12
13pub fn export_diff_as_sql(
19 right_path: &Path,
20 left: &DatabaseSummary,
21 right: &DatabaseSummary,
22 schema_diff: &SchemaDiff,
23 data_diffs: &[TableDataDiff],
24) -> Result<String> {
25 let mut buffer = Vec::new();
26 write_export(
27 &mut buffer,
28 right_path,
29 left,
30 right,
31 schema_diff,
32 data_diffs,
33 )?;
34 String::from_utf8(buffer).map_err(|error| {
35 PatchworksError::InvalidState(format!("generated SQL contained invalid UTF-8: {error}"))
36 })
37}
38
39pub fn write_export<W: Write>(
44 writer: &mut W,
45 right_path: &Path,
46 left: &DatabaseSummary,
47 right: &DatabaseSummary,
48 schema_diff: &SchemaDiff,
49 data_diffs: &[TableDataDiff],
50) -> Result<()> {
51 let left_tables = left
52 .tables
53 .iter()
54 .map(|table| (table.name.clone(), table))
55 .collect::<BTreeMap<_, _>>();
56 let right_tables = right
57 .tables
58 .iter()
59 .map(|table| (table.name.clone(), table))
60 .collect::<BTreeMap<_, _>>();
61 let rebuilt_tables = rebuilt_table_names(schema_diff);
62 let incrementally_changed_tables = incrementally_changed_table_names(schema_diff, data_diffs);
63 let object_changed_tables = schema_object_changed_table_names(schema_diff);
64 let trigger_reset_tables = rebuilt_tables
65 .union(&incrementally_changed_tables)
66 .cloned()
67 .chain(object_changed_tables.iter().cloned())
68 .collect::<BTreeSet<_>>();
69 let index_reset_tables = rebuilt_tables
70 .union(&object_changed_tables)
71 .cloned()
72 .collect::<BTreeSet<_>>();
73
74 writeln!(writer, "PRAGMA foreign_keys=OFF;")?;
75 writeln!(writer, "BEGIN TRANSACTION;")?;
76
77 for trigger in &left.triggers {
78 if trigger_reset_tables.contains(&trigger.table_name) {
79 writeln!(
80 writer,
81 "DROP TRIGGER IF EXISTS {};",
82 quote_identifier(&trigger.name)
83 )?;
84 }
85 }
86
87 for index in &left.indexes {
88 if index_reset_tables.contains(&index.table_name) {
89 writeln!(
90 writer,
91 "DROP INDEX IF EXISTS {};",
92 quote_identifier(&index.name)
93 )?;
94 }
95 }
96
97 for table in &schema_diff.removed_tables {
98 writeln!(
99 writer,
100 "DROP TABLE IF EXISTS {};",
101 quote_identifier(&table.name)
102 )?;
103 }
104
105 for table in &schema_diff.added_tables {
106 stream_create_and_seed(writer, right_path, table, &table.name)?;
107 }
108
109 for table_diff in &schema_diff.modified_tables {
110 let right_table = right_tables.get(&table_diff.table_name).ok_or_else(|| {
111 PatchworksError::InvalidState(format!(
112 "missing right-side table definition for `{}`",
113 table_diff.table_name
114 ))
115 })?;
116 let replacement_name = format!("__patchworks_new_{}", right_table.name);
117 stream_create_and_seed(writer, right_path, right_table, &replacement_name)?;
118 writeln!(
119 writer,
120 "DROP TABLE {};",
121 quote_identifier(&right_table.name)
122 )?;
123 writeln!(
124 writer,
125 "ALTER TABLE {} RENAME TO {};",
126 quote_identifier(&replacement_name),
127 quote_identifier(&right_table.name)
128 )?;
129 }
130
131 for table_name in &schema_diff.unchanged_tables {
132 let table = left_tables.get(table_name).ok_or_else(|| {
133 PatchworksError::InvalidState(format!("missing unchanged table `{table_name}`"))
134 })?;
135 if let Some(data_diff) = data_diffs
136 .iter()
137 .find(|diff| diff.table_name == *table_name)
138 {
139 write_incremental_changes(writer, table, data_diff)?;
140 }
141 }
142
143 for index in &right.indexes {
144 if index_reset_tables.contains(&index.table_name) {
145 writeln!(writer, "{}", schema_object_create_sql(index, "index")?)?;
146 }
147 }
148
149 for trigger in &right.triggers {
150 if trigger_reset_tables.contains(&trigger.table_name) {
151 writeln!(writer, "{}", schema_object_create_sql(trigger, "trigger")?)?;
152 }
153 }
154
155 writeln!(writer, "COMMIT;")?;
156 write!(writer, "PRAGMA foreign_keys=ON;")?;
157 Ok(())
158}
159
160fn rebuilt_table_names(schema_diff: &SchemaDiff) -> BTreeSet<String> {
161 schema_diff
162 .added_tables
163 .iter()
164 .map(|table| table.name.clone())
165 .chain(
166 schema_diff
167 .modified_tables
168 .iter()
169 .map(|table| table.table_name.clone()),
170 )
171 .collect()
172}
173
174fn incrementally_changed_table_names(
175 schema_diff: &SchemaDiff,
176 data_diffs: &[TableDataDiff],
177) -> BTreeSet<String> {
178 let unchanged_tables = schema_diff
179 .unchanged_tables
180 .iter()
181 .cloned()
182 .collect::<BTreeSet<_>>();
183
184 data_diffs
185 .iter()
186 .filter(|diff| diff.stats.added > 0 || diff.stats.removed > 0 || diff.stats.modified > 0)
187 .map(|diff| diff.table_name.clone())
188 .filter(|table_name| unchanged_tables.contains(table_name))
189 .collect()
190}
191
192fn schema_object_changed_table_names(schema_diff: &SchemaDiff) -> BTreeSet<String> {
193 schema_diff
194 .added_indexes
195 .iter()
196 .map(|object| object.table_name.clone())
197 .chain(
198 schema_diff
199 .removed_indexes
200 .iter()
201 .map(|object| object.table_name.clone()),
202 )
203 .chain(
204 schema_diff
205 .modified_indexes
206 .iter()
207 .flat_map(|(left, right)| [left.table_name.clone(), right.table_name.clone()]),
208 )
209 .chain(
210 schema_diff
211 .added_triggers
212 .iter()
213 .map(|object| object.table_name.clone()),
214 )
215 .chain(
216 schema_diff
217 .removed_triggers
218 .iter()
219 .map(|object| object.table_name.clone()),
220 )
221 .chain(
222 schema_diff
223 .modified_triggers
224 .iter()
225 .flat_map(|(left, right)| [left.table_name.clone(), right.table_name.clone()]),
226 )
227 .collect()
228}
229
230fn stream_create_and_seed<W: Write>(
232 writer: &mut W,
233 path: &Path,
234 table: &TableInfo,
235 target_name: &str,
236) -> Result<()> {
237 writeln!(writer, "{}", create_table_sql_for_name(table, target_name)?)?;
238 let column_list = table
239 .columns
240 .iter()
241 .map(|column| quote_identifier(&column.name))
242 .collect::<Vec<_>>()
243 .join(", ");
244 let target_quoted = quote_identifier(target_name);
245
246 for_each_row(path, table, |row| {
247 writeln!(
248 writer,
249 "INSERT INTO {} ({}) VALUES ({});",
250 target_quoted,
251 column_list,
252 row.iter().map(sql_literal).collect::<Vec<_>>().join(", ")
253 )?;
254 Ok(())
255 })?;
256
257 Ok(())
258}
259
260fn write_incremental_changes<W: Write>(
261 writer: &mut W,
262 table: &TableInfo,
263 data_diff: &TableDataDiff,
264) -> Result<()> {
265 let primary_key = export_identity_columns(table)?;
266
267 for (index, row) in data_diff.removed_rows.iter().enumerate() {
268 let key = if table.primary_key.is_empty() {
269 data_diff.removed_row_keys.get(index).unwrap_or(row)
270 } else {
271 row
272 };
273 writeln!(
274 writer,
275 "DELETE FROM {} WHERE {};",
276 quote_identifier(&table.name),
277 where_clause(&table.name, &data_diff.columns, key, &primary_key)?
278 )?;
279 }
280
281 for row in &data_diff.added_rows {
282 writeln!(
283 writer,
284 "INSERT INTO {} ({}) VALUES ({});",
285 quote_identifier(&table.name),
286 data_diff
287 .columns
288 .iter()
289 .map(|column| quote_identifier(column))
290 .collect::<Vec<_>>()
291 .join(", "),
292 row.iter().map(sql_literal).collect::<Vec<_>>().join(", ")
293 )?;
294 }
295
296 for row in &data_diff.modified_rows {
297 let set_clause = row
298 .changes
299 .iter()
300 .map(|change| {
301 format!(
302 "{} = {}",
303 quote_identifier(&change.column),
304 sql_literal(&change.new_value)
305 )
306 })
307 .collect::<Vec<_>>()
308 .join(", ");
309 let where_clause = if table.primary_key.is_empty() {
310 format!("rowid = {}", sql_literal(&row.primary_key[0]))
311 } else {
312 primary_key
313 .iter()
314 .zip(row.primary_key.iter())
315 .map(|(column, value)| {
316 format!("{} = {}", quote_identifier(column), sql_literal(value))
317 })
318 .collect::<Vec<_>>()
319 .join(" AND ")
320 };
321 writeln!(
322 writer,
323 "UPDATE {} SET {} WHERE {};",
324 quote_identifier(&table.name),
325 set_clause,
326 where_clause
327 )?;
328 }
329
330 Ok(())
331}
332
333fn schema_object_create_sql(object: &SchemaObjectInfo, kind: &str) -> Result<String> {
334 object
335 .create_sql
336 .as_ref()
337 .map(|sql| sql.trim_end_matches(';').to_owned() + ";")
338 .ok_or_else(|| {
339 PatchworksError::InvalidState(format!(
340 "missing CREATE {} SQL for `{}`",
341 kind, object.name
342 ))
343 })
344}
345
346fn where_clause(
347 table_name: &str,
348 columns: &[String],
349 row: &[SqlValue],
350 primary_key: &[String],
351) -> Result<String> {
352 if primary_key.len() == 1 && primary_key[0] == "rowid" {
353 return Ok(format!("rowid = {}", sql_literal(&row[0])));
354 }
355
356 let clauses = primary_key
357 .iter()
358 .map(|key| {
359 let index = columns
360 .iter()
361 .position(|column| column == key)
362 .ok_or_else(|| {
363 PatchworksError::InvalidState(format!(
364 "missing primary key column `{key}` while exporting `{table_name}`"
365 ))
366 })?;
367 let value = row.get(index).ok_or_else(|| {
368 PatchworksError::InvalidState(format!(
369 "missing primary key value for column `{key}` while exporting `{table_name}`"
370 ))
371 })?;
372 Ok(format!(
373 "{} = {}",
374 quote_identifier(key),
375 sql_literal(value)
376 ))
377 })
378 .collect::<Result<Vec<_>>>()?;
379
380 Ok(clauses.join(" AND "))
381}
382
383fn sql_literal(value: &SqlValue) -> String {
384 match value {
385 SqlValue::Null => "NULL".to_owned(),
386 SqlValue::Integer(value) => value.to_string(),
387 SqlValue::Real(value) => {
388 if value.is_finite() {
389 value.to_string()
390 } else {
391 "NULL".to_owned()
392 }
393 }
394 SqlValue::Text(value) => format!("'{}'", value.replace('\'', "''")),
395 SqlValue::Blob(bytes) => {
396 let hex = bytes
397 .iter()
398 .map(|byte| format!("{byte:02X}"))
399 .collect::<String>();
400 format!("X'{hex}'")
401 }
402 }
403}
404
405fn export_identity_columns(table: &TableInfo) -> Result<Vec<String>> {
406 if table.primary_key.is_empty() {
407 if table_supports_rowid(table) {
408 Ok(vec!["rowid".to_owned()])
409 } else {
410 Err(PatchworksError::InvalidState(format!(
411 "table `{}` has no primary key and cannot use rowid during SQL export",
412 table.name
413 )))
414 }
415 } else {
416 Ok(table.primary_key.clone())
417 }
418}
419
420fn table_supports_rowid(table: &TableInfo) -> bool {
421 table
422 .create_sql
423 .as_ref()
424 .map(|sql| !sql.to_ascii_uppercase().contains("WITHOUT ROWID"))
425 .unwrap_or(true)
426}
427
428fn create_table_sql_for_name(table: &TableInfo, target_name: &str) -> Result<String> {
429 let create_sql = table.create_sql.clone().ok_or_else(|| {
430 PatchworksError::InvalidState(format!("missing CREATE TABLE SQL for `{}`", table.name))
431 })?;
432 let trimmed = create_sql.trim_end_matches(';');
433 let sql = if table.name == target_name {
434 trimmed.to_owned()
435 } else {
436 rewrite_create_table_name(trimmed, target_name)?
437 };
438 Ok(sql + ";")
439}
440
441fn rewrite_create_table_name(create_sql: &str, target_name: &str) -> Result<String> {
442 let name_start = create_table_name_start(create_sql)?;
443 let name_end = create_table_name_end(create_sql, name_start)?;
444
445 if name_end <= name_start {
446 return Err(PatchworksError::InvalidState(
447 "CREATE TABLE SQL has an invalid table-name position while rewriting export".to_owned(),
448 ));
449 }
450
451 Ok(format!(
452 "{}{}{}",
453 &create_sql[..name_start],
454 target_name,
455 &create_sql[name_end..]
456 ))
457}
458
459fn create_table_name_start(create_sql: &str) -> Result<usize> {
460 let mut index = skip_ascii_whitespace(create_sql, 0);
461 index = consume_keyword(create_sql, index, "CREATE").ok_or_else(|| {
462 PatchworksError::InvalidState(
463 "CREATE TABLE SQL did not start with CREATE while rewriting export".to_owned(),
464 )
465 })?;
466 index = skip_ascii_whitespace(create_sql, index);
467
468 if let Some(next) = consume_keyword(create_sql, index, "TEMPORARY") {
469 index = skip_ascii_whitespace(create_sql, next);
470 } else if let Some(next) = consume_keyword(create_sql, index, "TEMP") {
471 index = skip_ascii_whitespace(create_sql, next);
472 }
473
474 index = consume_keyword(create_sql, index, "TABLE").ok_or_else(|| {
475 PatchworksError::InvalidState(
476 "CREATE TABLE SQL did not contain TABLE while rewriting export".to_owned(),
477 )
478 })?;
479 index = skip_ascii_whitespace(create_sql, index);
480
481 if let Some(next) = consume_keyword(create_sql, index, "IF") {
482 index = skip_ascii_whitespace(create_sql, next);
483 index = consume_keyword(create_sql, index, "NOT").ok_or_else(|| {
484 PatchworksError::InvalidState(
485 "CREATE TABLE SQL had IF without NOT while rewriting export".to_owned(),
486 )
487 })?;
488 index = skip_ascii_whitespace(create_sql, index);
489 index = consume_keyword(create_sql, index, "EXISTS").ok_or_else(|| {
490 PatchworksError::InvalidState(
491 "CREATE TABLE SQL had IF NOT without EXISTS while rewriting export".to_owned(),
492 )
493 })?;
494 index = skip_ascii_whitespace(create_sql, index);
495 }
496
497 Ok(index)
498}
499
500fn create_table_name_end(create_sql: &str, start: usize) -> Result<usize> {
501 let bytes = create_sql.as_bytes();
502 let mut index = start;
503 let mut quoted_by: Option<u8> = None;
504
505 while let Some(&byte) = bytes.get(index) {
506 if let Some(quote) = quoted_by {
507 if byte == quote {
508 if matches!(quote, b'"' | b'`') && bytes.get(index + 1) == Some("e) {
509 index += 2;
510 continue;
511 }
512 quoted_by = None;
513 }
514 index += 1;
515 continue;
516 }
517
518 match byte {
519 b'"' => quoted_by = Some(b'"'),
520 b'`' => quoted_by = Some(b'`'),
521 b'[' => quoted_by = Some(b']'),
522 b'(' => break,
523 _ if byte.is_ascii_whitespace() => break,
524 _ => {}
525 }
526 index += 1;
527 }
528
529 if index == start {
530 Err(PatchworksError::InvalidState(
531 "CREATE TABLE SQL is missing a table name while rewriting export".to_owned(),
532 ))
533 } else {
534 Ok(index)
535 }
536}
537
538fn skip_ascii_whitespace(sql: &str, mut index: usize) -> usize {
539 while let Some(byte) = sql.as_bytes().get(index) {
540 if byte.is_ascii_whitespace() {
541 index += 1;
542 } else {
543 break;
544 }
545 }
546 index
547}
548
549fn consume_keyword(sql: &str, index: usize, keyword: &str) -> Option<usize> {
550 let end = index.checked_add(keyword.len())?;
551 let slice = sql.get(index..end)?;
552 if !slice.eq_ignore_ascii_case(keyword) {
553 return None;
554 }
555
556 match sql[end..].chars().next() {
557 Some(ch) if !ch.is_ascii_whitespace() => None,
558 _ => Some(end),
559 }
560}
561
562#[cfg(test)]
563mod tests {
564 use super::{create_table_sql_for_name, schema_object_create_sql, where_clause, TableInfo};
565 use crate::db::types::{ColumnInfo, SchemaObjectInfo, SqlValue};
566 use crate::error::PatchworksError;
567
568 #[test]
569 fn where_clause_rejects_missing_primary_key_columns() {
570 let error = where_clause(
571 "items",
572 &[String::from("name")],
573 &[SqlValue::Text(String::from("widget"))],
574 &[String::from("id")],
575 )
576 .expect_err("missing primary key column should error");
577
578 assert!(matches!(error, PatchworksError::InvalidState(_)));
579 assert!(error
580 .to_string()
581 .contains("missing primary key column `id` while exporting `items`"));
582 }
583
584 #[test]
585 fn schema_object_create_sql_requires_source_sql() {
586 let error = schema_object_create_sql(
587 &SchemaObjectInfo {
588 name: String::from("items_name_idx"),
589 table_name: String::from("items"),
590 create_sql: None,
591 },
592 "index",
593 )
594 .expect_err("missing sql should error");
595
596 assert!(matches!(error, PatchworksError::InvalidState(_)));
597 assert!(error
598 .to_string()
599 .contains("missing CREATE index SQL for `items_name_idx`"));
600 }
601
602 #[test]
603 fn create_table_sql_for_name_rewrites_table_name_for_rebuilds() {
604 let sql = create_table_sql_for_name(
605 &TableInfo {
606 name: "parents".to_owned(),
607 columns: vec![ColumnInfo {
608 name: "id".to_owned(),
609 col_type: "INTEGER".to_owned(),
610 nullable: false,
611 default_value: None,
612 is_primary_key: true,
613 }],
614 row_count: 0,
615 primary_key: vec!["id".to_owned()],
616 create_sql: Some(
617 "CREATE TABLE IF NOT EXISTS parents (id INTEGER PRIMARY KEY) WITHOUT ROWID"
618 .to_owned(),
619 ),
620 },
621 "__patchworks_new_parents",
622 )
623 .expect("rewrite create sql");
624
625 assert_eq!(
626 sql,
627 "CREATE TABLE IF NOT EXISTS __patchworks_new_parents (id INTEGER PRIMARY KEY) WITHOUT ROWID;"
628 );
629 }
630}