1use super::batch::{flush_batch, BatchManager, MAX_ROWS_PER_BATCH};
4use super::types::TypeConverter;
5use super::{ImportStats, QueryConfig};
6use crate::convert::copy_to_insert::{copy_to_inserts, parse_copy_header, CopyHeader};
7use crate::parser::{
8 detect_dialect_from_file, parse_insert_for_bulk, Parser, SqlDialect, StatementType,
9};
10use crate::progress::ProgressReader;
11use crate::splitter::Compression;
12use anyhow::{Context, Result};
13use duckdb::Connection;
14use indicatif::{ProgressBar, ProgressStyle};
15use once_cell::sync::Lazy;
16use regex::Regex;
17use std::fs::File;
18use std::io::{BufRead, BufReader, Read};
19use std::path::Path;
20
/// Number of buffered COPY data rows after which `load_statements` hands the
/// accumulated batch to `process_copy_batch`.
const MAX_COPY_ROWS_PER_BATCH: usize = 10_000;
24
/// Loads a SQL dump file into DuckDB through a borrowed connection.
pub struct DumpLoader<'a> {
    /// Connection every converted statement is executed against.
    conn: &'a Connection,
    /// Import options; this file reads its `dialect`, `progress` and `tables` fields.
    config: &'a QueryConfig,
}
30
31impl<'a> DumpLoader<'a> {
32 pub fn new(conn: &'a Connection, config: &'a QueryConfig) -> Self {
34 Self { conn, config }
35 }
36
37 pub fn load(&self, dump_path: &Path) -> Result<ImportStats> {
39 let start = std::time::Instant::now();
40 let mut stats = ImportStats::default();
41
42 let dialect = if let Some(d) = self.config.dialect {
44 d
45 } else {
46 let result = detect_dialect_from_file(dump_path)?;
47 result.dialect
48 };
49
50 let file_size = std::fs::metadata(dump_path)?.len();
52
53 let progress_bar = if self.config.progress {
55 let pb = ProgressBar::new(file_size);
56 pb.set_style(
57 ProgressStyle::default_bar()
58 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({percent}%)")
59 .unwrap()
60 .progress_chars("=>-"),
61 );
62 Some(pb)
63 } else {
64 None
65 };
66
67 let file = File::open(dump_path).context("Failed to open dump file")?;
69 let compression = Compression::from_path(dump_path);
70 let reader: Box<dyn Read> = match compression {
71 Compression::Gzip => Box::new(flate2::read::GzDecoder::new(file)),
72 Compression::Bzip2 => Box::new(bzip2::read::BzDecoder::new(file)),
73 Compression::Xz => Box::new(xz2::read::XzDecoder::new(file)),
74 Compression::Zstd => Box::new(zstd::stream::Decoder::new(file)?),
75 Compression::None => Box::new(file),
76 };
77
78 let reader: Box<dyn Read> = if let Some(ref pb) = progress_bar {
79 let pb_clone = pb.clone();
80 Box::new(ProgressReader::new(reader, move |bytes| {
81 pb_clone.set_position(bytes);
82 }))
83 } else {
84 reader
85 };
86
87 let buf_reader = BufReader::with_capacity(256 * 1024, reader);
88
89 self.load_statements(buf_reader, dialect, &mut stats)?;
91
92 if let Some(pb) = progress_bar {
93 pb.finish_with_message("Import complete");
94 }
95
96 stats.duration_secs = start.elapsed().as_secs_f64();
97 Ok(stats)
98 }
99
100 fn load_statements<R: Read>(
102 &self,
103 reader: BufReader<R>,
104 dialect: SqlDialect,
105 stats: &mut ImportStats,
106 ) -> Result<()> {
107 let mut parser = StatementReader::new(reader, dialect);
108 let mut pending_copy: Option<CopyHeader> = None;
109
110 let mut copy_batch_data: Vec<u8> = Vec::new();
112 let mut copy_batch_rows: usize = 0;
113
114 let mut failed_tables: std::collections::HashSet<String> = std::collections::HashSet::new();
116
117 let mut batch_mgr = BatchManager::new(MAX_ROWS_PER_BATCH);
119
120 while let Some(stmt_result) = parser.next_statement() {
121 let stmt = stmt_result?;
122
123 if let Some(ref header) = pending_copy {
125 let trimmed = stmt.trim();
126
127 if trimmed == "\\." {
129 if !copy_batch_data.is_empty() {
131 self.process_copy_batch(
132 header,
133 ©_batch_data,
134 stats,
135 &mut failed_tables,
136 );
137 copy_batch_data.clear();
138 copy_batch_rows = 0;
139 }
140 pending_copy = None;
141 parser.set_copy_mode(false);
142 continue;
143 }
144
145 if trimmed.is_empty() {
147 continue;
148 }
149
150 if Self::looks_like_copy_data(&stmt) {
152 if failed_tables.contains(&header.table) {
154 continue;
155 }
156
157 copy_batch_data.extend_from_slice(stmt.as_bytes());
158 copy_batch_data.push(b'\n');
159 copy_batch_rows += 1;
160
161 if copy_batch_rows >= MAX_COPY_ROWS_PER_BATCH {
163 self.process_copy_batch(
164 header,
165 ©_batch_data,
166 stats,
167 &mut failed_tables,
168 );
169 copy_batch_data.clear();
170 copy_batch_rows = 0;
171 }
172 continue;
173 }
174
175 if !copy_batch_data.is_empty() {
177 self.process_copy_batch(header, ©_batch_data, stats, &mut failed_tables);
178 copy_batch_data.clear();
179 copy_batch_rows = 0;
180 }
181 pending_copy = None;
182 parser.set_copy_mode(false);
183 }
185
186 let (mut stmt_type, table_name) =
187 Parser::<&[u8]>::parse_statement_with_dialect(stmt.as_bytes(), dialect);
188
189 if dialect == SqlDialect::Postgres && stmt_type == StatementType::Unknown {
191 let upper = stmt.to_uppercase();
192 if let Some(copy_pos) = upper.find("COPY ") {
193 let after_copy = &upper[copy_pos..];
194 if after_copy.contains("FROM STDIN") {
195 stmt_type = StatementType::Copy;
196 }
197 }
198 }
199
200 if let Some(ref tables) = self.config.tables {
202 if !table_name.is_empty()
203 && !tables.iter().any(|t| t.eq_ignore_ascii_case(&table_name))
204 {
205 continue;
206 }
207 }
208
209 match stmt_type {
210 StatementType::CreateTable => {
211 let duckdb_sql = self.convert_create_table(&stmt, dialect)?;
212 match self.conn.execute(&duckdb_sql, []) {
213 Ok(_) => stats.tables_created += 1,
214 Err(e) => {
215 stats
216 .warnings
217 .push(format!("Failed to create table {}: {}", table_name, e));
218 stats.statements_skipped += 1;
219 }
220 }
221 }
222 StatementType::Insert => {
223 if !self.try_queue_for_bulk(
225 &stmt,
226 dialect,
227 &mut batch_mgr,
228 stats,
229 &mut failed_tables,
230 ) {
231 let duckdb_sql = self.convert_insert(&stmt, dialect)?;
233 match self.conn.execute(&duckdb_sql, []) {
234 Ok(_) => {
235 stats.insert_statements += 1;
236 stats.rows_inserted += Self::count_insert_rows(&duckdb_sql);
237 }
238 Err(e) => {
239 stats
240 .warnings
241 .push(format!("Failed INSERT for {}: {}", table_name, e));
242 stats.statements_skipped += 1;
243 }
244 }
245 }
246
247 for mut batch in batch_mgr.get_ready_batches() {
249 flush_batch(self.conn, &mut batch, stats, &mut failed_tables)?;
250 }
251 }
252 StatementType::Copy => {
253 if let Some(header) = parse_copy_header(&stmt) {
255 if failed_tables.contains(&header.table) {
257 parser.set_copy_mode(true);
259 Self::skip_copy_block(&mut parser);
260 parser.set_copy_mode(false);
261 continue;
262 }
263
264 if !self.table_exists(&header.table) {
266 failed_tables.insert(header.table.clone());
267 if stats.warnings.len() < 100 {
268 stats.warnings.push(format!(
269 "Skipping COPY for non-existent table {}",
270 header.table
271 ));
272 }
273 parser.set_copy_mode(true);
275 Self::skip_copy_block(&mut parser);
276 parser.set_copy_mode(false);
277 continue;
278 }
279
280 copy_batch_data.clear();
281 copy_batch_rows = 0;
282 pending_copy = Some(header);
283 parser.set_copy_mode(true);
284 }
285 }
286 StatementType::CreateIndex => {
287 stats.statements_skipped += 1;
289 }
290 _ => {
291 stats.statements_skipped += 1;
293 }
294 }
295 }
296
297 if let Some(ref header) = pending_copy {
299 if !copy_batch_data.is_empty() {
300 self.process_copy_batch(header, ©_batch_data, stats, &mut failed_tables);
301 }
302 }
303
304 for mut batch in batch_mgr.drain_all() {
306 flush_batch(self.conn, &mut batch, stats, &mut failed_tables)?;
307 }
308
309 Ok(())
310 }
311
312 fn process_copy_batch(
314 &self,
315 header: &CopyHeader,
316 batch_data: &[u8],
317 stats: &mut ImportStats,
318 failed_tables: &mut std::collections::HashSet<String>,
319 ) {
320 if batch_data.is_empty() {
321 return;
322 }
323
324 if failed_tables.contains(&header.table) {
326 return;
327 }
328
329 let inserts = copy_to_inserts(header, batch_data, SqlDialect::Postgres);
330 for insert in inserts {
331 let insert_sql = String::from_utf8_lossy(&insert);
332 match self.conn.execute(&insert_sql, []) {
333 Ok(_) => {
334 stats.rows_inserted += Self::count_insert_rows(&insert_sql);
335 }
336 Err(e) => {
337 let err_str = e.to_string();
338 if err_str.contains("does not exist") {
340 failed_tables.insert(header.table.clone());
341 if stats.warnings.len() < 100 {
342 stats.warnings.push(format!(
343 "Table {} does not exist, skipping COPY data",
344 header.table
345 ));
346 }
347 return; }
349 if stats.warnings.len() < 100 {
351 stats.warnings.push(format!(
352 "Failed to insert COPY data for {}: {}",
353 header.table, e
354 ));
355 }
356 stats.statements_skipped += 1;
357 }
358 }
359 }
360 }
361
362 fn table_exists(&self, table: &str) -> bool {
364 let query = "SELECT 1 FROM information_schema.tables WHERE table_name = ? LIMIT 1";
365 match self.conn.prepare(query) {
366 Ok(mut stmt) => stmt.exists([table]).unwrap_or(false),
367 Err(_) => false,
368 }
369 }
370
371 fn try_queue_for_bulk(
374 &self,
375 stmt: &str,
376 dialect: SqlDialect,
377 batch_mgr: &mut BatchManager,
378 stats: &mut ImportStats,
379 failed_tables: &mut std::collections::HashSet<String>,
380 ) -> bool {
381 let upper = stmt.to_uppercase();
383 if upper.contains("ON DUPLICATE KEY")
384 || upper.contains("ON CONFLICT")
385 || upper.contains("REPLACE")
386 || upper.contains("IGNORE")
387 || upper.contains("RETURNING")
388 || upper.contains("SELECT")
389 {
390 return false;
391 }
392
393 let parsed = match parse_insert_for_bulk(stmt.as_bytes()) {
395 Ok(p) => p,
396 Err(_) => return false, };
398
399 if failed_tables.contains(&parsed.table) {
401 return true; }
403
404 if parsed.rows.is_empty() {
406 return false;
407 }
408
409 let duckdb_sql = match self.convert_insert(stmt, dialect) {
411 Ok(sql) => sql,
412 Err(_) => return false,
413 };
414
415 if let Some(mut batch) =
417 batch_mgr.queue_insert(&parsed.table, parsed.columns, parsed.rows, duckdb_sql)
418 {
419 if let Err(e) = flush_batch(self.conn, &mut batch, stats, failed_tables) {
421 if stats.warnings.len() < 100 {
422 stats.warnings.push(format!("Batch flush error: {}", e));
423 }
424 }
425 }
426
427 true
428 }
429
430 fn skip_copy_block<R: Read>(parser: &mut StatementReader<R>) {
432 while let Some(Ok(line)) = parser.next_statement() {
433 if line.trim() == "\\." {
434 break;
435 }
436 }
437 }
438
439 fn looks_like_copy_data(line: &str) -> bool {
441 let trimmed = line.trim();
442
443 if trimmed.is_empty() {
445 return false;
446 }
447
448 if trimmed == "\\." {
450 return false;
451 }
452
453 let first_char = trimmed.chars().next().unwrap_or(' ');
456
457 if matches!(
459 first_char,
460 'S' | 's'
461 | 'I'
462 | 'i'
463 | 'C'
464 | 'c'
465 | 'D'
466 | 'd'
467 | 'A'
468 | 'a'
469 | 'U'
470 | 'u'
471 | 'G'
472 | 'g'
473 | '-'
474 | '/'
475 ) {
476 let upper_prefix: String = trimmed.chars().take(7).collect::<String>().to_uppercase();
477 if upper_prefix.starts_with("SELECT")
478 || upper_prefix.starts_with("INSERT")
479 || upper_prefix.starts_with("CREATE")
480 || upper_prefix.starts_with("DROP")
481 || upper_prefix.starts_with("ALTER")
482 || upper_prefix.starts_with("UPDATE")
483 || upper_prefix.starts_with("GRANT")
484 || upper_prefix.starts_with("--")
485 || upper_prefix.starts_with("/*")
486 {
487 return false;
488 }
489 }
490
491 true
494 }
495
496 fn convert_create_table(&self, stmt: &str, dialect: SqlDialect) -> Result<String> {
498 let mut result = stmt.to_string();
499
500 result = Self::convert_identifiers(&result, dialect);
502
503 result = Self::strip_mysql_clauses(&result);
506
507 result = Self::convert_types_in_statement(&result);
509
510 if dialect == SqlDialect::Postgres {
512 result = Self::strip_postgres_syntax(&result);
513 }
514
515 if dialect == SqlDialect::Mssql {
517 result = Self::strip_mssql_syntax(&result);
518 }
519
520 Ok(result)
521 }
522
523 fn convert_insert(&self, stmt: &str, dialect: SqlDialect) -> Result<String> {
525 let mut result = stmt.to_string();
526
527 result = Self::convert_identifiers(&result, dialect);
529
530 if dialect == SqlDialect::MySql {
532 result = Self::convert_mysql_escapes(&result);
533 }
534
535 if dialect == SqlDialect::Postgres {
537 result = Self::strip_schema_prefix(&result);
538 }
539
540 if dialect == SqlDialect::Mssql {
542 result = Self::strip_mssql_schema_prefix(&result);
543 }
544
545 Ok(result)
546 }
547
548 fn convert_mysql_escapes(stmt: &str) -> String {
550 let mut result = String::with_capacity(stmt.len() + 100);
551 let mut chars = stmt.chars().peekable();
552 let mut in_string = false;
553
554 while let Some(c) = chars.next() {
555 if c == '\'' {
556 in_string = !in_string;
557 result.push(c);
558 } else if c == '\\' && in_string {
559 match chars.peek() {
561 Some('\'') => {
562 chars.next();
564 result.push_str("''");
565 }
566 Some('\\') => {
567 chars.next();
569 result.push('\\');
570 }
571 Some('n') => {
572 chars.next();
574 result.push('\n');
575 }
576 Some('r') => {
577 chars.next();
579 result.push('\r');
580 }
581 Some('t') => {
582 chars.next();
584 result.push('\t');
585 }
586 Some('0') => {
587 chars.next();
589 }
590 Some('"') => {
591 chars.next();
593 result.push('"');
594 }
595 _ => {
596 result.push(c);
598 }
599 }
600 } else {
601 result.push(c);
602 }
603 }
604 result
605 }
606
607 fn convert_identifiers(stmt: &str, dialect: SqlDialect) -> String {
609 match dialect {
610 SqlDialect::MySql => {
611 let mut result = String::with_capacity(stmt.len());
613 let mut in_string = false;
614 let mut in_backtick = false;
615
616 for c in stmt.chars() {
617 if c == '\'' && !in_backtick {
618 in_string = !in_string;
619 result.push(c);
620 } else if c == '`' && !in_string {
621 in_backtick = !in_backtick;
622 result.push('"');
623 } else {
624 result.push(c);
625 }
626 }
627 result
628 }
629 SqlDialect::Mssql => {
630 let mut result = String::with_capacity(stmt.len());
632 let mut in_string = false;
633 let mut in_bracket = false;
634 let mut chars = stmt.chars().peekable();
635
636 while let Some(c) = chars.next() {
637 if c == '\'' && !in_bracket {
638 in_string = !in_string;
639 result.push(c);
640 } else if c == '[' && !in_string {
641 in_bracket = true;
642 result.push('"');
643 } else if c == ']' && !in_string {
644 if chars.peek() == Some(&']') {
646 chars.next();
647 result.push(']');
648 } else {
649 in_bracket = false;
650 result.push('"');
651 }
652 } else if c == 'N' && !in_string && !in_bracket && chars.peek() == Some(&'\'') {
653 } else {
656 result.push(c);
657 }
658 }
659 result
660 }
661 _ => stmt.to_string(),
662 }
663 }
664
665 fn convert_types_in_statement(stmt: &str) -> String {
667 static RE_COLUMN_TYPE: Lazy<Regex> = Lazy::new(|| {
674 Regex::new(r#"(?i)(["'`\]\s])\s*(BIGSERIAL|SMALLSERIAL|SERIAL|BIGINT|SMALLINT|MEDIUMINT|TINYINT|INTEGER|INT|DOUBLE\s+PRECISION|DOUBLE|FLOAT|DECIMAL|NUMERIC|CHARACTER\s+VARYING|NVARCHAR|NCHAR|VARCHAR|CHAR|VARBINARY|BINARY|LONGTEXT|MEDIUMTEXT|TINYTEXT|NTEXT|TEXT|LONGBLOB|MEDIUMBLOB|TINYBLOB|IMAGE|BLOB|DATETIME2|DATETIMEOFFSET|SMALLDATETIME|DATETIME|TIMESTAMPTZ|TIMESTAMP|TIMETZ|TIME|DATE|YEAR|ENUM|SET|JSONB|JSON|UUID|UNIQUEIDENTIFIER|BYTEA|BOOLEAN|BOOL|BIT|REAL|MONEY|SMALLMONEY|INTERVAL|ROWVERSION|XML|SQL_VARIANT)\b(\s*\([^)]+\))?(\s+(?:UNSIGNED|WITH(?:OUT)?\s+TIME\s+ZONE))?"#).unwrap()
675 });
676
677 RE_COLUMN_TYPE
678 .replace_all(stmt, |caps: ®ex::Captures| {
679 let full_match = caps.get(0).unwrap().as_str();
680 let leading_char = caps.get(1).unwrap().as_str();
681 let type_part = caps.get(2).unwrap().as_str();
682 let size_part = caps.get(3).map(|m| m.as_str()).unwrap_or("");
683 let suffix = caps.get(4).map(|m| m.as_str()).unwrap_or("");
684
685 let end_pos = caps.get(0).unwrap().end();
688 let stmt_bytes = stmt.as_bytes();
689 if end_pos < stmt_bytes.len() {
690 let next_char = stmt_bytes[end_pos] as char;
691 if next_char == '"' || next_char == '\'' || next_char == '`' {
693 return full_match.to_string();
694 }
695 }
696
697 let ws_len = full_match.len() - leading_char.len() - type_part.len() - size_part.len() - suffix.len();
699 let ws = &full_match[leading_char.len()..leading_char.len() + ws_len];
700
701 let converted = TypeConverter::convert(&format!("{}{}{}", type_part, size_part, suffix));
702 format!("{}{}{}", leading_char, ws, converted)
703 })
704 .to_string()
705 }
706
    /// Removes MySQL-specific clauses from a `CREATE TABLE` statement.
    ///
    /// Each step below deletes one kind of clause; the steps run in the order
    /// written and each operates on the output of the previous one.
    fn strip_mysql_clauses(stmt: &str) -> String {
        let mut result = stmt.to_string();

        // Table option `ENGINE=<name>`.
        static RE_ENGINE: Lazy<Regex> =
            Lazy::new(|| Regex::new(r"(?i)\s*ENGINE\s*=\s*\w+").unwrap());
        result = RE_ENGINE.replace_all(&result, "").to_string();

        // Table option `AUTO_INCREMENT=<n>`.
        static RE_AUTO_INC: Lazy<Regex> =
            Lazy::new(|| Regex::new(r"(?i)\s*AUTO_INCREMENT\s*=\s*\d+").unwrap());
        result = RE_AUTO_INC.replace_all(&result, "").to_string();

        // Column attribute `AUTO_INCREMENT` (upper- and lower-case spellings only).
        result = result.replace(" AUTO_INCREMENT", "");
        result = result.replace(" auto_increment", "");

        // `CHARACTER SET <name>`.
        static RE_CHAR_SET: Lazy<Regex> =
            Lazy::new(|| Regex::new(r"(?i)\s*CHARACTER\s+SET\s+\w+").unwrap());
        result = RE_CHAR_SET.replace_all(&result, "").to_string();

        // `[DEFAULT] CHARSET=<name>`.
        static RE_CHARSET: Lazy<Regex> =
            Lazy::new(|| Regex::new(r"(?i)\s*(DEFAULT\s+)?CHARSET\s*=\s*\w+").unwrap());
        result = RE_CHARSET.replace_all(&result, "").to_string();

        // `COLLATE <name>` / `COLLATE=<name>`.
        static RE_COLLATE: Lazy<Regex> =
            Lazy::new(|| Regex::new(r"(?i)\s*COLLATE\s*=?\s*\w+").unwrap());
        result = RE_COLLATE.replace_all(&result, "").to_string();

        // Table option `ROW_FORMAT=<name>`.
        static RE_ROW_FORMAT: Lazy<Regex> =
            Lazy::new(|| Regex::new(r"(?i)\s*ROW_FORMAT\s*=\s*\w+").unwrap());
        result = RE_ROW_FORMAT.replace_all(&result, "").to_string();

        // Table option `KEY_BLOCK_SIZE=<n>`.
        static RE_KEY_BLOCK: Lazy<Regex> =
            Lazy::new(|| Regex::new(r"(?i)\s*KEY_BLOCK_SIZE\s*=\s*\d+").unwrap());
        result = RE_KEY_BLOCK.replace_all(&result, "").to_string();

        // `COMMENT '<text>'` / `COMMENT='<text>'`; the text may not contain a quote.
        static RE_COMMENT: Lazy<Regex> =
            Lazy::new(|| Regex::new(r"(?i)\s*COMMENT\s*=?\s*'[^']*'").unwrap());
        result = RE_COMMENT.replace_all(&result, "").to_string();

        // Conditional-comment markers `/*!NNNNN` and every `*/`; the text
        // between them is kept.
        static RE_COND_COMMENT: Lazy<Regex> = Lazy::new(|| Regex::new(r"/\*!\d+\s*|\*/").unwrap());
        result = RE_COND_COMMENT.replace_all(&result, "").to_string();

        // `ON UPDATE CURRENT_TIMESTAMP`.
        static RE_ON_UPDATE: Lazy<Regex> =
            Lazy::new(|| Regex::new(r"(?i)\s*ON\s+UPDATE\s+CURRENT_TIMESTAMP").unwrap());
        result = RE_ON_UPDATE.replace_all(&result, "").to_string();

        // `UNIQUE KEY <name> (<cols>)`, with an optional leading comma.
        static RE_UNIQUE_KEY: Lazy<Regex> = Lazy::new(|| {
            Regex::new(r#"(?i),?\s*UNIQUE\s+KEY\s+[`"']?\w+[`"']?\s*\([^)]+\)"#).unwrap()
        });
        result = RE_UNIQUE_KEY.replace_all(&result, "").to_string();

        // `[FULLTEXT|SPATIAL] KEY <name> (<cols>)` after a comma or a newline.
        static RE_KEY_INDEX: Lazy<Regex> = Lazy::new(|| {
            Regex::new(r#"(?i)(?:,\s*|\n\s*)(?:FULLTEXT\s+|SPATIAL\s+)?KEY\s+[`"']?\w+[`"']?\s*\([^)]+\)"#).unwrap()
        });
        result = RE_KEY_INDEX.replace_all(&result, "").to_string();

        // Whole generated columns: `<name> <type> GENERATED ALWAYS AS (<expr>) [STORED|VIRTUAL]`.
        // The expression may contain one level of nested parentheses.
        static RE_GENERATED_COL: Lazy<Regex> = Lazy::new(|| {
            Regex::new(r#"(?i),?\s*[`"']?\w+[`"']?\s+\w+\s+GENERATED\s+ALWAYS\s+AS\s*\((?:[^()]+|\([^()]*\))+\)\s*(?:STORED|VIRTUAL)?"#).unwrap()
        });
        result = RE_GENERATED_COL.replace_all(&result, "").to_string();

        // `[CONSTRAINT <name>] FOREIGN KEY (...) REFERENCES <table> (...)` with any
        // trailing `ON DELETE` / `ON UPDATE` actions.
        static RE_FK_CONSTRAINT: Lazy<Regex> = Lazy::new(|| {
            Regex::new(r#"(?i),?\s*(?:CONSTRAINT\s+[`"']?\w+[`"']?\s+)?FOREIGN\s+KEY\s*\([^)]+\)\s*REFERENCES\s+[`"']?\w+[`"']?\s*\([^)]+\)(?:\s+ON\s+(?:DELETE|UPDATE)\s+(?:CASCADE|SET\s+NULL|SET\s+DEFAULT|NO\s+ACTION|RESTRICT))*"#).unwrap()
        });
        result = RE_FK_CONSTRAINT.replace_all(&result, "").to_string();

        result
    }
797
    /// Removes or rewrites PostgreSQL-specific syntax in a `CREATE TABLE`
    /// statement. The steps run in the order written.
    fn strip_postgres_syntax(stmt: &str) -> String {
        let mut result = stmt.to_string();

        // Drop `public.`, `pg_catalog.` and `pg_temp.` qualifiers.
        result = Self::strip_schema_prefix(&result);

        // Drop `::type` casts, including multi-word type names.
        static RE_CAST: Lazy<Regex> = Lazy::new(|| {
            Regex::new(r"::[a-zA-Z_][a-zA-Z0-9_]*(?:\s+[a-zA-Z_][a-zA-Z0-9_]*)*").unwrap()
        });
        result = RE_CAST.replace_all(&result, "").to_string();

        // Drop `DEFAULT nextval(...)` sequence defaults.
        static RE_NEXTVAL: Lazy<Regex> =
            Lazy::new(|| Regex::new(r"(?i)\s*DEFAULT\s+nextval\s*\([^)]+\)").unwrap());
        result = RE_NEXTVAL.replace_all(&result, "").to_string();

        // Rewrite `DEFAULT now()` as `DEFAULT CURRENT_TIMESTAMP`.
        static RE_NOW: Lazy<Regex> =
            Lazy::new(|| Regex::new(r"(?i)\bDEFAULT\s+now\s*\(\s*\)").unwrap());
        result = RE_NOW
            .replace_all(&result, "DEFAULT CURRENT_TIMESTAMP")
            .to_string();

        // Drop `INHERITS (...)`.
        static RE_INHERITS: Lazy<Regex> =
            Lazy::new(|| Regex::new(r"(?i)\s*INHERITS\s*\([^)]+\)").unwrap());
        result = RE_INHERITS.replace_all(&result, "").to_string();

        // Drop `WITH (...)` option lists.
        static RE_WITH: Lazy<Regex> = Lazy::new(|| Regex::new(r"(?i)\s*WITH\s*\([^)]+\)").unwrap());
        result = RE_WITH.replace_all(&result, "").to_string();

        result
    }
834
835 fn strip_schema_prefix(stmt: &str) -> String {
837 static RE_SCHEMA: Lazy<Regex> =
838 Lazy::new(|| Regex::new(r#"(?i)\b(public|pg_catalog|pg_temp)\s*\.\s*"#).unwrap());
839 RE_SCHEMA.replace_all(stmt, "").to_string()
840 }
841
842 fn strip_mssql_schema_prefix(stmt: &str) -> String {
844 static RE_SCHEMA: Lazy<Regex> =
846 Lazy::new(|| Regex::new(r#"(?i)"?(dbo|master|tempdb|model|msdb)"?\s*\.\s*"#).unwrap());
847 RE_SCHEMA.replace_all(stmt, "").to_string()
848 }
849
    /// Removes or rewrites SQL Server-specific syntax in a `CREATE TABLE`
    /// statement. The steps run in the order written; the patterns expect
    /// identifiers to be double-quoted already.
    fn strip_mssql_syntax(stmt: &str) -> String {
        // Drop `dbo.` and similar schema/database qualifiers first.
        let mut result = Self::strip_mssql_schema_prefix(stmt);

        // Drop `IDENTITY(seed, increment) NOT NULL` as a unit; this also
        // removes the `NOT NULL` that follows the identity clause.
        static RE_IDENTITY_NOT_NULL: Lazy<Regex> = Lazy::new(|| {
            Regex::new(r"(?i)\s*IDENTITY\s*\(\s*\d+\s*,\s*\d+\s*\)\s*NOT\s+NULL").unwrap()
        });
        result = RE_IDENTITY_NOT_NULL.replace_all(&result, "").to_string();

        // Drop any remaining `IDENTITY(seed, increment)`.
        static RE_IDENTITY: Lazy<Regex> =
            Lazy::new(|| Regex::new(r"(?i)\s*IDENTITY\s*\(\s*\d+\s*,\s*\d+\s*\)").unwrap());
        result = RE_IDENTITY.replace_all(&result, "").to_string();

        // Replace `CLUSTERED` / `NONCLUSTERED` and surrounding whitespace with one space.
        static RE_CLUSTERED: Lazy<Regex> =
            Lazy::new(|| Regex::new(r"(?i)\s*(?:NON)?CLUSTERED\s*").unwrap());
        result = RE_CLUSTERED.replace_all(&result, " ").to_string();

        // Drop the filegroup clause `ON "PRIMARY"`.
        static RE_FILEGROUP: Lazy<Regex> =
            Lazy::new(|| Regex::new(r#"(?i)\s*ON\s*"?PRIMARY"?"#).unwrap());
        result = RE_FILEGROUP.replace_all(&result, "").to_string();

        // Drop named primary-key constraints.
        static RE_PK_CONSTRAINT: Lazy<Regex> = Lazy::new(|| {
            Regex::new(r#"(?i),?\s*CONSTRAINT\s+"?\w+"?\s+PRIMARY\s+KEY\s+\([^)]+\)"#).unwrap()
        });
        result = RE_PK_CONSTRAINT.replace_all(&result, "").to_string();

        // Drop named foreign-key constraints.
        static RE_FK_CONSTRAINT: Lazy<Regex> = Lazy::new(|| {
            Regex::new(r#"(?i),?\s*CONSTRAINT\s+"?\w+"?\s+FOREIGN\s+KEY\s*\([^)]+\)\s*REFERENCES\s+[^\s(]+\s*\([^)]+\)"#).unwrap()
        });
        result = RE_FK_CONSTRAINT.replace_all(&result, "").to_string();

        // Drop `WITH (...)` option lists.
        static RE_WITH: Lazy<Regex> = Lazy::new(|| Regex::new(r"(?i)\s*WITH\s*\([^)]+\)").unwrap());
        result = RE_WITH.replace_all(&result, "").to_string();

        // Drop `TEXTIMAGE_ON <name>`.
        static RE_TEXTIMAGE: Lazy<Regex> =
            Lazy::new(|| Regex::new(r#"(?i)\s*TEXTIMAGE_ON\s*"?\w+"?"#).unwrap());
        result = RE_TEXTIMAGE.replace_all(&result, "").to_string();

        // Rewrite `GETDATE()` as `CURRENT_TIMESTAMP`.
        static RE_GETDATE: Lazy<Regex> =
            Lazy::new(|| Regex::new(r"(?i)\bGETDATE\s*\(\s*\)").unwrap());
        result = RE_GETDATE
            .replace_all(&result, "CURRENT_TIMESTAMP")
            .to_string();

        // Rewrite `NEWID()` as `gen_random_uuid()`.
        static RE_NEWID: Lazy<Regex> = Lazy::new(|| Regex::new(r"(?i)\bNEWID\s*\(\s*\)").unwrap());
        result = RE_NEWID
            .replace_all(&result, "gen_random_uuid()")
            .to_string();

        result
    }
912
913 fn count_insert_rows(sql: &str) -> u64 {
915 if let Some(values_pos) = sql.to_uppercase().find("VALUES") {
917 let after_values = &sql[values_pos + 6..];
918 let mut count = 0u64;
920 let mut depth = 0;
921 let mut in_string = false;
922 let mut prev_char = ' ';
923
924 for c in after_values.chars() {
925 if c == '\'' && prev_char != '\\' {
926 in_string = !in_string;
927 }
928 if !in_string {
929 if c == '(' {
930 if depth == 0 {
931 count += 1;
932 }
933 depth += 1;
934 } else if c == ')' {
935 depth -= 1;
936 }
937 }
938 prev_char = c;
939 }
940 count
941 } else {
942 1
943 }
944 }
945}
946
/// Streaming reader that splits buffered input into SQL statements or, in
/// copy mode, into single lines.
struct StatementReader<R> {
    /// Underlying buffered input.
    reader: BufReader<R>,
    /// Dialect that selects the quoting and separator rules used when splitting.
    dialect: SqlDialect,
    /// Text read from `reader` so far that has not been discarded by compaction.
    buffer: String,
    /// Byte offset into `buffer` of the first unconsumed byte.
    buffer_pos: usize,
    /// Set once `reader` has reported end of input.
    eof: bool,
    /// When `true`, `next_statement` returns one line per call.
    in_copy_mode: bool,
}
958
959impl<R: Read> StatementReader<R> {
960 fn new(reader: BufReader<R>, dialect: SqlDialect) -> Self {
961 Self {
962 reader,
963 dialect,
964 buffer: String::new(),
965 buffer_pos: 0,
966 eof: false,
967 in_copy_mode: false,
968 }
969 }
970
971 fn compact_buffer(&mut self) {
974 if self.buffer_pos > 0 {
975 self.buffer.drain(..self.buffer_pos);
976 self.buffer_pos = 0;
977 }
978 }
979
980 fn remaining_buffer(&self) -> &str {
982 &self.buffer[self.buffer_pos..]
983 }
984
    /// Turns copy mode on or off. In copy mode `next_statement` returns the
    /// input one line at a time instead of splitting it into SQL statements.
    fn set_copy_mode(&mut self, enabled: bool) {
        self.in_copy_mode = enabled;
    }
989
990 fn strip_leading_sql_comments(s: &str) -> &str {
992 let mut result = s.trim();
993 loop {
994 if result.starts_with("--") {
996 if let Some(pos) = result.find('\n') {
997 result = result[pos + 1..].trim();
998 continue;
999 } else {
1000 return ""; }
1002 }
1003 if result.starts_with("/*") {
1005 if let Some(pos) = result.find("*/") {
1006 result = result[pos + 2..].trim();
1007 continue;
1008 } else {
1009 return ""; }
1011 }
1012 break;
1013 }
1014 result
1015 }
1016
1017 fn next_statement(&mut self) -> Option<Result<String>> {
1018 if self.eof && self.remaining_buffer().is_empty() {
1019 return None;
1020 }
1021
1022 loop {
1023 if self.in_copy_mode {
1025 if let Some(line) = self.extract_copy_line() {
1026 return Some(Ok(line));
1027 }
1028 } else {
1029 if let Some(stmt) = self.extract_statement() {
1031 return Some(Ok(stmt));
1033 }
1034 }
1035
1036 if self.buffer_pos > 64 * 1024 {
1039 self.compact_buffer();
1040 }
1041
1042 let mut line = String::new();
1044 match self.reader.read_line(&mut line) {
1045 Ok(0) => {
1046 self.eof = true;
1047 self.in_copy_mode = false; let remaining = self.remaining_buffer().trim();
1049 if !remaining.is_empty() {
1050 let stmt = remaining.to_string();
1051 self.buffer.clear();
1052 self.buffer_pos = 0;
1053 return Some(Ok(stmt));
1054 }
1055 return None;
1056 }
1057 Ok(_) => {
1058 self.buffer.push_str(&line);
1059 }
1060 Err(e) => return Some(Err(e.into())),
1061 }
1062 }
1063 }
1064
1065 fn extract_copy_line(&mut self) -> Option<String> {
1067 let remaining = self.remaining_buffer();
1068 if let Some(newline_pos) = remaining.find('\n') {
1069 let line = remaining[..newline_pos].to_string();
1070 self.buffer_pos += newline_pos + 1;
1071 Some(line)
1073 } else {
1074 None
1075 }
1076 }
1077
    /// Tries to cut one complete statement off the front of the buffer.
    ///
    /// Returns `None` when the buffered text does not contain a complete
    /// statement yet. For MSSQL a `GO` line ends a statement; otherwise (and
    /// for MSSQL text without a `GO` line) a statement ends at the first `;`
    /// outside string literals, dollar quotes, brackets and comments.
    /// Statements made up only of comments are skipped. For Postgres, emitting
    /// a `COPY ... FROM STDIN;` statement switches the reader into copy mode.
    fn extract_statement(&mut self) -> Option<String> {
        let remaining = self.remaining_buffer();
        let mut in_string = false;
        let mut in_dollar_quote = false;
        let mut in_bracket = false;
        let mut in_line_comment = false;
        let mut in_block_comment = false;
        let mut escape_next = false;
        let mut chars = remaining.char_indices().peekable();
        let mut end_pos = None;

        if self.dialect == SqlDialect::Mssql {
            if let Some(go_pos) = self.find_go_separator() {
                let stmt = remaining[..go_pos].to_string();
                let after_go = &remaining[go_pos..];
                // Consume the whole `GO` line, or everything buffered when the
                // line has no newline yet.
                if let Some(line_end) = after_go.find('\n') {
                    self.buffer_pos += go_pos + line_end + 1;
                } else {
                    self.buffer_pos = self.buffer.len();
                }

                // Skip empty batches and batches that start with a comment
                // (unless they contain a `/*!` conditional comment).
                let trimmed = stmt.trim();
                if trimmed.is_empty()
                    || trimmed.starts_with("--")
                    || (trimmed.starts_with("/*") && !trimmed.contains("/*!"))
                {
                    return self.extract_statement();
                }
                return Some(stmt);
            }
        }

        while let Some((i, c)) = chars.next() {
            // Character after a MySQL backslash escape: never significant.
            if escape_next {
                escape_next = false;
                continue;
            }

            if in_line_comment {
                if c == '\n' {
                    in_line_comment = false;
                }
                continue;
            }

            if in_block_comment {
                if c == '*' && chars.peek().map(|(_, c)| *c == '/').unwrap_or(false) {
                    chars.next();
                    in_block_comment = false;
                }
                continue;
            }

            match c {
                '\\' if self.dialect == SqlDialect::MySql && in_string => {
                    escape_next = true;
                }
                '\'' if !in_dollar_quote && !in_bracket => {
                    in_string = !in_string;
                }
                '[' if self.dialect == SqlDialect::Mssql && !in_string => {
                    in_bracket = true;
                }
                ']' if self.dialect == SqlDialect::Mssql && !in_string => {
                    // `]]` is an escaped bracket and does not close the name.
                    if chars.peek().map(|(_, c)| *c == ']').unwrap_or(false) {
                        chars.next();
                    } else {
                        in_bracket = false;
                    }
                }
                '$' if self.dialect == SqlDialect::Postgres && !in_string => {
                    // Only the untagged `$$` delimiter toggles dollar quoting.
                    if chars.peek().map(|(_, c)| *c == '$').unwrap_or(false) {
                        in_dollar_quote = !in_dollar_quote;
                        chars.next();
                    }
                }
                '-' if !in_string && !in_dollar_quote && !in_bracket => {
                    if chars.peek().map(|(_, c)| *c == '-').unwrap_or(false) {
                        chars.next();
                        in_line_comment = true;
                    }
                }
                '/' if !in_string && !in_dollar_quote && !in_bracket => {
                    if chars.peek().map(|(_, c)| *c == '*').unwrap_or(false) {
                        chars.next();
                        in_block_comment = true;
                    }
                }
                ';' if !in_string && !in_dollar_quote && !in_bracket => {
                    // Statement ends just after the semicolon.
                    end_pos = Some(i + 1);
                    break;
                }
                _ => {}
            }
        }

        if let Some(pos) = end_pos {
            let stmt = remaining[..pos].to_string();
            let after_stmt = &remaining[pos..];
            // Also consume the whitespace that follows the statement.
            let trimmed_len = after_stmt.len() - after_stmt.trim_start().len();
            self.buffer_pos += pos + trimmed_len;

            let trimmed = stmt.trim();

            // Nothing but comments: move on to the next statement.
            let stripped = Self::strip_leading_sql_comments(trimmed);
            if stripped.is_empty() {
                return self.extract_statement();
            }

            let trimmed = stripped;

            if self.dialect == SqlDialect::Postgres {
                // The lines after `COPY ... FROM STDIN;` are data rows, so
                // switch to line-by-line reading.
                let upper = trimmed.to_uppercase();
                if upper.ends_with("FROM STDIN;") && upper.contains("COPY ") {
                    self.in_copy_mode = true;
                }
            }

            // The statement is returned with its leading comments intact.
            Some(stmt)
        } else {
            None
        }
    }
1215
1216 fn find_go_separator(&self) -> Option<usize> {
1218 let remaining = self.remaining_buffer();
1219 let mut in_string = false;
1220 let mut in_bracket = false;
1221 let mut line_start = 0;
1222
1223 for (i, c) in remaining.char_indices() {
1224 if c == '\'' && !in_bracket {
1225 in_string = !in_string;
1226 } else if c == '[' && !in_string {
1227 in_bracket = true;
1228 } else if c == ']' && !in_string {
1229 in_bracket = false;
1230 } else if c == '\n' {
1231 line_start = i + 1;
1232 } else if !in_string && !in_bracket && i == line_start {
1233 let rest = &remaining[i..];
1235 if rest.len() >= 2 {
1236 let word = &rest[..2.min(rest.len())];
1237 if word.eq_ignore_ascii_case("GO") {
1238 let after_go = if rest.len() > 2 {
1240 rest.chars().nth(2)
1241 } else {
1242 None
1243 };
1244 if after_go.is_none()
1245 || after_go == Some('\n')
1246 || after_go == Some('\r')
1247 || after_go == Some(' ')
1248 || after_go.unwrap().is_ascii_digit()
1249 {
1250 return Some(i);
1251 }
1252 }
1253 }
1254 }
1255 }
1256 None
1257 }
1258}
1259
/// Unit tests for the pure conversion helpers of `DumpLoader`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Row counting: single row, several rows, and parentheses inside a string.
    #[test]
    fn test_count_insert_rows() {
        assert_eq!(
            DumpLoader::count_insert_rows("INSERT INTO t VALUES (1, 'a')"),
            1
        );
        assert_eq!(
            DumpLoader::count_insert_rows("INSERT INTO t VALUES (1, 'a'), (2, 'b'), (3, 'c')"),
            3
        );
        assert_eq!(
            DumpLoader::count_insert_rows("INSERT INTO t VALUES (1, '(test)')"),
            1
        );
    }

    /// `ENGINE` and `CHARSET` table options are removed.
    #[test]
    fn test_strip_mysql_clauses() {
        let sql = "CREATE TABLE t (id INT) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4";
        let result = DumpLoader::strip_mysql_clauses(sql);
        assert!(!result.contains("ENGINE"));
        assert!(!result.contains("CHARSET"));
    }

    /// MySQL backtick identifiers become double-quoted identifiers.
    #[test]
    fn test_convert_identifiers() {
        let sql = "INSERT INTO `users` (`id`, `name`) VALUES (1, 'test')";
        let result = DumpLoader::convert_identifiers(sql, SqlDialect::MySql);
        assert_eq!(
            result,
            "INSERT INTO \"users\" (\"id\", \"name\") VALUES (1, 'test')"
        );
    }

    /// Tab-separated rows are data; lines starting with SQL keywords are not.
    #[test]
    fn test_looks_like_copy_data() {
        assert!(DumpLoader::looks_like_copy_data("1\tAlice\t2024-01-01"));
        assert!(!DumpLoader::looks_like_copy_data("SELECT * FROM users"));
        assert!(!DumpLoader::looks_like_copy_data("INSERT INTO t VALUES"));
    }

    /// `IDENTITY(...)` clauses are removed from MSSQL table definitions.
    #[test]
    fn test_strip_mssql_syntax() {
        // Input without IDENTITY: the output must not contain it either.
        let sql = r#"CREATE TABLE "users" (
 "id" INTEGER NOT NULL,
 "email" VARCHAR(255) NOT NULL
)"#;
        let result = DumpLoader::strip_mssql_syntax(sql);
        assert!(!result.contains("IDENTITY"), "IDENTITY should be stripped");

        // Input with IDENTITY followed by NOT NULL.
        let sql_with_identity = r#"CREATE TABLE "users" (
 "id" INTEGER IDENTITY(1,1) NOT NULL,
 "email" VARCHAR(255) NOT NULL
)"#;
        let result2 = DumpLoader::strip_mssql_syntax(sql_with_identity);
        assert!(
            !result2.contains("IDENTITY"),
            "IDENTITY should be stripped: {}",
            result2
        );
        assert!(
            !result2.contains("IDENTITY(1,1) NOT NULL"),
            "Should strip full IDENTITY NOT NULL"
        );
    }

    /// MSSQL brackets become double quotes and the `N` string prefix is dropped.
    #[test]
    fn test_convert_mssql_identifiers() {
        let sql = "INSERT INTO [dbo].[users] ([id], [name]) VALUES (1, N'test')";
        let result = DumpLoader::convert_identifiers(sql, SqlDialect::Mssql);
        assert_eq!(
            result,
            "INSERT INTO \"dbo\".\"users\" (\"id\", \"name\") VALUES (1, 'test')"
        );
    }
}