use crate::error::{CoreError, ErrorContext};
/// Sums `values` with Kahan's compensated summation.
///
/// A running compensation term `c` holds the low-order bits lost when each
/// element is added to the running total; it is subtracted from the next
/// element before that element is added.
///
/// If the compensated total ends up non-finite (an infinity in the input, an
/// intermediate overflow, or a NaN), the compensation carries no information:
/// `inf - inf` turns it into NaN and poisons every later step. In that case the
/// plain left-to-right sum is returned instead, matching ordinary IEEE 754
/// addition (so `[inf, 1.0]` sums to `inf`, not NaN).
///
/// # Errors
///
/// Returns [`CoreError::InvalidInput`] if `values` is empty.
pub fn kahan_sum(values: &[f64]) -> Result<f64, CoreError> {
    if values.is_empty() {
        return Err(CoreError::InvalidInput(ErrorContext::new(
            "kahan_sum: empty slice".to_string(),
        )));
    }
    let mut sum = 0.0_f64;
    let mut c = 0.0_f64;
    for &v in values {
        let y = v - c;
        let t = sum + y;
        // (t - sum) is the part of y that was actually added; the difference
        // from y is what rounding discarded.
        c = (t - sum) - y;
        sum = t;
    }
    if sum.is_finite() {
        Ok(sum)
    } else {
        // Compensation is meaningless once infinities/NaN appear; defer to the
        // IEEE semantics of naive summation.
        Ok(values.iter().sum::<f64>())
    }
}
/// Sums `values` with Neumaier's improved Kahan–Babuška summation.
///
/// Unlike plain Kahan summation, each step's rounding error is recovered
/// relative to whichever operand has the larger magnitude. This makes elements
/// larger than the running total harmless; for example, `[1.0, 1e100, 1.0, -1e100]`
/// sums to `2.0`. The accumulated compensation is added back once at the end.
///
/// `sum` is exactly the naive left-to-right sum. If it becomes non-finite (an
/// infinity in the input, an overflow, or a NaN), the compensation would be
/// poisoned by `inf - inf = NaN`. The naive IEEE 754 result is therefore
/// returned as-is, so `[1.0, inf]` sums to `inf` rather than NaN.
///
/// # Errors
///
/// Returns [`CoreError::InvalidInput`] if `values` is empty.
pub fn neumaier_sum(values: &[f64]) -> Result<f64, CoreError> {
    if values.is_empty() {
        return Err(CoreError::InvalidInput(ErrorContext::new(
            "neumaier_sum: empty slice".to_string(),
        )));
    }
    let mut sum = 0.0_f64;
    let mut c = 0.0_f64;
    for &v in values {
        let t = sum + v;
        // Recover the low-order bits of the smaller-magnitude operand.
        if sum.abs() >= v.abs() {
            c += (sum - t) + v;
        } else {
            c += (v - t) + sum;
        }
        sum = t;
    }
    if !sum.is_finite() {
        // Once non-finite, `sum` never returns to finite and `c` is garbage.
        return Ok(sum);
    }
    Ok(sum + c)
}
/// Sums `values` with the cascaded summation of Ogita, Rump and Oishi ("Sum2").
///
/// Every addition to the running sum `s` is performed with [`two_sum`], which
/// also yields the exact rounding error of that addition. The rounding errors
/// are then summed (here with [`neumaier_sum`]) and added to `s` as a single
/// correction.
///
/// `s` is exactly the naive left-to-right sum. If it becomes non-finite (an
/// infinity in the input, an overflow, or a NaN), the error terms degenerate to
/// NaN (`inf - inf`). The naive IEEE 754 result is returned instead, so
/// `[1.0, inf]` sums to `inf`.
///
/// # Errors
///
/// Returns [`CoreError::InvalidInput`] if `values` is empty.
pub fn ogita_sum(values: &[f64]) -> Result<f64, CoreError> {
    if values.is_empty() {
        return Err(CoreError::InvalidInput(ErrorContext::new(
            "ogita_sum: empty slice".to_string(),
        )));
    }
    let mut s = 0.0_f64;
    // Only the per-step rounding errors need to be stored, not a copy of the input.
    let mut errors = Vec::with_capacity(values.len());
    for &v in values {
        let (sigma, q) = two_sum(s, v);
        s = sigma;
        errors.push(q);
    }
    if !s.is_finite() {
        return Ok(s);
    }
    // `errors` has one entry per input element, so it is non-empty here.
    let err_sum = neumaier_sum(&errors)?;
    Ok(s + err_sum)
}
/// Computes the dot product of `a` and `b` with compensated accumulation
/// (the "Dot2" scheme of Ogita, Rump and Oishi).
///
/// Each product `a[i] * b[i]` is turned into its rounded value and exact
/// rounding error with [`two_product`]. The rounded products are accumulated
/// into `p` with [`two_sum`]. Every rounding error, from both the products and
/// the additions, is collected in `s` and added back once at the end.
///
/// `p` is the plain left-to-right sum of the rounded products. If `p` or the
/// error accumulator `s` is non-finite (infinite/NaN input, or an intermediate
/// overflow), the error terms carry no usable information (`inf - inf = NaN`).
/// In that case `p` is returned unchanged instead of NaN.
///
/// # Errors
///
/// Returns [`CoreError::InvalidInput`] if the slices differ in length or are
/// empty.
pub fn dot_product_compensated(a: &[f64], b: &[f64]) -> Result<f64, CoreError> {
    if a.len() != b.len() {
        return Err(CoreError::InvalidInput(ErrorContext::new(format!(
            "dot_product_compensated: slice length mismatch ({} vs {})",
            a.len(),
            b.len()
        ))));
    }
    if a.is_empty() {
        return Err(CoreError::InvalidInput(ErrorContext::new(
            "dot_product_compensated: empty slices".to_string(),
        )));
    }
    let mut p = 0.0_f64;
    let mut s = 0.0_f64;
    for (&ai, &bi) in a.iter().zip(b.iter()) {
        let (h, l) = two_product(ai, bi);
        let (p_new, q) = two_sum(p, h);
        p = p_new;
        s += l + q;
    }
    if !(p.is_finite() && s.is_finite()) {
        // Compensation is unusable; fall back to the uncompensated IEEE result.
        return Ok(p);
    }
    Ok(p + s)
}
/// Error-free transformation of a sum (Knuth's branch-free TwoSum).
///
/// Returns `(s, e)` where `s` is `a + b` rounded to `f64` and `e` is the
/// rounding error of that addition. Barring overflow, `s + e` equals `a + b`
/// exactly in real arithmetic. Unlike Fast2Sum, no ordering of `|a|` and `|b|`
/// is required.
#[inline]
pub fn two_sum(a: f64, b: f64) -> (f64, f64) {
    let rounded = a + b;
    // Reconstruct how much of each operand actually survived into `rounded`.
    let a_kept = rounded - b;
    let b_kept = rounded - a_kept;
    // Whatever each operand lost to rounding; together they form the exact error.
    let err = (a - a_kept) + (b - b_kept);
    (rounded, err)
}
/// Error-free transformation of a product (Dekker's TwoProduct).
///
/// Returns `(p, e)` where `p` is `a * b` rounded to `f64` and `e` is the
/// rounding error. When no overflow or underflow occurs, `p + e` equals `a * b`
/// exactly.
///
/// Both factors are split with [`split_f64`] into halves narrow enough that
/// their pairwise products are exact in `f64`. The error is then rebuilt by
/// subtracting `p` from the largest partial product and adding the remaining
/// partial products from largest to smallest.
#[inline]
pub fn two_product(a: f64, b: f64) -> (f64, f64) {
    let (ah, al) = split_f64(a);
    let (bh, bl) = split_f64(b);
    let prod = a * b;
    let high_residual = ah * bh - prod;
    let with_cross_terms = high_residual + ah * bl + al * bh;
    let err = with_cross_terms + al * bl;
    (prod, err)
}
/// Splits `a` into `(hi, lo)` with `hi + lo == a` exactly, using Veltkamp/Dekker
/// splitting with the splitter `2^27 + 1`.
///
/// Each half has few enough significant bits that products of halves are exact
/// in `f64`, which is what [`two_product`] relies on.
///
/// `SPLITTER * a` overflows when `|a|` exceeds roughly `f64::MAX / 2^27`. The
/// unscaled algorithm then produced `inf - inf = NaN`. Above `2^996` the input
/// is therefore scaled down by `2^-28`, split, and scaled back up; power-of-two
/// scaling is exact, so the split is unchanged. Inputs extremely close to
/// `f64::MAX` can still round the high part up to `2^1024` and overflow.
#[inline]
fn split_f64(a: f64) -> (f64, f64) {
    // 2^27 + 1
    const SPLITTER: f64 = 134_217_729.0;
    // Approximately 2^996: SPLITTER * a stays finite for |a| up to here.
    const SPLIT_THRESHOLD: f64 = 6.696_928_794_914_17e299;
    const TWO_POW_28: f64 = 268_435_456.0;
    const TWO_POW_NEG_28: f64 = 1.0 / TWO_POW_28;

    // Plain Veltkamp split; valid whenever SPLITTER * x does not overflow.
    #[inline]
    fn veltkamp(x: f64) -> (f64, f64) {
        let c = SPLITTER * x;
        let hi = c - (c - x);
        let lo = x - hi;
        (hi, lo)
    }

    if a.abs() > SPLIT_THRESHOLD {
        let (hi, lo) = veltkamp(a * TWO_POW_NEG_28);
        (hi * TWO_POW_28, lo * TWO_POW_28)
    } else {
        veltkamp(a)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Classic cancellation input: naive left-to-right summation returns 0.0,
    /// but the exact sum is 2.0.
    const CANCELLING: [f64; 4] = [1.0, 1e100, 1.0, -1e100];

    #[test]
    fn kahan_sum_basic() {
        let data = [1.0_f64, 2.0, 3.0, 4.0];
        let s = kahan_sum(&data).expect("valid");
        assert_eq!(s, 10.0);
    }

    #[test]
    fn kahan_sum_empty_error() {
        assert!(kahan_sum(&[]).is_err());
    }

    #[test]
    fn neumaier_sum_basic() {
        let data = [1.0_f64, 2.0, 3.0];
        let s = neumaier_sum(&data).expect("valid");
        assert_eq!(s, 6.0);
    }

    #[test]
    fn neumaier_sum_large_and_small() {
        let data = [1e15_f64, 1.0, -1e15_f64];
        let s = neumaier_sum(&data).expect("valid");
        assert_eq!(s, 1.0);
    }

    #[test]
    fn neumaier_sum_recovers_cancelled_terms() {
        // The two 1.0 terms are absorbed by 1e100 in naive summation.
        let s = neumaier_sum(&CANCELLING).expect("valid");
        assert_eq!(s, 2.0);
    }

    #[test]
    fn neumaier_sum_empty_error() {
        assert!(neumaier_sum(&[]).is_err());
    }

    #[test]
    fn ogita_sum_basic() {
        let data = [1.0_f64, 2.0, 3.0];
        let s = ogita_sum(&data).expect("valid");
        assert_eq!(s, 6.0);
    }

    #[test]
    fn ogita_sum_single() {
        let data = [42.0_f64];
        let s = ogita_sum(&data).expect("valid");
        assert_eq!(s, 42.0);
    }

    #[test]
    fn ogita_sum_recovers_cancelled_terms() {
        let s = ogita_sum(&CANCELLING).expect("valid");
        assert_eq!(s, 2.0);
    }

    #[test]
    fn ogita_sum_empty_error() {
        assert!(ogita_sum(&[]).is_err());
    }

    #[test]
    fn dot_product_exact_integers() {
        let a = [1.0_f64, 2.0, 3.0];
        let b = [4.0_f64, 5.0, 6.0];
        let d = dot_product_compensated(&a, &b).expect("valid");
        assert_eq!(d, 32.0);
    }

    #[test]
    fn dot_product_recovers_cancelled_terms() {
        // Naive accumulation gives 1e100 + 1.0 == 1e100, then 0.0.
        let a = [1e100_f64, 1.0, -1e100];
        let b = [1.0_f64, 1.0, 1.0];
        let d = dot_product_compensated(&a, &b).expect("valid");
        assert_eq!(d, 1.0);
    }

    #[test]
    fn dot_product_length_mismatch() {
        let a = [1.0_f64, 2.0];
        let b = [3.0_f64];
        assert!(dot_product_compensated(&a, &b).is_err());
    }

    #[test]
    fn dot_product_empty_error() {
        let empty: [f64; 0] = [];
        assert!(dot_product_compensated(&empty, &empty).is_err());
    }

    #[test]
    fn two_sum_exact_reconstruction() {
        // 1 + 2^-53 rounds (ties-to-even) to 1.0; the lost 2^-53 must land in `e`.
        let a = 1.0_f64;
        let b = f64::EPSILON / 2.0;
        let (s, e) = two_sum(a, b);
        assert_eq!(s, 1.0);
        assert_eq!(e, b);
    }

    #[test]
    fn two_product_integer_exact() {
        let (p, e) = two_product(3.0_f64, 7.0_f64);
        assert_eq!(p, 21.0);
        assert_eq!(e, 0.0);
    }

    #[test]
    fn two_product_error_reconstruct() {
        // fl(1/3) * 3 == 1 - 2^-54 exactly, which rounds to 1.0; the error is -2^-54.
        let a = 1.0_f64 / 3.0;
        let b = 3.0_f64;
        let (p, e) = two_product(a, b);
        assert_eq!(p, 1.0);
        assert_eq!(e, -f64::EPSILON / 4.0);
    }

    #[test]
    fn split_roundtrip() {
        let x = 1.23456789_f64;
        let (hi, lo) = super::split_f64(x);
        assert_eq!(hi + lo, x);
    }
}