// sql_splitter/merger/mod.rs
use crate::parser::SqlDialect;

use std::collections::HashSet;
use std::fs::{self, File};
use std::io::{self, BufRead, BufReader, BufWriter, Write};
use std::path::PathBuf;

/// Summary of a completed merge: how many tables were written, how many
/// bytes were emitted, and which tables (in output order).
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct MergeStats {
    // Number of table files merged into the output.
    pub tables_merged: usize,
    // Bytes written for separators, transaction statements, and table
    // contents. Dialect header/footer lines are not counted.
    pub bytes_written: u64,
    // Table names in the order they were merged (sorted lexicographically).
    pub table_names: Vec<String>,
}
16
17#[derive(Default)]
19pub struct MergerConfig {
20 pub dialect: SqlDialect,
21 pub tables: Option<HashSet<String>>,
22 pub exclude: HashSet<String>,
23 pub add_transaction: bool,
24 pub add_header: bool,
25}
26
27pub struct Merger {
29 input_dir: PathBuf,
30 output: Option<PathBuf>,
31 config: MergerConfig,
32}
33
34impl Merger {
35 pub fn new(input_dir: PathBuf, output: Option<PathBuf>) -> Self {
36 Self {
37 input_dir,
38 output,
39 config: MergerConfig::default(),
40 }
41 }
42
43 pub fn with_dialect(mut self, dialect: SqlDialect) -> Self {
44 self.config.dialect = dialect;
45 self
46 }
47
48 pub fn with_tables(mut self, tables: HashSet<String>) -> Self {
49 self.config.tables = Some(tables);
50 self
51 }
52
53 pub fn with_exclude(mut self, exclude: HashSet<String>) -> Self {
54 self.config.exclude = exclude;
55 self
56 }
57
58 pub fn with_transaction(mut self, add_transaction: bool) -> Self {
59 self.config.add_transaction = add_transaction;
60 self
61 }
62
63 pub fn with_header(mut self, add_header: bool) -> Self {
64 self.config.add_header = add_header;
65 self
66 }
67
68 pub fn merge(&self) -> anyhow::Result<MergeStats> {
70 let sql_files = self.discover_sql_files()?;
72 if sql_files.is_empty() {
73 anyhow::bail!(
74 "no .sql files found in directory: {}",
75 self.input_dir.display()
76 );
77 }
78
79 let filtered_files: Vec<(String, PathBuf)> = sql_files
81 .into_iter()
82 .filter(|(name, _)| {
83 let name_lower = name.to_lowercase();
84 if let Some(ref include) = self.config.tables {
85 if !include.contains(&name_lower) {
86 return false;
87 }
88 }
89 !self.config.exclude.contains(&name_lower)
90 })
91 .collect();
92
93 if filtered_files.is_empty() {
94 anyhow::bail!("no tables remaining after filtering");
95 }
96
97 let mut sorted_files = filtered_files;
99 sorted_files.sort_by(|a, b| a.0.cmp(&b.0));
100
101 if let Some(ref out_path) = self.output {
103 if let Some(parent) = out_path.parent() {
104 fs::create_dir_all(parent)?;
105 }
106 let file = File::create(out_path)?;
107 let writer = BufWriter::with_capacity(256 * 1024, file);
108 self.merge_files(sorted_files, writer)
109 } else {
110 let stdout = io::stdout();
111 let writer = BufWriter::new(stdout.lock());
112 self.merge_files(sorted_files, writer)
113 }
114 }
115
116 fn discover_sql_files(&self) -> anyhow::Result<Vec<(String, PathBuf)>> {
117 let mut files = Vec::new();
118
119 for entry in fs::read_dir(&self.input_dir)? {
120 let entry = entry?;
121 let path = entry.path();
122
123 if path.is_file() {
124 if let Some(ext) = path.extension() {
125 if ext.eq_ignore_ascii_case("sql") {
126 if let Some(stem) = path.file_stem() {
127 let table_name = stem.to_string_lossy().to_string();
128 files.push((table_name, path));
129 }
130 }
131 }
132 }
133 }
134
135 Ok(files)
136 }
137
138 fn merge_files<W: Write>(
139 &self,
140 files: Vec<(String, PathBuf)>,
141 mut writer: W,
142 ) -> anyhow::Result<MergeStats> {
143 let mut stats = MergeStats::default();
144
145 if self.config.add_header {
147 self.write_header(&mut writer, files.len())?;
148 }
149
150 if self.config.add_transaction {
152 let tx_start = self.transaction_start();
153 writer.write_all(tx_start.as_bytes())?;
154 stats.bytes_written += tx_start.len() as u64;
155 }
156
157 for (table_name, path) in &files {
159 let separator = format!(
161 "\n-- ============================================================\n-- Table: {}\n-- ============================================================\n\n",
162 table_name
163 );
164 writer.write_all(separator.as_bytes())?;
165 stats.bytes_written += separator.len() as u64;
166
167 let file = File::open(path)?;
169 let reader = BufReader::with_capacity(64 * 1024, file);
170
171 for line in reader.lines() {
172 let line = line?;
173 writer.write_all(line.as_bytes())?;
174 writer.write_all(b"\n")?;
175 stats.bytes_written += line.len() as u64 + 1;
176 }
177
178 stats.table_names.push(table_name.clone());
179 stats.tables_merged += 1;
180 }
181
182 if self.config.add_transaction {
184 let tx_end = "\nCOMMIT;\n";
185 writer.write_all(tx_end.as_bytes())?;
186 stats.bytes_written += tx_end.len() as u64;
187 }
188
189 if self.config.add_header {
191 self.write_footer(&mut writer)?;
192 }
193
194 writer.flush()?;
195
196 Ok(stats)
197 }
198
199 fn write_header<W: Write>(&self, w: &mut W, table_count: usize) -> io::Result<()> {
200 writeln!(w, "-- SQL Merge Output")?;
201 writeln!(w, "-- Generated by sql-splitter")?;
202 writeln!(w, "-- Tables: {}", table_count)?;
203 writeln!(w, "-- Dialect: {}", self.config.dialect)?;
204 writeln!(w)?;
205
206 match self.config.dialect {
207 SqlDialect::MySql => {
208 writeln!(w, "SET NAMES utf8mb4;")?;
209 writeln!(w, "SET FOREIGN_KEY_CHECKS = 0;")?;
210 }
211 SqlDialect::Postgres => {
212 writeln!(w, "SET client_encoding = 'UTF8';")?;
213 }
214 SqlDialect::Sqlite => {
215 writeln!(w, "PRAGMA foreign_keys = OFF;")?;
216 }
217 SqlDialect::Mssql => {
218 writeln!(w, "SET ANSI_NULLS ON;")?;
219 writeln!(w, "SET QUOTED_IDENTIFIER ON;")?;
220 writeln!(w, "SET NOCOUNT ON;")?;
221 }
222 }
223 writeln!(w)?;
224
225 Ok(())
226 }
227
228 fn write_footer<W: Write>(&self, w: &mut W) -> io::Result<()> {
229 writeln!(w)?;
230 match self.config.dialect {
231 SqlDialect::MySql => {
232 writeln!(w, "SET FOREIGN_KEY_CHECKS = 1;")?;
233 }
234 SqlDialect::Postgres | SqlDialect::Sqlite | SqlDialect::Mssql => {}
235 }
236 Ok(())
237 }
238
239 fn transaction_start(&self) -> &'static str {
240 match self.config.dialect {
241 SqlDialect::MySql => "START TRANSACTION;\n\n",
242 SqlDialect::Postgres => "BEGIN;\n\n",
243 SqlDialect::Sqlite | SqlDialect::Mssql => "BEGIN TRANSACTION;\n\n",
244 }
245 }
246}