polars_plan/plans/visitor/
expr.rs

1use std::fmt::Debug;
2#[cfg(feature = "cse")]
3use std::fmt::Formatter;
4
5use polars_core::prelude::{Field, Schema};
6use polars_utils::unitvec;
7
8use super::*;
9use crate::prelude::*;
10
11impl TreeWalker for Expr {
12    type Arena = ();
13
14    fn apply_children<F: FnMut(&Self, &Self::Arena) -> PolarsResult<VisitRecursion>>(
15        &self,
16        op: &mut F,
17        arena: &Self::Arena,
18    ) -> PolarsResult<VisitRecursion> {
19        let mut scratch = unitvec![];
20
21        self.nodes(&mut scratch);
22
23        for &child in scratch.as_slice() {
24            match op(child, arena)? {
25                // let the recursion continue
26                VisitRecursion::Continue | VisitRecursion::Skip => {},
27                // early stop
28                VisitRecursion::Stop => return Ok(VisitRecursion::Stop),
29            }
30        }
31        Ok(VisitRecursion::Continue)
32    }
33
34    fn map_children<F: FnMut(Self, &mut Self::Arena) -> PolarsResult<Self>>(
35        self,
36        f: &mut F,
37        _arena: &mut Self::Arena,
38    ) -> PolarsResult<Self> {
39        use polars_utils::functions::try_arc_map as am;
40        let mut f = |expr| f(expr, &mut ());
41        use AggExpr::*;
42        use Expr::*;
43        #[rustfmt::skip]
44        let ret = match self {
45            Alias(l, r) => Alias(am(l, f)?, r),
46            Column(_) => self,
47            Columns(_) => self,
48            DtypeColumn(_) => self,
49            IndexColumn(_) => self,
50            Literal(_) => self,
51            #[cfg(feature = "dtype-struct")]
52            Field(_) => self,
53            BinaryExpr { left, op, right } => {
54                BinaryExpr { left: am(left, &mut f)? , op, right: am(right, f)?}
55            },
56            Cast { expr, dtype, options: strict } => Cast { expr: am(expr, f)?, dtype, options: strict },
57            Sort { expr, options } => Sort { expr: am(expr, f)?, options },
58            Gather { expr, idx, returns_scalar } => Gather { expr: am(expr, &mut f)?, idx: am(idx, f)?, returns_scalar },
59            SortBy { expr, by, sort_options } => SortBy { expr: am(expr, &mut f)?, by: by.into_iter().map(f).collect::<Result<_, _>>()?, sort_options },
60            Agg(agg_expr) => Agg(match agg_expr {
61                Min { input, propagate_nans } => Min { input: am(input, f)?, propagate_nans },
62                Max { input, propagate_nans } => Max { input: am(input, f)?, propagate_nans },
63                Median(x) => Median(am(x, f)?),
64                NUnique(x) => NUnique(am(x, f)?),
65                First(x) => First(am(x, f)?),
66                Last(x) => Last(am(x, f)?),
67                Mean(x) => Mean(am(x, f)?),
68                Implode(x) => Implode(am(x, f)?),
69                Count(x, nulls) => Count(am(x, f)?, nulls),
70                Quantile { expr, quantile, method: interpol } => Quantile { expr: am(expr, &mut f)?, quantile: am(quantile, f)?, method: interpol },
71                Sum(x) => Sum(am(x, f)?),
72                AggGroups(x) => AggGroups(am(x, f)?),
73                Std(x, ddf) => Std(am(x, f)?, ddf),
74                Var(x, ddf) => Var(am(x, f)?, ddf),
75            }),
76            Ternary { predicate, truthy, falsy } => Ternary { predicate: am(predicate, &mut f)?, truthy: am(truthy, &mut f)?, falsy: am(falsy, f)? },
77            Function { input, function, options } => Function { input: input.into_iter().map(f).collect::<Result<_, _>>()?, function, options },
78            Explode { input, skip_empty } => Explode { input: am(input, f)?, skip_empty },
79            Filter { input, by } => Filter { input: am(input, &mut f)?, by: am(by, f)? },
80            Window { function, partition_by, order_by, options } => {
81                let partition_by = partition_by.into_iter().map(&mut f).collect::<Result<_, _>>()?;
82                Window { function: am(function, f)?, partition_by, order_by, options }
83            },
84            Wildcard => Wildcard,
85            Slice { input, offset, length } => Slice { input: am(input, &mut f)?, offset: am(offset, &mut f)?, length: am(length, f)? },
86            Exclude(expr, excluded) => Exclude(am(expr, f)?, excluded),
87            KeepName(expr) => KeepName(am(expr, f)?),
88            Len => Len,
89            Nth(_) => self,
90            RenameAlias { function, expr } => RenameAlias { function, expr: am(expr, f)? },
91            AnonymousFunction { input, function, output_type, options } => {
92                AnonymousFunction { input: input.into_iter().map(f).collect::<Result<_, _>>()?, function, output_type, options }
93            },
94            SubPlan(_, _) => self,
95            Selector(_) => self,
96        };
97        Ok(ret)
98    }
99}
100
101#[derive(Copy, Clone, Debug)]
102pub struct AexprNode {
103    node: Node,
104}
105
106impl AexprNode {
107    pub fn new(node: Node) -> Self {
108        Self { node }
109    }
110
111    /// Get the `Node`.
112    pub fn node(&self) -> Node {
113        self.node
114    }
115
116    pub fn to_aexpr<'a>(&self, arena: &'a Arena<AExpr>) -> &'a AExpr {
117        arena.get(self.node)
118    }
119
120    pub fn to_expr(&self, arena: &Arena<AExpr>) -> Expr {
121        node_to_expr(self.node, arena)
122    }
123
124    pub fn to_field(&self, schema: &Schema, arena: &Arena<AExpr>) -> PolarsResult<Field> {
125        let aexpr = arena.get(self.node);
126        aexpr.to_field(schema, Context::Default, arena)
127    }
128
129    pub fn assign(&mut self, ae: AExpr, arena: &mut Arena<AExpr>) {
130        let node = arena.add(ae);
131        self.node = node;
132    }
133
134    #[cfg(feature = "cse")]
135    pub(crate) fn is_leaf(&self, arena: &Arena<AExpr>) -> bool {
136        matches!(self.to_aexpr(arena), AExpr::Column(_) | AExpr::Literal(_))
137    }
138
139    #[cfg(feature = "cse")]
140    pub(crate) fn hashable_and_cmp<'a>(&self, arena: &'a Arena<AExpr>) -> AExprArena<'a> {
141        AExprArena {
142            node: self.node,
143            arena,
144        }
145    }
146}
147
/// A `Node` bundled with a borrow of the arena it is looked up in.
#[cfg(feature = "cse")]
pub struct AExprArena<'a> {
    node: Node,
    arena: &'a Arena<AExpr>,
}

#[cfg(feature = "cse")]
impl Debug for AExprArena<'_> {
    /// Prints only the node's inner index, not the expression it refers to.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let index = &self.node.0;
        write!(f, "AexprArena: {index}")
    }
}
160
161impl AExpr {
162    #[cfg(feature = "cse")]
163    fn is_equal_node(&self, other: &Self) -> bool {
164        use AExpr::*;
165        match (self, other) {
166            (Alias(_, l), Alias(_, r)) => l == r,
167            (Column(l), Column(r)) => l == r,
168            (Literal(l), Literal(r)) => l == r,
169            (Window { options: l, .. }, Window { options: r, .. }) => l == r,
170            (
171                Cast {
172                    options: strict_l,
173                    dtype: dtl,
174                    ..
175                },
176                Cast {
177                    options: strict_r,
178                    dtype: dtr,
179                    ..
180                },
181            ) => strict_l == strict_r && dtl == dtr,
182            (Sort { options: l, .. }, Sort { options: r, .. }) => l == r,
183            (Gather { .. }, Gather { .. })
184            | (Filter { .. }, Filter { .. })
185            | (Ternary { .. }, Ternary { .. })
186            | (Len, Len)
187            | (Slice { .. }, Slice { .. }) => true,
188            (
189                Explode {
190                    expr: _,
191                    skip_empty: l_skip_empty,
192                },
193                Explode {
194                    expr: _,
195                    skip_empty: r_skip_empty,
196                },
197            ) => l_skip_empty == r_skip_empty,
198            (
199                SortBy {
200                    sort_options: l_sort_options,
201                    ..
202                },
203                SortBy {
204                    sort_options: r_sort_options,
205                    ..
206                },
207            ) => l_sort_options == r_sort_options,
208            (Agg(l), Agg(r)) => l.equal_nodes(r),
209            (
210                Function {
211                    input: il,
212                    function: fl,
213                    options: ol,
214                },
215                Function {
216                    input: ir,
217                    function: fr,
218                    options: or,
219                },
220            ) => {
221                fl == fr && ol == or && {
222                    let mut all_same_name = true;
223                    for (l, r) in il.iter().zip(ir) {
224                        all_same_name &= l.output_name() == r.output_name()
225                    }
226
227                    all_same_name
228                }
229            },
230            (AnonymousFunction { .. }, AnonymousFunction { .. }) => false,
231            (BinaryExpr { op: l, .. }, BinaryExpr { op: r, .. }) => l == r,
232            _ => false,
233        }
234    }
235}
236
#[cfg(feature = "cse")]
impl<'a> AExprArena<'a> {
    /// Bundles `node` with the arena it is looked up in.
    pub fn new(node: Node, arena: &'a Arena<AExpr>) -> Self {
        AExprArena { node, arena }
    }

    /// Borrows the `AExpr` stored for this node in the bundled arena.
    pub fn to_aexpr(&self) -> &'a AExpr {
        let arena: &'a Arena<AExpr> = self.arena;
        arena.get(self.node)
    }

    /// Compares only this node with `other` via `AExpr::is_equal_node`.
    pub fn is_equal_single(&self, other: &Self) -> bool {
        self.to_aexpr().is_equal_node(other.to_aexpr())
    }
}
253
#[cfg(feature = "cse")]
impl PartialEq for AExprArena<'_> {
    /// Compares two expression trees node by node without recursion.
    ///
    /// Both trees are walked in lockstep using two explicit stacks: each popped pair is
    /// compared with `is_equal_single`, then the inputs of both nodes are pushed via
    /// `inputs_rev`. The trees are equal when every pair matched and both stacks run empty
    /// at the same time.
    fn eq(&self, other: &Self) -> bool {
        let mut scratch1 = unitvec![];
        let mut scratch2 = unitvec![];

        scratch1.push(self.node);
        scratch2.push(other.node);

        loop {
            match (scratch1.pop(), scratch2.pop()) {
                (Some(l), Some(r)) => {
                    let l = Self::new(l, self.arena);
                    // Resolve the right-hand node in `other`'s arena; looking it up in
                    // `self.arena` is only correct when both operands share one arena.
                    let r = Self::new(r, other.arena);

                    if !l.is_equal_single(&r) {
                        return false;
                    }

                    l.to_aexpr().inputs_rev(&mut scratch1);
                    r.to_aexpr().inputs_rev(&mut scratch2);
                },
                (None, None) => return true,
                // One tree still has nodes left while the other is exhausted.
                _ => return false,
            }
        }
    }
}
282
283impl TreeWalker for AexprNode {
284    type Arena = Arena<AExpr>;
285    fn apply_children<F: FnMut(&Self, &Self::Arena) -> PolarsResult<VisitRecursion>>(
286        &self,
287        op: &mut F,
288        arena: &Self::Arena,
289    ) -> PolarsResult<VisitRecursion> {
290        let mut scratch = unitvec![];
291
292        self.to_aexpr(arena).inputs_rev(&mut scratch);
293        for node in scratch.as_slice() {
294            let aenode = AexprNode::new(*node);
295            match op(&aenode, arena)? {
296                // let the recursion continue
297                VisitRecursion::Continue | VisitRecursion::Skip => {},
298                // early stop
299                VisitRecursion::Stop => return Ok(VisitRecursion::Stop),
300            }
301        }
302        Ok(VisitRecursion::Continue)
303    }
304
305    fn map_children<F: FnMut(Self, &mut Self::Arena) -> PolarsResult<Self>>(
306        mut self,
307        op: &mut F,
308        arena: &mut Self::Arena,
309    ) -> PolarsResult<Self> {
310        let mut scratch = unitvec![];
311
312        let ae = arena.get(self.node).clone();
313        ae.inputs_rev(&mut scratch);
314
315        // rewrite the nodes
316        for node in scratch.as_mut_slice() {
317            let aenode = AexprNode::new(*node);
318            *node = op(aenode, arena)?.node;
319        }
320
321        scratch.as_mut_slice().reverse();
322        let ae = ae.replace_inputs(&scratch);
323        self.node = arena.add(ae);
324        Ok(self)
325    }
326}