// sql_splitter/splitter/mod.rs

1use crate::parser::{determine_buffer_size, ContentFilter, Parser, SqlDialect, StatementType};
2use crate::progress::ProgressReader;
3use crate::writer::WriterPool;
4use ahash::AHashSet;
5use anyhow::Context;
6use serde::Serialize;
7use std::fs::File;
8use std::io::Read;
9use std::path::{Path, PathBuf};
10
11#[derive(Serialize)]
12pub struct Stats {
13    pub statements_processed: u64,
14    pub tables_found: usize,
15    pub bytes_processed: u64,
16    pub table_names: Vec<String>,
17}
18
19#[derive(Default)]
20pub struct SplitterConfig {
21    pub dialect: SqlDialect,
22    pub dry_run: bool,
23    pub table_filter: Option<AHashSet<String>>,
24    pub progress_fn: Option<Box<dyn Fn(u64)>>,
25    pub content_filter: ContentFilter,
26}
27
/// Compression format of an input file, as inferred from its extension.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Compression {
    /// No recognized compression extension; the input is read as-is.
    None,
    /// `.gz` / `.gzip`
    Gzip,
    /// `.bz2` / `.bzip2`
    Bzip2,
    /// `.xz` / `.lzma`
    Xz,
    /// `.zst` / `.zstd`
    Zstd,
}
37
38impl Compression {
39    /// Detect compression format from file extension
40    pub fn from_path(path: &Path) -> Self {
41        let ext = path
42            .extension()
43            .and_then(|e| e.to_str())
44            .map(|e| e.to_lowercase());
45
46        match ext.as_deref() {
47            Some("gz" | "gzip") => Compression::Gzip,
48            Some("bz2" | "bzip2") => Compression::Bzip2,
49            Some("xz" | "lzma") => Compression::Xz,
50            Some("zst" | "zstd") => Compression::Zstd,
51            _ => Compression::None,
52        }
53    }
54
55    /// Wrap a reader with the appropriate decompressor
56    pub fn wrap_reader<'a>(
57        &self,
58        reader: Box<dyn Read + 'a>,
59    ) -> std::io::Result<Box<dyn Read + 'a>> {
60        Ok(match self {
61            Compression::None => reader,
62            Compression::Gzip => Box::new(flate2::read::GzDecoder::new(reader)),
63            Compression::Bzip2 => Box::new(bzip2::read::BzDecoder::new(reader)),
64            Compression::Xz => Box::new(xz2::read::XzDecoder::new(reader)),
65            Compression::Zstd => Box::new(zstd::stream::read::Decoder::new(reader)?),
66        })
67    }
68}
69
70impl std::fmt::Display for Compression {
71    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72        match self {
73            Compression::None => write!(f, "none"),
74            Compression::Gzip => write!(f, "gzip"),
75            Compression::Bzip2 => write!(f, "bzip2"),
76            Compression::Xz => write!(f, "xz"),
77            Compression::Zstd => write!(f, "zstd"),
78        }
79    }
80}
81
82pub struct Splitter {
83    input_file: PathBuf,
84    output_dir: PathBuf,
85    config: SplitterConfig,
86}
87
88impl Splitter {
89    pub fn new(input_file: PathBuf, output_dir: PathBuf) -> Self {
90        Self {
91            input_file,
92            output_dir,
93            config: SplitterConfig::default(),
94        }
95    }
96
97    pub fn with_dialect(mut self, dialect: SqlDialect) -> Self {
98        self.config.dialect = dialect;
99        self
100    }
101
102    pub fn with_dry_run(mut self, dry_run: bool) -> Self {
103        self.config.dry_run = dry_run;
104        self
105    }
106
107    pub fn with_table_filter(mut self, tables: Vec<String>) -> Self {
108        if !tables.is_empty() {
109            self.config.table_filter = Some(tables.into_iter().collect());
110        }
111        self
112    }
113
114    pub fn with_progress<F: Fn(u64) + 'static>(mut self, f: F) -> Self {
115        self.config.progress_fn = Some(Box::new(f));
116        self
117    }
118
119    pub fn with_content_filter(mut self, filter: ContentFilter) -> Self {
120        self.config.content_filter = filter;
121        self
122    }
123
124    pub fn split(mut self) -> anyhow::Result<Stats> {
125        let file = File::open(&self.input_file)
126            .with_context(|| format!("Failed to open input file: {:?}", self.input_file))?;
127        let file_size = file.metadata()?.len();
128        let buffer_size = determine_buffer_size(file_size);
129        let dialect = self.config.dialect;
130        let content_filter = self.config.content_filter;
131
132        // Detect and apply decompression
133        let compression = Compression::from_path(&self.input_file);
134
135        let reader: Box<dyn Read> = if let Some(cb) = self.config.progress_fn.take() {
136            let progress_reader = ProgressReader::new(file, cb);
137            compression
138                .wrap_reader(Box::new(progress_reader))
139                .with_context(|| {
140                    format!(
141                        "Failed to initialize {} decompression for {:?}",
142                        compression, self.input_file
143                    )
144                })?
145        } else {
146            compression.wrap_reader(Box::new(file)).with_context(|| {
147                format!(
148                    "Failed to initialize {} decompression for {:?}",
149                    compression, self.input_file
150                )
151            })?
152        };
153
154        let mut parser = Parser::with_dialect(reader, buffer_size, dialect);
155
156        let mut writer_pool = WriterPool::new(self.output_dir.clone());
157        if !self.config.dry_run {
158            writer_pool.ensure_output_dir().with_context(|| {
159                format!("Failed to create output directory: {:?}", self.output_dir)
160            })?;
161        }
162
163        let mut tables_seen: AHashSet<String> = AHashSet::new();
164        let mut stats = Stats {
165            statements_processed: 0,
166            tables_found: 0,
167            bytes_processed: 0,
168            table_names: Vec::new(),
169        };
170
171        // Track the last COPY table for PostgreSQL COPY data blocks
172        let mut last_copy_table: Option<String> = None;
173
174        while let Some(stmt) = parser.read_statement()? {
175            let (stmt_type, mut table_name) =
176                Parser::<&[u8]>::parse_statement_with_dialect(&stmt, dialect);
177
178            // Track COPY statements for data association
179            if stmt_type == StatementType::Copy {
180                last_copy_table = Some(table_name.clone());
181            }
182
183            // Handle PostgreSQL COPY data blocks - associate with last COPY table
184            let is_copy_data = if stmt_type == StatementType::Unknown && last_copy_table.is_some() {
185                // Check if this looks like COPY data (ends with \.\n)
186                if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
187                    // Safe: we just checked is_some() above
188                    if let Some(copy_table) = last_copy_table.take() {
189                        table_name = copy_table;
190                        true
191                    } else {
192                        false
193                    }
194                } else {
195                    false
196                }
197            } else {
198                false
199            };
200
201            if !is_copy_data && (stmt_type == StatementType::Unknown || table_name.is_empty()) {
202                continue;
203            }
204
205            // Apply content filter (schema-only or data-only)
206            match content_filter {
207                ContentFilter::SchemaOnly => {
208                    if !stmt_type.is_schema() {
209                        continue;
210                    }
211                }
212                ContentFilter::DataOnly => {
213                    // For data-only, include INSERT, COPY, and COPY data blocks
214                    if !stmt_type.is_data() && !is_copy_data {
215                        continue;
216                    }
217                }
218                ContentFilter::All => {}
219            }
220
221            if let Some(ref filter) = self.config.table_filter {
222                if !filter.contains(&table_name) {
223                    continue;
224                }
225            }
226
227            if !tables_seen.contains(&table_name) {
228                tables_seen.insert(table_name.clone());
229                stats.tables_found += 1;
230                stats.table_names.push(table_name.clone());
231            }
232
233            if !self.config.dry_run {
234                // For MSSQL, add semicolon if statement doesn't end with one
235                // (MSSQL uses GO as batch separator, but we need semicolons for re-parsing)
236                let write_result = if self.config.dialect == SqlDialect::Mssql {
237                    let trimmed = stmt
238                        .iter()
239                        .rev()
240                        .find(|&&b| b != b'\n' && b != b'\r' && b != b' ' && b != b'\t');
241                    if trimmed != Some(&b';') {
242                        // Write statement + semicolon without cloning
243                        writer_pool.write_statement_with_suffix(&table_name, &stmt, b";")
244                    } else {
245                        writer_pool.write_statement(&table_name, &stmt)
246                    }
247                } else {
248                    writer_pool.write_statement(&table_name, &stmt)
249                };
250                write_result.with_context(|| {
251                    format!("Failed to write statement to table file: {}", table_name)
252                })?;
253            }
254
255            stats.statements_processed += 1;
256            stats.bytes_processed += stmt.len() as u64;
257        }
258
259        if !self.config.dry_run {
260            writer_pool.close_all()?;
261        }
262
263        Ok(stats)
264    }
265}