sql_splitter/splitter/
mod.rs1use crate::parser::{determine_buffer_size, Parser, SqlDialect, StatementType};
2use crate::writer::WriterPool;
3use ahash::AHashSet;
4use std::fs::File;
5use std::io::Read;
6use std::path::PathBuf;
7
/// Summary of a single split run.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Stats {
    /// Number of statements routed to a table (written, or only counted in dry-run mode).
    pub statements_processed: u64,
    /// Number of distinct tables that received at least one statement.
    pub tables_found: usize,
    /// Total size, in bytes, of the routed statements.
    pub bytes_processed: u64,
    /// Names of the distinct tables, in order of first appearance.
    pub table_names: Vec<String>,
}
14
15#[derive(Default)]
16pub struct SplitterConfig {
17 pub dialect: SqlDialect,
18 pub dry_run: bool,
19 pub table_filter: Option<AHashSet<String>>,
20 pub progress_fn: Option<Box<dyn Fn(u64)>>,
21}
22
23pub struct Splitter {
24 input_file: PathBuf,
25 output_dir: PathBuf,
26 config: SplitterConfig,
27}
28
29impl Splitter {
30 pub fn new(input_file: PathBuf, output_dir: PathBuf) -> Self {
31 Self {
32 input_file,
33 output_dir,
34 config: SplitterConfig::default(),
35 }
36 }
37
38 pub fn with_dialect(mut self, dialect: SqlDialect) -> Self {
39 self.config.dialect = dialect;
40 self
41 }
42
43 pub fn with_dry_run(mut self, dry_run: bool) -> Self {
44 self.config.dry_run = dry_run;
45 self
46 }
47
48 pub fn with_table_filter(mut self, tables: Vec<String>) -> Self {
49 if !tables.is_empty() {
50 self.config.table_filter = Some(tables.into_iter().collect());
51 }
52 self
53 }
54
55 pub fn with_progress<F: Fn(u64) + 'static>(mut self, f: F) -> Self {
56 self.config.progress_fn = Some(Box::new(f));
57 self
58 }
59
60 pub fn split(self) -> anyhow::Result<Stats> {
61 let file = File::open(&self.input_file)?;
62 let file_size = file.metadata()?.len();
63 let buffer_size = determine_buffer_size(file_size);
64 let dialect = self.config.dialect;
65
66 let reader: Box<dyn Read> = if self.config.progress_fn.is_some() {
67 Box::new(ProgressReader::new(file, self.config.progress_fn.unwrap()))
68 } else {
69 Box::new(file)
70 };
71
72 let mut parser = Parser::with_dialect(reader, buffer_size, dialect);
73
74 let mut writer_pool = WriterPool::new(self.output_dir.clone());
75 if !self.config.dry_run {
76 writer_pool.ensure_output_dir()?;
77 }
78
79 let mut tables_seen: AHashSet<String> = AHashSet::new();
80 let mut stats = Stats {
81 statements_processed: 0,
82 tables_found: 0,
83 bytes_processed: 0,
84 table_names: Vec::new(),
85 };
86
87 let mut last_copy_table: Option<String> = None;
89
90 while let Some(stmt) = parser.read_statement()? {
91 let (stmt_type, mut table_name) = Parser::<&[u8]>::parse_statement_with_dialect(&stmt, dialect);
92
93 if stmt_type == StatementType::Copy {
95 last_copy_table = Some(table_name.clone());
96 }
97
98 let is_copy_data = if stmt_type == StatementType::Unknown && last_copy_table.is_some() {
100 if stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n") {
102 table_name = last_copy_table.take().unwrap();
103 true
104 } else {
105 false
106 }
107 } else {
108 false
109 };
110
111 if !is_copy_data && (stmt_type == StatementType::Unknown || table_name.is_empty()) {
112 continue;
113 }
114
115 if let Some(ref filter) = self.config.table_filter {
116 if !filter.contains(&table_name) {
117 continue;
118 }
119 }
120
121 if !tables_seen.contains(&table_name) {
122 tables_seen.insert(table_name.clone());
123 stats.tables_found += 1;
124 stats.table_names.push(table_name.clone());
125 }
126
127 if !self.config.dry_run {
128 writer_pool.write_statement(&table_name, &stmt)?;
129 }
130
131 stats.statements_processed += 1;
132 stats.bytes_processed += stmt.len() as u64;
133 }
134
135 if !self.config.dry_run {
136 writer_pool.close_all()?;
137 }
138
139 Ok(stats)
140 }
141}
142
/// [`Read`] adapter that reports the running total of bytes read to a callback.
///
/// The callback fires after every successful `read` on the inner reader, including
/// a zero-byte read at end of input; failed reads propagate without invoking it.
struct ProgressReader<R: Read> {
    inner: R,
    on_progress: Box<dyn Fn(u64)>,
    total: u64,
}

impl<R: Read> ProgressReader<R> {
    /// Wraps `reader`, with the byte count starting at zero.
    fn new(reader: R, callback: Box<dyn Fn(u64)>) -> Self {
        ProgressReader {
            inner: reader,
            on_progress: callback,
            total: 0,
        }
    }
}

impl<R: Read> Read for ProgressReader<R> {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        self.inner.read(buf).map(|count| {
            self.total += count as u64;
            (self.on_progress)(self.total);
            count
        })
    }
}
167
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Writes `sql` to `input.sql` inside a fresh temporary directory.
    ///
    /// Returns the directory guard (keep it alive for the whole test), the input path,
    /// and the not-yet-created output directory path.
    fn fixture(sql: &[u8]) -> (TempDir, std::path::PathBuf, std::path::PathBuf) {
        let temp_dir = TempDir::new().unwrap();
        let input_file = temp_dir.path().join("input.sql");
        let output_dir = temp_dir.path().join("output");
        std::fs::write(&input_file, sql).unwrap();
        (temp_dir, input_file, output_dir)
    }

    #[test]
    fn test_splitter_basic() {
        let (_guard, input_file, output_dir) = fixture(
            b"CREATE TABLE users (id INT);\nINSERT INTO users VALUES (1);\nCREATE TABLE posts (id INT);\n",
        );

        let stats = Splitter::new(input_file, output_dir.clone())
            .split()
            .unwrap();

        assert_eq!(stats.tables_found, 2);
        assert_eq!(stats.statements_processed, 3);
        assert!(output_dir.join("users.sql").exists());
        assert!(output_dir.join("posts.sql").exists());
    }

    #[test]
    fn test_splitter_dry_run() {
        let (_guard, input_file, output_dir) = fixture(b"CREATE TABLE users (id INT);");

        let stats = Splitter::new(input_file, output_dir.clone())
            .with_dry_run(true)
            .split()
            .unwrap();

        assert_eq!(stats.tables_found, 1);
        assert!(!output_dir.exists());
    }

    #[test]
    fn test_splitter_table_filter() {
        let (_guard, input_file, output_dir) = fixture(
            b"CREATE TABLE users (id INT);\nCREATE TABLE posts (id INT);\nCREATE TABLE orders (id INT);",
        );

        let wanted = vec!["users".to_string(), "orders".to_string()];
        let stats = Splitter::new(input_file, output_dir.clone())
            .with_table_filter(wanted)
            .split()
            .unwrap();

        assert_eq!(stats.tables_found, 2);
        assert!(output_dir.join("users.sql").exists());
        assert!(!output_dir.join("posts.sql").exists());
        assert!(output_dir.join("orders.sql").exists());
    }
}