// sql_splitter/merger/mod.rs

use crate::parser::SqlDialect;
4use std::collections::HashSet;
5use std::fs::{self, File};
6use std::io::{self, BufRead, BufReader, BufWriter, Write};
7use std::path::PathBuf;
8
/// Summary of a completed merge run.
#[derive(Debug, Default)]
pub struct MergeStats {
    /// Number of table files merged into the output.
    pub tables_merged: usize,
    /// Bytes written to the output during the merge.
    pub bytes_written: u64,
    /// Names of the merged tables, in output order.
    pub table_names: Vec<String>,
}
19
/// Options controlling how table files are merged.
#[derive(Default)]
pub struct MergerConfig {
    /// SQL dialect used for header, footer, and transaction syntax.
    pub dialect: SqlDialect,
    /// When `Some`, only these tables are merged. Matching is done against
    /// the lowercased file stem, so entries are expected to be lowercase.
    pub tables: Option<HashSet<String>>,
    /// Tables to skip. Matched against the lowercased file stem, so entries
    /// are expected to be lowercase.
    pub exclude: HashSet<String>,
    /// Wrap the merged output in a dialect-appropriate transaction.
    pub add_transaction: bool,
    /// Emit a generated header (and matching footer) around the output.
    pub add_header: bool,
}
29
/// Merges per-table `.sql` files from a directory into a single SQL dump.
pub struct Merger {
    /// Directory scanned for `.sql` files (one file per table).
    input_dir: PathBuf,
    /// Output file path; `None` writes to stdout.
    output: Option<PathBuf>,
    /// Merge options (dialect, include/exclude filters, header/transaction flags).
    config: MergerConfig,
}
36
37impl Merger {
38 pub fn new(input_dir: PathBuf, output: Option<PathBuf>) -> Self {
39 Self {
40 input_dir,
41 output,
42 config: MergerConfig::default(),
43 }
44 }
45
46 pub fn with_dialect(mut self, dialect: SqlDialect) -> Self {
47 self.config.dialect = dialect;
48 self
49 }
50
51 pub fn with_tables(mut self, tables: HashSet<String>) -> Self {
52 self.config.tables = Some(tables);
53 self
54 }
55
56 pub fn with_exclude(mut self, exclude: HashSet<String>) -> Self {
57 self.config.exclude = exclude;
58 self
59 }
60
61 pub fn with_transaction(mut self, add_transaction: bool) -> Self {
62 self.config.add_transaction = add_transaction;
63 self
64 }
65
66 pub fn with_header(mut self, add_header: bool) -> Self {
67 self.config.add_header = add_header;
68 self
69 }
70
71 pub fn merge(&self) -> anyhow::Result<MergeStats> {
73 let sql_files = self.discover_sql_files()?;
75 if sql_files.is_empty() {
76 anyhow::bail!(
77 "no .sql files found in directory: {}",
78 self.input_dir.display()
79 );
80 }
81
82 let filtered_files: Vec<(String, PathBuf)> = sql_files
84 .into_iter()
85 .filter(|(name, _)| {
86 let name_lower = name.to_lowercase();
87 if let Some(ref include) = self.config.tables {
88 if !include.contains(&name_lower) {
89 return false;
90 }
91 }
92 !self.config.exclude.contains(&name_lower)
93 })
94 .collect();
95
96 if filtered_files.is_empty() {
97 anyhow::bail!("no tables remaining after filtering");
98 }
99
100 let mut sorted_files = filtered_files;
102 sorted_files.sort_by(|a, b| a.0.cmp(&b.0));
103
104 if let Some(ref out_path) = self.output {
106 if let Some(parent) = out_path.parent() {
107 fs::create_dir_all(parent)?;
108 }
109 let file = File::create(out_path)?;
110 let writer = BufWriter::with_capacity(256 * 1024, file);
111 self.merge_files(sorted_files, writer)
112 } else {
113 let stdout = io::stdout();
114 let writer = BufWriter::new(stdout.lock());
115 self.merge_files(sorted_files, writer)
116 }
117 }
118
119 fn discover_sql_files(&self) -> anyhow::Result<Vec<(String, PathBuf)>> {
120 let mut files = Vec::new();
121
122 for entry in fs::read_dir(&self.input_dir)? {
123 let entry = entry?;
124 let path = entry.path();
125
126 if path.is_file() {
127 if let Some(ext) = path.extension() {
128 if ext.eq_ignore_ascii_case("sql") {
129 if let Some(stem) = path.file_stem() {
130 let table_name = stem.to_string_lossy().to_string();
131 files.push((table_name, path));
132 }
133 }
134 }
135 }
136 }
137
138 Ok(files)
139 }
140
141 fn merge_files<W: Write>(
142 &self,
143 files: Vec<(String, PathBuf)>,
144 mut writer: W,
145 ) -> anyhow::Result<MergeStats> {
146 let mut stats = MergeStats::default();
147
148 if self.config.add_header {
150 self.write_header(&mut writer, files.len())?;
151 }
152
153 if self.config.add_transaction {
155 let tx_start = self.transaction_start();
156 writer.write_all(tx_start.as_bytes())?;
157 stats.bytes_written += tx_start.len() as u64;
158 }
159
160 for (table_name, path) in &files {
162 let separator = format!(
164 "\n-- ============================================================\n-- Table: {}\n-- ============================================================\n\n",
165 table_name
166 );
167 writer.write_all(separator.as_bytes())?;
168 stats.bytes_written += separator.len() as u64;
169
170 let file = File::open(path)?;
172 let reader = BufReader::with_capacity(64 * 1024, file);
173
174 for line in reader.lines() {
175 let line = line?;
176 writer.write_all(line.as_bytes())?;
177 writer.write_all(b"\n")?;
178 stats.bytes_written += line.len() as u64 + 1;
179 }
180
181 stats.table_names.push(table_name.clone());
182 stats.tables_merged += 1;
183 }
184
185 if self.config.add_transaction {
187 let tx_end = "\nCOMMIT;\n";
188 writer.write_all(tx_end.as_bytes())?;
189 stats.bytes_written += tx_end.len() as u64;
190 }
191
192 if self.config.add_header {
194 self.write_footer(&mut writer)?;
195 }
196
197 writer.flush()?;
198
199 Ok(stats)
200 }
201
202 fn write_header<W: Write>(&self, w: &mut W, table_count: usize) -> io::Result<()> {
203 writeln!(w, "-- SQL Merge Output")?;
204 writeln!(w, "-- Generated by sql-splitter")?;
205 writeln!(w, "-- Tables: {}", table_count)?;
206 writeln!(w, "-- Dialect: {}", self.config.dialect)?;
207 writeln!(w)?;
208
209 match self.config.dialect {
210 SqlDialect::MySql => {
211 writeln!(w, "SET NAMES utf8mb4;")?;
212 writeln!(w, "SET FOREIGN_KEY_CHECKS = 0;")?;
213 }
214 SqlDialect::Postgres => {
215 writeln!(w, "SET client_encoding = 'UTF8';")?;
216 }
217 SqlDialect::Sqlite => {
218 writeln!(w, "PRAGMA foreign_keys = OFF;")?;
219 }
220 SqlDialect::Mssql => {
221 writeln!(w, "SET ANSI_NULLS ON;")?;
222 writeln!(w, "SET QUOTED_IDENTIFIER ON;")?;
223 writeln!(w, "SET NOCOUNT ON;")?;
224 }
225 }
226 writeln!(w)?;
227
228 Ok(())
229 }
230
231 fn write_footer<W: Write>(&self, w: &mut W) -> io::Result<()> {
232 writeln!(w)?;
233 match self.config.dialect {
234 SqlDialect::MySql => {
235 writeln!(w, "SET FOREIGN_KEY_CHECKS = 1;")?;
236 }
237 SqlDialect::Postgres | SqlDialect::Sqlite | SqlDialect::Mssql => {}
238 }
239 Ok(())
240 }
241
242 fn transaction_start(&self) -> &'static str {
243 match self.config.dialect {
244 SqlDialect::MySql => "START TRANSACTION;\n\n",
245 SqlDialect::Postgres => "BEGIN;\n\n",
246 SqlDialect::Sqlite | SqlDialect::Mssql => "BEGIN TRANSACTION;\n\n",
247 }
248 }
249}