Skip to main content

pmcp_code_mode/
sql.rs

1//! SQL validation for Code Mode.
2//!
3//! Parses SQL statements with [`sqlparser`], classifies the statement type
4//! (`SELECT`/`INSERT`/`UPDATE`/`DELETE`/DDL), and extracts the tables, columns,
5//! and structural metadata that the Cedar policy evaluator needs.
6//!
7//! Gated behind the `sql-code-mode` feature.
8
9use crate::types::{
10    CodeType, Complexity, SecurityAnalysis, SecurityIssue, SecurityIssueType, ValidationError,
11};
12use sqlparser::ast::{
13    AssignmentTarget, Expr, FromTable, GroupByExpr, Join, LimitClause, ObjectName, Query, Select,
14    SelectItem, SetExpr, Statement, TableFactor, TableObject, TableWithJoins,
15};
16use sqlparser::dialect::{Dialect, GenericDialect};
17use sqlparser::parser::Parser;
18use std::collections::HashSet;
19
/// Coarse classification of a parsed SQL statement.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlStatementType {
    /// Read-only statements: `SELECT`, `SHOW`, `EXPLAIN`, `DESCRIBE`.
    Select,
    /// Row insertion: `INSERT`.
    Insert,
    /// Row modification: `UPDATE`, `MERGE`.
    Update,
    /// Row removal: `DELETE`, `TRUNCATE`.
    Delete,
    /// DDL/admin statements: `CREATE`/`ALTER`/`DROP`/`GRANT`/`REVOKE`.
    Ddl,
    /// Anything unrecognized or unsupported.
    Other,
}

impl SqlStatementType {
    /// Canonical uppercase name ("SELECT", "INSERT", ...) as consumed by the
    /// Cedar schema and [`UnifiedAction::from_sql`](crate::UnifiedAction::from_sql).
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Other => "OTHER",
            Self::Ddl => "DDL",
            Self::Delete => "DELETE",
            Self::Update => "UPDATE",
            Self::Insert => "INSERT",
            Self::Select => "SELECT",
        }
    }

    /// `true` only for the read-only category ([`Self::Select`]).
    pub fn is_read_only(&self) -> bool {
        *self == Self::Select
    }

    /// `true` for the data-writing categories ([`Self::Insert`], [`Self::Update`]).
    pub fn is_write(&self) -> bool {
        [Self::Insert, Self::Update].contains(self)
    }

    /// `true` for the data-removing category ([`Self::Delete`]).
    pub fn is_delete(&self) -> bool {
        *self == Self::Delete
    }

    /// `true` for schema/permission changes ([`Self::Ddl`]).
    pub fn is_admin(&self) -> bool {
        *self == Self::Ddl
    }
}
71
/// Structural information extracted from a parsed SQL statement.
///
/// Produced by [`SqlValidator::validate`] and consumed by
/// [`SqlValidator::analyze_security`] and [`SqlValidator::to_code_type`].
#[derive(Debug, Clone)]
pub struct SqlStatementInfo {
    /// High-level statement category.
    pub statement_type: SqlStatementType,

    /// Raw uppercase verb ("SELECT", "INSERT", "CREATE TABLE", etc.) — used
    /// for explanations. For Cedar entity building use [`Self::statement_type`].
    pub verb: String,

    /// All tables referenced by name (final path segment if qualified).
    pub tables: HashSet<String>,

    /// All columns referenced (where determinable). `*` recorded for wildcards.
    pub columns: HashSet<String>,

    /// Whether the statement has a `WHERE` clause.
    pub has_where: bool,

    /// Whether the statement has a `LIMIT` clause.
    pub has_limit: bool,

    /// Whether the statement has an `ORDER BY` clause.
    pub has_order_by: bool,

    /// Whether the statement includes `GROUP BY` or aggregate functions.
    pub has_aggregation: bool,

    /// Number of `JOIN` clauses across all FROM items.
    pub join_count: u32,

    /// Number of subqueries (naive count of nested SELECTs).
    pub subquery_count: u32,

    /// Row-count estimate: `LIMIT n` when a numeric literal limit is present,
    /// otherwise the validator's default row estimate.
    pub estimated_rows: u64,

    /// Length of the trimmed SQL string in bytes (`str::len`), not characters.
    pub sql_length: usize,
}
112
113/// SQL validator that parses and analyzes SQL statements.
114#[derive(Debug, Clone)]
115pub struct SqlValidator {
116    dialect: DialectBox,
117    default_row_estimate: u64,
118}
119
120impl Default for SqlValidator {
121    fn default() -> Self {
122        Self::new()
123    }
124}
125
126impl SqlValidator {
127    /// Create a new SQL validator with the generic ANSI dialect.
128    pub fn new() -> Self {
129        Self {
130            dialect: DialectBox::Generic,
131            default_row_estimate: 1000,
132        }
133    }
134
135    /// Parse SQL and extract statement info.
136    ///
137    /// Returns an error if the SQL fails to parse, is empty, or contains
138    /// multiple statements (SQL Code Mode validates one statement at a time).
139    pub fn validate(&self, sql: &str) -> Result<SqlStatementInfo, ValidationError> {
140        let trimmed = sql.trim();
141        if trimmed.is_empty() {
142            return Err(ValidationError::ParseError {
143                message: "SQL statement is empty".to_string(),
144                line: 1,
145                column: 1,
146            });
147        }
148
149        let statements = Parser::parse_sql(self.dialect.as_dialect(), trimmed).map_err(|e| {
150            ValidationError::ParseError {
151                message: format!("SQL parse error: {}", e),
152                line: 1,
153                column: 1,
154            }
155        })?;
156
157        match statements.len() {
158            0 => Err(ValidationError::ParseError {
159                message: "SQL contains no statements".to_string(),
160                line: 1,
161                column: 1,
162            }),
163            1 => Ok(self.analyze_statement(&statements[0], trimmed)),
164            n => Err(ValidationError::ParseError {
165                message: format!("SQL Code Mode validates one statement at a time; got {}", n),
166                line: 1,
167                column: 1,
168            }),
169        }
170    }
171
172    /// Produce a security analysis for the given statement info.
173    ///
174    /// The issues produced here are warnings only — config-level and
175    /// policy-level authorization are enforced separately in
176    /// [`ValidationPipeline::validate_sql_query`](crate::ValidationPipeline::validate_sql_query).
177    pub fn analyze_security(&self, info: &SqlStatementInfo) -> SecurityAnalysis {
178        let mut issues: Vec<SecurityIssue> = Vec::new();
179
180        // UPDATE/DELETE without WHERE affects every row — classify as UnboundedQuery.
181        if (info.statement_type.is_write() || info.statement_type.is_delete()) && !info.has_where {
182            issues.push(SecurityIssue::new(
183                SecurityIssueType::UnboundedQuery,
184                format!(
185                    "{} statement has no WHERE clause — affects all rows in the table",
186                    info.verb
187                ),
188            ));
189        }
190
191        // Pure SELECT without LIMIT is also unbounded.
192        if info.statement_type.is_read_only() && !info.has_limit {
193            issues.push(SecurityIssue::new(
194                SecurityIssueType::UnboundedQuery,
195                format!(
196                    "{} statement has no LIMIT — result set may be large",
197                    info.verb
198                ),
199            ));
200        }
201
202        // Excessive joins or subqueries — complexity signal.
203        if info.join_count > 5 {
204            issues.push(SecurityIssue::new(
205                SecurityIssueType::HighComplexity,
206                format!(
207                    "Query has {} JOINs, which may be expensive to execute",
208                    info.join_count
209                ),
210            ));
211        }
212        if info.subquery_count > 3 {
213            issues.push(SecurityIssue::new(
214                SecurityIssueType::DeepNesting,
215                format!("Query has {} nested subqueries", info.subquery_count),
216            ));
217        }
218
219        let complexity = estimate_complexity(info);
220
221        SecurityAnalysis {
222            is_read_only: info.statement_type.is_read_only(),
223            tables_accessed: info.tables.clone(),
224            fields_accessed: info.columns.clone(),
225            has_aggregation: info.has_aggregation,
226            has_subqueries: info.subquery_count > 0,
227            estimated_complexity: complexity,
228            potential_issues: issues,
229            estimated_rows: Some(info.estimated_rows),
230        }
231    }
232
233    /// Map parsed statement info to [`CodeType`].
234    pub fn to_code_type(&self, info: &SqlStatementInfo) -> CodeType {
235        if info.statement_type.is_read_only() {
236            CodeType::SqlQuery
237        } else {
238            CodeType::SqlMutation
239        }
240    }
241
242    fn analyze_statement(&self, stmt: &Statement, sql: &str) -> SqlStatementInfo {
243        let mut info = SqlStatementInfo {
244            statement_type: SqlStatementType::Other,
245            verb: verb_for(stmt),
246            tables: HashSet::new(),
247            columns: HashSet::new(),
248            has_where: false,
249            has_limit: false,
250            has_order_by: false,
251            has_aggregation: false,
252            join_count: 0,
253            subquery_count: 0,
254            estimated_rows: self.default_row_estimate,
255            sql_length: sql.len(),
256        };
257
258        match stmt {
259            Statement::Query(query) => {
260                info.statement_type = SqlStatementType::Select;
261                self.analyze_query(query, &mut info);
262            },
263            Statement::Insert(insert) => {
264                info.statement_type = SqlStatementType::Insert;
265                if let TableObject::TableName(name) = &insert.table {
266                    add_object_name(&mut info.tables, name);
267                }
268                for col in &insert.columns {
269                    // sqlparser 0.62 changed insert column items: use Display
270                    // instead of the removed `.value` field on ObjectName/Ident.
271                    info.columns.insert(col.to_string());
272                }
273                if let Some(source) = &insert.source {
274                    self.analyze_query(source, &mut info);
275                }
276            },
277            Statement::Update(update) => {
278                info.statement_type = SqlStatementType::Update;
279                self.analyze_table_with_joins(&update.table, &mut info);
280                for assignment in &update.assignments {
281                    match &assignment.target {
282                        AssignmentTarget::ColumnName(name) => {
283                            add_object_name(&mut info.columns, name);
284                        },
285                        AssignmentTarget::Tuple(names) => {
286                            for n in names {
287                                add_object_name(&mut info.columns, n);
288                            }
289                        },
290                    }
291                    self.analyze_expr(&assignment.value, &mut info);
292                }
293                if let Some(expr) = &update.selection {
294                    info.has_where = true;
295                    self.analyze_expr(expr, &mut info);
296                }
297            },
298            Statement::Delete(delete) => {
299                info.statement_type = SqlStatementType::Delete;
300                match &delete.from {
301                    FromTable::WithFromKeyword(tables) | FromTable::WithoutKeyword(tables) => {
302                        for t in tables {
303                            self.analyze_table_with_joins(t, &mut info);
304                        }
305                    },
306                }
307                // Multi-table delete names
308                for t in &delete.tables {
309                    add_object_name(&mut info.tables, t);
310                }
311                if let Some(expr) = &delete.selection {
312                    info.has_where = true;
313                    self.analyze_expr(expr, &mut info);
314                }
315            },
316            Statement::Truncate(truncate) => {
317                info.statement_type = SqlStatementType::Delete;
318                for tn in &truncate.table_names {
319                    add_object_name(&mut info.tables, &tn.name);
320                }
321            },
322            Statement::CreateTable(create) => {
323                info.statement_type = SqlStatementType::Ddl;
324                add_object_name(&mut info.tables, &create.name);
325            },
326            Statement::AlterTable(alter) => {
327                info.statement_type = SqlStatementType::Ddl;
328                add_object_name(&mut info.tables, &alter.name);
329            },
330            Statement::Drop { .. }
331            | Statement::CreateIndex(_)
332            | Statement::CreateView { .. }
333            | Statement::Grant { .. }
334            | Statement::Revoke { .. } => {
335                info.statement_type = SqlStatementType::Ddl;
336            },
337            _ => {
338                // Unknown statement — leave as Other.
339            },
340        }
341
342        info
343    }
344
345    fn analyze_query(&self, query: &Query, info: &mut SqlStatementInfo) {
346        if query.order_by.is_some() {
347            info.has_order_by = true;
348        }
349        if let Some(limit_clause) = &query.limit_clause {
350            info.has_limit = true;
351            let limit_expr = match limit_clause {
352                LimitClause::LimitOffset { limit, .. } => limit.as_ref(),
353                LimitClause::OffsetCommaLimit { limit, .. } => Some(limit),
354            };
355            if let Some(Expr::Value(v)) = limit_expr {
356                if let sqlparser::ast::Value::Number(n, _) = &v.value {
357                    if let Ok(parsed) = n.parse::<u64>() {
358                        info.estimated_rows = parsed;
359                    }
360                }
361            }
362        }
363
364        self.analyze_set_expr(&query.body, info);
365    }
366
367    fn analyze_set_expr(&self, set_expr: &SetExpr, info: &mut SqlStatementInfo) {
368        match set_expr {
369            SetExpr::Select(select) => self.analyze_select(select, info),
370            SetExpr::Query(inner) => {
371                info.subquery_count += 1;
372                self.analyze_query(inner, info);
373            },
374            SetExpr::SetOperation { left, right, .. } => {
375                self.analyze_set_expr(left, info);
376                self.analyze_set_expr(right, info);
377            },
378            _ => {},
379        }
380    }
381
382    fn analyze_select(&self, select: &Select, info: &mut SqlStatementInfo) {
383        // Projection columns
384        for item in &select.projection {
385            match item {
386                SelectItem::UnnamedExpr(expr) => self.analyze_expr(expr, info),
387                SelectItem::ExprWithAlias { expr, .. } => self.analyze_expr(expr, info),
388                // sqlparser 0.62 added the plural variant for SQL Server's
389                // multi-alias projection syntax — treat like ExprWithAlias.
390                SelectItem::ExprWithAliases { expr, .. } => self.analyze_expr(expr, info),
391                SelectItem::Wildcard(_) | SelectItem::QualifiedWildcard(_, _) => {
392                    info.columns.insert("*".to_string());
393                },
394            }
395        }
396
397        // FROM tables + joins
398        for table in &select.from {
399            self.analyze_table_with_joins(table, info);
400        }
401
402        // WHERE
403        if let Some(expr) = &select.selection {
404            info.has_where = true;
405            self.analyze_expr(expr, info);
406        }
407
408        // GROUP BY / aggregation
409        if !group_by_is_empty(&select.group_by) {
410            info.has_aggregation = true;
411        }
412    }
413
414    fn analyze_table_with_joins(&self, item: &TableWithJoins, info: &mut SqlStatementInfo) {
415        self.analyze_table_factor(&item.relation, info);
416        for join in &item.joins {
417            info.join_count += 1;
418            self.analyze_join(join, info);
419        }
420    }
421
422    fn analyze_join(&self, join: &Join, info: &mut SqlStatementInfo) {
423        self.analyze_table_factor(&join.relation, info);
424    }
425
426    fn analyze_table_factor(&self, factor: &TableFactor, info: &mut SqlStatementInfo) {
427        match factor {
428            TableFactor::Table { name, .. } => add_object_name(&mut info.tables, name),
429            TableFactor::Derived { subquery, .. } => {
430                info.subquery_count += 1;
431                self.analyze_query(subquery, info);
432            },
433            TableFactor::NestedJoin {
434                table_with_joins, ..
435            } => self.analyze_table_with_joins(table_with_joins, info),
436            _ => {},
437        }
438    }
439
440    fn analyze_expr(&self, expr: &Expr, info: &mut SqlStatementInfo) {
441        match expr {
442            Expr::Identifier(id) => {
443                info.columns.insert(id.value.clone());
444            },
445            Expr::CompoundIdentifier(ids) => {
446                if let Some(last) = ids.last() {
447                    info.columns.insert(last.value.clone());
448                }
449            },
450            Expr::Subquery(q)
451            | Expr::Exists { subquery: q, .. }
452            | Expr::InSubquery { subquery: q, .. } => {
453                info.subquery_count += 1;
454                self.analyze_query(q, info);
455            },
456            Expr::Function(f) => {
457                let name = f.name.to_string().to_uppercase();
458                if matches!(
459                    name.as_str(),
460                    "COUNT" | "SUM" | "AVG" | "MIN" | "MAX" | "ARRAY_AGG" | "STRING_AGG"
461                ) {
462                    info.has_aggregation = true;
463                }
464            },
465            _ => {},
466        }
467    }
468}
469
470fn estimate_complexity(info: &SqlStatementInfo) -> Complexity {
471    let joins = info.join_count;
472    let subs = info.subquery_count;
473    if joins >= 5 || subs >= 3 {
474        Complexity::High
475    } else if joins >= 2 || subs >= 1 || info.has_aggregation {
476        Complexity::Medium
477    } else {
478        Complexity::Low
479    }
480}
481
482fn group_by_is_empty(group_by: &GroupByExpr) -> bool {
483    match group_by {
484        GroupByExpr::All(_) => true,
485        GroupByExpr::Expressions(exprs, _) => exprs.is_empty(),
486    }
487}
488
489fn add_object_name(out: &mut HashSet<String>, name: &ObjectName) {
490    if let Some(last) = name.0.last() {
491        out.insert(last.to_string());
492    } else {
493        out.insert(name.to_string());
494    }
495}
496
497fn verb_for(stmt: &Statement) -> String {
498    match stmt {
499        Statement::Query(_) => "SELECT".to_string(),
500        Statement::Insert(_) => "INSERT".to_string(),
501        Statement::Update { .. } => "UPDATE".to_string(),
502        Statement::Delete(_) => "DELETE".to_string(),
503        Statement::Truncate { .. } => "TRUNCATE".to_string(),
504        Statement::CreateTable(_) => "CREATE TABLE".to_string(),
505        Statement::AlterTable { .. } => "ALTER TABLE".to_string(),
506        Statement::Drop { .. } => "DROP".to_string(),
507        Statement::CreateIndex(_) => "CREATE INDEX".to_string(),
508        Statement::CreateView { .. } => "CREATE VIEW".to_string(),
509        Statement::Grant { .. } => "GRANT".to_string(),
510        Statement::Revoke { .. } => "REVOKE".to_string(),
511        other => format!("{:?}", other)
512            .split('(')
513            .next()
514            .unwrap_or("OTHER")
515            .to_uppercase(),
516    }
517}
518
519/// Enum wrapper around concrete dialects so `SqlValidator` stays `Clone` and
520/// avoids trait-object gymnastics.
521#[derive(Debug, Clone)]
522enum DialectBox {
523    Generic,
524}
525
526impl DialectBox {
527    fn as_dialect(&self) -> &dyn Dialect {
528        match self {
529            Self::Generic => &GenericDialect {},
530        }
531    }
532}
533
#[cfg(test)]
mod tests {
    use super::*;

    /// Validate `sql` with a fresh validator, panicking if it is rejected.
    fn parse(sql: &str) -> SqlStatementInfo {
        SqlValidator::new().validate(sql).unwrap()
    }

    /// Whether a fresh validator rejects `sql` with a parse error.
    fn is_parse_error(sql: &str) -> bool {
        matches!(
            SqlValidator::new().validate(sql),
            Err(ValidationError::ParseError { .. })
        )
    }

    #[test]
    fn select_simple() {
        let info = parse("SELECT id, name FROM users");
        assert_eq!(info.statement_type, SqlStatementType::Select);
        assert!(info.tables.contains("users"));
        for column in ["id", "name"] {
            assert!(info.columns.contains(column));
        }
        assert!(!info.has_where);
        assert!(!info.has_limit);
    }

    #[test]
    fn select_with_where_limit_order() {
        let info = parse("SELECT id FROM users WHERE active = 1 ORDER BY id LIMIT 10");
        assert!(info.has_where);
        assert!(info.has_limit);
        assert!(info.has_order_by);
        assert_eq!(info.estimated_rows, 10);
    }

    #[test]
    fn select_star() {
        let info = parse("SELECT * FROM users");
        assert!(info.columns.contains("*"));
    }

    #[test]
    fn select_join_and_subquery() {
        let info = parse(
            "SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id \
             WHERE u.id IN (SELECT id FROM admins)",
        );
        assert_eq!(info.join_count, 1);
        assert!(info.subquery_count >= 1);
        for table in ["users", "orders", "admins"] {
            assert!(info.tables.contains(table));
        }
    }

    #[test]
    fn insert_extracts_table_and_columns() {
        let info = parse("INSERT INTO users (id, name) VALUES (1, 'Alice')");
        assert_eq!(info.statement_type, SqlStatementType::Insert);
        assert!(info.tables.contains("users"));
        for column in ["id", "name"] {
            assert!(info.columns.contains(column));
        }
    }

    #[test]
    fn update_without_where_flagged() {
        let validator = SqlValidator::new();
        let info = validator.validate("UPDATE users SET active = 0").unwrap();
        assert_eq!(info.statement_type, SqlStatementType::Update);
        assert!(!info.has_where);

        let analysis = validator.analyze_security(&info);
        let flagged_unbounded = analysis
            .potential_issues
            .iter()
            .any(|i| i.issue_type == SecurityIssueType::UnboundedQuery);
        assert!(flagged_unbounded);
    }

    #[test]
    fn update_with_where() {
        let info = parse("UPDATE users SET active = 0 WHERE id = 1");
        assert_eq!(info.statement_type, SqlStatementType::Update);
        assert!(info.has_where);
        assert!(info.columns.contains("active"));
    }

    #[test]
    fn delete_with_where() {
        let info = parse("DELETE FROM users WHERE id = 1");
        assert_eq!(info.statement_type, SqlStatementType::Delete);
        assert!(info.has_where);
    }

    #[test]
    fn ddl_is_admin() {
        let info = parse("CREATE TABLE foo (id INT)");
        assert_eq!(info.statement_type, SqlStatementType::Ddl);
        assert!(info.statement_type.is_admin());
    }

    #[test]
    fn empty_sql_rejected() {
        assert!(is_parse_error(""));
        assert!(is_parse_error("   "));
    }

    #[test]
    fn syntax_error_rejected() {
        assert!(is_parse_error("SELEC id FRM users"));
    }

    #[test]
    fn multiple_statements_rejected() {
        assert!(is_parse_error("SELECT 1; SELECT 2"));
    }

    #[test]
    fn aggregation_detected() {
        let info = parse("SELECT COUNT(*) FROM users");
        assert!(info.has_aggregation);
    }

    #[test]
    fn group_by_detected() {
        let info = parse("SELECT role, COUNT(*) FROM users GROUP BY role");
        assert!(info.has_aggregation);
    }
}