robin_sparkless/dataframe/aggregations.rs

//! GroupBy and aggregation operations.

use super::DataFrame;
use polars::prelude::{
    col, len, lit, when, DataFrame as PlDataFrame, DataType, Expr, LazyGroupBy, NamedFrom,
    PolarsError, Series,
};

/// GroupedData - represents a DataFrame grouped by certain columns.
/// Similar to PySpark's `GroupedData`.
pub struct GroupedData {
    pub(super) lazy_grouped: LazyGroupBy,
    pub(super) grouping_cols: Vec<String>,
    pub(super) case_sensitive: bool,
}
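
// Illustrative usage (hypothetical `df`; mirrors PySpark's
// `df.groupBy("dept").count()` — see the tests at the bottom of this file):
//
//     let grouped = df.group_by(vec!["dept"])?;
//     let counts = grouped.count()?; // columns: ["dept", "count"]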

impl GroupedData {
    /// Count rows in each group
    pub fn count(&self) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let agg_expr = vec![len().alias("count")];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Sum a column in each group
    pub fn sum(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let agg_expr = vec![col(column).sum().alias(format!("sum({column})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Average (mean) of a column in each group
    pub fn avg(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let agg_expr = vec![col(column).mean().alias(format!("avg({column})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Minimum value of a column in each group
    pub fn min(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let agg_expr = vec![col(column).min().alias(format!("min({column})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Maximum value of a column in each group
    pub fn max(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let agg_expr = vec![col(column).max().alias(format!("max({column})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// First value of a column in each group (order not guaranteed unless explicitly sorted).
    pub fn first(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let agg_expr = vec![col(column).first().alias(format!("first({column})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Last value of a column in each group (order not guaranteed unless explicitly sorted).
    pub fn last(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let agg_expr = vec![col(column).last().alias(format!("last({column})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Count of distinct values in each group (PySpark approx_count_distinct).
    /// Uses n_unique, so the result is exact rather than approximate.
    pub fn approx_count_distinct(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::{col, DataType};
        let agg_expr = vec![col(column)
            .n_unique()
            .cast(DataType::Int64)
            .alias(format!("approx_count_distinct({column})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Any value from the group (PySpark any_value). Uses the first value.
    pub fn any_value(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let agg_expr = vec![col(column).first().alias(format!("any_value({column})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Boolean AND across group (PySpark bool_and / every).
    pub fn bool_and(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let agg_expr = vec![col(column).all(true).alias(format!("bool_and({column})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Boolean OR across group (PySpark bool_or / some).
    pub fn bool_or(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let agg_expr = vec![col(column).any(true).alias(format!("bool_or({column})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Product of column values in each group (PySpark product).
    pub fn product(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let agg_expr = vec![col(column).product().alias(format!("product({column})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Collect column values into a list per group (PySpark collect_list).
    pub fn collect_list(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let agg_expr = vec![col(column)
            .implode()
            .alias(format!("collect_list({column})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Collect distinct column values into a list per group (PySpark collect_set).
    pub fn collect_set(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let agg_expr = vec![col(column)
            .unique()
            .implode()
            .alias(format!("collect_set({column})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }
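
    // Worked contrast for the two collectors above: for a group with
    // v = [1, 2, 2], collect_list yields [1, 2, 2] while collect_set yields
    // [1, 2] (element order follows Polars' unique() and is not guaranteed).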

    /// Count rows where condition column is true (PySpark count_if).
    pub fn count_if(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let agg_expr = vec![col(column)
            .cast(DataType::Int64)
            .sum()
            .alias(format!("count_if({column})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }
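
    // Worked check for count_if above: for a boolean group
    // [true, false, true, null], the Int64 cast gives [1, 0, 1, null] and
    // sum() skips nulls, so count_if = 2.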

    /// Percentile of column (PySpark percentile). p in 0.0..=1.0.
    pub fn percentile(&self, column: &str, p: f64) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let agg_expr = vec![col(column)
            .quantile(lit(p), QuantileMethod::Linear)
            .alias(format!("percentile({column}, {p})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }
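
    // Worked check for percentile above: for group values [1, 2, 3, 4] and
    // p = 0.5, linear interpolation gives (2 + 3) / 2 = 2.5.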

    /// Value of value_col where ord_col is maximum (PySpark max_by).
    pub fn max_by(&self, value_col: &str, ord_col: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let st = as_struct(vec![
            col(ord_col).alias("_ord"),
            col(value_col).alias("_val"),
        ]);
        let agg_expr = vec![st
            .sort(SortOptions::default().with_order_descending(true))
            .first()
            .struct_()
            .field_by_name("_val")
            .alias(format!("max_by({value_col}, {ord_col})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }
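
    // Worked check for max_by above: for rows (value, ord) = ("a", 1),
    // ("b", 3), ("c", 2), the struct sorts by its first field "_ord"
    // descending, so first() picks ("b", 3) and the result is "b".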

    /// Value of value_col where ord_col is minimum (PySpark min_by).
    pub fn min_by(&self, value_col: &str, ord_col: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let st = as_struct(vec![
            col(ord_col).alias("_ord"),
            col(value_col).alias("_val"),
        ]);
        let agg_expr = vec![st
            .sort(SortOptions::default())
            .first()
            .struct_()
            .field_by_name("_val")
            .alias(format!("min_by({value_col}, {ord_col})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Population covariance between two columns in each group (PySpark covar_pop).
    pub fn covar_pop(&self, col1: &str, col2: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::DataType;
        let c1 = col(col1).cast(DataType::Float64);
        let c2 = col(col2).cast(DataType::Float64);
        let n = len().cast(DataType::Float64);
        let sum_ab = (c1.clone() * c2.clone()).sum();
        let sum_a = col(col1).sum().cast(DataType::Float64);
        let sum_b = col(col2).sum().cast(DataType::Float64);
        let cov = (sum_ab - sum_a * sum_b / n.clone()) / n;
        let agg_expr = vec![cov.alias(format!("covar_pop({col1}, {col2})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }
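
    // Worked check for covar_pop above: for a group with a = [1, 2] and
    // b = [3, 5]: sum_ab = 1*3 + 2*5 = 13, sum_a = 3, sum_b = 8, n = 2,
    // so covar_pop = (13 - 3*8/2) / 2 = (13 - 12) / 2 = 0.5.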

    /// Sample covariance between two columns in each group (PySpark covar_samp). ddof=1.
    pub fn covar_samp(&self, col1: &str, col2: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::DataType;
        let c1 = col(col1).cast(DataType::Float64);
        let c2 = col(col2).cast(DataType::Float64);
        let n = len().cast(DataType::Float64);
        let sum_ab = (c1.clone() * c2.clone()).sum();
        let sum_a = col(col1).sum().cast(DataType::Float64);
        let sum_b = col(col2).sum().cast(DataType::Float64);
        let cov = when(len().gt(lit(1)))
            .then((sum_ab - sum_a * sum_b / n.clone()) / (len() - lit(1)).cast(DataType::Float64))
            .otherwise(lit(f64::NAN));
        let agg_expr = vec![cov.alias(format!("covar_samp({col1}, {col2})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }
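
    // Same data as the covar_pop check above: a = [1, 2], b = [3, 5] gives
    // covar_samp = (13 - 12) / (2 - 1) = 1.0; a single-row group falls into
    // the otherwise() branch and yields NaN.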

    /// Pearson correlation between two columns in each group (PySpark corr).
    pub fn corr(&self, col1: &str, col2: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::DataType;
        let c1 = col(col1).cast(DataType::Float64);
        let c2 = col(col2).cast(DataType::Float64);
        let n = len().cast(DataType::Float64);
        let n1 = (len() - lit(1)).cast(DataType::Float64);
        let sum_ab = (c1.clone() * c2.clone()).sum();
        let sum_a = col(col1).sum().cast(DataType::Float64);
        let sum_b = col(col2).sum().cast(DataType::Float64);
        let sum_a2 = (c1.clone() * c1).sum();
        let sum_b2 = (c2.clone() * c2).sum();
        let cov_samp = (sum_ab - sum_a.clone() * sum_b.clone() / n.clone()) / n1.clone();
        let var_a = (sum_a2 - sum_a.clone() * sum_a / n.clone()) / n1.clone();
        let var_b = (sum_b2 - sum_b.clone() * sum_b / n.clone()) / n1.clone();
        let std_a = var_a.sqrt();
        let std_b = var_b.sqrt();
        let corr_expr = when(len().gt(lit(1)))
            .then(cov_samp / (std_a * std_b))
            .otherwise(lit(f64::NAN));
        let agg_expr = vec![corr_expr.alias(format!("corr({col1}, {col2})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }
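
    // Worked check for corr above, with a = [1, 2], b = [3, 5]:
    // cov_samp = 1.0, var_a = 0.5, var_b = 2.0, so
    // corr = 1.0 / (sqrt(0.5) * sqrt(2.0)) = 1.0 (b is a perfect linear
    // function of a).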

    /// Regression count of (y, x) pairs where both are non-null (PySpark regr_count).
    pub fn regr_count(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
        let agg_expr = vec![crate::functions::regr_count_expr(y_col, x_col)
            .alias(format!("regr_count({y_col}, {x_col})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Regression average of x (PySpark regr_avgx).
    pub fn regr_avgx(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
        let agg_expr = vec![crate::functions::regr_avgx_expr(y_col, x_col)
            .alias(format!("regr_avgx({y_col}, {x_col})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Regression average of y (PySpark regr_avgy).
    pub fn regr_avgy(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
        let agg_expr = vec![crate::functions::regr_avgy_expr(y_col, x_col)
            .alias(format!("regr_avgy({y_col}, {x_col})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Regression slope (PySpark regr_slope).
    pub fn regr_slope(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
        let agg_expr = vec![crate::functions::regr_slope_expr(y_col, x_col)
            .alias(format!("regr_slope({y_col}, {x_col})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Regression intercept (PySpark regr_intercept).
    pub fn regr_intercept(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
        let agg_expr = vec![crate::functions::regr_intercept_expr(y_col, x_col)
            .alias(format!("regr_intercept({y_col}, {x_col})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Regression R-squared (PySpark regr_r2).
    pub fn regr_r2(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
        let agg_expr = vec![crate::functions::regr_r2_expr(y_col, x_col)
            .alias(format!("regr_r2({y_col}, {x_col})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Regression sum of (x - avg_x)^2 (PySpark regr_sxx).
    pub fn regr_sxx(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
        let agg_expr = vec![crate::functions::regr_sxx_expr(y_col, x_col)
            .alias(format!("regr_sxx({y_col}, {x_col})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Regression sum of (y - avg_y)^2 (PySpark regr_syy).
    pub fn regr_syy(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
        let agg_expr = vec![crate::functions::regr_syy_expr(y_col, x_col)
            .alias(format!("regr_syy({y_col}, {x_col})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Regression sum of (x - avg_x)(y - avg_y) (PySpark regr_sxy).
    pub fn regr_sxy(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
        let agg_expr = vec![crate::functions::regr_sxy_expr(y_col, x_col)
            .alias(format!("regr_sxy({y_col}, {x_col})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Kurtosis of a column in each group (PySpark kurtosis). Fisher definition, bias=true.
    pub fn kurtosis(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let agg_expr = vec![col(column)
            .cast(DataType::Float64)
            .kurtosis(true, true)
            .alias(format!("kurtosis({column})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Skewness of a column in each group (PySpark skewness). bias=true.
    pub fn skewness(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let agg_expr = vec![col(column)
            .cast(DataType::Float64)
            .skew(true)
            .alias(format!("skewness({column})"))];
        let lf = self.lazy_grouped.clone().agg(agg_expr);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Apply multiple aggregations at once (generic agg method).
    pub fn agg(&self, aggregations: Vec<Expr>) -> Result<DataFrame, PolarsError> {
        let lf = self.lazy_grouped.clone().agg(aggregations);
        let mut pl_df = lf.collect()?;
        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
        Ok(super::DataFrame::from_polars_with_options(
            pl_df,
            self.case_sensitive,
        ))
    }

    /// Get the grouping columns.
    pub fn grouping_columns(&self) -> &[String] {
        &self.grouping_cols
    }
}

/// Cube/rollup aggregation: evaluates multiple grouping sets, then unions the
/// results (PySpark cube / rollup).
pub struct CubeRollupData {
    pub(super) df: PlDataFrame,
    pub(super) grouping_cols: Vec<String>,
    pub(super) case_sensitive: bool,
    pub(super) is_cube: bool,
}

impl CubeRollupData {
    /// Run the aggregation on each grouping set and union the results. Missing keys become null.
    pub fn agg(&self, aggregations: Vec<Expr>) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let subsets: Vec<Vec<String>> = if self.is_cube {
            // All subsets of grouping_cols (2^n)
            let n = self.grouping_cols.len();
            (0..1 << n)
                .map(|mask| {
                    self.grouping_cols
                        .iter()
                        .enumerate()
                        .filter(|(i, _)| (mask & (1 << i)) != 0)
                        .map(|(_, c)| c.clone())
                        .collect()
                })
                .collect()
        } else {
            // Prefixes of increasing length: [], [first], ..., [all]
            (0..=self.grouping_cols.len())
                .map(|len| self.grouping_cols[..len].to_vec())
                .collect()
        };
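
        // For grouping_cols = ["a", "b"]: cube emits {}, {a}, {b}, {a, b}
        // (bitmask order), while rollup emits [], ["a"], ["a", "b"]
        // (prefixes of increasing length). The order of the sets does not
        // matter for correctness; the per-set results are concatenated below.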

        let schema = self.df.schema();
        let mut parts: Vec<PlDataFrame> = Vec::with_capacity(subsets.len());
        for subset in subsets {
            if subset.is_empty() {
                // Single row: no grouping keys, one row of aggregates over full table
                let lf = self.df.clone().lazy().select(aggregations.clone());
                let mut part = lf.collect()?;
                let n = part.height();
                for gc in &self.grouping_cols {
                    let dtype = schema.get(gc).cloned().unwrap_or(DataType::Null);
                    let null_series = null_series_for_dtype(gc.as_str(), n, &dtype)?;
                    part.with_column(null_series)?;
                }
                // Reorder to [grouping_cols..., agg_cols]
                let mut order: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
                for name in part.get_column_names() {
                    if !self.grouping_cols.iter().any(|g| g == name) {
                        order.push(name);
                    }
                }
                part = part.select(order)?;
                parts.push(part);
            } else {
                let grouped = self
                    .df
                    .clone()
                    .lazy()
                    .group_by(subset.iter().map(|s| col(s.as_str())).collect::<Vec<_>>());
                let mut part = grouped.agg(aggregations.clone()).collect()?;
                part = reorder_groupby_columns(&mut part, &subset)?;
                let n = part.height();
                for gc in &self.grouping_cols {
                    if subset.iter().any(|s| s == gc) {
                        continue;
                    }
                    let dtype = schema.get(gc).cloned().unwrap_or(DataType::Null);
                    let null_series = null_series_for_dtype(gc.as_str(), n, &dtype)?;
                    part.with_column(null_series)?;
                }
                let mut order: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
                for name in part.get_column_names() {
                    if !self.grouping_cols.iter().any(|g| g == name) {
                        order.push(name);
                    }
                }
                part = part.select(order)?;
                parts.push(part);
            }
        }

        if parts.is_empty() {
            return Ok(super::DataFrame::from_polars_with_options(
                PlDataFrame::empty(),
                self.case_sensitive,
            ));
        }
        let first_schema = parts[0].schema();
        let order: Vec<&str> = first_schema.iter_names().map(|s| s.as_str()).collect();
        for p in parts.iter_mut().skip(1) {
            *p = p.select(order.clone())?;
        }
        let lazy_frames: Vec<_> = parts.into_iter().map(|p| p.lazy()).collect();
        let out = polars::prelude::concat(lazy_frames, UnionArgs::default())?.collect()?;
        Ok(super::DataFrame::from_polars_with_options(
            out,
            self.case_sensitive,
        ))
    }
}
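
// Illustrative usage (hypothetical `df`, and assuming the DataFrame exposes
// cube()/rollup() constructors that populate this struct, as in PySpark):
//
//     let out = df.cube(vec!["a", "b"])?.agg(vec![col("v").sum().alias("sum(v)")])?;
//     // one row per (a, b), (a, null), (null, b), and (null, null) grouping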

fn null_series_for_dtype(name: &str, n: usize, dtype: &DataType) -> Result<Series, PolarsError> {
    let name = name.into();
    let s = match dtype {
        DataType::Int32 => Series::new(name, vec![None::<i32>; n]),
        DataType::Int64 => Series::new(name, vec![None::<i64>; n]),
        DataType::Float32 => Series::new(name, vec![None::<f32>; n]),
        DataType::Float64 => Series::new(name, vec![None::<f64>; n]),
        DataType::String => {
            let v: Vec<Option<String>> = (0..n).map(|_| None).collect();
            Series::new(name, v)
        }
        DataType::Boolean => Series::new(name, vec![None::<bool>; n]),
        DataType::Date => Series::new(name, vec![None::<i32>; n]).cast(dtype)?,
        DataType::Datetime(_, _) => Series::new(name, vec![None::<i64>; n]).cast(dtype)?,
        _ => Series::new(name, vec![None::<i64>; n]).cast(dtype)?,
    };
    Ok(s)
}
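
// For example, null_series_for_dtype("k", 3, &DataType::Int64) builds a
// Series named "k" holding three Int64 nulls; dtypes without a dedicated
// arm fall through to the Int64-then-cast case.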

/// Reorder columns after a groupBy to match PySpark's order: grouping columns first, then aggregations.
pub(super) fn reorder_groupby_columns(
    pl_df: &mut PlDataFrame,
    grouping_cols: &[String],
) -> Result<PlDataFrame, PolarsError> {
    let all_cols: Vec<String> = pl_df
        .get_column_names()
        .iter()
        .map(|s| s.to_string())
        .collect();
    let mut reordered_cols: Vec<&str> = Vec::new();
    for gc in grouping_cols {
        if all_cols.iter().any(|c| c == gc) {
            reordered_cols.push(gc);
        }
    }
    for col_name in &all_cols {
        if !grouping_cols.iter().any(|gc| gc == col_name) {
            reordered_cols.push(col_name);
        }
    }
    if !reordered_cols.is_empty() && reordered_cols.len() == all_cols.len() {
        pl_df.select(reordered_cols)
    } else {
        Ok(pl_df.clone())
    }
}
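
// For example, a group-by on "k" whose aggregate output arrives as
// ["count", "k"] is reordered to ["k", "count"]. The length guard skips the
// select when the reordered list is not a permutation of the existing
// columns (e.g. duplicated grouping names) and returns the frame unchanged.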

#[cfg(test)]
mod tests {
    use crate::{DataFrame, SparkSession};

    fn test_df() -> DataFrame {
        let spark = SparkSession::builder()
            .app_name("agg_tests")
            .get_or_create();
        let tuples = vec![
            (1i64, 10i64, "a".to_string()),
            (1i64, 20i64, "a".to_string()),
            (2i64, 30i64, "b".to_string()),
        ];
        spark
            .create_dataframe(tuples, vec!["k", "v", "label"])
            .unwrap()
    }

    #[test]
    fn group_by_count_single_group() {
        let df = test_df();
        let grouped = df.group_by(vec!["k"]).unwrap();
        let out = grouped.count().unwrap();
        assert_eq!(out.count().unwrap(), 2);
        let cols = out.columns().unwrap();
        assert!(cols.contains(&"k".to_string()));
        assert!(cols.contains(&"count".to_string()));
    }

    #[test]
    fn group_by_sum() {
        let df = test_df();
        let grouped = df.group_by(vec!["k"]).unwrap();
        let out = grouped.sum("v").unwrap();
        assert_eq!(out.count().unwrap(), 2);
        let cols = out.columns().unwrap();
        assert!(cols.iter().any(|c| c.starts_with("sum(")));
    }

    #[test]
    fn group_by_empty_groups() {
        let spark = SparkSession::builder()
            .app_name("agg_tests")
            .get_or_create();
        let tuples: Vec<(i64, i64, String)> = vec![];
        let df = spark.create_dataframe(tuples, vec!["a", "b", "c"]).unwrap();
        let grouped = df.group_by(vec!["a"]).unwrap();
        let out = grouped.count().unwrap();
        assert_eq!(out.count().unwrap(), 0);
    }

    #[test]
    fn group_by_agg_multi() {
        use polars::prelude::*;
        let df = test_df();
        let grouped = df.group_by(vec!["k"]).unwrap();
        let out = grouped
            .agg(vec![len().alias("cnt"), col("v").sum().alias("total")])
            .unwrap();
        assert_eq!(out.count().unwrap(), 2);
        let cols = out.columns().unwrap();
        assert!(cols.contains(&"k".to_string()));
        assert!(cols.contains(&"cnt".to_string()));
        assert!(cols.contains(&"total".to_string()));
    }
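
    // A small extra check in the same style as the tests above (illustrative;
    // asserts shape only, mirroring group_by_sum):
    #[test]
    fn group_by_avg_shape() {
        let df = test_df();
        let grouped = df.group_by(vec!["k"]).unwrap();
        let out = grouped.avg("v").unwrap();
        assert_eq!(out.count().unwrap(), 2);
        let cols = out.columns().unwrap();
        assert!(cols.iter().any(|c| c.starts_with("avg(")));
    }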
}