1#[cfg(test)]
2mod edge_case_tests;
3
4use once_cell::sync::Lazy;
5use regex::bytes::Regex;
6use std::io::{BufRead, BufReader, Read};
7
/// Read-buffer capacity of 64 KiB; `determine_buffer_size` returns this for
/// inputs of at most 1 GiB.
pub const SMALL_BUFFER_SIZE: usize = 64 * 1024;
/// Read-buffer capacity of 256 KiB; `determine_buffer_size` returns this for
/// inputs larger than 1 GiB.
pub const MEDIUM_BUFFER_SIZE: usize = 256 * 1024;
10
/// SQL dialect that the parser is configured for.
///
/// `MySql` is the default dialect.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SqlDialect {
    /// MySQL; `FromStr` also accepts "mariadb" for this variant.
    #[default]
    MySql,
    /// PostgreSQL; `FromStr` also accepts "postgresql" and "pg".
    Postgres,
    /// SQLite; `FromStr` also accepts "sqlite3".
    Sqlite,
}
22
23impl std::str::FromStr for SqlDialect {
24 type Err = String;
25
26 fn from_str(s: &str) -> Result<Self, Self::Err> {
27 match s.to_lowercase().as_str() {
28 "mysql" | "mariadb" => Ok(SqlDialect::MySql),
29 "postgres" | "postgresql" | "pg" => Ok(SqlDialect::Postgres),
30 "sqlite" | "sqlite3" => Ok(SqlDialect::Sqlite),
31 _ => Err(format!(
32 "Unknown dialect: {}. Valid options: mysql, postgres, sqlite",
33 s
34 )),
35 }
36 }
37}
38
39impl std::fmt::Display for SqlDialect {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 match self {
42 SqlDialect::MySql => write!(f, "mysql"),
43 SqlDialect::Postgres => write!(f, "postgres"),
44 SqlDialect::Sqlite => write!(f, "sqlite"),
45 }
46 }
47}
48
/// Kind of SQL statement, as classified by `Parser::parse_statement_with_dialect`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StatementType {
    /// Not one of the recognized statement kinds, or no table name could be
    /// extracted from it.
    Unknown,
    /// `CREATE TABLE ...`
    CreateTable,
    /// `INSERT INTO ...`
    Insert,
    /// `CREATE INDEX ... ON <table>`
    CreateIndex,
    /// `ALTER TABLE ...`
    AlterTable,
    /// `DROP TABLE ...`
    DropTable,
    /// `COPY <table> ...`
    Copy,
}
60
61static CREATE_TABLE_RE: Lazy<Regex> =
62 Lazy::new(|| Regex::new(r"(?i)^\s*CREATE\s+TABLE\s+`?([^\s`(]+)`?").unwrap());
63
64static INSERT_INTO_RE: Lazy<Regex> =
65 Lazy::new(|| Regex::new(r"(?i)^\s*INSERT\s+INTO\s+`?([^\s`(]+)`?").unwrap());
66
67static CREATE_INDEX_RE: Lazy<Regex> =
68 Lazy::new(|| Regex::new(r"(?i)ON\s+`?([^\s`(;]+)`?").unwrap());
69
70static ALTER_TABLE_RE: Lazy<Regex> =
71 Lazy::new(|| Regex::new(r"(?i)ALTER\s+TABLE\s+`?([^\s`;]+)`?").unwrap());
72
73static DROP_TABLE_RE: Lazy<Regex> =
74 Lazy::new(|| Regex::new(r"(?i)DROP\s+TABLE\s+`?([^\s`;]+)`?").unwrap());
75
76static COPY_RE: Lazy<Regex> =
78 Lazy::new(|| Regex::new(r#"(?i)^\s*COPY\s+(?:ONLY\s+)?[`"]?([^\s`"(]+)[`"]?"#).unwrap());
79
80static CREATE_TABLE_FLEXIBLE_RE: Lazy<Regex> = Lazy::new(|| {
86 Regex::new(r#"(?i)^\s*CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:[`"]?[\w]+[`"]?\s*\.\s*)?[`"]?([\w]+)[`"]?"#).unwrap()
87});
88
89static INSERT_FLEXIBLE_RE: Lazy<Regex> = Lazy::new(|| {
90 Regex::new(
91 r#"(?i)^\s*INSERT\s+INTO\s+(?:ONLY\s+)?(?:[`"]?[\w]+[`"]?\s*\.\s*)?[`"]?([\w]+)[`"]?"#,
92 )
93 .unwrap()
94});
95
/// Streaming SQL statement reader over any `Read` source.
pub struct Parser<R: Read> {
    /// Buffered wrapper around the underlying input.
    reader: BufReader<R>,
    /// Accumulates the bytes of the statement (or COPY data block) that is
    /// currently being read.
    stmt_buffer: Vec<u8>,
    /// Dialect that selects backslash-escape, dollar-quote and COPY handling.
    dialect: SqlDialect,
    /// Set after a Postgres `COPY ... FROM STDIN` statement has been
    /// returned; the next `read_statement` call then returns the COPY data
    /// block instead of scanning for `;`.
    in_copy_data: bool,
}
103
104impl<R: Read> Parser<R> {
    /// Creates a parser for the default dialect (`SqlDialect::default()`).
    ///
    /// `buffer_size` is the capacity, in bytes, of the internal read buffer.
    pub fn new(reader: R, buffer_size: usize) -> Self {
        Self::with_dialect(reader, buffer_size, SqlDialect::default())
    }
108
    /// Creates a parser for the given `dialect`.
    ///
    /// `buffer_size` is the capacity, in bytes, of the internal `BufReader`.
    /// The statement buffer starts with 32 KiB of capacity, and the parser
    /// starts outside of COPY-data mode.
    pub fn with_dialect(reader: R, buffer_size: usize, dialect: SqlDialect) -> Self {
        Self {
            reader: BufReader::with_capacity(buffer_size, reader),
            stmt_buffer: Vec::with_capacity(32 * 1024),
            dialect,
            in_copy_data: false,
        }
    }
117
118 pub fn read_statement(&mut self) -> std::io::Result<Option<Vec<u8>>> {
119 if self.in_copy_data {
121 return self.read_copy_data();
122 }
123
124 self.stmt_buffer.clear();
125
126 let mut inside_single_quote = false;
127 let mut inside_double_quote = false;
128 let mut escaped = false;
129 let mut in_line_comment = false;
130 let mut in_dollar_quote = false;
132 let mut dollar_tag: Vec<u8> = Vec::new();
133
134 loop {
135 let buf = self.reader.fill_buf()?;
136 if buf.is_empty() {
137 if self.stmt_buffer.is_empty() {
138 return Ok(None);
139 }
140 let result = std::mem::take(&mut self.stmt_buffer);
141 return Ok(Some(result));
142 }
143
144 let mut consumed = 0;
145 let mut found_terminator = false;
146
147 for (i, &b) in buf.iter().enumerate() {
148 let inside_string = inside_single_quote || inside_double_quote || in_dollar_quote;
149
150 if in_line_comment {
152 if b == b'\n' {
153 in_line_comment = false;
154 }
155 continue;
156 }
157
158 if escaped {
159 escaped = false;
160 continue;
161 }
162
163 if b == b'\\' && inside_string && self.dialect == SqlDialect::MySql {
165 escaped = true;
166 continue;
167 }
168
169 if b == b'-' && !inside_string && i + 1 < buf.len() && buf[i + 1] == b'-' {
171 in_line_comment = true;
172 continue;
173 }
174
175 if self.dialect == SqlDialect::Postgres
177 && !inside_single_quote
178 && !inside_double_quote
179 {
180 if b == b'$' && !in_dollar_quote {
181 if let Some(end) = buf[i + 1..].iter().position(|&c| c == b'$') {
183 dollar_tag = buf[i + 1..i + 1 + end].to_vec();
184 in_dollar_quote = true;
185 continue;
186 }
187 } else if b == b'$' && in_dollar_quote {
188 let tag_len = dollar_tag.len();
190 if i + 1 + tag_len < buf.len()
191 && buf[i + 1..i + 1 + tag_len] == dollar_tag[..]
192 && buf.get(i + 1 + tag_len) == Some(&b'$')
193 {
194 in_dollar_quote = false;
195 dollar_tag.clear();
196 continue;
197 }
198 }
199 }
200
201 if b == b'\'' && !inside_double_quote && !in_dollar_quote {
202 inside_single_quote = !inside_single_quote;
203 } else if b == b'"' && !inside_single_quote && !in_dollar_quote {
204 inside_double_quote = !inside_double_quote;
205 } else if b == b';' && !inside_string {
206 self.stmt_buffer.extend_from_slice(&buf[..=i]);
207 consumed = i + 1;
208 found_terminator = true;
209 break;
210 }
211 }
212
213 if found_terminator {
214 self.reader.consume(consumed);
215 let result = std::mem::take(&mut self.stmt_buffer);
216
217 if self.dialect == SqlDialect::Postgres && self.is_copy_from_stdin(&result) {
219 self.in_copy_data = true;
220 }
221
222 return Ok(Some(result));
223 }
224
225 self.stmt_buffer.extend_from_slice(buf);
226 let len = buf.len();
227 self.reader.consume(len);
228 }
229 }
230
231 fn is_copy_from_stdin(&self, stmt: &[u8]) -> bool {
233 let stmt = strip_leading_comments_and_whitespace(stmt);
235 if stmt.len() < 4 {
236 return false;
237 }
238
239 let upper: Vec<u8> = stmt
241 .iter()
242 .take(500)
243 .map(|b| b.to_ascii_uppercase())
244 .collect();
245 upper.starts_with(b"COPY ")
246 && (upper.windows(10).any(|w| w == b"FROM STDIN")
247 || upper.windows(11).any(|w| w == b"FROM STDIN;"))
248 }
249
250 fn read_copy_data(&mut self) -> std::io::Result<Option<Vec<u8>>> {
252 self.stmt_buffer.clear();
253
254 loop {
255 let buf = self.reader.fill_buf()?;
257 if buf.is_empty() {
258 self.in_copy_data = false;
259 if self.stmt_buffer.is_empty() {
260 return Ok(None);
261 }
262 return Ok(Some(std::mem::take(&mut self.stmt_buffer)));
263 }
264
265 let newline_pos = buf.iter().position(|&b| b == b'\n');
267
268 if let Some(i) = newline_pos {
269 self.stmt_buffer.extend_from_slice(&buf[..=i]);
271 self.reader.consume(i + 1);
272
273 if self.ends_with_copy_terminator() {
276 self.in_copy_data = false;
277 return Ok(Some(std::mem::take(&mut self.stmt_buffer)));
278 }
279 } else {
281 let len = buf.len();
283 self.stmt_buffer.extend_from_slice(buf);
284 self.reader.consume(len);
285 }
286 }
287 }
288
289 fn ends_with_copy_terminator(&self) -> bool {
291 let data = &self.stmt_buffer;
292 if data.len() < 2 {
293 return false;
294 }
295
296 let last_newline = data[..data.len() - 1]
299 .iter()
300 .rposition(|&b| b == b'\n')
301 .map(|i| i + 1)
302 .unwrap_or(0);
303
304 let last_line = &data[last_newline..];
305
306 last_line == b"\\.\n" || last_line == b"\\.\r\n"
308 }
309
    /// Classifies `stmt` and extracts its table name using the MySQL
    /// dialect rules; see `parse_statement_with_dialect`.
    pub fn parse_statement(stmt: &[u8]) -> (StatementType, String) {
        Self::parse_statement_with_dialect(stmt, SqlDialect::MySql)
    }
313
314 pub fn parse_statement_with_dialect(
316 stmt: &[u8],
317 dialect: SqlDialect,
318 ) -> (StatementType, String) {
319 let stmt = strip_leading_comments_and_whitespace(stmt);
321
322 if stmt.len() < 4 {
323 return (StatementType::Unknown, String::new());
324 }
325
326 let upper_prefix: Vec<u8> = stmt
327 .iter()
328 .take(25)
329 .map(|b| b.to_ascii_uppercase())
330 .collect();
331
332 if upper_prefix.starts_with(b"COPY ") {
334 if let Some(caps) = COPY_RE.captures(stmt) {
335 if let Some(m) = caps.get(1) {
336 let name = String::from_utf8_lossy(m.as_bytes()).into_owned();
337 let table_name = name.split('.').last().unwrap_or(&name).to_string();
339 return (StatementType::Copy, table_name);
340 }
341 }
342 }
343
344 if upper_prefix.starts_with(b"CREATE TABLE") {
345 if let Some(name) = extract_table_name_flexible(stmt, 12, dialect) {
347 return (StatementType::CreateTable, name);
348 }
349 if let Some(caps) = CREATE_TABLE_FLEXIBLE_RE.captures(stmt) {
351 if let Some(m) = caps.get(1) {
352 return (
353 StatementType::CreateTable,
354 String::from_utf8_lossy(m.as_bytes()).into_owned(),
355 );
356 }
357 }
358 if let Some(caps) = CREATE_TABLE_RE.captures(stmt) {
360 if let Some(m) = caps.get(1) {
361 return (
362 StatementType::CreateTable,
363 String::from_utf8_lossy(m.as_bytes()).into_owned(),
364 );
365 }
366 }
367 }
368
369 if upper_prefix.starts_with(b"INSERT INTO") || upper_prefix.starts_with(b"INSERT ONLY") {
370 if let Some(name) = extract_table_name_flexible(stmt, 11, dialect) {
371 return (StatementType::Insert, name);
372 }
373 if let Some(caps) = INSERT_FLEXIBLE_RE.captures(stmt) {
374 if let Some(m) = caps.get(1) {
375 return (
376 StatementType::Insert,
377 String::from_utf8_lossy(m.as_bytes()).into_owned(),
378 );
379 }
380 }
381 if let Some(caps) = INSERT_INTO_RE.captures(stmt) {
382 if let Some(m) = caps.get(1) {
383 return (
384 StatementType::Insert,
385 String::from_utf8_lossy(m.as_bytes()).into_owned(),
386 );
387 }
388 }
389 }
390
391 if upper_prefix.starts_with(b"CREATE INDEX") {
392 if let Some(caps) = CREATE_INDEX_RE.captures(stmt) {
393 if let Some(m) = caps.get(1) {
394 return (
395 StatementType::CreateIndex,
396 String::from_utf8_lossy(m.as_bytes()).into_owned(),
397 );
398 }
399 }
400 }
401
402 if upper_prefix.starts_with(b"ALTER TABLE") {
403 if let Some(name) = extract_table_name_flexible(stmt, 11, dialect) {
404 return (StatementType::AlterTable, name);
405 }
406 if let Some(caps) = ALTER_TABLE_RE.captures(stmt) {
407 if let Some(m) = caps.get(1) {
408 return (
409 StatementType::AlterTable,
410 String::from_utf8_lossy(m.as_bytes()).into_owned(),
411 );
412 }
413 }
414 }
415
416 if upper_prefix.starts_with(b"DROP TABLE") {
417 if let Some(name) = extract_table_name_flexible(stmt, 10, dialect) {
418 return (StatementType::DropTable, name);
419 }
420 if let Some(caps) = DROP_TABLE_RE.captures(stmt) {
421 if let Some(m) = caps.get(1) {
422 return (
423 StatementType::DropTable,
424 String::from_utf8_lossy(m.as_bytes()).into_owned(),
425 );
426 }
427 }
428 }
429
430 (StatementType::Unknown, String::new())
431 }
432}
433
/// Returns `data` without its leading spaces, tabs, line feeds and carriage
/// returns.
#[inline]
fn trim_ascii_start(data: &[u8]) -> &[u8] {
    let mut rest = data;
    while let [b' ' | b'\t' | b'\n' | b'\r', tail @ ..] = rest {
        rest = tail;
    }
    rest
}
442
443fn strip_leading_comments_and_whitespace(mut data: &[u8]) -> &[u8] {
446 loop {
447 data = trim_ascii_start(data);
449
450 if data.len() >= 2 && data[0] == b'-' && data[1] == b'-' {
451 if let Some(pos) = data.iter().position(|&b| b == b'\n') {
453 data = &data[pos + 1..];
454 continue;
455 } else {
456 return &[];
458 }
459 }
460
461 break;
462 }
463
464 data
465}
466
467#[inline]
473fn extract_table_name_flexible(stmt: &[u8], offset: usize, dialect: SqlDialect) -> Option<String> {
474 let mut i = offset;
475
476 while i < stmt.len() && is_whitespace(stmt[i]) {
478 i += 1;
479 }
480
481 if i >= stmt.len() {
482 return None;
483 }
484
485 let upper_check: Vec<u8> = stmt[i..]
487 .iter()
488 .take(20)
489 .map(|b| b.to_ascii_uppercase())
490 .collect();
491 if upper_check.starts_with(b"IF NOT EXISTS") {
492 i += 13; while i < stmt.len() && is_whitespace(stmt[i]) {
494 i += 1;
495 }
496 }
497
498 let upper_check: Vec<u8> = stmt[i..]
500 .iter()
501 .take(10)
502 .map(|b| b.to_ascii_uppercase())
503 .collect();
504 if upper_check.starts_with(b"ONLY ") || upper_check.starts_with(b"ONLY\t") {
505 i += 4;
506 while i < stmt.len() && is_whitespace(stmt[i]) {
507 i += 1;
508 }
509 }
510
511 if i >= stmt.len() {
512 return None;
513 }
514
515 let mut parts: Vec<String> = Vec::new();
517
518 loop {
519 let quote_char = match stmt.get(i) {
521 Some(b'`') if dialect == SqlDialect::MySql => {
522 i += 1;
523 Some(b'`')
524 }
525 Some(b'"') if dialect != SqlDialect::MySql => {
526 i += 1;
527 Some(b'"')
528 }
529 Some(b'"') => {
530 i += 1;
532 Some(b'"')
533 }
534 _ => None,
535 };
536
537 let start = i;
538
539 while i < stmt.len() {
540 let b = stmt[i];
541 if let Some(q) = quote_char {
542 if b == q {
543 let name = &stmt[start..i];
544 parts.push(String::from_utf8_lossy(name).into_owned());
545 i += 1; break;
547 }
548 } else if is_whitespace(b) || b == b'(' || b == b';' || b == b',' || b == b'.' {
549 if i > start {
550 let name = &stmt[start..i];
551 parts.push(String::from_utf8_lossy(name).into_owned());
552 }
553 break;
554 }
555 i += 1;
556 }
557
558 if quote_char.is_some() && i <= start {
560 break;
561 }
562
563 while i < stmt.len() && is_whitespace(stmt[i]) {
565 i += 1;
566 }
567
568 if i < stmt.len() && stmt[i] == b'.' {
569 i += 1; while i < stmt.len() && is_whitespace(stmt[i]) {
571 i += 1;
572 }
573 } else {
575 break;
576 }
577 }
578
579 parts.pop()
581}
582
/// Returns `true` for space, tab, line feed and carriage return.
#[inline]
fn is_whitespace(b: u8) -> bool {
    const WHITESPACE: [u8; 4] = [b' ', b'\t', b'\n', b'\r'];
    WHITESPACE.contains(&b)
}
587
588pub fn determine_buffer_size(file_size: u64) -> usize {
589 if file_size > 1024 * 1024 * 1024 {
590 MEDIUM_BUFFER_SIZE
591 } else {
592 SMALL_BUFFER_SIZE
593 }
594}
595
// Unit tests for statement classification (MySQL rules via
// `parse_statement`) and for `;`-based statement splitting.
#[cfg(test)]
mod tests {
    use super::*;

    // Unquoted table name in CREATE TABLE.
    #[test]
    fn test_parse_create_table() {
        let stmt = b"CREATE TABLE users (id INT);";
        let (typ, name) = Parser::<&[u8]>::parse_statement(stmt);
        assert_eq!(typ, StatementType::CreateTable);
        assert_eq!(name, "users");
    }

    // Backticks around the table name are stripped.
    #[test]
    fn test_parse_create_table_backticks() {
        let stmt = b"CREATE TABLE `my_table` (id INT);";
        let (typ, name) = Parser::<&[u8]>::parse_statement(stmt);
        assert_eq!(typ, StatementType::CreateTable);
        assert_eq!(name, "my_table");
    }

    // Unquoted table name in INSERT INTO.
    #[test]
    fn test_parse_insert() {
        let stmt = b"INSERT INTO posts VALUES (1, 'test');";
        let (typ, name) = Parser::<&[u8]>::parse_statement(stmt);
        assert_eq!(typ, StatementType::Insert);
        assert_eq!(name, "posts");
    }

    // Backtick-quoted table name in INSERT INTO.
    #[test]
    fn test_parse_insert_backticks() {
        let stmt = b"INSERT INTO `comments` VALUES (1);";
        let (typ, name) = Parser::<&[u8]>::parse_statement(stmt);
        assert_eq!(typ, StatementType::Insert);
        assert_eq!(name, "comments");
    }

    // ALTER TABLE yields the altered table.
    #[test]
    fn test_parse_alter_table() {
        let stmt = b"ALTER TABLE orders ADD COLUMN status INT;";
        let (typ, name) = Parser::<&[u8]>::parse_statement(stmt);
        assert_eq!(typ, StatementType::AlterTable);
        assert_eq!(name, "orders");
    }

    // A name that ends directly at `;` is still extracted.
    #[test]
    fn test_parse_drop_table() {
        let stmt = b"DROP TABLE temp_data;";
        let (typ, name) = Parser::<&[u8]>::parse_statement(stmt);
        assert_eq!(typ, StatementType::DropTable);
        assert_eq!(name, "temp_data");
    }

    // Two statements are split at `;`; the separating space stays with the
    // second statement, and the third read reports end of input.
    #[test]
    fn test_read_statement_basic() {
        let sql = b"CREATE TABLE t1 (id INT); INSERT INTO t1 VALUES (1);";
        let mut parser = Parser::new(&sql[..], 1024);

        let stmt1 = parser.read_statement().unwrap().unwrap();
        assert_eq!(stmt1, b"CREATE TABLE t1 (id INT);");

        let stmt2 = parser.read_statement().unwrap().unwrap();
        assert_eq!(stmt2, b" INSERT INTO t1 VALUES (1);");

        let stmt3 = parser.read_statement().unwrap();
        assert!(stmt3.is_none());
    }

    // A `;` inside a single-quoted string does not end the statement.
    #[test]
    fn test_read_statement_with_strings() {
        let sql = b"INSERT INTO t1 VALUES ('hello; world');";
        let mut parser = Parser::new(&sql[..], 1024);

        let stmt = parser.read_statement().unwrap().unwrap();
        assert_eq!(stmt, b"INSERT INTO t1 VALUES ('hello; world');");
    }

    // A backslash-escaped quote does not close the string (default dialect).
    #[test]
    fn test_read_statement_with_escaped_quotes() {
        let sql = b"INSERT INTO t1 VALUES ('it\\'s a test');";
        let mut parser = Parser::new(&sql[..], 1024);

        let stmt = parser.read_statement().unwrap().unwrap();
        assert_eq!(stmt, b"INSERT INTO t1 VALUES ('it\\'s a test');");
    }
}
681
// Tests for Postgres `COPY ... FROM stdin` handling: the COPY statement and
// its data block are returned by two consecutive `read_statement` calls.
#[cfg(test)]
mod copy_tests {
    use super::*;
    use std::io::Cursor;

    // The COPY statement comes first; the next read returns the data rows up
    // to and including the `\.` terminator line.
    #[test]
    fn test_copy_from_stdin_detection() {
        let data = b"COPY public.table_001 (id, col_int, col_varchar, col_text, col_decimal, created_at) FROM stdin;\n1\t6892\tvalue_1\tLorem ipsum\n\\.\n";
        let reader = Cursor::new(&data[..]);
        let mut parser = Parser::with_dialect(reader, 1024, SqlDialect::Postgres);

        let stmt1 = parser.read_statement().unwrap().unwrap();
        let s1 = String::from_utf8_lossy(&stmt1);
        assert!(s1.starts_with("COPY"), "First statement should be COPY");
        assert!(s1.contains("FROM stdin"), "Should contain FROM stdin");

        let stmt2 = parser.read_statement().unwrap().unwrap();
        let s2 = String::from_utf8_lossy(&stmt2);
        assert!(
            s2.contains("1\t6892"),
            "Data block should contain first row"
        );
        assert!(
            s2.ends_with("\\.\n"),
            "Data block should end with terminator"
        );
    }

    // Leading `--` comment lines must not prevent COPY detection, and the
    // schema prefix is dropped from the reported table name.
    #[test]
    fn test_copy_with_leading_comments() {
        let data = b"--\n-- Data for Name: table_001\n--\n\nCOPY public.table_001 (id, name) FROM stdin;\n1\tfoo\n\\.\n";
        let reader = Cursor::new(&data[..]);
        let mut parser = Parser::with_dialect(reader, 1024, SqlDialect::Postgres);

        let stmt1 = parser.read_statement().unwrap().unwrap();
        let (stmt_type, table_name) =
            Parser::<&[u8]>::parse_statement_with_dialect(&stmt1, SqlDialect::Postgres);
        assert_eq!(stmt_type, StatementType::Copy);
        assert_eq!(table_name, "table_001");

        let stmt2 = parser.read_statement().unwrap().unwrap();
        let s2 = String::from_utf8_lossy(&stmt2);
        assert!(
            s2.ends_with("\\.\n"),
            "Data block should end with terminator"
        );
    }
}
734}