
robin_sparkless/dataframe/aggregations.rs

1//! GroupBy and aggregation operations.
2
3use super::DataFrame;
4use crate::column::Column;
5use polars::prelude::{
6    DataFrame as PlDataFrame, DataType, Expr, LazyFrame, LazyGroupBy, NamedFrom, PolarsError,
7    SchemaNamesAndDtypes, Series, col, len, lit, when,
8};
9use std::collections::HashMap;
10
11/// Disambiguate duplicate output names in aggregation expressions (PySpark parity: issue #368).
12/// When multiple aggregations produce the same output name (e.g. sum("value") and avg("value")
13/// both yield "value"), all but the first are suffixed with _1, _2, ... so Polars does not error.
14pub(crate) fn disambiguate_agg_output_names(aggregations: Vec<Expr>) -> Vec<Expr> {
15    let mut name_count: HashMap<String, u32> = HashMap::new();
16    aggregations
17        .into_iter()
18        .map(|e| {
19            let base_name = polars_plan::utils::expr_output_name(&e)
20                .map(|s| s.to_string())
21                .unwrap_or_else(|_| "_".to_string());
22            let count = name_count.entry(base_name.clone()).or_insert(0);
23            *count += 1;
24            let final_name = if *count == 1 {
25                base_name
26            } else {
27                format!("{}_{}", base_name, *count - 1)
28            };
29            if *count == 1 {
30                e
31            } else {
32                e.alias(final_name.as_str())
33            }
34        })
35        .collect()
36}
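// Illustrative sketch: with two un-aliased aggregations over the same column,
//
//     let exprs = vec![col("value").sum(), col("value").mean()];
//     let fixed = disambiguate_agg_output_names(exprs);
//
// both expressions resolve to the output name "value", so the second is aliased
// to "value_1" and the grouped result carries ["value", "value_1"] instead of
// failing with a duplicate-name error.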
37
38/// GroupedData represents a DataFrame grouped by one or more columns.
39/// Similar to PySpark's GroupedData. Holds a LazyGroupBy so aggregations stay lazy until collected.
40pub struct GroupedData {
41    pub(crate) lf: LazyFrame,
42    pub(crate) lazy_grouped: LazyGroupBy,
43    pub(crate) grouping_cols: Vec<String>,
44    pub(crate) case_sensitive: bool,
45}
46
47impl GroupedData {
48    /// Resolve aggregation column name against LazyFrame schema (case-sensitive or -insensitive).
49    fn resolve_column(&self, name: &str) -> Result<String, PolarsError> {
50        let schema = self.lf.clone().collect_schema()?;
51        let names: Vec<String> = schema
52            .iter_names_and_dtypes()
53            .map(|(n, _)| n.to_string())
54            .collect();
55        if self.case_sensitive {
56            if names.iter().any(|n| n == name) {
57                return Ok(name.to_string());
58            }
59        } else {
60            let name_lower = name.to_lowercase();
61            for n in &names {
62                if n.to_lowercase() == name_lower {
63                    return Ok(n.clone());
64                }
65            }
66        }
67        let available = names.join(", ");
68        Err(PolarsError::ColumnNotFound(
69            format!(
70                "Column '{}' not found in grouped DataFrame. Available: [{}].",
71                name, available
72            )
73            .into(),
74        ))
75    }
76
77    /// Count rows in each group
78    pub fn count(&self) -> Result<DataFrame, PolarsError> {
79        use polars::prelude::*;
80        let agg_expr = vec![len().alias("count")];
81        let lf = self.lazy_grouped.clone().agg(agg_expr);
82        let mut pl_df = lf.collect()?;
83        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
84        Ok(super::DataFrame::from_polars_with_options(
85            pl_df,
86            self.case_sensitive,
87        ))
88    }
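    // Illustrative usage of count() above: for a frame with grouping column "k",
    //     df.group_by(vec!["k"])?.count()?
    // returns the columns ["k", "count"], i.e. grouping keys first and the
    // aggregate last, matching PySpark's groupBy("k").count() ordering.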
89
90    /// Sum a column in each group
91    pub fn sum(&self, column: &str) -> Result<DataFrame, PolarsError> {
92        use polars::prelude::*;
93        let c = self.resolve_column(column)?;
94        let agg_expr = vec![col(c.as_str()).sum().alias(format!("sum({column})"))];
95        let lf = self.lazy_grouped.clone().agg(agg_expr);
96        let mut pl_df = lf.collect()?;
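        // Note: the manual reordering below mirrors reorder_groupby_columns, keeping
        // the grouping columns first and the "sum(...)" column last so the output
        // matches PySpark's column order.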
97        let all_cols: Vec<String> = pl_df
98            .get_column_names()
99            .iter()
100            .map(|s| s.to_string())
101            .collect();
102        let grouping_cols: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
103        let mut reordered_cols: Vec<&str> = Vec::new();
104        for gc in &grouping_cols {
105            if all_cols.iter().any(|c| c == gc) {
106                reordered_cols.push(gc);
107            }
108        }
109        for col_name in &all_cols {
110            if !grouping_cols.iter().any(|gc| *gc == col_name) {
111                reordered_cols.push(col_name);
112            }
113        }
114        if !reordered_cols.is_empty() {
115            pl_df = pl_df.select(reordered_cols)?;
116        }
117        Ok(super::DataFrame::from_polars_with_options(
118            pl_df,
119            self.case_sensitive,
120        ))
121    }
122
123    /// Average (mean) of one or more columns in each group (PySpark: df.groupBy("x").avg("a", "b")).
124    pub fn avg(&self, columns: &[&str]) -> Result<DataFrame, PolarsError> {
125        if columns.is_empty() {
126            return Err(PolarsError::ComputeError(
127                "avg requires at least one column".into(),
128            ));
129        }
130        use polars::prelude::*;
131        let agg_expr: Vec<Expr> = columns
132            .iter()
133            .map(|c| {
134                let resolved = self.resolve_column(c)?;
135                Ok(col(resolved.as_str()).mean().alias(format!("avg({c})")))
136            })
137            .collect::<Result<Vec<_>, PolarsError>>()?;
138        let lf = self.lazy_grouped.clone().agg(agg_expr);
139        let mut pl_df = lf.collect()?;
140        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
141        Ok(super::DataFrame::from_polars_with_options(
142            pl_df,
143            self.case_sensitive,
144        ))
145    }
146
147    /// Minimum value of a column in each group
148    pub fn min(&self, column: &str) -> Result<DataFrame, PolarsError> {
149        use polars::prelude::*;
150        let c = self.resolve_column(column)?;
151        let agg_expr = vec![col(c.as_str()).min().alias(format!("min({column})"))];
152        let lf = self.lazy_grouped.clone().agg(agg_expr);
153        let mut pl_df = lf.collect()?;
154        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
155        Ok(super::DataFrame::from_polars_with_options(
156            pl_df,
157            self.case_sensitive,
158        ))
159    }
160
161    /// Maximum value of a column in each group
162    pub fn max(&self, column: &str) -> Result<DataFrame, PolarsError> {
163        use polars::prelude::*;
164        let c = self.resolve_column(column)?;
165        let agg_expr = vec![col(c.as_str()).max().alias(format!("max({column})"))];
166        let lf = self.lazy_grouped.clone().agg(agg_expr);
167        let mut pl_df = lf.collect()?;
168        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
169        Ok(super::DataFrame::from_polars_with_options(
170            pl_df,
171            self.case_sensitive,
172        ))
173    }
174
175    /// First value of a column in each group (order not guaranteed unless explicitly sorted).
176    pub fn first(&self, column: &str) -> Result<DataFrame, PolarsError> {
177        use polars::prelude::*;
178        let c = self.resolve_column(column)?;
179        let agg_expr = vec![col(c.as_str()).first().alias(format!("first({column})"))];
180        let lf = self.lazy_grouped.clone().agg(agg_expr);
181        let mut pl_df = lf.collect()?;
182        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
183        Ok(super::DataFrame::from_polars_with_options(
184            pl_df,
185            self.case_sensitive,
186        ))
187    }
188
189    /// Last value of a column in each group (order not guaranteed unless explicitly sorted).
190    pub fn last(&self, column: &str) -> Result<DataFrame, PolarsError> {
191        use polars::prelude::*;
192        let c = self.resolve_column(column)?;
193        let agg_expr = vec![col(c.as_str()).last().alias(format!("last({column})"))];
194        let lf = self.lazy_grouped.clone().agg(agg_expr);
195        let mut pl_df = lf.collect()?;
196        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
197        Ok(super::DataFrame::from_polars_with_options(
198            pl_df,
199            self.case_sensitive,
200        ))
201    }
202
203    /// Approximate count of distinct values in each group (PySpark approx_count_distinct). Implemented with n_unique, so the result is exact and matches count_distinct.
204    pub fn approx_count_distinct(&self, column: &str) -> Result<DataFrame, PolarsError> {
205        use polars::prelude::{DataType, col};
206        let c = self.resolve_column(column)?;
207        let agg_expr = vec![
208            col(c.as_str())
209                .n_unique()
210                .cast(DataType::Int64)
211                .alias(format!("approx_count_distinct({column})")),
212        ];
213        let lf = self.lazy_grouped.clone().agg(agg_expr);
214        let mut pl_df = lf.collect()?;
215        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
216        Ok(super::DataFrame::from_polars_with_options(
217            pl_df,
218            self.case_sensitive,
219        ))
220    }
221
222    /// Any value from the group (PySpark any_value). Uses first value.
223    pub fn any_value(&self, column: &str) -> Result<DataFrame, PolarsError> {
224        use polars::prelude::*;
225        let c = self.resolve_column(column)?;
226        let agg_expr = vec![
227            col(c.as_str())
228                .first()
229                .alias(format!("any_value({column})")),
230        ];
231        let lf = self.lazy_grouped.clone().agg(agg_expr);
232        let mut pl_df = lf.collect()?;
233        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
234        Ok(super::DataFrame::from_polars_with_options(
235            pl_df,
236            self.case_sensitive,
237        ))
238    }
239
240    /// Boolean AND across group (PySpark bool_and / every).
241    pub fn bool_and(&self, column: &str) -> Result<DataFrame, PolarsError> {
242        use polars::prelude::*;
243        let c = self.resolve_column(column)?;
244        let agg_expr = vec![
245            col(c.as_str())
246                .all(true)
247                .alias(format!("bool_and({column})")),
248        ];
249        let lf = self.lazy_grouped.clone().agg(agg_expr);
250        let mut pl_df = lf.collect()?;
251        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
252        Ok(super::DataFrame::from_polars_with_options(
253            pl_df,
254            self.case_sensitive,
255        ))
256    }
257
258    /// Boolean OR across group (PySpark bool_or / some).
259    pub fn bool_or(&self, column: &str) -> Result<DataFrame, PolarsError> {
260        use polars::prelude::*;
261        let c = self.resolve_column(column)?;
262        let agg_expr = vec![
263            col(c.as_str())
264                .any(true)
265                .alias(format!("bool_or({column})")),
266        ];
267        let lf = self.lazy_grouped.clone().agg(agg_expr);
268        let mut pl_df = lf.collect()?;
269        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
270        Ok(super::DataFrame::from_polars_with_options(
271            pl_df,
272            self.case_sensitive,
273        ))
274    }
275
276    /// Product of column values in each group (PySpark product).
277    pub fn product(&self, column: &str) -> Result<DataFrame, PolarsError> {
278        use polars::prelude::*;
279        let c = self.resolve_column(column)?;
280        let agg_expr = vec![
281            col(c.as_str())
282                .product()
283                .alias(format!("product({column})")),
284        ];
285        let lf = self.lazy_grouped.clone().agg(agg_expr);
286        let mut pl_df = lf.collect()?;
287        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
288        Ok(super::DataFrame::from_polars_with_options(
289            pl_df,
290            self.case_sensitive,
291        ))
292    }
293
294    /// Collect column values into list per group (PySpark collect_list).
295    pub fn collect_list(&self, column: &str) -> Result<DataFrame, PolarsError> {
296        use polars::prelude::*;
297        let c = self.resolve_column(column)?;
298        let agg_expr = vec![
299            col(c.as_str())
300                .implode()
301                .alias(format!("collect_list({column})")),
302        ];
303        let lf = self.lazy_grouped.clone().agg(agg_expr);
304        let mut pl_df = lf.collect()?;
305        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
306        Ok(super::DataFrame::from_polars_with_options(
307            pl_df,
308            self.case_sensitive,
309        ))
310    }
311
312    /// Collect distinct column values into list per group (PySpark collect_set).
313    pub fn collect_set(&self, column: &str) -> Result<DataFrame, PolarsError> {
314        use polars::prelude::*;
315        let c = self.resolve_column(column)?;
316        let agg_expr = vec![
317            col(c.as_str())
318                .unique()
319                .implode()
320                .alias(format!("collect_set({column})")),
321        ];
322        let lf = self.lazy_grouped.clone().agg(agg_expr);
323        let mut pl_df = lf.collect()?;
324        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
325        Ok(super::DataFrame::from_polars_with_options(
326            pl_df,
327            self.case_sensitive,
328        ))
329    }
330
331    /// Count rows where condition column is true (PySpark count_if).
332    pub fn count_if(&self, column: &str) -> Result<DataFrame, PolarsError> {
333        use polars::prelude::*;
334        let c = self.resolve_column(column)?;
335        let agg_expr = vec![
336            col(c.as_str())
337                .cast(DataType::Int64)
338                .sum()
339                .alias(format!("count_if({column})")),
340        ];
341        let lf = self.lazy_grouped.clone().agg(agg_expr);
342        let mut pl_df = lf.collect()?;
343        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
344        Ok(super::DataFrame::from_polars_with_options(
345            pl_df,
346            self.case_sensitive,
347        ))
348    }
349
350    /// Percentile of column (PySpark percentile). p in 0.0..=1.0.
351    pub fn percentile(&self, column: &str, p: f64) -> Result<DataFrame, PolarsError> {
352        use polars::prelude::*;
353        let c = self.resolve_column(column)?;
354        let agg_expr = vec![
355            col(c.as_str())
356                .quantile(lit(p), QuantileMethod::Linear)
357                .alias(format!("percentile({column}, {p})")),
358        ];
359        let lf = self.lazy_grouped.clone().agg(agg_expr);
360        let mut pl_df = lf.collect()?;
361        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
362        Ok(super::DataFrame::from_polars_with_options(
363            pl_df,
364            self.case_sensitive,
365        ))
366    }
367
368    /// Value of value_col where ord_col is maximum (PySpark max_by).
369    pub fn max_by(&self, value_col: &str, ord_col: &str) -> Result<DataFrame, PolarsError> {
370        use polars::prelude::*;
371        let vc = self.resolve_column(value_col)?;
372        let oc = self.resolve_column(ord_col)?;
373        let st = as_struct(vec![
374            col(oc.as_str()).alias("_ord"),
375            col(vc.as_str()).alias("_val"),
376        ]);
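        // Note: structs compare field-by-field, so sorting {_ord, _val} descending
        // orders rows by ord_col; .first() then picks the row holding the maximum
        // ord_col and .field_by_name("_val") extracts the matching value_col.
        // E.g. rows (ord=3, val="c") and (ord=7, val="a") give max_by = "a".
        // min_by below uses the same struct with an ascending sort.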
377        let agg_expr = vec![
378            st.sort(SortOptions::default().with_order_descending(true))
379                .first()
380                .struct_()
381                .field_by_name("_val")
382                .alias(format!("max_by({value_col}, {ord_col})")),
383        ];
384        let lf = self.lazy_grouped.clone().agg(agg_expr);
385        let mut pl_df = lf.collect()?;
386        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
387        Ok(super::DataFrame::from_polars_with_options(
388            pl_df,
389            self.case_sensitive,
390        ))
391    }
392
393    /// Value of value_col where ord_col is minimum (PySpark min_by).
394    pub fn min_by(&self, value_col: &str, ord_col: &str) -> Result<DataFrame, PolarsError> {
395        use polars::prelude::*;
396        let vc = self.resolve_column(value_col)?;
397        let oc = self.resolve_column(ord_col)?;
398        let st = as_struct(vec![
399            col(oc.as_str()).alias("_ord"),
400            col(vc.as_str()).alias("_val"),
401        ]);
402        let agg_expr = vec![
403            st.sort(SortOptions::default())
404                .first()
405                .struct_()
406                .field_by_name("_val")
407                .alias(format!("min_by({value_col}, {ord_col})")),
408        ];
409        let lf = self.lazy_grouped.clone().agg(agg_expr);
410        let mut pl_df = lf.collect()?;
411        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
412        Ok(super::DataFrame::from_polars_with_options(
413            pl_df,
414            self.case_sensitive,
415        ))
416    }
417
418    /// Population covariance between two columns in each group (PySpark covar_pop).
419    pub fn covar_pop(&self, col1: &str, col2: &str) -> Result<DataFrame, PolarsError> {
420        use polars::prelude::DataType;
421        let c1_res = self.resolve_column(col1)?;
422        let c2_res = self.resolve_column(col2)?;
423        let c1 = col(c1_res.as_str()).cast(DataType::Float64);
424        let c2 = col(c2_res.as_str()).cast(DataType::Float64);
425        let n = len().cast(DataType::Float64);
426        let sum_ab = (c1.clone() * c2.clone()).sum();
427        let sum_a = col(c1_res.as_str()).sum().cast(DataType::Float64);
428        let sum_b = col(c2_res.as_str()).sum().cast(DataType::Float64);
429        let cov = (sum_ab - sum_a * sum_b / n.clone()) / n;
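        // Note: population covariance is E[XY] - E[X]E[Y]; in sum form that is
        // (Σxy - Σx·Σy/n) / n, which is exactly the (sum_ab - sum_a * sum_b / n) / n
        // expression built above.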
430        let agg_expr = vec![cov.alias(format!("covar_pop({col1}, {col2})"))];
431        let lf = self.lazy_grouped.clone().agg(agg_expr);
432        let mut pl_df = lf.collect()?;
433        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
434        Ok(super::DataFrame::from_polars_with_options(
435            pl_df,
436            self.case_sensitive,
437        ))
438    }
439
440    /// Sample covariance between two columns in each group (PySpark covar_samp). ddof=1.
441    pub fn covar_samp(&self, col1: &str, col2: &str) -> Result<DataFrame, PolarsError> {
442        use polars::prelude::DataType;
443        let c1_res = self.resolve_column(col1)?;
444        let c2_res = self.resolve_column(col2)?;
445        let c1 = col(c1_res.as_str()).cast(DataType::Float64);
446        let c2 = col(c2_res.as_str()).cast(DataType::Float64);
447        let n = len().cast(DataType::Float64);
448        let sum_ab = (c1.clone() * c2.clone()).sum();
449        let sum_a = col(c1_res.as_str()).sum().cast(DataType::Float64);
450        let sum_b = col(c2_res.as_str()).sum().cast(DataType::Float64);
451        let cov = when(len().gt(lit(1)))
452            .then((sum_ab - sum_a * sum_b / n.clone()) / (len() - lit(1)).cast(DataType::Float64))
453            .otherwise(lit(f64::NAN));
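        // Note: same numerator as covar_pop but divided by (n - 1) (ddof = 1); the
        // when/otherwise guard returns NaN for single-row groups instead of dividing
        // by zero.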
454        let agg_expr = vec![cov.alias(format!("covar_samp({col1}, {col2})"))];
455        let lf = self.lazy_grouped.clone().agg(agg_expr);
456        let mut pl_df = lf.collect()?;
457        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
458        Ok(super::DataFrame::from_polars_with_options(
459            pl_df,
460            self.case_sensitive,
461        ))
462    }
463
464    /// Pearson correlation between two columns in each group (PySpark corr).
465    pub fn corr(&self, col1: &str, col2: &str) -> Result<DataFrame, PolarsError> {
466        use polars::prelude::DataType;
467        let c1_res = self.resolve_column(col1)?;
468        let c2_res = self.resolve_column(col2)?;
469        let c1 = col(c1_res.as_str()).cast(DataType::Float64);
470        let c2 = col(c2_res.as_str()).cast(DataType::Float64);
471        let n = len().cast(DataType::Float64);
472        let n1 = (len() - lit(1)).cast(DataType::Float64);
473        let sum_ab = (c1.clone() * c2.clone()).sum();
474        let sum_a = col(c1_res.as_str()).sum().cast(DataType::Float64);
475        let sum_b = col(c2_res.as_str()).sum().cast(DataType::Float64);
476        let sum_a2 = (c1.clone() * c1).sum();
477        let sum_b2 = (c2.clone() * c2).sum();
478        let cov_samp = (sum_ab - sum_a.clone() * sum_b.clone() / n.clone()) / n1.clone();
479        let var_a = (sum_a2 - sum_a.clone() * sum_a / n.clone()) / n1.clone();
480        let var_b = (sum_b2 - sum_b.clone() * sum_b / n.clone()) / n1.clone();
481        let std_a = var_a.sqrt();
482        let std_b = var_b.sqrt();
483        let corr_expr = when(len().gt(lit(1)))
484            .then(cov_samp / (std_a * std_b))
485            .otherwise(lit(f64::NAN));
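        // Note: Pearson r = cov(a, b) / (std(a) * std(b)); cov_samp, var_a and var_b
        // all carry the same (n - 1) denominator, so the ddof choice cancels and this
        // reduces to the usual Σ-based formula. Single-row groups yield NaN.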
486        let agg_expr = vec![corr_expr.alias(format!("corr({col1}, {col2})"))];
487        let lf = self.lazy_grouped.clone().agg(agg_expr);
488        let mut pl_df = lf.collect()?;
489        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
490        Ok(super::DataFrame::from_polars_with_options(
491            pl_df,
492            self.case_sensitive,
493        ))
494    }
495
496    /// Regression count of (y, x) pairs where both are non-null (PySpark regr_count).
497    pub fn regr_count(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
498        let yc = self.resolve_column(y_col)?;
499        let xc = self.resolve_column(x_col)?;
500        let agg_expr = vec![
501            crate::functions::regr_count_expr(yc.as_str(), xc.as_str())
502                .alias(format!("regr_count({y_col}, {x_col})")),
503        ];
504        let lf = self.lazy_grouped.clone().agg(agg_expr);
505        let mut pl_df = lf.collect()?;
506        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
507        Ok(super::DataFrame::from_polars_with_options(
508            pl_df,
509            self.case_sensitive,
510        ))
511    }
512
513    /// Regression average of x (PySpark regr_avgx).
514    pub fn regr_avgx(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
515        let yc = self.resolve_column(y_col)?;
516        let xc = self.resolve_column(x_col)?;
517        let agg_expr = vec![
518            crate::functions::regr_avgx_expr(yc.as_str(), xc.as_str())
519                .alias(format!("regr_avgx({y_col}, {x_col})")),
520        ];
521        let lf = self.lazy_grouped.clone().agg(agg_expr);
522        let mut pl_df = lf.collect()?;
523        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
524        Ok(super::DataFrame::from_polars_with_options(
525            pl_df,
526            self.case_sensitive,
527        ))
528    }
529
530    /// Regression average of y (PySpark regr_avgy).
531    pub fn regr_avgy(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
532        let yc = self.resolve_column(y_col)?;
533        let xc = self.resolve_column(x_col)?;
534        let agg_expr = vec![
535            crate::functions::regr_avgy_expr(yc.as_str(), xc.as_str())
536                .alias(format!("regr_avgy({y_col}, {x_col})")),
537        ];
538        let lf = self.lazy_grouped.clone().agg(agg_expr);
539        let mut pl_df = lf.collect()?;
540        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
541        Ok(super::DataFrame::from_polars_with_options(
542            pl_df,
543            self.case_sensitive,
544        ))
545    }
546
547    /// Regression slope (PySpark regr_slope).
548    pub fn regr_slope(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
549        let yc = self.resolve_column(y_col)?;
550        let xc = self.resolve_column(x_col)?;
551        let agg_expr = vec![
552            crate::functions::regr_slope_expr(yc.as_str(), xc.as_str())
553                .alias(format!("regr_slope({y_col}, {x_col})")),
554        ];
555        let lf = self.lazy_grouped.clone().agg(agg_expr);
556        let mut pl_df = lf.collect()?;
557        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
558        Ok(super::DataFrame::from_polars_with_options(
559            pl_df,
560            self.case_sensitive,
561        ))
562    }
563
564    /// Regression intercept (PySpark regr_intercept).
565    pub fn regr_intercept(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
566        let yc = self.resolve_column(y_col)?;
567        let xc = self.resolve_column(x_col)?;
568        let agg_expr = vec![
569            crate::functions::regr_intercept_expr(yc.as_str(), xc.as_str())
570                .alias(format!("regr_intercept({y_col}, {x_col})")),
571        ];
572        let lf = self.lazy_grouped.clone().agg(agg_expr);
573        let mut pl_df = lf.collect()?;
574        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
575        Ok(super::DataFrame::from_polars_with_options(
576            pl_df,
577            self.case_sensitive,
578        ))
579    }
580
581    /// Regression R-squared (PySpark regr_r2).
582    pub fn regr_r2(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
583        let yc = self.resolve_column(y_col)?;
584        let xc = self.resolve_column(x_col)?;
585        let agg_expr = vec![
586            crate::functions::regr_r2_expr(yc.as_str(), xc.as_str())
587                .alias(format!("regr_r2({y_col}, {x_col})")),
588        ];
589        let lf = self.lazy_grouped.clone().agg(agg_expr);
590        let mut pl_df = lf.collect()?;
591        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
592        Ok(super::DataFrame::from_polars_with_options(
593            pl_df,
594            self.case_sensitive,
595        ))
596    }
597
598    /// Regression sum (x - avg_x)^2 (PySpark regr_sxx).
599    pub fn regr_sxx(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
600        let yc = self.resolve_column(y_col)?;
601        let xc = self.resolve_column(x_col)?;
602        let agg_expr = vec![
603            crate::functions::regr_sxx_expr(yc.as_str(), xc.as_str())
604                .alias(format!("regr_sxx({y_col}, {x_col})")),
605        ];
606        let lf = self.lazy_grouped.clone().agg(agg_expr);
607        let mut pl_df = lf.collect()?;
608        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
609        Ok(super::DataFrame::from_polars_with_options(
610            pl_df,
611            self.case_sensitive,
612        ))
613    }
614
615    /// Regression sum (y - avg_y)^2 (PySpark regr_syy).
616    pub fn regr_syy(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
617        let yc = self.resolve_column(y_col)?;
618        let xc = self.resolve_column(x_col)?;
619        let agg_expr = vec![
620            crate::functions::regr_syy_expr(yc.as_str(), xc.as_str())
621                .alias(format!("regr_syy({y_col}, {x_col})")),
622        ];
623        let lf = self.lazy_grouped.clone().agg(agg_expr);
624        let mut pl_df = lf.collect()?;
625        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
626        Ok(super::DataFrame::from_polars_with_options(
627            pl_df,
628            self.case_sensitive,
629        ))
630    }
631
632    /// Regression sum (x - avg_x)(y - avg_y) (PySpark regr_sxy).
633    pub fn regr_sxy(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
634        let yc = self.resolve_column(y_col)?;
635        let xc = self.resolve_column(x_col)?;
636        let agg_expr = vec![
637            crate::functions::regr_sxy_expr(yc.as_str(), xc.as_str())
638                .alias(format!("regr_sxy({y_col}, {x_col})")),
639        ];
640        let lf = self.lazy_grouped.clone().agg(agg_expr);
641        let mut pl_df = lf.collect()?;
642        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
643        Ok(super::DataFrame::from_polars_with_options(
644            pl_df,
645            self.case_sensitive,
646        ))
647    }
648
649    /// Kurtosis of a column in each group (PySpark kurtosis). Fisher definition, bias=true.
650    pub fn kurtosis(&self, column: &str) -> Result<DataFrame, PolarsError> {
651        use polars::prelude::*;
652        let c = self.resolve_column(column)?;
653        let agg_expr = vec![
654            col(c.as_str())
655                .cast(DataType::Float64)
656                .kurtosis(true, true)
657                .alias(format!("kurtosis({column})")),
658        ];
659        let lf = self.lazy_grouped.clone().agg(agg_expr);
660        let mut pl_df = lf.collect()?;
661        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
662        Ok(super::DataFrame::from_polars_with_options(
663            pl_df,
664            self.case_sensitive,
665        ))
666    }
667
668    /// Skewness of a column in each group (PySpark skewness). bias=true.
669    pub fn skewness(&self, column: &str) -> Result<DataFrame, PolarsError> {
670        use polars::prelude::*;
671        let c = self.resolve_column(column)?;
672        let agg_expr = vec![
673            col(c.as_str())
674                .cast(DataType::Float64)
675                .skew(true)
676                .alias(format!("skewness({column})")),
677        ];
678        let lf = self.lazy_grouped.clone().agg(agg_expr);
679        let mut pl_df = lf.collect()?;
680        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
681        Ok(super::DataFrame::from_polars_with_options(
682            pl_df,
683            self.case_sensitive,
684        ))
685    }
686
687    /// Apply multiple aggregations at once (generic agg method).
688    /// Duplicate output names are disambiguated with _1, _2, ... (PySpark parity, issue #368).
689    pub fn agg(&self, aggregations: Vec<Expr>) -> Result<DataFrame, PolarsError> {
690        let disambiguated = disambiguate_agg_output_names(aggregations);
691        let lf = self.lazy_grouped.clone().agg(disambiguated);
692        let mut pl_df = lf.collect()?;
693        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
694        Ok(super::DataFrame::from_polars_with_options(
695            pl_df,
696            self.case_sensitive,
697        ))
698    }
699
700    /// Apply multiple aggregations expressed as robin-sparkless Columns.
701    /// This is a convenience for downstream bindings that work purely with
702    /// `Column` instead of `polars::Expr`, and wraps the generic `agg` API.
703    pub fn agg_columns(&self, aggregations: Vec<Column>) -> Result<DataFrame, PolarsError> {
704        let exprs: Vec<Expr> = aggregations.into_iter().map(|c| c.into_expr()).collect();
705        self.agg(exprs)
706    }
707
708    /// Get grouping columns
709    pub fn grouping_columns(&self) -> &[String] {
710        &self.grouping_cols
711    }
712
713    /// Pivot a column for pivot-table aggregation (PySpark: groupBy(...).pivot(pivot_col).sum(value_col)).
714    /// Returns PivotedGroupedData; call .sum(column), .avg(column), etc. to run the aggregation.
715    pub fn pivot(&self, pivot_col: &str, values: Option<Vec<String>>) -> PivotedGroupedData {
716        PivotedGroupedData {
717            lf: self.lf.clone(),
718            grouping_cols: self.grouping_cols.clone(),
719            pivot_col: pivot_col.to_string(),
720            values,
721            case_sensitive: self.case_sensitive,
722        }
723    }
724}
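// Illustrative pivot usage (a sketch against the builder API shown in the tests at
// the bottom of this file): for a frame with columns ("k", "v", "label"),
//
//     df.group_by(vec!["k"])?.pivot("label", None).sum("v")?
//
// yields one column per distinct "label" value next to "k", holding the sum of "v"
// for that (group, label) pair, and null where a group has no rows for that label.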
725
726/// Result of GroupedData.pivot(pivot_col); has .sum(), .avg(), etc. (PySpark pivot table).
727pub struct PivotedGroupedData {
728    pub(crate) lf: LazyFrame,
729    pub(crate) grouping_cols: Vec<String>,
730    pub(crate) pivot_col: String,
731    pub(crate) values: Option<Vec<String>>,
732    pub(crate) case_sensitive: bool,
733}
734
735/// PySpark: pivot column names use string representation; null → "null".
736fn pivot_value_to_column_name(av: polars::prelude::AnyValue<'_>) -> String {
737    use polars::prelude::AnyValue;
738    match av {
739        AnyValue::Null => "null".to_string(),
740        AnyValue::String(s) => s.to_string(),
741        _ => av.to_string(),
742    }
743}
744
745fn pivot_values_from_lf(lf: &LazyFrame, pivot_col: &str) -> Result<Vec<String>, PolarsError> {
746    use polars::prelude::*;
747    let pl_df = lf
748        .clone()
749        .select([col(pivot_col)])
750        .unique(None, Default::default())
751        .collect()?;
752    let s = pl_df.column(pivot_col)?;
753    let mut out = Vec::with_capacity(s.len());
754    for i in 0..s.len() {
755        let av = s.get(i)?;
756        out.push(pivot_value_to_column_name(av));
757    }
758    // PySpark parity: deterministic column order when values not provided (lexicographic)
759    out.sort();
760    Ok(out)
761}
762
763impl PivotedGroupedData {
764    fn resolve_column(&self, name: &str) -> Result<String, PolarsError> {
765        let schema = self.lf.clone().collect_schema()?;
766        let names: Vec<String> = schema
767            .iter_names_and_dtypes()
768            .map(|(n, _)| n.to_string())
769            .collect();
770        if self.case_sensitive {
771            if names.iter().any(|n| n == name) {
772                return Ok(name.to_string());
773            }
774        } else {
775            let name_lower = name.to_lowercase();
776            for n in &names {
777                if n.to_lowercase() == name_lower {
778                    return Ok(n.clone());
779                }
780            }
781        }
782        let available = names.join(", ");
783        Err(PolarsError::ColumnNotFound(
784            format!(
785                "Column '{}' not found in pivot DataFrame. Available: [{}].",
786                name, available
787            )
788            .into(),
789        ))
790    }
791
792    fn pivot_values(&self) -> Result<Vec<String>, PolarsError> {
793        if let Some(ref v) = self.values {
794            return Ok(v.clone());
795        }
796        let resolved = self.resolve_column(&self.pivot_col)?;
797        pivot_values_from_lf(&self.lf, &resolved)
798    }
799
800    fn pivot_agg(
801        &self,
802        value_col: &str,
803        agg_fn: fn(Expr) -> Expr,
804    ) -> Result<DataFrame, PolarsError> {
805        use polars::prelude::*;
806        let pivot_resolved = self.resolve_column(&self.pivot_col)?;
807        let value_resolved = self.resolve_column(value_col)?;
808        let pivot_vals = self.pivot_values()?;
809        if pivot_vals.is_empty() {
810            let by: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
811            let lf = self.lf.clone().group_by(by).agg(vec![]);
812            let pl_df = lf.collect()?;
813            return Ok(super::DataFrame::from_polars_with_options(
814                pl_df,
815                self.case_sensitive,
816            ));
817        }
818        let mut agg_exprs: Vec<Expr> = Vec::with_capacity(pivot_vals.len());
819        use polars::prelude::DataType;
820        for v in &pivot_vals {
821            // PySpark: pivot_col can be any type; compare as string for column names. Null → is_null().
822            let pred = if v == "null" {
823                col(pivot_resolved.as_str()).is_null()
824            } else {
825                col(pivot_resolved.as_str())
826                    .cast(DataType::String)
827                    .eq(lit(v.as_str()))
828            };
829            let then_expr = col(value_resolved.as_str());
830            let expr = when(pred).then(then_expr).otherwise(lit(NULL));
831            // PySpark parity: pivot value with no matching rows → null (not 0)
832            let has_any = expr
833                .clone()
834                .is_not_null()
835                .cast(DataType::UInt32)
836                .sum()
837                .gt(lit(0));
838            let agg_expr = when(has_any)
839                .then(agg_fn(expr))
840                .otherwise(lit(NULL))
841                .alias(v.as_str());
842            agg_exprs.push(agg_expr);
843        }
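        // Illustration: for pivot values ["a", "b"] the loop above builds, per value,
        // roughly
        //     when(col(pivot).cast(String).eq(lit("a"))).then(col(value)).otherwise(NULL)
        // and only applies agg_fn when at least one row matched, so an unmatched pivot
        // value yields null (not 0) and the aggregated column is named after the value.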
844        let by: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
845        let lf = self.lf.clone().group_by(by).agg(agg_exprs);
846        let mut pl_df = lf.collect()?;
847        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
848        Ok(super::DataFrame::from_polars_with_options(
849            pl_df,
850            self.case_sensitive,
851        ))
852    }
853
854    /// Pivot then sum (PySpark: groupBy(...).pivot(...).sum(column)).
855    pub fn sum(&self, value_col: &str) -> Result<DataFrame, PolarsError> {
856        self.pivot_agg(value_col, polars::prelude::Expr::sum)
857    }
858
859    /// Pivot then mean (PySpark: groupBy(...).pivot(...).avg(column)).
860    pub fn avg(&self, value_col: &str) -> Result<DataFrame, PolarsError> {
861        self.pivot_agg(value_col, polars::prelude::Expr::mean)
862    }
863
864    /// Pivot then min (PySpark: groupBy(...).pivot(...).min(column)).
865    pub fn min(&self, value_col: &str) -> Result<DataFrame, PolarsError> {
866        self.pivot_agg(value_col, polars::prelude::Expr::min)
867    }
868
869    /// Pivot then max (PySpark: groupBy(...).pivot(...).max(column)).
870    pub fn max(&self, value_col: &str) -> Result<DataFrame, PolarsError> {
871        self.pivot_agg(value_col, polars::prelude::Expr::max)
872    }
873
874    /// Pivot then count (PySpark: groupBy(...).pivot(...).count()). Counts rows per group per pivot value.
875    pub fn count(&self) -> Result<DataFrame, PolarsError> {
876        use polars::prelude::*;
877        let pivot_vals = self.pivot_values()?;
878        if pivot_vals.is_empty() {
879            let by: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
880            let lf = self.lf.clone().group_by(by).agg(vec![]);
881            let pl_df = lf.collect()?;
882            return Ok(super::DataFrame::from_polars_with_options(
883                pl_df,
884                self.case_sensitive,
885            ));
886        }
887        let mut agg_exprs: Vec<Expr> = Vec::with_capacity(pivot_vals.len());
888        use polars::prelude::DataType;
889        let pivot_resolved = self.resolve_column(&self.pivot_col)?;
890        for v in &pivot_vals {
891            let pred = if v == "null" {
892                col(pivot_resolved.as_str()).is_null()
893            } else {
894                col(pivot_resolved.as_str())
895                    .cast(DataType::String)
896                    .eq(lit(v.as_str()))
897            };
898            let expr = when(pred).then(lit(1)).otherwise(lit(NULL));
899            let has_any = expr
900                .clone()
901                .is_not_null()
902                .cast(DataType::UInt32)
903                .sum()
904                .gt(lit(0));
905            let agg_expr = when(has_any)
906                .then(expr.sum())
907                .otherwise(lit(NULL))
908                .alias(v.as_str());
909            agg_exprs.push(agg_expr);
910        }
911        let by: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
912        let lf = self.lf.clone().group_by(by).agg(agg_exprs);
913        let mut pl_df = lf.collect()?;
914        pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
915        Ok(super::DataFrame::from_polars_with_options(
916            pl_df,
917            self.case_sensitive,
918        ))
919    }
920}
921
922/// Cube/rollup: multiple grouping sets then union (PySpark cube / rollup).
923pub struct CubeRollupData {
924    pub(super) lf: LazyFrame,
925    pub(super) grouping_cols: Vec<String>,
926    pub(super) case_sensitive: bool,
927    pub(super) is_cube: bool,
928}
929
930impl CubeRollupData {
931    /// Count rows per grouping set (PySpark cube/rollup .count()).
932    pub fn count(&self) -> Result<DataFrame, PolarsError> {
933        use polars::prelude::*;
934        self.agg(vec![len().alias("count")])
935    }
936
937    /// Run aggregation on each grouping set and union results. Missing keys become null.
938    /// Duplicate agg output names are disambiguated (issue #368).
939    pub fn agg(&self, aggregations: Vec<Expr>) -> Result<DataFrame, PolarsError> {
940        use polars::prelude::*;
941        let aggregations = disambiguate_agg_output_names(aggregations);
942        let subsets: Vec<Vec<String>> = if self.is_cube {
943            // All subsets of grouping_cols (2^n)
944            let n = self.grouping_cols.len();
945            (0..1 << n)
946                .map(|mask| {
947                    self.grouping_cols
948                        .iter()
949                        .enumerate()
950                        .filter(|(i, _)| (mask & (1 << i)) != 0)
951                        .map(|(_, c)| c.clone())
952                        .collect()
953                })
954                .collect()
955        } else {
956            // Prefixes of grouping_cols, shortest to longest: [], [c1], [c1, c2], ..., [all]
957            (0..=self.grouping_cols.len())
958                .map(|len| self.grouping_cols[..len].to_vec())
959                .collect()
960        };
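        // Illustration: with grouping_cols = ["a", "b"], cube enumerates all 2^n
        // subsets {}, {a}, {b}, {a, b}, while rollup keeps only the prefixes
        // {}, {a}, {a, b}. Each subset is aggregated on its own and the parts are
        // unioned below, with the omitted grouping keys filled in as typed nulls.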
961
962        let schema = self.lf.clone().collect_schema()?;
963        let mut parts: Vec<PlDataFrame> = Vec::with_capacity(subsets.len());
964        for subset in subsets {
965            if subset.is_empty() {
966                // Single row: no grouping keys, one row of aggregates over full table
967                let lf = self.lf.clone().select(&aggregations);
968                let mut part = lf.collect()?;
969                let n = part.height();
970                for gc in &self.grouping_cols {
971                    let dtype = schema.get(gc).cloned().unwrap_or(DataType::Null);
972                    let null_series = null_series_for_dtype(gc.as_str(), n, &dtype)?;
973                    part.with_column(null_series.into())?;
974                }
975                // Reorder to [grouping_cols..., agg_cols]
976                let mut order: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
977                for name in part.get_column_names() {
978                    if !self.grouping_cols.iter().any(|g| g == name) {
979                        order.push(name);
980                    }
981                }
982                part = part.select(order)?;
983                parts.push(part);
984            } else {
985                let grouped = self
986                    .lf
987                    .clone()
988                    .group_by(subset.iter().map(|s| col(s.as_str())).collect::<Vec<_>>());
989                let mut part = grouped.agg(aggregations.clone()).collect()?;
990                part = reorder_groupby_columns(&mut part, &subset)?;
991                let n = part.height();
992                for gc in &self.grouping_cols {
993                    if subset.iter().any(|s| s == gc) {
994                        continue;
995                    }
996                    let dtype = schema.get(gc).cloned().unwrap_or(DataType::Null);
997                    let null_series = null_series_for_dtype(gc.as_str(), n, &dtype)?;
998                    part.with_column(null_series.into())?;
999                }
1000                let mut order: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
1001                for name in part.get_column_names() {
1002                    if !self.grouping_cols.iter().any(|g| g == name) {
1003                        order.push(name);
1004                    }
1005                }
1006                part = part.select(order)?;
1007                parts.push(part);
1008            }
1009        }
1010
1011        if parts.is_empty() {
1012            return Ok(super::DataFrame::from_polars_with_options(
1013                PlDataFrame::empty(),
1014                self.case_sensitive,
1015            ));
1016        }
1017        let order: Vec<String> = parts[0]
1018            .schema()
1019            .iter_names()
1020            .map(|s| s.to_string())
1021            .collect();
1022        for p in parts.iter_mut().skip(1) {
1023            *p = p.select(order.as_slice())?;
1024        }
1025        let lazy_frames: Vec<_> = parts.into_iter().map(|p| p.lazy()).collect();
1026        let out = polars::prelude::concat(lazy_frames, UnionArgs::default())?.collect()?;
1027        Ok(super::DataFrame::from_polars_with_options(
1028            out,
1029            self.case_sensitive,
1030        ))
1031    }
1032}
1033
1034fn null_series_for_dtype(name: &str, n: usize, dtype: &DataType) -> Result<Series, PolarsError> {
1035    let name = name.into();
1036    let s = match dtype {
1037        DataType::Int32 => Series::new(name, vec![None::<i32>; n]),
1038        DataType::Int64 => Series::new(name, vec![None::<i64>; n]),
1039        DataType::Float32 => Series::new(name, vec![None::<f32>; n]),
1040        DataType::Float64 => Series::new(name, vec![None::<f64>; n]),
1041        DataType::String => {
1042            let v: Vec<Option<String>> = (0..n).map(|_| None).collect();
1043            Series::new(name, v)
1044        }
1045        DataType::Boolean => Series::new(name, vec![None::<bool>; n]),
1046        DataType::Date => Series::new(name, vec![None::<i32>; n]).cast(dtype)?,
1047        DataType::Datetime(_, _) => Series::new(name, vec![None::<i64>; n]).cast(dtype)?,
1048        _ => Series::new(name, vec![None::<i64>; n]).cast(dtype)?,
1049    };
1050    Ok(s)
1051}
1052
1053/// Reorder columns after groupBy to match PySpark order: grouping columns first, then aggregations
1054pub(super) fn reorder_groupby_columns(
1055    pl_df: &mut PlDataFrame,
1056    grouping_cols: &[String],
1057) -> Result<PlDataFrame, PolarsError> {
1058    let all_cols: Vec<String> = pl_df
1059        .get_column_names()
1060        .iter()
1061        .map(|s| s.to_string())
1062        .collect();
1063    let mut reordered_cols: Vec<&str> = Vec::new();
1064    for gc in grouping_cols {
1065        if all_cols.iter().any(|c| c == gc) {
1066            reordered_cols.push(gc);
1067        }
1068    }
1069    for col_name in &all_cols {
1070        if !grouping_cols.iter().any(|gc| gc == col_name) {
1071            reordered_cols.push(col_name);
1072        }
1073    }
1074    if !reordered_cols.is_empty() && reordered_cols.len() == all_cols.len() {
1075        pl_df.select(reordered_cols)
1076    } else {
1077        Ok(pl_df.clone())
1078    }
1079}
1080
1081#[cfg(test)]
1082mod tests {
1083    use crate::{DataFrame, SparkSession, functions};
1084
1085    fn test_df() -> DataFrame {
1086        let spark = SparkSession::builder()
1087            .app_name("agg_tests")
1088            .get_or_create();
1089        let tuples = vec![
1090            (1i64, 10i64, "a".to_string()),
1091            (1i64, 20i64, "a".to_string()),
1092            (2i64, 30i64, "b".to_string()),
1093        ];
1094        spark
1095            .create_dataframe(tuples, vec!["k", "v", "label"])
1096            .unwrap()
1097    }
1098
1099    #[test]
1100    fn group_by_count_single_group() {
1101        let df = test_df();
1102        let grouped = df.group_by(vec!["k"]).unwrap();
1103        let out = grouped.count().unwrap();
1104        assert_eq!(out.count().unwrap(), 2);
1105        let cols = out.columns().unwrap();
1106        assert!(cols.contains(&"k".to_string()));
1107        assert!(cols.contains(&"count".to_string()));
1108    }
1109
1110    #[test]
1111    fn group_by_sum() {
1112        let df = test_df();
1113        let grouped = df.group_by(vec!["k"]).unwrap();
1114        let out = grouped.sum("v").unwrap();
1115        assert_eq!(out.count().unwrap(), 2);
1116        let cols = out.columns().unwrap();
1117        assert!(cols.iter().any(|c| c.starts_with("sum(")));
1118    }
1119
1120    #[test]
1121    fn group_by_empty_groups() {
1122        let spark = SparkSession::builder()
1123            .app_name("agg_tests")
1124            .get_or_create();
1125        let tuples: Vec<(i64, i64, String)> = vec![];
1126        let df = spark.create_dataframe(tuples, vec!["a", "b", "c"]).unwrap();
1127        let grouped = df.group_by(vec!["a"]).unwrap();
1128        let out = grouped.count().unwrap();
1129        assert_eq!(out.count().unwrap(), 0);
1130    }
1131
1132    #[test]
1133    fn group_by_agg_multi() {
1134        let df = test_df();
1135        let grouped = df.group_by(vec!["k"]).unwrap();
1136        let out = grouped
1137            .agg(vec![
1138                polars::prelude::len().alias("cnt"),
1139                polars::prelude::col("v").sum().alias("total"),
1140            ])
1141            .unwrap();
1142        assert_eq!(out.count().unwrap(), 2);
1143        let cols = out.columns().unwrap();
1144        assert!(cols.contains(&"k".to_string()));
1145        assert!(cols.contains(&"cnt".to_string()));
1146        assert!(cols.contains(&"total".to_string()));
1147    }
1148
1149    #[test]
1150    fn group_by_agg_columns_multi() {
1151        let df = test_df();
1152        let grouped = df.group_by(vec!["k"]).unwrap();
1153        let v_col = functions::col("v");
1154        let aggs = vec![functions::count(&v_col), functions::sum(&v_col)];
1155        let out = grouped.agg_columns(aggs).unwrap();
1156        assert_eq!(out.count().unwrap(), 2);
1157        let cols = out.columns().unwrap();
1158        assert!(cols.contains(&"k".to_string()));
1159        assert_eq!(cols.len(), 3);
1160    }
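
    // Illustrative sketches: the two tests below assume the same helpers and APIs
    // used above; the expected column names follow from disambiguate_agg_output_names
    // and the pivot implementation in this file.
    #[test]
    fn group_by_agg_duplicate_output_names() {
        let df = test_df();
        let grouped = df.group_by(vec!["k"]).unwrap();
        // Two un-aliased aggregations over "v" collide on the output name "v";
        // the second should be renamed to "v_1".
        let out = grouped
            .agg(vec![
                polars::prelude::col("v").sum(),
                polars::prelude::col("v").mean(),
            ])
            .unwrap();
        let cols = out.columns().unwrap();
        assert!(cols.contains(&"v".to_string()));
        assert!(cols.contains(&"v_1".to_string()));
    }

    #[test]
    fn group_by_pivot_sum() {
        let df = test_df();
        let grouped = df.group_by(vec!["k"]).unwrap();
        // Pivot on "label" with auto-detected values; expect one column per label.
        let out = grouped.pivot("label", None).sum("v").unwrap();
        assert_eq!(out.count().unwrap(), 2);
        let cols = out.columns().unwrap();
        assert!(cols.contains(&"a".to_string()));
        assert!(cols.contains(&"b".to_string()));
    }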
1161}