mod builder;
mod determinism;
mod equality;
mod evaluate;
mod function_expr;
mod hash;
mod minterm_iter;
pub(crate) mod or_factoring;
pub mod predicates;
mod scalar;
mod schema;
mod traverse;
use std::hash::{Hash, Hasher};
pub use determinism::{is_inherently_nondeterministic, is_inherently_nondeterministic_top_level};
pub use function_expr::*;
pub(crate) use hash::traverse_and_hash_aexpr;
pub use minterm_iter::MintermIter;
use polars_core::chunked_array::cast::CastOptions;
use polars_core::prelude::*;
use polars_core::utils::{get_time_units, try_get_supertype};
use polars_utils::arena::{Arena, Node};
pub use scalar::{is_length_preserving_ae, is_scalar_ae};
use strum_macros::IntoStaticStr;
pub use traverse::*;
pub mod projection_height;
mod properties;
pub use aexpr::function_expr::schema::FieldsMapper;
pub use builder::AExprBuilder;
pub use evaluate::{constant_evaluate, into_column};
pub use properties::*;
pub use schema::ToFieldContext;
use crate::constants::LEN;
use crate::prelude::*;
/// Aggregation variants of the expression IR.
///
/// Every variant stores the arena [`Node`] of the expression it aggregates,
/// plus any option specific to that aggregation. The hand-written `Hash` impl
/// and `equal_nodes` in this file never look at the input node itself.
/// Each variant maps onto a `GroupByMethod` through the `From` impl in this file.
#[derive(Clone, Debug, IntoStaticStr)]
#[cfg_attr(feature = "ir_serde", derive(serde::Serialize, serde::Deserialize))]
pub enum IRAggExpr {
/// Converts to `GroupByMethod::NanMin` when `propagate_nans` is set,
/// otherwise to `GroupByMethod::Min`.
Min {
input: Node,
propagate_nans: bool,
},
/// Converts to `GroupByMethod::NanMax` when `propagate_nans` is set,
/// otherwise to `GroupByMethod::Max`.
Max {
input: Node,
propagate_nans: bool,
},
/// Converts to `GroupByMethod::Median`.
Median(Node),
/// Converts to `GroupByMethod::NUnique`.
NUnique(Node),
/// Converts to `GroupByMethod::Item`, forwarding `allow_empty`.
// NOTE(review): presumably `allow_empty` controls whether an empty group is accepted - confirm.
Item {
input: Node,
allow_empty: bool,
},
/// Converts to `GroupByMethod::First`.
First(Node),
/// Converts to `GroupByMethod::FirstNonNull`.
FirstNonNull(Node),
/// Converts to `GroupByMethod::Last`.
Last(Node),
/// Converts to `GroupByMethod::LastNonNull`.
LastNonNull(Node),
/// Converts to `GroupByMethod::Mean`.
Mean(Node),
/// Converts to `GroupByMethod::Implode`, forwarding `maintain_order`.
Implode {
input: Node,
maintain_order: bool,
},
/// Converts to `GroupByMethod::Sum`.
Sum(Node),
/// Converts to `GroupByMethod::Count`, forwarding `include_nulls`.
Count {
input: Node,
include_nulls: bool,
},
/// `(input, ddof)`; converts to `GroupByMethod::Std(ddof)`.
Std(Node, u8),
/// `(input, ddof)`; converts to `GroupByMethod::Var(ddof)`.
Var(Node, u8),
/// Converts to `GroupByMethod::Groups`.
AggGroups(Node),
}
impl Hash for IRAggExpr {
fn hash<H: Hasher>(&self, state: &mut H) {
std::mem::discriminant(self).hash(state);
match self {
Self::Min {
input: _,
propagate_nans,
}
| Self::Max {
input: _,
propagate_nans,
} => propagate_nans.hash(state),
Self::Std(_, v) | Self::Var(_, v) => v.hash(state),
Self::Count {
input: _,
include_nulls,
} => include_nulls.hash(state),
_ => {},
}
}
}
impl IRAggExpr {
    /// Returns `true` when `self` and `other` are the same aggregation with the
    /// same options.
    ///
    /// Only the variant and its non-input options are compared; the input
    /// `Node`s are ignored, so callers that care about the inputs must compare
    /// those separately.
    ///
    /// Every option that changes what the aggregation computes takes part in
    /// the comparison: `propagate_nans` (`Min`/`Max`), `include_nulls`
    /// (`Count`), `allow_empty` (`Item`), `maintain_order` (`Implode`) and the
    /// `u8` payload of `Std`/`Var`. Leaving any of them out would let two
    /// aggregations that behave differently compare as equal.
    pub(super) fn equal_nodes(&self, other: &IRAggExpr) -> bool {
        use IRAggExpr::*;
        match (self, other) {
            // Same variant carrying a single boolean option: the flags must agree.
            (
                Min {
                    propagate_nans: l, ..
                },
                Min {
                    propagate_nans: r, ..
                },
            )
            | (
                Max {
                    propagate_nans: l, ..
                },
                Max {
                    propagate_nans: r, ..
                },
            )
            | (
                Count {
                    include_nulls: l, ..
                },
                Count {
                    include_nulls: r, ..
                },
            )
            | (Item { allow_empty: l, .. }, Item { allow_empty: r, .. })
            | (
                Implode {
                    maintain_order: l, ..
                },
                Implode {
                    maintain_order: r, ..
                },
            ) => l == r,
            // Same variant carrying a `u8` payload: the payloads must agree.
            (Std(_, l), Std(_, r)) | (Var(_, l), Var(_, r)) => l == r,
            // Option-free variants are equal iff they are the same variant;
            // mismatched variants always land here and compare unequal.
            _ => std::mem::discriminant(self) == std::mem::discriminant(other),
        }
    }
}
impl From<IRAggExpr> for GroupByMethod {
fn from(value: IRAggExpr) -> Self {
use IRAggExpr::*;
match value {
Min {
input: _,
propagate_nans,
} => {
if propagate_nans {
GroupByMethod::NanMin
} else {
GroupByMethod::Min
}
},
Max {
input: _,
propagate_nans,
} => {
if propagate_nans {
GroupByMethod::NanMax
} else {
GroupByMethod::Max
}
},
Median(_) => GroupByMethod::Median,
NUnique(_) => GroupByMethod::NUnique,
First(_) => GroupByMethod::First,
FirstNonNull(_) => GroupByMethod::FirstNonNull,
Last(_) => GroupByMethod::Last,
LastNonNull(_) => GroupByMethod::LastNonNull,
Item { allow_empty, .. } => GroupByMethod::Item { allow_empty },
Mean(_) => GroupByMethod::Mean,
Implode { maintain_order, .. } => GroupByMethod::Implode { maintain_order },
Sum(_) => GroupByMethod::Sum,
Count {
input: _,
include_nulls,
} => GroupByMethod::Count { include_nulls },
Std(_, ddof) => GroupByMethod::Std(ddof),
Var(_, ddof) => GroupByMethod::Var(ddof),
AggGroups(_) => GroupByMethod::Groups,
}
}
}
/// Expression node kept in an `Arena<AExpr>`.
///
/// Child expressions are not stored inline; they are referenced through
/// [`Node`] handles (or `ExprIR` values) that are resolved against the arena,
/// as `is_fallible_top_level` and `deep_clone_ae` in this file do.
/// The `Default` value is [`AExpr::Len`].
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "ir_serde", derive(serde::Serialize, serde::Deserialize))]
pub enum AExpr {
// NOTE(review): presumably the current element inside an `Eval` context - confirm.
Element,
Explode {
expr: Node,
options: ExplodeOptions,
},
/// A column identified by name (see `AExpr::col`).
Column(PlSmallStr),
/// Only available with the `dtype-struct` feature.
#[cfg(feature = "dtype-struct")]
StructField(PlSmallStr),
/// A literal value. `is_fallible_top_level` treats a cast or `strptime`
/// whose first input is a literal as infallible.
Literal(LiteralValue),
BinaryExpr {
left: Node,
op: Operator,
right: Node,
},
/// Cast of `expr` to `dtype`. With `CastOptions::Strict` and a non-literal
/// input this is reported as fallible by `is_fallible_top_level`.
Cast {
expr: Node,
dtype: DataType,
options: CastOptions,
},
Sort {
expr: Node,
options: SortOptions,
},
Gather {
expr: Node,
idx: Node,
returns_scalar: bool,
null_on_oob: bool,
},
SortBy {
expr: Node,
by: Vec<Node>,
sort_options: SortMultipleOptions,
},
Filter {
input: Node,
by: Node,
},
/// An aggregation; see [`IRAggExpr`].
Agg(IRAggExpr),
Ternary {
predicate: Node,
truthy: Node,
falsy: Node,
},
// NOTE(review): `fmt_str` is presumably only used when displaying the expression - confirm.
AnonymousAgg {
input: Vec<ExprIR>,
fmt_str: Box<PlSmallStr>,
function: OpaqueStreamingAgg,
},
AnonymousFunction {
input: Vec<ExprIR>,
function: OpaqueColumnUdf,
options: FunctionOptions,
fmt_str: Box<PlSmallStr>,
},
Eval {
expr: Node,
evaluation: Node,
variant: EvalVariant,
},
/// Only available with the `dtype-struct` feature.
#[cfg(feature = "dtype-struct")]
StructEval {
expr: Node,
evaluation: Vec<ExprIR>,
},
/// Call of a built-in `IRFunctionExpr` on `input`.
Function {
input: Vec<ExprIR>,
function: IRFunctionExpr,
options: FunctionOptions,
},
Over {
function: Node,
partition_by: Vec<Node>,
order_by: Option<(Node, SortOptions)>,
mapping: WindowMapping,
},
/// Only available with the `dynamic_group_by` feature.
#[cfg(feature = "dynamic_group_by")]
Rolling {
function: Node,
index_column: Node,
period: Duration,
offset: Duration,
closed_window: ClosedWindow,
},
Slice {
input: Node,
offset: Node,
length: Node,
},
/// The variant produced by `AExpr::default()`.
#[default]
Len,
}
impl AExpr {
/// Builds an [`AExpr::Column`] referring to `name`.
///
/// Only compiled when the `cse` feature is enabled.
#[cfg(feature = "cse")]
pub(crate) fn col(name: PlSmallStr) -> Self {
AExpr::Column(name)
}
/// Reports whether this node, looked at on its own, is considered fallible.
///
/// Only `self` is classified; `arena` is used solely to resolve the inputs
/// of `self` in order to check whether they are literals. Children are not
/// inspected recursively.
///
/// Returns `true` for:
/// - `ListExpr` `Get(false)` and, with feature `list_gather`, `Gather(false)`;
/// - `ArrayExpr` `Get(false)` (feature `dtype-array`);
/// - `ReplaceStrict` (feature `replace`);
/// - `Strptime` (features `strings` + `temporal`) when the first input is not
///   a literal and either parsing is strict or the `ambiguous` argument is
///   not a literal known to be infallible;
/// - a `Cast` with `CastOptions::Strict` whose input is not a literal.
///
/// Everything else yields `false`.
pub fn is_fallible_top_level(&self, arena: &Arena<AExpr>) -> bool {
// The nested matches are kept explicit so that cfg-gated arms can be added
// or removed without reshaping the expression.
#[allow(clippy::collapsible_match, clippy::match_like_matches_macro)]
match self {
AExpr::Function {
input, function, ..
} => match function {
// NOTE(review): the boolean is presumably `null_on_oob`, so `false` means
// an out-of-bounds index raises - confirm against `IRListFunction`.
IRFunctionExpr::ListExpr(f) => match f {
IRListFunction::Get(false) => true,
#[cfg(feature = "list_gather")]
IRListFunction::Gather(false) => true,
_ => false,
},
#[cfg(feature = "dtype-array")]
IRFunctionExpr::ArrayExpr(f) => match f {
IRArrayFunction::Get(false) => true,
_ => false,
},
#[cfg(feature = "replace")]
IRFunctionExpr::ReplaceStrict { .. } => true,
#[cfg(all(feature = "strings", feature = "temporal"))]
IRFunctionExpr::StringExpr(f) => match f {
IRStringFunction::Strptime(_, strptime_options) => {
debug_assert!(input.len() <= 2);
// The optional second input is the `ambiguous` argument. It is
// infallible only when it is a string literal with one of the
// values "earliest", "latest" or "null". A missing argument, a
// non-literal, a non-string literal or "raise" all count as fallible.
let ambiguous_arg_is_infallible_scalar = input
.get(1)
.map(|x| arena.get(x.node()))
.is_some_and(|ae| match ae {
AExpr::Literal(lv) => {
lv.extract_str().is_some_and(|ambiguous| match ambiguous {
"earliest" | "latest" | "null" => true,
"raise" => false,
v => {
// Unknown values panic in debug builds and are treated
// as fallible in release builds.
if cfg!(debug_assertions) {
panic!("unhandled parameter to ambiguous: {v}")
}
false
},
})
},
_ => false,
});
let ambiguous_is_fallible = !ambiguous_arg_is_infallible_scalar;
// A literal first input is never reported as fallible here.
!matches!(arena.get(input[0].node()), AExpr::Literal(_))
&& (strptime_options.strict || ambiguous_is_fallible)
},
_ => false,
},
_ => false,
},
// Only strict casts are matched; a literal input is not reported as fallible.
AExpr::Cast {
expr,
dtype: _,
options: CastOptions::Strict,
} => !matches!(arena.get(*expr), AExpr::Literal(_)),
_ => false,
}
}
}
/// Recursively duplicates the expression rooted at `ae` inside `arena` and
/// returns the node of the fresh copy.
///
/// Every descendant is copied into a new arena slot as well, so the returned
/// tree shares no nodes with the original one.
#[recursive::recursive]
pub fn deep_clone_ae(ae: Node, arena: &mut Arena<AExpr>) -> Node {
    let template = arena.get(ae).clone();

    // Children arrive in reverse order; they are cloned in that same order so
    // the arena receives its new nodes in the same sequence as before.
    let mut reversed: Vec<Node> = Vec::new();
    template.children_rev(&mut reversed);

    let mut cloned: Vec<Node> = reversed
        .into_iter()
        .map(|child| deep_clone_ae(child, arena))
        .collect();
    // Restore the natural child order expected by `replace_children`.
    cloned.reverse();

    let copy = template.replace_children(&cloned);
    arena.add(copy)
}