1use polars_plan::prelude::*;
3use polars_utils::arena::{Arena, Node};
4
5use super::*;
6use crate::reduce::any_all::{new_all_reduction, new_any_reduction};
7#[cfg(feature = "approx_unique")]
8use crate::reduce::approx_n_unique::new_approx_n_unique_reduction;
9#[cfg(feature = "bitwise")]
10use crate::reduce::bitwise::{
11 new_bitwise_and_reduction, new_bitwise_or_reduction, new_bitwise_xor_reduction,
12};
13use crate::reduce::count::{CountReduce, NullCountReduce};
14#[cfg(feature = "cov")]
15use crate::reduce::cov::{new_cov_reduction, new_pearson_corr_reduction};
16use crate::reduce::first_last::{new_first_reduction, new_item_reduction, new_last_reduction};
17use crate::reduce::first_last_nonnull::{new_first_nonnull_reduction, new_last_nonnull_reduction};
18use crate::reduce::has_nulls::HasNullsReduce;
19use crate::reduce::implode::new_unordered_implode_reduction;
20use crate::reduce::is_empty::IsEmptyReduce;
21use crate::reduce::mean::new_mean_reduction;
22use crate::reduce::min_max::{new_max_reduction, new_min_reduction};
23use crate::reduce::min_max_by::{new_max_by_reduction, new_min_by_reduction};
24#[cfg(feature = "moment")]
25use crate::reduce::skew_kurtosis::{new_kurtosis_reduction, new_skew_reduction};
26use crate::reduce::sum::new_sum_reduction;
27use crate::reduce::var_std::new_var_std_reduction;
28
29pub fn into_reduction(
31 node: Node,
32 expr_arena: &mut Arena<AExpr>,
33 schema: &Schema,
34 is_aggregation_context: bool,
35) -> PolarsResult<(Box<dyn GroupedReduction>, Vec<Node>)> {
36 let get_dt = |node| {
37 expr_arena
38 .get(node)
39 .to_dtype(&ToFieldContext::new(expr_arena, schema))?
40 .materialize_unknown(false)
41 };
42 let (gr, in_node) = match expr_arena.get(node) {
43 AExpr::Agg(agg) => match agg {
44 IRAggExpr::Sum(input) => (new_sum_reduction(get_dt(*input)?)?, *input),
45 IRAggExpr::Mean(input) => (new_mean_reduction(get_dt(*input)?)?, *input),
46 IRAggExpr::Min {
47 propagate_nans,
48 input,
49 } => (new_min_reduction(get_dt(*input)?, *propagate_nans)?, *input),
50 IRAggExpr::Max {
51 propagate_nans,
52 input,
53 } => (new_max_reduction(get_dt(*input)?, *propagate_nans)?, *input),
54 IRAggExpr::Var(input, ddof) => (
55 new_var_std_reduction(get_dt(*input)?, false, *ddof)?,
56 *input,
57 ),
58 IRAggExpr::Std(input, ddof) => {
59 (new_var_std_reduction(get_dt(*input)?, true, *ddof)?, *input)
60 },
61 IRAggExpr::First(input) => (new_first_reduction(get_dt(*input)?), *input),
62 IRAggExpr::FirstNonNull(input) => {
63 (new_first_nonnull_reduction(get_dt(*input)?), *input)
64 },
65 IRAggExpr::Last(input) => (new_last_reduction(get_dt(*input)?), *input),
66 IRAggExpr::LastNonNull(input) => (new_last_nonnull_reduction(get_dt(*input)?), *input),
67 IRAggExpr::Item { input, allow_empty } => {
68 (new_item_reduction(get_dt(*input)?, *allow_empty), *input)
69 },
70 IRAggExpr::Count {
71 input,
72 include_nulls,
73 } => {
74 let count = Box::new(CountReduce::new(*include_nulls)) as Box<_>;
75 (count, *input)
76 },
77 IRAggExpr::Implode {
78 input,
79 maintain_order: false,
80 } => (new_unordered_implode_reduction(get_dt(*input)?), *input),
81 IRAggExpr::Median(_) => todo!(),
82 IRAggExpr::NUnique(_) => todo!(),
83 IRAggExpr::Implode { .. } => todo!(),
84 IRAggExpr::AggGroups(_) => todo!(),
85 },
86 AExpr::Len => {
87 if let Some(first_column) = schema.iter_names().next() {
88 let out: Box<dyn GroupedReduction> = Box::new(CountReduce::new(true));
89 let expr = expr_arena.add(AExpr::Column(first_column.as_str().into()));
90
91 (out, expr)
92 } else {
93 polars_ensure!(
100 !is_aggregation_context,
101 ComputeError:
102 "not implemented: len() of groups with no columns"
103 );
104
105 let out: Box<dyn GroupedReduction> = new_sum_reduction(DataType::IDX_DTYPE)?;
106 let expr = expr_arena.add(AExpr::Len);
107
108 (out, expr)
109 }
110 },
111
112 AExpr::Function {
113 input: inner_exprs,
114 function: IRFunctionExpr::NullCount,
115 options: _,
116 } => {
117 assert!(inner_exprs.len() == 1);
118 let input = inner_exprs[0].node();
119 let count = Box::new(NullCountReduce::new()) as Box<_>;
120 (count, input)
121 },
122
123 #[cfg(feature = "approx_unique")]
124 AExpr::Function {
125 input: inner_exprs,
126 function: IRFunctionExpr::ApproxNUnique,
127 options: _,
128 } => {
129 assert!(inner_exprs.len() == 1);
130 let input = inner_exprs[0].node();
131 let out = new_approx_n_unique_reduction(get_dt(input)?)?;
132 (out, input)
133 },
134
135 #[cfg(feature = "bitwise")]
136 AExpr::Function {
137 input: inner_exprs,
138 function: IRFunctionExpr::Bitwise(inner_fn),
139 options: _,
140 } => {
141 assert!(inner_exprs.len() == 1);
142 let input = inner_exprs[0].node();
143 match inner_fn {
144 IRBitwiseFunction::And => (new_bitwise_and_reduction(get_dt(input)?), input),
145 IRBitwiseFunction::Or => (new_bitwise_or_reduction(get_dt(input)?), input),
146 IRBitwiseFunction::Xor => (new_bitwise_xor_reduction(get_dt(input)?), input),
147 _ => unreachable!(),
148 }
149 },
150
151 AExpr::Function {
152 input: inner_exprs,
153 function: IRFunctionExpr::Boolean(inner_fn),
154 options: _,
155 } => {
156 assert!(inner_exprs.len() == 1);
157 let input = inner_exprs[0].node();
158 match inner_fn {
159 IRBooleanFunction::Any { ignore_nulls } => {
160 (new_any_reduction(*ignore_nulls), input)
161 },
162 IRBooleanFunction::All { ignore_nulls } => {
163 (new_all_reduction(*ignore_nulls), input)
164 },
165 IRBooleanFunction::IsEmpty { ignore_nulls } => {
166 let is_empty = Box::new(IsEmptyReduce::new(*ignore_nulls)) as Box<_>;
167 (is_empty, input)
168 },
169 IRBooleanFunction::HasNulls => (Box::new(HasNullsReduce::new()) as Box<_>, input),
170 _ => unreachable!(),
171 }
172 },
173
174 AExpr::Function {
175 input: inner_exprs,
176 function: IRFunctionExpr::MinBy,
177 options: _,
178 } => {
179 assert!(inner_exprs.len() == 2);
180 let input = inner_exprs[0].node();
181 let mut by = inner_exprs[1].node();
182 let input_dtype = get_dt(input)?;
183 let mut by_dtype = get_dt(by)?;
184 if by_dtype.is_nested() {
185 by = AExprBuilder::row_encode(
186 vec![inner_exprs[1].clone()],
187 vec![by_dtype.clone()],
188 RowEncodingVariant::Ordered {
189 descending: None,
190 nulls_last: None,
191 broadcast_nulls: None,
192 },
193 expr_arena,
194 )
195 .node();
196 by_dtype = DataType::BinaryOffset;
197 }
198 let gr = new_min_by_reduction(input_dtype, by_dtype)?;
199 return Ok((gr, vec![input, by]));
200 },
201
202 AExpr::Function {
203 input: inner_exprs,
204 function: IRFunctionExpr::MaxBy,
205 options: _,
206 } => {
207 assert!(inner_exprs.len() == 2);
208 let input = inner_exprs[0].node();
209 let mut by = inner_exprs[1].node();
210 let input_dtype = get_dt(input)?;
211 let mut by_dtype = get_dt(by)?;
212 if by_dtype.is_nested() {
213 by = AExprBuilder::row_encode(
214 vec![inner_exprs[1].clone()],
215 vec![by_dtype.clone()],
216 RowEncodingVariant::Ordered {
217 descending: None,
218 nulls_last: None,
219 broadcast_nulls: None,
220 },
221 expr_arena,
222 )
223 .node();
224 by_dtype = DataType::BinaryOffset;
225 }
226 let gr = new_max_by_reduction(input_dtype, by_dtype)?;
227 return Ok((gr, vec![input, by]));
228 },
229
230 AExpr::AnonymousAgg {
231 input: inner_exprs,
232 fmt_str: _,
233 function,
234 } => {
235 let ann_agg = function.materialize()?;
236 assert!(inner_exprs.len() == 1);
237 let input = inner_exprs[0].node();
238 let reduction = ann_agg.as_any();
239 let reduction = reduction
240 .downcast_ref::<Box<dyn GroupedReduction>>()
241 .unwrap();
242 (reduction.new_empty(), input)
243 },
244
245 #[cfg(feature = "cov")]
246 AExpr::Function {
247 input: inner_exprs,
248 function:
249 IRFunctionExpr::Correlation {
250 method:
251 method @ (polars_plan::plans::IRCorrelationMethod::Covariance(_)
252 | polars_plan::plans::IRCorrelationMethod::Pearson),
253 },
254 options: _,
255 } => {
256 use polars_plan::plans::IRCorrelationMethod;
257 assert!(inner_exprs.len() == 2);
258 let input_x = inner_exprs[0].node();
259 let input_y = inner_exprs[1].node();
260 let dtype_x = get_dt(input_x)?;
261 let dtype_y = get_dt(input_y)?;
262 let gr: Box<dyn GroupedReduction> = match method {
263 IRCorrelationMethod::Covariance(ddof) => {
264 new_cov_reduction(dtype_x, dtype_y, *ddof)?
265 },
266 IRCorrelationMethod::Pearson => new_pearson_corr_reduction(dtype_x, dtype_y)?,
267 _ => unreachable!(),
268 };
269 return Ok((gr, vec![input_x, input_y]));
270 },
271
272 #[cfg(feature = "moment")]
273 AExpr::Function {
274 input: inner_exprs,
275 function: IRFunctionExpr::Skew(bias),
276 options: _,
277 } => {
278 assert!(inner_exprs.len() == 1);
279 let input = inner_exprs[0].node();
280 let out = new_skew_reduction(get_dt(input)?, *bias)?;
281 (out, input)
282 },
283
284 #[cfg(feature = "moment")]
285 AExpr::Function {
286 input: inner_exprs,
287 function: IRFunctionExpr::Kurtosis(fisher, bias),
288 options: _,
289 } => {
290 assert!(inner_exprs.len() == 1);
291 let input = inner_exprs[0].node();
292 let out = new_kurtosis_reduction(get_dt(input)?, *fisher, *bias)?;
293 (out, input)
294 },
295
296 _ => unreachable!(),
297 };
298 Ok((gr, vec![in_node]))
299}