extern crate num_traits;
use num_traits::float::Float;
use std::borrow::Borrow;
use std::mem::swap;
use std::ops::{Add, AddAssign};
/// Accumulator for compensated (Kahan-style) floating-point summation.
///
/// `sum` is the rounded running total. `err` is the rounding error of the most
/// recent addition (computed result minus exact result). It is subtracted out
/// when the next value is added, so low-order bits lost to rounding are
/// carried forward instead of discarded.
///
/// The type is `Copy` because both fields are `Float` (and therefore `Copy`).
/// Without it, the by-value `+` operators would consume their operands.
#[derive(Debug, Clone, Copy)]
pub struct KahanSum<T: Float> {
    /// Rounded running total.
    sum: T,
    /// Pending compensation (rounding error of the last addition).
    err: T,
}
impl<T: Float> Default for KahanSum<T> {
    /// Returns an empty accumulator: zero running sum and zero compensation.
    fn default() -> Self {
        let zero = T::zero();
        Self {
            sum: zero,
            err: zero,
        }
    }
}
impl<T: Float> KahanSum<T> {
    /// Creates an empty accumulator (equivalent to `KahanSum::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates an accumulator whose running sum starts at `initial`.
    ///
    /// The compensation term starts at zero.
    pub fn new_with_value(initial: T) -> Self {
        let no_error = T::zero();
        Self {
            err: no_error,
            sum: initial,
        }
    }

    /// Returns the current rounded running sum.
    pub fn sum(&self) -> T {
        let total = self.sum;
        total
    }

    /// Returns the pending compensation term.
    ///
    /// This is the rounding error of the last addition; it is subtracted out
    /// on the next addition.
    pub fn err(&self) -> T {
        let pending = self.err;
        pending
    }
}
impl<T: Float> AddAssign<T> for KahanSum<T> {
    /// Adds `rhs` to the running sum, carrying the rounding error forward.
    ///
    /// The larger-magnitude operand is kept in `self.sum`, so that the
    /// error-recovery step `(sum - self.sum) - y` subtracts from the big value.
    ///
    /// When the new running sum is not finite (adding `±inf`, overflow, or
    /// NaN), the compensation term has no meaning and is reset to zero.
    /// Otherwise `inf - inf` would store NaN in `err`, and every later
    /// addition would turn the sum into NaN. For example, `[inf, 1.0]` would
    /// sum to NaN instead of `inf`.
    fn add_assign(&mut self, mut rhs: T) {
        if self.sum.abs() < rhs.abs() {
            swap(&mut self.sum, &mut rhs);
        }
        let y = rhs - self.err;
        let sum = self.sum + y;
        let err = (sum - self.sum) - y;
        self.sum = sum;
        self.err = if sum.is_finite() { err } else { T::zero() };
    }
}
impl<T: Float> Add<T> for KahanSum<T> {
    type Output = Self;

    /// Consumes the accumulator, folds `rhs` into it via `+=`, and returns it.
    fn add(mut self, rhs: T) -> Self::Output {
        self += rhs;
        self
    }
}
impl<T: Float> AddAssign<&KahanSum<T>> for KahanSum<T> {
    /// Merges another accumulator into this one.
    ///
    /// The merge proceeds in three steps:
    /// - The partial sum with the larger magnitude is the base.
    /// - The other partial sum, minus both pending compensation terms, is
    ///   added to it.
    /// - The rounding error of that addition becomes the new compensation.
    ///
    /// Only the two partial sums need ordering. The compensation terms are
    /// combined with a commutative `+`, so no clone or swap of whole
    /// accumulators is required.
    ///
    /// If the merged sum is not finite (either side is `±inf`, the merge
    /// overflows, or NaN is involved), the compensation is reset to zero.
    /// Otherwise `inf - inf` would store NaN and poison all later additions.
    fn add_assign(&mut self, rhs: &KahanSum<T>) {
        let (big, small) = if self.sum.abs() < rhs.sum.abs() {
            (rhs.sum, self.sum)
        } else {
            (self.sum, rhs.sum)
        };
        let combined_errors = rhs.err + self.err;
        let y = small - combined_errors;
        let sum = big + y;
        let err = (sum - big) - y;
        self.sum = sum;
        self.err = if sum.is_finite() { err } else { T::zero() };
    }
}
impl<T: Float> Add<&KahanSum<T>> for KahanSum<T> {
    type Output = Self;

    /// Consumes the left accumulator, merges the borrowed one into it, and
    /// returns the merged accumulator.
    fn add(mut self, rhs: &KahanSum<T>) -> Self::Output {
        self += rhs;
        self
    }
}
impl<T: Float> AddAssign<KahanSum<T>> for KahanSum<T> {
    /// Merges an owned accumulator by delegating to the by-reference merge.
    fn add_assign(&mut self, rhs: KahanSum<T>) {
        let borrowed: &KahanSum<T> = &rhs;
        AddAssign::add_assign(self, borrowed);
    }
}
impl<T: Float> Add<KahanSum<T>> for KahanSum<T> {
    type Output = Self;

    /// Consumes both accumulators and returns their merge.
    fn add(mut self, rhs: KahanSum<T>) -> Self::Output {
        self += &rhs;
        self
    }
}
/// Extension trait for summing a sequence with compensated (Kahan-style)
/// summation.
pub trait KahanSummator<T>
where
    T: Float,
{
    /// Consumes `self` and returns the resulting [`KahanSum`] accumulator.
    fn kahan_sum(self) -> KahanSum<T>;
}
impl<T, U, V> KahanSummator<T> for U
where
    U: Iterator<Item = V>,
    V: Borrow<T>,
    T: Float,
{
    /// Feeds every item, in iteration order, into a fresh accumulator.
    ///
    /// Items may be owned values or anything that borrows as `T`, such as the
    /// `&T` items produced by `slice::iter`.
    fn kahan_sum(self) -> KahanSum<T> {
        let mut total: KahanSum<T> = KahanSum::new();
        for item in self {
            let value: &T = item.borrow();
            total += *value;
        }
        total
    }
}
#[cfg(test)]
// The literals below are arbitrary test data that only resemble PI and E;
// clippy's deny-by-default `approx_constant` lint would otherwise reject them.
#[allow(clippy::approx_constant)]
mod tests {
    // `super::` resolves on every edition and from any module depth, unlike
    // the crate-root-relative `use KahanSum;` form (edition 2015, crate root only).
    use super::{KahanSum, KahanSummator};

    /// Sums one large value followed by several small ones and checks the
    /// exactly expected rounded `f32` result.
    #[test]
    fn it_works() {
        let summands = [
            10000.0f32, 3.14159f32, 2.71828f32, 3.14159f32, 2.71828f32, 3.14159f32, 2.71828f32,
        ];
        assert_eq!(10017.58f32, summands.iter().kahan_sum().sum());
    }

    /// Summing two sequences separately and merging the accumulators must
    /// give the same rounded sum as summing their concatenation in one pass.
    #[test]
    fn associativity_holds() {
        let summands = [
            123.14159f32,
            -2.71828f32,
            -3.14159f32,
            2.71828f32,
            3.14159f32,
            2.71828f32,
        ];
        let more_summands = [-3.14159f32, 2.71828f32, 3.14159f32, 2.71828f32, 3.14159f32];
        let first: KahanSum<f32> = summands.iter().kahan_sum();
        let second: KahanSum<f32> = more_summands.iter().kahan_sum();
        let summed = first + second;
        let proper: KahanSum<f32> = summands.iter().chain(more_summands.iter()).kahan_sum();
        assert_eq!(summed.sum(), proper.sum());
    }
}