1mod evaluate;
2#[cfg(feature = "cse")]
3mod hash;
4mod minterm_iter;
5pub mod predicates;
6mod scalar;
7mod schema;
8mod traverse;
9
10use std::hash::{Hash, Hasher};
11
12#[cfg(feature = "cse")]
13pub(super) use hash::traverse_and_hash_aexpr;
14pub use minterm_iter::MintermIter;
15use polars_compute::rolling::QuantileMethod;
16use polars_core::chunked_array::cast::CastOptions;
17use polars_core::prelude::*;
18use polars_core::utils::{get_time_units, try_get_supertype};
19use polars_utils::arena::{Arena, Node};
20pub use scalar::is_scalar_ae;
21#[cfg(feature = "ir_serde")]
22use serde::{Deserialize, Serialize};
23use strum_macros::IntoStaticStr;
24pub use traverse::*;
25mod properties;
26pub use properties::*;
27
28use crate::constants::LEN;
29use crate::plans::Context;
30use crate::prelude::*;
31
/// Aggregation expressions of the IR.
///
/// Every variant carries the [`Node`] handle(s) of its input expression(s)
/// plus whatever extra parameters that aggregation takes.
#[derive(Clone, Debug, IntoStaticStr)]
#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))]
pub enum IRAggExpr {
    /// Converts to `GroupByMethod::NanMin` when `propagate_nans` is set,
    /// otherwise to `GroupByMethod::Min`.
    Min {
        input: Node,
        propagate_nans: bool,
    },
    /// Converts to `GroupByMethod::NanMax` when `propagate_nans` is set,
    /// otherwise to `GroupByMethod::Max`.
    Max {
        input: Node,
        propagate_nans: bool,
    },
    /// Converts to `GroupByMethod::Median`.
    Median(Node),
    /// Converts to `GroupByMethod::NUnique`.
    NUnique(Node),
    /// Converts to `GroupByMethod::First`.
    First(Node),
    /// Converts to `GroupByMethod::Last`.
    Last(Node),
    /// Converts to `GroupByMethod::Mean`.
    Mean(Node),
    /// Converts to `GroupByMethod::Implode`.
    Implode(Node),
    /// Quantile aggregation; `quantile` is itself an expression node and
    /// `method` is the [`QuantileMethod`] to use.
    ///
    /// This is the only variant with no `GroupByMethod` counterpart in the
    /// `From<IRAggExpr>` conversion (that conversion hits `unreachable!()`).
    Quantile {
        expr: Node,
        quantile: Node,
        method: QuantileMethod,
    },
    /// Converts to `GroupByMethod::Sum`.
    Sum(Node),
    /// The `bool` is forwarded as `include_nulls` to `GroupByMethod::Count`.
    Count(Node, bool),
    /// The `u8` is forwarded to `GroupByMethod::Std` (named `ddof` there).
    Std(Node, u8),
    /// The `u8` is forwarded to `GroupByMethod::Var` (named `ddof` there).
    Var(Node, u8),
    /// Converts to `GroupByMethod::Groups`.
    AggGroups(Node),
}
61
62impl Hash for IRAggExpr {
63 fn hash<H: Hasher>(&self, state: &mut H) {
64 std::mem::discriminant(self).hash(state);
65 match self {
66 Self::Min { propagate_nans, .. } | Self::Max { propagate_nans, .. } => {
67 propagate_nans.hash(state)
68 },
69 Self::Quantile {
70 method: interpol, ..
71 } => interpol.hash(state),
72 Self::Std(_, v) | Self::Var(_, v) => v.hash(state),
73 _ => {},
74 }
75 }
76}
77
#[cfg(feature = "cse")]
impl IRAggExpr {
    /// Reports whether `self` and `other` are the same aggregation variant
    /// with the same non-`Node` parameters.
    ///
    /// `Node` inputs are not inspected here.
    /// NOTE(review): presumably the caller compares the child nodes
    /// separately — confirm against the CSE traversal.
    pub(super) fn equal_nodes(&self, other: &IRAggExpr) -> bool {
        use IRAggExpr::*;
        match (self, other) {
            (
                Min {
                    propagate_nans: l, ..
                },
                Min {
                    propagate_nans: r, ..
                },
            )
            | (
                Max {
                    propagate_nans: l, ..
                },
                Max {
                    propagate_nans: r, ..
                },
            ) => l == r,
            (Quantile { method: l, .. }, Quantile { method: r, .. }) => l == r,
            // The flag ends up as `GroupByMethod::Count { include_nulls }`, so
            // two counts that differ in it are different aggregations and must
            // not be treated as interchangeable.
            (Count(_, l), Count(_, r)) => l == r,
            (Std(_, l), Std(_, r)) | (Var(_, l), Var(_, r)) => l == r,
            // Remaining variants carry only `Node`s; mixed pairs of
            // parameterised variants also land here and compare unequal
            // because their discriminants differ.
            _ => std::mem::discriminant(self) == std::mem::discriminant(other),
        }
    }
}
106
107impl From<IRAggExpr> for GroupByMethod {
108 fn from(value: IRAggExpr) -> Self {
109 use IRAggExpr::*;
110 match value {
111 Min { propagate_nans, .. } => {
112 if propagate_nans {
113 GroupByMethod::NanMin
114 } else {
115 GroupByMethod::Min
116 }
117 },
118 Max { propagate_nans, .. } => {
119 if propagate_nans {
120 GroupByMethod::NanMax
121 } else {
122 GroupByMethod::Max
123 }
124 },
125 Median(_) => GroupByMethod::Median,
126 NUnique(_) => GroupByMethod::NUnique,
127 First(_) => GroupByMethod::First,
128 Last(_) => GroupByMethod::Last,
129 Mean(_) => GroupByMethod::Mean,
130 Implode(_) => GroupByMethod::Implode,
131 Sum(_) => GroupByMethod::Sum,
132 Count(_, include_nulls) => GroupByMethod::Count { include_nulls },
133 Std(_, ddof) => GroupByMethod::Std(ddof),
134 Var(_, ddof) => GroupByMethod::Var(ddof),
135 AggGroups(_) => GroupByMethod::Groups,
136 Quantile { .. } => unreachable!(),
137 }
138 }
139}
140
/// Expression node of the IR, held in an `Arena<AExpr>`.
///
/// Child expressions are referenced through [`Node`] handles instead of being
/// owned inline (presumably indices into that arena — confirm against
/// `polars_utils::arena`). The default value is [`AExpr::Len`].
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))]
pub enum AExpr {
    /// Explode applied to `expr`.
    // NOTE(review): `skip_empty` presumably controls how empty sub-lists are
    // treated — confirm against the executor.
    Explode {
        expr: Node,
        skip_empty: bool,
    },
    /// The expression at the given node, paired with a name.
    Alias(Node, PlSmallStr),
    /// A column identified by name.
    Column(PlSmallStr),
    /// A literal value.
    Literal(LiteralValue),
    /// Binary operation `op` with operand nodes `left` and `right`.
    BinaryExpr {
        left: Node,
        op: Operator,
        right: Node,
    },
    /// Cast of `expr` to `dtype`, configured by `options`.
    Cast {
        expr: Node,
        dtype: DataType,
        options: CastOptions,
    },
    /// Sort of `expr`, configured by `options`.
    Sort {
        expr: Node,
        options: SortOptions,
    },
    /// Gather from `expr` using the expression at `idx`.
    Gather {
        expr: Node,
        idx: Node,
        returns_scalar: bool,
    },
    /// Sort of `expr` keyed on the `by` expressions.
    SortBy {
        expr: Node,
        by: Vec<Node>,
        sort_options: SortMultipleOptions,
    },
    /// Filter of `input` using the expression at `by`.
    Filter {
        input: Node,
        by: Node,
    },
    /// An aggregation; see [`IRAggExpr`].
    Agg(IRAggExpr),
    /// Three-way expression made of a `predicate` node and one node for each
    /// outcome (`truthy` / `falsy`).
    Ternary {
        predicate: Node,
        truthy: Node,
        falsy: Node,
    },
    /// Call of an opaque user-defined function over `input`.
    AnonymousFunction {
        input: Vec<ExprIR>,
        function: OpaqueColumnUdf,
        output_type: GetOutput,
        options: FunctionOptions,
    },
    /// Call of a [`FunctionExpr`] over `input`.
    Function {
        input: Vec<ExprIR>,
        function: FunctionExpr,
        options: FunctionOptions,
    },
    /// Window expression: `function` together with its partitioning
    /// expressions and an optional ordering expression.
    Window {
        function: Node,
        partition_by: Vec<Node>,
        order_by: Option<(Node, SortOptions)>,
        options: WindowType,
    },
    /// Slice of `input`; `offset` and `length` are themselves expression
    /// nodes rather than plain integers.
    Slice {
        input: Node,
        offset: Node,
        length: Node,
    },
    /// Default variant; carries no data.
    #[default]
    Len,
}
216
217impl AExpr {
218 #[cfg(feature = "cse")]
219 pub(crate) fn col(name: PlSmallStr) -> Self {
220 AExpr::Column(name)
221 }
222
223 pub fn get_type(
225 &self,
226 schema: &Schema,
227 ctxt: Context,
228 arena: &Arena<AExpr>,
229 ) -> PolarsResult<DataType> {
230 self.to_field(schema, ctxt, arena)
231 .map(|f| f.dtype().clone())
232 }
233}