use crate::parser::mysql_insert::{InsertParser, ParsedValue};
use crate::parser::postgres_copy::{parse_copy_columns, CopyParser};
use crate::parser::SqlDialect;
use crate::redactor::strategy::{
ConstantStrategy, FakeStrategy, HashStrategy, MaskStrategy, NullStrategy, RedactValue,
Strategy, StrategyKind,
};
use crate::schema::TableSchema;
use rand::rngs::StdRng;
use rand::SeedableRng;
/// Rewrites INSERT statements and COPY data blocks, replacing individual column values
/// according to a per-column list of redaction strategies.
///
/// Holds the random source consumed by the strategies together with the output dialect
/// and the locale used for fake-data generation.
pub struct ValueRewriter {
    /// Output dialect: selects identifier quoting and string-literal escaping.
    dialect: SqlDialect,
    /// Locale string handed to the fake-data strategy.
    locale: String,
    /// Random source passed to every strategy application.
    rng: StdRng,
}
impl ValueRewriter {
pub fn new(seed: Option<u64>, dialect: SqlDialect, locale: String) -> Self {
let rng = match seed {
Some(s) => StdRng::seed_from_u64(s),
None => StdRng::from_os_rng(),
};
Self {
rng,
dialect,
locale,
}
}
pub fn rewrite_insert(
&mut self,
stmt: &[u8],
table_name: &str,
table: &TableSchema,
strategies: &[StrategyKind],
) -> anyhow::Result<(Vec<u8>, u64, u64)> {
let mut parser = InsertParser::new(stmt).with_schema(table);
let rows = parser.parse_rows()?;
if rows.is_empty() {
return Ok((stmt.to_vec(), 0, 0));
}
let stmt_str = String::from_utf8_lossy(stmt);
let column_list = self.extract_column_list(&stmt_str);
let mut result = self.build_insert_header(table_name, &column_list);
let mut rows_redacted = 0u64;
let mut columns_redacted = 0u64;
let num_strategies = strategies.len();
for (row_idx, row) in rows.iter().enumerate() {
if row_idx > 0 {
result.extend_from_slice(b",");
}
result.extend_from_slice(b"\n(");
let mut row_had_redaction = false;
for (col_idx, value) in row.values.iter().enumerate() {
if col_idx > 0 {
result.extend_from_slice(b", ");
}
let strategy = strategies.get(col_idx).unwrap_or(&StrategyKind::Skip);
let (redacted_sql, was_redacted) =
self.redact_value(value, strategy, col_idx < num_strategies);
result.extend_from_slice(redacted_sql.as_bytes());
if was_redacted {
columns_redacted += 1;
row_had_redaction = true;
}
}
result.extend_from_slice(b")");
if row_had_redaction {
rows_redacted += 1;
}
}
result.extend_from_slice(b";\n");
Ok((result, rows_redacted, columns_redacted))
}
pub fn rewrite_copy(
&mut self,
stmt: &[u8],
_table_name: &str,
table: &TableSchema,
strategies: &[StrategyKind],
) -> anyhow::Result<(Vec<u8>, u64, u64)> {
let stmt_str = String::from_utf8_lossy(stmt);
let header_end = stmt_str
.find('\n')
.ok_or_else(|| anyhow::anyhow!("Invalid COPY statement: no newline"))?;
let header = &stmt_str[..header_end];
let data_block = &stmt[header_end + 1..];
let columns = parse_copy_columns(header);
let mut parser = CopyParser::new(data_block)
.with_schema(table)
.with_column_order(columns.clone());
let rows = parser.parse_rows()?;
if rows.is_empty() {
return Ok((stmt.to_vec(), 0, 0));
}
let mut result = Vec::with_capacity(stmt.len());
result.extend_from_slice(header.as_bytes());
result.push(b'\n');
let mut rows_redacted = 0u64;
let mut columns_redacted = 0u64;
for row in &rows {
let mut row_had_redaction = false;
let mut first = true;
let values = self.parse_copy_row_values(&row.raw);
for (col_idx, value) in values.iter().enumerate() {
if !first {
result.push(b'\t');
}
first = false;
let strategy = strategies.get(col_idx).unwrap_or(&StrategyKind::Skip);
let (redacted, was_redacted) = self.redact_copy_value(value, strategy);
result.extend_from_slice(&redacted);
if was_redacted {
columns_redacted += 1;
row_had_redaction = true;
}
}
result.push(b'\n');
if row_had_redaction {
rows_redacted += 1;
}
}
result.extend_from_slice(b"\\.\n");
Ok((result, rows_redacted, columns_redacted))
}
pub fn rewrite_copy_data(
&mut self,
data_block: &[u8],
table: &TableSchema,
strategies: &[StrategyKind],
columns: &[String],
) -> anyhow::Result<(Vec<u8>, u64, u64)> {
let mut parser = CopyParser::new(data_block)
.with_schema(table)
.with_column_order(columns.to_vec());
let rows = parser.parse_rows()?;
if rows.is_empty() {
return Ok((data_block.to_vec(), 0, 0));
}
let mut result = Vec::with_capacity(data_block.len());
let mut rows_redacted = 0u64;
let mut columns_redacted = 0u64;
for row in &rows {
let mut row_had_redaction = false;
let mut first = true;
let values = self.parse_copy_row_values(&row.raw);
for (col_idx, value) in values.iter().enumerate() {
if !first {
result.push(b'\t');
}
first = false;
let strategy = strategies.get(col_idx).unwrap_or(&StrategyKind::Skip);
let (redacted, was_redacted) = self.redact_copy_value(value, strategy);
result.extend_from_slice(&redacted);
if was_redacted {
columns_redacted += 1;
row_had_redaction = true;
}
}
result.push(b'\n');
if row_had_redaction {
rows_redacted += 1;
}
}
result.extend_from_slice(b"\\.\n");
Ok((result, rows_redacted, columns_redacted))
}
fn parse_copy_row_values(&self, raw: &[u8]) -> Vec<CopyValueRef> {
let mut values = Vec::new();
let mut start = 0;
for (i, &b) in raw.iter().enumerate() {
if b == b'\t' {
values.push(self.parse_single_copy_value(&raw[start..i]));
start = i + 1;
}
}
if start <= raw.len() {
values.push(self.parse_single_copy_value(&raw[start..]));
}
values
}
fn parse_single_copy_value(&self, raw: &[u8]) -> CopyValueRef {
if raw == b"\\N" {
CopyValueRef::Null
} else {
CopyValueRef::Text(raw.to_vec())
}
}
fn redact_copy_value(
&mut self,
value: &CopyValueRef,
strategy: &StrategyKind,
) -> (Vec<u8>, bool) {
if matches!(strategy, StrategyKind::Skip) {
let bytes = match value {
CopyValueRef::Null => b"\\N".to_vec(),
CopyValueRef::Text(t) => t.clone(),
};
return (bytes, false);
}
let redact_value = match value {
CopyValueRef::Null => RedactValue::Null,
CopyValueRef::Text(t) => {
let decoded = self.decode_copy_escapes(t);
RedactValue::String(String::from_utf8_lossy(&decoded).into_owned())
}
};
let result = self.apply_strategy(&redact_value, strategy);
let bytes = match result {
RedactValue::Null => b"\\N".to_vec(),
RedactValue::String(s) => self.encode_copy_escapes(&s),
RedactValue::Integer(i) => i.to_string().into_bytes(),
RedactValue::Bytes(b) => self.encode_copy_escapes(&String::from_utf8_lossy(&b)),
};
(bytes, true)
}
fn decode_copy_escapes(&self, value: &[u8]) -> Vec<u8> {
let mut result = Vec::with_capacity(value.len());
let mut i = 0;
while i < value.len() {
if value[i] == b'\\' && i + 1 < value.len() {
let next = value[i + 1];
let decoded = match next {
b'n' => b'\n',
b'r' => b'\r',
b't' => b'\t',
b'\\' => b'\\',
_ => {
result.push(b'\\');
result.push(next);
i += 2;
continue;
}
};
result.push(decoded);
i += 2;
} else {
result.push(value[i]);
i += 1;
}
}
result
}
fn encode_copy_escapes(&self, value: &str) -> Vec<u8> {
let mut result = Vec::with_capacity(value.len());
for b in value.bytes() {
match b {
b'\n' => result.extend_from_slice(b"\\n"),
b'\r' => result.extend_from_slice(b"\\r"),
b'\t' => result.extend_from_slice(b"\\t"),
b'\\' => result.extend_from_slice(b"\\\\"),
_ => result.push(b),
}
}
result
}
fn extract_column_list(&self, stmt: &str) -> Option<Vec<String>> {
let upper = stmt.to_uppercase();
let values_pos = upper.find("VALUES")?;
let before_values = &stmt[..values_pos];
let close_paren = before_values.rfind(')')?;
let open_paren = before_values[..close_paren].rfind('(')?;
let col_list = &before_values[open_paren + 1..close_paren];
let upper_cols = col_list.to_uppercase();
if col_list.trim().is_empty()
|| upper_cols.contains("SELECT")
|| upper_cols.contains("VALUES")
{
return None;
}
let columns: Vec<String> = col_list
.split(',')
.map(|c| {
c.trim()
.trim_matches('`')
.trim_matches('"')
.trim_matches('[')
.trim_matches(']')
.to_string()
})
.collect();
if columns.is_empty() {
None
} else {
Some(columns)
}
}
fn build_insert_header(&self, table_name: &str, columns: &Option<Vec<String>>) -> Vec<u8> {
let mut result = Vec::new();
result.extend_from_slice(b"INSERT INTO ");
result.extend_from_slice(self.quote_identifier(table_name).as_bytes());
if let Some(cols) = columns {
result.extend_from_slice(b" (");
for (i, col) in cols.iter().enumerate() {
if i > 0 {
result.extend_from_slice(b", ");
}
result.extend_from_slice(self.quote_identifier(col).as_bytes());
}
result.extend_from_slice(b")");
}
result.extend_from_slice(b" VALUES");
result
}
fn quote_identifier(&self, name: &str) -> String {
match self.dialect {
SqlDialect::MySql => format!("`{}`", name),
SqlDialect::Postgres | SqlDialect::Sqlite => format!("\"{}\"", name),
SqlDialect::Mssql => format!("[{}]", name),
}
}
fn redact_value(
&mut self,
value: &ParsedValue,
strategy: &StrategyKind,
has_strategy: bool,
) -> (String, bool) {
if !has_strategy || matches!(strategy, StrategyKind::Skip) {
return (self.format_value(value), false);
}
let redact_value = self.parsed_to_redact(value);
let result = self.apply_strategy(&redact_value, strategy);
(self.format_redact_value(&result), true)
}
fn parsed_to_redact(&self, value: &ParsedValue) -> RedactValue {
match value {
ParsedValue::Null => RedactValue::Null,
ParsedValue::Integer(n) => RedactValue::Integer(*n),
ParsedValue::BigInteger(n) => RedactValue::Integer(*n as i64), ParsedValue::String { value } => RedactValue::String(value.clone()),
ParsedValue::Hex(bytes) => RedactValue::Bytes(bytes.clone()),
ParsedValue::Other(bytes) => {
RedactValue::String(String::from_utf8_lossy(bytes).into_owned())
}
}
}
fn apply_strategy(&mut self, value: &RedactValue, strategy: &StrategyKind) -> RedactValue {
match strategy {
StrategyKind::Null => NullStrategy::new().apply(value, &mut self.rng),
StrategyKind::Constant { value: constant } => {
ConstantStrategy::new(constant.clone()).apply(value, &mut self.rng)
}
StrategyKind::Hash { preserve_domain } => {
HashStrategy::new(*preserve_domain).apply(value, &mut self.rng)
}
StrategyKind::Mask { pattern } => {
MaskStrategy::new(pattern.clone()).apply(value, &mut self.rng)
}
StrategyKind::Fake { generator } => {
FakeStrategy::new(generator.clone(), self.locale.clone())
.apply(value, &mut self.rng)
}
StrategyKind::Shuffle => {
value.clone()
}
StrategyKind::Skip => value.clone(),
}
}
fn format_value(&self, value: &ParsedValue) -> String {
match value {
ParsedValue::Null => "NULL".to_string(),
ParsedValue::Integer(n) => n.to_string(),
ParsedValue::BigInteger(n) => n.to_string(),
ParsedValue::String { value } => self.format_sql_string(value),
ParsedValue::Hex(bytes) => String::from_utf8_lossy(bytes).into_owned(),
ParsedValue::Other(bytes) => String::from_utf8_lossy(bytes).into_owned(),
}
}
fn format_redact_value(&self, value: &RedactValue) -> String {
match value {
RedactValue::Null => "NULL".to_string(),
RedactValue::Integer(n) => n.to_string(),
RedactValue::String(s) => self.format_sql_string(s),
RedactValue::Bytes(b) => {
format!("0x{}", hex::encode(b))
}
}
}
fn format_sql_string(&self, value: &str) -> String {
match self.dialect {
SqlDialect::MySql => {
let escaped = value
.replace('\\', "\\\\")
.replace('\'', "\\'")
.replace('\n', "\\n")
.replace('\r', "\\r")
.replace('\t', "\\t")
.replace('\0', "\\0");
format!("'{}'", escaped)
}
SqlDialect::Postgres | SqlDialect::Sqlite => {
let escaped = value.replace('\'', "''");
format!("'{}'", escaped)
}
SqlDialect::Mssql => {
let escaped = value.replace('\'', "''");
if value.bytes().any(|b| b > 127) {
format!("N'{}'", escaped)
} else {
format!("'{}'", escaped)
}
}
}
}
}
/// One field of a PostgreSQL COPY text-format row, before escape decoding.
#[derive(Debug, Clone, PartialEq, Eq)]
enum CopyValueRef {
    /// The field was exactly the `\N` NULL marker.
    Null,
    /// Raw field bytes, still containing COPY backslash escapes.
    Text(Vec<u8>),
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::{Column, ColumnId, ColumnType, TableId, TableSchema};

    /// Shorthand for a schema column with the given type, ordinal and flags.
    fn column(
        name: &str,
        col_type: ColumnType,
        ordinal: ColumnId,
        is_primary_key: bool,
        is_nullable: bool,
    ) -> Column {
        Column {
            name: name.to_string(),
            col_type,
            ordinal,
            is_primary_key,
            is_nullable,
        }
    }

    /// Fixture: `users(id INT PRIMARY KEY, email TEXT NOT NULL, name TEXT NULL)`.
    fn create_test_schema() -> TableSchema {
        let columns = vec![
            column("id", ColumnType::Int, ColumnId(0), true, false),
            column("email", ColumnType::Text, ColumnId(1), false, false),
            column("name", ColumnType::Text, ColumnId(2), false, true),
        ];
        TableSchema {
            name: "users".to_string(),
            id: TableId(0),
            columns,
            primary_key: vec![ColumnId(0)],
            foreign_keys: vec![],
            indexes: vec![],
            create_statement: None,
        }
    }

    /// Rewriter with a fixed seed so that redacted output is reproducible.
    fn seeded(dialect: SqlDialect) -> ValueRewriter {
        ValueRewriter::new(Some(42), dialect, "en".to_string())
    }

    #[test]
    fn test_rewrite_insert_mysql() {
        let mut rewriter = seeded(SqlDialect::MySql);
        let schema = create_test_schema();
        let stmt = b"INSERT INTO `users` (`id`, `email`, `name`) VALUES (1, 'alice@example.com', 'Alice');";
        let strategies = [
            StrategyKind::Skip,
            StrategyKind::Hash {
                preserve_domain: true,
            },
            StrategyKind::Fake {
                generator: "name".to_string(),
            },
        ];
        let (output, rows, cols) = rewriter
            .rewrite_insert(stmt, "users", &schema, &strategies)
            .unwrap();
        let sql = String::from_utf8_lossy(&output);
        assert!(sql.contains("INSERT INTO `users`"));
        assert!(sql.contains("VALUES"));
        // Only `email` and `name` have a non-Skip strategy.
        assert_eq!(rows, 1);
        assert_eq!(cols, 2);
    }

    #[test]
    fn test_rewrite_insert_mssql() {
        let mut rewriter = seeded(SqlDialect::Mssql);
        let schema = create_test_schema();
        let stmt = b"INSERT INTO [users] ([id], [email], [name]) VALUES (1, N'alice@example.com', N'Alice');";
        let strategies = [StrategyKind::Skip, StrategyKind::Null, StrategyKind::Skip];
        let (output, rows, cols) = rewriter
            .rewrite_insert(stmt, "users", &schema, &strategies)
            .unwrap();
        let sql = String::from_utf8_lossy(&output);
        assert!(sql.contains("INSERT INTO [users]"));
        assert!(sql.contains("NULL"));
        assert_eq!(rows, 1);
        assert_eq!(cols, 1);
    }

    #[test]
    fn test_format_sql_string_mysql() {
        let rewriter = seeded(SqlDialect::MySql);
        assert_eq!(rewriter.format_sql_string("hello"), "'hello'");
        assert_eq!(rewriter.format_sql_string("it's"), "'it\\'s'");
        assert_eq!(rewriter.format_sql_string("line\nbreak"), "'line\\nbreak'");
    }

    #[test]
    fn test_format_sql_string_postgres() {
        let rewriter = seeded(SqlDialect::Postgres);
        assert_eq!(rewriter.format_sql_string("hello"), "'hello'");
        assert_eq!(rewriter.format_sql_string("it's"), "'it''s'");
    }

    #[test]
    fn test_format_sql_string_mssql() {
        let rewriter = seeded(SqlDialect::Mssql);
        assert_eq!(rewriter.format_sql_string("hello"), "'hello'");
        assert_eq!(rewriter.format_sql_string("café"), "N'café'");
    }

    #[test]
    fn test_quote_identifier() {
        let cases = vec![
            (SqlDialect::MySql, "`users`"),
            (SqlDialect::Postgres, "\"users\""),
            (SqlDialect::Mssql, "[users]"),
        ];
        for (dialect, expected) in cases {
            let rewriter = ValueRewriter::new(None, dialect, "en".to_string());
            assert_eq!(rewriter.quote_identifier("users"), expected);
        }
    }
}