use std::collections::HashMap;
/// In-memory collection of string-keyed records with simple numeric aggregations.
///
/// Every record is a `HashMap<String, String>`. The numeric operations parse a
/// field's value as `f64` on demand; records where the field is missing or its
/// value does not parse are skipped silently.
#[derive(Debug, Clone)]
pub struct DataAggregator {
    /// The records, in the order they were supplied to [`DataAggregator::new`].
    data: Vec<HashMap<String, String>>,
}

impl DataAggregator {
    /// Wraps `data` as-is; nothing is validated or parsed up front.
    pub fn new(data: Vec<HashMap<String, String>>) -> Self {
        Self { data }
    }

    /// Partitions the records by the value they hold for `field`.
    ///
    /// Records without `field` are left out of every group. Each group stores
    /// clones of its records, in their original relative order.
    pub fn group_by(&self, field: &str) -> GroupedData {
        let mut groups: HashMap<String, Vec<HashMap<String, String>>> = HashMap::new();
        for record in &self.data {
            if let Some(key) = record.get(field) {
                groups.entry(key.clone()).or_default().push(record.clone());
            }
        }
        GroupedData {
            groups,
            group_field: field.to_string(),
        }
    }

    /// Sums every value of `field` that parses as `f64`.
    ///
    /// Strings such as `"NaN"` or `"inf"` parse successfully and therefore take
    /// part in the sum (a single NaN makes the result NaN).
    pub fn sum(&self, field: &str) -> f64 {
        self.numeric_values(field).sum()
    }

    /// Arithmetic mean of every value of `field` that parses as `f64`.
    ///
    /// Returns `0.0` when no record yields a value. As with [`Self::sum`],
    /// parsed NaN/infinite values are included.
    pub fn avg(&self, field: &str) -> f64 {
        // Count while summing so the values are parsed once and never buffered.
        let mut parsed = 0usize;
        let total: f64 = self.numeric_values(field).inspect(|_| parsed += 1).sum();
        if parsed == 0 {
            0.0
        } else {
            total / parsed as f64
        }
    }

    /// Number of records, regardless of which fields they contain.
    pub fn count(&self) -> usize {
        self.data.len()
    }

    /// Smallest non-NaN value of `field`, or `None` if there is none.
    ///
    /// Values are ordered with `f64::total_cmp`, so `-0.0` sorts below `0.0`.
    pub fn min(&self, field: &str) -> Option<f64> {
        self.numeric_values(field)
            .filter(|value| !value.is_nan())
            .min_by(|a, b| a.total_cmp(b))
    }

    /// Largest non-NaN value of `field`, or `None` if there is none.
    ///
    /// Values are ordered with `f64::total_cmp`, so `0.0` sorts above `-0.0`.
    pub fn max(&self, field: &str) -> Option<f64> {
        self.numeric_values(field)
            .filter(|value| !value.is_nan())
            .max_by(|a, b| a.total_cmp(b))
    }

    /// Returns a new aggregator holding clones of the records accepted by `predicate`.
    pub fn filter<F>(&self, predicate: F) -> DataAggregator
    where
        F: Fn(&HashMap<String, String>) -> bool,
    {
        let data = self
            .data
            .iter()
            .filter(|record| predicate(record))
            .cloned()
            .collect();
        DataAggregator { data }
    }

    /// Yields, in record order, each value of `field` that parses as `f64`.
    ///
    /// Single source of truth for the "missing or unparseable means skipped"
    /// rule shared by `sum`, `avg`, `min` and `max`.
    fn numeric_values<'a>(&'a self, field: &'a str) -> impl Iterator<Item = f64> + 'a {
        self.data
            .iter()
            .filter_map(move |record| record.get(field)?.parse::<f64>().ok())
    }
}
/// Records partitioned by the distinct values of one field.
///
/// Instances are produced by `DataAggregator::group_by`.
#[derive(Debug, Clone)]
pub struct GroupedData {
    /// Group key (the grouping field's value) mapped to the records that carry it.
    groups: HashMap<String, Vec<HashMap<String, String>>>,
    /// Name of the field the records were grouped by.
    group_field: String,
}

impl GroupedData {
    /// Name of the field the records were grouped by.
    pub fn group_field(&self) -> &str {
        &self.group_field
    }

    /// Reduces `field` within every group using `func`.
    ///
    /// Returns one `(label(group_key), value)` row per group, ordered by group
    /// key (not by label) so the output is the same on every run.
    ///
    /// Values of `field` are parsed as `f64`; records where the field is
    /// missing or unparseable are skipped. `AggregateFunc::Count` ignores
    /// `field` and reports the number of records in the group.
    /// `AggregateFunc::Min` / `AggregateFunc::Max` ignore NaN and report `0.0`
    /// for a group with no usable value, which callers cannot distinguish from
    /// a genuine `0.0`.
    pub fn aggregate<F>(&self, field: &str, func: AggregateFunc, label: F) -> Vec<(String, f64)>
    where
        F: Fn(&str) -> String,
    {
        self.sorted_groups()
            .into_iter()
            .map(|(key, records)| {
                let value = Self::reduce(records, field, func);
                (label(key), value)
            })
            .collect()
    }

    /// Per-group sum of `field`, labelled with the group key.
    pub fn sum(&self, field: &str) -> Vec<(String, f64)> {
        self.aggregate(field, AggregateFunc::Sum, str::to_string)
    }

    /// Per-group mean of `field` (`0.0` for a group with no parseable value),
    /// labelled with the group key.
    pub fn avg(&self, field: &str) -> Vec<(String, f64)> {
        self.aggregate(field, AggregateFunc::Avg, str::to_string)
    }

    /// Number of records in each group, ordered by group key.
    pub fn count(&self) -> Vec<(String, f64)> {
        self.sorted_groups()
            .into_iter()
            .map(|(key, records)| (key.to_string(), records.len() as f64))
            .collect()
    }

    /// Borrowed view of every group, sorted by key.
    ///
    /// `HashMap` iteration order is unspecified, so sorting here is what makes
    /// the public results reproducible.
    fn sorted_groups(&self) -> Vec<(&str, &[HashMap<String, String>])> {
        let mut entries: Vec<(&str, &[HashMap<String, String>])> = self
            .groups
            .iter()
            .map(|(key, records)| (key.as_str(), records.as_slice()))
            .collect();
        // Keys are unique, so an unstable sort cannot reorder equal elements.
        entries.sort_unstable_by(|a, b| a.0.cmp(b.0));
        entries
    }

    /// Applies `func` to one group's records without copying them.
    fn reduce(records: &[HashMap<String, String>], field: &str, func: AggregateFunc) -> f64 {
        match func {
            AggregateFunc::Sum => Self::numeric_values(records, field).sum(),
            AggregateFunc::Avg => {
                // Count while summing so values are parsed once and never buffered.
                let mut parsed = 0usize;
                let total: f64 = Self::numeric_values(records, field)
                    .inspect(|_| parsed += 1)
                    .sum();
                if parsed == 0 {
                    0.0
                } else {
                    total / parsed as f64
                }
            }
            AggregateFunc::Count => records.len() as f64,
            AggregateFunc::Min => Self::numeric_values(records, field)
                .filter(|value| !value.is_nan())
                .min_by(|a, b| a.total_cmp(b))
                .unwrap_or(0.0),
            AggregateFunc::Max => Self::numeric_values(records, field)
                .filter(|value| !value.is_nan())
                .max_by(|a, b| a.total_cmp(b))
                .unwrap_or(0.0),
        }
    }

    /// Yields, in record order, each value of `field` that parses as `f64`.
    fn numeric_values<'a>(
        records: &'a [HashMap<String, String>],
        field: &'a str,
    ) -> impl Iterator<Item = f64> + 'a {
        records
            .iter()
            .filter_map(move |record| record.get(field)?.parse::<f64>().ok())
    }
}
/// Selects the reduction that `GroupedData::aggregate` applies to each group.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateFunc {
/// Sum of the field's values that parse as `f64`.
Sum,
/// Mean of the field's values that parse as `f64`; `0.0` when there are none.
Avg,
/// Number of records in the group; the field argument is not consulted.
Count,
/// Smallest non-NaN value of the field; `aggregate` reports `0.0` when there is none.
Min,
/// Largest non-NaN value of the field; `aggregate` reports `0.0` when there is none.
Max,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one record from `(field, value)` string pairs.
    fn record(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(field, value)| (field.to_string(), value.to_string()))
            .collect()
    }

    /// Three sales records: "North" twice (100 and 150) and "South" once (200).
    fn sample_data() -> Vec<HashMap<String, String>> {
        vec![
            record(&[("region", "North"), ("amount", "100")]),
            record(&[("region", "North"), ("amount", "150")]),
            record(&[("region", "South"), ("amount", "200")]),
        ]
    }

    /// Aggregator over [`sample_data`].
    fn sample_aggregator() -> DataAggregator {
        DataAggregator::new(sample_data())
    }

    /// True when `rows` contains the exact pair `(key, value)`; row order is irrelevant.
    fn has_entry(rows: &[(String, f64)], key: &str, value: f64) -> bool {
        rows.iter().any(|(k, v)| k == key && *v == value)
    }

    #[test]
    fn test_sum() {
        assert_eq!(sample_aggregator().sum("amount"), 450.0);
    }

    #[test]
    fn test_avg() {
        assert_eq!(sample_aggregator().avg("amount"), 150.0);
    }

    #[test]
    fn test_count() {
        assert_eq!(sample_aggregator().count(), 3);
    }

    #[test]
    fn test_min_max() {
        let agg = sample_aggregator();
        assert_eq!(agg.min("amount"), Some(100.0));
        assert_eq!(agg.max("amount"), Some(200.0));
    }

    #[test]
    fn test_group_by_sum() {
        let grouped = sample_aggregator().group_by("region").sum("amount");
        assert_eq!(grouped.len(), 2);
        assert!(has_entry(&grouped, "North", 250.0));
        assert!(has_entry(&grouped, "South", 200.0));
    }

    #[test]
    fn test_group_by_count() {
        let grouped = sample_aggregator().group_by("region").count();
        assert_eq!(grouped.len(), 2);
        assert!(has_entry(&grouped, "North", 2.0));
        assert!(has_entry(&grouped, "South", 1.0));
    }

    #[test]
    fn test_filter() {
        let filtered = sample_aggregator()
            .filter(|r| r.get("region").map(String::as_str) == Some("North"));
        assert_eq!(filtered.count(), 2);
        assert_eq!(filtered.sum("amount"), 250.0);
    }

    #[test]
    fn test_avg_empty_data() {
        let agg = DataAggregator::new(Vec::new());
        assert_eq!(agg.avg("amount"), 0.0);
    }

    #[test]
    fn test_min_max_empty_data() {
        let agg = DataAggregator::new(Vec::new());
        assert_eq!(agg.min("amount"), None);
        assert_eq!(agg.max("amount"), None);
    }

    #[test]
    fn test_sum_nonexistent_field() {
        assert_eq!(sample_aggregator().sum("nonexistent"), 0.0);
    }

    #[test]
    fn test_avg_nonexistent_field() {
        assert_eq!(sample_aggregator().avg("nonexistent"), 0.0);
    }

    #[test]
    fn test_group_by_avg() {
        let grouped = sample_aggregator().group_by("region").avg("amount");
        assert_eq!(grouped.len(), 2);
        assert!(has_entry(&grouped, "North", 125.0));
        assert!(has_entry(&grouped, "South", 200.0));
    }

    #[test]
    fn test_aggregate_with_custom_label() {
        let grouped = sample_aggregator().group_by("region").aggregate(
            "amount",
            AggregateFunc::Sum,
            |k| format!("Region: {}", k),
        );
        assert_eq!(grouped.len(), 2);
        assert!(has_entry(&grouped, "Region: North", 250.0));
        assert!(has_entry(&grouped, "Region: South", 200.0));
    }

    #[test]
    fn test_aggregate_with_count() {
        let grouped = sample_aggregator().group_by("region").aggregate(
            "amount",
            AggregateFunc::Count,
            |k| k.to_string(),
        );
        assert!(has_entry(&grouped, "North", 2.0));
        assert!(has_entry(&grouped, "South", 1.0));
    }

    #[test]
    fn test_aggregate_with_min_max() {
        let agg = sample_aggregator();
        let min_grouped =
            agg.group_by("region")
                .aggregate("amount", AggregateFunc::Min, |k| k.to_string());
        let max_grouped =
            agg.group_by("region")
                .aggregate("amount", AggregateFunc::Max, |k| k.to_string());
        assert!(has_entry(&min_grouped, "North", 100.0));
        assert!(has_entry(&max_grouped, "North", 150.0));
    }

    #[test]
    fn test_aggregate_func_enum() {
        // Every variant equals itself...
        assert_eq!(AggregateFunc::Sum, AggregateFunc::Sum);
        assert_eq!(AggregateFunc::Avg, AggregateFunc::Avg);
        assert_eq!(AggregateFunc::Count, AggregateFunc::Count);
        assert_eq!(AggregateFunc::Min, AggregateFunc::Min);
        assert_eq!(AggregateFunc::Max, AggregateFunc::Max);
        // ...and distinct variants differ.
        assert_ne!(AggregateFunc::Sum, AggregateFunc::Avg);
    }

    #[test]
    fn test_group_by_missing_field() {
        let grouped = sample_aggregator().group_by("nonexistent");
        assert_eq!(grouped.count().len(), 0);
    }

    #[test]
    fn test_filter_all_records() {
        let filtered = sample_aggregator().filter(|_| true);
        assert_eq!(filtered.count(), 3);
    }

    #[test]
    fn test_filter_no_records() {
        let filtered = sample_aggregator().filter(|_| false);
        assert_eq!(filtered.count(), 0);
    }
}