Skip to main content

polars_expr/reduce/
convert.rs

1// use polars_core::error::feature_gated;
2use polars_plan::prelude::*;
3use polars_utils::arena::{Arena, Node};
4
5use super::*;
6use crate::reduce::any_all::{new_all_reduction, new_any_reduction};
7#[cfg(feature = "approx_unique")]
8use crate::reduce::approx_n_unique::new_approx_n_unique_reduction;
9#[cfg(feature = "bitwise")]
10use crate::reduce::bitwise::{
11    new_bitwise_and_reduction, new_bitwise_or_reduction, new_bitwise_xor_reduction,
12};
13use crate::reduce::count::{CountReduce, NullCountReduce};
14#[cfg(feature = "cov")]
15use crate::reduce::cov::{new_cov_reduction, new_pearson_corr_reduction};
16use crate::reduce::first_last::{new_first_reduction, new_item_reduction, new_last_reduction};
17use crate::reduce::first_last_nonnull::{new_first_nonnull_reduction, new_last_nonnull_reduction};
18use crate::reduce::has_nulls::HasNullsReduce;
19use crate::reduce::implode::new_unordered_implode_reduction;
20use crate::reduce::is_empty::IsEmptyReduce;
21use crate::reduce::mean::new_mean_reduction;
22use crate::reduce::min_max::{new_max_reduction, new_min_reduction};
23use crate::reduce::min_max_by::{new_max_by_reduction, new_min_by_reduction};
24#[cfg(feature = "moment")]
25use crate::reduce::skew_kurtosis::{new_kurtosis_reduction, new_skew_reduction};
26use crate::reduce::sum::new_sum_reduction;
27use crate::reduce::var_std::new_var_std_reduction;
28
29/// Converts a node into a reduction + its associated selector expression.
30pub fn into_reduction(
31    node: Node,
32    expr_arena: &mut Arena<AExpr>,
33    schema: &Schema,
34    is_aggregation_context: bool,
35) -> PolarsResult<(Box<dyn GroupedReduction>, Vec<Node>)> {
36    let get_dt = |node| {
37        expr_arena
38            .get(node)
39            .to_dtype(&ToFieldContext::new(expr_arena, schema))?
40            .materialize_unknown(false)
41    };
42    let (gr, in_node) = match expr_arena.get(node) {
43        AExpr::Agg(agg) => match agg {
44            IRAggExpr::Sum(input) => (new_sum_reduction(get_dt(*input)?)?, *input),
45            IRAggExpr::Mean(input) => (new_mean_reduction(get_dt(*input)?)?, *input),
46            IRAggExpr::Min {
47                propagate_nans,
48                input,
49            } => (new_min_reduction(get_dt(*input)?, *propagate_nans)?, *input),
50            IRAggExpr::Max {
51                propagate_nans,
52                input,
53            } => (new_max_reduction(get_dt(*input)?, *propagate_nans)?, *input),
54            IRAggExpr::Var(input, ddof) => (
55                new_var_std_reduction(get_dt(*input)?, false, *ddof)?,
56                *input,
57            ),
58            IRAggExpr::Std(input, ddof) => {
59                (new_var_std_reduction(get_dt(*input)?, true, *ddof)?, *input)
60            },
61            IRAggExpr::First(input) => (new_first_reduction(get_dt(*input)?), *input),
62            IRAggExpr::FirstNonNull(input) => {
63                (new_first_nonnull_reduction(get_dt(*input)?), *input)
64            },
65            IRAggExpr::Last(input) => (new_last_reduction(get_dt(*input)?), *input),
66            IRAggExpr::LastNonNull(input) => (new_last_nonnull_reduction(get_dt(*input)?), *input),
67            IRAggExpr::Item { input, allow_empty } => {
68                (new_item_reduction(get_dt(*input)?, *allow_empty), *input)
69            },
70            IRAggExpr::Count {
71                input,
72                include_nulls,
73            } => {
74                let count = Box::new(CountReduce::new(*include_nulls)) as Box<_>;
75                (count, *input)
76            },
77            IRAggExpr::Implode {
78                input,
79                maintain_order: false,
80            } => (new_unordered_implode_reduction(get_dt(*input)?), *input),
81            IRAggExpr::Median(_) => todo!(),
82            IRAggExpr::NUnique(_) => todo!(),
83            IRAggExpr::Implode { .. } => todo!(),
84            IRAggExpr::AggGroups(_) => todo!(),
85        },
86        AExpr::Len => {
87            if let Some(first_column) = schema.iter_names().next() {
88                let out: Box<dyn GroupedReduction> = Box::new(CountReduce::new(true));
89                let expr = expr_arena.add(AExpr::Column(first_column.as_str().into()));
90
91                (out, expr)
92            } else {
93                // Support len aggregation on 0-width morsels.
94                // Notes:
95                // * We do this instead of projecting a scalar, because scalar literals don't
96                //   project to the height of the DataFrame (in the PhysicalExpr impl).
97                // * This approach is not sound for `update_groups()`, but currently that case is
98                //   not hit (it would need group-by -> len on empty morsels).
99                polars_ensure!(
100                    !is_aggregation_context,
101                    ComputeError:
102                    "not implemented: len() of groups with no columns"
103                );
104
105                let out: Box<dyn GroupedReduction> = new_sum_reduction(DataType::IDX_DTYPE)?;
106                let expr = expr_arena.add(AExpr::Len);
107
108                (out, expr)
109            }
110        },
111
112        AExpr::Function {
113            input: inner_exprs,
114            function: IRFunctionExpr::NullCount,
115            options: _,
116        } => {
117            assert!(inner_exprs.len() == 1);
118            let input = inner_exprs[0].node();
119            let count = Box::new(NullCountReduce::new()) as Box<_>;
120            (count, input)
121        },
122
123        #[cfg(feature = "approx_unique")]
124        AExpr::Function {
125            input: inner_exprs,
126            function: IRFunctionExpr::ApproxNUnique,
127            options: _,
128        } => {
129            assert!(inner_exprs.len() == 1);
130            let input = inner_exprs[0].node();
131            let out = new_approx_n_unique_reduction(get_dt(input)?)?;
132            (out, input)
133        },
134
135        #[cfg(feature = "bitwise")]
136        AExpr::Function {
137            input: inner_exprs,
138            function: IRFunctionExpr::Bitwise(inner_fn),
139            options: _,
140        } => {
141            assert!(inner_exprs.len() == 1);
142            let input = inner_exprs[0].node();
143            match inner_fn {
144                IRBitwiseFunction::And => (new_bitwise_and_reduction(get_dt(input)?), input),
145                IRBitwiseFunction::Or => (new_bitwise_or_reduction(get_dt(input)?), input),
146                IRBitwiseFunction::Xor => (new_bitwise_xor_reduction(get_dt(input)?), input),
147                _ => unreachable!(),
148            }
149        },
150
151        AExpr::Function {
152            input: inner_exprs,
153            function: IRFunctionExpr::Boolean(inner_fn),
154            options: _,
155        } => {
156            assert!(inner_exprs.len() == 1);
157            let input = inner_exprs[0].node();
158            match inner_fn {
159                IRBooleanFunction::Any { ignore_nulls } => {
160                    (new_any_reduction(*ignore_nulls), input)
161                },
162                IRBooleanFunction::All { ignore_nulls } => {
163                    (new_all_reduction(*ignore_nulls), input)
164                },
165                IRBooleanFunction::IsEmpty { ignore_nulls } => {
166                    let is_empty = Box::new(IsEmptyReduce::new(*ignore_nulls)) as Box<_>;
167                    (is_empty, input)
168                },
169                IRBooleanFunction::HasNulls => (Box::new(HasNullsReduce::new()) as Box<_>, input),
170                _ => unreachable!(),
171            }
172        },
173
174        AExpr::Function {
175            input: inner_exprs,
176            function: IRFunctionExpr::MinBy,
177            options: _,
178        } => {
179            assert!(inner_exprs.len() == 2);
180            let input = inner_exprs[0].node();
181            let mut by = inner_exprs[1].node();
182            let input_dtype = get_dt(input)?;
183            let mut by_dtype = get_dt(by)?;
184            if by_dtype.is_nested() {
185                by = AExprBuilder::row_encode(
186                    vec![inner_exprs[1].clone()],
187                    vec![by_dtype.clone()],
188                    RowEncodingVariant::Ordered {
189                        descending: None,
190                        nulls_last: None,
191                        broadcast_nulls: None,
192                    },
193                    expr_arena,
194                )
195                .node();
196                by_dtype = DataType::BinaryOffset;
197            }
198            let gr = new_min_by_reduction(input_dtype, by_dtype)?;
199            return Ok((gr, vec![input, by]));
200        },
201
202        AExpr::Function {
203            input: inner_exprs,
204            function: IRFunctionExpr::MaxBy,
205            options: _,
206        } => {
207            assert!(inner_exprs.len() == 2);
208            let input = inner_exprs[0].node();
209            let mut by = inner_exprs[1].node();
210            let input_dtype = get_dt(input)?;
211            let mut by_dtype = get_dt(by)?;
212            if by_dtype.is_nested() {
213                by = AExprBuilder::row_encode(
214                    vec![inner_exprs[1].clone()],
215                    vec![by_dtype.clone()],
216                    RowEncodingVariant::Ordered {
217                        descending: None,
218                        nulls_last: None,
219                        broadcast_nulls: None,
220                    },
221                    expr_arena,
222                )
223                .node();
224                by_dtype = DataType::BinaryOffset;
225            }
226            let gr = new_max_by_reduction(input_dtype, by_dtype)?;
227            return Ok((gr, vec![input, by]));
228        },
229
230        AExpr::AnonymousAgg {
231            input: inner_exprs,
232            fmt_str: _,
233            function,
234        } => {
235            let ann_agg = function.materialize()?;
236            assert!(inner_exprs.len() == 1);
237            let input = inner_exprs[0].node();
238            let reduction = ann_agg.as_any();
239            let reduction = reduction
240                .downcast_ref::<Box<dyn GroupedReduction>>()
241                .unwrap();
242            (reduction.new_empty(), input)
243        },
244
245        #[cfg(feature = "cov")]
246        AExpr::Function {
247            input: inner_exprs,
248            function:
249                IRFunctionExpr::Correlation {
250                    method:
251                        method @ (polars_plan::plans::IRCorrelationMethod::Covariance(_)
252                        | polars_plan::plans::IRCorrelationMethod::Pearson),
253                },
254            options: _,
255        } => {
256            use polars_plan::plans::IRCorrelationMethod;
257            assert!(inner_exprs.len() == 2);
258            let input_x = inner_exprs[0].node();
259            let input_y = inner_exprs[1].node();
260            let dtype_x = get_dt(input_x)?;
261            let dtype_y = get_dt(input_y)?;
262            let gr: Box<dyn GroupedReduction> = match method {
263                IRCorrelationMethod::Covariance(ddof) => {
264                    new_cov_reduction(dtype_x, dtype_y, *ddof)?
265                },
266                IRCorrelationMethod::Pearson => new_pearson_corr_reduction(dtype_x, dtype_y)?,
267                _ => unreachable!(),
268            };
269            return Ok((gr, vec![input_x, input_y]));
270        },
271
272        #[cfg(feature = "moment")]
273        AExpr::Function {
274            input: inner_exprs,
275            function: IRFunctionExpr::Skew(bias),
276            options: _,
277        } => {
278            assert!(inner_exprs.len() == 1);
279            let input = inner_exprs[0].node();
280            let out = new_skew_reduction(get_dt(input)?, *bias)?;
281            (out, input)
282        },
283
284        #[cfg(feature = "moment")]
285        AExpr::Function {
286            input: inner_exprs,
287            function: IRFunctionExpr::Kurtosis(fisher, bias),
288            options: _,
289        } => {
290            assert!(inner_exprs.len() == 1);
291            let input = inner_exprs[0].node();
292            let out = new_kurtosis_reduction(get_dt(input)?, *fisher, *bias)?;
293            (out, input)
294        },
295
296        _ => unreachable!(),
297    };
298    Ok((gr, vec![in_node]))
299}