clipivot/
aggfunc.rs

1//! The `aggfunc` module is the central module for computing statistics from a stream of records.
2//!
//! The central component is a trait called `Accumulate`, which declares a `new` function to initialize an accumulator
//! from its first record, an `update` function to add a new record, and a `compute` function to compute the final
//! value of the aggregation. The trait takes two type parameters: an input type (used by the `new` and `update`
//! functions) and an output type (returned by `compute`).
6//!
7//! Internally, all of the structs implementing this trait are used in the main `aggregation` module
8//! with the input type bounded by `FromStr` so the tool can convert from string records to the internal data types
9//! that these aggregation types manipulate. And the output type is bounded by `Display` so the tool can write
10//! the outputs to standard output.
11
12use crate::parsing::DecimalWrapper;
13use rust_decimal::Decimal;
14use std::collections::{BTreeMap, HashMap, HashSet};
15use std::marker::PhantomData;
16
/// Folds a stream of records into a single statistic, one record at a time.
///
/// `I` is the type of each incoming record and `O` is the type of the finished statistic.
/// Working record-by-record lets each implementor keep only the state it needs, so functions
/// can be optimized for minimal memory usage.
pub trait Accumulate<I, O> {
    /// Builds an accumulator seeded with the first record, `item`.
    ///
    /// Seeding is a separate step from `update` because some statistics (the sample standard
    /// deviation, for example) treat their first record differently from every later one.
    fn new(item: I) -> Self;

    /// Feeds one more record, `item`, into the accumulator.
    fn update(&mut self, item: I);

    /// Produces the statistic for every record seen so far.
    ///
    /// The result is wrapped in an `Option`; it is usually `Some(val)`, with `StdDev` being the
    /// exception.
    fn compute(&self) -> Option<O>;
}
30
31/// The total number of records added to the accumulator.
32pub struct Count<I>(usize, PhantomData<I>);
33
34impl<I> Accumulate<I, usize> for Count<I> {
35    fn new(_item: I) -> Count<I> {
36        Count(1, PhantomData)
37    }
38
39    fn update(&mut self, _item: I) {
40        self.0 += 1;
41    }
42
43    fn compute(&self) -> Option<usize> {
44        Some(self.0)
45    }
46}
47
48/// The total number of *unique* records.
49pub struct CountUnique<I>(HashSet<I>);
50
51impl<I> Accumulate<I, usize> for CountUnique<I>
52where
53    I: std::cmp::Eq,
54    I: std::hash::Hash,
55{
56    fn new(item: I) -> CountUnique<I> {
57        let mut vals = HashSet::new();
58        vals.insert(item);
59        CountUnique(vals)
60    }
61
62    fn update(&mut self, item: I) {
63        self.0.insert(item);
64    }
65
66    fn compute(&self) -> Option<usize> {
67        Some(self.0.len())
68    }
69}
70
71/// The largest value (or the value that would appear last in a sorted array)
72pub struct Maximum<I>(I);
73
74impl<I> Accumulate<I, I> for Maximum<I>
75where
76    I: std::cmp::PartialOrd,
77    I: std::clone::Clone,
78{
79    fn new(item: I) -> Maximum<I> {
80        Maximum(item)
81    }
82
83    fn update(&mut self, item: I) {
84        if self.0 < item {
85            self.0 = item;
86        }
87    }
88
89    fn compute(&self) -> Option<I> {
90        Some(self.0.clone())
91    }
92}
93
94/// The mean. This is only implemented for `DecimalWrapper`, though it  could probably be extended for floating point
95/// types.
96pub struct Mean {
97    running_sum: DecimalWrapper,
98    running_count: usize,
99}
100
101impl Accumulate<DecimalWrapper, DecimalWrapper> for Mean {
102    fn new(item: DecimalWrapper) -> Mean {
103        Mean {
104            running_sum: item,
105            running_count: 1,
106        }
107    }
108
109    fn update(&mut self, item: DecimalWrapper) {
110        self.running_sum.item += item.item;
111        self.running_count += 1;
112    }
113
114    fn compute(&self) -> Option<DecimalWrapper> {
115        let decimal_count = Decimal::new(self.running_count as i64, 0);
116        let result = self.running_sum.item / decimal_count;
117        Some(DecimalWrapper { item: result })
118    }
119}
120
121/// The median value. I've stored values in a `BTreeMap` in order to minimize memory usage.
122/// As a result, this is the least performant of all the functions (running at `Nlog(m)`, rather than
123/// the `N` of all the other algorithms (where `m` is the number of *unique* values in the accumulator).
124pub struct Median {
125    values: BTreeMap<DecimalWrapper, usize>,
126    num: usize,
127}
128
129impl Accumulate<DecimalWrapper, DecimalWrapper> for Median {
130    fn new(item: DecimalWrapper) -> Median {
131        let mut mapping = BTreeMap::new();
132        mapping.insert(item, 1);
133        Median {
134            values: mapping,
135            num: 1,
136        }
137    }
138
139    fn update(&mut self, item: DecimalWrapper) {
140        self.values
141            .entry(item)
142            .and_modify(|val| *val += 1)
143            .or_insert(1);
144        self.num += 1;
145    }
146
147    fn compute(&self) -> Option<DecimalWrapper> {
148        let mut cur_count = 0;
149        let mut cur_val = DecimalWrapper {
150            item: Decimal::new(0, 0),
151        };
152        // creating an iter bc we're stopping at N/2
153        let mut iter = self.values.iter();
154        while (cur_count as f64) < (self.num as f64 / 2.) {
155            // should break before iter.next().is_none()
156            let (result, count) = iter.next().unwrap();
157            cur_count += count;
158            cur_val = *result;
159        }
160        // -- take the mean if we have an even number of records and end at *exactly* the midpoint.
161        if (self.num % 2) == 0
162            && ((cur_count as f64) - (self.num as f64 / 2.)).abs() < std::f64::EPSILON
163        {
164            // iter.next() will always be Some(_) because this is always initialized with
165            let median = (cur_val + *iter.next().unwrap().0)
166                / DecimalWrapper {
167                    item: Decimal::new(2, 0),
168                };
169            Some(median)
170        } else {
171            Some(cur_val)
172        }
173    }
174}
175
176/// The minimum value
177pub struct Minimum<I>(I);
178
179impl<I> Accumulate<I, I> for Minimum<I>
180where
181    I: std::cmp::PartialOrd,
182    I: std::clone::Clone,
183{
184    fn new(item: I) -> Minimum<I> {
185        Minimum(item)
186    }
187
188    fn update(&mut self, item: I) {
189        if self.0 > item {
190            self.0 = item;
191        }
192    }
193
194    fn compute(&self) -> Option<I> {
195        Some(self.0.clone())
196    }
197}
198
199/// A combination of the minimum and maximum values, producing a string concatenating
200/// the minimum value and the maximum value together, separated by a hyphen.
201pub struct MinMax<I> {
202    max_val: I,
203    min_val: I,
204}
205
206impl<I> Accumulate<I, String> for MinMax<I>
207where
208    I: std::fmt::Display,
209    I: std::cmp::PartialOrd,
210    I: std::clone::Clone,
211{
212    fn new(item: I) -> MinMax<I> {
213        MinMax {
214            min_val: item.clone(),
215            max_val: item,
216        }
217    }
218
219    fn update(&mut self, item: I) {
220        if self.min_val > item {
221            self.min_val = item;
222        } else if self.max_val < item {
223            self.max_val = item;
224        }
225    }
226
227    fn compute(&self) -> Option<String> {
228        Some(format!("{} - {}", self.min_val, self.max_val))
229    }
230}
231
232/// The most commonly appearing item.
233///
234/// If there is more than one mode, it returns
235/// the item that reached the maximum value first. So in the case of
236/// ["a", "b", "b", "a"], it will return "b" because "b" was the first
237/// value to appear twice.
238pub struct Mode<I> {
239    histogram: HashMap<I, usize>,
240    max_count: usize,
241    max_val: I,
242}
243
244impl<I> Accumulate<I, I> for Mode<I>
245where
246    I: std::cmp::PartialOrd,
247    I: std::cmp::Eq,
248    I: std::hash::Hash,
249    I: std::clone::Clone,
250{
251    fn new(item: I) -> Mode<I> {
252        let mut histogram = HashMap::new();
253        let max_val = item.clone();
254        histogram.insert(item, 1);
255        Mode {
256            histogram,
257            max_count: 1,
258            max_val,
259        }
260    }
261
262    fn update(&mut self, item: I) {
263        // barely adapted from https://docs.rs/indexmap/1.0.2/indexmap/map/struct.IndexMap.html
264        let new_count = *self.histogram.get(&item).unwrap_or(&0) + 1;
265        if new_count > self.max_count {
266            self.max_count = new_count;
267            self.max_val = item.clone();
268        }
269        *self.histogram.entry(item).or_insert(0) += 1;
270    }
271
272    fn compute(&self) -> Option<I> {
273        Some(self.max_val.clone())
274    }
275}
276
277/// The range, or the difference between the minimum and maximum values (where the minimum value is subtracted from the maximum value).
278pub struct Range<I, O> {
279    max_val: I,
280    min_val: I,
281    phantom: PhantomData<O>,
282}
283
284impl<I, O> Accumulate<I, O> for Range<I, O>
285where
286    I: std::cmp::PartialOrd,
287    I: std::ops::Sub<Output = O>,
288    I: std::marker::Copy,
289{
290    #[allow(clippy::clone_on_copy)]
291    fn new(item: I) -> Range<I, O> {
292        Range {
293            min_val: item,
294            max_val: item.clone(),
295            phantom: PhantomData,
296        }
297    }
298
299    fn update(&mut self, item: I) {
300        if self.min_val > item {
301            self.min_val = item;
302        }
303        if self.max_val < item {
304            self.max_val = item;
305        }
306    }
307
308    fn compute(&self) -> Option<O> {
309        Some(self.max_val - self.min_val)
310    }
311}
312
313/// Computes the *sample* variance in a single pass, using
314/// [Welford's algorithm](https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm).
315
316/// The attributes in this method refer to the same ones described in
317/// *Accuracy and Stability of Numerical Algorithms* by Higham (2nd Edition, page 11).
318pub struct StdDev {
319    // solution from Nicholas Higham: Accuracy and Stability of Numerical Algorithms
320    // Second Edition, 2002, p. 11
321    q: f64,
322    m: f64,
323    /// The number of records parsed so far
324    num_records: f64,
325}
326
327impl Accumulate<f64, f64> for StdDev {
328    fn new(item: f64) -> Self {
329        StdDev {
330            q: 0.,
331            m: item,
332            num_records: 1.,
333        }
334    }
335
336    fn update(&mut self, item: f64) {
337        self.num_records += 1.;
338        let squared_diff = (item - self.m).powi(2);
339        self.q += ((self.num_records - 1.) * squared_diff) / self.num_records;
340        self.m += (item - self.m) / self.num_records;
341    }
342
343    fn compute(&self) -> Option<f64> {
344        if self.num_records <= 1. {
345            return None;
346        }
347        Some((self.q / (self.num_records - 1.)).sqrt())
348    }
349}
350
351/// The running sum of a stream of values.
352pub struct Sum<I>(I);
353
354impl<I> Accumulate<I, I> for Sum<I>
355where
356    I: std::ops::AddAssign,
357    I: std::fmt::Display,
358    I: std::marker::Copy,
359{
360    fn new(item: I) -> Sum<I> {
361        Sum(item)
362    }
363
364    fn update(&mut self, item: I) {
365        self.0 += item;
366    }
367
368    fn compute(&self) -> Option<I> {
369        Some(self.0)
370    }
371}
372
/// Unit and property tests for the accumulators defined in this module.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parsing::{self, CustomDateObject, DecimalWrapper};
    use proptest::prelude::*;
    use proptest::test_runner::Config;

    /// Duplicated strings must only be counted once.
    #[test]
    fn test_unique_count() {
        let update_vals = vec!["apple", "pie", "is", "good"]
            .into_iter()
            .map(|v| v.to_string());
        let mut no_dups = CountUnique::new("really".to_string());
        let mut dup = CountUnique::new("good".to_string());
        for val in update_vals {
            no_dups.update(val.clone());
            dup.update(val);
        }
        assert_eq!(no_dups.compute().unwrap(), 5);
        assert_eq!(dup.compute().unwrap(), 4);
    }

    /// Strings are compared lexicographically.
    #[test]
    fn test_max_string() {
        let update_vals = vec!["2019-02-03", "2020-01-03", "2018-01-02"]
            .into_iter()
            .map(|v| v.to_string());
        let mut max_vals = Maximum::new("2019-12-31".to_string());
        for val in update_vals {
            max_vals.update(val);
        }
        assert_eq!(max_vals.compute(), Some("2020-01-03".to_string()));
    }

    #[test]
    fn test_max_dates() {
        // there's probably a better way of handling this, but this uses the same
        // code as parsing does, so shouldn't affect either
        let _ex = parsing::set_date_format("%Y-%m-%d %H:%M:%S".to_string());
        let date_updates = vec![
            "2019-02-03 12:23:10",
            "2020-01-03 13:45:02",
            "2018-01-02 12:23:10",
        ];
        let cust_date: CustomDateObject = "2019-12-31 01:20:13".parse().unwrap();
        let mut date_vals = Maximum::new(cust_date);
        for val in date_updates {
            let date_parse: CustomDateObject = val.parse().unwrap();
            date_vals.update(date_parse);
        }
        assert_eq!(
            date_vals.compute().unwrap().to_string(),
            "2020-01-03 13:45:02".to_string()
        );
    }

    #[test]
    fn test_max_decimals() {
        let updates = vec!["1.2", "2e-7", "2E3", "10000"];
        let start_dec: DecimalWrapper = ".278".parse().unwrap();
        let mut max_dec = Maximum::new(start_dec);
        for val in updates {
            max_dec.update(val.parse().unwrap());
        }
        assert_eq!(max_dec.compute().unwrap().to_string(), "10000".to_string());
    }

    /// Strings are compared lexicographically.
    #[test]
    fn test_min_string() {
        let update_vals = vec!["2019-02-03", "2020-01-03", "2018-01-02"]
            .into_iter()
            .map(|v| v.to_string());
        let mut min_vals = Minimum::new("2019-12-31".to_string());
        for val in update_vals {
            min_vals.update(val);
        }
        assert_eq!(min_vals.compute(), Some("2018-01-02".to_string()));
    }

    #[test]
    fn test_min_dates() {
        // there's probably a better way of handling this, but this uses the same
        // code as parsing does, so shouldn't affect either
        let _ex = parsing::set_date_format("%Y-%m-%d %H:%M:%S".to_string());
        let date_updates = vec![
            "2019-02-03 12:23:10",
            "2020-01-03 13:45:02",
            "2018-01-02 12:23:10",
        ];
        let cust_date: CustomDateObject = "2019-12-31 01:20:13".parse().unwrap();
        let mut date_vals = Minimum::new(cust_date);
        for val in date_updates {
            let date_parse: CustomDateObject = val.parse().unwrap();
            date_vals.update(date_parse);
        }
        assert_eq!(
            date_vals.compute().unwrap().to_string(),
            "2018-01-02 12:23:10".to_string()
        );
    }

    #[test]
    fn test_min_decimals() {
        let updates = vec!["1.2", "2e-7", "2E3", "10000"];
        let start_dec: DecimalWrapper = ".278".parse().unwrap();
        let mut min_dec = Minimum::new(start_dec);
        for val in updates {
            min_dec.update(val.parse().unwrap());
        }
        assert_eq!(
            min_dec.compute().unwrap().to_string(),
            "0.0000002".to_string()
        );
    }

    #[test]
    fn test_minmax_string() {
        let update_vals = vec!["2019-02-03", "2020-01-03", "2018-01-02"]
            .into_iter()
            .map(|v| v.to_string());
        let mut minmax_vals = MinMax::new("2019-12-31".to_string());
        for val in update_vals {
            minmax_vals.update(val);
        }
        assert_eq!(
            minmax_vals.compute(),
            Some("2018-01-02 - 2020-01-03".to_string())
        );
    }

    #[test]
    fn test_minmax_dates() {
        // there's probably a better way of handling this, but this uses the same
        // code as parsing does, so shouldn't affect either
        let _ex = parsing::set_date_format("%Y-%m-%d %H:%M:%S".to_string());
        let date_updates = vec![
            "2019-02-03 12:23:10",
            "2020-01-03 13:45:02",
            "2018-01-02 12:23:10",
        ];
        let cust_date: CustomDateObject = "2019-12-31 01:20:13".parse().unwrap();
        let mut date_vals = MinMax::new(cust_date);
        for val in date_updates {
            let date_parse: CustomDateObject = val.parse().unwrap();
            date_vals.update(date_parse);
        }
        assert_eq!(
            date_vals.compute().unwrap().to_string(),
            "2018-01-02 12:23:10 - 2020-01-03 13:45:02".to_string()
        );
    }

    #[test]
    fn test_minmax_decimals() {
        let updates = vec!["1.2", "2e-7", "2E3", "10000"];
        let start_dec: DecimalWrapper = ".278".parse().unwrap();
        let mut minmax_dec = MinMax::new(start_dec);
        for val in updates {
            minmax_dec.update(val.parse().unwrap());
        }
        assert_eq!(
            minmax_dec.compute().unwrap().to_string(),
            "0.0000002 - 10000".to_string()
        );
    }

    #[test]
    fn test_range_dates() {
        // there's probably a better way of handling this, but this uses the same
        // code as parsing does, so shouldn't affect either
        let _ex = parsing::set_date_format("%Y-%m-%d %H:%M:%S".to_string());
        let date_updates = vec![
            "2019-02-03 00:00:00",
            "2020-01-03 12:00:00",
            "2018-01-02 06:00:00",
        ];
        let cust_date: CustomDateObject = "2019-12-31 01:20:13".parse().unwrap();
        let mut date_vals = Range::new(cust_date);
        for val in date_updates {
            let date_parse: CustomDateObject = val.parse().unwrap();
            date_vals.update(date_parse);
        }
        assert_eq!(date_vals.compute().unwrap(), 731.25);
    }

    /// Covers a single record, an odd count, an even count, and an even count where the
    /// lower middle value is repeated.
    #[test]
    fn test_median() {
        let dec1: DecimalWrapper = "2".parse().unwrap();
        let mut dec_vals = Median::new(dec1);
        assert_eq!(dec_vals.compute().unwrap().to_string(), "2".to_string());
        let new_vals = vec!["3", "5"];
        for val in new_vals {
            let dec: DecimalWrapper = val.parse().unwrap();
            dec_vals.update(dec);
        }
        assert_eq!(dec_vals.compute().unwrap().to_string(), "3".to_string());
        let next_val: DecimalWrapper = "1".parse().unwrap();
        dec_vals.update(next_val);
        assert_eq!(dec_vals.compute().unwrap().to_string(), "2.5".to_string());
        let mult_middle_vals: DecimalWrapper = "3".parse().unwrap();
        let mut mult_median = Median::new(mult_middle_vals);
        for val in vec!["5", "6", "1", "4", "3"] {
            mult_median.update(val.parse().unwrap());
        }
        assert_eq!(
            mult_median.compute().unwrap().to_string(),
            "3.5".to_string()
        );
    }

    #[test]
    fn test_range_decimals() {
        let updates = vec!["1.2", "2E3", "10000"];
        let start_dec: DecimalWrapper = "19".parse().unwrap();
        let mut range_dec = Range::new(start_dec);
        for val in updates {
            range_dec.update(val.parse().unwrap());
        }
        assert_eq!(
            range_dec.compute().unwrap().to_string(),
            "9998.8".to_string()
        );
    }

    #[test]
    fn test_mode() {
        let mut mode = Mode::new("a".to_string());
        assert_eq!(mode.compute().unwrap().to_string(), "a".to_string());
        let new_vals = vec!["b", "c", "a"].into_iter().map(|v| v.to_string());
        for val in new_vals {
            mode.update(val);
        }
        assert_eq!(mode.compute().unwrap(), "a".to_string());
        // Repeated many times so that a dependence on `HashMap` iteration order (which is
        // randomized per map) would show up as a flaky failure. "a" reaches a count of two
        // before "b" does, so it must win the tie every time.
        for _i in 1..=10000 {
            let mut mode_ordering = Mode::new("a".to_string());
            for val in vec!["b", "a", "b"].into_iter().map(|v| v.to_string()) {
                mode_ordering.update(val);
            }
            assert_eq!(mode_ordering.compute().unwrap(), "a".to_string());
        }
        // ["a", "b", "b", "a"]: "b" is the first value to appear twice.
        let mut first_to_two = Mode::new("a".to_string());
        for val in vec!["b", "b", "a"].into_iter().map(|v| v.to_string()) {
            first_to_two.update(val);
        }
        assert_eq!(first_to_two.compute().unwrap(), "b".to_string());
    }

    #[test]
    fn test_sum() {
        let dec_num: DecimalWrapper = "10".parse().unwrap();
        let mut summation = Sum::new(dec_num);
        assert_eq!(summation.compute().unwrap().to_string(), "10".to_string());
        let addl_vals = vec!["0.3", "100", "3.2"];
        for val in addl_vals {
            summation.update(val.parse().unwrap());
        }
        assert_eq!(
            summation.compute().unwrap().to_string(),
            "113.5".to_string()
        );
    }

    proptest! {
        #![proptest_config(Config::with_cases(100))]
        /// `Count` must report exactly as many records as it was given, whatever they contain.
        #[test]
        fn test_count_gets_raw_count(mut string_vecs in prop::collection::vec(any::<String>(), 1 .. 50)) {
            let total_count = string_vecs.len();
            // The first element seeds the accumulator; everything after it goes through `update`.
            let count_split = string_vecs.split_off(1);
            let mut count_obj = Count::new(string_vecs[0].clone());
            for item in count_split {
                count_obj.update(item);
            }
            prop_assert_eq!(count_obj.compute().unwrap(), total_count);
        }
    }
}