Skip to main content

dm_database_sqllog2db/stats/
normalize.rs

1//! SQL 字面量标准化:将 SQL 文本中的字符串字面量(单引号包裹,含 `''` 转义)
2//! 和数字字面量(整数、浮点)替换为 `?` 占位符。
3//!
4//! 标识符中的数字(如 `col1`、`table2`)保持原样,不被替换。
5//! 用于把参数不同但模板相同的 SQL 调用归并为同一字符串键(Phase 51/52 统计聚合)。
6
7/// 将 SQL 文本中的字符串字面量(单引号包裹,含 `''` 转义)和数字字面量
8/// (整数、浮点)替换为 `?` 占位符。标识符中的数字(如 `col1`)保持原样。
9///
10/// 用于把参数不同但模板相同的 SQL 调用归并为同一字符串键。
11///
12/// # Panics
13///
14/// 不会在实践中 panic:输出字节要么来自 UTF-8 输入的原样复制,要么是 ASCII
15/// 字节 `b'?'`(单字节 ASCII 不会破坏多字节 UTF-8 序列)。`expect` 是内部
16/// 一致性断言,正常情况下不会触发。
17#[must_use]
18pub fn normalize_sql(sql: &str) -> String {
19    let bytes = sql.as_bytes();
20    let len = bytes.len();
21    let mut output = Vec::with_capacity(len);
22    let mut cursor = 0usize;
23    let mut prev_was_ident_char = false;
24
25    while cursor < len {
26        let byte = bytes[cursor];
27        match byte {
28            b'\'' => {
29                cursor = skip_string_literal(bytes, cursor + 1, len);
30                output.push(b'?');
31                prev_was_ident_char = false;
32            }
33            byte_val if byte_val.is_ascii_digit() && !prev_was_ident_char => {
34                cursor = skip_number_literal(bytes, cursor, len);
35                output.push(b'?');
36                prev_was_ident_char = false;
37            }
38            b'-' | b'+'
39                if !prev_was_ident_char
40                    && cursor + 1 < len
41                    && bytes[cursor + 1].is_ascii_digit() =>
42            {
43                // 负号或正号紧跟数字时,整体视为一个带符号的数字字面量,用单个 `?` 替换
44                cursor = skip_number_literal(bytes, cursor + 1, len);
45                output.push(b'?');
46                prev_was_ident_char = false;
47            }
48            _ => {
49                output.push(byte);
50                prev_was_ident_char = byte.is_ascii_alphanumeric() || byte == b'_' || byte == b'$';
51                cursor += 1;
52            }
53        }
54    }
55
56    String::from_utf8(output).expect("normalize_sql produced invalid UTF-8")
57}
58
59/// 跳过单引号字符串字面量,处理 `''` 转义引号。
60///
61/// `start` 是第一个开始引号之后的位置(即字符串内容起始处)。
62/// 返回字符串结束后的下一个游标位置(即结束引号 `'` 之后)。
63/// 若字符串未闭合,则返回 `len`(字节末尾)。
64fn skip_string_literal(bytes: &[u8], start: usize, len: usize) -> usize {
65    let mut cursor = start;
66    loop {
67        let Some(relative_pos) = memchr::memchr(b'\'', &bytes[cursor..]) else {
68            // 未闭合字符串——跳到末尾
69            return len;
70        };
71        cursor += relative_pos + 1;
72        if cursor < len && bytes[cursor] == b'\'' {
73            // `''` 转义引号——继续扫描字符串内容
74            cursor += 1;
75        } else {
76            // 字符串正常结束
77            return cursor;
78        }
79    }
80}
81
82/// 跳过数字字面量(整数或浮点数)。
83///
84/// `start` 是数字字面量的第一个数字位置。
85/// 返回数字字面量结束后的下一个游标位置。
86/// 支持浮点格式:整数部分后跟 `.` 再跟至少一个数字(如 `3.14`)。
87fn skip_number_literal(bytes: &[u8], start: usize, len: usize) -> usize {
88    let mut cursor = start;
89    // 跳过整数部分
90    while cursor < len && bytes[cursor].is_ascii_digit() {
91        cursor += 1;
92    }
93    // 处理浮点小数部分:`.` 后必须跟数字才算浮点
94    if cursor + 1 < len && bytes[cursor] == b'.' && bytes[cursor + 1].is_ascii_digit() {
95        cursor += 1; // 跳过 `.`
96        while cursor < len && bytes[cursor].is_ascii_digit() {
97            cursor += 1;
98        }
99    }
100    cursor
101}
102
103#[cfg(test)]
104mod tests {
105    use super::*;
106
107    #[test]
108    fn test_basic_where_number_and_string() {
109        assert_eq!(
110            normalize_sql("SELECT * FROM t WHERE id = 42 AND name = 'alice'"),
111            "SELECT * FROM t WHERE id = ? AND name = ?"
112        );
113    }
114
115    #[test]
116    fn test_multiple_numeric_literals() {
117        assert_eq!(
118            normalize_sql("INSERT INTO t VALUES (1, 2, 3)"),
119            "INSERT INTO t VALUES (?, ?, ?)"
120        );
121    }
122
123    #[test]
124    fn test_escaped_quote_in_string() {
125        assert_eq!(normalize_sql("WHERE name = 'O''Brien'"), "WHERE name = ?");
126    }
127
128    #[test]
129    fn test_no_literals_unchanged() {
130        let sql_with_placeholder = "SELECT col FROM t WHERE id = ?";
131        assert_eq!(normalize_sql(sql_with_placeholder), sql_with_placeholder);
132
133        let sql_plain = "SELECT col FROM t";
134        assert_eq!(normalize_sql(sql_plain), sql_plain);
135    }
136
137    #[test]
138    fn test_insert_multiple_columns_with_float() {
139        assert_eq!(
140            normalize_sql("INSERT INTO orders (id, name, amount) VALUES (100, 'test', 3.14)"),
141            "INSERT INTO orders (id, name, amount) VALUES (?, ?, ?)"
142        );
143    }
144
145    #[test]
146    fn test_identifier_with_digits_not_replaced() {
147        assert_eq!(
148            normalize_sql("SELECT col1, table2 FROM t WHERE id = 1"),
149            "SELECT col1, table2 FROM t WHERE id = ?"
150        );
151    }
152
153    #[test]
154    fn test_unclosed_string_does_not_panic() {
155        // 未闭合字符串:进入字符串状态后到达末尾,应产生单个 `?` 且不 panic
156        let result = normalize_sql("SELECT 'unclosed");
157        assert_eq!(result, "SELECT ?");
158    }
159}