// polars_plan/plans/aexpr/mod.rs

1mod evaluate;
2#[cfg(feature = "cse")]
3mod hash;
4mod minterm_iter;
5pub mod predicates;
6mod scalar;
7mod schema;
8mod traverse;
9
10use std::hash::{Hash, Hasher};
11
12#[cfg(feature = "cse")]
13pub(super) use hash::traverse_and_hash_aexpr;
14pub use minterm_iter::MintermIter;
15use polars_compute::rolling::QuantileMethod;
16use polars_core::chunked_array::cast::CastOptions;
17use polars_core::prelude::*;
18use polars_core::utils::{get_time_units, try_get_supertype};
19use polars_utils::arena::{Arena, Node};
20pub use scalar::is_scalar_ae;
21#[cfg(feature = "ir_serde")]
22use serde::{Deserialize, Serialize};
23use strum_macros::IntoStaticStr;
24pub use traverse::*;
25mod properties;
26pub use properties::*;
27
28use crate::constants::LEN;
29use crate::plans::Context;
30use crate::prelude::*;
31
/// Aggregation node of the expression IR.
///
/// Each variant references its input expression by [`Node`]; any extra fields
/// are the aggregation's parameters.
#[derive(Clone, Debug, IntoStaticStr)]
#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))]
pub enum IRAggExpr {
    /// Minimum. When `propagate_nans` is set this converts to
    /// [`GroupByMethod::NanMin`], otherwise to [`GroupByMethod::Min`].
    Min {
        input: Node,
        propagate_nans: bool,
    },
    /// Maximum. When `propagate_nans` is set this converts to
    /// [`GroupByMethod::NanMax`], otherwise to [`GroupByMethod::Max`].
    Max {
        input: Node,
        propagate_nans: bool,
    },
    Median(Node),
    NUnique(Node),
    First(Node),
    Last(Node),
    Mean(Node),
    Implode(Node),
    /// Quantile of `expr`. `quantile` is itself an expression [`Node`], not a
    /// constant, so this variant cannot be converted into a [`GroupByMethod`].
    Quantile {
        expr: Node,
        quantile: Node,
        method: QuantileMethod,
    },
    Sum(Node),
    /// The `bool` is `include_nulls` (forwarded to
    /// `GroupByMethod::Count { include_nulls }`).
    Count(Node, bool),
    /// The `u8` is the `ddof` (forwarded to `GroupByMethod::Std`).
    Std(Node, u8),
    /// The `u8` is the `ddof` (forwarded to `GroupByMethod::Var`).
    Var(Node, u8),
    /// Converts to [`GroupByMethod::Groups`].
    AggGroups(Node),
}
61
62impl Hash for IRAggExpr {
63    fn hash<H: Hasher>(&self, state: &mut H) {
64        std::mem::discriminant(self).hash(state);
65        match self {
66            Self::Min { propagate_nans, .. } | Self::Max { propagate_nans, .. } => {
67                propagate_nans.hash(state)
68            },
69            Self::Quantile {
70                method: interpol, ..
71            } => interpol.hash(state),
72            Self::Std(_, v) | Self::Var(_, v) => v.hash(state),
73            _ => {},
74        }
75    }
76}
77
78#[cfg(feature = "cse")]
79impl IRAggExpr {
80    pub(super) fn equal_nodes(&self, other: &IRAggExpr) -> bool {
81        use IRAggExpr::*;
82        match (self, other) {
83            (
84                Min {
85                    propagate_nans: l, ..
86                },
87                Min {
88                    propagate_nans: r, ..
89                },
90            ) => l == r,
91            (
92                Max {
93                    propagate_nans: l, ..
94                },
95                Max {
96                    propagate_nans: r, ..
97                },
98            ) => l == r,
99            (Quantile { method: l, .. }, Quantile { method: r, .. }) => l == r,
100            (Std(_, l), Std(_, r)) => l == r,
101            (Var(_, l), Var(_, r)) => l == r,
102            _ => std::mem::discriminant(self) == std::mem::discriminant(other),
103        }
104    }
105}
106
107impl From<IRAggExpr> for GroupByMethod {
108    fn from(value: IRAggExpr) -> Self {
109        use IRAggExpr::*;
110        match value {
111            Min { propagate_nans, .. } => {
112                if propagate_nans {
113                    GroupByMethod::NanMin
114                } else {
115                    GroupByMethod::Min
116                }
117            },
118            Max { propagate_nans, .. } => {
119                if propagate_nans {
120                    GroupByMethod::NanMax
121                } else {
122                    GroupByMethod::Max
123                }
124            },
125            Median(_) => GroupByMethod::Median,
126            NUnique(_) => GroupByMethod::NUnique,
127            First(_) => GroupByMethod::First,
128            Last(_) => GroupByMethod::Last,
129            Mean(_) => GroupByMethod::Mean,
130            Implode(_) => GroupByMethod::Implode,
131            Sum(_) => GroupByMethod::Sum,
132            Count(_, include_nulls) => GroupByMethod::Count { include_nulls },
133            Std(_, ddof) => GroupByMethod::Std(ddof),
134            Var(_, ddof) => GroupByMethod::Var(ddof),
135            AggGroups(_) => GroupByMethod::Groups,
136            Quantile { .. } => unreachable!(),
137        }
138    }
139}
140
/// IR expression node that is allocated in an [`Arena`][polars_utils::arena::Arena].
///
/// Child expressions are referenced by [`Node`] handles (presumably into the
/// same `Arena<AExpr>` — see [`AExpr::get_type`], which takes one).
/// [`AExpr::Len`] is the [`Default`] variant.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))]
pub enum AExpr {
    Explode {
        expr: Node,
        skip_empty: bool,
    },
    Alias(Node, PlSmallStr),
    Column(PlSmallStr),
    Literal(LiteralValue),
    BinaryExpr {
        left: Node,
        op: Operator,
        right: Node,
    },
    Cast {
        expr: Node,
        dtype: DataType,
        options: CastOptions,
    },
    Sort {
        expr: Node,
        options: SortOptions,
    },
    Gather {
        expr: Node,
        idx: Node,
        returns_scalar: bool,
    },
    SortBy {
        expr: Node,
        by: Vec<Node>,
        sort_options: SortMultipleOptions,
    },
    Filter {
        input: Node,
        by: Node,
    },
    /// An aggregation; see [`IRAggExpr`].
    Agg(IRAggExpr),
    Ternary {
        predicate: Node,
        truthy: Node,
        falsy: Node,
    },
    /// User-defined function. Like [`AExpr::Function`], its inputs are
    /// [`ExprIR`]s rather than bare [`Node`]s.
    AnonymousFunction {
        input: Vec<ExprIR>,
        function: OpaqueColumnUdf,
        output_type: GetOutput,
        options: FunctionOptions,
    },
    Function {
        /// Function arguments.
        /// Some functions rely on aliases,
        /// for instance assignment of struct fields.
        /// Therefore we need [`ExprIR`].
        input: Vec<ExprIR>,
        /// Function to apply.
        function: FunctionExpr,
        options: FunctionOptions,
    },
    Window {
        function: Node,
        partition_by: Vec<Node>,
        /// Optional ordering expression together with its sort options.
        order_by: Option<(Node, SortOptions)>,
        options: WindowType,
    },
    /// Slice whose `offset` and `length` are expression nodes, not constants.
    Slice {
        input: Node,
        offset: Node,
        length: Node,
    },
    /// Default variant (marked `#[default]`).
    #[default]
    Len,
}
216
217impl AExpr {
218    #[cfg(feature = "cse")]
219    pub(crate) fn col(name: PlSmallStr) -> Self {
220        AExpr::Column(name)
221    }
222
223    /// This should be a 1 on 1 copy of the get_type method of Expr until Expr is completely phased out.
224    pub fn get_type(
225        &self,
226        schema: &Schema,
227        ctxt: Context,
228        arena: &Arena<AExpr>,
229    ) -> PolarsResult<DataType> {
230        self.to_field(schema, ctxt, arena)
231            .map(|f| f.dtype().clone())
232    }
233}