use once_cell::sync::Lazy;
use regex::Regex;
/// Upper bound on the number of rows placed in one generated `INSERT`
/// statement; `copy_to_inserts` splits the parsed rows into chunks of this size.
const MAX_ROWS_PER_INSERT: usize = 1000;
/// Target of a `COPY ... FROM stdin` statement, as extracted by
/// `parse_copy_header`.
#[derive(Debug, Clone)]
pub struct CopyHeader {
/// Schema qualifier, if the statement wrote the table as `schema.table`.
pub schema: Option<String>,
/// Table name with surrounding double quotes removed.
pub table: String,
/// Column names in the order listed in the statement; empty when the
/// statement has no parenthesised column list.
pub columns: Vec<String>,
}
/// Parses the header of a `COPY [schema.]table [(columns)] FROM stdin` statement.
///
/// Leading SQL comments are removed before matching. The keyword match is
/// case-insensitive, and identifiers may optionally be wrapped in double
/// quotes. Column names are trimmed and stripped of surrounding `"` and
/// backtick characters.
///
/// Returns `None` when the statement does not match the expected shape.
pub fn parse_copy_header(stmt: &str) -> Option<CopyHeader> {
    // Compiled once on first use. The pattern is a fixed literal, so `unwrap`
    // can only fail if the pattern itself is edited into something invalid.
    static RE_COPY: Lazy<Regex> = Lazy::new(|| {
        Regex::new(
            r#"(?i)^\s*COPY\s+(?:ONLY\s+)?(?:"?(\w+)"?\.)?["]?(\w+)["]?\s*(?:\(([^)]+)\))?\s+FROM\s+stdin"#
        ).unwrap()
    });

    let cleaned = strip_leading_comments(stmt);
    let caps = RE_COPY.captures(&cleaned)?;

    // Capture groups: 1 = schema (optional), 2 = table, 3 = column list (optional).
    let table = caps.get(2)?.as_str().to_string();
    let schema = caps.get(1).map(|found| found.as_str().to_string());
    let columns: Vec<String> = match caps.get(3) {
        Some(list) => list
            .as_str()
            .split(',')
            .map(|raw| raw.trim().trim_matches('"').trim_matches('`').to_string())
            .collect(),
        None => Vec::new(),
    };

    Some(CopyHeader {
        schema,
        table,
        columns,
    })
}
/// Removes any run of leading SQL comments (and surrounding whitespace) from
/// `stmt` and returns what is left as an owned string.
///
/// Both `-- line` comments and `/* block */` comments are recognised, in any
/// order and any number. If a leading comment is never terminated (a `--`
/// comment with no newline after it, or a `/*` with no closing `*/`), the whole
/// input is considered to be comment and an empty string is returned.
/// Comments that appear after the first non-comment token are left untouched.
fn strip_leading_comments(stmt: &str) -> String {
    let mut rest = stmt.trim();
    loop {
        if rest.starts_with("--") {
            match rest.find('\n') {
                Some(newline) => rest = rest[newline + 1..].trim(),
                None => return String::new(),
            }
        } else if let Some(body) = rest.strip_prefix("/*") {
            // Look for the terminator only *after* the opening delimiter.
            // Searching from the start would let the `*` of `/*` pair with a
            // following `/`, so `/*/` would wrongly count as a closed comment.
            match body.find("*/") {
                Some(end) => rest = body[end + 2..].trim(),
                None => return String::new(),
            }
        } else {
            return rest.to_string();
        }
    }
}
/// Converts the data section of a `COPY ... FROM stdin` block into `INSERT`
/// statements for `target_dialect`.
///
/// * `header` - table (and optional schema / column list) the rows belong to.
/// * `data` - raw COPY text-format payload; parsed with `parse_copy_data`.
/// * `target_dialect` - decides identifier quoting (backticks for MySQL,
///   double quotes otherwise) and how values are escaped.
///
/// Returns one UTF-8 encoded statement per batch of at most
/// `MAX_ROWS_PER_INSERT` rows, or an empty vector when `data` has no rows.
pub fn copy_to_inserts(
    header: &CopyHeader,
    data: &[u8],
    target_dialect: crate::parser::SqlDialect,
) -> Vec<Vec<u8>> {
    let rows = parse_copy_data(data);
    if rows.is_empty() {
        return Vec::new();
    }

    let is_mysql = matches!(target_dialect, crate::parser::SqlDialect::MySql);
    let quote_char = if is_mysql { '`' } else { '"' };
    let quoted = |ident: &str| -> String { format!("{}{}{}", quote_char, ident, quote_char) };

    // The schema qualifier is kept only for non-MySQL targets, and only when
    // it is something other than `public` / `pg_catalog`.
    let table_ref = match header.schema.as_deref() {
        Some(schema) if !is_mysql && schema != "public" && schema != "pg_catalog" => {
            format!("{}.{}", quoted(schema), quoted(header.table.as_str()))
        }
        _ => quoted(header.table.as_str()),
    };

    let columns_str = if header.columns.is_empty() {
        String::new()
    } else {
        let quoted_cols: Vec<String> = header
            .columns
            .iter()
            .map(|col| quoted(col.as_str()))
            .collect();
        format!(" ({})", quoted_cols.join(", "))
    };

    rows.chunks(MAX_ROWS_PER_INSERT)
        .map(|batch| {
            // One parenthesised tuple per row, values separated by ", ".
            let tuples: Vec<String> = batch
                .iter()
                .map(|row| {
                    let rendered: Vec<String> = row
                        .iter()
                        .map(|value| format_value(value, target_dialect))
                        .collect();
                    format!("({})", rendered.join(", "))
                })
                .collect();

            let mut statement = format!("INSERT INTO {}{} VALUES\n", table_ref, columns_str);
            statement.push_str(&tuples.join(",\n"));
            statement.push(';');
            statement.into_bytes()
        })
        .collect()
}
/// A single field parsed from a COPY text-format row.
#[derive(Debug, Clone)]
pub enum CopyValue {
/// The field was the `\N` marker; rendered as SQL `NULL`.
Null,
/// Any other field, with backslash escapes already decoded.
Text(String),
}
/// Splits a COPY text-format payload into rows of decoded values.
///
/// Lines are separated by `\n`; a single trailing `\r` on a line is treated
/// as part of the line terminator so that CRLF dumps parse the same as LF
/// dumps (a carriage return that belongs to the data is written as the
/// two-character escape `\r` in COPY text format, never as a raw byte at the
/// end of a line). Empty lines and the `\.` end-of-data marker are skipped.
///
/// Returns one `Vec<CopyValue>` per data line, in input order.
pub fn parse_copy_data(data: &[u8]) -> Vec<Vec<CopyValue>> {
    let mut rows = Vec::new();
    for raw_line in data.split(|&byte| byte == b'\n') {
        // Drop the CR of a CRLF terminator; without this, `\.` followed by CR
        // is not recognised and the last column of every row keeps a stray CR.
        let line = match raw_line.split_last() {
            Some((&b'\r', head)) => head,
            _ => raw_line,
        };
        if line.is_empty() || line == b"\\." {
            continue;
        }
        let row = parse_row(line);
        if !row.is_empty() {
            rows.push(row);
        }
    }
    rows
}
fn parse_row(line: &[u8]) -> Vec<CopyValue> {
let mut values = Vec::new();
let mut start = 0;
for (i, &b) in line.iter().enumerate() {
if b == b'\t' {
values.push(parse_value(&line[start..i]));
start = i + 1;
}
}
if start <= line.len() {
values.push(parse_value(&line[start..]));
}
values
}
/// Interprets a single raw COPY field: the exact bytes `\N` mean NULL, and
/// anything else is text whose backslash escapes are decoded.
fn parse_value(value: &[u8]) -> CopyValue {
    match value {
        b"\\N" => CopyValue::Null,
        raw => CopyValue::Text(decode_escapes(raw)),
    }
}
/// Decodes the backslash escapes of a COPY text-format field into a `String`.
///
/// Recognised escapes:
/// * `\n`, `\r`, `\t`, `\\`, `\b`, `\f`, `\v` - the usual control characters;
/// * `\` followed by one to three octal digits - the byte with that value
///   (values above 0o377 wrap to a single byte).
///
/// Any other backslash sequence, and a lone trailing backslash, is kept
/// verbatim. Decoding is done at the byte level and the result is converted
/// to UTF-8 once at the end, so octal escapes can spell out multi-byte
/// characters and invalid sequences become U+FFFD without disturbing the
/// valid text around them.
fn decode_escapes(value: &[u8]) -> String {
    let mut bytes: Vec<u8> = Vec::with_capacity(value.len());
    let mut i = 0;
    while i < value.len() {
        let current = value[i];
        // Plain byte, or a backslash with nothing after it: copy through.
        if current != b'\\' || i + 1 >= value.len() {
            bytes.push(current);
            i += 1;
            continue;
        }

        let next = value[i + 1];
        let simple = match next {
            b'n' => Some(b'\n'),
            b'r' => Some(b'\r'),
            b't' => Some(b'\t'),
            b'\\' => Some(b'\\'),
            b'b' => Some(0x08),
            b'f' => Some(0x0C),
            b'v' => Some(0x0B),
            _ => None,
        };
        if let Some(byte) = simple {
            bytes.push(byte);
            i += 2;
            continue;
        }

        // Up to three octal digits directly after the backslash.
        let digits = value[i + 1..]
            .iter()
            .take(3)
            .take_while(|d| (b'0'..=b'7').contains(*d))
            .count();
        if digits > 0 {
            // Wrapping arithmetic keeps the low eight bits for `\400`..`\777`
            // instead of overflowing (which panics in debug builds).
            let code = value[i + 1..i + 1 + digits]
                .iter()
                .fold(0u8, |acc, d| acc.wrapping_mul(8).wrapping_add(d - b'0'));
            bytes.push(code);
            i += 1 + digits;
            continue;
        }

        // Unknown escape: keep the backslash and the following byte as-is.
        bytes.push(b'\\');
        bytes.push(next);
        i += 2;
    }
    String::from_utf8_lossy(&bytes).into_owned()
}
/// Renders a parsed COPY value as a SQL literal for `dialect`.
///
/// `Null` becomes the bare keyword `NULL`. Text is wrapped in single quotes:
/// for MySQL, backslash, single quote, newline, carriage return, tab and NUL
/// are backslash-escaped; for every other dialect only single quotes are
/// escaped, by doubling them.
fn format_value(value: &CopyValue, dialect: crate::parser::SqlDialect) -> String {
    let text = match value {
        CopyValue::Null => return "NULL".to_string(),
        CopyValue::Text(text) => text,
    };

    let mut literal = String::with_capacity(text.len() + 2);
    literal.push('\'');
    if matches!(dialect, crate::parser::SqlDialect::MySql) {
        for ch in text.chars() {
            match ch {
                '\\' => literal.push_str("\\\\"),
                '\'' => literal.push_str("\\'"),
                '\n' => literal.push_str("\\n"),
                '\r' => literal.push_str("\\r"),
                '\t' => literal.push_str("\\t"),
                '\0' => literal.push_str("\\0"),
                other => literal.push(other),
            }
        }
    } else {
        for ch in text.chars() {
            if ch == '\'' {
                literal.push_str("''");
            } else {
                literal.push(ch);
            }
        }
    }
    literal.push('\'');
    literal
}