stats_ci/
utils.rs

//!
//! Groups utility functions and types.
//!
//! Offers support for computing compensated floating-point sums (a.k.a. Kahan summation).
//!
//! The main type is [`KahanSum`], where you can find further explanations and an example.
7
8use num_traits::Float;
9
10///
11/// Kahan compensated summation register.
12///
13/// This is a register that can be used to sum a sequence of floating point numbers with a better precision than a naive summation.
14///
15/// See <https://en.wikipedia.org/wiki/Kahan_summation_algorithm>
16///
17/// # Examples
18///
19/// ```ignore
20/// let repetitions = 10_000;
21/// let mut naive = 0.0_f32;
22/// let mut sum = KahanSum::new(0.0_f32);
23/// (1..=repetitions).for_each(|_| {
24///     sum += 0.1;
25///     naive += 0.1;
26/// });
27/// assert_eq!(sum.sum(), repetitions as f32 * 0.1);
28/// assert_ne!(naive, repetitions as f32 * 0.1);
29/// ```
30#[derive(Debug, Clone, Copy)]
31pub struct KahanSum<T: Float> {
32    sum: T,
33    compensation: T,
34}
35
36impl<T: Float> KahanSum<T> {
37    ///
38    /// Create a new KahanSum register with the given initial value
39    ///
40    /// # Arguments
41    ///
42    /// * `value` - the initial value
43    ///
44    pub fn new(value: T) -> Self {
45        Self {
46            sum: value,
47            compensation: T::zero(),
48        }
49    }
50
51    ///
52    /// Return the current value of the sum
53    ///
54    pub fn value(&self) -> T {
55        self.sum + self.compensation
56    }
57}
58
59impl<T: Float> Default for KahanSum<T> {
60    fn default() -> Self {
61        Self::new(T::zero())
62    }
63}
64
65impl<T: Float> PartialEq for KahanSum<T> {
66    fn eq(&self, other: &Self) -> bool {
67        self.value() == other.value()
68    }
69}
70
71impl<T: Float + std::fmt::Display> std::fmt::Display for KahanSum<T> {
72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73        self.value().fmt(f)
74    }
75}
76
77impl<T: Float> std::ops::AddAssign<Self> for KahanSum<T> {
78    fn add_assign(&mut self, rhs: Self) {
79        kahan_add(&mut self.sum, rhs.sum, &mut self.compensation);
80        kahan_add(&mut self.sum, rhs.compensation, &mut self.compensation);
81    }
82}
83
84impl<T: Float> std::ops::AddAssign<T> for KahanSum<T> {
85    fn add_assign(&mut self, rhs: T) {
86        kahan_add(&mut self.sum, rhs, &mut self.compensation);
87    }
88}
89
90impl<T: Float, X> std::ops::Add<X> for KahanSum<T>
91where
92    Self: std::ops::AddAssign<X>,
93{
94    type Output = Self;
95
96    fn add(self, rhs: X) -> Self::Output {
97        let mut sum = self;
98        sum += rhs;
99        sum
100    }
101}
102
103impl<T: Float> From<T> for KahanSum<T> {
104    fn from(value: T) -> Self {
105        Self::new(value)
106    }
107}
108
109///
110/// Compensated Kahan summation.
111/// See <https://en.wikipedia.org/wiki/Kahan_summation_algorithm>
112///
113/// The function is meant to be called at each iteration of the summation,
114/// with relevant variables managed externally
115///
116/// # Arguments
117///
118/// * `current_sum` - the current sum
119/// * `x` - the next value to add to the sum
120/// * `compensation` - the compensation term
121///
122#[inline]
123fn kahan_add<T: Float>(current_sum: &mut T, x: T, compensation: &mut T) {
124    let sum = *current_sum;
125    let c = *compensation;
126    let y = x - c;
127    let t = sum + y;
128    *compensation = (t - sum) - y;
129    *current_sum = t;
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use approx::*;

    /// Raw `kahan_add` versus naive `+=` over a long `f32` accumulation.
    ///
    /// The `1e-10` epsilon is far below one `f32` ulp at this magnitude (about 5.5e7),
    /// so this effectively checks that the compensated sum equals the single-rounded
    /// product `iterations * x`. The 50 million iterations make this test slow in debug builds.
    #[test]
    fn test_kahan_add() {
        type Float = f32;
        let iterations = 50_000_000_usize;
        let mut normal: Float = 0.;
        let mut kahan: Float = 0.;
        let mut kahan_c: Float = 0.;
        let x = 1.1;

        for _ in 0..iterations {
            normal += x;
            kahan_add(&mut kahan, x, &mut kahan_c);
        }
        let expected = iterations as Float * x;
        println!("should be: {}", expected);
        println!(
            "normal: {} (diff: {:.0}%)",
            normal,
            (normal - expected) / expected * 100.
        );
        println!(
            "kahan: {} (diff: {:.0}%)",
            kahan,
            (kahan - expected) / expected * 100.
        );
        assert_abs_diff_eq!(expected, kahan, epsilon = 1e-10);
        assert!((expected - normal).abs() > 500_000.); // normal summation is not accurate for f32
    }

    /// Same accumulation through the `KahanSum` register. `kahan` receives scalars
    /// directly; `kahan2` merges two-element sub-registers via `AddAssign<Self>`.
    #[test]
    fn test_kahan_sum() {
        type Float = f32;

        let iterations = 50_000_000_usize;
        let mut normal: Float = 0.;
        let mut kahan = KahanSum::<Float>::default();
        let mut kahan2 = KahanSum::<Float>::default();

        let x = 1.1;

        for i in 0..iterations {
            normal += x;
            kahan += x;
            if i % 2 == 1 {
                let mut double = KahanSum::<Float>::default();
                double += x;
                double += x;
                kahan2 += double;
            }
        }
        let expected = iterations as Float * x;
        println!("should be: {}", expected);
        println!(
            "normal: {} (diff: {:.0}%)",
            normal,
            (normal - expected) / expected * 100.
        );
        println!(
            "kahan: {} (diff: {:.0}%)",
            kahan,
            (kahan.value() - expected) / expected * 100.
        );
        println!(
            "kahan2: {} (diff: {:.0}%)",
            kahan2,
            (kahan2.value() - expected) / expected * 100.
        );
        assert_abs_diff_eq!(expected, kahan.value(), epsilon = 1e-10);
        assert_abs_diff_eq!(expected, kahan2.value(), epsilon = 1e-10);
        assert!((expected - normal).abs() > 500_000.); // normal summation is not accurate for f32
    }

    /// Mirrors the `ignore`d example in the `KahanSum` documentation so it is actually run.
    #[test]
    fn test_doctest() {
        let repetitions = 10_000;
        let mut naive = 0.0_f32;
        let mut sum = KahanSum::new(0.0_f32);
        (1..=repetitions).for_each(|_| {
            sum += 0.1;
            naive += 0.1;
        });
        assert_eq!(sum.value(), repetitions as f32 * 0.1);
        assert_ne!(naive, repetitions as f32 * 0.1);
    }

    /// `Add` must agree with `AddAssign`, for both scalar and register right-hand sides.
    #[test]
    fn test_add_operator_matches_add_assign() {
        let mut assigned = KahanSum::new(1.0_f64);
        assigned += 2.5_f64;
        let added = KahanSum::new(1.0_f64) + 2.5_f64;
        assert_eq!(added, assigned);
        assert_eq!(added.value(), 3.5);

        let merged = KahanSum::new(1.0_f64) + KahanSum::new(2.0_f64);
        assert_eq!(merged.value(), 3.0);
    }

    /// `Default`, `From` / `Into` and `new` must build equivalent registers.
    #[test]
    fn test_constructors_agree() {
        assert_eq!(KahanSum::<f64>::default().value(), 0.0);
        assert_eq!(KahanSum::from(2.5_f64), KahanSum::new(2.5_f64));
        let converted: KahanSum<f32> = 4.0_f32.into();
        assert_eq!(converted.value(), 4.0);
    }

    /// `Display` formats the total and forwards the caller's format spec (e.g. precision).
    #[test]
    fn test_display_forwards_format_spec() {
        assert_eq!(KahanSum::new(1.5_f64).to_string(), "1.5");
        assert_eq!(format!("{:.2}", KahanSum::new(1.0_f64)), "1.00");
    }
}