use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use regex::Regex;
use once_cell::sync::Lazy;
/// Result of normalizing one SQL statement.
///
/// All fields are public; the accessor methods below return the same data by
/// reference (or by copy for the hash) for callers that prefer method syntax.
#[derive(Debug, Clone)]
pub struct NormalizedQuery {
    /// Statement text with literals replaced by `?` placeholders.
    pub fingerprint: String,
    /// 64-bit hash computed by the normalizer that built this value.
    pub hash: u64,
    /// Table names found in the statement.
    pub tables: Vec<String>,
    /// Literal values that were replaced by placeholders.
    pub parameters: Vec<String>,
}

impl NormalizedQuery {
    /// Borrows the fingerprint text.
    pub fn fingerprint(&self) -> &str {
        self.fingerprint.as_str()
    }

    /// Returns the stored hash value.
    pub fn hash(&self) -> u64 {
        self.hash
    }

    /// Borrows the table names as a slice.
    pub fn tables(&self) -> &[String] {
        self.tables.as_slice()
    }
}
/// Converts raw SQL text into fingerprints, hashes and table lists.
///
/// Build one with `QueryNormalizer::new()` (`preserve_order = true`) or
/// `QueryNormalizer::unordered()` (`preserve_order = false`).
#[derive(Debug, Clone)]
pub struct QueryNormalizer {
// Controls what `normalize` feeds into the hash: when `true`, every extracted
// parameter value is hashed after the fingerprint, so the same statement with
// different literals gets a different hash; when `false`, only the fingerprint
// is hashed. The fingerprint text itself does not depend on this flag.
preserve_order: bool,
}
// Single-quoted SQL string literal. Inside the quotes a backslash escapes the
// next character and a doubled quote (`''`) stands for one literal quote, so
// `'it''s'` is matched as ONE literal instead of two adjacent ones.
static STRING_LITERAL: Lazy<Regex> = Lazy::new(|| {
    Regex::new(r#"'(?:[^'\\]|\\.|'')*'"#).expect("STRING_LITERAL is a valid pattern")
});

// Double-quoted token with backslash escapes. Not referenced by the code in
// this file's visible sections; kept unchanged.
static DOUBLE_QUOTED: Lazy<Regex> = Lazy::new(|| {
    Regex::new(r#""(?:[^"\\]|\\.)*""#).unwrap()
});

// Unsigned integer or decimal literal with an optional exponent. The exponent
// marker is accepted in either case (`1e10` and `1E10`); the word boundaries
// keep digits that are part of an identifier (`users2`, `col_1`) untouched.
static NUMBER_LITERAL: Lazy<Regex> = Lazy::new(|| {
    Regex::new(r"\b\d+(?:\.\d+)?(?:[eE][+-]?\d+)?\b").expect("NUMBER_LITERAL is a valid pattern")
});

// One or more whitespace characters (collapsed to a single space by callers).
static WHITESPACE: Lazy<Regex> = Lazy::new(|| {
    Regex::new(r"\s+").expect("WHITESPACE is a valid pattern")
});

// Keyword that introduces a table reference, followed by an optionally
// schema-qualified identifier (capture group 1). The leading `\b` stops the
// keyword from matching as the tail of a longer identifier: without it,
// `xfrom y` or a column alias such as `valid_from AS v` would yield a table.
static TABLE_PATTERN: Lazy<Regex> = Lazy::new(|| {
    Regex::new(r"(?i)\b(?:FROM|JOIN|INTO|UPDATE|TABLE)\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)?)")
        .expect("TABLE_PATTERN is a valid pattern")
});

// Block comment / optimizer hint: `/* ... */`. `(?s)` lets `.` cross newlines
// and the lazy `.*?` stops at the first `*/`, so bodies that contain `*`
// (e.g. `/* a * b */` or `/** doc **/`) are matched too. Nested comments are
// not supported.
static HINT_PATTERN: Lazy<Regex> = Lazy::new(|| {
    Regex::new(r"(?s)/\*.*?\*/").expect("HINT_PATTERN is a valid pattern")
});

// Line comment: `--` up to (not including) the end of the line.
static COMMENT_PATTERN: Lazy<Regex> = Lazy::new(|| {
    Regex::new(r"--[^\n]*").expect("COMMENT_PATTERN is a valid pattern")
});
impl QueryNormalizer {
pub fn new() -> Self {
Self {
preserve_order: true,
}
}
pub fn unordered() -> Self {
Self {
preserve_order: false,
}
}
pub fn normalize(&self, sql: &str) -> NormalizedQuery {
let mut parameters = Vec::new();
let sql = HINT_PATTERN.replace_all(sql, "");
let sql = COMMENT_PATTERN.replace_all(&sql, "");
let tables = self.extract_tables(&sql);
let sql = STRING_LITERAL.replace_all(&sql, |caps: ®ex::Captures| {
let value = caps.get(0).unwrap().as_str();
let inner = &value[1..value.len()-1];
parameters.push(inner.to_string());
"?"
});
let sql = NUMBER_LITERAL.replace_all(&sql, |caps: ®ex::Captures| {
let value = caps.get(0).unwrap().as_str();
parameters.push(value.to_string());
"?"
});
let sql = WHITESPACE.replace_all(&sql, " ");
let fingerprint = sql.trim().to_uppercase();
let mut hasher = DefaultHasher::new();
fingerprint.hash(&mut hasher);
if self.preserve_order {
for param in ¶meters {
param.hash(&mut hasher);
}
}
let hash = hasher.finish();
NormalizedQuery {
fingerprint,
hash,
tables,
parameters,
}
}
fn extract_tables(&self, sql: &str) -> Vec<String> {
let mut tables = Vec::new();
for cap in TABLE_PATTERN.captures_iter(sql) {
if let Some(table_match) = cap.get(1) {
let table = table_match.as_str().to_lowercase();
let table_name = table.split('.').last().unwrap_or(&table);
if !tables.contains(&table_name.to_string()) {
tables.push(table_name.to_string());
}
}
}
tables
}
pub fn fingerprint(&self, sql: &str) -> String {
let sql = HINT_PATTERN.replace_all(sql, "");
let sql = COMMENT_PATTERN.replace_all(&sql, "");
let sql = STRING_LITERAL.replace_all(&sql, "?");
let sql = NUMBER_LITERAL.replace_all(&sql, "?");
let sql = WHITESPACE.replace_all(&sql, " ");
sql.trim().to_uppercase()
}
pub fn are_equivalent(&self, sql1: &str, sql2: &str) -> bool {
self.fingerprint(sql1) == self.fingerprint(sql2)
}
}
impl Default for QueryNormalizer {
fn default() -> Self {
Self::new()
}
}
/// Hashes the fingerprint of `sql` into a `u64`.
///
/// The fingerprint comes from a `QueryNormalizer::new()` instance, so two
/// statements that differ only in their literal values produce the same
/// result.
///
/// NOTE(review): the value comes from `DefaultHasher`, whose algorithm the
/// standard library does not promise to keep stable across releases — avoid
/// persisting it.
pub fn quick_fingerprint(sql: &str) -> u64 {
    let mut state = DefaultHasher::new();
    QueryNormalizer::new().fingerprint(sql).hash(&mut state);
    state.finish()
}
/// Convenience wrapper: lists the tables referenced by `sql` using a
/// throw-away `QueryNormalizer::new()` instance.
///
/// Names are lower-cased, stripped of any schema qualifier, de-duplicated and
/// returned in order of first appearance.
pub fn extract_tables(sql: &str) -> Vec<String> {
    QueryNormalizer::new().extract_tables(sql)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A numeric literal becomes a placeholder and is reported as a parameter.
    #[test]
    fn test_normalize_simple_query() {
        let normalizer = QueryNormalizer::new();
        let query = "SELECT * FROM users WHERE id = 123";
        let normalized = normalizer.normalize(query);
        assert_eq!(normalized.fingerprint, "SELECT * FROM USERS WHERE ID = ?");
        assert_eq!(normalized.parameters, vec!["123"]);
        assert_eq!(normalized.tables, vec!["users"]);
    }

    /// String parameters are stored without their surrounding quotes.
    #[test]
    fn test_normalize_string_literals() {
        let normalizer = QueryNormalizer::new();
        let query = "SELECT * FROM users WHERE name = 'John Doe'";
        let normalized = normalizer.normalize(query);
        assert_eq!(normalized.fingerprint, "SELECT * FROM USERS WHERE NAME = ?");
        assert_eq!(normalized.parameters, vec!["John Doe"]);
    }

    #[test]
    fn test_normalize_multiple_parameters() {
        let normalizer = QueryNormalizer::new();
        let query = "SELECT * FROM users WHERE age > 18 AND status = 'active' AND score < 100";
        let normalized = normalizer.normalize(query);
        assert_eq!(normalized.fingerprint, "SELECT * FROM USERS WHERE AGE > ? AND STATUS = ? AND SCORE < ?");
        assert_eq!(normalized.parameters.len(), 3);
    }

    #[test]
    fn test_extract_tables_join() {
        let normalizer = QueryNormalizer::new();
        let query = "SELECT u.*, o.* FROM users u JOIN orders o ON u.id = o.user_id";
        let normalized = normalizer.normalize(query);
        assert!(normalized.tables.contains(&"users".to_string()));
        assert!(normalized.tables.contains(&"orders".to_string()));
    }

    /// Both block comments/hints and `--` line comments are dropped.
    #[test]
    fn test_normalize_removes_comments() {
        let normalizer = QueryNormalizer::new();
        let query = "/* helios:cache_ttl=60 */ SELECT * FROM users -- inline comment\nWHERE id = 1";
        let normalized = normalizer.normalize(query);
        assert_eq!(normalized.fingerprint, "SELECT * FROM USERS WHERE ID = ?");
    }

    /// Runs of spaces, tabs and newlines collapse to a single space.
    #[test]
    fn test_normalize_whitespace() {
        let normalizer = QueryNormalizer::new();
        let query1 = "SELECT * FROM users WHERE id=1";
        // Same statement, but with irregular whitespace between the tokens;
        // comparing two identical strings would prove nothing.
        let query2 = "  SELECT  *   FROM\tusers\n  WHERE id=1  ";
        assert_eq!(
            normalizer.fingerprint(query1),
            normalizer.fingerprint(query2)
        );
        assert_eq!(
            normalizer.fingerprint(query2),
            "SELECT * FROM USERS WHERE ID=?"
        );
    }

    #[test]
    fn test_equivalent_queries() {
        let normalizer = QueryNormalizer::new();
        let query1 = "SELECT * FROM users WHERE id = 123";
        let query2 = "SELECT * FROM users WHERE id = 456";
        assert!(normalizer.are_equivalent(query1, query2));
        let query3 = "SELECT * FROM users WHERE name = 'test'";
        assert!(!normalizer.are_equivalent(query1, query3));
    }

    /// Normalizing the same text twice yields the same hash.
    #[test]
    fn test_hash_consistency() {
        let normalizer = QueryNormalizer::new();
        let query1 = "SELECT * FROM users WHERE id = 1";
        let query2 = "SELECT * FROM users WHERE id = 1";
        let norm1 = normalizer.normalize(query1);
        let norm2 = normalizer.normalize(query2);
        assert_eq!(norm1.hash, norm2.hash);
    }

    /// With `new()`, parameter values are part of the hash.
    #[test]
    fn test_hash_different_params() {
        let normalizer = QueryNormalizer::new();
        let query1 = "SELECT * FROM users WHERE id = 1";
        let query2 = "SELECT * FROM users WHERE id = 2";
        let norm1 = normalizer.normalize(query1);
        let norm2 = normalizer.normalize(query2);
        assert_ne!(norm1.hash, norm2.hash);
    }

    /// With `unordered()`, parameter values do not influence the hash.
    #[test]
    fn test_unordered_normalizer() {
        let normalizer = QueryNormalizer::unordered();
        let query1 = "SELECT * FROM users WHERE id = 1";
        let query2 = "SELECT * FROM users WHERE id = 2";
        let norm1 = normalizer.normalize(query1);
        let norm2 = normalizer.normalize(query2);
        assert_eq!(norm1.fingerprint, norm2.fingerprint);
        // Fingerprints match for any normalizer; the hash is what the
        // unordered mode actually changes.
        assert_eq!(norm1.hash, norm2.hash);
        // The literal values are still reported.
        assert_eq!(norm1.parameters, vec!["1"]);
        assert_eq!(norm2.parameters, vec!["2"]);
    }

    #[test]
    fn test_extract_tables_various() {
        let normalizer = QueryNormalizer::new();
        let queries = vec![
            ("INSERT INTO users VALUES (1)", vec!["users"]),
            ("UPDATE products SET price = 10", vec!["products"]),
            ("DELETE FROM orders WHERE id = 1", vec!["orders"]),
            ("SELECT * FROM schema.table", vec!["table"]),
            ("TABLE users", vec!["users"]),
        ];
        for (sql, expected_tables) in queries {
            let normalized = normalizer.normalize(sql);
            for table in expected_tables {
                assert!(
                    normalized.tables.contains(&table.to_string()),
                    "Query '{}' should contain table '{}'",
                    sql,
                    table
                );
            }
        }
    }

    #[test]
    fn test_decimal_numbers() {
        let normalizer = QueryNormalizer::new();
        let query = "SELECT * FROM products WHERE price < 99.99 AND rating > 4.5";
        let normalized = normalizer.normalize(query);
        assert!(normalized.parameters.contains(&"99.99".to_string()));
        assert!(normalized.parameters.contains(&"4.5".to_string()));
    }

    #[test]
    fn test_scientific_notation() {
        let normalizer = QueryNormalizer::new();
        let query = "SELECT * FROM data WHERE value = 1e10";
        let normalized = normalizer.normalize(query);
        assert!(normalized.fingerprint.contains("VALUE = ?"));
    }

    /// `quick_fingerprint` ignores literal values.
    #[test]
    fn test_quick_fingerprint() {
        let hash1 = quick_fingerprint("SELECT * FROM users WHERE id = 1");
        let hash2 = quick_fingerprint("SELECT * FROM users WHERE id = 2");
        assert_eq!(hash1, hash2);
    }
}