kimberlite_query/rbac_filter.rs
1//! RBAC query filtering and rewriting.
2//!
3//! This module provides query rewriting to enforce RBAC policies:
4//! - **Column filtering**: Remove unauthorized columns from SELECT
5//! - **Row-level security**: Inject WHERE clauses
6//!
7//! ## Architecture
8//!
9//! ```text
//! ┌─────────────────────────────────────┐
//! │  Original Query                     │
//! │  SELECT name, ssn FROM patients     │
//! └───────────────┬─────────────────────┘
//!                 │
//!                 ▼
//! ┌─────────────────────────────────────┐
//! │  RBAC Filter                        │
//! │  - Check stream access              │
//! │  - Filter columns (remove "ssn")    │
//! │  - Inject WHERE clause              │
//! └───────────────┬─────────────────────┘
//!                 │
//!                 ▼
//! ┌─────────────────────────────────────┐
//! │  Rewritten Query                    │
//! │  SELECT name FROM patients          │
//! │  WHERE tenant_id = 42               │
//! └─────────────────────────────────────┘
29//! ```
30
31use crate::error::QueryError;
32use kimberlite_rbac::{AccessPolicy, enforcement::PolicyEnforcer};
33use sqlparser::ast::{Expr, Query, Select, SelectItem, SetExpr, Statement, TableFactor};
34use thiserror::Error;
35use tracing::{debug, info, warn};
36
37/// Error type for RBAC filtering.
38#[derive(Debug, Error)]
39pub enum RbacError {
40 /// Access denied by policy.
41 #[error("Access denied: {0}")]
42 AccessDenied(String),
43
44 /// No authorized columns in query.
45 #[error("No authorized columns in query")]
46 NoAuthorizedColumns,
47
48 /// Unsupported query type for RBAC.
49 #[error("Unsupported query type: {0}")]
50 UnsupportedQuery(String),
51
52 /// Policy enforcement failed.
53 #[error("Policy enforcement failed: {0}")]
54 EnforcementFailed(String),
55}
56
57impl From<kimberlite_rbac::enforcement::EnforcementError> for RbacError {
58 fn from(err: kimberlite_rbac::enforcement::EnforcementError) -> Self {
59 match err {
60 kimberlite_rbac::enforcement::EnforcementError::AccessDenied { reason } => {
61 RbacError::AccessDenied(reason)
62 }
63 _ => RbacError::EnforcementFailed(err.to_string()),
64 }
65 }
66}
67
68impl From<RbacError> for QueryError {
69 fn from(err: RbacError) -> Self {
70 QueryError::UnsupportedFeature(err.to_string())
71 }
72}
73
74/// Result type for RBAC operations.
75pub type Result<T> = std::result::Result<T, RbacError>;
76
77/// Output of [`RbacFilter::rewrite_statement`].
78///
79/// Carries the rewritten statement alongside the alias mapping derived
80/// from the original projection. Downstream code (e.g. the masking pass
81/// in `kimberlite`) uses the mapping to resolve output column names
82/// back to their source columns so masks are applied to the underlying
83/// sensitive attribute, not the user-visible alias.
84#[derive(Debug)]
85pub struct RewriteOutput {
86 /// The rewritten SQL statement.
87 pub statement: Statement,
88 /// Pairs of `(output_column_name, source_column_name)` for each
89 /// projection item that survived RBAC filtering.
90 ///
91 /// Bare identifiers produce pairs where both entries are equal.
92 /// Aliased identifiers (`SELECT ssn AS id`) produce distinct
93 /// output/source entries — the masking pass must key its lookup
94 /// on the source entry (AUDIT-2026-04 M-7).
95 pub column_aliases: Vec<(String, String)>,
96}
97
98/// RBAC query filter.
99///
100/// Rewrites SQL queries to enforce access control policies.
101pub struct RbacFilter {
102 enforcer: PolicyEnforcer,
103}
104
105impl RbacFilter {
106 /// Creates a new RBAC filter with the given policy.
107 pub fn new(policy: AccessPolicy) -> Self {
108 Self {
109 enforcer: PolicyEnforcer::new(policy),
110 }
111 }
112
113 /// Rewrites a SQL statement to enforce RBAC policy.
114 ///
115 /// **Transformations:**
116 /// 1. Check stream access (deny if unauthorized)
117 /// 2. Filter SELECT columns (remove unauthorized columns)
118 /// 3. Inject WHERE clause for row-level security
119 ///
120 /// # Arguments
121 ///
122 /// * `stmt` - SQL statement to rewrite
123 ///
124 /// # Returns
125 ///
126 /// Rewritten statement plus a map of `(output_column_name,
127 /// source_column_name)` pairs — one entry per projection item that
128 /// survived RBAC filtering. The masking pass uses this map to look
129 /// up column masks by source column rather than by the
130 /// potentially-aliased output name (AUDIT-2026-04 M-7).
131 ///
132 /// # Errors
133 ///
134 /// - `AccessDenied` if stream access is denied
135 /// - `NoAuthorizedColumns` if all columns are unauthorized
136 /// - `UnsupportedQuery` if query type is not supported
137 pub fn rewrite_statement(&self, mut stmt: Statement) -> Result<RewriteOutput> {
138 match &mut stmt {
139 Statement::Query(query) => {
140 let column_aliases = self.rewrite_query(query)?;
141 Ok(RewriteOutput {
142 statement: stmt,
143 column_aliases,
144 })
145 }
146 _ => Err(RbacError::UnsupportedQuery(
147 "Only SELECT queries are currently supported".to_string(),
148 )),
149 }
150 }
151
152 /// Rewrites a query to enforce RBAC.
153 ///
154 /// **AUDIT-2026-04 M-7 — recursive traversal.** Prior to this
155 /// change, only the top-level `SetExpr::Select` was rewritten,
156 /// so a predicate like
157 /// `SELECT id FROM t WHERE x IN (SELECT ssn FROM patients)`
158 /// would bypass column filtering on `ssn`. The recursive walk
159 /// below ensures every nested `Query` (CTE / UNION / subquery
160 /// in FROM / subquery in WHERE) is rewritten under the same
161 /// policy before the outer select is processed.
162 fn rewrite_query(&self, query: &mut Query) -> Result<Vec<(String, String)>> {
163 // 1. Rewrite CTEs first — later referenced by name in the
164 // main set-expression, so their filtering must land
165 // before the outer select reads them.
166 if let Some(with) = query.with.as_mut() {
167 for cte in &mut with.cte_tables {
168 // CTEs themselves cannot leak if the outer select
169 // never references the denied columns — but we
170 // rewrite defensively so that any CTE reference
171 // through `SELECT * FROM cte_name` (once wildcard
172 // support lands) does not expose masked sources.
173 let _ = self.rewrite_query(cte.query.as_mut())?;
174 }
175 }
176
177 // 2. Dispatch on set-expression shape.
178 self.rewrite_set_expr(query.body.as_mut())
179 }
180
181 /// Recursively rewrites a `SetExpr`, returning the column
182 /// lineage for the *representative* select (the left-most
183 /// branch of a UNION, or the inner select of a parenthesised
184 /// query).
185 ///
186 /// UNION branches must all satisfy the policy independently —
187 /// if any branch references a denied column, the whole query
188 /// is rejected.
189 fn rewrite_set_expr(&self, set_expr: &mut SetExpr) -> Result<Vec<(String, String)>> {
190 match set_expr {
191 SetExpr::Select(select) => self.rewrite_select(select),
192 // Parenthesised query — recurse.
193 SetExpr::Query(inner) => self.rewrite_query(inner.as_mut()),
194 // UNION / INTERSECT / EXCEPT — every branch must pass
195 // RBAC independently. The outer lineage comes from the
196 // left branch (all branches must have compatible
197 // column counts, so either branch's lineage is a valid
198 // descriptor; we use left for determinism).
199 SetExpr::SetOperation { left, right, .. } => {
200 let left_lineage = self.rewrite_set_expr(left.as_mut())?;
201 let _right_lineage = self.rewrite_set_expr(right.as_mut())?;
202 Ok(left_lineage)
203 }
204 _ => Err(RbacError::UnsupportedQuery(format!(
205 "unsupported set-expression: {set_expr:?}"
206 ))),
207 }
208 }
209
210 /// Rewrites a SELECT statement. Returns the `(output, source)`
211 /// column pairs for the surviving projection items.
212 fn rewrite_select(&self, select: &mut Select) -> Result<Vec<(String, String)>> {
213 // AUDIT-2026-04 M-7 — subquery / nested-SELECT recursion.
214 //
215 // Step 0a: rewrite any `TableFactor::Derived { subquery }`
216 // in the FROM clause. A predicate that reads
217 // `SELECT outer.x FROM (SELECT ssn AS x FROM patients) outer`
218 // was previously accepted because `extract_stream_name` only
219 // saw the outer derived-table reference — the inner SELECT
220 // was never filtered against the `patients.ssn` deny policy.
221 // Now the inner SELECT is rewritten first; if it references
222 // a denied column it errors out here, before any outer
223 // lineage is reported.
224 for table_with_joins in &mut select.from {
225 self.rewrite_table_factor(&mut table_with_joins.relation)?;
226 for join in &mut table_with_joins.joins {
227 self.rewrite_table_factor(&mut join.relation)?;
228 }
229 }
230
231 // Step 0b: rewrite subqueries inside the WHERE clause.
232 // Handles `IN (SELECT ...)`, `EXISTS (SELECT ...)`, and
233 // scalar-subquery forms. The traversal is read-mutable
234 // because the inner rewrite replaces column projections.
235 if let Some(ref mut selection) = select.selection {
236 self.rewrite_expr_subqueries(selection)?;
237 }
238
239 // 1. Extract stream name from FROM clause
240 let stream_name = Self::extract_stream_name(select)?;
241
242 debug!(stream = %stream_name, "Extracting columns from SELECT");
243
244 // 2. Extract requested columns (source names) and aliases
245 let aliases = Self::extract_column_aliases(select)?;
246 let requested_columns: Vec<String> = aliases.iter().map(|(_, src)| src.clone()).collect();
247
248 info!(
249 stream = %stream_name,
250 columns = ?requested_columns,
251 "Enforcing RBAC policy"
252 );
253
254 // 3. Enforce policy (checks stream access + filters columns)
255 let (allowed_columns, where_clause_sql) = self
256 .enforcer
257 .enforce_query(&stream_name, &requested_columns)?;
258
259 if allowed_columns.is_empty() {
260 warn!(stream = %stream_name, "No authorized columns");
261 return Err(RbacError::NoAuthorizedColumns);
262 }
263
264 // 4. Rewrite SELECT projection (filter columns)
265 Self::rewrite_projection(select, &allowed_columns);
266
267 // 5. Inject WHERE clause for row-level security
268 if !where_clause_sql.is_empty() {
269 Self::inject_where_clause(select, &where_clause_sql)?;
270 }
271
272 info!(
273 stream = %stream_name,
274 allowed_columns = ?allowed_columns,
275 where_clause = %where_clause_sql,
276 "Query rewritten successfully"
277 );
278
279 // 6. Trim the alias map to the surviving projection.
280 let allowed: std::collections::HashSet<&str> =
281 allowed_columns.iter().map(String::as_str).collect();
282 let surviving_aliases = aliases
283 .into_iter()
284 .filter(|(_, src)| allowed.contains(src.as_str()))
285 .collect();
286
287 Ok(surviving_aliases)
288 }
289
290 /// AUDIT-2026-04 M-7 helper — recurse into nested queries
291 /// carried by a `TableFactor`.
292 ///
293 /// `TableFactor::Derived { subquery }` is the AST node for
294 /// `FROM (SELECT ...)`. `TableFactor::NestedJoin` wraps a
295 /// `TableWithJoins` that may itself contain derived tables.
296 /// Anything else is a terminal table reference handled by
297 /// `extract_stream_name` downstream.
298 fn rewrite_table_factor(&self, factor: &mut TableFactor) -> Result<()> {
299 match factor {
300 TableFactor::Derived { subquery, .. } => {
301 self.rewrite_query(subquery.as_mut())?;
302 Ok(())
303 }
304 TableFactor::NestedJoin {
305 table_with_joins, ..
306 } => {
307 self.rewrite_table_factor(&mut table_with_joins.relation)?;
308 for join in &mut table_with_joins.joins {
309 self.rewrite_table_factor(&mut join.relation)?;
310 }
311 Ok(())
312 }
313 _ => Ok(()),
314 }
315 }
316
317 /// AUDIT-2026-04 M-7 helper — recurse into subqueries embedded
318 /// in a WHERE-clause `Expr`.
319 ///
320 /// Walks `Expr::Subquery`, `Expr::InSubquery`, `Expr::Exists`,
321 /// and combinators (`BinaryOp`, `UnaryOp`, `Nested`) that can
322 /// transport a subquery in their children. Non-subquery leaves
323 /// (identifiers, literals) are terminal.
324 ///
325 /// A bounded-depth guard would belong here if the recursive
326 /// kernel principle forbade it; the query parser already
327 /// rejects SQL with unbounded expression depth before reaching
328 /// this point, so we rely on the sqlparser limit.
329 fn rewrite_expr_subqueries(&self, expr: &mut Expr) -> Result<()> {
330 match expr {
331 Expr::Subquery(q) | Expr::Exists { subquery: q, .. } => {
332 self.rewrite_query(q.as_mut())?;
333 Ok(())
334 }
335 Expr::InSubquery {
336 subquery,
337 expr: inner,
338 ..
339 } => {
340 self.rewrite_expr_subqueries(inner.as_mut())?;
341 self.rewrite_query(subquery.as_mut())?;
342 Ok(())
343 }
344 Expr::BinaryOp { left, right, .. } => {
345 self.rewrite_expr_subqueries(left.as_mut())?;
346 self.rewrite_expr_subqueries(right.as_mut())
347 }
348 Expr::UnaryOp { expr: inner, .. } | Expr::Nested(inner) => {
349 self.rewrite_expr_subqueries(inner.as_mut())
350 }
351 Expr::InList {
352 expr: inner, list, ..
353 } => {
354 self.rewrite_expr_subqueries(inner.as_mut())?;
355 for item in list.iter_mut() {
356 self.rewrite_expr_subqueries(item)?;
357 }
358 Ok(())
359 }
360 Expr::Between {
361 expr: inner,
362 low,
363 high,
364 ..
365 } => {
366 self.rewrite_expr_subqueries(inner.as_mut())?;
367 self.rewrite_expr_subqueries(low.as_mut())?;
368 self.rewrite_expr_subqueries(high.as_mut())
369 }
370 Expr::Case {
371 conditions,
372 else_result,
373 ..
374 } => {
375 for case_when in conditions.iter_mut() {
376 self.rewrite_expr_subqueries(&mut case_when.condition)?;
377 self.rewrite_expr_subqueries(&mut case_when.result)?;
378 }
379 if let Some(else_r) = else_result.as_mut() {
380 self.rewrite_expr_subqueries(else_r.as_mut())?;
381 }
382 Ok(())
383 }
384 // Identifiers, literals, function calls without subquery
385 // arguments, etc. — nothing to rewrite.
386 _ => Ok(()),
387 }
388 }
389
390 /// Extracts the stream name from the FROM clause.
391 fn extract_stream_name(select: &Select) -> Result<String> {
392 if select.from.is_empty() {
393 return Err(RbacError::UnsupportedQuery(
394 "SELECT without FROM clause".to_string(),
395 ));
396 }
397
398 let table = &select.from[0];
399 match &table.relation {
400 TableFactor::Table { name, .. } => {
401 let stream_name = name
402 .0
403 .iter()
404 .map(|part| match part.as_ident() {
405 Some(ident) => ident.value.clone(),
406 None => part.to_string(),
407 })
408 .collect::<Vec<_>>()
409 .join(".");
410 Ok(stream_name)
411 }
412 _ => Err(RbacError::UnsupportedQuery(
413 "Only simple table references are supported".to_string(),
414 )),
415 }
416 }
417
418 /// Extracts `(output_column_name, source_column_name)` pairs for
419 /// each item in the SELECT projection. See [`column_aliases`] for
420 /// the free-function entry point used by the SQL-level mask pass.
421 fn extract_column_aliases(select: &Select) -> Result<Vec<(String, String)>> {
422 column_aliases_from_select(select)
423 }
424}
425
426/// Extracts `(output_column_name, source_column_name)` pairs for each
427/// item in the SELECT projection of `stmt`.
428///
429/// Returns an empty vector for non-`SELECT` statements or for set-expr
430/// bodies that are not a plain `SELECT` (e.g. `UNION`) — the masking
431/// pass treats an empty map as "no aliases known" and falls back to
432/// output-name keying, matching pre-M-7 semantics for those shapes.
433///
434/// Semantics:
435/// - `SELECT col` → `("col", "col")`
436/// - `SELECT col AS alias` → `("alias", "col")`
437/// - `SELECT UPPER(col) AS alias` → `("alias", "alias")` (non-identifier
438/// expressions cannot be resolved to a source column — mask lookup
439/// keys on the alias, mirroring the pre-M-7 behaviour).
440///
441/// AUDIT-2026-04 M-7: the masking pass uses the source half of the
442/// pair to look up column masks. Without this, `SELECT ssn AS id FROM
443/// patients` passed RBAC (source `ssn` is permitted) but
444/// `mask_for_column("id")` returned `None`, leaking the masked
445/// attribute under a rename.
446pub fn column_aliases(stmt: &Statement) -> Vec<(String, String)> {
447 let Statement::Query(query) = stmt else {
448 return Vec::new();
449 };
450 let SetExpr::Select(select) = query.body.as_ref() else {
451 return Vec::new();
452 };
453 column_aliases_from_select(select).unwrap_or_default()
454}
455
456fn column_aliases_from_select(select: &Select) -> Result<Vec<(String, String)>> {
457 let mut pairs = Vec::new();
458
459 for item in &select.projection {
460 match item {
461 SelectItem::UnnamedExpr(Expr::Identifier(ident)) => {
462 pairs.push((ident.value.clone(), ident.value.clone()));
463 }
464 SelectItem::ExprWithAlias { expr, alias } => {
465 if let Expr::Identifier(ident) = expr {
466 pairs.push((alias.value.clone(), ident.value.clone()));
467 } else {
468 pairs.push((alias.value.clone(), alias.value.clone()));
469 }
470 }
471 SelectItem::Wildcard(_) => {
472 return Err(RbacError::UnsupportedQuery(
473 "SELECT * is not yet supported with RBAC".to_string(),
474 ));
475 }
476 _ => {
477 return Err(RbacError::UnsupportedQuery(format!(
478 "Unsupported SELECT item: {item:?}"
479 )));
480 }
481 }
482 }
483
484 Ok(pairs)
485}
486
487impl RbacFilter {
488 /// Rewrites the SELECT projection to include only allowed columns.
489 fn rewrite_projection(select: &mut Select, allowed_columns: &[String]) {
490 let allowed_set: std::collections::HashSet<_> = allowed_columns.iter().collect();
491
492 select.projection.retain(|item| match item {
493 SelectItem::UnnamedExpr(Expr::Identifier(ident))
494 | SelectItem::ExprWithAlias {
495 expr: Expr::Identifier(ident),
496 ..
497 } => allowed_set.contains(&ident.value),
498 _ => false,
499 });
500 }
501
502 /// Injects a WHERE clause for row-level security.
503 fn inject_where_clause(select: &mut Select, where_clause_sql: &str) -> Result<()> {
504 // Parse the WHERE clause SQL into an Expr
505 let where_expr = Self::parse_where_clause(where_clause_sql)?;
506
507 // Combine with existing WHERE clause (if any)
508 select.selection = match select.selection.take() {
509 Some(existing) => Some(Expr::BinaryOp {
510 left: Box::new(existing),
511 op: sqlparser::ast::BinaryOperator::And,
512 right: Box::new(where_expr),
513 }),
514 None => Some(where_expr),
515 };
516
517 Ok(())
518 }
519
520 /// Parses a WHERE clause SQL string into an Expr.
521 ///
522 /// # Security boundary
523 ///
524 /// This function is **only ever called with trusted `RowFilter` values** generated
525 /// internally by the RBAC policy engine (see [`PolicyEnforcer::row_filter`]).
526 /// It is **not** called with user-supplied SQL strings and is therefore not a
527 /// SQL-injection vector. If you ever call this with data derived from user input,
528 /// you MUST validate/sanitize the input first.
529 ///
530 /// The parser handles `column = value` predicates joined by `AND`. It produces
531 /// AST nodes directly (not concatenated SQL), so the result is safe to pass to
532 /// the query planner without further escaping.
533 ///
534 /// More complex predicates may require the full SQL parser.
535 fn parse_where_clause(where_clause_sql: &str) -> Result<Expr> {
536 // Simple parser for "column = value" and "column1 = value1 AND column2 = value2".
537 // SAFETY: Only called with trusted, internally-generated RowFilter strings.
538 let parts: Vec<&str> = where_clause_sql.split(" AND ").collect();
539
540 let mut exprs = Vec::new();
541
542 for part in parts {
543 // Parse "column = value"
544 let tokens: Vec<&str> = part.trim().split('=').collect();
545 if tokens.len() != 2 {
546 return Err(RbacError::UnsupportedQuery(format!(
547 "Invalid WHERE clause: {part}"
548 )));
549 }
550
551 let column = tokens[0].trim();
552 let value = tokens[1].trim();
553
554 let expr = Expr::BinaryOp {
555 left: Box::new(Expr::Identifier(sqlparser::ast::Ident::new(column))),
556 op: sqlparser::ast::BinaryOperator::Eq,
557 right: Box::new(Expr::Value(
558 sqlparser::ast::Value::Number(value.to_string(), false).into(),
559 )),
560 };
561
562 exprs.push(expr);
563 }
564
565 // Combine with AND
566 let mut iter = exprs.into_iter();
567 let mut result = iter
568 .next()
569 .ok_or_else(|| RbacError::UnsupportedQuery("Empty WHERE clause".to_string()))?;
570
571 for expr in iter {
572 result = Expr::BinaryOp {
573 left: Box::new(result),
574 op: sqlparser::ast::BinaryOperator::And,
575 right: Box::new(expr),
576 };
577 }
578
579 Ok(result)
580 }
581
582 /// Returns the underlying policy enforcer.
583 pub fn enforcer(&self) -> &PolicyEnforcer {
584 &self.enforcer
585 }
586}
587
588#[cfg(test)]
589mod tests {
590 use super::*;
591 use kimberlite_rbac::policy::StandardPolicies;
592 use kimberlite_types::TenantId;
593 use sqlparser::dialect::GenericDialect;
594 use sqlparser::parser::Parser;
595
596 fn parse_sql(sql: &str) -> Statement {
597 let dialect = GenericDialect {};
598 let statements = Parser::parse_sql(&dialect, sql).expect("Failed to parse SQL");
599 statements.into_iter().next().expect("No statement parsed")
600 }
601
602 #[test]
603 fn test_rewrite_admin_policy() {
604 let policy = StandardPolicies::admin();
605 let filter = RbacFilter::new(policy);
606
607 let sql = "SELECT name, email FROM users";
608 let stmt = parse_sql(sql);
609
610 let result = filter.rewrite_statement(stmt);
611 assert!(result.is_ok());
612 }
613
614 #[test]
615 fn test_rewrite_user_policy_column_filter() {
616 let policy = kimberlite_rbac::policy::AccessPolicy::new(kimberlite_rbac::roles::Role::User)
617 .allow_stream("users")
618 .allow_column("name")
619 .deny_column("ssn");
620
621 let filter = RbacFilter::new(policy);
622
623 let sql = "SELECT name, ssn FROM users";
624 let stmt = parse_sql(sql);
625
626 let result = filter.rewrite_statement(stmt);
627 assert!(result.is_ok());
628
629 // Check that ssn was filtered out
630 if let Statement::Query(query) = result.unwrap().statement {
631 if let SetExpr::Select(select) = query.body.as_ref() {
632 assert_eq!(select.projection.len(), 1);
633 // Should only have "name" column
634 }
635 }
636 }
637
638 #[test]
639 fn test_rewrite_with_row_filter() {
640 let tenant_id = TenantId::new(42);
641 let policy = StandardPolicies::user(tenant_id);
642 let filter = RbacFilter::new(policy);
643
644 let sql = "SELECT name, email FROM users";
645 let stmt = parse_sql(sql);
646
647 let result = filter.rewrite_statement(stmt);
648 assert!(result.is_ok());
649
650 // Check that WHERE clause was injected
651 if let Statement::Query(query) = result.unwrap().statement {
652 if let SetExpr::Select(select) = query.body.as_ref() {
653 assert!(select.selection.is_some());
654 // Should have WHERE tenant_id = 42
655 }
656 }
657 }
658
659 #[test]
660 fn test_rewrite_access_denied() {
661 let policy = StandardPolicies::auditor();
662 let filter = RbacFilter::new(policy);
663
664 let sql = "SELECT name FROM users"; // Auditor cannot access users table
665 let stmt = parse_sql(sql);
666
667 let result = filter.rewrite_statement(stmt);
668 assert!(result.is_err());
669 assert!(matches!(result.unwrap_err(), RbacError::AccessDenied(_)));
670 }
671
672 #[test]
673 fn test_rewrite_no_authorized_columns() {
674 let policy = kimberlite_rbac::policy::AccessPolicy::new(kimberlite_rbac::roles::Role::User)
675 .allow_stream("users")
676 .deny_column("*"); // Deny all columns
677
678 let filter = RbacFilter::new(policy);
679
680 let sql = "SELECT name, email FROM users";
681 let stmt = parse_sql(sql);
682
683 let result = filter.rewrite_statement(stmt);
684 assert!(result.is_err());
685 let err = result.unwrap_err();
686 assert!(
687 matches!(err, RbacError::AccessDenied(ref msg) if msg.contains("No authorized columns"))
688 );
689 }
690
691 // -----------------------------------------------------------------
692 // AUDIT-2026-04 M-7 — subquery / nested-SELECT RBAC enforcement.
693 //
694 // Before this fix, `rewrite_statement` only processed the
695 // top-level SELECT. A predicate like
696 // SELECT id FROM orders WHERE customer IN (SELECT ssn FROM users)
697 // passed through untouched because the inner SELECT was never
698 // visited; `ssn` was exposed despite the user's `deny_column`.
699 //
700 // These tests pin that every nested Query (WHERE IN, EXISTS,
701 // derived table in FROM, UNION branch) is rewritten under the
702 // same policy.
703 // -----------------------------------------------------------------
704
705 fn user_denies_ssn_policy() -> kimberlite_rbac::policy::AccessPolicy {
706 // `users` stream is fully accessible on the allow-list, but
707 // `ssn` is explicitly denied. Any nested reference to
708 // `ssn` must be rejected by the recursive walk.
709 kimberlite_rbac::policy::AccessPolicy::new(kimberlite_rbac::roles::Role::User)
710 .allow_stream("users")
711 .allow_stream("orders")
712 .allow_column("name")
713 .allow_column("email")
714 .allow_column("customer")
715 .allow_column("id")
716 .deny_column("ssn")
717 }
718
719 #[test]
720 fn subquery_rbac_in_where_clause_enforces_inner_grants() {
721 // AUDIT-2026-04 M-7 regression test. Prior to the fix, this
722 // returned `Ok(_)` — `ssn` was never seen by the enforcer.
723 // After the fix, the inner SELECT is rewritten, and since
724 // `ssn` is denied + the inner projection has no other
725 // allowed columns, the whole query is rejected.
726 let filter = RbacFilter::new(user_denies_ssn_policy());
727 let sql = "SELECT id FROM orders WHERE customer IN (SELECT ssn FROM users)";
728 let stmt = parse_sql(sql);
729 let result = filter.rewrite_statement(stmt);
730 assert!(
731 result.is_err(),
732 "nested subquery referencing denied column must be rejected"
733 );
734 }
735
736 #[test]
737 fn subquery_rbac_exists_clause_recurses() {
738 // EXISTS subqueries are rewritten too.
739 let filter = RbacFilter::new(user_denies_ssn_policy());
740 let sql = "SELECT id FROM orders WHERE EXISTS (SELECT ssn FROM users)";
741 let stmt = parse_sql(sql);
742 let result = filter.rewrite_statement(stmt);
743 assert!(
744 result.is_err(),
745 "EXISTS-subquery referencing denied column must be rejected"
746 );
747 }
748
749 #[test]
750 fn subquery_rbac_derived_table_in_from_recurses() {
751 // Derived-table subquery in FROM clause.
752 let filter = RbacFilter::new(user_denies_ssn_policy());
753 let sql = "SELECT t.email FROM (SELECT ssn FROM users) t";
754 let stmt = parse_sql(sql);
755 let result = filter.rewrite_statement(stmt);
756 assert!(
757 result.is_err(),
758 "derived-table SELECT referencing denied column must be rejected"
759 );
760 }
761
762 #[test]
763 fn subquery_rbac_union_both_branches_checked() {
764 // UNION — both branches must pass RBAC. The left branch
765 // asks for `ssn` (denied) → whole query rejected.
766 let filter = RbacFilter::new(user_denies_ssn_policy());
767 let sql = "SELECT ssn FROM users UNION SELECT name FROM users";
768 let stmt = parse_sql(sql);
769 let result = filter.rewrite_statement(stmt);
770 assert!(
771 result.is_err(),
772 "UNION branch referencing denied column must be rejected"
773 );
774 }
775
776 #[test]
777 fn subquery_rbac_allowed_subquery_still_succeeds() {
778 // Sanity check: a subquery that references only allowed
779 // columns is unaffected — the M-7 fix must not introduce
780 // false-positive rejections.
781 let filter = RbacFilter::new(user_denies_ssn_policy());
782 let sql = "SELECT id FROM orders WHERE customer IN (SELECT name FROM users)";
783 let stmt = parse_sql(sql);
784 let result = filter.rewrite_statement(stmt);
785 assert!(
786 result.is_ok(),
787 "all-allowed subquery must pass, got: {:?}",
788 result.err()
789 );
790 }
791
792 #[test]
793 fn subquery_rbac_cte_with_denied_column_rejected() {
794 // CTEs are rewritten before the outer select reads them.
795 let filter = RbacFilter::new(user_denies_ssn_policy());
796 let sql = "WITH u AS (SELECT ssn FROM users) SELECT id FROM orders";
797 let stmt = parse_sql(sql);
798 let result = filter.rewrite_statement(stmt);
799 assert!(
800 result.is_err(),
801 "CTE referencing denied column must be rejected"
802 );
803 }
804
805 #[test]
806 fn subquery_rbac_deeply_nested_three_levels() {
807 // Three levels of nesting — inner-most references denied
808 // column. Recursive walk must reach it.
809 let filter = RbacFilter::new(user_denies_ssn_policy());
810 let sql = "SELECT id FROM orders \
811 WHERE customer IN ( \
812 SELECT name FROM users \
813 WHERE email IN (SELECT ssn FROM users) \
814 )";
815 let stmt = parse_sql(sql);
816 let result = filter.rewrite_statement(stmt);
817 assert!(
818 result.is_err(),
819 "deeply nested subquery referencing denied column must be rejected"
820 );
821 }
822
823 #[test]
824 fn subquery_rbac_in_list_does_not_recurse_into_values() {
825 // `IN (literal_list)` is NOT a subquery — no recursion
826 // needed. The fix must not trip on regular in-list
827 // predicates.
828 let filter = RbacFilter::new(user_denies_ssn_policy());
829 let sql = "SELECT id FROM orders WHERE customer IN ('alice', 'bob')";
830 let stmt = parse_sql(sql);
831 let result = filter.rewrite_statement(stmt);
832 assert!(
833 result.is_ok(),
834 "in-list with literal values must pass: {:?}",
835 result.err()
836 );
837 }
838}