1use std::collections::HashMap;
7
8#[derive(Debug, Clone)]
10pub struct DataAggregator {
11 data: Vec<HashMap<String, String>>,
12}
13
14impl DataAggregator {
15 pub fn new(data: Vec<HashMap<String, String>>) -> Self {
17 Self { data }
18 }
19
20 pub fn group_by(&self, field: &str) -> GroupedData {
22 let mut groups: HashMap<String, Vec<HashMap<String, String>>> = HashMap::new();
23
24 for record in &self.data {
25 if let Some(value) = record.get(field) {
26 groups
27 .entry(value.clone())
28 .or_insert_with(Vec::new)
29 .push(record.clone());
30 }
31 }
32
33 GroupedData {
34 groups,
35 group_field: field.to_string(),
36 }
37 }
38
39 pub fn sum(&self, field: &str) -> f64 {
41 self.data
42 .iter()
43 .filter_map(|record| record.get(field))
44 .filter_map(|value| value.parse::<f64>().ok())
45 .sum()
46 }
47
48 pub fn avg(&self, field: &str) -> f64 {
50 let values: Vec<f64> = self
51 .data
52 .iter()
53 .filter_map(|record| record.get(field))
54 .filter_map(|value| value.parse::<f64>().ok())
55 .collect();
56
57 if values.is_empty() {
58 0.0
59 } else {
60 values.iter().sum::<f64>() / values.len() as f64
61 }
62 }
63
64 pub fn count(&self) -> usize {
66 self.data.len()
67 }
68
69 pub fn min(&self, field: &str) -> Option<f64> {
71 self.data
72 .iter()
73 .filter_map(|record| record.get(field))
74 .filter_map(|value| value.parse::<f64>().ok())
75 .filter(|v| !v.is_nan()) .min_by(|a, b| a.total_cmp(b))
77 }
78
79 pub fn max(&self, field: &str) -> Option<f64> {
81 self.data
82 .iter()
83 .filter_map(|record| record.get(field))
84 .filter_map(|value| value.parse::<f64>().ok())
85 .filter(|v| !v.is_nan()) .max_by(|a, b| a.total_cmp(b))
87 }
88
89 pub fn filter<F>(&self, predicate: F) -> DataAggregator
91 where
92 F: Fn(&HashMap<String, String>) -> bool,
93 {
94 DataAggregator {
95 data: self.data.iter().filter(|r| predicate(r)).cloned().collect(),
96 }
97 }
98}
99
100#[derive(Debug, Clone)]
102pub struct GroupedData {
103 groups: HashMap<String, Vec<HashMap<String, String>>>,
104 group_field: String,
105}
106
107impl GroupedData {
108 pub fn aggregate<F>(&self, field: &str, func: AggregateFunc, label: F) -> Vec<(String, f64)>
110 where
111 F: Fn(&str) -> String,
112 {
113 self.groups
114 .iter()
115 .map(|(key, records)| {
116 let aggregator = DataAggregator::new(records.clone());
117 let value = match func {
118 AggregateFunc::Sum => aggregator.sum(field),
119 AggregateFunc::Avg => aggregator.avg(field),
120 AggregateFunc::Count => aggregator.count() as f64,
121 AggregateFunc::Min => aggregator.min(field).unwrap_or(0.0),
122 AggregateFunc::Max => aggregator.max(field).unwrap_or(0.0),
123 };
124 (label(key), value)
125 })
126 .collect()
127 }
128
129 pub fn sum(&self, field: &str) -> Vec<(String, f64)> {
131 self.aggregate(field, AggregateFunc::Sum, |k| k.to_string())
132 }
133
134 pub fn avg(&self, field: &str) -> Vec<(String, f64)> {
136 self.aggregate(field, AggregateFunc::Avg, |k| k.to_string())
137 }
138
139 pub fn count(&self) -> Vec<(String, f64)> {
141 self.groups
142 .iter()
143 .map(|(key, records)| (key.clone(), records.len() as f64))
144 .collect()
145 }
146}
147
/// Selects which reduction a grouped aggregation applies to each group.
///
/// `Hash` is derived alongside `Eq` so a variant can serve as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AggregateFunc {
    /// Sum of the field's numeric values.
    Sum,
    /// Arithmetic mean of the field's numeric values.
    Avg,
    /// Number of records in the group; the field is not consulted.
    Count,
    /// Smallest numeric value of the field.
    Min,
    /// Largest numeric value of the field.
    Max,
}
157
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one record with a `region` and an `amount` field.
    fn record(region: &str, amount: &str) -> HashMap<String, String> {
        HashMap::from([
            ("region".to_string(), region.to_string()),
            ("amount".to_string(), amount.to_string()),
        ])
    }

    /// Builds one record that only has an `amount` field.
    fn amount_only(amount: &str) -> HashMap<String, String> {
        HashMap::from([("amount".to_string(), amount.to_string())])
    }

    /// Two North records (100, 150) and one South record (200).
    fn sample_data() -> Vec<HashMap<String, String>> {
        vec![
            record("North", "100"),
            record("North", "150"),
            record("South", "200"),
        ]
    }

    /// Looks up the value reported for `key`; grouped output order is unspecified.
    fn value_for(rows: &[(String, f64)], key: &str) -> Option<f64> {
        rows.iter().find(|(k, _)| k == key).map(|(_, v)| *v)
    }

    #[test]
    fn test_sum() {
        let agg = DataAggregator::new(sample_data());
        assert_eq!(agg.sum("amount"), 450.0);
    }

    #[test]
    fn test_avg() {
        let agg = DataAggregator::new(sample_data());
        assert_eq!(agg.avg("amount"), 150.0);
    }

    #[test]
    fn test_count() {
        let agg = DataAggregator::new(sample_data());
        assert_eq!(agg.count(), 3);
    }

    #[test]
    fn test_min_max() {
        let agg = DataAggregator::new(sample_data());
        assert_eq!(agg.min("amount"), Some(100.0));
        assert_eq!(agg.max("amount"), Some(200.0));
    }

    #[test]
    fn test_group_by_sum() {
        let agg = DataAggregator::new(sample_data());
        let grouped = agg.group_by("region").sum("amount");

        assert_eq!(grouped.len(), 2);
        assert_eq!(value_for(&grouped, "North"), Some(250.0));
        assert_eq!(value_for(&grouped, "South"), Some(200.0));
    }

    #[test]
    fn test_group_by_count() {
        let agg = DataAggregator::new(sample_data());
        let grouped = agg.group_by("region").count();

        assert_eq!(grouped.len(), 2);
        assert_eq!(value_for(&grouped, "North"), Some(2.0));
        assert_eq!(value_for(&grouped, "South"), Some(1.0));
    }

    #[test]
    fn test_filter() {
        let agg = DataAggregator::new(sample_data());
        let filtered = agg.filter(|r| r.get("region").map(String::as_str) == Some("North"));

        assert_eq!(filtered.count(), 2);
        assert_eq!(filtered.sum("amount"), 250.0);
    }

    #[test]
    fn test_avg_empty_data() {
        let agg = DataAggregator::new(vec![]);
        assert_eq!(agg.avg("amount"), 0.0);
    }

    #[test]
    fn test_min_max_empty_data() {
        let agg = DataAggregator::new(vec![]);
        assert_eq!(agg.min("amount"), None);
        assert_eq!(agg.max("amount"), None);
    }

    #[test]
    fn test_sum_nonexistent_field() {
        let agg = DataAggregator::new(sample_data());
        assert_eq!(agg.sum("nonexistent"), 0.0);
    }

    #[test]
    fn test_avg_nonexistent_field() {
        let agg = DataAggregator::new(sample_data());
        assert_eq!(agg.avg("nonexistent"), 0.0);
    }

    #[test]
    fn test_group_by_avg() {
        let agg = DataAggregator::new(sample_data());
        let grouped = agg.group_by("region").avg("amount");

        assert_eq!(grouped.len(), 2);
        assert_eq!(value_for(&grouped, "North"), Some(125.0));
        assert_eq!(value_for(&grouped, "South"), Some(200.0));
    }

    #[test]
    fn test_aggregate_with_custom_label() {
        let agg = DataAggregator::new(sample_data());
        let grouped = agg
            .group_by("region")
            .aggregate("amount", AggregateFunc::Sum, |k| format!("Region: {}", k));

        assert_eq!(grouped.len(), 2);
        assert_eq!(value_for(&grouped, "Region: North"), Some(250.0));
        assert_eq!(value_for(&grouped, "Region: South"), Some(200.0));
    }

    #[test]
    fn test_aggregate_with_count() {
        let agg = DataAggregator::new(sample_data());
        let grouped = agg
            .group_by("region")
            .aggregate("amount", AggregateFunc::Count, |k| k.to_string());

        assert_eq!(value_for(&grouped, "North"), Some(2.0));
        assert_eq!(value_for(&grouped, "South"), Some(1.0));
    }

    #[test]
    fn test_aggregate_with_min_max() {
        let agg = DataAggregator::new(sample_data());
        let min_grouped = agg
            .group_by("region")
            .aggregate("amount", AggregateFunc::Min, |k| k.to_string());
        let max_grouped = agg
            .group_by("region")
            .aggregate("amount", AggregateFunc::Max, |k| k.to_string());

        assert_eq!(value_for(&min_grouped, "North"), Some(100.0));
        assert_eq!(value_for(&max_grouped, "North"), Some(150.0));
    }

    #[test]
    fn test_aggregate_func_enum() {
        assert_eq!(AggregateFunc::Sum, AggregateFunc::Sum);
        assert_eq!(AggregateFunc::Avg, AggregateFunc::Avg);
        assert_eq!(AggregateFunc::Count, AggregateFunc::Count);
        assert_eq!(AggregateFunc::Min, AggregateFunc::Min);
        assert_eq!(AggregateFunc::Max, AggregateFunc::Max);
        assert_ne!(AggregateFunc::Sum, AggregateFunc::Avg);
    }

    #[test]
    fn test_group_by_missing_field() {
        let agg = DataAggregator::new(sample_data());
        let grouped = agg.group_by("nonexistent");
        assert!(grouped.count().is_empty());
    }

    #[test]
    fn test_filter_all_records() {
        let agg = DataAggregator::new(sample_data());
        let filtered = agg.filter(|_| true);
        assert_eq!(filtered.count(), 3);
    }

    #[test]
    fn test_filter_no_records() {
        let agg = DataAggregator::new(sample_data());
        let filtered = agg.filter(|_| false);
        assert_eq!(filtered.count(), 0);
    }

    #[test]
    fn test_unparseable_values_are_skipped() {
        let agg = DataAggregator::new(vec![
            amount_only("10"),
            amount_only("not a number"),
            amount_only("30"),
        ]);

        assert_eq!(agg.count(), 3);
        assert_eq!(agg.sum("amount"), 40.0);
        assert_eq!(agg.avg("amount"), 20.0);
        assert_eq!(agg.min("amount"), Some(10.0));
        assert_eq!(agg.max("amount"), Some(30.0));
    }

    #[test]
    fn test_min_max_ignore_nan() {
        let agg = DataAggregator::new(vec![
            amount_only("NaN"),
            amount_only("5"),
            amount_only("7"),
        ]);

        assert_eq!(agg.min("amount"), Some(5.0));
        assert_eq!(agg.max("amount"), Some(7.0));
    }
}