use std::ops::{Add, AddAssign};
#[cfg(doc)]
use embed_doc_image::embed_doc_image;
use num_traits::Float;
use super::i_fast_sum_in_place;
use super::traits::{IFastSum, SumAccumulator};
use crate::util::traits::{FloatFormat, RawExponent};
use crate::util::two_sum;
/// Floating-point sum accumulator that keeps a pair of partial sums per raw exponent.
///
/// Each value added through `AddAssign` is routed to the bin indexed by its raw
/// exponent. `a1[j]` holds the rounded `two_sum` running sum for that bin, and `a2[j]`
/// accumulates the rounding errors produced along the way. When the addition counter
/// reaches `F::base_pow_significand_digits_half()`, the bins are compacted. `sum()` passes
/// every non-zero bin value to `i_fast_sum_in_place`.
///
/// ![][OnlineExactSum]
///
/// NOTE(review): the name suggests Zhu & Hayes' "OnlineExactSum" (ACM TOMS Algorithm 908);
/// confirm and add a proper reference.
#[cfg_attr(doc, embed_doc_image("OnlineExactSum", "images/OnlineExactSum.svg"))]
#[derive(Clone, Debug)]
pub struct OnlineExactSum<F> {
    /// Addition counter compared against `F::base_pow_significand_digits_half()` to decide
    /// when to compact. Starts at 0; `compact` sets it to `2 * F::base_pow_exponent_digits()`.
    i: usize,
    /// Per-exponent running sums, `F::base_pow_exponent_digits()` entries long.
    a1: Box<[F]>,
    /// Per-exponent accumulated `two_sum` error terms, same length as `a1`.
    a2: Box<[F]>,
}
impl<F> OnlineExactSum<F>
where
    F: Float + FloatFormat + RawExponent,
{
    /// Allocates one zeroed bin per possible raw exponent value
    /// (`F::base_pow_exponent_digits()` entries).
    fn zeroed_bins() -> Box<[F]> {
        vec![F::zero(); F::base_pow_exponent_digits()].into_boxed_slice()
    }

    /// Creates an empty accumulator: every bin of `a1` and `a2` is zero and the
    /// addition counter starts at 0.
    fn new() -> Self {
        OnlineExactSum {
            i: 0,
            a1: Self::zeroed_bins(),
            a2: Self::zeroed_bins(),
        }
    }

    /// Re-bins every stored value into fresh bins.
    ///
    /// Each entry of `a1`, then each entry of `a2`, is added with `two_sum` into the
    /// fresh `a1` bin selected by that entry's own raw exponent. The rounding error goes
    /// into the matching fresh `a2` bin. The fresh bins then replace the old ones.
    ///
    /// `#[inline(never)]` keeps this rarely taken path out of the `#[inline]`
    /// `add_assign`. Because the path is cold, it uses checked indexing: a bounds
    /// check per entry is negligible here, and an out-of-range raw exponent panics
    /// instead of causing undefined behaviour.
    ///
    /// # Panics
    ///
    /// Panics if some stored value's `raw_exponent()` is not below
    /// `F::base_pow_exponent_digits()`.
    #[inline(never)]
    fn compact(&mut self) {
        let mut sums = Self::zeroed_bins();
        let mut errors = Self::zeroed_bins();
        for &y in self.a1.iter().chain(self.a2.iter()) {
            let j = y.raw_exponent();
            let (s, e) = two_sum(sums[j], y);
            sums[j] = s;
            // Written as `x = x + e` because `F` is not required to implement `AddAssign`.
            errors[j] = errors[j] + e;
        }
        self.a1 = sums;
        self.a2 = errors;
        // The fresh bins have just absorbed the 2 * N old entries (N = bins per
        // array), so the counter restarts at that count rather than at zero.
        self.i = 2 * F::base_pow_exponent_digits();
    }
}
impl<F> SumAccumulator<F> for OnlineExactSum<F>
where
    F: Float + IFastSum + FloatFormat + RawExponent,
{
    /// Returns an empty accumulator (all bins zero, counter at 0).
    fn zero() -> Self {
        Self::new()
    }

    /// Consumes the accumulator and returns `i_fast_sum_in_place` over its bin values.
    ///
    /// Values are gathered from `a1` first and then from `a2`. Entries equal to zero
    /// (which, by float comparison, includes `-0.0`) are skipped.
    #[inline]
    fn sum(self) -> F {
        let zero = F::zero();
        let mut terms: Vec<F> = self
            .a1
            .iter()
            .chain(self.a2.iter())
            .copied()
            .filter(|&x| x != zero)
            .collect();
        i_fast_sum_in_place(&mut terms[..])
    }
}
impl<F> Add<F> for OnlineExactSum<F>
where
    OnlineExactSum<F>: AddAssign<F>,
{
    type Output = Self;

    /// Returns the accumulator with `rhs` added; the work is done by the
    /// `AddAssign<F>` implementation.
    #[inline]
    fn add(self, rhs: F) -> Self::Output {
        let mut acc = self;
        acc.add_assign(rhs);
        acc
    }
}
impl<F> From<F> for OnlineExactSum<F>
where
    F: Float + FloatFormat + RawExponent,
{
    /// Builds a fresh accumulator that already contains the single value `x`.
    fn from(x: F) -> Self {
        let mut acc = Self::new();
        acc += x;
        acc
    }
}
impl<F> Add for OnlineExactSum<F>
where
F: Float + IFastSum + FloatFormat + RawExponent,
{
type Output = Self;
#[inline]
fn add(self, rhs: Self) -> Self::Output {
self.absorb(rhs.a1.iter().cloned().chain(rhs.a2.iter().cloned()))
}
}
// No manual `unsafe impl Send` for `OnlineExactSum<F>`. Its fields are a `usize` and two
// `Box<[F]>`, so the compiler's automatic `Send` impl already applies exactly when
// `F: Send`. Relying on the auto trait also means a future non-`Send` field can never be
// silently (and unsoundly) declared `Send`.
impl<F> AddAssign<F> for OnlineExactSum<F>
where
F: Float + FloatFormat + RawExponent,
{
#[inline]
fn add_assign(&mut self, rhs: F) {
{
let j = rhs.raw_exponent();
debug_assert_eq!(self.a1.len(), F::base_pow_exponent_digits());
debug_assert_eq!(self.a2.len(), F::base_pow_exponent_digits());
debug_assert!(j < F::base_pow_exponent_digits());
let a1 = unsafe { self.a1.get_unchecked_mut(j) };
let a2 = unsafe { self.a2.get_unchecked_mut(j) };
let (a, e) = two_sum(*a1, rhs);
*a1 = a;
*a2 = *a2 + e;
}
debug_assert!(self.i < F::base_pow_significand_digits_half());
debug_assert!(F::base_pow_significand_digits_half() < usize::MAX);
debug_assert!(self.i.checked_add(1).is_some());
self.i += 1;
if self.i >= F::base_pow_significand_digits_half() {
self.compact();
}
}
}