1use qail_core::access::{AccessContext, AccessError, AccessPolicy};
4use qail_core::ast::Qail;
5
6use super::{
7 AstPipelineMode, AutoCountPlan, PgDriver, PgError, PgPool, PgResult, PgRow, PooledConnection,
8 PreparedAstQuery, QueryResult, ResultFormat,
9};
10
11fn access_denied_error(err: AccessError) -> PgError {
12 PgError::Query(format!("Access denied by policy: {}", err))
13}
14
15fn check_access(policy: &AccessPolicy, ctx: &AccessContext, cmd: &Qail) -> PgResult<()> {
16 policy.check_command(ctx, cmd).map_err(access_denied_error)
17}
18
19fn check_all_access(policy: &AccessPolicy, ctx: &AccessContext, cmds: &[Qail]) -> PgResult<()> {
20 for cmd in cmds {
21 check_access(policy, ctx, cmd)?;
22 }
23 Ok(())
24}
25
26fn copy_export_table_command(table: &str, columns: &[String]) -> Qail {
27 Qail::export(table).columns(columns.iter().map(String::as_str))
28}
29
30impl PgDriver {
31 pub fn check_access(
33 &self,
34 cmd: &Qail,
35 access_ctx: &AccessContext,
36 access_policy: &AccessPolicy,
37 ) -> PgResult<()> {
38 check_access(access_policy, access_ctx, cmd)
39 }
40
41 pub async fn fetch_all_checked(
43 &mut self,
44 cmd: &Qail,
45 access_ctx: &AccessContext,
46 access_policy: &AccessPolicy,
47 ) -> PgResult<Vec<PgRow>> {
48 check_access(access_policy, access_ctx, cmd)?;
49 self.fetch_all(cmd).await
50 }
51
52 pub async fn fetch_all_with_format_checked(
54 &mut self,
55 cmd: &Qail,
56 result_format: ResultFormat,
57 access_ctx: &AccessContext,
58 access_policy: &AccessPolicy,
59 ) -> PgResult<Vec<PgRow>> {
60 check_access(access_policy, access_ctx, cmd)?;
61 self.fetch_all_with_format(cmd, result_format).await
62 }
63
64 pub async fn fetch_all_uncached_checked(
66 &mut self,
67 cmd: &Qail,
68 access_ctx: &AccessContext,
69 access_policy: &AccessPolicy,
70 ) -> PgResult<Vec<PgRow>> {
71 check_access(access_policy, access_ctx, cmd)?;
72 self.fetch_all_uncached(cmd).await
73 }
74
75 pub async fn fetch_all_uncached_with_format_checked(
77 &mut self,
78 cmd: &Qail,
79 result_format: ResultFormat,
80 access_ctx: &AccessContext,
81 access_policy: &AccessPolicy,
82 ) -> PgResult<Vec<PgRow>> {
83 check_access(access_policy, access_ctx, cmd)?;
84 self.fetch_all_uncached_with_format(cmd, result_format)
85 .await
86 }
87
88 pub async fn fetch_all_fast_checked(
90 &mut self,
91 cmd: &Qail,
92 access_ctx: &AccessContext,
93 access_policy: &AccessPolicy,
94 ) -> PgResult<Vec<PgRow>> {
95 check_access(access_policy, access_ctx, cmd)?;
96 self.fetch_all_fast(cmd).await
97 }
98
99 pub async fn fetch_all_fast_with_format_checked(
101 &mut self,
102 cmd: &Qail,
103 result_format: ResultFormat,
104 access_ctx: &AccessContext,
105 access_policy: &AccessPolicy,
106 ) -> PgResult<Vec<PgRow>> {
107 check_access(access_policy, access_ctx, cmd)?;
108 self.fetch_all_fast_with_format(cmd, result_format).await
109 }
110
111 pub async fn fetch_one_checked(
113 &mut self,
114 cmd: &Qail,
115 access_ctx: &AccessContext,
116 access_policy: &AccessPolicy,
117 ) -> PgResult<PgRow> {
118 check_access(access_policy, access_ctx, cmd)?;
119 self.fetch_one(cmd).await
120 }
121
122 pub async fn prepare_ast_query_checked(
128 &mut self,
129 cmd: &Qail,
130 access_ctx: &AccessContext,
131 access_policy: &AccessPolicy,
132 ) -> PgResult<PreparedAstQuery> {
133 check_access(access_policy, access_ctx, cmd)?;
134 self.prepare_ast_query(cmd).await
135 }
136
137 pub async fn fetch_typed_checked<T: super::row::QailRow>(
139 &mut self,
140 cmd: &Qail,
141 access_ctx: &AccessContext,
142 access_policy: &AccessPolicy,
143 ) -> PgResult<Vec<T>> {
144 check_access(access_policy, access_ctx, cmd)?;
145 self.fetch_typed(cmd).await
146 }
147
148 pub async fn fetch_typed_with_format_checked<T: super::row::QailRow>(
150 &mut self,
151 cmd: &Qail,
152 result_format: ResultFormat,
153 access_ctx: &AccessContext,
154 access_policy: &AccessPolicy,
155 ) -> PgResult<Vec<T>> {
156 check_access(access_policy, access_ctx, cmd)?;
157 self.fetch_typed_with_format(cmd, result_format).await
158 }
159
160 pub async fn fetch_one_typed_checked<T: super::row::QailRow>(
162 &mut self,
163 cmd: &Qail,
164 access_ctx: &AccessContext,
165 access_policy: &AccessPolicy,
166 ) -> PgResult<Option<T>> {
167 check_access(access_policy, access_ctx, cmd)?;
168 self.fetch_one_typed(cmd).await
169 }
170
171 pub async fn fetch_one_typed_with_format_checked<T: super::row::QailRow>(
173 &mut self,
174 cmd: &Qail,
175 result_format: ResultFormat,
176 access_ctx: &AccessContext,
177 access_policy: &AccessPolicy,
178 ) -> PgResult<Option<T>> {
179 check_access(access_policy, access_ctx, cmd)?;
180 self.fetch_one_typed_with_format(cmd, result_format).await
181 }
182
183 pub async fn execute_checked(
185 &mut self,
186 cmd: &Qail,
187 access_ctx: &AccessContext,
188 access_policy: &AccessPolicy,
189 ) -> PgResult<u64> {
190 check_access(access_policy, access_ctx, cmd)?;
191 self.execute(cmd).await
192 }
193
194 pub async fn copy_bulk_checked(
196 &mut self,
197 cmd: &Qail,
198 rows: &[Vec<qail_core::ast::Value>],
199 access_ctx: &AccessContext,
200 access_policy: &AccessPolicy,
201 ) -> PgResult<u64> {
202 check_access(access_policy, access_ctx, cmd)?;
203 self.copy_bulk(cmd, rows).await
204 }
205
206 pub async fn copy_bulk_bytes_checked(
208 &mut self,
209 cmd: &Qail,
210 data: &[u8],
211 access_ctx: &AccessContext,
212 access_policy: &AccessPolicy,
213 ) -> PgResult<u64> {
214 check_access(access_policy, access_ctx, cmd)?;
215 self.copy_bulk_bytes(cmd, data).await
216 }
217
218 pub async fn copy_export_table_checked(
220 &mut self,
221 table: &str,
222 columns: &[String],
223 access_ctx: &AccessContext,
224 access_policy: &AccessPolicy,
225 ) -> PgResult<Vec<u8>> {
226 check_access(
227 access_policy,
228 access_ctx,
229 ©_export_table_command(table, columns),
230 )?;
231 self.copy_export_table(table, columns).await
232 }
233
234 pub async fn copy_export_table_stream_checked<F, Fut>(
236 &mut self,
237 table: &str,
238 columns: &[String],
239 on_chunk: F,
240 access_ctx: &AccessContext,
241 access_policy: &AccessPolicy,
242 ) -> PgResult<()>
243 where
244 F: FnMut(Vec<u8>) -> Fut,
245 Fut: std::future::Future<Output = PgResult<()>>,
246 {
247 check_access(
248 access_policy,
249 access_ctx,
250 ©_export_table_command(table, columns),
251 )?;
252 self.copy_export_table_stream(table, columns, on_chunk)
253 .await
254 }
255
256 pub async fn copy_export_cmd_stream_checked<F, Fut>(
258 &mut self,
259 cmd: &Qail,
260 on_chunk: F,
261 access_ctx: &AccessContext,
262 access_policy: &AccessPolicy,
263 ) -> PgResult<()>
264 where
265 F: FnMut(Vec<u8>) -> Fut,
266 Fut: std::future::Future<Output = PgResult<()>>,
267 {
268 check_access(access_policy, access_ctx, cmd)?;
269 self.copy_export_cmd_stream(cmd, on_chunk).await
270 }
271
272 pub async fn copy_export_cmd_stream_rows_checked<F>(
274 &mut self,
275 cmd: &Qail,
276 on_row: F,
277 access_ctx: &AccessContext,
278 access_policy: &AccessPolicy,
279 ) -> PgResult<()>
280 where
281 F: FnMut(Vec<String>) -> PgResult<()>,
282 {
283 check_access(access_policy, access_ctx, cmd)?;
284 self.copy_export_cmd_stream_rows(cmd, on_row).await
285 }
286
287 pub async fn stream_cmd_checked(
289 &mut self,
290 cmd: &Qail,
291 batch_size: usize,
292 access_ctx: &AccessContext,
293 access_policy: &AccessPolicy,
294 ) -> PgResult<Vec<Vec<PgRow>>> {
295 check_access(access_policy, access_ctx, cmd)?;
296 self.stream_cmd(cmd, batch_size).await
297 }
298
299 pub async fn query_ast_checked(
301 &mut self,
302 cmd: &Qail,
303 access_ctx: &AccessContext,
304 access_policy: &AccessPolicy,
305 ) -> PgResult<QueryResult> {
306 check_access(access_policy, access_ctx, cmd)?;
307 self.query_ast(cmd).await
308 }
309
310 pub async fn query_ast_with_format_checked(
312 &mut self,
313 cmd: &Qail,
314 result_format: ResultFormat,
315 access_ctx: &AccessContext,
316 access_policy: &AccessPolicy,
317 ) -> PgResult<QueryResult> {
318 check_access(access_policy, access_ctx, cmd)?;
319 self.query_ast_with_format(cmd, result_format).await
320 }
321
322 pub async fn execute_batch_checked(
327 &mut self,
328 cmds: &[Qail],
329 access_ctx: &AccessContext,
330 access_policy: &AccessPolicy,
331 ) -> PgResult<Vec<u64>> {
332 check_all_access(access_policy, access_ctx, cmds)?;
333 self.execute_batch(cmds).await
334 }
335
336 pub async fn execute_count_auto_with_plan_checked(
338 &mut self,
339 cmds: &[Qail],
340 access_ctx: &AccessContext,
341 access_policy: &AccessPolicy,
342 ) -> PgResult<(usize, AutoCountPlan)> {
343 check_all_access(access_policy, access_ctx, cmds)?;
344 self.execute_count_auto_with_plan(cmds).await
345 }
346
347 pub async fn execute_count_auto_checked(
349 &mut self,
350 cmds: &[Qail],
351 access_ctx: &AccessContext,
352 access_policy: &AccessPolicy,
353 ) -> PgResult<usize> {
354 check_all_access(access_policy, access_ctx, cmds)?;
355 self.execute_count_auto(cmds).await
356 }
357
358 pub async fn pipeline_execute_count_with_mode_checked(
360 &mut self,
361 cmds: &[Qail],
362 mode: AstPipelineMode,
363 access_ctx: &AccessContext,
364 access_policy: &AccessPolicy,
365 ) -> PgResult<usize> {
366 check_all_access(access_policy, access_ctx, cmds)?;
367 self.pipeline_execute_count_with_mode(cmds, mode).await
368 }
369
370 pub async fn pipeline_execute_count_checked(
372 &mut self,
373 cmds: &[Qail],
374 access_ctx: &AccessContext,
375 access_policy: &AccessPolicy,
376 ) -> PgResult<usize> {
377 check_all_access(access_policy, access_ctx, cmds)?;
378 self.pipeline_execute_count(cmds).await
379 }
380
381 pub async fn pipeline_execute_rows_checked(
383 &mut self,
384 cmds: &[Qail],
385 access_ctx: &AccessContext,
386 access_policy: &AccessPolicy,
387 ) -> PgResult<Vec<Vec<PgRow>>> {
388 check_all_access(access_policy, access_ctx, cmds)?;
389 self.pipeline_execute_rows(cmds).await
390 }
391}
392
393impl PooledConnection {
394 pub fn check_access(
396 &self,
397 cmd: &Qail,
398 access_ctx: &AccessContext,
399 access_policy: &AccessPolicy,
400 ) -> PgResult<()> {
401 check_access(access_policy, access_ctx, cmd)
402 }
403
404 pub async fn fetch_all_cached_checked(
406 &mut self,
407 cmd: &Qail,
408 access_ctx: &AccessContext,
409 access_policy: &AccessPolicy,
410 ) -> PgResult<Vec<PgRow>> {
411 check_access(access_policy, access_ctx, cmd)?;
412 self.fetch_all_cached(cmd).await
413 }
414
415 pub async fn fetch_all_cached_with_format_checked(
417 &mut self,
418 cmd: &Qail,
419 result_format: ResultFormat,
420 access_ctx: &AccessContext,
421 access_policy: &AccessPolicy,
422 ) -> PgResult<Vec<PgRow>> {
423 check_access(access_policy, access_ctx, cmd)?;
424 self.fetch_all_cached_with_format(cmd, result_format).await
425 }
426
427 pub async fn fetch_all_uncached_checked(
429 &mut self,
430 cmd: &Qail,
431 access_ctx: &AccessContext,
432 access_policy: &AccessPolicy,
433 ) -> PgResult<Vec<PgRow>> {
434 check_access(access_policy, access_ctx, cmd)?;
435 self.fetch_all_uncached(cmd).await
436 }
437
438 pub async fn fetch_all_uncached_with_format_checked(
440 &mut self,
441 cmd: &Qail,
442 result_format: ResultFormat,
443 access_ctx: &AccessContext,
444 access_policy: &AccessPolicy,
445 ) -> PgResult<Vec<PgRow>> {
446 check_access(access_policy, access_ctx, cmd)?;
447 self.fetch_all_uncached_with_format(cmd, result_format)
448 .await
449 }
450
451 pub async fn fetch_all_fast_checked(
453 &mut self,
454 cmd: &Qail,
455 access_ctx: &AccessContext,
456 access_policy: &AccessPolicy,
457 ) -> PgResult<Vec<PgRow>> {
458 check_access(access_policy, access_ctx, cmd)?;
459 self.fetch_all_fast(cmd).await
460 }
461
462 pub async fn fetch_all_fast_with_format_checked(
464 &mut self,
465 cmd: &Qail,
466 result_format: ResultFormat,
467 access_ctx: &AccessContext,
468 access_policy: &AccessPolicy,
469 ) -> PgResult<Vec<PgRow>> {
470 check_access(access_policy, access_ctx, cmd)?;
471 self.fetch_all_fast_with_format(cmd, result_format).await
472 }
473
474 pub async fn fetch_all_with_rls_checked(
476 &mut self,
477 cmd: &Qail,
478 rls_sql: &str,
479 access_ctx: &AccessContext,
480 access_policy: &AccessPolicy,
481 ) -> PgResult<Vec<PgRow>> {
482 check_access(access_policy, access_ctx, cmd)?;
483 self.fetch_all_with_rls(cmd, rls_sql).await
484 }
485
486 pub async fn fetch_all_with_rls_with_format_checked(
488 &mut self,
489 cmd: &Qail,
490 rls_sql: &str,
491 result_format: ResultFormat,
492 access_ctx: &AccessContext,
493 access_policy: &AccessPolicy,
494 ) -> PgResult<Vec<PgRow>> {
495 check_access(access_policy, access_ctx, cmd)?;
496 self.fetch_all_with_rls_with_format(cmd, rls_sql, result_format)
497 .await
498 }
499
500 pub async fn copy_export_checked(
502 &mut self,
503 cmd: &Qail,
504 access_ctx: &AccessContext,
505 access_policy: &AccessPolicy,
506 ) -> PgResult<Vec<Vec<String>>> {
507 check_access(access_policy, access_ctx, cmd)?;
508 self.copy_export(cmd).await
509 }
510
511 pub async fn copy_export_stream_raw_checked<F, Fut>(
513 &mut self,
514 cmd: &Qail,
515 on_chunk: F,
516 access_ctx: &AccessContext,
517 access_policy: &AccessPolicy,
518 ) -> PgResult<()>
519 where
520 F: FnMut(Vec<u8>) -> Fut,
521 Fut: std::future::Future<Output = PgResult<()>>,
522 {
523 check_access(access_policy, access_ctx, cmd)?;
524 self.copy_export_stream_raw(cmd, on_chunk).await
525 }
526
527 pub async fn copy_export_stream_rows_checked<F>(
529 &mut self,
530 cmd: &Qail,
531 on_row: F,
532 access_ctx: &AccessContext,
533 access_policy: &AccessPolicy,
534 ) -> PgResult<()>
535 where
536 F: FnMut(Vec<String>) -> PgResult<()>,
537 {
538 check_access(access_policy, access_ctx, cmd)?;
539 self.copy_export_stream_rows(cmd, on_row).await
540 }
541
542 pub async fn copy_export_table_checked(
544 &mut self,
545 table: &str,
546 columns: &[String],
547 access_ctx: &AccessContext,
548 access_policy: &AccessPolicy,
549 ) -> PgResult<Vec<u8>> {
550 check_access(
551 access_policy,
552 access_ctx,
553 ©_export_table_command(table, columns),
554 )?;
555 self.copy_export_table(table, columns).await
556 }
557
558 pub async fn copy_export_table_stream_checked<F, Fut>(
560 &mut self,
561 table: &str,
562 columns: &[String],
563 on_chunk: F,
564 access_ctx: &AccessContext,
565 access_policy: &AccessPolicy,
566 ) -> PgResult<()>
567 where
568 F: FnMut(Vec<u8>) -> Fut,
569 Fut: std::future::Future<Output = PgResult<()>>,
570 {
571 check_access(
572 access_policy,
573 access_ctx,
574 ©_export_table_command(table, columns),
575 )?;
576 self.copy_export_table_stream(table, columns, on_chunk)
577 .await
578 }
579
580 pub async fn fetch_typed_checked<T: super::row::QailRow>(
582 &mut self,
583 cmd: &Qail,
584 access_ctx: &AccessContext,
585 access_policy: &AccessPolicy,
586 ) -> PgResult<Vec<T>> {
587 check_access(access_policy, access_ctx, cmd)?;
588 self.fetch_typed(cmd).await
589 }
590
591 pub async fn fetch_typed_with_format_checked<T: super::row::QailRow>(
593 &mut self,
594 cmd: &Qail,
595 result_format: ResultFormat,
596 access_ctx: &AccessContext,
597 access_policy: &AccessPolicy,
598 ) -> PgResult<Vec<T>> {
599 check_access(access_policy, access_ctx, cmd)?;
600 self.fetch_typed_with_format(cmd, result_format).await
601 }
602
603 pub async fn fetch_one_typed_checked<T: super::row::QailRow>(
605 &mut self,
606 cmd: &Qail,
607 access_ctx: &AccessContext,
608 access_policy: &AccessPolicy,
609 ) -> PgResult<Option<T>> {
610 check_access(access_policy, access_ctx, cmd)?;
611 self.fetch_one_typed(cmd).await
612 }
613
614 pub async fn fetch_one_typed_with_format_checked<T: super::row::QailRow>(
616 &mut self,
617 cmd: &Qail,
618 result_format: ResultFormat,
619 access_ctx: &AccessContext,
620 access_policy: &AccessPolicy,
621 ) -> PgResult<Option<T>> {
622 check_access(access_policy, access_ctx, cmd)?;
623 self.fetch_one_typed_with_format(cmd, result_format).await
624 }
625
626 pub async fn pipeline_execute_rows_ast_checked(
628 &mut self,
629 cmds: &[Qail],
630 access_ctx: &AccessContext,
631 access_policy: &AccessPolicy,
632 ) -> PgResult<Vec<Vec<Vec<Option<Vec<u8>>>>>> {
633 check_all_access(access_policy, access_ctx, cmds)?;
634 self.pipeline_execute_rows_ast(cmds).await
635 }
636}
637
638impl PgPool {
639 pub async fn execute_count_auto_with_plan_checked(
641 &self,
642 cmds: &[Qail],
643 access_ctx: &AccessContext,
644 access_policy: &AccessPolicy,
645 ) -> PgResult<(usize, AutoCountPlan)> {
646 check_all_access(access_policy, access_ctx, cmds)?;
647 self.execute_count_auto_with_plan(cmds).await
648 }
649
650 pub async fn execute_count_auto_checked(
652 &self,
653 cmds: &[Qail],
654 access_ctx: &AccessContext,
655 access_policy: &AccessPolicy,
656 ) -> PgResult<usize> {
657 check_all_access(access_policy, access_ctx, cmds)?;
658 self.execute_count_auto(cmds).await
659 }
660}
661
#[cfg(test)]
mod tests {
    use qail_core::access::{
        AccessContext, AccessOperation, AccessPolicy, ColumnRule, TableAccessPolicy,
    };
    use qail_core::ast::{Expr, Qail};

    use super::{check_access, check_all_access, copy_export_table_command};
    use crate::driver::PgError;

    /// Policy shared by several tests: `orders` is readable, but only its `id` column.
    fn orders_id_only_policy() -> AccessPolicy {
        AccessPolicy::new().with_table(
            "orders",
            TableAccessPolicy::new()
                .allow_operations([AccessOperation::Read])
                .read_columns(ColumnRule::only(["id"])),
        )
    }

    #[test]
    fn checked_pg_error_uses_existing_query_variant() {
        let err = check_access(
            &AccessPolicy::new(),
            &AccessContext::anonymous(),
            &Qail::get("orders"),
        )
        .expect_err("missing policy should fail closed");

        match err {
            PgError::Query(message) => {
                assert!(message.contains("Access denied by policy"));
                assert!(message.contains("orders"));
            }
            other => panic!("unexpected error variant: {other:?}"),
        }
    }

    #[test]
    fn checked_batch_rejects_denied_later_command_before_execution() {
        let policy = orders_id_only_policy();
        let cmds = vec![
            Qail::get("orders").columns(["id"]),
            Qail::get("orders").columns(["id", "private_note"]),
        ];

        let err = check_all_access(&policy, &AccessContext::anonymous(), &cmds)
            .expect_err("second command should deny before any wrapper executes");

        // Pin that the denial comes from the later command's forbidden column,
        // not from some unrelated failure on the first command.
        match err {
            PgError::Query(message) => {
                assert!(message.contains("Access denied by policy"));
                assert!(message.contains("private_note"));
            }
            other => panic!("unexpected error variant: {other:?}"),
        }
    }

    #[test]
    fn checked_batch_accepts_empty_and_fully_permitted_batches() {
        let policy = orders_id_only_policy();
        let ctx = AccessContext::anonymous();

        check_all_access(&policy, &ctx, &[]).expect("an empty batch has nothing to deny");
        check_all_access(&policy, &ctx, &[Qail::get("orders").columns(["id"])])
            .expect("a batch of permitted commands should pass");
    }

    #[test]
    fn checked_policy_recurses_into_subqueries() {
        let policy = AccessPolicy::new().with_table(
            "orders",
            TableAccessPolicy::new().allow_operations([AccessOperation::Read]),
        );
        let cmd = Qail::get("orders").columns_expr([Expr::Subquery {
            query: Box::new(Qail::get("users").columns(["id"])),
            alias: None,
        }]);

        let err = check_access(&policy, &AccessContext::anonymous(), &cmd)
            .expect_err("subquery table should require its own policy");

        match err {
            PgError::Query(message) => assert!(message.contains("users")),
            other => panic!("unexpected error variant: {other:?}"),
        }
    }

    #[test]
    fn checked_copy_export_table_command_uses_read_column_policy() {
        let policy = orders_id_only_policy();
        let columns = vec!["id".to_string(), "private_note".to_string()];
        let cmd = copy_export_table_command("orders", &columns);

        let err = check_access(&policy, &AccessContext::anonymous(), &cmd)
            .expect_err("denied COPY export column should fail before execution");

        match err {
            PgError::Query(message) => assert!(message.contains("private_note")),
            other => panic!("unexpected error variant: {other:?}"),
        }
    }
}