use std::sync::Arc;
use arrow_schema::DataType;
use datafusion_common::Result;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_physical_optimizer::PhysicalOptimizerRule;
use datafusion_physical_plan::ExecutionPlan;
use datafusion_physical_plan::empty::EmptyExec;
use datafusion_physical_plan::filter::FilterExec;
use crate::extensions::HirnSessionExt;
/// Physical optimizer rule that enforces a session's namespace allow-list on a plan.
///
/// When run, the rule reads the allowed namespaces from the [`HirnSessionExt`]
/// registered in the config extensions:
///
/// * no extension, or no allow-list on it: the plan is returned untouched;
/// * an empty allow-list: the whole plan is replaced by an `EmptyExec` with the
///   same schema;
/// * otherwise: every childless plan node whose schema has a `Utf8` column named
///   `namespace` is wrapped in a `FilterExec` keeping only the allowed values.
#[derive(Debug, Default)]
pub struct PolicyPushdownRule;
impl PolicyPushdownRule {
pub fn new() -> Self {
Self
}
fn has_namespace_column(plan: &dyn ExecutionPlan) -> bool {
plan.schema()
.fields()
.iter()
.any(|f| f.name() == "namespace" && f.data_type() == &DataType::Utf8)
}
fn build_namespace_filter(
input: Arc<dyn ExecutionPlan>,
namespaces: &[String],
) -> Result<Arc<dyn ExecutionPlan>> {
use datafusion_physical_expr::expressions::{self, BinaryExpr, InListExpr};
let schema = input.schema();
let (_idx, _) = schema.column_with_name("namespace").ok_or_else(|| {
datafusion_common::DataFusionError::Internal(
"PolicyPushdownRule: expected 'namespace' column".into(),
)
})?;
let ns_col = expressions::col("namespace", &schema)?;
let predicate: Arc<dyn datafusion_physical_expr::PhysicalExpr> = if namespaces.len() == 1 {
let lit = expressions::lit(datafusion_common::ScalarValue::Utf8(Some(
namespaces[0].clone(),
)));
Arc::new(BinaryExpr::new(ns_col, datafusion_expr::Operator::Eq, lit))
} else {
let list: Vec<Arc<dyn datafusion_physical_expr::PhysicalExpr>> = namespaces
.iter()
.map(|ns| {
expressions::lit(datafusion_common::ScalarValue::Utf8(Some(ns.clone())))
as Arc<dyn datafusion_physical_expr::PhysicalExpr>
})
.collect();
Arc::new(InListExpr::try_new(ns_col, list, false, &schema)?)
};
let filter = FilterExec::try_new(predicate, input)?;
Ok(Arc::new(filter))
}
}
impl PhysicalOptimizerRule for PolicyPushdownRule {
    /// Applies the namespace allow-list found in `config` to `plan`.
    ///
    /// * No [`HirnSessionExt`] registered, or it reports no allow-list: `plan` is
    ///   returned unchanged.
    /// * Empty allow-list: the entire plan is replaced by an `EmptyExec` carrying
    ///   the plan's schema.
    /// * Otherwise each childless node that has a `Utf8` `namespace` column is
    ///   wrapped in a namespace filter; all other nodes are left as they are.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while building a namespace filter.
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        config: &ConfigOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let Some(ext) = config.extensions.get::<HirnSessionExt>() else {
            return Ok(plan);
        };
        let Some(allowed) = ext.allowed_namespaces().map(|ns| ns.to_vec()) else {
            return Ok(plan);
        };

        // An allow-list with no entries means nothing may be read.
        if allowed.is_empty() {
            return Ok(Arc::new(EmptyExec::new(plan.schema())));
        }

        let rewritten = plan.transform_up(|node| {
            // Only childless nodes are candidates; the schema is inspected only for those.
            let is_leaf = node.children().is_empty();
            if is_leaf && Self::has_namespace_column(node.as_ref()) {
                Self::build_namespace_filter(node, &allowed).map(Transformed::yes)
            } else {
                Ok(Transformed::no(node))
            }
        })?;
        Ok(rewritten.data)
    }

    /// Name under which this rule is reported.
    fn name(&self) -> &str {
        "PolicyPushdownRule"
    }

    /// Always `true`: the rule requests the optimizer's schema check.
    fn schema_check(&self) -> bool {
        true
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{RecordBatch, StringArray};
    use arrow_schema::{Field, Schema};
    use datafusion_datasource::memory::MemorySourceConfig;

    /// In-memory scan with `id`, `namespace` and `content` Utf8 columns and three rows
    /// (two in `ns_a`, one in `ns_b`).
    fn scan_with_namespace() -> Arc<dyn ExecutionPlan> {
        let fields = vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("namespace", DataType::Utf8, false),
            Field::new("content", DataType::Utf8, true),
        ];
        let schema = Arc::new(Schema::new(fields));

        let ids = StringArray::from(vec!["m1", "m2", "m3"]);
        let namespaces = StringArray::from(vec!["ns_a", "ns_b", "ns_a"]);
        let contents = StringArray::from(vec!["hello", "world", "foo"]);
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(ids), Arc::new(namespaces), Arc::new(contents)],
        )
        .unwrap();

        MemorySourceConfig::try_new_exec(&[vec![batch]], schema, None).unwrap()
    }

    /// In-memory scan with only `id` and `content` columns (no `namespace`), one row.
    fn scan_without_namespace() -> Arc<dyn ExecutionPlan> {
        let fields = vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("content", DataType::Utf8, true),
        ];
        let schema = Arc::new(Schema::new(fields));

        let ids = StringArray::from(vec!["m1"]);
        let contents = StringArray::from(vec!["hello"]);
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(ids), Arc::new(contents)],
        )
        .unwrap();

        MemorySourceConfig::try_new_exec(&[vec![batch]], schema, None).unwrap()
    }

    /// Default config with a `HirnSessionExt` registered that carries `namespaces`
    /// as its allow-list.
    fn config_with_namespaces(namespaces: Option<Vec<String>>) -> ConfigOptions {
        let session_ext = HirnSessionExt::new(
            Arc::new(()),
            Arc::new(hirn_core::config::HirnConfig::default()),
            None,
        )
        .with_allowed_namespaces(namespaces);

        let mut config = ConfigOptions::default();
        config.extensions.insert(session_ext);
        config
    }

    /// Runs a fresh rule over `plan` with `config`, panicking if optimization fails.
    fn optimize_with(
        plan: Arc<dyn ExecutionPlan>,
        config: &ConfigOptions,
    ) -> Arc<dyn ExecutionPlan> {
        PolicyPushdownRule::new().optimize(plan, config).unwrap()
    }

    #[test]
    fn open_mode_no_filter() {
        let config = config_with_namespaces(None);
        let optimized = optimize_with(scan_with_namespace(), &config);
        assert!(!optimized.as_any().is::<FilterExec>());
    }

    #[test]
    fn deny_all_returns_empty() {
        let config = config_with_namespaces(Some(vec![]));
        let optimized = optimize_with(scan_with_namespace(), &config);
        assert!(optimized.as_any().is::<EmptyExec>());
    }

    #[test]
    fn single_namespace_equality_filter() {
        let config = config_with_namespaces(Some(vec!["ns_a".to_string()]));
        let optimized = optimize_with(scan_with_namespace(), &config);

        let filter = optimized
            .as_any()
            .downcast_ref::<FilterExec>()
            .expect("expected FilterExec");
        let pred_str = filter.predicate().to_string();
        assert!(
            pred_str.contains("namespace") && pred_str.contains("ns_a"),
            "expected namespace = 'ns_a' predicate, got: {pred_str}"
        );
    }

    #[test]
    fn multiple_namespaces_in_list_filter() {
        let config = config_with_namespaces(Some(vec!["ns_a".to_string(), "ns_b".to_string()]));
        let optimized = optimize_with(scan_with_namespace(), &config);

        let filter = optimized
            .as_any()
            .downcast_ref::<FilterExec>()
            .expect("expected FilterExec");
        let pred_str = filter.predicate().to_string();
        assert!(
            pred_str.contains("namespace") && pred_str.contains("IN"),
            "expected IN predicate, got: {pred_str}"
        );
    }

    #[test]
    fn no_namespace_column_no_filter() {
        let config = config_with_namespaces(Some(vec!["ns_a".to_string()]));
        let optimized = optimize_with(scan_without_namespace(), &config);
        assert!(!optimized.as_any().is::<FilterExec>());
    }

    #[test]
    fn no_ext_registered_no_filter() {
        // Plain default config: no HirnSessionExt has been inserted.
        let config = ConfigOptions::default();
        let optimized = optimize_with(scan_with_namespace(), &config);
        assert!(!optimized.as_any().is::<FilterExec>());
    }
}