acme_tensor/stats/
mod.rs

/*
    Appellation: stats <mod>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
5
6mod impl_stats;
7
8use crate::prelude::{Axis, Scalar};
9#[cfg(not(feature = "std"))]
10use alloc::collections::BTreeMap;
11#[cfg(feature = "std")]
12use std::collections::BTreeMap;
13
/// A trait describing the behavior of a collection of values that can be used to compute statistics.
///
/// Every method summarizes the whole collection into a single value of type `T`.
///
/// # Panics
///
/// Implementations may panic when a statistic is undefined for the collection (for example,
/// when the collection is empty); consult each implementation for its exact behavior.
pub trait Statistics<T> {
    /// Returns the maximum value in the collection.
    fn max(&self) -> T;
    /// Returns the arithmetic mean (average) of the collection.
    fn mean(&self) -> T;
    /// Returns the median value of the collection.
    fn median(&self) -> T;
    /// Returns the minimum value in the collection.
    fn min(&self) -> T;
    /// Returns the mode (the most frequently occurring value) of the collection.
    fn mode(&self) -> T;
    /// Returns the sum of all values in the collection.
    fn sum(&self) -> T;
    /// Computes the standard deviation of the collection.
    fn std(&self) -> T;
    /// Computes the variance of the collection.
    fn variance(&self) -> T;
}
33
34macro_rules! impl_stats {
35    ($container:ty, $size:ident) => {
36        impl<T> Statistics<T> for $container
37        where
38            Self: Clone,
39            T: Ord + Scalar,
40        {
41            fn max(&self) -> T {
42                self.iter().max().unwrap().clone()
43            }
44
45            fn mean(&self) -> T {
46                self.sum() / T::from_usize(self.$size()).unwrap()
47            }
48
49            fn median(&self) -> T {
50                let mut sorted = self.clone();
51                sorted.sort();
52                let mid = sorted.$size() / 2;
53                if sorted.$size() % 2 == 0 {
54                    (sorted[mid - 1] + sorted[mid]) / T::from_usize(2).unwrap()
55                } else {
56                    sorted[mid]
57                }
58            }
59
60            fn min(&self) -> T {
61                self.iter().min().unwrap().clone()
62            }
63
64            fn mode(&self) -> T {
65                let mut freqs = BTreeMap::new();
66                for &val in self.iter() {
67                    *freqs.entry(val).or_insert(0) += 1;
68                }
69                let max_freq = freqs.values().max().unwrap();
70                *freqs.iter().find(|(_, &freq)| freq == *max_freq).unwrap().0
71            }
72
73            fn sum(&self) -> T {
74                self.iter().copied().sum()
75            }
76
77            fn std(&self) -> T {
78                self.variance().sqrt()
79            }
80
81            fn variance(&self) -> T {
82                let sqr = |x| x * x;
83                let mean = self.mean();
84                self.iter().map(|x| sqr(*x - mean)).sum::<T>()
85                    / T::from_usize(self.$size()).unwrap()
86            }
87        }
88    };
89}
90impl_stats!(Vec<T>, len);
91pub trait StatisticsExt<T>: Statistics<T> {
92    /// Compute the mean along the specified axis.
93    fn mean_axis(&self, axis: Axis) -> T;
94}
95
96pub(crate) mod prelude {
97    pub use super::{Statistics, StatisticsExt};
98}
99
#[cfg(test)]
mod tests {
    // NOTE(review): placeholder with no tests yet. The `Vec<T>` implementation in this file
    // requires `T: Ord + Scalar`; confirm which concrete types satisfy both before adding cases.
}