use crate::core::error::{Error, Result};
use crate::dataframe::base::DataFrame;
use crate::series::Series;
use std::collections::HashMap;
/// Grouped view over a borrowed [`DataFrame`], keyed by the values found in one or
/// more grouping columns.
///
/// Rows whose grouping columns render to the same combination of strings belong to
/// the same group.
pub struct DataFrameGroupBy<'a> {
/// The frame being grouped; borrowed for the lifetime of this value.
df: &'a DataFrame,
/// Names of the grouping columns, in the order they were passed to `new`.
group_columns: Vec<String>,
/// Composite group key -> indices of the rows in that group, in ascending row order.
groups: HashMap<String, Vec<usize>>,
/// Composite group key -> the per-column string values that make up that key.
group_keys: HashMap<String, Vec<String>>,
}
impl<'a> DataFrameGroupBy<'a> {
/// Builds the grouping of `df` by the columns named in `by`.
///
/// Each grouping column is rendered to strings: string columns are used as-is,
/// numeric columns go through `to_string` with NaN rendered as `"NaN"`, and a
/// column that is neither contributes empty strings. Rows whose rendered values
/// agree in every grouping column land in the same group.
///
/// # Errors
///
/// Returns [`Error::InvalidValue`] when `by` is empty or names a column that `df`
/// does not contain.
pub fn new(df: &'a DataFrame, by: &[&str]) -> Result<Self> {
    if by.is_empty() {
        return Err(Error::InvalidValue(
            "GroupBy requires at least one column".to_string(),
        ));
    }
    for col in by {
        if !df.contains_column(col) {
            return Err(Error::InvalidValue(format!(
                "Column '{}' not found in DataFrame",
                col
            )));
        }
    }
    let group_columns: Vec<String> = by.iter().map(|s| s.to_string()).collect();
    let row_count = df.row_count();

    // Render each grouping column once, instead of repeating the column lookup
    // for every row. Every rendered column has exactly `row_count` entries.
    let mut rendered_columns: Vec<Vec<String>> = Vec::with_capacity(group_columns.len());
    for col in &group_columns {
        let rendered: Vec<String> = if let Ok(values) = df.get_column_string_values(col) {
            (0..row_count)
                .map(|row_idx| values.get(row_idx).cloned().unwrap_or_default())
                .collect()
        } else if let Ok(values) = df.get_column_numeric_values(col) {
            (0..row_count)
                .map(|row_idx| {
                    let v = values.get(row_idx).copied().unwrap_or(f64::NAN);
                    if v.is_nan() {
                        "NaN".to_string()
                    } else {
                        v.to_string()
                    }
                })
                .collect()
        } else {
            vec![String::new(); row_count]
        };
        rendered_columns.push(rendered);
    }

    let mut groups: HashMap<String, Vec<usize>> = HashMap::new();
    let mut group_keys: HashMap<String, Vec<String>> = HashMap::new();
    for row_idx in 0..row_count {
        let key_parts: Vec<String> = rendered_columns
            .iter()
            .map(|column| column[row_idx].clone())
            .collect();

        // Length-prefix every part so the composite key is unambiguous. A plain
        // separator join would merge distinct tuples whenever a value itself
        // contains the separator text.
        let mut key = String::new();
        for part in &key_parts {
            key.push_str(&part.len().to_string());
            key.push(':');
            key.push_str(part);
        }

        groups.entry(key.clone()).or_default().push(row_idx);
        group_keys.entry(key).or_insert(key_parts);
    }

    Ok(Self {
        df,
        group_columns,
        groups,
        group_keys,
    })
}
/// Number of distinct groups found when the grouping was built.
pub fn ngroups(&self) -> usize {
    let distinct_keys = self.groups.keys();
    distinct_keys.len()
}
/// Returns one row per group: the grouping-column values followed by a `size`
/// column holding the number of rows in that group (as `f64`).
///
/// Row order follows the iteration order of the internal hash map.
pub fn size(&self) -> Result<DataFrame> {
    let key_width = self.group_columns.len();
    let mut labels: Vec<Vec<String>> = (0..key_width).map(|_| Vec::new()).collect();
    let mut counts: Vec<f64> = Vec::new();

    for (group_key, members) in self.groups.iter() {
        match self.group_keys.get(group_key) {
            Some(parts) => {
                for (slot, part) in parts.iter().enumerate() {
                    labels[slot].push(part.clone());
                }
            }
            None => {}
        }
        counts.push(members.len() as f64);
    }

    let mut out = DataFrame::new();
    for (name, column) in self.group_columns.iter().zip(labels) {
        let series = Series::new(column, Some(name.clone()))?;
        out.add_column(name.clone(), series)?;
    }
    let size_name = "size".to_string();
    let size_series = Series::new(counts, Some(size_name.clone()))?;
    out.add_column(size_name, size_series)?;
    Ok(out)
}
/// Alias for [`size`](Self::size): reports the number of rows per group in a
/// `size` column. Rows are counted regardless of NaN values in other columns.
pub fn count(&self) -> Result<DataFrame> {
    let per_group_sizes = self.size();
    per_group_sizes
}
/// Sums each non-grouping numeric column within every group, skipping NaN entries.
pub fn sum(&self) -> Result<DataFrame> {
    fn nan_skipping_total(group: &[f64]) -> f64 {
        group.iter().filter(|x| !x.is_nan()).sum()
    }
    self.aggregate(nan_skipping_total)
}
/// Arithmetic mean of each non-grouping numeric column within every group.
///
/// NaN entries are skipped; a group with no non-NaN entries yields NaN.
pub fn mean(&self) -> Result<DataFrame> {
    self.aggregate(|group| {
        let present = group.iter().filter(|x| !x.is_nan()).count();
        match present {
            0 => f64::NAN,
            n => group.iter().filter(|x| !x.is_nan()).sum::<f64>() / n as f64,
        }
    })
}
/// Minimum of each non-grouping numeric column within every group, skipping NaN
/// entries.
///
/// A group with no non-NaN entries yields NaN, matching how `mean`, `std` and
/// `var` report "no usable data".
pub fn min(&self) -> Result<DataFrame> {
    self.aggregate(|values| {
        let mut valid = values.iter().copied().filter(|v| !v.is_nan());
        match valid.next() {
            // Seed the fold with a real value so the result is always one of the inputs.
            Some(first) => valid.fold(first, f64::min),
            // Previously the fold seed (+infinity) leaked out for such groups.
            None => f64::NAN,
        }
    })
}
/// Maximum of each non-grouping numeric column within every group, skipping NaN
/// entries.
///
/// A group with no non-NaN entries yields NaN, matching how `mean`, `std` and
/// `var` report "no usable data".
pub fn max(&self) -> Result<DataFrame> {
    self.aggregate(|values| {
        let mut valid = values.iter().copied().filter(|v| !v.is_nan());
        match valid.next() {
            // Seed the fold with a real value so the result is always one of the inputs.
            Some(first) => valid.fold(first, f64::max),
            // Previously the fold seed (-infinity) leaked out for such groups.
            None => f64::NAN,
        }
    })
}
/// Sample standard deviation (divisor `n - 1`) of each non-grouping numeric column
/// within every group.
///
/// NaN entries are skipped; groups with fewer than two non-NaN entries yield NaN.
pub fn std(&self) -> Result<DataFrame> {
    self.aggregate(|group| {
        let kept: Vec<f64> = group.iter().copied().filter(|x| !x.is_nan()).collect();
        let n = kept.len();
        if n < 2 {
            return f64::NAN;
        }
        let center = kept.iter().sum::<f64>() / n as f64;
        let squared_deviations = kept.iter().map(|x| (x - center).powi(2)).sum::<f64>();
        let sample_variance: f64 = squared_deviations / (n - 1) as f64;
        sample_variance.sqrt()
    })
}
/// Sample variance (divisor `n - 1`) of each non-grouping numeric column within
/// every group.
///
/// NaN entries are skipped; groups with fewer than two non-NaN entries yield NaN.
pub fn var(&self) -> Result<DataFrame> {
    self.aggregate(|group| {
        let kept: Vec<f64> = group.iter().copied().filter(|x| !x.is_nan()).collect();
        let n = kept.len();
        if n < 2 {
            return f64::NAN;
        }
        let center = kept.iter().sum::<f64>() / n as f64;
        let squared_deviations = kept.iter().map(|x| (x - center).powi(2)).sum::<f64>();
        squared_deviations / (n - 1) as f64
    })
}
/// Takes, for every group, the values of the non-grouping columns from the group's
/// first row (lowest row index).
pub fn first(&self) -> Result<DataFrame> {
    let pick_first_row = true;
    self.aggregate_first_last(pick_first_row)
}
/// Takes, for every group, the values of the non-grouping columns from the group's
/// last row (highest row index).
pub fn last(&self) -> Result<DataFrame> {
    let pick_first_row = false;
    self.aggregate_first_last(pick_first_row)
}
fn aggregate<F>(&self, agg_fn: F) -> Result<DataFrame>
where
F: Fn(&[f64]) -> f64,
{
let mut result = DataFrame::new();
let numeric_cols: Vec<String> = self
.df
.column_names()
.into_iter()
.filter(|col| {
!self.group_columns.contains(col) && self.df.get_column_numeric_values(col).is_ok()
})
.collect();
let mut group_col_values: Vec<Vec<String>> = vec![Vec::new(); self.group_columns.len()];
let mut agg_values: HashMap<String, Vec<f64>> = HashMap::new();
for col in &numeric_cols {
agg_values.insert(col.clone(), Vec::new());
}
for (key, indices) in &self.groups {
if let Some(key_values) = self.group_keys.get(key) {
for (i, val) in key_values.iter().enumerate() {
group_col_values[i].push(val.clone());
}
}
for col in &numeric_cols {
if let Ok(all_values) = self.df.get_column_numeric_values(col) {
let group_values: Vec<f64> = indices
.iter()
.filter_map(|&i| all_values.get(i).copied())
.collect();
let aggregated = agg_fn(&group_values);
agg_values
.get_mut(col)
.expect("test should succeed")
.push(aggregated);
}
}
}
for (i, col_name) in self.group_columns.iter().enumerate() {
result.add_column(
col_name.clone(),
Series::new(group_col_values[i].clone(), Some(col_name.clone()))?,
)?;
}
for col in &numeric_cols {
if let Some(values) = agg_values.get(col) {
result.add_column(col.clone(), Series::new(values.clone(), Some(col.clone()))?)?;
}
}
Ok(result)
}
fn aggregate_first_last(&self, first: bool) -> Result<DataFrame> {
let mut result = DataFrame::new();
let other_cols: Vec<String> = self
.df
.column_names()
.into_iter()
.filter(|col| !self.group_columns.contains(col))
.collect();
let mut group_col_values: Vec<Vec<String>> = vec![Vec::new(); self.group_columns.len()];
let mut numeric_values: HashMap<String, Vec<f64>> = HashMap::new();
let mut string_values: HashMap<String, Vec<String>> = HashMap::new();
for col in &other_cols {
if self.df.get_column_numeric_values(col).is_ok() {
numeric_values.insert(col.clone(), Vec::new());
} else if self.df.get_column_string_values(col).is_ok() {
string_values.insert(col.clone(), Vec::new());
}
}
for (key, indices) in &self.groups {
if let Some(key_values) = self.group_keys.get(key) {
for (i, val) in key_values.iter().enumerate() {
group_col_values[i].push(val.clone());
}
}
let target_idx = if first {
*indices.first().expect("test should succeed")
} else {
*indices.last().expect("test should succeed")
};
for col in &other_cols {
if let Ok(all_values) = self.df.get_column_numeric_values(col) {
let value = all_values.get(target_idx).copied().unwrap_or(f64::NAN);
numeric_values
.get_mut(col)
.expect("test should succeed")
.push(value);
} else if let Ok(all_values) = self.df.get_column_string_values(col) {
let value = all_values.get(target_idx).cloned().unwrap_or_default();
string_values
.get_mut(col)
.expect("test should succeed")
.push(value);
}
}
}
for (i, col_name) in self.group_columns.iter().enumerate() {
result.add_column(
col_name.clone(),
Series::new(group_col_values[i].clone(), Some(col_name.clone()))?,
)?;
}
for col in &other_cols {
if let Some(values) = numeric_values.get(col) {
result.add_column(col.clone(), Series::new(values.clone(), Some(col.clone()))?)?;
} else if let Some(values) = string_values.get(col) {
result.add_column(col.clone(), Series::new(values.clone(), Some(col.clone()))?)?;
}
}
Ok(result)
}
pub fn agg(&self, aggs: &[(&str, &str)]) -> Result<DataFrame> {
let mut result = DataFrame::new();
let mut group_col_values: Vec<Vec<String>> = vec![Vec::new(); self.group_columns.len()];
for (key, _) in &self.groups {
if let Some(key_values) = self.group_keys.get(key) {
for (i, val) in key_values.iter().enumerate() {
group_col_values[i].push(val.clone());
}
}
}
for (i, col_name) in self.group_columns.iter().enumerate() {
result.add_column(
col_name.clone(),
Series::new(group_col_values[i].clone(), Some(col_name.clone()))?,
)?;
}
for (col, agg_name) in aggs {
if !self.df.contains_column(col) {
continue;
}
if let Ok(all_values) = self.df.get_column_numeric_values(col) {
let mut agg_values: Vec<f64> = Vec::new();
for (_, indices) in &self.groups {
let group_values: Vec<f64> = indices
.iter()
.filter_map(|&i| all_values.get(i).copied())
.collect();
let aggregated = match *agg_name {
"sum" => group_values.iter().filter(|v| !v.is_nan()).sum(),
"mean" => {
let valid: Vec<f64> = group_values
.iter()
.filter(|v| !v.is_nan())
.copied()
.collect();
if valid.is_empty() {
f64::NAN
} else {
valid.iter().sum::<f64>() / valid.len() as f64
}
}
"min" => group_values
.iter()
.filter(|v| !v.is_nan())
.copied()
.fold(f64::INFINITY, f64::min),
"max" => group_values
.iter()
.filter(|v| !v.is_nan())
.copied()
.fold(f64::NEG_INFINITY, f64::max),
"count" => group_values.iter().filter(|v| !v.is_nan()).count() as f64,
"std" => {
let valid: Vec<f64> = group_values
.iter()
.filter(|v| !v.is_nan())
.copied()
.collect();
if valid.len() <= 1 {
f64::NAN
} else {
let mean = valid.iter().sum::<f64>() / valid.len() as f64;
let variance: f64 =
valid.iter().map(|v| (v - mean).powi(2)).sum::<f64>()
/ (valid.len() - 1) as f64;
variance.sqrt()
}
}
"var" => {
let valid: Vec<f64> = group_values
.iter()
.filter(|v| !v.is_nan())
.copied()
.collect();
if valid.len() <= 1 {
f64::NAN
} else {
let mean = valid.iter().sum::<f64>() / valid.len() as f64;
valid.iter().map(|v| (v - mean).powi(2)).sum::<f64>()
/ (valid.len() - 1) as f64
}
}
"first" => group_values.first().copied().unwrap_or(f64::NAN),
"last" => group_values.last().copied().unwrap_or(f64::NAN),
_ => f64::NAN,
};
agg_values.push(aggregated);
}
let result_col_name = format!("{}_{}", col, agg_name);
result.add_column(
result_col_name.clone(),
Series::new(agg_values, Some(result_col_name))?,
)?;
}
}
Ok(result)
}
}
/// Extension trait adding multi-column group-by to a frame type.
pub trait PandasGroupByExt {
    /// Groups the rows of `self` by the columns named in `by`.
    ///
    /// The returned grouping borrows `self`.
    fn groupby_multi(&self, by: &[&str]) -> Result<DataFrameGroupBy<'_>>;
}
impl PandasGroupByExt for DataFrame {
    /// Delegates to [`DataFrameGroupBy::new`], so it fails under the same conditions.
    fn groupby_multi(&self, by: &[&str]) -> Result<DataFrameGroupBy<'_>> {
        let frame: &DataFrame = self;
        DataFrameGroupBy::new(frame, by)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Appends a string column built from `values` to `df`.
    fn add_text_column(df: &mut DataFrame, name: &str, values: &[&str]) {
        let data: Vec<String> = values.iter().map(|s| s.to_string()).collect();
        df.add_column(
            name.to_string(),
            Series::new(data, Some(name.to_string())).expect("test should succeed"),
        )
        .expect("test should succeed");
    }

    /// Appends a numeric column holding `values` to `df`.
    fn add_numeric_column(df: &mut DataFrame, name: &str, values: Vec<f64>) {
        df.add_column(
            name.to_string(),
            Series::new(values, Some(name.to_string())).expect("test should succeed"),
        )
        .expect("test should succeed");
    }

    /// Five rows: categories A/B alternating, with `value` and `score` columns.
    fn create_test_df() -> DataFrame {
        let mut df = DataFrame::new();
        add_text_column(&mut df, "category", &["A", "B", "A", "B", "A"]);
        add_numeric_column(&mut df, "value", vec![10.0, 20.0, 30.0, 40.0, 50.0]);
        add_numeric_column(&mut df, "score", vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        df
    }

    /// Groups `df` by its `category` column.
    fn by_category(df: &DataFrame) -> DataFrameGroupBy<'_> {
        df.groupby_multi(&["category"])
            .expect("test should succeed")
    }

    /// Looks up `value_col` in the result row whose `category` equals `label`.
    /// Result row order is unspecified, so rows are located by label.
    fn value_for(result: &DataFrame, label: &str, value_col: &str) -> f64 {
        let cats = result
            .get_column_string_values("category")
            .expect("test should succeed");
        let values = result
            .get_column_numeric_values(value_col)
            .expect("test should succeed");
        let idx = cats
            .iter()
            .position(|c| c == label)
            .expect("test should succeed");
        values[idx]
    }

    #[test]
    fn test_groupby_sum() {
        let df = create_test_df();
        let result = by_category(&df).sum().expect("test should succeed");
        assert_eq!(result.row_count(), 2);
        assert_eq!(value_for(&result, "A", "value"), 90.0);
        assert_eq!(value_for(&result, "B", "value"), 60.0);
    }

    #[test]
    fn test_groupby_mean() {
        let df = create_test_df();
        let result = by_category(&df).mean().expect("test should succeed");
        assert_eq!(value_for(&result, "A", "value"), 30.0);
        assert_eq!(value_for(&result, "B", "value"), 30.0);
    }

    #[test]
    fn test_groupby_min() {
        let df = create_test_df();
        let result = by_category(&df).min().expect("test should succeed");
        assert_eq!(value_for(&result, "A", "value"), 10.0);
        assert_eq!(value_for(&result, "B", "value"), 20.0);
    }

    #[test]
    fn test_groupby_max() {
        let df = create_test_df();
        let result = by_category(&df).max().expect("test should succeed");
        assert_eq!(value_for(&result, "A", "value"), 50.0);
        assert_eq!(value_for(&result, "B", "value"), 40.0);
    }

    #[test]
    fn test_groupby_count() {
        let df = create_test_df();
        let result = by_category(&df).count().expect("test should succeed");
        assert_eq!(value_for(&result, "A", "size"), 3.0);
        assert_eq!(value_for(&result, "B", "size"), 2.0);
    }

    #[test]
    fn test_groupby_std() {
        let df = create_test_df();
        let result = by_category(&df).std().expect("test should succeed");
        // A = [10, 30, 50]: sample variance 400, std 20.
        assert!((value_for(&result, "A", "value") - 20.0).abs() < 0.001);
        // B = [20, 40]: sample variance 200, std sqrt(200).
        assert!((value_for(&result, "B", "value") - 200.0_f64.sqrt()).abs() < 0.001);
    }

    #[test]
    fn test_groupby_first() {
        let df = create_test_df();
        let result = by_category(&df).first().expect("test should succeed");
        assert_eq!(value_for(&result, "A", "value"), 10.0);
        assert_eq!(value_for(&result, "B", "value"), 20.0);
    }

    #[test]
    fn test_groupby_last() {
        let df = create_test_df();
        let result = by_category(&df).last().expect("test should succeed");
        assert_eq!(value_for(&result, "A", "value"), 50.0);
        assert_eq!(value_for(&result, "B", "value"), 40.0);
    }

    #[test]
    fn test_groupby_multiple_columns() {
        let mut df = DataFrame::new();
        add_text_column(&mut df, "cat1", &["A", "A", "B", "B"]);
        add_text_column(&mut df, "cat2", &["X", "Y", "X", "Y"]);
        add_numeric_column(&mut df, "value", vec![1.0, 2.0, 3.0, 4.0]);
        let result = df
            .groupby_multi(&["cat1", "cat2"])
            .expect("test should succeed")
            .sum()
            .expect("test should succeed");
        // Every (cat1, cat2) combination is distinct, so each row is its own group.
        assert_eq!(result.row_count(), 4);
    }

    #[test]
    fn test_groupby_with_nan() {
        let mut df = DataFrame::new();
        add_text_column(&mut df, "category", &["A", "A", "A"]);
        add_numeric_column(&mut df, "value", vec![10.0, f64::NAN, 30.0]);
        let result = by_category(&df).sum().expect("test should succeed");
        let values = result
            .get_column_numeric_values("value")
            .expect("test should succeed");
        // The NaN entry is skipped rather than poisoning the sum.
        assert_eq!(values[0], 40.0);
    }

    #[test]
    fn test_groupby_agg() {
        let df = create_test_df();
        let result = by_category(&df)
            .agg(&[("value", "sum"), ("value", "mean"), ("score", "max")])
            .expect("test should succeed");
        assert!(result.contains_column("value_sum"));
        assert!(result.contains_column("value_mean"));
        assert!(result.contains_column("score_max"));
        assert_eq!(value_for(&result, "A", "value_sum"), 90.0);
    }

    #[test]
    fn test_ngroups() {
        let df = create_test_df();
        let gb = by_category(&df);
        assert_eq!(gb.ngroups(), 2);
    }
}