// oxidize_pdf/dashboard/data_aggregation.rs

//! Data Aggregation DSL
//!
//! Provides a simple, fluent API for common data aggregation operations
//! used in dashboard reporting.
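//!
//! A minimal usage sketch of the fluent API (paths omitted and marked
//! `ignore`, so it is not compiled as a doctest; it mirrors the unit tests
//! at the bottom of this file):
//!
//! ```ignore
//! use std::collections::HashMap;
//!
//! let record: HashMap<String, String> = HashMap::from([
//!     ("region".to_string(), "North".to_string()),
//!     ("amount".to_string(), "100".to_string()),
//! ]);
//!
//! let totals = DataAggregator::new(vec![record])
//!     .group_by("region")
//!     .sum("amount");
//! assert_eq!(totals, vec![("North".to_string(), 100.0)]);
//! ```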

use std::collections::HashMap;

/// Data aggregation builder for dashboard components
#[derive(Debug, Clone)]
pub struct DataAggregator {
    data: Vec<HashMap<String, String>>,
}

impl DataAggregator {
    /// Create a new data aggregator from raw data
    pub fn new(data: Vec<HashMap<String, String>>) -> Self {
        Self { data }
    }

    /// Group data by a field
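    ///
    /// Records that do not contain `field` are skipped. A small sketch of
    /// typical use (marked `ignore`; `aggregator` is an assumed
    /// `DataAggregator` value):
    ///
    /// ```ignore
    /// let by_region = aggregator.group_by("region");
    /// let totals = by_region.sum("amount"); // Vec<(String, f64)>
    /// ```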
    pub fn group_by(&self, field: &str) -> GroupedData {
        let mut groups: HashMap<String, Vec<HashMap<String, String>>> = HashMap::new();

        for record in &self.data {
            if let Some(value) = record.get(field) {
                groups
                    .entry(value.clone())
                    .or_insert_with(Vec::new)
                    .push(record.clone());
            }
        }

        GroupedData {
            groups,
            group_field: field.to_string(),
        }
    }

    /// Sum a numeric field
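    ///
    /// Missing values and values that fail to parse as `f64` are ignored, so
    /// a field with no numeric values sums to `0.0`. Sketch (marked `ignore`;
    /// `agg` is an assumed `DataAggregator` built from the sample data used
    /// in the tests below):
    ///
    /// ```ignore
    /// let total = agg.sum("amount"); // 450.0 for the test fixture
    /// ```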
    pub fn sum(&self, field: &str) -> f64 {
        self.data
            .iter()
            .filter_map(|record| record.get(field))
            .filter_map(|value| value.parse::<f64>().ok())
            .sum()
    }

    /// Calculate average of a numeric field
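    ///
    /// Returns `0.0` when there are no parseable numeric values for `field`.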
    pub fn avg(&self, field: &str) -> f64 {
        let values: Vec<f64> = self
            .data
            .iter()
            .filter_map(|record| record.get(field))
            .filter_map(|value| value.parse::<f64>().ok())
            .collect();

        if values.is_empty() {
            0.0
        } else {
            values.iter().sum::<f64>() / values.len() as f64
        }
    }

    /// Count records
    pub fn count(&self) -> usize {
        self.data.len()
    }

    /// Get minimum value of a numeric field
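    ///
    /// Returns `None` when no parseable, non-NaN values are present
    /// (for example, an empty dataset or a missing field).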
    pub fn min(&self, field: &str) -> Option<f64> {
        self.data
            .iter()
            .filter_map(|record| record.get(field))
            .filter_map(|value| value.parse::<f64>().ok())
            .filter(|v| !v.is_nan()) // Filter out NaN values before comparison
            .min_by(|a, b| {
                a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal) // NaN values treated as equal (shouldn't happen after filter)
            })
    }

    /// Get maximum value of a numeric field
    pub fn max(&self, field: &str) -> Option<f64> {
        self.data
            .iter()
            .filter_map(|record| record.get(field))
            .filter_map(|value| value.parse::<f64>().ok())
            .filter(|v| !v.is_nan()) // Filter out NaN values before comparison
            .max_by(|a, b| {
                a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal) // NaN values treated as equal (shouldn't happen after filter)
            })
    }

    /// Filter data by a condition
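    ///
    /// Returns a new `DataAggregator` containing only the records for which
    /// `predicate` returns `true`. Sketch (marked `ignore`; `agg` is an
    /// assumed `DataAggregator` value):
    ///
    /// ```ignore
    /// let north = agg.filter(|r| r.get("region") == Some(&"North".to_string()));
    /// let north_total = north.sum("amount");
    /// ```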
    pub fn filter<F>(&self, predicate: F) -> DataAggregator
    where
        F: Fn(&HashMap<String, String>) -> bool,
    {
        DataAggregator {
            data: self.data.iter().filter(|r| predicate(r)).cloned().collect(),
        }
    }
}

/// Grouped data for aggregation operations
#[derive(Debug, Clone)]
pub struct GroupedData {
    groups: HashMap<String, Vec<HashMap<String, String>>>,
    group_field: String,
}

impl GroupedData {
    /// Aggregate each group with a function
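    ///
    /// `label` maps each group key to the label used in the output pairs.
    /// Sketch (marked `ignore`; `grouped` is an assumed result of
    /// `DataAggregator::group_by`):
    ///
    /// ```ignore
    /// let labelled =
    ///     grouped.aggregate("amount", AggregateFunc::Sum, |k| format!("Region: {}", k));
    /// // e.g. [("Region: North", 250.0), ("Region: South", 200.0)]
    /// ```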
    pub fn aggregate<F>(&self, field: &str, func: AggregateFunc, label: F) -> Vec<(String, f64)>
    where
        F: Fn(&str) -> String,
    {
        self.groups
            .iter()
            .map(|(key, records)| {
                let aggregator = DataAggregator::new(records.clone());
                let value = match func {
                    AggregateFunc::Sum => aggregator.sum(field),
                    AggregateFunc::Avg => aggregator.avg(field),
                    AggregateFunc::Count => aggregator.count() as f64,
                    AggregateFunc::Min => aggregator.min(field).unwrap_or(0.0),
                    AggregateFunc::Max => aggregator.max(field).unwrap_or(0.0),
                };
                (label(key), value)
            })
            .collect()
    }

    /// Sum each group
    pub fn sum(&self, field: &str) -> Vec<(String, f64)> {
        self.aggregate(field, AggregateFunc::Sum, |k| k.to_string())
    }

    /// Average each group
    pub fn avg(&self, field: &str) -> Vec<(String, f64)> {
        self.aggregate(field, AggregateFunc::Avg, |k| k.to_string())
    }

    /// Count each group
    pub fn count(&self) -> Vec<(String, f64)> {
        self.groups
            .iter()
            .map(|(key, records)| (key.clone(), records.len() as f64))
            .collect()
    }
}

/// Aggregate function types
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateFunc {
    Sum,
    Avg,
    Count,
    Min,
    Max,
}

#[cfg(test)]
mod tests {
    use super::*;

    fn sample_data() -> Vec<HashMap<String, String>> {
        vec![
            HashMap::from([
                ("region".to_string(), "North".to_string()),
                ("amount".to_string(), "100".to_string()),
            ]),
            HashMap::from([
                ("region".to_string(), "North".to_string()),
                ("amount".to_string(), "150".to_string()),
            ]),
            HashMap::from([
                ("region".to_string(), "South".to_string()),
                ("amount".to_string(), "200".to_string()),
            ]),
        ]
    }

    #[test]
    fn test_sum() {
        let agg = DataAggregator::new(sample_data());
        assert_eq!(agg.sum("amount"), 450.0);
    }

    #[test]
    fn test_avg() {
        let agg = DataAggregator::new(sample_data());
        assert_eq!(agg.avg("amount"), 150.0);
    }

    #[test]
    fn test_count() {
        let agg = DataAggregator::new(sample_data());
        assert_eq!(agg.count(), 3);
    }

    #[test]
    fn test_min_max() {
        let agg = DataAggregator::new(sample_data());
        assert_eq!(agg.min("amount"), Some(100.0));
        assert_eq!(agg.max("amount"), Some(200.0));
    }

    #[test]
    fn test_group_by_sum() {
        let agg = DataAggregator::new(sample_data());
        let grouped = agg.group_by("region").sum("amount");

        assert_eq!(grouped.len(), 2);
        assert!(grouped.iter().any(|(k, v)| k == "North" && *v == 250.0));
        assert!(grouped.iter().any(|(k, v)| k == "South" && *v == 200.0));
    }

    #[test]
    fn test_group_by_count() {
        let agg = DataAggregator::new(sample_data());
        let grouped = agg.group_by("region").count();

        assert_eq!(grouped.len(), 2);
        assert!(grouped.iter().any(|(k, v)| k == "North" && *v == 2.0));
        assert!(grouped.iter().any(|(k, v)| k == "South" && *v == 1.0));
    }

    #[test]
    fn test_filter() {
        let agg = DataAggregator::new(sample_data());
        let filtered = agg.filter(|r| r.get("region") == Some(&"North".to_string()));

        assert_eq!(filtered.count(), 2);
        assert_eq!(filtered.sum("amount"), 250.0);
    }

    #[test]
    fn test_avg_empty_data() {
        let agg = DataAggregator::new(vec![]);
        assert_eq!(agg.avg("amount"), 0.0);
    }

    #[test]
    fn test_min_max_empty_data() {
        let agg = DataAggregator::new(vec![]);
        assert_eq!(agg.min("amount"), None);
        assert_eq!(agg.max("amount"), None);
    }

    #[test]
    fn test_sum_nonexistent_field() {
        let agg = DataAggregator::new(sample_data());
        assert_eq!(agg.sum("nonexistent"), 0.0);
    }

    #[test]
    fn test_avg_nonexistent_field() {
        let agg = DataAggregator::new(sample_data());
        assert_eq!(agg.avg("nonexistent"), 0.0);
    }

    #[test]
    fn test_group_by_avg() {
        let agg = DataAggregator::new(sample_data());
        let grouped = agg.group_by("region").avg("amount");

        assert_eq!(grouped.len(), 2);
        assert!(grouped.iter().any(|(k, v)| k == "North" && *v == 125.0));
        assert!(grouped.iter().any(|(k, v)| k == "South" && *v == 200.0));
    }

    #[test]
    fn test_aggregate_with_custom_label() {
        let agg = DataAggregator::new(sample_data());
        let grouped = agg
            .group_by("region")
            .aggregate("amount", AggregateFunc::Sum, |k| format!("Region: {}", k));

        assert_eq!(grouped.len(), 2);
        assert!(grouped
            .iter()
            .any(|(k, v)| k == "Region: North" && *v == 250.0));
        assert!(grouped
            .iter()
            .any(|(k, v)| k == "Region: South" && *v == 200.0));
    }

    #[test]
    fn test_aggregate_with_count() {
        let agg = DataAggregator::new(sample_data());
        let grouped = agg
            .group_by("region")
            .aggregate("amount", AggregateFunc::Count, |k| k.to_string());

        assert!(grouped.iter().any(|(k, v)| k == "North" && *v == 2.0));
        assert!(grouped.iter().any(|(k, v)| k == "South" && *v == 1.0));
    }

    #[test]
    fn test_aggregate_with_min_max() {
        let agg = DataAggregator::new(sample_data());
        let min_grouped = agg
            .group_by("region")
            .aggregate("amount", AggregateFunc::Min, |k| k.to_string());
        let max_grouped = agg
            .group_by("region")
            .aggregate("amount", AggregateFunc::Max, |k| k.to_string());

        assert!(min_grouped.iter().any(|(k, v)| k == "North" && *v == 100.0));
        assert!(max_grouped.iter().any(|(k, v)| k == "North" && *v == 150.0));
    }

    #[test]
    fn test_aggregate_func_enum() {
        assert_eq!(AggregateFunc::Sum, AggregateFunc::Sum);
        assert_eq!(AggregateFunc::Avg, AggregateFunc::Avg);
        assert_eq!(AggregateFunc::Count, AggregateFunc::Count);
        assert_eq!(AggregateFunc::Min, AggregateFunc::Min);
        assert_eq!(AggregateFunc::Max, AggregateFunc::Max);
        assert_ne!(AggregateFunc::Sum, AggregateFunc::Avg);
    }

    #[test]
    fn test_group_by_missing_field() {
        let agg = DataAggregator::new(sample_data());
        let grouped = agg.group_by("nonexistent");
        // Should result in empty groups
        assert_eq!(grouped.count().len(), 0);
    }

    #[test]
    fn test_filter_all_records() {
        let agg = DataAggregator::new(sample_data());
        let filtered = agg.filter(|_| true);
        assert_eq!(filtered.count(), 3);
    }

    #[test]
    fn test_filter_no_records() {
        let agg = DataAggregator::new(sample_data());
        let filtered = agg.filter(|_| false);
        assert_eq!(filtered.count(), 0);
    }
}