use std::cell::RefCell;
use crate::{
scalar::{IndexType, Scalar},
NonLinearOp, Vector,
};
use anyhow::Result;
use num_traits::{abs, One, Zero};
pub struct RootFinder<V: Vector> {
t0: RefCell<V::T>,
g0: RefCell<V>,
g1: RefCell<V>,
gmid: RefCell<V>,
}
impl<V: Vector> RootFinder<V> {
pub fn new(n: usize) -> Self {
Self {
t0: RefCell::new(V::T::zero()),
g0: RefCell::new(V::zeros(n)),
g1: RefCell::new(V::zeros(n)),
gmid: RefCell::new(V::zeros(n)),
}
}
pub fn init(&self, root_fn: &impl NonLinearOp<V = V, T = V::T>, y: &V, t: V::T) {
root_fn.call_inplace(y, t, &mut self.g0.borrow_mut());
self.t0.replace(t);
}
pub fn check_root(
&self,
interpolate: &impl Fn(V::T) -> Result<V>,
root_fn: &impl NonLinearOp<V = V, T = V::T>,
y: &V,
t: V::T,
) -> Option<V::T> {
let g1 = &mut *self.g1.borrow_mut();
let g0 = &mut *self.g0.borrow_mut();
let gmid = &mut *self.gmid.borrow_mut();
root_fn.call_inplace(y, t, g1);
let sign_change_fn = |mut acc: (bool, V::T, i32), g0: V::T, g1: V::T, i: IndexType| {
if g1 == V::T::zero() {
acc.0 = true;
} else if g0 * g1 < V::T::zero() {
let gfrac = abs(g1 / (g1 - g0));
if gfrac > acc.1 {
acc.1 = gfrac;
acc.2 = i32::try_from(i).unwrap();
}
}
acc
};
let (rootfnd, _gfracmax, imax) =
(*g0).binary_fold(g1, (false, V::T::zero(), -1), sign_change_fn);
if imax < 0 {
std::mem::swap(g0, g1);
self.t0.replace(t);
return if rootfnd {
Some(t)
} else {
None
};
}
let mut imax = IndexType::try_from(imax).unwrap();
let mut alpha = V::T::one();
let mut sign_change = [false, true];
let mut i = 0;
let mut t1 = t;
let mut t0 = *self.t0.borrow();
let tol = V::T::from(100.0) * V::T::EPSILON * (abs(t1) + abs(t1 - t0));
let half = V::T::from(0.5);
let double = V::T::from(2.0);
let five = V::T::from(5.0);
let pntone = V::T::from(0.1);
while abs(t1 - t0) > tol {
let mut t_mid = t1 - (t1 - t0) * g1[imax] / (g1[imax] - alpha * g0[imax]);
if abs(t_mid - t0) < half * tol {
let fracint = abs(t1 - t0) / tol;
let fracsub = if fracint > five {
pntone
} else {
half / fracint
};
t_mid = t0 + fracsub * (t1 - t0);
}
if abs(t1 - t_mid) < half * tol {
let fracint = abs(t1 - t0) / tol;
let fracsub = if fracint > five {
pntone
} else {
half / fracint
};
t_mid = t1 - fracsub * (t1 - t0);
}
let ymid = interpolate(t_mid).unwrap();
root_fn.call_inplace(&ymid, t_mid, gmid);
let (rootfnd, _gfracmax, imax_i32) =
(*g0).binary_fold(gmid, (false, V::T::zero(), -1), sign_change_fn);
let lower = imax_i32 >= 0;
if lower {
t1 = t_mid;
imax = IndexType::try_from(imax_i32).unwrap();
std::mem::swap(g1, gmid);
} else if rootfnd {
root_fn.call_inplace(y, t, g0);
return Some(t_mid);
} else {
t0 = t_mid;
std::mem::swap(g0, gmid);
}
sign_change[i % 2] = lower;
if i >= 2 {
alpha = if sign_change[0] != sign_change[1] {
V::T::one()
} else if sign_change[0] {
half * alpha
} else {
double * alpha
};
}
i += 1;
}
root_fn.call_inplace(y, t, g0);
Some(t1)
}
}
#[cfg(test)]
mod tests {
use std::rc::Rc;
use crate::{ClosureNoJac, RootFinder, Vector};
use anyhow::Result;
#[test]
fn test_root() {
type V = nalgebra::DVector<f64>;
type M = nalgebra::DMatrix<f64>;
let interpolate = |t: f64| -> Result<V> { Ok(Vector::from_vec(vec![t])) };
let root_fn = ClosureNoJac::<M, _>::new(
|y: &V, _p: &V, _t: f64, g: &mut V| {
g[0] = y[0] - 0.4;
},
1,
1,
Rc::new(V::zeros(0)),
);
let root_finder = RootFinder::new(1);
root_finder.init(&root_fn, &Vector::from_vec(vec![0.0]), 0.0);
let root =
root_finder.check_root(&interpolate, &root_fn, &Vector::from_vec(vec![0.3]), 0.3);
assert_eq!(root, None);
let root_finder = RootFinder::new(1);
root_finder.init(&root_fn, &Vector::from_vec(vec![0.0]), 0.0);
let root =
root_finder.check_root(&interpolate, &root_fn, &Vector::from_vec(vec![1.3]), 1.3);
if let Some(root) = root {
assert!((root - 0.4).abs() < 1e-10);
} else {
unreachable!();
}
}
}