1use super::batch::{flush_batch, BatchManager, MAX_ROWS_PER_BATCH};
4use super::types::TypeConverter;
5use super::{ImportStats, QueryConfig};
6use crate::convert::copy_to_insert::{copy_to_inserts, parse_copy_header, CopyHeader};
7use crate::parser::{
8 detect_dialect_from_file, parse_insert_for_bulk, Parser, SqlDialect, StatementType,
9};
10use crate::progress::ProgressReader;
11use crate::splitter::Compression;
12use anyhow::{Context, Result};
13use duckdb::Connection;
14use indicatif::{ProgressBar, ProgressStyle};
15use once_cell::sync::Lazy;
16use regex::Regex;
17use std::fs::File;
18use std::io::{BufRead, BufReader, Read};
19use std::path::Path;
20
21const MAX_COPY_ROWS_PER_BATCH: usize = 10_000;
24
25pub struct DumpLoader<'a> {
27 conn: &'a Connection,
28 config: &'a QueryConfig,
29}
30
31impl<'a> DumpLoader<'a> {
32 pub fn new(conn: &'a Connection, config: &'a QueryConfig) -> Self {
34 Self { conn, config }
35 }
36
37 pub fn load(&self, dump_path: &Path) -> Result<ImportStats> {
39 let start = std::time::Instant::now();
40 let mut stats = ImportStats::default();
41
42 let dialect = if let Some(d) = self.config.dialect {
44 d
45 } else {
46 let result = detect_dialect_from_file(dump_path)?;
47 result.dialect
48 };
49
50 let file_size = std::fs::metadata(dump_path)?.len();
52
53 let progress_bar = if self.config.progress {
55 let pb = ProgressBar::new(file_size);
56 pb.set_style(
57 ProgressStyle::default_bar()
58 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({percent}%)")
59 .unwrap()
60 .progress_chars("=>-"),
61 );
62 Some(pb)
63 } else {
64 None
65 };
66
67 let file = File::open(dump_path).context("Failed to open dump file")?;
69 let compression = Compression::from_path(dump_path);
70 let reader: Box<dyn Read> = match compression {
71 Compression::Gzip => Box::new(flate2::read::GzDecoder::new(file)),
72 Compression::Bzip2 => Box::new(bzip2::read::BzDecoder::new(file)),
73 Compression::Xz => Box::new(xz2::read::XzDecoder::new(file)),
74 Compression::Zstd => Box::new(zstd::stream::Decoder::new(file)?),
75 Compression::None => Box::new(file),
76 };
77
78 let reader: Box<dyn Read> = if let Some(ref pb) = progress_bar {
79 let pb_clone = pb.clone();
80 Box::new(ProgressReader::new(reader, move |bytes| {
81 pb_clone.set_position(bytes);
82 }))
83 } else {
84 reader
85 };
86
87 let buf_reader = BufReader::with_capacity(256 * 1024, reader);
88
89 self.load_statements(buf_reader, dialect, &mut stats)?;
91
92 if let Some(pb) = progress_bar {
93 pb.finish_with_message("Import complete");
94 }
95
96 stats.duration_secs = start.elapsed().as_secs_f64();
97 Ok(stats)
98 }
99
100 fn load_statements<R: Read>(
102 &self,
103 reader: BufReader<R>,
104 dialect: SqlDialect,
105 stats: &mut ImportStats,
106 ) -> Result<()> {
107 let mut parser = StatementReader::new(reader, dialect);
108 let mut pending_copy: Option<CopyHeader> = None;
109
110 let mut copy_batch_data: Vec<u8> = Vec::new();
112 let mut copy_batch_rows: usize = 0;
113
114 let mut failed_tables: std::collections::HashSet<String> = std::collections::HashSet::new();
116
117 let mut batch_mgr = BatchManager::new(MAX_ROWS_PER_BATCH);
119
120 while let Some(stmt_result) = parser.next_statement() {
121 let stmt = stmt_result?;
122
123 if let Some(ref header) = pending_copy {
125 let trimmed = stmt.trim();
126
127 if trimmed == "\\." {
129 if !copy_batch_data.is_empty() {
131 self.process_copy_batch(
132 header,
133 ©_batch_data,
134 stats,
135 &mut failed_tables,
136 );
137 copy_batch_data.clear();
138 copy_batch_rows = 0;
139 }
140 pending_copy = None;
141 parser.set_copy_mode(false);
142 continue;
143 }
144
145 if trimmed.is_empty() {
147 continue;
148 }
149
150 if Self::looks_like_copy_data(&stmt) {
152 if failed_tables.contains(&header.table) {
154 continue;
155 }
156
157 copy_batch_data.extend_from_slice(stmt.as_bytes());
158 copy_batch_data.push(b'\n');
159 copy_batch_rows += 1;
160
161 if copy_batch_rows >= MAX_COPY_ROWS_PER_BATCH {
163 self.process_copy_batch(
164 header,
165 ©_batch_data,
166 stats,
167 &mut failed_tables,
168 );
169 copy_batch_data.clear();
170 copy_batch_rows = 0;
171 }
172 continue;
173 }
174
175 if !copy_batch_data.is_empty() {
177 self.process_copy_batch(header, ©_batch_data, stats, &mut failed_tables);
178 copy_batch_data.clear();
179 copy_batch_rows = 0;
180 }
181 pending_copy = None;
182 parser.set_copy_mode(false);
183 }
185
186 let (mut stmt_type, table_name) =
187 Parser::<&[u8]>::parse_statement_with_dialect(stmt.as_bytes(), dialect);
188
189 if dialect == SqlDialect::Postgres && stmt_type == StatementType::Unknown {
191 let upper = stmt.to_uppercase();
192 if let Some(copy_pos) = upper.find("COPY ") {
193 let after_copy = &upper[copy_pos..];
194 if after_copy.contains("FROM STDIN") {
195 stmt_type = StatementType::Copy;
196 }
197 }
198 }
199
200 if let Some(ref tables) = self.config.tables {
202 if !table_name.is_empty()
203 && !tables.iter().any(|t| t.eq_ignore_ascii_case(&table_name))
204 {
205 continue;
206 }
207 }
208
209 match stmt_type {
210 StatementType::CreateTable => {
211 let duckdb_sql = self.convert_create_table(&stmt, dialect)?;
212 match self.conn.execute(&duckdb_sql, []) {
213 Ok(_) => stats.tables_created += 1,
214 Err(e) => {
215 stats
216 .warnings
217 .push(format!("Failed to create table {}: {}", table_name, e));
218 stats.statements_skipped += 1;
219 }
220 }
221 }
222 StatementType::Insert => {
223 if !self.try_queue_for_bulk(
225 &stmt,
226 dialect,
227 &mut batch_mgr,
228 stats,
229 &mut failed_tables,
230 ) {
231 let duckdb_sql = self.convert_insert(&stmt, dialect)?;
233 match self.conn.execute(&duckdb_sql, []) {
234 Ok(_) => {
235 stats.insert_statements += 1;
236 stats.rows_inserted += Self::count_insert_rows(&duckdb_sql);
237 }
238 Err(e) => {
239 stats
240 .warnings
241 .push(format!("Failed INSERT for {}: {}", table_name, e));
242 stats.statements_skipped += 1;
243 }
244 }
245 }
246
247 for mut batch in batch_mgr.get_ready_batches() {
249 flush_batch(self.conn, &mut batch, stats, &mut failed_tables)?;
250 }
251 }
252 StatementType::Copy => {
253 if let Some(header) = parse_copy_header(&stmt) {
255 if failed_tables.contains(&header.table) {
257 parser.set_copy_mode(true);
259 Self::skip_copy_block(&mut parser);
260 parser.set_copy_mode(false);
261 continue;
262 }
263
264 if !self.table_exists(&header.table) {
266 failed_tables.insert(header.table.clone());
267 if stats.warnings.len() < 100 {
268 stats.warnings.push(format!(
269 "Skipping COPY for non-existent table {}",
270 header.table
271 ));
272 }
273 parser.set_copy_mode(true);
275 Self::skip_copy_block(&mut parser);
276 parser.set_copy_mode(false);
277 continue;
278 }
279
280 copy_batch_data.clear();
281 copy_batch_rows = 0;
282 pending_copy = Some(header);
283 parser.set_copy_mode(true);
284 }
285 }
286 StatementType::CreateIndex => {
287 stats.statements_skipped += 1;
289 }
290 _ => {
291 stats.statements_skipped += 1;
293 }
294 }
295 }
296
297 if let Some(ref header) = pending_copy {
299 if !copy_batch_data.is_empty() {
300 self.process_copy_batch(header, ©_batch_data, stats, &mut failed_tables);
301 }
302 }
303
304 for mut batch in batch_mgr.drain_all() {
306 flush_batch(self.conn, &mut batch, stats, &mut failed_tables)?;
307 }
308
309 Ok(())
310 }
311
312 fn process_copy_batch(
314 &self,
315 header: &CopyHeader,
316 batch_data: &[u8],
317 stats: &mut ImportStats,
318 failed_tables: &mut std::collections::HashSet<String>,
319 ) {
320 if batch_data.is_empty() {
321 return;
322 }
323
324 if failed_tables.contains(&header.table) {
326 return;
327 }
328
329 let inserts = copy_to_inserts(header, batch_data, SqlDialect::Postgres);
330 for insert in inserts {
331 let insert_sql = String::from_utf8_lossy(&insert);
332 match self.conn.execute(&insert_sql, []) {
333 Ok(_) => {
334 stats.rows_inserted += Self::count_insert_rows(&insert_sql);
335 }
336 Err(e) => {
337 let err_str = e.to_string();
338 if err_str.contains("does not exist") {
340 failed_tables.insert(header.table.clone());
341 if stats.warnings.len() < 100 {
342 stats.warnings.push(format!(
343 "Table {} does not exist, skipping COPY data",
344 header.table
345 ));
346 }
347 return; }
349 if stats.warnings.len() < 100 {
351 stats.warnings.push(format!(
352 "Failed to insert COPY data for {}: {}",
353 header.table, e
354 ));
355 }
356 stats.statements_skipped += 1;
357 }
358 }
359 }
360 }
361
362 fn table_exists(&self, table: &str) -> bool {
364 let query = "SELECT 1 FROM information_schema.tables WHERE table_name = ? LIMIT 1";
365 match self.conn.prepare(query) {
366 Ok(mut stmt) => stmt.exists([table]).unwrap_or(false),
367 Err(_) => false,
368 }
369 }
370
371 fn try_queue_for_bulk(
374 &self,
375 stmt: &str,
376 dialect: SqlDialect,
377 batch_mgr: &mut BatchManager,
378 stats: &mut ImportStats,
379 failed_tables: &mut std::collections::HashSet<String>,
380 ) -> bool {
381 let upper = stmt.to_uppercase();
383 if upper.contains("ON DUPLICATE KEY")
384 || upper.contains("ON CONFLICT")
385 || upper.contains("REPLACE")
386 || upper.contains("IGNORE")
387 || upper.contains("RETURNING")
388 || upper.contains("SELECT")
389 {
390 return false;
391 }
392
393 let parsed = match parse_insert_for_bulk(stmt.as_bytes()) {
395 Ok(p) => p,
396 Err(_) => return false, };
398
399 if failed_tables.contains(&parsed.table) {
401 return true; }
403
404 if parsed.rows.is_empty() {
406 return false;
407 }
408
409 let duckdb_sql = match self.convert_insert(stmt, dialect) {
411 Ok(sql) => sql,
412 Err(_) => return false,
413 };
414
415 if let Some(mut batch) =
417 batch_mgr.queue_insert(&parsed.table, parsed.columns, parsed.rows, duckdb_sql)
418 {
419 if let Err(e) = flush_batch(self.conn, &mut batch, stats, failed_tables) {
421 if stats.warnings.len() < 100 {
422 stats.warnings.push(format!("Batch flush error: {}", e));
423 }
424 }
425 }
426
427 true
428 }
429
430 fn skip_copy_block<R: Read>(parser: &mut StatementReader<R>) {
432 while let Some(Ok(line)) = parser.next_statement() {
433 if line.trim() == "\\." {
434 break;
435 }
436 }
437 }
438
439 fn looks_like_copy_data(line: &str) -> bool {
441 let trimmed = line.trim();
442
443 if trimmed.is_empty() {
445 return false;
446 }
447
448 if trimmed == "\\." {
450 return false;
451 }
452
453 let first_char = trimmed.chars().next().unwrap_or(' ');
456
457 if matches!(
459 first_char,
460 'S' | 's'
461 | 'I'
462 | 'i'
463 | 'C'
464 | 'c'
465 | 'D'
466 | 'd'
467 | 'A'
468 | 'a'
469 | 'U'
470 | 'u'
471 | 'G'
472 | 'g'
473 | '-'
474 | '/'
475 ) {
476 let upper_prefix: String = trimmed.chars().take(7).collect::<String>().to_uppercase();
477 if upper_prefix.starts_with("SELECT")
478 || upper_prefix.starts_with("INSERT")
479 || upper_prefix.starts_with("CREATE")
480 || upper_prefix.starts_with("DROP")
481 || upper_prefix.starts_with("ALTER")
482 || upper_prefix.starts_with("UPDATE")
483 || upper_prefix.starts_with("GRANT")
484 || upper_prefix.starts_with("--")
485 || upper_prefix.starts_with("/*")
486 {
487 return false;
488 }
489 }
490
491 true
494 }
495
496 fn convert_create_table(&self, stmt: &str, dialect: SqlDialect) -> Result<String> {
498 let mut result = stmt.to_string();
499
500 result = Self::convert_identifiers(&result, dialect);
502
503 result = Self::strip_mysql_clauses(&result);
506
507 result = Self::convert_types_in_statement(&result);
509
510 if dialect == SqlDialect::Postgres {
512 result = Self::strip_postgres_syntax(&result);
513 }
514
515 if dialect == SqlDialect::Mssql {
517 result = Self::strip_mssql_syntax(&result);
518 }
519
520 if dialect == SqlDialect::Sqlite {
522 result = Self::strip_sqlite_syntax(&result);
523 }
524
525 Ok(result)
526 }
527
528 fn convert_insert(&self, stmt: &str, dialect: SqlDialect) -> Result<String> {
530 let mut result = stmt.to_string();
531
532 result = Self::convert_identifiers(&result, dialect);
534
535 if dialect == SqlDialect::MySql {
537 result = Self::convert_mysql_escapes(&result);
538 }
539
540 if dialect == SqlDialect::Postgres {
542 result = Self::strip_schema_prefix(&result);
543 }
544
545 if dialect == SqlDialect::Mssql {
547 result = Self::strip_mssql_schema_prefix(&result);
548 }
549
550 Ok(result)
551 }
552
553 fn convert_mysql_escapes(stmt: &str) -> String {
555 let mut result = String::with_capacity(stmt.len() + 100);
556 let mut chars = stmt.chars().peekable();
557 let mut in_string = false;
558
559 while let Some(c) = chars.next() {
560 if c == '\'' {
561 in_string = !in_string;
562 result.push(c);
563 } else if c == '\\' && in_string {
564 match chars.peek() {
566 Some('\'') => {
567 chars.next();
569 result.push_str("''");
570 }
571 Some('\\') => {
572 chars.next();
574 result.push('\\');
575 }
576 Some('n') => {
577 chars.next();
579 result.push('\n');
580 }
581 Some('r') => {
582 chars.next();
584 result.push('\r');
585 }
586 Some('t') => {
587 chars.next();
589 result.push('\t');
590 }
591 Some('0') => {
592 chars.next();
594 }
595 Some('"') => {
596 chars.next();
598 result.push('"');
599 }
600 _ => {
601 result.push(c);
603 }
604 }
605 } else {
606 result.push(c);
607 }
608 }
609 result
610 }
611
612 fn convert_identifiers(stmt: &str, dialect: SqlDialect) -> String {
614 match dialect {
615 SqlDialect::MySql => {
616 let mut result = String::with_capacity(stmt.len());
618 let mut in_string = false;
619 let mut in_backtick = false;
620
621 for c in stmt.chars() {
622 if c == '\'' && !in_backtick {
623 in_string = !in_string;
624 result.push(c);
625 } else if c == '`' && !in_string {
626 in_backtick = !in_backtick;
627 result.push('"');
628 } else {
629 result.push(c);
630 }
631 }
632 result
633 }
634 SqlDialect::Mssql => {
635 let mut result = String::with_capacity(stmt.len());
637 let mut in_string = false;
638 let mut in_bracket = false;
639 let mut chars = stmt.chars().peekable();
640
641 while let Some(c) = chars.next() {
642 if c == '\'' && !in_bracket {
643 in_string = !in_string;
644 result.push(c);
645 } else if c == '[' && !in_string {
646 in_bracket = true;
647 result.push('"');
648 } else if c == ']' && !in_string {
649 if chars.peek() == Some(&']') {
651 chars.next();
652 result.push(']');
653 } else {
654 in_bracket = false;
655 result.push('"');
656 }
657 } else if c == 'N' && !in_string && !in_bracket && chars.peek() == Some(&'\'') {
658 } else {
661 result.push(c);
662 }
663 }
664 result
665 }
666 _ => stmt.to_string(),
667 }
668 }
669
670 fn convert_types_in_statement(stmt: &str) -> String {
672 static RE_COLUMN_TYPE: Lazy<Regex> = Lazy::new(|| {
679 Regex::new(r#"(?i)(["'`\]\s])\s*(BIGSERIAL|SMALLSERIAL|SERIAL|BIGINT|SMALLINT|MEDIUMINT|TINYINT|INTEGER|INT|DOUBLE\s+PRECISION|DOUBLE|FLOAT|DECIMAL|NUMERIC|CHARACTER\s+VARYING|NVARCHAR|NCHAR|VARCHAR|CHAR|VARBINARY|BINARY|LONGTEXT|MEDIUMTEXT|TINYTEXT|NTEXT|TEXT|LONGBLOB|MEDIUMBLOB|TINYBLOB|IMAGE|BLOB|DATETIME2|DATETIMEOFFSET|SMALLDATETIME|DATETIME|TIMESTAMPTZ|TIMESTAMP|TIMETZ|TIME|DATE|YEAR|ENUM|SET|JSONB|JSON|UUID|UNIQUEIDENTIFIER|BYTEA|BOOLEAN|BOOL|BIT|REAL|MONEY|SMALLMONEY|INTERVAL|ROWVERSION|XML|SQL_VARIANT)\b(\s*\([^)]+\))?(\s+(?:UNSIGNED|WITH(?:OUT)?\s+TIME\s+ZONE))?"#).unwrap()
680 });
681
682 RE_COLUMN_TYPE
683 .replace_all(stmt, |caps: ®ex::Captures| {
684 let full_match = caps.get(0).unwrap().as_str();
685 let leading_char = caps.get(1).unwrap().as_str();
686 let type_part = caps.get(2).unwrap().as_str();
687 let size_part = caps.get(3).map(|m| m.as_str()).unwrap_or("");
688 let suffix = caps.get(4).map(|m| m.as_str()).unwrap_or("");
689
690 let end_pos = caps.get(0).unwrap().end();
693 let stmt_bytes = stmt.as_bytes();
694 if end_pos < stmt_bytes.len() {
695 let next_char = stmt_bytes[end_pos] as char;
696 if next_char == '"' || next_char == '\'' || next_char == '`' {
698 return full_match.to_string();
699 }
700 }
701
702 let ws_len = full_match.len()
704 - leading_char.len()
705 - type_part.len()
706 - size_part.len()
707 - suffix.len();
708 let ws = &full_match[leading_char.len()..leading_char.len() + ws_len];
709
710 let converted =
711 TypeConverter::convert(&format!("{}{}{}", type_part, size_part, suffix));
712 format!("{}{}{}", leading_char, ws, converted)
713 })
714 .to_string()
715 }
716
717 fn strip_mysql_clauses(stmt: &str) -> String {
719 let mut result = stmt.to_string();
720
721 static RE_ENGINE: Lazy<Regex> =
723 Lazy::new(|| Regex::new(r"(?i)\s*ENGINE\s*=\s*\w+").unwrap());
724 result = RE_ENGINE.replace_all(&result, "").to_string();
725
726 static RE_AUTO_INC: Lazy<Regex> =
728 Lazy::new(|| Regex::new(r"(?i)\s*AUTO_INCREMENT\s*=\s*\d+").unwrap());
729 result = RE_AUTO_INC.replace_all(&result, "").to_string();
730
731 result = result.replace(" AUTO_INCREMENT", "");
733 result = result.replace(" auto_increment", "");
734
735 static RE_CHAR_SET: Lazy<Regex> =
737 Lazy::new(|| Regex::new(r"(?i)\s*CHARACTER\s+SET\s+\w+").unwrap());
738 result = RE_CHAR_SET.replace_all(&result, "").to_string();
739
740 static RE_CHARSET: Lazy<Regex> =
742 Lazy::new(|| Regex::new(r"(?i)\s*(DEFAULT\s+)?CHARSET\s*=\s*\w+").unwrap());
743 result = RE_CHARSET.replace_all(&result, "").to_string();
744
745 static RE_COLLATE: Lazy<Regex> =
747 Lazy::new(|| Regex::new(r"(?i)\s*COLLATE\s*=?\s*\w+").unwrap());
748 result = RE_COLLATE.replace_all(&result, "").to_string();
749
750 static RE_ROW_FORMAT: Lazy<Regex> =
752 Lazy::new(|| Regex::new(r"(?i)\s*ROW_FORMAT\s*=\s*\w+").unwrap());
753 result = RE_ROW_FORMAT.replace_all(&result, "").to_string();
754
755 static RE_KEY_BLOCK: Lazy<Regex> =
757 Lazy::new(|| Regex::new(r"(?i)\s*KEY_BLOCK_SIZE\s*=\s*\d+").unwrap());
758 result = RE_KEY_BLOCK.replace_all(&result, "").to_string();
759
760 static RE_COMMENT: Lazy<Regex> =
762 Lazy::new(|| Regex::new(r"(?i)\s*COMMENT\s*=?\s*'[^']*'").unwrap());
763 result = RE_COMMENT.replace_all(&result, "").to_string();
764
765 static RE_COND_COMMENT: Lazy<Regex> = Lazy::new(|| Regex::new(r"/\*!\d+\s*|\*/").unwrap());
767 result = RE_COND_COMMENT.replace_all(&result, "").to_string();
768
769 static RE_ON_UPDATE: Lazy<Regex> =
771 Lazy::new(|| Regex::new(r"(?i)\s*ON\s+UPDATE\s+CURRENT_TIMESTAMP").unwrap());
772 result = RE_ON_UPDATE.replace_all(&result, "").to_string();
773
774 static RE_UNIQUE_KEY: Lazy<Regex> = Lazy::new(|| {
777 Regex::new(r#"(?i),?\s*UNIQUE\s+KEY\s+[`"']?\w+[`"']?\s*\([^)]+\)"#).unwrap()
778 });
779 result = RE_UNIQUE_KEY.replace_all(&result, "").to_string();
780
781 static RE_KEY_INDEX: Lazy<Regex> = Lazy::new(|| {
785 Regex::new(
786 r#"(?i)(?:,\s*|\n\s*)(?:FULLTEXT\s+|SPATIAL\s+)?KEY\s+[`"']?\w+[`"']?\s*\([^)]+\)"#,
787 )
788 .unwrap()
789 });
790 result = RE_KEY_INDEX.replace_all(&result, "").to_string();
791
792 static RE_GENERATED_COL: Lazy<Regex> = Lazy::new(|| {
796 Regex::new(r#"(?i),?\s*[`"']?\w+[`"']?\s+\w+\s+GENERATED\s+ALWAYS\s+AS\s*\((?:[^()]+|\([^()]*\))+\)\s*(?:STORED|VIRTUAL)?"#).unwrap()
797 });
798 result = RE_GENERATED_COL.replace_all(&result, "").to_string();
799
800 static RE_FK_CONSTRAINT: Lazy<Regex> = Lazy::new(|| {
804 Regex::new(r#"(?i),?\s*(?:CONSTRAINT\s+[`"']?\w+[`"']?\s+)?FOREIGN\s+KEY\s*\([^)]+\)\s*REFERENCES\s+[`"']?\w+[`"']?\s*\([^)]+\)(?:\s+ON\s+(?:DELETE|UPDATE)\s+(?:CASCADE|SET\s+NULL|SET\s+DEFAULT|NO\s+ACTION|RESTRICT))*"#).unwrap()
805 });
806 result = RE_FK_CONSTRAINT.replace_all(&result, "").to_string();
807
808 result
809 }
810
811 fn strip_postgres_syntax(stmt: &str) -> String {
813 let mut result = stmt.to_string();
814
815 result = Self::strip_schema_prefix(&result);
817
818 static RE_CAST: Lazy<Regex> = Lazy::new(|| {
820 Regex::new(r"::[a-zA-Z_][a-zA-Z0-9_]*(?:\s+[a-zA-Z_][a-zA-Z0-9_]*)*").unwrap()
821 });
822 result = RE_CAST.replace_all(&result, "").to_string();
823
824 static RE_NEXTVAL: Lazy<Regex> =
826 Lazy::new(|| Regex::new(r"(?i)\s*DEFAULT\s+nextval\s*\([^)]+\)").unwrap());
827 result = RE_NEXTVAL.replace_all(&result, "").to_string();
828
829 static RE_NOW: Lazy<Regex> =
831 Lazy::new(|| Regex::new(r"(?i)\bDEFAULT\s+now\s*\(\s*\)").unwrap());
832 result = RE_NOW
833 .replace_all(&result, "DEFAULT CURRENT_TIMESTAMP")
834 .to_string();
835
836 static RE_INHERITS: Lazy<Regex> =
838 Lazy::new(|| Regex::new(r"(?i)\s*INHERITS\s*\([^)]+\)").unwrap());
839 result = RE_INHERITS.replace_all(&result, "").to_string();
840
841 static RE_WITH: Lazy<Regex> = Lazy::new(|| Regex::new(r"(?i)\s*WITH\s*\([^)]+\)").unwrap());
843 result = RE_WITH.replace_all(&result, "").to_string();
844
845 result
846 }
847
848 fn strip_schema_prefix(stmt: &str) -> String {
850 static RE_SCHEMA: Lazy<Regex> =
851 Lazy::new(|| Regex::new(r#"(?i)\b(public|pg_catalog|pg_temp)\s*\.\s*"#).unwrap());
852 RE_SCHEMA.replace_all(stmt, "").to_string()
853 }
854
855 fn strip_mssql_schema_prefix(stmt: &str) -> String {
857 static RE_SCHEMA: Lazy<Regex> =
859 Lazy::new(|| Regex::new(r#"(?i)"?(dbo|master|tempdb|model|msdb)"?\s*\.\s*"#).unwrap());
860 RE_SCHEMA.replace_all(stmt, "").to_string()
861 }
862
863 fn strip_mssql_syntax(stmt: &str) -> String {
865 let mut result = Self::strip_mssql_schema_prefix(stmt);
866
867 static RE_IDENTITY_NOT_NULL: Lazy<Regex> = Lazy::new(|| {
870 Regex::new(r"(?i)\s*IDENTITY\s*\(\s*\d+\s*,\s*\d+\s*\)\s*NOT\s+NULL").unwrap()
871 });
872 result = RE_IDENTITY_NOT_NULL.replace_all(&result, "").to_string();
873
874 static RE_IDENTITY: Lazy<Regex> =
876 Lazy::new(|| Regex::new(r"(?i)\s*IDENTITY\s*\(\s*\d+\s*,\s*\d+\s*\)").unwrap());
877 result = RE_IDENTITY.replace_all(&result, "").to_string();
878
879 static RE_CLUSTERED: Lazy<Regex> =
881 Lazy::new(|| Regex::new(r"(?i)\s*(?:NON)?CLUSTERED\s*").unwrap());
882 result = RE_CLUSTERED.replace_all(&result, " ").to_string();
883
884 static RE_FILEGROUP: Lazy<Regex> =
886 Lazy::new(|| Regex::new(r#"(?i)\s*ON\s*"?PRIMARY"?"#).unwrap());
887 result = RE_FILEGROUP.replace_all(&result, "").to_string();
888
889 static RE_PK_CONSTRAINT: Lazy<Regex> = Lazy::new(|| {
891 Regex::new(r#"(?i),?\s*CONSTRAINT\s+"?\w+"?\s+PRIMARY\s+KEY\s+\([^)]+\)"#).unwrap()
892 });
893 result = RE_PK_CONSTRAINT.replace_all(&result, "").to_string();
894
895 static RE_FK_CONSTRAINT: Lazy<Regex> = Lazy::new(|| {
897 Regex::new(r#"(?i),?\s*CONSTRAINT\s+"?\w+"?\s+FOREIGN\s+KEY\s*\([^)]+\)\s*REFERENCES\s+[^\s(]+\s*\([^)]+\)"#).unwrap()
898 });
899 result = RE_FK_CONSTRAINT.replace_all(&result, "").to_string();
900
901 static RE_WITH: Lazy<Regex> = Lazy::new(|| Regex::new(r"(?i)\s*WITH\s*\([^)]+\)").unwrap());
903 result = RE_WITH.replace_all(&result, "").to_string();
904
905 static RE_TEXTIMAGE: Lazy<Regex> =
907 Lazy::new(|| Regex::new(r#"(?i)\s*TEXTIMAGE_ON\s*"?\w+"?"#).unwrap());
908 result = RE_TEXTIMAGE.replace_all(&result, "").to_string();
909
910 static RE_GETDATE: Lazy<Regex> =
912 Lazy::new(|| Regex::new(r"(?i)\bGETDATE\s*\(\s*\)").unwrap());
913 result = RE_GETDATE
914 .replace_all(&result, "CURRENT_TIMESTAMP")
915 .to_string();
916
917 static RE_NEWID: Lazy<Regex> = Lazy::new(|| Regex::new(r"(?i)\bNEWID\s*\(\s*\)").unwrap());
919 result = RE_NEWID
920 .replace_all(&result, "gen_random_uuid()")
921 .to_string();
922
923 result
924 }
925
926 fn strip_sqlite_syntax(stmt: &str) -> String {
928 let mut result = stmt.to_string();
929
930 result = result.replace(" AUTOINCREMENT", "");
933 result = result.replace(" autoincrement", "");
934
935 static RE_STRICT: Lazy<Regex> = Lazy::new(|| Regex::new(r"(?i)\)\s*STRICT\s*;").unwrap());
940 result = RE_STRICT.replace_all(&result, ");").to_string();
941
942 static RE_WITHOUT_ROWID: Lazy<Regex> =
944 Lazy::new(|| Regex::new(r"(?i)\)\s*WITHOUT\s+ROWID\s*;").unwrap());
945 result = RE_WITHOUT_ROWID.replace_all(&result, ");").to_string();
946
947 result
948 }
949
950 fn count_insert_rows(sql: &str) -> u64 {
952 if let Some(values_pos) = sql.to_uppercase().find("VALUES") {
954 let after_values = &sql[values_pos + 6..];
955 let mut count = 0u64;
957 let mut depth = 0;
958 let mut in_string = false;
959 let mut prev_char = ' ';
960
961 for c in after_values.chars() {
962 if c == '\'' && prev_char != '\\' {
963 in_string = !in_string;
964 }
965 if !in_string {
966 if c == '(' {
967 if depth == 0 {
968 count += 1;
969 }
970 depth += 1;
971 } else if c == ')' {
972 depth -= 1;
973 }
974 }
975 prev_char = c;
976 }
977 count
978 } else {
979 1
980 }
981 }
982}
983
984struct StatementReader<R> {
986 reader: BufReader<R>,
987 dialect: SqlDialect,
988 buffer: String,
989 buffer_pos: usize,
991 eof: bool,
992 in_copy_mode: bool,
994}
995
996impl<R: Read> StatementReader<R> {
997 fn new(reader: BufReader<R>, dialect: SqlDialect) -> Self {
998 Self {
999 reader,
1000 dialect,
1001 buffer: String::new(),
1002 buffer_pos: 0,
1003 eof: false,
1004 in_copy_mode: false,
1005 }
1006 }
1007
1008 fn compact_buffer(&mut self) {
1011 if self.buffer_pos > 0 {
1012 self.buffer.drain(..self.buffer_pos);
1013 self.buffer_pos = 0;
1014 }
1015 }
1016
1017 fn remaining_buffer(&self) -> &str {
1019 &self.buffer[self.buffer_pos..]
1020 }
1021
1022 fn set_copy_mode(&mut self, enabled: bool) {
1024 self.in_copy_mode = enabled;
1025 }
1026
1027 fn strip_leading_sql_comments(s: &str) -> &str {
1029 let mut result = s.trim();
1030 loop {
1031 if result.starts_with("--") {
1033 if let Some(pos) = result.find('\n') {
1034 result = result[pos + 1..].trim();
1035 continue;
1036 } else {
1037 return ""; }
1039 }
1040 if result.starts_with("/*") {
1042 if let Some(pos) = result.find("*/") {
1043 result = result[pos + 2..].trim();
1044 continue;
1045 } else {
1046 return ""; }
1048 }
1049 break;
1050 }
1051 result
1052 }
1053
1054 fn next_statement(&mut self) -> Option<Result<String>> {
1055 if self.eof && self.remaining_buffer().is_empty() {
1056 return None;
1057 }
1058
1059 loop {
1060 if self.in_copy_mode {
1062 if let Some(line) = self.extract_copy_line() {
1063 return Some(Ok(line));
1064 }
1065 } else {
1066 if let Some(stmt) = self.extract_statement() {
1068 return Some(Ok(stmt));
1070 }
1071 }
1072
1073 if self.buffer_pos > 64 * 1024 {
1076 self.compact_buffer();
1077 }
1078
1079 let mut line = String::new();
1081 match self.reader.read_line(&mut line) {
1082 Ok(0) => {
1083 self.eof = true;
1084 self.in_copy_mode = false; let remaining = self.remaining_buffer().trim();
1086 if !remaining.is_empty() {
1087 let stmt = remaining.to_string();
1088 self.buffer.clear();
1089 self.buffer_pos = 0;
1090 return Some(Ok(stmt));
1091 }
1092 return None;
1093 }
1094 Ok(_) => {
1095 self.buffer.push_str(&line);
1096 }
1097 Err(e) => return Some(Err(e.into())),
1098 }
1099 }
1100 }
1101
1102 fn extract_copy_line(&mut self) -> Option<String> {
1104 let remaining = self.remaining_buffer();
1105 if let Some(newline_pos) = remaining.find('\n') {
1106 let line = remaining[..newline_pos].to_string();
1107 self.buffer_pos += newline_pos + 1;
1108 Some(line)
1110 } else {
1111 None
1112 }
1113 }
1114
1115 fn extract_statement(&mut self) -> Option<String> {
1116 let remaining = self.remaining_buffer();
1117 let mut in_string = false;
1118 let mut in_dollar_quote = false;
1119 let mut in_bracket = false;
1120 let mut in_line_comment = false;
1121 let mut in_block_comment = false;
1122 let mut escape_next = false;
1123 let mut chars = remaining.char_indices().peekable();
1124 let mut end_pos = None;
1125
1126 if self.dialect == SqlDialect::Mssql {
1128 if let Some(go_pos) = self.find_go_separator() {
1129 let stmt = remaining[..go_pos].to_string();
1130 let after_go = &remaining[go_pos..];
1132 if let Some(line_end) = after_go.find('\n') {
1133 self.buffer_pos += go_pos + line_end + 1;
1134 } else {
1135 self.buffer_pos = self.buffer.len();
1136 }
1137
1138 let trimmed = stmt.trim();
1139 if trimmed.is_empty()
1140 || trimmed.starts_with("--")
1141 || (trimmed.starts_with("/*") && !trimmed.contains("/*!"))
1142 {
1143 return self.extract_statement();
1144 }
1145 return Some(stmt);
1146 }
1147 }
1148
1149 while let Some((i, c)) = chars.next() {
1150 if escape_next {
1151 escape_next = false;
1152 continue;
1153 }
1154
1155 if in_line_comment {
1157 if c == '\n' {
1158 in_line_comment = false;
1159 }
1160 continue;
1161 }
1162
1163 if in_block_comment {
1165 if c == '*' && chars.peek().map(|(_, c)| *c == '/').unwrap_or(false) {
1166 chars.next();
1167 in_block_comment = false;
1168 }
1169 continue;
1170 }
1171
1172 match c {
1173 '\\' if self.dialect == SqlDialect::MySql && in_string => {
1174 escape_next = true;
1175 }
1176 '\'' if !in_dollar_quote && !in_bracket => {
1177 in_string = !in_string;
1178 }
1179 '[' if self.dialect == SqlDialect::Mssql && !in_string => {
1180 in_bracket = true;
1181 }
1182 ']' if self.dialect == SqlDialect::Mssql && !in_string => {
1183 if chars.peek().map(|(_, c)| *c == ']').unwrap_or(false) {
1185 chars.next();
1186 } else {
1187 in_bracket = false;
1188 }
1189 }
1190 '$' if self.dialect == SqlDialect::Postgres && !in_string => {
1191 if chars.peek().map(|(_, c)| *c == '$').unwrap_or(false) {
1193 in_dollar_quote = !in_dollar_quote;
1194 chars.next();
1195 }
1196 }
1197 '-' if !in_string && !in_dollar_quote && !in_bracket => {
1198 if chars.peek().map(|(_, c)| *c == '-').unwrap_or(false) {
1200 chars.next();
1201 in_line_comment = true;
1202 }
1203 }
1204 '/' if !in_string && !in_dollar_quote && !in_bracket => {
1205 if chars.peek().map(|(_, c)| *c == '*').unwrap_or(false) {
1207 chars.next();
1208 in_block_comment = true;
1209 }
1210 }
1211 ';' if !in_string && !in_dollar_quote && !in_bracket => {
1212 end_pos = Some(i + 1);
1213 break;
1214 }
1215 _ => {}
1216 }
1217 }
1218
1219 if let Some(pos) = end_pos {
1220 let stmt = remaining[..pos].to_string();
1221 let after_stmt = &remaining[pos..];
1223 let trimmed_len = after_stmt.len() - after_stmt.trim_start().len();
1224 self.buffer_pos += pos + trimmed_len;
1225
1226 let trimmed = stmt.trim();
1228
1229 let stripped = Self::strip_leading_sql_comments(trimmed);
1231 if stripped.is_empty() {
1232 return self.extract_statement();
1233 }
1234
1235 let trimmed = stripped;
1237
1238 if self.dialect == SqlDialect::Postgres {
1241 let upper = trimmed.to_uppercase();
1242 if upper.ends_with("FROM STDIN;") && upper.contains("COPY ") {
1243 self.in_copy_mode = true;
1244 }
1245 }
1246
1247 Some(stmt)
1248 } else {
1249 None
1250 }
1251 }
1252
1253 fn find_go_separator(&self) -> Option<usize> {
1255 let remaining = self.remaining_buffer();
1256 let mut in_string = false;
1257 let mut in_bracket = false;
1258 let mut line_start = 0;
1259
1260 for (i, c) in remaining.char_indices() {
1261 if c == '\'' && !in_bracket {
1262 in_string = !in_string;
1263 } else if c == '[' && !in_string {
1264 in_bracket = true;
1265 } else if c == ']' && !in_string {
1266 in_bracket = false;
1267 } else if c == '\n' {
1268 line_start = i + 1;
1269 } else if !in_string && !in_bracket && i == line_start {
1270 let rest = &remaining[i..];
1272 if rest.len() >= 2 {
1273 let word = &rest[..2.min(rest.len())];
1274 if word.eq_ignore_ascii_case("GO") {
1275 let after_go = if rest.len() > 2 {
1277 rest.chars().nth(2)
1278 } else {
1279 None
1280 };
1281 if after_go.is_none()
1282 || after_go == Some('\n')
1283 || after_go == Some('\r')
1284 || after_go == Some(' ')
1285 || after_go.unwrap().is_ascii_digit()
1286 {
1287 return Some(i);
1288 }
1289 }
1290 }
1291 }
1292 }
1293 None
1294 }
1295}
1296
#[cfg(test)]
mod tests {
    use super::*;

    /// Row counting over single rows, multiple rows, and parentheses inside literals.
    #[test]
    fn test_count_insert_rows() {
        let cases: [(&str, u64); 3] = [
            ("INSERT INTO t VALUES (1, 'a')", 1),
            ("INSERT INTO t VALUES (1, 'a'), (2, 'b'), (3, 'c')", 3),
            ("INSERT INTO t VALUES (1, '(test)')", 1),
        ];
        for &(sql, expected_rows) in &cases {
            assert_eq!(DumpLoader::count_insert_rows(sql), expected_rows, "sql: {}", sql);
        }
    }

    /// ENGINE and CHARSET table options disappear from MySQL DDL.
    #[test]
    fn test_strip_mysql_clauses() {
        let stripped = DumpLoader::strip_mysql_clauses(
            "CREATE TABLE t (id INT) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4",
        );
        for option in &["ENGINE", "CHARSET"] {
            assert!(!stripped.contains(option), "{} left in: {}", option, stripped);
        }
    }

    /// MySQL backticks become double quotes; string literals are untouched.
    #[test]
    fn test_convert_identifiers() {
        let converted = DumpLoader::convert_identifiers(
            "INSERT INTO `users` (`id`, `name`) VALUES (1, 'test')",
            SqlDialect::MySql,
        );
        let expected = r#"INSERT INTO "users" ("id", "name") VALUES (1, 'test')"#;
        assert_eq!(converted, expected);
    }

    /// Tab-separated rows are data; SQL statements are not.
    #[test]
    fn test_looks_like_copy_data() {
        assert!(DumpLoader::looks_like_copy_data("1\tAlice\t2024-01-01"));
        for statement in &["SELECT * FROM users", "INSERT INTO t VALUES"] {
            assert!(!DumpLoader::looks_like_copy_data(statement), "{}", statement);
        }
    }

    /// IDENTITY column specs (with their NOT NULL) are removed from MSSQL DDL.
    #[test]
    fn test_strip_mssql_syntax() {
        let plain = r#"CREATE TABLE "users" (
    "id" INTEGER NOT NULL,
    "email" VARCHAR(255) NOT NULL
)"#;
        let stripped_plain = DumpLoader::strip_mssql_syntax(plain);
        assert!(!stripped_plain.contains("IDENTITY"), "IDENTITY should be stripped");

        let with_identity = r#"CREATE TABLE "users" (
    "id" INTEGER IDENTITY(1,1) NOT NULL,
    "email" VARCHAR(255) NOT NULL
)"#;
        let stripped_identity = DumpLoader::strip_mssql_syntax(with_identity);
        assert!(
            !stripped_identity.contains("IDENTITY"),
            "IDENTITY should be stripped: {}",
            stripped_identity
        );
        assert!(
            !stripped_identity.contains("IDENTITY(1,1) NOT NULL"),
            "Should strip full IDENTITY NOT NULL"
        );
    }

    /// MSSQL brackets become double quotes and the N'' prefix is dropped.
    #[test]
    fn test_convert_mssql_identifiers() {
        let converted = DumpLoader::convert_identifiers(
            "INSERT INTO [dbo].[users] ([id], [name]) VALUES (1, N'test')",
            SqlDialect::Mssql,
        );
        let expected = r#"INSERT INTO "dbo"."users" ("id", "name") VALUES (1, 'test')"#;
        assert_eq!(converted, expected);
    }
}