1use super::DataFrame;
4use polars::prelude::{
5 col, len, lit, when, DataFrame as PlDataFrame, DataType, Expr, LazyGroupBy, NamedFrom,
6 PolarsError, Series,
7};
8
9pub struct GroupedData {
12 #[cfg_attr(not(feature = "pyo3"), allow(dead_code))]
16 pub(crate) df: PlDataFrame,
17 pub(crate) lazy_grouped: LazyGroupBy,
18 pub(crate) grouping_cols: Vec<String>,
19 pub(crate) case_sensitive: bool,
20}
21
22impl GroupedData {
23 pub fn count(&self) -> Result<DataFrame, PolarsError> {
25 use polars::prelude::*;
26 let agg_expr = vec![len().alias("count")];
27 let lf = self.lazy_grouped.clone().agg(agg_expr);
28 let mut pl_df = lf.collect()?;
29 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
30 Ok(super::DataFrame::from_polars_with_options(
31 pl_df,
32 self.case_sensitive,
33 ))
34 }
35
36 pub fn sum(&self, column: &str) -> Result<DataFrame, PolarsError> {
38 use polars::prelude::*;
39 let agg_expr = vec![col(column).sum().alias(format!("sum({column})"))];
40 let lf = self.lazy_grouped.clone().agg(agg_expr);
41 let mut pl_df = lf.collect()?;
42 let all_cols: Vec<String> = pl_df
43 .get_column_names()
44 .iter()
45 .map(|s| s.to_string())
46 .collect();
47 let grouping_cols: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
48 let mut reordered_cols: Vec<&str> = Vec::new();
49 for gc in &grouping_cols {
50 if all_cols.iter().any(|c| c == gc) {
51 reordered_cols.push(gc);
52 }
53 }
54 for col_name in &all_cols {
55 if !grouping_cols.iter().any(|gc| *gc == col_name) {
56 reordered_cols.push(col_name);
57 }
58 }
59 if !reordered_cols.is_empty() {
60 pl_df = pl_df.select(reordered_cols)?;
61 }
62 Ok(super::DataFrame::from_polars_with_options(
63 pl_df,
64 self.case_sensitive,
65 ))
66 }
67
68 pub fn avg(&self, column: &str) -> Result<DataFrame, PolarsError> {
70 use polars::prelude::*;
71 let agg_expr = vec![col(column).mean().alias(format!("avg({column})"))];
72 let lf = self.lazy_grouped.clone().agg(agg_expr);
73 let mut pl_df = lf.collect()?;
74 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
75 Ok(super::DataFrame::from_polars_with_options(
76 pl_df,
77 self.case_sensitive,
78 ))
79 }
80
81 pub fn min(&self, column: &str) -> Result<DataFrame, PolarsError> {
83 use polars::prelude::*;
84 let agg_expr = vec![col(column).min().alias(format!("min({column})"))];
85 let lf = self.lazy_grouped.clone().agg(agg_expr);
86 let mut pl_df = lf.collect()?;
87 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
88 Ok(super::DataFrame::from_polars_with_options(
89 pl_df,
90 self.case_sensitive,
91 ))
92 }
93
94 pub fn max(&self, column: &str) -> Result<DataFrame, PolarsError> {
96 use polars::prelude::*;
97 let agg_expr = vec![col(column).max().alias(format!("max({column})"))];
98 let lf = self.lazy_grouped.clone().agg(agg_expr);
99 let mut pl_df = lf.collect()?;
100 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
101 Ok(super::DataFrame::from_polars_with_options(
102 pl_df,
103 self.case_sensitive,
104 ))
105 }
106
107 pub fn first(&self, column: &str) -> Result<DataFrame, PolarsError> {
109 use polars::prelude::*;
110 let agg_expr = vec![col(column).first().alias(format!("first({column})"))];
111 let lf = self.lazy_grouped.clone().agg(agg_expr);
112 let mut pl_df = lf.collect()?;
113 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
114 Ok(super::DataFrame::from_polars_with_options(
115 pl_df,
116 self.case_sensitive,
117 ))
118 }
119
120 pub fn last(&self, column: &str) -> Result<DataFrame, PolarsError> {
122 use polars::prelude::*;
123 let agg_expr = vec![col(column).last().alias(format!("last({column})"))];
124 let lf = self.lazy_grouped.clone().agg(agg_expr);
125 let mut pl_df = lf.collect()?;
126 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
127 Ok(super::DataFrame::from_polars_with_options(
128 pl_df,
129 self.case_sensitive,
130 ))
131 }
132
133 pub fn approx_count_distinct(&self, column: &str) -> Result<DataFrame, PolarsError> {
135 use polars::prelude::{col, DataType};
136 let agg_expr = vec![col(column)
137 .n_unique()
138 .cast(DataType::Int64)
139 .alias(format!("approx_count_distinct({column})"))];
140 let lf = self.lazy_grouped.clone().agg(agg_expr);
141 let mut pl_df = lf.collect()?;
142 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
143 Ok(super::DataFrame::from_polars_with_options(
144 pl_df,
145 self.case_sensitive,
146 ))
147 }
148
149 pub fn any_value(&self, column: &str) -> Result<DataFrame, PolarsError> {
151 use polars::prelude::*;
152 let agg_expr = vec![col(column).first().alias(format!("any_value({column})"))];
153 let lf = self.lazy_grouped.clone().agg(agg_expr);
154 let mut pl_df = lf.collect()?;
155 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
156 Ok(super::DataFrame::from_polars_with_options(
157 pl_df,
158 self.case_sensitive,
159 ))
160 }
161
162 pub fn bool_and(&self, column: &str) -> Result<DataFrame, PolarsError> {
164 use polars::prelude::*;
165 let agg_expr = vec![col(column).all(true).alias(format!("bool_and({column})"))];
166 let lf = self.lazy_grouped.clone().agg(agg_expr);
167 let mut pl_df = lf.collect()?;
168 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
169 Ok(super::DataFrame::from_polars_with_options(
170 pl_df,
171 self.case_sensitive,
172 ))
173 }
174
175 pub fn bool_or(&self, column: &str) -> Result<DataFrame, PolarsError> {
177 use polars::prelude::*;
178 let agg_expr = vec![col(column).any(true).alias(format!("bool_or({column})"))];
179 let lf = self.lazy_grouped.clone().agg(agg_expr);
180 let mut pl_df = lf.collect()?;
181 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
182 Ok(super::DataFrame::from_polars_with_options(
183 pl_df,
184 self.case_sensitive,
185 ))
186 }
187
188 pub fn product(&self, column: &str) -> Result<DataFrame, PolarsError> {
190 use polars::prelude::*;
191 let agg_expr = vec![col(column).product().alias(format!("product({column})"))];
192 let lf = self.lazy_grouped.clone().agg(agg_expr);
193 let mut pl_df = lf.collect()?;
194 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
195 Ok(super::DataFrame::from_polars_with_options(
196 pl_df,
197 self.case_sensitive,
198 ))
199 }
200
201 pub fn collect_list(&self, column: &str) -> Result<DataFrame, PolarsError> {
203 use polars::prelude::*;
204 let agg_expr = vec![col(column)
205 .implode()
206 .alias(format!("collect_list({column})"))];
207 let lf = self.lazy_grouped.clone().agg(agg_expr);
208 let mut pl_df = lf.collect()?;
209 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
210 Ok(super::DataFrame::from_polars_with_options(
211 pl_df,
212 self.case_sensitive,
213 ))
214 }
215
216 pub fn collect_set(&self, column: &str) -> Result<DataFrame, PolarsError> {
218 use polars::prelude::*;
219 let agg_expr = vec![col(column)
220 .unique()
221 .implode()
222 .alias(format!("collect_set({column})"))];
223 let lf = self.lazy_grouped.clone().agg(agg_expr);
224 let mut pl_df = lf.collect()?;
225 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
226 Ok(super::DataFrame::from_polars_with_options(
227 pl_df,
228 self.case_sensitive,
229 ))
230 }
231
232 pub fn count_if(&self, column: &str) -> Result<DataFrame, PolarsError> {
234 use polars::prelude::*;
235 let agg_expr = vec![col(column)
236 .cast(DataType::Int64)
237 .sum()
238 .alias(format!("count_if({column})"))];
239 let lf = self.lazy_grouped.clone().agg(agg_expr);
240 let mut pl_df = lf.collect()?;
241 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
242 Ok(super::DataFrame::from_polars_with_options(
243 pl_df,
244 self.case_sensitive,
245 ))
246 }
247
248 pub fn percentile(&self, column: &str, p: f64) -> Result<DataFrame, PolarsError> {
250 use polars::prelude::*;
251 let agg_expr = vec![col(column)
252 .quantile(lit(p), QuantileMethod::Linear)
253 .alias(format!("percentile({column}, {p})"))];
254 let lf = self.lazy_grouped.clone().agg(agg_expr);
255 let mut pl_df = lf.collect()?;
256 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
257 Ok(super::DataFrame::from_polars_with_options(
258 pl_df,
259 self.case_sensitive,
260 ))
261 }
262
263 pub fn max_by(&self, value_col: &str, ord_col: &str) -> Result<DataFrame, PolarsError> {
265 use polars::prelude::*;
266 let st = as_struct(vec![
267 col(ord_col).alias("_ord"),
268 col(value_col).alias("_val"),
269 ]);
270 let agg_expr = vec![st
271 .sort(SortOptions::default().with_order_descending(true))
272 .first()
273 .struct_()
274 .field_by_name("_val")
275 .alias(format!("max_by({value_col}, {ord_col})"))];
276 let lf = self.lazy_grouped.clone().agg(agg_expr);
277 let mut pl_df = lf.collect()?;
278 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
279 Ok(super::DataFrame::from_polars_with_options(
280 pl_df,
281 self.case_sensitive,
282 ))
283 }
284
285 pub fn min_by(&self, value_col: &str, ord_col: &str) -> Result<DataFrame, PolarsError> {
287 use polars::prelude::*;
288 let st = as_struct(vec![
289 col(ord_col).alias("_ord"),
290 col(value_col).alias("_val"),
291 ]);
292 let agg_expr = vec![st
293 .sort(SortOptions::default())
294 .first()
295 .struct_()
296 .field_by_name("_val")
297 .alias(format!("min_by({value_col}, {ord_col})"))];
298 let lf = self.lazy_grouped.clone().agg(agg_expr);
299 let mut pl_df = lf.collect()?;
300 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
301 Ok(super::DataFrame::from_polars_with_options(
302 pl_df,
303 self.case_sensitive,
304 ))
305 }
306
307 pub fn covar_pop(&self, col1: &str, col2: &str) -> Result<DataFrame, PolarsError> {
309 use polars::prelude::DataType;
310 let c1 = col(col1).cast(DataType::Float64);
311 let c2 = col(col2).cast(DataType::Float64);
312 let n = len().cast(DataType::Float64);
313 let sum_ab = (c1.clone() * c2.clone()).sum();
314 let sum_a = col(col1).sum().cast(DataType::Float64);
315 let sum_b = col(col2).sum().cast(DataType::Float64);
316 let cov = (sum_ab - sum_a * sum_b / n.clone()) / n;
317 let agg_expr = vec![cov.alias(format!("covar_pop({col1}, {col2})"))];
318 let lf = self.lazy_grouped.clone().agg(agg_expr);
319 let mut pl_df = lf.collect()?;
320 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
321 Ok(super::DataFrame::from_polars_with_options(
322 pl_df,
323 self.case_sensitive,
324 ))
325 }
326
327 pub fn covar_samp(&self, col1: &str, col2: &str) -> Result<DataFrame, PolarsError> {
329 use polars::prelude::DataType;
330 let c1 = col(col1).cast(DataType::Float64);
331 let c2 = col(col2).cast(DataType::Float64);
332 let n = len().cast(DataType::Float64);
333 let sum_ab = (c1.clone() * c2.clone()).sum();
334 let sum_a = col(col1).sum().cast(DataType::Float64);
335 let sum_b = col(col2).sum().cast(DataType::Float64);
336 let cov = when(len().gt(lit(1)))
337 .then((sum_ab - sum_a * sum_b / n.clone()) / (len() - lit(1)).cast(DataType::Float64))
338 .otherwise(lit(f64::NAN));
339 let agg_expr = vec![cov.alias(format!("covar_samp({col1}, {col2})"))];
340 let lf = self.lazy_grouped.clone().agg(agg_expr);
341 let mut pl_df = lf.collect()?;
342 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
343 Ok(super::DataFrame::from_polars_with_options(
344 pl_df,
345 self.case_sensitive,
346 ))
347 }
348
349 pub fn corr(&self, col1: &str, col2: &str) -> Result<DataFrame, PolarsError> {
351 use polars::prelude::DataType;
352 let c1 = col(col1).cast(DataType::Float64);
353 let c2 = col(col2).cast(DataType::Float64);
354 let n = len().cast(DataType::Float64);
355 let n1 = (len() - lit(1)).cast(DataType::Float64);
356 let sum_ab = (c1.clone() * c2.clone()).sum();
357 let sum_a = col(col1).sum().cast(DataType::Float64);
358 let sum_b = col(col2).sum().cast(DataType::Float64);
359 let sum_a2 = (c1.clone() * c1).sum();
360 let sum_b2 = (c2.clone() * c2).sum();
361 let cov_samp = (sum_ab - sum_a.clone() * sum_b.clone() / n.clone()) / n1.clone();
362 let var_a = (sum_a2 - sum_a.clone() * sum_a / n.clone()) / n1.clone();
363 let var_b = (sum_b2 - sum_b.clone() * sum_b / n.clone()) / n1.clone();
364 let std_a = var_a.sqrt();
365 let std_b = var_b.sqrt();
366 let corr_expr = when(len().gt(lit(1)))
367 .then(cov_samp / (std_a * std_b))
368 .otherwise(lit(f64::NAN));
369 let agg_expr = vec![corr_expr.alias(format!("corr({col1}, {col2})"))];
370 let lf = self.lazy_grouped.clone().agg(agg_expr);
371 let mut pl_df = lf.collect()?;
372 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
373 Ok(super::DataFrame::from_polars_with_options(
374 pl_df,
375 self.case_sensitive,
376 ))
377 }
378
379 pub fn regr_count(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
381 let agg_expr = vec![crate::functions::regr_count_expr(y_col, x_col)
382 .alias(format!("regr_count({y_col}, {x_col})"))];
383 let lf = self.lazy_grouped.clone().agg(agg_expr);
384 let mut pl_df = lf.collect()?;
385 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
386 Ok(super::DataFrame::from_polars_with_options(
387 pl_df,
388 self.case_sensitive,
389 ))
390 }
391
392 pub fn regr_avgx(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
394 let agg_expr = vec![crate::functions::regr_avgx_expr(y_col, x_col)
395 .alias(format!("regr_avgx({y_col}, {x_col})"))];
396 let lf = self.lazy_grouped.clone().agg(agg_expr);
397 let mut pl_df = lf.collect()?;
398 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
399 Ok(super::DataFrame::from_polars_with_options(
400 pl_df,
401 self.case_sensitive,
402 ))
403 }
404
405 pub fn regr_avgy(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
407 let agg_expr = vec![crate::functions::regr_avgy_expr(y_col, x_col)
408 .alias(format!("regr_avgy({y_col}, {x_col})"))];
409 let lf = self.lazy_grouped.clone().agg(agg_expr);
410 let mut pl_df = lf.collect()?;
411 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
412 Ok(super::DataFrame::from_polars_with_options(
413 pl_df,
414 self.case_sensitive,
415 ))
416 }
417
418 pub fn regr_slope(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
420 let agg_expr = vec![crate::functions::regr_slope_expr(y_col, x_col)
421 .alias(format!("regr_slope({y_col}, {x_col})"))];
422 let lf = self.lazy_grouped.clone().agg(agg_expr);
423 let mut pl_df = lf.collect()?;
424 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
425 Ok(super::DataFrame::from_polars_with_options(
426 pl_df,
427 self.case_sensitive,
428 ))
429 }
430
431 pub fn regr_intercept(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
433 let agg_expr = vec![crate::functions::regr_intercept_expr(y_col, x_col)
434 .alias(format!("regr_intercept({y_col}, {x_col})"))];
435 let lf = self.lazy_grouped.clone().agg(agg_expr);
436 let mut pl_df = lf.collect()?;
437 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
438 Ok(super::DataFrame::from_polars_with_options(
439 pl_df,
440 self.case_sensitive,
441 ))
442 }
443
444 pub fn regr_r2(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
446 let agg_expr = vec![crate::functions::regr_r2_expr(y_col, x_col)
447 .alias(format!("regr_r2({y_col}, {x_col})"))];
448 let lf = self.lazy_grouped.clone().agg(agg_expr);
449 let mut pl_df = lf.collect()?;
450 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
451 Ok(super::DataFrame::from_polars_with_options(
452 pl_df,
453 self.case_sensitive,
454 ))
455 }
456
457 pub fn regr_sxx(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
459 let agg_expr = vec![crate::functions::regr_sxx_expr(y_col, x_col)
460 .alias(format!("regr_sxx({y_col}, {x_col})"))];
461 let lf = self.lazy_grouped.clone().agg(agg_expr);
462 let mut pl_df = lf.collect()?;
463 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
464 Ok(super::DataFrame::from_polars_with_options(
465 pl_df,
466 self.case_sensitive,
467 ))
468 }
469
470 pub fn regr_syy(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
472 let agg_expr = vec![crate::functions::regr_syy_expr(y_col, x_col)
473 .alias(format!("regr_syy({y_col}, {x_col})"))];
474 let lf = self.lazy_grouped.clone().agg(agg_expr);
475 let mut pl_df = lf.collect()?;
476 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
477 Ok(super::DataFrame::from_polars_with_options(
478 pl_df,
479 self.case_sensitive,
480 ))
481 }
482
483 pub fn regr_sxy(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
485 let agg_expr = vec![crate::functions::regr_sxy_expr(y_col, x_col)
486 .alias(format!("regr_sxy({y_col}, {x_col})"))];
487 let lf = self.lazy_grouped.clone().agg(agg_expr);
488 let mut pl_df = lf.collect()?;
489 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
490 Ok(super::DataFrame::from_polars_with_options(
491 pl_df,
492 self.case_sensitive,
493 ))
494 }
495
496 pub fn kurtosis(&self, column: &str) -> Result<DataFrame, PolarsError> {
498 use polars::prelude::*;
499 let agg_expr = vec![col(column)
500 .cast(DataType::Float64)
501 .kurtosis(true, true)
502 .alias(format!("kurtosis({column})"))];
503 let lf = self.lazy_grouped.clone().agg(agg_expr);
504 let mut pl_df = lf.collect()?;
505 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
506 Ok(super::DataFrame::from_polars_with_options(
507 pl_df,
508 self.case_sensitive,
509 ))
510 }
511
512 pub fn skewness(&self, column: &str) -> Result<DataFrame, PolarsError> {
514 use polars::prelude::*;
515 let agg_expr = vec![col(column)
516 .cast(DataType::Float64)
517 .skew(true)
518 .alias(format!("skewness({column})"))];
519 let lf = self.lazy_grouped.clone().agg(agg_expr);
520 let mut pl_df = lf.collect()?;
521 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
522 Ok(super::DataFrame::from_polars_with_options(
523 pl_df,
524 self.case_sensitive,
525 ))
526 }
527
528 pub fn agg(&self, aggregations: Vec<Expr>) -> Result<DataFrame, PolarsError> {
530 let lf = self.lazy_grouped.clone().agg(aggregations);
531 let mut pl_df = lf.collect()?;
532 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
533 Ok(super::DataFrame::from_polars_with_options(
534 pl_df,
535 self.case_sensitive,
536 ))
537 }
538
539 pub fn grouping_columns(&self) -> &[String] {
541 &self.grouping_cols
542 }
543}
544
545pub struct CubeRollupData {
547 pub(super) df: PlDataFrame,
548 pub(super) grouping_cols: Vec<String>,
549 pub(super) case_sensitive: bool,
550 pub(super) is_cube: bool,
551}
552
553impl CubeRollupData {
554 pub fn agg(&self, aggregations: Vec<Expr>) -> Result<DataFrame, PolarsError> {
556 use polars::prelude::*;
557 let subsets: Vec<Vec<String>> = if self.is_cube {
558 let n = self.grouping_cols.len();
560 (0..1 << n)
561 .map(|mask| {
562 self.grouping_cols
563 .iter()
564 .enumerate()
565 .filter(|(i, _)| (mask & (1 << i)) != 0)
566 .map(|(_, c)| c.clone())
567 .collect()
568 })
569 .collect()
570 } else {
571 (0..=self.grouping_cols.len())
573 .map(|len| self.grouping_cols[..len].to_vec())
574 .collect()
575 };
576
577 let schema = self.df.schema();
578 let mut parts: Vec<PlDataFrame> = Vec::with_capacity(subsets.len());
579 for subset in subsets {
580 if subset.is_empty() {
581 let lf = self.df.clone().lazy().select(aggregations.clone());
583 let mut part = lf.collect()?;
584 let n = part.height();
585 for gc in &self.grouping_cols {
586 let dtype = schema.get(gc).cloned().unwrap_or(DataType::Null);
587 let null_series = null_series_for_dtype(gc.as_str(), n, &dtype)?;
588 part.with_column(null_series)?;
589 }
590 let mut order: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
592 for name in part.get_column_names() {
593 if !self.grouping_cols.iter().any(|g| g == name) {
594 order.push(name);
595 }
596 }
597 part = part.select(order)?;
598 parts.push(part);
599 } else {
600 let grouped = self
601 .df
602 .clone()
603 .lazy()
604 .group_by(subset.iter().map(|s| col(s.as_str())).collect::<Vec<_>>());
605 let mut part = grouped.agg(aggregations.clone()).collect()?;
606 part = reorder_groupby_columns(&mut part, &subset)?;
607 let n = part.height();
608 for gc in &self.grouping_cols {
609 if subset.iter().any(|s| s == gc) {
610 continue;
611 }
612 let dtype = schema.get(gc).cloned().unwrap_or(DataType::Null);
613 let null_series = null_series_for_dtype(gc.as_str(), n, &dtype)?;
614 part.with_column(null_series)?;
615 }
616 let mut order: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
617 for name in part.get_column_names() {
618 if !self.grouping_cols.iter().any(|g| g == name) {
619 order.push(name);
620 }
621 }
622 part = part.select(order)?;
623 parts.push(part);
624 }
625 }
626
627 if parts.is_empty() {
628 return Ok(super::DataFrame::from_polars_with_options(
629 PlDataFrame::empty(),
630 self.case_sensitive,
631 ));
632 }
633 let first_schema = parts[0].schema();
634 let order: Vec<&str> = first_schema.iter_names().map(|s| s.as_str()).collect();
635 for p in parts.iter_mut().skip(1) {
636 *p = p.select(order.clone())?;
637 }
638 let lazy_frames: Vec<_> = parts.into_iter().map(|p| p.lazy()).collect();
639 let out = polars::prelude::concat(lazy_frames, UnionArgs::default())?.collect()?;
640 Ok(super::DataFrame::from_polars_with_options(
641 out,
642 self.case_sensitive,
643 ))
644 }
645}
646
647fn null_series_for_dtype(name: &str, n: usize, dtype: &DataType) -> Result<Series, PolarsError> {
648 let name = name.into();
649 let s = match dtype {
650 DataType::Int32 => Series::new(name, vec![None::<i32>; n]),
651 DataType::Int64 => Series::new(name, vec![None::<i64>; n]),
652 DataType::Float32 => Series::new(name, vec![None::<f32>; n]),
653 DataType::Float64 => Series::new(name, vec![None::<f64>; n]),
654 DataType::String => {
655 let v: Vec<Option<String>> = (0..n).map(|_| None).collect();
656 Series::new(name, v)
657 }
658 DataType::Boolean => Series::new(name, vec![None::<bool>; n]),
659 DataType::Date => Series::new(name, vec![None::<i32>; n]).cast(dtype)?,
660 DataType::Datetime(_, _) => Series::new(name, vec![None::<i64>; n]).cast(dtype)?,
661 _ => Series::new(name, vec![None::<i64>; n]).cast(dtype)?,
662 };
663 Ok(s)
664}
665
666pub(super) fn reorder_groupby_columns(
668 pl_df: &mut PlDataFrame,
669 grouping_cols: &[String],
670) -> Result<PlDataFrame, PolarsError> {
671 let all_cols: Vec<String> = pl_df
672 .get_column_names()
673 .iter()
674 .map(|s| s.to_string())
675 .collect();
676 let mut reordered_cols: Vec<&str> = Vec::new();
677 for gc in grouping_cols {
678 if all_cols.iter().any(|c| c == gc) {
679 reordered_cols.push(gc);
680 }
681 }
682 for col_name in &all_cols {
683 if !grouping_cols.iter().any(|gc| gc == col_name) {
684 reordered_cols.push(col_name);
685 }
686 }
687 if !reordered_cols.is_empty() && reordered_cols.len() == all_cols.len() {
688 pl_df.select(reordered_cols)
689 } else {
690 Ok(pl_df.clone())
691 }
692}
693
#[cfg(test)]
mod tests {
    use crate::{DataFrame, SparkSession};

    /// Session shared by the fixtures below.
    fn session() -> SparkSession {
        SparkSession::builder()
            .app_name("agg_tests")
            .get_or_create()
    }

    /// Three rows over columns `k`, `v`, `label`; `k` takes the values 1, 1, 2.
    fn test_df() -> DataFrame {
        let rows = vec![
            (1i64, 10i64, "a".to_string()),
            (1i64, 20i64, "a".to_string()),
            (2i64, 30i64, "b".to_string()),
        ];
        session()
            .create_dataframe(rows, vec!["k", "v", "label"])
            .unwrap()
    }

    #[test]
    fn group_by_count_single_group() {
        let counted = test_df().group_by(vec!["k"]).unwrap().count().unwrap();
        assert_eq!(counted.count().unwrap(), 2);
        let names = counted.columns().unwrap();
        for expected in ["k", "count"] {
            assert!(names.iter().any(|c| c == expected));
        }
    }

    #[test]
    fn group_by_sum() {
        let summed = test_df().group_by(vec!["k"]).unwrap().sum("v").unwrap();
        assert_eq!(summed.count().unwrap(), 2);
        let names = summed.columns().unwrap();
        assert!(names.iter().any(|c| c.starts_with("sum(")));
    }

    #[test]
    fn group_by_empty_groups() {
        let no_rows: Vec<(i64, i64, String)> = Vec::new();
        let empty = session()
            .create_dataframe(no_rows, vec!["a", "b", "c"])
            .unwrap();
        let counted = empty.group_by(vec!["a"]).unwrap().count().unwrap();
        assert_eq!(counted.count().unwrap(), 0);
    }

    #[test]
    fn group_by_agg_multi() {
        use polars::prelude::*;
        let aggregated = test_df()
            .group_by(vec!["k"])
            .unwrap()
            .agg(vec![len().alias("cnt"), col("v").sum().alias("total")])
            .unwrap();
        assert_eq!(aggregated.count().unwrap(), 2);
        let names = aggregated.columns().unwrap();
        for expected in ["k", "cnt", "total"] {
            assert!(names.iter().any(|c| c == expected));
        }
    }
}