1use std::collections::{BTreeMap, BTreeSet};
8
9use crate::ast::{
10 CageKind, Condition, ConflictAction, Expr, MergeAction, MergeSource, Qail, Value,
11};
12
13mod columns;
14mod config;
15mod error;
16mod ident;
17mod model;
18mod operations;
19
20use columns::{
21 check_named_read_column, check_projection_rule, check_qualified_read_column, create_columns,
22 expr_projects_all_columns, projection_restricted_action, update_columns,
23};
24use ident::{normalize_column_name, normalize_table_ref, target_refs_for_command};
25
26pub use error::{AccessError, AccessErrorKind, AccessPolicyLoadError};
27pub use model::{AccessContext, AccessDecision, AccessOperation, ColumnRule, TableAccessPolicy};
28pub use operations::required_operations_for_command;
29
30#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
32pub struct AccessPolicy {
33 pub default_decision: AccessDecision,
35 #[serde(default)]
37 pub tables: BTreeMap<String, TableAccessPolicy>,
38}
39
40impl AccessPolicy {
41 pub fn check_command(&self, ctx: &AccessContext, cmd: &Qail) -> Result<(), AccessError> {
43 self.check_command_inner(ctx, cmd)
44 }
45
46 fn check_command_inner(&self, ctx: &AccessContext, cmd: &Qail) -> Result<(), AccessError> {
47 if ctx.bypasses_access() {
48 return Ok(());
49 }
50
51 for cte in &cmd.ctes {
52 self.check_command_inner(ctx, &cte.base_query)?;
53 if let Some(recursive_query) = &cte.recursive_query {
54 self.check_command_inner(ctx, recursive_query)?;
55 }
56 }
57 for (_, set_query) in &cmd.set_ops {
58 self.check_command_inner(ctx, set_query)?;
59 }
60 if let Some(source_query) = &cmd.source_query {
61 self.check_command_inner(ctx, source_query)?;
62 }
63 if let Some(merge) = &cmd.merge {
64 match &merge.source {
65 MergeSource::Query { query, .. } => self.check_command_inner(ctx, query)?,
66 MergeSource::Table { name, .. } => self.check_merge_table_source(ctx, name)?,
67 }
68 }
69
70 let table = normalize_table_ref(&cmd.table);
71 if table.is_empty() {
72 return Err(AccessError::new(
73 String::new(),
74 None,
75 AccessErrorKind::EmptyTable,
76 ));
77 }
78
79 self.check_embedded_queries(ctx, cmd)?;
80 self.check_condition_read_columns(&table, cmd)?;
81
82 let cte_names: BTreeSet<String> = cmd
83 .ctes
84 .iter()
85 .map(|cte| normalize_table_ref(&cte.name))
86 .collect();
87 self.check_join_read_access(ctx, cmd, &cte_names)?;
88 self.check_auxiliary_read_access(ctx, cmd, &cte_names)?;
89
90 let required_ops = required_operations_for_command(cmd).ok_or_else(|| {
91 AccessError::new(
92 table.clone(),
93 None,
94 AccessErrorKind::UnsupportedAction(cmd.action),
95 )
96 })?;
97 if cte_names.contains(&table) {
98 if required_ops.iter().all(|op| *op == AccessOperation::Read) {
99 return Ok(());
100 }
101 return Err(AccessError::new(
102 table,
103 None,
104 AccessErrorKind::CteMutationUnsupported,
105 ));
106 }
107
108 for operation in &required_ops {
109 self.check_table_operation(ctx, &table, *operation)?;
110 }
111
112 if required_ops.contains(&AccessOperation::Read) && projection_restricted_action(cmd.action)
113 {
114 self.check_read_columns(&table, AccessOperation::Read, &cmd.columns)?;
115 }
116
117 if required_ops.contains(&AccessOperation::Create) {
118 let columns = create_columns(cmd)?;
119 self.check_write_columns(&table, AccessOperation::Create, &columns)?;
120 }
121
122 if required_ops.contains(&AccessOperation::Update) {
123 let columns = update_columns(cmd)?;
124 self.check_write_columns(&table, AccessOperation::Update, &columns)?;
125 }
126
127 if let Some(returning) = &cmd.returning {
128 self.check_returning_columns(&table, returning)?;
129 }
130
131 Ok(())
132 }
133
134 fn check_merge_table_source(
135 &self,
136 ctx: &AccessContext,
137 source_table: &str,
138 ) -> Result<(), AccessError> {
139 let table = normalize_table_ref(source_table);
140 if table.is_empty() {
141 return Err(AccessError::new(
142 String::new(),
143 Some(AccessOperation::Read),
144 AccessErrorKind::EmptyTable,
145 ));
146 }
147
148 self.check_table_operation(ctx, &table, AccessOperation::Read)?;
149 if self
150 .table_policy(&table)
151 .is_some_and(|policy| policy.read_columns.is_restrictive())
152 {
153 return Err(AccessError::new(
154 table,
155 Some(AccessOperation::Read),
156 AccessErrorKind::SourceTableColumnPolicyUnsupported,
157 ));
158 }
159 Ok(())
160 }
161
162 fn check_table_operation(
163 &self,
164 ctx: &AccessContext,
165 table: &str,
166 operation: AccessOperation,
167 ) -> Result<(), AccessError> {
168 let Some(policy) = self.table_policy(table) else {
169 return match self.default_decision {
170 AccessDecision::Allow => Ok(()),
171 AccessDecision::Deny => Err(AccessError::new(
172 table.to_string(),
173 Some(operation),
174 AccessErrorKind::NoPolicy,
175 )),
176 };
177 };
178
179 if !ctx.has_any_role(&policy.require_any_role) {
180 return Err(AccessError::new(
181 table.to_string(),
182 Some(operation),
183 AccessErrorKind::MissingRole {
184 required: policy.require_any_role.clone(),
185 },
186 ));
187 }
188
189 if !ctx.has_all_scopes(&policy.require_scopes) {
190 return Err(AccessError::new(
191 table.to_string(),
192 Some(operation),
193 AccessErrorKind::MissingScope {
194 required: policy.require_scopes.clone(),
195 },
196 ));
197 }
198
199 if !policy.allows_operation(operation) {
200 return Err(AccessError::new(
201 table.to_string(),
202 Some(operation),
203 AccessErrorKind::OperationDenied,
204 ));
205 }
206
207 Ok(())
208 }
209
210 fn check_join_read_access(
211 &self,
212 ctx: &AccessContext,
213 cmd: &Qail,
214 cte_names: &BTreeSet<String>,
215 ) -> Result<(), AccessError> {
216 for join in &cmd.joins {
217 let table = normalize_table_ref(&join.table);
218 if table.is_empty() || cte_names.contains(&table) {
219 continue;
220 }
221 self.check_table_operation(ctx, &table, AccessOperation::Read)?;
222 if self
223 .table_policy(&table)
224 .is_some_and(|policy| policy.read_columns.is_restrictive())
225 {
226 return Err(AccessError::new(
227 table,
228 Some(AccessOperation::Read),
229 AccessErrorKind::JoinedTableColumnPolicyUnsupported,
230 ));
231 }
232 }
233 Ok(())
234 }
235
236 fn check_auxiliary_read_access(
237 &self,
238 ctx: &AccessContext,
239 cmd: &Qail,
240 cte_names: &BTreeSet<String>,
241 ) -> Result<(), AccessError> {
242 for table_ref in cmd.from_tables.iter().chain(&cmd.using_tables) {
243 let table = normalize_table_ref(table_ref);
244 if table.is_empty() || cte_names.contains(&table) {
245 continue;
246 }
247 self.check_table_operation(ctx, &table, AccessOperation::Read)?;
248 if self
249 .table_policy(&table)
250 .is_some_and(|policy| policy.read_columns.is_restrictive())
251 {
252 return Err(AccessError::new(
253 table,
254 Some(AccessOperation::Read),
255 AccessErrorKind::AuxiliaryTableColumnPolicyUnsupported,
256 ));
257 }
258 }
259 Ok(())
260 }
261
262 fn check_condition_read_columns(&self, table: &str, cmd: &Qail) -> Result<(), AccessError> {
263 let rule = self
264 .table_policy(table)
265 .map(|policy| &policy.read_columns)
266 .unwrap_or(&ColumnRule::Any);
267 if !rule.is_restrictive() {
268 return Ok(());
269 }
270
271 let target_refs = target_refs_for_command(cmd, table);
272 self.check_distinct_on_columns(table, rule, &target_refs, cmd)?;
273 self.check_grouping_set_columns(table, rule, &target_refs, cmd)?;
274 for cage in &cmd.cages {
275 if matches!(cage.kind, CageKind::Payload) {
276 for condition in &cage.conditions {
277 self.check_value_column_refs(
278 table,
279 rule,
280 &target_refs,
281 &condition.value,
282 "write payload value",
283 )?;
284 }
285 continue;
286 }
287 for condition in &cage.conditions {
288 self.check_condition_column_refs(
289 table,
290 rule,
291 &target_refs,
292 condition,
293 "condition",
294 )?;
295 }
296 }
297 for condition in &cmd.having {
298 self.check_condition_column_refs(
299 table,
300 rule,
301 &target_refs,
302 condition,
303 "having condition",
304 )?;
305 }
306 if let Some(on_conflict) = &cmd.on_conflict
307 && let ConflictAction::DoUpdate { assignments } = &on_conflict.action
308 {
309 for (_, expr) in assignments {
310 self.check_expr_column_refs(
311 table,
312 rule,
313 &target_refs,
314 expr,
315 "conflict update value",
316 )?;
317 }
318 }
319 for join in &cmd.joins {
320 if let Some(conditions) = &join.on {
321 for condition in conditions {
322 self.check_condition_column_refs(
323 table,
324 rule,
325 &target_refs,
326 condition,
327 "join condition",
328 )?;
329 }
330 }
331 }
332 if let Some(merge) = &cmd.merge {
333 for condition in &merge.on {
334 self.check_condition_column_refs(
335 table,
336 rule,
337 &target_refs,
338 condition,
339 "merge condition",
340 )?;
341 }
342 for clause in &merge.clauses {
343 for condition in &clause.condition {
344 self.check_condition_column_refs(
345 table,
346 rule,
347 &target_refs,
348 condition,
349 "merge condition",
350 )?;
351 }
352 match &clause.action {
353 MergeAction::Update { assignments } => {
354 for (_, expr) in assignments {
355 self.check_expr_column_refs(
356 table,
357 rule,
358 &target_refs,
359 expr,
360 "merge update value",
361 )?;
362 }
363 }
364 MergeAction::Insert { values, .. } => {
365 for expr in values {
366 self.check_expr_column_refs(
367 table,
368 rule,
369 &target_refs,
370 expr,
371 "merge insert value",
372 )?;
373 }
374 }
375 MergeAction::Delete | MergeAction::DoNothing => {}
376 }
377 }
378 }
379 Ok(())
380 }
381
382 fn check_distinct_on_columns(
383 &self,
384 table: &str,
385 rule: &ColumnRule,
386 target_refs: &BTreeSet<String>,
387 cmd: &Qail,
388 ) -> Result<(), AccessError> {
389 for expr in &cmd.distinct_on {
390 if expr_projects_all_columns(expr) {
391 return Err(AccessError::new(
392 table.to_string(),
393 Some(AccessOperation::Read),
394 AccessErrorKind::WildcardProjectionDenied,
395 ));
396 }
397 self.check_expr_column_refs(table, rule, target_refs, expr, "distinct on")?;
398 }
399 Ok(())
400 }
401
402 fn check_grouping_set_columns(
403 &self,
404 table: &str,
405 rule: &ColumnRule,
406 target_refs: &BTreeSet<String>,
407 cmd: &Qail,
408 ) -> Result<(), AccessError> {
409 if let crate::ast::GroupByMode::GroupingSets(sets) = &cmd.group_by_mode {
410 for group in sets {
411 for column in group {
412 check_named_read_column(table, rule, target_refs, column, "grouping sets")?;
413 }
414 }
415 }
416 Ok(())
417 }
418
419 fn check_condition_column_refs(
420 &self,
421 table: &str,
422 rule: &ColumnRule,
423 target_refs: &BTreeSet<String>,
424 condition: &Condition,
425 context: &'static str,
426 ) -> Result<(), AccessError> {
427 self.check_expr_column_refs(table, rule, target_refs, &condition.left, context)?;
428 self.check_value_column_refs(table, rule, target_refs, &condition.value, context)
429 }
430
431 fn check_expr_column_refs(
432 &self,
433 table: &str,
434 rule: &ColumnRule,
435 target_refs: &BTreeSet<String>,
436 expr: &Expr,
437 context: &'static str,
438 ) -> Result<(), AccessError> {
439 match expr {
440 Expr::Named(name)
441 | Expr::Aliased { name, .. }
442 | Expr::JsonAccess { column: name, .. } => {
443 check_named_read_column(table, rule, target_refs, name, context)
444 }
445 Expr::Aggregate { col, filter, .. } => {
446 if col != "*" {
447 check_named_read_column(table, rule, target_refs, col, context)?;
448 }
449 if let Some(conditions) = filter {
450 for condition in conditions {
451 self.check_condition_column_refs(
452 table,
453 rule,
454 target_refs,
455 condition,
456 context,
457 )?;
458 }
459 }
460 Ok(())
461 }
462 Expr::Cast { expr, .. }
463 | Expr::Mod { col: expr, .. }
464 | Expr::FieldAccess { expr, .. }
465 | Expr::Collate { expr, .. } => {
466 self.check_expr_column_refs(table, rule, target_refs, expr, context)
467 }
468 Expr::Subscript { expr, index, .. } => {
469 self.check_expr_column_refs(table, rule, target_refs, expr, context)?;
470 self.check_expr_column_refs(table, rule, target_refs, index, context)
471 }
472 Expr::FunctionCall { args, .. } => {
473 for arg in args {
474 self.check_expr_column_refs(table, rule, target_refs, arg, context)?;
475 }
476 Ok(())
477 }
478 Expr::SpecialFunction { args, .. } => {
479 for (_, arg) in args {
480 self.check_expr_column_refs(table, rule, target_refs, arg, context)?;
481 }
482 Ok(())
483 }
484 Expr::Binary { left, right, .. } => {
485 self.check_expr_column_refs(table, rule, target_refs, left, context)?;
486 self.check_expr_column_refs(table, rule, target_refs, right, context)
487 }
488 Expr::Literal(value) => {
489 self.check_value_column_refs(table, rule, target_refs, value, context)
490 }
491 Expr::ArrayConstructor { elements, .. } | Expr::RowConstructor { elements, .. } => {
492 for element in elements {
493 self.check_expr_column_refs(table, rule, target_refs, element, context)?;
494 }
495 Ok(())
496 }
497 Expr::Case {
498 when_clauses,
499 else_value,
500 ..
501 } => {
502 for (condition, value) in when_clauses {
503 self.check_condition_column_refs(table, rule, target_refs, condition, context)?;
504 self.check_expr_column_refs(table, rule, target_refs, value, context)?;
505 }
506 if let Some(value) = else_value {
507 self.check_expr_column_refs(table, rule, target_refs, value, context)?;
508 }
509 Ok(())
510 }
511 Expr::Window {
512 params,
513 partition,
514 order,
515 ..
516 } => {
517 for param in params {
518 self.check_expr_column_refs(table, rule, target_refs, param, context)?;
519 }
520 for column in partition {
521 check_named_read_column(table, rule, target_refs, column, context)?;
522 }
523 for cage in order {
524 for condition in &cage.conditions {
525 self.check_condition_column_refs(
526 table,
527 rule,
528 target_refs,
529 condition,
530 context,
531 )?;
532 }
533 }
534 Ok(())
535 }
536 Expr::Subquery { query, .. } | Expr::Exists { query, .. } => {
537 self.check_outer_command_column_refs(table, rule, target_refs, query)
538 }
539 Expr::Star | Expr::Def { .. } => Ok(()),
540 }
541 }
542
543 fn check_value_column_refs(
544 &self,
545 table: &str,
546 rule: &ColumnRule,
547 target_refs: &BTreeSet<String>,
548 value: &Value,
549 context: &'static str,
550 ) -> Result<(), AccessError> {
551 match value {
552 Value::Column(name) => check_named_read_column(table, rule, target_refs, name, context),
553 Value::Expr(expr) => {
554 self.check_expr_column_refs(table, rule, target_refs, expr, context)
555 }
556 Value::Array(values) => {
557 for value in values {
558 self.check_value_column_refs(table, rule, target_refs, value, context)?;
559 }
560 Ok(())
561 }
562 Value::Function(_) => Err(AccessError::new(
563 table.to_string(),
564 Some(AccessOperation::Read),
565 AccessErrorKind::UnsupportedColumnExpression { context },
566 )),
567 Value::Subquery(query) => {
568 self.check_outer_command_column_refs(table, rule, target_refs, query)
569 }
570 _ => Ok(()),
571 }
572 }
573
574 fn check_outer_command_column_refs(
575 &self,
576 table: &str,
577 rule: &ColumnRule,
578 target_refs: &BTreeSet<String>,
579 cmd: &Qail,
580 ) -> Result<(), AccessError> {
581 for expr in &cmd.columns {
582 self.check_outer_expr_column_refs(table, rule, target_refs, expr)?;
583 }
584 if let Some(returning) = &cmd.returning {
585 for expr in returning {
586 self.check_outer_expr_column_refs(table, rule, target_refs, expr)?;
587 }
588 }
589 for cage in &cmd.cages {
590 for condition in &cage.conditions {
591 self.check_outer_condition_column_refs(table, rule, target_refs, condition)?;
592 }
593 }
594 for condition in &cmd.having {
595 self.check_outer_condition_column_refs(table, rule, target_refs, condition)?;
596 }
597 for join in &cmd.joins {
598 if let Some(conditions) = &join.on {
599 for condition in conditions {
600 self.check_outer_condition_column_refs(table, rule, target_refs, condition)?;
601 }
602 }
603 }
604 if let Some(on_conflict) = &cmd.on_conflict
605 && let ConflictAction::DoUpdate { assignments } = &on_conflict.action
606 {
607 for (_, expr) in assignments {
608 self.check_outer_expr_column_refs(table, rule, target_refs, expr)?;
609 }
610 }
611 if let Some(merge) = &cmd.merge {
612 if let MergeSource::Query { query, .. } = &merge.source {
613 self.check_outer_command_column_refs(table, rule, target_refs, query)?;
614 }
615 for condition in &merge.on {
616 self.check_outer_condition_column_refs(table, rule, target_refs, condition)?;
617 }
618 for clause in &merge.clauses {
619 for condition in &clause.condition {
620 self.check_outer_condition_column_refs(table, rule, target_refs, condition)?;
621 }
622 match &clause.action {
623 MergeAction::Update { assignments } => {
624 for (_, expr) in assignments {
625 self.check_outer_expr_column_refs(table, rule, target_refs, expr)?;
626 }
627 }
628 MergeAction::Insert { values, .. } => {
629 for expr in values {
630 self.check_outer_expr_column_refs(table, rule, target_refs, expr)?;
631 }
632 }
633 MergeAction::Delete | MergeAction::DoNothing => {}
634 }
635 }
636 }
637 for cte in &cmd.ctes {
638 self.check_outer_command_column_refs(table, rule, target_refs, &cte.base_query)?;
639 if let Some(recursive_query) = &cte.recursive_query {
640 self.check_outer_command_column_refs(table, rule, target_refs, recursive_query)?;
641 }
642 }
643 for (_, set_query) in &cmd.set_ops {
644 self.check_outer_command_column_refs(table, rule, target_refs, set_query)?;
645 }
646 if let Some(source_query) = &cmd.source_query {
647 self.check_outer_command_column_refs(table, rule, target_refs, source_query)?;
648 }
649 Ok(())
650 }
651
652 fn check_outer_condition_column_refs(
653 &self,
654 table: &str,
655 rule: &ColumnRule,
656 target_refs: &BTreeSet<String>,
657 condition: &Condition,
658 ) -> Result<(), AccessError> {
659 self.check_outer_expr_column_refs(table, rule, target_refs, &condition.left)?;
660 self.check_outer_value_column_refs(table, rule, target_refs, &condition.value)
661 }
662
663 fn check_outer_expr_column_refs(
664 &self,
665 table: &str,
666 rule: &ColumnRule,
667 target_refs: &BTreeSet<String>,
668 expr: &Expr,
669 ) -> Result<(), AccessError> {
670 match expr {
671 Expr::Named(name)
672 | Expr::Aliased { name, .. }
673 | Expr::JsonAccess { column: name, .. } => {
674 check_qualified_read_column(table, rule, target_refs, name)
675 }
676 Expr::Aggregate { col, filter, .. } => {
677 if col != "*" {
678 check_qualified_read_column(table, rule, target_refs, col)?;
679 }
680 if let Some(conditions) = filter {
681 for condition in conditions {
682 self.check_outer_condition_column_refs(
683 table,
684 rule,
685 target_refs,
686 condition,
687 )?;
688 }
689 }
690 Ok(())
691 }
692 Expr::Cast { expr, .. }
693 | Expr::Mod { col: expr, .. }
694 | Expr::FieldAccess { expr, .. }
695 | Expr::Collate { expr, .. } => {
696 self.check_outer_expr_column_refs(table, rule, target_refs, expr)
697 }
698 Expr::Subscript { expr, index, .. } => {
699 self.check_outer_expr_column_refs(table, rule, target_refs, expr)?;
700 self.check_outer_expr_column_refs(table, rule, target_refs, index)
701 }
702 Expr::FunctionCall { args, .. } => {
703 for arg in args {
704 self.check_outer_expr_column_refs(table, rule, target_refs, arg)?;
705 }
706 Ok(())
707 }
708 Expr::SpecialFunction { args, .. } => {
709 for (_, arg) in args {
710 self.check_outer_expr_column_refs(table, rule, target_refs, arg)?;
711 }
712 Ok(())
713 }
714 Expr::Binary { left, right, .. } => {
715 self.check_outer_expr_column_refs(table, rule, target_refs, left)?;
716 self.check_outer_expr_column_refs(table, rule, target_refs, right)
717 }
718 Expr::Literal(value) => {
719 self.check_outer_value_column_refs(table, rule, target_refs, value)
720 }
721 Expr::ArrayConstructor { elements, .. } | Expr::RowConstructor { elements, .. } => {
722 for element in elements {
723 self.check_outer_expr_column_refs(table, rule, target_refs, element)?;
724 }
725 Ok(())
726 }
727 Expr::Case {
728 when_clauses,
729 else_value,
730 ..
731 } => {
732 for (condition, value) in when_clauses {
733 self.check_outer_condition_column_refs(table, rule, target_refs, condition)?;
734 self.check_outer_expr_column_refs(table, rule, target_refs, value)?;
735 }
736 if let Some(value) = else_value {
737 self.check_outer_expr_column_refs(table, rule, target_refs, value)?;
738 }
739 Ok(())
740 }
741 Expr::Window {
742 params,
743 partition,
744 order,
745 ..
746 } => {
747 for param in params {
748 self.check_outer_expr_column_refs(table, rule, target_refs, param)?;
749 }
750 for column in partition {
751 check_qualified_read_column(table, rule, target_refs, column)?;
752 }
753 for cage in order {
754 for condition in &cage.conditions {
755 self.check_outer_condition_column_refs(
756 table,
757 rule,
758 target_refs,
759 condition,
760 )?;
761 }
762 }
763 Ok(())
764 }
765 Expr::Subquery { query, .. } | Expr::Exists { query, .. } => {
766 self.check_outer_command_column_refs(table, rule, target_refs, query)
767 }
768 Expr::Star | Expr::Def { .. } => Ok(()),
769 }
770 }
771
772 fn check_outer_value_column_refs(
773 &self,
774 table: &str,
775 rule: &ColumnRule,
776 target_refs: &BTreeSet<String>,
777 value: &Value,
778 ) -> Result<(), AccessError> {
779 match value {
780 Value::Column(name) => check_qualified_read_column(table, rule, target_refs, name),
781 Value::Expr(expr) => self.check_outer_expr_column_refs(table, rule, target_refs, expr),
782 Value::Array(values) => {
783 for value in values {
784 self.check_outer_value_column_refs(table, rule, target_refs, value)?;
785 }
786 Ok(())
787 }
788 Value::Subquery(query) => {
789 self.check_outer_command_column_refs(table, rule, target_refs, query)
790 }
791 _ => Ok(()),
792 }
793 }
794
795 fn check_read_columns(
796 &self,
797 table: &str,
798 operation: AccessOperation,
799 columns: &[Expr],
800 ) -> Result<(), AccessError> {
801 let rule = self
802 .table_policy(table)
803 .map(|policy| &policy.read_columns)
804 .unwrap_or(&ColumnRule::Any);
805 check_projection_rule(table, operation, rule, columns, "read projection")
806 }
807
808 fn check_write_columns(
809 &self,
810 table: &str,
811 operation: AccessOperation,
812 columns: &[String],
813 ) -> Result<(), AccessError> {
814 let rule = self
815 .table_policy(table)
816 .map(|policy| &policy.write_columns)
817 .unwrap_or(&ColumnRule::Any);
818 if !rule.is_restrictive() {
819 return Ok(());
820 }
821 if columns.is_empty() {
822 return Err(AccessError::new(
823 table.to_string(),
824 Some(operation),
825 AccessErrorKind::ExplicitWriteColumnsRequired,
826 ));
827 }
828 for column in columns {
829 if !rule.allows(column) {
830 return Err(AccessError::new(
831 table.to_string(),
832 Some(operation),
833 AccessErrorKind::ColumnDenied {
834 column: normalize_column_name(column),
835 },
836 ));
837 }
838 }
839 Ok(())
840 }
841
842 fn check_returning_columns(&self, table: &str, columns: &[Expr]) -> Result<(), AccessError> {
843 let read_rule = self
844 .table_policy(table)
845 .map(|policy| &policy.read_columns)
846 .unwrap_or(&ColumnRule::Any);
847 let returning_rule = self
848 .table_policy(table)
849 .map(|policy| &policy.returning_columns)
850 .unwrap_or(&ColumnRule::Any);
851 check_projection_rule(
852 table,
853 AccessOperation::Read,
854 read_rule,
855 columns,
856 "returning projection",
857 )?;
858 check_projection_rule(
859 table,
860 AccessOperation::Read,
861 returning_rule,
862 columns,
863 "returning projection",
864 )
865 }
866
867 fn table_policy(&self, table: &str) -> Option<&TableAccessPolicy> {
868 self.tables.get(table).or_else(|| self.tables.get("*"))
869 }
870
871 fn check_embedded_queries(&self, ctx: &AccessContext, cmd: &Qail) -> Result<(), AccessError> {
872 for expr in &cmd.columns {
873 self.check_expr(ctx, expr)?;
874 }
875 if let Some(returning) = &cmd.returning {
876 for expr in returning {
877 self.check_expr(ctx, expr)?;
878 }
879 }
880 for cage in &cmd.cages {
881 for condition in &cage.conditions {
882 self.check_condition(ctx, condition)?;
883 }
884 }
885 for condition in &cmd.having {
886 self.check_condition(ctx, condition)?;
887 }
888 for join in &cmd.joins {
889 if let Some(conditions) = &join.on {
890 for condition in conditions {
891 self.check_condition(ctx, condition)?;
892 }
893 }
894 }
895 if let Some(on_conflict) = &cmd.on_conflict
896 && let ConflictAction::DoUpdate { assignments } = &on_conflict.action
897 {
898 for (_, expr) in assignments {
899 self.check_expr(ctx, expr)?;
900 }
901 }
902 if let Some(merge) = &cmd.merge {
903 for condition in &merge.on {
904 self.check_condition(ctx, condition)?;
905 }
906 for clause in &merge.clauses {
907 for condition in &clause.condition {
908 self.check_condition(ctx, condition)?;
909 }
910 match &clause.action {
911 MergeAction::Update { assignments } => {
912 for (_, expr) in assignments {
913 self.check_expr(ctx, expr)?;
914 }
915 }
916 MergeAction::Insert { values, .. } => {
917 for expr in values {
918 self.check_expr(ctx, expr)?;
919 }
920 }
921 MergeAction::Delete | MergeAction::DoNothing => {}
922 }
923 }
924 }
925 Ok(())
926 }
927
928 fn check_condition(
929 &self,
930 ctx: &AccessContext,
931 condition: &Condition,
932 ) -> Result<(), AccessError> {
933 self.check_expr(ctx, &condition.left)?;
934 self.check_value(ctx, &condition.value)
935 }
936
937 fn check_expr(&self, ctx: &AccessContext, expr: &Expr) -> Result<(), AccessError> {
938 match expr {
939 Expr::Cast { expr, .. }
940 | Expr::Mod { col: expr, .. }
941 | Expr::FieldAccess { expr, .. }
942 | Expr::Collate { expr, .. } => self.check_expr(ctx, expr),
943 Expr::Subscript { expr, index, .. } => {
944 self.check_expr(ctx, expr)?;
945 self.check_expr(ctx, index)
946 }
947 Expr::FunctionCall { args, .. } => {
948 for arg in args {
949 self.check_expr(ctx, arg)?;
950 }
951 Ok(())
952 }
953 Expr::SpecialFunction { args, .. } => {
954 for (_, arg) in args {
955 self.check_expr(ctx, arg)?;
956 }
957 Ok(())
958 }
959 Expr::Binary { left, right, .. } => {
960 self.check_expr(ctx, left)?;
961 self.check_expr(ctx, right)
962 }
963 Expr::Literal(value) => self.check_value(ctx, value),
964 Expr::ArrayConstructor { elements, .. } | Expr::RowConstructor { elements, .. } => {
965 for element in elements {
966 self.check_expr(ctx, element)?;
967 }
968 Ok(())
969 }
970 Expr::Case {
971 when_clauses,
972 else_value,
973 ..
974 } => {
975 for (condition, value) in when_clauses {
976 self.check_condition(ctx, condition)?;
977 self.check_expr(ctx, value)?;
978 }
979 if let Some(value) = else_value {
980 self.check_expr(ctx, value)?;
981 }
982 Ok(())
983 }
984 Expr::Window { params, order, .. } => {
985 for param in params {
986 self.check_expr(ctx, param)?;
987 }
988 for cage in order {
989 for condition in &cage.conditions {
990 self.check_condition(ctx, condition)?;
991 }
992 }
993 Ok(())
994 }
995 Expr::Aggregate { filter, .. } => {
996 if let Some(conditions) = filter {
997 for condition in conditions {
998 self.check_condition(ctx, condition)?;
999 }
1000 }
1001 Ok(())
1002 }
1003 Expr::Subquery { query, .. } | Expr::Exists { query, .. } => {
1004 self.check_command_inner(ctx, query)
1005 }
1006 Expr::Star
1007 | Expr::Named(_)
1008 | Expr::Aliased { .. }
1009 | Expr::Def { .. }
1010 | Expr::JsonAccess { .. } => Ok(()),
1011 }
1012 }
1013
1014 fn check_value(&self, ctx: &AccessContext, value: &Value) -> Result<(), AccessError> {
1015 match value {
1016 Value::Subquery(query) => self.check_command_inner(ctx, query),
1017 Value::Expr(expr) => self.check_expr(ctx, expr),
1018 Value::Array(values) => {
1019 for value in values {
1020 self.check_value(ctx, value)?;
1021 }
1022 Ok(())
1023 }
1024 _ => Ok(()),
1025 }
1026 }
1027}
1028
1029#[cfg(test)]
1030mod tests;