1use super::*;
2
3impl AExpr {
4 pub fn inputs_rev<E>(&self, container: &mut E)
7 where
8 E: Extend<Node>,
9 {
10 use AExpr::*;
11
12 match self {
13 Column(_) | Literal(_) | Len => {},
14 BinaryExpr { left, op: _, right } => {
15 container.extend([*right, *left]);
16 },
17 Cast { expr, .. } => container.extend([*expr]),
18 Sort { expr, .. } => container.extend([*expr]),
19 Gather { expr, idx, .. } => {
20 container.extend([*idx, *expr]);
21 },
22 SortBy { expr, by, .. } => {
23 container.extend(by.iter().cloned().rev());
24 container.extend([*expr]);
25 },
26 Filter { input, by } => {
27 container.extend([*by, *input]);
28 },
29 Agg(agg_e) => match agg_e.get_input() {
30 NodeInputs::Single(node) => container.extend([node]),
31 NodeInputs::Many(nodes) => container.extend(nodes.into_iter().rev()),
32 NodeInputs::Leaf => {},
33 },
34 Ternary {
35 truthy,
36 falsy,
37 predicate,
38 } => {
39 container.extend([*predicate, *falsy, *truthy]);
40 },
41 AnonymousFunction { input, .. } | Function { input, .. } => {
42 container.extend(input.iter().rev().map(|e| e.node()))
43 },
44 Explode { expr: e, .. } => container.extend([*e]),
45 Window {
46 function,
47 partition_by,
48 order_by,
49 options: _,
50 } => {
51 if let Some((n, _)) = order_by {
52 container.extend([*n]);
53 }
54 container.extend(partition_by.iter().rev().cloned());
55 container.extend([*function]);
56 },
57 Eval {
58 expr,
59 evaluation,
60 variant: _,
61 } => {
62 _ = evaluation;
64 container.extend([*expr]);
65 },
66 Slice {
67 input,
68 offset,
69 length,
70 } => {
71 container.extend([*length, *offset, *input]);
72 },
73 }
74 }
75
76 pub fn replace_inputs(mut self, inputs: &[Node]) -> Self {
77 use AExpr::*;
78 let input = match &mut self {
79 Column(_) | Literal(_) | Len => return self,
80 Cast { expr, .. } => expr,
81 Explode { expr, .. } => expr,
82 BinaryExpr { left, right, .. } => {
83 *left = inputs[0];
84 *right = inputs[1];
85 return self;
86 },
87 Gather { expr, idx, .. } => {
88 *expr = inputs[0];
89 *idx = inputs[1];
90 return self;
91 },
92 Sort { expr, .. } => expr,
93 SortBy { expr, by, .. } => {
94 *expr = inputs[0];
95 by.clear();
96 by.extend_from_slice(&inputs[1..]);
97 return self;
98 },
99 Filter { input, by, .. } => {
100 *input = inputs[0];
101 *by = inputs[1];
102 return self;
103 },
104 Agg(a) => {
105 match a {
106 IRAggExpr::Quantile { expr, quantile, .. } => {
107 *expr = inputs[0];
108 *quantile = inputs[1];
109 },
110 _ => {
111 a.set_input(inputs[0]);
112 },
113 }
114 return self;
115 },
116 Ternary {
117 truthy,
118 falsy,
119 predicate,
120 } => {
121 *truthy = inputs[0];
122 *falsy = inputs[1];
123 *predicate = inputs[2];
124 return self;
125 },
126 AnonymousFunction { input, .. } | Function { input, .. } => {
127 assert_eq!(input.len(), inputs.len());
128 for (e, node) in input.iter_mut().zip(inputs.iter()) {
129 e.set_node(*node);
130 }
131 return self;
132 },
133 Eval {
134 expr,
135 evaluation,
136 variant: _,
137 } => {
138 *expr = inputs[0];
139 _ = evaluation; return self;
141 },
142 Slice {
143 input,
144 offset,
145 length,
146 } => {
147 *input = inputs[0];
148 *offset = inputs[1];
149 *length = inputs[2];
150 return self;
151 },
152 Window {
153 function,
154 partition_by,
155 order_by,
156 ..
157 } => {
158 let offset = order_by.is_some() as usize;
159 *function = inputs[0];
160 partition_by.clear();
161 partition_by.extend_from_slice(&inputs[1..inputs.len() - offset]);
162 if let Some((_, options)) = order_by {
163 *order_by = Some((*inputs.last().unwrap(), *options));
164 }
165 return self;
166 },
167 };
168 *input = inputs[0];
169 self
170 }
171}
172
173impl IRAggExpr {
174 pub fn get_input(&self) -> NodeInputs {
175 use IRAggExpr::*;
176 use NodeInputs::*;
177 match self {
178 Min { input, .. } => Single(*input),
179 Max { input, .. } => Single(*input),
180 Median(input) => Single(*input),
181 NUnique(input) => Single(*input),
182 First(input) => Single(*input),
183 Last(input) => Single(*input),
184 Mean(input) => Single(*input),
185 Implode(input) => Single(*input),
186 Quantile { expr, quantile, .. } => Many(vec![*expr, *quantile]),
187 Sum(input) => Single(*input),
188 Count(input, _) => Single(*input),
189 Std(input, _) => Single(*input),
190 Var(input, _) => Single(*input),
191 AggGroups(input) => Single(*input),
192 }
193 }
194 pub fn set_input(&mut self, input: Node) {
195 use IRAggExpr::*;
196 let node = match self {
197 Min { input, .. } => input,
198 Max { input, .. } => input,
199 Median(input) => input,
200 NUnique(input) => input,
201 First(input) => input,
202 Last(input) => input,
203 Mean(input) => input,
204 Implode(input) => input,
205 Quantile { expr, .. } => expr,
206 Sum(input) => input,
207 Count(input, _) => input,
208 Std(input, _) => input,
209 Var(input, _) => input,
210 AggGroups(input) => input,
211 };
212 *node = input;
213 }
214}
215
216pub enum NodeInputs {
217 Leaf,
218 Single(Node),
219 Many(Vec<Node>),
220}
221
222impl NodeInputs {
223 pub fn first(&self) -> Node {
224 match self {
225 NodeInputs::Single(node) => *node,
226 NodeInputs::Many(nodes) => nodes[0],
227 NodeInputs::Leaf => panic!(),
228 }
229 }
230}