use regex::Regex;
use std::borrow::Cow;
use std::sync::LazyLock;
/// Matches the keyword `IN` (case-insensitive) followed by a parenthesised list
/// of one or more `?` placeholders, e.g. `IN (?, ?, ?)`.
///
/// The leading `\b` anchors `IN` at a word boundary so that names which merely
/// end in "in" (`min(?)`, `sin(?)`) are left alone when lists are collapsed.
static IN_LIST_RE: LazyLock<Regex> = LazyLock::new(|| {
    // The pattern is a fixed literal, so a parse failure is a programming error.
    Regex::new(r"(?i)\bIN\s*\(\s*\?(?:\s*,\s*\?)*\s*\)").expect("static regex")
});
/// Result of [`normalize_sql`]: the query with its literals replaced by `?`
/// placeholders, plus the literal values that were taken out.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SqlNormalized {
/// Query text in which every extracted literal appears as `?`. A placeholder
/// list following `IN` is collapsed to a single `IN (?)`.
pub template: String,
/// Extracted literal values in the order they were encountered. Collapsing an
/// `IN` list in `template` does not remove any entries from this vector.
pub params: Vec<String>,
}
/// Scanner state; selects which `step_*` function handles the next byte.
#[derive(Debug, Clone, Copy, PartialEq)]
enum State {
// Outside any literal; bytes belong to the verbatim run copied into the template.
Normal,
// Inside a single-quoted string literal.
InString,
// Inside a run of digits with at most one `.`.
InNumber,
// Inside a double-quoted span; its bytes stay part of the verbatim run.
InDoubleQuote,
// Inside a dollar-quoted string, waiting for the closing delimiter.
InDollarQuote,
}
/// Mutable scanning state shared by [`normalize_sql`] and its `step_*` helpers.
struct Tokenizer<'a> {
// Input text (already truncated by `normalize_sql`); used for `&str` slicing.
query: &'a str,
// The same input viewed as bytes; used for byte-wise scanning.
bytes: &'a [u8],
// Output template built so far.
template: String,
// Literal values extracted so far, in scan order.
params: Vec<String>,
// Byte offset of the next byte to examine.
i: usize,
// Current scanner state.
state: State,
// Accumulates the literal currently being scanned.
current_value: String,
// Whether the number being scanned has already consumed a `.`.
seen_dot: bool,
// Set once an `IN` keyword was seen in `Normal` state; gates the IN-list regex pass.
has_in_list: bool,
// Byte offset where the current not-yet-copied run of verbatim text begins.
normal_start: usize,
// Byte offset where the not-yet-copied part of the current string value begins.
value_start: usize,
// Opening delimiter of the current dollar-quoted string, both `$` included.
dollar_tag: Vec<u8>,
}
/// Upper bound, in bytes, on how much of a query [`normalize_sql`] scans; longer
/// input is cut at the nearest char boundary at or below this length.
const MAX_QUERY_LEN: usize = 65_536;
/// Replaces the literals in `query` with `?` placeholders and returns the
/// resulting template together with the extracted values.
///
/// Extracted as parameters:
/// - single-quoted strings (a doubled `''` is unescaped to `'`),
/// - dollar-quoted strings (`$$…$$`, `$tag$…$tag$`),
/// - digit runs with at most one `.` that do not directly follow an identifier byte.
///
/// Double-quoted spans are copied to the template unchanged, and SQL comments
/// receive no special treatment. A literal still open at the end of the input is
/// flushed as a parameter. Input longer than [`MAX_QUERY_LEN`] bytes is truncated
/// before scanning. A placeholder list after `IN` is collapsed to `IN (?)` in the
/// template while `params` keeps every value.
#[must_use]
pub fn normalize_sql(query: &str) -> SqlNormalized {
    // Cut on a char boundary so the truncated slice is still valid UTF-8.
    let scan_len = if query.len() > MAX_QUERY_LEN {
        query.floor_char_boundary(MAX_QUERY_LEN)
    } else {
        query.len()
    };
    let query = &query[..scan_len];

    let mut t = Tokenizer {
        query,
        bytes: query.as_bytes(),
        template: String::with_capacity(query.len()),
        params: Vec::with_capacity(4),
        i: 0,
        state: State::Normal,
        current_value: String::new(),
        seen_dot: false,
        has_in_list: false,
        normal_start: 0,
        value_start: 0,
        dollar_tag: Vec::new(),
    };

    // Each step consumes zero or more bytes and may switch state.
    while t.i < t.bytes.len() {
        match t.state {
            State::Normal => step_normal(&mut t),
            State::InString => step_in_string(&mut t),
            State::InNumber => step_in_number(&mut t),
            State::InDoubleQuote => step_in_double_quote(&mut t),
            State::InDollarQuote => step_in_dollar_quote(&mut t),
        }
    }
    // Emit whatever was still pending when the input ran out.
    flush_pending(&mut t);

    let Tokenizer {
        template,
        params,
        has_in_list,
        ..
    } = t;
    collapse_in_lists(template, has_in_list, params)
}
/// Handles one byte while outside any literal.
///
/// Starts a string, double-quoted span, dollar-quoted string or number when the
/// byte opens one; otherwise the byte stays in the pending verbatim run and is
/// only checked for being the start of an `IN` keyword.
fn step_normal(t: &mut Tokenizer<'_>) {
    let pos = t.i;
    match t.bytes[pos] {
        b'\'' => {
            // Copy the verbatim text before the quote, then collect the string body.
            flush_normal_run(t);
            t.current_value.clear();
            t.value_start = pos + 1;
            t.state = State::InString;
            t.i = pos + 1;
        }
        b'"' => {
            // No flush: the quoted span remains part of the verbatim run.
            t.state = State::InDoubleQuote;
            t.i = pos + 1;
        }
        b'$' if is_dollar_quote_start(pos, t.bytes) => {
            let tag = extract_dollar_tag(pos, t.bytes);
            flush_normal_run(t);
            t.current_value.clear();
            // Skip the opening delimiter; the body starts right after it.
            t.i = pos + tag.len();
            t.value_start = t.i;
            t.dollar_tag = tag;
            t.state = State::InDollarQuote;
        }
        digit if digit.is_ascii_digit() && !is_identifier_byte_before(pos, t.bytes) => {
            flush_normal_run(t);
            t.current_value.clear();
            t.current_value.push(char::from(digit));
            t.seen_dot = false;
            t.state = State::InNumber;
            t.i = pos + 1;
        }
        _ => {
            // Once an IN keyword has been seen there is no need to look again.
            if !t.has_in_list {
                t.has_in_list = is_in_keyword(pos, t.bytes);
            }
            t.i = pos + 1;
        }
    }
}
/// Handles one byte inside a single-quoted string.
///
/// A doubled quote (`''`) contributes one `'` to the value and scanning goes on;
/// a lone quote ends the literal, which is pushed to `params` and replaced by `?`.
fn step_in_string(t: &mut Tokenizer<'_>) {
    let pos = t.i;
    if t.bytes[pos] != b'\'' {
        // Ordinary content byte; it is copied in bulk at the next quote.
        t.i = pos + 1;
        return;
    }

    // Both outcomes first copy the text gathered since `value_start`.
    t.current_value.push_str(&t.query[t.value_start..pos]);

    if t.bytes.get(pos + 1) == Some(&b'\'') {
        t.current_value.push('\'');
        t.i = pos + 2;
        t.value_start = t.i;
    } else {
        let value = std::mem::take(&mut t.current_value);
        t.params.push(value);
        t.template.push('?');
        t.state = State::Normal;
        t.i = pos + 1;
        t.normal_start = t.i;
    }
}
/// Handles one byte inside a numeric literal.
///
/// Digits and the first `.` extend the number. Any other byte (including a
/// second `.`) ends it: the value is pushed to `params`, `?` is emitted, and the
/// byte is left unconsumed for the `Normal` state to process.
fn step_in_number(t: &mut Tokenizer<'_>) {
    match t.bytes[t.i] {
        digit @ b'0'..=b'9' => {
            t.current_value.push(char::from(digit));
            t.i += 1;
        }
        b'.' if !t.seen_dot => {
            t.current_value.push('.');
            t.seen_dot = true;
            t.i += 1;
        }
        _ => {
            let number = std::mem::take(&mut t.current_value);
            t.params.push(number);
            t.template.push('?');
            t.normal_start = t.i;
            t.state = State::Normal;
        }
    }
}
/// Consumes one byte of a double-quoted span and returns to `Normal` once the
/// closing `"` has been consumed. Nothing is emitted here because the span is
/// part of the pending verbatim run.
fn step_in_double_quote(t: &mut Tokenizer<'_>) {
    let closes_span = t.bytes[t.i] == b'"';
    t.i += 1;
    if closes_span {
        t.state = State::Normal;
    }
}
/// Handles one position inside a dollar-quoted string.
///
/// When the remaining input begins with the stored delimiter, the body is pushed
/// to `params`, `?` is emitted and the delimiter is skipped; otherwise the scan
/// advances by one byte.
fn step_in_dollar_quote(t: &mut Tokenizer<'_>) {
    let pos = t.i;
    if !t.bytes[pos..].starts_with(&t.dollar_tag) {
        t.i = pos + 1;
        return;
    }

    t.current_value.push_str(&t.query[t.value_start..pos]);
    let body = std::mem::take(&mut t.current_value);
    t.params.push(body);
    t.template.push('?');
    t.state = State::Normal;
    // Resume verbatim copying right after the closing delimiter.
    t.i = pos + t.dollar_tag.len();
    t.normal_start = t.i;
}
/// Returns `true` when `bytes[i]` is a `$` that opens a dollar-quote delimiter:
/// either `$$`, or `$` + one or more ASCII alphanumerics/underscores + `$`.
///
/// Returns `false` when `i` is out of range.
fn is_dollar_quote_start(i: usize, bytes: &[u8]) -> bool {
    if bytes.get(i) != Some(&b'$') {
        return false;
    }
    let after = &bytes[i + 1..];
    // Untagged form: `$$`.
    if after.first() == Some(&b'$') {
        return true;
    }
    // Tagged form: a non-empty tag that is immediately closed by `$`.
    let tag_len = after
        .iter()
        .take_while(|&&c| c.is_ascii_alphanumeric() || c == b'_')
        .count();
    tag_len > 0 && after.get(tag_len) == Some(&b'$')
}
/// Returns the dollar-quote delimiter that starts at `bytes[i]`, including both
/// `$` characters (`$$` or `$tag$`).
///
/// Intended to be called only after [`is_dollar_quote_start`] returned `true`
/// for the same position; without a byte after the tag the slice below panics.
fn extract_dollar_tag(i: usize, bytes: &[u8]) -> Vec<u8> {
    let body_start = i + 1;
    if bytes.get(body_start) == Some(&b'$') {
        return b"$$".to_vec();
    }
    let tag_len = bytes[body_start..]
        .iter()
        .take_while(|&&c| c.is_ascii_alphanumeric() || c == b'_')
        .count();
    // Inclusive end: the byte after the tag characters is the closing `$`.
    bytes[i..=body_start + tag_len].to_vec()
}
/// Appends the pending verbatim text, `query[normal_start..i]`, to the template.
/// Does nothing when that range is empty.
fn flush_normal_run(t: &mut Tokenizer<'_>) {
    let (start, end) = (t.normal_start, t.i);
    if start < end {
        t.template.push_str(&t.query[start..end]);
    }
}
/// Emits whatever is still pending once the input is exhausted.
///
/// - `Normal` / `InDoubleQuote`: the remaining verbatim text goes to the template.
/// - `InNumber`: the collected digits become a parameter.
/// - `InString` / `InDollarQuote`: the unterminated body, up to the end of the
///   input, becomes a parameter.
fn flush_pending(t: &mut Tokenizer<'_>) {
    let query = t.query;
    let end = t.bytes.len();

    // Text that still has to be appended to the literal being collected.
    let literal_tail = match t.state {
        State::Normal | State::InDoubleQuote => {
            if t.normal_start < end {
                t.template.push_str(&query[t.normal_start..end]);
            }
            return;
        }
        // Digits were pushed into `current_value` as they were read.
        State::InNumber => "",
        State::InString | State::InDollarQuote => &query[t.value_start..end],
    };

    t.current_value.push_str(literal_tail);
    let value = std::mem::take(&mut t.current_value);
    t.params.push(value);
    t.template.push('?');
}
/// Returns `true` when the bytes at `i` spell the keyword `IN` (any ASCII case)
/// as a standalone word.
///
/// "Standalone" means the neighbouring bytes, if any, are not ASCII identifier
/// bytes (alphanumeric or `_`). Accepting any non-identifier byte before the
/// keyword, rather than only whitespace, lets forms such as `(a)IN (1, 2)` be
/// recognised; rejecting `_` after it keeps identifiers like `IN_x` out.
///
/// Returns `false` when fewer than two bytes are available at `i`.
fn is_in_keyword(i: usize, bytes: &[u8]) -> bool {
    let is_word_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    let Some(candidate) = bytes.get(i..i + 2) else {
        return false;
    };
    candidate.eq_ignore_ascii_case(b"in")
        && (i == 0 || !is_word_byte(bytes[i - 1]))
        && bytes.get(i + 2).map_or(true, |&next| !is_word_byte(next))
}
/// Builds the final [`SqlNormalized`], collapsing every placeholder list matched
/// by `IN_LIST_RE` into `IN (?)` when `has_in_list` is set.
///
/// `params` is passed through untouched. The regex is skipped entirely when the
/// scanner saw no `IN` keyword.
fn collapse_in_lists(template: String, has_in_list: bool, params: Vec<String>) -> SqlNormalized {
    if !has_in_list {
        return SqlNormalized { template, params };
    }

    // `replace_all` only allocates when something matched; otherwise the
    // original `String` is reused as-is.
    let rewritten = match IN_LIST_RE.replace_all(&template, "IN (?)") {
        Cow::Owned(collapsed) => Some(collapsed),
        Cow::Borrowed(_) => None,
    };

    SqlNormalized {
        template: rewritten.unwrap_or(template),
        params,
    }
}
/// Returns `true` when the byte just before position `i` can be part of an
/// identifier, meaning a digit at `i` continues a name rather than starting a
/// numeric literal.
///
/// Identifier bytes are ASCII alphanumerics, `_`, and any non-ASCII byte. The
/// last case covers identifiers containing multi-byte UTF-8 letters (`café2`),
/// whose trailing digit must not be extracted as a parameter.
///
/// Returns `false` when there is no preceding byte.
fn is_identifier_byte_before(i: usize, bytes: &[u8]) -> bool {
    let Some(&prev) = i.checked_sub(1).and_then(|p| bytes.get(p)) else {
        return false;
    };
    prev.is_ascii_alphanumeric() || prev == b'_' || !prev.is_ascii()
}
/// Unit tests for [`normalize_sql`].
#[cfg(test)]
mod tests {
    use super::*;

    /// Normalizes `sql` and asserts both the template and the complete parameter
    /// list. `#[track_caller]` makes a failure point at the calling test.
    #[track_caller]
    fn assert_normalized(sql: &str, template: &str, params: &[&str]) {
        let r = normalize_sql(sql);
        assert_eq!(r.template, template);
        assert_eq!(r.params, params);
    }

    #[test]
    fn numeric_literal() {
        assert_normalized(
            "SELECT * FROM order_item WHERE order_id = 42",
            "SELECT * FROM order_item WHERE order_id = ?",
            &["42"],
        );
    }

    #[test]
    fn float_literal() {
        assert_normalized(
            "SELECT * FROM t WHERE price > 3.14",
            "SELECT * FROM t WHERE price > ?",
            &["3.14"],
        );
    }

    #[test]
    fn string_literal() {
        assert_normalized(
            "SELECT * FROM users WHERE name = 'Alice'",
            "SELECT * FROM users WHERE name = ?",
            &["Alice"],
        );
    }

    #[test]
    fn uuid_in_string() {
        assert_normalized(
            "SELECT * FROM t WHERE id = 'a1b2c3d4-e5f6-7890-abcd-ef1234567890'",
            "SELECT * FROM t WHERE id = ?",
            &["a1b2c3d4-e5f6-7890-abcd-ef1234567890"],
        );
    }

    #[test]
    fn in_list_collapsed() {
        assert_normalized(
            "SELECT * FROM t WHERE id IN (1, 2, 3)",
            "SELECT * FROM t WHERE id IN (?)",
            &["1", "2", "3"],
        );
    }

    #[test]
    fn in_list_strings_collapsed() {
        assert_normalized(
            "SELECT * FROM t WHERE name IN ('a', 'b', 'c')",
            "SELECT * FROM t WHERE name IN (?)",
            &["a", "b", "c"],
        );
    }

    #[test]
    fn escaped_quotes() {
        assert_normalized(
            "SELECT * FROM t WHERE name = 'O''Brien'",
            "SELECT * FROM t WHERE name = ?",
            &["O'Brien"],
        );
    }

    #[test]
    fn table_names_with_digits_preserved() {
        assert_normalized(
            "SELECT * FROM order_item2 WHERE id = 1",
            "SELECT * FROM order_item2 WHERE id = ?",
            &["1"],
        );
    }

    #[test]
    fn join_query() {
        assert_normalized(
            "SELECT p.name FROM order_item p JOIN orders g ON p.order_id = g.id WHERE g.id = 42",
            "SELECT p.name FROM order_item p JOIN orders g ON p.order_id = g.id WHERE g.id = ?",
            &["42"],
        );
    }

    #[test]
    fn multiple_params() {
        assert_normalized(
            "UPDATE t SET a = 1, b = 'foo' WHERE id = 99",
            "UPDATE t SET a = ?, b = ? WHERE id = ?",
            &["1", "foo", "99"],
        );
    }

    #[test]
    fn no_literals() {
        assert_normalized("SELECT count(*) FROM users", "SELECT count(*) FROM users", &[]);
    }

    #[test]
    fn multi_dot_number_rejected() {
        // A second dot ends the number, so only "1.2" is the first parameter.
        let r = normalize_sql("SELECT * FROM t WHERE x = 1.2.3");
        assert_eq!(r.params[0], "1.2");
    }

    #[test]
    fn unterminated_string_flushed() {
        assert_normalized(
            "SELECT * FROM t WHERE name = 'unterminated",
            "SELECT * FROM t WHERE name = ?",
            &["unterminated"],
        );
    }

    #[test]
    fn empty_query() {
        assert_normalized("", "", &[]);
    }

    #[test]
    fn number_at_start_of_query() {
        assert_normalized("42", "?", &["42"]);
    }

    #[test]
    fn multi_dot_full_template() {
        assert_normalized(
            "SELECT * FROM t WHERE x = 1.2.3",
            "SELECT * FROM t WHERE x = ?.?",
            &["1.2", "3"],
        );
    }

    #[test]
    fn empty_string_literal() {
        assert_normalized(
            "SELECT * FROM t WHERE name = ''",
            "SELECT * FROM t WHERE name = ?",
            &[""],
        );
    }

    #[test]
    fn digit_in_string_literal() {
        assert_normalized(
            "SELECT * FROM t WHERE code = '42'",
            "SELECT * FROM t WHERE code = ?",
            &["42"],
        );
    }

    #[test]
    fn underscore_before_digit_preserved() {
        assert_normalized("SELECT col_1 FROM t", "SELECT col_1 FROM t", &[]);
    }

    #[test]
    fn number_only_query_at_eof() {
        assert_normalized(
            "SELECT * FROM t LIMIT 100",
            "SELECT * FROM t LIMIT ?",
            &["100"],
        );
    }

    #[test]
    fn cow_borrowed_path_no_in_list() {
        assert_normalized("SELECT 1", "SELECT ?", &["1"]);
    }

    #[test]
    fn negative_number_not_collapsed() {
        // The sign is not part of the literal; it stays in the template.
        assert_normalized(
            "SELECT * FROM t WHERE x = -5",
            "SELECT * FROM t WHERE x = -?",
            &["5"],
        );
    }

    #[test]
    fn cte_basic() {
        assert_normalized(
            "WITH active AS (SELECT id FROM users WHERE status = 'active') \
             SELECT * FROM orders WHERE user_id IN (SELECT id FROM active) AND total > 100",
            "WITH active AS (SELECT id FROM users WHERE status = ?) \
             SELECT * FROM orders WHERE user_id IN (SELECT id FROM active) AND total > ?",
            &["active", "100"],
        );
    }

    #[test]
    fn cte_nested() {
        assert_normalized(
            "WITH a AS (SELECT 1), b AS (SELECT * FROM a WHERE x = 'test') \
             SELECT * FROM b WHERE id = 42",
            "WITH a AS (SELECT ?), b AS (SELECT * FROM a WHERE x = ?) \
             SELECT * FROM b WHERE id = ?",
            &["1", "test", "42"],
        );
    }

    #[test]
    fn double_quoted_identifier_preserved() {
        assert_normalized(
            r#"SELECT * FROM "MyTable" WHERE "Column" = 42"#,
            r#"SELECT * FROM "MyTable" WHERE "Column" = ?"#,
            &["42"],
        );
    }

    #[test]
    fn double_quoted_with_digits_preserved() {
        assert_normalized(
            r#"SELECT * FROM "table_2" WHERE "col_3" = 'value'"#,
            r#"SELECT * FROM "table_2" WHERE "col_3" = ?"#,
            &["value"],
        );
    }

    #[test]
    fn dollar_quote_basic() {
        assert_normalized(
            "SELECT $$hello world$$ AS greeting",
            "SELECT ? AS greeting",
            &["hello world"],
        );
    }

    #[test]
    fn dollar_quote_tagged() {
        assert_normalized(
            "SELECT $tag$some body$tag$ AS body",
            "SELECT ? AS body",
            &["some body"],
        );
    }

    #[test]
    fn dollar_quote_in_function() {
        let r = normalize_sql(
            "CREATE FUNCTION foo() RETURNS void AS $$ BEGIN RAISE NOTICE 'hi'; END; $$ LANGUAGE plpgsql",
        );
        assert_eq!(
            r.template,
            "CREATE FUNCTION foo() RETURNS void AS ? LANGUAGE plpgsql"
        );
    }

    #[test]
    fn call_with_params() {
        assert_normalized(
            "CALL process_order(42, 'rush', NOW())",
            "CALL process_order(?, ?, NOW())",
            &["42", "rush"],
        );
    }

    #[test]
    fn call_with_interval() {
        assert_normalized(
            "CALL schedule_task(1, INTERVAL '2 days')",
            "CALL schedule_task(?, INTERVAL ?)",
            &["1", "2 days"],
        );
    }

    #[test]
    fn utf8_in_string_literal() {
        assert_normalized(
            "SELECT * FROM t WHERE name = 'caf\u{00e9}'",
            "SELECT * FROM t WHERE name = ?",
            &["caf\u{00e9}"],
        );
    }

    #[test]
    fn utf8_emoji_in_string_literal() {
        assert_normalized(
            "INSERT INTO t (msg) VALUES ('\u{1F600} hello')",
            "INSERT INTO t (msg) VALUES (?)",
            &["\u{1F600} hello"],
        );
    }

    #[test]
    fn utf8_cjk_in_string_literal() {
        assert_normalized(
            "SELECT * FROM t WHERE name = '\u{4F60}\u{597D}'",
            "SELECT * FROM t WHERE name = ?",
            &["\u{4F60}\u{597D}"],
        );
    }

    #[test]
    fn utf8_in_dollar_quote() {
        assert_normalized(
            "SELECT $$caf\u{00e9} au lait$$ AS drink",
            "SELECT ? AS drink",
            &["caf\u{00e9} au lait"],
        );
    }

    #[test]
    fn utf8_with_escaped_quotes() {
        assert_normalized(
            "SELECT * FROM t WHERE name = 'caf\u{00e9} d''or'",
            "SELECT * FROM t WHERE name = ?",
            &["caf\u{00e9} d'or"],
        );
    }

    #[test]
    fn line_comment_passes_through() {
        assert_normalized(
            "SELECT 1 -- this is a comment",
            "SELECT ? -- this is a comment",
            &["1"],
        );
    }

    #[test]
    fn block_comment_passes_through() {
        let r = normalize_sql("SELECT /* comment */ 1 FROM t");
        assert_eq!(r.template, "SELECT /* comment */ ? FROM t");
    }

    #[test]
    fn comment_inside_string_not_treated_as_comment() {
        assert_normalized(
            "SELECT * FROM t WHERE name = 'value -- not a comment'",
            "SELECT * FROM t WHERE name = ?",
            &["value -- not a comment"],
        );
    }

    #[test]
    fn unterminated_dollar_quote_flushed() {
        assert_normalized("SELECT $$incomplete", "SELECT ?", &["incomplete"]);
    }

    #[test]
    fn unterminated_double_quote_flushed() {
        assert_normalized("SELECT \"unterminated", "SELECT \"unterminated", &[]);
    }

    #[test]
    fn double_quoted_identifier_with_digits() {
        assert_normalized(
            r#"SELECT "col123" FROM "table456" WHERE id = 1"#,
            r#"SELECT "col123" FROM "table456" WHERE id = ?"#,
            &["1"],
        );
    }

    #[test]
    fn empty_dollar_quoted_string() {
        assert_normalized("SELECT $$$$ AS empty", "SELECT ? AS empty", &[""]);
    }

    #[test]
    fn long_query_truncated_at_max() {
        let prefix = "SELECT * FROM t WHERE name = '";
        let long_query = format!("{prefix}{}'", "a".repeat(70_000));
        let r = normalize_sql(&long_query);
        // The cut lands inside the literal, so it is flushed as an unterminated
        // string. The parameter length proves the input really was truncated;
        // the template alone is short either way.
        assert_eq!(r.template, "SELECT * FROM t WHERE name = ?");
        assert_eq!(r.params.len(), 1);
        assert_eq!(r.params[0].len(), MAX_QUERY_LEN - prefix.len());
    }

    #[test]
    fn truncation_respects_char_boundary() {
        // The extra ASCII byte shifts the two-byte characters so that
        // MAX_QUERY_LEN falls in the middle of one; the cut must back off by one byte.
        let prefix = "SELECT * FROM t WHERE name = '";
        let long_query = format!("{prefix}x{}'", "\u{00e9}".repeat(40_000));
        let r = normalize_sql(&long_query);
        assert_eq!(r.template, "SELECT * FROM t WHERE name = ?");
        assert_eq!(r.params.len(), 1);
        assert_eq!(r.params[0].len(), MAX_QUERY_LEN - 1 - prefix.len());
        assert!(r.params[0].ends_with('\u{00e9}'));
    }

    #[test]
    fn whitespace_only_string_literal() {
        assert_normalized(
            "SELECT * FROM t WHERE name = ' '",
            "SELECT * FROM t WHERE name = ?",
            &[" "],
        );
    }

    #[test]
    fn four_consecutive_quotes() {
        // `''''` is a string containing a single escaped quote.
        assert_normalized(
            "SELECT * FROM t WHERE name = ''''",
            "SELECT * FROM t WHERE name = ?",
            &["'"],
        );
    }
}