1use super::DataFrame;
4use crate::column::Column;
5use polars::prelude::{
6 DataFrame as PlDataFrame, DataType, Expr, LazyFrame, LazyGroupBy, NamedFrom, PolarsError,
7 SchemaNamesAndDtypes, Series, col, len, lit, when,
8};
9use std::collections::HashMap;
10
11pub(crate) fn disambiguate_agg_output_names(aggregations: Vec<Expr>) -> Vec<Expr> {
15 let mut name_count: HashMap<String, u32> = HashMap::new();
16 aggregations
17 .into_iter()
18 .map(|e| {
19 let base_name = polars_plan::utils::expr_output_name(&e)
20 .map(|s| s.to_string())
21 .unwrap_or_else(|_| "_".to_string());
22 let count = name_count.entry(base_name.clone()).or_insert(0);
23 *count += 1;
24 let final_name = if *count == 1 {
25 base_name
26 } else {
27 format!("{}_{}", base_name, *count - 1)
28 };
29 if *count == 1 {
30 e
31 } else {
32 e.alias(final_name.as_str())
33 }
34 })
35 .collect()
36}
37
38pub struct GroupedData {
41 pub(crate) lf: LazyFrame,
42 pub(crate) lazy_grouped: LazyGroupBy,
43 pub(crate) grouping_cols: Vec<String>,
44 pub(crate) case_sensitive: bool,
45}
46
47impl GroupedData {
48 fn resolve_column(&self, name: &str) -> Result<String, PolarsError> {
50 let schema = self.lf.clone().collect_schema()?;
51 let names: Vec<String> = schema
52 .iter_names_and_dtypes()
53 .map(|(n, _)| n.to_string())
54 .collect();
55 if self.case_sensitive {
56 if names.iter().any(|n| n == name) {
57 return Ok(name.to_string());
58 }
59 } else {
60 let name_lower = name.to_lowercase();
61 for n in &names {
62 if n.to_lowercase() == name_lower {
63 return Ok(n.clone());
64 }
65 }
66 }
67 let available = names.join(", ");
68 Err(PolarsError::ColumnNotFound(
69 format!(
70 "Column '{}' not found in grouped DataFrame. Available: [{}].",
71 name, available
72 )
73 .into(),
74 ))
75 }
76
77 pub fn count(&self) -> Result<DataFrame, PolarsError> {
79 use polars::prelude::*;
80 let agg_expr = vec![len().alias("count")];
81 let lf = self.lazy_grouped.clone().agg(agg_expr);
82 let mut pl_df = lf.collect()?;
83 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
84 Ok(super::DataFrame::from_polars_with_options(
85 pl_df,
86 self.case_sensitive,
87 ))
88 }
89
90 pub fn sum(&self, column: &str) -> Result<DataFrame, PolarsError> {
92 use polars::prelude::*;
93 let c = self.resolve_column(column)?;
94 let agg_expr = vec![col(c.as_str()).sum().alias(format!("sum({column})"))];
95 let lf = self.lazy_grouped.clone().agg(agg_expr);
96 let mut pl_df = lf.collect()?;
97 let all_cols: Vec<String> = pl_df
98 .get_column_names()
99 .iter()
100 .map(|s| s.to_string())
101 .collect();
102 let grouping_cols: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
103 let mut reordered_cols: Vec<&str> = Vec::new();
104 for gc in &grouping_cols {
105 if all_cols.iter().any(|c| c == gc) {
106 reordered_cols.push(gc);
107 }
108 }
109 for col_name in &all_cols {
110 if !grouping_cols.iter().any(|gc| *gc == col_name) {
111 reordered_cols.push(col_name);
112 }
113 }
114 if !reordered_cols.is_empty() {
115 pl_df = pl_df.select(reordered_cols)?;
116 }
117 Ok(super::DataFrame::from_polars_with_options(
118 pl_df,
119 self.case_sensitive,
120 ))
121 }
122
123 pub fn avg(&self, columns: &[&str]) -> Result<DataFrame, PolarsError> {
125 if columns.is_empty() {
126 return Err(PolarsError::ComputeError(
127 "avg requires at least one column".into(),
128 ));
129 }
130 use polars::prelude::*;
131 let agg_expr: Vec<Expr> = columns
132 .iter()
133 .map(|c| {
134 let resolved = self.resolve_column(c)?;
135 Ok(col(resolved.as_str()).mean().alias(format!("avg({c})")))
136 })
137 .collect::<Result<Vec<_>, PolarsError>>()?;
138 let lf = self.lazy_grouped.clone().agg(agg_expr);
139 let mut pl_df = lf.collect()?;
140 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
141 Ok(super::DataFrame::from_polars_with_options(
142 pl_df,
143 self.case_sensitive,
144 ))
145 }
146
147 pub fn min(&self, column: &str) -> Result<DataFrame, PolarsError> {
149 use polars::prelude::*;
150 let c = self.resolve_column(column)?;
151 let agg_expr = vec![col(c.as_str()).min().alias(format!("min({column})"))];
152 let lf = self.lazy_grouped.clone().agg(agg_expr);
153 let mut pl_df = lf.collect()?;
154 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
155 Ok(super::DataFrame::from_polars_with_options(
156 pl_df,
157 self.case_sensitive,
158 ))
159 }
160
161 pub fn max(&self, column: &str) -> Result<DataFrame, PolarsError> {
163 use polars::prelude::*;
164 let c = self.resolve_column(column)?;
165 let agg_expr = vec![col(c.as_str()).max().alias(format!("max({column})"))];
166 let lf = self.lazy_grouped.clone().agg(agg_expr);
167 let mut pl_df = lf.collect()?;
168 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
169 Ok(super::DataFrame::from_polars_with_options(
170 pl_df,
171 self.case_sensitive,
172 ))
173 }
174
175 pub fn first(&self, column: &str) -> Result<DataFrame, PolarsError> {
177 use polars::prelude::*;
178 let c = self.resolve_column(column)?;
179 let agg_expr = vec![col(c.as_str()).first().alias(format!("first({column})"))];
180 let lf = self.lazy_grouped.clone().agg(agg_expr);
181 let mut pl_df = lf.collect()?;
182 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
183 Ok(super::DataFrame::from_polars_with_options(
184 pl_df,
185 self.case_sensitive,
186 ))
187 }
188
189 pub fn last(&self, column: &str) -> Result<DataFrame, PolarsError> {
191 use polars::prelude::*;
192 let c = self.resolve_column(column)?;
193 let agg_expr = vec![col(c.as_str()).last().alias(format!("last({column})"))];
194 let lf = self.lazy_grouped.clone().agg(agg_expr);
195 let mut pl_df = lf.collect()?;
196 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
197 Ok(super::DataFrame::from_polars_with_options(
198 pl_df,
199 self.case_sensitive,
200 ))
201 }
202
203 pub fn approx_count_distinct(&self, column: &str) -> Result<DataFrame, PolarsError> {
205 use polars::prelude::{DataType, col};
206 let c = self.resolve_column(column)?;
207 let agg_expr = vec![
208 col(c.as_str())
209 .n_unique()
210 .cast(DataType::Int64)
211 .alias(format!("approx_count_distinct({column})")),
212 ];
213 let lf = self.lazy_grouped.clone().agg(agg_expr);
214 let mut pl_df = lf.collect()?;
215 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
216 Ok(super::DataFrame::from_polars_with_options(
217 pl_df,
218 self.case_sensitive,
219 ))
220 }
221
222 pub fn any_value(&self, column: &str) -> Result<DataFrame, PolarsError> {
224 use polars::prelude::*;
225 let c = self.resolve_column(column)?;
226 let agg_expr = vec![
227 col(c.as_str())
228 .first()
229 .alias(format!("any_value({column})")),
230 ];
231 let lf = self.lazy_grouped.clone().agg(agg_expr);
232 let mut pl_df = lf.collect()?;
233 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
234 Ok(super::DataFrame::from_polars_with_options(
235 pl_df,
236 self.case_sensitive,
237 ))
238 }
239
240 pub fn bool_and(&self, column: &str) -> Result<DataFrame, PolarsError> {
242 use polars::prelude::*;
243 let c = self.resolve_column(column)?;
244 let agg_expr = vec![
245 col(c.as_str())
246 .all(true)
247 .alias(format!("bool_and({column})")),
248 ];
249 let lf = self.lazy_grouped.clone().agg(agg_expr);
250 let mut pl_df = lf.collect()?;
251 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
252 Ok(super::DataFrame::from_polars_with_options(
253 pl_df,
254 self.case_sensitive,
255 ))
256 }
257
258 pub fn bool_or(&self, column: &str) -> Result<DataFrame, PolarsError> {
260 use polars::prelude::*;
261 let c = self.resolve_column(column)?;
262 let agg_expr = vec![
263 col(c.as_str())
264 .any(true)
265 .alias(format!("bool_or({column})")),
266 ];
267 let lf = self.lazy_grouped.clone().agg(agg_expr);
268 let mut pl_df = lf.collect()?;
269 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
270 Ok(super::DataFrame::from_polars_with_options(
271 pl_df,
272 self.case_sensitive,
273 ))
274 }
275
276 pub fn product(&self, column: &str) -> Result<DataFrame, PolarsError> {
278 use polars::prelude::*;
279 let c = self.resolve_column(column)?;
280 let agg_expr = vec![
281 col(c.as_str())
282 .product()
283 .alias(format!("product({column})")),
284 ];
285 let lf = self.lazy_grouped.clone().agg(agg_expr);
286 let mut pl_df = lf.collect()?;
287 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
288 Ok(super::DataFrame::from_polars_with_options(
289 pl_df,
290 self.case_sensitive,
291 ))
292 }
293
294 pub fn collect_list(&self, column: &str) -> Result<DataFrame, PolarsError> {
296 use polars::prelude::*;
297 let c = self.resolve_column(column)?;
298 let agg_expr = vec![
299 col(c.as_str())
300 .implode()
301 .alias(format!("collect_list({column})")),
302 ];
303 let lf = self.lazy_grouped.clone().agg(agg_expr);
304 let mut pl_df = lf.collect()?;
305 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
306 Ok(super::DataFrame::from_polars_with_options(
307 pl_df,
308 self.case_sensitive,
309 ))
310 }
311
312 pub fn collect_set(&self, column: &str) -> Result<DataFrame, PolarsError> {
314 use polars::prelude::*;
315 let c = self.resolve_column(column)?;
316 let agg_expr = vec![
317 col(c.as_str())
318 .unique()
319 .implode()
320 .alias(format!("collect_set({column})")),
321 ];
322 let lf = self.lazy_grouped.clone().agg(agg_expr);
323 let mut pl_df = lf.collect()?;
324 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
325 Ok(super::DataFrame::from_polars_with_options(
326 pl_df,
327 self.case_sensitive,
328 ))
329 }
330
331 pub fn count_if(&self, column: &str) -> Result<DataFrame, PolarsError> {
333 use polars::prelude::*;
334 let c = self.resolve_column(column)?;
335 let agg_expr = vec![
336 col(c.as_str())
337 .cast(DataType::Int64)
338 .sum()
339 .alias(format!("count_if({column})")),
340 ];
341 let lf = self.lazy_grouped.clone().agg(agg_expr);
342 let mut pl_df = lf.collect()?;
343 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
344 Ok(super::DataFrame::from_polars_with_options(
345 pl_df,
346 self.case_sensitive,
347 ))
348 }
349
350 pub fn percentile(&self, column: &str, p: f64) -> Result<DataFrame, PolarsError> {
352 use polars::prelude::*;
353 let c = self.resolve_column(column)?;
354 let agg_expr = vec![
355 col(c.as_str())
356 .quantile(lit(p), QuantileMethod::Linear)
357 .alias(format!("percentile({column}, {p})")),
358 ];
359 let lf = self.lazy_grouped.clone().agg(agg_expr);
360 let mut pl_df = lf.collect()?;
361 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
362 Ok(super::DataFrame::from_polars_with_options(
363 pl_df,
364 self.case_sensitive,
365 ))
366 }
367
368 pub fn max_by(&self, value_col: &str, ord_col: &str) -> Result<DataFrame, PolarsError> {
370 use polars::prelude::*;
371 let vc = self.resolve_column(value_col)?;
372 let oc = self.resolve_column(ord_col)?;
373 let st = as_struct(vec![
374 col(oc.as_str()).alias("_ord"),
375 col(vc.as_str()).alias("_val"),
376 ]);
377 let agg_expr = vec![
378 st.sort(SortOptions::default().with_order_descending(true))
379 .first()
380 .struct_()
381 .field_by_name("_val")
382 .alias(format!("max_by({value_col}, {ord_col})")),
383 ];
384 let lf = self.lazy_grouped.clone().agg(agg_expr);
385 let mut pl_df = lf.collect()?;
386 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
387 Ok(super::DataFrame::from_polars_with_options(
388 pl_df,
389 self.case_sensitive,
390 ))
391 }
392
393 pub fn min_by(&self, value_col: &str, ord_col: &str) -> Result<DataFrame, PolarsError> {
395 use polars::prelude::*;
396 let vc = self.resolve_column(value_col)?;
397 let oc = self.resolve_column(ord_col)?;
398 let st = as_struct(vec![
399 col(oc.as_str()).alias("_ord"),
400 col(vc.as_str()).alias("_val"),
401 ]);
402 let agg_expr = vec![
403 st.sort(SortOptions::default())
404 .first()
405 .struct_()
406 .field_by_name("_val")
407 .alias(format!("min_by({value_col}, {ord_col})")),
408 ];
409 let lf = self.lazy_grouped.clone().agg(agg_expr);
410 let mut pl_df = lf.collect()?;
411 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
412 Ok(super::DataFrame::from_polars_with_options(
413 pl_df,
414 self.case_sensitive,
415 ))
416 }
417
418 pub fn covar_pop(&self, col1: &str, col2: &str) -> Result<DataFrame, PolarsError> {
420 use polars::prelude::DataType;
421 let c1_res = self.resolve_column(col1)?;
422 let c2_res = self.resolve_column(col2)?;
423 let c1 = col(c1_res.as_str()).cast(DataType::Float64);
424 let c2 = col(c2_res.as_str()).cast(DataType::Float64);
425 let n = len().cast(DataType::Float64);
426 let sum_ab = (c1.clone() * c2.clone()).sum();
427 let sum_a = col(c1_res.as_str()).sum().cast(DataType::Float64);
428 let sum_b = col(c2_res.as_str()).sum().cast(DataType::Float64);
429 let cov = (sum_ab - sum_a * sum_b / n.clone()) / n;
430 let agg_expr = vec![cov.alias(format!("covar_pop({col1}, {col2})"))];
431 let lf = self.lazy_grouped.clone().agg(agg_expr);
432 let mut pl_df = lf.collect()?;
433 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
434 Ok(super::DataFrame::from_polars_with_options(
435 pl_df,
436 self.case_sensitive,
437 ))
438 }
439
440 pub fn covar_samp(&self, col1: &str, col2: &str) -> Result<DataFrame, PolarsError> {
442 use polars::prelude::DataType;
443 let c1_res = self.resolve_column(col1)?;
444 let c2_res = self.resolve_column(col2)?;
445 let c1 = col(c1_res.as_str()).cast(DataType::Float64);
446 let c2 = col(c2_res.as_str()).cast(DataType::Float64);
447 let n = len().cast(DataType::Float64);
448 let sum_ab = (c1.clone() * c2.clone()).sum();
449 let sum_a = col(c1_res.as_str()).sum().cast(DataType::Float64);
450 let sum_b = col(c2_res.as_str()).sum().cast(DataType::Float64);
451 let cov = when(len().gt(lit(1)))
452 .then((sum_ab - sum_a * sum_b / n.clone()) / (len() - lit(1)).cast(DataType::Float64))
453 .otherwise(lit(f64::NAN));
454 let agg_expr = vec![cov.alias(format!("covar_samp({col1}, {col2})"))];
455 let lf = self.lazy_grouped.clone().agg(agg_expr);
456 let mut pl_df = lf.collect()?;
457 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
458 Ok(super::DataFrame::from_polars_with_options(
459 pl_df,
460 self.case_sensitive,
461 ))
462 }
463
464 pub fn corr(&self, col1: &str, col2: &str) -> Result<DataFrame, PolarsError> {
466 use polars::prelude::DataType;
467 let c1_res = self.resolve_column(col1)?;
468 let c2_res = self.resolve_column(col2)?;
469 let c1 = col(c1_res.as_str()).cast(DataType::Float64);
470 let c2 = col(c2_res.as_str()).cast(DataType::Float64);
471 let n = len().cast(DataType::Float64);
472 let n1 = (len() - lit(1)).cast(DataType::Float64);
473 let sum_ab = (c1.clone() * c2.clone()).sum();
474 let sum_a = col(c1_res.as_str()).sum().cast(DataType::Float64);
475 let sum_b = col(c2_res.as_str()).sum().cast(DataType::Float64);
476 let sum_a2 = (c1.clone() * c1).sum();
477 let sum_b2 = (c2.clone() * c2).sum();
478 let cov_samp = (sum_ab - sum_a.clone() * sum_b.clone() / n.clone()) / n1.clone();
479 let var_a = (sum_a2 - sum_a.clone() * sum_a / n.clone()) / n1.clone();
480 let var_b = (sum_b2 - sum_b.clone() * sum_b / n.clone()) / n1.clone();
481 let std_a = var_a.sqrt();
482 let std_b = var_b.sqrt();
483 let corr_expr = when(len().gt(lit(1)))
484 .then(cov_samp / (std_a * std_b))
485 .otherwise(lit(f64::NAN));
486 let agg_expr = vec![corr_expr.alias(format!("corr({col1}, {col2})"))];
487 let lf = self.lazy_grouped.clone().agg(agg_expr);
488 let mut pl_df = lf.collect()?;
489 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
490 Ok(super::DataFrame::from_polars_with_options(
491 pl_df,
492 self.case_sensitive,
493 ))
494 }
495
496 pub fn regr_count(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
498 let yc = self.resolve_column(y_col)?;
499 let xc = self.resolve_column(x_col)?;
500 let agg_expr = vec![
501 crate::functions::regr_count_expr(yc.as_str(), xc.as_str())
502 .alias(format!("regr_count({y_col}, {x_col})")),
503 ];
504 let lf = self.lazy_grouped.clone().agg(agg_expr);
505 let mut pl_df = lf.collect()?;
506 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
507 Ok(super::DataFrame::from_polars_with_options(
508 pl_df,
509 self.case_sensitive,
510 ))
511 }
512
513 pub fn regr_avgx(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
515 let yc = self.resolve_column(y_col)?;
516 let xc = self.resolve_column(x_col)?;
517 let agg_expr = vec![
518 crate::functions::regr_avgx_expr(yc.as_str(), xc.as_str())
519 .alias(format!("regr_avgx({y_col}, {x_col})")),
520 ];
521 let lf = self.lazy_grouped.clone().agg(agg_expr);
522 let mut pl_df = lf.collect()?;
523 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
524 Ok(super::DataFrame::from_polars_with_options(
525 pl_df,
526 self.case_sensitive,
527 ))
528 }
529
530 pub fn regr_avgy(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
532 let yc = self.resolve_column(y_col)?;
533 let xc = self.resolve_column(x_col)?;
534 let agg_expr = vec![
535 crate::functions::regr_avgy_expr(yc.as_str(), xc.as_str())
536 .alias(format!("regr_avgy({y_col}, {x_col})")),
537 ];
538 let lf = self.lazy_grouped.clone().agg(agg_expr);
539 let mut pl_df = lf.collect()?;
540 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
541 Ok(super::DataFrame::from_polars_with_options(
542 pl_df,
543 self.case_sensitive,
544 ))
545 }
546
547 pub fn regr_slope(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
549 let yc = self.resolve_column(y_col)?;
550 let xc = self.resolve_column(x_col)?;
551 let agg_expr = vec![
552 crate::functions::regr_slope_expr(yc.as_str(), xc.as_str())
553 .alias(format!("regr_slope({y_col}, {x_col})")),
554 ];
555 let lf = self.lazy_grouped.clone().agg(agg_expr);
556 let mut pl_df = lf.collect()?;
557 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
558 Ok(super::DataFrame::from_polars_with_options(
559 pl_df,
560 self.case_sensitive,
561 ))
562 }
563
564 pub fn regr_intercept(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
566 let yc = self.resolve_column(y_col)?;
567 let xc = self.resolve_column(x_col)?;
568 let agg_expr = vec![
569 crate::functions::regr_intercept_expr(yc.as_str(), xc.as_str())
570 .alias(format!("regr_intercept({y_col}, {x_col})")),
571 ];
572 let lf = self.lazy_grouped.clone().agg(agg_expr);
573 let mut pl_df = lf.collect()?;
574 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
575 Ok(super::DataFrame::from_polars_with_options(
576 pl_df,
577 self.case_sensitive,
578 ))
579 }
580
581 pub fn regr_r2(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
583 let yc = self.resolve_column(y_col)?;
584 let xc = self.resolve_column(x_col)?;
585 let agg_expr = vec![
586 crate::functions::regr_r2_expr(yc.as_str(), xc.as_str())
587 .alias(format!("regr_r2({y_col}, {x_col})")),
588 ];
589 let lf = self.lazy_grouped.clone().agg(agg_expr);
590 let mut pl_df = lf.collect()?;
591 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
592 Ok(super::DataFrame::from_polars_with_options(
593 pl_df,
594 self.case_sensitive,
595 ))
596 }
597
598 pub fn regr_sxx(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
600 let yc = self.resolve_column(y_col)?;
601 let xc = self.resolve_column(x_col)?;
602 let agg_expr = vec![
603 crate::functions::regr_sxx_expr(yc.as_str(), xc.as_str())
604 .alias(format!("regr_sxx({y_col}, {x_col})")),
605 ];
606 let lf = self.lazy_grouped.clone().agg(agg_expr);
607 let mut pl_df = lf.collect()?;
608 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
609 Ok(super::DataFrame::from_polars_with_options(
610 pl_df,
611 self.case_sensitive,
612 ))
613 }
614
615 pub fn regr_syy(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
617 let yc = self.resolve_column(y_col)?;
618 let xc = self.resolve_column(x_col)?;
619 let agg_expr = vec![
620 crate::functions::regr_syy_expr(yc.as_str(), xc.as_str())
621 .alias(format!("regr_syy({y_col}, {x_col})")),
622 ];
623 let lf = self.lazy_grouped.clone().agg(agg_expr);
624 let mut pl_df = lf.collect()?;
625 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
626 Ok(super::DataFrame::from_polars_with_options(
627 pl_df,
628 self.case_sensitive,
629 ))
630 }
631
632 pub fn regr_sxy(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
634 let yc = self.resolve_column(y_col)?;
635 let xc = self.resolve_column(x_col)?;
636 let agg_expr = vec![
637 crate::functions::regr_sxy_expr(yc.as_str(), xc.as_str())
638 .alias(format!("regr_sxy({y_col}, {x_col})")),
639 ];
640 let lf = self.lazy_grouped.clone().agg(agg_expr);
641 let mut pl_df = lf.collect()?;
642 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
643 Ok(super::DataFrame::from_polars_with_options(
644 pl_df,
645 self.case_sensitive,
646 ))
647 }
648
649 pub fn kurtosis(&self, column: &str) -> Result<DataFrame, PolarsError> {
651 use polars::prelude::*;
652 let c = self.resolve_column(column)?;
653 let agg_expr = vec![
654 col(c.as_str())
655 .cast(DataType::Float64)
656 .kurtosis(true, true)
657 .alias(format!("kurtosis({column})")),
658 ];
659 let lf = self.lazy_grouped.clone().agg(agg_expr);
660 let mut pl_df = lf.collect()?;
661 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
662 Ok(super::DataFrame::from_polars_with_options(
663 pl_df,
664 self.case_sensitive,
665 ))
666 }
667
668 pub fn skewness(&self, column: &str) -> Result<DataFrame, PolarsError> {
670 use polars::prelude::*;
671 let c = self.resolve_column(column)?;
672 let agg_expr = vec![
673 col(c.as_str())
674 .cast(DataType::Float64)
675 .skew(true)
676 .alias(format!("skewness({column})")),
677 ];
678 let lf = self.lazy_grouped.clone().agg(agg_expr);
679 let mut pl_df = lf.collect()?;
680 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
681 Ok(super::DataFrame::from_polars_with_options(
682 pl_df,
683 self.case_sensitive,
684 ))
685 }
686
687 pub fn agg(&self, aggregations: Vec<Expr>) -> Result<DataFrame, PolarsError> {
690 let disambiguated = disambiguate_agg_output_names(aggregations);
691 let lf = self.lazy_grouped.clone().agg(disambiguated);
692 let mut pl_df = lf.collect()?;
693 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
694 Ok(super::DataFrame::from_polars_with_options(
695 pl_df,
696 self.case_sensitive,
697 ))
698 }
699
700 pub fn agg_columns(&self, aggregations: Vec<Column>) -> Result<DataFrame, PolarsError> {
704 let exprs: Vec<Expr> = aggregations.into_iter().map(|c| c.into_expr()).collect();
705 self.agg(exprs)
706 }
707
708 pub fn grouping_columns(&self) -> &[String] {
710 &self.grouping_cols
711 }
712
713 pub fn pivot(&self, pivot_col: &str, values: Option<Vec<String>>) -> PivotedGroupedData {
716 PivotedGroupedData {
717 lf: self.lf.clone(),
718 grouping_cols: self.grouping_cols.clone(),
719 pivot_col: pivot_col.to_string(),
720 values,
721 case_sensitive: self.case_sensitive,
722 }
723 }
724}
725
726pub struct PivotedGroupedData {
728 pub(crate) lf: LazyFrame,
729 pub(crate) grouping_cols: Vec<String>,
730 pub(crate) pivot_col: String,
731 pub(crate) values: Option<Vec<String>>,
732 pub(crate) case_sensitive: bool,
733}
734
735fn pivot_value_to_column_name(av: polars::prelude::AnyValue<'_>) -> String {
737 use polars::prelude::AnyValue;
738 match av {
739 AnyValue::Null => "null".to_string(),
740 AnyValue::String(s) => s.to_string(),
741 _ => av.to_string(),
742 }
743}
744
745fn pivot_values_from_lf(lf: &LazyFrame, pivot_col: &str) -> Result<Vec<String>, PolarsError> {
746 use polars::prelude::*;
747 let pl_df = lf
748 .clone()
749 .select([col(pivot_col)])
750 .unique(None, Default::default())
751 .collect()?;
752 let s = pl_df.column(pivot_col)?;
753 let mut out = Vec::with_capacity(s.len());
754 for i in 0..s.len() {
755 let av = s.get(i)?;
756 out.push(pivot_value_to_column_name(av));
757 }
758 out.sort();
760 Ok(out)
761}
762
763impl PivotedGroupedData {
764 fn resolve_column(&self, name: &str) -> Result<String, PolarsError> {
765 let schema = self.lf.clone().collect_schema()?;
766 let names: Vec<String> = schema
767 .iter_names_and_dtypes()
768 .map(|(n, _)| n.to_string())
769 .collect();
770 if self.case_sensitive {
771 if names.iter().any(|n| n == name) {
772 return Ok(name.to_string());
773 }
774 } else {
775 let name_lower = name.to_lowercase();
776 for n in &names {
777 if n.to_lowercase() == name_lower {
778 return Ok(n.clone());
779 }
780 }
781 }
782 let available = names.join(", ");
783 Err(PolarsError::ColumnNotFound(
784 format!(
785 "Column '{}' not found in pivot DataFrame. Available: [{}].",
786 name, available
787 )
788 .into(),
789 ))
790 }
791
792 fn pivot_values(&self) -> Result<Vec<String>, PolarsError> {
793 if let Some(ref v) = self.values {
794 return Ok(v.clone());
795 }
796 let resolved = self.resolve_column(&self.pivot_col)?;
797 pivot_values_from_lf(&self.lf, &resolved)
798 }
799
800 fn pivot_agg(
801 &self,
802 value_col: &str,
803 agg_fn: fn(Expr) -> Expr,
804 ) -> Result<DataFrame, PolarsError> {
805 use polars::prelude::*;
806 let pivot_resolved = self.resolve_column(&self.pivot_col)?;
807 let value_resolved = self.resolve_column(value_col)?;
808 let pivot_vals = self.pivot_values()?;
809 if pivot_vals.is_empty() {
810 let by: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
811 let lf = self.lf.clone().group_by(by).agg(vec![]);
812 let pl_df = lf.collect()?;
813 return Ok(super::DataFrame::from_polars_with_options(
814 pl_df,
815 self.case_sensitive,
816 ));
817 }
818 let mut agg_exprs: Vec<Expr> = Vec::with_capacity(pivot_vals.len());
819 use polars::prelude::DataType;
820 for v in &pivot_vals {
821 let pred = if v == "null" {
823 col(pivot_resolved.as_str()).is_null()
824 } else {
825 col(pivot_resolved.as_str())
826 .cast(DataType::String)
827 .eq(lit(v.as_str()))
828 };
829 let then_expr = col(value_resolved.as_str());
830 let expr = when(pred).then(then_expr).otherwise(lit(NULL));
831 let has_any = expr
833 .clone()
834 .is_not_null()
835 .cast(DataType::UInt32)
836 .sum()
837 .gt(lit(0));
838 let agg_expr = when(has_any)
839 .then(agg_fn(expr))
840 .otherwise(lit(NULL))
841 .alias(v.as_str());
842 agg_exprs.push(agg_expr);
843 }
844 let by: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
845 let lf = self.lf.clone().group_by(by).agg(agg_exprs);
846 let mut pl_df = lf.collect()?;
847 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
848 Ok(super::DataFrame::from_polars_with_options(
849 pl_df,
850 self.case_sensitive,
851 ))
852 }
853
854 pub fn sum(&self, value_col: &str) -> Result<DataFrame, PolarsError> {
856 self.pivot_agg(value_col, polars::prelude::Expr::sum)
857 }
858
859 pub fn avg(&self, value_col: &str) -> Result<DataFrame, PolarsError> {
861 self.pivot_agg(value_col, polars::prelude::Expr::mean)
862 }
863
864 pub fn min(&self, value_col: &str) -> Result<DataFrame, PolarsError> {
866 self.pivot_agg(value_col, polars::prelude::Expr::min)
867 }
868
869 pub fn max(&self, value_col: &str) -> Result<DataFrame, PolarsError> {
871 self.pivot_agg(value_col, polars::prelude::Expr::max)
872 }
873
874 pub fn count(&self) -> Result<DataFrame, PolarsError> {
876 use polars::prelude::*;
877 let pivot_vals = self.pivot_values()?;
878 if pivot_vals.is_empty() {
879 let by: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
880 let lf = self.lf.clone().group_by(by).agg(vec![]);
881 let pl_df = lf.collect()?;
882 return Ok(super::DataFrame::from_polars_with_options(
883 pl_df,
884 self.case_sensitive,
885 ));
886 }
887 let mut agg_exprs: Vec<Expr> = Vec::with_capacity(pivot_vals.len());
888 use polars::prelude::DataType;
889 let pivot_resolved = self.resolve_column(&self.pivot_col)?;
890 for v in &pivot_vals {
891 let pred = if v == "null" {
892 col(pivot_resolved.as_str()).is_null()
893 } else {
894 col(pivot_resolved.as_str())
895 .cast(DataType::String)
896 .eq(lit(v.as_str()))
897 };
898 let expr = when(pred).then(lit(1)).otherwise(lit(NULL));
899 let has_any = expr
900 .clone()
901 .is_not_null()
902 .cast(DataType::UInt32)
903 .sum()
904 .gt(lit(0));
905 let agg_expr = when(has_any)
906 .then(expr.sum())
907 .otherwise(lit(NULL))
908 .alias(v.as_str());
909 agg_exprs.push(agg_expr);
910 }
911 let by: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
912 let lf = self.lf.clone().group_by(by).agg(agg_exprs);
913 let mut pl_df = lf.collect()?;
914 pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
915 Ok(super::DataFrame::from_polars_with_options(
916 pl_df,
917 self.case_sensitive,
918 ))
919 }
920}
921
922pub struct CubeRollupData {
924 pub(super) lf: LazyFrame,
925 pub(super) grouping_cols: Vec<String>,
926 pub(super) case_sensitive: bool,
927 pub(super) is_cube: bool,
928}
929
930impl CubeRollupData {
931 pub fn count(&self) -> Result<DataFrame, PolarsError> {
933 use polars::prelude::*;
934 self.agg(vec![len().alias("count")])
935 }
936
937 pub fn agg(&self, aggregations: Vec<Expr>) -> Result<DataFrame, PolarsError> {
940 use polars::prelude::*;
941 let aggregations = disambiguate_agg_output_names(aggregations);
942 let subsets: Vec<Vec<String>> = if self.is_cube {
943 let n = self.grouping_cols.len();
945 (0..1 << n)
946 .map(|mask| {
947 self.grouping_cols
948 .iter()
949 .enumerate()
950 .filter(|(i, _)| (mask & (1 << i)) != 0)
951 .map(|(_, c)| c.clone())
952 .collect()
953 })
954 .collect()
955 } else {
956 (0..=self.grouping_cols.len())
958 .map(|len| self.grouping_cols[..len].to_vec())
959 .collect()
960 };
961
962 let schema = self.lf.clone().collect_schema()?;
963 let mut parts: Vec<PlDataFrame> = Vec::with_capacity(subsets.len());
964 for subset in subsets {
965 if subset.is_empty() {
966 let lf = self.lf.clone().select(&aggregations);
968 let mut part = lf.collect()?;
969 let n = part.height();
970 for gc in &self.grouping_cols {
971 let dtype = schema.get(gc).cloned().unwrap_or(DataType::Null);
972 let null_series = null_series_for_dtype(gc.as_str(), n, &dtype)?;
973 part.with_column(null_series.into())?;
974 }
975 let mut order: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
977 for name in part.get_column_names() {
978 if !self.grouping_cols.iter().any(|g| g == name) {
979 order.push(name);
980 }
981 }
982 part = part.select(order)?;
983 parts.push(part);
984 } else {
985 let grouped = self
986 .lf
987 .clone()
988 .group_by(subset.iter().map(|s| col(s.as_str())).collect::<Vec<_>>());
989 let mut part = grouped.agg(aggregations.clone()).collect()?;
990 part = reorder_groupby_columns(&mut part, &subset)?;
991 let n = part.height();
992 for gc in &self.grouping_cols {
993 if subset.iter().any(|s| s == gc) {
994 continue;
995 }
996 let dtype = schema.get(gc).cloned().unwrap_or(DataType::Null);
997 let null_series = null_series_for_dtype(gc.as_str(), n, &dtype)?;
998 part.with_column(null_series.into())?;
999 }
1000 let mut order: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
1001 for name in part.get_column_names() {
1002 if !self.grouping_cols.iter().any(|g| g == name) {
1003 order.push(name);
1004 }
1005 }
1006 part = part.select(order)?;
1007 parts.push(part);
1008 }
1009 }
1010
1011 if parts.is_empty() {
1012 return Ok(super::DataFrame::from_polars_with_options(
1013 PlDataFrame::empty(),
1014 self.case_sensitive,
1015 ));
1016 }
1017 let order: Vec<String> = parts[0]
1018 .schema()
1019 .iter_names()
1020 .map(|s| s.to_string())
1021 .collect();
1022 for p in parts.iter_mut().skip(1) {
1023 *p = p.select(order.as_slice())?;
1024 }
1025 let lazy_frames: Vec<_> = parts.into_iter().map(|p| p.lazy()).collect();
1026 let out = polars::prelude::concat(lazy_frames, UnionArgs::default())?.collect()?;
1027 Ok(super::DataFrame::from_polars_with_options(
1028 out,
1029 self.case_sensitive,
1030 ))
1031 }
1032}
1033
1034fn null_series_for_dtype(name: &str, n: usize, dtype: &DataType) -> Result<Series, PolarsError> {
1035 let name = name.into();
1036 let s = match dtype {
1037 DataType::Int32 => Series::new(name, vec![None::<i32>; n]),
1038 DataType::Int64 => Series::new(name, vec![None::<i64>; n]),
1039 DataType::Float32 => Series::new(name, vec![None::<f32>; n]),
1040 DataType::Float64 => Series::new(name, vec![None::<f64>; n]),
1041 DataType::String => {
1042 let v: Vec<Option<String>> = (0..n).map(|_| None).collect();
1043 Series::new(name, v)
1044 }
1045 DataType::Boolean => Series::new(name, vec![None::<bool>; n]),
1046 DataType::Date => Series::new(name, vec![None::<i32>; n]).cast(dtype)?,
1047 DataType::Datetime(_, _) => Series::new(name, vec![None::<i64>; n]).cast(dtype)?,
1048 _ => Series::new(name, vec![None::<i64>; n]).cast(dtype)?,
1049 };
1050 Ok(s)
1051}
1052
1053pub(super) fn reorder_groupby_columns(
1055 pl_df: &mut PlDataFrame,
1056 grouping_cols: &[String],
1057) -> Result<PlDataFrame, PolarsError> {
1058 let all_cols: Vec<String> = pl_df
1059 .get_column_names()
1060 .iter()
1061 .map(|s| s.to_string())
1062 .collect();
1063 let mut reordered_cols: Vec<&str> = Vec::new();
1064 for gc in grouping_cols {
1065 if all_cols.iter().any(|c| c == gc) {
1066 reordered_cols.push(gc);
1067 }
1068 }
1069 for col_name in &all_cols {
1070 if !grouping_cols.iter().any(|gc| gc == col_name) {
1071 reordered_cols.push(col_name);
1072 }
1073 }
1074 if !reordered_cols.is_empty() && reordered_cols.len() == all_cols.len() {
1075 pl_df.select(reordered_cols)
1076 } else {
1077 Ok(pl_df.clone())
1078 }
1079}
1080
#[cfg(test)]
mod tests {
    use crate::{DataFrame, SparkSession, functions};
    use polars::prelude::{DataType, col};

    /// Three-row fixture: two rows share `k = 1`, one row has `k = 2`.
    fn test_df() -> DataFrame {
        let spark = SparkSession::builder()
            .app_name("agg_tests")
            .get_or_create();
        let tuples = vec![
            (1i64, 10i64, "a".to_string()),
            (1i64, 20i64, "a".to_string()),
            (2i64, 30i64, "b".to_string()),
        ];
        spark
            .create_dataframe(tuples, vec!["k", "v", "label"])
            .unwrap()
    }

    #[test]
    fn group_by_count_single_group() {
        let df = test_df();
        let grouped = df.group_by(vec!["k"]).unwrap();
        let out = grouped.count().unwrap();
        assert_eq!(out.count().unwrap(), 2);
        // `count()` reorders the result so grouping keys come first, then `count`.
        let cols = out.columns().unwrap();
        assert_eq!(cols, vec!["k", "count"]);
    }

    #[test]
    fn group_by_sum() {
        let df = test_df();
        let grouped = df.group_by(vec!["k"]).unwrap();
        let out = grouped.sum("v").unwrap();
        assert_eq!(out.count().unwrap(), 2);
        let cols = out.columns().unwrap();
        assert!(cols.iter().any(|c| c.starts_with("sum(")));
    }

    #[test]
    fn group_by_empty_groups() {
        let spark = SparkSession::builder()
            .app_name("agg_tests")
            .get_or_create();
        let tuples: Vec<(i64, i64, String)> = vec![];
        let df = spark.create_dataframe(tuples, vec!["a", "b", "c"]).unwrap();
        let grouped = df.group_by(vec!["a"]).unwrap();
        let out = grouped.count().unwrap();
        assert_eq!(out.count().unwrap(), 0);
    }

    #[test]
    fn group_by_agg_multi() {
        let df = test_df();
        let grouped = df.group_by(vec!["k"]).unwrap();
        let out = grouped
            .agg(vec![
                polars::prelude::len().alias("cnt"),
                polars::prelude::col("v").sum().alias("total"),
            ])
            .unwrap();
        assert_eq!(out.count().unwrap(), 2);
        let cols = out.columns().unwrap();
        assert!(cols.contains(&"k".to_string()));
        assert!(cols.contains(&"cnt".to_string()));
        assert!(cols.contains(&"total".to_string()));
    }

    #[test]
    fn group_by_agg_columns_multi() {
        let df = test_df();
        let grouped = df.group_by(vec!["k"]).unwrap();
        let v_col = functions::col("v");
        let aggs = vec![functions::count(&v_col), functions::sum(&v_col)];
        let out = grouped.agg_columns(aggs).unwrap();
        assert_eq!(out.count().unwrap(), 2);
        let cols = out.columns().unwrap();
        assert!(cols.contains(&"k".to_string()));
        assert_eq!(cols.len(), 3);
    }

    #[test]
    fn disambiguate_agg_output_names_suffixes_duplicates() {
        // Same output name "v" three times: first keeps its name, later ones get _1, _2.
        let exprs = super::disambiguate_agg_output_names(vec![
            col("v").sum(),
            col("v").max(),
            col("w").sum(),
            col("v").min(),
        ]);
        let names: Vec<String> = exprs
            .iter()
            .map(|e| polars_plan::utils::expr_output_name(e).unwrap().to_string())
            .collect();
        assert_eq!(names, vec!["v", "v_1", "w", "v_2"]);
    }

    #[test]
    fn reorder_groupby_columns_moves_grouping_columns_first() {
        let mut pl_df = polars::df!("v" => [1i64, 2], "k" => [3i64, 4]).unwrap();
        let out = super::reorder_groupby_columns(&mut pl_df, &["k".to_string()]).unwrap();
        let names: Vec<String> = out
            .get_column_names()
            .iter()
            .map(|s| s.to_string())
            .collect();
        assert_eq!(names, vec!["k", "v"]);
    }

    #[test]
    fn null_series_for_dtype_is_all_null_with_requested_dtype() {
        for dtype in [
            DataType::Int64,
            DataType::Float64,
            DataType::String,
            DataType::Boolean,
            DataType::Date,
            DataType::UInt32,
        ] {
            let s = super::null_series_for_dtype("x", 3, &dtype).unwrap();
            assert_eq!(s.name().to_string(), "x");
            assert_eq!(s.len(), 3);
            assert_eq!(s.null_count(), 3);
            assert_eq!(s.dtype(), &dtype);
        }
    }
}