use std::collections::HashMap;
use std::future::Future;
use std::sync::{Arc, RwLock};
use datafusion::common::tree_node::{Transformed, TreeNode};
use datafusion::config::ConfigOptions;
use datafusion::error::Result as DfResult;
use datafusion::logical_expr::{BinaryExpr, Expr, Filter, LogicalPlan, Operator, TableScan};
use datafusion::optimizer::AnalyzerRule;
use datafusion::scalar::ScalarValue;
use crate::tenant::{TenantContext, TenantId};
// Task-local state consulted by `TenantBinding` and the analyzer rule below.
//
// * `CURRENT_TENANT_OVERRIDE` - when set for the running task,
//   `TenantBinding::current` returns it instead of the binding's shared context.
// * `ADMIN_SCOPE_ACTIVE` - a unit marker; while it is set,
//   `TenantBinding::is_admin_scope` reports `true` and the analyzer rule
//   returns plans without injecting tenant predicates.
tokio::task_local! {
static CURRENT_TENANT_OVERRIDE: TenantContext;
static ADMIN_SCOPE_ACTIVE: ();
}
/// Registry mapping a source name to the (optional) name of its tenant column.
///
/// The map lives behind an `RwLock`, so entries can be registered and looked
/// up through a shared reference.
#[derive(Debug, Default)]
pub struct SourceTenantColumns {
    inner: RwLock<HashMap<String, Option<String>>>,
}

impl SourceTenantColumns {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            inner: RwLock::default(),
        }
    }

    /// Records `column` as the tenant column for `source`, replacing any
    /// previous entry for that source.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn set(&self, source: &str, column: Option<String>) {
        let mut entries = self
            .inner
            .write()
            .expect("source tenant columns lock poisoned");
        entries.insert(source.to_owned(), column);
    }

    /// Returns the tenant column registered for `source_name`.
    ///
    /// Yields `None` both when the source is unknown and when it was
    /// registered with `None`.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn tenant_column(&self, source_name: &str) -> Option<String> {
        let entries = self
            .inner
            .read()
            .expect("source tenant columns lock poisoned");
        match entries.get(source_name) {
            Some(Some(column)) => Some(column.clone()),
            Some(None) | None => None,
        }
    }
}
/// Handle to the tenant context that queries should run under.
///
/// The context is stored behind an `Arc<RwLock<..>>`, so clones of a binding
/// observe the same shared value. A task-local override, installed with
/// [`TenantBinding::scope`], takes precedence over the shared value.
#[derive(Debug, Clone)]
pub struct TenantBinding {
    shared: Arc<RwLock<TenantContext>>,
}

impl TenantBinding {
    /// Creates a binding whose shared context starts as `TenantContext::Unscoped`.
    pub fn unscoped() -> Self {
        let shared = Arc::new(RwLock::new(TenantContext::Unscoped));
        Self { shared }
    }

    /// Returns the effective context: the task-local override when one is
    /// set for the running task, otherwise the shared context.
    pub fn current(&self) -> TenantContext {
        match CURRENT_TENANT_OVERRIDE.try_with(|ctx| *ctx) {
            Ok(overridden) => overridden,
            Err(_) => self.read_shared(),
        }
    }

    /// Returns the tenant reported by the effective context, if any.
    pub fn current_tenant(&self) -> Option<TenantId> {
        let effective = self.current();
        effective.tenant()
    }

    /// Reads the shared context, ignoring any task-local override.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned.
    pub fn read_shared(&self) -> TenantContext {
        let guard = self.shared.read().expect("tenant binding lock poisoned");
        *guard
    }

    /// Replaces the shared context with `ctx`.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned.
    pub fn set_shared(&self, ctx: TenantContext) {
        let mut guard = self.shared.write().expect("tenant binding lock poisoned");
        *guard = ctx;
    }

    /// Returns another handle to the shared context cell.
    pub fn shared_arc(&self) -> Arc<RwLock<TenantContext>> {
        let handle = &self.shared;
        Arc::clone(handle)
    }

    /// Awaits `f` with `ctx` installed as the task-local tenant override.
    pub async fn scope<F, T>(&self, ctx: TenantContext, f: F) -> T
    where
        F: Future<Output = T>,
    {
        let scoped = CURRENT_TENANT_OVERRIDE.scope(ctx, f);
        scoped.await
    }

    /// Reports whether the running task is inside [`TenantBinding::admin_scope`].
    pub fn is_admin_scope() -> bool {
        matches!(ADMIN_SCOPE_ACTIVE.try_with(|_| ()), Ok(()))
    }

    /// Awaits `f` with the admin-scope marker set for the running task.
    pub async fn admin_scope<F, T>(f: F) -> T
    where
        F: Future<Output = T>,
    {
        let scoped = ADMIN_SCOPE_ACTIVE.scope((), f);
        scoped.await
    }
}

impl Default for TenantBinding {
    /// Equivalent to [`TenantBinding::unscoped`].
    fn default() -> Self {
        TenantBinding::unscoped()
    }
}
#[derive(Debug)]
pub struct TenantScopeAnalyzerRule {
binding: TenantBinding,
source_columns: Arc<SourceTenantColumns>,
}
impl TenantScopeAnalyzerRule {
pub fn new(binding: TenantBinding, source_columns: Arc<SourceTenantColumns>) -> Self {
Self {
binding,
source_columns,
}
}
fn current_context(&self) -> TenantContext {
self.binding.current()
}
fn discover_tenant_column(&self, scan: &TableScan) -> Option<String> {
if scan.source.schema().field_with_name("tenant_id").is_ok() {
return Some("tenant_id".to_string());
}
let r = &scan.table_name;
if let Some(col) = self.source_columns.tenant_column(r.table()) {
return Some(col);
}
if let Some(catalog) = r.catalog() {
if let Some(col) = self.source_columns.tenant_column(catalog) {
return Some(col);
}
}
None
}
fn build_predicate(&self, col_name: &str, ctx: TenantContext) -> Expr {
let col_expr = Expr::Cast(datafusion::logical_expr::Cast::new(
Box::new(Expr::Column(col_name.into())),
arrow_schema::DataType::Utf8,
));
match ctx {
TenantContext::Unscoped => Expr::IsNull(Box::new(col_expr)),
TenantContext::Scoped(t) => {
let eq = Expr::BinaryExpr(BinaryExpr {
left: Box::new(col_expr.clone()),
op: Operator::Eq,
right: Box::new(Expr::Literal(ScalarValue::Utf8(Some(t.to_string())), None)),
});
let is_null = Expr::IsNull(Box::new(col_expr));
Expr::BinaryExpr(BinaryExpr {
left: Box::new(eq),
op: Operator::Or,
right: Box::new(is_null),
})
}
}
}
fn rewrite_node(
&self,
node: LogicalPlan,
ctx: TenantContext,
) -> DfResult<Transformed<LogicalPlan>> {
if let LogicalPlan::TableScan(scan) = &node {
if let Some(col) = self.discover_tenant_column(scan) {
let predicate = self.build_predicate(&col, ctx);
let filter = Filter::try_new(predicate, Arc::new(node))?;
return Ok(Transformed::yes(LogicalPlan::Filter(filter)));
}
}
Ok(Transformed::no(node))
}
}
impl AnalyzerRule for TenantScopeAnalyzerRule {
    /// Name under which this rule is registered with the analyzer.
    fn name(&self) -> &str {
        "tenant_scope_predicate_injection"
    }

    /// Injects tenant predicates above every tenant-aware table scan in `plan`.
    ///
    /// When the running task is inside an admin scope the plan is returned
    /// untouched. Otherwise the tenant context is resolved once and applied
    /// to every scan, including scans nested inside expression subqueries.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while rewriting a node.
    fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> DfResult<LogicalPlan> {
        if TenantBinding::is_admin_scope() {
            return Ok(plan);
        }
        let ctx = self.current_context();
        // Plain `transform_up` only follows plan inputs. Scans that live in
        // expression subqueries (`IN (SELECT ..)`, `EXISTS`, scalar
        // subqueries) are not inputs and would be left unfiltered, so the
        // subquery-aware traversal is required for tenant isolation.
        plan.transform_up_with_subqueries(|node| self.rewrite_node(node, ctx))
            .map(|rewritten| rewritten.data)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;
    use std::sync::Arc;
    use arrow_schema::{DataType, Field, Schema, SchemaRef};
    use datafusion::functions_aggregate::count::count;
    use datafusion::logical_expr::{
        col, JoinType, LogicalPlan, LogicalPlanBuilder, LogicalTableSource, SubqueryAlias,
    };
    use datafusion::optimizer::AnalyzerRule;
    use crate::tenant::TenantId;

    /// Fixed tenant id used by the scoped tests.
    fn tenant_a() -> TenantId {
        let parsed = TenantId::from_str("01906c83-d4c8-7e10-9c4f-3b6f7c5a8e9a");
        parsed.unwrap()
    }

    /// Schema with `id`, `region` and a nullable `tenant_id` column.
    fn tenanted_schema() -> SchemaRef {
        let fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("region", DataType::Utf8, true),
            Field::new("tenant_id", DataType::Utf8, true),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Logical table source backed by [`tenanted_schema`].
    fn tenanted_source() -> Arc<LogicalTableSource> {
        let source = LogicalTableSource::new(tenanted_schema());
        Arc::new(source)
    }

    /// Binding whose shared context is scoped to `t`.
    fn scoped_binding(t: TenantId) -> TenantBinding {
        let binding = TenantBinding::unscoped();
        binding.set_shared(TenantContext::Scoped(t));
        binding
    }

    /// Binding left in its initial unscoped state.
    fn unscoped_binding() -> TenantBinding {
        TenantBinding::unscoped()
    }

    /// Rule over `binding` with an empty per-source column registry.
    fn rule_with(binding: TenantBinding) -> TenantScopeAnalyzerRule {
        let registry = Arc::new(SourceTenantColumns::new());
        TenantScopeAnalyzerRule::new(binding, registry)
    }

    /// Counts `Filter` nodes that sit directly above a scan of `table` and
    /// whose predicate mentions `tenant_id`.
    fn count_tenant_filters_over(plan: &LogicalPlan, table: &str) -> usize {
        use datafusion::common::tree_node::TreeNodeRecursion;
        use datafusion::logical_expr::Filter;

        let mut hits = 0usize;
        plan.apply(|node| {
            let is_hit = match node {
                LogicalPlan::Filter(Filter {
                    input, predicate, ..
                }) => match input.as_ref() {
                    LogicalPlan::TableScan(scan) => {
                        scan.table_name.table() == table && involves_tenant_id(predicate)
                    }
                    _ => false,
                },
                _ => false,
            };
            if is_hit {
                hits += 1;
            }
            Ok(TreeNodeRecursion::Continue)
        })
        .unwrap();
        hits
    }

    /// Whether `expr` references a column named `tenant_id`, looking through
    /// casts, `IS NULL` and binary expressions only.
    fn involves_tenant_id(expr: &Expr) -> bool {
        match expr {
            Expr::Column(column) => column.name == "tenant_id",
            Expr::Cast(cast) => involves_tenant_id(cast.expr.as_ref()),
            Expr::IsNull(operand) => involves_tenant_id(operand),
            Expr::BinaryExpr(BinaryExpr { left, right, .. }) => {
                if involves_tenant_id(left) {
                    true
                } else {
                    involves_tenant_id(right)
                }
            }
            _ => false,
        }
    }

    /// Asserts `predicate` is `CAST(col AS Utf8) = '<t>' OR CAST(col) IS NULL`.
    fn assert_or_eq_isnull(predicate: &Expr, col_name: &str, t: TenantId) {
        let Expr::BinaryExpr(outer) = predicate else {
            panic!("expected OR BinaryExpr, got {predicate:?}");
        };
        assert!(
            matches!(outer.op, Operator::Or),
            "outer predicate must be OR, got {:?}",
            outer.op
        );

        // Left arm: CAST(col AS Utf8) = '<tenant>'.
        let eq_arm = outer.left.as_ref();
        let Expr::BinaryExpr(eq) = eq_arm else {
            panic!("expected Eq BinaryExpr, got {eq_arm:?}");
        };
        assert!(matches!(eq.op, Operator::Eq), "left arm op should be Eq");

        let eq_lhs = eq.left.as_ref();
        let Expr::Cast(lhs_cast) = eq_lhs else {
            panic!("expected Cast on left arm lhs, got {eq_lhs:?}");
        };
        let cast_operand = lhs_cast.expr.as_ref();
        let Expr::Column(lhs_col) = cast_operand else {
            panic!("expected column inside Cast, got {cast_operand:?}");
        };
        assert_eq!(lhs_col.name, col_name);
        assert_eq!(lhs_cast.data_type, arrow_schema::DataType::Utf8);

        let eq_rhs = eq.right.as_ref();
        let Expr::Literal(ScalarValue::Utf8(Some(literal)), _) = eq_rhs else {
            panic!("expected Utf8 literal on left arm rhs, got {eq_rhs:?}");
        };
        let rhs_lit = literal.clone();
        assert_eq!(rhs_lit, t.to_string());

        // Right arm: CAST(col) IS NULL.
        let null_arm = outer.right.as_ref();
        let Expr::IsNull(is_null_inner) = null_arm else {
            panic!("expected IsNull on right arm, got {null_arm:?}");
        };
        let null_operand = is_null_inner.as_ref();
        let Expr::Cast(cast) = null_operand else {
            panic!("expected Cast inside IsNull, got {null_operand:?}");
        };
        let null_cast_operand = cast.expr.as_ref();
        let Expr::Column(isnull_col) = null_cast_operand else {
            panic!("expected column inside Cast(IsNull), got {null_cast_operand:?}");
        };
        assert_eq!(isnull_col.name, col_name);
    }

    /// Asserts `predicate` is exactly `CAST(col) IS NULL`.
    fn assert_isnull_only(predicate: &Expr, col_name: &str) {
        let Expr::IsNull(inner) = predicate else {
            panic!("expected IsNull predicate, got {predicate:?}");
        };
        let operand = inner.as_ref();
        let Expr::Cast(cast) = operand else {
            panic!("expected Cast inside IsNull, got {operand:?}");
        };
        let cast_operand = cast.expr.as_ref();
        let Expr::Column(column) = cast_operand else {
            panic!("expected column inside Cast, got {cast_operand:?}");
        };
        assert_eq!(column.name, col_name);
    }

    /// Returns the predicate of the first `Filter` (pre-order, depth-first)
    /// whose input is a `TableScan`.
    fn first_predicate_above_scan(plan: &LogicalPlan) -> Expr {
        fn find(plan: &LogicalPlan) -> Option<Expr> {
            if let LogicalPlan::Filter(filter) = plan {
                if matches!(filter.input.as_ref(), LogicalPlan::TableScan(_)) {
                    return Some(filter.predicate.clone());
                }
            }
            plan.inputs().into_iter().find_map(|child| find(child))
        }

        find(plan).expect("expected a Filter directly above a TableScan")
    }

    #[test]
    fn single_scan_scoped_emits_eq_or_isnull() {
        let scan = LogicalPlanBuilder::scan("sources", tenanted_source(), None).unwrap();
        let plan = scan.project(vec![col("id")]).unwrap().build().unwrap();

        let rule = rule_with(scoped_binding(tenant_a()));
        let out = rule.analyze(plan, &ConfigOptions::default()).unwrap();

        let predicate = first_predicate_above_scan(&out);
        assert_or_eq_isnull(&predicate, "tenant_id", tenant_a());
    }

    #[test]
    fn join_rewrites_both_sides() {
        let right = LogicalPlanBuilder::scan("models", tenanted_source(), None)
            .unwrap()
            .build()
            .unwrap();
        let left = LogicalPlanBuilder::scan("sources", tenanted_source(), None).unwrap();
        let join_condition = vec![col("sources.id").eq(col("models.id"))];
        let plan = left
            .join_on(right, JoinType::Inner, join_condition)
            .unwrap()
            .build()
            .unwrap();

        let rule = rule_with(scoped_binding(tenant_a()));
        let out = rule.analyze(plan, &ConfigOptions::default()).unwrap();

        let expectations = [
            ("sources", "JOIN left should get exactly one tenant filter"),
            ("models", "JOIN right should get exactly one tenant filter"),
        ];
        for (table, why) in expectations {
            assert_eq!(count_tenant_filters_over(&out, table), 1, "{why}");
        }
    }

    #[test]
    fn subquery_inner_scan_rewritten_via_transform_up() {
        let inner = LogicalPlanBuilder::scan("sources", tenanted_source(), None)
            .unwrap()
            .project(vec![col("id")])
            .unwrap()
            .build()
            .unwrap();
        let aliased = SubqueryAlias::try_new(Arc::new(inner), "inner_q").unwrap();
        let inner_alias = LogicalPlan::SubqueryAlias(aliased);

        let outer = LogicalPlanBuilder::scan("models", tenanted_source(), None)
            .unwrap()
            .cross_join(inner_alias)
            .unwrap()
            .build()
            .unwrap();

        let rule = rule_with(scoped_binding(tenant_a()));
        let out = rule.analyze(outer, &ConfigOptions::default()).unwrap();

        assert_eq!(count_tenant_filters_over(&out, "sources"), 1);
        assert_eq!(count_tenant_filters_over(&out, "models"), 1);
    }

    #[test]
    fn group_by_input_rewritten_aggregate_untouched() {
        let scan = LogicalPlanBuilder::scan("sources", tenanted_source(), None).unwrap();
        let plan = scan
            .aggregate(vec![col("region")], vec![count(col("id"))])
            .unwrap()
            .build()
            .unwrap();

        let rule = rule_with(scoped_binding(tenant_a()));
        let out = rule.analyze(plan, &ConfigOptions::default()).unwrap();

        let LogicalPlan::Aggregate(agg) = &out else {
            panic!("expected Aggregate at top, got {out:?}");
        };
        let agg_input = agg.input.as_ref();
        let LogicalPlan::Filter(filter) = agg_input else {
            panic!("expected Filter under Aggregate, got {agg_input:?}");
        };
        assert!(matches!(filter.input.as_ref(), LogicalPlan::TableScan(_)));
        assert_eq!(agg.group_expr.len(), 1);
        assert_eq!(agg.aggr_expr.len(), 1);
    }

    #[test]
    fn unscoped_session_rewrites_to_isnull_only() {
        let scan = LogicalPlanBuilder::scan("sources", tenanted_source(), None).unwrap();
        let plan = scan.project(vec![col("id")]).unwrap().build().unwrap();

        let rule = rule_with(unscoped_binding());
        let out = rule.analyze(plan, &ConfigOptions::default()).unwrap();

        let predicate = first_predicate_above_scan(&out);
        assert_isnull_only(&predicate, "tenant_id");
    }
}