use crate::utils::convergence_data::ConvergenceData;
use crate::utils::enums::SolverError;
use crate::utils::enums::TerminationReason;
use crate::utils::perturb::perturb_real;
use crate::utils::solver_settings::SolverSettings;
use crate::utils::termination::{is_btol_satisfied, is_vtol_satisfied};
use std::fmt;
/// An interval returned by [`initial_interval_handling`], bundled with the function value at
/// its lower endpoint and the number of function evaluations spent producing it.
#[derive(Debug, PartialEq)]
pub(crate) struct UpdatedInterval {
    /// The interval `[a, b]` to continue solving on.
    pub(crate) interval: Interval,
    /// The value of the function at `interval.a`.
    pub(crate) fa: f64,
    /// Number of function evaluations used to obtain this interval.
    pub(crate) n_feval: u32,
}

impl UpdatedInterval {
    /// Bundles an interval, its lower-endpoint function value, and an evaluation count.
    pub(crate) fn new(interval: Interval, fa: f64, n_feval: u32) -> Self {
        Self { interval, fa, n_feval }
    }
}
/// Outcome of [`initial_interval_handling`].
#[derive(Debug)]
pub enum IntervalResult {
    /// A sign-changing interval (possibly widened) ready for a bracketing solver.
    UpdatedInterval(UpdatedInterval),
    /// A root identified directly from the endpoint checks.
    Root(f64),
    /// The interval could not be used (no sign change, or rebracketing failed).
    SolverError(SolverError),
}
/// A real interval with lower endpoint `a` and upper endpoint `b`.
///
/// Prefer [`Interval::new`] or [`Interval::from_point`] for construction: they order the
/// endpoints and widen a zero-width interval. The fields are public, so a struct literal can
/// bypass those guarantees.
#[derive(Debug, PartialEq, Clone, Copy)]
pub struct Interval {
    /// Lower endpoint.
    pub a: f64,
    /// Upper endpoint.
    pub b: f64,
}
impl Interval {
    /// Creates an interval from two endpoints given in either order.
    ///
    /// The smaller endpoint becomes `a`. When `a == b`, the interval is built with
    /// [`Interval::from_point`], so the upper endpoint is nudged by `perturb_real` (e.g. `1.0`
    /// yields `[1.0, 1.0000000000000444]`).
    pub fn new(a: f64, b: f64) -> Self {
        if a == b {
            return Self::from_point(a);
        }
        // Comparisons involving NaN are false, so NaN inputs land in the swapped branch.
        let (lower, upper) = if b > a { (a, b) } else { (b, a) };
        Self { a: lower, b: upper }
    }

    /// Creates a narrow interval starting at `x`, with the upper endpoint `perturb_real(x)`.
    pub fn from_point(x: f64) -> Self {
        let upper = perturb_real(x);
        Self { a: x, b: upper }
    }
}
impl fmt::Display for Interval {
    /// Renders the interval as `Interval(a: <a>, b: <b>)`, using each endpoint's `Display` form.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Interval(a: {lo}, b: {hi})", lo = self.a, hi = self.b)
    }
}
/// Widens `ab` symmetrically about its midpoint until `f` changes sign across it.
///
/// `fa` and `fb` must be the values of `f` at `ab.a` and `ab.b`. If they already have strictly
/// opposite signs, `ab` is returned unchanged and `f` is not evaluated. Otherwise, each
/// iteration doubles the half-width about the fixed midpoint and evaluates `f` at both new
/// endpoints. At most `max_bracket_iter` iterations are performed. An endpoint value of exactly
/// zero (or NaN) does not count as a sign change.
///
/// # Returns
///
/// `Ok((interval, fa, n_bracket_iter, n_feval))`:
/// - `fa` is `f(interval.a)`.
/// - `n_bracket_iter` is the number of widening iterations performed.
/// - `n_feval` is the number of function evaluations made by this call (two per iteration).
///
/// # Errors
///
/// [`SolverError::BracketingIntervalNotFound`] if no sign change is found within
/// `max_bracket_iter` iterations.
pub fn bracket_sign_change(
    f: &impl Fn(f64) -> f64,
    ab: Interval,
    fa: f64,
    fb: f64,
    max_bracket_iter: u32,
) -> Result<(Interval, f64, u32, u32), SolverError> {
    if has_sign_change(fa, fb) {
        return Ok((ab, fa, 0, 0));
    }
    // The midpoint stays fixed; only the half-width grows.
    let c = (ab.a + ab.b) / 2.0;
    let mut wh = (ab.b - ab.a) / 2.0;
    let mut n_feval: u32 = 0;
    for n_bracket_iter in 1..=max_bracket_iter {
        wh *= 2.0;
        let a = c - wh;
        let b = c + wh;
        let fa = f(a);
        let fb = f(b);
        n_feval += 2;
        if has_sign_change(fa, fb) {
            return Ok((Interval::new(a, b), fa, n_bracket_iter, n_feval));
        }
    }
    Err(SolverError::BracketingIntervalNotFound)
}

/// Returns `true` when `fa` and `fb` are non-zero, non-NaN, and of opposite sign.
///
/// Signs are compared directly instead of testing `fa * fb < 0.0`. That product underflows to
/// `-0.0` for tiny values of opposite sign (e.g. `1e-200 * -1e-200`), which hides a genuine sign
/// change.
fn has_sign_change(fa: f64, fb: f64) -> bool {
    (fa < 0.0 && fb > 0.0) || (fa > 0.0 && fb < 0.0)
}
/// Evaluates `f` at the endpoints of `ab` and prepares a sign-changing interval for a
/// bracketing solver.
///
/// The checks run in this order:
/// 1. If either endpoint value satisfies `is_vtol_satisfied` or is exactly zero, that endpoint is
///    returned as the root. The lower endpoint takes precedence.
/// 2. If the endpoint values have strictly opposite signs and `is_btol_satisfied` accepts `ab`,
///    the midpoint is returned as the root.
/// 3. If `solver_settings.rebracket` is `Some(true)`, the interval is widened with
///    [`bracket_sign_change`]. The number of iterations is capped by `max_bracket_iter`
///    (default 200) and by the evaluations left under `max_feval`, at two per iteration.
/// 4. Otherwise the interval is accepted only if its endpoint values have strictly opposite
///    signs. A NaN endpoint value never counts as a sign change.
///
/// # Returns
///
/// - `IntervalResult::Root(x)` when step 1 or 2 finds a root.
/// - `IntervalResult::UpdatedInterval` with the (possibly widened) interval, `f` at its lower
///   endpoint, and the number of function evaluations performed by this call.
/// - `IntervalResult::SolverError`:
///   - `BracketingIntervalNotFound` if rebracketing fails.
///   - `IntervalDoesNotBracketSignChange` if there is no sign change and rebracketing is off.
///
/// When `convergence_data` is provided, this call updates its `n_feval` and `n_bracket_iter`.
/// If a root is found, it also appends one entry to the iterate histories and may set
/// `termination_reason`.
pub fn initial_interval_handling(
    f: &impl Fn(f64) -> f64,
    ab: Interval,
    solver_settings: &SolverSettings,
    mut convergence_data: Option<&mut ConvergenceData>,
) -> IntervalResult {
    let rebracket = solver_settings.rebracket.unwrap_or(false);

    let fa = f(ab.a);
    let fb = f(ab.b);
    let mut n_feval: u32 = 2;
    if let Some(cd) = convergence_data.as_deref_mut() {
        cd.n_feval += n_feval;
    }

    // 1. Root at one of the endpoints.
    let root_at_lower_bound =
        is_vtol_satisfied(fa, solver_settings, convergence_data.as_deref_mut()) || fa == 0.0;
    let root_at_upper_bound =
        is_vtol_satisfied(fb, solver_settings, convergence_data.as_deref_mut()) || fb == 0.0;
    if root_at_lower_bound || root_at_upper_bound {
        let (root, f_root, bound_reason) = if root_at_lower_bound {
            (ab.a, fa, TerminationReason::RootAtLowerBound)
        } else {
            (ab.b, fb, TerminationReason::RootAtUpperBound)
        };
        if let Some(cd) = convergence_data.as_deref_mut() {
            record_iterate(cd, root, ab, f_root);
            // Preserve an existing value-tolerance reason; otherwise record which bound holds
            // the root.
            if cd.termination_reason != TerminationReason::ValueToleranceSatisfied {
                cd.termination_reason = bound_reason;
            }
        }
        return IntervalResult::Root(root);
    }

    // Both values are non-zero here. Explicit comparisons are used instead of `signum()`
    // because `NaN.signum() != x` is always true, which would treat a NaN endpoint as a
    // sign change.
    let sign_change = (fa < 0.0 && fb > 0.0) || (fa > 0.0 && fb < 0.0);

    // 2. The interval is already narrow enough: return its midpoint.
    if sign_change
        && is_btol_satisfied(ab.a, ab.b, solver_settings, convergence_data.as_deref_mut())
    {
        let root = (ab.a + ab.b) / 2.0;
        if let Some(cd) = convergence_data.as_deref_mut() {
            // `f` is not evaluated at the midpoint, so its value is recorded as NaN.
            record_iterate(cd, root, ab, f64::NAN);
        }
        return IntervalResult::Root(root);
    }

    // 3. Widen the interval until it brackets a sign change.
    if rebracket {
        let mut max_bracket_iter = solver_settings.max_bracket_iter.unwrap_or(200);
        if let Some(max_feval) = solver_settings.max_feval {
            // Each bracketing iteration costs two evaluations. `saturating_sub` avoids u32
            // underflow (a panic in debug builds) when `max_feval` is below the two already
            // spent.
            max_bracket_iter = max_bracket_iter.min(max_feval.saturating_sub(n_feval) / 2);
        }
        return match bracket_sign_change(f, ab, fa, fb, max_bracket_iter) {
            Ok((ab, fa, n_bracket_iter, n_feval_rebracket)) => {
                n_feval += n_feval_rebracket;
                if let Some(cd) = convergence_data {
                    cd.n_bracket_iter = n_bracket_iter;
                    cd.n_feval += n_feval_rebracket;
                }
                IntervalResult::UpdatedInterval(UpdatedInterval::new(ab, fa, n_feval))
            }
            Err(err) => IntervalResult::SolverError(err),
        };
    }

    // 4. No rebracketing: accept the interval only if it brackets a sign change.
    if sign_change {
        IntervalResult::UpdatedInterval(UpdatedInterval::new(ab, fa, n_feval))
    } else {
        IntervalResult::SolverError(SolverError::IntervalDoesNotBracketSignChange)
    }
}

/// Appends one iterate to the convergence history: the point `x`, the current interval `ab`,
/// and the function value `fx`.
fn record_iterate(convergence_data: &mut ConvergenceData, x: f64, ab: Interval, fx: f64) {
    convergence_data.x_all.push(x);
    convergence_data.a_all.push(ab.a);
    convergence_data.b_all.push(ab.b);
    convergence_data.f_all.push(fx);
}
#[cfg(test)]
mod tests {
    // Unit tests for `Interval`, `bracket_sign_change`, and `initial_interval_handling`.
    // Several expectations are exact floating-point values compared with `assert_eq!`.
    use super::*;
    use crate::utils::solver_settings::DEFAULT_SOLVER_SETTINGS;
    use numtest::*;

    // ----- Interval construction and formatting -----

    #[test]
    fn test_interval_correct_order() {
        let ab = Interval::new(1.0, 2.0);
        assert_eq!(ab.a, 1.0);
        assert_eq!(ab.b, 2.0);
    }

    #[test]
    fn test_interval_incorrect_order() {
        // Endpoints given in descending order are swapped.
        let ab = Interval::new(2.0, 1.0);
        assert_eq!(ab.a, 1.0);
        assert_eq!(ab.b, 2.0);
    }

    #[test]
    fn test_interval_zero_width() {
        // Equal endpoints: the upper endpoint is nudged by `perturb_real`.
        let ab = Interval::new(1.0, 1.0);
        assert_eq!(ab.a, 1.0);
        assert_eq!(ab.b, 1.0000000000000444);
    }

    #[test]
    fn test_interval_print() {
        assert_eq!(
            format!("{}", Interval::new(1.0, 2.5)),
            "Interval(a: 1, b: 2.5)"
        );
    }

    // ----- bracket_sign_change -----

    #[test]
    fn test_bracket_sign_change_already_bracketed() {
        // A sign change already exists: the interval is returned as-is with no evaluations.
        let f = |x: f64| x;
        let ab = Interval::new(-1.0, 1.0);
        let fa = f(ab.a);
        let fb = f(ab.b);
        let (ab_new, fa, n_bracket_iter, n_feval) =
            bracket_sign_change(&f, ab, fa, fb, 200).unwrap();
        assert_eq!(ab_new, ab);
        assert_eq!(fa, -1.0);
        assert_eq!(n_bracket_iter, 0);
        assert_eq!(n_feval, 0);
    }

    #[test]
    fn test_bracket_sign_change_from_close_initial_guess() {
        // Starting at the root itself: a single doubling straddles 0.
        let f = |x: f64| x;
        let ab = Interval::from_point(0.0);
        let fa = f(ab.a);
        let fb = f(ab.b);
        let (ab_new, fa, n_bracket_iter, n_feval) =
            bracket_sign_change(&f, ab, fa, fb, 200).unwrap();
        assert_eq!(
            ab_new,
            Interval::new(-1.1102230246251565e-14, 3.3306690738754696e-14)
        );
        assert_eq!(fa, -1.1102230246251565e-14);
        assert_eq!(n_bracket_iter, 1);
        assert_eq!(n_feval, 2);
    }

    #[test]
    fn test_bracket_sign_change_from_worse_initial_guess() {
        // A tiny interval far from the root needs 47 doublings (two evaluations each).
        let f = |x: f64| x;
        let ab = Interval::from_point(10.0);
        let fa = f(ab.a);
        let fb = f(ab.b);
        let (ab_new, fa, n_bracket_iter, n_feval) =
            bracket_sign_change(&f, ab, fa, fb, 200).unwrap();
        assert_eq!(ab_new, Interval::new(-7.249999999999877, 27.25000000000012));
        assert_eq!(fa, -7.249999999999877);
        assert_eq!(n_bracket_iter, 47);
        assert_eq!(n_feval, 94);
    }

    #[test]
    fn test_bracket_sign_change_not_bracketing() {
        // Midpoint 75, half-width 25 -> 50 -> 100: the second expansion gives [-25, 175].
        let f = |x: f64| x;
        let ab = Interval::new(50.0, 100.0);
        let fa = f(ab.a);
        let fb = f(ab.b);
        let (ab_new, fa, n_bracket_iter, n_feval) =
            bracket_sign_change(&f, ab, fa, fb, 200).unwrap();
        assert_eq!(ab_new, Interval::new(-25.0, 175.0));
        assert_eq!(fa, -25.0);
        assert_eq!(n_bracket_iter, 2);
        assert_eq!(n_feval, 4);
    }

    #[test]
    fn test_bracket_sign_change_no_interval_found() {
        // x^2 + 1 is always positive, so no expansion can bracket a sign change.
        let f = |x: f64| x.powi(2) + 1.0;
        let ab = Interval::new(-1.0, 1.0);
        let fa = f(ab.a);
        let fb = f(ab.b);
        let result = bracket_sign_change(&f, ab, fa, fb, 200);
        assert!(matches!(
            result.unwrap_err(),
            SolverError::BracketingIntervalNotFound
        ));
    }

    #[test]
    fn test_bracket_sign_change_max_bracket_iter_reached() {
        // [10, 10.1] needs 8 expansions (see the rebracketing test below); 7 are allowed.
        let f = |x: f64| x;
        let ab = Interval::new(10.0, 10.1);
        let fa = f(ab.a);
        let fb = f(ab.b);
        let result = bracket_sign_change(&f, ab, fa, fb, 7);
        assert!(matches!(
            result.unwrap_err(),
            SolverError::BracketingIntervalNotFound
        ));
    }

    // ----- initial_interval_handling -----

    #[test]
    fn test_initial_interval_handling_already_bracketed() {
        // With or without rebracketing, a bracketing interval passes through unchanged.
        let f = |x: f64| x;
        let ab = Interval::new(-1.0, 1.0);
        let solver_settings_1 = SolverSettings::default();
        let solver_settings_2 = SolverSettings {
            rebracket: Some(true),
            ..Default::default()
        };
        for solver_settings in [solver_settings_1, solver_settings_2] {
            let mut convergence_data = ConvergenceData::default();
            let interval_result =
                initial_interval_handling(&f, ab, &solver_settings, Some(&mut convergence_data));
            match interval_result {
                IntervalResult::UpdatedInterval(updated_interval) => {
                    assert_eq!(
                        updated_interval,
                        UpdatedInterval::new(Interval::new(-1.0, 1.0), -1.0, 2)
                    );
                    assert_eq!(convergence_data.x_all, vec![]);
                    assert_eq!(convergence_data.a_all, vec![]);
                    assert_eq!(convergence_data.b_all, vec![]);
                    assert_eq!(convergence_data.f_all, vec![]);
                    assert_eq!(convergence_data.n_feval, 2);
                    assert_eq!(convergence_data.n_bracket_iter, 0);
                }
                _ => panic!("Test failed."),
            }
        }
    }

    #[test]
    fn test_initial_interval_handling_rebracketing() {
        // 18 evaluations = 2 at the original endpoints + 8 expansions x 2.
        let f = |x: f64| x;
        let ab = Interval::new(10.0, 10.1);
        let solver_settings = SolverSettings {
            rebracket: Some(true),
            ..Default::default()
        };
        let mut convergence_data = ConvergenceData::default();
        let interval_result =
            initial_interval_handling(&f, ab, &solver_settings, Some(&mut convergence_data));
        match interval_result {
            IntervalResult::UpdatedInterval(updated_interval) => {
                assert_eq!(
                    updated_interval,
                    UpdatedInterval::new(
                        Interval::new(-2.749999999999954, 22.849999999999955),
                        -2.749999999999954,
                        18
                    )
                );
                assert_eq!(convergence_data.x_all, vec![]);
                assert_eq!(convergence_data.a_all, vec![]);
                assert_eq!(convergence_data.b_all, vec![]);
                assert_eq!(convergence_data.f_all, vec![]);
                assert_eq!(convergence_data.n_feval, 18);
                assert_eq!(convergence_data.n_bracket_iter, 8);
            }
            _ => panic!("Test failed."),
        }
    }

    #[test]
    fn test_initial_interval_handling_root_at_lower_bound_no_vtol() {
        // f(a) == 0 exactly, so the lower bound is the root.
        let f = |x: f64| x;
        let ab = Interval::new(0.0, 1.0);
        let solver_settings_1 = SolverSettings::default();
        let solver_settings_2 = SolverSettings {
            rebracket: Some(true),
            ..Default::default()
        };
        for solver_settings in [solver_settings_1, solver_settings_2] {
            let mut convergence_data = ConvergenceData::default();
            let interval_result =
                initial_interval_handling(&f, ab, &solver_settings, Some(&mut convergence_data));
            match interval_result {
                IntervalResult::Root(root) => {
                    assert_eq!(root, 0.0);
                    assert_eq!(convergence_data.x_all, vec![0.0]);
                    assert_eq!(convergence_data.a_all, vec![0.0]);
                    assert_eq!(convergence_data.b_all, vec![1.0]);
                    assert_eq!(convergence_data.f_all, vec![0.0]);
                    assert_eq!(convergence_data.n_feval, 2);
                    assert_eq!(convergence_data.n_bracket_iter, 0);
                    assert_eq!(
                        convergence_data.termination_reason,
                        TerminationReason::RootAtLowerBound
                    );
                }
                _ => panic!("Test failed."),
            }
        }
    }

    #[test]
    fn test_initial_interval_handling_root_at_lower_bound_with_vtol() {
        // |f(a)| = 0.1 meets vtol = 0.1, so the reason is ValueToleranceSatisfied.
        let f = |x: f64| x;
        let ab = Interval::new(0.1, 1.0);
        let solver_settings_1 = SolverSettings {
            vtol: Some(0.1),
            ..Default::default()
        };
        let solver_settings_2 = SolverSettings {
            vtol: Some(0.1),
            rebracket: Some(true),
            ..Default::default()
        };
        for solver_settings in [solver_settings_1, solver_settings_2] {
            let mut convergence_data = ConvergenceData::default();
            let interval_result =
                initial_interval_handling(&f, ab, &solver_settings, Some(&mut convergence_data));
            match interval_result {
                IntervalResult::Root(root) => {
                    assert_eq!(root, 0.1);
                    assert_eq!(convergence_data.x_all, vec![0.1]);
                    assert_eq!(convergence_data.a_all, vec![0.1]);
                    assert_eq!(convergence_data.b_all, vec![1.0]);
                    assert_eq!(convergence_data.f_all, vec![0.1]);
                    assert_eq!(convergence_data.n_feval, 2);
                    assert_eq!(convergence_data.n_bracket_iter, 0);
                    assert_eq!(
                        convergence_data.termination_reason,
                        TerminationReason::ValueToleranceSatisfied
                    );
                }
                _ => panic!("Test failed."),
            }
        }
    }

    #[test]
    fn test_initial_interval_handling_root_at_upper_bound_no_vtol() {
        // f(b) == 0 exactly, so the upper bound is the root.
        let f = |x: f64| x;
        let ab = Interval::new(-1.0, 0.0);
        let solver_settings_1 = SolverSettings::default();
        let solver_settings_2 = SolverSettings {
            rebracket: Some(true),
            ..Default::default()
        };
        for solver_settings in [solver_settings_1, solver_settings_2] {
            let mut convergence_data = ConvergenceData::default();
            let interval_result =
                initial_interval_handling(&f, ab, &solver_settings, Some(&mut convergence_data));
            match interval_result {
                IntervalResult::Root(root) => {
                    assert_eq!(root, 0.0);
                    assert_eq!(convergence_data.x_all, vec![0.0]);
                    assert_eq!(convergence_data.a_all, vec![-1.0]);
                    assert_eq!(convergence_data.b_all, vec![0.0]);
                    assert_eq!(convergence_data.f_all, vec![0.0]);
                    assert_eq!(convergence_data.n_feval, 2);
                    assert_eq!(convergence_data.n_bracket_iter, 0);
                    assert_eq!(
                        convergence_data.termination_reason,
                        TerminationReason::RootAtUpperBound
                    );
                }
                _ => panic!("Test failed."),
            }
        }
    }

    #[test]
    fn test_initial_interval_handling_root_at_upper_bound_with_vtol() {
        // |f(b)| = 0.1 meets vtol = 0.1, so the reason is ValueToleranceSatisfied.
        let f = |x: f64| x;
        let ab = Interval::new(-1.0, -0.1);
        let solver_settings_1 = SolverSettings {
            vtol: Some(0.1),
            ..Default::default()
        };
        let solver_settings_2 = SolverSettings {
            vtol: Some(0.1),
            rebracket: Some(true),
            ..Default::default()
        };
        for solver_settings in [solver_settings_1, solver_settings_2] {
            let mut convergence_data = ConvergenceData::default();
            let interval_result =
                initial_interval_handling(&f, ab, &solver_settings, Some(&mut convergence_data));
            match interval_result {
                IntervalResult::Root(root) => {
                    assert_eq!(root, -0.1);
                    assert_eq!(convergence_data.x_all, vec![-0.1]);
                    assert_eq!(convergence_data.a_all, vec![-1.0]);
                    assert_eq!(convergence_data.b_all, vec![-0.1]);
                    assert_eq!(convergence_data.f_all, vec![-0.1]);
                    assert_eq!(convergence_data.n_feval, 2);
                    assert_eq!(convergence_data.n_bracket_iter, 0);
                    assert_eq!(
                        convergence_data.termination_reason,
                        TerminationReason::ValueToleranceSatisfied
                    );
                }
                _ => panic!("Test failed."),
            }
        }
    }

    #[test]
    fn test_initial_interval_handling_root_batol_satisfied_with_sign_change() {
        // Width 0.2 meets batol = 0.2 with a sign change: the midpoint is returned and its
        // function value (never evaluated) is recorded as NaN.
        let f = |x: f64| x;
        let ab = Interval::new(-0.1, 0.1);
        let solver_settings_1 = SolverSettings {
            batol: Some(0.2),
            ..Default::default()
        };
        let solver_settings_2 = SolverSettings {
            batol: Some(0.2),
            rebracket: Some(true),
            ..Default::default()
        };
        for solver_settings in [solver_settings_1, solver_settings_2] {
            let mut convergence_data = ConvergenceData::default();
            let interval_result =
                initial_interval_handling(&f, ab, &solver_settings, Some(&mut convergence_data));
            match interval_result {
                IntervalResult::Root(root) => {
                    assert_eq!(root, 0.0);
                    assert_eq!(convergence_data.x_all, vec![0.0]);
                    assert_eq!(convergence_data.a_all, vec![-0.1]);
                    assert_eq!(convergence_data.b_all, vec![0.1]);
                    assert_arrays_equal!(convergence_data.f_all, [f64::NAN]);
                    assert_eq!(convergence_data.n_feval, 2);
                    assert_eq!(convergence_data.n_bracket_iter, 0);
                    assert_eq!(
                        convergence_data.termination_reason,
                        TerminationReason::AbsoluteBracketToleranceSatisfied
                    );
                }
                _ => panic!("Test failed."),
            }
        }
    }

    #[test]
    fn test_initial_interval_handling_root_batol_satisfied_without_sign_change_without_rebracketing()
    {
        // The bracket tolerance alone is not enough: without a sign change the interval is
        // rejected.
        let f = |x: f64| x;
        let ab = Interval::new(0.1, 0.2);
        let solver_settings = SolverSettings {
            batol: Some(0.2),
            ..Default::default()
        };
        let mut convergence_data = ConvergenceData::default();
        let interval_result =
            initial_interval_handling(&f, ab, &solver_settings, Some(&mut convergence_data));
        match interval_result {
            IntervalResult::SolverError(err) => {
                assert!(matches!(err, SolverError::IntervalDoesNotBracketSignChange));
            }
            _ => panic!("Test failed."),
        }
    }

    #[test]
    fn test_initial_interval_handling_not_rebracketing_no_sign_change() {
        let f = |x: f64| x;
        let ab = Interval::new(50.0, 100.0);
        let solver_settings = &DEFAULT_SOLVER_SETTINGS;
        let interval_result = initial_interval_handling(&f, ab, solver_settings, None);
        match interval_result {
            IntervalResult::SolverError(err) => {
                assert!(matches!(err, SolverError::IntervalDoesNotBracketSignChange));
            }
            _ => panic!("Test failed."),
        }
    }

    #[test]
    fn test_initial_interval_handling_rebracketing_no_interval_found() {
        // x^2 + 1 has no real root, so rebracketing fails.
        let f = |x: f64| x.powi(2) + 1.0;
        let ab = Interval::new(-1.0, 1.0);
        let solver_settings = SolverSettings {
            rebracket: Some(true),
            ..Default::default()
        };
        let interval_result = initial_interval_handling(&f, ab, &solver_settings, None);
        match interval_result {
            IntervalResult::SolverError(err) => {
                assert!(matches!(err, SolverError::BracketingIntervalNotFound));
            }
            _ => panic!("Test failed."),
        }
    }

    #[test]
    fn test_initial_interval_handling_even_max_feval_reached_during_rebracketing() {
        // max_feval = 16 leaves (16 - 2) / 2 = 7 expansions; 8 are needed for [10, 10.1].
        let f = |x: f64| x;
        let ab = Interval::new(10.0, 10.1);
        let solver_settings = SolverSettings {
            rebracket: Some(true),
            max_feval: Some(16),
            ..Default::default()
        };
        let interval_result = initial_interval_handling(&f, ab, &solver_settings, None);
        match interval_result {
            IntervalResult::SolverError(err) => {
                assert!(matches!(err, SolverError::BracketingIntervalNotFound));
            }
            _ => panic!("Test failed."),
        }
    }

    #[test]
    fn test_initial_interval_handling_odd_max_feval_reached_during_rebracketing() {
        // max_feval = 17 leaves (17 - 2) / 2 = 7 expansions (integer division); 8 are needed.
        let f = |x: f64| x;
        let ab = Interval::new(10.0, 10.1);
        let solver_settings = SolverSettings {
            rebracket: Some(true),
            max_feval: Some(17),
            ..Default::default()
        };
        let interval_result = initial_interval_handling(&f, ab, &solver_settings, None);
        match interval_result {
            IntervalResult::SolverError(err) => {
                assert!(matches!(err, SolverError::BracketingIntervalNotFound));
            }
            _ => panic!("Test failed."),
        }
    }
}