use std::ops::ControlFlow;
use polars_core::prelude::{PlHashMap, PlHashSet};
use polars_error::PolarsResult;
use polars_utils::aliases::{InitHashMaps, PlIndexMap};
use polars_utils::arena::{Arena, Node};
use polars_utils::pl_str::PlSmallStr;
use polars_utils::unique_id::UniqueId;
use crate::dsl::Expr;
use crate::plans::deep_copy::deep_copy_ir_delete_caches;
use crate::plans::optimizer::ir_traversal::ir_graph_traversal;
use crate::plans::optimizer::ir_traversal::storage::IRTraversalStorage;
use crate::plans::{AExpr, IR, PredicatePushDown};
use crate::traversal::visitor::{FnVisitors, SubtreeVisit};
use crate::utils::aexpr_to_leaf_names;
/// Gathers the column names that a single ancestor of a cache node requires.
///
/// `parent` is looked up in `lp_arena` and, depending on its kind, names are appended to
/// `names_scratch`:
///
/// - `IR::Filter`: the leaf column names of the predicate are appended and `true` is
///   returned, telling the caller to inspect the next ancestor as well.
/// - `IR::SimpleProjection`: every projected column name is appended,
///   `found_required_columns` is set to `true`, and `false` is returned.
/// - any other node: nothing is appended and `false` is returned.
///
/// The returned flag answers "should the caller keep walking up the ancestors?".
fn get_upper_projections(
    parent: Node,
    lp_arena: &Arena<IR>,
    expr_arena: &Arena<AExpr>,
    names_scratch: &mut Vec<PlSmallStr>,
    found_required_columns: &mut bool,
) -> bool {
    let ir = lp_arena.get(parent);

    if let IR::Filter { predicate, .. } = ir {
        // The columns read by the predicate are required too; a projection may still sit
        // above this filter, so the caller is asked to continue.
        let predicate_columns = aexpr_to_leaf_names(predicate.node(), expr_arena);
        names_scratch.extend(predicate_columns);
        return true;
    }

    if let IR::SimpleProjection { columns, .. } = ir {
        // An explicit projection pins down the required columns; the search ends here.
        names_scratch.extend(columns.iter_names_cloned());
        *found_required_columns = true;
    }

    // Either a projection was found or the node is of a kind we do not look through.
    false
}
/// Collects the predicate of a single ancestor of a cache node, if that ancestor is a filter.
///
/// - `IR::Filter`: the predicate is converted to an `Expr` via `expr_arena`, pushed onto
///   `predicate_scratch`, and `false` is returned (the search stops at the first filter).
/// - `IR::SimpleProjection`: nothing is pushed and `true` is returned, so the caller goes on
///   to inspect the next ancestor.
/// - any other node: nothing is pushed and `false` is returned.
///
/// The returned flag answers "should the caller keep walking up the ancestors?".
fn get_upper_predicates(
    parent: Node,
    lp_arena: &Arena<IR>,
    expr_arena: &mut Arena<AExpr>,
    predicate_scratch: &mut Vec<Expr>,
) -> bool {
    let ir = lp_arena.get(parent);

    if let IR::Filter { predicate, .. } = ir {
        predicate_scratch.push(predicate.to_expr(expr_arena));
        return false;
    }

    // Only a simple projection is looked through; every other node ends the search.
    matches!(ir, IR::SimpleProjection { .. })
}
/// The two closest ancestors of a node during the plan walk in `set_cache_states`:
/// index 0 is the direct parent, index 1 is the grandparent. A slot is `None` when the
/// node has no ancestor at that distance.
type TwoParents = [Option<Node>; 2];
/// Runs predicate pushdown around the `IR::Cache` nodes reachable from `root`, handling all
/// cache nodes that share an id as one group.
///
/// The function works in three phases:
///
/// 1. A traversal through `ir_graph_traversal` registers an empty bookkeeping entry for every
///    cache id found in the plan.
/// 2. A stack-based walk from `root` records, for every cache node that has a parent, the
///    cache's input, the cache node itself and its two closest ancestors, together with the
///    column names and predicates found in those ancestors.
/// 3. Every cache id is then processed:
///    - more than one distinct predicate was seen: for each occurrence, the subtree starting
///      at its topmost recorded `Filter`/`SimpleProjection` ancestor (or at the cache node
///      itself) is replaced by the predicate-pushdown result of
///      `deep_copy_ir_delete_caches`, and the function returns immediately;
///    - exactly one predicate was seen and it sits above every occurrence: pushdown is started
///      at the filter node of the first occurrence and the filter nodes of the remaining
///      occurrences are rewired to the resulting cache node;
///    - otherwise: pushdown is run on the input of the first occurrence and the optimized plan
///      is cloned into the inputs of the remaining occurrences.
///
/// # Parameters
/// - `root`: node at which both traversals start.
/// - `lp_arena` / `expr_arena`: arenas holding the plan and its expressions; `lp_arena` is
///   modified in place.
/// - `scratch`: reusable buffer; cleared on entry and used to hold the inputs of the node
///   currently being visited.
/// - `verbose`: when `true`, a message is written to stderr if caches are removed because the
///   predicates above them differ.
/// - `pushdown_maintain_errors`, `streaming`: forwarded to `PredicatePushDown::new`.
///
/// # Errors
/// Propagates any error returned by `PredicatePushDown::optimize`.
///
/// # Panics
/// Panics if a cache id met in the second walk was not registered by the first traversal, if
/// an expected filter node cannot be found, or if one of the `unreachable!()` plan shapes is
/// encountered.
pub(crate) fn set_cache_states(
root: Node,
lp_arena: &mut Arena<IR>,
expr_arena: &mut Arena<AExpr>,
scratch: &mut Vec<Node>,
verbose: bool,
pushdown_maintain_errors: bool,
streaming: bool,
) -> PolarsResult<()> {
let mut stack = Vec::with_capacity(4);
let mut names_scratch = vec![];
let mut predicates_scratch = vec![];
scratch.clear();
stack.clear();
// Bookkeeping collected per cache id. `children`, `parents` and `cache_nodes` are pushed
// together, so entry `i` of each describes the same occurrence of the cache.
#[derive(Default)]
struct Value {
// Input node of every occurrence of this cache id.
children: Vec<Node>,
// `[parent, grandparent]` of every occurrence.
parents: Vec<TwoParents>,
// The cache node of every occurrence.
cache_nodes: Vec<Node>,
// Union of the column names gathered from the ancestors of all occurrences.
// NOTE(review): filled below but never read again in this function.
names_union: PlHashSet<PlSmallStr>,
// Every distinct predicate found above an occurrence, mapped to the number of occurrences
// it was found above (at most one predicate is collected per occurrence).
predicate_union: PlHashMap<Expr, u32>,
}
let mut cache_schema_and_children = PlIndexMap::new();
// One entry of the explicit traversal stack.
#[derive(Default, Clone)]
struct Frame {
// Node to visit.
current: Node,
// Id of the closest cache node on the path from the root down to `current`.
// NOTE(review): written below but never read in this function.
cache_id: Option<UniqueId>,
// `[parent, grandparent]` of `current`.
parent: TwoParents,
}
let init = Frame {
current: root,
..Default::default()
};
stack.push(init);
// Phase 1: register an (initially empty) entry for every cache id in the plan, so that the
// `get_mut(id).unwrap()` in phase 2 finds it.
ir_graph_traversal(
root,
&mut FnVisitors::new(
|| (),
|key, storage: &mut IRTraversalStorage<'_>, _| {
if let IR::Cache { input: _, id } = storage.get(key) {
cache_schema_and_children.insert(*id, Value::default());
}
ControlFlow::Continue(SubtreeVisit::Visit)
},
|_, _, _| ControlFlow::<()>::Continue(()),
),
&mut vec![],
&mut vec![],
IRTraversalStorage {
arena: lp_arena,
skip_subtree: |_| false,
},
)
// Both callbacks above always return `Continue`, so a continue value is expected here.
.continue_value()
.unwrap();
// Phase 2: walk the plan with an explicit stack and collect the state of every cache node.
while let Some(mut frame) = stack.pop() {
let lp = lp_arena.get(frame.current);
lp.copy_inputs(scratch);
if let IR::Cache { input, id, .. } = lp {
// Only cache nodes that have a parent are recorded; the root frame starts with an empty
// parent array.
if frame.parent[0].is_some() {
let v = cache_schema_and_children.get_mut(id).unwrap();
v.children.push(*input);
v.parents.push(frame.parent);
v.cache_nodes.push(frame.current);
let mut found_required_columns = false;
// Inspect the parent first, then the grandparent, stopping as soon as the helper
// reports that the search should not continue.
for parent_node in frame.parent.into_iter().flatten() {
let keep_going = get_upper_projections(
parent_node,
lp_arena,
expr_arena,
&mut names_scratch,
&mut found_required_columns,
);
if !names_scratch.is_empty() {
v.names_union.extend(names_scratch.drain(..));
}
if !keep_going {
break;
}
}
// Same walk, this time counting the predicate of the first filter found.
for parent_node in frame.parent.into_iter().flatten() {
let keep_going = get_upper_predicates(
parent_node,
lp_arena,
expr_arena,
&mut predicates_scratch,
);
if !predicates_scratch.is_empty() {
for pred in predicates_scratch.drain(..) {
let count = v.predicate_union.entry(pred).or_insert(0);
*count += 1;
}
}
if !keep_going {
break;
}
}
// No projection was found above the cache: fall back to every column of the cache
// node's own schema.
if !found_required_columns {
let schema = lp.schema(lp_arena);
v.names_union.extend(schema.iter_names_cloned());
}
}
frame.cache_id = Some(*id);
};
// Shift the ancestors: the current node becomes the parent of its inputs and the old
// parent becomes their grandparent.
frame.parent[1] = frame.parent[0];
frame.parent[0] = Some(frame.current);
for n in scratch.iter() {
let mut new_frame = frame.clone();
new_frame.current = *n;
stack.push(new_frame);
}
scratch.clear();
}
// Phase 3: restart predicate pushdown per cache id.
if !cache_schema_and_children.is_empty() {
let mut pred_pd = PredicatePushDown::new(pushdown_maintain_errors, streaming);
// Values are visited in reverse of the map's iteration order (presumably reverse insertion
// order — confirm against `PlIndexMap`).
for v in cache_schema_and_children.into_values().rev() {
// More than one distinct predicate above the occurrences of this cache id.
if v.predicate_union.len() > 1 {
if verbose {
eprintln!("cache nodes will be removed because predicates don't match")
}
for ((_, cache), parents) in v.children.iter().zip(v.cache_nodes).zip(v.parents) {
// Climb from the cache node to its parent and then its grandparent for as long as
// they are filter or simple-projection nodes; `node` ends up as the topmost one.
let mut node = cache;
for p_node in parents.into_iter().flatten() {
// The `true` produced by the first arm is discarded; the match only decides whether
// to leave the loop.
match lp_arena.get(p_node) {
IR::Filter { .. } | IR::SimpleProjection { .. } => true,
_ => break,
};
node = p_node
}
// Presumably yields a copy of the subtree with its cache nodes removed (going by the
// helper's name — confirm); the optimized copy then replaces `node` in place.
let copied_node = deep_copy_ir_delete_caches(node, lp_arena, expr_arena);
let lp = lp_arena.take(copied_node);
let lp = pred_pd.optimize(lp, lp_arena, expr_arena)?;
lp_arena.replace(node, lp);
}
// NOTE(review): this returns from the whole function, so cache ids that come later in
// the iteration are not processed — confirm this is intended.
return Ok(());
}
// The filter above the caches may be pushed through only if there is exactly one distinct
// predicate and it was found above every occurrence.
let allow_parent_predicate_pushdown = v.predicate_union.len() == 1 && {
let (_pred, count) = v.predicate_union.iter().next().unwrap();
*count == v.children.len() as u32
};
if allow_parent_predicate_pushdown {
// Optimize starting at the filter node of the first occurrence.
let parents = *v.parents.first().unwrap();
let node = get_filter_node(parents, lp_arena)
.expect("expected filter; this is an optimizer bug");
let start_lp = lp_arena.take(node);
let mut pred_pd =
PredicatePushDown::new(pushdown_maintain_errors, streaming).block_at_cache(1);
let lp = pred_pd.optimize(start_lp, lp_arena, expr_arena)?;
lp_arena.replace(node, lp.clone());
// Descend from the rewritten node through simple projections until the cache node is
// reached; any other node kind is treated as impossible here.
let mut updated_cache_node = node;
loop {
match lp_arena.get(updated_cache_node) {
IR::Cache { .. } => break,
IR::SimpleProjection { input, .. } => updated_cache_node = *input,
_ => unreachable!(),
}
}
// Replace the filter node of every other occurrence so that it uses the cache node
// found above.
for &parents in &v.parents[1..] {
let filter_node = get_filter_node(parents, lp_arena)
.expect("expected filter; this is an optimizer bug");
let IR::Filter { input, .. } = lp_arena.get(filter_node) else {
unreachable!()
};
let new_lp = match lp_arena.get(*input) {
// Filter over projection over cache: keep the projection's columns but point it at
// the updated cache node.
IR::SimpleProjection { input, columns } => {
debug_assert!(matches!(lp_arena.get(*input), IR::Cache { .. }));
IR::SimpleProjection {
input: updated_cache_node,
columns: columns.clone(),
}
},
// Filter directly over the cache: the filter node becomes a clone of the updated
// cache node's plan.
ir => {
debug_assert!(matches!(ir, IR::Cache { .. }));
lp_arena.get(updated_cache_node).clone()
},
};
lp_arena.replace(filter_node, new_lp);
}
} else {
// Optimize the input of the first occurrence only and clone the result into the input
// nodes of all other occurrences.
let child = *v.children.first().unwrap();
let child_lp = lp_arena.take(child);
let lp = pred_pd.optimize(child_lp, lp_arena, expr_arena)?;
lp_arena.replace(child, lp.clone());
for &child in &v.children[1..] {
lp_arena.replace(child, lp.clone());
}
}
}
}
Ok(())
}
/// Returns the first of the given ancestors that is an `IR::Filter` node.
///
/// `parents` is scanned in order (index 0 first) and empty slots are skipped. `None` is
/// returned when none of the present ancestors is a filter.
fn get_filter_node(parents: TwoParents, lp_arena: &Arena<IR>) -> Option<Node> {
    for candidate in parents.into_iter().flatten() {
        if let IR::Filter { .. } = lp_arena.get(candidate) {
            return Some(candidate);
        }
    }
    None
}