Skip to main content

pg_logstats/sql/
query.rs

1use serde::{Deserialize, Serialize};
2use sqlparser::{
3    ast::{Expr, Value, VisitMut, VisitorMut},
4    dialect::PostgreSqlDialect,
5    parser::Parser,
6};
7
8use crate::PgLogstatsError;
9
10/// Query type classification
11#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
12pub enum QueryType {
13    /// SELECT queries
14    Select,
15    /// INSERT queries
16    Insert,
17    /// UPDATE queries
18    Update,
19    /// DELETE queries
20    Delete,
21    /// Data Definition Language (CREATE, DROP, ALTER, etc.)
22    DDL,
23    /// Other queries (BEGIN, COMMIT, ROLLBACK, etc.)
24    Other,
25}
26
27impl std::fmt::Display for QueryType {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        match self {
30            QueryType::Select => write!(f, "SELECT"),
31            QueryType::Insert => write!(f, "INSERT"),
32            QueryType::Update => write!(f, "UPDATE"),
33            QueryType::Delete => write!(f, "DELETE"),
34            QueryType::DDL => write!(f, "DDL"),
35            QueryType::Other => write!(f, "OTHER"),
36        }
37    }
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct Query {
42    pub sql: String,
43    pub query_type: QueryType,
44    pub normalized_query: String,
45}
46
47impl Query {
48    /// Parse SQL and return a vector of Query, one for each statement
49    pub fn from_sql(sql: &str) -> Result<Vec<Query>, PgLogstatsError> {
50        let dialect = PostgreSqlDialect {};
51        let ast = Parser::parse_sql(&dialect, sql).map_err(|e| PgLogstatsError::Parse {
52            message: format!("Failed to parse SQL: {}", e),
53            line_number: None,
54            line_content: Some(sql.to_string()),
55        })?;
56
57        let mut queries = Vec::new();
58        for stmt in &ast {
59            let query_type = Query::query_type_from_statement(stmt);
60            let normalized_query = Query::normalize_query(std::slice::from_ref(stmt))
61                .unwrap_or_else(|_| stmt.to_string());
62            queries.push(Query {
63                sql: stmt.to_string(),
64                query_type,
65                normalized_query,
66            });
67        }
68        Ok(queries)
69    }
70
71    fn query_type_from_statement(stmt: &sqlparser::ast::Statement) -> QueryType {
72        use sqlparser::ast::Statement::*;
73        match stmt {
74            Query(_) => QueryType::Select,
75            Insert { .. } => QueryType::Insert,
76            Update { .. } => QueryType::Update,
77            Delete { .. } => QueryType::Delete,
78            CreateTable { .. }
79            | CreateView { .. }
80            | CreateIndex { .. }
81            | CreateSchema { .. }
82            | CreateDatabase { .. }
83            | Drop { .. }
84            | AlterTable { .. }
85            | Truncate { .. } => QueryType::DDL,
86            _ => QueryType::Other,
87        }
88    }
89
90    /// Normalize SQL query using an existing AST
91    fn normalize_query(ast: &[sqlparser::ast::Statement]) -> Result<String, PgLogstatsError> {
92        if ast.is_empty() {
93            return Ok("".to_string());
94        }
95
96        // Clone AST to mutate
97        let mut ast = ast.to_owned();
98
99        let mut normalizer = LiteralNormalizer;
100        for stmt in &mut ast {
101            let _ = stmt.visit(&mut normalizer);
102        }
103
104        let normalized_sql = ast
105            .iter()
106            .map(|stmt| stmt.to_string())
107            .collect::<Vec<_>>()
108            .join("; ");
109
110        Ok(normalized_sql)
111    }
112}
113
114/// Visitor that replaces literal values with placeholders
115struct LiteralNormalizer;
116
117impl VisitorMut for LiteralNormalizer {
118    type Break = ();
119
120    fn pre_visit_expr(&mut self, _expr: &mut Expr) -> std::ops::ControlFlow<Self::Break> {
121        // Always continue traversal to visit nested expressions
122        std::ops::ControlFlow::Continue(())
123    }
124
125    fn post_visit_expr(&mut self, expr: &mut Expr) -> std::ops::ControlFlow<Self::Break> {
126        match expr {
127            // Replace literal constants with placeholders
128            Expr::Value(Value::Number(_, _))
129            | Expr::Value(Value::SingleQuotedString(_))
130            | Expr::Value(Value::DoubleQuotedString(_))
131            | Expr::Value(Value::Boolean(_))
132            | Expr::Value(Value::Null) => {
133                *expr = Expr::Value(Value::Placeholder("?".to_string()));
134            }
135
136            // Normalize existing parameters to standard format
137            Expr::Value(Value::Placeholder(_)) => {
138                *expr = Expr::Value(Value::Placeholder("?".to_string()));
139            }
140
141            // Continue traversing for all other expressions
142            _ => {}
143        }
144
145        std::ops::ControlFlow::Continue(())
146    }
147}
148
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `sql`, panicking with the parser error included if parsing fails.
    fn parse(sql: &str) -> Vec<Query> {
        Query::from_sql(sql)
            .unwrap_or_else(|e| panic!("Parsing failed for: {}\nError: {:?}", sql, e))
    }

    /// Parse `original`, require exactly one statement, and compare its
    /// normalized form against `expected`.
    fn run_normalization_test(original: &str, expected: &str) {
        let queries = parse(original);
        assert_eq!(queries.len(), 1, "Expected one query for: {}", original);
        let query = &queries[0];
        assert_eq!(
            query.normalized_query, expected,
            "Normalization failed for: {}\nGot: {}\nExpected: {}",
            original, query.normalized_query, expected
        );
    }

    #[test]
    fn test_parameterized_normalization() {
        let cases = [
            // Numeric literal
            (
                "SELECT * FROM users WHERE id = 1",
                "SELECT * FROM users WHERE id = ?",
            ),
            // String literals
            (
                "SELECT * FROM users WHERE name = 'John' AND city = 'New York'",
                "SELECT * FROM users WHERE name = ? AND city = ?",
            ),
            // Existing positional parameters
            (
                "UPDATE users SET name = $1, email = $2 WHERE id = $3",
                "UPDATE users SET name = ?, email = ? WHERE id = ?",
            ),
            // Irregular whitespace is canonicalized by re-rendering the AST
            (
                "SELECT   *   FROM    users   WHERE   id=1",
                "SELECT * FROM users WHERE id = ?",
            ),
            // Nested expressions and IN lists
            (
                "SELECT * FROM users WHERE (age > 25 AND name = 'John') OR id IN (1, 2, 3)",
                "SELECT * FROM users WHERE (age > ? AND name = ?) OR id IN (?, ?, ?)",
            ),
            (
                "INSERT INTO users (name, age) VALUES ('Alice', 30)",
                "INSERT INTO users (name, age) VALUES (?, ?)",
            ),
            // Boolean literals
            (
                "DELETE FROM users WHERE active = true",
                "DELETE FROM users WHERE active = ?",
            ),
            (
                "SELECT * FROM users WHERE active = true",
                "SELECT * FROM users WHERE active = ?",
            ),
            // Decimal literals
            (
                "SELECT * FROM orders WHERE price > 100.5",
                "SELECT * FROM orders WHERE price > ?",
            ),
            (
                "SELECT * FROM users WHERE age > 25 AND score < 100.5",
                "SELECT * FROM users WHERE age > ? AND score < ?",
            ),
            // `IS NULL` is an operator, not a literal, and must be preserved
            (
                "SELECT * FROM logs WHERE message IS NULL",
                "SELECT * FROM logs WHERE message IS NULL",
            ),
            (
                "SELECT * FROM products WHERE id IN ($1, $2, $3)",
                "SELECT * FROM products WHERE id IN (?, ?, ?)",
            ),
        ];

        for (original, expected) in cases {
            run_normalization_test(original, expected);
        }
    }

    #[test]
    fn test_query_type_classification() {
        let cases = [
            ("SELECT * FROM users", QueryType::Select),
            ("INSERT INTO users (name) VALUES ('a')", QueryType::Insert),
            ("UPDATE users SET name = 'a' WHERE id = 1", QueryType::Update),
            ("DELETE FROM users WHERE id = 1", QueryType::Delete),
            ("DROP TABLE users", QueryType::DDL),
            ("BEGIN", QueryType::Other),
        ];

        for (sql, expected) in cases {
            let queries = parse(sql);
            assert_eq!(queries.len(), 1, "Expected one query for: {}", sql);
            assert_eq!(queries[0].query_type, expected, "Wrong type for: {}", sql);
        }
    }

    #[test]
    fn test_multiple_statements_are_split() {
        let queries = parse("SELECT 1; SELECT 2");
        assert_eq!(queries.len(), 2);
        for query in &queries {
            assert_eq!(query.query_type, QueryType::Select);
            assert_eq!(query.normalized_query, "SELECT ?");
        }
    }
}