use std::collections::HashMap;
use std::fmt::Debug;
use std::hash::Hash;
use crate::dataframe::DataFrame;
use crate::error::{PandRSError, Result};
use crate::series::Series;
/// Grouping of a borrowed [`Series`] by a parallel vector of keys.
///
/// [`GroupBy::new`] pairs `keys[i]` with position `i` of `source` and records,
/// for every distinct key, the positions at which that key occurs.
#[derive(Debug)]
pub struct GroupBy<'a, K, T>
where
K: Debug + Eq + Hash + Clone,
T: Debug + Clone,
{
/// The keys exactly as passed to the constructor, one per element of `source`.
#[allow(dead_code)]
keys: Vec<K>,
/// Maps each distinct key to the positions in `keys` at which it occurs.
groups: HashMap<K, Vec<usize>>,
/// The series whose values are aggregated per group.
source: &'a Series<T>,
/// Optional name passed to the constructor; stored as given.
#[allow(dead_code)]
name: Option<String>,
}
impl<'a, K, T> GroupBy<'a, K, T>
where
    K: Debug + Eq + Hash + Clone,
    T: Debug + Clone,
{
    /// Builds a grouping of `source` by `keys`, pairing `keys[i]` with position `i`.
    ///
    /// `name` is stored as given.
    ///
    /// # Errors
    ///
    /// Returns [`PandRSError::Consistency`] when `keys` and `source` differ in length.
    pub fn new(keys: Vec<K>, source: &'a Series<T>, name: Option<String>) -> Result<Self> {
        let (key_count, value_count) = (keys.len(), source.len());
        if key_count != value_count {
            return Err(PandRSError::Consistency(format!(
                "Length of keys ({}) and source ({}) do not match",
                key_count, value_count
            )));
        }

        // Bucket every position under its key.
        let mut groups: HashMap<K, Vec<usize>> = HashMap::new();
        for (position, key) in keys.iter().enumerate() {
            groups.entry(key.clone()).or_default().push(position);
        }

        Ok(Self {
            keys,
            groups,
            source,
            name,
        })
    }

    /// Returns the number of distinct keys.
    pub fn group_count(&self) -> usize {
        self.groups.len()
    }

    /// Returns, for each distinct key, how many positions carry that key.
    pub fn size(&self) -> HashMap<K, usize> {
        let mut sizes = HashMap::with_capacity(self.groups.len());
        for (key, members) in &self.groups {
            sizes.insert(key.clone(), members.len());
        }
        sizes
    }

    /// Sums the values of each group.
    ///
    /// Positions for which `source.get` yields nothing are skipped, and a group
    /// left without any value is omitted from the result. Always returns `Ok`.
    pub fn sum(&self) -> Result<HashMap<K, T>>
    where
        T: Copy + std::iter::Sum,
    {
        let totals = self
            .groups
            .iter()
            .filter_map(|(key, members)| {
                let values: Vec<T> = members
                    .iter()
                    .filter_map(|&position| self.source.get(position).copied())
                    .collect();
                if values.is_empty() {
                    return None;
                }
                let total: T = values.into_iter().sum();
                Some((key.clone(), total))
            })
            .collect();

        Ok(totals)
    }

    /// Computes the arithmetic mean of each group after converting values to `f64`.
    ///
    /// Positions for which `source.get` yields nothing are skipped, and a group
    /// left without any value is omitted from the result. Always returns `Ok`.
    pub fn mean(&self) -> Result<HashMap<K, f64>>
    where
        T: Copy + Into<f64>,
    {
        let means = self
            .groups
            .iter()
            .filter_map(|(key, members)| {
                let values: Vec<f64> = members
                    .iter()
                    .filter_map(|&position| self.source.get(position).map(|value| (*value).into()))
                    .collect();
                if values.is_empty() {
                    return None;
                }
                let total: f64 = values.iter().sum();
                Some((key.clone(), total / values.len() as f64))
            })
            .collect();

        Ok(means)
    }
}
/// Grouping of the rows of a borrowed [`DataFrame`] by a parallel vector of keys.
///
/// [`DataFrameGroupBy::new`] pairs `keys[i]` with row `i` and records, for every
/// distinct key, the row indices at which that key occurs.
pub struct DataFrameGroupBy<'a, K>
where
K: Debug + Eq + Hash + Clone,
{
/// The keys as passed to the constructor, one per row of `source`.
#[allow(dead_code)]
keys: Vec<K>,
/// Maps each distinct key to the row indices at which it occurs.
groups: HashMap<K, Vec<usize>>,
/// The data frame being grouped.
#[allow(dead_code)]
source: &'a DataFrame,
/// The `by` argument passed to the constructor, stored as given.
/// NOTE(review): presumably the name of the grouping column — confirm against callers.
#[allow(dead_code)]
by: String,
}
impl<'a, K> DataFrameGroupBy<'a, K>
where
K: Debug + Eq + Hash + Clone,
{
pub fn new(keys: Vec<K>, source: &'a DataFrame, by: String) -> Result<Self> {
if keys.len() != source.row_count() {
return Err(PandRSError::Consistency(format!(
"Length of keys ({}) and DataFrame row count ({}) do not match",
keys.len(),
source.row_count()
)));
}
let mut groups = HashMap::new();
for (i, key) in keys.iter().enumerate() {
groups.entry(key.clone()).or_insert_with(Vec::new).push(i);
}
Ok(DataFrameGroupBy {
keys,
groups,
source,
by,
})
}
pub fn group_count(&self) -> usize {
self.groups.len()
}
pub fn size(&self) -> HashMap<K, usize> {
self.groups
.iter()
.map(|(k, indices)| (k.clone(), indices.len()))
.collect()
}
pub fn size_as_df(&self) -> Result<DataFrame> {
let mut result = DataFrame::new();
let mut keys = Vec::new();
let mut sizes = Vec::new();
for (key, indices) in &self.groups {
keys.push(format!("{:?}", key)); sizes.push(indices.len().to_string()); }
let key_column = Series::new(keys, Some("group_key".to_string()))?;
result.add_column("group_key".to_string(), key_column)?;
let size_column = Series::new(sizes, Some("size".to_string()))?;
result.add_column("size".to_string(), size_column)?;
Ok(result)
}
pub fn aggregate(&self, column_name: &str, func_name: &str) -> Result<DataFrame> {
if !self.source.contains_column(column_name) {
return Err(PandRSError::Column(format!(
"Column '{}' not found",
column_name
)));
}
let mut result = DataFrame::new();
let mut keys = Vec::new();
let mut aggregated_values = Vec::new();
let column_data: Vec<f64> = Vec::new();
for (key, indices) in &self.groups {
keys.push(format!("{:?}", key));
let group_data: Vec<f64> = indices
.iter()
.filter_map(|&idx| {
if idx < column_data.len() {
Some(column_data[idx])
} else {
None
}
})
.collect();
let result_value = if group_data.is_empty() {
"0.0".to_string()
} else {
match func_name {
"sum" => group_data.iter().sum::<f64>().to_string(),
"mean" => {
(group_data.iter().sum::<f64>() / group_data.len() as f64).to_string()
}
"min" => group_data
.iter()
.fold(f64::INFINITY, |a, &b| a.min(b))
.to_string(),
"max" => group_data
.iter()
.fold(f64::NEG_INFINITY, |a, &b| a.max(b))
.to_string(),
"count" => group_data.len().to_string(),
_ => "0.0".to_string(),
}
};
aggregated_values.push(result_value);
}
let key_column = Series::new(keys, Some("group_key".to_string()))?;
result.add_column("group_key".to_string(), key_column)?;
let result_column_name = format!("{}_{}", column_name, func_name);
let value_column = Series::new(aggregated_values, Some(result_column_name.clone()))?;
result.add_column(result_column_name, value_column)?;
Ok(result)
}
}