polars_plan/plans/aexpr/
traverse.rs

1use super::*;
2
3impl AExpr {
4    /// Push the inputs of this node to the given container, in reverse order.
5    /// This ensures the primary node responsible for the name is pushed last.
6    pub fn inputs_rev<E>(&self, container: &mut E)
7    where
8        E: Extend<Node>,
9    {
10        use AExpr::*;
11
12        match self {
13            Column(_) | Literal(_) | Len => {},
14            BinaryExpr { left, op: _, right } => {
15                container.extend([*right, *left]);
16            },
17            Cast { expr, .. } => container.extend([*expr]),
18            Sort { expr, .. } => container.extend([*expr]),
19            Gather { expr, idx, .. } => {
20                container.extend([*idx, *expr]);
21            },
22            SortBy { expr, by, .. } => {
23                container.extend(by.iter().cloned().rev());
24                container.extend([*expr]);
25            },
26            Filter { input, by } => {
27                container.extend([*by, *input]);
28            },
29            Agg(agg_e) => match agg_e.get_input() {
30                NodeInputs::Single(node) => container.extend([node]),
31                NodeInputs::Many(nodes) => container.extend(nodes.into_iter().rev()),
32                NodeInputs::Leaf => {},
33            },
34            Ternary {
35                truthy,
36                falsy,
37                predicate,
38            } => {
39                container.extend([*predicate, *falsy, *truthy]);
40            },
41            AnonymousFunction { input, .. } | Function { input, .. } => {
42                container.extend(input.iter().rev().map(|e| e.node()))
43            },
44            Explode { expr: e, .. } => container.extend([*e]),
45            Window {
46                function,
47                partition_by,
48                order_by,
49                options: _,
50            } => {
51                if let Some((n, _)) = order_by {
52                    container.extend([*n]);
53                }
54                container.extend(partition_by.iter().rev().cloned());
55                container.extend([*function]);
56            },
57            Eval {
58                expr,
59                evaluation,
60                variant: _,
61            } => {
62                // We don't use the evaluation here because it does not contain inputs.
63                _ = evaluation;
64                container.extend([*expr]);
65            },
66            Slice {
67                input,
68                offset,
69                length,
70            } => {
71                container.extend([*length, *offset, *input]);
72            },
73        }
74    }
75
76    pub fn replace_inputs(mut self, inputs: &[Node]) -> Self {
77        use AExpr::*;
78        let input = match &mut self {
79            Column(_) | Literal(_) | Len => return self,
80            Cast { expr, .. } => expr,
81            Explode { expr, .. } => expr,
82            BinaryExpr { left, right, .. } => {
83                *left = inputs[0];
84                *right = inputs[1];
85                return self;
86            },
87            Gather { expr, idx, .. } => {
88                *expr = inputs[0];
89                *idx = inputs[1];
90                return self;
91            },
92            Sort { expr, .. } => expr,
93            SortBy { expr, by, .. } => {
94                *expr = inputs[0];
95                by.clear();
96                by.extend_from_slice(&inputs[1..]);
97                return self;
98            },
99            Filter { input, by, .. } => {
100                *input = inputs[0];
101                *by = inputs[1];
102                return self;
103            },
104            Agg(a) => {
105                match a {
106                    IRAggExpr::Quantile { expr, quantile, .. } => {
107                        *expr = inputs[0];
108                        *quantile = inputs[1];
109                    },
110                    _ => {
111                        a.set_input(inputs[0]);
112                    },
113                }
114                return self;
115            },
116            Ternary {
117                truthy,
118                falsy,
119                predicate,
120            } => {
121                *truthy = inputs[0];
122                *falsy = inputs[1];
123                *predicate = inputs[2];
124                return self;
125            },
126            AnonymousFunction { input, .. } | Function { input, .. } => {
127                assert_eq!(input.len(), inputs.len());
128                for (e, node) in input.iter_mut().zip(inputs.iter()) {
129                    e.set_node(*node);
130                }
131                return self;
132            },
133            Eval {
134                expr,
135                evaluation,
136                variant: _,
137            } => {
138                *expr = inputs[0];
139                _ = evaluation; // Intentional.
140                return self;
141            },
142            Slice {
143                input,
144                offset,
145                length,
146            } => {
147                *input = inputs[0];
148                *offset = inputs[1];
149                *length = inputs[2];
150                return self;
151            },
152            Window {
153                function,
154                partition_by,
155                order_by,
156                ..
157            } => {
158                let offset = order_by.is_some() as usize;
159                *function = inputs[0];
160                partition_by.clear();
161                partition_by.extend_from_slice(&inputs[1..inputs.len() - offset]);
162                if let Some((_, options)) = order_by {
163                    *order_by = Some((*inputs.last().unwrap(), *options));
164                }
165                return self;
166            },
167        };
168        *input = inputs[0];
169        self
170    }
171}
172
173impl IRAggExpr {
174    pub fn get_input(&self) -> NodeInputs {
175        use IRAggExpr::*;
176        use NodeInputs::*;
177        match self {
178            Min { input, .. } => Single(*input),
179            Max { input, .. } => Single(*input),
180            Median(input) => Single(*input),
181            NUnique(input) => Single(*input),
182            First(input) => Single(*input),
183            Last(input) => Single(*input),
184            Mean(input) => Single(*input),
185            Implode(input) => Single(*input),
186            Quantile { expr, quantile, .. } => Many(vec![*expr, *quantile]),
187            Sum(input) => Single(*input),
188            Count(input, _) => Single(*input),
189            Std(input, _) => Single(*input),
190            Var(input, _) => Single(*input),
191            AggGroups(input) => Single(*input),
192        }
193    }
194    pub fn set_input(&mut self, input: Node) {
195        use IRAggExpr::*;
196        let node = match self {
197            Min { input, .. } => input,
198            Max { input, .. } => input,
199            Median(input) => input,
200            NUnique(input) => input,
201            First(input) => input,
202            Last(input) => input,
203            Mean(input) => input,
204            Implode(input) => input,
205            Quantile { expr, .. } => expr,
206            Sum(input) => input,
207            Count(input, _) => input,
208            Std(input, _) => input,
209            Var(input, _) => input,
210            AggGroups(input) => input,
211        };
212        *node = input;
213    }
214}
215
216pub enum NodeInputs {
217    Leaf,
218    Single(Node),
219    Many(Vec<Node>),
220}
221
222impl NodeInputs {
223    pub fn first(&self) -> Node {
224        match self {
225            NodeInputs::Single(node) => *node,
226            NodeInputs::Many(nodes) => nodes[0],
227            NodeInputs::Leaf => panic!(),
228        }
229    }
230}