use polars_utils::arena::{Arena, Node};
use crate::plans::aexpr::AExpr;
use crate::prelude::Operator;
/// Rewrites, in place, every OR node (`Or` / `LogicalOr`) reachable from `node`
/// whose branches share AND-ed terms, factoring those terms out via
/// `try_factor_or`.
///
/// Nodes are handled children-first, so an inner OR has already been rewritten
/// by the time the OR that contains it is considered. The visit list is fixed
/// before any rewriting starts, so nodes created by a rewrite are not revisited.
pub(crate) fn factor_or_in_aexpr(node: Node, expr_arena: &mut Arena<AExpr>) {
    // Pre-order walk: every node is recorded before any of its descendants.
    let mut pending = vec![node];
    let mut visit_order = Vec::new();
    while let Some(current) = pending.pop() {
        visit_order.push(current);
        expr_arena.get(current).children_rev(&mut pending);
    }

    // Walking the pre-order list backwards yields children before parents.
    for current in visit_order.into_iter().rev() {
        let is_or = matches!(
            expr_arena.get(current),
            AExpr::BinaryExpr {
                op: Operator::Or | Operator::LogicalOr,
                ..
            }
        );
        if !is_or {
            continue;
        }
        let Some(factored) = try_factor_or(current, expr_arena) else {
            continue;
        };
        expr_arena.replace(current, factored);
    }
}
/// Appends to `out` the leaves of the OR tree rooted at `root`, in left-to-right
/// order.
///
/// A leaf is any subexpression that is not itself an `Or` / `LogicalOr` binary
/// node. If `root` is not an OR, `root` itself is the single leaf pushed.
fn collect_or_branches(root: Node, expr_arena: &Arena<AExpr>, out: &mut Vec<Node>) {
    let mut pending = vec![root];
    while let Some(current) = pending.pop() {
        if let AExpr::BinaryExpr {
            left,
            op: Operator::Or | Operator::LogicalOr,
            right,
        } = expr_arena.get(current)
        {
            // Right goes on the stack first so the left subtree is expanded
            // first, which keeps the leaves in source order.
            pending.extend([*right, *left]);
        } else {
            out.push(current);
        }
    }
}
/// Attempts to factor AND-ed terms shared by every branch of the OR tree rooted
/// at `or_node`, e.g. `(a & b) | (a & c)` becomes `a & (b | c)`.
///
/// Steps:
/// 1. Flatten the OR tree with `collect_or_branches`.
/// 2. Split each branch into its AND-ed terms with `MintermIter`.
/// 3. Treat each term of the branch with the fewest terms as a candidate. A
///    candidate is factored out when every other branch holds a not-yet-matched
///    term that `is_expr_equal_to_amortized` reports as equal.
///
/// Inherently nondeterministic candidates are never factored.
///
/// If some branch has all of its terms factored out, its residual is the empty
/// conjunction (`true`), so the OR of residuals disappears:
/// `a | (a & b)` becomes `a`.
///
/// Returns `None` when no term is shared by all branches. Otherwise the new
/// tree is built from fresh nodes added to `expr_arena`, and a clone of its root
/// expression is returned for the caller to store at `or_node`.
fn try_factor_or(or_node: Node, expr_arena: &mut Arena<AExpr>) -> Option<AExpr> {
    use std::hash::{BuildHasher, Hasher};
    use polars_utils::aliases::{InitHashMaps, PlFixedStateQuality, PlHashMap};
    use polars_utils::scratch_vec::ScratchVec;
    use crate::plans::aexpr::{
        MintermIter, is_inherently_nondeterministic, traverse_and_hash_aexpr,
    };
    let mut branches = Vec::new();
    collect_or_branches(or_node, expr_arena, &mut branches);
    if branches.len() < 2 {
        return None;
    }

    // Split every branch into its AND-ed terms. A common term must appear in
    // every branch, so the shortest branch (first after sorting) supplies the
    // candidates.
    let mut branch_terms: Vec<Vec<Node>> = branches
        .iter()
        .map(|&b| MintermIter::new(b, expr_arena).collect::<Vec<_>>())
        .collect();
    branch_terms.sort_by_key(|terms| terms.len());

    // Every term is hashed with the same builder so equal expressions fall into
    // the same bucket.
    let hb = PlFixedStateQuality::with_seed(0);
    let hash_of = |n: Node, arena: &Arena<AExpr>| -> u64 {
        let mut h = hb.build_hasher();
        traverse_and_hash_aexpr(n, arena, &mut h);
        h.finish()
    };

    // Hash -> term positions for each non-candidate branch.
    // `buckets[b_idx - 1]` indexes `branch_terms[b_idx]`. The candidate branch
    // (index 0) is never looked up, so it gets no map. Hashes only pre-filter;
    // equality is confirmed with `is_expr_equal_to_amortized` below.
    let buckets: Vec<PlHashMap<u64, Vec<usize>>> = branch_terms[1..]
        .iter()
        .map(|terms| {
            let mut m: PlHashMap<u64, Vec<usize>> = PlHashMap::with_capacity(terms.len());
            for (i, &n) in terms.iter().enumerate() {
                m.entry(hash_of(n, expr_arena)).or_default().push(i);
            }
            m
        })
        .collect();

    // `taken[b][i]` marks `branch_terms[b][i]` as already factored out, so each
    // term is matched at most once (`(a & a) | (a & b)` factors a single `a`).
    let mut common = Vec::new();
    let mut taken: Vec<_> = branch_terms.iter().map(|t| vec![false; t.len()]).collect();
    let (mut l_stack, mut r_stack) = (Vec::new(), Vec::new());
    let mut other_matches: ScratchVec<usize> = ScratchVec::default();
    for (cand_idx, &cand) in branch_terms[0].iter().enumerate() {
        // Two occurrences of a nondeterministic expression need not produce the
        // same value, so such terms are never treated as common.
        if is_inherently_nondeterministic(cand, expr_arena) {
            continue;
        }
        let cand_expr = expr_arena.get(cand);
        let cand_hash = hash_of(cand, expr_arena);
        // Matching position in branches 1.., pushed in branch order.
        let other_matches = other_matches.get();
        let all_matched = (1..branch_terms.len()).all(|b_idx| {
            let Some(m) = buckets[b_idx - 1].get(&cand_hash).and_then(|ixs| {
                ixs.iter().copied().find(|&i| {
                    !taken[b_idx][i]
                        && cand_expr.is_expr_equal_to_amortized(
                            expr_arena.get(branch_terms[b_idx][i]),
                            expr_arena,
                            &mut l_stack,
                            &mut r_stack,
                        )
                })
            }) else {
                return false;
            };
            other_matches.push(m);
            true
        });
        if !all_matched {
            continue;
        }
        common.push(cand);
        taken[0][cand_idx] = true;
        for (offset, &m) in other_matches.iter().enumerate() {
            taken[offset + 1][m] = true;
        }
    }
    if common.is_empty() {
        return None;
    }

    // Absorption: a branch with every term factored out has residual `true`,
    // which makes the OR of residuals `true`, so only the common terms remain.
    // Decide this before building any residual AND-chains, so that no
    // unreachable nodes are left behind in the arena.
    let absorbed = taken.iter().any(|flags| flags.iter().all(|&t| t));
    let folded_node = if absorbed {
        combine_with(common, Operator::And, expr_arena)
    } else {
        // Every branch still keeps at least one term here, so `combine_with`
        // never receives an empty iterator.
        let residuals: Vec<Node> = branch_terms
            .into_iter()
            .zip(&taken)
            .map(|(terms, flags)| {
                let kept = terms
                    .into_iter()
                    .zip(flags)
                    .filter_map(|(n, &t)| (!t).then_some(n));
                combine_with(kept, Operator::And, expr_arena)
            })
            .collect();
        let or_node = combine_with(residuals, Operator::Or, expr_arena);
        let mut all_nodes = common;
        all_nodes.push(or_node);
        combine_with(all_nodes, Operator::And, expr_arena)
    };
    Some(expr_arena.get(folded_node).clone())
}
/// Joins `nodes` into a left-deep chain of `op` binary expressions,
/// `((n0 op n1) op n2) ...`, adding each new node to `expr_arena`, and returns
/// the root.
///
/// A single input node is returned unchanged, and no node is added.
///
/// # Panics
/// Panics if `nodes` yields no elements.
fn combine_with(
    nodes: impl IntoIterator<Item = Node>,
    op: Operator,
    expr_arena: &mut Arena<AExpr>,
) -> Node {
    let mut remaining = nodes.into_iter();
    let first = remaining
        .next()
        .expect("combine_with: non-empty iterator");
    remaining.fold(first, |acc, next| {
        expr_arena.add(AExpr::BinaryExpr {
            left: acc,
            op,
            right: next,
        })
    })
}