Skip to main content

kimberlite_query/
rbac_filter.rs

1//! RBAC query filtering and rewriting.
2//!
3//! This module provides query rewriting to enforce RBAC policies:
4//! - **Column filtering**: Remove unauthorized columns from SELECT
5//! - **Row-level security**: Inject WHERE clauses
6//!
7//! ## Architecture
8//!
9//! ```text
10//! ┌─────────────────────────────────────┐
11//! │  Original Query                      │
12//! │  SELECT name, ssn FROM patients      │
13//! └───────────────┬─────────────────────┘
14//!                 │
15//!                 ▼
16//! ┌─────────────────────────────────────┐
17//! │  RBAC Filter                         │
18//! │  - Check stream access               │
19//! │  - Filter columns (remove "ssn")     │
20//! │  - Inject WHERE clause               │
21//! └───────────────┬─────────────────────┘
22//!                 │
23//!                 ▼
24//! ┌─────────────────────────────────────┐
25//! │  Rewritten Query                     │
26//! │  SELECT name FROM patients           │
27//! │  WHERE tenant_id = 42                │
28//! └─────────────────────────────────────┘
29//! ```
30
31use crate::error::QueryError;
32use kimberlite_rbac::{AccessPolicy, enforcement::PolicyEnforcer};
33use sqlparser::ast::{Expr, Query, Select, SelectItem, SetExpr, Statement, TableFactor};
34use thiserror::Error;
35use tracing::{debug, info, warn};
36
37/// Error type for RBAC filtering.
38#[derive(Debug, Error)]
39pub enum RbacError {
40    /// Access denied by policy.
41    #[error("Access denied: {0}")]
42    AccessDenied(String),
43
44    /// No authorized columns in query.
45    #[error("No authorized columns in query")]
46    NoAuthorizedColumns,
47
48    /// Unsupported query type for RBAC.
49    #[error("Unsupported query type: {0}")]
50    UnsupportedQuery(String),
51
52    /// Policy enforcement failed.
53    #[error("Policy enforcement failed: {0}")]
54    EnforcementFailed(String),
55}
56
57impl From<kimberlite_rbac::enforcement::EnforcementError> for RbacError {
58    fn from(err: kimberlite_rbac::enforcement::EnforcementError) -> Self {
59        match err {
60            kimberlite_rbac::enforcement::EnforcementError::AccessDenied { reason } => {
61                RbacError::AccessDenied(reason)
62            }
63            _ => RbacError::EnforcementFailed(err.to_string()),
64        }
65    }
66}
67
68impl From<RbacError> for QueryError {
69    fn from(err: RbacError) -> Self {
70        QueryError::UnsupportedFeature(err.to_string())
71    }
72}
73
74/// Result type for RBAC operations.
75pub type Result<T> = std::result::Result<T, RbacError>;
76
77/// Output of [`RbacFilter::rewrite_statement`].
78///
79/// Carries the rewritten statement alongside the alias mapping derived
80/// from the original projection. Downstream code (e.g. the masking pass
81/// in `kimberlite`) uses the mapping to resolve output column names
82/// back to their source columns so masks are applied to the underlying
83/// sensitive attribute, not the user-visible alias.
84#[derive(Debug)]
85pub struct RewriteOutput {
86    /// The rewritten SQL statement.
87    pub statement: Statement,
88    /// Pairs of `(output_column_name, source_column_name)` for each
89    /// projection item that survived RBAC filtering.
90    ///
91    /// Bare identifiers produce pairs where both entries are equal.
92    /// Aliased identifiers (`SELECT ssn AS id`) produce distinct
93    /// output/source entries — the masking pass must key its lookup
94    /// on the source entry (AUDIT-2026-04 M-7).
95    pub column_aliases: Vec<(String, String)>,
96}
97
98/// RBAC query filter.
99///
100/// Rewrites SQL queries to enforce access control policies.
101pub struct RbacFilter {
102    enforcer: PolicyEnforcer,
103}
104
105impl RbacFilter {
106    /// Creates a new RBAC filter with the given policy.
107    pub fn new(policy: AccessPolicy) -> Self {
108        Self {
109            enforcer: PolicyEnforcer::new(policy),
110        }
111    }
112
113    /// Rewrites a SQL statement to enforce RBAC policy.
114    ///
115    /// **Transformations:**
116    /// 1. Check stream access (deny if unauthorized)
117    /// 2. Filter SELECT columns (remove unauthorized columns)
118    /// 3. Inject WHERE clause for row-level security
119    ///
120    /// # Arguments
121    ///
122    /// * `stmt` - SQL statement to rewrite
123    ///
124    /// # Returns
125    ///
126    /// Rewritten statement plus a map of `(output_column_name,
127    /// source_column_name)` pairs — one entry per projection item that
128    /// survived RBAC filtering. The masking pass uses this map to look
129    /// up column masks by source column rather than by the
130    /// potentially-aliased output name (AUDIT-2026-04 M-7).
131    ///
132    /// # Errors
133    ///
134    /// - `AccessDenied` if stream access is denied
135    /// - `NoAuthorizedColumns` if all columns are unauthorized
136    /// - `UnsupportedQuery` if query type is not supported
137    pub fn rewrite_statement(&self, mut stmt: Statement) -> Result<RewriteOutput> {
138        match &mut stmt {
139            Statement::Query(query) => {
140                let column_aliases = self.rewrite_query(query)?;
141                Ok(RewriteOutput {
142                    statement: stmt,
143                    column_aliases,
144                })
145            }
146            _ => Err(RbacError::UnsupportedQuery(
147                "Only SELECT queries are currently supported".to_string(),
148            )),
149        }
150    }
151
152    /// Rewrites a query to enforce RBAC.
153    ///
154    /// **AUDIT-2026-04 M-7 — recursive traversal.** Prior to this
155    /// change, only the top-level `SetExpr::Select` was rewritten,
156    /// so a predicate like
157    /// `SELECT id FROM t WHERE x IN (SELECT ssn FROM patients)`
158    /// would bypass column filtering on `ssn`. The recursive walk
159    /// below ensures every nested `Query` (CTE / UNION / subquery
160    /// in FROM / subquery in WHERE) is rewritten under the same
161    /// policy before the outer select is processed.
162    fn rewrite_query(&self, query: &mut Query) -> Result<Vec<(String, String)>> {
163        // 1. Rewrite CTEs first — later referenced by name in the
164        //    main set-expression, so their filtering must land
165        //    before the outer select reads them.
166        if let Some(with) = query.with.as_mut() {
167            for cte in &mut with.cte_tables {
168                // CTEs themselves cannot leak if the outer select
169                // never references the denied columns — but we
170                // rewrite defensively so that any CTE reference
171                // through `SELECT * FROM cte_name` (once wildcard
172                // support lands) does not expose masked sources.
173                let _ = self.rewrite_query(cte.query.as_mut())?;
174            }
175        }
176
177        // 2. Dispatch on set-expression shape.
178        self.rewrite_set_expr(query.body.as_mut())
179    }
180
181    /// Recursively rewrites a `SetExpr`, returning the column
182    /// lineage for the *representative* select (the left-most
183    /// branch of a UNION, or the inner select of a parenthesised
184    /// query).
185    ///
186    /// UNION branches must all satisfy the policy independently —
187    /// if any branch references a denied column, the whole query
188    /// is rejected.
189    fn rewrite_set_expr(&self, set_expr: &mut SetExpr) -> Result<Vec<(String, String)>> {
190        match set_expr {
191            SetExpr::Select(select) => self.rewrite_select(select),
192            // Parenthesised query — recurse.
193            SetExpr::Query(inner) => self.rewrite_query(inner.as_mut()),
194            // UNION / INTERSECT / EXCEPT — every branch must pass
195            // RBAC independently. The outer lineage comes from the
196            // left branch (all branches must have compatible
197            // column counts, so either branch's lineage is a valid
198            // descriptor; we use left for determinism).
199            SetExpr::SetOperation { left, right, .. } => {
200                let left_lineage = self.rewrite_set_expr(left.as_mut())?;
201                let _right_lineage = self.rewrite_set_expr(right.as_mut())?;
202                Ok(left_lineage)
203            }
204            _ => Err(RbacError::UnsupportedQuery(format!(
205                "unsupported set-expression: {set_expr:?}"
206            ))),
207        }
208    }
209
210    /// Rewrites a SELECT statement. Returns the `(output, source)`
211    /// column pairs for the surviving projection items.
212    fn rewrite_select(&self, select: &mut Select) -> Result<Vec<(String, String)>> {
213        // AUDIT-2026-04 M-7 — subquery / nested-SELECT recursion.
214        //
215        // Step 0a: rewrite any `TableFactor::Derived { subquery }`
216        // in the FROM clause. A predicate that reads
217        // `SELECT outer.x FROM (SELECT ssn AS x FROM patients) outer`
218        // was previously accepted because `extract_stream_name` only
219        // saw the outer derived-table reference — the inner SELECT
220        // was never filtered against the `patients.ssn` deny policy.
221        // Now the inner SELECT is rewritten first; if it references
222        // a denied column it errors out here, before any outer
223        // lineage is reported.
224        for table_with_joins in &mut select.from {
225            self.rewrite_table_factor(&mut table_with_joins.relation)?;
226            for join in &mut table_with_joins.joins {
227                self.rewrite_table_factor(&mut join.relation)?;
228            }
229        }
230
231        // Step 0b: rewrite subqueries inside the WHERE clause.
232        // Handles `IN (SELECT ...)`, `EXISTS (SELECT ...)`, and
233        // scalar-subquery forms. The traversal is read-mutable
234        // because the inner rewrite replaces column projections.
235        if let Some(ref mut selection) = select.selection {
236            self.rewrite_expr_subqueries(selection)?;
237        }
238
239        // 1. Extract stream name from FROM clause
240        let stream_name = Self::extract_stream_name(select)?;
241
242        debug!(stream = %stream_name, "Extracting columns from SELECT");
243
244        // 2. Extract requested columns (source names) and aliases
245        let aliases = Self::extract_column_aliases(select)?;
246        let requested_columns: Vec<String> = aliases.iter().map(|(_, src)| src.clone()).collect();
247
248        info!(
249            stream = %stream_name,
250            columns = ?requested_columns,
251            "Enforcing RBAC policy"
252        );
253
254        // 3. Enforce policy (checks stream access + filters columns)
255        let (allowed_columns, where_clause_sql) = self
256            .enforcer
257            .enforce_query(&stream_name, &requested_columns)?;
258
259        if allowed_columns.is_empty() {
260            warn!(stream = %stream_name, "No authorized columns");
261            return Err(RbacError::NoAuthorizedColumns);
262        }
263
264        // 4. Rewrite SELECT projection (filter columns)
265        Self::rewrite_projection(select, &allowed_columns);
266
267        // 5. Inject WHERE clause for row-level security
268        if !where_clause_sql.is_empty() {
269            Self::inject_where_clause(select, &where_clause_sql)?;
270        }
271
272        info!(
273            stream = %stream_name,
274            allowed_columns = ?allowed_columns,
275            where_clause = %where_clause_sql,
276            "Query rewritten successfully"
277        );
278
279        // 6. Trim the alias map to the surviving projection.
280        let allowed: std::collections::HashSet<&str> =
281            allowed_columns.iter().map(String::as_str).collect();
282        let surviving_aliases = aliases
283            .into_iter()
284            .filter(|(_, src)| allowed.contains(src.as_str()))
285            .collect();
286
287        Ok(surviving_aliases)
288    }
289
290    /// AUDIT-2026-04 M-7 helper — recurse into nested queries
291    /// carried by a `TableFactor`.
292    ///
293    /// `TableFactor::Derived { subquery }` is the AST node for
294    /// `FROM (SELECT ...)`. `TableFactor::NestedJoin` wraps a
295    /// `TableWithJoins` that may itself contain derived tables.
296    /// Anything else is a terminal table reference handled by
297    /// `extract_stream_name` downstream.
298    fn rewrite_table_factor(&self, factor: &mut TableFactor) -> Result<()> {
299        match factor {
300            TableFactor::Derived { subquery, .. } => {
301                self.rewrite_query(subquery.as_mut())?;
302                Ok(())
303            }
304            TableFactor::NestedJoin {
305                table_with_joins, ..
306            } => {
307                self.rewrite_table_factor(&mut table_with_joins.relation)?;
308                for join in &mut table_with_joins.joins {
309                    self.rewrite_table_factor(&mut join.relation)?;
310                }
311                Ok(())
312            }
313            _ => Ok(()),
314        }
315    }
316
317    /// AUDIT-2026-04 M-7 helper — recurse into subqueries embedded
318    /// in a WHERE-clause `Expr`.
319    ///
320    /// Walks `Expr::Subquery`, `Expr::InSubquery`, `Expr::Exists`,
321    /// and combinators (`BinaryOp`, `UnaryOp`, `Nested`) that can
322    /// transport a subquery in their children. Non-subquery leaves
323    /// (identifiers, literals) are terminal.
324    ///
325    /// A bounded-depth guard would belong here if the recursive
326    /// kernel principle forbade it; the query parser already
327    /// rejects SQL with unbounded expression depth before reaching
328    /// this point, so we rely on the sqlparser limit.
329    fn rewrite_expr_subqueries(&self, expr: &mut Expr) -> Result<()> {
330        match expr {
331            Expr::Subquery(q) | Expr::Exists { subquery: q, .. } => {
332                self.rewrite_query(q.as_mut())?;
333                Ok(())
334            }
335            Expr::InSubquery {
336                subquery,
337                expr: inner,
338                ..
339            } => {
340                self.rewrite_expr_subqueries(inner.as_mut())?;
341                self.rewrite_query(subquery.as_mut())?;
342                Ok(())
343            }
344            Expr::BinaryOp { left, right, .. } => {
345                self.rewrite_expr_subqueries(left.as_mut())?;
346                self.rewrite_expr_subqueries(right.as_mut())
347            }
348            Expr::UnaryOp { expr: inner, .. } | Expr::Nested(inner) => {
349                self.rewrite_expr_subqueries(inner.as_mut())
350            }
351            Expr::InList {
352                expr: inner, list, ..
353            } => {
354                self.rewrite_expr_subqueries(inner.as_mut())?;
355                for item in list.iter_mut() {
356                    self.rewrite_expr_subqueries(item)?;
357                }
358                Ok(())
359            }
360            Expr::Between {
361                expr: inner,
362                low,
363                high,
364                ..
365            } => {
366                self.rewrite_expr_subqueries(inner.as_mut())?;
367                self.rewrite_expr_subqueries(low.as_mut())?;
368                self.rewrite_expr_subqueries(high.as_mut())
369            }
370            Expr::Case {
371                conditions,
372                else_result,
373                ..
374            } => {
375                for case_when in conditions.iter_mut() {
376                    self.rewrite_expr_subqueries(&mut case_when.condition)?;
377                    self.rewrite_expr_subqueries(&mut case_when.result)?;
378                }
379                if let Some(else_r) = else_result.as_mut() {
380                    self.rewrite_expr_subqueries(else_r.as_mut())?;
381                }
382                Ok(())
383            }
384            // Identifiers, literals, function calls without subquery
385            // arguments, etc. — nothing to rewrite.
386            _ => Ok(()),
387        }
388    }
389
390    /// Extracts the stream name from the FROM clause.
391    fn extract_stream_name(select: &Select) -> Result<String> {
392        if select.from.is_empty() {
393            return Err(RbacError::UnsupportedQuery(
394                "SELECT without FROM clause".to_string(),
395            ));
396        }
397
398        let table = &select.from[0];
399        match &table.relation {
400            TableFactor::Table { name, .. } => {
401                let stream_name = name
402                    .0
403                    .iter()
404                    .map(|part| match part.as_ident() {
405                        Some(ident) => ident.value.clone(),
406                        None => part.to_string(),
407                    })
408                    .collect::<Vec<_>>()
409                    .join(".");
410                Ok(stream_name)
411            }
412            _ => Err(RbacError::UnsupportedQuery(
413                "Only simple table references are supported".to_string(),
414            )),
415        }
416    }
417
418    /// Extracts `(output_column_name, source_column_name)` pairs for
419    /// each item in the SELECT projection. See [`column_aliases`] for
420    /// the free-function entry point used by the SQL-level mask pass.
421    fn extract_column_aliases(select: &Select) -> Result<Vec<(String, String)>> {
422        column_aliases_from_select(select)
423    }
424}
425
426/// Extracts `(output_column_name, source_column_name)` pairs for each
427/// item in the SELECT projection of `stmt`.
428///
429/// Returns an empty vector for non-`SELECT` statements or for set-expr
430/// bodies that are not a plain `SELECT` (e.g. `UNION`) — the masking
431/// pass treats an empty map as "no aliases known" and falls back to
432/// output-name keying, matching pre-M-7 semantics for those shapes.
433///
434/// Semantics:
435/// - `SELECT col` → `("col", "col")`
436/// - `SELECT col AS alias` → `("alias", "col")`
437/// - `SELECT UPPER(col) AS alias` → `("alias", "alias")` (non-identifier
438///   expressions cannot be resolved to a source column — mask lookup
439///   keys on the alias, mirroring the pre-M-7 behaviour).
440///
441/// AUDIT-2026-04 M-7: the masking pass uses the source half of the
442/// pair to look up column masks. Without this, `SELECT ssn AS id FROM
443/// patients` passed RBAC (source `ssn` is permitted) but
444/// `mask_for_column("id")` returned `None`, leaking the masked
445/// attribute under a rename.
446pub fn column_aliases(stmt: &Statement) -> Vec<(String, String)> {
447    let Statement::Query(query) = stmt else {
448        return Vec::new();
449    };
450    let SetExpr::Select(select) = query.body.as_ref() else {
451        return Vec::new();
452    };
453    column_aliases_from_select(select).unwrap_or_default()
454}
455
456fn column_aliases_from_select(select: &Select) -> Result<Vec<(String, String)>> {
457    let mut pairs = Vec::new();
458
459    for item in &select.projection {
460        match item {
461            SelectItem::UnnamedExpr(Expr::Identifier(ident)) => {
462                pairs.push((ident.value.clone(), ident.value.clone()));
463            }
464            SelectItem::ExprWithAlias { expr, alias } => {
465                if let Expr::Identifier(ident) = expr {
466                    pairs.push((alias.value.clone(), ident.value.clone()));
467                } else {
468                    pairs.push((alias.value.clone(), alias.value.clone()));
469                }
470            }
471            SelectItem::Wildcard(_) => {
472                return Err(RbacError::UnsupportedQuery(
473                    "SELECT * is not yet supported with RBAC".to_string(),
474                ));
475            }
476            _ => {
477                return Err(RbacError::UnsupportedQuery(format!(
478                    "Unsupported SELECT item: {item:?}"
479                )));
480            }
481        }
482    }
483
484    Ok(pairs)
485}
486
487impl RbacFilter {
488    /// Rewrites the SELECT projection to include only allowed columns.
489    fn rewrite_projection(select: &mut Select, allowed_columns: &[String]) {
490        let allowed_set: std::collections::HashSet<_> = allowed_columns.iter().collect();
491
492        select.projection.retain(|item| match item {
493            SelectItem::UnnamedExpr(Expr::Identifier(ident))
494            | SelectItem::ExprWithAlias {
495                expr: Expr::Identifier(ident),
496                ..
497            } => allowed_set.contains(&ident.value),
498            _ => false,
499        });
500    }
501
502    /// Injects a WHERE clause for row-level security.
503    fn inject_where_clause(select: &mut Select, where_clause_sql: &str) -> Result<()> {
504        // Parse the WHERE clause SQL into an Expr
505        let where_expr = Self::parse_where_clause(where_clause_sql)?;
506
507        // Combine with existing WHERE clause (if any)
508        select.selection = match select.selection.take() {
509            Some(existing) => Some(Expr::BinaryOp {
510                left: Box::new(existing),
511                op: sqlparser::ast::BinaryOperator::And,
512                right: Box::new(where_expr),
513            }),
514            None => Some(where_expr),
515        };
516
517        Ok(())
518    }
519
520    /// Parses a WHERE clause SQL string into an Expr.
521    ///
522    /// # Security boundary
523    ///
524    /// This function is **only ever called with trusted `RowFilter` values** generated
525    /// internally by the RBAC policy engine (see [`PolicyEnforcer::row_filter`]).
526    /// It is **not** called with user-supplied SQL strings and is therefore not a
527    /// SQL-injection vector.  If you ever call this with data derived from user input,
528    /// you MUST validate/sanitize the input first.
529    ///
530    /// The parser handles `column = value` predicates joined by `AND`.  It produces
531    /// AST nodes directly (not concatenated SQL), so the result is safe to pass to
532    /// the query planner without further escaping.
533    ///
534    /// More complex predicates may require the full SQL parser.
535    fn parse_where_clause(where_clause_sql: &str) -> Result<Expr> {
536        // Simple parser for "column = value" and "column1 = value1 AND column2 = value2".
537        // SAFETY: Only called with trusted, internally-generated RowFilter strings.
538        let parts: Vec<&str> = where_clause_sql.split(" AND ").collect();
539
540        let mut exprs = Vec::new();
541
542        for part in parts {
543            // Parse "column = value"
544            let tokens: Vec<&str> = part.trim().split('=').collect();
545            if tokens.len() != 2 {
546                return Err(RbacError::UnsupportedQuery(format!(
547                    "Invalid WHERE clause: {part}"
548                )));
549            }
550
551            let column = tokens[0].trim();
552            let value = tokens[1].trim();
553
554            let expr = Expr::BinaryOp {
555                left: Box::new(Expr::Identifier(sqlparser::ast::Ident::new(column))),
556                op: sqlparser::ast::BinaryOperator::Eq,
557                right: Box::new(Expr::Value(
558                    sqlparser::ast::Value::Number(value.to_string(), false).into(),
559                )),
560            };
561
562            exprs.push(expr);
563        }
564
565        // Combine with AND
566        let mut iter = exprs.into_iter();
567        let mut result = iter
568            .next()
569            .ok_or_else(|| RbacError::UnsupportedQuery("Empty WHERE clause".to_string()))?;
570
571        for expr in iter {
572            result = Expr::BinaryOp {
573                left: Box::new(result),
574                op: sqlparser::ast::BinaryOperator::And,
575                right: Box::new(expr),
576            };
577        }
578
579        Ok(result)
580    }
581
582    /// Returns the underlying policy enforcer.
583    pub fn enforcer(&self) -> &PolicyEnforcer {
584        &self.enforcer
585    }
586}
587
#[cfg(test)]
mod tests {
    use super::*;
    use kimberlite_rbac::policy::StandardPolicies;
    use kimberlite_types::TenantId;
    use sqlparser::dialect::GenericDialect;
    use sqlparser::parser::Parser;

    fn parse_sql(sql: &str) -> Statement {
        let dialect = GenericDialect {};
        let statements = Parser::parse_sql(&dialect, sql).expect("Failed to parse SQL");
        statements.into_iter().next().expect("No statement parsed")
    }

    /// Returns the top-level `SELECT` of `stmt`, panicking if the
    /// statement is not a plain single-`SELECT` query. This replaces
    /// `if let` chains so that an unexpected shape fails the test
    /// instead of silently skipping its assertions.
    fn top_level_select(stmt: &Statement) -> &Select {
        let Statement::Query(query) = stmt else {
            panic!("expected a query statement");
        };
        let SetExpr::Select(select) = query.body.as_ref() else {
            panic!("expected a plain SELECT body");
        };
        select
    }

    #[test]
    fn test_rewrite_admin_policy() {
        let policy = StandardPolicies::admin();
        let filter = RbacFilter::new(policy);

        let sql = "SELECT name, email FROM users";
        let stmt = parse_sql(sql);

        let result = filter.rewrite_statement(stmt);
        assert!(result.is_ok(), "admin rewrite failed: {:?}", result.err());
    }

    #[test]
    fn test_rewrite_user_policy_column_filter() {
        let policy = kimberlite_rbac::policy::AccessPolicy::new(kimberlite_rbac::roles::Role::User)
            .allow_stream("users")
            .allow_column("name")
            .deny_column("ssn");

        let filter = RbacFilter::new(policy);

        let sql = "SELECT name, ssn FROM users";
        let stmt = parse_sql(sql);

        let output = filter
            .rewrite_statement(stmt)
            .expect("rewrite should succeed");

        // `ssn` must be filtered out, leaving only `name`.
        let select = top_level_select(&output.statement);
        let projected: Vec<String> = select.projection.iter().map(ToString::to_string).collect();
        assert_eq!(projected, vec!["name".to_string()]);
        assert_eq!(
            output.column_aliases,
            vec![("name".to_string(), "name".to_string())]
        );
    }

    #[test]
    fn test_rewrite_with_row_filter() {
        let tenant_id = TenantId::new(42);
        let policy = StandardPolicies::user(tenant_id);
        let filter = RbacFilter::new(policy);

        let sql = "SELECT name, email FROM users";
        let stmt = parse_sql(sql);

        let output = filter
            .rewrite_statement(stmt)
            .expect("rewrite should succeed");

        // A tenant-scoped WHERE clause must have been injected.
        let select = top_level_select(&output.statement);
        let selection = select
            .selection
            .as_ref()
            .expect("row filter should inject a WHERE clause");
        assert!(
            selection.to_string().contains("42"),
            "row filter should reference the tenant id, got: {selection}"
        );
    }

    #[test]
    fn test_rewrite_access_denied() {
        let policy = StandardPolicies::auditor();
        let filter = RbacFilter::new(policy);

        let sql = "SELECT name FROM users"; // Auditor cannot access users table
        let stmt = parse_sql(sql);

        let result = filter.rewrite_statement(stmt);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), RbacError::AccessDenied(_)));
    }

    #[test]
    fn test_rewrite_no_authorized_columns() {
        let policy = kimberlite_rbac::policy::AccessPolicy::new(kimberlite_rbac::roles::Role::User)
            .allow_stream("users")
            .deny_column("*"); // Deny all columns

        let filter = RbacFilter::new(policy);

        let sql = "SELECT name, email FROM users";
        let stmt = parse_sql(sql);

        let result = filter.rewrite_statement(stmt);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            matches!(err, RbacError::AccessDenied(ref msg) if msg.contains("No authorized columns"))
        );
    }

    // -----------------------------------------------------------------
    // `column_aliases` / `parse_where_clause` — pure helpers.
    // -----------------------------------------------------------------

    #[test]
    fn column_aliases_maps_alias_to_source() {
        let stmt = parse_sql("SELECT name, ssn AS id FROM users");
        assert_eq!(
            column_aliases(&stmt),
            vec![
                ("name".to_string(), "name".to_string()),
                ("id".to_string(), "ssn".to_string()),
            ]
        );
    }

    #[test]
    fn column_aliases_non_identifier_expression_keys_on_alias() {
        let stmt = parse_sql("SELECT UPPER(name) AS shout FROM users");
        assert_eq!(
            column_aliases(&stmt),
            vec![("shout".to_string(), "shout".to_string())]
        );
    }

    #[test]
    fn column_aliases_empty_for_set_operations() {
        let stmt = parse_sql("SELECT name FROM users UNION SELECT email FROM users");
        assert!(column_aliases(&stmt).is_empty());
    }

    #[test]
    fn parse_where_clause_builds_conjunction() {
        let expr = RbacFilter::parse_where_clause("tenant_id = 42 AND region = 7")
            .expect("valid row filter");
        assert_eq!(expr.to_string(), "tenant_id = 42 AND region = 7");
    }

    // -----------------------------------------------------------------
    // AUDIT-2026-04 M-7 — subquery / nested-SELECT RBAC enforcement.
    //
    // Before this fix, `rewrite_statement` only processed the
    // top-level SELECT. A predicate like
    //   SELECT id FROM orders WHERE customer IN (SELECT ssn FROM users)
    // passed through untouched because the inner SELECT was never
    // visited; `ssn` was exposed despite the user's `deny_column`.
    //
    // These tests pin that every nested Query (WHERE IN, EXISTS,
    // derived table in FROM, UNION branch) is rewritten under the
    // same policy.
    // -----------------------------------------------------------------

    fn user_denies_ssn_policy() -> kimberlite_rbac::policy::AccessPolicy {
        // `users` stream is fully accessible on the allow-list, but
        // `ssn` is explicitly denied. Any nested reference to
        // `ssn` must be rejected by the recursive walk.
        kimberlite_rbac::policy::AccessPolicy::new(kimberlite_rbac::roles::Role::User)
            .allow_stream("users")
            .allow_stream("orders")
            .allow_column("name")
            .allow_column("email")
            .allow_column("customer")
            .allow_column("id")
            .deny_column("ssn")
    }

    #[test]
    fn subquery_rbac_in_where_clause_enforces_inner_grants() {
        // AUDIT-2026-04 M-7 regression test. Prior to the fix, this
        // returned `Ok(_)` — `ssn` was never seen by the enforcer.
        // After the fix, the inner SELECT is rewritten, and since
        // `ssn` is denied + the inner projection has no other
        // allowed columns, the whole query is rejected.
        let filter = RbacFilter::new(user_denies_ssn_policy());
        let sql = "SELECT id FROM orders WHERE customer IN (SELECT ssn FROM users)";
        let stmt = parse_sql(sql);
        let result = filter.rewrite_statement(stmt);
        assert!(
            result.is_err(),
            "nested subquery referencing denied column must be rejected"
        );
    }

    #[test]
    fn subquery_rbac_exists_clause_recurses() {
        // EXISTS subqueries are rewritten too.
        let filter = RbacFilter::new(user_denies_ssn_policy());
        let sql = "SELECT id FROM orders WHERE EXISTS (SELECT ssn FROM users)";
        let stmt = parse_sql(sql);
        let result = filter.rewrite_statement(stmt);
        assert!(
            result.is_err(),
            "EXISTS-subquery referencing denied column must be rejected"
        );
    }

    #[test]
    fn subquery_rbac_derived_table_in_from_recurses() {
        // Derived-table subquery in FROM clause.
        let filter = RbacFilter::new(user_denies_ssn_policy());
        let sql = "SELECT t.email FROM (SELECT ssn FROM users) t";
        let stmt = parse_sql(sql);
        let result = filter.rewrite_statement(stmt);
        assert!(
            result.is_err(),
            "derived-table SELECT referencing denied column must be rejected"
        );
    }

    #[test]
    fn subquery_rbac_union_both_branches_checked() {
        // UNION — both branches must pass RBAC. The left branch
        // asks for `ssn` (denied) → whole query rejected.
        let filter = RbacFilter::new(user_denies_ssn_policy());
        let sql = "SELECT ssn FROM users UNION SELECT name FROM users";
        let stmt = parse_sql(sql);
        let result = filter.rewrite_statement(stmt);
        assert!(
            result.is_err(),
            "UNION branch referencing denied column must be rejected"
        );
    }

    #[test]
    fn subquery_rbac_allowed_subquery_still_succeeds() {
        // Sanity check: a subquery that references only allowed
        // columns is unaffected — the M-7 fix must not introduce
        // false-positive rejections.
        let filter = RbacFilter::new(user_denies_ssn_policy());
        let sql = "SELECT id FROM orders WHERE customer IN (SELECT name FROM users)";
        let stmt = parse_sql(sql);
        let result = filter.rewrite_statement(stmt);
        assert!(
            result.is_ok(),
            "all-allowed subquery must pass, got: {:?}",
            result.err()
        );
    }

    #[test]
    fn subquery_rbac_cte_with_denied_column_rejected() {
        // CTEs are rewritten before the outer select reads them.
        let filter = RbacFilter::new(user_denies_ssn_policy());
        let sql = "WITH u AS (SELECT ssn FROM users) SELECT id FROM orders";
        let stmt = parse_sql(sql);
        let result = filter.rewrite_statement(stmt);
        assert!(
            result.is_err(),
            "CTE referencing denied column must be rejected"
        );
    }

    #[test]
    fn subquery_rbac_deeply_nested_three_levels() {
        // Three levels of nesting — inner-most references denied
        // column. Recursive walk must reach it.
        let filter = RbacFilter::new(user_denies_ssn_policy());
        let sql = "SELECT id FROM orders \
                   WHERE customer IN ( \
                     SELECT name FROM users \
                     WHERE email IN (SELECT ssn FROM users) \
                   )";
        let stmt = parse_sql(sql);
        let result = filter.rewrite_statement(stmt);
        assert!(
            result.is_err(),
            "deeply nested subquery referencing denied column must be rejected"
        );
    }

    #[test]
    fn subquery_rbac_in_list_does_not_recurse_into_values() {
        // `IN (literal_list)` is NOT a subquery — no recursion
        // needed. The fix must not trip on regular in-list
        // predicates.
        let filter = RbacFilter::new(user_denies_ssn_policy());
        let sql = "SELECT id FROM orders WHERE customer IN ('alice', 'bob')";
        let stmt = parse_sql(sql);
        let result = filter.rewrite_statement(stmt);
        assert!(
            result.is_ok(),
            "in-list with literal values must pass: {:?}",
            result.err()
        );
    }
}