1use std::fmt::Debug;
2#[cfg(feature = "cse")]
3use std::fmt::Formatter;
4
5use polars_core::prelude::{Field, Schema};
6use polars_utils::unitvec;
7
8use super::*;
9use crate::prelude::*;
10
/// Tree traversal over the user-facing `Expr` tree.
///
/// `Expr` children are owned by the tree itself (no arena), so the associated
/// `Arena` type is the unit type `()`.
impl TreeWalker for Expr {
    type Arena = ();

    /// Calls `op` on every direct child of `self`, as collected by `Expr::nodes`.
    ///
    /// `Continue` and `Skip` both move on to the next child; `Stop` aborts the
    /// loop and is returned to the caller. Errors from `op` propagate via `?`.
    /// Returns `Continue` once all children have been visited.
    fn apply_children<F: FnMut(&Self, &Self::Arena) -> PolarsResult<VisitRecursion>>(
        &self,
        op: &mut F,
        arena: &Self::Arena,
    ) -> PolarsResult<VisitRecursion> {
        let mut scratch = unitvec![];

        // Collect references to the direct children only (no recursion here).
        self.nodes(&mut scratch);

        for &child in scratch.as_slice() {
            match op(child, arena)? {
                VisitRecursion::Continue | VisitRecursion::Skip => {},
                VisitRecursion::Stop => return Ok(VisitRecursion::Stop),
            }
        }
        Ok(VisitRecursion::Continue)
    }

    /// Rebuilds `self` with every mapped child replaced by the result of `f`.
    ///
    /// Variants without mapped children (columns, literals, `Nth`, `SubPlan`,
    /// `Selector`, ...) are returned unchanged. Errors from `f` propagate via `?`.
    fn map_children<F: FnMut(Self, &mut Self::Arena) -> PolarsResult<Self>>(
        self,
        f: &mut F,
        _arena: &mut Self::Arena,
    ) -> PolarsResult<Self> {
        // `try_arc_map` (aliased `am`) presumably applies `f` to the `Expr`
        // inside an `Arc` and re-wraps the result — confirm in polars_utils.
        use polars_utils::functions::try_arc_map as am;
        // Adapt `f` to a one-argument closure; the arena for `Expr` is `()`.
        let mut f = |expr| f(expr, &mut ());
        use AggExpr::*;
        use Expr::*;
        // In arms with several children, `f` is lent as `&mut f` for all but the
        // last child, which receives `f` by value.
        #[rustfmt::skip]
        let ret = match self {
            Alias(l, r) => Alias(am(l, f)?, r),
            Column(_) => self,
            Columns(_) => self,
            DtypeColumn(_) => self,
            IndexColumn(_) => self,
            Literal(_) => self,
            #[cfg(feature = "dtype-struct")]
            Field(_) => self,
            BinaryExpr { left, op, right } => {
                BinaryExpr { left: am(left, &mut f)? , op, right: am(right, f)?}
            },
            Cast { expr, dtype, options: strict } => Cast { expr: am(expr, f)?, dtype, options: strict },
            Sort { expr, options } => Sort { expr: am(expr, f)?, options },
            Gather { expr, idx, returns_scalar } => Gather { expr: am(expr, &mut f)?, idx: am(idx, f)?, returns_scalar },
            SortBy { expr, by, sort_options } => SortBy { expr: am(expr, &mut f)?, by: by.into_iter().map(f).collect::<Result<_, _>>()?, sort_options },
            Agg(agg_expr) => Agg(match agg_expr {
                Min { input, propagate_nans } => Min { input: am(input, f)?, propagate_nans },
                Max { input, propagate_nans } => Max { input: am(input, f)?, propagate_nans },
                Median(x) => Median(am(x, f)?),
                NUnique(x) => NUnique(am(x, f)?),
                First(x) => First(am(x, f)?),
                Last(x) => Last(am(x, f)?),
                Mean(x) => Mean(am(x, f)?),
                Implode(x) => Implode(am(x, f)?),
                Count(x, nulls) => Count(am(x, f)?, nulls),
                Quantile { expr, quantile, method: interpol } => Quantile { expr: am(expr, &mut f)?, quantile: am(quantile, f)?, method: interpol },
                Sum(x) => Sum(am(x, f)?),
                AggGroups(x) => AggGroups(am(x, f)?),
                Std(x, ddf) => Std(am(x, f)?, ddf),
                Var(x, ddf) => Var(am(x, f)?, ddf),
            }),
            Ternary { predicate, truthy, falsy } => Ternary { predicate: am(predicate, &mut f)?, truthy: am(truthy, &mut f)?, falsy: am(falsy, f)? },
            Function { input, function, options } => Function { input: input.into_iter().map(f).collect::<Result<_, _>>()?, function, options },
            Explode { input, skip_empty } => Explode { input: am(input, f)?, skip_empty },
            Filter { input, by } => Filter { input: am(input, &mut f)?, by: am(by, f)? },
            // NOTE(review): `order_by` is passed through without calling `f`,
            // unlike `partition_by` and `function`. If `Expr::nodes` reports
            // the `order_by` expression as a child, `apply_children` and
            // `map_children` disagree here — confirm whether this is intended.
            Window { function, partition_by, order_by, options } => {
                let partition_by = partition_by.into_iter().map(&mut f).collect::<Result<_, _>>()?;
                Window { function: am(function, f)?, partition_by, order_by, options }
            },
            Wildcard => Wildcard,
            Slice { input, offset, length } => Slice { input: am(input, &mut f)?, offset: am(offset, &mut f)?, length: am(length, f)? },
            Exclude(expr, excluded) => Exclude(am(expr, f)?, excluded),
            KeepName(expr) => KeepName(am(expr, f)?),
            Len => Len,
            Nth(_) => self,
            RenameAlias { function, expr } => RenameAlias { function, expr: am(expr, f)? },
            AnonymousFunction { input, function, output_type, options } => {
                AnonymousFunction { input: input.into_iter().map(f).collect::<Result<_, _>>()?, function, output_type, options }
            },
            SubPlan(_, _) => self,
            Selector(_) => self,
        };
        Ok(ret)
    }
}
100
101#[derive(Copy, Clone, Debug)]
102pub struct AexprNode {
103 node: Node,
104}
105
106impl AexprNode {
107 pub fn new(node: Node) -> Self {
108 Self { node }
109 }
110
111 pub fn node(&self) -> Node {
113 self.node
114 }
115
116 pub fn to_aexpr<'a>(&self, arena: &'a Arena<AExpr>) -> &'a AExpr {
117 arena.get(self.node)
118 }
119
120 pub fn to_expr(&self, arena: &Arena<AExpr>) -> Expr {
121 node_to_expr(self.node, arena)
122 }
123
124 pub fn to_field(&self, schema: &Schema, arena: &Arena<AExpr>) -> PolarsResult<Field> {
125 let aexpr = arena.get(self.node);
126 aexpr.to_field(schema, Context::Default, arena)
127 }
128
129 pub fn assign(&mut self, ae: AExpr, arena: &mut Arena<AExpr>) {
130 let node = arena.add(ae);
131 self.node = node;
132 }
133
134 #[cfg(feature = "cse")]
135 pub(crate) fn is_leaf(&self, arena: &Arena<AExpr>) -> bool {
136 matches!(self.to_aexpr(arena), AExpr::Column(_) | AExpr::Literal(_))
137 }
138
139 #[cfg(feature = "cse")]
140 pub(crate) fn hashable_and_cmp<'a>(&self, arena: &'a Arena<AExpr>) -> AExprArena<'a> {
141 AExprArena {
142 node: self.node,
143 arena,
144 }
145 }
146}
147
/// An arena index paired with the arena it belongs to, so the referenced
/// expression can be resolved and compared without copying it out.
#[cfg(feature = "cse")]
pub struct AExprArena<'a> {
    node: Node,
    arena: &'a Arena<AExpr>,
}

#[cfg(feature = "cse")]
impl Debug for AExprArena<'_> {
    /// Prints only the raw arena index, not the expression itself.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let index = self.node.0;
        f.write_fmt(format_args!("AexprArena: {}", index))
    }
}
160
161impl AExpr {
162 #[cfg(feature = "cse")]
163 fn is_equal_node(&self, other: &Self) -> bool {
164 use AExpr::*;
165 match (self, other) {
166 (Alias(_, l), Alias(_, r)) => l == r,
167 (Column(l), Column(r)) => l == r,
168 (Literal(l), Literal(r)) => l == r,
169 (Window { options: l, .. }, Window { options: r, .. }) => l == r,
170 (
171 Cast {
172 options: strict_l,
173 dtype: dtl,
174 ..
175 },
176 Cast {
177 options: strict_r,
178 dtype: dtr,
179 ..
180 },
181 ) => strict_l == strict_r && dtl == dtr,
182 (Sort { options: l, .. }, Sort { options: r, .. }) => l == r,
183 (Gather { .. }, Gather { .. })
184 | (Filter { .. }, Filter { .. })
185 | (Ternary { .. }, Ternary { .. })
186 | (Len, Len)
187 | (Slice { .. }, Slice { .. }) => true,
188 (
189 Explode {
190 expr: _,
191 skip_empty: l_skip_empty,
192 },
193 Explode {
194 expr: _,
195 skip_empty: r_skip_empty,
196 },
197 ) => l_skip_empty == r_skip_empty,
198 (
199 SortBy {
200 sort_options: l_sort_options,
201 ..
202 },
203 SortBy {
204 sort_options: r_sort_options,
205 ..
206 },
207 ) => l_sort_options == r_sort_options,
208 (Agg(l), Agg(r)) => l.equal_nodes(r),
209 (
210 Function {
211 input: il,
212 function: fl,
213 options: ol,
214 },
215 Function {
216 input: ir,
217 function: fr,
218 options: or,
219 },
220 ) => {
221 fl == fr && ol == or && {
222 let mut all_same_name = true;
223 for (l, r) in il.iter().zip(ir) {
224 all_same_name &= l.output_name() == r.output_name()
225 }
226
227 all_same_name
228 }
229 },
230 (AnonymousFunction { .. }, AnonymousFunction { .. }) => false,
231 (BinaryExpr { op: l, .. }, BinaryExpr { op: r, .. }) => l == r,
232 _ => false,
233 }
234 }
235}
236
#[cfg(feature = "cse")]
impl<'a> AExprArena<'a> {
    /// Builds a handle for `node` inside `arena`.
    pub fn new(node: Node, arena: &'a Arena<AExpr>) -> Self {
        AExprArena { node, arena }
    }

    /// Resolves the handle to the expression it refers to.
    pub fn to_aexpr(&self) -> &'a AExpr {
        let arena: &'a Arena<AExpr> = self.arena;
        arena.get(self.node)
    }

    /// Compares the two referenced nodes with `AExpr::is_equal_node`, without
    /// walking into their inputs.
    pub fn is_equal_single(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.to_aexpr(), other.to_aexpr());
        lhs.is_equal_node(rhs)
    }
}
253
#[cfg(feature = "cse")]
impl PartialEq for AExprArena<'_> {
    /// Deep structural equality of the two expression trees.
    ///
    /// Both trees are walked in lock-step with explicit stacks, and each popped
    /// pair is compared with `is_equal_single`. Each side's nodes are resolved
    /// in that side's own arena.
    ///
    /// After every step, both stacks must hold the same number of pending
    /// nodes. This means each compared pair contributed the same number of
    /// inputs. Without the check, a difference in arity could be hidden by the
    /// flattened traversal: `f(f(a), b)` and `f(f(a, b))` visit the same
    /// sequence of nodes.
    fn eq(&self, other: &Self) -> bool {
        let mut lhs_pending = unitvec![];
        let mut rhs_pending = unitvec![];

        lhs_pending.push(self.node);
        rhs_pending.push(other.node);

        loop {
            match (lhs_pending.pop(), rhs_pending.pop()) {
                (Some(l), Some(r)) => {
                    let l = Self::new(l, self.arena);
                    // The right-hand node belongs to `other`'s arena.
                    let r = Self::new(r, other.arena);

                    if !l.is_equal_single(&r) {
                        return false;
                    }

                    l.to_aexpr().inputs_rev(&mut lhs_pending);
                    r.to_aexpr().inputs_rev(&mut rhs_pending);

                    // Nodes with a different number of inputs are not
                    // structurally equal.
                    if lhs_pending.len() != rhs_pending.len() {
                        return false;
                    }
                },
                (None, None) => return true,
                _ => return false,
            }
        }
    }
}
282
283impl TreeWalker for AexprNode {
284 type Arena = Arena<AExpr>;
285 fn apply_children<F: FnMut(&Self, &Self::Arena) -> PolarsResult<VisitRecursion>>(
286 &self,
287 op: &mut F,
288 arena: &Self::Arena,
289 ) -> PolarsResult<VisitRecursion> {
290 let mut scratch = unitvec![];
291
292 self.to_aexpr(arena).inputs_rev(&mut scratch);
293 for node in scratch.as_slice() {
294 let aenode = AexprNode::new(*node);
295 match op(&aenode, arena)? {
296 VisitRecursion::Continue | VisitRecursion::Skip => {},
298 VisitRecursion::Stop => return Ok(VisitRecursion::Stop),
300 }
301 }
302 Ok(VisitRecursion::Continue)
303 }
304
305 fn map_children<F: FnMut(Self, &mut Self::Arena) -> PolarsResult<Self>>(
306 mut self,
307 op: &mut F,
308 arena: &mut Self::Arena,
309 ) -> PolarsResult<Self> {
310 let mut scratch = unitvec![];
311
312 let ae = arena.get(self.node).clone();
313 ae.inputs_rev(&mut scratch);
314
315 for node in scratch.as_mut_slice() {
317 let aenode = AexprNode::new(*node);
318 *node = op(aenode, arena)?.node;
319 }
320
321 scratch.as_mut_slice().reverse();
322 let ae = ae.replace_inputs(&scratch);
323 self.node = arena.add(ae);
324 Ok(self)
325 }
326}