// sql_splitter/differ/mod.rs

//! Diff module for comparing two SQL dumps.
//!
//! This module provides:
//! - Schema comparison (tables added/removed/modified, columns, PKs, FKs)
//! - Data comparison (row counts: added/removed/modified)
//! - Memory-bounded operation using PK hashing
//! - Multiple output formats (text, json, sql)
8
mod data;
mod output;
mod schema;

pub use data::*;
pub use output::*;
pub use schema::*;

use crate::parser::{determine_buffer_size, Parser, SqlDialect, StatementType};
use crate::progress::ProgressReader;
use crate::schema::{Schema, SchemaBuilder};
use crate::splitter::Compression;
use glob::Pattern;
use serde::Serialize;
use std::fs::File;
use std::io::Read;
use std::path::{Path, PathBuf};
use std::sync::Arc;
27
28/// Configuration for the diff operation
29#[derive(Debug, Clone)]
30pub struct DiffConfig {
31    /// Path to the "old" SQL file
32    pub old_path: PathBuf,
33    /// Path to the "new" SQL file
34    pub new_path: PathBuf,
35    /// SQL dialect (auto-detected if None)
36    pub dialect: Option<SqlDialect>,
37    /// Only compare schema, skip data
38    pub schema_only: bool,
39    /// Only compare data, skip schema
40    pub data_only: bool,
41    /// Tables to include (if empty, include all)
42    pub tables: Vec<String>,
43    /// Tables to exclude
44    pub exclude: Vec<String>,
45    /// Output format
46    pub format: DiffOutputFormat,
47    /// Show verbose row-level details
48    pub verbose: bool,
49    /// Show progress bar
50    pub progress: bool,
51    /// Maximum PK entries to track globally
52    pub max_pk_entries: usize,
53    /// Don't skip tables without PK, use all columns as key
54    pub allow_no_pk: bool,
55    /// Ignore column order when comparing schemas
56    pub ignore_column_order: bool,
57    /// Primary key overrides: table name -> column names
58    pub pk_overrides: std::collections::HashMap<String, Vec<String>>,
59    /// Column patterns to ignore (glob format: table.column)
60    pub ignore_columns: Vec<String>,
61}
62
63impl Default for DiffConfig {
64    fn default() -> Self {
65        Self {
66            old_path: PathBuf::new(),
67            new_path: PathBuf::new(),
68            dialect: None,
69            schema_only: false,
70            data_only: false,
71            tables: Vec::new(),
72            exclude: Vec::new(),
73            format: DiffOutputFormat::Text,
74            verbose: false,
75            progress: false,
76            max_pk_entries: 10_000_000, // 10M entries ~= 160MB
77            allow_no_pk: false,
78            ignore_column_order: false,
79            pk_overrides: std::collections::HashMap::new(),
80            ignore_columns: Vec::new(),
81        }
82    }
83}
84
/// Output format for diff results
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum DiffOutputFormat {
    /// Plain-text output (the default)
    #[default]
    Text,
    /// JSON output
    Json,
    /// SQL output
    Sql,
}

impl std::str::FromStr for DiffOutputFormat {
    type Err = String;

    /// Parses a format name, ignoring ASCII case.
    ///
    /// Accepted names are `text`, `json` and `sql`. The input is not trimmed,
    /// so surrounding whitespace makes it invalid.
    ///
    /// # Errors
    ///
    /// Returns a message naming the rejected input and listing the accepted
    /// names when `s` is none of them.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Compare in place instead of allocating a lowercased copy; all
        // accepted names are pure ASCII, so the result is the same.
        if s.eq_ignore_ascii_case("text") {
            Ok(Self::Text)
        } else if s.eq_ignore_ascii_case("json") {
            Ok(Self::Json)
        } else if s.eq_ignore_ascii_case("sql") {
            Ok(Self::Sql)
        } else {
            Err(format!("Unknown format: {}. Use: text, json, sql", s))
        }
    }
}
106
107/// A warning generated during diff operation
108#[derive(Debug, Serialize, Clone)]
109pub struct DiffWarning {
110    #[serde(skip_serializing_if = "Option::is_none")]
111    pub table: Option<String>,
112    pub message: String,
113}
114
115/// Complete diff result
116#[derive(Debug, Serialize)]
117pub struct DiffResult {
118    /// Schema differences
119    #[serde(skip_serializing_if = "Option::is_none")]
120    pub schema: Option<SchemaDiff>,
121    /// Data differences
122    #[serde(skip_serializing_if = "Option::is_none")]
123    pub data: Option<DataDiff>,
124    /// Warnings generated during diff
125    #[serde(skip_serializing_if = "Vec::is_empty")]
126    pub warnings: Vec<DiffWarning>,
127    /// Summary statistics
128    pub summary: DiffSummary,
129}
130
131/// Summary of differences
132#[derive(Debug, Serialize)]
133pub struct DiffSummary {
134    /// Number of tables added
135    pub tables_added: usize,
136    /// Number of tables removed
137    pub tables_removed: usize,
138    /// Number of tables modified (schema or data)
139    pub tables_modified: usize,
140    /// Total rows added across all tables
141    pub rows_added: u64,
142    /// Total rows removed across all tables
143    pub rows_removed: u64,
144    /// Total rows modified across all tables
145    pub rows_modified: u64,
146    /// Whether any data was truncated due to memory limits
147    pub truncated: bool,
148}
149
150/// Main differ engine
151pub struct Differ {
152    config: DiffConfig,
153    dialect: SqlDialect,
154    progress_fn: Option<Arc<dyn Fn(u64, u64) + Send + Sync>>,
155}
156
157impl Differ {
158    /// Create a new differ with the given configuration
159    pub fn new(config: DiffConfig) -> Self {
160        Self {
161            dialect: config.dialect.unwrap_or(SqlDialect::MySql),
162            config,
163            progress_fn: None,
164        }
165    }
166
167    /// Set a progress callback (receives current bytes, total bytes)
168    pub fn with_progress<F>(mut self, f: F) -> Self
169    where
170        F: Fn(u64, u64) + Send + Sync + 'static,
171    {
172        self.progress_fn = Some(Arc::new(f));
173        self
174    }
175
176    /// Run the diff operation
177    pub fn diff(self) -> anyhow::Result<DiffResult> {
178        // Calculate total bytes for progress (4 passes max: 2 schema + 2 data)
179        let old_size = std::fs::metadata(&self.config.old_path)?.len();
180        let new_size = std::fs::metadata(&self.config.new_path)?.len();
181        let total_bytes = if self.config.schema_only || self.config.data_only {
182            old_size + new_size
183        } else {
184            (old_size + new_size) * 2 // Schema pass + data pass for each file
185        };
186
187        // Pass 0: Extract schemas from both files (always needed, even data-only needs PK info)
188        let old_schema = self.extract_schema(&self.config.old_path.clone(), 0, total_bytes)?;
189        let new_schema =
190            self.extract_schema(&self.config.new_path.clone(), old_size, total_bytes)?;
191
192        // Schema comparison
193        let schema_diff = if !self.config.data_only {
194            Some(compare_schemas(&old_schema, &new_schema, &self.config))
195        } else {
196            None
197        };
198
199        // Data comparison
200        let (data_diff, warnings) = if !self.config.schema_only {
201            let base_offset = if self.config.data_only {
202                0
203            } else {
204                old_size + new_size
205            };
206
207            let (data, warns) =
208                self.compare_data(&old_schema, &new_schema, base_offset, total_bytes)?;
209            (Some(data), warns)
210        } else {
211            (None, Vec::new())
212        };
213
214        // Build summary
215        let summary = self.build_summary(&schema_diff, &data_diff);
216
217        Ok(DiffResult {
218            schema: schema_diff,
219            data: data_diff,
220            warnings,
221            summary,
222        })
223    }
224
225    /// Extract schema from a SQL file
226    fn extract_schema(
227        &self,
228        path: &Path,
229        byte_offset: u64,
230        total_bytes: u64,
231    ) -> anyhow::Result<Schema> {
232        let file = File::open(path)?;
233        let file_size = file.metadata()?.len();
234        let buffer_size = determine_buffer_size(file_size);
235        let compression = Compression::from_path(path);
236
237        let reader: Box<dyn Read> = if let Some(ref cb) = self.progress_fn {
238            let cb = Arc::clone(cb);
239            let progress_reader = ProgressReader::new(file, move |bytes| {
240                cb(byte_offset + bytes, total_bytes);
241            });
242            compression.wrap_reader(Box::new(progress_reader))?
243        } else {
244            compression.wrap_reader(Box::new(file))?
245        };
246
247        let mut parser = Parser::with_dialect(reader, buffer_size, self.dialect);
248        let mut builder = SchemaBuilder::new();
249
250        while let Some(stmt) = parser.read_statement()? {
251            let (stmt_type, _table_name) =
252                Parser::<&[u8]>::parse_statement_with_dialect(&stmt, self.dialect);
253
254            match stmt_type {
255                StatementType::CreateTable => {
256                    if let Ok(stmt_str) = std::str::from_utf8(&stmt) {
257                        builder.parse_create_table(stmt_str);
258                    }
259                }
260                StatementType::AlterTable => {
261                    if let Ok(stmt_str) = std::str::from_utf8(&stmt) {
262                        builder.parse_alter_table(stmt_str);
263                    }
264                }
265                StatementType::CreateIndex => {
266                    if let Ok(stmt_str) = std::str::from_utf8(&stmt) {
267                        builder.parse_create_index(stmt_str);
268                    }
269                }
270                _ => {}
271            }
272        }
273
274        Ok(builder.build())
275    }
276
277    /// Compare data between two SQL files
278    fn compare_data(
279        &self,
280        old_schema: &Schema,
281        new_schema: &Schema,
282        byte_offset: u64,
283        total_bytes: u64,
284    ) -> anyhow::Result<(DataDiff, Vec<DiffWarning>)> {
285        let mut data_differ = DataDiffer::new(DataDiffOptions {
286            max_pk_entries_global: self.config.max_pk_entries,
287            max_pk_entries_per_table: self.config.max_pk_entries / 2,
288            sample_size: if self.config.verbose { 100 } else { 0 },
289            tables: self.config.tables.clone(),
290            exclude: self.config.exclude.clone(),
291            allow_no_pk: self.config.allow_no_pk,
292            pk_overrides: self.config.pk_overrides.clone(),
293            ignore_columns: self.config.ignore_columns.clone(),
294        });
295
296        let old_size = std::fs::metadata(&self.config.old_path)?.len();
297
298        // Pass 1: Scan old file
299        data_differ.scan_file(
300            &self.config.old_path,
301            old_schema,
302            self.dialect,
303            true, // is_old
304            &self.progress_fn,
305            byte_offset,
306            total_bytes,
307        )?;
308
309        // Pass 2: Scan new file
310        data_differ.scan_file(
311            &self.config.new_path,
312            new_schema,
313            self.dialect,
314            false, // is_old
315            &self.progress_fn,
316            byte_offset + old_size,
317            total_bytes,
318        )?;
319
320        Ok(data_differ.compute_diff())
321    }
322
323    /// Build summary from diff results
324    fn build_summary(
325        &self,
326        schema_diff: &Option<SchemaDiff>,
327        data_diff: &Option<DataDiff>,
328    ) -> DiffSummary {
329        let (tables_added, tables_removed, schema_modified) = schema_diff
330            .as_ref()
331            .map(|s| {
332                (
333                    s.tables_added.len(),
334                    s.tables_removed.len(),
335                    s.tables_modified.len(),
336                )
337            })
338            .unwrap_or((0, 0, 0));
339
340        let (rows_added, rows_removed, rows_modified, data_modified, truncated) = data_diff
341            .as_ref()
342            .map(|d| {
343                let mut added = 0u64;
344                let mut removed = 0u64;
345                let mut modified = 0u64;
346                let mut tables_with_changes = 0usize;
347                let mut any_truncated = false;
348
349                for table_diff in d.tables.values() {
350                    added += table_diff.added_count;
351                    removed += table_diff.removed_count;
352                    modified += table_diff.modified_count;
353                    if table_diff.added_count > 0
354                        || table_diff.removed_count > 0
355                        || table_diff.modified_count > 0
356                    {
357                        tables_with_changes += 1;
358                    }
359                    if table_diff.truncated {
360                        any_truncated = true;
361                    }
362                }
363
364                (added, removed, modified, tables_with_changes, any_truncated)
365            })
366            .unwrap_or((0, 0, 0, 0, false));
367
368        DiffSummary {
369            tables_added,
370            tables_removed,
371            tables_modified: schema_modified.max(data_modified),
372            rows_added,
373            rows_removed,
374            rows_modified,
375            truncated,
376        }
377    }
378}
379
380/// Parse ignore column patterns into compiled Pattern objects
381pub fn parse_ignore_patterns(patterns: &[String]) -> Vec<Pattern> {
382    patterns
383        .iter()
384        .filter_map(|p| Pattern::new(&p.to_lowercase()).ok())
385        .collect()
386}
387
388/// Check if a column should be ignored based on patterns
389pub fn should_ignore_column(table: &str, column: &str, patterns: &[Pattern]) -> bool {
390    let full_name = format!("{}.{}", table.to_lowercase(), column.to_lowercase());
391    patterns.iter().any(|p| p.matches(&full_name))
392}
393
/// Check if a table should be included based on filter config.
///
/// Names are compared case-insensitively (via Unicode lowercasing) and must
/// match exactly; no glob expansion is performed.
///
/// * `tables` - include list; when non-empty, the table must appear in it.
/// * `exclude` - exclude list; a table listed here is always rejected, even if
///   it also appears in `tables`.
///
/// Returns `true` when the table passes both filters.
pub fn should_include_table(table_name: &str, tables: &[String], exclude: &[String]) -> bool {
    // Fast path: with no filters configured there is nothing to compare, so
    // skip the lowercase allocation entirely.
    if tables.is_empty() && exclude.is_empty() {
        return true;
    }

    // Lowercase the candidate once and reuse it for both lists.
    let name_lower = table_name.to_lowercase();
    let listed_in = |names: &[String]| names.iter().any(|t| t.to_lowercase() == name_lower);

    (tables.is_empty() || listed_in(tables)) && !listed_in(exclude)
}