robin_sparkless/dataframe/aggregations.rs

//! GroupBy and aggregation operations.

use super::DataFrame;
use polars::prelude::{
    col, len, lit, when, DataFrame as PlDataFrame, DataType, Expr, LazyGroupBy, NamedFrom,
    PolarsError, Series,
};

/// GroupedData - represents a DataFrame grouped by certain columns.
/// Similar to PySpark's `GroupedData`.
pub struct GroupedData {
    // Underlying Polars DataFrame (before grouping). Used by some Python-only paths
    // (e.g. grouped vectorized UDF execution). When the `pyo3` feature is not
    // enabled this field is effectively unused, so we allow dead_code there.
    #[cfg_attr(not(feature = "pyo3"), allow(dead_code))]
    pub(crate) df: PlDataFrame,
    pub(crate) lazy_grouped: LazyGroupBy,
    pub(crate) grouping_cols: Vec<String>,
    pub(crate) case_sensitive: bool,
}

impl GroupedData {
    /// Count rows in each group
    pub fn count(&self) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let agg_expr = vec![len().alias("count")];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Sum a column in each group
    pub fn sum(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let agg_expr = vec![col(column).sum().alias(format!("sum({column})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Average (mean) of a column in each group
    pub fn avg(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let agg_expr = vec![col(column).mean().alias(format!("avg({column})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Minimum value of a column in each group
    pub fn min(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let agg_expr = vec![col(column).min().alias(format!("min({column})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Maximum value of a column in each group
    pub fn max(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let agg_expr = vec![col(column).max().alias(format!("max({column})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// First value of a column in each group (order not guaranteed unless explicitly sorted).
    pub fn first(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let agg_expr = vec![col(column).first().alias(format!("first({column})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Last value of a column in each group (order not guaranteed unless explicitly sorted).
    pub fn last(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let agg_expr = vec![col(column).last().alias(format!("last({column})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Approximate count of distinct values in each group (PySpark approx_count_distinct).
    /// Implemented with an exact n_unique, so results match count_distinct.
    pub fn approx_count_distinct(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::{col, DataType};
        let agg_expr = vec![col(column)
            .n_unique()
            .cast(DataType::Int64)
            .alias(format!("approx_count_distinct({column})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Any value from the group (PySpark any_value). Uses the first value.
    pub fn any_value(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let agg_expr = vec![col(column).first().alias(format!("any_value({column})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Boolean AND across group (PySpark bool_and / every).
    pub fn bool_and(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        // The `true` flag is polars' ignore_nulls: nulls don't poison the result.
        let agg_expr = vec![col(column).all(true).alias(format!("bool_and({column})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Boolean OR across group (PySpark bool_or / some).
    pub fn bool_or(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        // As in bool_and, `true` means ignore_nulls.
        let agg_expr = vec![col(column).any(true).alias(format!("bool_or({column})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Product of column values in each group (PySpark product).
    pub fn product(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let agg_expr = vec![col(column).product().alias(format!("product({column})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Collect column values into a list per group (PySpark collect_list).
    pub fn collect_list(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        // implode() gathers each group's values into a single list value.
        let agg_expr = vec![col(column)
            .implode()
            .alias(format!("collect_list({column})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Collect distinct column values into a list per group (PySpark collect_set).
    pub fn collect_set(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let agg_expr = vec![col(column)
            .unique()
            .implode()
            .alias(format!("collect_set({column})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Count rows where the condition column is true (PySpark count_if).
    pub fn count_if(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        // Cast booleans to Int64 and sum: true counts as 1, false as 0, and
        // nulls are ignored by the sum.
        let agg_expr = vec![col(column)
            .cast(DataType::Int64)
            .sum()
            .alias(format!("count_if({column})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Percentile of a column (PySpark percentile). p in 0.0..=1.0.
    pub fn percentile(&self, column: &str, p: f64) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let agg_expr = vec![col(column)
            .quantile(lit(p), QuantileMethod::Linear)
            .alias(format!("percentile({column}, {p})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Value of value_col where ord_col is maximum (PySpark max_by).
    pub fn max_by(&self, value_col: &str, ord_col: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        // Pack (ord, val) into a struct, sort descending (the leading "_ord"
        // field drives the order), take the first row, then extract "_val".
        let st = as_struct(vec![
            col(ord_col).alias("_ord"),
            col(value_col).alias("_val"),
        ]);
        let agg_expr = vec![st
            .sort(SortOptions::default().with_order_descending(true))
            .first()
            .struct_()
            .field_by_name("_val")
            .alias(format!("max_by({value_col}, {ord_col})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Value of value_col where ord_col is minimum (PySpark min_by).
    pub fn min_by(&self, value_col: &str, ord_col: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        // Same struct-sort trick as max_by, but with the default ascending order.
        let st = as_struct(vec![
            col(ord_col).alias("_ord"),
            col(value_col).alias("_val"),
        ]);
        let agg_expr = vec![st
            .sort(SortOptions::default())
            .first()
            .struct_()
            .field_by_name("_val")
            .alias(format!("min_by({value_col}, {ord_col})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Population covariance between two columns in each group (PySpark covar_pop).
    pub fn covar_pop(&self, col1: &str, col2: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::DataType;
        // covar_pop(a, b) = (sum(a*b) - sum(a) * sum(b) / n) / n
        let c1 = col(col1).cast(DataType::Float64);
        let c2 = col(col2).cast(DataType::Float64);
        let n = len().cast(DataType::Float64);
        let sum_ab = (c1.clone() * c2.clone()).sum();
        let sum_a = col(col1).sum().cast(DataType::Float64);
        let sum_b = col(col2).sum().cast(DataType::Float64);
        let cov = (sum_ab - sum_a * sum_b / n.clone()) / n;
        let agg_expr = vec![cov.alias(format!("covar_pop({col1}, {col2})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Sample covariance between two columns in each group (PySpark covar_samp). ddof=1.
    pub fn covar_samp(&self, col1: &str, col2: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::DataType;
        // Same numerator as covar_pop, but divided by n - 1; NaN when n <= 1.
        let c1 = col(col1).cast(DataType::Float64);
        let c2 = col(col2).cast(DataType::Float64);
        let n = len().cast(DataType::Float64);
        let sum_ab = (c1.clone() * c2.clone()).sum();
        let sum_a = col(col1).sum().cast(DataType::Float64);
        let sum_b = col(col2).sum().cast(DataType::Float64);
        let cov = when(len().gt(lit(1)))
            .then((sum_ab - sum_a * sum_b / n.clone()) / (len() - lit(1)).cast(DataType::Float64))
            .otherwise(lit(f64::NAN));
        let agg_expr = vec![cov.alias(format!("covar_samp({col1}, {col2})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Pearson correlation between two columns in each group (PySpark corr).
    pub fn corr(&self, col1: &str, col2: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::DataType;
        // corr(a, b) = covar_samp(a, b) / (std_samp(a) * std_samp(b)); NaN when n <= 1.
        let c1 = col(col1).cast(DataType::Float64);
        let c2 = col(col2).cast(DataType::Float64);
        let n = len().cast(DataType::Float64);
        let n1 = (len() - lit(1)).cast(DataType::Float64);
        let sum_ab = (c1.clone() * c2.clone()).sum();
        let sum_a = col(col1).sum().cast(DataType::Float64);
        let sum_b = col(col2).sum().cast(DataType::Float64);
        let sum_a2 = (c1.clone() * c1).sum();
        let sum_b2 = (c2.clone() * c2).sum();
        let cov_samp = (sum_ab - sum_a.clone() * sum_b.clone() / n.clone()) / n1.clone();
        let var_a = (sum_a2 - sum_a.clone() * sum_a / n.clone()) / n1.clone();
        let var_b = (sum_b2 - sum_b.clone() * sum_b / n.clone()) / n1.clone();
        let std_a = var_a.sqrt();
        let std_b = var_b.sqrt();
        let corr_expr = when(len().gt(lit(1)))
            .then(cov_samp / (std_a * std_b))
            .otherwise(lit(f64::NAN));
        let agg_expr = vec![corr_expr.alias(format!("corr({col1}, {col2})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Regression count of (y, x) pairs where both are non-null (PySpark regr_count).
    pub fn regr_count(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
        let agg_expr = vec![crate::functions::regr_count_expr(y_col, x_col)
            .alias(format!("regr_count({y_col}, {x_col})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Regression average of x (PySpark regr_avgx).
    pub fn regr_avgx(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
        let agg_expr = vec![crate::functions::regr_avgx_expr(y_col, x_col)
            .alias(format!("regr_avgx({y_col}, {x_col})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Regression average of y (PySpark regr_avgy).
    pub fn regr_avgy(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
        let agg_expr = vec![crate::functions::regr_avgy_expr(y_col, x_col)
            .alias(format!("regr_avgy({y_col}, {x_col})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Regression slope (PySpark regr_slope).
    pub fn regr_slope(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
        let agg_expr = vec![crate::functions::regr_slope_expr(y_col, x_col)
            .alias(format!("regr_slope({y_col}, {x_col})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Regression intercept (PySpark regr_intercept).
    pub fn regr_intercept(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
        let agg_expr = vec![crate::functions::regr_intercept_expr(y_col, x_col)
            .alias(format!("regr_intercept({y_col}, {x_col})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Regression R-squared (PySpark regr_r2).
    pub fn regr_r2(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
        let agg_expr = vec![crate::functions::regr_r2_expr(y_col, x_col)
            .alias(format!("regr_r2({y_col}, {x_col})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Regression sum (x - avg_x)^2 (PySpark regr_sxx).
    pub fn regr_sxx(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
        let agg_expr = vec![crate::functions::regr_sxx_expr(y_col, x_col)
            .alias(format!("regr_sxx({y_col}, {x_col})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Regression sum (y - avg_y)^2 (PySpark regr_syy).
    pub fn regr_syy(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
        let agg_expr = vec![crate::functions::regr_syy_expr(y_col, x_col)
            .alias(format!("regr_syy({y_col}, {x_col})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Regression sum (x - avg_x)(y - avg_y) (PySpark regr_sxy).
    pub fn regr_sxy(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
        let agg_expr = vec![crate::functions::regr_sxy_expr(y_col, x_col)
            .alias(format!("regr_sxy({y_col}, {x_col})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Kurtosis of a column in each group (PySpark kurtosis). Fisher definition, bias=true.
    pub fn kurtosis(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let agg_expr = vec![col(column)
            .cast(DataType::Float64)
            .kurtosis(true, true)
            .alias(format!("kurtosis({column})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Skewness of a column in each group (PySpark skewness). bias=true.
    pub fn skewness(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let agg_expr = vec![col(column)
            .cast(DataType::Float64)
            .skew(true)
            .alias(format!("skewness({column})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Apply multiple aggregations at once (generic agg method)
    pub fn agg(&self, aggregations: Vec<Expr>) -> Result<DataFrame, PolarsError> {
        let lf = self.lazy_grouped.clone().agg(aggregations);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Get grouping columns
    pub fn grouping_columns(&self) -> &[String] {
        &self.grouping_cols
    }
}

/// Cube/rollup: evaluate multiple grouping sets, then union the results (PySpark cube / rollup).
pub struct CubeRollupData {
    pub(super) df: PlDataFrame,
    pub(super) grouping_cols: Vec<String>,
    pub(super) case_sensitive: bool,
    pub(super) is_cube: bool,
}

impl CubeRollupData {
    /// Run aggregation on each grouping set and union results. Missing keys become null.
    pub fn agg(&self, aggregations: Vec<Expr>) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let subsets: Vec<Vec<String>> = if self.is_cube {
            // All subsets of grouping_cols (2^n)
            let n = self.grouping_cols.len();
            (0..1 << n)
                .map(|mask| {
                    self.grouping_cols
                        .iter()
                        .enumerate()
                        .filter(|(i, _)| (mask & (1 << i)) != 0)
                        .map(|(_, c)| c.clone())
                        .collect()
                })
                .collect()
        } else {
            // Prefixes of grouping_cols: [], [c1], [c1, c2], ..., [all]
            (0..=self.grouping_cols.len())
                .map(|len| self.grouping_cols[..len].to_vec())
                .collect()
        };

        let schema = self.df.schema();
        let mut parts: Vec<PlDataFrame> = Vec::with_capacity(subsets.len());
        for subset in subsets {
            if subset.is_empty() {
                // Single row: no grouping keys, one row of aggregates over the full table
                let lf = self.df.clone().lazy().select(aggregations.clone());
                let mut part = lf.collect()?;
                let n = part.height();
                for gc in &self.grouping_cols {
                    let dtype = schema.get(gc).cloned().unwrap_or(DataType::Null);
                    let null_series = null_series_for_dtype(gc.as_str(), n, &dtype)?;
                    part.with_column(null_series)?;
                }
                // Reorder to [grouping_cols..., agg_cols]
                let mut order: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
                for name in part.get_column_names() {
                    if !self.grouping_cols.iter().any(|g| g == name) {
                        order.push(name);
                    }
                }
                part = part.select(order)?;
                parts.push(part);
            } else {
                let grouped = self
                    .df
                    .clone()
                    .lazy()
                    .group_by(subset.iter().map(|s| col(s.as_str())).collect::<Vec<_>>());
                let mut part = grouped.agg(aggregations.clone()).collect()?;
                part = reorder_groupby_columns(&mut part, &subset)?;
                let n = part.height();
                for gc in &self.grouping_cols {
                    if subset.iter().any(|s| s == gc) {
                        continue;
                    }
                    // Grouping columns missing from this subset become nulls.
                    let dtype = schema.get(gc).cloned().unwrap_or(DataType::Null);
                    let null_series = null_series_for_dtype(gc.as_str(), n, &dtype)?;
                    part.with_column(null_series)?;
                }
                let mut order: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
                for name in part.get_column_names() {
                    if !self.grouping_cols.iter().any(|g| g == name) {
                        order.push(name);
                    }
                }
                part = part.select(order)?;
                parts.push(part);
            }
        }

        if parts.is_empty() {
            return Ok(super::DataFrame::from_polars_with_options(
                PlDataFrame::empty(),
                self.case_sensitive,
            ));
        }
        // Align every part to the first part's column order before the union.
        let first_schema = parts[0].schema();
        let order: Vec<&str> = first_schema.iter_names().map(|s| s.as_str()).collect();
        for p in parts.iter_mut().skip(1) {
            *p = p.select(order.clone())?;
        }
        let lazy_frames: Vec<_> = parts.into_iter().map(|p| p.lazy()).collect();
        let out = polars::prelude::concat(lazy_frames, UnionArgs::default())?.collect()?;
        Ok(super::DataFrame::from_polars_with_options(
            out,
            self.case_sensitive,
        ))
    }
}

fn null_series_for_dtype(name: &str, n: usize, dtype: &DataType) -> Result<Series, PolarsError> {
    let name = name.into();
    let s = match dtype {
        DataType::Int32 => Series::new(name, vec![None::<i32>; n]),
        DataType::Int64 => Series::new(name, vec![None::<i64>; n]),
        DataType::Float32 => Series::new(name, vec![None::<f32>; n]),
        DataType::Float64 => Series::new(name, vec![None::<f64>; n]),
        DataType::String => {
            let v: Vec<Option<String>> = (0..n).map(|_| None).collect();
            Series::new(name, v)
        }
        DataType::Boolean => Series::new(name, vec![None::<bool>; n]),
        DataType::Date => Series::new(name, vec![None::<i32>; n]).cast(dtype)?,
        DataType::Datetime(_, _) => Series::new(name, vec![None::<i64>; n]).cast(dtype)?,
        // Fallback: build Int64 nulls, then cast to the requested dtype.
        _ => Series::new(name, vec![None::<i64>; n]).cast(dtype)?,
    };
    Ok(s)
}

/// Reorder columns after groupBy to match PySpark order: grouping columns first, then aggregations
pub(super) fn reorder_groupby_columns(
    pl_df: &mut PlDataFrame,
    grouping_cols: &[String],
) -> Result<PlDataFrame, PolarsError> {
    let all_cols: Vec<String> = pl_df
        .get_column_names()
        .iter()
        .map(|s| s.to_string())
        .collect();
    let mut reordered_cols: Vec<&str> = Vec::new();
    for gc in grouping_cols {
        if all_cols.iter().any(|c| c == gc) {
            reordered_cols.push(gc);
        }
    }
    for col_name in &all_cols {
        if !grouping_cols.iter().any(|gc| gc == col_name) {
            reordered_cols.push(col_name);
        }
    }
    // Only reorder when every column is accounted for; otherwise return the
    // frame unchanged rather than silently dropping columns.
    if !reordered_cols.is_empty() && reordered_cols.len() == all_cols.len() {
        pl_df.select(reordered_cols)
    } else {
        Ok(pl_df.clone())
    }
}

#[cfg(test)]
mod tests {
    use crate::{DataFrame, SparkSession};

    fn test_df() -> DataFrame {
        let spark = SparkSession::builder()
            .app_name("agg_tests")
            .get_or_create();
        let tuples = vec![
            (1i64, 10i64, "a".to_string()),
            (1i64, 20i64, "a".to_string()),
            (2i64, 30i64, "b".to_string()),
        ];
        spark
            .create_dataframe(tuples, vec!["k", "v", "label"])
            .unwrap()
    }
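
    // A hedged sketch of the generic `agg` escape hatch: boolean reductions
    // can be built from polars `Expr`s on the fly (here `v > 15` per group),
    // mirroring what bool_and/bool_or do for existing boolean columns.
    // Assumes only the `test_df` fixture above.
    #[test]
    fn group_by_agg_boolean_exprs() {
        use polars::prelude::*;
        let df = test_df();
        let grouped = df.group_by(vec!["k"]).unwrap();
        let out = grouped
            .agg(vec![
                col("v").gt(lit(15)).all(true).alias("all_big"),
                col("v").gt(lit(15)).any(true).alias("any_big"),
            ])
            .unwrap();
        assert_eq!(out.count().unwrap(), 2);
        let cols = out.columns().unwrap();
        assert!(cols.contains(&"all_big".to_string()));
        assert!(cols.contains(&"any_big".to_string()));
    }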

    #[test]
    fn group_by_count_single_group() {
        let df = test_df();
        let grouped = df.group_by(vec!["k"]).unwrap();
        let out = grouped.count().unwrap();
        assert_eq!(out.count().unwrap(), 2);
        let cols = out.columns().unwrap();
        assert!(cols.contains(&"k".to_string()));
        assert!(cols.contains(&"count".to_string()));
    }
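
    // A hedged sketch covering the column-naming convention of the simple
    // single-column aggregations (`avg(v)`, `min(v)`, `max(v)`); values are
    // not inspected, only shape and names. Assumes the `test_df` fixture.
    #[test]
    fn group_by_avg_min_max_naming() {
        let df = test_df();
        let grouped = df.group_by(vec!["k"]).unwrap();
        let avg_out = grouped.avg("v").unwrap();
        assert_eq!(avg_out.count().unwrap(), 2);
        assert!(avg_out.columns().unwrap().contains(&"avg(v)".to_string()));
        let min_out = grouped.min("v").unwrap();
        assert!(min_out.columns().unwrap().contains(&"min(v)".to_string()));
        let max_out = grouped.max("v").unwrap();
        assert!(max_out.columns().unwrap().contains(&"max(v)".to_string()));
    }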

    #[test]
    fn group_by_sum() {
        let df = test_df();
        let grouped = df.group_by(vec!["k"]).unwrap();
        let out = grouped.sum("v").unwrap();
        assert_eq!(out.count().unwrap(), 2);
        let cols = out.columns().unwrap();
        assert!(cols.iter().any(|c| c.starts_with("sum(")));
    }
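
    // A hedged sketch of the list collectors: collect_list keeps duplicates,
    // collect_set dedupes via unique() before imploding. Only row counts and
    // column names are asserted. Assumes the `test_df` fixture.
    #[test]
    fn group_by_collect_list_and_set() {
        let df = test_df();
        let grouped = df.group_by(vec!["k"]).unwrap();
        let lst = grouped.collect_list("v").unwrap();
        assert_eq!(lst.count().unwrap(), 2);
        assert!(lst
            .columns()
            .unwrap()
            .contains(&"collect_list(v)".to_string()));
        let set = grouped.collect_set("v").unwrap();
        assert!(set
            .columns()
            .unwrap()
            .contains(&"collect_set(v)".to_string()));
    }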

    #[test]
    fn group_by_empty_groups() {
        let spark = SparkSession::builder()
            .app_name("agg_tests")
            .get_or_create();
        let tuples: Vec<(i64, i64, String)> = vec![];
        let df = spark.create_dataframe(tuples, vec!["a", "b", "c"]).unwrap();
        let grouped = df.group_by(vec!["a"]).unwrap();
        let out = grouped.count().unwrap();
        assert_eq!(out.count().unwrap(), 0);
    }
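
    // A hedged sketch of the struct-sort trick behind max_by/min_by: for group
    // k=1 the rows are (v=10, label="a") and (v=20, label="a"), so max_by picks
    // the label at the larger v. Only shape and names are asserted here.
    #[test]
    fn group_by_max_by_min_by_naming() {
        let df = test_df();
        let grouped = df.group_by(vec!["k"]).unwrap();
        let hi = grouped.max_by("label", "v").unwrap();
        assert_eq!(hi.count().unwrap(), 2);
        assert!(hi
            .columns()
            .unwrap()
            .contains(&"max_by(label, v)".to_string()));
        let lo = grouped.min_by("label", "v").unwrap();
        assert!(lo
            .columns()
            .unwrap()
            .contains(&"min_by(label, v)".to_string()));
    }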

    #[test]
    fn group_by_agg_multi() {
        use polars::prelude::*;
        let df = test_df();
        let grouped = df.group_by(vec!["k"]).unwrap();
        let out = grouped
            .agg(vec![len().alias("cnt"), col("v").sum().alias("total")])
            .unwrap();
        assert_eq!(out.count().unwrap(), 2);
        let cols = out.columns().unwrap();
        assert!(cols.contains(&"k".to_string()));
        assert!(cols.contains(&"cnt".to_string()));
        assert!(cols.contains(&"total".to_string()));
    }
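
    // Hedged sketches for the statistical helpers and the private utilities in
    // this module; column names follow the `func(args)` convention used above.
    #[test]
    fn group_by_percentile_and_covar_naming() {
        let df = test_df();
        let grouped = df.group_by(vec!["k"]).unwrap();
        let pct = grouped.percentile("v", 0.5).unwrap();
        assert_eq!(pct.count().unwrap(), 2);
        assert!(pct
            .columns()
            .unwrap()
            .contains(&"percentile(v, 0.5)".to_string()));
        let cv = grouped.covar_pop("v", "v").unwrap();
        assert!(cv
            .columns()
            .unwrap()
            .contains(&"covar_pop(v, v)".to_string()));
    }

    // null_series_for_dtype should hand back an all-null series of the
    // requested length and dtype; this exercises the non-cast Int64 arm.
    #[test]
    fn null_series_matches_requested_dtype() {
        use polars::prelude::DataType;
        let s = super::null_series_for_dtype("g", 3, &DataType::Int64).unwrap();
        assert_eq!(s.len(), 3);
        assert_eq!(s.null_count(), 3);
        assert_eq!(s.dtype(), &DataType::Int64);
    }

    // A standalone sketch of the grouping-set enumeration in
    // `CubeRollupData::agg`: cube yields all 2^n subsets, rollup the n + 1
    // prefixes. This mirrors the bitmask/prefix logic on plain strings.
    #[test]
    fn cube_and_rollup_grouping_set_counts() {
        let cols = vec!["a".to_string(), "b".to_string()];
        let n = cols.len();
        let cube_sets: Vec<Vec<String>> = (0..1usize << n)
            .map(|mask| {
                cols.iter()
                    .enumerate()
                    .filter(|(i, _)| (mask & (1 << i)) != 0)
                    .map(|(_, c)| c.clone())
                    .collect()
            })
            .collect();
        assert_eq!(cube_sets.len(), 4); // {}, {a}, {b}, {a, b}
        let rollup_sets: Vec<Vec<String>> = (0..=n).map(|len| cols[..len].to_vec()).collect();
        assert_eq!(rollup_sets.len(), 3); // {}, {a}, {a, b}
    }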
}