1use super::DataFrame;
4use polars::prelude::{
5 col, len, lit, when, DataFrame as PlDataFrame, DataType, Expr, LazyGroupBy, NamedFrom,
6 PolarsError, Series,
7};
8
/// A grouped frame awaiting an aggregation.
///
/// Every aggregation method clones `lazy_grouped`, applies its expressions,
/// collects the result, and returns it with the grouping columns placed first.
pub struct GroupedData {
    /// The Polars lazy group-by; cloned for each aggregation call so the
    /// same `GroupedData` can be aggregated more than once.
    pub(super) lazy_grouped: LazyGroupBy,
    /// Names of the grouping columns, in the order they should lead the output.
    pub(super) grouping_cols: Vec<String>,
    /// Forwarded unchanged to `DataFrame::from_polars_with_options` when
    /// wrapping aggregation results.
    pub(super) case_sensitive: bool,
}
16
17impl GroupedData {
18 pub fn count(&self) -> Result<DataFrame, PolarsError> {
20 use polars::prelude::*;
21 let agg_expr = vec![len().alias("count")];
22 let lf = self.lazy_grouped.clone().agg(agg_expr);
23 let mut pl_df = lf.collect()?;
24 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
25 Ok(super::DataFrame::from_polars_with_options(
26 pl_df,
27 self.case_sensitive,
28 ))
29 }
30
31 pub fn sum(&self, column: &str) -> Result<DataFrame, PolarsError> {
33 use polars::prelude::*;
34 let agg_expr = vec![col(column).sum().alias(format!("sum({column})"))];
35 let lf = self.lazy_grouped.clone().agg(agg_expr);
36 let mut pl_df = lf.collect()?;
37 let all_cols: Vec<String> = pl_df
38 .get_column_names()
39 .iter()
40 .map(|s| s.to_string())
41 .collect();
42 let grouping_cols: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
43 let mut reordered_cols: Vec<&str> = Vec::new();
44 for gc in &grouping_cols {
45 if all_cols.iter().any(|c| c == gc) {
46 reordered_cols.push(gc);
47 }
48 }
49 for col_name in &all_cols {
50 if !grouping_cols.iter().any(|gc| *gc == col_name) {
51 reordered_cols.push(col_name);
52 }
53 }
54 if !reordered_cols.is_empty() {
55 pl_df = pl_df.select(reordered_cols)?;
56 }
57 Ok(super::DataFrame::from_polars_with_options(
58 pl_df,
59 self.case_sensitive,
60 ))
61 }
62
63 pub fn avg(&self, column: &str) -> Result<DataFrame, PolarsError> {
65 use polars::prelude::*;
66 let agg_expr = vec![col(column).mean().alias(format!("avg({column})"))];
67 let lf = self.lazy_grouped.clone().agg(agg_expr);
68 let mut pl_df = lf.collect()?;
69 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
70 Ok(super::DataFrame::from_polars_with_options(
71 pl_df,
72 self.case_sensitive,
73 ))
74 }
75
76 pub fn min(&self, column: &str) -> Result<DataFrame, PolarsError> {
78 use polars::prelude::*;
79 let agg_expr = vec![col(column).min().alias(format!("min({column})"))];
80 let lf = self.lazy_grouped.clone().agg(agg_expr);
81 let mut pl_df = lf.collect()?;
82 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
83 Ok(super::DataFrame::from_polars_with_options(
84 pl_df,
85 self.case_sensitive,
86 ))
87 }
88
89 pub fn max(&self, column: &str) -> Result<DataFrame, PolarsError> {
91 use polars::prelude::*;
92 let agg_expr = vec![col(column).max().alias(format!("max({column})"))];
93 let lf = self.lazy_grouped.clone().agg(agg_expr);
94 let mut pl_df = lf.collect()?;
95 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
96 Ok(super::DataFrame::from_polars_with_options(
97 pl_df,
98 self.case_sensitive,
99 ))
100 }
101
102 pub fn first(&self, column: &str) -> Result<DataFrame, PolarsError> {
104 use polars::prelude::*;
105 let agg_expr = vec![col(column).first().alias(format!("first({column})"))];
106 let lf = self.lazy_grouped.clone().agg(agg_expr);
107 let mut pl_df = lf.collect()?;
108 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
109 Ok(super::DataFrame::from_polars_with_options(
110 pl_df,
111 self.case_sensitive,
112 ))
113 }
114
115 pub fn last(&self, column: &str) -> Result<DataFrame, PolarsError> {
117 use polars::prelude::*;
118 let agg_expr = vec![col(column).last().alias(format!("last({column})"))];
119 let lf = self.lazy_grouped.clone().agg(agg_expr);
120 let mut pl_df = lf.collect()?;
121 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
122 Ok(super::DataFrame::from_polars_with_options(
123 pl_df,
124 self.case_sensitive,
125 ))
126 }
127
128 pub fn approx_count_distinct(&self, column: &str) -> Result<DataFrame, PolarsError> {
130 use polars::prelude::{col, DataType};
131 let agg_expr = vec![col(column)
132 .n_unique()
133 .cast(DataType::Int64)
134 .alias(format!("approx_count_distinct({column})"))];
135 let lf = self.lazy_grouped.clone().agg(agg_expr);
136 let mut pl_df = lf.collect()?;
137 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
138 Ok(super::DataFrame::from_polars_with_options(
139 pl_df,
140 self.case_sensitive,
141 ))
142 }
143
144 pub fn any_value(&self, column: &str) -> Result<DataFrame, PolarsError> {
146 use polars::prelude::*;
147 let agg_expr = vec![col(column).first().alias(format!("any_value({column})"))];
148 let lf = self.lazy_grouped.clone().agg(agg_expr);
149 let mut pl_df = lf.collect()?;
150 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
151 Ok(super::DataFrame::from_polars_with_options(
152 pl_df,
153 self.case_sensitive,
154 ))
155 }
156
157 pub fn bool_and(&self, column: &str) -> Result<DataFrame, PolarsError> {
159 use polars::prelude::*;
160 let agg_expr = vec![col(column).all(true).alias(format!("bool_and({column})"))];
161 let lf = self.lazy_grouped.clone().agg(agg_expr);
162 let mut pl_df = lf.collect()?;
163 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
164 Ok(super::DataFrame::from_polars_with_options(
165 pl_df,
166 self.case_sensitive,
167 ))
168 }
169
170 pub fn bool_or(&self, column: &str) -> Result<DataFrame, PolarsError> {
172 use polars::prelude::*;
173 let agg_expr = vec![col(column).any(true).alias(format!("bool_or({column})"))];
174 let lf = self.lazy_grouped.clone().agg(agg_expr);
175 let mut pl_df = lf.collect()?;
176 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
177 Ok(super::DataFrame::from_polars_with_options(
178 pl_df,
179 self.case_sensitive,
180 ))
181 }
182
183 pub fn product(&self, column: &str) -> Result<DataFrame, PolarsError> {
185 use polars::prelude::*;
186 let agg_expr = vec![col(column).product().alias(format!("product({column})"))];
187 let lf = self.lazy_grouped.clone().agg(agg_expr);
188 let mut pl_df = lf.collect()?;
189 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
190 Ok(super::DataFrame::from_polars_with_options(
191 pl_df,
192 self.case_sensitive,
193 ))
194 }
195
196 pub fn collect_list(&self, column: &str) -> Result<DataFrame, PolarsError> {
198 use polars::prelude::*;
199 let agg_expr = vec![col(column)
200 .implode()
201 .alias(format!("collect_list({column})"))];
202 let lf = self.lazy_grouped.clone().agg(agg_expr);
203 let mut pl_df = lf.collect()?;
204 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
205 Ok(super::DataFrame::from_polars_with_options(
206 pl_df,
207 self.case_sensitive,
208 ))
209 }
210
211 pub fn collect_set(&self, column: &str) -> Result<DataFrame, PolarsError> {
213 use polars::prelude::*;
214 let agg_expr = vec![col(column)
215 .unique()
216 .implode()
217 .alias(format!("collect_set({column})"))];
218 let lf = self.lazy_grouped.clone().agg(agg_expr);
219 let mut pl_df = lf.collect()?;
220 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
221 Ok(super::DataFrame::from_polars_with_options(
222 pl_df,
223 self.case_sensitive,
224 ))
225 }
226
227 pub fn count_if(&self, column: &str) -> Result<DataFrame, PolarsError> {
229 use polars::prelude::*;
230 let agg_expr = vec![col(column)
231 .cast(DataType::Int64)
232 .sum()
233 .alias(format!("count_if({column})"))];
234 let lf = self.lazy_grouped.clone().agg(agg_expr);
235 let mut pl_df = lf.collect()?;
236 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
237 Ok(super::DataFrame::from_polars_with_options(
238 pl_df,
239 self.case_sensitive,
240 ))
241 }
242
243 pub fn percentile(&self, column: &str, p: f64) -> Result<DataFrame, PolarsError> {
245 use polars::prelude::*;
246 let agg_expr = vec![col(column)
247 .quantile(lit(p), QuantileMethod::Linear)
248 .alias(format!("percentile({column}, {p})"))];
249 let lf = self.lazy_grouped.clone().agg(agg_expr);
250 let mut pl_df = lf.collect()?;
251 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
252 Ok(super::DataFrame::from_polars_with_options(
253 pl_df,
254 self.case_sensitive,
255 ))
256 }
257
258 pub fn max_by(&self, value_col: &str, ord_col: &str) -> Result<DataFrame, PolarsError> {
260 use polars::prelude::*;
261 let st = as_struct(vec![
262 col(ord_col).alias("_ord"),
263 col(value_col).alias("_val"),
264 ]);
265 let agg_expr = vec![st
266 .sort(SortOptions::default().with_order_descending(true))
267 .first()
268 .struct_()
269 .field_by_name("_val")
270 .alias(format!("max_by({value_col}, {ord_col})"))];
271 let lf = self.lazy_grouped.clone().agg(agg_expr);
272 let mut pl_df = lf.collect()?;
273 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
274 Ok(super::DataFrame::from_polars_with_options(
275 pl_df,
276 self.case_sensitive,
277 ))
278 }
279
280 pub fn min_by(&self, value_col: &str, ord_col: &str) -> Result<DataFrame, PolarsError> {
282 use polars::prelude::*;
283 let st = as_struct(vec![
284 col(ord_col).alias("_ord"),
285 col(value_col).alias("_val"),
286 ]);
287 let agg_expr = vec![st
288 .sort(SortOptions::default())
289 .first()
290 .struct_()
291 .field_by_name("_val")
292 .alias(format!("min_by({value_col}, {ord_col})"))];
293 let lf = self.lazy_grouped.clone().agg(agg_expr);
294 let mut pl_df = lf.collect()?;
295 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
296 Ok(super::DataFrame::from_polars_with_options(
297 pl_df,
298 self.case_sensitive,
299 ))
300 }
301
302 pub fn covar_pop(&self, col1: &str, col2: &str) -> Result<DataFrame, PolarsError> {
304 use polars::prelude::DataType;
305 let c1 = col(col1).cast(DataType::Float64);
306 let c2 = col(col2).cast(DataType::Float64);
307 let n = len().cast(DataType::Float64);
308 let sum_ab = (c1.clone() * c2.clone()).sum();
309 let sum_a = col(col1).sum().cast(DataType::Float64);
310 let sum_b = col(col2).sum().cast(DataType::Float64);
311 let cov = (sum_ab - sum_a * sum_b / n.clone()) / n;
312 let agg_expr = vec![cov.alias(format!("covar_pop({col1}, {col2})"))];
313 let lf = self.lazy_grouped.clone().agg(agg_expr);
314 let mut pl_df = lf.collect()?;
315 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
316 Ok(super::DataFrame::from_polars_with_options(
317 pl_df,
318 self.case_sensitive,
319 ))
320 }
321
322 pub fn covar_samp(&self, col1: &str, col2: &str) -> Result<DataFrame, PolarsError> {
324 use polars::prelude::DataType;
325 let c1 = col(col1).cast(DataType::Float64);
326 let c2 = col(col2).cast(DataType::Float64);
327 let n = len().cast(DataType::Float64);
328 let sum_ab = (c1.clone() * c2.clone()).sum();
329 let sum_a = col(col1).sum().cast(DataType::Float64);
330 let sum_b = col(col2).sum().cast(DataType::Float64);
331 let cov = when(len().gt(lit(1)))
332 .then((sum_ab - sum_a * sum_b / n.clone()) / (len() - lit(1)).cast(DataType::Float64))
333 .otherwise(lit(f64::NAN));
334 let agg_expr = vec![cov.alias(format!("covar_samp({col1}, {col2})"))];
335 let lf = self.lazy_grouped.clone().agg(agg_expr);
336 let mut pl_df = lf.collect()?;
337 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
338 Ok(super::DataFrame::from_polars_with_options(
339 pl_df,
340 self.case_sensitive,
341 ))
342 }
343
344 pub fn corr(&self, col1: &str, col2: &str) -> Result<DataFrame, PolarsError> {
346 use polars::prelude::DataType;
347 let c1 = col(col1).cast(DataType::Float64);
348 let c2 = col(col2).cast(DataType::Float64);
349 let n = len().cast(DataType::Float64);
350 let n1 = (len() - lit(1)).cast(DataType::Float64);
351 let sum_ab = (c1.clone() * c2.clone()).sum();
352 let sum_a = col(col1).sum().cast(DataType::Float64);
353 let sum_b = col(col2).sum().cast(DataType::Float64);
354 let sum_a2 = (c1.clone() * c1).sum();
355 let sum_b2 = (c2.clone() * c2).sum();
356 let cov_samp = (sum_ab - sum_a.clone() * sum_b.clone() / n.clone()) / n1.clone();
357 let var_a = (sum_a2 - sum_a.clone() * sum_a / n.clone()) / n1.clone();
358 let var_b = (sum_b2 - sum_b.clone() * sum_b / n.clone()) / n1.clone();
359 let std_a = var_a.sqrt();
360 let std_b = var_b.sqrt();
361 let corr_expr = when(len().gt(lit(1)))
362 .then(cov_samp / (std_a * std_b))
363 .otherwise(lit(f64::NAN));
364 let agg_expr = vec![corr_expr.alias(format!("corr({col1}, {col2})"))];
365 let lf = self.lazy_grouped.clone().agg(agg_expr);
366 let mut pl_df = lf.collect()?;
367 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
368 Ok(super::DataFrame::from_polars_with_options(
369 pl_df,
370 self.case_sensitive,
371 ))
372 }
373
374 pub fn regr_count(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
376 let agg_expr = vec![crate::functions::regr_count_expr(y_col, x_col)
377 .alias(format!("regr_count({y_col}, {x_col})"))];
378 let lf = self.lazy_grouped.clone().agg(agg_expr);
379 let mut pl_df = lf.collect()?;
380 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
381 Ok(super::DataFrame::from_polars_with_options(
382 pl_df,
383 self.case_sensitive,
384 ))
385 }
386
387 pub fn regr_avgx(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
389 let agg_expr = vec![crate::functions::regr_avgx_expr(y_col, x_col)
390 .alias(format!("regr_avgx({y_col}, {x_col})"))];
391 let lf = self.lazy_grouped.clone().agg(agg_expr);
392 let mut pl_df = lf.collect()?;
393 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
394 Ok(super::DataFrame::from_polars_with_options(
395 pl_df,
396 self.case_sensitive,
397 ))
398 }
399
400 pub fn regr_avgy(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
402 let agg_expr = vec![crate::functions::regr_avgy_expr(y_col, x_col)
403 .alias(format!("regr_avgy({y_col}, {x_col})"))];
404 let lf = self.lazy_grouped.clone().agg(agg_expr);
405 let mut pl_df = lf.collect()?;
406 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
407 Ok(super::DataFrame::from_polars_with_options(
408 pl_df,
409 self.case_sensitive,
410 ))
411 }
412
413 pub fn regr_slope(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
415 let agg_expr = vec![crate::functions::regr_slope_expr(y_col, x_col)
416 .alias(format!("regr_slope({y_col}, {x_col})"))];
417 let lf = self.lazy_grouped.clone().agg(agg_expr);
418 let mut pl_df = lf.collect()?;
419 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
420 Ok(super::DataFrame::from_polars_with_options(
421 pl_df,
422 self.case_sensitive,
423 ))
424 }
425
426 pub fn regr_intercept(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
428 let agg_expr = vec![crate::functions::regr_intercept_expr(y_col, x_col)
429 .alias(format!("regr_intercept({y_col}, {x_col})"))];
430 let lf = self.lazy_grouped.clone().agg(agg_expr);
431 let mut pl_df = lf.collect()?;
432 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
433 Ok(super::DataFrame::from_polars_with_options(
434 pl_df,
435 self.case_sensitive,
436 ))
437 }
438
439 pub fn regr_r2(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
441 let agg_expr = vec![crate::functions::regr_r2_expr(y_col, x_col)
442 .alias(format!("regr_r2({y_col}, {x_col})"))];
443 let lf = self.lazy_grouped.clone().agg(agg_expr);
444 let mut pl_df = lf.collect()?;
445 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
446 Ok(super::DataFrame::from_polars_with_options(
447 pl_df,
448 self.case_sensitive,
449 ))
450 }
451
452 pub fn regr_sxx(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
454 let agg_expr = vec![crate::functions::regr_sxx_expr(y_col, x_col)
455 .alias(format!("regr_sxx({y_col}, {x_col})"))];
456 let lf = self.lazy_grouped.clone().agg(agg_expr);
457 let mut pl_df = lf.collect()?;
458 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
459 Ok(super::DataFrame::from_polars_with_options(
460 pl_df,
461 self.case_sensitive,
462 ))
463 }
464
465 pub fn regr_syy(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
467 let agg_expr = vec![crate::functions::regr_syy_expr(y_col, x_col)
468 .alias(format!("regr_syy({y_col}, {x_col})"))];
469 let lf = self.lazy_grouped.clone().agg(agg_expr);
470 let mut pl_df = lf.collect()?;
471 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
472 Ok(super::DataFrame::from_polars_with_options(
473 pl_df,
474 self.case_sensitive,
475 ))
476 }
477
478 pub fn regr_sxy(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
480 let agg_expr = vec![crate::functions::regr_sxy_expr(y_col, x_col)
481 .alias(format!("regr_sxy({y_col}, {x_col})"))];
482 let lf = self.lazy_grouped.clone().agg(agg_expr);
483 let mut pl_df = lf.collect()?;
484 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
485 Ok(super::DataFrame::from_polars_with_options(
486 pl_df,
487 self.case_sensitive,
488 ))
489 }
490
491 pub fn kurtosis(&self, column: &str) -> Result<DataFrame, PolarsError> {
493 use polars::prelude::*;
494 let agg_expr = vec![col(column)
495 .cast(DataType::Float64)
496 .kurtosis(true, true)
497 .alias(format!("kurtosis({column})"))];
498 let lf = self.lazy_grouped.clone().agg(agg_expr);
499 let mut pl_df = lf.collect()?;
500 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
501 Ok(super::DataFrame::from_polars_with_options(
502 pl_df,
503 self.case_sensitive,
504 ))
505 }
506
507 pub fn skewness(&self, column: &str) -> Result<DataFrame, PolarsError> {
509 use polars::prelude::*;
510 let agg_expr = vec![col(column)
511 .cast(DataType::Float64)
512 .skew(true)
513 .alias(format!("skewness({column})"))];
514 let lf = self.lazy_grouped.clone().agg(agg_expr);
515 let mut pl_df = lf.collect()?;
516 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
517 Ok(super::DataFrame::from_polars_with_options(
518 pl_df,
519 self.case_sensitive,
520 ))
521 }
522
523 pub fn agg(&self, aggregations: Vec<Expr>) -> Result<DataFrame, PolarsError> {
525 let lf = self.lazy_grouped.clone().agg(aggregations);
526 let mut pl_df = lf.collect()?;
527 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
528 Ok(super::DataFrame::from_polars_with_options(
529 pl_df,
530 self.case_sensitive,
531 ))
532 }
533
534 pub fn grouping_columns(&self) -> &[String] {
536 &self.grouping_cols
537 }
538}
539
/// A frame prepared for a cube or rollup aggregation over several grouping sets.
pub struct CubeRollupData {
    /// The eager source frame; cloned and made lazy once per grouping set.
    pub(super) df: PlDataFrame,
    /// Grouping columns from which the grouping sets are derived.
    pub(super) grouping_cols: Vec<String>,
    /// Forwarded unchanged to `DataFrame::from_polars_with_options` when
    /// wrapping the aggregation result.
    pub(super) case_sensitive: bool,
    /// `true`: aggregate over every subset of `grouping_cols` (cube).
    /// `false`: aggregate over every prefix of `grouping_cols` (rollup).
    pub(super) is_cube: bool,
}
547
548impl CubeRollupData {
549 pub fn agg(&self, aggregations: Vec<Expr>) -> Result<DataFrame, PolarsError> {
551 use polars::prelude::*;
552 let subsets: Vec<Vec<String>> = if self.is_cube {
553 let n = self.grouping_cols.len();
555 (0..1 << n)
556 .map(|mask| {
557 self.grouping_cols
558 .iter()
559 .enumerate()
560 .filter(|(i, _)| (mask & (1 << i)) != 0)
561 .map(|(_, c)| c.clone())
562 .collect()
563 })
564 .collect()
565 } else {
566 (0..=self.grouping_cols.len())
568 .map(|len| self.grouping_cols[..len].to_vec())
569 .collect()
570 };
571
572 let schema = self.df.schema();
573 let mut parts: Vec<PlDataFrame> = Vec::with_capacity(subsets.len());
574 for subset in subsets {
575 if subset.is_empty() {
576 let lf = self.df.clone().lazy().select(aggregations.clone());
578 let mut part = lf.collect()?;
579 let n = part.height();
580 for gc in &self.grouping_cols {
581 let dtype = schema.get(gc).cloned().unwrap_or(DataType::Null);
582 let null_series = null_series_for_dtype(gc.as_str(), n, &dtype)?;
583 part.with_column(null_series)?;
584 }
585 let mut order: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
587 for name in part.get_column_names() {
588 if !self.grouping_cols.iter().any(|g| g == name) {
589 order.push(name);
590 }
591 }
592 part = part.select(order)?;
593 parts.push(part);
594 } else {
595 let grouped = self
596 .df
597 .clone()
598 .lazy()
599 .group_by(subset.iter().map(|s| col(s.as_str())).collect::<Vec<_>>());
600 let mut part = grouped.agg(aggregations.clone()).collect()?;
601 part = reorder_groupby_columns(&mut part, &subset)?;
602 let n = part.height();
603 for gc in &self.grouping_cols {
604 if subset.iter().any(|s| s == gc) {
605 continue;
606 }
607 let dtype = schema.get(gc).cloned().unwrap_or(DataType::Null);
608 let null_series = null_series_for_dtype(gc.as_str(), n, &dtype)?;
609 part.with_column(null_series)?;
610 }
611 let mut order: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
612 for name in part.get_column_names() {
613 if !self.grouping_cols.iter().any(|g| g == name) {
614 order.push(name);
615 }
616 }
617 part = part.select(order)?;
618 parts.push(part);
619 }
620 }
621
622 if parts.is_empty() {
623 return Ok(super::DataFrame::from_polars_with_options(
624 PlDataFrame::empty(),
625 self.case_sensitive,
626 ));
627 }
628 let first_schema = parts[0].schema();
629 let order: Vec<&str> = first_schema.iter_names().map(|s| s.as_str()).collect();
630 for p in parts.iter_mut().skip(1) {
631 *p = p.select(order.clone())?;
632 }
633 let lazy_frames: Vec<_> = parts.into_iter().map(|p| p.lazy()).collect();
634 let out = polars::prelude::concat(lazy_frames, UnionArgs::default())?.collect()?;
635 Ok(super::DataFrame::from_polars_with_options(
636 out,
637 self.case_sensitive,
638 ))
639 }
640}
641
642fn null_series_for_dtype(name: &str, n: usize, dtype: &DataType) -> Result<Series, PolarsError> {
643 let name = name.into();
644 let s = match dtype {
645 DataType::Int32 => Series::new(name, vec![None::<i32>; n]),
646 DataType::Int64 => Series::new(name, vec![None::<i64>; n]),
647 DataType::Float32 => Series::new(name, vec![None::<f32>; n]),
648 DataType::Float64 => Series::new(name, vec![None::<f64>; n]),
649 DataType::String => {
650 let v: Vec<Option<String>> = (0..n).map(|_| None).collect();
651 Series::new(name, v)
652 }
653 DataType::Boolean => Series::new(name, vec![None::<bool>; n]),
654 DataType::Date => Series::new(name, vec![None::<i32>; n]).cast(dtype)?,
655 DataType::Datetime(_, _) => Series::new(name, vec![None::<i64>; n]).cast(dtype)?,
656 _ => Series::new(name, vec![None::<i64>; n]).cast(dtype)?,
657 };
658 Ok(s)
659}
660
661pub(super) fn reorder_groupby_columns(
663 pl_df: &mut PlDataFrame,
664 grouping_cols: &[String],
665) -> Result<PlDataFrame, PolarsError> {
666 let all_cols: Vec<String> = pl_df
667 .get_column_names()
668 .iter()
669 .map(|s| s.to_string())
670 .collect();
671 let mut reordered_cols: Vec<&str> = Vec::new();
672 for gc in grouping_cols {
673 if all_cols.iter().any(|c| c == gc) {
674 reordered_cols.push(gc);
675 }
676 }
677 for col_name in &all_cols {
678 if !grouping_cols.iter().any(|gc| gc == col_name) {
679 reordered_cols.push(col_name);
680 }
681 }
682 if !reordered_cols.is_empty() && reordered_cols.len() == all_cols.len() {
683 pl_df.select(reordered_cols)
684 } else {
685 Ok(pl_df.clone())
686 }
687}
688
#[cfg(test)]
mod tests {
    use crate::{DataFrame, SparkSession};

    /// Three-row fixture with columns `k`, `v`, `label`; `k` has two distinct values.
    fn test_df() -> DataFrame {
        let session = SparkSession::builder()
            .app_name("agg_tests")
            .get_or_create();
        let rows = vec![
            (1i64, 10i64, "a".to_string()),
            (1i64, 20i64, "a".to_string()),
            (2i64, 30i64, "b".to_string()),
        ];
        session
            .create_dataframe(rows, vec!["k", "v", "label"])
            .unwrap()
    }

    /// True when `columns` contains an entry equal to `wanted`.
    fn has_column(columns: &[String], wanted: &str) -> bool {
        columns.contains(&wanted.to_string())
    }

    #[test]
    fn group_by_count_single_group() {
        let counted = test_df().group_by(vec!["k"]).unwrap().count().unwrap();
        assert_eq!(counted.count().unwrap(), 2);
        let names = counted.columns().unwrap();
        assert!(has_column(&names, "k"));
        assert!(has_column(&names, "count"));
    }

    #[test]
    fn group_by_sum() {
        let summed = test_df().group_by(vec!["k"]).unwrap().sum("v").unwrap();
        assert_eq!(summed.count().unwrap(), 2);
        let names = summed.columns().unwrap();
        assert!(names.iter().any(|name| name.starts_with("sum(")));
    }

    #[test]
    fn group_by_empty_groups() {
        let session = SparkSession::builder()
            .app_name("agg_tests")
            .get_or_create();
        let no_rows: Vec<(i64, i64, String)> = Vec::new();
        let empty = session
            .create_dataframe(no_rows, vec!["a", "b", "c"])
            .unwrap();
        let counted = empty.group_by(vec!["a"]).unwrap().count().unwrap();
        assert_eq!(counted.count().unwrap(), 0);
    }

    #[test]
    fn group_by_agg_multi() {
        use polars::prelude::{col, len};
        let aggregated = test_df()
            .group_by(vec!["k"])
            .unwrap()
            .agg(vec![len().alias("cnt"), col("v").sum().alias("total")])
            .unwrap();
        assert_eq!(aggregated.count().unwrap(), 2);
        let names = aggregated.columns().unwrap();
        for expected in ["k", "cnt", "total"] {
            assert!(has_column(&names, expected));
        }
    }
}