use polars_utils::idx_vec::UnitVec;
use polars_utils::unitvec;
use super::super::*;
impl AExpr {
pub(crate) fn is_leaf(&self) -> bool {
matches!(self, AExpr::Column(_) | AExpr::Literal(_) | AExpr::Len)
}
pub(crate) fn is_col(&self) -> bool {
matches!(self, AExpr::Column(_))
}
pub(crate) fn is_elementwise_top_level(&self) -> bool {
use AExpr::*;
match self {
AnonymousFunction { options, .. } => options.is_elementwise(),
Function { options, .. } => options.is_elementwise(),
Literal(v) => v.is_scalar(),
Eval { variant, .. } => variant.is_elementwise(),
Element | BinaryExpr { .. } | Column(_) | Ternary { .. } | Cast { .. } => true,
#[cfg(feature = "dtype-struct")]
StructEval { .. } | StructField(_) => true,
#[cfg(feature = "dynamic_group_by")]
Rolling { .. } => false,
Agg { .. }
| AnonymousAgg { .. }
| Explode { .. }
| Filter { .. }
| Gather { .. }
| Len
| Slice { .. }
| Sort { .. }
| SortBy { .. }
| Over { .. } => false,
}
}
pub(crate) fn is_row_separable_top_level(&self) -> bool {
use AExpr::*;
match self {
AnonymousFunction { options, .. } => options.is_row_separable(),
Function { options, .. } => options.is_row_separable(),
Literal(v) => v.is_scalar(),
Explode { .. } | Filter { .. } => true,
_ => self.is_elementwise_top_level(),
}
}
pub(crate) fn does_not_modify_top_level(&self) -> bool {
match self {
AExpr::Column(_) => true,
AExpr::Function { function, .. } => {
matches!(function, IRFunctionExpr::SetSortedFlag(_))
},
_ => false,
}
}
}
/// Evaluates `property` on `ae`; when it holds, all of `ae`'s inputs are pushed onto `stack`
/// (via `inputs_rev`) so a driver such as `property_rec` visits them next.
///
/// Returns whether `property` held for `ae`. Nothing is pushed when it does not.
fn property_and_traverse<F>(stack: &mut UnitVec<Node>, ae: &AExpr, property: F) -> bool
where
    F: Fn(&AExpr) -> bool,
{
    let holds = property(ae);
    if holds {
        ae.inputs_rev(stack);
    }
    holds
}
/// Checks `property` on every node reachable from `node`, using an explicit stack instead of
/// recursion.
///
/// `property` receives the work stack and is responsible for pushing whichever children should
/// be visited next. The walk stops and returns `false` at the first node for which `property`
/// returns `false`; otherwise it returns `true` once the stack is exhausted.
fn property_rec<F>(node: Node, expr_arena: &Arena<AExpr>, property: F) -> bool
where
    F: Fn(&mut UnitVec<Node>, &AExpr, &Arena<AExpr>) -> bool,
{
    let mut pending = unitvec![];
    let mut cursor = Some(node);

    while let Some(current) = cursor {
        if !property(&mut pending, expr_arena.get(current), expr_arena) {
            return false;
        }
        cursor = pending.pop();
    }

    true
}
fn does_not_modify(stack: &mut UnitVec<Node>, ae: &AExpr, _expr_arena: &Arena<AExpr>) -> bool {
property_and_traverse(stack, ae, |ae| ae.does_not_modify_top_level())
}
pub fn does_not_modify_rec(node: Node, expr_arena: &Arena<AExpr>) -> bool {
property_rec(node, expr_arena, does_not_modify)
}
pub fn is_prop<P: Fn(&AExpr) -> bool>(
stack: &mut UnitVec<Node>,
ae: &AExpr,
expr_arena: &Arena<AExpr>,
prop_top_level: P,
) -> bool {
use AExpr::*;
if !prop_top_level(ae) {
return false;
}
match ae {
#[cfg(feature = "is_in")]
Function {
function: IRFunctionExpr::Boolean(IRBooleanFunction::IsIn { .. }),
input,
..
} => (|| {
if let Some(rhs) = input.get(1) {
assert_eq!(input.len(), 2); let rhs = rhs.node();
if matches!(expr_arena.get(rhs), AExpr::Literal { .. }) {
stack.extend([input[0].node()]);
return;
}
};
ae.inputs_rev(stack);
})(),
_ => {
ae.inputs_rev(stack);
},
}
true
}
pub fn is_elementwise(stack: &mut UnitVec<Node>, ae: &AExpr, expr_arena: &Arena<AExpr>) -> bool {
is_prop(stack, ae, expr_arena, |ae| ae.is_elementwise_top_level())
}
pub fn all_elementwise<'a, N>(nodes: &'a [N], expr_arena: &Arena<AExpr>) -> bool
where
Node: From<&'a N>,
{
nodes
.iter()
.all(|n| is_elementwise_rec(n.into(), expr_arena))
}
pub fn is_elementwise_rec(node: Node, expr_arena: &Arena<AExpr>) -> bool {
property_rec(node, expr_arena, is_elementwise)
}
pub fn is_row_separable(stack: &mut UnitVec<Node>, ae: &AExpr, expr_arena: &Arena<AExpr>) -> bool {
is_prop(stack, ae, expr_arena, |ae| ae.is_row_separable_top_level())
}
pub fn all_row_separable<'a, N>(nodes: &'a [N], expr_arena: &Arena<AExpr>) -> bool
where
Node: From<&'a N>,
{
nodes
.iter()
.all(|n| is_row_separable_rec(n.into(), expr_arena))
}
pub fn is_row_separable_rec(node: Node, expr_arena: &Arena<AExpr>) -> bool {
property_rec(node, expr_arena, is_row_separable)
}
/// Classification of an expression tree used to decide whether it may be pushed down.
///
/// Built up node by node with `update_with_expr` / `update_with_expr_rec` and consumed by
/// `blocks_pushdown`. Variants only ever move "down" this list: `Pushable` -> `Fallible` ->
/// `Barrier`.
#[derive(Debug, Clone)]
pub enum ExprPushdownGroup {
    /// No node seen so far is fallible or non-elementwise; never blocks pushdown.
    Pushable,
    /// Some node is fallible (per `AExpr::is_fallible_top_level`) but all are elementwise;
    /// blocks pushdown only when errors must be maintained.
    Fallible,
    /// Some node is not elementwise; always blocks pushdown.
    Barrier,
}
impl ExprPushdownGroup {
    /// Folds the single node `ae` (not its inputs) into this classification.
    ///
    /// From `Pushable` or `Fallible`, the following transitions apply:
    /// - a node for which `AExpr::is_fallible_top_level` holds moves the group to `Fallible`;
    /// - a node that fails [`is_elementwise`] moves it to `Barrier`.
    ///
    /// `Barrier` is absorbing: once reached, this method changes nothing.
    ///
    /// Side effect: when the elementwise check runs and passes, [`is_elementwise`] pushes the
    /// inputs of `ae` that still need inspecting onto `stack`. Nothing is pushed otherwise.
    pub fn update_with_expr(
        &mut self,
        stack: &mut UnitVec<Node>,
        ae: &AExpr,
        expr_arena: &Arena<AExpr>,
    ) -> &mut Self {
        match self {
            ExprPushdownGroup::Pushable | ExprPushdownGroup::Fallible => {
                if ae.is_fallible_top_level(expr_arena) {
                    *self = ExprPushdownGroup::Fallible;
                }
                if !is_elementwise(stack, ae, expr_arena) {
                    *self = ExprPushdownGroup::Barrier
                }
            },
            ExprPushdownGroup::Barrier => {},
        }
        self
    }

    /// Folds every node of the expression tree rooted at `ae` into this classification.
    ///
    /// The walk stops as soon as the group becomes `Barrier`. `scratch` optionally supplies a
    /// reusable work stack. On return the stack is always empty, including after an early
    /// `Barrier` exit, so the buffer can be safely reused for the next expression.
    pub fn update_with_expr_rec<'a>(
        &mut self,
        mut ae: &'a AExpr,
        expr_arena: &'a Arena<AExpr>,
        scratch: Option<&mut UnitVec<Node>>,
    ) -> &mut Self {
        let mut local_scratch = unitvec![];
        let stack = scratch.unwrap_or(&mut local_scratch);

        loop {
            self.update_with_expr(stack, ae, expr_arena);

            if let ExprPushdownGroup::Barrier = self {
                // Nothing visited after a barrier can change the result. Drop the nodes that
                // elementwise ancestors already queued. Otherwise a caller-provided scratch
                // buffer would leak them into the next traversal, which would misclassify
                // that (unrelated) expression.
                stack.clear();
                return self;
            }

            let Some(node) = stack.pop() else {
                break;
            };
            ae = expr_arena.get(node);
        }
        self
    }

    /// Whether this classification prevents pushdown.
    ///
    /// - `Barrier`: always.
    /// - `Pushable`: never.
    /// - `Fallible`: only when `maintain_errors` is set.
    pub fn blocks_pushdown(&self, maintain_errors: bool) -> bool {
        match self {
            ExprPushdownGroup::Barrier => true,
            ExprPushdownGroup::Fallible => maintain_errors,
            ExprPushdownGroup::Pushable => false,
        }
    }
}
pub fn can_pre_agg_exprs(
exprs: &[ExprIR],
expr_arena: &Arena<AExpr>,
_input_schema: &Schema,
) -> bool {
exprs
.iter()
.all(|e| can_pre_agg(e.node(), expr_arena, _input_schema))
}
/// Decides whether the expression at `agg` can be pre-aggregated.
///
/// - `len()` is accepted.
/// - Bare columns and literals are rejected.
/// - Anything that is not an `Agg` is rejected.
///
/// For an `Agg` root, every node of its tree must be on an allow-list:
/// - min/max/sum/first/last, and `count` with `include_nulls: true`;
/// - `mean`, only when its own dtype is primitive numeric (feature `dtype-struct`);
/// - elementwise single-input functions, binary expressions and ternaries, but only when none
///   of their inputs contain another aggregation;
/// - scalar literals, columns, `len()` and casts.
///
/// With the `object` feature, any `Object`-typed leaf column also rejects the expression.
///
/// # Panics
/// With the `object` feature, panics if a leaf column name of `agg` is missing from
/// `_input_schema`.
pub fn can_pre_agg(agg: Node, expr_arena: &Arena<AExpr>, _input_schema: &Schema) -> bool {
    let aexpr = expr_arena.get(agg);
    match aexpr {
        AExpr::Len => true,
        AExpr::Column(_) | AExpr::Literal(_) => false,
        AExpr::Agg(_) => {
            let has_aggregation =
                |node: Node| has_aexpr(node, expr_arena, |ae| matches!(ae, AExpr::Agg(_)));
            let can_partition = (expr_arena).iter(agg).all(|(_, ae)| {
                use AExpr::*;
                match ae {
                    #[cfg(feature = "dtype-struct")]
                    Agg(IRAggExpr::Mean(_)) => {
                        // Check the dtype of the mean node being visited. Previously this used
                        // the root `agg`, which is wrong whenever the mean is nested below a
                        // different root aggregation (e.g. `count(mean(date_col))` has a
                        // numeric root but a temporal mean).
                        ae.to_dtype(&ToFieldContext::new(expr_arena, _input_schema))
                            .is_ok_and(|dtype| dtype.is_primitive_numeric())
                    },
                    Agg(agg_e) => {
                        matches!(
                            agg_e,
                            IRAggExpr::Min { .. }
                                | IRAggExpr::Max { .. }
                                | IRAggExpr::Sum(_)
                                | IRAggExpr::Last(_)
                                | IRAggExpr::First(_)
                                | IRAggExpr::Count {
                                    input: _,
                                    include_nulls: true
                                }
                        )
                    },
                    Function { input, options, .. } => {
                        options.is_elementwise()
                            && input.len() == 1
                            && !has_aggregation(input[0].node())
                    },
                    BinaryExpr { left, right, .. } => {
                        !has_aggregation(*left) && !has_aggregation(*right)
                    },
                    Ternary {
                        truthy,
                        falsy,
                        predicate,
                        ..
                    } => {
                        !has_aggregation(*truthy)
                            && !has_aggregation(*falsy)
                            && !has_aggregation(*predicate)
                    },
                    Literal(lv) => lv.is_scalar(),
                    Column(_) | Len | Cast { .. } => true,
                    _ => false,
                }
            });
            #[cfg(feature = "object")]
            {
                for name in aexpr_to_leaf_names(agg, expr_arena) {
                    let dtype = _input_schema.get(&name).unwrap();
                    if let DataType::Object(_) = dtype {
                        return false;
                    }
                }
            }
            can_partition
        },
        _ => false,
    }
}
/// Reports, through `non_null_column_callback`, columns that the predicate at `predicate_node`
/// implies to be non-null.
///
/// Each minterm from `MintermIter` is analysed separately. These are presumably the top-level
/// AND-ed terms of the predicate — confirm against `MintermIter`.
///
/// `is_not_null(x)` and `not(is_null(x))` are unwrapped to `x`. From there the walk descends
/// only through nodes whose output is null whenever a traversed input is null:
/// - comparisons and arithmetic (plus `xor`) binary operators, following both inputs;
/// - casts to non-nested dtypes;
/// - functions flagged `PRESERVES_NULL_FIRST_INPUT` (first input only) or
///   `PRESERVES_NULL_ALL_INPUTS` (all inputs).
///
/// Every column reached this way is reported. The same column may be reported more than once.
pub(crate) fn predicate_non_null_column_outputs(
    predicate_node: Node,
    expr_arena: &Arena<AExpr>,
    non_null_column_callback: &mut dyn FnMut(&PlSmallStr),
) {
    use AExpr::*;

    // Where to start analysing a minterm. `is_not_null(x)` / `not(is_null(x))` being true
    // means `x` itself is non-null, so the walk begins at `x`. Any other minterm is walked
    // from its own root.
    let minterm_root = |minterm: Node| -> Node {
        match expr_arena.get(minterm) {
            Function {
                input,
                function: IRFunctionExpr::Boolean(IRBooleanFunction::IsNotNull),
                ..
            } if !input.is_empty() => input[0].node(),
            Function {
                input,
                function: IRFunctionExpr::Boolean(IRBooleanFunction::Not),
                ..
            } if !input.is_empty() => match expr_arena.get(input[0].node()) {
                Function {
                    input: negated_input,
                    function: IRFunctionExpr::Boolean(IRBooleanFunction::IsNull),
                    ..
                } if !negated_input.is_empty() => negated_input[0].node(),
                _ => minterm,
            },
            _ => minterm,
        }
    };

    let mut minterms = MintermIter::new(predicate_node, expr_arena);
    let mut pending: UnitVec<Node> = unitvec![];

    loop {
        // Drain nodes queued from the current minterm before moving on to the next one.
        let node = match pending.pop() {
            Some(node) => node,
            None => match minterms.next() {
                Some(minterm) => minterm_root(minterm),
                None => break,
            },
        };

        let ae = expr_arena.get(node);
        let follow_all_inputs = match ae {
            Column(name) => {
                non_null_column_callback(name);
                false
            },
            BinaryExpr { op, .. } => {
                use Operator::*;
                // Exhaustive on purpose: a new operator must be classified explicitly.
                match op {
                    Eq | NotEq | Lt | LtEq | Gt | GtEq | Plus | Minus | Multiply | RustDivide
                    | TrueDivide | FloorDivide | Modulus | Xor => true,
                    EqValidity | NotEqValidity | Or | LogicalOr | And | LogicalAnd => false,
                }
            },
            Cast { dtype, .. } => !dtype.is_nested(),
            Function { input, options, .. } => {
                if options
                    .flags
                    .contains(FunctionFlags::PRESERVES_NULL_FIRST_INPUT)
                {
                    // Only the first input is known to be non-null; queue just that one.
                    if let Some(first) = input.first() {
                        pending.push(first.node());
                    }
                    false
                } else {
                    options
                        .flags
                        .contains(FunctionFlags::PRESERVES_NULL_ALL_INPUTS)
                }
            },
            _ => false,
        };

        if follow_all_inputs {
            ae.inputs_rev(&mut pending);
        }
    }
}