use super::dzror::{ZrorAction, ZrorConfig, ZrorState};
/// Tuning parameters for the [`InvrState`] bracketing search.
#[derive(Debug, Clone, Copy)]
pub(crate) struct InvrConfig {
/// Lower end of the search interval; the function is evaluated here first.
pub small: f64,
/// Upper end of the search interval; the function is evaluated here second.
pub big: f64,
/// Lower bound on the first step taken away from the initial guess.
pub abs_step: f64,
/// The first step is also at least `rel_step * |x_initial|`.
pub rel_step: f64,
/// Factor applied to the step after every probe that fails to bracket the zero.
pub stp_mul: f64,
/// Absolute tolerance handed to the zero finder once a bracket is found.
pub abs_tol: f64,
/// Relative tolerance handed to the zero finder once a bracket is found.
pub rel_tol: f64,
}
/// Which input the next call to [`InvrState::step`] is waiting for.
#[derive(Debug, Clone, Copy)]
enum Stage {
/// Nothing requested yet; the `fx` passed to the first call is ignored.
Start,
/// Waiting for f(small).
AwaitFsmall,
/// Waiting for f(big).
AwaitFbig,
/// Waiting for f(x_initial).
AwaitInitial,
/// Stepping upwards; waiting for f(xub).
AwaitUpper,
/// Stepping downwards; waiting for f(xlb).
AwaitLower,
/// A bracket was found; every request is delegated to the zero finder.
InZror,
}
/// Result of one [`InvrState::step`] call.
#[derive(Debug, Clone, Copy)]
pub(crate) enum InvrAction {
/// Evaluate the function at this abscissa and pass the value to the next `step` call.
NeedEval(f64),
/// The search finished; this is the located abscissa.
Converged(f64),
// NOTE(review): this allow is presumably here because no caller in the crate
// reads the payload fields yet — confirm before removing it.
#[allow(dead_code)]
/// The zero could not be bracketed inside `[small, big]`.
Failed {
/// `true` if the search stopped at `small`, `false` if it stopped at `big`.
qleft: bool,
/// `true` if the function value at the stopping endpoint was positive.
qhi: bool,
/// The endpoint at which the search stopped.
x: f64,
},
}
/// Reverse-communication state for locating a zero of a function on `[small, big]`.
///
/// The caller never hands over the function itself: it repeatedly calls
/// [`InvrState::step`] with the value requested by the previous call.
#[derive(Debug)]
pub(crate) struct InvrState {
// Search parameters supplied at construction.
cfg: InvrConfig,
// Which input the next `step` call expects.
stage: Stage,
// The initial guess the bracketing walk starts from.
xsave: f64,
// f(small), recorded during the first evaluation.
fsmall: f64,
// f(big), recorded during the second evaluation.
fbig: f64,
// Whether f(big) > f(small); selects the stepping direction.
qincr: bool,
// Current bracketing step; multiplied by `stp_mul` after each unsuccessful probe.
step: f64,
// Lower end of the candidate bracket.
xlb: f64,
// Upper end of the candidate bracket.
xub: f64,
// Zero finder; `Some` only once the search has reached `Stage::InZror`.
zror: Option<ZrorState>,
}
impl InvrState {
#[inline]
pub(crate) fn new(cfg: InvrConfig, x_initial: f64) -> Self {
Self {
cfg,
stage: Stage::Start,
xsave: x_initial,
fsmall: 0.0,
fbig: 0.0,
qincr: false,
step: 0.0,
xlb: 0.0,
xub: 0.0,
zror: None,
}
}
#[inline]
pub(crate) fn step(&mut self, fx: f64) -> InvrAction {
match self.stage {
Stage::Start => {
self.stage = Stage::AwaitFsmall;
InvrAction::NeedEval(self.cfg.small)
}
Stage::AwaitFsmall => {
self.fsmall = fx;
self.stage = Stage::AwaitFbig;
InvrAction::NeedEval(self.cfg.big)
}
Stage::AwaitFbig => {
self.fbig = fx;
self.qincr = self.fbig > self.fsmall;
if self.fsmall <= self.fbig {
if self.fsmall > 0.0 {
return InvrAction::Failed {
qleft: true,
qhi: true,
x: self.cfg.small,
};
}
if self.fbig < 0.0 {
return InvrAction::Failed {
qleft: false,
qhi: false,
x: self.cfg.big,
};
}
} else {
if self.fsmall < 0.0 {
return InvrAction::Failed {
qleft: true,
qhi: false,
x: self.cfg.small,
};
}
if self.fbig > 0.0 {
return InvrAction::Failed {
qleft: false,
qhi: true,
x: self.cfg.big,
};
}
}
let x = self.xsave;
self.step = self.cfg.abs_step.max(self.cfg.rel_step * x.abs());
self.stage = Stage::AwaitInitial;
InvrAction::NeedEval(x)
}
Stage::AwaitInitial => {
let yy = fx;
if yy == 0.0 {
return InvrAction::Converged(self.xsave);
}
let qup = (self.qincr && yy < 0.0) || (!self.qincr && yy > 0.0);
if qup {
self.xlb = self.xsave;
self.xub = (self.xlb + self.step).min(self.cfg.big);
self.stage = Stage::AwaitUpper;
InvrAction::NeedEval(self.xub)
} else {
self.xub = self.xsave;
self.xlb = (self.xub - self.step).max(self.cfg.small);
self.stage = Stage::AwaitLower;
InvrAction::NeedEval(self.xlb)
}
}
Stage::AwaitUpper => {
let yy = fx;
let qbdd = (self.qincr && yy >= 0.0) || (!self.qincr && yy <= 0.0);
let qlim = self.xub >= self.cfg.big;
if qbdd {
return self.start_zror();
}
if qlim {
return InvrAction::Failed {
qleft: false,
qhi: !self.qincr,
x: self.cfg.big,
};
}
self.step *= self.cfg.stp_mul;
self.xlb = self.xub;
self.xub = (self.xlb + self.step).min(self.cfg.big);
InvrAction::NeedEval(self.xub)
}
Stage::AwaitLower => {
let yy = fx;
let qbdd = (self.qincr && yy <= 0.0) || (!self.qincr && yy >= 0.0);
let qlim = self.xlb <= self.cfg.small;
if qbdd {
return self.start_zror();
}
if qlim {
return InvrAction::Failed {
qleft: true,
qhi: self.qincr,
x: self.cfg.small,
};
}
self.step *= self.cfg.stp_mul;
self.xub = self.xlb;
self.xlb = (self.xub - self.step).max(self.cfg.small);
InvrAction::NeedEval(self.xlb)
}
Stage::InZror => {
let z = self.zror.as_mut().expect("zror");
match z.step(fx) {
ZrorAction::NeedEval(x) => InvrAction::NeedEval(x),
ZrorAction::Converged { xlo, .. } => InvrAction::Converged(xlo),
ZrorAction::Failed { xlo, .. } => InvrAction::Converged(xlo),
}
}
}
}
#[inline]
fn start_zror(&mut self) -> InvrAction {
let mut z = ZrorState::new(ZrorConfig {
xlo: self.xlb,
xhi: self.xub,
abstol: self.cfg.abs_tol,
reltol: self.cfg.rel_tol,
});
let first = z.step(0.0);
self.zror = Some(z);
self.stage = Stage::InZror;
match first {
ZrorAction::NeedEval(x) => InvrAction::NeedEval(x),
ZrorAction::Converged { xlo, .. } => InvrAction::Converged(xlo),
ZrorAction::Failed { xlo, .. } => InvrAction::Converged(xlo),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Search window `[0, 1]`; the first step is 0.5 and grows fivefold per probe.
    fn cfg() -> InvrConfig {
        InvrConfig {
            small: 0.0,
            big: 1.0,
            abs_step: 0.5,
            rel_step: 0.5,
            stp_mul: 5.0,
            abs_tol: 1.0e-50,
            rel_tol: 1.0e-8,
        }
    }

    /// Feeds each `(fx, want)` pair to the solver in order, asserting that
    /// every call answers `NeedEval(want)`.
    fn feed(state: &mut InvrState, script: &[(f64, f64)]) {
        for &(fx, want) in script {
            assert!(matches!(state.step(fx), InvrAction::NeedEval(got) if got == want));
        }
    }

    /// Feeds `fx` and asserts that the solver gives up with exactly
    /// `(qleft, qhi, x)`.
    fn expect_failure(state: &mut InvrState, fx: f64, (want_left, want_hi, want_x): (bool, bool, f64)) {
        assert!(matches!(
            state.step(fx),
            InvrAction::Failed { qleft, qhi, x }
                if qleft == want_left && qhi == want_hi && x == want_x
        ));
    }

    #[test]
    fn rejects_increasing_range_when_small_is_already_positive() {
        let mut state = InvrState::new(cfg(), 0.5);
        feed(&mut state, &[(0.0, 0.0), (1.0, 1.0)]);
        expect_failure(&mut state, 2.0, (true, true, 0.0));
    }

    #[test]
    fn rejects_increasing_range_when_big_is_still_negative() {
        let mut state = InvrState::new(cfg(), 0.5);
        feed(&mut state, &[(0.0, 0.0), (-2.0, 1.0)]);
        expect_failure(&mut state, -1.0, (false, false, 1.0));
    }

    #[test]
    fn rejects_decreasing_range_when_small_is_already_negative() {
        let mut state = InvrState::new(cfg(), 0.5);
        feed(&mut state, &[(0.0, 0.0), (-1.0, 1.0)]);
        expect_failure(&mut state, -2.0, (true, false, 0.0));
    }

    #[test]
    fn rejects_decreasing_range_when_big_is_still_positive() {
        let mut state = InvrState::new(cfg(), 0.5);
        feed(&mut state, &[(0.0, 0.0), (2.0, 1.0)]);
        expect_failure(&mut state, 1.0, (false, true, 1.0));
    }

    #[test]
    fn reports_upper_bound_failure_when_search_runs_out_of_room() {
        let mut state = InvrState::new(cfg(), 0.9);
        feed(&mut state, &[(0.0, 0.0), (-1.0, 1.0), (1.0, 0.9), (-0.1, 1.0)]);
        expect_failure(&mut state, -0.05, (false, false, 1.0));
    }

    #[test]
    fn reports_lower_bound_failure_when_search_runs_out_of_room() {
        let mut state = InvrState::new(cfg(), 0.1);
        feed(&mut state, &[(0.0, 0.0), (-1.0, 1.0), (1.0, 0.1), (0.1, 0.0)]);
        expect_failure(&mut state, 0.05, (true, true, 0.0));
    }
}