use std::sync::Arc;
use crate::PhysicalOptimizerRule;
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_common::{Result, assert_eq_or_internal_err, config::ConfigOptions};
use datafusion_physical_expr::PhysicalExpr;
use datafusion_physical_expr_common::physical_expr::is_volatile;
use datafusion_physical_plan::filter_pushdown::{
ChildFilterPushdownResult, ChildPushdownResult, FilterPushdownPhase,
FilterPushdownPropagation, PushedDown,
};
use datafusion_physical_plan::{ExecutionPlan, with_new_children_if_necessary};
use itertools::{Itertools, izip};
/// Physical optimizer rule that pushes filter predicates down the plan tree.
///
/// The rule is parameterised by a [`FilterPushdownPhase`]; the phase is
/// forwarded unchanged to `push_down_filters` when the rule runs and also
/// determines the name the rule reports.
#[derive(Debug)]
pub struct FilterPushdown {
/// Phase this rule instance runs in (`Pre` or `Post`).
phase: FilterPushdownPhase,
/// Name reported by `PhysicalOptimizerRule::name`, derived from `phase`
/// at construction time.
name: String,
}
impl FilterPushdown {
    /// Builds a rule bound to `phase`, deriving the reported rule name from it:
    /// `"FilterPushdown"` for `Pre` and `"FilterPushdown(Post)"` for `Post`.
    fn new_with_phase(phase: FilterPushdownPhase) -> Self {
        let label: &str = match phase {
            FilterPushdownPhase::Pre => "FilterPushdown",
            FilterPushdownPhase::Post => "FilterPushdown(Post)",
        };
        Self {
            phase,
            name: String::from(label),
        }
    }

    /// Creates the rule for the [`FilterPushdownPhase::Pre`] phase.
    pub fn new() -> Self {
        Self::new_with_phase(FilterPushdownPhase::Pre)
    }

    /// Creates the rule for the [`FilterPushdownPhase::Post`] phase.
    pub fn new_post_optimization() -> Self {
        Self::new_with_phase(FilterPushdownPhase::Post)
    }
}
impl Default for FilterPushdown {
fn default() -> Self {
Self::new()
}
}
impl PhysicalOptimizerRule for FilterPushdown {
    /// Runs filter pushdown over `plan`, starting at the root.
    ///
    /// Returns the rewritten plan when `push_down_filters` reports an updated
    /// node, and the input `plan` itself otherwise.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `push_down_filters`.
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        config: &ConfigOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // The root has no parent node, so the recursion starts without any
        // parent predicates. Borrowing `plan` directly avoids bumping the
        // reference count just to create a temporary `Arc`.
        let propagation = push_down_filters(&plan, Vec::new(), config, self.phase)?;
        Ok(propagation.updated_node.unwrap_or(plan))
    }

    /// Returns the phase-specific rule name chosen at construction time.
    fn name(&self) -> &str {
        &self.name
    }

    /// Always reports `true`.
    // NOTE(review): presumably because pushing filters down is not expected to
    // change the plan's schema — confirm against the trait's contract.
    fn schema_check(&self) -> bool {
        true
    }
}
/// Recursively pushes filters through `node` and all of its descendants.
///
/// `parent_predicates` are the filters the caller would like `node` (or
/// something below it) to absorb. For each child the function:
///
/// 1. drops predicates rejected by `allow_pushdown_for_expr` (they are never
///    offered to `node` and are reported as [`PushedDown::No`]),
/// 2. asks `node` which parent filters and which of its own filters may be
///    offered to that child (`gather_filters_for_pushdown`),
/// 3. recurses into the child with the node's own pushable filters followed by
///    the parent filters marked [`PushedDown::Yes`],
/// 4. splits the child's answer back into "self" and "parent" results and maps
///    both back to their original positions, using [`PushedDown::No`] for every
///    filter that was never offered.
///
/// Finally the node is rebuilt with the (possibly new) children and is handed
/// the collected results through `handle_child_pushdown_result`.
///
/// # Errors
///
/// Returns an internal error when the node's filter description does not have
/// one entry per child, or when a child returns a different number of results
/// than the number of filters it was given. Errors from the node's own
/// pushdown hooks and from rebuilding the node are propagated.
fn push_down_filters(
    node: &Arc<dyn ExecutionPlan>,
    parent_predicates: Vec<Arc<dyn PhysicalExpr>>,
    config: &ConfigOptions,
    phase: FilterPushdownPhase,
) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
    let children = node.children();
    let mut new_children = Vec::with_capacity(children.len());
    // Row `i` holds, per child, the outcome for parent predicate `i`.
    let mut parent_filter_pushdown_supports: Vec<Vec<PushedDown>> =
        vec![vec![]; parent_predicates.len()];
    // One entry per child: the outcome for each of the node's own filters.
    let mut self_filters_pushdown_supports = Vec::with_capacity(children.len());

    // Only predicates accepted by `allow_pushdown_for_expr` are offered to the node.
    let parent_filtered = FilteredVec::new(&parent_predicates, allow_pushdown_for_expr);
    let filter_description = node.gather_filters_for_pushdown(
        phase,
        parent_filtered.items().to_vec(),
        config,
    )?;

    // Materialize each per-child list once; they are consumed by the loop below.
    let parent_filters_per_child = filter_description.parent_filters();
    let self_filters_per_child = filter_description.self_filters();
    assert_eq_or_internal_err!(
        parent_filters_per_child.len(),
        children.len(),
        "Filter pushdown expected parent filters count to match number of children for node {}",
        node.name()
    );
    assert_eq_or_internal_err!(
        self_filters_per_child.len(),
        children.len(),
        "Filter pushdown expected self filters count to match number of children for node {}",
        node.name()
    );

    for (child, parent_filters, self_filters) in
        izip!(children, parent_filters_per_child, self_filters_per_child)
    {
        // The node's own filters for this child, minus the ones that must not move.
        let self_filtered = FilteredVec::new(&self_filters, allow_pushdown_for_expr);
        let num_self_filters = self_filtered.len();

        // Second filtering pass over the parent filters: keep only those the
        // node agreed to forward to this child.
        let parent_filters_for_child = parent_filtered
            .chain_filter_slice(&parent_filters, |filter| {
                matches!(filter.discriminant, PushedDown::Yes)
            });
        let num_parent_filters = parent_filters_for_child.len();

        // The child receives self filters first, then the forwarded parent filters.
        let mut all_predicates = self_filtered.items().to_vec();
        all_predicates.extend(
            parent_filters_for_child
                .items()
                .iter()
                .map(|filter| Arc::clone(&filter.predicate)),
        );

        let result = push_down_filters(child, all_predicates, config, phase)?;
        new_children.push(result.updated_node.unwrap_or_else(|| Arc::clone(child)));

        // The child cannot tell self filters from parent filters, so its
        // answer is split at the same boundary used to build `all_predicates`.
        let mut all_filters = result.filters.into_iter().collect_vec();
        assert_eq_or_internal_err!(
            all_filters.len(),
            num_self_filters + num_parent_filters,
            "Filter pushdown did not return the expected number of filters from {}",
            child.name()
        );
        let parent_results = all_filters.split_off(num_self_filters);
        let self_results = all_filters;

        // Self filters that were never offered are reported as `No`.
        let self_filter_results: Vec<_> = self_filtered
            .map_results_to_original(self_results, PushedDown::No)
            .into_iter()
            .zip(self_filters)
            .map(|(support, filter)| support.wrap_expression(filter))
            .collect();
        self_filters_pushdown_supports.push(self_filter_results);

        // Parent filters that were never offered to this child are reported as `No`.
        let mapped_parent_results = parent_filters_for_child
            .map_results_to_original(parent_results, PushedDown::No);
        assert_eq_or_internal_err!(
            mapped_parent_results.len(),
            parent_filter_pushdown_supports.len(),
            "Filter pushdown expected one result per parent filter from child {}",
            child.name()
        );
        for (supports, outcome) in parent_filter_pushdown_supports
            .iter_mut()
            .zip(mapped_parent_results)
        {
            supports.push(outcome);
        }
    }

    // Re-create this node on top of the (possibly rewritten) children.
    let updated_node = with_new_children_if_necessary(Arc::clone(node), new_children)?;
    let mut res = updated_node.handle_child_pushdown_result(
        phase,
        ChildPushdownResult {
            parent_filters: parent_predicates
                .into_iter()
                .zip(parent_filter_pushdown_supports)
                .map(|(filter, child_results)| ChildFilterPushdownResult {
                    filter,
                    child_results,
                })
                .collect(),
            self_filters: self_filters_pushdown_supports,
        },
        config,
    )?;

    // If the node itself reported no replacement but its children changed
    // (the rebuilt node is a different allocation), the rebuilt node must
    // still be handed back to the caller.
    if res.updated_node.is_none() && !Arc::ptr_eq(&updated_node, node) {
        res.updated_node = Some(updated_node)
    }
    Ok(res)
}
/// A filtered copy of a slice that remembers where each surviving element
/// came from, so results computed on the filtered items can be mapped back
/// onto the positions of the original slice.
struct FilteredVec<T> {
    /// Elements that passed the most recent filter.
    items: Vec<T>,
    /// One mapping per filter pass. `index_mappings[0]` maps positions after
    /// the first pass to positions in the original slice; each later entry
    /// maps positions after that pass to positions after the previous pass.
    index_mappings: Vec<Vec<usize>>,
    /// Length of the slice the first pass was applied to.
    original_len: usize,
}

/// Clones the elements of `items` accepted by `predicate`, returning them
/// together with the position each one had in `items`.
fn retain_with_positions<U, F>(items: &[U], predicate: F) -> (Vec<U>, Vec<usize>)
where
    U: Clone,
    F: Fn(&U) -> bool,
{
    items
        .iter()
        .enumerate()
        .filter(|&(_, item)| predicate(item))
        .map(|(position, item)| (item.clone(), position))
        .unzip()
}

impl<T: Clone> FilteredVec<T> {
    /// Filters `items` with `predicate`, recording the original position of
    /// every element that is kept.
    fn new<F>(items: &[T], predicate: F) -> Self
    where
        F: Fn(&T) -> bool,
    {
        let (kept, positions) = retain_with_positions(items, predicate);
        Self {
            items: kept,
            index_mappings: vec![positions],
            original_len: items.len(),
        }
    }

    /// The elements that survived filtering.
    fn items(&self) -> &[T] {
        self.items.as_slice()
    }

    /// Number of elements that survived filtering.
    fn len(&self) -> usize {
        self.items.len()
    }

    /// Expands `results` (one per filtered item, in order) to a vector of
    /// `original_len` entries; positions that were filtered out receive
    /// `default_value`.
    ///
    /// Panics if `results` has more entries than there are filtered items,
    /// since the extra positions cannot be traced back.
    fn map_results_to_original<R: Clone>(
        &self,
        results: Vec<R>,
        default_value: R,
    ) -> Vec<R> {
        let mut expanded = vec![default_value; self.original_len];
        results
            .into_iter()
            .enumerate()
            .for_each(|(position, value)| {
                let target = self.trace_to_original_index(position);
                expanded[target] = value;
            });
        expanded
    }

    /// Translates a position in `items` to the matching position in the
    /// original slice by walking the filter passes from newest to oldest.
    fn trace_to_original_index(&self, current_idx: usize) -> usize {
        self.index_mappings
            .iter()
            .rev()
            .fold(current_idx, |position, mapping| mapping[position])
    }

    /// Applies a further filter pass over `items`, a slice whose positions
    /// correspond to this filter's surviving items, and returns a new
    /// `FilteredVec` that still maps back to the original slice.
    fn chain_filter_slice<U: Clone, F>(&self, items: &[U], predicate: F) -> FilteredVec<U>
    where
        F: Fn(&U) -> bool,
    {
        let (kept, positions) = retain_with_positions(items, predicate);
        let index_mappings = self
            .index_mappings
            .iter()
            .cloned()
            .chain(std::iter::once(positions))
            .collect();
        FilteredVec {
            items: kept,
            index_mappings,
            original_len: self.original_len,
        }
    }
}
/// Returns `true` when `expr` may be pushed down, i.e. when the traversal of
/// the expression tree finds no node for which `is_volatile` returns `true`.
///
/// The traversal stops at the first volatile node it encounters.
fn allow_pushdown_for_expr(expr: &Arc<dyn PhysicalExpr>) -> bool {
    let mut found_volatile = false;
    expr.apply(|e| {
        if is_volatile(e) {
            found_volatile = true;
            Ok(TreeNodeRecursion::Stop)
        } else {
            Ok(TreeNodeRecursion::Continue)
        }
    })
    .expect("Infallible traversal of PhysicalExpr tree failed");
    !found_volatile
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A single filter pass keeps the even numbers and maps results back,
    /// filling the dropped positions with the default.
    #[test]
    fn test_filtered_vec_single_pass() {
        let numbers = vec![1, 2, 3, 4, 5, 6];
        let evens = FilteredVec::new(&numbers, |&n| n % 2 == 0);
        assert_eq!(evens.items(), &[2, 4, 6]);
        assert_eq!(evens.len(), 3);

        let expanded = evens.map_results_to_original(vec!["a", "b", "c"], "default");
        assert_eq!(
            expanded,
            vec!["default", "a", "default", "b", "default", "c"]
        );
    }

    /// When nothing passes the filter, every original position gets the default.
    #[test]
    fn test_filtered_vec_empty_filter() {
        let numbers = vec![1, 3, 5];
        let evens = FilteredVec::new(&numbers, |&n| n % 2 == 0);
        assert_eq!(evens.items(), &[] as &[i32]);
        assert_eq!(evens.len(), 0);

        let no_results: Vec<&str> = vec![];
        let expanded = evens.map_results_to_original(no_results, "default");
        assert_eq!(expanded, vec!["default", "default", "default"]);
    }

    /// When everything passes the filter, results map back one-to-one.
    #[test]
    fn test_filtered_vec_all_pass() {
        let numbers = vec![2, 4, 6];
        let evens = FilteredVec::new(&numbers, |&n| n % 2 == 0);
        assert_eq!(evens.items(), &[2, 4, 6]);
        assert_eq!(evens.len(), 3);

        let expanded = evens.map_results_to_original(vec!["a", "b", "c"], "default");
        assert_eq!(expanded, vec!["a", "b", "c"]);
    }

    /// A chained pass may operate on a different element type than the first.
    #[test]
    fn test_chain_filter_slice_different_types() {
        let numbers = vec![1, 2, 3, 4, 5, 6];
        let above_three = FilteredVec::new(&numbers, |&n| n > 3);
        assert_eq!(above_three.items(), &[4, 5, 6]);

        // Stand-in for a transformation of the surviving items.
        let spelled = vec!["four", "five", "six"];
        let with_i = above_three.chain_filter_slice(&spelled, |word| word.contains('i'));
        assert_eq!(with_i.items(), &["five", "six"]);

        // "five" and "six" sit at original indices 4 and 5.
        let expanded = with_i.map_results_to_original(vec![100, 200], 0);
        assert_eq!(expanded, vec![0, 0, 0, 0, 100, 200]);
    }

    /// Two passes: drop "B" and "D" first, then drop the transformed "A".
    #[test]
    fn test_chain_filter_slice_complex_scenario() {
        let parent_predicates = vec!["A", "B", "C", "D", "E"];
        let without_b_d =
            FilteredVec::new(&parent_predicates, |name| *name != "B" && *name != "D");
        assert_eq!(without_b_d.items(), &["A", "C", "E"]);

        #[derive(Clone, Debug, PartialEq)]
        struct TransformedPredicate {
            name: String,
            can_push: bool,
        }

        let child_predicates: Vec<TransformedPredicate> = [
            ("A_transformed", false),
            ("C_transformed", true),
            ("E_transformed", true),
        ]
        .into_iter()
        .map(|(name, can_push)| TransformedPredicate {
            name: name.to_string(),
            can_push,
        })
        .collect();

        let pushable =
            without_b_d.chain_filter_slice(&child_predicates, |pred| pred.can_push);
        assert_eq!(pushable.len(), 2);
        assert_eq!(pushable.items()[0].name, "C_transformed");
        assert_eq!(pushable.items()[1].name, "E_transformed");

        let expanded = pushable
            .map_results_to_original(vec!["C_result", "E_result"], "no_result");
        assert_eq!(
            expanded,
            vec![
                "no_result",
                "no_result",
                "C_result",
                "no_result",
                "E_result"
            ]
        );
    }

    /// Filtered positions 0, 1, 2 trace back to original positions 0, 2, 4.
    #[test]
    fn test_trace_to_original_index() {
        let values = vec![10, 20, 30, 40, 50];
        let kept = FilteredVec::new(&values, |&v| v != 20 && v != 40);
        assert_eq!(kept.trace_to_original_index(0), 0);
        assert_eq!(kept.trace_to_original_index(1), 2);
        assert_eq!(kept.trace_to_original_index(2), 4);
    }

    /// Chaining keeps the length of the very first input slice.
    #[test]
    fn test_chain_filter_preserves_original_len() {
        let numbers = vec![1, 2, 3, 4, 5];
        let above_two = FilteredVec::new(&numbers, |&n| n > 2);

        let spelled = vec!["three", "four", "five"];
        let four_letters = above_two.chain_filter_slice(&spelled, |word| word.len() == 4);

        let expanded = four_letters.map_results_to_original(vec!["x", "y"], "-");
        assert_eq!(expanded.len(), 5);
    }
}