Skip to main content

qail_core/
access.rs

1//! Native vertical access policy checks.
2//!
3//! Row-level security decides which rows a subject can see. This module covers
4//! the vertical layer: which operations and columns a subject may use before a
5//! driver sends the AST to a backend.
6
7use std::collections::{BTreeMap, BTreeSet};
8
9use crate::ast::{
10    CageKind, Condition, ConflictAction, Expr, MergeAction, MergeSource, Qail, Value,
11};
12
13mod columns;
14mod config;
15mod error;
16mod ident;
17mod model;
18mod operations;
19
20use columns::{
21    check_named_read_column, check_projection_rule, check_qualified_read_column, create_columns,
22    expr_projects_all_columns, projection_restricted_action, update_columns,
23};
24use ident::{normalize_column_name, normalize_table_ref, target_refs_for_command};
25
26pub use error::{AccessError, AccessErrorKind, AccessPolicyLoadError};
27pub use model::{AccessContext, AccessDecision, AccessOperation, ColumnRule, TableAccessPolicy};
28pub use operations::required_operations_for_command;
29
30/// Complete access policy set.
31#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
32pub struct AccessPolicy {
33    /// Default decision when no exact or wildcard table policy matches.
34    pub default_decision: AccessDecision,
35    /// Table policies by table name. `"*"` is a wildcard fallback.
36    #[serde(default)]
37    pub tables: BTreeMap<String, TableAccessPolicy>,
38}
39
40impl AccessPolicy {
41    /// Check whether a command is allowed for the supplied context.
42    pub fn check_command(&self, ctx: &AccessContext, cmd: &Qail) -> Result<(), AccessError> {
43        self.check_command_inner(ctx, cmd)
44    }
45
46    fn check_command_inner(&self, ctx: &AccessContext, cmd: &Qail) -> Result<(), AccessError> {
47        if ctx.bypasses_access() {
48            return Ok(());
49        }
50
51        for cte in &cmd.ctes {
52            self.check_command_inner(ctx, &cte.base_query)?;
53            if let Some(recursive_query) = &cte.recursive_query {
54                self.check_command_inner(ctx, recursive_query)?;
55            }
56        }
57        for (_, set_query) in &cmd.set_ops {
58            self.check_command_inner(ctx, set_query)?;
59        }
60        if let Some(source_query) = &cmd.source_query {
61            self.check_command_inner(ctx, source_query)?;
62        }
63        if let Some(merge) = &cmd.merge {
64            match &merge.source {
65                MergeSource::Query { query, .. } => self.check_command_inner(ctx, query)?,
66                MergeSource::Table { name, .. } => self.check_merge_table_source(ctx, name)?,
67            }
68        }
69
70        let table = normalize_table_ref(&cmd.table);
71        if table.is_empty() {
72            return Err(AccessError::new(
73                String::new(),
74                None,
75                AccessErrorKind::EmptyTable,
76            ));
77        }
78
79        self.check_embedded_queries(ctx, cmd)?;
80        self.check_condition_read_columns(&table, cmd)?;
81
82        let cte_names: BTreeSet<String> = cmd
83            .ctes
84            .iter()
85            .map(|cte| normalize_table_ref(&cte.name))
86            .collect();
87        self.check_join_read_access(ctx, cmd, &cte_names)?;
88        self.check_auxiliary_read_access(ctx, cmd, &cte_names)?;
89
90        let required_ops = required_operations_for_command(cmd).ok_or_else(|| {
91            AccessError::new(
92                table.clone(),
93                None,
94                AccessErrorKind::UnsupportedAction(cmd.action),
95            )
96        })?;
97        if cte_names.contains(&table) {
98            if required_ops.iter().all(|op| *op == AccessOperation::Read) {
99                return Ok(());
100            }
101            return Err(AccessError::new(
102                table,
103                None,
104                AccessErrorKind::CteMutationUnsupported,
105            ));
106        }
107
108        for operation in &required_ops {
109            self.check_table_operation(ctx, &table, *operation)?;
110        }
111
112        if required_ops.contains(&AccessOperation::Read) && projection_restricted_action(cmd.action)
113        {
114            self.check_read_columns(&table, AccessOperation::Read, &cmd.columns)?;
115        }
116
117        if required_ops.contains(&AccessOperation::Create) {
118            let columns = create_columns(cmd)?;
119            self.check_write_columns(&table, AccessOperation::Create, &columns)?;
120        }
121
122        if required_ops.contains(&AccessOperation::Update) {
123            let columns = update_columns(cmd)?;
124            self.check_write_columns(&table, AccessOperation::Update, &columns)?;
125        }
126
127        if let Some(returning) = &cmd.returning {
128            self.check_returning_columns(&table, returning)?;
129        }
130
131        Ok(())
132    }
133
134    fn check_merge_table_source(
135        &self,
136        ctx: &AccessContext,
137        source_table: &str,
138    ) -> Result<(), AccessError> {
139        let table = normalize_table_ref(source_table);
140        if table.is_empty() {
141            return Err(AccessError::new(
142                String::new(),
143                Some(AccessOperation::Read),
144                AccessErrorKind::EmptyTable,
145            ));
146        }
147
148        self.check_table_operation(ctx, &table, AccessOperation::Read)?;
149        if self
150            .table_policy(&table)
151            .is_some_and(|policy| policy.read_columns.is_restrictive())
152        {
153            return Err(AccessError::new(
154                table,
155                Some(AccessOperation::Read),
156                AccessErrorKind::SourceTableColumnPolicyUnsupported,
157            ));
158        }
159        Ok(())
160    }
161
162    fn check_table_operation(
163        &self,
164        ctx: &AccessContext,
165        table: &str,
166        operation: AccessOperation,
167    ) -> Result<(), AccessError> {
168        let Some(policy) = self.table_policy(table) else {
169            return match self.default_decision {
170                AccessDecision::Allow => Ok(()),
171                AccessDecision::Deny => Err(AccessError::new(
172                    table.to_string(),
173                    Some(operation),
174                    AccessErrorKind::NoPolicy,
175                )),
176            };
177        };
178
179        if !ctx.has_any_role(&policy.require_any_role) {
180            return Err(AccessError::new(
181                table.to_string(),
182                Some(operation),
183                AccessErrorKind::MissingRole {
184                    required: policy.require_any_role.clone(),
185                },
186            ));
187        }
188
189        if !ctx.has_all_scopes(&policy.require_scopes) {
190            return Err(AccessError::new(
191                table.to_string(),
192                Some(operation),
193                AccessErrorKind::MissingScope {
194                    required: policy.require_scopes.clone(),
195                },
196            ));
197        }
198
199        if !policy.allows_operation(operation) {
200            return Err(AccessError::new(
201                table.to_string(),
202                Some(operation),
203                AccessErrorKind::OperationDenied,
204            ));
205        }
206
207        Ok(())
208    }
209
210    fn check_join_read_access(
211        &self,
212        ctx: &AccessContext,
213        cmd: &Qail,
214        cte_names: &BTreeSet<String>,
215    ) -> Result<(), AccessError> {
216        for join in &cmd.joins {
217            let table = normalize_table_ref(&join.table);
218            if table.is_empty() || cte_names.contains(&table) {
219                continue;
220            }
221            self.check_table_operation(ctx, &table, AccessOperation::Read)?;
222            if self
223                .table_policy(&table)
224                .is_some_and(|policy| policy.read_columns.is_restrictive())
225            {
226                return Err(AccessError::new(
227                    table,
228                    Some(AccessOperation::Read),
229                    AccessErrorKind::JoinedTableColumnPolicyUnsupported,
230                ));
231            }
232        }
233        Ok(())
234    }
235
236    fn check_auxiliary_read_access(
237        &self,
238        ctx: &AccessContext,
239        cmd: &Qail,
240        cte_names: &BTreeSet<String>,
241    ) -> Result<(), AccessError> {
242        for table_ref in cmd.from_tables.iter().chain(&cmd.using_tables) {
243            let table = normalize_table_ref(table_ref);
244            if table.is_empty() || cte_names.contains(&table) {
245                continue;
246            }
247            self.check_table_operation(ctx, &table, AccessOperation::Read)?;
248            if self
249                .table_policy(&table)
250                .is_some_and(|policy| policy.read_columns.is_restrictive())
251            {
252                return Err(AccessError::new(
253                    table,
254                    Some(AccessOperation::Read),
255                    AccessErrorKind::AuxiliaryTableColumnPolicyUnsupported,
256                ));
257            }
258        }
259        Ok(())
260    }
261
262    fn check_condition_read_columns(&self, table: &str, cmd: &Qail) -> Result<(), AccessError> {
263        let rule = self
264            .table_policy(table)
265            .map(|policy| &policy.read_columns)
266            .unwrap_or(&ColumnRule::Any);
267        if !rule.is_restrictive() {
268            return Ok(());
269        }
270
271        let target_refs = target_refs_for_command(cmd, table);
272        self.check_distinct_on_columns(table, rule, &target_refs, cmd)?;
273        self.check_grouping_set_columns(table, rule, &target_refs, cmd)?;
274        for cage in &cmd.cages {
275            if matches!(cage.kind, CageKind::Payload) {
276                for condition in &cage.conditions {
277                    self.check_value_column_refs(
278                        table,
279                        rule,
280                        &target_refs,
281                        &condition.value,
282                        "write payload value",
283                    )?;
284                }
285                continue;
286            }
287            for condition in &cage.conditions {
288                self.check_condition_column_refs(
289                    table,
290                    rule,
291                    &target_refs,
292                    condition,
293                    "condition",
294                )?;
295            }
296        }
297        for condition in &cmd.having {
298            self.check_condition_column_refs(
299                table,
300                rule,
301                &target_refs,
302                condition,
303                "having condition",
304            )?;
305        }
306        if let Some(on_conflict) = &cmd.on_conflict
307            && let ConflictAction::DoUpdate { assignments } = &on_conflict.action
308        {
309            for (_, expr) in assignments {
310                self.check_expr_column_refs(
311                    table,
312                    rule,
313                    &target_refs,
314                    expr,
315                    "conflict update value",
316                )?;
317            }
318        }
319        for join in &cmd.joins {
320            if let Some(conditions) = &join.on {
321                for condition in conditions {
322                    self.check_condition_column_refs(
323                        table,
324                        rule,
325                        &target_refs,
326                        condition,
327                        "join condition",
328                    )?;
329                }
330            }
331        }
332        if let Some(merge) = &cmd.merge {
333            for condition in &merge.on {
334                self.check_condition_column_refs(
335                    table,
336                    rule,
337                    &target_refs,
338                    condition,
339                    "merge condition",
340                )?;
341            }
342            for clause in &merge.clauses {
343                for condition in &clause.condition {
344                    self.check_condition_column_refs(
345                        table,
346                        rule,
347                        &target_refs,
348                        condition,
349                        "merge condition",
350                    )?;
351                }
352                match &clause.action {
353                    MergeAction::Update { assignments } => {
354                        for (_, expr) in assignments {
355                            self.check_expr_column_refs(
356                                table,
357                                rule,
358                                &target_refs,
359                                expr,
360                                "merge update value",
361                            )?;
362                        }
363                    }
364                    MergeAction::Insert { values, .. } => {
365                        for expr in values {
366                            self.check_expr_column_refs(
367                                table,
368                                rule,
369                                &target_refs,
370                                expr,
371                                "merge insert value",
372                            )?;
373                        }
374                    }
375                    MergeAction::Delete | MergeAction::DoNothing => {}
376                }
377            }
378        }
379        Ok(())
380    }
381
382    fn check_distinct_on_columns(
383        &self,
384        table: &str,
385        rule: &ColumnRule,
386        target_refs: &BTreeSet<String>,
387        cmd: &Qail,
388    ) -> Result<(), AccessError> {
389        for expr in &cmd.distinct_on {
390            if expr_projects_all_columns(expr) {
391                return Err(AccessError::new(
392                    table.to_string(),
393                    Some(AccessOperation::Read),
394                    AccessErrorKind::WildcardProjectionDenied,
395                ));
396            }
397            self.check_expr_column_refs(table, rule, target_refs, expr, "distinct on")?;
398        }
399        Ok(())
400    }
401
402    fn check_grouping_set_columns(
403        &self,
404        table: &str,
405        rule: &ColumnRule,
406        target_refs: &BTreeSet<String>,
407        cmd: &Qail,
408    ) -> Result<(), AccessError> {
409        if let crate::ast::GroupByMode::GroupingSets(sets) = &cmd.group_by_mode {
410            for group in sets {
411                for column in group {
412                    check_named_read_column(table, rule, target_refs, column, "grouping sets")?;
413                }
414            }
415        }
416        Ok(())
417    }
418
419    fn check_condition_column_refs(
420        &self,
421        table: &str,
422        rule: &ColumnRule,
423        target_refs: &BTreeSet<String>,
424        condition: &Condition,
425        context: &'static str,
426    ) -> Result<(), AccessError> {
427        self.check_expr_column_refs(table, rule, target_refs, &condition.left, context)?;
428        self.check_value_column_refs(table, rule, target_refs, &condition.value, context)
429    }
430
431    fn check_expr_column_refs(
432        &self,
433        table: &str,
434        rule: &ColumnRule,
435        target_refs: &BTreeSet<String>,
436        expr: &Expr,
437        context: &'static str,
438    ) -> Result<(), AccessError> {
439        match expr {
440            Expr::Named(name)
441            | Expr::Aliased { name, .. }
442            | Expr::JsonAccess { column: name, .. } => {
443                check_named_read_column(table, rule, target_refs, name, context)
444            }
445            Expr::Aggregate { col, filter, .. } => {
446                if col != "*" {
447                    check_named_read_column(table, rule, target_refs, col, context)?;
448                }
449                if let Some(conditions) = filter {
450                    for condition in conditions {
451                        self.check_condition_column_refs(
452                            table,
453                            rule,
454                            target_refs,
455                            condition,
456                            context,
457                        )?;
458                    }
459                }
460                Ok(())
461            }
462            Expr::Cast { expr, .. }
463            | Expr::Mod { col: expr, .. }
464            | Expr::FieldAccess { expr, .. }
465            | Expr::Collate { expr, .. } => {
466                self.check_expr_column_refs(table, rule, target_refs, expr, context)
467            }
468            Expr::Subscript { expr, index, .. } => {
469                self.check_expr_column_refs(table, rule, target_refs, expr, context)?;
470                self.check_expr_column_refs(table, rule, target_refs, index, context)
471            }
472            Expr::FunctionCall { args, .. } => {
473                for arg in args {
474                    self.check_expr_column_refs(table, rule, target_refs, arg, context)?;
475                }
476                Ok(())
477            }
478            Expr::SpecialFunction { args, .. } => {
479                for (_, arg) in args {
480                    self.check_expr_column_refs(table, rule, target_refs, arg, context)?;
481                }
482                Ok(())
483            }
484            Expr::Binary { left, right, .. } => {
485                self.check_expr_column_refs(table, rule, target_refs, left, context)?;
486                self.check_expr_column_refs(table, rule, target_refs, right, context)
487            }
488            Expr::Literal(value) => {
489                self.check_value_column_refs(table, rule, target_refs, value, context)
490            }
491            Expr::ArrayConstructor { elements, .. } | Expr::RowConstructor { elements, .. } => {
492                for element in elements {
493                    self.check_expr_column_refs(table, rule, target_refs, element, context)?;
494                }
495                Ok(())
496            }
497            Expr::Case {
498                when_clauses,
499                else_value,
500                ..
501            } => {
502                for (condition, value) in when_clauses {
503                    self.check_condition_column_refs(table, rule, target_refs, condition, context)?;
504                    self.check_expr_column_refs(table, rule, target_refs, value, context)?;
505                }
506                if let Some(value) = else_value {
507                    self.check_expr_column_refs(table, rule, target_refs, value, context)?;
508                }
509                Ok(())
510            }
511            Expr::Window {
512                params,
513                partition,
514                order,
515                ..
516            } => {
517                for param in params {
518                    self.check_expr_column_refs(table, rule, target_refs, param, context)?;
519                }
520                for column in partition {
521                    check_named_read_column(table, rule, target_refs, column, context)?;
522                }
523                for cage in order {
524                    for condition in &cage.conditions {
525                        self.check_condition_column_refs(
526                            table,
527                            rule,
528                            target_refs,
529                            condition,
530                            context,
531                        )?;
532                    }
533                }
534                Ok(())
535            }
536            Expr::Subquery { query, .. } | Expr::Exists { query, .. } => {
537                self.check_outer_command_column_refs(table, rule, target_refs, query)
538            }
539            Expr::Star | Expr::Def { .. } => Ok(()),
540        }
541    }
542
543    fn check_value_column_refs(
544        &self,
545        table: &str,
546        rule: &ColumnRule,
547        target_refs: &BTreeSet<String>,
548        value: &Value,
549        context: &'static str,
550    ) -> Result<(), AccessError> {
551        match value {
552            Value::Column(name) => check_named_read_column(table, rule, target_refs, name, context),
553            Value::Expr(expr) => {
554                self.check_expr_column_refs(table, rule, target_refs, expr, context)
555            }
556            Value::Array(values) => {
557                for value in values {
558                    self.check_value_column_refs(table, rule, target_refs, value, context)?;
559                }
560                Ok(())
561            }
562            Value::Function(_) => Err(AccessError::new(
563                table.to_string(),
564                Some(AccessOperation::Read),
565                AccessErrorKind::UnsupportedColumnExpression { context },
566            )),
567            Value::Subquery(query) => {
568                self.check_outer_command_column_refs(table, rule, target_refs, query)
569            }
570            _ => Ok(()),
571        }
572    }
573
574    fn check_outer_command_column_refs(
575        &self,
576        table: &str,
577        rule: &ColumnRule,
578        target_refs: &BTreeSet<String>,
579        cmd: &Qail,
580    ) -> Result<(), AccessError> {
581        for expr in &cmd.columns {
582            self.check_outer_expr_column_refs(table, rule, target_refs, expr)?;
583        }
584        if let Some(returning) = &cmd.returning {
585            for expr in returning {
586                self.check_outer_expr_column_refs(table, rule, target_refs, expr)?;
587            }
588        }
589        for cage in &cmd.cages {
590            for condition in &cage.conditions {
591                self.check_outer_condition_column_refs(table, rule, target_refs, condition)?;
592            }
593        }
594        for condition in &cmd.having {
595            self.check_outer_condition_column_refs(table, rule, target_refs, condition)?;
596        }
597        for join in &cmd.joins {
598            if let Some(conditions) = &join.on {
599                for condition in conditions {
600                    self.check_outer_condition_column_refs(table, rule, target_refs, condition)?;
601                }
602            }
603        }
604        if let Some(on_conflict) = &cmd.on_conflict
605            && let ConflictAction::DoUpdate { assignments } = &on_conflict.action
606        {
607            for (_, expr) in assignments {
608                self.check_outer_expr_column_refs(table, rule, target_refs, expr)?;
609            }
610        }
611        if let Some(merge) = &cmd.merge {
612            if let MergeSource::Query { query, .. } = &merge.source {
613                self.check_outer_command_column_refs(table, rule, target_refs, query)?;
614            }
615            for condition in &merge.on {
616                self.check_outer_condition_column_refs(table, rule, target_refs, condition)?;
617            }
618            for clause in &merge.clauses {
619                for condition in &clause.condition {
620                    self.check_outer_condition_column_refs(table, rule, target_refs, condition)?;
621                }
622                match &clause.action {
623                    MergeAction::Update { assignments } => {
624                        for (_, expr) in assignments {
625                            self.check_outer_expr_column_refs(table, rule, target_refs, expr)?;
626                        }
627                    }
628                    MergeAction::Insert { values, .. } => {
629                        for expr in values {
630                            self.check_outer_expr_column_refs(table, rule, target_refs, expr)?;
631                        }
632                    }
633                    MergeAction::Delete | MergeAction::DoNothing => {}
634                }
635            }
636        }
637        for cte in &cmd.ctes {
638            self.check_outer_command_column_refs(table, rule, target_refs, &cte.base_query)?;
639            if let Some(recursive_query) = &cte.recursive_query {
640                self.check_outer_command_column_refs(table, rule, target_refs, recursive_query)?;
641            }
642        }
643        for (_, set_query) in &cmd.set_ops {
644            self.check_outer_command_column_refs(table, rule, target_refs, set_query)?;
645        }
646        if let Some(source_query) = &cmd.source_query {
647            self.check_outer_command_column_refs(table, rule, target_refs, source_query)?;
648        }
649        Ok(())
650    }
651
652    fn check_outer_condition_column_refs(
653        &self,
654        table: &str,
655        rule: &ColumnRule,
656        target_refs: &BTreeSet<String>,
657        condition: &Condition,
658    ) -> Result<(), AccessError> {
659        self.check_outer_expr_column_refs(table, rule, target_refs, &condition.left)?;
660        self.check_outer_value_column_refs(table, rule, target_refs, &condition.value)
661    }
662
663    fn check_outer_expr_column_refs(
664        &self,
665        table: &str,
666        rule: &ColumnRule,
667        target_refs: &BTreeSet<String>,
668        expr: &Expr,
669    ) -> Result<(), AccessError> {
670        match expr {
671            Expr::Named(name)
672            | Expr::Aliased { name, .. }
673            | Expr::JsonAccess { column: name, .. } => {
674                check_qualified_read_column(table, rule, target_refs, name)
675            }
676            Expr::Aggregate { col, filter, .. } => {
677                if col != "*" {
678                    check_qualified_read_column(table, rule, target_refs, col)?;
679                }
680                if let Some(conditions) = filter {
681                    for condition in conditions {
682                        self.check_outer_condition_column_refs(
683                            table,
684                            rule,
685                            target_refs,
686                            condition,
687                        )?;
688                    }
689                }
690                Ok(())
691            }
692            Expr::Cast { expr, .. }
693            | Expr::Mod { col: expr, .. }
694            | Expr::FieldAccess { expr, .. }
695            | Expr::Collate { expr, .. } => {
696                self.check_outer_expr_column_refs(table, rule, target_refs, expr)
697            }
698            Expr::Subscript { expr, index, .. } => {
699                self.check_outer_expr_column_refs(table, rule, target_refs, expr)?;
700                self.check_outer_expr_column_refs(table, rule, target_refs, index)
701            }
702            Expr::FunctionCall { args, .. } => {
703                for arg in args {
704                    self.check_outer_expr_column_refs(table, rule, target_refs, arg)?;
705                }
706                Ok(())
707            }
708            Expr::SpecialFunction { args, .. } => {
709                for (_, arg) in args {
710                    self.check_outer_expr_column_refs(table, rule, target_refs, arg)?;
711                }
712                Ok(())
713            }
714            Expr::Binary { left, right, .. } => {
715                self.check_outer_expr_column_refs(table, rule, target_refs, left)?;
716                self.check_outer_expr_column_refs(table, rule, target_refs, right)
717            }
718            Expr::Literal(value) => {
719                self.check_outer_value_column_refs(table, rule, target_refs, value)
720            }
721            Expr::ArrayConstructor { elements, .. } | Expr::RowConstructor { elements, .. } => {
722                for element in elements {
723                    self.check_outer_expr_column_refs(table, rule, target_refs, element)?;
724                }
725                Ok(())
726            }
727            Expr::Case {
728                when_clauses,
729                else_value,
730                ..
731            } => {
732                for (condition, value) in when_clauses {
733                    self.check_outer_condition_column_refs(table, rule, target_refs, condition)?;
734                    self.check_outer_expr_column_refs(table, rule, target_refs, value)?;
735                }
736                if let Some(value) = else_value {
737                    self.check_outer_expr_column_refs(table, rule, target_refs, value)?;
738                }
739                Ok(())
740            }
741            Expr::Window {
742                params,
743                partition,
744                order,
745                ..
746            } => {
747                for param in params {
748                    self.check_outer_expr_column_refs(table, rule, target_refs, param)?;
749                }
750                for column in partition {
751                    check_qualified_read_column(table, rule, target_refs, column)?;
752                }
753                for cage in order {
754                    for condition in &cage.conditions {
755                        self.check_outer_condition_column_refs(
756                            table,
757                            rule,
758                            target_refs,
759                            condition,
760                        )?;
761                    }
762                }
763                Ok(())
764            }
765            Expr::Subquery { query, .. } | Expr::Exists { query, .. } => {
766                self.check_outer_command_column_refs(table, rule, target_refs, query)
767            }
768            Expr::Star | Expr::Def { .. } => Ok(()),
769        }
770    }
771
772    fn check_outer_value_column_refs(
773        &self,
774        table: &str,
775        rule: &ColumnRule,
776        target_refs: &BTreeSet<String>,
777        value: &Value,
778    ) -> Result<(), AccessError> {
779        match value {
780            Value::Column(name) => check_qualified_read_column(table, rule, target_refs, name),
781            Value::Expr(expr) => self.check_outer_expr_column_refs(table, rule, target_refs, expr),
782            Value::Array(values) => {
783                for value in values {
784                    self.check_outer_value_column_refs(table, rule, target_refs, value)?;
785                }
786                Ok(())
787            }
788            Value::Subquery(query) => {
789                self.check_outer_command_column_refs(table, rule, target_refs, query)
790            }
791            _ => Ok(()),
792        }
793    }
794
795    fn check_read_columns(
796        &self,
797        table: &str,
798        operation: AccessOperation,
799        columns: &[Expr],
800    ) -> Result<(), AccessError> {
801        let rule = self
802            .table_policy(table)
803            .map(|policy| &policy.read_columns)
804            .unwrap_or(&ColumnRule::Any);
805        check_projection_rule(table, operation, rule, columns, "read projection")
806    }
807
808    fn check_write_columns(
809        &self,
810        table: &str,
811        operation: AccessOperation,
812        columns: &[String],
813    ) -> Result<(), AccessError> {
814        let rule = self
815            .table_policy(table)
816            .map(|policy| &policy.write_columns)
817            .unwrap_or(&ColumnRule::Any);
818        if !rule.is_restrictive() {
819            return Ok(());
820        }
821        if columns.is_empty() {
822            return Err(AccessError::new(
823                table.to_string(),
824                Some(operation),
825                AccessErrorKind::ExplicitWriteColumnsRequired,
826            ));
827        }
828        for column in columns {
829            if !rule.allows(column) {
830                return Err(AccessError::new(
831                    table.to_string(),
832                    Some(operation),
833                    AccessErrorKind::ColumnDenied {
834                        column: normalize_column_name(column),
835                    },
836                ));
837            }
838        }
839        Ok(())
840    }
841
842    fn check_returning_columns(&self, table: &str, columns: &[Expr]) -> Result<(), AccessError> {
843        let read_rule = self
844            .table_policy(table)
845            .map(|policy| &policy.read_columns)
846            .unwrap_or(&ColumnRule::Any);
847        let returning_rule = self
848            .table_policy(table)
849            .map(|policy| &policy.returning_columns)
850            .unwrap_or(&ColumnRule::Any);
851        check_projection_rule(
852            table,
853            AccessOperation::Read,
854            read_rule,
855            columns,
856            "returning projection",
857        )?;
858        check_projection_rule(
859            table,
860            AccessOperation::Read,
861            returning_rule,
862            columns,
863            "returning projection",
864        )
865    }
866
867    fn table_policy(&self, table: &str) -> Option<&TableAccessPolicy> {
868        self.tables.get(table).or_else(|| self.tables.get("*"))
869    }
870
871    fn check_embedded_queries(&self, ctx: &AccessContext, cmd: &Qail) -> Result<(), AccessError> {
872        for expr in &cmd.columns {
873            self.check_expr(ctx, expr)?;
874        }
875        if let Some(returning) = &cmd.returning {
876            for expr in returning {
877                self.check_expr(ctx, expr)?;
878            }
879        }
880        for cage in &cmd.cages {
881            for condition in &cage.conditions {
882                self.check_condition(ctx, condition)?;
883            }
884        }
885        for condition in &cmd.having {
886            self.check_condition(ctx, condition)?;
887        }
888        for join in &cmd.joins {
889            if let Some(conditions) = &join.on {
890                for condition in conditions {
891                    self.check_condition(ctx, condition)?;
892                }
893            }
894        }
895        if let Some(on_conflict) = &cmd.on_conflict
896            && let ConflictAction::DoUpdate { assignments } = &on_conflict.action
897        {
898            for (_, expr) in assignments {
899                self.check_expr(ctx, expr)?;
900            }
901        }
902        if let Some(merge) = &cmd.merge {
903            for condition in &merge.on {
904                self.check_condition(ctx, condition)?;
905            }
906            for clause in &merge.clauses {
907                for condition in &clause.condition {
908                    self.check_condition(ctx, condition)?;
909                }
910                match &clause.action {
911                    MergeAction::Update { assignments } => {
912                        for (_, expr) in assignments {
913                            self.check_expr(ctx, expr)?;
914                        }
915                    }
916                    MergeAction::Insert { values, .. } => {
917                        for expr in values {
918                            self.check_expr(ctx, expr)?;
919                        }
920                    }
921                    MergeAction::Delete | MergeAction::DoNothing => {}
922                }
923            }
924        }
925        Ok(())
926    }
927
928    fn check_condition(
929        &self,
930        ctx: &AccessContext,
931        condition: &Condition,
932    ) -> Result<(), AccessError> {
933        self.check_expr(ctx, &condition.left)?;
934        self.check_value(ctx, &condition.value)
935    }
936
937    fn check_expr(&self, ctx: &AccessContext, expr: &Expr) -> Result<(), AccessError> {
938        match expr {
939            Expr::Cast { expr, .. }
940            | Expr::Mod { col: expr, .. }
941            | Expr::FieldAccess { expr, .. }
942            | Expr::Collate { expr, .. } => self.check_expr(ctx, expr),
943            Expr::Subscript { expr, index, .. } => {
944                self.check_expr(ctx, expr)?;
945                self.check_expr(ctx, index)
946            }
947            Expr::FunctionCall { args, .. } => {
948                for arg in args {
949                    self.check_expr(ctx, arg)?;
950                }
951                Ok(())
952            }
953            Expr::SpecialFunction { args, .. } => {
954                for (_, arg) in args {
955                    self.check_expr(ctx, arg)?;
956                }
957                Ok(())
958            }
959            Expr::Binary { left, right, .. } => {
960                self.check_expr(ctx, left)?;
961                self.check_expr(ctx, right)
962            }
963            Expr::Literal(value) => self.check_value(ctx, value),
964            Expr::ArrayConstructor { elements, .. } | Expr::RowConstructor { elements, .. } => {
965                for element in elements {
966                    self.check_expr(ctx, element)?;
967                }
968                Ok(())
969            }
970            Expr::Case {
971                when_clauses,
972                else_value,
973                ..
974            } => {
975                for (condition, value) in when_clauses {
976                    self.check_condition(ctx, condition)?;
977                    self.check_expr(ctx, value)?;
978                }
979                if let Some(value) = else_value {
980                    self.check_expr(ctx, value)?;
981                }
982                Ok(())
983            }
984            Expr::Window { params, order, .. } => {
985                for param in params {
986                    self.check_expr(ctx, param)?;
987                }
988                for cage in order {
989                    for condition in &cage.conditions {
990                        self.check_condition(ctx, condition)?;
991                    }
992                }
993                Ok(())
994            }
995            Expr::Aggregate { filter, .. } => {
996                if let Some(conditions) = filter {
997                    for condition in conditions {
998                        self.check_condition(ctx, condition)?;
999                    }
1000                }
1001                Ok(())
1002            }
1003            Expr::Subquery { query, .. } | Expr::Exists { query, .. } => {
1004                self.check_command_inner(ctx, query)
1005            }
1006            Expr::Star
1007            | Expr::Named(_)
1008            | Expr::Aliased { .. }
1009            | Expr::Def { .. }
1010            | Expr::JsonAccess { .. } => Ok(()),
1011        }
1012    }
1013
1014    fn check_value(&self, ctx: &AccessContext, value: &Value) -> Result<(), AccessError> {
1015        match value {
1016            Value::Subquery(query) => self.check_command_inner(ctx, query),
1017            Value::Expr(expr) => self.check_expr(ctx, expr),
1018            Value::Array(values) => {
1019                for value in values {
1020                    self.check_value(ctx, value)?;
1021                }
1022                Ok(())
1023            }
1024            _ => Ok(()),
1025        }
1026    }
1027}
1028
1029#[cfg(test)]
1030mod tests;