use num_traits::Float;
use crate::util::traits::{HalfUlp, Round3};
use crate::util::{round3, two_sum};
/// Bundle of the bounds required by the summation routines in this module.
///
/// A type qualifies when it is a [`Float`] that also implements the crate's
/// `HalfUlp` and `Round3` helper traits. The trait has no items of its own;
/// it only serves as a shorthand for that combination of bounds.
pub trait IFastSum: Float + HalfUlp + Round3 {}

/// Blanket implementation: every type satisfying the bounds is an `IFastSum`.
impl<F: Float + HalfUlp + Round3> IFastSum for F {}
/// Sums the elements of `xs`, using the slice itself as scratch space.
///
/// The work is delegated to `i_fast_sum_in_place_aux` with the refinement
/// (recursive) step enabled.
///
/// # Arguments
///
/// * `xs` - the terms to add. The slice is overwritten with intermediate
///   error terms while the sum is computed; its contents after the call
///   should be treated as unspecified.
///
/// # Returns
///
/// The computed sum, or zero when `xs` is empty.
pub fn i_fast_sum_in_place<F>(xs: &mut [F]) -> F
where
    F: IFastSum,
{
    // An empty sum is zero. Return it directly rather than running the
    // distillation on a slice that has no room for its residual term.
    if xs.is_empty() {
        return F::zero();
    }

    let mut n = xs.len();
    i_fast_sum_in_place_aux(xs, &mut n, true)
}
/// Distillation loop behind [`i_fast_sum_in_place`].
///
/// NOTE(review): the structure appears to follow the iFastSum algorithm of
/// Zhu and Hayes; confirm against the paper before relying on its guarantees.
///
/// # Arguments
///
/// * `xs` - working storage. The first `*n` elements are the terms to add;
///   they are overwritten with error terms produced by `two_sum`.
/// * `n` - in: number of live terms in `xs`. Out: number of live terms left
///   in `xs` after the last completed pass (compacted error terms plus one
///   residual). Left untouched on the early returns described below.
/// * `recurse` - when `true`, a result that fails the final rounding check is
///   refined by two further non-recursive passes over the leftover terms and
///   combined with `round3`. When `false`, the running sum is returned as soon
///   as the error estimate is small enough.
///
/// # Returns
///
/// The computed sum. Two cases return early:
///
/// * `*n == 0` returns zero.
/// * If the plain running sum of the terms is not finite (a NaN or infinite
///   term, or an intermediate overflow), that non-finite value is returned
///   as is, because the error terms are meaningless in that situation.
///
/// # Panics
///
/// Panics if `*n > xs.len()`, or if `F` cannot represent the number of
/// non-zero error terms.
fn i_fast_sum_in_place_aux<F>(xs: &mut [F], n: &mut usize, recurse: bool) -> F
where
    F: IFastSum,
{
    debug_assert!(*n <= xs.len());

    // With no terms there is nowhere to store the residual written at the end
    // of each pass, so handle the empty sum up front.
    if *n == 0 {
        return F::zero();
    }

    // First pass: accumulate the running sum in `s` and replace every term by
    // the error term that `two_sum` reports for its addition.
    let mut s = F::zero();
    for x in &mut xs[..*n] {
        let (sum, err) = two_sum(s, *x);
        s = sum;
        *x = err;
    }

    // A non-finite running sum means the error terms are NaN. Every one of
    // them would count as "non-zero" below, pushing `count` past the end of
    // the live region, and the termination test could never be trusted.
    if !Float::is_finite(s) {
        return s;
    }

    loop {
        // `count` is the number of non-zero error terms kept so far, `st` is
        // the running sum of the current terms, and `sm` is the largest
        // magnitude `st` has reached when an error term was kept.
        let mut count: usize = 0;
        let mut st = F::zero();
        let mut sm = F::zero();

        // Sum the current terms into `st`, compacting the non-zero error terms
        // to the front of the slice. `count <= i` holds throughout, so the
        // write never overtakes the read position.
        for i in 0..*n {
            let (sum, err) = two_sum(st, xs[i]);
            st = sum;
            if err != F::zero() {
                xs[count] = err;
                count += 1;
                sm = sm.max(Float::abs(st));
            }
        }

        // Error estimate for `st`: number of kept error terms times
        // `half_ulp` of the largest partial sum.
        let em = F::from(count).expect("count not representable as floating point number")
            * sm.half_ulp();

        // Fold `st` into the result and keep the residual of that addition as
        // one more term after the compacted error terms.
        let (sum, err) = two_sum(s, st);
        s = sum;
        st = err;
        // Checked on purpose: should every term have produced a non-zero error
        // (not expected for finite input, assuming `two_sum` is exact for
        // `0 + x`), this panics instead of writing past the end of the slice.
        xs[count] = st;
        *n = count + 1;

        // Stop once the estimate vanishes or drops below `half_ulp` of `s`.
        if (em == F::zero()) || (em < s.half_ulp()) {
            if !recurse {
                return s;
            }

            // Check that perturbing the residual by +/- `em` cannot change the
            // rounded result.
            let (w1, e1) = two_sum(st, em);
            let (w2, e2) = two_sum(st, -em);
            if (w1 + s != s)
                || (w2 + s != s)
                || (round3(s, w1, e1) != s)
                || (round3(s, w2, e2) != s)
            {
                // Not safe yet: sum the leftover terms twice more (without
                // further refinement) and round the three parts together.
                let mut s1 = i_fast_sum_in_place_aux(xs, n, false);
                let (sum, err) = two_sum(s, s1);
                s = sum;
                s1 = err;
                let s2 = i_fast_sum_in_place_aux(xs, n, false);
                s = round3(s, s1, s2);
            }
            return s;
        }
    }
}
#[cfg(test)]
mod test {
    use super::i_fast_sum_in_place;

    /// Regression test (named after issue 5).
    ///
    /// The exact sum 8388608.5 lies halfway between two adjacent `f32`
    /// values; the expected result is the even neighbour, 8388608.0.
    #[test]
    fn issue_5() {
        let mut terms: [f32; 2] = [4194304.0, 4194304.5];
        let total = i_fast_sum_in_place(&mut terms);
        assert_eq!(total, 8388608.0);
    }
}