// sql_splitter/splitter/mod.rs

1use crate::parser::{determine_buffer_size, ContentFilter, Parser, SqlDialect, StatementType};
2use crate::progress::ProgressReader;
3use crate::writer::WriterPool;
4use ahash::AHashSet;
5use anyhow::Context;
6use serde::Serialize;
7use std::fs::File;
8use std::io::Read;
9use std::path::{Path, PathBuf};
10
11/// Statistics from a split operation.
12#[derive(Serialize)]
13pub struct Stats {
14    /// Total statements processed.
15    pub statements_processed: u64,
16    /// Number of unique tables found.
17    pub tables_found: usize,
18    /// Total bytes processed from input.
19    pub bytes_processed: u64,
20    /// Names of all tables found.
21    pub table_names: Vec<String>,
22}
23
/// Configuration for the splitter.
///
/// Populated via the `Splitter::with_*` builder methods; `Default` yields
/// the baseline configuration (default dialect, no filters, no progress
/// callback, real writes).
#[derive(Default)]
pub struct SplitterConfig {
    /// SQL dialect for parsing.
    pub dialect: SqlDialect,
    /// If true, parse without writing output files.
    pub dry_run: bool,
    /// If set, only process tables in this set.
    pub table_filter: Option<AHashSet<String>>,
    /// Optional callback for progress reporting; handed to a `ProgressReader`
    /// wrapping the raw input file (presumably invoked with bytes read —
    /// confirm against `ProgressReader`).
    pub progress_fn: Option<Box<dyn Fn(u64)>>,
    /// Filter for which statement types to include.
    pub content_filter: ContentFilter,
}
38
/// Compression format detected from file extension
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Compression {
    /// No recognized compression; input is read as-is.
    None,
    /// gzip (`.gz` / `.gzip` extensions), decoded via `flate2`.
    Gzip,
    /// bzip2 (`.bz2` / `.bzip2` extensions), decoded via `bzip2`.
    Bzip2,
    /// xz/LZMA (`.xz` / `.lzma` extensions), decoded via `xz2`.
    Xz,
    /// Zstandard (`.zst` / `.zstd` extensions), decoded via `zstd`.
    Zstd,
}
48
49impl Compression {
50    /// Detect compression format from file extension
51    pub fn from_path(path: &Path) -> Self {
52        let ext = path
53            .extension()
54            .and_then(|e| e.to_str())
55            .map(|e| e.to_lowercase());
56
57        match ext.as_deref() {
58            Some("gz" | "gzip") => Compression::Gzip,
59            Some("bz2" | "bzip2") => Compression::Bzip2,
60            Some("xz" | "lzma") => Compression::Xz,
61            Some("zst" | "zstd") => Compression::Zstd,
62            _ => Compression::None,
63        }
64    }
65
66    /// Wrap a reader with the appropriate decompressor
67    pub fn wrap_reader<'a>(
68        &self,
69        reader: Box<dyn Read + 'a>,
70    ) -> std::io::Result<Box<dyn Read + 'a>> {
71        Ok(match self {
72            Compression::None => reader,
73            Compression::Gzip => Box::new(flate2::read::GzDecoder::new(reader)),
74            Compression::Bzip2 => Box::new(bzip2::read::BzDecoder::new(reader)),
75            Compression::Xz => Box::new(xz2::read::XzDecoder::new(reader)),
76            Compression::Zstd => Box::new(zstd::stream::read::Decoder::new(reader)?),
77        })
78    }
79}
80
81impl std::fmt::Display for Compression {
82    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83        match self {
84            Compression::None => write!(f, "none"),
85            Compression::Gzip => write!(f, "gzip"),
86            Compression::Bzip2 => write!(f, "bzip2"),
87            Compression::Xz => write!(f, "xz"),
88            Compression::Zstd => write!(f, "zstd"),
89        }
90    }
91}
92
/// Streams a SQL dump and splits its statements into per-table output files.
pub struct Splitter {
    // Path to the input dump; compression is inferred from its extension.
    input_file: PathBuf,
    // Directory where per-table output files are created.
    output_dir: PathBuf,
    // Options accumulated via the `with_*` builder methods.
    config: SplitterConfig,
}
98
99impl Splitter {
100    pub fn new(input_file: PathBuf, output_dir: PathBuf) -> Self {
101        Self {
102            input_file,
103            output_dir,
104            config: SplitterConfig::default(),
105        }
106    }
107
108    pub fn with_dialect(mut self, dialect: SqlDialect) -> Self {
109        self.config.dialect = dialect;
110        self
111    }
112
113    pub fn with_dry_run(mut self, dry_run: bool) -> Self {
114        self.config.dry_run = dry_run;
115        self
116    }
117
118    pub fn with_table_filter(mut self, tables: Vec<String>) -> Self {
119        if !tables.is_empty() {
120            self.config.table_filter = Some(tables.into_iter().collect());
121        }
122        self
123    }
124
125    pub fn with_progress<F: Fn(u64) + 'static>(mut self, f: F) -> Self {
126        self.config.progress_fn = Some(Box::new(f));
127        self
128    }
129
130    pub fn with_content_filter(mut self, filter: ContentFilter) -> Self {
131        self.config.content_filter = filter;
132        self
133    }
134
135    pub fn split(mut self) -> anyhow::Result<Stats> {
136        let file = File::open(&self.input_file)
137            .with_context(|| format!("Failed to open input file: {:?}", self.input_file))?;
138        let file_size = file.metadata()?.len();
139        let buffer_size = determine_buffer_size(file_size);
140        let dialect = self.config.dialect;
141        let content_filter = self.config.content_filter;
142
143        // Detect and apply decompression
144        let compression = Compression::from_path(&self.input_file);
145
146        let reader: Box<dyn Read> = if let Some(cb) = self.config.progress_fn.take() {
147            let progress_reader = ProgressReader::new(file, cb);
148            compression
149                .wrap_reader(Box::new(progress_reader))
150                .with_context(|| {
151                    format!(
152                        "Failed to initialize {} decompression for {:?}",
153                        compression, self.input_file
154                    )
155                })?
156        } else {
157            compression.wrap_reader(Box::new(file)).with_context(|| {
158                format!(
159                    "Failed to initialize {} decompression for {:?}",
160                    compression, self.input_file
161                )
162            })?
163        };
164
165        let mut parser = Parser::with_dialect(reader, buffer_size, dialect);
166
167        let mut writer_pool = WriterPool::new(self.output_dir.clone());
168        if !self.config.dry_run {
169            writer_pool.ensure_output_dir().with_context(|| {
170                format!("Failed to create output directory: {:?}", self.output_dir)
171            })?;
172        }
173
174        let mut tables_seen: AHashSet<String> = AHashSet::new();
175        let mut stats = Stats {
176            statements_processed: 0,
177            tables_found: 0,
178            bytes_processed: 0,
179            table_names: Vec::new(),
180        };
181
182        // Track the last COPY table for PostgreSQL COPY data blocks
183        let mut last_copy_table: Option<String> = None;
184
185        while let Some(stmt) = parser.read_statement()? {
186            let (stmt_type, mut table_name) =
187                Parser::<&[u8]>::parse_statement_with_dialect(&stmt, dialect);
188
189            // Track COPY statements for data association
190            if stmt_type == StatementType::Copy {
191                last_copy_table = Some(table_name.clone());
192            }
193
194            // Handle PostgreSQL COPY data blocks - associate with last COPY table
195            let is_copy_data = if stmt_type == StatementType::Unknown && last_copy_table.is_some() {
196                // Check if this looks like COPY data (ends with \.\n)
197                if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
198                    // Safe: we just checked is_some() above
199                    if let Some(copy_table) = last_copy_table.take() {
200                        table_name = copy_table;
201                        true
202                    } else {
203                        false
204                    }
205                } else {
206                    false
207                }
208            } else {
209                false
210            };
211
212            if !is_copy_data && (stmt_type == StatementType::Unknown || table_name.is_empty()) {
213                continue;
214            }
215
216            // Apply content filter (schema-only or data-only)
217            match content_filter {
218                ContentFilter::SchemaOnly => {
219                    if !stmt_type.is_schema() {
220                        continue;
221                    }
222                }
223                ContentFilter::DataOnly => {
224                    // For data-only, include INSERT, COPY, and COPY data blocks
225                    if !stmt_type.is_data() && !is_copy_data {
226                        continue;
227                    }
228                }
229                ContentFilter::All => {}
230            }
231
232            if let Some(ref filter) = self.config.table_filter {
233                if !filter.contains(&table_name) {
234                    continue;
235                }
236            }
237
238            if !tables_seen.contains(&table_name) {
239                tables_seen.insert(table_name.clone());
240                stats.tables_found += 1;
241                stats.table_names.push(table_name.clone());
242            }
243
244            if !self.config.dry_run {
245                // For MSSQL, add semicolon if statement doesn't end with one
246                // (MSSQL uses GO as batch separator, but we need semicolons for re-parsing)
247                let write_result = if self.config.dialect == SqlDialect::Mssql {
248                    let trimmed = stmt
249                        .iter()
250                        .rev()
251                        .find(|&&b| b != b'\n' && b != b'\r' && b != b' ' && b != b'\t');
252                    if trimmed != Some(&b';') {
253                        // Write statement + semicolon without cloning
254                        writer_pool.write_statement_with_suffix(&table_name, &stmt, b";")
255                    } else {
256                        writer_pool.write_statement(&table_name, &stmt)
257                    }
258                } else {
259                    writer_pool.write_statement(&table_name, &stmt)
260                };
261                write_result.with_context(|| {
262                    format!("Failed to write statement to table file: {}", table_name)
263                })?;
264            }
265
266            stats.statements_processed += 1;
267            stats.bytes_processed += stmt.len() as u64;
268        }
269
270        if !self.config.dry_run {
271            writer_pool.close_all()?;
272        }
273
274        Ok(stats)
275    }
276}