Skip to main content

kimberlite_query/
rbac_filter.rs

1//! RBAC query filtering and rewriting.
2//!
3//! This module provides query rewriting to enforce RBAC policies:
4//! - **Column filtering**: Remove unauthorized columns from SELECT
5//! - **Row-level security**: Inject WHERE clauses
6//!
7//! ## Architecture
8//!
9//! ```text
10//! ┌─────────────────────────────────────┐
11//! │  Original Query                      │
12//! │  SELECT name, ssn FROM patients      │
13//! └───────────────┬─────────────────────┘
14//!                 │
15//!                 ▼
16//! ┌─────────────────────────────────────┐
17//! │  RBAC Filter                         │
18//! │  - Check stream access               │
19//! │  - Filter columns (remove "ssn")     │
20//! │  - Inject WHERE clause               │
21//! └───────────────┬─────────────────────┘
22//!                 │
23//!                 ▼
24//! ┌─────────────────────────────────────┐
25//! │  Rewritten Query                     │
26//! │  SELECT name FROM patients           │
27//! │  WHERE tenant_id = 42                │
28//! └─────────────────────────────────────┘
29//! ```
30
31use crate::error::QueryError;
32use kimberlite_rbac::{AccessPolicy, enforcement::PolicyEnforcer};
33use sqlparser::ast::{Expr, Query, Select, SelectItem, SetExpr, Statement, TableFactor};
34use thiserror::Error;
35use tracing::{debug, info, warn};
36
/// Errors produced while rewriting a query under an RBAC policy.
///
/// Each variant's `#[error(...)]` string is its `Display` output, which is
/// what the `From<RbacError> for QueryError` conversion in this module
/// forwards to callers.
#[derive(Debug, Error)]
pub enum RbacError {
    /// The policy enforcer denied the request. The payload is the reason
    /// string supplied by the enforcer.
    #[error("Access denied: {0}")]
    AccessDenied(String),

    /// Column filtering left nothing to select.
    #[error("No authorized columns in query")]
    NoAuthorizedColumns,

    /// The statement (or a part of it) has a shape the filter does not
    /// handle. The payload describes the offending construct.
    #[error("Unsupported query type: {0}")]
    UnsupportedQuery(String),

    /// Any enforcer error other than an access denial. The payload is the
    /// enforcer error rendered with `to_string()`.
    #[error("Policy enforcement failed: {0}")]
    EnforcementFailed(String),
}
56
57impl From<kimberlite_rbac::enforcement::EnforcementError> for RbacError {
58    fn from(err: kimberlite_rbac::enforcement::EnforcementError) -> Self {
59        match err {
60            kimberlite_rbac::enforcement::EnforcementError::AccessDenied { reason } => {
61                RbacError::AccessDenied(reason)
62            }
63            _ => RbacError::EnforcementFailed(err.to_string()),
64        }
65    }
66}
67
68impl From<RbacError> for QueryError {
69    fn from(err: RbacError) -> Self {
70        QueryError::UnsupportedFeature(err.to_string())
71    }
72}
73
/// Shorthand for results whose error type is [`RbacError`].
pub type Result<T> = std::result::Result<T, RbacError>;
76
/// Output of [`RbacFilter::rewrite_statement`].
///
/// Bundles the rewritten statement with the alias mapping derived from
/// its projection. Downstream code (e.g. the masking pass in
/// `kimberlite`) uses the mapping to resolve output column names back to
/// their source columns, so masks are applied to the underlying
/// sensitive attribute rather than to the user-visible alias.
#[derive(Debug)]
pub struct RewriteOutput {
    /// The statement after RBAC rewriting.
    pub statement: Statement,
    /// `(output_column_name, source_column_name)` pairs, one per
    /// projection item that survived RBAC filtering.
    ///
    /// A bare identifier yields a pair with two equal entries. An
    /// aliased identifier (`SELECT ssn AS id`) yields `("id", "ssn")`;
    /// the masking pass must key its lookup on the source entry
    /// (AUDIT-2026-04 M-7).
    pub column_aliases: Vec<(String, String)>,
}
97
/// RBAC query filter.
///
/// Rewrites parsed SQL statements so they comply with an access-control
/// policy. All policy decisions are delegated to the wrapped
/// [`PolicyEnforcer`].
pub struct RbacFilter {
    // Enforcer built from the policy passed to `RbacFilter::new`.
    enforcer: PolicyEnforcer,
}
104
105impl RbacFilter {
106    /// Creates a new RBAC filter with the given policy.
107    pub fn new(policy: AccessPolicy) -> Self {
108        Self {
109            enforcer: PolicyEnforcer::new(policy),
110        }
111    }
112
113    /// Rewrites a SQL statement to enforce RBAC policy.
114    ///
115    /// **Transformations:**
116    /// 1. Check stream access (deny if unauthorized)
117    /// 2. Filter SELECT columns (remove unauthorized columns)
118    /// 3. Inject WHERE clause for row-level security
119    ///
120    /// # Arguments
121    ///
122    /// * `stmt` - SQL statement to rewrite
123    ///
124    /// # Returns
125    ///
126    /// Rewritten statement plus a map of `(output_column_name,
127    /// source_column_name)` pairs — one entry per projection item that
128    /// survived RBAC filtering. The masking pass uses this map to look
129    /// up column masks by source column rather than by the
130    /// potentially-aliased output name (AUDIT-2026-04 M-7).
131    ///
132    /// # Errors
133    ///
134    /// - `AccessDenied` if stream access is denied
135    /// - `NoAuthorizedColumns` if all columns are unauthorized
136    /// - `UnsupportedQuery` if query type is not supported
137    pub fn rewrite_statement(&self, mut stmt: Statement) -> Result<RewriteOutput> {
138        match &mut stmt {
139            Statement::Query(query) => {
140                let column_aliases = self.rewrite_query(query)?;
141                Ok(RewriteOutput {
142                    statement: stmt,
143                    column_aliases,
144                })
145            }
146            _ => Err(RbacError::UnsupportedQuery(
147                "Only SELECT queries are currently supported".to_string(),
148            )),
149        }
150    }
151
152    /// Rewrites a query to enforce RBAC.
153    ///
154    /// **AUDIT-2026-04 M-7 — recursive traversal.** Prior to this
155    /// change, only the top-level `SetExpr::Select` was rewritten,
156    /// so a predicate like
157    /// `SELECT id FROM t WHERE x IN (SELECT ssn FROM patients)`
158    /// would bypass column filtering on `ssn`. The recursive walk
159    /// below ensures every nested `Query` (CTE / UNION / subquery
160    /// in FROM / subquery in WHERE) is rewritten under the same
161    /// policy before the outer select is processed.
162    fn rewrite_query(&self, query: &mut Query) -> Result<Vec<(String, String)>> {
163        // 1. Rewrite CTEs first — later referenced by name in the
164        //    main set-expression, so their filtering must land
165        //    before the outer select reads them.
166        if let Some(with) = query.with.as_mut() {
167            for cte in with.cte_tables.iter_mut() {
168                // CTEs themselves cannot leak if the outer select
169                // never references the denied columns — but we
170                // rewrite defensively so that any CTE reference
171                // through `SELECT * FROM cte_name` (once wildcard
172                // support lands) does not expose masked sources.
173                let _ = self.rewrite_query(cte.query.as_mut())?;
174            }
175        }
176
177        // 2. Dispatch on set-expression shape.
178        self.rewrite_set_expr(query.body.as_mut())
179    }
180
181    /// Recursively rewrites a `SetExpr`, returning the column
182    /// lineage for the *representative* select (the left-most
183    /// branch of a UNION, or the inner select of a parenthesised
184    /// query).
185    ///
186    /// UNION branches must all satisfy the policy independently —
187    /// if any branch references a denied column, the whole query
188    /// is rejected.
189    fn rewrite_set_expr(&self, set_expr: &mut SetExpr) -> Result<Vec<(String, String)>> {
190        match set_expr {
191            SetExpr::Select(select) => self.rewrite_select(select),
192            // Parenthesised query — recurse.
193            SetExpr::Query(inner) => self.rewrite_query(inner.as_mut()),
194            // UNION / INTERSECT / EXCEPT — every branch must pass
195            // RBAC independently. The outer lineage comes from the
196            // left branch (all branches must have compatible
197            // column counts, so either branch's lineage is a valid
198            // descriptor; we use left for determinism).
199            SetExpr::SetOperation { left, right, .. } => {
200                let left_lineage = self.rewrite_set_expr(left.as_mut())?;
201                let _right_lineage = self.rewrite_set_expr(right.as_mut())?;
202                Ok(left_lineage)
203            }
204            _ => Err(RbacError::UnsupportedQuery(format!(
205                "unsupported set-expression: {set_expr:?}"
206            ))),
207        }
208    }
209
210    /// Rewrites a SELECT statement. Returns the `(output, source)`
211    /// column pairs for the surviving projection items.
212    fn rewrite_select(&self, select: &mut Select) -> Result<Vec<(String, String)>> {
213        // AUDIT-2026-04 M-7 — subquery / nested-SELECT recursion.
214        //
215        // Step 0a: rewrite any `TableFactor::Derived { subquery }`
216        // in the FROM clause. A predicate that reads
217        // `SELECT outer.x FROM (SELECT ssn AS x FROM patients) outer`
218        // was previously accepted because `extract_stream_name` only
219        // saw the outer derived-table reference — the inner SELECT
220        // was never filtered against the `patients.ssn` deny policy.
221        // Now the inner SELECT is rewritten first; if it references
222        // a denied column it errors out here, before any outer
223        // lineage is reported.
224        for table_with_joins in select.from.iter_mut() {
225            self.rewrite_table_factor(&mut table_with_joins.relation)?;
226            for join in table_with_joins.joins.iter_mut() {
227                self.rewrite_table_factor(&mut join.relation)?;
228            }
229        }
230
231        // Step 0b: rewrite subqueries inside the WHERE clause.
232        // Handles `IN (SELECT ...)`, `EXISTS (SELECT ...)`, and
233        // scalar-subquery forms. The traversal is read-mutable
234        // because the inner rewrite replaces column projections.
235        if let Some(ref mut selection) = select.selection {
236            self.rewrite_expr_subqueries(selection)?;
237        }
238
239        // 1. Extract stream name from FROM clause
240        let stream_name = Self::extract_stream_name(select)?;
241
242        debug!(stream = %stream_name, "Extracting columns from SELECT");
243
244        // 2. Extract requested columns (source names) and aliases
245        let aliases = Self::extract_column_aliases(select)?;
246        let requested_columns: Vec<String> =
247            aliases.iter().map(|(_, src)| src.clone()).collect();
248
249        info!(
250            stream = %stream_name,
251            columns = ?requested_columns,
252            "Enforcing RBAC policy"
253        );
254
255        // 3. Enforce policy (checks stream access + filters columns)
256        let (allowed_columns, where_clause_sql) = self
257            .enforcer
258            .enforce_query(&stream_name, &requested_columns)?;
259
260        if allowed_columns.is_empty() {
261            warn!(stream = %stream_name, "No authorized columns");
262            return Err(RbacError::NoAuthorizedColumns);
263        }
264
265        // 4. Rewrite SELECT projection (filter columns)
266        Self::rewrite_projection(select, &allowed_columns);
267
268        // 5. Inject WHERE clause for row-level security
269        if !where_clause_sql.is_empty() {
270            Self::inject_where_clause(select, &where_clause_sql)?;
271        }
272
273        info!(
274            stream = %stream_name,
275            allowed_columns = ?allowed_columns,
276            where_clause = %where_clause_sql,
277            "Query rewritten successfully"
278        );
279
280        // 6. Trim the alias map to the surviving projection.
281        let allowed: std::collections::HashSet<&str> =
282            allowed_columns.iter().map(String::as_str).collect();
283        let surviving_aliases = aliases
284            .into_iter()
285            .filter(|(_, src)| allowed.contains(src.as_str()))
286            .collect();
287
288        Ok(surviving_aliases)
289    }
290
291    /// AUDIT-2026-04 M-7 helper — recurse into nested queries
292    /// carried by a `TableFactor`.
293    ///
294    /// `TableFactor::Derived { subquery }` is the AST node for
295    /// `FROM (SELECT ...)`. `TableFactor::NestedJoin` wraps a
296    /// `TableWithJoins` that may itself contain derived tables.
297    /// Anything else is a terminal table reference handled by
298    /// `extract_stream_name` downstream.
299    fn rewrite_table_factor(&self, factor: &mut TableFactor) -> Result<()> {
300        match factor {
301            TableFactor::Derived { subquery, .. } => {
302                self.rewrite_query(subquery.as_mut())?;
303                Ok(())
304            }
305            TableFactor::NestedJoin {
306                table_with_joins, ..
307            } => {
308                self.rewrite_table_factor(&mut table_with_joins.relation)?;
309                for join in table_with_joins.joins.iter_mut() {
310                    self.rewrite_table_factor(&mut join.relation)?;
311                }
312                Ok(())
313            }
314            _ => Ok(()),
315        }
316    }
317
318    /// AUDIT-2026-04 M-7 helper — recurse into subqueries embedded
319    /// in a WHERE-clause `Expr`.
320    ///
321    /// Walks `Expr::Subquery`, `Expr::InSubquery`, `Expr::Exists`,
322    /// and combinators (`BinaryOp`, `UnaryOp`, `Nested`) that can
323    /// transport a subquery in their children. Non-subquery leaves
324    /// (identifiers, literals) are terminal.
325    ///
326    /// A bounded-depth guard would belong here if the recursive
327    /// kernel principle forbade it; the query parser already
328    /// rejects SQL with unbounded expression depth before reaching
329    /// this point, so we rely on the sqlparser limit.
330    fn rewrite_expr_subqueries(&self, expr: &mut Expr) -> Result<()> {
331        match expr {
332            Expr::Subquery(q) | Expr::Exists { subquery: q, .. } => {
333                self.rewrite_query(q.as_mut())?;
334                Ok(())
335            }
336            Expr::InSubquery { subquery, expr: inner, .. } => {
337                self.rewrite_expr_subqueries(inner.as_mut())?;
338                self.rewrite_query(subquery.as_mut())?;
339                Ok(())
340            }
341            Expr::BinaryOp { left, right, .. } => {
342                self.rewrite_expr_subqueries(left.as_mut())?;
343                self.rewrite_expr_subqueries(right.as_mut())
344            }
345            Expr::UnaryOp { expr: inner, .. } => self.rewrite_expr_subqueries(inner.as_mut()),
346            Expr::Nested(inner) => self.rewrite_expr_subqueries(inner.as_mut()),
347            Expr::InList { expr: inner, list, .. } => {
348                self.rewrite_expr_subqueries(inner.as_mut())?;
349                for item in list.iter_mut() {
350                    self.rewrite_expr_subqueries(item)?;
351                }
352                Ok(())
353            }
354            Expr::Between {
355                expr: inner,
356                low,
357                high,
358                ..
359            } => {
360                self.rewrite_expr_subqueries(inner.as_mut())?;
361                self.rewrite_expr_subqueries(low.as_mut())?;
362                self.rewrite_expr_subqueries(high.as_mut())
363            }
364            Expr::Case {
365                conditions,
366                results,
367                else_result,
368                ..
369            } => {
370                for c in conditions.iter_mut() {
371                    self.rewrite_expr_subqueries(c)?;
372                }
373                for r in results.iter_mut() {
374                    self.rewrite_expr_subqueries(r)?;
375                }
376                if let Some(else_r) = else_result.as_mut() {
377                    self.rewrite_expr_subqueries(else_r.as_mut())?;
378                }
379                Ok(())
380            }
381            // Identifiers, literals, function calls without subquery
382            // arguments, etc. — nothing to rewrite.
383            _ => Ok(()),
384        }
385    }
386
387    /// Extracts the stream name from the FROM clause.
388    fn extract_stream_name(select: &Select) -> Result<String> {
389        if select.from.is_empty() {
390            return Err(RbacError::UnsupportedQuery(
391                "SELECT without FROM clause".to_string(),
392            ));
393        }
394
395        let table = &select.from[0];
396        match &table.relation {
397            TableFactor::Table { name, .. } => {
398                let stream_name = name
399                    .0
400                    .iter()
401                    .map(|i| i.value.as_str())
402                    .collect::<Vec<_>>()
403                    .join(".");
404                Ok(stream_name)
405            }
406            _ => Err(RbacError::UnsupportedQuery(
407                "Only simple table references are supported".to_string(),
408            )),
409        }
410    }
411
    /// Extracts `(output_column_name, source_column_name)` pairs for
    /// each item in the SELECT projection.
    ///
    /// Thin wrapper around the module-level `column_aliases_from_select`;
    /// errors from that function are returned unchanged. See
    /// [`column_aliases`] for the free-function entry point used by the
    /// SQL-level mask pass.
    fn extract_column_aliases(select: &Select) -> Result<Vec<(String, String)>> {
        column_aliases_from_select(select)
    }
418}
419
420/// Extracts `(output_column_name, source_column_name)` pairs for each
421/// item in the SELECT projection of `stmt`.
422///
423/// Returns an empty vector for non-`SELECT` statements or for set-expr
424/// bodies that are not a plain `SELECT` (e.g. `UNION`) — the masking
425/// pass treats an empty map as "no aliases known" and falls back to
426/// output-name keying, matching pre-M-7 semantics for those shapes.
427///
428/// Semantics:
429/// - `SELECT col` → `("col", "col")`
430/// - `SELECT col AS alias` → `("alias", "col")`
431/// - `SELECT UPPER(col) AS alias` → `("alias", "alias")` (non-identifier
432///   expressions cannot be resolved to a source column — mask lookup
433///   keys on the alias, mirroring the pre-M-7 behaviour).
434///
435/// AUDIT-2026-04 M-7: the masking pass uses the source half of the
436/// pair to look up column masks. Without this, `SELECT ssn AS id FROM
437/// patients` passed RBAC (source `ssn` is permitted) but
438/// `mask_for_column("id")` returned `None`, leaking the masked
439/// attribute under a rename.
440pub fn column_aliases(stmt: &Statement) -> Vec<(String, String)> {
441    let Statement::Query(query) = stmt else {
442        return Vec::new();
443    };
444    let SetExpr::Select(select) = query.body.as_ref() else {
445        return Vec::new();
446    };
447    column_aliases_from_select(select).unwrap_or_default()
448}
449
450fn column_aliases_from_select(select: &Select) -> Result<Vec<(String, String)>> {
451    let mut pairs = Vec::new();
452
453    for item in &select.projection {
454        match item {
455            SelectItem::UnnamedExpr(Expr::Identifier(ident)) => {
456                pairs.push((ident.value.clone(), ident.value.clone()));
457            }
458            SelectItem::ExprWithAlias { expr, alias } => {
459                if let Expr::Identifier(ident) = expr {
460                    pairs.push((alias.value.clone(), ident.value.clone()));
461                } else {
462                    pairs.push((alias.value.clone(), alias.value.clone()));
463                }
464            }
465            SelectItem::Wildcard(_) => {
466                return Err(RbacError::UnsupportedQuery(
467                    "SELECT * is not yet supported with RBAC".to_string(),
468                ));
469            }
470            _ => {
471                return Err(RbacError::UnsupportedQuery(format!(
472                    "Unsupported SELECT item: {item:?}"
473                )));
474            }
475        }
476    }
477
478    Ok(pairs)
479}
480
481impl RbacFilter {
482
483    /// Rewrites the SELECT projection to include only allowed columns.
484    fn rewrite_projection(select: &mut Select, allowed_columns: &[String]) {
485        let allowed_set: std::collections::HashSet<_> = allowed_columns.iter().collect();
486
487        select.projection.retain(|item| match item {
488            SelectItem::UnnamedExpr(Expr::Identifier(ident))
489            | SelectItem::ExprWithAlias {
490                expr: Expr::Identifier(ident),
491                ..
492            } => allowed_set.contains(&ident.value),
493            _ => false,
494        });
495    }
496
497    /// Injects a WHERE clause for row-level security.
498    fn inject_where_clause(select: &mut Select, where_clause_sql: &str) -> Result<()> {
499        // Parse the WHERE clause SQL into an Expr
500        let where_expr = Self::parse_where_clause(where_clause_sql)?;
501
502        // Combine with existing WHERE clause (if any)
503        select.selection = match select.selection.take() {
504            Some(existing) => Some(Expr::BinaryOp {
505                left: Box::new(existing),
506                op: sqlparser::ast::BinaryOperator::And,
507                right: Box::new(where_expr),
508            }),
509            None => Some(where_expr),
510        };
511
512        Ok(())
513    }
514
515    /// Parses a WHERE clause SQL string into an Expr.
516    ///
517    /// # Security boundary
518    ///
519    /// This function is **only ever called with trusted `RowFilter` values** generated
520    /// internally by the RBAC policy engine (see [`PolicyEnforcer::row_filter`]).
521    /// It is **not** called with user-supplied SQL strings and is therefore not a
522    /// SQL-injection vector.  If you ever call this with data derived from user input,
523    /// you MUST validate/sanitize the input first.
524    ///
525    /// The parser handles `column = value` predicates joined by `AND`.  It produces
526    /// AST nodes directly (not concatenated SQL), so the result is safe to pass to
527    /// the query planner without further escaping.
528    ///
529    /// More complex predicates may require the full SQL parser.
530    fn parse_where_clause(where_clause_sql: &str) -> Result<Expr> {
531        // Simple parser for "column = value" and "column1 = value1 AND column2 = value2".
532        // SAFETY: Only called with trusted, internally-generated RowFilter strings.
533        let parts: Vec<&str> = where_clause_sql.split(" AND ").collect();
534
535        let mut exprs = Vec::new();
536
537        for part in parts {
538            // Parse "column = value"
539            let tokens: Vec<&str> = part.trim().split('=').collect();
540            if tokens.len() != 2 {
541                return Err(RbacError::UnsupportedQuery(format!(
542                    "Invalid WHERE clause: {part}"
543                )));
544            }
545
546            let column = tokens[0].trim();
547            let value = tokens[1].trim();
548
549            let expr = Expr::BinaryOp {
550                left: Box::new(Expr::Identifier(sqlparser::ast::Ident::new(column))),
551                op: sqlparser::ast::BinaryOperator::Eq,
552                right: Box::new(Expr::Value(sqlparser::ast::Value::Number(
553                    value.to_string(),
554                    false,
555                ))),
556            };
557
558            exprs.push(expr);
559        }
560
561        // Combine with AND
562        let mut iter = exprs.into_iter();
563        let mut result = iter
564            .next()
565            .ok_or_else(|| RbacError::UnsupportedQuery("Empty WHERE clause".to_string()))?;
566
567        for expr in iter {
568            result = Expr::BinaryOp {
569                left: Box::new(result),
570                op: sqlparser::ast::BinaryOperator::And,
571                right: Box::new(expr),
572            };
573        }
574
575        Ok(result)
576    }
577
    /// Returns a shared reference to the policy enforcer this filter
    /// was constructed with.
    pub fn enforcer(&self) -> &PolicyEnforcer {
        &self.enforcer
    }
582}
583
#[cfg(test)]
mod tests {
    use super::*;
    use kimberlite_rbac::policy::StandardPolicies;
    use kimberlite_types::TenantId;
    use sqlparser::dialect::GenericDialect;
    use sqlparser::parser::Parser;

    /// Parses `sql` with the generic dialect and returns its first statement.
    fn parse_sql(sql: &str) -> Statement {
        let dialect = GenericDialect {};
        let statements = Parser::parse_sql(&dialect, sql).expect("Failed to parse SQL");
        statements.into_iter().next().expect("No statement parsed")
    }

    /// Returns the top-level `SELECT` of `stmt`, failing the test if the
    /// statement has any other shape. Using this instead of nested
    /// `if let` keeps assertions from being skipped silently.
    fn top_level_select(stmt: &Statement) -> &Select {
        let Statement::Query(query) = stmt else {
            panic!("expected a query statement, got: {stmt:?}");
        };
        let SetExpr::Select(select) = query.body.as_ref() else {
            panic!("expected a plain SELECT body, got: {:?}", query.body);
        };
        select
    }

    #[test]
    fn test_rewrite_admin_policy() {
        let policy = StandardPolicies::admin();
        let filter = RbacFilter::new(policy);

        let sql = "SELECT name, email FROM users";
        let stmt = parse_sql(sql);

        let result = filter.rewrite_statement(stmt);
        assert!(result.is_ok());
    }

    #[test]
    fn test_rewrite_user_policy_column_filter() {
        let policy = kimberlite_rbac::policy::AccessPolicy::new(kimberlite_rbac::roles::Role::User)
            .allow_stream("users")
            .allow_column("name")
            .deny_column("ssn");

        let filter = RbacFilter::new(policy);

        let sql = "SELECT name, ssn FROM users";
        let stmt = parse_sql(sql);

        let output = filter
            .rewrite_statement(stmt)
            .expect("query with one allowed column must be rewritten");

        // `ssn` must be gone and `name` must be the only survivor, both
        // in the projection and in the reported lineage.
        let select = top_level_select(&output.statement);
        assert_eq!(select.projection.len(), 1);
        assert_eq!(select.projection[0].to_string(), "name");
        assert_eq!(
            output.column_aliases,
            vec![("name".to_string(), "name".to_string())]
        );
    }

    #[test]
    fn test_rewrite_with_row_filter() {
        let tenant_id = TenantId::new(42);
        let policy = StandardPolicies::user(tenant_id);
        let filter = RbacFilter::new(policy);

        let sql = "SELECT name, email FROM users";
        let stmt = parse_sql(sql);

        let output = filter
            .rewrite_statement(stmt)
            .expect("tenant-scoped user query must be rewritten");

        // The original query had no WHERE clause, so any selection
        // present now was injected by the filter.
        let select = top_level_select(&output.statement);
        assert!(select.selection.is_some());
    }

    #[test]
    fn test_rewrite_access_denied() {
        let policy = StandardPolicies::auditor();
        let filter = RbacFilter::new(policy);

        let sql = "SELECT name FROM users"; // Auditor cannot access users table
        let stmt = parse_sql(sql);

        let result = filter.rewrite_statement(stmt);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), RbacError::AccessDenied(_)));
    }

    #[test]
    fn test_rewrite_no_authorized_columns() {
        let policy = kimberlite_rbac::policy::AccessPolicy::new(kimberlite_rbac::roles::Role::User)
            .allow_stream("users")
            .deny_column("*"); // Deny all columns

        let filter = RbacFilter::new(policy);

        let sql = "SELECT name, email FROM users";
        let stmt = parse_sql(sql);

        let result = filter.rewrite_statement(stmt);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            matches!(err, RbacError::AccessDenied(ref msg) if msg.contains("No authorized columns"))
        );
    }

    // -----------------------------------------------------------------
    // AUDIT-2026-04 M-7 — subquery / nested-SELECT RBAC enforcement.
    //
    // Before this fix, `rewrite_statement` only processed the
    // top-level SELECT. A predicate like
    //   SELECT id FROM orders WHERE customer IN (SELECT ssn FROM users)
    // passed through untouched because the inner SELECT was never
    // visited; `ssn` was exposed despite the user's `deny_column`.
    //
    // These tests pin that every nested Query (WHERE IN, EXISTS,
    // derived table in FROM, UNION branch) is rewritten under the
    // same policy.
    // -----------------------------------------------------------------

    fn user_denies_ssn_policy() -> kimberlite_rbac::policy::AccessPolicy {
        // `users` and `orders` are on the stream allow-list, but `ssn`
        // is explicitly denied. A nested SELECT that asks only for
        // `ssn` must be rejected by the recursive walk.
        kimberlite_rbac::policy::AccessPolicy::new(kimberlite_rbac::roles::Role::User)
            .allow_stream("users")
            .allow_stream("orders")
            .allow_column("name")
            .allow_column("email")
            .allow_column("customer")
            .allow_column("id")
            .deny_column("ssn")
    }

    #[test]
    fn subquery_rbac_in_where_clause_enforces_inner_grants() {
        // AUDIT-2026-04 M-7 regression test. Prior to the fix, this
        // returned `Ok(_)` — `ssn` was never seen by the enforcer.
        // After the fix, the inner SELECT is rewritten, and since
        // `ssn` is denied + the inner projection has no other
        // allowed columns, the whole query is rejected.
        let filter = RbacFilter::new(user_denies_ssn_policy());
        let sql = "SELECT id FROM orders WHERE customer IN (SELECT ssn FROM users)";
        let stmt = parse_sql(sql);
        let result = filter.rewrite_statement(stmt);
        assert!(
            result.is_err(),
            "nested subquery referencing denied column must be rejected"
        );
    }

    #[test]
    fn subquery_rbac_exists_clause_recurses() {
        // EXISTS subqueries are rewritten too.
        let filter = RbacFilter::new(user_denies_ssn_policy());
        let sql = "SELECT id FROM orders WHERE EXISTS (SELECT ssn FROM users)";
        let stmt = parse_sql(sql);
        let result = filter.rewrite_statement(stmt);
        assert!(
            result.is_err(),
            "EXISTS-subquery referencing denied column must be rejected"
        );
    }

    #[test]
    fn subquery_rbac_derived_table_in_from_recurses() {
        // Derived-table subquery in FROM clause.
        let filter = RbacFilter::new(user_denies_ssn_policy());
        let sql = "SELECT t.email FROM (SELECT ssn FROM users) t";
        let stmt = parse_sql(sql);
        let result = filter.rewrite_statement(stmt);
        assert!(
            result.is_err(),
            "derived-table SELECT referencing denied column must be rejected"
        );
    }

    #[test]
    fn subquery_rbac_union_both_branches_checked() {
        // UNION — both branches must pass RBAC. The left branch
        // asks for `ssn` (denied) → whole query rejected.
        let filter = RbacFilter::new(user_denies_ssn_policy());
        let sql = "SELECT ssn FROM users UNION SELECT name FROM users";
        let stmt = parse_sql(sql);
        let result = filter.rewrite_statement(stmt);
        assert!(
            result.is_err(),
            "UNION branch referencing denied column must be rejected"
        );
    }

    #[test]
    fn subquery_rbac_allowed_subquery_still_succeeds() {
        // Sanity check: a subquery that references only allowed
        // columns is unaffected — the M-7 fix must not introduce
        // false-positive rejections.
        let filter = RbacFilter::new(user_denies_ssn_policy());
        let sql = "SELECT id FROM orders WHERE customer IN (SELECT name FROM users)";
        let stmt = parse_sql(sql);
        let result = filter.rewrite_statement(stmt);
        assert!(
            result.is_ok(),
            "all-allowed subquery must pass, got: {:?}",
            result.err()
        );
    }

    #[test]
    fn subquery_rbac_cte_with_denied_column_rejected() {
        // CTEs are rewritten before the outer select reads them.
        let filter = RbacFilter::new(user_denies_ssn_policy());
        let sql = "WITH u AS (SELECT ssn FROM users) SELECT id FROM orders";
        let stmt = parse_sql(sql);
        let result = filter.rewrite_statement(stmt);
        assert!(
            result.is_err(),
            "CTE referencing denied column must be rejected"
        );
    }

    #[test]
    fn subquery_rbac_deeply_nested_three_levels() {
        // Three levels of nesting — inner-most references denied
        // column. Recursive walk must reach it.
        let filter = RbacFilter::new(user_denies_ssn_policy());
        let sql = "SELECT id FROM orders \
                   WHERE customer IN ( \
                     SELECT name FROM users \
                     WHERE email IN (SELECT ssn FROM users) \
                   )";
        let stmt = parse_sql(sql);
        let result = filter.rewrite_statement(stmt);
        assert!(
            result.is_err(),
            "deeply nested subquery referencing denied column must be rejected"
        );
    }

    #[test]
    fn subquery_rbac_in_list_does_not_recurse_into_values() {
        // `IN (literal_list)` carries no subquery. The walk still
        // visits each list item, but literals are terminal, so a
        // regular in-list predicate must pass untouched.
        let filter = RbacFilter::new(user_denies_ssn_policy());
        let sql = "SELECT id FROM orders WHERE customer IN ('alice', 'bob')";
        let stmt = parse_sql(sql);
        let result = filter.rewrite_statement(stmt);
        assert!(
            result.is_ok(),
            "in-list with literal values must pass: {:?}",
            result.err()
        );
    }
}
840}