use crate::schema::{ColumnId, ColumnType, TableSchema};
use ahash::AHashSet;
use smallvec::SmallVec;
/// One component of a primary-key or foreign-key tuple.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PkValue {
    /// Integer key component stored as `i64`.
    Int(i64),
    /// Integer key component stored as `i128`.
    BigInt(i128),
    /// Textual key component.
    Text(Box<str>),
    /// SQL `NULL`.
    Null,
}

impl PkValue {
    /// Returns `true` only for [`PkValue::Null`].
    pub fn is_null(&self) -> bool {
        match self {
            PkValue::Null => true,
            PkValue::Int(_) | PkValue::BigInt(_) | PkValue::Text(_) => false,
        }
    }
}
/// Ordered key components of one row; up to two components are stored inline
/// by the `SmallVec` before it spills to the heap.
pub type PkTuple = SmallVec<[PkValue; 2]>;
/// Hash set of key tuples.
pub type PkSet = AHashSet<PkTuple>;
/// Identifies one foreign-key constraint: the table that declares it plus the
/// constraint's position in that table's foreign-key list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FkRef {
/// Numeric id of the table whose schema declares the foreign key.
pub table_id: u32,
/// Index of the constraint within that schema's `foreign_keys`.
pub fk_index: u16,
}
/// One `(...)` tuple extracted from an `INSERT ... VALUES` statement.
#[derive(Debug, Clone)]
pub struct ParsedRow {
/// Bytes of the tuple as they appear in the statement, starting at its
/// opening parenthesis and ending where row parsing stopped.
pub raw: Vec<u8>,
/// Primary-key tuple of the row. `None` when the parser has no schema, when
/// no primary-key value was found, or when any component is `NULL`.
pub pk: Option<PkTuple>,
/// Foreign-key tuples found in the row, each tagged with the constraint it
/// belongs to. Empty when the parser has no schema.
pub fk_values: Vec<(FkRef, PkTuple)>,
}
/// Cursor-based parser over the bytes of a single `INSERT ... VALUES` statement.
pub struct InsertParser<'a> {
// Raw statement bytes being parsed.
stmt: &'a [u8],
// Current byte offset into `stmt`.
pos: usize,
// Schema of the target table; `None` until `with_schema` is called, in which
// case rows are returned without key information.
table_schema: Option<&'a TableSchema>,
// Column id for each value position in a row; an entry is `None` when the
// schema lookup for a listed column name returned nothing.
column_order: Vec<Option<ColumnId>>,
}
impl<'a> InsertParser<'a> {
pub fn new(stmt: &'a [u8]) -> Self {
Self {
stmt,
pos: 0,
table_schema: None,
column_order: Vec::new(),
}
}
pub fn with_schema(mut self, schema: &'a TableSchema) -> Self {
self.table_schema = Some(schema);
self
}
/// Parses every `(...)` tuple of the statement's `VALUES` list.
///
/// The cursor is first moved just past the `VALUES` keyword and the column
/// mapping is resolved (a no-op without a schema). Rows are then collected
/// until a `;`, the start of a trailing clause, or the end of the statement.
///
/// # Errors
/// Fails when the statement contains no `VALUES` keyword or when a row
/// cannot be parsed.
pub fn parse_rows(&mut self) -> anyhow::Result<Vec<ParsedRow>> {
    self.pos = self.find_values_keyword()?;
    self.parse_column_list();

    let mut rows = Vec::new();
    // True while the last non-whitespace token consumed was a complete row.
    let mut after_row = false;
    loop {
        self.skip_whitespace();
        let byte = match self.stmt.get(self.pos) {
            Some(&b) => b,
            None => break,
        };
        match byte {
            b'(' => {
                if let Some(row) = self.parse_row()? {
                    rows.push(row);
                }
                after_row = true;
            }
            b';' => break,
            // A keyword directly after a row (no comma in between) starts a
            // trailing clause such as `ON DUPLICATE KEY UPDATE c = VALUES(c)`.
            // Its parentheses are expressions, not data rows, so stop here.
            _ if after_row && byte.is_ascii_alphabetic() => break,
            _ => {
                // Row separators and any stray bytes are stepped over.
                after_row = false;
                self.pos += 1;
            }
        }
    }
    Ok(rows)
}
/// Returns the byte offset just past the statement's `VALUES` keyword.
///
/// The search works directly on the statement bytes (ASCII case-insensitive),
/// so the returned offset is always a valid index into `stmt` and no copy of
/// the statement is made. A match is preferred when it is a whole word
/// outside backtick- or double-quoted identifiers, which keeps names such as
/// `product_values` from being mistaken for the keyword. If no such match
/// exists, the first raw occurrence is used.
///
/// # Errors
/// Fails when `VALUES` does not occur anywhere in the statement.
fn find_values_keyword(&self) -> anyhow::Result<usize> {
    const KEYWORD: &[u8] = b"VALUES";
    let bytes = self.stmt;
    // Bytes that may continue an unquoted identifier (non-ASCII included).
    let is_ident = |b: u8| b.is_ascii_alphanumeric() || b == b'_' || b == b'$' || b >= 0x80;

    let mut i = 0;
    while i < bytes.len() {
        let b = bytes[i];
        if b == b'`' || b == b'"' {
            // Skip a quoted identifier; a doubled quote inside it is an escape.
            i += 1;
            while i < bytes.len() {
                if bytes[i] == b {
                    if bytes.get(i + 1) == Some(&b) {
                        i += 2;
                        continue;
                    }
                    break;
                }
                i += 1;
            }
            i += 1;
            continue;
        }
        let starts_word = i == 0 || !is_ident(bytes[i - 1]);
        let end = i + KEYWORD.len();
        if starts_word
            && end <= bytes.len()
            && bytes[i..end].eq_ignore_ascii_case(KEYWORD)
            && bytes.get(end).map_or(true, |&next| !is_ident(next))
        {
            return Ok(end);
        }
        i += 1;
    }

    // Fallback: accept the first occurrence anywhere, so a statement is only
    // rejected when the keyword is truly absent.
    match bytes
        .windows(KEYWORD.len())
        .position(|window| window.eq_ignore_ascii_case(KEYWORD))
    {
        Some(start) => Ok(start + KEYWORD.len()),
        None => anyhow::bail!("INSERT statement missing VALUES keyword"),
    }
}
/// Resolves which schema column each value position of a row belongs to.
///
/// Does nothing when no schema is attached. Otherwise the text before the
/// `VALUES` keyword is searched for its last parenthesised group; if found,
/// each comma-separated name (backticks / double quotes stripped) is looked
/// up in the schema. When there is no such group, or the group contains the
/// keyword `SELECT`, the schema's own column order is used instead.
///
/// Expects `self.pos` to sit just past the 6-byte `VALUES` keyword.
fn parse_column_list(&mut self) {
    let schema = match self.table_schema {
        Some(schema) => schema,
        None => return,
    };

    let before_values = &self.stmt[..self.pos.saturating_sub(6)];
    let head = String::from_utf8_lossy(before_values);

    let explicit_list = head.rfind(')').and_then(|close| {
        head[..close]
            .rfind('(')
            .map(|open| &head[open + 1..close])
    });

    if let Some(col_list) = explicit_list {
        // Test for SELECT as a whole word: a substring test would also reject
        // ordinary column names such as `selected_at`.
        let has_select_keyword = col_list
            .split(|c: char| !(c.is_alphanumeric() || c == '_'))
            .any(|word| word.eq_ignore_ascii_case("SELECT"));

        if !has_select_keyword {
            self.column_order = col_list
                .split(',')
                .map(|c| {
                    let name = c.trim().trim_matches('`').trim_matches('"');
                    schema.get_column_id(name)
                })
                .collect();
            return;
        }
    }

    self.column_order = schema.columns.iter().map(|c| Some(c.ordinal)).collect();
}
fn parse_row(&mut self) -> anyhow::Result<Option<ParsedRow>> {
self.skip_whitespace();
if self.pos >= self.stmt.len() || self.stmt[self.pos] != b'(' {
return Ok(None);
}
let start = self.pos;
self.pos += 1;
let mut values: Vec<ParsedValue> = Vec::new();
let mut depth = 1;
while self.pos < self.stmt.len() && depth > 0 {
self.skip_whitespace();
if self.pos >= self.stmt.len() {
break;
}
match self.stmt[self.pos] {
b'(' => {
depth += 1;
self.pos += 1;
}
b')' => {
depth -= 1;
self.pos += 1;
}
b',' if depth == 1 => {
self.pos += 1;
}
_ if depth == 1 => {
values.push(self.parse_value()?);
}
_ => {
self.pos += 1;
}
}
}
let end = self.pos;
let raw = self.stmt[start..end].to_vec();
let (pk, fk_values) = if let Some(schema) = self.table_schema {
self.extract_pk_fk(&values, schema)
} else {
(None, Vec::new())
};
Ok(Some(ParsedRow { raw, pk, fk_values }))
}
/// Parses a single value at the cursor and advances past it.
///
/// Dispatch order: the keyword `NULL` (any ASCII case, as a whole word), a
/// single-quoted string, a `0x` / `0X` hex literal, and finally the numeric /
/// fallback parser. At end of input `ParsedValue::Null` is returned.
fn parse_value(&mut self) -> anyhow::Result<ParsedValue> {
    self.skip_whitespace();
    let stmt = self.stmt;
    let rest = match stmt.get(self.pos..) {
        Some(rest) if !rest.is_empty() => rest,
        _ => return Ok(ParsedValue::Null),
    };

    if rest.len() >= 4 && rest[..4].eq_ignore_ascii_case(b"NULL") {
        // Require a word boundary so `NULLIF(...)` and similar tokens are not
        // mistaken for the NULL literal.
        let continues_word = rest
            .get(4)
            .map_or(false, |&b| b.is_ascii_alphanumeric() || b == b'_');
        if !continues_word {
            self.pos += 4;
            return Ok(ParsedValue::Null);
        }
    }

    match rest[0] {
        b'\'' => self.parse_string_value(),
        b'0' if matches!(rest.get(1), Some(&b'x') | Some(&b'X')) => self.parse_hex_value(),
        _ => self.parse_number_value(),
    }
}
/// Parses a single-quoted string literal; the cursor must be on the opening
/// quote and is left just past the closing quote.
///
/// Decodes MySQL backslash escapes (`\0`, `\b`, `\n`, `\r`, `\t`, `\Z`; `\%`
/// and `\_` keep their backslash; any other escaped byte stands for itself)
/// and the doubled-quote escape `''`. An unterminated literal yields whatever
/// was read up to the end of input. Bytes that are not valid UTF-8 are
/// replaced lossily.
fn parse_string_value(&mut self) -> anyhow::Result<ParsedValue> {
    let stmt = self.stmt;
    // Step over the opening quote.
    self.pos += 1;

    let mut value = Vec::new();
    while let Some(&b) = stmt.get(self.pos) {
        self.pos += 1;
        match b {
            b'\\' => {
                let escaped = match stmt.get(self.pos) {
                    Some(&e) => e,
                    // Dangling backslash at end of input: nothing to decode.
                    None => break,
                };
                self.pos += 1;
                match escaped {
                    b'0' => value.push(0),
                    b'b' => value.push(0x08),
                    b'n' => value.push(b'\n'),
                    b'r' => value.push(b'\r'),
                    b't' => value.push(b'\t'),
                    b'Z' => value.push(0x1A),
                    // Outside pattern matching MySQL keeps these two as-is,
                    // backslash included.
                    b'%' | b'_' => value.extend_from_slice(&[b'\\', escaped]),
                    other => value.push(other),
                }
            }
            b'\'' => {
                if stmt.get(self.pos) == Some(&b'\'') {
                    // `''` is an escaped quote inside the literal.
                    value.push(b'\'');
                    self.pos += 1;
                } else {
                    break;
                }
            }
            other => value.push(other),
        }
    }

    // Reuse the buffer when it is already valid UTF-8; only fall back to the
    // lossy (copying) conversion for invalid input.
    let text = String::from_utf8(value)
        .unwrap_or_else(|err| String::from_utf8_lossy(err.as_bytes()).into_owned());
    Ok(ParsedValue::String { value: text })
}
/// Consumes a hex literal whose two-byte prefix starts at the cursor and
/// returns its bytes, prefix included, as [`ParsedValue::Hex`].
fn parse_hex_value(&mut self) -> anyhow::Result<ParsedValue> {
    let begin = self.pos;
    // Step over the two-byte prefix, then over every hex digit that follows.
    self.pos += 2;
    let digit_count = self.stmt.get(self.pos..).map_or(0, |rest| {
        rest.iter().take_while(|b| b.is_ascii_hexdigit()).count()
    });
    self.pos += digit_count;
    Ok(ParsedValue::Hex(self.stmt[begin..self.pos].to_vec()))
}
/// Parses a numeric literal or, failing that, any other unquoted token.
///
/// Digits, one `.`, and `e`/`E` exponents (with optional sign) are consumed
/// as a number. On any other byte the rest of the value is consumed up to the
/// next `,` or `)` that lies outside nested parentheses and quoted text, so
/// function calls such as `NOW()` and prefixed literals such as
/// `_binary 'a,b'` stay in one piece.
///
/// Returns `Integer` when the token has no `.` and fits `i64`, `BigInteger`
/// when it fits `i128`, and `Other` with the raw bytes in every other case.
fn parse_number_value(&mut self) -> anyhow::Result<ParsedValue> {
    let stmt = self.stmt;
    let len = stmt.len();
    let start = self.pos;
    let mut has_dot = false;

    if stmt.get(self.pos) == Some(&b'-') {
        self.pos += 1;
    }
    while self.pos < len {
        let b = stmt[self.pos];
        if b.is_ascii_digit() {
            self.pos += 1;
        } else if b == b'.' && !has_dot {
            has_dot = true;
            self.pos += 1;
        } else if b == b'e' || b == b'E' {
            self.pos += 1;
            if matches!(stmt.get(self.pos), Some(&b'+') | Some(&b'-')) {
                self.pos += 1;
            }
        } else if b == b',' || b == b')' || b.is_ascii_whitespace() {
            break;
        } else {
            // Not a plain number: consume the whole expression, tracking
            // nesting so inner `)` / `,` do not end the value early.
            let mut depth = 0usize;
            while self.pos < len {
                let c = stmt[self.pos];
                match c {
                    b'\'' | b'"' | b'`' => {
                        let quote = c;
                        self.pos += 1;
                        while self.pos < len {
                            let q = stmt[self.pos];
                            if q == b'\\' && quote != b'`' {
                                // Backslash escape: skip the escaped byte too.
                                self.pos = (self.pos + 2).min(len);
                            } else if q == quote {
                                self.pos += 1;
                                if stmt.get(self.pos) == Some(&quote) {
                                    // Doubled quote is an escaped quote.
                                    self.pos += 1;
                                } else {
                                    break;
                                }
                            } else {
                                self.pos += 1;
                            }
                        }
                    }
                    b'(' => {
                        depth += 1;
                        self.pos += 1;
                    }
                    b')' if depth == 0 => break,
                    b')' => {
                        depth -= 1;
                        self.pos += 1;
                    }
                    b',' if depth == 0 => break,
                    _ => self.pos += 1,
                }
            }
            break;
        }
    }

    let raw = &stmt[start..self.pos];
    if !has_dot {
        // Invalid UTF-8 can never be a valid integer, so it falls through.
        if let Ok(text) = std::str::from_utf8(raw) {
            if let Ok(n) = text.parse::<i64>() {
                return Ok(ParsedValue::Integer(n));
            }
            if let Ok(n) = text.parse::<i128>() {
                return Ok(ParsedValue::BigInteger(n));
            }
        }
    }
    Ok(ParsedValue::Other(raw.to_vec()))
}
/// Advances the cursor past any run of ASCII whitespace; stops at the first
/// other byte or at the end of the statement.
fn skip_whitespace(&mut self) {
    while matches!(self.stmt.get(self.pos), Some(byte) if byte.is_ascii_whitespace()) {
        self.pos += 1;
    }
}
fn extract_pk_fk(
&self,
values: &[ParsedValue],
schema: &TableSchema,
) -> (Option<PkTuple>, Vec<(FkRef, PkTuple)>) {
let mut pk_values = PkTuple::new();
let mut fk_values = Vec::new();
for (idx, col_id_opt) in self.column_order.iter().enumerate() {
if let Some(col_id) = col_id_opt {
if schema.is_pk_column(*col_id) {
if let Some(value) = values.get(idx) {
let pk_val = self.value_to_pk(value, schema.column(*col_id));
pk_values.push(pk_val);
}
}
}
}
for (fk_idx, fk) in schema.foreign_keys.iter().enumerate() {
if fk.referenced_table_id.is_none() {
continue;
}
let mut fk_tuple = PkTuple::new();
let mut all_non_null = true;
for &col_id in &fk.columns {
if let Some(idx) = self.column_order.iter().position(|&c| c == Some(col_id)) {
if let Some(value) = values.get(idx) {
let pk_val = self.value_to_pk(value, schema.column(col_id));
if pk_val.is_null() {
all_non_null = false;
break;
}
fk_tuple.push(pk_val);
}
}
}
if all_non_null && !fk_tuple.is_empty() {
fk_values.push((
FkRef {
table_id: schema.id.0,
fk_index: fk_idx as u16,
},
fk_tuple,
));
}
}
let pk = if pk_values.is_empty() || pk_values.iter().any(|v| v.is_null()) {
None
} else {
Some(pk_values)
};
(pk, fk_values)
}
/// Converts a parsed value into a key component.
///
/// Integers map to `Int` / `BigInt` and `NULL` to `Null`. A quoted string in
/// an `Int` or `BigInt` column is parsed as a number with the same rule used
/// for unquoted literals (`i64` first, then `i128`), so `'5'` and `5` yield
/// the same component. Every other string, and every hex or opaque token,
/// becomes `Text`.
fn value_to_pk(&self, value: &ParsedValue, col: Option<&crate::schema::Column>) -> PkValue {
    match value {
        ParsedValue::Null => PkValue::Null,
        ParsedValue::Integer(n) => PkValue::Int(*n),
        ParsedValue::BigInteger(n) => PkValue::BigInt(*n),
        ParsedValue::String { value } => {
            let integer_column = col.map_or(false, |c| {
                matches!(c.col_type, ColumnType::Int | ColumnType::BigInt)
            });
            if integer_column {
                if let Ok(n) = value.parse::<i64>() {
                    return PkValue::Int(n);
                }
                if let Ok(n) = value.parse::<i128>() {
                    return PkValue::BigInt(n);
                }
            }
            PkValue::Text(value.as_str().into())
        }
        ParsedValue::Hex(raw) | ParsedValue::Other(raw) => {
            PkValue::Text(String::from_utf8_lossy(raw).into_owned().into_boxed_str())
        }
    }
}
}
/// A single value as read from a row, before it is mapped to a key component.
#[derive(Debug, Clone)]
enum ParsedValue {
/// The `NULL` keyword (also produced when a value is requested at end of input).
Null,
/// Unquoted literal without a `.` that parses as `i64`.
Integer(i64),
/// Unquoted literal without a `.` that parses as `i128` but not as `i64`.
BigInteger(i128),
/// Single-quoted string literal with its escapes decoded.
String { value: String },
/// Raw bytes of a `0x…` literal, prefix included.
Hex(Vec<u8>),
/// Raw bytes of any other token (decimals, keywords, expressions, ...).
Other(Vec<u8>),
}
/// Parses all rows of an `INSERT ... VALUES` statement and extracts each
/// row's primary-key and foreign-key tuples using `schema`.
///
/// # Errors
/// Propagates any error from [`InsertParser::parse_rows`].
pub fn parse_mysql_insert_rows(
    stmt: &[u8],
    schema: &TableSchema,
) -> anyhow::Result<Vec<ParsedRow>> {
    InsertParser::new(stmt).with_schema(schema).parse_rows()
}
/// Parses all rows of an `INSERT ... VALUES` statement without a schema; the
/// returned rows carry only their raw bytes (no key information).
///
/// # Errors
/// Propagates any error from [`InsertParser::parse_rows`].
pub fn parse_mysql_insert_rows_raw(stmt: &[u8]) -> anyhow::Result<Vec<ParsedRow>> {
    InsertParser::new(stmt).parse_rows()
}