use super::DataFrame;
use crate::column::Column;
use crate::dataframe::resolve_column_with_schema as resolve_column_with_schema_df;
use polars::prelude::{
DataFrame as PlDataFrame, DataType, Expr, LazyFrame, LazyGroupBy, NULL, NamedFrom, PlSmallStr,
PolarsError, Schema, SchemaNamesAndDtypes, Series, col, len, lit, when,
};
use polars_plan::dsl::AggExpr;
use std::collections::HashMap;
/// Builds a PySpark-style `CAST(<inner> AS <TYPE>)` name for a cast expression.
///
/// Alias wrappers around `expr` are peeled off first. Returns `None` when the
/// remaining expression is not a cast, or when the cast target is not one of
/// String / Int32 / Int64 / Float32 / Float64 / Boolean.
///
/// On success returns `(cast_name, inner_name)`, where `inner_name` is the name
/// of the expression being cast: the alias reported by `parse_pivot_agg_expr`
/// when it recognises the inner expression, otherwise the polars output name,
/// otherwise `"?"`.
fn pyspark_style_cast_agg_name(expr: &Expr) -> Option<(String, String)> {
    // Strip every alias layer so the cast (if any) is at the top.
    let mut unaliased: &Expr = expr;
    loop {
        match unaliased {
            Expr::Alias(wrapped, _) => unaliased = wrapped.as_ref(),
            _ => break,
        }
    }

    let (cast_inner, dtype) = match unaliased {
        Expr::Cast {
            expr: cast_inner,
            dtype,
            ..
        } => (cast_inner, dtype),
        _ => return None,
    };

    let inner_name = match parse_pivot_agg_expr(cast_inner.as_ref()) {
        Some((alias, _, _)) => alias,
        None => match polars_plan::utils::expr_output_name(cast_inner.as_ref()) {
            Ok(name) => name.to_string(),
            Err(_) => "?".to_string(),
        },
    };

    // Only these target types get a Spark-style name.
    let spark_type = match dtype.as_literal() {
        Some(DataType::String) => "STRING",
        Some(DataType::Int32) => "INT",
        Some(DataType::Int64) => "LONG",
        Some(DataType::Float32) => "FLOAT",
        Some(DataType::Float64) => "DOUBLE",
        Some(DataType::Boolean) => "BOOLEAN",
        _ => return None,
    };

    Some((format!("CAST({inner_name} AS {spark_type})"), inner_name))
}
/// Gives every aggregation expression a distinct output name.
///
/// For each expression the base name is:
/// * the PySpark-style `CAST(.. AS ..)` name when the expression is a cast whose
///   current output name equals the name of the expression being cast, or
/// * the current output name otherwise (the explicit alias, or the polars
///   output name, or `"_"` when polars cannot determine one).
///
/// The first occurrence of a base name keeps it; later occurrences get a
/// `_<n>` suffix (`name_1`, `name_2`, ...). A suffix that would clash with a
/// name already emitted (for example inputs named `a`, `a`, `a_1`) is skipped,
/// so the returned expressions never share an output name.
///
/// Expressions are re-aliased only when their name actually changes.
pub(crate) fn disambiguate_agg_output_names(aggregations: Vec<Expr>) -> Vec<Expr> {
    let mut name_count: HashMap<String, u32> = HashMap::new();
    // Every output name handed out so far, used to detect suffix collisions.
    let mut taken: std::collections::HashSet<String> = std::collections::HashSet::new();
    aggregations
        .into_iter()
        .map(|e| {
            let current_name = match &e {
                Expr::Alias(_, name) => name.to_string(),
                _ => polars_plan::utils::expr_output_name(&e)
                    .map(|s| s.to_string())
                    .unwrap_or_else(|_| "_".to_string()),
            };
            let (base_name, used_cast_name) = match pyspark_style_cast_agg_name(&e) {
                Some((cast_name, inner_name)) if current_name == inner_name => (cast_name, true),
                _ => (current_name, false),
            };

            let count = name_count.entry(base_name.clone()).or_insert(0);
            *count += 1;
            let mut final_name = if *count == 1 {
                base_name.clone()
            } else {
                format!("{}_{}", base_name, *count - 1)
            };
            // Keep bumping the suffix until the name is genuinely unused.
            while taken.contains(&final_name) {
                *count += 1;
                final_name = format!("{}_{}", base_name, *count - 1);
            }
            taken.insert(final_name.clone());

            let needs_alias = used_cast_name || final_name != base_name;
            if needs_alias {
                e.alias(final_name.as_str())
            } else {
                e
            }
        })
        .collect()
}
/// A grouped frame on which aggregations can be run.
pub struct GroupedData {
/// The ungrouped lazy frame; used to look up the schema when resolving column
/// names and handed on to `PivotedGroupedData` by `pivot`.
pub(crate) lf: LazyFrame,
/// The grouped lazy frame that every aggregation method calls `.agg(..)` on.
pub(crate) lazy_grouped: LazyGroupBy,
/// Names of the grouping columns; passed to column resolution and to
/// `reorder_groupby_columns` after each aggregation.
pub(crate) grouping_cols: Vec<String>,
/// Case-sensitivity flag forwarded to column resolution and to the resulting
/// `DataFrame`.
pub(crate) case_sensitive: bool,
}
impl GroupedData {
fn resolve_column(&self, name: &str) -> Result<String, PolarsError> {
let schema = self.lf.clone().collect_schema()?;
resolve_column_with_schema_df(
name,
&schema,
self.case_sensitive,
Some(&self.grouping_cols),
)
}
/// Resolves `name` against a caller-supplied `schema`.
///
/// Thin wrapper over the shared dataframe resolver that supplies this group's
/// case-sensitivity flag and grouping columns.
fn resolve_column_with_schema(
    &self,
    name: &str,
    schema: &Schema,
) -> Result<String, PolarsError> {
    let case_sensitive = self.case_sensitive;
    resolve_column_with_schema_df(name, schema, case_sensitive, Some(&self.grouping_cols))
}
/// Rewrites the column references inside `expr` to resolved schema names,
/// collecting the schema of the ungrouped frame first.
#[allow(dead_code)]
fn resolve_expr_column_names(&self, expr: Expr) -> Result<Expr, PolarsError> {
    self.lf
        .clone()
        .collect_schema()
        .and_then(|frame_schema| self.resolve_expr_column_names_with_schema(expr, &frame_schema))
}
/// Rewrites every column reference in `expr` to its resolved schema name.
///
/// * Column names that match an alias defined somewhere inside `expr` are left
///   untouched, as are empty names.
/// * A dotted name (`a.b.c`) is treated as struct access: the first segment is
///   resolved as a column and each remaining segment becomes a
///   `struct_().field_by_name(..)` step.
/// * Any other name is resolved through `resolve_column_with_schema`.
///
/// # Errors
/// Returns `PolarsError::ColumnNotFound` for a dotted name that ends with a dot
/// (for example `"a."`), and propagates resolution errors.
fn resolve_expr_column_names_with_schema(
    &self,
    expr: Expr,
    schema: &Schema,
) -> Result<Expr, PolarsError> {
    use std::collections::HashSet;
    // First pass: record every alias so references to them are not resolved as
    // schema columns in the second pass.
    let mut alias_output_names: HashSet<String> = HashSet::new();
    let _ = expr.clone().try_map_expr(|e| {
        if let Expr::Alias(_, name) = &e {
            alias_output_names.insert(name.as_str().to_string());
        }
        Ok(e)
    })?;
    let gd = self;
    let schema = schema.clone();
    expr.try_map_expr(move |e| {
        if let Expr::Column(name) = &e {
            let name_str = name.as_str();
            if name_str.is_empty() || alias_output_names.contains(name_str) {
                return Ok(e);
            }
            if let Some((first, rest)) = name_str.split_once('.') {
                // `split` never yields fewer than two parts for a dotted name, so
                // a trailing dot has to be detected on the text itself.
                if name_str.ends_with('.') {
                    return Err(PolarsError::ColumnNotFound(
                        format!(
                            "cannot resolve: Column '{}': trailing dot not allowed",
                            name_str
                        )
                        .into(),
                    ));
                }
                let resolved = gd.resolve_column_with_schema(first, &schema)?;
                let mut expr = col(PlSmallStr::from(resolved.as_str()));
                for field in rest.split('.') {
                    expr = expr.struct_().field_by_name(field);
                }
                return Ok(expr);
            }
            let resolved = gd.resolve_column_with_schema(name_str, &schema)?;
            return Ok(Expr::Column(PlSmallStr::from(resolved.as_str())));
        }
        Ok(e)
    })
}
/// Counts the rows of each group.
///
/// The result column is named `count` and is cast to `Int64`. The collected
/// frame is passed through `reorder_groupby_columns` before being wrapped.
pub fn count(&self) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let row_count = len().cast(DataType::Int64).alias("count");
    let mut collected = self.lazy_grouped.clone().agg(vec![row_count]).collect()?;
    let ordered = reorder_groupby_columns(&mut collected, &self.grouping_cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        ordered,
        self.case_sensitive,
    ))
}
pub fn sum(&self, column: &str) -> Result<DataFrame, PolarsError> {
use polars::prelude::*;
let c = self.resolve_column(column)?;
let agg_expr = vec![col(c.as_str()).sum().alias(format!("sum({column})"))];
let lf = self.lazy_grouped.clone().agg(agg_expr);
let mut pl_df = lf.collect()?;
let all_cols: Vec<String> = pl_df
.get_column_names()
.iter()
.map(|s| s.to_string())
.collect();
let grouping_cols: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
let mut reordered_cols: Vec<&str> = Vec::new();
for gc in &grouping_cols {
if all_cols.iter().any(|c| c == gc) {
reordered_cols.push(gc);
}
}
for col_name in &all_cols {
if !grouping_cols.iter().any(|gc| *gc == col_name) {
reordered_cols.push(col_name);
}
}
if !reordered_cols.is_empty() {
pl_df = pl_df.select(reordered_cols)?;
}
Ok(super::DataFrame::from_polars_with_options(
pl_df,
self.case_sensitive,
))
}
/// Computes the per-group mean of each column in `columns`.
///
/// Each output column is named `avg({name})` using the name as passed in.
///
/// # Errors
/// Returns `PolarsError::ComputeError` when `columns` is empty, and propagates
/// resolution / collection errors.
pub fn avg(&self, columns: &[&str]) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    if columns.is_empty() {
        return Err(PolarsError::ComputeError(
            "avg requires at least one column".into(),
        ));
    }
    let mut means: Vec<Expr> = Vec::with_capacity(columns.len());
    for name in columns {
        let resolved = self.resolve_column(name)?;
        means.push(col(resolved.as_str()).mean().alias(format!("avg({name})")));
    }
    let mut collected = self.lazy_grouped.clone().agg(means).collect()?;
    let ordered = reorder_groupby_columns(&mut collected, &self.grouping_cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        ordered,
        self.case_sensitive,
    ))
}
/// Per-group minimum of `column`; the output column is named `min({column})`.
pub fn min(&self, column: &str) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let resolved = self.resolve_column(column)?;
    let output_name = format!("min({column})");
    let minimum = col(resolved.as_str()).min().alias(output_name);
    let mut collected = self.lazy_grouped.clone().agg(vec![minimum]).collect()?;
    let ordered = reorder_groupby_columns(&mut collected, &self.grouping_cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        ordered,
        self.case_sensitive,
    ))
}
/// Per-group maximum of `column`; the output column is named `max({column})`.
pub fn max(&self, column: &str) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let resolved = self.resolve_column(column)?;
    let output_name = format!("max({column})");
    let maximum = col(resolved.as_str()).max().alias(output_name);
    let mut collected = self.lazy_grouped.clone().agg(vec![maximum]).collect()?;
    let ordered = reorder_groupby_columns(&mut collected, &self.grouping_cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        ordered,
        self.case_sensitive,
    ))
}
/// First value of `column` in each group; the output column is named
/// `first({column})`.
pub fn first(&self, column: &str) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let resolved = self.resolve_column(column)?;
    let output_name = format!("first({column})");
    let first_value = col(resolved.as_str()).first().alias(output_name);
    let mut collected = self
        .lazy_grouped
        .clone()
        .agg(vec![first_value])
        .collect()?;
    let ordered = reorder_groupby_columns(&mut collected, &self.grouping_cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        ordered,
        self.case_sensitive,
    ))
}
/// Last value of `column` in each group; the output column is named
/// `last({column})`.
pub fn last(&self, column: &str) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let resolved = self.resolve_column(column)?;
    let output_name = format!("last({column})");
    let last_value = col(resolved.as_str()).last().alias(output_name);
    let mut collected = self
        .lazy_grouped
        .clone()
        .agg(vec![last_value])
        .collect()?;
    let ordered = reorder_groupby_columns(&mut collected, &self.grouping_cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        ordered,
        self.case_sensitive,
    ))
}
/// Per-group distinct count of `column`.
///
/// Implemented with `n_unique()` cast to `Int64`; the output column is named
/// `approx_count_distinct({column})`.
pub fn approx_count_distinct(&self, column: &str) -> Result<DataFrame, PolarsError> {
    use polars::prelude::{DataType, col};
    let resolved = self.resolve_column(column)?;
    let output_name = format!("approx_count_distinct({column})");
    let distinct = col(resolved.as_str())
        .n_unique()
        .cast(DataType::Int64)
        .alias(output_name);
    let mut collected = self.lazy_grouped.clone().agg(vec![distinct]).collect()?;
    let ordered = reorder_groupby_columns(&mut collected, &self.grouping_cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        ordered,
        self.case_sensitive,
    ))
}
/// Picks a value of `column` for each group (implemented as the first value).
///
/// The output column is named `any_value({column})`.
pub fn any_value(&self, column: &str) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let resolved = self.resolve_column(column)?;
    let output_name = format!("any_value({column})");
    let picked = col(resolved.as_str()).first().alias(output_name);
    let mut collected = self.lazy_grouped.clone().agg(vec![picked]).collect()?;
    let ordered = reorder_groupby_columns(&mut collected, &self.grouping_cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        ordered,
        self.case_sensitive,
    ))
}
/// Per-group logical AND of `column`, computed with `all(true)`.
///
/// The output column is named `bool_and({column})`.
pub fn bool_and(&self, column: &str) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let resolved = self.resolve_column(column)?;
    let output_name = format!("bool_and({column})");
    let conjunction = col(resolved.as_str()).all(true).alias(output_name);
    let mut collected = self
        .lazy_grouped
        .clone()
        .agg(vec![conjunction])
        .collect()?;
    let ordered = reorder_groupby_columns(&mut collected, &self.grouping_cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        ordered,
        self.case_sensitive,
    ))
}
/// Per-group logical OR of `column`, computed with `any(true)`.
///
/// The output column is named `bool_or({column})`.
pub fn bool_or(&self, column: &str) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let resolved = self.resolve_column(column)?;
    let output_name = format!("bool_or({column})");
    let disjunction = col(resolved.as_str()).any(true).alias(output_name);
    let mut collected = self
        .lazy_grouped
        .clone()
        .agg(vec![disjunction])
        .collect()?;
    let ordered = reorder_groupby_columns(&mut collected, &self.grouping_cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        ordered,
        self.case_sensitive,
    ))
}
/// Per-group product of `column`; the output column is named
/// `product({column})`.
pub fn product(&self, column: &str) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let resolved = self.resolve_column(column)?;
    let output_name = format!("product({column})");
    let multiplied = col(resolved.as_str()).product().alias(output_name);
    let mut collected = self
        .lazy_grouped
        .clone()
        .agg(vec![multiplied])
        .collect()?;
    let ordered = reorder_groupby_columns(&mut collected, &self.grouping_cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        ordered,
        self.case_sensitive,
    ))
}
/// Gathers the values of `column` in each group into a list (`implode`).
///
/// The output column is named `collect_list({column})`.
pub fn collect_list(&self, column: &str) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let resolved = self.resolve_column(column)?;
    let output_name = format!("collect_list({column})");
    let gathered = col(resolved.as_str()).implode().alias(output_name);
    let mut collected = self.lazy_grouped.clone().agg(vec![gathered]).collect()?;
    let ordered = reorder_groupby_columns(&mut collected, &self.grouping_cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        ordered,
        self.case_sensitive,
    ))
}
/// Gathers the distinct values of `column` in each group into a list
/// (`unique` followed by `implode`).
///
/// The output column is named `collect_set({column})`.
pub fn collect_set(&self, column: &str) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let resolved = self.resolve_column(column)?;
    let output_name = format!("collect_set({column})");
    let distinct_values = col(resolved.as_str())
        .unique()
        .implode()
        .alias(output_name);
    let mut collected = self
        .lazy_grouped
        .clone()
        .agg(vec![distinct_values])
        .collect()?;
    let ordered = reorder_groupby_columns(&mut collected, &self.grouping_cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        ordered,
        self.case_sensitive,
    ))
}
/// Per-group `count_if`: casts `column` to `Int64` and sums it.
///
/// The output column is named `count_if({column})`.
pub fn count_if(&self, column: &str) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let resolved = self.resolve_column(column)?;
    let output_name = format!("count_if({column})");
    let truthy_total = col(resolved.as_str())
        .cast(DataType::Int64)
        .sum()
        .alias(output_name);
    let mut collected = self
        .lazy_grouped
        .clone()
        .agg(vec![truthy_total])
        .collect()?;
    let ordered = reorder_groupby_columns(&mut collected, &self.grouping_cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        ordered,
        self.case_sensitive,
    ))
}
/// Per-group quantile `p` of `column`, using linear interpolation
/// (`QuantileMethod::Linear`).
///
/// The output column is named `percentile({column}, {p})`.
pub fn percentile(&self, column: &str, p: f64) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let resolved = self.resolve_column(column)?;
    let output_name = format!("percentile({column}, {p})");
    let quantile = col(resolved.as_str())
        .quantile(lit(p), QuantileMethod::Linear)
        .alias(output_name);
    let mut collected = self.lazy_grouped.clone().agg(vec![quantile]).collect()?;
    let ordered = reorder_groupby_columns(&mut collected, &self.grouping_cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        ordered,
        self.case_sensitive,
    ))
}
/// For each group, returns the `value_col` value taken from the first row
/// after sorting `(ord_col, value_col)` structs in descending order.
///
/// The output column is named `max_by({value_col}, {ord_col})`.
pub fn max_by(&self, value_col: &str, ord_col: &str) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let value_name = self.resolve_column(value_col)?;
    let order_name = self.resolve_column(ord_col)?;
    // Ordering field comes first so the struct sort is driven by it.
    let pair = as_struct(vec![
        col(order_name.as_str()).alias("_ord"),
        col(value_name.as_str()).alias("_val"),
    ]);
    let descending = SortOptions::default().with_order_descending(true);
    let picked = pair
        .sort(descending)
        .first()
        .struct_()
        .field_by_name("_val")
        .alias(format!("max_by({value_col}, {ord_col})"));
    let mut collected = self.lazy_grouped.clone().agg(vec![picked]).collect()?;
    let ordered = reorder_groupby_columns(&mut collected, &self.grouping_cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        ordered,
        self.case_sensitive,
    ))
}
/// For each group, returns the `value_col` value taken from the first row
/// after sorting `(ord_col, value_col)` structs with the default sort options.
///
/// The output column is named `min_by({value_col}, {ord_col})`.
pub fn min_by(&self, value_col: &str, ord_col: &str) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let value_name = self.resolve_column(value_col)?;
    let order_name = self.resolve_column(ord_col)?;
    // Ordering field comes first so the struct sort is driven by it.
    let pair = as_struct(vec![
        col(order_name.as_str()).alias("_ord"),
        col(value_name.as_str()).alias("_val"),
    ]);
    let picked = pair
        .sort(SortOptions::default())
        .first()
        .struct_()
        .field_by_name("_val")
        .alias(format!("min_by({value_col}, {ord_col})"));
    let mut collected = self.lazy_grouped.clone().agg(vec![picked]).collect()?;
    let ordered = reorder_groupby_columns(&mut collected, &self.grouping_cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        ordered,
        self.case_sensitive,
    ))
}
/// Per-group population covariance of `col1` and `col2`.
///
/// Computed as `(Σab − Σa·Σb / n) / n`, where `n` is the group length and the
/// product terms are evaluated in `Float64`. The output column is named
/// `covar_pop({col1}, {col2})`.
pub fn covar_pop(&self, col1: &str, col2: &str) -> Result<DataFrame, PolarsError> {
    use polars::prelude::DataType;
    let a_name = self.resolve_column(col1)?;
    let b_name = self.resolve_column(col2)?;
    let a = col(a_name.as_str()).cast(DataType::Float64);
    let b = col(b_name.as_str()).cast(DataType::Float64);
    let group_len = len().cast(DataType::Float64);
    let sum_of_products = (a * b).sum();
    let total_a = col(a_name.as_str()).sum().cast(DataType::Float64);
    let total_b = col(b_name.as_str()).sum().cast(DataType::Float64);
    let covariance = (sum_of_products - total_a * total_b / group_len.clone()) / group_len;
    let named = covariance.alias(format!("covar_pop({col1}, {col2})"));
    let mut collected = self.lazy_grouped.clone().agg(vec![named]).collect()?;
    let ordered = reorder_groupby_columns(&mut collected, &self.grouping_cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        ordered,
        self.case_sensitive,
    ))
}
/// Per-group sample covariance of `col1` and `col2`.
///
/// Computed as `(Σab − Σa·Σb / n) / (n − 1)` when the group has more than one
/// row, and NaN otherwise. The output column is named
/// `covar_samp({col1}, {col2})`.
pub fn covar_samp(&self, col1: &str, col2: &str) -> Result<DataFrame, PolarsError> {
    use polars::prelude::DataType;
    let a_name = self.resolve_column(col1)?;
    let b_name = self.resolve_column(col2)?;
    let a = col(a_name.as_str()).cast(DataType::Float64);
    let b = col(b_name.as_str()).cast(DataType::Float64);
    let group_len = len().cast(DataType::Float64);
    let sum_of_products = (a * b).sum();
    let total_a = col(a_name.as_str()).sum().cast(DataType::Float64);
    let total_b = col(b_name.as_str()).sum().cast(DataType::Float64);
    let degrees_of_freedom = (len() - lit(1)).cast(DataType::Float64);
    let sample_cov =
        (sum_of_products - total_a * total_b / group_len.clone()) / degrees_of_freedom;
    // A single-row group has no sample covariance.
    let guarded = when(len().gt(lit(1)))
        .then(sample_cov)
        .otherwise(lit(f64::NAN));
    let named = guarded.alias(format!("covar_samp({col1}, {col2})"));
    let mut collected = self.lazy_grouped.clone().agg(vec![named]).collect()?;
    let ordered = reorder_groupby_columns(&mut collected, &self.grouping_cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        ordered,
        self.case_sensitive,
    ))
}
/// Per-group correlation of `col1` and `col2`.
///
/// Computed as sample covariance divided by the product of the two sample
/// standard deviations, all derived from sums over the group. Groups with a
/// single row yield NaN. The output column is named `corr({col1}, {col2})`.
pub fn corr(&self, col1: &str, col2: &str) -> Result<DataFrame, PolarsError> {
    use polars::prelude::DataType;
    let a_name = self.resolve_column(col1)?;
    let b_name = self.resolve_column(col2)?;
    let a = col(a_name.as_str()).cast(DataType::Float64);
    let b = col(b_name.as_str()).cast(DataType::Float64);
    let group_len = len().cast(DataType::Float64);
    let degrees_of_freedom = (len() - lit(1)).cast(DataType::Float64);

    // Raw sums that every moment below is built from.
    let sum_ab = (a.clone() * b.clone()).sum();
    let total_a = col(a_name.as_str()).sum().cast(DataType::Float64);
    let total_b = col(b_name.as_str()).sum().cast(DataType::Float64);
    let sum_aa = (a.clone() * a).sum();
    let sum_bb = (b.clone() * b).sum();

    let sample_cov = (sum_ab - total_a.clone() * total_b.clone() / group_len.clone())
        / degrees_of_freedom.clone();
    let variance_a =
        (sum_aa - total_a.clone() * total_a / group_len.clone()) / degrees_of_freedom.clone();
    let variance_b =
        (sum_bb - total_b.clone() * total_b / group_len.clone()) / degrees_of_freedom.clone();
    let spread = variance_a.sqrt() * variance_b.sqrt();

    let guarded = when(len().gt(lit(1)))
        .then(sample_cov / spread)
        .otherwise(lit(f64::NAN));
    let named = guarded.alias(format!("corr({col1}, {col2})"));
    let mut collected = self.lazy_grouped.clone().agg(vec![named]).collect()?;
    let ordered = reorder_groupby_columns(&mut collected, &self.grouping_cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        ordered,
        self.case_sensitive,
    ))
}
/// Per-group `regr_count`, built by `crate::functions::regr_count_expr`.
///
/// Both names are resolved against the frame schema; the output column is
/// named `regr_count({y_col}, {x_col})` using the names as passed in.
pub fn regr_count(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
    let y_name = self.resolve_column(y_col)?;
    let x_name = self.resolve_column(x_col)?;
    let output_name = format!("regr_count({y_col}, {x_col})");
    let regression =
        crate::functions::regr_count_expr(y_name.as_str(), x_name.as_str()).alias(output_name);
    let mut collected = self
        .lazy_grouped
        .clone()
        .agg(vec![regression])
        .collect()?;
    let ordered = reorder_groupby_columns(&mut collected, &self.grouping_cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        ordered,
        self.case_sensitive,
    ))
}
/// Per-group `regr_avgx`, built by `crate::functions::regr_avgx_expr`.
///
/// Both names are resolved against the frame schema; the output column is
/// named `regr_avgx({y_col}, {x_col})` using the names as passed in.
pub fn regr_avgx(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
    let y_name = self.resolve_column(y_col)?;
    let x_name = self.resolve_column(x_col)?;
    let output_name = format!("regr_avgx({y_col}, {x_col})");
    let regression =
        crate::functions::regr_avgx_expr(y_name.as_str(), x_name.as_str()).alias(output_name);
    let mut collected = self
        .lazy_grouped
        .clone()
        .agg(vec![regression])
        .collect()?;
    let ordered = reorder_groupby_columns(&mut collected, &self.grouping_cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        ordered,
        self.case_sensitive,
    ))
}
/// Per-group `regr_avgy`, built by `crate::functions::regr_avgy_expr`.
///
/// Both names are resolved against the frame schema; the output column is
/// named `regr_avgy({y_col}, {x_col})` using the names as passed in.
pub fn regr_avgy(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
    let y_name = self.resolve_column(y_col)?;
    let x_name = self.resolve_column(x_col)?;
    let output_name = format!("regr_avgy({y_col}, {x_col})");
    let regression =
        crate::functions::regr_avgy_expr(y_name.as_str(), x_name.as_str()).alias(output_name);
    let mut collected = self
        .lazy_grouped
        .clone()
        .agg(vec![regression])
        .collect()?;
    let ordered = reorder_groupby_columns(&mut collected, &self.grouping_cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        ordered,
        self.case_sensitive,
    ))
}
/// Per-group `regr_slope`, built by `crate::functions::regr_slope_expr`.
///
/// Both names are resolved against the frame schema; the output column is
/// named `regr_slope({y_col}, {x_col})` using the names as passed in.
pub fn regr_slope(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
    let y_name = self.resolve_column(y_col)?;
    let x_name = self.resolve_column(x_col)?;
    let output_name = format!("regr_slope({y_col}, {x_col})");
    let regression =
        crate::functions::regr_slope_expr(y_name.as_str(), x_name.as_str()).alias(output_name);
    let mut collected = self
        .lazy_grouped
        .clone()
        .agg(vec![regression])
        .collect()?;
    let ordered = reorder_groupby_columns(&mut collected, &self.grouping_cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        ordered,
        self.case_sensitive,
    ))
}
/// Per-group `regr_intercept`, built by `crate::functions::regr_intercept_expr`.
///
/// Both names are resolved against the frame schema; the output column is
/// named `regr_intercept({y_col}, {x_col})` using the names as passed in.
pub fn regr_intercept(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
    let y_name = self.resolve_column(y_col)?;
    let x_name = self.resolve_column(x_col)?;
    let output_name = format!("regr_intercept({y_col}, {x_col})");
    let regression = crate::functions::regr_intercept_expr(y_name.as_str(), x_name.as_str())
        .alias(output_name);
    let mut collected = self
        .lazy_grouped
        .clone()
        .agg(vec![regression])
        .collect()?;
    let ordered = reorder_groupby_columns(&mut collected, &self.grouping_cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        ordered,
        self.case_sensitive,
    ))
}
/// Per-group `regr_r2`, built by `crate::functions::regr_r2_expr`.
///
/// Both names are resolved against the frame schema; the output column is
/// named `regr_r2({y_col}, {x_col})` using the names as passed in.
pub fn regr_r2(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
    let y_name = self.resolve_column(y_col)?;
    let x_name = self.resolve_column(x_col)?;
    let output_name = format!("regr_r2({y_col}, {x_col})");
    let regression =
        crate::functions::regr_r2_expr(y_name.as_str(), x_name.as_str()).alias(output_name);
    let mut collected = self
        .lazy_grouped
        .clone()
        .agg(vec![regression])
        .collect()?;
    let ordered = reorder_groupby_columns(&mut collected, &self.grouping_cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        ordered,
        self.case_sensitive,
    ))
}
/// Per-group `regr_sxx`, built by `crate::functions::regr_sxx_expr`.
///
/// Both names are resolved against the frame schema; the output column is
/// named `regr_sxx({y_col}, {x_col})` using the names as passed in.
pub fn regr_sxx(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
    let y_name = self.resolve_column(y_col)?;
    let x_name = self.resolve_column(x_col)?;
    let output_name = format!("regr_sxx({y_col}, {x_col})");
    let regression =
        crate::functions::regr_sxx_expr(y_name.as_str(), x_name.as_str()).alias(output_name);
    let mut collected = self
        .lazy_grouped
        .clone()
        .agg(vec![regression])
        .collect()?;
    let ordered = reorder_groupby_columns(&mut collected, &self.grouping_cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        ordered,
        self.case_sensitive,
    ))
}
/// Per-group `regr_syy`, built by `crate::functions::regr_syy_expr`.
///
/// Both names are resolved against the frame schema; the output column is
/// named `regr_syy({y_col}, {x_col})` using the names as passed in.
pub fn regr_syy(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
    let y_name = self.resolve_column(y_col)?;
    let x_name = self.resolve_column(x_col)?;
    let output_name = format!("regr_syy({y_col}, {x_col})");
    let regression =
        crate::functions::regr_syy_expr(y_name.as_str(), x_name.as_str()).alias(output_name);
    let mut collected = self
        .lazy_grouped
        .clone()
        .agg(vec![regression])
        .collect()?;
    let ordered = reorder_groupby_columns(&mut collected, &self.grouping_cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        ordered,
        self.case_sensitive,
    ))
}
/// Per-group `regr_sxy`, built by `crate::functions::regr_sxy_expr`.
///
/// Both names are resolved against the frame schema; the output column is
/// named `regr_sxy({y_col}, {x_col})` using the names as passed in.
pub fn regr_sxy(&self, y_col: &str, x_col: &str) -> Result<DataFrame, PolarsError> {
    let y_name = self.resolve_column(y_col)?;
    let x_name = self.resolve_column(x_col)?;
    let output_name = format!("regr_sxy({y_col}, {x_col})");
    let regression =
        crate::functions::regr_sxy_expr(y_name.as_str(), x_name.as_str()).alias(output_name);
    let mut collected = self
        .lazy_grouped
        .clone()
        .agg(vec![regression])
        .collect()?;
    let ordered = reorder_groupby_columns(&mut collected, &self.grouping_cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        ordered,
        self.case_sensitive,
    ))
}
/// Per-group kurtosis of `column`, computed on the `Float64` cast of the
/// column via `kurtosis(true, true)`.
///
/// The output column is named `kurtosis({column})`.
pub fn kurtosis(&self, column: &str) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let resolved = self.resolve_column(column)?;
    let output_name = format!("kurtosis({column})");
    let tail_weight = col(resolved.as_str())
        .cast(DataType::Float64)
        .kurtosis(true, true)
        .alias(output_name);
    let mut collected = self
        .lazy_grouped
        .clone()
        .agg(vec![tail_weight])
        .collect()?;
    let ordered = reorder_groupby_columns(&mut collected, &self.grouping_cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        ordered,
        self.case_sensitive,
    ))
}
/// Per-group skewness of `column`, computed on the `Float64` cast of the
/// column via `skew(true)`.
///
/// The output column is named `skewness({column})`.
pub fn skewness(&self, column: &str) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let resolved = self.resolve_column(column)?;
    let output_name = format!("skewness({column})");
    let asymmetry = col(resolved.as_str())
        .cast(DataType::Float64)
        .skew(true)
        .alias(output_name);
    let mut collected = self.lazy_grouped.clone().agg(vec![asymmetry]).collect()?;
    let ordered = reorder_groupby_columns(&mut collected, &self.grouping_cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        ordered,
        self.case_sensitive,
    ))
}
/// Runs arbitrary aggregation expressions over the groups.
///
/// Steps:
/// 1. resolve the column references of every expression against the frame
///    schema (collected once),
/// 2. make the output names unique via `disambiguate_agg_output_names`,
/// 3. aggregate, then sort by the grouping columns (ascending, nulls first)
///    when there are any,
/// 4. pass the collected frame through `reorder_groupby_columns`.
///
/// # Errors
/// Propagates schema, resolution, collection and reordering errors.
pub fn agg(&self, aggregations: Vec<Expr>) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let schema = self.lf.clone().collect_schema()?;
    let mut resolved: Vec<Expr> = Vec::with_capacity(aggregations.len());
    for aggregation in aggregations {
        resolved.push(self.resolve_expr_column_names_with_schema(aggregation, &schema)?);
    }
    let unique_named = disambiguate_agg_output_names(resolved);

    let aggregated = self.lazy_grouped.clone().agg(unique_named);
    let planned = if self.grouping_cols.is_empty() {
        aggregated
    } else {
        let sort_keys: Vec<Expr> = self
            .grouping_cols
            .iter()
            .map(|name| col(name.as_str()))
            .collect();
        let key_count = sort_keys.len();
        let sort_options = SortMultipleOptions::new()
            .with_order_descending_multi(vec![false; key_count])
            .with_nulls_last_multi(vec![false; key_count]);
        aggregated.sort_by_exprs(sort_keys, sort_options)
    };

    let mut collected = planned.collect()?;
    let ordered = reorder_groupby_columns(&mut collected, &self.grouping_cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        ordered,
        self.case_sensitive,
    ))
}
/// Same as [`Self::agg`], but takes `Column` wrappers and converts each one
/// to its expression first.
pub fn agg_columns(&self, aggregations: Vec<Column>) -> Result<DataFrame, PolarsError> {
    let mut exprs: Vec<Expr> = Vec::with_capacity(aggregations.len());
    for column in aggregations {
        exprs.push(column.into_expr());
    }
    self.agg(exprs)
}
/// Returns the names of the grouping columns.
pub fn grouping_columns(&self) -> &[String] {
    self.grouping_cols.as_slice()
}
/// Starts a pivot on `pivot_col`.
///
/// `values` optionally fixes the pivot values; it is stored as given together
/// with copies of the ungrouped frame, the grouping columns and the
/// case-sensitivity flag.
pub fn pivot(&self, pivot_col: &str, values: Option<Vec<String>>) -> PivotedGroupedData {
    let lf = self.lf.clone();
    let grouping_cols = self.grouping_cols.clone();
    let pivot_col = pivot_col.to_string();
    let case_sensitive = self.case_sensitive;
    PivotedGroupedData {
        lf,
        grouping_cols,
        pivot_col,
        values,
        case_sensitive,
    }
}
}
/// A grouped frame with a pivot column selected, as produced by
/// `GroupedData::pivot`.
pub struct PivotedGroupedData {
/// The ungrouped lazy frame; pivot aggregations group it by `grouping_cols`.
pub(crate) lf: LazyFrame,
/// Names of the grouping columns.
pub(crate) grouping_cols: Vec<String>,
/// Name of the pivot column as supplied by the caller; it is resolved against
/// the schema before use.
pub(crate) pivot_col: String,
/// Explicit pivot values. When `None`, the distinct values of the pivot
/// column are read from `lf`.
pub(crate) values: Option<Vec<String>>,
/// Case-sensitivity flag used for column resolution and forwarded to the
/// resulting `DataFrame`.
pub(crate) case_sensitive: bool,
}
/// Finds the column name an aggregation input ultimately refers to.
///
/// Looks through casts and through the first argument of function calls.
/// Returns `None` for any other expression shape, or for a function with no
/// inputs.
fn pivot_agg_input_column_name(input_expr: &Expr) -> Option<String> {
    let next: &Expr = match input_expr {
        Expr::Column(name) => return Some(name.as_str().to_string()),
        Expr::Cast { expr: inner, .. } => inner.as_ref(),
        Expr::Function { input, .. } => input.first()?,
        _ => return None,
    };
    pivot_agg_input_column_name(next)
}
/// Decomposes an aggregation expression into `(alias, agg_kind, value_column)`.
///
/// Alias wrappers are peeled off (the innermost alias is the one kept), then a
/// single cast layer, and the remainder must be one of the supported
/// aggregations. An `Implode` aggregation is reported as `collect_set` when the
/// kept alias contains the text `collect_set`, otherwise as `collect_list`.
///
/// When no alias is present the returned alias is `"{agg_kind}({value_column})"`.
/// Returns `None` for unsupported expressions or when the input column cannot
/// be determined.
fn parse_pivot_agg_expr(expr: &Expr) -> Option<(String, &'static str, String)> {
    let mut current: &Expr = expr;
    let mut kept_alias: Option<String> = None;
    // Each layer overwrites the previous one, so the innermost alias survives.
    while let Expr::Alias(wrapped, name) = current {
        kept_alias = Some(name.as_str().to_string());
        current = wrapped.as_ref();
    }
    if let Expr::Cast {
        expr: cast_input, ..
    } = current
    {
        current = cast_input.as_ref();
    }
    let Expr::Agg(agg) = current else {
        return None;
    };
    let (agg_kind, input_expr): (&'static str, &Expr) = match agg {
        AggExpr::Sum(e) => ("sum", e.as_ref()),
        AggExpr::Mean(e) => ("avg", e.as_ref()),
        AggExpr::Min { input: e, .. } => ("min", e.as_ref()),
        AggExpr::Max { input: e, .. } => ("max", e.as_ref()),
        AggExpr::NUnique(e) => ("count_distinct", e.as_ref()),
        AggExpr::First(e) => ("first", e.as_ref()),
        AggExpr::Last(e) => ("last", e.as_ref()),
        AggExpr::Implode(e) => {
            let is_set = matches!(&kept_alias, Some(a) if a.contains("collect_set"));
            let kind = if is_set { "collect_set" } else { "collect_list" };
            (kind, e.as_ref())
        }
        AggExpr::Std(e, _) => ("stddev", e.as_ref()),
        AggExpr::Var(e, _) => ("variance", e.as_ref()),
        _ => return None,
    };
    let value_col = pivot_agg_input_column_name(input_expr)?;
    let alias = match kept_alias {
        Some(name) => name,
        None => format!("{}({})", agg_kind, value_col),
    };
    Some((alias, agg_kind, value_col))
}
/// Converts a pivot-column value into the name of its output column.
///
/// * null becomes `"null"`,
/// * string values (borrowed or owned) are used verbatim,
/// * everything else uses the value's `Display` form.
fn pivot_value_to_column_name(av: polars::prelude::AnyValue<'_>) -> String {
    use polars::prelude::AnyValue;
    match av {
        AnyValue::Null => "null".to_string(),
        AnyValue::String(s) => s.to_string(),
        // Owned strings must bypass `Display` as well; polars' `Display` for
        // string values is understood to add surrounding quotes.
        AnyValue::StringOwned(s) => s.to_string(),
        _ => av.to_string(),
    }
}
fn pivot_values_from_lf(lf: &LazyFrame, pivot_col: &str) -> Result<Vec<String>, PolarsError> {
use polars::prelude::*;
let pl_df = lf
.clone()
.select([col(pivot_col)])
.unique(None, Default::default())
.collect()?;
let s = pl_df.column(pivot_col)?;
let mut out = Vec::with_capacity(s.len());
for i in 0..s.len() {
let av = s.get(i)?;
out.push(pivot_value_to_column_name(av));
}
out.sort();
Ok(out)
}
impl PivotedGroupedData {
/// Resolves `name` to a column of the pivot frame's schema.
///
/// With case-sensitive matching the name must match exactly; otherwise the
/// first schema column whose lowercase form equals the lowercase `name` is
/// returned.
///
/// # Errors
/// Returns `PolarsError::ColumnNotFound` listing the available columns when
/// nothing matches, and propagates schema-collection errors.
fn resolve_column(&self, name: &str) -> Result<String, PolarsError> {
    let schema = self.lf.clone().collect_schema()?;
    let names: Vec<String> = schema
        .iter_names_and_dtypes()
        .map(|(column_name, _)| column_name.to_string())
        .collect();

    let matched = if self.case_sensitive {
        names.iter().find(|candidate| candidate.as_str() == name)
    } else {
        let wanted = name.to_lowercase();
        names
            .iter()
            .find(|candidate| candidate.to_lowercase() == wanted)
    };

    match matched {
        Some(found) => Ok(found.clone()),
        None => {
            let available = names.join(", ");
            Err(PolarsError::ColumnNotFound(
                format!(
                    "cannot resolve: column '{}' not found in pivot DataFrame. Available: [{}].",
                    name, available
                )
                .into(),
            ))
        }
    }
}
fn pivot_values(&self) -> Result<Vec<String>, PolarsError> {
if let Some(ref v) = self.values {
return Ok(v.clone());
}
let resolved = self.resolve_column(&self.pivot_col)?;
pivot_values_from_lf(&self.lf, &resolved)
}
fn pivot_agg<F>(&self, value_col: &str, agg_fn: F) -> Result<DataFrame, PolarsError>
where
F: Fn(Expr) -> Expr,
{
use polars::prelude::*;
let pivot_resolved = self.resolve_column(&self.pivot_col)?;
let value_resolved = self.resolve_column(value_col)?;
let pivot_vals = self.pivot_values()?;
if pivot_vals.is_empty() {
let by: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
let lf = self.lf.clone().group_by(by).agg(vec![]);
let pl_df = lf.collect()?;
return Ok(super::DataFrame::from_polars_with_options(
pl_df,
self.case_sensitive,
));
}
let mut agg_exprs: Vec<Expr> = Vec::with_capacity(pivot_vals.len());
use polars::prelude::DataType;
for v in &pivot_vals {
let pred = if v == "null" {
col(pivot_resolved.as_str()).is_null()
} else {
col(pivot_resolved.as_str())
.cast(DataType::String)
.eq(lit(v.as_str()))
};
let then_expr = col(value_resolved.as_str());
let expr = when(pred).then(then_expr).otherwise(lit(NULL));
let has_any = expr
.clone()
.is_not_null()
.cast(DataType::UInt32)
.sum()
.gt(lit(0));
let agg_expr = when(has_any)
.then(agg_fn(expr))
.otherwise(lit(NULL))
.alias(v.as_str());
agg_exprs.push(agg_expr);
}
let by: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
let lf = self.lf.clone().group_by(by).agg(agg_exprs);
let mut pl_df = lf.collect()?;
pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
pl_df = reorder_pivot_columns(&pl_df, &self.grouping_cols, &pivot_vals)?;
Ok(super::DataFrame::from_polars_with_options(
pl_df,
self.case_sensitive,
))
}
pub fn sum(&self, value_col: &str) -> Result<DataFrame, PolarsError> {
self.pivot_agg(value_col, polars::prelude::Expr::sum)
}
pub fn avg(&self, value_col: &str) -> Result<DataFrame, PolarsError> {
self.pivot_agg(value_col, polars::prelude::Expr::mean)
}
pub fn min(&self, value_col: &str) -> Result<DataFrame, PolarsError> {
self.pivot_agg(value_col, polars::prelude::Expr::min)
}
pub fn max(&self, value_col: &str) -> Result<DataFrame, PolarsError> {
self.pivot_agg(value_col, polars::prelude::Expr::max)
}
pub fn count(&self) -> Result<DataFrame, PolarsError> {
use polars::prelude::*;
let pivot_vals = self.pivot_values()?;
if pivot_vals.is_empty() {
let by: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
let lf = self.lf.clone().group_by(by).agg(vec![]);
let pl_df = lf.collect()?;
return Ok(super::DataFrame::from_polars_with_options(
pl_df,
self.case_sensitive,
));
}
let mut agg_exprs: Vec<Expr> = Vec::with_capacity(pivot_vals.len());
use polars::prelude::DataType;
let pivot_resolved = self.resolve_column(&self.pivot_col)?;
for v in &pivot_vals {
let pred = if v == "null" {
col(pivot_resolved.as_str()).is_null()
} else {
col(pivot_resolved.as_str())
.cast(DataType::String)
.eq(lit(v.as_str()))
};
let expr = when(pred).then(lit(1)).otherwise(lit(NULL));
let has_any = expr
.clone()
.is_not_null()
.cast(DataType::UInt32)
.sum()
.gt(lit(0));
let agg_expr = when(has_any)
.then(expr.sum())
.otherwise(lit(NULL))
.alias(v.as_str());
agg_exprs.push(agg_expr);
}
let by: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
let lf = self.lf.clone().group_by(by).agg(agg_exprs);
let mut pl_df = lf.collect()?;
pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
pl_df = reorder_pivot_columns(&pl_df, &self.grouping_cols, &pivot_vals)?;
Ok(super::DataFrame::from_polars_with_options(
pl_df,
self.case_sensitive,
))
}
pub fn _count_distinct(&self, value_col: &str) -> Result<DataFrame, PolarsError> {
use polars::prelude::*;
self.pivot_agg(value_col, Expr::n_unique)
}
pub fn _first(&self, value_col: &str) -> Result<DataFrame, PolarsError> {
use polars::prelude::*;
self.pivot_agg(value_col, Expr::first)
}
pub fn _last(&self, value_col: &str) -> Result<DataFrame, PolarsError> {
self.pivot_agg(value_col, Expr::last)
}
pub fn _stddev(&self, value_col: &str) -> Result<DataFrame, PolarsError> {
self.pivot_agg(value_col, |e| e.std(1))
}
pub fn _variance(&self, value_col: &str) -> Result<DataFrame, PolarsError> {
self.pivot_agg(value_col, |e| e.var(1))
}
pub fn mean(&self, value_col: &str) -> Result<DataFrame, PolarsError> {
self.avg(value_col)
}
pub fn agg(&self, exprs: Vec<Expr>) -> Result<DataFrame, PolarsError> {
let mut parsed = Vec::with_capacity(exprs.len());
for e in &exprs {
match parse_pivot_agg_expr(e) {
Some(t) => parsed.push(t),
None => {
return Err(PolarsError::ComputeError(
"pivot.agg expects expressions like F.sum(\"col\").alias(\"name\"); got unsupported expr"
.to_string()
.into(),
));
}
}
}
let pivot_resolved = self.resolve_column(&self.pivot_col)?;
let pivot_vals = self.pivot_values()?;
if parsed.len() == 1 {
let (alias, agg_kind, value_col) = &parsed[0];
match *agg_kind {
"count_distinct" => return self._count_distinct(value_col),
"first" => return self._first(value_col),
"last" => return self._last(value_col),
"collect_list" => return self._collect_list(value_col),
"collect_set" => return self._collect_set(value_col),
"stddev" => return self._stddev(value_col),
"variance" => return self._variance(value_col),
_ => {}
}
let value_resolved = self.resolve_column(value_col)?;
let col_expr = col(value_resolved.as_str());
let agg_expr = match *agg_kind {
"sum" => col_expr.sum().alias(alias.as_str()),
"avg" => col_expr.mean().alias(alias.as_str()),
"min" => col_expr.min().alias(alias.as_str()),
"max" => col_expr.max().alias(alias.as_str()),
_ => {
return Err(PolarsError::ComputeError(
format!("pivot.agg unsupported agg: {}", agg_kind).into(),
));
}
};
let by: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
let lf = self.lf.clone().group_by(by).agg(vec![agg_expr]);
let mut pl_df = lf.collect()?;
pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
return Ok(super::DataFrame::from_polars_with_options(
pl_df,
self.case_sensitive,
));
}
if pivot_vals.is_empty() {
let by: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
let lf = self.lf.clone().group_by(by).agg(vec![]);
let pl_df = lf.collect()?;
return Ok(super::DataFrame::from_polars_with_options(
pl_df,
self.case_sensitive,
));
}
let mut agg_exprs: Vec<Expr> = Vec::with_capacity(pivot_vals.len() * parsed.len());
use polars::prelude::DataType;
for v in &pivot_vals {
let pred = if v == "null" {
col(pivot_resolved.as_str()).is_null()
} else {
col(pivot_resolved.as_str())
.cast(DataType::String)
.eq(lit(v.as_str()))
};
for (alias, agg_kind, value_col) in &parsed {
let value_resolved = self.resolve_column(value_col)?;
let then_expr = col(value_resolved.as_str());
let expr = when(pred.clone()).then(then_expr).otherwise(lit(NULL));
let has_any = expr
.clone()
.is_not_null()
.cast(DataType::UInt32)
.sum()
.gt(lit(0));
let aggregated = match *agg_kind {
"sum" => expr.sum(),
"avg" => expr.mean(),
"min" => expr.min(),
"max" => expr.max(),
_ => {
return Err(PolarsError::ComputeError(
format!("pivot.agg unsupported agg: {}", agg_kind).into(),
));
}
};
let agg_expr = when(has_any)
.then(aggregated)
.otherwise(lit(NULL))
.alias(format!("{}_{}", v, alias));
agg_exprs.push(agg_expr);
}
}
let by: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
let lf = self.lf.clone().group_by(by).agg(agg_exprs);
let mut pl_df = lf.collect()?;
pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
Ok(super::DataFrame::from_polars_with_options(
pl_df,
self.case_sensitive,
))
}
pub fn _collect_list(&self, value_col: &str) -> Result<DataFrame, PolarsError> {
use polars::prelude::*;
let pivot_resolved = self.resolve_column(&self.pivot_col)?;
let value_resolved = self.resolve_column(value_col)?;
let pivot_vals = self.pivot_values()?;
if pivot_vals.is_empty() {
let by: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
let lf = self.lf.clone().group_by(by).agg(vec![]);
let pl_df = lf.collect()?;
return Ok(super::DataFrame::from_polars_with_options(
pl_df,
self.case_sensitive,
));
}
let mut agg_exprs: Vec<Expr> = Vec::with_capacity(pivot_vals.len());
use polars::prelude::DataType;
for v in &pivot_vals {
let pred = if v == "null" {
col(pivot_resolved.as_str()).is_null()
} else {
col(pivot_resolved.as_str())
.cast(DataType::String)
.eq(lit(v.as_str()))
};
let filtered = col(value_resolved.as_str()).filter(pred.clone());
let has_any = filtered.clone().count().gt(lit(0));
let agg_expr = when(has_any)
.then(filtered.implode())
.otherwise(lit(NULL))
.alias(v.as_str());
agg_exprs.push(agg_expr);
}
let by: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
let lf = self.lf.clone().group_by(by).agg(agg_exprs);
let mut pl_df = lf.collect()?;
pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
Ok(super::DataFrame::from_polars_with_options(
pl_df,
self.case_sensitive,
))
}
pub fn _collect_set(&self, value_col: &str) -> Result<DataFrame, PolarsError> {
use polars::prelude::*;
let pivot_resolved = self.resolve_column(&self.pivot_col)?;
let value_resolved = self.resolve_column(value_col)?;
let pivot_vals = self.pivot_values()?;
if pivot_vals.is_empty() {
let by: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
let lf = self.lf.clone().group_by(by).agg(vec![]);
let pl_df = lf.collect()?;
return Ok(super::DataFrame::from_polars_with_options(
pl_df,
self.case_sensitive,
));
}
let mut agg_exprs: Vec<Expr> = Vec::with_capacity(pivot_vals.len());
use polars::prelude::DataType;
for v in &pivot_vals {
let pred = if v == "null" {
col(pivot_resolved.as_str()).is_null()
} else {
col(pivot_resolved.as_str())
.cast(DataType::String)
.eq(lit(v.as_str()))
};
let filtered = col(value_resolved.as_str()).filter(pred.clone());
let has_any = filtered.clone().count().gt(lit(0));
let agg_expr = when(has_any)
.then(filtered.unique().implode())
.otherwise(lit(NULL))
.alias(v.as_str());
agg_exprs.push(agg_expr);
}
let by: Vec<&str> = self.grouping_cols.iter().map(|s| s.as_str()).collect();
let lf = self.lf.clone().group_by(by).agg(agg_exprs);
let mut pl_df = lf.collect()?;
pl_df = reorder_groupby_columns(&mut pl_df, &self.grouping_cols)?;
Ok(super::DataFrame::from_polars_with_options(
pl_df,
self.case_sensitive,
))
}
}
/// State for a cube or rollup aggregation over a lazy frame.
///
/// `agg` evaluates the requested aggregations once per grouping set and
/// concatenates the results.
pub struct CubeRollupData {
// Lazy frame the aggregations are evaluated against.
pub(super) lf: LazyFrame,
// Grouping columns, in the order they appear in the output.
pub(super) grouping_cols: Vec<String>,
// Forwarded to the `DataFrame` constructors when wrapping results.
pub(super) case_sensitive: bool,
// true: every subset of `grouping_cols` is a grouping set (cube);
// false: only the prefixes of `grouping_cols` are (rollup).
pub(super) is_cube: bool,
}
impl CubeRollupData {
    /// Row count per grouping set, as an Int64 column named `count`.
    pub fn count(&self) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let row_count = len().cast(DataType::Int64).alias("count");
        self.agg(vec![row_count])
    }

    /// Sum of `column` per grouping set, output column named `sum(<column>)`.
    pub fn sum(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let total = col(column).sum().alias(format!("sum({column})"));
        self.agg(vec![total])
    }

    /// Mean of each of `columns` per grouping set, output columns named `avg(<column>)`.
    ///
    /// Returns a `ComputeError` when `columns` is empty.
    pub fn avg(&self, columns: &[&str]) -> Result<DataFrame, PolarsError> {
        if columns.is_empty() {
            return Err(PolarsError::ComputeError(
                "avg requires at least one column".into(),
            ));
        }
        use polars::prelude::*;
        let mut means: Vec<Expr> = Vec::with_capacity(columns.len());
        for &name in columns {
            means.push(col(name).mean().alias(format!("avg({name})")));
        }
        self.agg(means)
    }

    /// Minimum of `column` per grouping set, output column named `min(<column>)`.
    pub fn min(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let smallest = col(column).min().alias(format!("min({column})"));
        self.agg(vec![smallest])
    }

    /// Maximum of `column` per grouping set, output column named `max(<column>)`.
    pub fn max(&self, column: &str) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let largest = col(column).max().alias(format!("max({column})"));
        self.agg(vec![largest])
    }

    /// Evaluates `aggregations` for every grouping set and stacks the results.
    ///
    /// For a cube the grouping sets are all subsets of the grouping columns; for a
    /// rollup they are the prefixes (including the empty one). Grouping columns that
    /// are not part of a given set are filled with typed nulls, and every part is
    /// laid out as: grouping columns first, then the aggregate columns.
    pub fn agg(&self, aggregations: Vec<Expr>) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let resolver =
            super::DataFrame::from_lazy_with_options(self.lf.clone(), self.case_sensitive);
        let mut resolved: Vec<Expr> = Vec::with_capacity(aggregations.len());
        for e in aggregations {
            resolved.push(resolver.resolve_expr_column_names(e)?);
        }
        let aggregations = disambiguate_agg_output_names(resolved);

        let key_count = self.grouping_cols.len();
        let grouping_sets: Vec<Vec<String>> = if self.is_cube {
            // Bit i of the mask decides whether grouping column i is in the set.
            let mut sets = Vec::new();
            for mask in 0..1 << key_count {
                let mut chosen = Vec::new();
                for (i, c) in self.grouping_cols.iter().enumerate() {
                    if (mask & (1 << i)) != 0 {
                        chosen.push(c.clone());
                    }
                }
                sets.push(chosen);
            }
            sets
        } else {
            let mut sets = Vec::with_capacity(key_count + 1);
            for prefix_len in 0..=key_count {
                sets.push(self.grouping_cols[..prefix_len].to_vec());
            }
            sets
        };

        let schema = self.lf.clone().collect_schema()?;

        // Adds a null column for every grouping column absent from `present`, then
        // puts the grouping columns first and the remaining columns after them.
        let pad_and_order =
            |mut part: PlDataFrame, present: &[String]| -> Result<PlDataFrame, PolarsError> {
                let height = part.height();
                for gc in &self.grouping_cols {
                    if present.iter().any(|s| s == gc) {
                        continue;
                    }
                    let dtype = schema.get(gc).cloned().unwrap_or(DataType::Null);
                    let filler = null_series_for_dtype(gc.as_str(), height, &dtype)?;
                    part.with_column(filler.into())?;
                }
                let mut layout: Vec<&str> =
                    self.grouping_cols.iter().map(|s| s.as_str()).collect();
                for name in part.get_column_names() {
                    if !self.grouping_cols.iter().any(|g| g == name) {
                        layout.push(name);
                    }
                }
                part.select(layout)
            };

        let mut parts: Vec<PlDataFrame> = Vec::with_capacity(grouping_sets.len());
        for subset in grouping_sets {
            let raw = if subset.is_empty() {
                // Grand total: aggregate over the whole frame without grouping.
                self.lf.clone().select(&aggregations).collect()?
            } else {
                let keys = subset.iter().map(|s| col(s.as_str())).collect::<Vec<_>>();
                let mut grouped = self
                    .lf
                    .clone()
                    .group_by(keys)
                    .agg(aggregations.clone())
                    .collect()?;
                reorder_groupby_columns(&mut grouped, &subset)?
            };
            parts.push(pad_and_order(raw, &subset)?);
        }

        if parts.is_empty() {
            return Ok(super::DataFrame::from_polars_with_options(
                PlDataFrame::empty(),
                self.case_sensitive,
            ));
        }

        // Align every later part with the column order of the first one.
        let reference_order: Vec<String> = parts[0]
            .schema()
            .iter_names()
            .map(|s| s.to_string())
            .collect();
        for later in parts.iter_mut().skip(1) {
            *later = later.select(reference_order.as_slice())?;
        }

        let lazy_parts: Vec<_> = parts.into_iter().map(|p| p.lazy()).collect();
        let stacked = polars::prelude::concat(lazy_parts, UnionArgs::default())?.collect()?;
        Ok(super::DataFrame::from_polars_with_options(
            stacked,
            self.case_sensitive,
        ))
    }
}
/// Builds an all-null [`Series`] named `name` with `n` rows and dtype `dtype`.
///
/// Uses `Series::full_null`, which constructs the null column directly for the
/// requested dtype. This avoids building Int64 nulls and casting them, a route that
/// only works for dtypes that have a cast path from Int64.
///
/// The `Result` return type is kept for callers that use `?` on this function.
fn null_series_for_dtype(name: &str, n: usize, dtype: &DataType) -> Result<Series, PolarsError> {
    Ok(Series::full_null(name.into(), n, dtype))
}
/// Reorders a pivot result so grouping columns come first, followed by the pivot
/// value columns in `pivot_vals` order.
///
/// Names that are not present in `pl_df` are skipped. If the resulting order does
/// not account for every distinct column of `pl_df`, the frame is returned as an
/// unchanged clone.
fn reorder_pivot_columns(
    pl_df: &PlDataFrame,
    grouping_cols: &[String],
    pivot_vals: &[String],
) -> Result<PlDataFrame, PolarsError> {
    let present: std::collections::HashSet<String> = pl_df
        .get_column_names()
        .iter()
        .map(|s| s.to_string())
        .collect();
    let wanted: Vec<&str> = grouping_cols
        .iter()
        .chain(pivot_vals.iter())
        .filter(|name| present.contains(*name))
        .map(|name| name.as_str())
        .collect();
    if wanted.len() != present.len() {
        return Ok(pl_df.clone());
    }
    pl_df.select(wanted)
}
/// Moves the grouping columns to the front of `pl_df`, keeping every other column
/// in its existing relative order.
///
/// Grouping columns missing from the frame are skipped. If the computed order is
/// empty or does not cover exactly as many columns as the frame has, an unchanged
/// clone is returned instead.
pub(super) fn reorder_groupby_columns(
    pl_df: &mut PlDataFrame,
    grouping_cols: &[String],
) -> Result<PlDataFrame, PolarsError> {
    let existing: Vec<String> = pl_df
        .get_column_names()
        .iter()
        .map(|s| s.to_string())
        .collect();
    let keys_first = grouping_cols.iter().filter(|gc| existing.contains(*gc));
    let remainder = existing
        .iter()
        .filter(|name| !grouping_cols.contains(*name));
    let target: Vec<&str> = keys_first.chain(remainder).map(String::as_str).collect();
    if target.is_empty() || target.len() != existing.len() {
        return Ok(pl_df.clone());
    }
    pl_df.select(target)
}
#[cfg(test)]
mod tests {
    use crate::{DataFrame, SparkSession, functions};

    /// Three-row fixture with columns `k`, `v`, `label`; `k` has two distinct values.
    fn test_df() -> DataFrame {
        let session = SparkSession::builder()
            .app_name("agg_tests")
            .get_or_create();
        let rows = vec![
            (1i64, 10i64, "a".to_string()),
            (1i64, 20i64, "a".to_string()),
            (2i64, 30i64, "b".to_string()),
        ];
        session
            .create_dataframe(rows, vec!["k", "v", "label"])
            .unwrap()
    }

    /// Counting by `k` yields one row per distinct key plus a `count` column.
    #[test]
    fn group_by_count_single_group() {
        let counted = test_df().group_by(vec!["k"]).unwrap().count().unwrap();
        assert_eq!(counted.count().unwrap(), 2);
        let names = counted.columns().unwrap();
        assert!(names.contains(&String::from("k")));
        assert!(names.contains(&String::from("count")));
    }

    /// Summing `v` by `k` yields a column whose name starts with `sum(`.
    #[test]
    fn group_by_sum() {
        let summed = test_df().group_by(vec!["k"]).unwrap().sum("v").unwrap();
        assert_eq!(summed.count().unwrap(), 2);
        let names = summed.columns().unwrap();
        assert!(names.iter().any(|name| name.starts_with("sum(")));
    }

    /// Grouping an empty frame produces no rows.
    #[test]
    fn group_by_empty_groups() {
        let session = SparkSession::builder()
            .app_name("agg_tests")
            .get_or_create();
        let no_rows: Vec<(i64, i64, String)> = vec![];
        let empty = session
            .create_dataframe(no_rows, vec!["a", "b", "c"])
            .unwrap();
        let counted = empty.group_by(vec!["a"]).unwrap().count().unwrap();
        assert_eq!(counted.count().unwrap(), 0);
    }

    /// Several raw Polars aggregations keep their aliases in the output.
    #[test]
    fn group_by_agg_multi() {
        let row_count = polars::prelude::len().alias("cnt");
        let value_total = polars::prelude::col("v").sum().alias("total");
        let result = test_df()
            .group_by(vec!["k"])
            .unwrap()
            .agg(vec![row_count, value_total])
            .unwrap();
        assert_eq!(result.count().unwrap(), 2);
        let names = result.columns().unwrap();
        assert!(names.contains(&String::from("k")));
        assert!(names.contains(&String::from("cnt")));
        assert!(names.contains(&String::from("total")));
    }

    /// Column-based aggregations yield the key plus one column per aggregate.
    #[test]
    fn group_by_agg_columns_multi() {
        let value = functions::col("v");
        let requested = vec![functions::count(&value), functions::sum(&value)];
        let result = test_df()
            .group_by(vec!["k"])
            .unwrap()
            .agg_columns(requested)
            .unwrap();
        assert_eq!(result.count().unwrap(), 2);
        let names = result.columns().unwrap();
        assert!(names.contains(&String::from("k")));
        assert_eq!(names.len(), 3);
    }
}