Skip to main content

chio_data_guards/
sql_parser.rs

1//! Thin wrapper over the `sqlparser` crate that produces a normalized
2//! [`SqlAnalysis`] for the guard to evaluate.
3//!
4//! Goals:
5//!
6//! - Keep [`sqlparser`] types out of the guard surface.  Everything the
7//!   guard consumes is a plain `String`, `Vec<String>`, or an
8//!   [`SqlOperation`].
//! - Extract what the guard cares about: the operation class, the
//!   referenced tables, the projected columns per table (from `SELECT`
//!   projections), and whether a `WHERE` clause is present together with
//!   its canonicalized text.
12//! - Fail-closed on parse errors: returning an [`Err`] causes the guard
13//!   to deny.
14
15use sqlparser::ast::{
16    Delete, FromTable, Insert, ObjectName, ObjectNamePart, Query, Select, SelectItem, SetExpr,
17    Statement, TableFactor, TableObject, Update, UpdateTableFromKind,
18};
19use sqlparser::dialect::{
20    BigQueryDialect, Dialect, GenericDialect, MsSqlDialect, MySqlDialect, PostgreSqlDialect,
21    SQLiteDialect, SnowflakeDialect,
22};
23use sqlparser::parser::Parser;
24
25use crate::config::{SqlDialect, SqlOperation};
26
27/// Normalized view of a parsed SQL statement.
28#[derive(Clone, Debug, PartialEq, Eq)]
29pub struct SqlAnalysis {
30    /// Operation class.
31    pub operation: SqlOperation,
32    /// All tables referenced anywhere in the statement.  Names are left as
33    /// the parser produced them (case preserved); case-insensitive compare
34    /// happens in the config layer.
35    pub tables: Vec<String>,
36    /// Projected columns per source table, for `SELECT` queries only.
37    ///
38    /// Each entry is `(table, column)`.  `column == "*"` means the query
39    /// uses a wildcard projection.  The table is the source table as
40    /// resolved from the `FROM` list (aliases are resolved back to the
41    /// underlying table).  When the projection cannot be resolved to a
42    /// specific table, the special sentinel `"?"` is used so the guard
43    /// can conservatively apply column checks across every referenced
44    /// table.
45    pub projected_columns: Vec<(String, String)>,
46    /// Whether the statement contains a `WHERE` clause.  Applies to
47    /// `SELECT`, `UPDATE`, `DELETE`.  `INSERT` always reports `false`.
48    pub has_where: bool,
49    /// Canonicalized WHERE text, lower-cased and whitespace-collapsed, or
50    /// an empty string when absent.  Used against the predicate denylist.
51    pub where_canonical: String,
52}
53
54/// Parse `query` and return a normalized analysis.  Parse errors are
55/// returned as [`Err(String)`] so the guard can build a
56/// [`SqlGuardDenyReason::ParseError`](crate::error::SqlGuardDenyReason::ParseError)
57/// from them.
58pub fn parse(query: &str, dialect: SqlDialect) -> Result<SqlAnalysis, String> {
59    let dialect_obj = dialect_for(dialect);
60    let statements = Parser::parse_sql(dialect_obj.as_ref(), query).map_err(|e| e.to_string())?;
61    // Reject multi-statement queries fail-closed. Analyzing only the first
62    // statement would let a payload like `SELECT ...; DROP TABLE ...` sail
63    // past scope checks because the guard classifies the SELECT while the
64    // destructive DROP hides behind it. Drivers like mysql-connector or
65    // postgres that support `multi_statements` would then execute both.
66    // Operators who legitimately need a batch can split and evaluate each
67    // statement independently.
68    if statements.len() > 1 {
69        return Err(format!(
70            "multi-statement SQL not supported by guard (found {} statements); split into separate evaluations",
71            statements.len()
72        ));
73    }
74    let Some(statement) = statements.into_iter().next() else {
75        return Err("empty statement".to_string());
76    };
77
78    Ok(analyze(&statement))
79}
80
81fn dialect_for(dialect: SqlDialect) -> Box<dyn Dialect + Send + Sync> {
82    match dialect {
83        SqlDialect::Generic => Box::new(GenericDialect {}),
84        SqlDialect::Postgres => Box::new(PostgreSqlDialect {}),
85        SqlDialect::MySql => Box::new(MySqlDialect {}),
86        SqlDialect::Sqlite => Box::new(SQLiteDialect {}),
87        SqlDialect::MsSql => Box::new(MsSqlDialect {}),
88        SqlDialect::Snowflake => Box::new(SnowflakeDialect {}),
89        SqlDialect::BigQuery => Box::new(BigQueryDialect {}),
90    }
91}
92
93fn analyze(stmt: &Statement) -> SqlAnalysis {
94    let mut analysis = SqlAnalysis {
95        operation: classify(stmt),
96        tables: Vec::new(),
97        projected_columns: Vec::new(),
98        has_where: false,
99        where_canonical: String::new(),
100    };
101
102    match stmt {
103        Statement::Query(query) => analyze_query(query, &mut analysis),
104        Statement::Insert(insert) => analyze_insert(insert, &mut analysis),
105        Statement::Update(update) => analyze_update(update, &mut analysis),
106        Statement::Delete(Delete {
107            from, selection, ..
108        }) => {
109            let twj_list = match from {
110                FromTable::WithFromKeyword(list) | FromTable::WithoutKeyword(list) => list,
111            };
112            for twj in twj_list {
113                collect_table_factor(&twj.relation, &mut analysis.tables, &mut Vec::new());
114            }
115            if let Some(expr) = selection {
116                analysis.has_where = true;
117                analysis.where_canonical = canonicalize(&expr_to_string(expr));
118            }
119        }
120        Statement::Truncate(truncate) => {
121            for truncate_target in &truncate.table_names {
122                analysis
123                    .tables
124                    .push(object_name_to_string(&truncate_target.name));
125            }
126        }
127        Statement::CreateTable(ct) => analysis.tables.push(object_name_to_string(&ct.name)),
128        Statement::Drop { names, .. } => {
129            for name in names {
130                analysis.tables.push(object_name_to_string(name));
131            }
132        }
133        Statement::AlterTable(alter) => analysis.tables.push(object_name_to_string(&alter.name)),
134        _ => {}
135    }
136
137    dedupe(&mut analysis.tables);
138    analysis
139}
140
141fn classify(stmt: &Statement) -> SqlOperation {
142    match stmt {
143        Statement::Query(_) => SqlOperation::Select,
144        Statement::Insert(_) => SqlOperation::Insert,
145        Statement::Update(_) => SqlOperation::Update,
146        Statement::Delete(_) | Statement::Truncate(_) => SqlOperation::Delete,
147        Statement::CreateTable(_)
148        | Statement::CreateView { .. }
149        | Statement::CreateIndex(_)
150        | Statement::CreateSchema { .. }
151        | Statement::CreateDatabase { .. }
152        | Statement::CreateFunction { .. }
153        | Statement::CreateProcedure { .. }
154        | Statement::CreateTrigger { .. }
155        | Statement::Drop { .. }
156        | Statement::AlterTable(_)
157        | Statement::AlterIndex { .. }
158        | Statement::AlterView { .. }
159        | Statement::RenameTable(_)
160        | Statement::Comment { .. } => SqlOperation::Ddl,
161        _ => SqlOperation::Other,
162    }
163}
164
165fn analyze_query(query: &Query, analysis: &mut SqlAnalysis) {
166    match query.body.as_ref() {
167        SetExpr::Select(select) => analyze_select(select, analysis),
168        SetExpr::Query(inner) => analyze_query(inner, analysis),
169        SetExpr::SetOperation { left, right, .. } => {
170            analyze_set_expr(left, analysis);
171            analyze_set_expr(right, analysis);
172        }
173        _ => {}
174    }
175    if let Some(with) = &query.with {
176        for cte in &with.cte_tables {
177            analyze_query(&cte.query, analysis);
178        }
179    }
180}
181
182fn analyze_set_expr(expr: &SetExpr, analysis: &mut SqlAnalysis) {
183    match expr {
184        SetExpr::Select(select) => analyze_select(select, analysis),
185        SetExpr::Query(inner) => analyze_query(inner, analysis),
186        SetExpr::SetOperation { left, right, .. } => {
187            analyze_set_expr(left, analysis);
188            analyze_set_expr(right, analysis);
189        }
190        _ => {}
191    }
192}
193
/// Fold one `SELECT` (a single query-body node) into `analysis`.
///
/// In order, this:
/// 1. treats `SELECT ... INTO target` as a write: the operation becomes
///    [`SqlOperation::Ddl`] and `target` is pushed onto `analysis.tables`;
/// 2. pushes every `FROM`/`JOIN` relation onto `analysis.tables`, building
///    an alias -> table map along the way;
/// 3. appends one `(table, column)` entry per projection item to
///    `analysis.projected_columns`;
/// 4. records the `WHERE` clause, overwriting any `WHERE` summary already
///    stored in `analysis`.
///
/// NOTE(review): `analysis.tables` is shared with the rest of the statement
/// (an `INTO` target, earlier `UNION` branches, an `INSERT` target, ...).
/// The single-table shortcut and the bare-`*` expansion below therefore
/// look at that accumulated list, not only at this `SELECT`'s own `FROM`.
///
/// NOTE(review): only `FROM`, the projection list and `WHERE` are inspected.
/// Subqueries inside projection or `WHERE` expressions, `GROUP BY` and
/// `HAVING` are not walked, so tables referenced only there are not
/// recorded.
fn analyze_select(select: &Select, analysis: &mut SqlAnalysis) {
    // `SELECT ... INTO t` creates/writes `t`, so it is classified as DDL.
    if let Some(into) = &select.into {
        analysis.operation = SqlOperation::Ddl;
        analysis.tables.push(object_name_to_string(&into.name));
    }

    // Resolve FROM/JOIN table list and build an alias -> table map so
    // qualified projections (`u.id`) can be attributed to their source
    // table.
    let mut aliases: Vec<(String, String)> = Vec::new();
    for twj in &select.from {
        collect_table_factor(&twj.relation, &mut analysis.tables, &mut aliases);
        for join in &twj.joins {
            collect_table_factor(&join.relation, &mut analysis.tables, &mut aliases);
        }
    }

    // Determine the "primary" source table for unqualified projections.
    // If there is exactly one source table, use it; otherwise mark "?".
    // (The count is taken over the accumulated `analysis.tables`, which
    // includes any INTO target pushed above.)
    let primary_table: String = if analysis.tables.len() == 1 {
        analysis.tables[0].clone()
    } else {
        "?".to_string()
    };

    for item in &select.projection {
        match item {
            // Bare `*`: attribute a wildcard to every table collected so
            // far, or to the "?" sentinel when there are none.
            SelectItem::Wildcard(_) => {
                if analysis.tables.is_empty() {
                    analysis.projected_columns.push(("?".into(), "*".into()));
                } else {
                    for tbl in &analysis.tables {
                        analysis.projected_columns.push((tbl.clone(), "*".into()));
                    }
                }
            }
            // `name.*`: resolve `name` through the alias map; an
            // expression-qualified wildcard cannot be tied to a table.
            SelectItem::QualifiedWildcard(kind, _) => {
                let object_name = match kind {
                    sqlparser::ast::SelectItemQualifiedWildcardKind::ObjectName(name) => name,
                    sqlparser::ast::SelectItemQualifiedWildcardKind::Expr(_) => {
                        analysis.projected_columns.push(("?".into(), "*".into()));
                        continue;
                    }
                };
                let qualifier = object_name_to_string(object_name);
                let resolved = resolve_alias(&qualifier, &aliases).unwrap_or(qualifier);
                analysis.projected_columns.push((resolved, "*".into()));
            }
            // An output alias (`expr AS x`) does not change which source
            // column is read, so both forms are resolved the same way.
            SelectItem::UnnamedExpr(expr) | SelectItem::ExprWithAlias { expr, .. } => {
                let (table, column) = resolve_projected_expr(expr, &primary_table, &aliases);
                analysis.projected_columns.push((table, column));
            }
        }
    }

    if let Some(expr) = &select.selection {
        analysis.has_where = true;
        analysis.where_canonical = canonicalize(&expr_to_string(expr));
    }
}
254
255fn expr_to_string(expr: &sqlparser::ast::Expr) -> String {
256    format!("{expr}")
257}
258
259fn collect_table_factor(
260    factor: &TableFactor,
261    tables: &mut Vec<String>,
262    aliases: &mut Vec<(String, String)>,
263) {
264    match factor {
265        TableFactor::Table { name, alias, .. } => {
266            let full = object_name_to_string(name);
267            tables.push(full.clone());
268            if let Some(a) = alias {
269                aliases.push((a.name.value.clone(), full));
270            }
271        }
272        TableFactor::Derived {
273            subquery, alias, ..
274        } => {
275            let mut nested = SqlAnalysis {
276                operation: SqlOperation::Select,
277                tables: Vec::new(),
278                projected_columns: Vec::new(),
279                has_where: false,
280                where_canonical: String::new(),
281            };
282            analyze_query(subquery, &mut nested);
283            for t in nested.tables {
284                tables.push(t.clone());
285                if let Some(a) = alias {
286                    aliases.push((a.name.value.clone(), t));
287                }
288            }
289        }
290        TableFactor::NestedJoin {
291            table_with_joins, ..
292        } => {
293            collect_table_factor(&table_with_joins.relation, tables, aliases);
294            for join in &table_with_joins.joins {
295                collect_table_factor(&join.relation, tables, aliases);
296            }
297        }
298        _ => {}
299    }
300}
301
302fn resolve_projected_expr(
303    expr: &sqlparser::ast::Expr,
304    primary_table: &str,
305    aliases: &[(String, String)],
306) -> (String, String) {
307    use sqlparser::ast::Expr;
308    match expr {
309        Expr::Identifier(ident) => (primary_table.to_string(), ident.value.clone()),
310        Expr::CompoundIdentifier(parts) => {
311            if parts.len() >= 2 {
312                let qualifier = parts[parts.len() - 2].value.clone();
313                let column = parts[parts.len() - 1].value.clone();
314                let resolved = resolve_alias(&qualifier, aliases).unwrap_or(qualifier);
315                (resolved, column)
316            } else if let Some(single) = parts.first() {
317                (primary_table.to_string(), single.value.clone())
318            } else {
319                ("?".into(), "?".into())
320            }
321        }
322        // Any other expression (function call, literal, arithmetic) does
323        // not project a single identified column; we mark it with "?" so
324        // the guard will neither allow nor deny on column grounds.  The
325        // guard falls back to table-allowlist enforcement for these.
326        _ => (primary_table.to_string(), "?".to_string()),
327    }
328}
329
/// Look up `qualifier` among the collected `(alias, table)` pairs, ignoring
/// ASCII case, and return the table of the first matching alias.
fn resolve_alias(qualifier: &str, aliases: &[(String, String)]) -> Option<String> {
    aliases.iter().find_map(|(alias, table)| {
        if alias.eq_ignore_ascii_case(qualifier) {
            Some(table.clone())
        } else {
            None
        }
    })
}
337
338fn analyze_insert(insert: &Insert, analysis: &mut SqlAnalysis) {
339    match &insert.table {
340        TableObject::TableName(name) => analysis.tables.push(object_name_to_string(name)),
341        TableObject::TableFunction(_) => {}
342    }
343    if let Some(source) = &insert.source {
344        analyze_query(source, analysis);
345    }
346}
347
348fn analyze_update(update: &Update, analysis: &mut SqlAnalysis) {
349    collect_table_factor(
350        &update.table.relation,
351        &mut analysis.tables,
352        &mut Vec::new(),
353    );
354    for join in &update.table.joins {
355        collect_table_factor(&join.relation, &mut analysis.tables, &mut Vec::new());
356    }
357    if let Some(UpdateTableFromKind::BeforeSet(from_list))
358    | Some(UpdateTableFromKind::AfterSet(from_list)) = &update.from
359    {
360        for twj in from_list {
361            collect_table_factor(&twj.relation, &mut analysis.tables, &mut Vec::new());
362        }
363    }
364    if let Some(expr) = &update.selection {
365        analysis.has_where = true;
366        analysis.where_canonical = canonicalize(&expr_to_string(expr));
367    }
368}
369
370fn object_name_to_string(name: &ObjectName) -> String {
371    name.0
372        .iter()
373        .map(|part| match part {
374            ObjectNamePart::Identifier(i) => i.value.clone(),
375            ObjectNamePart::Function(f) => f.name.value.clone(),
376        })
377        .collect::<Vec<_>>()
378        .join(".")
379}
380
/// Normalize a rendered `WHERE` clause for denylist matching.
///
/// Transformations applied:
/// - every run of (Unicode) whitespace becomes a single space;
/// - leading and trailing whitespace is dropped;
/// - ASCII letters are lower-cased, while non-ASCII characters are left
///   untouched.
fn canonicalize(raw: &str) -> String {
    raw.split_whitespace()
        .map(str::to_ascii_lowercase)
        .collect::<Vec<_>>()
        .join(" ")
}
397
/// Remove case-insensitive (ASCII) duplicates in place.
///
/// The first spelling of each name is kept, and the relative order of the
/// survivors is preserved.
fn dedupe(items: &mut Vec<String>) {
    let mut kept_keys: Vec<String> = Vec::with_capacity(items.len());
    items.retain(|entry| {
        let key = entry.to_ascii_lowercase();
        let first_time = !kept_keys.iter().any(|existing| *existing == key);
        if first_time {
            kept_keys.push(key);
        }
        first_time
    });
}
410
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `sql` with `dialect`, failing the test on a parse error.
    fn analyze_with(sql: &str, dialect: SqlDialect) -> SqlAnalysis {
        parse(sql, dialect).expect("parse")
    }

    /// Parse `sql` with [`SqlDialect::Generic`].
    fn analyze_generic(sql: &str) -> SqlAnalysis {
        analyze_with(sql, SqlDialect::Generic)
    }

    /// Owned `(table, column)` pair.
    fn col(table: &str, column: &str) -> (String, String) {
        (table.to_owned(), column.to_owned())
    }

    /// Owned list of table names.
    fn names(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| (*item).to_owned()).collect()
    }

    #[test]
    fn parses_simple_select() {
        let analysis = analyze_generic("SELECT id, name FROM orders");
        assert_eq!(analysis.operation, SqlOperation::Select);
        assert_eq!(analysis.tables, names(&["orders"]));
        assert_eq!(
            analysis.projected_columns,
            vec![col("orders", "id"), col("orders", "name")]
        );
        assert!(!analysis.has_where);
    }

    #[test]
    fn parses_select_star() {
        let analysis = analyze_generic("SELECT * FROM users");
        assert_eq!(analysis.operation, SqlOperation::Select);
        assert_eq!(analysis.tables, names(&["users"]));
        assert_eq!(analysis.projected_columns, vec![col("users", "*")]);
    }

    #[test]
    fn classifies_drop_as_ddl() {
        let analysis = analyze_generic("DROP TABLE orders");
        assert_eq!(analysis.operation, SqlOperation::Ddl);
        assert_eq!(analysis.tables, names(&["orders"]));
    }

    #[test]
    fn classifies_update_with_where() {
        let analysis = analyze_generic("UPDATE orders SET total = 0 WHERE id = 1");
        assert_eq!(analysis.operation, SqlOperation::Update);
        assert!(analysis.has_where);
        assert!(analysis.where_canonical.contains("id = 1"));
    }

    #[test]
    fn classifies_delete_without_where() {
        let analysis = analyze_generic("DELETE FROM orders");
        assert_eq!(analysis.operation, SqlOperation::Delete);
        assert!(!analysis.has_where);
    }

    #[test]
    fn resolves_alias_in_projection() {
        let analysis =
            analyze_generic("SELECT o.id FROM orders o JOIN users u ON o.user_id = u.id");
        assert_eq!(analysis.operation, SqlOperation::Select);
        // `o` is an alias for `orders`, so `o.id` must be attributed to it.
        assert!(analysis.projected_columns.contains(&col("orders", "id")));
    }

    #[test]
    fn parses_postgres_dialect() {
        let analysis = analyze_with(
            "SELECT id FROM orders WHERE created_at > NOW() - INTERVAL '1 day'",
            SqlDialect::Postgres,
        );
        assert_eq!(analysis.operation, SqlOperation::Select);
    }

    #[test]
    fn parses_mysql_dialect() {
        let analysis = analyze_with(
            "SELECT `id` FROM `orders` WHERE `name` = 'x'",
            SqlDialect::MySql,
        );
        assert_eq!(analysis.operation, SqlOperation::Select);
        assert_eq!(analysis.tables, names(&["orders"]));
    }

    #[test]
    fn parse_error_is_surfaced() {
        let message = parse("SELEKT * FRUM", SqlDialect::Generic).expect_err("should fail");
        assert!(!message.is_empty());
    }

    #[test]
    fn canonicalize_normalizes_whitespace_and_case() {
        let cases = [("  ID  =  1  ", "id = 1"), ("A\n\tOR\n1=1", "a or 1=1")];
        for (raw, expected) in cases {
            assert_eq!(canonicalize(raw), expected);
        }
    }

    #[test]
    fn truncate_is_delete() {
        let analysis = analyze_generic("TRUNCATE TABLE orders");
        assert_eq!(analysis.operation, SqlOperation::Delete);
    }

    #[test]
    fn select_into_is_treated_as_write_ddl() {
        let analysis = analyze_with("SELECT id INTO archive FROM orders", SqlDialect::MsSql);
        assert_eq!(analysis.operation, SqlOperation::Ddl);
        for table in ["archive", "orders"] {
            assert!(analysis.tables.iter().any(|t| t == table));
        }
    }
}