1use sqlparser::ast::{
16 Delete, FromTable, Insert, ObjectName, ObjectNamePart, Query, Select, SelectItem, SetExpr,
17 Statement, TableFactor, TableObject, Update, UpdateTableFromKind,
18};
19use sqlparser::dialect::{
20 BigQueryDialect, Dialect, GenericDialect, MsSqlDialect, MySqlDialect, PostgreSqlDialect,
21 SQLiteDialect, SnowflakeDialect,
22};
23use sqlparser::parser::Parser;
24
25use crate::config::{SqlDialect, SqlOperation};
26
/// Summary of a single parsed SQL statement, as produced by [`parse`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SqlAnalysis {
    /// Coarse operation class of the statement (set from `classify`; a `SELECT ... INTO`
    /// overrides it to `Ddl`).
    pub operation: SqlOperation,
    /// Dotted names of the tables the statement references, de-duplicated
    /// case-insensitively with the first spelling kept.
    pub tables: Vec<String>,
    /// `(table, column)` pairs for the projected items. `"?"` stands for a table or column
    /// that could not be determined and `"*"` for a wildcard column.
    pub projected_columns: Vec<(String, String)>,
    /// `true` when a `WHERE` clause was recorded for the statement.
    pub has_where: bool,
    /// Whitespace-collapsed, ASCII-lowercased rendering of the recorded `WHERE`
    /// expression; empty when `has_where` is `false`.
    pub where_canonical: String,
}
53
54pub fn parse(query: &str, dialect: SqlDialect) -> Result<SqlAnalysis, String> {
59 let dialect_obj = dialect_for(dialect);
60 let statements = Parser::parse_sql(dialect_obj.as_ref(), query).map_err(|e| e.to_string())?;
61 if statements.len() > 1 {
69 return Err(format!(
70 "multi-statement SQL not supported by guard (found {} statements); split into separate evaluations",
71 statements.len()
72 ));
73 }
74 let Some(statement) = statements.into_iter().next() else {
75 return Err("empty statement".to_string());
76 };
77
78 Ok(analyze(&statement))
79}
80
/// Maps the configured [`SqlDialect`] onto the matching `sqlparser` dialect object.
///
/// The dialect is returned boxed as a trait object so that every variant shares one
/// return type.
fn dialect_for(dialect: SqlDialect) -> Box<dyn Dialect + Send + Sync> {
    match dialect {
        SqlDialect::Generic => Box::new(GenericDialect {}),
        SqlDialect::Postgres => Box::new(PostgreSqlDialect {}),
        SqlDialect::MySql => Box::new(MySqlDialect {}),
        SqlDialect::Sqlite => Box::new(SQLiteDialect {}),
        SqlDialect::MsSql => Box::new(MsSqlDialect {}),
        SqlDialect::Snowflake => Box::new(SnowflakeDialect {}),
        SqlDialect::BigQuery => Box::new(BigQueryDialect {}),
    }
}
92
93fn analyze(stmt: &Statement) -> SqlAnalysis {
94 let mut analysis = SqlAnalysis {
95 operation: classify(stmt),
96 tables: Vec::new(),
97 projected_columns: Vec::new(),
98 has_where: false,
99 where_canonical: String::new(),
100 };
101
102 match stmt {
103 Statement::Query(query) => analyze_query(query, &mut analysis),
104 Statement::Insert(insert) => analyze_insert(insert, &mut analysis),
105 Statement::Update(update) => analyze_update(update, &mut analysis),
106 Statement::Delete(Delete {
107 from, selection, ..
108 }) => {
109 let twj_list = match from {
110 FromTable::WithFromKeyword(list) | FromTable::WithoutKeyword(list) => list,
111 };
112 for twj in twj_list {
113 collect_table_factor(&twj.relation, &mut analysis.tables, &mut Vec::new());
114 }
115 if let Some(expr) = selection {
116 analysis.has_where = true;
117 analysis.where_canonical = canonicalize(&expr_to_string(expr));
118 }
119 }
120 Statement::Truncate(truncate) => {
121 for truncate_target in &truncate.table_names {
122 analysis
123 .tables
124 .push(object_name_to_string(&truncate_target.name));
125 }
126 }
127 Statement::CreateTable(ct) => analysis.tables.push(object_name_to_string(&ct.name)),
128 Statement::Drop { names, .. } => {
129 for name in names {
130 analysis.tables.push(object_name_to_string(name));
131 }
132 }
133 Statement::AlterTable(alter) => analysis.tables.push(object_name_to_string(&alter.name)),
134 _ => {}
135 }
136
137 dedupe(&mut analysis.tables);
138 analysis
139}
140
141fn classify(stmt: &Statement) -> SqlOperation {
142 match stmt {
143 Statement::Query(_) => SqlOperation::Select,
144 Statement::Insert(_) => SqlOperation::Insert,
145 Statement::Update(_) => SqlOperation::Update,
146 Statement::Delete(_) | Statement::Truncate(_) => SqlOperation::Delete,
147 Statement::CreateTable(_)
148 | Statement::CreateView { .. }
149 | Statement::CreateIndex(_)
150 | Statement::CreateSchema { .. }
151 | Statement::CreateDatabase { .. }
152 | Statement::CreateFunction { .. }
153 | Statement::CreateProcedure { .. }
154 | Statement::CreateTrigger { .. }
155 | Statement::Drop { .. }
156 | Statement::AlterTable(_)
157 | Statement::AlterIndex { .. }
158 | Statement::AlterView { .. }
159 | Statement::RenameTable(_)
160 | Statement::Comment { .. } => SqlOperation::Ddl,
161 _ => SqlOperation::Other,
162 }
163}
164
165fn analyze_query(query: &Query, analysis: &mut SqlAnalysis) {
166 match query.body.as_ref() {
167 SetExpr::Select(select) => analyze_select(select, analysis),
168 SetExpr::Query(inner) => analyze_query(inner, analysis),
169 SetExpr::SetOperation { left, right, .. } => {
170 analyze_set_expr(left, analysis);
171 analyze_set_expr(right, analysis);
172 }
173 _ => {}
174 }
175 if let Some(with) = &query.with {
176 for cte in &with.cte_tables {
177 analyze_query(&cte.query, analysis);
178 }
179 }
180}
181
182fn analyze_set_expr(expr: &SetExpr, analysis: &mut SqlAnalysis) {
183 match expr {
184 SetExpr::Select(select) => analyze_select(select, analysis),
185 SetExpr::Query(inner) => analyze_query(inner, analysis),
186 SetExpr::SetOperation { left, right, .. } => {
187 analyze_set_expr(left, analysis);
188 analyze_set_expr(right, analysis);
189 }
190 _ => {}
191 }
192}
193
194fn analyze_select(select: &Select, analysis: &mut SqlAnalysis) {
195 if let Some(into) = &select.into {
196 analysis.operation = SqlOperation::Ddl;
197 analysis.tables.push(object_name_to_string(&into.name));
198 }
199
200 let mut aliases: Vec<(String, String)> = Vec::new();
204 for twj in &select.from {
205 collect_table_factor(&twj.relation, &mut analysis.tables, &mut aliases);
206 for join in &twj.joins {
207 collect_table_factor(&join.relation, &mut analysis.tables, &mut aliases);
208 }
209 }
210
211 let primary_table: String = if analysis.tables.len() == 1 {
214 analysis.tables[0].clone()
215 } else {
216 "?".to_string()
217 };
218
219 for item in &select.projection {
220 match item {
221 SelectItem::Wildcard(_) => {
222 if analysis.tables.is_empty() {
223 analysis.projected_columns.push(("?".into(), "*".into()));
224 } else {
225 for tbl in &analysis.tables {
226 analysis.projected_columns.push((tbl.clone(), "*".into()));
227 }
228 }
229 }
230 SelectItem::QualifiedWildcard(kind, _) => {
231 let object_name = match kind {
232 sqlparser::ast::SelectItemQualifiedWildcardKind::ObjectName(name) => name,
233 sqlparser::ast::SelectItemQualifiedWildcardKind::Expr(_) => {
234 analysis.projected_columns.push(("?".into(), "*".into()));
235 continue;
236 }
237 };
238 let qualifier = object_name_to_string(object_name);
239 let resolved = resolve_alias(&qualifier, &aliases).unwrap_or(qualifier);
240 analysis.projected_columns.push((resolved, "*".into()));
241 }
242 SelectItem::UnnamedExpr(expr) | SelectItem::ExprWithAlias { expr, .. } => {
243 let (table, column) = resolve_projected_expr(expr, &primary_table, &aliases);
244 analysis.projected_columns.push((table, column));
245 }
246 }
247 }
248
249 if let Some(expr) = &select.selection {
250 analysis.has_where = true;
251 analysis.where_canonical = canonicalize(&expr_to_string(expr));
252 }
253}
254
255fn expr_to_string(expr: &sqlparser::ast::Expr) -> String {
256 format!("{expr}")
257}
258
259fn collect_table_factor(
260 factor: &TableFactor,
261 tables: &mut Vec<String>,
262 aliases: &mut Vec<(String, String)>,
263) {
264 match factor {
265 TableFactor::Table { name, alias, .. } => {
266 let full = object_name_to_string(name);
267 tables.push(full.clone());
268 if let Some(a) = alias {
269 aliases.push((a.name.value.clone(), full));
270 }
271 }
272 TableFactor::Derived {
273 subquery, alias, ..
274 } => {
275 let mut nested = SqlAnalysis {
276 operation: SqlOperation::Select,
277 tables: Vec::new(),
278 projected_columns: Vec::new(),
279 has_where: false,
280 where_canonical: String::new(),
281 };
282 analyze_query(subquery, &mut nested);
283 for t in nested.tables {
284 tables.push(t.clone());
285 if let Some(a) = alias {
286 aliases.push((a.name.value.clone(), t));
287 }
288 }
289 }
290 TableFactor::NestedJoin {
291 table_with_joins, ..
292 } => {
293 collect_table_factor(&table_with_joins.relation, tables, aliases);
294 for join in &table_with_joins.joins {
295 collect_table_factor(&join.relation, tables, aliases);
296 }
297 }
298 _ => {}
299 }
300}
301
302fn resolve_projected_expr(
303 expr: &sqlparser::ast::Expr,
304 primary_table: &str,
305 aliases: &[(String, String)],
306) -> (String, String) {
307 use sqlparser::ast::Expr;
308 match expr {
309 Expr::Identifier(ident) => (primary_table.to_string(), ident.value.clone()),
310 Expr::CompoundIdentifier(parts) => {
311 if parts.len() >= 2 {
312 let qualifier = parts[parts.len() - 2].value.clone();
313 let column = parts[parts.len() - 1].value.clone();
314 let resolved = resolve_alias(&qualifier, aliases).unwrap_or(qualifier);
315 (resolved, column)
316 } else if let Some(single) = parts.first() {
317 (primary_table.to_string(), single.value.clone())
318 } else {
319 ("?".into(), "?".into())
320 }
321 }
322 _ => (primary_table.to_string(), "?".to_string()),
327 }
328}
329
/// Looks up the table that `qualifier` aliases, comparing ASCII case-insensitively.
///
/// Returns the table of the first matching `(alias, table)` entry, or `None` when
/// `qualifier` is not a known alias.
fn resolve_alias(qualifier: &str, aliases: &[(String, String)]) -> Option<String> {
    aliases
        .iter()
        // Compares in place, without allocating a lowercase copy of each alias.
        .find(|(alias, _)| alias.eq_ignore_ascii_case(qualifier))
        .map(|(_, table)| table.clone())
}
337
338fn analyze_insert(insert: &Insert, analysis: &mut SqlAnalysis) {
339 match &insert.table {
340 TableObject::TableName(name) => analysis.tables.push(object_name_to_string(name)),
341 TableObject::TableFunction(_) => {}
342 }
343 if let Some(source) = &insert.source {
344 analyze_query(source, analysis);
345 }
346}
347
348fn analyze_update(update: &Update, analysis: &mut SqlAnalysis) {
349 collect_table_factor(
350 &update.table.relation,
351 &mut analysis.tables,
352 &mut Vec::new(),
353 );
354 for join in &update.table.joins {
355 collect_table_factor(&join.relation, &mut analysis.tables, &mut Vec::new());
356 }
357 if let Some(UpdateTableFromKind::BeforeSet(from_list))
358 | Some(UpdateTableFromKind::AfterSet(from_list)) = &update.from
359 {
360 for twj in from_list {
361 collect_table_factor(&twj.relation, &mut analysis.tables, &mut Vec::new());
362 }
363 }
364 if let Some(expr) = &update.selection {
365 analysis.has_where = true;
366 analysis.where_canonical = canonicalize(&expr_to_string(expr));
367 }
368}
369
370fn object_name_to_string(name: &ObjectName) -> String {
371 name.0
372 .iter()
373 .map(|part| match part {
374 ObjectNamePart::Identifier(i) => i.value.clone(),
375 ObjectNamePart::Function(f) => f.name.value.clone(),
376 })
377 .collect::<Vec<_>>()
378 .join(".")
379}
380
/// Produces the canonical form of a SQL fragment: surrounding whitespace removed, every
/// run of whitespace collapsed to one space, and ASCII letters lowercased.
///
/// Non-ASCII characters are left untouched.
fn canonicalize(raw: &str) -> String {
    let mut canonical = String::with_capacity(raw.len());
    for token in raw.split_whitespace() {
        // Tokens are never empty, so an empty buffer means this is the first one.
        if !canonical.is_empty() {
            canonical.push(' ');
        }
        canonical.push_str(token);
    }
    canonical.make_ascii_lowercase();
    canonical
}
397
/// Removes ASCII-case-insensitive duplicates from `items` in place, keeping the first
/// occurrence of each name with its original spelling and position.
fn dedupe(items: &mut Vec<String>) {
    // A hash set makes each membership check constant-time instead of a linear scan.
    let mut seen = std::collections::HashSet::new();
    // `insert` returns `true` only the first time a lowercase key is seen.
    items.retain(|item| seen.insert(item.to_ascii_lowercase()));
}
410
/// Unit tests covering classification, table/column extraction, WHERE detection,
/// dialect parsing and the `canonicalize` helper.
#[cfg(test)]
mod tests {
    use super::*;

    // Unqualified columns of a single-table SELECT are attributed to that table.
    #[test]
    fn parses_simple_select() {
        let a = parse("SELECT id, name FROM orders", SqlDialect::Generic).expect("parse");
        assert_eq!(a.operation, SqlOperation::Select);
        assert_eq!(a.tables, vec!["orders".to_string()]);
        assert_eq!(
            a.projected_columns,
            vec![
                ("orders".to_string(), "id".to_string()),
                ("orders".to_string(), "name".to_string()),
            ]
        );
        assert!(!a.has_where);
    }

    // A bare `*` is reported as the `*` column of the FROM table.
    #[test]
    fn parses_select_star() {
        let a = parse("SELECT * FROM users", SqlDialect::Generic).expect("parse");
        assert_eq!(a.operation, SqlOperation::Select);
        assert_eq!(a.tables, vec!["users".to_string()]);
        assert_eq!(
            a.projected_columns,
            vec![("users".to_string(), "*".to_string())]
        );
    }

    // DROP is DDL and reports the dropped object's name.
    #[test]
    fn classifies_drop_as_ddl() {
        let a = parse("DROP TABLE orders", SqlDialect::Generic).expect("parse");
        assert_eq!(a.operation, SqlOperation::Ddl);
        assert_eq!(a.tables, vec!["orders".to_string()]);
    }

    // An UPDATE with a WHERE clause records the canonical clause text.
    #[test]
    fn classifies_update_with_where() {
        let a = parse(
            "UPDATE orders SET total = 0 WHERE id = 1",
            SqlDialect::Generic,
        )
        .expect("parse");
        assert_eq!(a.operation, SqlOperation::Update);
        assert!(a.has_where);
        assert!(a.where_canonical.contains("id = 1"));
    }

    // A DELETE without WHERE must report `has_where == false`.
    #[test]
    fn classifies_delete_without_where() {
        let a = parse("DELETE FROM orders", SqlDialect::Generic).expect("parse");
        assert_eq!(a.operation, SqlOperation::Delete);
        assert!(!a.has_where);
    }

    // A column qualified by a table alias is attributed to the aliased table.
    #[test]
    fn resolves_alias_in_projection() {
        let a = parse(
            "SELECT o.id FROM orders o JOIN users u ON o.user_id = u.id",
            SqlDialect::Generic,
        )
        .expect("parse");
        assert_eq!(a.operation, SqlOperation::Select);
        assert!(a
            .projected_columns
            .iter()
            .any(|(t, c)| t == "orders" && c == "id"));
    }

    // Postgres-specific syntax (INTERVAL literal) parses under the Postgres dialect.
    #[test]
    fn parses_postgres_dialect() {
        let a = parse(
            "SELECT id FROM orders WHERE created_at > NOW() - INTERVAL '1 day'",
            SqlDialect::Postgres,
        )
        .expect("parse");
        assert_eq!(a.operation, SqlOperation::Select);
    }

    // Backtick-quoted identifiers parse under MySQL and are reported unquoted.
    #[test]
    fn parses_mysql_dialect() {
        let a = parse(
            "SELECT `id` FROM `orders` WHERE `name` = 'x'",
            SqlDialect::MySql,
        )
        .expect("parse");
        assert_eq!(a.operation, SqlOperation::Select);
        assert_eq!(a.tables, vec!["orders".to_string()]);
    }

    // Invalid SQL yields a non-empty error message instead of an analysis.
    #[test]
    fn parse_error_is_surfaced() {
        let err = parse("SELEKT * FRUM", SqlDialect::Generic).expect_err("should fail");
        assert!(!err.is_empty());
    }

    // `canonicalize` trims, collapses whitespace runs and lowercases ASCII.
    #[test]
    fn canonicalize_normalizes_whitespace_and_case() {
        assert_eq!(canonicalize(" ID = 1 "), "id = 1");
        assert_eq!(canonicalize("A\n\tOR\n1=1"), "a or 1=1");
    }

    // TRUNCATE is classified as a delete operation.
    #[test]
    fn truncate_is_delete() {
        let a = parse("TRUNCATE TABLE orders", SqlDialect::Generic).expect("parse");
        assert_eq!(a.operation, SqlOperation::Delete);
    }

    // SELECT ... INTO writes a table: it is classified as DDL and reports both the
    // target and the source table.
    #[test]
    fn select_into_is_treated_as_write_ddl() {
        let a = parse("SELECT id INTO archive FROM orders", SqlDialect::MsSql).expect("parse");
        assert_eq!(a.operation, SqlOperation::Ddl);
        assert!(a.tables.contains(&"archive".to_string()));
        assert!(a.tables.contains(&"orders".to_string()));
    }
}