kimberlite_query/rbac_filter.rs
//! RBAC query filtering and rewriting.
//!
//! This module provides query rewriting to enforce RBAC policies:
//! - **Column filtering**: Remove unauthorized columns from SELECT
//! - **Row-level security**: Inject WHERE clauses
//!
//! ## Architecture
//!
//! ```text
//! ┌─────────────────────────────────────┐
//! │ Original Query                      │
//! │ SELECT name, ssn FROM patients      │
//! └──────────────────┬──────────────────┘
//!                    │
//!                    ▼
//! ┌─────────────────────────────────────┐
//! │ RBAC Filter                         │
//! │ - Check stream access               │
//! │ - Filter columns (remove "ssn")     │
//! │ - Inject WHERE clause               │
//! └──────────────────┬──────────────────┘
//!                    │
//!                    ▼
//! ┌─────────────────────────────────────┐
//! │ Rewritten Query                     │
//! │ SELECT name FROM patients           │
//! │ WHERE tenant_id = 42                │
//! └─────────────────────────────────────┘
//! ```
30
31use crate::error::QueryError;
32use kimberlite_rbac::{AccessPolicy, enforcement::PolicyEnforcer};
33use sqlparser::ast::{Expr, Query, Select, SelectItem, SetExpr, Statement, TableFactor};
34use thiserror::Error;
35use tracing::{debug, info, warn};
36
37/// Error type for RBAC filtering.
38#[derive(Debug, Error)]
39pub enum RbacError {
40 /// Access denied by policy.
41 #[error("Access denied: {0}")]
42 AccessDenied(String),
43
44 /// No authorized columns in query.
45 #[error("No authorized columns in query")]
46 NoAuthorizedColumns,
47
48 /// Unsupported query type for RBAC.
49 #[error("Unsupported query type: {0}")]
50 UnsupportedQuery(String),
51
52 /// Policy enforcement failed.
53 #[error("Policy enforcement failed: {0}")]
54 EnforcementFailed(String),
55}
56
57impl From<kimberlite_rbac::enforcement::EnforcementError> for RbacError {
58 fn from(err: kimberlite_rbac::enforcement::EnforcementError) -> Self {
59 match err {
60 kimberlite_rbac::enforcement::EnforcementError::AccessDenied { reason } => {
61 RbacError::AccessDenied(reason)
62 }
63 _ => RbacError::EnforcementFailed(err.to_string()),
64 }
65 }
66}
67
68impl From<RbacError> for QueryError {
69 fn from(err: RbacError) -> Self {
70 QueryError::UnsupportedFeature(err.to_string())
71 }
72}
73
74/// Result type for RBAC operations.
75pub type Result<T> = std::result::Result<T, RbacError>;
76
77/// Output of [`RbacFilter::rewrite_statement`].
78///
79/// Carries the rewritten statement alongside the alias mapping derived
80/// from the original projection. Downstream code (e.g. the masking pass
81/// in `kimberlite`) uses the mapping to resolve output column names
82/// back to their source columns so masks are applied to the underlying
83/// sensitive attribute, not the user-visible alias.
84#[derive(Debug)]
85pub struct RewriteOutput {
86 /// The rewritten SQL statement.
87 pub statement: Statement,
88 /// Pairs of `(output_column_name, source_column_name)` for each
89 /// projection item that survived RBAC filtering.
90 ///
91 /// Bare identifiers produce pairs where both entries are equal.
92 /// Aliased identifiers (`SELECT ssn AS id`) produce distinct
93 /// output/source entries — the masking pass must key its lookup
94 /// on the source entry (AUDIT-2026-04 M-7).
95 pub column_aliases: Vec<(String, String)>,
96}
97
98/// RBAC query filter.
99///
100/// Rewrites SQL queries to enforce access control policies.
101pub struct RbacFilter {
102 enforcer: PolicyEnforcer,
103}
104
105impl RbacFilter {
106 /// Creates a new RBAC filter with the given policy.
107 pub fn new(policy: AccessPolicy) -> Self {
108 Self {
109 enforcer: PolicyEnforcer::new(policy),
110 }
111 }
112
113 /// Rewrites a SQL statement to enforce RBAC policy.
114 ///
115 /// **Transformations:**
116 /// 1. Check stream access (deny if unauthorized)
117 /// 2. Filter SELECT columns (remove unauthorized columns)
118 /// 3. Inject WHERE clause for row-level security
119 ///
120 /// # Arguments
121 ///
122 /// * `stmt` - SQL statement to rewrite
123 ///
124 /// # Returns
125 ///
126 /// Rewritten statement plus a map of `(output_column_name,
127 /// source_column_name)` pairs — one entry per projection item that
128 /// survived RBAC filtering. The masking pass uses this map to look
129 /// up column masks by source column rather than by the
130 /// potentially-aliased output name (AUDIT-2026-04 M-7).
131 ///
132 /// # Errors
133 ///
134 /// - `AccessDenied` if stream access is denied
135 /// - `NoAuthorizedColumns` if all columns are unauthorized
136 /// - `UnsupportedQuery` if query type is not supported
137 pub fn rewrite_statement(&self, mut stmt: Statement) -> Result<RewriteOutput> {
138 match &mut stmt {
139 Statement::Query(query) => {
140 let column_aliases = self.rewrite_query(query)?;
141 Ok(RewriteOutput {
142 statement: stmt,
143 column_aliases,
144 })
145 }
146 _ => Err(RbacError::UnsupportedQuery(
147 "Only SELECT queries are currently supported".to_string(),
148 )),
149 }
150 }
151
152 /// Rewrites a query to enforce RBAC.
153 ///
154 /// **AUDIT-2026-04 M-7 — recursive traversal.** Prior to this
155 /// change, only the top-level `SetExpr::Select` was rewritten,
156 /// so a predicate like
157 /// `SELECT id FROM t WHERE x IN (SELECT ssn FROM patients)`
158 /// would bypass column filtering on `ssn`. The recursive walk
159 /// below ensures every nested `Query` (CTE / UNION / subquery
160 /// in FROM / subquery in WHERE) is rewritten under the same
161 /// policy before the outer select is processed.
162 fn rewrite_query(&self, query: &mut Query) -> Result<Vec<(String, String)>> {
163 // 1. Rewrite CTEs first — later referenced by name in the
164 // main set-expression, so their filtering must land
165 // before the outer select reads them.
166 if let Some(with) = query.with.as_mut() {
167 for cte in with.cte_tables.iter_mut() {
168 // CTEs themselves cannot leak if the outer select
169 // never references the denied columns — but we
170 // rewrite defensively so that any CTE reference
171 // through `SELECT * FROM cte_name` (once wildcard
172 // support lands) does not expose masked sources.
173 let _ = self.rewrite_query(cte.query.as_mut())?;
174 }
175 }
176
177 // 2. Dispatch on set-expression shape.
178 self.rewrite_set_expr(query.body.as_mut())
179 }
180
181 /// Recursively rewrites a `SetExpr`, returning the column
182 /// lineage for the *representative* select (the left-most
183 /// branch of a UNION, or the inner select of a parenthesised
184 /// query).
185 ///
186 /// UNION branches must all satisfy the policy independently —
187 /// if any branch references a denied column, the whole query
188 /// is rejected.
189 fn rewrite_set_expr(&self, set_expr: &mut SetExpr) -> Result<Vec<(String, String)>> {
190 match set_expr {
191 SetExpr::Select(select) => self.rewrite_select(select),
192 // Parenthesised query — recurse.
193 SetExpr::Query(inner) => self.rewrite_query(inner.as_mut()),
194 // UNION / INTERSECT / EXCEPT — every branch must pass
195 // RBAC independently. The outer lineage comes from the
196 // left branch (all branches must have compatible
197 // column counts, so either branch's lineage is a valid
198 // descriptor; we use left for determinism).
199 SetExpr::SetOperation { left, right, .. } => {
200 let left_lineage = self.rewrite_set_expr(left.as_mut())?;
201 let _right_lineage = self.rewrite_set_expr(right.as_mut())?;
202 Ok(left_lineage)
203 }
204 _ => Err(RbacError::UnsupportedQuery(format!(
205 "unsupported set-expression: {set_expr:?}"
206 ))),
207 }
208 }
209
210 /// Rewrites a SELECT statement. Returns the `(output, source)`
211 /// column pairs for the surviving projection items.
212 fn rewrite_select(&self, select: &mut Select) -> Result<Vec<(String, String)>> {
213 // AUDIT-2026-04 M-7 — subquery / nested-SELECT recursion.
214 //
215 // Step 0a: rewrite any `TableFactor::Derived { subquery }`
216 // in the FROM clause. A predicate that reads
217 // `SELECT outer.x FROM (SELECT ssn AS x FROM patients) outer`
218 // was previously accepted because `extract_stream_name` only
219 // saw the outer derived-table reference — the inner SELECT
220 // was never filtered against the `patients.ssn` deny policy.
221 // Now the inner SELECT is rewritten first; if it references
222 // a denied column it errors out here, before any outer
223 // lineage is reported.
224 for table_with_joins in select.from.iter_mut() {
225 self.rewrite_table_factor(&mut table_with_joins.relation)?;
226 for join in table_with_joins.joins.iter_mut() {
227 self.rewrite_table_factor(&mut join.relation)?;
228 }
229 }
230
231 // Step 0b: rewrite subqueries inside the WHERE clause.
232 // Handles `IN (SELECT ...)`, `EXISTS (SELECT ...)`, and
233 // scalar-subquery forms. The traversal is read-mutable
234 // because the inner rewrite replaces column projections.
235 if let Some(ref mut selection) = select.selection {
236 self.rewrite_expr_subqueries(selection)?;
237 }
238
239 // 1. Extract stream name from FROM clause
240 let stream_name = Self::extract_stream_name(select)?;
241
242 debug!(stream = %stream_name, "Extracting columns from SELECT");
243
244 // 2. Extract requested columns (source names) and aliases
245 let aliases = Self::extract_column_aliases(select)?;
246 let requested_columns: Vec<String> =
247 aliases.iter().map(|(_, src)| src.clone()).collect();
248
249 info!(
250 stream = %stream_name,
251 columns = ?requested_columns,
252 "Enforcing RBAC policy"
253 );
254
255 // 3. Enforce policy (checks stream access + filters columns)
256 let (allowed_columns, where_clause_sql) = self
257 .enforcer
258 .enforce_query(&stream_name, &requested_columns)?;
259
260 if allowed_columns.is_empty() {
261 warn!(stream = %stream_name, "No authorized columns");
262 return Err(RbacError::NoAuthorizedColumns);
263 }
264
265 // 4. Rewrite SELECT projection (filter columns)
266 Self::rewrite_projection(select, &allowed_columns);
267
268 // 5. Inject WHERE clause for row-level security
269 if !where_clause_sql.is_empty() {
270 Self::inject_where_clause(select, &where_clause_sql)?;
271 }
272
273 info!(
274 stream = %stream_name,
275 allowed_columns = ?allowed_columns,
276 where_clause = %where_clause_sql,
277 "Query rewritten successfully"
278 );
279
280 // 6. Trim the alias map to the surviving projection.
281 let allowed: std::collections::HashSet<&str> =
282 allowed_columns.iter().map(String::as_str).collect();
283 let surviving_aliases = aliases
284 .into_iter()
285 .filter(|(_, src)| allowed.contains(src.as_str()))
286 .collect();
287
288 Ok(surviving_aliases)
289 }
290
291 /// AUDIT-2026-04 M-7 helper — recurse into nested queries
292 /// carried by a `TableFactor`.
293 ///
294 /// `TableFactor::Derived { subquery }` is the AST node for
295 /// `FROM (SELECT ...)`. `TableFactor::NestedJoin` wraps a
296 /// `TableWithJoins` that may itself contain derived tables.
297 /// Anything else is a terminal table reference handled by
298 /// `extract_stream_name` downstream.
299 fn rewrite_table_factor(&self, factor: &mut TableFactor) -> Result<()> {
300 match factor {
301 TableFactor::Derived { subquery, .. } => {
302 self.rewrite_query(subquery.as_mut())?;
303 Ok(())
304 }
305 TableFactor::NestedJoin {
306 table_with_joins, ..
307 } => {
308 self.rewrite_table_factor(&mut table_with_joins.relation)?;
309 for join in table_with_joins.joins.iter_mut() {
310 self.rewrite_table_factor(&mut join.relation)?;
311 }
312 Ok(())
313 }
314 _ => Ok(()),
315 }
316 }
317
318 /// AUDIT-2026-04 M-7 helper — recurse into subqueries embedded
319 /// in a WHERE-clause `Expr`.
320 ///
321 /// Walks `Expr::Subquery`, `Expr::InSubquery`, `Expr::Exists`,
322 /// and combinators (`BinaryOp`, `UnaryOp`, `Nested`) that can
323 /// transport a subquery in their children. Non-subquery leaves
324 /// (identifiers, literals) are terminal.
325 ///
326 /// A bounded-depth guard would belong here if the recursive
327 /// kernel principle forbade it; the query parser already
328 /// rejects SQL with unbounded expression depth before reaching
329 /// this point, so we rely on the sqlparser limit.
330 fn rewrite_expr_subqueries(&self, expr: &mut Expr) -> Result<()> {
331 match expr {
332 Expr::Subquery(q) | Expr::Exists { subquery: q, .. } => {
333 self.rewrite_query(q.as_mut())?;
334 Ok(())
335 }
336 Expr::InSubquery { subquery, expr: inner, .. } => {
337 self.rewrite_expr_subqueries(inner.as_mut())?;
338 self.rewrite_query(subquery.as_mut())?;
339 Ok(())
340 }
341 Expr::BinaryOp { left, right, .. } => {
342 self.rewrite_expr_subqueries(left.as_mut())?;
343 self.rewrite_expr_subqueries(right.as_mut())
344 }
345 Expr::UnaryOp { expr: inner, .. } => self.rewrite_expr_subqueries(inner.as_mut()),
346 Expr::Nested(inner) => self.rewrite_expr_subqueries(inner.as_mut()),
347 Expr::InList { expr: inner, list, .. } => {
348 self.rewrite_expr_subqueries(inner.as_mut())?;
349 for item in list.iter_mut() {
350 self.rewrite_expr_subqueries(item)?;
351 }
352 Ok(())
353 }
354 Expr::Between {
355 expr: inner,
356 low,
357 high,
358 ..
359 } => {
360 self.rewrite_expr_subqueries(inner.as_mut())?;
361 self.rewrite_expr_subqueries(low.as_mut())?;
362 self.rewrite_expr_subqueries(high.as_mut())
363 }
364 Expr::Case {
365 conditions,
366 results,
367 else_result,
368 ..
369 } => {
370 for c in conditions.iter_mut() {
371 self.rewrite_expr_subqueries(c)?;
372 }
373 for r in results.iter_mut() {
374 self.rewrite_expr_subqueries(r)?;
375 }
376 if let Some(else_r) = else_result.as_mut() {
377 self.rewrite_expr_subqueries(else_r.as_mut())?;
378 }
379 Ok(())
380 }
381 // Identifiers, literals, function calls without subquery
382 // arguments, etc. — nothing to rewrite.
383 _ => Ok(()),
384 }
385 }
386
387 /// Extracts the stream name from the FROM clause.
388 fn extract_stream_name(select: &Select) -> Result<String> {
389 if select.from.is_empty() {
390 return Err(RbacError::UnsupportedQuery(
391 "SELECT without FROM clause".to_string(),
392 ));
393 }
394
395 let table = &select.from[0];
396 match &table.relation {
397 TableFactor::Table { name, .. } => {
398 let stream_name = name
399 .0
400 .iter()
401 .map(|i| i.value.as_str())
402 .collect::<Vec<_>>()
403 .join(".");
404 Ok(stream_name)
405 }
406 _ => Err(RbacError::UnsupportedQuery(
407 "Only simple table references are supported".to_string(),
408 )),
409 }
410 }
411
412 /// Extracts `(output_column_name, source_column_name)` pairs for
413 /// each item in the SELECT projection. See [`column_aliases`] for
414 /// the free-function entry point used by the SQL-level mask pass.
415 fn extract_column_aliases(select: &Select) -> Result<Vec<(String, String)>> {
416 column_aliases_from_select(select)
417 }
418}
419
420/// Extracts `(output_column_name, source_column_name)` pairs for each
421/// item in the SELECT projection of `stmt`.
422///
423/// Returns an empty vector for non-`SELECT` statements or for set-expr
424/// bodies that are not a plain `SELECT` (e.g. `UNION`) — the masking
425/// pass treats an empty map as "no aliases known" and falls back to
426/// output-name keying, matching pre-M-7 semantics for those shapes.
427///
428/// Semantics:
429/// - `SELECT col` → `("col", "col")`
430/// - `SELECT col AS alias` → `("alias", "col")`
431/// - `SELECT UPPER(col) AS alias` → `("alias", "alias")` (non-identifier
432/// expressions cannot be resolved to a source column — mask lookup
433/// keys on the alias, mirroring the pre-M-7 behaviour).
434///
435/// AUDIT-2026-04 M-7: the masking pass uses the source half of the
436/// pair to look up column masks. Without this, `SELECT ssn AS id FROM
437/// patients` passed RBAC (source `ssn` is permitted) but
438/// `mask_for_column("id")` returned `None`, leaking the masked
439/// attribute under a rename.
440pub fn column_aliases(stmt: &Statement) -> Vec<(String, String)> {
441 let Statement::Query(query) = stmt else {
442 return Vec::new();
443 };
444 let SetExpr::Select(select) = query.body.as_ref() else {
445 return Vec::new();
446 };
447 column_aliases_from_select(select).unwrap_or_default()
448}
449
450fn column_aliases_from_select(select: &Select) -> Result<Vec<(String, String)>> {
451 let mut pairs = Vec::new();
452
453 for item in &select.projection {
454 match item {
455 SelectItem::UnnamedExpr(Expr::Identifier(ident)) => {
456 pairs.push((ident.value.clone(), ident.value.clone()));
457 }
458 SelectItem::ExprWithAlias { expr, alias } => {
459 if let Expr::Identifier(ident) = expr {
460 pairs.push((alias.value.clone(), ident.value.clone()));
461 } else {
462 pairs.push((alias.value.clone(), alias.value.clone()));
463 }
464 }
465 SelectItem::Wildcard(_) => {
466 return Err(RbacError::UnsupportedQuery(
467 "SELECT * is not yet supported with RBAC".to_string(),
468 ));
469 }
470 _ => {
471 return Err(RbacError::UnsupportedQuery(format!(
472 "Unsupported SELECT item: {item:?}"
473 )));
474 }
475 }
476 }
477
478 Ok(pairs)
479}
480
481impl RbacFilter {
482
483 /// Rewrites the SELECT projection to include only allowed columns.
484 fn rewrite_projection(select: &mut Select, allowed_columns: &[String]) {
485 let allowed_set: std::collections::HashSet<_> = allowed_columns.iter().collect();
486
487 select.projection.retain(|item| match item {
488 SelectItem::UnnamedExpr(Expr::Identifier(ident))
489 | SelectItem::ExprWithAlias {
490 expr: Expr::Identifier(ident),
491 ..
492 } => allowed_set.contains(&ident.value),
493 _ => false,
494 });
495 }
496
497 /// Injects a WHERE clause for row-level security.
498 fn inject_where_clause(select: &mut Select, where_clause_sql: &str) -> Result<()> {
499 // Parse the WHERE clause SQL into an Expr
500 let where_expr = Self::parse_where_clause(where_clause_sql)?;
501
502 // Combine with existing WHERE clause (if any)
503 select.selection = match select.selection.take() {
504 Some(existing) => Some(Expr::BinaryOp {
505 left: Box::new(existing),
506 op: sqlparser::ast::BinaryOperator::And,
507 right: Box::new(where_expr),
508 }),
509 None => Some(where_expr),
510 };
511
512 Ok(())
513 }
514
515 /// Parses a WHERE clause SQL string into an Expr.
516 ///
517 /// # Security boundary
518 ///
519 /// This function is **only ever called with trusted `RowFilter` values** generated
520 /// internally by the RBAC policy engine (see [`PolicyEnforcer::row_filter`]).
521 /// It is **not** called with user-supplied SQL strings and is therefore not a
522 /// SQL-injection vector. If you ever call this with data derived from user input,
523 /// you MUST validate/sanitize the input first.
524 ///
525 /// The parser handles `column = value` predicates joined by `AND`. It produces
526 /// AST nodes directly (not concatenated SQL), so the result is safe to pass to
527 /// the query planner without further escaping.
528 ///
529 /// More complex predicates may require the full SQL parser.
530 fn parse_where_clause(where_clause_sql: &str) -> Result<Expr> {
531 // Simple parser for "column = value" and "column1 = value1 AND column2 = value2".
532 // SAFETY: Only called with trusted, internally-generated RowFilter strings.
533 let parts: Vec<&str> = where_clause_sql.split(" AND ").collect();
534
535 let mut exprs = Vec::new();
536
537 for part in parts {
538 // Parse "column = value"
539 let tokens: Vec<&str> = part.trim().split('=').collect();
540 if tokens.len() != 2 {
541 return Err(RbacError::UnsupportedQuery(format!(
542 "Invalid WHERE clause: {part}"
543 )));
544 }
545
546 let column = tokens[0].trim();
547 let value = tokens[1].trim();
548
549 let expr = Expr::BinaryOp {
550 left: Box::new(Expr::Identifier(sqlparser::ast::Ident::new(column))),
551 op: sqlparser::ast::BinaryOperator::Eq,
552 right: Box::new(Expr::Value(sqlparser::ast::Value::Number(
553 value.to_string(),
554 false,
555 ))),
556 };
557
558 exprs.push(expr);
559 }
560
561 // Combine with AND
562 let mut iter = exprs.into_iter();
563 let mut result = iter
564 .next()
565 .ok_or_else(|| RbacError::UnsupportedQuery("Empty WHERE clause".to_string()))?;
566
567 for expr in iter {
568 result = Expr::BinaryOp {
569 left: Box::new(result),
570 op: sqlparser::ast::BinaryOperator::And,
571 right: Box::new(expr),
572 };
573 }
574
575 Ok(result)
576 }
577
578 /// Returns the underlying policy enforcer.
579 pub fn enforcer(&self) -> &PolicyEnforcer {
580 &self.enforcer
581 }
582}
583
#[cfg(test)]
mod tests {
    use super::*;
    use kimberlite_rbac::policy::StandardPolicies;
    use kimberlite_types::TenantId;
    use sqlparser::dialect::GenericDialect;
    use sqlparser::parser::Parser;

    /// Parses `sql` with the generic dialect and returns its first statement.
    fn parse_sql(sql: &str) -> Statement {
        let dialect = GenericDialect {};
        let statements = Parser::parse_sql(&dialect, sql).expect("Failed to parse SQL");
        statements.into_iter().next().expect("No statement parsed")
    }

    /// Returns the top-level SELECT of `stmt`.
    ///
    /// Panics, and so fails the calling test, when the statement is not a
    /// query with a plain SELECT body. The earlier `if let` checks passed
    /// vacuously on any other shape, so their assertions could silently
    /// never run.
    fn top_level_select(stmt: &Statement) -> &Select {
        let Statement::Query(query) = stmt else {
            panic!("expected a query statement, got {stmt:?}");
        };
        let SetExpr::Select(select) = query.body.as_ref() else {
            panic!("expected a plain SELECT body, got {:?}", query.body);
        };
        select
    }

    /// Builds an owned `(output, source)` lineage vector from string pairs.
    fn lineage(pairs: &[(&str, &str)]) -> Vec<(String, String)> {
        pairs
            .iter()
            .map(|&(output, source)| (output.to_string(), source.to_string()))
            .collect()
    }

    #[test]
    fn test_rewrite_admin_policy() {
        let filter = RbacFilter::new(StandardPolicies::admin());

        let result = filter.rewrite_statement(parse_sql("SELECT name, email FROM users"));
        assert!(
            result.is_ok(),
            "admin must be able to read users: {:?}",
            result.err()
        );
    }

    #[test]
    fn test_rewrite_user_policy_column_filter() {
        let policy = kimberlite_rbac::policy::AccessPolicy::new(kimberlite_rbac::roles::Role::User)
            .allow_stream("users")
            .allow_column("name")
            .deny_column("ssn");
        let filter = RbacFilter::new(policy);

        let output = filter
            .rewrite_statement(parse_sql("SELECT name, ssn FROM users"))
            .expect("query with one allowed column must be rewritten");

        // `ssn` is filtered out; only `name` remains and the lineage agrees.
        let select = top_level_select(&output.statement);
        assert_eq!(select.projection.len(), 1);
        assert_eq!(select.projection[0].to_string(), "name");
        assert_eq!(output.column_aliases, lineage(&[("name", "name")]));
    }

    #[test]
    fn test_rewrite_with_row_filter() {
        let policy = StandardPolicies::user(TenantId::new(42));
        let filter = RbacFilter::new(policy);

        let output = filter
            .rewrite_statement(parse_sql("SELECT name, email FROM users"))
            .expect("tenant user must be able to read users");

        // A row-level predicate (expected: tenant_id = 42) must be injected.
        assert!(
            top_level_select(&output.statement).selection.is_some(),
            "row-level filter must be injected into WHERE"
        );
    }

    #[test]
    fn test_rewrite_access_denied() {
        let filter = RbacFilter::new(StandardPolicies::auditor());

        // Auditor cannot access the users table.
        let result = filter.rewrite_statement(parse_sql("SELECT name FROM users"));
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), RbacError::AccessDenied(_)));
    }

    #[test]
    fn test_rewrite_no_authorized_columns() {
        let policy = kimberlite_rbac::policy::AccessPolicy::new(kimberlite_rbac::roles::Role::User)
            .allow_stream("users")
            .deny_column("*"); // Deny all columns
        let filter = RbacFilter::new(policy);

        let result = filter.rewrite_statement(parse_sql("SELECT name, email FROM users"));
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            matches!(err, RbacError::AccessDenied(ref msg) if msg.contains("No authorized columns"))
        );
    }

    // -----------------------------------------------------------------
    // AUDIT-2026-04 M-7: subquery / nested-SELECT RBAC enforcement.
    //
    // Before this fix, `rewrite_statement` only processed the top-level
    // SELECT. A predicate like
    //   SELECT id FROM orders WHERE customer IN (SELECT ssn FROM users)
    // passed through untouched because the inner SELECT was never
    // visited, so `ssn` was exposed despite the user's `deny_column`.
    //
    // These tests pin that every nested Query (WHERE IN, EXISTS, derived
    // table in FROM, UNION branch, CTE) is rewritten under the same
    // policy, and that alias lineage tracks the underlying source column.
    // -----------------------------------------------------------------

    /// Policy over `users` and `orders` that allows a few columns and
    /// explicitly denies `ssn`. Any nested reference to `ssn` must be
    /// rejected by the recursive walk.
    fn user_denies_ssn_policy() -> kimberlite_rbac::policy::AccessPolicy {
        kimberlite_rbac::policy::AccessPolicy::new(kimberlite_rbac::roles::Role::User)
            .allow_stream("users")
            .allow_stream("orders")
            .allow_column("name")
            .allow_column("email")
            .allow_column("customer")
            .allow_column("id")
            .deny_column("ssn")
    }

    /// Rewrites `sql` under [`user_denies_ssn_policy`].
    fn rewrite_with_ssn_denied(sql: &str) -> Result<RewriteOutput> {
        RbacFilter::new(user_denies_ssn_policy()).rewrite_statement(parse_sql(sql))
    }

    #[test]
    fn subquery_rbac_in_where_clause_enforces_inner_grants() {
        // Regression test: the inner SELECT only asks for the denied `ssn`,
        // so the whole query must be rejected.
        let result =
            rewrite_with_ssn_denied("SELECT id FROM orders WHERE customer IN (SELECT ssn FROM users)");
        assert!(
            result.is_err(),
            "nested subquery referencing denied column must be rejected"
        );
    }

    #[test]
    fn subquery_rbac_exists_clause_recurses() {
        let result = rewrite_with_ssn_denied("SELECT id FROM orders WHERE EXISTS (SELECT ssn FROM users)");
        assert!(
            result.is_err(),
            "EXISTS-subquery referencing denied column must be rejected"
        );
    }

    #[test]
    fn subquery_rbac_derived_table_in_from_recurses() {
        let result = rewrite_with_ssn_denied("SELECT t.email FROM (SELECT ssn FROM users) t");
        assert!(
            result.is_err(),
            "derived-table SELECT referencing denied column must be rejected"
        );
    }

    #[test]
    fn subquery_rbac_union_both_branches_checked() {
        // The left branch asks for `ssn` (denied), so the whole query is rejected.
        let result = rewrite_with_ssn_denied("SELECT ssn FROM users UNION SELECT name FROM users");
        assert!(
            result.is_err(),
            "UNION branch referencing denied column must be rejected"
        );
    }

    #[test]
    fn subquery_rbac_allowed_subquery_still_succeeds() {
        // A subquery over allowed columns only must not be rejected.
        let result =
            rewrite_with_ssn_denied("SELECT id FROM orders WHERE customer IN (SELECT name FROM users)");
        assert!(
            result.is_ok(),
            "all-allowed subquery must pass, got: {:?}",
            result.err()
        );
    }

    #[test]
    fn subquery_rbac_cte_with_denied_column_rejected() {
        let result = rewrite_with_ssn_denied("WITH u AS (SELECT ssn FROM users) SELECT id FROM orders");
        assert!(
            result.is_err(),
            "CTE referencing denied column must be rejected"
        );
    }

    #[test]
    fn subquery_rbac_deeply_nested_three_levels() {
        // The innermost level references the denied column; the recursive
        // walk must reach it.
        let result = rewrite_with_ssn_denied(
            "SELECT id FROM orders \
             WHERE customer IN ( \
                 SELECT name FROM users \
                 WHERE email IN (SELECT ssn FROM users) \
             )",
        );
        assert!(
            result.is_err(),
            "deeply nested subquery referencing denied column must be rejected"
        );
    }

    #[test]
    fn subquery_rbac_in_list_does_not_recurse_into_values() {
        // `IN (literal_list)` is not a subquery and must not be rejected.
        let result = rewrite_with_ssn_denied("SELECT id FROM orders WHERE customer IN ('alice', 'bob')");
        assert!(
            result.is_ok(),
            "in-list with literal values must pass: {:?}",
            result.err()
        );
    }

    #[test]
    fn aliased_identifier_reports_source_column() {
        // `SELECT name AS n` must report `name` as the source so masks are
        // keyed on the underlying attribute.
        let output = rewrite_with_ssn_denied("SELECT name AS n FROM users")
            .expect("aliased allowed column must pass");
        assert_eq!(output.column_aliases, lineage(&[("n", "name")]));
    }

    #[test]
    fn aliased_denied_column_is_dropped_with_its_lineage() {
        // Renaming a denied column must not smuggle it past the filter.
        let output = rewrite_with_ssn_denied("SELECT ssn AS name, email FROM users")
            .expect("query with an allowed column must be rewritten");
        let select = top_level_select(&output.statement);
        assert_eq!(select.projection.len(), 1);
        assert_eq!(select.projection[0].to_string(), "email");
        assert_eq!(output.column_aliases, lineage(&[("email", "email")]));
    }

    #[test]
    fn column_aliases_resolves_renamed_identifiers() {
        let stmt = parse_sql("SELECT ssn AS id, name FROM patients");
        assert_eq!(
            column_aliases(&stmt),
            lineage(&[("id", "ssn"), ("name", "name")])
        );
    }

    #[test]
    fn column_aliases_keys_expressions_on_alias() {
        let stmt = parse_sql("SELECT UPPER(name) AS shout FROM users");
        assert_eq!(column_aliases(&stmt), lineage(&[("shout", "shout")]));
    }

    #[test]
    fn column_aliases_is_empty_for_non_select_shapes() {
        let union = parse_sql("SELECT name FROM users UNION SELECT email FROM users");
        assert!(column_aliases(&union).is_empty());

        let delete = parse_sql("DELETE FROM users");
        assert!(column_aliases(&delete).is_empty());
    }
}