use std::sync::Arc;
use datafusion_common::Result;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_physical_expr::expressions::{BinaryExpr, InListExpr};
use datafusion_physical_optimizer::PhysicalOptimizerRule;
use datafusion_physical_plan::ExecutionPlan;
use datafusion_physical_plan::filter::FilterExec;
use crate::extensions::HirnSessionExt;
/// Stateless physical optimizer rule.
///
/// Its [`PhysicalOptimizerRule`] implementation rewrites a `FilterExec` whose
/// predicate is a non-negated, single-element `IN` list into an equality
/// comparison, but only when the session's `HirnSessionExt` reports exactly
/// one allowed namespace.
#[derive(Default, Debug)]
pub struct NamespacePartitionPruneRule;

impl NamespacePartitionPruneRule {
    /// Creates the rule. The type carries no state, so this is the same value
    /// as [`Default::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
impl PhysicalOptimizerRule for NamespacePartitionPruneRule {
    /// Rewrites `expr IN (value)` filter predicates into `expr = value`.
    ///
    /// The rule is a no-op unless a `HirnSessionExt` is registered in
    /// `config.extensions` and its `allowed_namespaces()` yields exactly one
    /// entry. In every other case `plan` is returned untouched.
    ///
    /// When active, the plan is walked bottom-up and each `FilterExec` whose
    /// *entire* predicate is a non-negated `InListExpr` with a single list
    /// element is replaced by an equivalent `FilterExec` using a `BinaryExpr`
    /// with `Operator::Eq`. Note that the rewrite does not inspect which column
    /// the `IN` list is applied to; any single-element `IN` filter qualifies.
    /// Predicates that merely contain an `IN` list (e.g. inside an `AND`) are
    /// left alone.
    ///
    /// # Errors
    ///
    /// Propagates any error from constructing the replacement `FilterExec`.
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        config: &ConfigOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let single_namespace = config
            .extensions
            .get::<HirnSessionExt>()
            .and_then(|ext| ext.allowed_namespaces())
            .is_some_and(|ns| ns.len() == 1);
        if !single_namespace {
            return Ok(plan);
        }

        plan.transform_up(|node| {
            let Some(filter_exec) = node.as_any().downcast_ref::<FilterExec>() else {
                return Ok(Transformed::no(node));
            };
            let Some(in_list) = filter_exec
                .predicate()
                .as_any()
                .downcast_ref::<InListExpr>()
            else {
                return Ok(Transformed::no(node));
            };
            // `NOT IN` and multi-value lists have no single-equality equivalent.
            if in_list.negated() || in_list.list().len() != 1 {
                return Ok(Transformed::no(node));
            }

            let eq_predicate: Arc<dyn datafusion_physical_expr::PhysicalExpr> =
                Arc::new(BinaryExpr::new(
                    Arc::clone(in_list.expr()),
                    datafusion_expr::Operator::Eq,
                    Arc::clone(&in_list.list()[0]),
                ));

            // `FilterExec::try_new` starts from default settings. Carry over the
            // original filter's selectivity and, more importantly, its embedded
            // projection: dropping the projection would change the node's output
            // schema and break parents that reference columns by index.
            // NOTE(review): if the DataFusion version in use exposes further
            // per-filter settings, they should be copied here as well.
            let rewritten = FilterExec::try_new(eq_predicate, Arc::clone(filter_exec.input()))?
                .with_default_selectivity(filter_exec.default_selectivity())?
                .with_projection(filter_exec.projection().cloned())?;

            Ok(Transformed::yes(
                Arc::new(rewritten) as Arc<dyn ExecutionPlan>
            ))
        })
        .map(|t| t.data)
    }

    /// Name under which the rule is reported by the optimizer.
    fn name(&self) -> &str {
        "NamespacePartitionPruneRule"
    }

    /// The rewrite only swaps the predicate expression and keeps the filter's
    /// input and projection, so the plan schema must be unchanged and the
    /// optimizer is asked to verify that.
    fn schema_check(&self) -> bool {
        true
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rules::PolicyPushdownRule;
    use arrow_array::{RecordBatch, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion_datasource::memory::MemorySourceConfig;
    use datafusion_physical_expr::expressions;

    /// In-memory scan with `id` and `namespace` string columns and one row.
    fn scan_with_namespace() -> Arc<dyn ExecutionPlan> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("namespace", DataType::Utf8, false),
        ]));
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(StringArray::from(vec!["m1"])),
                Arc::new(StringArray::from(vec!["ns_a"])),
            ],
        )
        .unwrap();
        MemorySourceConfig::try_new_exec(&[vec![batch]], schema, None).unwrap()
    }

    /// Builds a `ConfigOptions` carrying a `HirnSessionExt` configured with the
    /// given allowed namespaces.
    fn config_with_namespaces(namespaces: Option<Vec<String>>) -> ConfigOptions {
        let mut config = ConfigOptions::default();
        let ext = HirnSessionExt::new(
            Arc::new(()),
            Arc::new(hirn_core::config::HirnConfig::default()),
            None,
        )
        .with_allowed_namespaces(namespaces);
        config.extensions.insert(ext);
        config
    }

    /// Wraps `input` in a `FilterExec` whose predicate is
    /// `namespace [NOT] IN (values...)`.
    fn namespace_in_list_filter(
        input: Arc<dyn ExecutionPlan>,
        values: &[&str],
        negated: bool,
    ) -> Arc<dyn ExecutionPlan> {
        let schema = input.schema();
        let ns_col = expressions::col("namespace", &schema).unwrap();
        let list = values
            .iter()
            .map(|v| expressions::lit(datafusion_common::ScalarValue::Utf8(Some((*v).to_string()))))
            .collect::<Vec<_>>();
        let predicate = Arc::new(InListExpr::try_new(ns_col, list, negated, &schema).unwrap());
        Arc::new(FilterExec::try_new(predicate, input).unwrap())
    }

    /// True when `plan`'s root is a `FilterExec` whose predicate is still an
    /// `InListExpr`, i.e. the rule did not rewrite it.
    fn filter_predicate_is_in_list(plan: &Arc<dyn ExecutionPlan>) -> bool {
        plan.as_any()
            .downcast_ref::<FilterExec>()
            .is_some_and(|f| f.predicate().as_any().downcast_ref::<InListExpr>().is_some())
    }

    // The passthrough tests deliberately use a filter the rule WOULD rewrite if
    // it were active; on a bare scan they would pass even with broken gating.

    #[test]
    fn passthrough_when_multiple_namespaces() {
        let plan = namespace_in_list_filter(scan_with_namespace(), &["a"], false);
        let config = config_with_namespaces(Some(vec!["a".into(), "b".into()]));
        let rule = NamespacePartitionPruneRule::new();
        let result = rule.optimize(Arc::clone(&plan), &config).unwrap();
        assert_eq!(format!("{:?}", result), format!("{:?}", plan));
        assert!(
            filter_predicate_is_in_list(&result),
            "IN list must be left alone when several namespaces are allowed"
        );
    }

    #[test]
    fn passthrough_when_open_mode() {
        let plan = namespace_in_list_filter(scan_with_namespace(), &["a"], false);
        let config = config_with_namespaces(None);
        let rule = NamespacePartitionPruneRule::new();
        let result = rule.optimize(Arc::clone(&plan), &config).unwrap();
        assert_eq!(format!("{:?}", result), format!("{:?}", plan));
        assert!(
            filter_predicate_is_in_list(&result),
            "IN list must be left alone when no namespace restriction is set"
        );
    }

    #[test]
    fn passthrough_when_extension_missing() {
        let plan = namespace_in_list_filter(scan_with_namespace(), &["a"], false);
        let config = ConfigOptions::default();
        let rule = NamespacePartitionPruneRule::new();
        let result = rule.optimize(plan, &config).unwrap();
        assert!(
            filter_predicate_is_in_list(&result),
            "IN list must be left alone when HirnSessionExt is not registered"
        );
    }

    #[test]
    fn single_namespace_after_pushdown_is_equality() {
        let plan = scan_with_namespace();
        let config = config_with_namespaces(Some(vec!["ns_a".into()]));
        let pushdown = PolicyPushdownRule::new();
        let after_pushdown = pushdown.optimize(plan, &config).unwrap();
        let prune = NamespacePartitionPruneRule::new();
        let after_prune = prune.optimize(after_pushdown, &config).unwrap();
        let filter = after_prune
            .as_any()
            .downcast_ref::<FilterExec>()
            .expect("expected a FilterExec at the plan root");
        let pred_str = format!("{}", filter.predicate());
        assert!(
            pred_str.contains('=') && pred_str.contains("ns_a"),
            "expected equality predicate, got: {pred_str}"
        );
    }

    #[test]
    fn rewrites_single_element_in_list_to_equality() {
        let filter = namespace_in_list_filter(scan_with_namespace(), &["ns_x"], false);
        let config = config_with_namespaces(Some(vec!["ns_x".into()]));
        let rule = NamespacePartitionPruneRule::new();
        let result = rule.optimize(filter, &config).unwrap();
        let out_filter = result.as_any().downcast_ref::<FilterExec>().unwrap();
        let pred_str = format!("{}", out_filter.predicate());
        assert!(
            pred_str.contains('=') && !pred_str.contains("IN"),
            "expected equality, got: {pred_str}"
        );
    }

    #[test]
    fn leaves_negated_in_list_untouched() {
        let filter = namespace_in_list_filter(scan_with_namespace(), &["ns_x"], true);
        let config = config_with_namespaces(Some(vec!["ns_x".into()]));
        let rule = NamespacePartitionPruneRule::new();
        let result = rule.optimize(filter, &config).unwrap();
        assert!(
            filter_predicate_is_in_list(&result),
            "NOT IN must not be rewritten to equality"
        );
    }

    #[test]
    fn leaves_multi_element_in_list_untouched() {
        let filter = namespace_in_list_filter(scan_with_namespace(), &["ns_x", "ns_y"], false);
        let config = config_with_namespaces(Some(vec!["ns_x".into()]));
        let rule = NamespacePartitionPruneRule::new();
        let result = rule.optimize(filter, &config).unwrap();
        assert!(
            filter_predicate_is_in_list(&result),
            "multi-element IN list must not be rewritten to equality"
        );
    }

    #[test]
    fn rewrites_svo_events_schema() {
        let svo_schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("subject", DataType::Utf8, false),
            Field::new("verb", DataType::Utf8, false),
            Field::new("object", DataType::Utf8, true),
            Field::new("confidence", DataType::Float64, false),
            Field::new("namespace", DataType::Utf8, false),
        ]));
        let batch = RecordBatch::try_new(
            Arc::clone(&svo_schema),
            vec![
                Arc::new(StringArray::from(vec!["e1"])),
                Arc::new(StringArray::from(vec!["Alice"])),
                Arc::new(StringArray::from(vec!["met"])),
                Arc::new(StringArray::from(vec![Some("Bob")])),
                Arc::new(arrow_array::Float64Array::from(vec![0.9])),
                Arc::new(StringArray::from(vec!["project_x"])),
            ],
        )
        .unwrap();
        let scan = MemorySourceConfig::try_new_exec(&[vec![batch]], svo_schema, None).unwrap();
        let filter = namespace_in_list_filter(scan, &["project_x"], false);
        let config = config_with_namespaces(Some(vec!["project_x".into()]));
        let rule = NamespacePartitionPruneRule::new();
        let result = rule.optimize(filter, &config).unwrap();
        let out_filter = result.as_any().downcast_ref::<FilterExec>().unwrap();
        let pred_str = format!("{}", out_filter.predicate());
        assert!(
            pred_str.contains('=') && !pred_str.contains("IN"),
            "expected equality on svo_events schema, got: {pred_str}"
        );
    }
}