use crate::error::QueryError;
use kimberlite_rbac::{AccessPolicy, enforcement::PolicyEnforcer};
use sqlparser::ast::{Expr, Query, Select, SelectItem, SetExpr, Statement, TableFactor};
use thiserror::Error;
use tracing::{debug, info, warn};
/// Errors raised while applying an RBAC policy to a SQL statement.
#[derive(Debug, Error)]
pub enum RbacError {
    /// The policy refused the request; the payload is a human-readable reason
    /// (e.g. the enforcer's `AccessDenied` reason).
    #[error("Access denied: {0}")]
    AccessDenied(String),

    /// Nothing in the requested projection survived column filtering.
    #[error("No authorized columns in query")]
    NoAuthorizedColumns,

    /// The statement uses a construct the rewriter cannot reason about.
    #[error("Unsupported query type: {0}")]
    UnsupportedQuery(String),

    /// Enforcement failed for some reason other than an access denial.
    #[error("Policy enforcement failed: {0}")]
    EnforcementFailed(String),
}
/// Maps enforcer failures onto [`RbacError`].
///
/// An enforcer `AccessDenied` keeps its reason as [`RbacError::AccessDenied`];
/// every other enforcer error is reported as [`RbacError::EnforcementFailed`]
/// using its display text.
impl From<kimberlite_rbac::enforcement::EnforcementError> for RbacError {
    fn from(err: kimberlite_rbac::enforcement::EnforcementError) -> Self {
        use kimberlite_rbac::enforcement::EnforcementError;

        if let EnforcementError::AccessDenied { reason } = err {
            Self::AccessDenied(reason)
        } else {
            Self::EnforcementFailed(err.to_string())
        }
    }
}
impl From<RbacError> for QueryError {
fn from(err: RbacError) -> Self {
QueryError::UnsupportedFeature(err.to_string())
}
}
/// Result alias used throughout the RBAC rewriter.
pub type Result<T> = std::result::Result<T, RbacError>;

/// A rewritten statement together with the column lineage of its top-level query.
#[derive(Debug)]
pub struct RewriteOutput {
    /// The statement after projection trimming and row-filter injection.
    pub statement: Statement,
    /// `(output_name, source_column)` pairs for the top-level projection
    /// columns that survived filtering.
    pub column_aliases: Vec<(String, String)>,
}

/// Rewrites SQL statements so they only read what an [`AccessPolicy`] allows.
pub struct RbacFilter {
    /// Policy enforcer consulted for every `SELECT` in a statement.
    enforcer: PolicyEnforcer,
}
impl RbacFilter {
    /// Creates a filter that enforces `policy` on every statement it rewrites.
    pub fn new(policy: AccessPolicy) -> Self {
        Self {
            enforcer: PolicyEnforcer::new(policy),
        }
    }

    /// Rewrites `stmt` so that it only reads what the wrapped policy allows.
    ///
    /// Nested queries are enforced first. These are CTEs, derived tables, both
    /// branches of set operations, and subqueries inside `WHERE` / `HAVING`.
    /// Then each `SELECT`'s projection is trimmed to the allowed columns, and
    /// the policy's row filter, if any, is AND-ed into its `WHERE` clause.
    ///
    /// # Errors
    ///
    /// * [`RbacError::UnsupportedQuery`]: non-query statements, and query shapes
    ///   the rewriter cannot reason about.
    /// * [`RbacError::NoAuthorizedColumns`]: nothing in a projection survives
    ///   filtering.
    /// * [`RbacError::AccessDenied`] / [`RbacError::EnforcementFailed`]:
    ///   propagated from the policy enforcer.
    pub fn rewrite_statement(&self, mut stmt: Statement) -> Result<RewriteOutput> {
        match &mut stmt {
            Statement::Query(query) => {
                let column_aliases = self.rewrite_query(query)?;
                Ok(RewriteOutput {
                    statement: stmt,
                    column_aliases,
                })
            }
            _ => Err(RbacError::UnsupportedQuery(
                "Only SELECT queries are currently supported".to_string(),
            )),
        }
    }

    /// Enforces every CTE of `query`, then its body.
    ///
    /// Returns the body's column lineage. CTE lineage is not needed by callers
    /// and is discarded.
    fn rewrite_query(&self, query: &mut Query) -> Result<Vec<(String, String)>> {
        if let Some(with) = query.with.as_mut() {
            for cte in &mut with.cte_tables {
                self.rewrite_query(cte.query.as_mut())?;
            }
        }
        self.rewrite_set_expr(query.body.as_mut())
    }

    /// Enforces a query body and returns its column lineage.
    ///
    /// For set operations, both branches are enforced and the left branch's
    /// lineage is returned, because SQL names result columns after it.
    fn rewrite_set_expr(&self, set_expr: &mut SetExpr) -> Result<Vec<(String, String)>> {
        match set_expr {
            SetExpr::Select(select) => self.rewrite_select(select),
            SetExpr::Query(inner) => self.rewrite_query(inner.as_mut()),
            SetExpr::SetOperation { left, right, .. } => {
                let left_lineage = self.rewrite_set_expr(left.as_mut())?;
                let right_lineage = self.rewrite_set_expr(right.as_mut())?;
                // Column filtering runs per branch. If the branches lost a
                // different number of columns, the set operation would no
                // longer be well-formed.
                if left_lineage.len() != right_lineage.len() {
                    return Err(RbacError::UnsupportedQuery(format!(
                        "set-operation branches expose different numbers of authorized \
                         columns ({} vs {})",
                        left_lineage.len(),
                        right_lineage.len()
                    )));
                }
                Ok(left_lineage)
            }
            _ => Err(RbacError::UnsupportedQuery(format!(
                "unsupported set-expression: {set_expr:?}"
            ))),
        }
    }

    /// Enforces one `SELECT`: nested queries first, then its own projection
    /// and row filter.
    ///
    /// Returns `(output_name, source_column)` pairs for the projection items
    /// that remain after rewriting.
    fn rewrite_select(&self, select: &mut Select) -> Result<Vec<(String, String)>> {
        // Walk nested queries before the policy predicate is injected, so the
        // injected predicate itself is never walked.
        for table_with_joins in &mut select.from {
            self.rewrite_table_factor(&mut table_with_joins.relation)?;
            for join in &mut table_with_joins.joins {
                self.rewrite_table_factor(&mut join.relation)?;
            }
        }
        if let Some(selection) = select.selection.as_mut() {
            self.rewrite_expr_subqueries(selection)?;
        }
        // HAVING can host subqueries just like WHERE.
        if let Some(having) = select.having.as_mut() {
            self.rewrite_expr_subqueries(having)?;
        }
        // NOTE(review): JOIN ... ON constraints and GROUP BY expressions are
        // not walked for subqueries. Only the first FROM relation's stream is
        // handed to the enforcer, so joined base tables and columns used only
        // in predicates are not checked here. Confirm this is intended or is
        // enforced elsewhere.

        let stream_name = Self::extract_stream_name(select)?;
        debug!(stream = %stream_name, "Extracting columns from SELECT");
        let aliases = Self::extract_column_aliases(select)?;
        let requested_columns: Vec<String> = aliases.iter().map(|(_, src)| src.clone()).collect();
        info!(
            stream = %stream_name,
            columns = ?requested_columns,
            "Enforcing RBAC policy"
        );
        let (allowed_columns, where_clause_sql) = self
            .enforcer
            .enforce_query(&stream_name, &requested_columns)?;
        if allowed_columns.is_empty() {
            warn!(stream = %stream_name, "No authorized columns");
            return Err(RbacError::NoAuthorizedColumns);
        }

        Self::rewrite_projection(select, &allowed_columns);
        // rewrite_projection keeps only plain column references. A computed
        // item whose alias was "allowed" is still dropped, so the projection
        // can become empty even when allowed_columns is not.
        if select.projection.is_empty() {
            warn!(stream = %stream_name, "No authorized columns remain in projection");
            return Err(RbacError::NoAuthorizedColumns);
        }
        if !where_clause_sql.is_empty() {
            Self::inject_where_clause(select, &where_clause_sql)?;
        }
        info!(
            stream = %stream_name,
            allowed_columns = ?allowed_columns,
            where_clause = %where_clause_sql,
            "Query rewritten successfully"
        );

        // Derive lineage from what actually survived in the projection, so it
        // always lines up with the rewritten statement's output columns.
        Self::extract_column_aliases(select)
    }

    /// Enforces queries nested inside a FROM/JOIN relation.
    ///
    /// Derived tables are rewritten in place and nested joins are walked.
    /// Other relation kinds are left untouched.
    fn rewrite_table_factor(&self, factor: &mut TableFactor) -> Result<()> {
        match factor {
            TableFactor::Derived { subquery, .. } => {
                self.rewrite_query(subquery.as_mut())?;
                Ok(())
            }
            TableFactor::NestedJoin {
                table_with_joins, ..
            } => {
                self.rewrite_table_factor(&mut table_with_joins.relation)?;
                for join in &mut table_with_joins.joins {
                    self.rewrite_table_factor(&mut join.relation)?;
                }
                Ok(())
            }
            _ => Ok(()),
        }
    }

    /// Recursively enforces every subquery reachable from `expr`.
    ///
    /// Composite expression kinds are walked explicitly, and leaf kinds
    /// (identifiers, literals) are accepted. Any other kind is accepted only
    /// if it visibly embeds no query (see [`Self::reject_if_embeds_query`]).
    /// Unknown shapes therefore fail closed instead of letting a nested
    /// `SELECT` skip enforcement.
    fn rewrite_expr_subqueries(&self, expr: &mut Expr) -> Result<()> {
        match expr {
            Expr::Subquery(q) | Expr::Exists { subquery: q, .. } => {
                self.rewrite_query(q.as_mut())?;
                Ok(())
            }
            Expr::InSubquery {
                subquery,
                expr: inner,
                ..
            } => {
                self.rewrite_expr_subqueries(inner.as_mut())?;
                self.rewrite_query(subquery.as_mut())?;
                Ok(())
            }
            Expr::BinaryOp { left, right, .. }
            | Expr::AnyOp { left, right, .. }
            | Expr::AllOp { left, right, .. }
            | Expr::IsDistinctFrom(left, right)
            | Expr::IsNotDistinctFrom(left, right)
            | Expr::Like {
                expr: left,
                pattern: right,
                ..
            }
            | Expr::ILike {
                expr: left,
                pattern: right,
                ..
            }
            | Expr::SimilarTo {
                expr: left,
                pattern: right,
                ..
            } => {
                self.rewrite_expr_subqueries(left.as_mut())?;
                self.rewrite_expr_subqueries(right.as_mut())
            }
            Expr::UnaryOp { expr: inner, .. }
            | Expr::Nested(inner)
            | Expr::Cast { expr: inner, .. }
            | Expr::IsNull(inner)
            | Expr::IsNotNull(inner)
            | Expr::IsTrue(inner)
            | Expr::IsNotTrue(inner)
            | Expr::IsFalse(inner)
            | Expr::IsNotFalse(inner)
            | Expr::IsUnknown(inner)
            | Expr::IsNotUnknown(inner) => self.rewrite_expr_subqueries(inner.as_mut()),
            Expr::InList {
                expr: inner, list, ..
            } => {
                self.rewrite_expr_subqueries(inner.as_mut())?;
                list.iter_mut()
                    .try_for_each(|item| self.rewrite_expr_subqueries(item))
            }
            Expr::Tuple(items) => items
                .iter_mut()
                .try_for_each(|item| self.rewrite_expr_subqueries(item)),
            Expr::Between {
                expr: inner,
                low,
                high,
                ..
            } => {
                self.rewrite_expr_subqueries(inner.as_mut())?;
                self.rewrite_expr_subqueries(low.as_mut())?;
                self.rewrite_expr_subqueries(high.as_mut())
            }
            Expr::Case {
                operand,
                conditions,
                else_result,
                ..
            } => {
                // `CASE (SELECT ...) WHEN ...` hides the subquery in the operand.
                if let Some(operand) = operand.as_mut() {
                    self.rewrite_expr_subqueries(operand.as_mut())?;
                }
                for case_when in conditions.iter_mut() {
                    self.rewrite_expr_subqueries(&mut case_when.condition)?;
                    self.rewrite_expr_subqueries(&mut case_when.result)?;
                }
                if let Some(else_r) = else_result.as_mut() {
                    self.rewrite_expr_subqueries(else_r.as_mut())?;
                }
                Ok(())
            }
            // Leaves: nothing nested to enforce.
            Expr::Identifier(_) | Expr::CompoundIdentifier(_) | Expr::Value(_) => Ok(()),
            other => Self::reject_if_embeds_query(other),
        }
    }

    /// Fail-closed guard for expression kinds the walker does not descend into
    /// (function calls, subscripts, `EXTRACT`, ...).
    ///
    /// sqlparser renders embedded queries as `(SELECT ...`, `(WITH ...` or
    /// `(VALUES ...`. If the rendering contains one of these openers as a
    /// keyword, the statement is rejected rather than left unenforced. This
    /// can reject an expression that only contains such text inside a string
    /// literal; that trade-off is deliberate.
    fn reject_if_embeds_query(expr: &Expr) -> Result<()> {
        let sql = expr.to_string();
        let embeds_query = ["(SELECT", "(WITH", "(VALUES"].iter().any(|&opener| {
            sql.match_indices(opener).any(|(at, _)| {
                // `(SELECTED_FLAG` is an identifier, not a keyword.
                !matches!(
                    sql[at + opener.len()..].chars().next(),
                    Some(c) if c.is_alphanumeric() || c == '_'
                )
            })
        });
        if embeds_query {
            Err(RbacError::UnsupportedQuery(format!(
                "subqueries are not supported inside this expression with RBAC: {sql}"
            )))
        } else {
            Ok(())
        }
    }

    /// Returns the dot-joined name of the first FROM relation.
    ///
    /// # Errors
    ///
    /// [`RbacError::UnsupportedQuery`] when there is no FROM clause or the
    /// first relation is not a plain table reference.
    fn extract_stream_name(select: &Select) -> Result<String> {
        if select.from.is_empty() {
            return Err(RbacError::UnsupportedQuery(
                "SELECT without FROM clause".to_string(),
            ));
        }
        let table = &select.from[0];
        match &table.relation {
            TableFactor::Table { name, .. } => {
                let stream_name = name
                    .0
                    .iter()
                    .map(|part| match part.as_ident() {
                        Some(ident) => ident.value.clone(),
                        None => part.to_string(),
                    })
                    .collect::<Vec<_>>()
                    .join(".");
                Ok(stream_name)
            }
            _ => Err(RbacError::UnsupportedQuery(
                "Only simple table references are supported".to_string(),
            )),
        }
    }

    /// Returns `(output_name, source_column)` pairs for `select`'s projection.
    fn extract_column_aliases(select: &Select) -> Result<Vec<(String, String)>> {
        column_aliases_from_select(select)
    }
}
pub fn column_aliases(stmt: &Statement) -> Vec<(String, String)> {
let Statement::Query(query) = stmt else {
return Vec::new();
};
let SetExpr::Select(select) = query.body.as_ref() else {
return Vec::new();
};
column_aliases_from_select(select).unwrap_or_default()
}
/// Maps each projection item to an `(output_name, source_column)` pair.
///
/// * A bare column `c` maps to `(c, c)`.
/// * `c AS a` maps to `(a, c)`.
/// * Any other aliased expression maps to `(a, a)`, because no single source
///   column can be named.
///
/// # Errors
///
/// [`RbacError::UnsupportedQuery`] for `*` and for any other projection item
/// kind. The first offending item is reported.
fn column_aliases_from_select(select: &Select) -> Result<Vec<(String, String)>> {
    select
        .projection
        .iter()
        .map(|item| match item {
            SelectItem::UnnamedExpr(Expr::Identifier(column)) => {
                Ok((column.value.clone(), column.value.clone()))
            }
            SelectItem::ExprWithAlias {
                expr: Expr::Identifier(source),
                alias,
            } => Ok((alias.value.clone(), source.value.clone())),
            SelectItem::ExprWithAlias { alias, .. } => {
                Ok((alias.value.clone(), alias.value.clone()))
            }
            SelectItem::Wildcard(_) => Err(RbacError::UnsupportedQuery(
                "SELECT * is not yet supported with RBAC".to_string(),
            )),
            other => Err(RbacError::UnsupportedQuery(format!(
                "Unsupported SELECT item: {other:?}"
            ))),
        })
        .collect()
}
impl RbacFilter {
    /// Keeps only plain column references whose source column is in
    /// `allowed_columns`.
    ///
    /// Non-identifier items (computed expressions, wildcards, ...) are always
    /// dropped, because their source columns cannot be verified here.
    fn rewrite_projection(select: &mut Select, allowed_columns: &[String]) {
        let allowed_set: std::collections::HashSet<&str> =
            allowed_columns.iter().map(String::as_str).collect();
        select.projection.retain(|item| match item {
            SelectItem::UnnamedExpr(Expr::Identifier(ident))
            | SelectItem::ExprWithAlias {
                expr: Expr::Identifier(ident),
                ..
            } => allowed_set.contains(ident.value.as_str()),
            _ => false,
        });
    }

    /// AND-s the policy's row filter into `select`'s `WHERE` clause.
    ///
    /// # Errors
    ///
    /// Propagates [`RbacError::UnsupportedQuery`] when the filter SQL cannot
    /// be parsed.
    fn inject_where_clause(select: &mut Select, where_clause_sql: &str) -> Result<()> {
        let policy_expr = Self::parse_where_clause(where_clause_sql)?;
        select.selection = Some(match select.selection.take() {
            Some(existing) => Expr::BinaryOp {
                left: Box::new(Self::parenthesize_if_loose(existing)),
                op: sqlparser::ast::BinaryOperator::And,
                right: Box::new(Self::parenthesize_if_loose(policy_expr)),
            },
            None => policy_expr,
        });
        Ok(())
    }

    /// Wraps `expr` in `Expr::Nested` when its top-level operator binds more
    /// loosely than `AND` (i.e. `OR` / `XOR`).
    ///
    /// sqlparser renders `BinaryOp` without parentheses. Without this wrapper,
    /// a user predicate `a OR b` AND-ed with the row filter would print as
    /// `a OR b AND filter`. That re-parses as `a OR (b AND filter)` and drops
    /// the row filter from the `a` branch.
    fn parenthesize_if_loose(expr: Expr) -> Expr {
        let binds_looser_than_and = matches!(
            expr,
            Expr::BinaryOp {
                op: sqlparser::ast::BinaryOperator::Or | sqlparser::ast::BinaryOperator::Xor,
                ..
            }
        );
        if binds_looser_than_and {
            Expr::Nested(Box::new(expr))
        } else {
            expr
        }
    }

    /// Parses the enforcer's row-filter SQL into an expression.
    ///
    /// This uses sqlparser's own expression parser rather than splitting on
    /// `" AND "` / `'='`. As a result, string literals keep their type,
    /// operators such as `>=` / `<>` keep their meaning, and quoting and
    /// parentheses are honoured. The whole input must be consumed.
    ///
    /// # Errors
    ///
    /// [`RbacError::UnsupportedQuery`] for empty input, unparseable input, or
    /// trailing tokens after the expression.
    fn parse_where_clause(where_clause_sql: &str) -> Result<Expr> {
        let invalid =
            |detail: String| RbacError::UnsupportedQuery(format!("Invalid WHERE clause: {detail}"));

        if where_clause_sql.trim().is_empty() {
            return Err(RbacError::UnsupportedQuery("Empty WHERE clause".to_string()));
        }

        let dialect = sqlparser::dialect::GenericDialect {};
        let mut parser = sqlparser::parser::Parser::new(&dialect)
            .try_with_sql(where_clause_sql)
            .map_err(|e| invalid(e.to_string()))?;
        let expr = parser.parse_expr().map_err(|e| invalid(e.to_string()))?;

        // parse_expr stops at the first token it cannot use; anything left
        // over means the filter was not a single expression.
        if parser.peek_token().token != sqlparser::tokenizer::Token::EOF {
            return Err(invalid(where_clause_sql.to_string()));
        }
        Ok(expr)
    }

    /// Borrows the underlying policy enforcer.
    pub fn enforcer(&self) -> &PolicyEnforcer {
        &self.enforcer
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use kimberlite_rbac::policy::StandardPolicies;
    use kimberlite_types::TenantId;
    use sqlparser::dialect::GenericDialect;
    use sqlparser::parser::Parser;

    /// Parses `sql` with the generic dialect and returns its first statement.
    fn parse_sql(sql: &str) -> Statement {
        let dialect = GenericDialect {};
        let statements = Parser::parse_sql(&dialect, sql).expect("Failed to parse SQL");
        statements.into_iter().next().expect("No statement parsed")
    }

    /// Returns the top-level SELECT of `stmt`. The test fails, instead of
    /// silently skipping its assertions, if the statement has another shape.
    fn top_level_select(stmt: &Statement) -> &Select {
        let Statement::Query(query) = stmt else {
            panic!("expected a query statement, got {stmt:?}");
        };
        let SetExpr::Select(select) = query.body.as_ref() else {
            panic!("expected a plain SELECT body, got {:?}", query.body);
        };
        select
    }

    #[test]
    fn test_rewrite_admin_policy() {
        let filter = RbacFilter::new(StandardPolicies::admin());
        let result = filter.rewrite_statement(parse_sql("SELECT name, email FROM users"));
        assert!(result.is_ok(), "admin rewrite failed: {:?}", result.err());
    }

    #[test]
    fn test_rewrite_user_policy_column_filter() {
        let policy = kimberlite_rbac::policy::AccessPolicy::new(kimberlite_rbac::roles::Role::User)
            .allow_stream("users")
            .allow_column("name")
            .deny_column("ssn");
        let filter = RbacFilter::new(policy);
        let output = filter
            .rewrite_statement(parse_sql("SELECT name, ssn FROM users"))
            .expect("rewrite should succeed");
        assert_eq!(top_level_select(&output.statement).projection.len(), 1);
    }

    #[test]
    fn test_rewrite_with_row_filter() {
        let filter = RbacFilter::new(StandardPolicies::user(TenantId::new(42)));
        let output = filter
            .rewrite_statement(parse_sql("SELECT name, email FROM users"))
            .expect("rewrite should succeed");
        assert!(top_level_select(&output.statement).selection.is_some());
    }

    #[test]
    fn test_rewrite_access_denied() {
        let filter = RbacFilter::new(StandardPolicies::auditor());
        let result = filter.rewrite_statement(parse_sql("SELECT name FROM users"));
        assert!(matches!(result, Err(RbacError::AccessDenied(_))));
    }

    #[test]
    fn test_rewrite_no_authorized_columns() {
        let policy = kimberlite_rbac::policy::AccessPolicy::new(kimberlite_rbac::roles::Role::User)
            .allow_stream("users")
            .deny_column("*");
        let filter = RbacFilter::new(policy);
        let result = filter.rewrite_statement(parse_sql("SELECT name, email FROM users"));
        let err = result.expect_err("all columns denied must be rejected");
        assert!(
            matches!(err, RbacError::AccessDenied(ref msg) if msg.contains("No authorized columns"))
        );
    }

    #[test]
    fn rewrite_rejects_non_query_statements() {
        let filter = RbacFilter::new(StandardPolicies::admin());
        let result =
            filter.rewrite_statement(parse_sql("INSERT INTO users (name) VALUES ('x')"));
        assert!(matches!(result, Err(RbacError::UnsupportedQuery(_))));
    }

    #[test]
    fn rewrite_rejects_wildcard_projection() {
        let filter = RbacFilter::new(StandardPolicies::admin());
        let result = filter.rewrite_statement(parse_sql("SELECT * FROM users"));
        assert!(matches!(result, Err(RbacError::UnsupportedQuery(_))));
    }

    #[test]
    fn column_aliases_maps_output_names_to_sources() {
        let stmt = parse_sql("SELECT name AS n, email FROM users");
        assert_eq!(
            column_aliases(&stmt),
            vec![
                ("n".to_string(), "name".to_string()),
                ("email".to_string(), "email".to_string()),
            ]
        );
    }

    #[test]
    fn column_aliases_of_non_query_is_empty() {
        let stmt = parse_sql("INSERT INTO users (name) VALUES ('x')");
        assert!(column_aliases(&stmt).is_empty());
    }

    /// User policy that can read `users`/`orders` but never `ssn`.
    fn user_denies_ssn_policy() -> kimberlite_rbac::policy::AccessPolicy {
        kimberlite_rbac::policy::AccessPolicy::new(kimberlite_rbac::roles::Role::User)
            .allow_stream("users")
            .allow_stream("orders")
            .allow_column("name")
            .allow_column("email")
            .allow_column("customer")
            .allow_column("id")
            .deny_column("ssn")
    }

    #[test]
    fn subquery_rbac_in_where_clause_enforces_inner_grants() {
        let filter = RbacFilter::new(user_denies_ssn_policy());
        let sql = "SELECT id FROM orders WHERE customer IN (SELECT ssn FROM users)";
        let result = filter.rewrite_statement(parse_sql(sql));
        assert!(
            result.is_err(),
            "nested subquery referencing denied column must be rejected"
        );
    }

    #[test]
    fn subquery_rbac_exists_clause_recurses() {
        let filter = RbacFilter::new(user_denies_ssn_policy());
        let sql = "SELECT id FROM orders WHERE EXISTS (SELECT ssn FROM users)";
        let result = filter.rewrite_statement(parse_sql(sql));
        assert!(
            result.is_err(),
            "EXISTS-subquery referencing denied column must be rejected"
        );
    }

    #[test]
    fn subquery_rbac_derived_table_in_from_recurses() {
        let filter = RbacFilter::new(user_denies_ssn_policy());
        let sql = "SELECT t.email FROM (SELECT ssn FROM users) t";
        let result = filter.rewrite_statement(parse_sql(sql));
        assert!(
            result.is_err(),
            "derived-table SELECT referencing denied column must be rejected"
        );
    }

    #[test]
    fn subquery_rbac_union_both_branches_checked() {
        let filter = RbacFilter::new(user_denies_ssn_policy());
        let sql = "SELECT ssn FROM users UNION SELECT name FROM users";
        let result = filter.rewrite_statement(parse_sql(sql));
        assert!(
            result.is_err(),
            "UNION branch referencing denied column must be rejected"
        );
    }

    #[test]
    fn subquery_rbac_allowed_subquery_still_succeeds() {
        let filter = RbacFilter::new(user_denies_ssn_policy());
        let sql = "SELECT id FROM orders WHERE customer IN (SELECT name FROM users)";
        let result = filter.rewrite_statement(parse_sql(sql));
        assert!(
            result.is_ok(),
            "all-allowed subquery must pass, got: {:?}",
            result.err()
        );
    }

    #[test]
    fn subquery_rbac_cte_with_denied_column_rejected() {
        let filter = RbacFilter::new(user_denies_ssn_policy());
        let sql = "WITH u AS (SELECT ssn FROM users) SELECT id FROM orders";
        let result = filter.rewrite_statement(parse_sql(sql));
        assert!(
            result.is_err(),
            "CTE referencing denied column must be rejected"
        );
    }

    #[test]
    fn subquery_rbac_deeply_nested_three_levels() {
        let filter = RbacFilter::new(user_denies_ssn_policy());
        let sql = "SELECT id FROM orders \
                   WHERE customer IN ( \
                       SELECT name FROM users \
                       WHERE email IN (SELECT ssn FROM users) \
                   )";
        let result = filter.rewrite_statement(parse_sql(sql));
        assert!(
            result.is_err(),
            "deeply nested subquery referencing denied column must be rejected"
        );
    }

    #[test]
    fn subquery_rbac_in_list_does_not_recurse_into_values() {
        let filter = RbacFilter::new(user_denies_ssn_policy());
        let sql = "SELECT id FROM orders WHERE customer IN ('alice', 'bob')";
        let result = filter.rewrite_statement(parse_sql(sql));
        assert!(
            result.is_ok(),
            "in-list with literal values must pass: {:?}",
            result.err()
        );
    }
}