1use num_traits::Float;
9
/// A running floating-point total accumulated with Kahan (compensated) summation.
///
/// Alongside the rounded running `sum`, a `compensation` term records low-order bits that
/// ordinary addition would discard, so long sequences of additions lose far less precision.
///
/// The type parameter is deliberately unbounded on the struct itself; the constructors and
/// arithmetic impls carry the `T: Float` bound they actually need.
#[derive(Debug, Clone, Copy)]
pub struct KahanSum<T> {
    /// Rounded running total.
    sum: T,
    /// Correction term maintained by `kahan_add` and combined with `sum` when reading the total.
    compensation: T,
}
35
36impl<T: Float> KahanSum<T> {
37 pub fn new(value: T) -> Self {
45 Self {
46 sum: value,
47 compensation: T::zero(),
48 }
49 }
50
51 pub fn value(&self) -> T {
55 self.sum + self.compensation
56 }
57}
58
59impl<T: Float> Default for KahanSum<T> {
60 fn default() -> Self {
61 Self::new(T::zero())
62 }
63}
64
65impl<T: Float> PartialEq for KahanSum<T> {
66 fn eq(&self, other: &Self) -> bool {
67 self.value() == other.value()
68 }
69}
70
71impl<T: Float + std::fmt::Display> std::fmt::Display for KahanSum<T> {
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 self.value().fmt(f)
74 }
75}
76
77impl<T: Float> std::ops::AddAssign<Self> for KahanSum<T> {
78 fn add_assign(&mut self, rhs: Self) {
79 kahan_add(&mut self.sum, rhs.sum, &mut self.compensation);
80 kahan_add(&mut self.sum, rhs.compensation, &mut self.compensation);
81 }
82}
83
84impl<T: Float> std::ops::AddAssign<T> for KahanSum<T> {
85 fn add_assign(&mut self, rhs: T) {
86 kahan_add(&mut self.sum, rhs, &mut self.compensation);
87 }
88}
89
90impl<T: Float, X> std::ops::Add<X> for KahanSum<T>
91where
92 Self: std::ops::AddAssign<X>,
93{
94 type Output = Self;
95
96 fn add(self, rhs: X) -> Self::Output {
97 let mut sum = self;
98 sum += rhs;
99 sum
100 }
101}
102
103impl<T: Float> From<T> for KahanSum<T> {
104 fn from(value: T) -> Self {
105 Self::new(value)
106 }
107}
108
109#[inline]
123fn kahan_add<T: Float>(current_sum: &mut T, x: T, compensation: &mut T) {
124 let sum = *current_sum;
125 let c = *compensation;
126 let y = x - c;
127 let t = sum + y;
128 *compensation = (t - sum) - y;
129 *current_sum = t;
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use approx::*;

    /// Over 50 million additions of `1.1_f32`, the free `kahan_add` helper must land on the
    /// expected product, while a plain `+=` accumulator drifts by more than 500 000.
    #[test]
    fn test_kahan_add() {
        type Real = f32;
        const ITERATIONS: usize = 50_000_000;
        let step: Real = 1.1;

        let mut naive: Real = 0.;
        let mut compensated: Real = 0.;
        let mut correction: Real = 0.;
        for _ in 0..ITERATIONS {
            naive += step;
            kahan_add(&mut compensated, step, &mut correction);
        }

        let expected = ITERATIONS as Real * step;
        println!("should be: {}", expected);
        println!(
            "normal: {} (diff: {:.0}%)",
            naive,
            (naive - expected) / expected * 100.
        );
        println!(
            "kahan: {} (diff: {:.0}%)",
            compensated,
            (compensated - expected) / expected * 100.
        );
        assert_abs_diff_eq!(expected, compensated, epsilon = 1e-10);
        assert!((expected - naive).abs() > 500_000.);
    }

    /// Same workload through `KahanSum`. One accumulator adds every value directly. The other
    /// merges a freshly built two-element `KahanSum` on every odd iteration, which exercises
    /// `AddAssign<Self>`.
    #[test]
    fn test_kahan_sum() {
        type Real = f32;
        const ITERATIONS: usize = 50_000_000;
        let step: Real = 1.1;

        let mut naive: Real = 0.;
        let mut single = KahanSum::<Real>::default();
        let mut paired = KahanSum::<Real>::default();

        for i in 0..ITERATIONS {
            naive += step;
            single += step;
            // Every odd iteration contributes two steps via a merged partial sum, so both
            // accumulators see the same total number of steps.
            if i % 2 == 1 {
                let mut pair = KahanSum::<Real>::default();
                pair += step;
                pair += step;
                paired += pair;
            }
        }

        let expected = ITERATIONS as Real * step;
        println!("should be: {}", expected);
        println!(
            "normal: {} (diff: {:.0}%)",
            naive,
            (naive - expected) / expected * 100.
        );
        println!(
            "kahan: {} (diff: {:.0}%)",
            single,
            (single.value() - expected) / expected * 100.
        );
        println!(
            "kahan2: {} (diff: {:.0}%)",
            paired,
            (paired.value() - expected) / expected * 100.
        );
        assert_abs_diff_eq!(expected, single.value(), epsilon = 1e-10);
        assert_abs_diff_eq!(expected, paired.value(), epsilon = 1e-10);
        assert!((expected - naive).abs() > 500_000.);
    }

    /// Ten thousand additions of `0.1_f32`: the compensated sum matches the directly computed
    /// product exactly, while the naive loop does not.
    #[test]
    fn test_doctest() {
        let repetitions = 10_000;
        let mut naive = 0.0_f32;
        let mut sum = KahanSum::new(0.0_f32);
        for _ in 0..repetitions {
            sum += 0.1;
            naive += 0.1;
        }
        let target = repetitions as f32 * 0.1;
        assert_eq!(sum.value(), target);
        assert_ne!(naive, target);
    }
}