/// Search interval and convergence tolerances for the bracketing zero finder.
///
/// Convergence is declared once half the width of the current bracket is at
/// most `0.5 * max(abstol, reltol * |x|)`, where `x` is the current iterate.
#[derive(Clone, Copy, Debug)]
pub(crate) struct ZrorConfig {
    /// One end of the search interval; the function is evaluated here first.
    pub xlo: f64,
    /// Other end of the search interval; evaluated second.
    pub xhi: f64,
    /// Absolute part of the convergence tolerance.
    pub abstol: f64,
    /// Relative part of the convergence tolerance (scaled by `|x|`).
    pub reltol: f64,
}
/// Resume point of the reverse-communication state machine: which function
/// value the next call to `ZrorState::step` is expected to deliver.
#[derive(Clone, Copy, Debug)]
enum Stage {
    /// Nothing requested yet; the incoming value is ignored.
    Start,
    /// Expecting the function value at `cfg.xlo`.
    AwaitFb,
    /// Expecting the function value at `cfg.xhi`.
    AwaitFa,
    /// Expecting the function value at a point proposed by a refinement step.
    AwaitFbStep,
}
/// Instruction returned to the caller after each `ZrorState::step`.
#[derive(Clone, Copy, Debug)]
pub(crate) enum ZrorAction {
    /// Evaluate the function at this abscissa and pass the value to the next `step`.
    NeedEval(f64),
    /// The bracket shrank below tolerance while still straddling a sign change.
    Converged {
        /// Final iterate.
        xlo: f64,
        /// Other end of the final bracket.
        // Not read by current callers; kept so the full bracket is reported.
        #[allow(dead_code)]
        xhi: f64,
    },
    /// No sign change could be bracketed or kept.
    ///
    /// When both initial values share a strict sign, `qhi` is `true` for the
    /// both-positive case and `qleft` is `true` when `|f(cfg.xlo)| < |f(cfg.xhi)|`.
    /// When the converged bracket no longer straddles zero, both flags are `false`.
    // Fields are diagnostic payload that current callers do not all read.
    #[allow(dead_code)]
    Failed { xlo: f64, qleft: bool, qhi: bool },
}
/// Resumable state of the reverse-communication zero finder.
///
/// The caller feeds function values in through `step`; everything needed to
/// resume the search between calls is stored here.
#[derive(Debug)]
pub(crate) struct ZrorState {
    /// Interval and tolerances supplied at construction.
    cfg: ZrorConfig,
    /// Which function value the next `step` call delivers.
    stage: Stage,
    /// Abscissa reported to the caller; tracks the current iterate once refinement starts.
    xlo: f64,
    /// Starts as `cfg.xhi`; overwritten with the contrapoint `c` at termination.
    xhi: f64,
    /// Previous iterate.
    a: f64,
    /// Current iterate; kept as whichever of `b`/`c` has the smaller |f|.
    b: f64,
    /// Contrapoint: reset to `a` whenever `f(b)` and `f(c)` stop having opposite signs.
    c: f64,
    /// Iterate before `a`; third point for the interpolation formula.
    d: f64,
    /// f(a).
    fa: f64,
    /// f(b).
    fb: f64,
    /// f(c).
    fc: f64,
    /// f(d).
    fd: f64,
    /// Last step taken: `b` moved by this amount.
    w: f64,
    /// Signed distance from `b` to the bracket midpoint at the last refinement.
    mb: f64,
    /// Consecutive non-bisection steps since the last bisection or bracket restart.
    ext: i32,
    /// When set, the next interpolation uses the two-point (secant) formula.
    first: bool,
}
impl ZrorState {
    /// Creates a solver in its initial state for the interval and tolerances in `cfg`.
    ///
    /// Nothing has been evaluated yet: the first call to [`ZrorState::step`]
    /// ignores its argument and requests `f(cfg.xlo)`.
    #[inline]
    pub(crate) fn new(cfg: ZrorConfig) -> Self {
        Self {
            cfg,
            stage: Stage::Start,
            xlo: 0.0,
            xhi: 0.0,
            a: 0.0,
            b: 0.0,
            c: 0.0,
            d: 0.0,
            fa: 0.0,
            fb: 0.0,
            fc: 0.0,
            fd: 0.0,
            w: 0.0,
            mb: 0.0,
            ext: 0,
            first: true,
        }
    }

    /// Feeds one function value into the search and returns what to do next.
    ///
    /// Reverse-communication protocol:
    /// - Call once with any value; it is ignored.
    /// - While the result is [`ZrorAction::NeedEval`]`(x)`, evaluate the function
    ///   at `x` and pass that value to the next call.
    ///
    /// The first two evaluations are at `cfg.xlo` and `cfg.xhi`. If both values
    /// are strictly negative, or both strictly positive, [`ZrorAction::Failed`] is
    /// returned at once:
    /// - `qhi` is `true` in the both-positive case.
    /// - `qleft` is `true` when `|f(cfg.xlo)| < |f(cfg.xhi)|`.
    ///
    /// Otherwise refinement continues until the bracket is within tolerance.
    /// The search then ends with [`ZrorAction::Converged`], or with `Failed`
    /// (both flags `false`) if the final bracket does not straddle a sign change.
    #[inline]
    pub(crate) fn step(&mut self, fx: f64) -> ZrorAction {
        // `refine_iteration` currently always yields an action, so every pass
        // through this loop returns.
        loop {
            match self.stage {
                Stage::Start => {
                    self.xlo = self.cfg.xlo;
                    self.xhi = self.cfg.xhi;
                    self.b = self.xlo;
                    self.stage = Stage::AwaitFb;
                    return ZrorAction::NeedEval(self.b);
                }
                Stage::AwaitFb => {
                    self.fb = fx;
                    self.xlo = self.xhi;
                    self.a = self.xlo;
                    self.stage = Stage::AwaitFa;
                    return ZrorAction::NeedEval(self.a);
                }
                Stage::AwaitFa => {
                    // Both endpoint values strictly on the same side of zero:
                    // the interval does not bracket a root.
                    if self.fb < 0.0 {
                        if fx < 0.0 {
                            return ZrorAction::Failed {
                                xlo: self.xlo,
                                qleft: fx < self.fb,
                                qhi: false,
                            };
                        }
                    } else if self.fb > 0.0 && fx > 0.0 {
                        return ZrorAction::Failed {
                            xlo: self.xlo,
                            qleft: fx > self.fb,
                            qhi: true,
                        };
                    }
                    self.fa = fx;
                    // The first interpolation has only two points: use the secant form.
                    self.first = true;
                    self.restart_c_from_a();
                    if let Some(action) = self.refine_iteration() {
                        return action;
                    }
                }
                Stage::AwaitFbStep => {
                    self.fb = fx;
                    if self.fc * self.fb >= 0.0 {
                        // b and c no longer bracket a sign change; the previous
                        // iterate `a` becomes the new contrapoint.
                        self.restart_c_from_a();
                    } else if self.w == self.mb {
                        // Last step was a bisection.
                        self.ext = 0;
                    } else {
                        self.ext += 1;
                    }
                    if let Some(action) = self.refine_iteration() {
                        return action;
                    }
                }
            }
        }
    }

    /// Makes the previous iterate `a` the contrapoint `c` and resets the
    /// non-bisection step counter.
    #[inline]
    fn restart_c_from_a(&mut self) {
        self.c = self.a;
        self.fc = self.fa;
        self.ext = 0;
    }

    /// Runs one refinement step on the bracket `[b, c]`.
    ///
    /// The iteration follows the structure of DCDFLIB's `DZROR`.
    ///
    /// Returns `Converged`/`Failed` when half the bracket width is within
    /// tolerance. Otherwise it moves `b` by an interpolated, minimal (`tol`), or
    /// bisection step and returns `NeedEval` for the new `b`. At present it
    /// always returns `Some`.
    #[inline]
    fn refine_iteration(&mut self) -> Option<ZrorAction> {
        // Keep `b` as the endpoint with the smaller residual by swapping b and c.
        if self.fc.abs() < self.fb.abs() {
            // The old `a` is only a new history point when it differs from `c`.
            // When c == a (the bracket was just restarted), the old `a` is the
            // very point that becomes the new `b` below. Copying it into `d`
            // would make d == b and throw away the real interpolation history.
            if self.c != self.a {
                self.d = self.a;
                self.fd = self.fa;
            }
            self.a = self.b;
            self.fa = self.fb;
            self.xlo = self.c;
            self.b = self.xlo;
            self.fb = self.fc;
            self.c = self.a;
            self.fc = self.fa;
        }
        let tol = 0.5 * self.cfg.abstol.max(self.cfg.reltol * self.xlo.abs());
        let m = 0.5 * (self.c + self.b);
        let mb = m - self.b;
        self.mb = mb;
        if mb.abs() <= tol {
            self.xhi = self.c;
            // Success only if the final bracket still straddles (or touches) zero.
            let qrzero =
                (self.fc >= 0.0 && self.fb <= 0.0) || (self.fc < 0.0 && self.fb >= 0.0);
            return Some(if qrzero {
                ZrorAction::Converged {
                    xlo: self.xlo,
                    xhi: self.xhi,
                }
            } else {
                ZrorAction::Failed {
                    xlo: self.xlo,
                    qleft: false,
                    qhi: false,
                }
            });
        }
        let w = if self.ext > 3 {
            // Too many consecutive non-bisection steps: force a bisection.
            mb
        } else {
            let tol_signed = tol.copysign(mb);
            let mut p = (self.b - self.a) * self.fb;
            let q = if self.first {
                // Only two points known: secant step.
                self.first = false;
                self.fa - self.fb
            } else {
                // Three-point step using divided differences through `d`. A zero
                // denominator is guarded by substituting 1.0 for that difference.
                let fdb = if self.d == self.b {
                    1.0
                } else {
                    (self.fd - self.fb) / (self.d - self.b)
                };
                let fda = if self.d == self.a {
                    1.0
                } else {
                    (self.fd - self.fa) / (self.d - self.a)
                };
                p *= fda;
                fdb * self.fa - fda * self.fb
            };
            // Normalize so that p >= 0; the step is p / q.
            let (mut p, q) = if p < 0.0 { (-p, -q) } else { (p, q) };
            if self.ext == 3 {
                p *= 2.0;
            }
            if p == 0.0 || p <= q * tol_signed {
                // Step would be smaller than tolerance: take a minimal step.
                tol_signed
            } else if p < mb * q {
                // Interpolated point lies strictly between b and the midpoint.
                p / q
            } else {
                mb
            }
        };
        self.w = w;
        self.d = self.a;
        self.fd = self.fa;
        self.a = self.b;
        self.fa = self.fb;
        self.b += w;
        self.xlo = self.b;
        self.stage = Stage::AwaitFbStep;
        Some(ZrorAction::NeedEval(self.b))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    fn cfg() -> ZrorConfig {
        ZrorConfig {
            xlo: 0.0,
            xhi: 1.0,
            abstol: 1.0e-50,
            reltol: 1.0e-8,
        }
    }

    #[test]
    fn swap_branch_preserves_history_when_c_equals_a() {
        // Samples of f(x) = 4x - 5 (root 1.25). c == a, so the old `a` becomes
        // the new `b`, and `d` (x = 3, f = 7) must survive the swap. Interpolating
        // through the points 1, 2, 3 then lands exactly on the root.
        let mut z = ZrorState {
            cfg: cfg(),
            stage: Stage::AwaitFbStep,
            xlo: 0.0,
            xhi: 0.0,
            a: 1.0,
            b: 2.0,
            c: 1.0,
            d: 3.0,
            fa: -1.0,
            fb: 3.0,
            fc: -1.0,
            fd: 7.0,
            w: 0.0,
            mb: 0.0,
            ext: 0,
            first: false,
        };
        let action = z.refine_iteration();
        match action {
            Some(ZrorAction::NeedEval(x)) => assert_eq!(x, 1.25),
            other => panic!("unexpected action: {other:?}"),
        }
    }

    #[test]
    fn swap_branch_records_old_a_when_c_differs_from_a() {
        // Same function, f(x) = 4x - 5. Here c != a, so the old `a` (x = 3, f = 7)
        // must replace the stale history point in `d` before interpolating.
        let mut z = ZrorState {
            cfg: cfg(),
            stage: Stage::AwaitFbStep,
            xlo: 0.0,
            xhi: 0.0,
            a: 3.0,
            b: 2.0,
            c: 1.0,
            d: 99.0,
            fa: 7.0,
            fb: 3.0,
            fc: -1.0,
            fd: 77.0,
            w: 0.0,
            mb: 0.0,
            ext: 0,
            first: false,
        };
        match z.refine_iteration() {
            Some(ZrorAction::NeedEval(x)) => assert_eq!(x, 1.25),
            other => panic!("unexpected action: {other:?}"),
        }
    }

    #[test]
    fn guarded_divided_difference_when_d_equals_a() {
        let mut z = ZrorState {
            cfg: cfg(),
            stage: Stage::AwaitFbStep,
            xlo: 2.0,
            xhi: 4.0,
            a: 1.0,
            b: 2.0,
            c: 4.0,
            d: 1.0,
            fa: 4.0,
            fb: -1.0,
            fc: 5.0,
            fd: 5.0,
            w: 0.0,
            mb: 0.0,
            ext: 0,
            first: false,
        };
        let action = z.refine_iteration();
        match action {
            Some(ZrorAction::NeedEval(x)) => assert!((x - (2.0 + 1.0 / 23.0)).abs() < 1e-15),
            other => panic!("unexpected action: {other:?}"),
        }
    }

    #[test]
    fn guarded_divided_difference_when_d_equals_b() {
        let mut z = ZrorState {
            cfg: cfg(),
            stage: Stage::AwaitFbStep,
            xlo: 2.0,
            xhi: 4.0,
            a: 1.0,
            b: 2.0,
            c: 4.0,
            d: 2.0,
            fa: 4.0,
            fb: -1.0,
            fc: 5.0,
            fd: 5.0,
            w: 0.0,
            mb: 0.0,
            ext: 0,
            first: false,
        };
        let action = z.refine_iteration();
        match action {
            Some(ZrorAction::NeedEval(x)) => assert_eq!(x, 3.0),
            other => panic!("unexpected action: {other:?}"),
        }
    }

    #[test]
    fn fails_when_both_initial_values_are_negative() {
        let mut z = ZrorState::new(cfg());
        assert!(matches!(z.step(0.0), ZrorAction::NeedEval(0.0)));
        assert!(matches!(z.step(-2.0), ZrorAction::NeedEval(1.0)));
        assert!(matches!(
            z.step(-1.0),
            ZrorAction::Failed {
                qleft: false,
                qhi: false,
                ..
            }
        ));
    }

    #[test]
    fn fails_when_both_initial_values_are_positive() {
        let mut z = ZrorState::new(cfg());
        assert!(matches!(z.step(0.0), ZrorAction::NeedEval(0.0)));
        assert!(matches!(z.step(1.0), ZrorAction::NeedEval(1.0)));
        assert!(matches!(
            z.step(2.0),
            ZrorAction::Failed {
                qleft: true,
                qhi: true,
                ..
            }
        ));
    }

    #[test]
    fn reports_failed_convergence_if_interval_no_longer_straddles_zero() {
        let mut z = ZrorState {
            cfg: cfg(),
            stage: Stage::AwaitFbStep,
            xlo: 1.0,
            xhi: 1.0,
            a: 1.0,
            b: 1.0,
            c: 1.0,
            d: 0.0,
            fa: 1.0,
            fb: 1.0,
            fc: 1.0,
            fd: 0.0,
            w: 0.0,
            mb: 0.0,
            ext: 0,
            first: false,
        };
        assert!(matches!(
            z.refine_iteration(),
            Some(ZrorAction::Failed {
                qleft: false,
                qhi: false,
                ..
            })
        ));
    }

    #[test]
    fn converges_on_simple_quadratic_through_step() {
        // Drive the full reverse-communication protocol on f(x) = x^2 - 1/2.
        let f = |x: f64| x * x - 0.5;
        let mut z = ZrorState::new(cfg());
        let mut fx = 0.0;
        for _ in 0..1000 {
            match z.step(fx) {
                ZrorAction::NeedEval(x) => fx = f(x),
                ZrorAction::Converged { xlo, .. } => {
                    assert!(
                        (xlo - std::f64::consts::FRAC_1_SQRT_2).abs() < 1e-7,
                        "xlo = {xlo}"
                    );
                    return;
                }
                other => panic!("unexpected action: {other:?}"),
            }
        }
        panic!("zero finder did not converge");
    }
}