pub mod mysql_insert;
pub mod postgres_copy;
use once_cell::sync::Lazy;
use regex::bytes::Regex;
use std::io::{BufRead, BufReader, Read};
/// Read-buffer size of 64 KiB; `determine_buffer_size` picks it for inputs up to 1 GiB.
pub const SMALL_BUFFER_SIZE: usize = 1 << 16;
/// Read-buffer size of 256 KiB; `determine_buffer_size` picks it for inputs above 1 GiB.
pub const MEDIUM_BUFFER_SIZE: usize = 1 << 18;
/// SQL dialect of a dump; MySQL is the default.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum SqlDialect {
    #[default]
    MySql,
    Postgres,
    Sqlite,
}

impl std::str::FromStr for SqlDialect {
    type Err = String;

    /// Parses a dialect name case-insensitively, also accepting the aliases
    /// `mariadb`, `postgresql`, `pg` and `sqlite3`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let lowered = s.to_lowercase();
        let dialect = match lowered.as_str() {
            "mysql" | "mariadb" => SqlDialect::MySql,
            "postgres" | "postgresql" | "pg" => SqlDialect::Postgres,
            "sqlite" | "sqlite3" => SqlDialect::Sqlite,
            _ => {
                return Err(format!(
                    "Unknown dialect: {}. Valid options: mysql, postgres, sqlite",
                    s
                ));
            }
        };
        Ok(dialect)
    }
}

impl std::fmt::Display for SqlDialect {
    /// Writes the canonical lowercase name (`mysql`, `postgres`, `sqlite`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            SqlDialect::MySql => "mysql",
            SqlDialect::Postgres => "postgres",
            SqlDialect::Sqlite => "sqlite",
        };
        f.write_str(name)
    }
}
/// Outcome of `detect_dialect`: the guessed dialect and how strong the evidence was.
#[derive(Clone, Debug)]
pub struct DialectDetectionResult {
    /// The most likely dialect.
    pub dialect: SqlDialect,
    /// Strength of the evidence behind `dialect`.
    pub confidence: DialectConfidence,
}
/// How strongly the markers found by `detect_dialect` point at its chosen dialect.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DialectConfidence {
    /// The winning dialect scored at least 10 points.
    High,
    /// The winning dialect scored 5 to 9 points.
    Medium,
    /// The winning dialect scored fewer than 5 points (including no markers at all).
    Low,
}

/// Per-dialect point totals accumulated while scanning a dump header.
#[derive(Default)]
struct DialectScore {
    // Points for MySQL / MariaDB markers.
    mysql: u32,
    // Points for PostgreSQL markers.
    postgres: u32,
    // Points for SQLite markers.
    sqlite: u32,
}
pub fn detect_dialect(header: &[u8]) -> DialectDetectionResult {
let mut score = DialectScore::default();
if contains_bytes(header, b"pg_dump") {
score.postgres += 10;
}
if contains_bytes(header, b"PostgreSQL database dump") {
score.postgres += 10;
}
if contains_bytes(header, b"MySQL dump") {
score.mysql += 10;
}
if contains_bytes(header, b"MariaDB dump") {
score.mysql += 10;
}
if contains_bytes(header, b"SQLite") {
score.sqlite += 10;
}
if contains_bytes(header, b"COPY ") && contains_bytes(header, b"FROM stdin") {
score.postgres += 5;
}
if contains_bytes(header, b"search_path") {
score.postgres += 5;
}
if contains_bytes(header, b"/*!40") || contains_bytes(header, b"/*!50") {
score.mysql += 5;
}
if contains_bytes(header, b"LOCK TABLES") {
score.mysql += 5;
}
if contains_bytes(header, b"PRAGMA") {
score.sqlite += 5;
}
if contains_bytes(header, b"$$") {
score.postgres += 2;
}
if contains_bytes(header, b"CREATE EXTENSION") {
score.postgres += 2;
}
if contains_bytes(header, b"BEGIN TRANSACTION") {
score.sqlite += 2;
}
if header.contains(&b'`') {
score.mysql += 2;
}
let max_score = score.mysql.max(score.postgres).max(score.sqlite);
if max_score == 0 {
return DialectDetectionResult {
dialect: SqlDialect::MySql,
confidence: DialectConfidence::Low,
};
}
let (dialect, confidence) = if score.postgres > score.mysql && score.postgres > score.sqlite {
let conf = if score.postgres >= 10 {
DialectConfidence::High
} else if score.postgres >= 5 {
DialectConfidence::Medium
} else {
DialectConfidence::Low
};
(SqlDialect::Postgres, conf)
} else if score.sqlite > score.mysql {
let conf = if score.sqlite >= 10 {
DialectConfidence::High
} else if score.sqlite >= 5 {
DialectConfidence::Medium
} else {
DialectConfidence::Low
};
(SqlDialect::Sqlite, conf)
} else {
let conf = if score.mysql >= 10 {
DialectConfidence::High
} else if score.mysql >= 5 {
DialectConfidence::Medium
} else {
DialectConfidence::Low
};
(SqlDialect::MySql, conf)
};
DialectDetectionResult {
dialect,
confidence,
}
}
/// Detects the dialect of the dump at `path` by inspecting its first 8 KiB.
///
/// # Errors
///
/// Returns any I/O error raised while opening or reading the file.
pub fn detect_dialect_from_file(path: &std::path::Path) -> std::io::Result<DialectDetectionResult> {
    use std::fs::File;
    use std::io::Read;
    // Number of leading bytes handed to `detect_dialect`.
    const HEADER_LEN: usize = 8192;
    let mut header = Vec::with_capacity(HEADER_LEN);
    // A single `read` may legally return fewer bytes than are available (pipes,
    // FIFOs, network filesystems) or fail with `Interrupted`. `take` + `read_to_end`
    // keeps reading until HEADER_LEN bytes or EOF and retries `Interrupted` reads.
    File::open(path)?
        .take(HEADER_LEN as u64)
        .read_to_end(&mut header)?;
    Ok(detect_dialect(&header))
}
/// Returns `true` if `needle` occurs anywhere in `haystack`.
///
/// An empty needle matches every haystack (like `str::contains("")`), instead of
/// panicking inside `slice::windows(0)`.
#[inline]
fn contains_bytes(haystack: &[u8], needle: &[u8]) -> bool {
    needle.is_empty()
        || haystack
            .windows(needle.len())
            .any(|window| window == needle)
}
/// Kind of SQL statement recognised by the parser.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StatementType {
    /// Not one of the recognised statement kinds.
    Unknown,
    CreateTable,
    Insert,
    CreateIndex,
    AlterTable,
    DropTable,
    /// PostgreSQL `COPY` statement.
    Copy,
}

impl StatementType {
    /// `true` for table/index DDL: CREATE TABLE, CREATE INDEX, ALTER TABLE, DROP TABLE.
    pub fn is_schema(&self) -> bool {
        [
            Self::CreateTable,
            Self::CreateIndex,
            Self::AlterTable,
            Self::DropTable,
        ]
        .contains(self)
    }

    /// `true` for row-data statements: INSERT and COPY.
    pub fn is_data(&self) -> bool {
        [Self::Insert, Self::Copy].contains(self)
    }
}
/// Statement-content selection mode; `All` is the default.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ContentFilter {
    /// No restriction (the default).
    #[default]
    All,
    /// Schema statements only.
    SchemaOnly,
    /// Data statements only.
    DataOnly,
}
// Table-name patterns used when classifying statements; capture group 1 is the
// table name. Most serve as fallbacks when `extract_table_name_flexible` fails.

/// `CREATE TABLE name`, with an optional backtick-quoted name.
static CREATE_TABLE_RE: Lazy<Regex> =
    Lazy::new(|| Regex::new(r"(?i)^\s*CREATE\s+TABLE\s+`?([^\s`(]+)`?").unwrap());
/// `INSERT INTO name`, with an optional backtick-quoted name.
static INSERT_INTO_RE: Lazy<Regex> =
    Lazy::new(|| Regex::new(r"(?i)^\s*INSERT\s+INTO\s+`?([^\s`(]+)`?").unwrap());
/// Table after the `ON` keyword of a `CREATE [UNIQUE] INDEX` statement.
///
/// `ON` must start at an (ASCII) word boundary. Without it, an index name ending in
/// "on" (`CREATE INDEX idx_session ON t`) matched first and yielded the table name
/// "ON". An optional `ONLY`, an optional schema qualifier and backtick/double quotes
/// are stripped, so `ON ONLY public."t"` yields `t`, consistent with the other
/// statement types.
static CREATE_INDEX_RE: Lazy<Regex> = Lazy::new(|| {
    Regex::new(
        r#"(?i)(?-u:\b)ON\s+(?:ONLY\s+)?(?:[`"]?\w+[`"]?\s*\.\s*)?[`"]?([^\s`"(;.]+)[`"]?"#,
    )
    .unwrap()
});
/// `ALTER TABLE name`, with an optional backtick-quoted name.
static ALTER_TABLE_RE: Lazy<Regex> =
    Lazy::new(|| Regex::new(r"(?i)ALTER\s+TABLE\s+`?([^\s`;]+)`?").unwrap());
/// `DROP TABLE [IF EXISTS] name`, with an optional backtick- or double-quoted name.
static DROP_TABLE_RE: Lazy<Regex> = Lazy::new(|| {
    Regex::new(r#"(?i)DROP\s+TABLE\s+(?:IF\s+EXISTS\s+)?[`"]?([^\s`";]+)[`"]?"#).unwrap()
});
/// `COPY [ONLY] name`; the capture may still contain a schema prefix.
static COPY_RE: Lazy<Regex> =
    Lazy::new(|| Regex::new(r#"(?i)^\s*COPY\s+(?:ONLY\s+)?[`"]?([^\s`"(]+)[`"]?"#).unwrap());
/// `CREATE TABLE [IF NOT EXISTS] [schema.]name`, quotes optional.
static CREATE_TABLE_FLEXIBLE_RE: Lazy<Regex> = Lazy::new(|| {
    Regex::new(r#"(?i)^\s*CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:[`"]?[\w]+[`"]?\s*\.\s*)?[`"]?([\w]+)[`"]?"#).unwrap()
});
/// `INSERT INTO [ONLY] [schema.]name`, quotes optional.
static INSERT_FLEXIBLE_RE: Lazy<Regex> = Lazy::new(|| {
    Regex::new(
        r#"(?i)^\s*INSERT\s+INTO\s+(?:ONLY\s+)?(?:[`"]?[\w]+[`"]?\s*\.\s*)?[`"]?([\w]+)[`"]?"#,
    )
    .unwrap()
});
pub struct Parser<R: Read> {
reader: BufReader<R>,
stmt_buffer: Vec<u8>,
dialect: SqlDialect,
in_copy_data: bool,
}
impl<R: Read> Parser<R> {
#[allow(dead_code)]
pub fn new(reader: R, buffer_size: usize) -> Self {
Self::with_dialect(reader, buffer_size, SqlDialect::default())
}
pub fn with_dialect(reader: R, buffer_size: usize, dialect: SqlDialect) -> Self {
Self {
reader: BufReader::with_capacity(buffer_size, reader),
stmt_buffer: Vec::with_capacity(32 * 1024),
dialect,
in_copy_data: false,
}
}
pub fn read_statement(&mut self) -> std::io::Result<Option<Vec<u8>>> {
if self.in_copy_data {
return self.read_copy_data();
}
self.stmt_buffer.clear();
let mut inside_single_quote = false;
let mut inside_double_quote = false;
let mut escaped = false;
let mut in_line_comment = false;
let mut in_dollar_quote = false;
let mut dollar_tag: Vec<u8> = Vec::new();
loop {
let buf = self.reader.fill_buf()?;
if buf.is_empty() {
if self.stmt_buffer.is_empty() {
return Ok(None);
}
let result = std::mem::take(&mut self.stmt_buffer);
return Ok(Some(result));
}
let mut consumed = 0;
let mut found_terminator = false;
for (i, &b) in buf.iter().enumerate() {
let inside_string = inside_single_quote || inside_double_quote || in_dollar_quote;
if in_line_comment {
if b == b'\n' {
in_line_comment = false;
}
continue;
}
if escaped {
escaped = false;
continue;
}
if b == b'\\' && inside_string && self.dialect == SqlDialect::MySql {
escaped = true;
continue;
}
if b == b'-' && !inside_string && i + 1 < buf.len() && buf[i + 1] == b'-' {
in_line_comment = true;
continue;
}
if self.dialect == SqlDialect::Postgres
&& !inside_single_quote
&& !inside_double_quote
{
if b == b'$' && !in_dollar_quote {
if let Some(end) = buf[i + 1..].iter().position(|&c| c == b'$') {
let tag_bytes = &buf[i + 1..i + 1 + end];
let is_valid_tag = if tag_bytes.is_empty() {
true
} else {
let mut iter = tag_bytes.iter();
match iter.next() {
Some(&first)
if first.is_ascii_alphabetic() || first == b'_' =>
{
iter.all(|&c| c.is_ascii_alphanumeric() || c == b'_')
}
_ => false,
}
};
if is_valid_tag {
dollar_tag = tag_bytes.to_vec();
in_dollar_quote = true;
continue;
}
}
} else if b == b'$' && in_dollar_quote {
let tag_len = dollar_tag.len();
if i + 1 + tag_len < buf.len()
&& buf[i + 1..i + 1 + tag_len] == dollar_tag[..]
&& buf.get(i + 1 + tag_len) == Some(&b'$')
{
in_dollar_quote = false;
dollar_tag.clear();
continue;
}
}
}
if b == b'\'' && !inside_double_quote && !in_dollar_quote {
inside_single_quote = !inside_single_quote;
} else if b == b'"' && !inside_single_quote && !in_dollar_quote {
inside_double_quote = !inside_double_quote;
} else if b == b';' && !inside_string {
self.stmt_buffer.extend_from_slice(&buf[..=i]);
consumed = i + 1;
found_terminator = true;
break;
}
}
if found_terminator {
self.reader.consume(consumed);
let result = std::mem::take(&mut self.stmt_buffer);
if self.dialect == SqlDialect::Postgres && self.is_copy_from_stdin(&result) {
self.in_copy_data = true;
}
return Ok(Some(result));
}
self.stmt_buffer.extend_from_slice(buf);
let len = buf.len();
self.reader.consume(len);
}
}
fn is_copy_from_stdin(&self, stmt: &[u8]) -> bool {
let stmt = strip_leading_comments_and_whitespace(stmt);
if stmt.len() < 4 {
return false;
}
let upper: Vec<u8> = stmt
.iter()
.take(500)
.map(|b| b.to_ascii_uppercase())
.collect();
upper.starts_with(b"COPY ")
&& (upper.windows(10).any(|w| w == b"FROM STDIN")
|| upper.windows(11).any(|w| w == b"FROM STDIN;"))
}
fn read_copy_data(&mut self) -> std::io::Result<Option<Vec<u8>>> {
self.stmt_buffer.clear();
loop {
let buf = self.reader.fill_buf()?;
if buf.is_empty() {
self.in_copy_data = false;
if self.stmt_buffer.is_empty() {
return Ok(None);
}
return Ok(Some(std::mem::take(&mut self.stmt_buffer)));
}
let newline_pos = buf.iter().position(|&b| b == b'\n');
if let Some(i) = newline_pos {
self.stmt_buffer.extend_from_slice(&buf[..=i]);
self.reader.consume(i + 1);
if self.ends_with_copy_terminator() {
self.in_copy_data = false;
return Ok(Some(std::mem::take(&mut self.stmt_buffer)));
}
} else {
let len = buf.len();
self.stmt_buffer.extend_from_slice(buf);
self.reader.consume(len);
}
}
}
fn ends_with_copy_terminator(&self) -> bool {
let data = &self.stmt_buffer;
if data.len() < 2 {
return false;
}
let last_newline = data[..data.len() - 1]
.iter()
.rposition(|&b| b == b'\n')
.map(|i| i + 1)
.unwrap_or(0);
let last_line = &data[last_newline..];
last_line == b"\\.\n" || last_line == b"\\.\r\n"
}
#[allow(dead_code)]
pub fn parse_statement(stmt: &[u8]) -> (StatementType, String) {
Self::parse_statement_with_dialect(stmt, SqlDialect::MySql)
}
pub fn parse_statement_with_dialect(
stmt: &[u8],
dialect: SqlDialect,
) -> (StatementType, String) {
let stmt = strip_leading_comments_and_whitespace(stmt);
if stmt.len() < 4 {
return (StatementType::Unknown, String::new());
}
let upper_prefix: Vec<u8> = stmt
.iter()
.take(25)
.map(|b| b.to_ascii_uppercase())
.collect();
if upper_prefix.starts_with(b"COPY ") {
if let Some(caps) = COPY_RE.captures(stmt) {
if let Some(m) = caps.get(1) {
let name = String::from_utf8_lossy(m.as_bytes()).into_owned();
let table_name = name.split('.').next_back().unwrap_or(&name).to_string();
return (StatementType::Copy, table_name);
}
}
}
if upper_prefix.starts_with(b"CREATE TABLE") {
if let Some(name) = extract_table_name_flexible(stmt, 12, dialect) {
return (StatementType::CreateTable, name);
}
if let Some(caps) = CREATE_TABLE_FLEXIBLE_RE.captures(stmt) {
if let Some(m) = caps.get(1) {
return (
StatementType::CreateTable,
String::from_utf8_lossy(m.as_bytes()).into_owned(),
);
}
}
if let Some(caps) = CREATE_TABLE_RE.captures(stmt) {
if let Some(m) = caps.get(1) {
return (
StatementType::CreateTable,
String::from_utf8_lossy(m.as_bytes()).into_owned(),
);
}
}
}
if upper_prefix.starts_with(b"INSERT INTO") || upper_prefix.starts_with(b"INSERT ONLY") {
if let Some(name) = extract_table_name_flexible(stmt, 11, dialect) {
return (StatementType::Insert, name);
}
if let Some(caps) = INSERT_FLEXIBLE_RE.captures(stmt) {
if let Some(m) = caps.get(1) {
return (
StatementType::Insert,
String::from_utf8_lossy(m.as_bytes()).into_owned(),
);
}
}
if let Some(caps) = INSERT_INTO_RE.captures(stmt) {
if let Some(m) = caps.get(1) {
return (
StatementType::Insert,
String::from_utf8_lossy(m.as_bytes()).into_owned(),
);
}
}
}
if upper_prefix.starts_with(b"CREATE INDEX") {
if let Some(caps) = CREATE_INDEX_RE.captures(stmt) {
if let Some(m) = caps.get(1) {
return (
StatementType::CreateIndex,
String::from_utf8_lossy(m.as_bytes()).into_owned(),
);
}
}
}
if upper_prefix.starts_with(b"ALTER TABLE") {
if let Some(name) = extract_table_name_flexible(stmt, 11, dialect) {
return (StatementType::AlterTable, name);
}
if let Some(caps) = ALTER_TABLE_RE.captures(stmt) {
if let Some(m) = caps.get(1) {
return (
StatementType::AlterTable,
String::from_utf8_lossy(m.as_bytes()).into_owned(),
);
}
}
}
if upper_prefix.starts_with(b"DROP TABLE") {
if let Some(name) = extract_table_name_flexible(stmt, 10, dialect) {
return (StatementType::DropTable, name);
}
if let Some(caps) = DROP_TABLE_RE.captures(stmt) {
if let Some(m) = caps.get(1) {
return (
StatementType::DropTable,
String::from_utf8_lossy(m.as_bytes()).into_owned(),
);
}
}
}
(StatementType::Unknown, String::new())
}
}
/// Returns `data` without its leading spaces, tabs, `\n` and `\r` bytes
/// (other ASCII whitespace such as form feed is kept).
#[inline]
fn trim_ascii_start(data: &[u8]) -> &[u8] {
    let mut rest = data;
    while let [b' ' | b'\t' | b'\n' | b'\r', tail @ ..] = rest {
        rest = tail;
    }
    rest
}
/// Skips any mix of leading whitespace, `-- ...` and `# ...` line comments, and
/// (possibly nested) `/* ... */` block comments, returning what follows.
///
/// An unterminated comment consumes the rest of the input, yielding an empty slice.
fn strip_leading_comments_and_whitespace(mut data: &[u8]) -> &[u8] {
    loop {
        data = trim_ascii_start(data);
        data = match data {
            // Line comments run through the next newline.
            [b'-', b'-', ..] | [b'#', ..] => match data.iter().position(|&b| b == b'\n') {
                Some(newline) => &data[newline + 1..],
                None => return &[],
            },
            // Block comments may nest; scan until the outermost one closes.
            [b'/', b'*', ..] => {
                let mut depth: usize = 1;
                let mut pos = 2;
                while depth > 0 && pos + 1 < data.len() {
                    match (data[pos], data[pos + 1]) {
                        (b'*', b'/') => {
                            depth -= 1;
                            pos += 2;
                        }
                        (b'/', b'*') => {
                            depth += 1;
                            pos += 2;
                        }
                        _ => pos += 1,
                    }
                }
                if depth > 0 {
                    return &[];
                }
                &data[pos..]
            }
            _ => return data,
        };
    }
}
/// Extracts the unqualified table name that follows a statement keyword.
///
/// Scanning starts at `offset` (just past e.g. `CREATE TABLE`). The function skips:
/// - leading whitespace,
/// - an optional `IF EXISTS` / `IF NOT EXISTS` clause,
/// - an optional `ONLY` keyword.
///
/// Keywords match ASCII case-insensitively with any whitespace between them. The
/// name may be schema-qualified and quoted (`a.b`, `"a"."b"`, `` `a`.`b` ``); the
/// last component is returned, including when it runs to the end of the input.
///
/// Double quotes delimit identifiers in every dialect. Backticks delimit them for
/// MySQL and SQLite; SQLite accepts backticks for MySQL compatibility. Returns `None`
/// when no name can be read, including for an unterminated quoted identifier.
fn extract_table_name_flexible(stmt: &[u8], offset: usize, dialect: SqlDialect) -> Option<String> {
    let mut i = skip_whitespace_from(stmt, offset);

    // Optional `IF [NOT] EXISTS`; `i` only advances when the whole clause matches.
    if let Some(after_if) = skip_keyword(stmt, i, b"IF") {
        let after_not = skip_keyword(stmt, after_if, b"NOT").unwrap_or(after_if);
        if let Some(after_exists) = skip_keyword(stmt, after_not, b"EXISTS") {
            i = after_exists;
        }
    }
    // `ONLY` must be followed by whitespace, so `only(` is still read as a name.
    if let Some(after_only) = skip_keyword(stmt, i, b"ONLY").filter(|&next| next > i + 4) {
        i = after_only;
    }

    let mut name = None;
    loop {
        let quote = match stmt.get(i) {
            Some(&b'"') => Some(b'"'),
            Some(&b'`') if dialect != SqlDialect::Postgres => Some(b'`'),
            _ => None,
        };
        let part: &[u8] = if let Some(q) = quote {
            let start = i + 1;
            let len = stmt[start..].iter().position(|&b| b == q)?;
            i = start + len + 1;
            &stmt[start..start + len]
        } else {
            let start = i;
            while i < stmt.len()
                && !is_whitespace(stmt[i])
                && !matches!(stmt[i], b'(' | b';' | b',' | b'.')
            {
                i += 1;
            }
            if i == start {
                break;
            }
            &stmt[start..i]
        };
        name = Some(String::from_utf8_lossy(part).into_owned());

        // Continue only across a `.` separating qualified-name components.
        i = skip_whitespace_from(stmt, i);
        if stmt.get(i) != Some(&b'.') {
            break;
        }
        i = skip_whitespace_from(stmt, i + 1);
    }
    name
}

/// Returns the index of the first non-whitespace byte at or after `i`
/// (`i` itself when it is already at or past the end).
fn skip_whitespace_from(stmt: &[u8], mut i: usize) -> usize {
    while i < stmt.len() && is_whitespace(stmt[i]) {
        i += 1;
    }
    i
}

/// If `keyword` (ASCII case-insensitive) starts at `stmt[i]` and is not followed by
/// an identifier byte, returns the index past it and any whitespace that follows.
fn skip_keyword(stmt: &[u8], i: usize, keyword: &[u8]) -> Option<usize> {
    let end = i.checked_add(keyword.len())?;
    if !stmt.get(i..end)?.eq_ignore_ascii_case(keyword) {
        return None;
    }
    match stmt.get(end) {
        Some(&b) if b.is_ascii_alphanumeric() || b == b'_' => None,
        _ => Some(skip_whitespace_from(stmt, end)),
    }
}
/// `true` for space, tab, line feed and carriage return; other ASCII whitespace
/// (form feed, vertical tab) is not included.
#[inline]
fn is_whitespace(b: u8) -> bool {
    b == b' ' || b == b'\t' || b == b'\n' || b == b'\r'
}
/// Picks a read-buffer size for a file of `file_size` bytes.
///
/// Files up to and including 1 GiB get `SMALL_BUFFER_SIZE`; larger files get
/// `MEDIUM_BUFFER_SIZE`.
pub fn determine_buffer_size(file_size: u64) -> usize {
    const ONE_GIB: u64 = 1 << 30;
    match file_size {
        0..=ONE_GIB => SMALL_BUFFER_SIZE,
        _ => MEDIUM_BUFFER_SIZE,
    }
}