1use std::collections::HashMap;
7
/// An in-memory collection of string-keyed records with simple numeric
/// aggregations.
///
/// Every record maps field names to raw string values. Numeric operations
/// parse values as `f64` on demand; a record that lacks the requested field,
/// or whose value does not parse as `f64`, is skipped rather than reported as
/// an error.
#[derive(Debug, Clone)]
pub struct DataAggregator {
    data: Vec<HashMap<String, String>>,
}

impl DataAggregator {
    /// Creates an aggregator that owns `data`.
    pub fn new(data: Vec<HashMap<String, String>>) -> Self {
        Self { data }
    }

    /// Partitions the records by their value for `field`.
    ///
    /// Records that do not contain `field` are left out of every group. Each
    /// group stores its own clone of the matching records, so the returned
    /// [`GroupedData`] is independent of `self`.
    pub fn group_by(&self, field: &str) -> GroupedData {
        let mut groups: HashMap<String, Vec<HashMap<String, String>>> = HashMap::new();

        for record in &self.data {
            if let Some(value) = record.get(field) {
                groups.entry(value.clone()).or_default().push(record.clone());
            }
        }

        GroupedData {
            groups,
            group_field: field.to_string(),
        }
    }

    /// Sums every value of `field` that parses as `f64`.
    ///
    /// Returns zero when no record contributes a value. A value that parses
    /// to NaN is *not* skipped and makes the result NaN.
    pub fn sum(&self, field: &str) -> f64 {
        Self::sum_of(&self.data, field)
    }

    /// Averages every value of `field` that parses as `f64`.
    ///
    /// Only records that contribute a parsed value count towards the
    /// divisor. Returns `0.0` when no record contributes a value.
    pub fn avg(&self, field: &str) -> f64 {
        Self::avg_of(&self.data, field)
    }

    /// Returns the number of records, whether or not they hold numeric data.
    pub fn count(&self) -> usize {
        self.data.len()
    }

    /// Returns the smallest non-NaN value of `field`, or `None` when no
    /// record contributes one.
    pub fn min(&self, field: &str) -> Option<f64> {
        Self::min_of(&self.data, field)
    }

    /// Returns the largest non-NaN value of `field`, or `None` when no record
    /// contributes one.
    pub fn max(&self, field: &str) -> Option<f64> {
        Self::max_of(&self.data, field)
    }

    /// Returns a new aggregator holding clones of the records for which
    /// `predicate` returns `true`; `self` is left untouched.
    pub fn filter<F>(&self, predicate: F) -> DataAggregator
    where
        F: Fn(&HashMap<String, String>) -> bool,
    {
        DataAggregator {
            data: self.data.iter().filter(|r| predicate(r)).cloned().collect(),
        }
    }

    /// Yields the `f64` values of `field` across `records`, skipping records
    /// without the field and values that fail to parse.
    ///
    /// The helpers below work on borrowed slices so that [`GroupedData`] can
    /// aggregate a group in place instead of cloning it into a temporary
    /// aggregator.
    fn numeric_values<'a>(
        records: &'a [HashMap<String, String>],
        field: &'a str,
    ) -> impl Iterator<Item = f64> + 'a {
        records
            .iter()
            .filter_map(move |record| record.get(field))
            .filter_map(|value| value.parse::<f64>().ok())
    }

    /// Slice-based implementation behind [`DataAggregator::sum`].
    fn sum_of(records: &[HashMap<String, String>], field: &str) -> f64 {
        Self::numeric_values(records, field).sum()
    }

    /// Slice-based implementation behind [`DataAggregator::avg`].
    fn avg_of(records: &[HashMap<String, String>], field: &str) -> f64 {
        // One pass keeps both the running total and the divisor, so no
        // intermediate vector of parsed values is needed.
        let (count, total) = Self::numeric_values(records, field)
            .fold((0usize, 0.0f64), |(count, total), value| {
                (count + 1, total + value)
            });

        if count == 0 {
            0.0
        } else {
            total / count as f64
        }
    }

    /// Slice-based implementation behind [`DataAggregator::min`].
    fn min_of(records: &[HashMap<String, String>], field: &str) -> Option<f64> {
        Self::numeric_values(records, field)
            .filter(|v| !v.is_nan())
            // NaN has been filtered out, so `partial_cmp` always succeeds;
            // the fallback only exists to satisfy the comparator signature.
            .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
    }

    /// Slice-based implementation behind [`DataAggregator::max`].
    fn max_of(records: &[HashMap<String, String>], field: &str) -> Option<f64> {
        Self::numeric_values(records, field)
            .filter(|v| !v.is_nan())
            // See `min_of` for why the `Equal` fallback is never taken.
            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
    }
}

/// Records partitioned by the value of a single field, as produced by
/// [`DataAggregator::group_by`].
///
/// Results are returned as `(label, value)` pairs in the iteration order of
/// the underlying `HashMap`, which is unspecified; sort the result if a stable
/// order is required.
#[derive(Debug, Clone)]
pub struct GroupedData {
    groups: HashMap<String, Vec<HashMap<String, String>>>,
    group_field: String,
}

impl GroupedData {
    /// Returns the name of the field the records were grouped by.
    pub fn group_field(&self) -> &str {
        &self.group_field
    }

    /// Applies `func` to `field` within every group.
    ///
    /// `label` converts each group key into the label stored in the result.
    /// For [`AggregateFunc::Min`] and [`AggregateFunc::Max`], a group without
    /// any usable value reports `0.0`. [`AggregateFunc::Count`] ignores
    /// `field` and counts every record in the group.
    pub fn aggregate<F>(&self, field: &str, func: AggregateFunc, label: F) -> Vec<(String, f64)>
    where
        F: Fn(&str) -> String,
    {
        self.groups
            .iter()
            .map(|(key, records)| {
                // Aggregate straight from the borrowed group; nothing here
                // needs an owned copy of the records.
                let value = match func {
                    AggregateFunc::Sum => DataAggregator::sum_of(records, field),
                    AggregateFunc::Avg => DataAggregator::avg_of(records, field),
                    AggregateFunc::Count => records.len() as f64,
                    AggregateFunc::Min => DataAggregator::min_of(records, field).unwrap_or(0.0),
                    AggregateFunc::Max => DataAggregator::max_of(records, field).unwrap_or(0.0),
                };
                (label(key), value)
            })
            .collect()
    }

    /// Sums `field` per group, labelling each result with the group key.
    pub fn sum(&self, field: &str) -> Vec<(String, f64)> {
        self.aggregate(field, AggregateFunc::Sum, |k| k.to_string())
    }

    /// Averages `field` per group, labelling each result with the group key.
    pub fn avg(&self, field: &str) -> Vec<(String, f64)> {
        self.aggregate(field, AggregateFunc::Avg, |k| k.to_string())
    }

    /// Counts the records in each group, labelling each result with the
    /// group key.
    pub fn count(&self) -> Vec<(String, f64)> {
        self.groups
            .iter()
            .map(|(key, records)| (key.clone(), records.len() as f64))
            .collect()
    }
}

/// Aggregation selectable through [`GroupedData::aggregate`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateFunc {
    /// Sum of the parsed values.
    Sum,
    /// Mean of the parsed values.
    Avg,
    /// Number of records in the group.
    Count,
    /// Smallest non-NaN parsed value.
    Min,
    /// Largest non-NaN parsed value.
    Max,
}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one record carrying the two fields every test relies on.
    fn record(region: &str, amount: &str) -> HashMap<String, String> {
        let mut fields = HashMap::new();
        fields.insert("region".to_string(), region.to_string());
        fields.insert("amount".to_string(), amount.to_string());
        fields
    }

    /// Three records: North/100, North/150 and South/200.
    fn sample_data() -> Vec<HashMap<String, String>> {
        vec![
            record("North", "100"),
            record("North", "150"),
            record("South", "200"),
        ]
    }

    /// Aggregator over [`sample_data`].
    fn sample_aggregator() -> DataAggregator {
        DataAggregator::new(sample_data())
    }

    /// Aggregator without any records.
    fn empty_aggregator() -> DataAggregator {
        DataAggregator::new(vec![])
    }

    /// True when `rows` contains the pair (`key`, `expected`). Group order is
    /// unspecified, so results are probed rather than compared positionally.
    fn has_entry(rows: &[(String, f64)], key: &str, expected: f64) -> bool {
        rows.iter().any(|(k, v)| k == key && *v == expected)
    }

    #[test]
    fn test_sum() {
        let total = sample_aggregator().sum("amount");
        assert_eq!(total, 450.0);
    }

    #[test]
    fn test_avg() {
        let mean = sample_aggregator().avg("amount");
        assert_eq!(mean, 150.0);
    }

    #[test]
    fn test_count() {
        let records = sample_aggregator().count();
        assert_eq!(records, 3);
    }

    #[test]
    fn test_min_max() {
        let source = sample_aggregator();
        assert_eq!(source.min("amount"), Some(100.0));
        assert_eq!(source.max("amount"), Some(200.0));
    }

    #[test]
    fn test_group_by_sum() {
        let totals = sample_aggregator().group_by("region").sum("amount");

        assert_eq!(totals.len(), 2);
        assert!(has_entry(&totals, "North", 250.0));
        assert!(has_entry(&totals, "South", 200.0));
    }

    #[test]
    fn test_group_by_count() {
        let sizes = sample_aggregator().group_by("region").count();

        assert_eq!(sizes.len(), 2);
        assert!(has_entry(&sizes, "North", 2.0));
        assert!(has_entry(&sizes, "South", 1.0));
    }

    #[test]
    fn test_filter() {
        let northern = sample_aggregator()
            .filter(|row| row.get("region").map(String::as_str) == Some("North"));

        assert_eq!(northern.count(), 2);
        assert_eq!(northern.sum("amount"), 250.0);
    }

    #[test]
    fn test_avg_empty_data() {
        assert_eq!(empty_aggregator().avg("amount"), 0.0);
    }

    #[test]
    fn test_min_max_empty_data() {
        let source = empty_aggregator();
        assert_eq!(source.min("amount"), None);
        assert_eq!(source.max("amount"), None);
    }

    #[test]
    fn test_sum_nonexistent_field() {
        assert_eq!(sample_aggregator().sum("nonexistent"), 0.0);
    }

    #[test]
    fn test_avg_nonexistent_field() {
        assert_eq!(sample_aggregator().avg("nonexistent"), 0.0);
    }

    #[test]
    fn test_group_by_avg() {
        let means = sample_aggregator().group_by("region").avg("amount");

        assert_eq!(means.len(), 2);
        assert!(has_entry(&means, "North", 125.0));
        assert!(has_entry(&means, "South", 200.0));
    }

    #[test]
    fn test_aggregate_with_custom_label() {
        let by_region = sample_aggregator().group_by("region");
        let labelled =
            by_region.aggregate("amount", AggregateFunc::Sum, |k| format!("Region: {}", k));

        assert_eq!(labelled.len(), 2);
        assert!(has_entry(&labelled, "Region: North", 250.0));
        assert!(has_entry(&labelled, "Region: South", 200.0));
    }

    #[test]
    fn test_aggregate_with_count() {
        let by_region = sample_aggregator().group_by("region");
        let sizes = by_region.aggregate("amount", AggregateFunc::Count, |k| k.to_string());

        assert!(has_entry(&sizes, "North", 2.0));
        assert!(has_entry(&sizes, "South", 1.0));
    }

    #[test]
    fn test_aggregate_with_min_max() {
        let source = sample_aggregator();
        let lows = source
            .group_by("region")
            .aggregate("amount", AggregateFunc::Min, |k| k.to_string());
        let highs = source
            .group_by("region")
            .aggregate("amount", AggregateFunc::Max, |k| k.to_string());

        assert!(has_entry(&lows, "North", 100.0));
        assert!(has_entry(&highs, "North", 150.0));
    }

    #[test]
    fn test_aggregate_func_enum() {
        let variants = vec![
            AggregateFunc::Sum,
            AggregateFunc::Avg,
            AggregateFunc::Count,
            AggregateFunc::Min,
            AggregateFunc::Max,
        ];
        // Each variant must compare equal to itself.
        for variant in &variants {
            assert_eq!(*variant, *variant);
        }
        assert_ne!(AggregateFunc::Sum, AggregateFunc::Avg);
    }

    #[test]
    fn test_group_by_missing_field() {
        let grouped = sample_aggregator().group_by("nonexistent");
        assert!(grouped.count().is_empty());
    }

    #[test]
    fn test_filter_all_records() {
        let kept = sample_aggregator().filter(|_| true);
        assert_eq!(kept.count(), 3);
    }

    #[test]
    fn test_filter_no_records() {
        let kept = sample_aggregator().filter(|_| false);
        assert_eq!(kept.count(), 0);
    }
}