use crate::storage::query::lexer::{Lexer, Token};
use crate::storage::schema::Value;
/// Normalizes `sql` into a plan-cache key that is identical for queries differing only in
/// literal values, whitespace, or the ASCII case of keywords and identifiers.
///
/// * Runs of ASCII whitespace collapse to a single space; leading and trailing whitespace is
///   dropped.
/// * Unquoted words are upper-cased (ASCII letters only).
/// * String literals (`'...'`, with `''` as an escaped quote), numeric literals, and the words
///   `TRUE`, `FALSE` and `NULL` (any case) are replaced by `?`.
/// * A numeric literal immediately following the word `LIMIT` or `OFFSET` is kept verbatim, so
///   queries with different page sizes get different keys.
/// * Double-quoted identifiers are copied verbatim.
/// * A digit run that does not parse as a number (e.g. `1.2.3` or `1e`) is kept verbatim, so
///   every `?` in the key has a value in [`normalize_and_extract`]'s bind list.
///
/// The key is always the one [`normalize_and_extract`] returns for the same input.
pub fn normalize_cache_key(sql: &str) -> String {
    scan_sql(sql, None)
}

/// Returns `true` when `a` and `b` normalize to the same cache key.
pub fn same_cache_key(a: &str, b: &str) -> bool {
    normalize_cache_key(a) == normalize_cache_key(b)
}

/// Returns the [`normalize_cache_key`] key for `sql` together with the values of the literals
/// that were replaced by `?`, in source order.
///
/// Strings bind as `Value::text` (with `''` unescaped), integers as `Value::Integer` or, when
/// they do not fit in `i64`, `Value::UnsignedInteger`; literals containing `.`, `e` or `E` bind
/// as `Value::Float`, `TRUE`/`FALSE` as `Value::Boolean`, and `NULL` as `Value::Null`.
pub fn normalize_and_extract(sql: &str) -> (String, Vec<Value>) {
    let mut binds = Vec::new();
    let key = scan_sql(sql, Some(&mut binds));
    (key, binds)
}

/// Single-pass scanner shared by [`normalize_cache_key`] and [`normalize_and_extract`] so the
/// two can never disagree on the key. Literal values are decoded and pushed onto `binds` only
/// when it is `Some`; the key-only path never allocates per string literal.
fn scan_sql(sql: &str, mut binds: Option<&mut Vec<Value>>) -> String {
    let bytes = sql.as_bytes();
    let mut out = String::with_capacity(sql.len());
    let mut i = 0;
    // Starts `true` so leading whitespace is dropped instead of emitted.
    let mut last_was_space = true;
    // Set while the previous non-whitespace token was the word LIMIT or OFFSET.
    let mut after_limit_or_offset = false;

    while i < bytes.len() {
        let b = bytes[i];
        if b.is_ascii_whitespace() {
            if !last_was_space {
                out.push(' ');
                last_was_space = true;
            }
            i += 1;
            continue;
        }
        last_was_space = false;
        // Only the token right after LIMIT/OFFSET may keep its numeric literal; every other
        // token kind (string, quoted identifier, word, punctuation) clears the flag.
        let keep_numeric = std::mem::take(&mut after_limit_or_offset);

        match b {
            b'\'' => {
                let (body_end, next) = string_literal_bounds(bytes, i);
                if let Some(sink) = binds.as_deref_mut() {
                    // Unescape on the `str` so multi-byte UTF-8 characters survive intact.
                    sink.push(Value::text(sql[i + 1..body_end].replace("''", "'")));
                }
                out.push('?');
                i = next;
            }
            b'"' => {
                // Copy up to and including the closing quote, or to the end if unterminated.
                let next = bytes[i + 1..]
                    .iter()
                    .position(|&c| c == b'"')
                    .map_or(bytes.len(), |off| i + off + 2);
                out.push_str(&sql[i..next]);
                i = next;
            }
            b'0'..=b'9' => {
                let end = numeric_literal_end(bytes, i);
                let lit = &sql[i..end];
                let value = if keep_numeric { None } else { parse_numeric_literal(lit) };
                match value {
                    Some(value) => {
                        out.push('?');
                        if let Some(sink) = binds.as_deref_mut() {
                            sink.push(value);
                        }
                    }
                    // LIMIT/OFFSET operands are part of the query shape, and an unparsable
                    // digit run has no value to bind: keep both verbatim so that placeholders
                    // and binds stay aligned.
                    None => out.push_str(lit),
                }
                i = end;
            }
            b'a'..=b'z' | b'A'..=b'Z' | b'_' => {
                let end = bytes[i..]
                    .iter()
                    .position(|&c| !(c.is_ascii_alphanumeric() || c == b'_'))
                    .map_or(bytes.len(), |off| i + off);
                let word = &sql[i..end];
                let literal = if word.eq_ignore_ascii_case("true") {
                    Some(Value::Boolean(true))
                } else if word.eq_ignore_ascii_case("false") {
                    Some(Value::Boolean(false))
                } else if word.eq_ignore_ascii_case("null") {
                    Some(Value::Null)
                } else {
                    None
                };
                match literal {
                    Some(value) => {
                        out.push('?');
                        if let Some(sink) = binds.as_deref_mut() {
                            sink.push(value);
                        }
                    }
                    None => {
                        out.extend(word.chars().map(|c| c.to_ascii_uppercase()));
                        after_limit_or_offset = word.eq_ignore_ascii_case("limit")
                            || word.eq_ignore_ascii_case("offset");
                    }
                }
                i = end;
            }
            _ if b.is_ascii() => {
                out.push(char::from(b));
                i += 1;
            }
            _ => {
                // Copy the whole UTF-8 character; pushing `b as char` per byte would turn
                // e.g. `é` into `Ã©`. Every branch above stops on an ASCII byte or at the end,
                // so `i` is always on a character boundary here.
                let ch = sql[i..]
                    .chars()
                    .next()
                    .expect("scanner index is on a char boundary");
                out.push(ch);
                i += ch.len_utf8();
            }
        }
    }
    if out.ends_with(' ') {
        out.pop();
    }
    out
}

/// Given `open`, the index of an opening `'`, returns `(body_end, next)`: the index of the
/// closing quote (`bytes.len()` if the literal is unterminated) and the index just past it.
/// A doubled quote (`''`) inside the literal is an escape, not a terminator.
fn string_literal_bounds(bytes: &[u8], open: usize) -> (usize, usize) {
    let mut i = open + 1;
    while i < bytes.len() {
        if bytes[i] == b'\'' {
            if bytes.get(i + 1) == Some(&b'\'') {
                i += 2;
                continue;
            }
            return (i, i + 1);
        }
        i += 1;
    }
    (bytes.len(), bytes.len())
}

/// Returns the end index of the numeric literal starting at `start` (which must be a digit).
/// Digits, `.`, `e` and `E` continue the literal; `+`/`-` continue it only as an exponent sign.
fn numeric_literal_end(bytes: &[u8], start: usize) -> usize {
    let mut i = start;
    while let Some(&c) = bytes.get(i) {
        let continues = match c {
            b'0'..=b'9' | b'.' | b'e' | b'E' => true,
            // `1e-5` keeps the sign; in `1-2` the `-` is an operator.
            b'+' | b'-' => i > start && matches!(bytes[i - 1], b'e' | b'E'),
            _ => false,
        };
        if !continues {
            break;
        }
        i += 1;
    }
    i
}

/// Types a scanned numeric literal: text containing `.`, `e` or `E` is an `f64`; otherwise it
/// is an `i64`, falling back to `u64`. Returns `None` when the text fits none of these.
fn parse_numeric_literal(lit: &str) -> Option<Value> {
    if lit.bytes().any(|c| matches!(c, b'.' | b'e' | b'E')) {
        lit.parse::<f64>().ok().map(Value::Float)
    } else if let Ok(n) = lit.parse::<i64>() {
        Some(Value::Integer(n))
    } else {
        lit.parse::<u64>().ok().map(Value::UnsignedInteger)
    }
}
/// Lexes `sql` and returns the values of its literal tokens (integers, floats, strings,
/// `TRUE`, `FALSE`, `NULL`) in source order.
///
/// An integer or float token that comes immediately after a `LIMIT` or `OFFSET` token is not
/// collected; any other token in between cancels that exemption.
///
/// # Errors
///
/// Returns the lexer error, rendered with `to_string()`, if tokenization fails.
pub fn extract_literal_bindings(sql: &str) -> Result<Vec<Value>, String> {
    let mut lexer = Lexer::new(sql);
    let mut binds = Vec::new();
    let mut after_limit_or_offset = false;
    loop {
        let token = lexer.next_token().map_err(|err| err.to_string())?.token;
        let literal = match token {
            Token::Eof => break,
            Token::Limit | Token::Offset => {
                after_limit_or_offset = true;
                continue;
            }
            Token::Integer(n) if !after_limit_or_offset => Some(Value::Integer(n)),
            Token::Float(n) if !after_limit_or_offset => Some(Value::Float(n)),
            Token::String(s) => Some(Value::text(s)),
            Token::True => Some(Value::Boolean(true)),
            Token::False => Some(Value::Boolean(false)),
            Token::Null => Some(Value::Null),
            // Non-literal tokens, and numbers consumed as LIMIT/OFFSET operands.
            _ => None,
        };
        after_limit_or_offset = false;
        binds.extend(literal);
    }
    Ok(binds)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn integer_literals_collapse() {
        assert_eq!(
            normalize_cache_key("SELECT * FROM t WHERE id = 1"),
            normalize_cache_key("SELECT * FROM t WHERE id = 2"),
        );
    }

    #[test]
    fn string_literals_collapse() {
        assert_eq!(
            normalize_cache_key("SELECT * FROM t WHERE name = 'alice'"),
            normalize_cache_key("SELECT * FROM t WHERE name = 'bob'"),
        );
    }

    #[test]
    fn case_insensitive_keywords() {
        assert_eq!(
            normalize_cache_key("select * from t"),
            normalize_cache_key("SELECT * FROM t"),
        );
    }

    #[test]
    fn whitespace_collapses() {
        // Mixed runs of spaces, tabs and newlines, plus leading/trailing whitespace, must all
        // collapse to single separating spaces.
        let messy = "  SELECT   *\n\tFROM  t  ";
        assert_eq!(normalize_cache_key(messy), normalize_cache_key("SELECT * FROM t"));
        assert_eq!(normalize_cache_key(messy), "SELECT * FROM T");
    }

    #[test]
    fn different_shape_different_key() {
        assert_ne!(
            normalize_cache_key("SELECT * FROM a WHERE x = 1"),
            normalize_cache_key("SELECT * FROM b WHERE x = 1"),
        );
    }

    #[test]
    fn float_and_scientific_collapse() {
        assert_eq!(
            normalize_cache_key("SELECT 1.5e10"),
            normalize_cache_key("SELECT 3.14"),
        );
    }

    #[test]
    fn null_and_boolean_are_literals() {
        assert_eq!(
            normalize_cache_key("WHERE x IS NULL"),
            normalize_cache_key("WHERE x IS TRUE"),
        );
    }

    #[test]
    fn quoted_identifiers_preserved() {
        assert_ne!(
            normalize_cache_key(r#"SELECT "col" FROM t"#),
            normalize_cache_key(r#"SELECT "other" FROM t"#),
        );
    }

    #[test]
    fn limit_and_offset_literals_remain_in_shape() {
        assert_ne!(
            normalize_cache_key("SELECT * FROM t WHERE id = 1 LIMIT 10"),
            normalize_cache_key("SELECT * FROM t WHERE id = 2 LIMIT 20"),
        );
        assert_ne!(
            normalize_cache_key("SELECT * FROM t WHERE id = 1 OFFSET 10"),
            normalize_cache_key("SELECT * FROM t WHERE id = 2 OFFSET 20"),
        );
    }

    #[test]
    fn normalize_and_extract_agrees_with_separate_paths() {
        let queries = [
            "SELECT * FROM users WHERE id = 42",
            "UPDATE users SET score = 99.5 WHERE city = 'NYC' AND age > 30",
            "DELETE FROM t WHERE name = 'al''ice' AND active = TRUE",
            "SELECT 1, 'x', 2.5, NULL, FALSE FROM t",
            "SELECT * FROM t LIMIT 10 OFFSET 5",
        ];
        for q in queries {
            let (fused_key, fused_binds) = normalize_and_extract(q);
            assert_eq!(fused_key, normalize_cache_key(q), "cache_key mismatch for: {q}");
            let separate_binds = extract_literal_bindings(q).unwrap();
            // Comparing Debug output checks length, order and exact variant in one assertion.
            assert_eq!(
                format!("{fused_binds:?}"),
                format!("{separate_binds:?}"),
                "bind mismatch for {q}"
            );
        }
    }

    #[test]
    fn extract_literal_bindings_skips_limit_and_offset() {
        let binds =
            extract_literal_bindings("SELECT * FROM t WHERE age = 18 AND active = true LIMIT 10")
                .unwrap();
        assert_eq!(binds, vec![Value::Integer(18), Value::Boolean(true)]);
    }
}