use num_traits::Float;
/// Convergence parameters, generic over the floating-point type `F`.
///
/// `Default` is implemented for `f64` and `f32`; the `f32` defaults use looser
/// tolerances than the `f64` ones. How each tolerance is interpreted (absolute
/// vs. relative, which norm, etc.) is up to the solver that consumes these
/// values — NOTE(review): confirm against callers.
///
/// Every field is a plain scalar, so the type is `Copy` whenever `F` is and can
/// be compared with `==`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ConvergenceParams<F> {
    /// Iteration cap.
    pub max_iter: usize,
    /// Gradient tolerance.
    pub grad_tol: F,
    /// Step tolerance.
    pub step_tol: F,
    /// Function-value tolerance. Both provided defaults set this to `0.0`;
    /// presumably that disables the criterion — confirm against the solver.
    pub func_tol: F,
}

/// Double-precision defaults: 100 iterations, `grad_tol = 1e-8`,
/// `step_tol = 1e-12`, `func_tol = 0.0`.
impl Default for ConvergenceParams<f64> {
    fn default() -> Self {
        Self {
            max_iter: 100,
            grad_tol: 1e-8,
            step_tol: 1e-12,
            func_tol: 0.0,
        }
    }
}

/// Single-precision defaults: 100 iterations, `grad_tol = 1e-5`,
/// `step_tol = 1e-7`, `func_tol = 0.0` (looser than `f64` to match the
/// narrower precision).
impl Default for ConvergenceParams<f32> {
    fn default() -> Self {
        Self {
            max_iter: 100,
            grad_tol: 1e-5,
            step_tol: 1e-7,
            func_tol: 0.0,
        }
    }
}
/// Euclidean (L2) norm of `v`, i.e. `sqrt(Σ v[i]²)`.
///
/// The squares are accumulated with [`kahan_sum`]. Squaring overflows once an
/// element's magnitude exceeds about `sqrt(F::max_value())` (≈1.3e154 for
/// `f64`, ≈1.8e19 for `f32`). It underflows below about
/// `sqrt(F::min_positive_value())`. When the direct sum of squares is
/// non-finite or below the normal range, the norm is recomputed from the
/// elements divided by the largest magnitude. For example,
/// `norm(&[1e200_f64])` returns `1e200` instead of infinity. Inputs that need
/// no rescaling get exactly the same result as `kahan_sum(x * x).sqrt()`.
///
/// Returns `0` for an empty slice, `NaN` if any element is `NaN`, and `+inf`
/// if any element is infinite (and none is `NaN`).
pub fn norm<F: Float>(v: &[F]) -> F {
    let sum_sq = kahan_sum(v.iter().map(|&x| x * x));
    // A finite sum of squares in the normal range cannot have overflowed.
    // Any individual squares that underflowed only contribute rounding-level
    // error at this magnitude, so the direct result is kept as is.
    if sum_sq.is_finite() && sum_sq >= F::min_positive_value() {
        return sum_sq.sqrt();
    }
    norm_rescaled(v)
}

/// Overflow/underflow-safe fallback for [`norm`]: divides every element by the
/// largest magnitude before squaring, then multiplies the root back up.
fn norm_rescaled<F: Float>(v: &[F]) -> F {
    // `Float::max` returns the non-NaN operand, so NaN must be detected here
    // or it would be silently dropped from `scale`.
    if v.iter().any(|&x| x.is_nan()) {
        return F::nan();
    }
    let scale = v.iter().fold(F::zero(), |m, &x| m.max(x.abs()));
    // Empty or all-zero input → 0; any infinite element → +inf.
    if scale == F::zero() || scale.is_infinite() {
        return scale;
    }
    // Every scaled square is ≤ 1 and the largest is exactly 1, so this sum
    // can neither overflow nor vanish.
    let sum_sq = kahan_sum(v.iter().map(|&x| {
        let r = x / scale;
        r * r
    }));
    sum_sq.sqrt() * scale
}
/// Inner product `Σ a[i] * b[i]`, accumulated with [`kahan_sum`].
///
/// # Panics
///
/// In debug builds, panics if `a.len() != b.len()`. In release builds the
/// length check is compiled out. The pairwise `zip` then stops at the shorter
/// slice, and the trailing elements of the longer one are ignored.
pub fn dot<F: Float>(a: &[F], b: &[F]) -> F {
    debug_assert_eq!(a.len(), b.len());
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    kahan_sum(products)
}
/// Number of leading terms [`kahan_sum`] buffers before choosing a strategy.
/// Inputs with fewer terms get a plain sum; longer inputs get compensated
/// summation.
const KAHAN_THRESHOLD: usize = 64;

/// Sums the items of `iter`.
///
/// Up to [`KAHAN_THRESHOLD`] items are first buffered in a fixed-size local
/// array. If the iterator ends before the buffer fills, the buffered items are
/// added with a plain left-to-right sum.
///
/// Otherwise every item is accumulated with Neumaier's variant of Kahan
/// summation: the buffered items first, then the rest of the iterator. This
/// variant tracks the rounding error of each addition in a separate
/// compensation term and adds it back at the end.
///
/// If the running sum becomes infinite or NaN, it is returned unchanged.
/// Otherwise the compensation term (`inf - inf`) would be NaN. An overflowing
/// or infinite sum of `KAHAN_THRESHOLD` or more terms would then come out as
/// NaN, while the plain-sum path returns infinity for the same values.
fn kahan_sum<F: Float, I: Iterator<Item = F>>(iter: I) -> F {
    let mut rest = iter;
    let mut prefix = [F::zero(); KAHAN_THRESHOLD];
    let mut len = 0usize;
    // `zip` polls `prefix` first, so once the buffer is full no extra item is
    // pulled out of `rest`.
    for (slot, x) in prefix.iter_mut().zip(rest.by_ref()) {
        *slot = x;
        len += 1;
    }

    if len < KAHAN_THRESHOLD {
        return prefix[..len].iter().fold(F::zero(), |acc, &x| acc + x);
    }

    let mut sum = F::zero();
    let mut comp = F::zero();
    for x in prefix.iter().copied().chain(rest) {
        let t = sum + x;
        // Recover the low-order bits lost in `sum + x`, subtracting from the
        // larger-magnitude operand (Neumaier's rule).
        comp = comp
            + if sum.abs() >= x.abs() {
                (sum - t) + x
            } else {
                (x - t) + sum
            };
        sum = t;
    }

    if sum.is_finite() {
        sum + comp
    } else {
        sum
    }
}