use std::collections::{BTreeMap, BTreeSet};
use std::hash::{BuildHasher, Hash, Hasher};
use polars_core::prelude::*;
use crate::prelude::*;
/// A single root-to-leaf path of plan nodes through the logical plan.
type Trail = Vec<Node>;
/// Recursively walk the plan from `root` downward and record every
/// root-to-leaf path ("trail") of nodes in `trails`, keyed by a unique id.
///
/// `id` always holds the id of the trail currently being extended. Branching
/// nodes (`Join`, `Union`) fork: the trail collected so far is cloned, `id`
/// is bumped, and each additional input continues on its own fresh trail.
///
/// `collect` controls whether `root` itself is pushed onto the current trail;
/// the initial caller passes `false` so the plan root is excluded. Note that
/// the single-input fall-through arm propagates `collect` unchanged, so the
/// chain above the first branching node stays uncollected — presumably
/// intentional, since those nodes are shared by all trails anyway (TODO
/// confirm).
///
/// Returns `None` to abort collection entirely: either a user-placed `Cache`
/// was found (we don't interfere with explicit caching) or a `Union` is too
/// wide to analyze cheaply.
pub(super) fn collect_trails(
    root: Node,
    lp_arena: &Arena<ALogicalPlan>,
    trails: &mut BTreeMap<u32, Trail>,
    id: &mut u32,
    collect: bool,
) -> Option<()> {
    if collect {
        trails.get_mut(id).unwrap().push(root);
    }
    use ALogicalPlan::*;
    match lp_arena.get(root) {
        // An existing cache means caching was already set up; bail out.
        Cache { .. } => return None,
        Join {
            input_left,
            input_right,
            ..
        } => {
            // Snapshot the trail so far; the right input continues on a copy.
            let new_trail = trails.get(id).unwrap().clone();
            collect_trails(*input_left, lp_arena, trails, id, true)?;
            *id += 1;
            trails.insert(*id, new_trail);
            collect_trails(*input_right, lp_arena, trails, id, true)?;
        }
        Union { inputs, .. } => {
            // Trails are compared pairwise later (quadratic); cap the width.
            if inputs.len() > 200 {
                return None;
            }
            let new_trail = trails.get(id).unwrap().clone();
            let last_i = inputs.len() - 1;
            for (i, input) in inputs.iter().enumerate() {
                collect_trails(*input, lp_arena, trails, id, true)?;
                // Every input except the last starts a new trail from the snapshot.
                if i != last_i {
                    *id += 1;
                    trails.insert(*id, new_trail.clone());
                }
            }
        }
        ExtContext { .. } => {
            // Not traversed: trails end here (blocked for now).
        }
        lp => {
            // All remaining variants have at most a single input.
            let nodes = &mut [None];
            lp.copy_inputs(nodes);
            if let Some(input) = nodes[0] {
                collect_trails(input, lp_arena, trails, id, collect)?
            }
        }
    }
    Some(())
}
/// Semantic equality of two expression-node slices: each pair of nodes is
/// materialized back into an `Expr` and those are compared.
fn expr_nodes_equal(a: &[Node], b: &[Node], expr_arena: &Arena<AExpr>) -> bool {
    if a.len() != b.len() {
        return false;
    }
    a.iter()
        .zip(b.iter())
        .all(|(l, r)| node_to_expr(*l, expr_arena) == node_to_expr(*r, expr_arena))
}
/// Semantic equality of two optional predicate nodes. Both absent counts as
/// equal; one present and one absent does not.
fn predicate_equal(a: Option<Node>, b: Option<Node>, expr_arena: &Arena<AExpr>) -> bool {
    match a.zip(b) {
        // Both predicates present: compare their materialized expressions.
        Some((l, r)) => node_to_expr(l, expr_arena) == node_to_expr(r, expr_arena),
        // `zip` yielded nothing: equal only when *both* sides are absent.
        None => a.is_none() && b.is_none(),
    }
}
/// Shallow (single-node) semantic equality of two plan nodes, ignoring their
/// inputs. Trails are grown bottom-up while consecutive node pairs compare
/// equal under this function.
///
/// Any variant pair not matched below — including anything with an opaque
/// user function — is conservatively reported as unequal.
fn lp_node_equal(a: &ALogicalPlan, b: &ALogicalPlan, expr_arena: &Arena<AExpr>) -> bool {
    use ALogicalPlan::*;
    match (a, b) {
        // In-memory scans: equal only when they point at the *same* DataFrame
        // allocation and no projection/selection has been pushed into either.
        (
            DataFrameScan {
                df: left_df,
                projection: None,
                selection: None,
                ..
            },
            DataFrameScan {
                df: right_df,
                projection: None,
                selection: None,
                ..
            },
        ) => Arc::ptr_eq(left_df, right_df),
        // File scans: same path, same options, semantically equal predicate.
        #[cfg(feature = "parquet")]
        (
            ParquetScan {
                path: path_left,
                predicate: predicate_l,
                options: options_l,
                ..
            },
            ParquetScan {
                path: path_right,
                predicate: predicate_r,
                options: options_r,
                ..
            },
        ) => {
            path_left == path_right
                && options_l == options_r
                && predicate_equal(*predicate_l, *predicate_r, expr_arena)
        }
        #[cfg(feature = "ipc")]
        (
            IpcScan {
                path: path_left,
                predicate: predicate_l,
                options: options_l,
                ..
            },
            IpcScan {
                path: path_right,
                predicate: predicate_r,
                options: options_r,
                ..
            },
        ) => {
            path_left == path_right
                && options_l == options_r
                && predicate_equal(*predicate_l, *predicate_r, expr_arena)
        }
        #[cfg(feature = "csv-file")]
        (
            CsvScan {
                path: path_left,
                predicate: predicate_l,
                options: options_l,
                ..
            },
            CsvScan {
                path: path_right,
                predicate: predicate_r,
                options: options_r,
                ..
            },
        ) => {
            path_left == path_right
                && options_l == options_r
                && predicate_equal(*predicate_l, *predicate_r, expr_arena)
        }
        // Filters: equal predicates.
        (Selection { predicate: l, .. }, Selection { predicate: r, .. }) => {
            node_to_expr(*l, expr_arena) == node_to_expr(*r, expr_arena)
        }
        // Projections / with-columns: equal expression lists.
        (Projection { expr: l, .. }, Projection { expr: r, .. })
        | (HStack { exprs: l, .. }, HStack { exprs: r, .. }) => expr_nodes_equal(l, r, expr_arena),
        // Melt: args compared by Arc identity only — cheap, but may miss
        // structurally equal args in distinct allocations (false negatives
        // are safe here, they just skip a caching opportunity).
        (Melt { args: l, .. }, Melt { args: r, .. }) => Arc::ptr_eq(l, r),
        (
            Slice {
                offset: offset_l,
                len: len_l,
                ..
            },
            Slice {
                offset: offset_r,
                len: len_r,
                ..
            },
        ) => offset_l == offset_r && len_l == len_r,
        (
            Sort {
                by_column: by_l,
                args: args_l,
                ..
            },
            Sort {
                by_column: by_r,
                args: args_r,
                ..
            },
        ) => expr_nodes_equal(by_l, by_r, expr_arena) && args_l == args_r,
        (Explode { columns: l, .. }, Explode { columns: r, .. }) => l == r,
        (Distinct { options: l, .. }, Distinct { options: r, .. }) => l == r,
        (MapFunction { function: l, .. }, MapFunction { function: r, .. }) => l == r,
        // Group-bys: only comparable when neither side carries an opaque
        // `apply` closure; keys, aggs, and options must all agree.
        (
            Aggregate {
                keys: keys_l,
                aggs: agg_l,
                apply: None,
                maintain_order: maintain_order_l,
                options: options_l,
                ..
            },
            Aggregate {
                keys: keys_r,
                aggs: agg_r,
                apply: None,
                maintain_order: maintain_order_r,
                options: options_r,
                ..
            },
        ) => {
            maintain_order_l == maintain_order_r
                && options_l == options_r
                && expr_nodes_equal(keys_l, keys_r, expr_arena)
                && expr_nodes_equal(agg_l, agg_r, expr_arena)
        }
        #[cfg(feature = "python")]
        (PythonScan { options: l, .. }, PythonScan { options: r, .. }) => l == r,
        _ => {
            // Mismatched or unsupported variants: never equal.
            false
        }
    }
}
/// Find the longest shared bottom-up subgraph of two trails.
///
/// Both trails are walked from the leaves upward (`rev`) for as long as the
/// node pairs compare semantically equal. Walking stops at the first node the
/// trails literally share (same arena `Node`) or at the first semantic
/// difference.
///
/// Returns the highest pair of equal-but-distinct nodes, plus a flag telling
/// whether the trails were equal over their entire length; `None` when not
/// even the leaves match.
fn longest_subgraph(
    trail_a: &Trail,
    trail_b: &Trail,
    lp_arena: &Arena<ALogicalPlan>,
    expr_arena: &Arena<AExpr>,
) -> Option<(Node, Node, bool)> {
    if trail_a.is_empty() || trail_b.is_empty() {
        return None;
    }
    // Highest pair of equal, distinct nodes seen so far.
    let mut deepest_equal: Option<(Node, Node)> = None;
    // Trails of different lengths can never be entirely equal.
    let mut entirely_equal = trail_a.len() == trail_b.len();

    for (&node_a, &node_b) in trail_a.iter().rev().zip(trail_b.iter().rev()) {
        // Identical arena node: the trails merged into a shared ancestor.
        if node_a == node_b {
            break;
        }
        if !lp_node_equal(lp_arena.get(node_a), lp_arena.get(node_b), expr_arena) {
            entirely_equal = false;
            break;
        }
        deepest_equal = Some((node_a, node_b));
    }

    deepest_equal.map(|(a, b)| (a, b, entirely_equal))
}
/// Eliminate common subplans by caching subtrees that occur in multiple
/// trails of the logical plan.
///
/// All root-to-leaf trails are collected first; every pair of trails is then
/// compared for its longest shared subgraph. Each shared subgraph is replaced
/// by a `Cache` node (the original subplan moves to a fresh arena slot and
/// becomes the cache's input) so it executes only once.
///
/// Returns the (unchanged) root node and whether any cache was inserted.
pub(crate) fn elim_cmn_subplans(
    root: Node,
    lp_arena: &mut Arena<ALogicalPlan>,
    expr_arena: &Arena<AExpr>,
) -> (Node, bool) {
    // One trail per branch of the plan, keyed by a unique id.
    let mut trails = BTreeMap::new();
    let mut id = 0;
    trails.insert(id, Vec::new());
    if collect_trails(root, lp_arena, &mut trails, &mut id, false).is_none() {
        // A user-placed cache or an overly wide union: don't interfere.
        return (root, false);
    }
    let trails = trails.into_values().collect::<Vec<_>>();

    // Pairwise comparison; `trail_ends` holds, per pair, the top node of the
    // shared subgraph on each side.
    let mut trail_ends = vec![];
    // Trails found entirely equal to an earlier one need no further pairing.
    let mut to_skip = BTreeSet::new();
    for i in 0..trails.len() {
        if to_skip.contains(&i) {
            continue;
        }
        let trail_i = &trails[i];
        for (j, trail_j) in trails.iter().enumerate().skip(i + 1) {
            if let Some((a, b, all_equal)) =
                longest_subgraph(trail_i, trail_j, lp_arena, expr_arena)
            {
                if all_equal {
                    to_skip.insert(j);
                }
                trail_ends.push((a, b))
            }
        }
    }

    // Mix the arena's address into the cache ids so ids from different plans
    // don't collide — presumably relevant when plans are combined; TODO confirm.
    let lp_cache = lp_arena as *const Arena<ALogicalPlan> as usize;
    let hb = ahash::RandomState::new();
    let mut changed = false;
    let mut cache_mapping = BTreeMap::new();
    // Number of hits per cache id; written into every `Cache` node.
    let mut cache_counts = PlHashMap::with_capacity(trail_ends.len());

    // First pass: assign a cache id to each node pair and count its hits.
    for combination in trail_ends.iter() {
        let node1 = combination.0 .0;
        let node2 = combination.1 .0;
        let cache_id = match (cache_mapping.get(&node1), cache_mapping.get(&node2)) {
            (Some(h), _) => *h,
            (_, Some(h)) => *h,
            _ => {
                let mut h = hb.build_hasher();
                node1.hash(&mut h);
                let hash = h.finish();
                let mut cache_id = lp_cache.wrapping_add(hash as usize);
                // Keep a gap below usize::MAX — that region appears to be
                // reserved as a sentinel elsewhere; TODO confirm.
                if (usize::MAX - cache_id) < 2048 {
                    cache_id -= 2048
                }
                cache_mapping.insert(node1, cache_id);
                cache_mapping.insert(node2, cache_id);
                cache_id
            }
        };
        *cache_counts.entry(cache_id).or_insert(0usize) += 1;
    }

    // Second pass: insert the `Cache` nodes (or update an existing one's count).
    for combination in trail_ends.iter() {
        let node1 = combination.0 .0;
        let node2 = combination.1 .0;
        let cache_id = match (cache_mapping.get(&node1), cache_mapping.get(&node2)) {
            (Some(h), _) => *h,
            (_, Some(h)) => *h,
            _ => {
                // Every pair was mapped in the first pass.
                unreachable!()
            }
        };
        let cache_count = *cache_counts.get(&cache_id).unwrap();
        for inp_node in [combination.0, combination.1] {
            // A cache may already have been placed here via an earlier pair;
            // then only refresh its hit count.
            if let ALogicalPlan::Cache { count, .. } = lp_arena.get_mut(inp_node) {
                *count = cache_count;
            } else {
                // Move the shared subplan to a new arena slot and put a
                // `Cache` in its old place, pointing at it.
                let lp = lp_arena.get(inp_node).clone();
                let node = lp_arena.add(lp);
                let cache_lp = ALogicalPlan::Cache {
                    input: node,
                    id: cache_id,
                    count: cache_count,
                };
                // `cache_lp` is consumed exactly once — the former
                // `cache_lp.clone()` here was a redundant deep clone.
                lp_arena.replace(inp_node, cache_lp);
            };
        }
        changed = true;
    }

    (root, changed)
}
/// Lower the `file_counter` of every file scan below `root` by the number of
/// cache hits accumulated on the path to it.
///
/// Once subplans are cached, a scan executes fewer times than the naive plan
/// indicates, so counters set during plan construction must shrink
/// accordingly.
///
/// * `acc_count` - cache hits collected from the plan root down to `root`.
/// * `scratch` - reusable buffer for child nodes; every call leaves it at the
///   length it found it (see the fall-through arm).
pub(crate) fn decrement_file_counters_by_cache_hits(
    root: Node,
    lp_arena: &mut Arena<ALogicalPlan>,
    _expr_arena: &Arena<AExpr>,
    acc_count: FileCount,
    scratch: &mut Vec<Node>,
) {
    use ALogicalPlan::*;
    match lp_arena.get_mut(root) {
        #[cfg(feature = "parquet")]
        ParquetScan { options, .. } => {
            // Saturate at 1 rather than 0: the file is still read at least once.
            if acc_count >= options.file_counter {
                options.file_counter = 1;
            } else {
                options.file_counter -= acc_count as FileCount
            }
        }
        #[cfg(feature = "ipc")]
        IpcScan { options, .. } => {
            if acc_count >= options.file_counter {
                options.file_counter = 1;
            } else {
                options.file_counter -= acc_count as FileCount
            }
        }
        #[cfg(feature = "csv-file")]
        CsvScan { options, .. } => {
            if acc_count >= options.file_counter {
                options.file_counter = 1;
            } else {
                options.file_counter -= acc_count as FileCount
            }
        }
        Cache { count, input, .. } => {
            // `usize::MAX` appears to mark an unlimited/sentinel cache count
            // that must not inflate the accumulator — TODO confirm.
            let new_count = if *count != usize::MAX {
                acc_count + *count as FileCount
            } else {
                acc_count
            };
            decrement_file_counters_by_cache_hits(*input, lp_arena, _expr_arena, new_count, scratch)
        }
        lp => {
            // BUGFIX: only pop the inputs *this* call pushed. `scratch` can
            // still hold unprocessed siblings from an ancestor's loop;
            // popping until empty handed those siblings to a deeper recursion
            // level, which processed them with the wrong `acc_count` whenever
            // a `Cache` node sat in between.
            let base = scratch.len();
            lp.copy_inputs(scratch);
            while scratch.len() > base {
                let input = scratch.pop().unwrap();
                decrement_file_counters_by_cache_hits(
                    input,
                    lp_arena,
                    _expr_arena,
                    acc_count,
                    scratch,
                )
            }
        }
    }
}