Skip to main content

pmcp_code_mode/
sql.rs

1//! SQL validation for Code Mode.
2//!
3//! Parses SQL statements with [`sqlparser`], classifies the statement type
4//! (`SELECT`/`INSERT`/`UPDATE`/`DELETE`/DDL), and extracts the tables, columns,
5//! and structural metadata that the Cedar policy evaluator needs.
6//!
7//! Gated behind the `sql-code-mode` feature.
8
9use crate::types::{
10    CodeType, Complexity, SecurityAnalysis, SecurityIssue, SecurityIssueType, ValidationError,
11};
12use sqlparser::ast::{
13    AssignmentTarget, Expr, FromTable, GroupByExpr, Join, LimitClause, ObjectName, Query, Select,
14    SelectItem, SetExpr, Statement, TableFactor, TableObject, TableWithJoins,
15};
16use sqlparser::dialect::{Dialect, GenericDialect};
17use sqlparser::parser::Parser;
18use std::collections::HashSet;
19
/// Coarse classification of a parsed SQL statement.
///
/// The per-variant lists name the statements this module currently maps to
/// each category; everything else is classified as [`SqlStatementType::Other`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlStatementType {
    /// Query statements (`SELECT` and other query expressions).
    Select,
    /// `INSERT`
    Insert,
    /// `UPDATE`
    Update,
    /// `DELETE`, `TRUNCATE`
    Delete,
    /// `CREATE`/`ALTER`/`DROP`/`GRANT`/`REVOKE` (DDL/admin)
    Ddl,
    /// Any statement that is not classified into one of the categories above.
    Other,
}

impl SqlStatementType {
    /// Canonical uppercase name of the category ("SELECT", "INSERT", ...), as
    /// used by the Cedar schema and
    /// [`UnifiedAction::from_sql`](crate::UnifiedAction::from_sql).
    pub fn as_str(&self) -> &'static str {
        match *self {
            SqlStatementType::Select => "SELECT",
            SqlStatementType::Insert => "INSERT",
            SqlStatementType::Update => "UPDATE",
            SqlStatementType::Delete => "DELETE",
            SqlStatementType::Ddl => "DDL",
            SqlStatementType::Other => "OTHER",
        }
    }

    /// `true` only for [`SqlStatementType::Select`].
    pub fn is_read_only(&self) -> bool {
        *self == SqlStatementType::Select
    }

    /// `true` for statements that write rows (`Insert` or `Update`).
    pub fn is_write(&self) -> bool {
        match *self {
            SqlStatementType::Insert | SqlStatementType::Update => true,
            _ => false,
        }
    }

    /// `true` for statements that remove rows (the `Delete` category, which
    /// also covers `TRUNCATE`).
    pub fn is_delete(&self) -> bool {
        *self == SqlStatementType::Delete
    }

    /// `true` for schema or permission changes (the `Ddl` category).
    pub fn is_admin(&self) -> bool {
        *self == SqlStatementType::Ddl
    }
}
71
72/// Structural information extracted from a parsed SQL statement.
73#[derive(Debug, Clone)]
74pub struct SqlStatementInfo {
75    /// High-level statement category.
76    pub statement_type: SqlStatementType,
77
78    /// Raw uppercase verb ("SELECT", "INSERT", "CREATE TABLE", etc.) — used
79    /// for explanations. For Cedar entity building use [`Self::statement_type`].
80    pub verb: String,
81
82    /// All tables referenced by name (final path segment if qualified).
83    pub tables: HashSet<String>,
84
85    /// All columns referenced (where determinable). `*` recorded for wildcards.
86    pub columns: HashSet<String>,
87
88    /// Whether the statement has a `WHERE` clause.
89    pub has_where: bool,
90
91    /// Whether the statement has a `LIMIT` clause.
92    pub has_limit: bool,
93
94    /// Whether the statement has an `ORDER BY` clause.
95    pub has_order_by: bool,
96
97    /// Whether the statement includes `GROUP BY` or aggregate functions.
98    pub has_aggregation: bool,
99
100    /// Number of `JOIN` clauses across all FROM items.
101    pub join_count: u32,
102
103    /// Number of subqueries (naive count of nested SELECTs).
104    pub subquery_count: u32,
105
106    /// Row-count estimate: `LIMIT n` when present, otherwise a configurable default.
107    pub estimated_rows: u64,
108
109    /// Raw length of the SQL string (characters).
110    pub sql_length: usize,
111}
112
113/// SQL validator that parses and analyzes SQL statements.
114#[derive(Debug, Clone)]
115pub struct SqlValidator {
116    dialect: DialectBox,
117    default_row_estimate: u64,
118}
119
120impl Default for SqlValidator {
121    fn default() -> Self {
122        Self::new()
123    }
124}
125
126impl SqlValidator {
127    /// Create a new SQL validator with the generic ANSI dialect.
128    pub fn new() -> Self {
129        Self {
130            dialect: DialectBox::Generic,
131            default_row_estimate: 1000,
132        }
133    }
134
135    /// Parse SQL and extract statement info.
136    ///
137    /// Returns an error if the SQL fails to parse, is empty, or contains
138    /// multiple statements (SQL Code Mode validates one statement at a time).
139    pub fn validate(&self, sql: &str) -> Result<SqlStatementInfo, ValidationError> {
140        let trimmed = sql.trim();
141        if trimmed.is_empty() {
142            return Err(ValidationError::ParseError {
143                message: "SQL statement is empty".to_string(),
144                line: 1,
145                column: 1,
146            });
147        }
148
149        let statements = Parser::parse_sql(self.dialect.as_dialect(), trimmed).map_err(|e| {
150            ValidationError::ParseError {
151                message: format!("SQL parse error: {}", e),
152                line: 1,
153                column: 1,
154            }
155        })?;
156
157        match statements.len() {
158            0 => Err(ValidationError::ParseError {
159                message: "SQL contains no statements".to_string(),
160                line: 1,
161                column: 1,
162            }),
163            1 => Ok(self.analyze_statement(&statements[0], trimmed)),
164            n => Err(ValidationError::ParseError {
165                message: format!("SQL Code Mode validates one statement at a time; got {}", n),
166                line: 1,
167                column: 1,
168            }),
169        }
170    }
171
172    /// Produce a security analysis for the given statement info.
173    ///
174    /// The issues produced here are warnings only — config-level and
175    /// policy-level authorization are enforced separately in
176    /// [`ValidationPipeline::validate_sql_query`](crate::ValidationPipeline::validate_sql_query).
177    pub fn analyze_security(&self, info: &SqlStatementInfo) -> SecurityAnalysis {
178        let mut issues: Vec<SecurityIssue> = Vec::new();
179
180        // UPDATE/DELETE without WHERE affects every row — classify as UnboundedQuery.
181        if (info.statement_type.is_write() || info.statement_type.is_delete()) && !info.has_where {
182            issues.push(SecurityIssue::new(
183                SecurityIssueType::UnboundedQuery,
184                format!(
185                    "{} statement has no WHERE clause — affects all rows in the table",
186                    info.verb
187                ),
188            ));
189        }
190
191        // Pure SELECT without LIMIT is also unbounded.
192        if info.statement_type.is_read_only() && !info.has_limit {
193            issues.push(SecurityIssue::new(
194                SecurityIssueType::UnboundedQuery,
195                format!(
196                    "{} statement has no LIMIT — result set may be large",
197                    info.verb
198                ),
199            ));
200        }
201
202        // Excessive joins or subqueries — complexity signal.
203        if info.join_count > 5 {
204            issues.push(SecurityIssue::new(
205                SecurityIssueType::HighComplexity,
206                format!(
207                    "Query has {} JOINs, which may be expensive to execute",
208                    info.join_count
209                ),
210            ));
211        }
212        if info.subquery_count > 3 {
213            issues.push(SecurityIssue::new(
214                SecurityIssueType::DeepNesting,
215                format!("Query has {} nested subqueries", info.subquery_count),
216            ));
217        }
218
219        let complexity = estimate_complexity(info);
220
221        SecurityAnalysis {
222            is_read_only: info.statement_type.is_read_only(),
223            tables_accessed: info.tables.clone(),
224            fields_accessed: info.columns.clone(),
225            has_aggregation: info.has_aggregation,
226            has_subqueries: info.subquery_count > 0,
227            estimated_complexity: complexity,
228            potential_issues: issues,
229            estimated_rows: Some(info.estimated_rows),
230        }
231    }
232
233    /// Map parsed statement info to [`CodeType`].
234    pub fn to_code_type(&self, info: &SqlStatementInfo) -> CodeType {
235        if info.statement_type.is_read_only() {
236            CodeType::SqlQuery
237        } else {
238            CodeType::SqlMutation
239        }
240    }
241
242    fn analyze_statement(&self, stmt: &Statement, sql: &str) -> SqlStatementInfo {
243        let mut info = SqlStatementInfo {
244            statement_type: SqlStatementType::Other,
245            verb: verb_for(stmt),
246            tables: HashSet::new(),
247            columns: HashSet::new(),
248            has_where: false,
249            has_limit: false,
250            has_order_by: false,
251            has_aggregation: false,
252            join_count: 0,
253            subquery_count: 0,
254            estimated_rows: self.default_row_estimate,
255            sql_length: sql.len(),
256        };
257
258        match stmt {
259            Statement::Query(query) => {
260                info.statement_type = SqlStatementType::Select;
261                self.analyze_query(query, &mut info);
262            },
263            Statement::Insert(insert) => {
264                info.statement_type = SqlStatementType::Insert;
265                if let TableObject::TableName(name) = &insert.table {
266                    add_object_name(&mut info.tables, name);
267                }
268                for col in &insert.columns {
269                    info.columns.insert(col.value.clone());
270                }
271                if let Some(source) = &insert.source {
272                    self.analyze_query(source, &mut info);
273                }
274            },
275            Statement::Update(update) => {
276                info.statement_type = SqlStatementType::Update;
277                self.analyze_table_with_joins(&update.table, &mut info);
278                for assignment in &update.assignments {
279                    match &assignment.target {
280                        AssignmentTarget::ColumnName(name) => {
281                            add_object_name(&mut info.columns, name);
282                        },
283                        AssignmentTarget::Tuple(names) => {
284                            for n in names {
285                                add_object_name(&mut info.columns, n);
286                            }
287                        },
288                    }
289                    self.analyze_expr(&assignment.value, &mut info);
290                }
291                if let Some(expr) = &update.selection {
292                    info.has_where = true;
293                    self.analyze_expr(expr, &mut info);
294                }
295            },
296            Statement::Delete(delete) => {
297                info.statement_type = SqlStatementType::Delete;
298                match &delete.from {
299                    FromTable::WithFromKeyword(tables) | FromTable::WithoutKeyword(tables) => {
300                        for t in tables {
301                            self.analyze_table_with_joins(t, &mut info);
302                        }
303                    },
304                }
305                // Multi-table delete names
306                for t in &delete.tables {
307                    add_object_name(&mut info.tables, t);
308                }
309                if let Some(expr) = &delete.selection {
310                    info.has_where = true;
311                    self.analyze_expr(expr, &mut info);
312                }
313            },
314            Statement::Truncate(truncate) => {
315                info.statement_type = SqlStatementType::Delete;
316                for tn in &truncate.table_names {
317                    add_object_name(&mut info.tables, &tn.name);
318                }
319            },
320            Statement::CreateTable(create) => {
321                info.statement_type = SqlStatementType::Ddl;
322                add_object_name(&mut info.tables, &create.name);
323            },
324            Statement::AlterTable(alter) => {
325                info.statement_type = SqlStatementType::Ddl;
326                add_object_name(&mut info.tables, &alter.name);
327            },
328            Statement::Drop { .. }
329            | Statement::CreateIndex(_)
330            | Statement::CreateView { .. }
331            | Statement::Grant { .. }
332            | Statement::Revoke { .. } => {
333                info.statement_type = SqlStatementType::Ddl;
334            },
335            _ => {
336                // Unknown statement — leave as Other.
337            },
338        }
339
340        info
341    }
342
343    fn analyze_query(&self, query: &Query, info: &mut SqlStatementInfo) {
344        if query.order_by.is_some() {
345            info.has_order_by = true;
346        }
347        if let Some(limit_clause) = &query.limit_clause {
348            info.has_limit = true;
349            let limit_expr = match limit_clause {
350                LimitClause::LimitOffset { limit, .. } => limit.as_ref(),
351                LimitClause::OffsetCommaLimit { limit, .. } => Some(limit),
352            };
353            if let Some(Expr::Value(v)) = limit_expr {
354                if let sqlparser::ast::Value::Number(n, _) = &v.value {
355                    if let Ok(parsed) = n.parse::<u64>() {
356                        info.estimated_rows = parsed;
357                    }
358                }
359            }
360        }
361
362        self.analyze_set_expr(&query.body, info);
363    }
364
365    fn analyze_set_expr(&self, set_expr: &SetExpr, info: &mut SqlStatementInfo) {
366        match set_expr {
367            SetExpr::Select(select) => self.analyze_select(select, info),
368            SetExpr::Query(inner) => {
369                info.subquery_count += 1;
370                self.analyze_query(inner, info);
371            },
372            SetExpr::SetOperation { left, right, .. } => {
373                self.analyze_set_expr(left, info);
374                self.analyze_set_expr(right, info);
375            },
376            _ => {},
377        }
378    }
379
380    fn analyze_select(&self, select: &Select, info: &mut SqlStatementInfo) {
381        // Projection columns
382        for item in &select.projection {
383            match item {
384                SelectItem::UnnamedExpr(expr) => self.analyze_expr(expr, info),
385                SelectItem::ExprWithAlias { expr, .. } => self.analyze_expr(expr, info),
386                SelectItem::Wildcard(_) | SelectItem::QualifiedWildcard(_, _) => {
387                    info.columns.insert("*".to_string());
388                },
389            }
390        }
391
392        // FROM tables + joins
393        for table in &select.from {
394            self.analyze_table_with_joins(table, info);
395        }
396
397        // WHERE
398        if let Some(expr) = &select.selection {
399            info.has_where = true;
400            self.analyze_expr(expr, info);
401        }
402
403        // GROUP BY / aggregation
404        if !group_by_is_empty(&select.group_by) {
405            info.has_aggregation = true;
406        }
407    }
408
409    fn analyze_table_with_joins(&self, item: &TableWithJoins, info: &mut SqlStatementInfo) {
410        self.analyze_table_factor(&item.relation, info);
411        for join in &item.joins {
412            info.join_count += 1;
413            self.analyze_join(join, info);
414        }
415    }
416
417    fn analyze_join(&self, join: &Join, info: &mut SqlStatementInfo) {
418        self.analyze_table_factor(&join.relation, info);
419    }
420
421    fn analyze_table_factor(&self, factor: &TableFactor, info: &mut SqlStatementInfo) {
422        match factor {
423            TableFactor::Table { name, .. } => add_object_name(&mut info.tables, name),
424            TableFactor::Derived { subquery, .. } => {
425                info.subquery_count += 1;
426                self.analyze_query(subquery, info);
427            },
428            TableFactor::NestedJoin {
429                table_with_joins, ..
430            } => self.analyze_table_with_joins(table_with_joins, info),
431            _ => {},
432        }
433    }
434
435    fn analyze_expr(&self, expr: &Expr, info: &mut SqlStatementInfo) {
436        match expr {
437            Expr::Identifier(id) => {
438                info.columns.insert(id.value.clone());
439            },
440            Expr::CompoundIdentifier(ids) => {
441                if let Some(last) = ids.last() {
442                    info.columns.insert(last.value.clone());
443                }
444            },
445            Expr::Subquery(q)
446            | Expr::Exists { subquery: q, .. }
447            | Expr::InSubquery { subquery: q, .. } => {
448                info.subquery_count += 1;
449                self.analyze_query(q, info);
450            },
451            Expr::Function(f) => {
452                let name = f.name.to_string().to_uppercase();
453                if matches!(
454                    name.as_str(),
455                    "COUNT" | "SUM" | "AVG" | "MIN" | "MAX" | "ARRAY_AGG" | "STRING_AGG"
456                ) {
457                    info.has_aggregation = true;
458                }
459            },
460            _ => {},
461        }
462    }
463}
464
465fn estimate_complexity(info: &SqlStatementInfo) -> Complexity {
466    let joins = info.join_count;
467    let subs = info.subquery_count;
468    if joins >= 5 || subs >= 3 {
469        Complexity::High
470    } else if joins >= 2 || subs >= 1 || info.has_aggregation {
471        Complexity::Medium
472    } else {
473        Complexity::Low
474    }
475}
476
477fn group_by_is_empty(group_by: &GroupByExpr) -> bool {
478    match group_by {
479        GroupByExpr::All(_) => true,
480        GroupByExpr::Expressions(exprs, _) => exprs.is_empty(),
481    }
482}
483
484fn add_object_name(out: &mut HashSet<String>, name: &ObjectName) {
485    if let Some(last) = name.0.last() {
486        out.insert(last.to_string());
487    } else {
488        out.insert(name.to_string());
489    }
490}
491
492fn verb_for(stmt: &Statement) -> String {
493    match stmt {
494        Statement::Query(_) => "SELECT".to_string(),
495        Statement::Insert(_) => "INSERT".to_string(),
496        Statement::Update { .. } => "UPDATE".to_string(),
497        Statement::Delete(_) => "DELETE".to_string(),
498        Statement::Truncate { .. } => "TRUNCATE".to_string(),
499        Statement::CreateTable(_) => "CREATE TABLE".to_string(),
500        Statement::AlterTable { .. } => "ALTER TABLE".to_string(),
501        Statement::Drop { .. } => "DROP".to_string(),
502        Statement::CreateIndex(_) => "CREATE INDEX".to_string(),
503        Statement::CreateView { .. } => "CREATE VIEW".to_string(),
504        Statement::Grant { .. } => "GRANT".to_string(),
505        Statement::Revoke { .. } => "REVOKE".to_string(),
506        other => format!("{:?}", other)
507            .split('(')
508            .next()
509            .unwrap_or("OTHER")
510            .to_uppercase(),
511    }
512}
513
514/// Enum wrapper around concrete dialects so `SqlValidator` stays `Clone` and
515/// avoids trait-object gymnastics.
516#[derive(Debug, Clone)]
517enum DialectBox {
518    Generic,
519}
520
521impl DialectBox {
522    fn as_dialect(&self) -> &dyn Dialect {
523        match self {
524            Self::Generic => &GenericDialect {},
525        }
526    }
527}
528
#[cfg(test)]
mod tests {
    use super::*;

    /// Validate `sql` with a fresh default validator, panicking on failure.
    fn parse(sql: &str) -> SqlStatementInfo {
        SqlValidator::new().validate(sql).unwrap()
    }

    /// Whether validating `sql` is rejected with a `ParseError`.
    fn is_parse_error(sql: &str) -> bool {
        matches!(
            SqlValidator::new().validate(sql),
            Err(ValidationError::ParseError { .. })
        )
    }

    #[test]
    fn select_simple() {
        let info = parse("SELECT id, name FROM users");
        assert_eq!(info.statement_type, SqlStatementType::Select);
        assert!(info.tables.contains("users"));
        for column in ["id", "name"] {
            assert!(info.columns.contains(column));
        }
        assert!(!info.has_where);
        assert!(!info.has_limit);
    }

    #[test]
    fn select_with_where_limit_order() {
        let info = parse("SELECT id FROM users WHERE active = 1 ORDER BY id LIMIT 10");
        assert!(info.has_where);
        assert!(info.has_limit);
        assert!(info.has_order_by);
        assert_eq!(info.estimated_rows, 10);
    }

    #[test]
    fn select_star() {
        let info = parse("SELECT * FROM users");
        assert!(info.columns.contains("*"));
    }

    #[test]
    fn select_join_and_subquery() {
        let info = parse(
            "SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id \
             WHERE u.id IN (SELECT id FROM admins)",
        );
        assert_eq!(info.join_count, 1);
        assert!(info.subquery_count >= 1);
        for table in ["users", "orders", "admins"] {
            assert!(info.tables.contains(table));
        }
    }

    #[test]
    fn insert_extracts_table_and_columns() {
        let info = parse("INSERT INTO users (id, name) VALUES (1, 'Alice')");
        assert_eq!(info.statement_type, SqlStatementType::Insert);
        assert!(info.tables.contains("users"));
        for column in ["id", "name"] {
            assert!(info.columns.contains(column));
        }
    }

    #[test]
    fn update_without_where_flagged() {
        let validator = SqlValidator::new();
        let info = validator.validate("UPDATE users SET active = 0").unwrap();
        assert_eq!(info.statement_type, SqlStatementType::Update);
        assert!(!info.has_where);

        let analysis = validator.analyze_security(&info);
        let flagged_unbounded = analysis
            .potential_issues
            .iter()
            .any(|i| i.issue_type == SecurityIssueType::UnboundedQuery);
        assert!(flagged_unbounded);
    }

    #[test]
    fn update_with_where() {
        let info = parse("UPDATE users SET active = 0 WHERE id = 1");
        assert_eq!(info.statement_type, SqlStatementType::Update);
        assert!(info.has_where);
        assert!(info.columns.contains("active"));
    }

    #[test]
    fn delete_with_where() {
        let info = parse("DELETE FROM users WHERE id = 1");
        assert_eq!(info.statement_type, SqlStatementType::Delete);
        assert!(info.has_where);
    }

    #[test]
    fn ddl_is_admin() {
        let info = parse("CREATE TABLE foo (id INT)");
        assert_eq!(info.statement_type, SqlStatementType::Ddl);
        assert!(info.statement_type.is_admin());
    }

    #[test]
    fn empty_sql_rejected() {
        assert!(is_parse_error(""));
        assert!(is_parse_error("   "));
    }

    #[test]
    fn syntax_error_rejected() {
        assert!(is_parse_error("SELEC id FRM users"));
    }

    #[test]
    fn multiple_statements_rejected() {
        assert!(is_parse_error("SELECT 1; SELECT 2"));
    }

    #[test]
    fn aggregation_detected() {
        let info = parse("SELECT COUNT(*) FROM users");
        assert!(info.has_aggregation);
    }

    #[test]
    fn group_by_detected() {
        let info = parse("SELECT role, COUNT(*) FROM users GROUP BY role");
        assert!(info.has_aggregation);
    }
}