1use serde::{Deserialize, Serialize};
2use sqlparser::{
3 ast::{Expr, Value, VisitMut, VisitorMut},
4 dialect::PostgreSqlDialect,
5 parser::Parser,
6};
7
8use crate::PgLogstatsError;
9
10#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
12pub enum QueryType {
13 Select,
15 Insert,
17 Update,
19 Delete,
21 DDL,
23 Other,
25}
26
27impl std::fmt::Display for QueryType {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 match self {
30 QueryType::Select => write!(f, "SELECT"),
31 QueryType::Insert => write!(f, "INSERT"),
32 QueryType::Update => write!(f, "UPDATE"),
33 QueryType::Delete => write!(f, "DELETE"),
34 QueryType::DDL => write!(f, "DDL"),
35 QueryType::Other => write!(f, "OTHER"),
36 }
37 }
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct Query {
42 pub sql: String,
43 pub query_type: QueryType,
44 pub normalized_query: String,
45}
46
47impl Query {
48 pub fn from_sql(sql: &str) -> Result<Vec<Query>, PgLogstatsError> {
50 let dialect = PostgreSqlDialect {};
51 let ast = Parser::parse_sql(&dialect, sql).map_err(|e| PgLogstatsError::Parse {
52 message: format!("Failed to parse SQL: {}", e),
53 line_number: None,
54 line_content: Some(sql.to_string()),
55 })?;
56
57 let mut queries = Vec::new();
58 for stmt in &ast {
59 let query_type = Query::query_type_from_statement(stmt);
60 let normalized_query = Query::normalize_query(std::slice::from_ref(stmt))
61 .unwrap_or_else(|_| stmt.to_string());
62 queries.push(Query {
63 sql: stmt.to_string(),
64 query_type,
65 normalized_query,
66 });
67 }
68 Ok(queries)
69 }
70
71 fn query_type_from_statement(stmt: &sqlparser::ast::Statement) -> QueryType {
72 use sqlparser::ast::Statement::*;
73 match stmt {
74 Query(_) => QueryType::Select,
75 Insert { .. } => QueryType::Insert,
76 Update { .. } => QueryType::Update,
77 Delete { .. } => QueryType::Delete,
78 CreateTable { .. }
79 | CreateView { .. }
80 | CreateIndex { .. }
81 | CreateSchema { .. }
82 | CreateDatabase { .. }
83 | Drop { .. }
84 | AlterTable { .. }
85 | Truncate { .. } => QueryType::DDL,
86 _ => QueryType::Other,
87 }
88 }
89
90 fn normalize_query(ast: &[sqlparser::ast::Statement]) -> Result<String, PgLogstatsError> {
92 if ast.is_empty() {
93 return Ok("".to_string());
94 }
95
96 let mut ast = ast.to_owned();
98
99 let mut normalizer = LiteralNormalizer;
100 for stmt in &mut ast {
101 let _ = stmt.visit(&mut normalizer);
102 }
103
104 let normalized_sql = ast
105 .iter()
106 .map(|stmt| stmt.to_string())
107 .collect::<Vec<_>>()
108 .join("; ");
109
110 Ok(normalized_sql)
111 }
112}
113
114struct LiteralNormalizer;
116
117impl VisitorMut for LiteralNormalizer {
118 type Break = ();
119
120 fn pre_visit_expr(&mut self, _expr: &mut Expr) -> std::ops::ControlFlow<Self::Break> {
121 std::ops::ControlFlow::Continue(())
123 }
124
125 fn post_visit_expr(&mut self, expr: &mut Expr) -> std::ops::ControlFlow<Self::Break> {
126 match expr {
127 Expr::Value(Value::Number(_, _))
129 | Expr::Value(Value::SingleQuotedString(_))
130 | Expr::Value(Value::DoubleQuotedString(_))
131 | Expr::Value(Value::Boolean(_))
132 | Expr::Value(Value::Null) => {
133 *expr = Expr::Value(Value::Placeholder("?".to_string()));
134 }
135
136 Expr::Value(Value::Placeholder(_)) => {
138 *expr = Expr::Value(Value::Placeholder("?".to_string()));
139 }
140
141 _ => {}
143 }
144
145 std::ops::ControlFlow::Continue(())
146 }
147}
148
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `original`, expects exactly one statement, and checks that its
    /// normalized form equals `expected`.
    fn run_normalization_test(original: &str, expected: &str) {
        // Surface the parse error itself instead of a bare "is_ok" failure.
        let queries = Query::from_sql(original)
            .unwrap_or_else(|e| panic!("Parsing failed for: {}: {:?}", original, e));
        assert_eq!(queries.len(), 1, "Expected one query for: {}", original);
        let query = &queries[0];
        assert_eq!(
            query.normalized_query, expected,
            "Normalization failed for: {}\nGot: {}\nExpected: {}",
            original, query.normalized_query, expected
        );
    }

    #[test]
    fn test_parameterized_normalization() {
        // (input SQL, expected normalized SQL)
        let cases = [
            (
                "SELECT * FROM users WHERE id = 1",
                "SELECT * FROM users WHERE id = ?",
            ),
            (
                "SELECT * FROM users WHERE name = 'John' AND city = 'New York'",
                "SELECT * FROM users WHERE name = ? AND city = ?",
            ),
            (
                "UPDATE users SET name = $1, email = $2 WHERE id = $3",
                "UPDATE users SET name = ?, email = ? WHERE id = ?",
            ),
            // Missing whitespace around the operator is normalized by re-rendering.
            (
                "SELECT * FROM users WHERE id=1",
                "SELECT * FROM users WHERE id = ?",
            ),
            (
                "SELECT * FROM users WHERE (age > 25 AND name = 'John') OR id IN (1, 2, 3)",
                "SELECT * FROM users WHERE (age > ? AND name = ?) OR id IN (?, ?, ?)",
            ),
            (
                "INSERT INTO users (name, age) VALUES ('Alice', 30)",
                "INSERT INTO users (name, age) VALUES (?, ?)",
            ),
            (
                "DELETE FROM users WHERE active = true",
                "DELETE FROM users WHERE active = ?",
            ),
            (
                "SELECT * FROM orders WHERE price > 100.5",
                "SELECT * FROM orders WHERE price > ?",
            ),
            // `IS NULL` is a predicate, not a literal, and must be left as is.
            (
                "SELECT * FROM logs WHERE message IS NULL",
                "SELECT * FROM logs WHERE message IS NULL",
            ),
            (
                "SELECT * FROM products WHERE id IN ($1, $2, $3)",
                "SELECT * FROM products WHERE id IN (?, ?, ?)",
            ),
            (
                "SELECT * FROM users WHERE age > 25 AND score < 100.5",
                "SELECT * FROM users WHERE age > ? AND score < ?",
            ),
            (
                "SELECT * FROM users WHERE active = true",
                "SELECT * FROM users WHERE active = ?",
            ),
        ];

        for (original, expected) in cases {
            run_normalization_test(original, expected);
        }
    }
}