use crate::schema::{ColumnId, ColumnType, TableSchema};
use ahash::AHashSet;
use smallvec::SmallVec;
/// A single component of a primary-key or foreign-key tuple.
///
/// Integers that fit in `i64` use `Int`; wider integers use `BigInt`; everything
/// else (strings, hex literals, opaque expressions) is kept as `Text`.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq, Hash)]
pub enum PkValue {
    Int(i64),
    BigInt(i128),
    Text(Box<str>),
    Null,
}

impl PkValue {
    /// Returns `true` only for the SQL `NULL` component.
    pub fn is_null(&self) -> bool {
        matches!(*self, Self::Null)
    }
}
/// Set of 64-bit key digests (presumably produced by `hash_pk_tuple`; verify at call sites).
pub type PkHashSet = AHashSet<u64>;
/// Ordered key tuple; up to two components are stored inline without heap allocation.
pub type PkTuple = SmallVec<[PkValue; 2]>;
/// Set of full key tuples.
pub type PkSet = AHashSet<PkTuple>;
/// Computes a 64-bit digest of a key tuple with `ahash::AHasher::default()`.
///
/// The tuple length (truncated to a `u8`) is hashed first. Each component then
/// contributes a one-byte variant tag followed by its payload, so components of
/// different variants never feed identical byte sequences into the hasher.
pub fn hash_pk_tuple(pk: &PkTuple) -> u64 {
    use std::hash::{Hash, Hasher};

    let mut state = ahash::AHasher::default();
    (pk.len() as u8).hash(&mut state);
    for component in pk.iter() {
        let tag: u8 = match component {
            PkValue::Int(_) => 0,
            PkValue::BigInt(_) => 1,
            PkValue::Text(_) => 2,
            PkValue::Null => 3,
        };
        tag.hash(&mut state);
        match component {
            PkValue::Int(n) => n.hash(&mut state),
            PkValue::BigInt(n) => n.hash(&mut state),
            PkValue::Text(text) => text.hash(&mut state),
            PkValue::Null => {}
        }
    }
    state.finish()
}
/// Identifies one foreign key: the owning table's id plus the FK's index within
/// that table's foreign-key list.
#[derive(Clone, Copy)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct FkRef {
    /// Id of the table that declares the foreign key.
    pub table_id: u32,
    /// Position of the foreign key in the declaring table's FK list.
    pub fk_index: u16,
}
/// One row tuple parsed out of an `INSERT ... VALUES` statement.
#[derive(Clone)]
#[derive(Debug)]
pub struct ParsedRow {
    /// Raw bytes of the tuple, from the opening `(` through the closing `)`.
    pub raw: Vec<u8>,
    /// Parsed values in statement order.
    pub values: Vec<ParsedValue>,
    /// Primary-key tuple, or `None` if no PK value was found or any component is NULL.
    pub pk: Option<PkTuple>,
    /// Non-NULL foreign-key tuples found in this row.
    pub fk_values: Vec<(FkRef, PkTuple)>,
    /// Every value converted to a `PkValue`, in statement order.
    pub all_values: Vec<PkValue>,
    /// For each schema column index, the position of its value in `all_values`.
    pub column_map: Vec<Option<usize>>,
}

impl ParsedRow {
    /// Looks up the converted value for the schema column at `schema_col_index`.
    ///
    /// Returns `None` when the index is out of range, the column was not supplied
    /// by the statement, or the row holds fewer values than the mapping expects.
    pub fn get_column_value(&self, schema_col_index: usize) -> Option<&PkValue> {
        let value_index = (*self.column_map.get(schema_col_index)?)?;
        self.all_values.get(value_index)
    }
}
/// Cursor-based parser for the row tuples of a single `INSERT ... VALUES` statement.
pub struct InsertParser<'a> {
    /// Schema of the target table; `None` disables PK/FK extraction.
    table_schema: Option<&'a TableSchema>,
    /// Schema column id for each value position (`None` for unknown columns).
    column_order: Vec<Option<ColumnId>>,
    /// The full statement as raw bytes.
    stmt: &'a [u8],
    /// Current byte offset into `stmt`.
    pos: usize,
}
impl<'a> InsertParser<'a> {
pub fn new(stmt: &'a [u8]) -> Self {
Self {
stmt,
pos: 0,
table_schema: None,
column_order: Vec::new(),
}
}
pub fn with_schema(mut self, schema: &'a TableSchema) -> Self {
self.table_schema = Some(schema);
self
}
pub fn parse_rows(&mut self) -> anyhow::Result<Vec<ParsedRow>> {
let values_pos = self.find_values_keyword()?;
self.pos = values_pos;
self.parse_column_list();
let mut rows = Vec::new();
while self.pos < self.stmt.len() {
self.skip_whitespace();
if self.pos >= self.stmt.len() {
break;
}
if self.stmt[self.pos] == b'(' {
if let Some(row) = self.parse_row()? {
rows.push(row);
}
} else if self.stmt[self.pos] == b',' {
self.pos += 1;
} else if self.stmt[self.pos] == b';' {
break;
} else {
self.pos += 1;
}
}
Ok(rows)
}
fn find_values_keyword(&self) -> anyhow::Result<usize> {
let stmt_str = String::from_utf8_lossy(self.stmt);
let upper = stmt_str.to_uppercase();
if let Some(pos) = upper.find("VALUES") {
Ok(pos + 6) } else {
anyhow::bail!("INSERT statement missing VALUES keyword")
}
}
fn parse_column_list(&mut self) {
if self.table_schema.is_none() {
return;
}
let schema = self.table_schema.unwrap();
let before_values = &self.stmt[..self.pos.saturating_sub(6)];
let stmt_str = String::from_utf8_lossy(before_values);
if let Some(close_paren) = stmt_str.rfind(')') {
if let Some(open_paren) = stmt_str[..close_paren].rfind('(') {
let col_list = &stmt_str[open_paren + 1..close_paren];
if !col_list.to_uppercase().contains("SELECT") {
let cols: Vec<&str> = col_list.split(',').collect();
self.column_order = cols
.iter()
.map(|c| {
let name = c
.trim()
.trim_matches('`')
.trim_matches('"')
.trim_matches('[')
.trim_matches(']');
schema.get_column_id(name)
})
.collect();
return;
}
}
}
self.column_order = schema.columns.iter().map(|c| Some(c.ordinal)).collect();
}
fn parse_row(&mut self) -> anyhow::Result<Option<ParsedRow>> {
self.skip_whitespace();
if self.pos >= self.stmt.len() || self.stmt[self.pos] != b'(' {
return Ok(None);
}
let start = self.pos;
self.pos += 1;
let mut values: Vec<ParsedValue> = Vec::new();
let mut depth = 1;
while self.pos < self.stmt.len() && depth > 0 {
self.skip_whitespace();
if self.pos >= self.stmt.len() {
break;
}
match self.stmt[self.pos] {
b'(' => {
depth += 1;
self.pos += 1;
}
b')' => {
depth -= 1;
self.pos += 1;
}
b',' if depth == 1 => {
self.pos += 1;
}
_ if depth == 1 => {
values.push(self.parse_value()?);
}
_ => {
self.pos += 1;
}
}
}
let end = self.pos;
let raw = self.stmt[start..end].to_vec();
let (pk, fk_values, all_values, column_map) = if let Some(schema) = self.table_schema {
let (pk, fk_values, all_values) = self.extract_pk_fk(&values, schema);
let column_map = self.build_column_map(schema);
(pk, fk_values, all_values, column_map)
} else {
(None, Vec::new(), Vec::new(), Vec::new())
};
Ok(Some(ParsedRow {
raw,
values,
pk,
fk_values,
all_values,
column_map,
}))
}
fn parse_value(&mut self) -> anyhow::Result<ParsedValue> {
self.skip_whitespace();
if self.pos >= self.stmt.len() {
return Ok(ParsedValue::Null);
}
let b = self.stmt[self.pos];
if self.pos + 4 <= self.stmt.len() {
let word = &self.stmt[self.pos..self.pos + 4];
if word.eq_ignore_ascii_case(b"NULL") {
self.pos += 4;
return Ok(ParsedValue::Null);
}
}
if b == b'\'' {
return self.parse_string_value();
}
if (b == b'N' || b == b'n')
&& self.pos + 1 < self.stmt.len()
&& self.stmt[self.pos + 1] == b'\''
{
self.pos += 1; return self.parse_string_value();
}
if b == b'0' && self.pos + 1 < self.stmt.len() {
let next = self.stmt[self.pos + 1];
if next == b'x' || next == b'X' {
return self.parse_hex_value();
}
}
self.parse_number_value()
}
fn parse_string_value(&mut self) -> anyhow::Result<ParsedValue> {
self.pos += 1;
let mut value = Vec::new();
let mut escape_next = false;
while self.pos < self.stmt.len() {
let b = self.stmt[self.pos];
if escape_next {
let escaped = match b {
b'n' => b'\n',
b'r' => b'\r',
b't' => b'\t',
b'0' => 0,
_ => b, };
value.push(escaped);
escape_next = false;
self.pos += 1;
} else if b == b'\\' {
escape_next = true;
self.pos += 1;
} else if b == b'\'' {
if self.pos + 1 < self.stmt.len() && self.stmt[self.pos + 1] == b'\'' {
value.push(b'\'');
self.pos += 2;
} else {
self.pos += 1; break;
}
} else {
value.push(b);
self.pos += 1;
}
}
let text = String::from_utf8_lossy(&value).into_owned();
Ok(ParsedValue::String { value: text })
}
fn parse_hex_value(&mut self) -> anyhow::Result<ParsedValue> {
let start = self.pos;
self.pos += 2;
while self.pos < self.stmt.len() {
let b = self.stmt[self.pos];
if b.is_ascii_hexdigit() {
self.pos += 1;
} else {
break;
}
}
let raw = self.stmt[start..self.pos].to_vec();
Ok(ParsedValue::Hex(raw))
}
fn parse_number_value(&mut self) -> anyhow::Result<ParsedValue> {
let start = self.pos;
let mut has_dot = false;
if self.pos < self.stmt.len() && self.stmt[self.pos] == b'-' {
self.pos += 1;
}
while self.pos < self.stmt.len() {
let b = self.stmt[self.pos];
if b.is_ascii_digit() {
self.pos += 1;
} else if b == b'.' && !has_dot {
has_dot = true;
self.pos += 1;
} else if b == b'e' || b == b'E' {
self.pos += 1;
if self.pos < self.stmt.len()
&& (self.stmt[self.pos] == b'+' || self.stmt[self.pos] == b'-')
{
self.pos += 1;
}
} else if b == b',' || b == b')' || b.is_ascii_whitespace() {
break;
} else {
while self.pos < self.stmt.len() {
let c = self.stmt[self.pos];
if c == b',' || c == b')' {
break;
}
self.pos += 1;
}
break;
}
}
let raw = self.stmt[start..self.pos].to_vec();
let value_str = String::from_utf8_lossy(&raw);
if !has_dot {
if let Ok(n) = value_str.parse::<i64>() {
return Ok(ParsedValue::Integer(n));
}
if let Ok(n) = value_str.parse::<i128>() {
return Ok(ParsedValue::BigInteger(n));
}
}
Ok(ParsedValue::Other(raw))
}
fn skip_whitespace(&mut self) {
while self.pos < self.stmt.len() {
let b = self.stmt[self.pos];
if b.is_ascii_whitespace() {
self.pos += 1;
} else {
break;
}
}
}
fn extract_pk_fk(
&self,
values: &[ParsedValue],
schema: &TableSchema,
) -> (Option<PkTuple>, Vec<(FkRef, PkTuple)>, Vec<PkValue>) {
let mut pk_values = PkTuple::new();
let mut fk_values = Vec::new();
let all_values: Vec<PkValue> = values
.iter()
.enumerate()
.map(|(idx, v)| {
let col = self
.column_order
.get(idx)
.and_then(|c| *c)
.and_then(|id| schema.column(id));
self.value_to_pk(v, col)
})
.collect();
for (idx, col_id_opt) in self.column_order.iter().enumerate() {
if let Some(col_id) = col_id_opt {
if schema.is_pk_column(*col_id) {
if let Some(value) = values.get(idx) {
let pk_val = self.value_to_pk(value, schema.column(*col_id));
pk_values.push(pk_val);
}
}
}
}
for (fk_idx, fk) in schema.foreign_keys.iter().enumerate() {
if fk.referenced_table_id.is_none() {
continue;
}
let mut fk_tuple = PkTuple::new();
let mut all_non_null = true;
for &col_id in &fk.columns {
if let Some(idx) = self.column_order.iter().position(|&c| c == Some(col_id)) {
if let Some(value) = values.get(idx) {
let pk_val = self.value_to_pk(value, schema.column(col_id));
if pk_val.is_null() {
all_non_null = false;
break;
}
fk_tuple.push(pk_val);
}
}
}
if all_non_null && !fk_tuple.is_empty() {
fk_values.push((
FkRef {
table_id: schema.id.0,
fk_index: fk_idx as u16,
},
fk_tuple,
));
}
}
let pk = if pk_values.is_empty() || pk_values.iter().any(|v| v.is_null()) {
None
} else {
Some(pk_values)
};
(pk, fk_values, all_values)
}
fn build_column_map(&self, schema: &TableSchema) -> Vec<Option<usize>> {
let mut map = vec![None; schema.columns.len()];
for (val_idx, col_id_opt) in self.column_order.iter().enumerate() {
if let Some(col_id) = col_id_opt {
let ordinal = col_id.0 as usize;
if ordinal < map.len() {
map[ordinal] = Some(val_idx);
}
}
}
map
}
fn value_to_pk(&self, value: &ParsedValue, col: Option<&crate::schema::Column>) -> PkValue {
match value {
ParsedValue::Null => PkValue::Null,
ParsedValue::Integer(n) => PkValue::Int(*n),
ParsedValue::BigInteger(n) => PkValue::BigInt(*n),
ParsedValue::String { value } => {
if let Some(col) = col {
match col.col_type {
ColumnType::Int => {
if let Ok(n) = value.parse::<i64>() {
return PkValue::Int(n);
}
}
ColumnType::BigInt => {
if let Ok(n) = value.parse::<i128>() {
return PkValue::BigInt(n);
}
}
_ => {}
}
}
PkValue::Text(value.clone().into_boxed_str())
}
ParsedValue::Hex(raw) => {
PkValue::Text(String::from_utf8_lossy(raw).into_owned().into_boxed_str())
}
ParsedValue::Other(raw) => {
PkValue::Text(String::from_utf8_lossy(raw).into_owned().into_boxed_str())
}
}
}
}
/// A single value from an `INSERT` row tuple, classified by its literal form.
#[derive(Clone)]
#[derive(Debug)]
pub enum ParsedValue {
    /// The `NULL` keyword.
    Null,
    /// Integer literal that fits in `i64`.
    Integer(i64),
    /// Integer literal that fits in `i128` but not `i64`.
    BigInteger(i128),
    /// Quoted string with escapes resolved.
    String { value: String },
    /// Raw bytes of a `0x...` literal, prefix included.
    Hex(Vec<u8>),
    /// Raw bytes of anything else (decimals, expressions, function calls, ...).
    Other(Vec<u8>),
}
/// Parses every row of a MySQL `INSERT` statement against `schema`, producing
/// PK/FK tuples and column maps alongside the parsed values.
///
/// # Errors
///
/// Propagates errors from [`InsertParser::parse_rows`] (e.g. a missing `VALUES` keyword).
pub fn parse_mysql_insert_rows(
    stmt: &[u8],
    schema: &TableSchema,
) -> anyhow::Result<Vec<ParsedRow>> {
    InsertParser::new(stmt).with_schema(schema).parse_rows()
}
/// Parses every row of an `INSERT` statement without schema information; rows
/// carry only their raw bytes and parsed values.
///
/// # Errors
///
/// Propagates errors from [`InsertParser::parse_rows`] (e.g. a missing `VALUES` keyword).
pub fn parse_mysql_insert_rows_raw(stmt: &[u8]) -> anyhow::Result<Vec<ParsedRow>> {
    InsertParser::new(stmt).parse_rows()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `parse_insert_for_bulk` on `sql` and checks the table name, the optional
    /// column list and the number of parsed rows.
    fn assert_bulk(sql: &[u8], table: &str, columns: Option<&[&str]>, row_count: usize) {
        let parsed = parse_insert_for_bulk(sql).unwrap();
        assert_eq!(parsed.table, table);
        let expected: Option<Vec<String>> =
            columns.map(|names| names.iter().map(|n| n.to_string()).collect());
        assert_eq!(parsed.columns, expected);
        assert_eq!(parsed.rows.len(), row_count);
    }

    #[test]
    fn test_parse_insert_for_bulk_simple() {
        assert_bulk(b"INSERT INTO users VALUES (1, 'Alice')", "users", None, 1);
    }

    #[test]
    fn test_parse_insert_for_bulk_with_columns() {
        assert_bulk(
            b"INSERT INTO users (name, id) VALUES ('Alice', 1)",
            "users",
            Some(&["name", "id"][..]),
            1,
        );
    }

    #[test]
    fn test_parse_insert_for_bulk_mssql() {
        assert_bulk(
            b"INSERT INTO [dbo].[users] ([email], [name]) VALUES (N'alice@example.com', N'Alice')",
            "users",
            Some(&["email", "name"][..]),
            1,
        );
    }

    #[test]
    fn test_parse_insert_for_bulk_mysql() {
        assert_bulk(
            b"INSERT INTO `users` (`id`, `name`) VALUES (1, 'Bob')",
            "users",
            Some(&["id", "name"][..]),
            1,
        );
    }
}
/// Dialect-agnostic result of [`parse_insert_for_bulk`]: the target table, its
/// optional explicit column list, and each row's parsed values in statement order.
#[derive(Clone)]
#[derive(Debug)]
pub struct InsertValues {
    /// Unqualified, unquoted table name.
    pub table: String,
    /// Column names as written in the statement, or `None` when no list was given.
    pub columns: Option<Vec<String>>,
    /// One entry per row tuple.
    pub rows: Vec<Vec<ParsedValue>>,
}
/// Parses an `INSERT` statement into its table name, optional column list and row values.
///
/// The table name is reduced to its last dot-separated part with identifier quotes
/// (`` ` ``, `"`, `[...]`) removed; no schema is needed.
///
/// # Errors
///
/// Fails if the statement is not an `INSERT`, has no table reference, or has no
/// `VALUES` keyword.
pub fn parse_insert_for_bulk(stmt: &[u8]) -> anyhow::Result<InsertValues> {
    let stmt_str = String::from_utf8_lossy(stmt);
    // ASCII-only uppercasing keeps byte offsets identical between `upper` and
    // `stmt_str`. Unicode `to_uppercase` can change byte lengths (e.g. `ß` -> `SS`),
    // which made keyword offsets slice `stmt_str` in the wrong place or mid-char.
    let upper = stmt_str.to_ascii_uppercase();
    let table = extract_insert_table_name(&stmt_str, &upper)?;
    let columns = extract_column_list(&stmt_str, &upper);
    let rows = InsertParser::new(stmt)
        .parse_rows()?
        .into_iter()
        .map(|row| row.values)
        .collect();
    Ok(InsertValues {
        table,
        columns,
        rows,
    })
}
/// Extracts the unqualified, unquoted target table name from an `INSERT` statement.
///
/// `upper` should be the ASCII-uppercased form of `stmt` (same byte offsets).
/// After `INSERT`, the MySQL modifiers `LOW_PRIORITY`, `DELAYED`, `HIGH_PRIORITY`,
/// `IGNORE` and an optional `INTO` are skipped regardless of the whitespace between
/// them. Previously `INSERT IGNORE INTO t` or `INSERT\nINTO t` yielded `IGNORE` or
/// `INTO` as the table name. A schema qualifier (`db.t`, `[dbo].[t]`) is dropped.
///
/// # Errors
///
/// Fails if there is no `INSERT` keyword or no table reference follows it.
fn extract_insert_table_name(stmt: &str, upper: &str) -> anyhow::Result<String> {
    const SKIPPED_WORDS: [&str; 5] = ["LOW_PRIORITY", "DELAYED", "HIGH_PRIORITY", "IGNORE", "INTO"];

    let start_pos = if let Some(pos) = upper.find("INSERT INTO") {
        pos + 11
    } else if let Some(pos) = upper.find("INSERT") {
        pos + 6
    } else {
        anyhow::bail!("Not an INSERT statement");
    };
    // `get` instead of indexing: a misaligned `upper` must not cause a panic; an
    // out-of-range start simply yields an empty reference error below.
    let mut remaining = stmt.get(start_pos..).unwrap_or("").trim_start();
    loop {
        let word_len = remaining
            .bytes()
            .take_while(|b| b.is_ascii_alphanumeric() || *b == b'_')
            .count();
        // `word_len` counts ASCII bytes only, so this slice is on a char boundary.
        let word = &remaining[..word_len];
        if word_len == 0 || !SKIPPED_WORDS.iter().any(|kw| kw.eq_ignore_ascii_case(word)) {
            break;
        }
        remaining = remaining[word_len..].trim_start();
    }
    let table_ref = extract_table_reference(remaining)?;
    let table_part = table_ref.rsplit('.').next().unwrap_or(table_ref.as_str());
    Ok(strip_identifier_quotes(table_part))
}
/// Reads the (possibly schema-qualified and quoted) table reference at the start of `s`.
///
/// Quoted sections (`[...]`, `` `...` ``, `"..."`) are copied verbatim including their
/// quotes and may contain spaces; unquoted text ends at whitespace, `(` or `,`.
///
/// # Errors
///
/// Fails when `s` is blank or no reference characters precede the first terminator.
fn extract_table_reference(s: &str) -> anyhow::Result<String> {
    let s = s.trim();
    if s.is_empty() {
        anyhow::bail!("Empty table reference");
    }
    let mut reference = String::new();
    let mut chars = s.chars();
    while let Some(ch) = chars.next() {
        let closing = match ch {
            '[' => ']',
            '`' => '`',
            '"' => '"',
            c if c.is_whitespace() || c == '(' || c == ',' => break,
            c => {
                reference.push(c);
                continue;
            }
        };
        // Copy the opening quote and everything through the matching closer.
        reference.push(ch);
        for inner in chars.by_ref() {
            reference.push(inner);
            if inner == closing {
                break;
            }
        }
    }
    if reference.is_empty() {
        anyhow::bail!("Empty table reference");
    }
    Ok(reference)
}
/// Removes identifier quote characters from both ends of `s`.
///
/// The characters `` ` ``, `"`, `[` and `]` are trimmed one kind at a time, in
/// that order, so `` [`x`] `` becomes `` `x` `` rather than `x`.
fn strip_identifier_quotes(s: &str) -> String {
    ['`', '"', '[', ']']
        .iter()
        .fold(s, |acc, quote| acc.trim_matches(*quote))
        .to_string()
}
/// Returns the explicit column list of an `INSERT ... (cols) VALUES ...` statement.
///
/// `upper` is expected to be the ASCII-uppercased form of `stmt`, so byte offsets
/// found in it are valid for `stmt`. `VALUES` and `SELECT` are recognised only as
/// standalone words. Names like `order_values`, `` `values` `` or `selected_at`
/// therefore no longer hide or discard the real list. Column names are trimmed and
/// unquoted (`` ` ``, `"`, `[...]`).
///
/// Returns `None` when there is no column list, it is empty, or it looks like a subquery.
fn extract_column_list(stmt: &str, upper: &str) -> Option<Vec<String>> {
    /// Identifier characters and identifier quotes: bytes that may not touch a keyword.
    fn is_word_byte(b: u8) -> bool {
        b.is_ascii_alphanumeric() || matches!(b, b'_' | b'$' | b'`' | b'"' | b'[' | b']')
    }

    /// Byte offset of the first occurrence of `keyword` in `haystack` that stands
    /// alone as a word.
    fn find_keyword(haystack: &str, keyword: &str) -> Option<usize> {
        let bytes = haystack.as_bytes();
        haystack.match_indices(keyword).map(|(at, _)| at).find(|&at| {
            let before = at.checked_sub(1).and_then(|i| bytes.get(i));
            let after = bytes.get(at + keyword.len());
            !matches!(before, Some(&b) if is_word_byte(b))
                && !matches!(after, Some(&b) if is_word_byte(b))
        })
    }

    let values_pos = find_keyword(upper, "VALUES")?;
    // `get` so a misaligned `upper` yields `None` instead of a slicing panic.
    let before_values = stmt.get(..values_pos)?;
    let close_paren = before_values.rfind(')')?;
    let open_paren = before_values[..close_paren].rfind('(')?;
    let col_list = &before_values[open_paren + 1..close_paren];
    let upper_cols = col_list.to_ascii_uppercase();
    if col_list.trim().is_empty()
        || find_keyword(&upper_cols, "SELECT").is_some()
        || find_keyword(&upper_cols, "VALUES").is_some()
    {
        return None;
    }
    // `split` always yields at least one item, so the result is never empty here.
    Some(
        col_list
            .split(',')
            .map(|c| {
                c.trim()
                    .trim_matches('`')
                    .trim_matches('"')
                    .trim_matches('[')
                    .trim_matches(']')
                    .to_string()
            })
            .collect(),
    )
}