use crate::utils::{
bracketing::{Interval, IntervalResult, initial_interval_handling},
convergence_data::ConvergenceData,
enums::{SolverError, TerminationReason},
solver_settings::{DEFAULT_SOLVER_SETTINGS, SolverSettings},
termination::is_vtol_satisfied,
};
/// Smallest number of halvings of the bracket width `ab.b - ab.a` after which the width no
/// longer exceeds `batol`, i.e. `ceil(log2((b - a) / batol))`.
///
/// The final float-to-integer `as` cast saturates. Any ratio `(b - a) / batol` that is at
/// most 1 therefore yields 0; this includes zero, negative, and NaN ratios.
fn get_k12(ab: Interval, batol: f64) -> u32 {
    let width_ratio = (ab.b - ab.a) / batol;
    let halvings = width_ratio.log2().ceil();
    halvings as u32
}
/// Finds a root of `f` inside the bracket `ab` using the bisection method.
///
/// The bracket is first passed to `initial_interval_handling`, which can do one of three things:
/// - return a root directly,
/// - return an error, or
/// - hand back an updated interval together with `f(a)` and the number of function
///   evaluations it spent.
///
/// Bisection then runs for at most `k_12 = ceil(log2((b - a) / batol))` iterations. This
/// count is further limited by `solver_settings.max_iter` and by the evaluations still
/// available under `solver_settings.max_feval`. The loop stops early as soon as
/// `is_vtol_satisfied` accepts the function value at the current midpoint.
///
/// # Arguments
///
/// * `f` - Function whose root is sought.
/// * `ab` - Initial bracketing interval.
/// * `solver_settings` - Solver settings; `DEFAULT_SOLVER_SETTINGS` is used when `None`.
///   An unset `batol` defaults to `2.0 * f64::EPSILON`.
/// * `convergence_data` - Optional record of the iteration. What it receives:
///   - Each iteration pushes the midpoint, the current bracket endpoints, and `f` at the
///     midpoint.
///   - A final entry pushes the returned midpoint and bracket with `f64::NAN` as the function
///     value.
///   - `n_feval` is set to the total number of function evaluations.
///   - `termination_reason` is overwritten only if it is still
///     `TerminationReason::NotYetTerminated`.
///
/// # Returns
///
/// The midpoint of the final bracket, or the root returned by `initial_interval_handling`.
///
/// # Errors
///
/// Propagates any `SolverError` produced by `initial_interval_handling`. Examples are
/// `SolverError::IntervalDoesNotBracketSignChange` and
/// `SolverError::BracketingIntervalNotFound`.
pub fn root_bisection(
    f: &impl Fn(f64) -> f64,
    mut ab: Interval,
    solver_settings: Option<&SolverSettings>,
    mut convergence_data: Option<&mut ConvergenceData>,
) -> Result<f64, SolverError> {
    let mut termination_reason = TerminationReason::AbsoluteBracketToleranceSatisfied;
    let solver_settings: &SolverSettings = solver_settings.unwrap_or(&DEFAULT_SOLVER_SETTINGS);
    let batol = solver_settings.batol.unwrap_or(2.0 * f64::EPSILON);
    let mut fa: f64;
    let mut n_feval: u32 = 0;

    // Validate (and possibly replace) the initial bracket before bisecting.
    match initial_interval_handling(f, ab, solver_settings, convergence_data.as_deref_mut()) {
        IntervalResult::Root(root) => return Ok(root),
        IntervalResult::SolverError(e) => return Err(e),
        IntervalResult::UpdatedInterval(updated_interval) => {
            ab = updated_interval.interval;
            fa = updated_interval.fa;
            n_feval += updated_interval.n_feval;
        }
    };
    let mut a = ab.a;
    let mut b = ab.b;

    // Iteration budget: `k_12` halvings reach `batol`. A smaller `max_iter` or a tighter
    // evaluation budget caps it, and the termination reason records which limit applies.
    let k_12 = get_k12(Interval::new(a, b), batol);
    let mut max_iter = solver_settings.max_iter.unwrap_or(k_12);
    if max_iter >= k_12 {
        max_iter = k_12;
    } else {
        termination_reason = TerminationReason::MaxIterationsReached;
    }
    if let Some(max_feval) = solver_settings.max_feval {
        // `saturating_sub` guards against the initial bracket handling having already used
        // more evaluations than `max_feval` allows. A plain `u32` subtraction would overflow
        // in that case: a panic in debug builds, a huge wrapped budget in release builds.
        let n_feval_remaining = max_feval.saturating_sub(n_feval);
        if n_feval_remaining < max_iter {
            max_iter = n_feval_remaining;
            termination_reason = TerminationReason::MaxFunctionEvaluationsReached;
        }
    }

    let mut c = (a + b) / 2.0;
    for _ in 0..max_iter {
        let fc = f(c);
        n_feval += 1;
        if let Some(data) = convergence_data.as_deref_mut() {
            data.x_all.push(c);
            data.a_all.push(a);
            data.b_all.push(b);
            data.f_all.push(fc);
            data.n_iter += 1;
        }
        if is_vtol_satisfied(fc, solver_settings, convergence_data.as_deref_mut()) {
            break;
        }
        // Compare signs directly instead of testing `fa * fc > 0.0`. The product of two tiny
        // same-signed values can underflow to zero and be mistaken for a sign change, which
        // would discard the half of the bracket that actually contains the root.
        if (fa > 0.0 && fc > 0.0) || (fa < 0.0 && fc < 0.0) {
            a = c;
            fa = fc;
        } else {
            b = c;
        }
        c = (a + b) / 2.0;
    }

    if let Some(data) = convergence_data {
        // Final entry: the returned midpoint. `f` was not evaluated there, hence NaN.
        data.x_all.push(c);
        data.a_all.push(a);
        data.b_all.push(b);
        data.f_all.push(f64::NAN);
        data.n_feval = n_feval;
        // Keep any reason already recorded (e.g. by `is_vtol_satisfied`).
        if let TerminationReason::NotYetTerminated = data.termination_reason {
            data.termination_reason = termination_reason;
        }
    }
    Ok(c)
}
/// Finds a root of `f` in `ab` by plain bisection. It performs no bracket validation, no
/// tolerance checks, and no bookkeeping.
///
/// The method runs a fixed `ceil(log2((b - a) / (2 * f64::EPSILON)))` iterations. This is the
/// bracket-width bound `root_bisection` uses with its default `batol`. It returns the final
/// midpoint.
///
/// The caller is responsible for ensuring that `f(ab.a)` and `f(ab.b)` have opposite signs (or
/// that one of them is zero). Otherwise the returned value is not guaranteed to be a root.
pub fn root_bisection_fast(f: &impl Fn(f64) -> f64, ab: Interval) -> f64 {
    let n_iter = ((ab.b - ab.a) / (2.0 * f64::EPSILON)).log2().ceil() as u32;
    let mut a = ab.a;
    let mut b = ab.b;
    let mut fa = f(a);
    let mut c = (a + b) / 2.0;
    for _ in 0..n_iter {
        let fc = f(c);
        // Compare signs directly instead of testing `fa * fc > 0.0`. The product of two tiny
        // same-signed values can underflow to zero and be mistaken for a sign change.
        if (fa > 0.0 && fc > 0.0) || (fa < 0.0 && fc < 0.0) {
            a = c;
            fa = fc;
        } else {
            b = c;
        }
        c = (a + b) / 2.0;
    }
    c
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::enums::TerminationReason;
    use numtest::*;

    #[test]
    fn test_get_k12() {
        assert_eq!(
            get_k12(
                Interval::new(-f64::EPSILON, f64::EPSILON),
                2.0 * f64::EPSILON
            ),
            0
        );
        assert_eq!(get_k12(Interval::new(-1.0, 1.0), 2.0 * f64::EPSILON), 52);
    }

    /// Runs `root_bisection` on `f` over `ab`. It checks the returned root, the residual
    /// `f(x)`, and the statistics recorded in `ConvergenceData`.
    ///
    /// `root_tol` and `value_tol` default to `f64::EPSILON`. When rebracketing is explicitly
    /// disabled, the root is also required to equal the result of `root_bisection_fast`
    /// exactly.
    // Every argument is a distinct expectation of the test case; bundling them into a struct
    // would only move the same list elsewhere, so the lint is silenced here.
    #[allow(clippy::too_many_arguments)]
    fn root_bisection_test_helper(
        f: &impl Fn(f64) -> f64,
        ab: Interval,
        solver_settings: Option<&SolverSettings>,
        x_exp: f64,
        root_tol: Option<f64>,
        value_tol: Option<f64>,
        n_iter_exp: u32,
        n_feval_exp: u32,
        n_bracket_iter_exp: u32,
        reason_exp: TerminationReason,
    ) {
        let root_tol = root_tol.unwrap_or(f64::EPSILON);
        let value_tol = value_tol.unwrap_or(f64::EPSILON);
        let mut convergence_data = ConvergenceData::default();
        let x = root_bisection(f, ab, solver_settings, Some(&mut convergence_data)).unwrap();
        assert_equal_to_atol!(x, x_exp, root_tol);
        assert_equal_to_atol!(f(x), 0.0, value_tol);
        // `root_bisection_fast` never rebrackets, so the two results are only comparable when
        // rebracketing has been explicitly turned off.
        if solver_settings.and_then(|s| s.rebracket) == Some(false) {
            assert_eq!(x, root_bisection_fast(f, ab));
        }
        assert_eq!(convergence_data.n_iter, n_iter_exp);
        assert_eq!(convergence_data.n_feval, n_feval_exp);
        assert_eq!(convergence_data.n_bracket_iter, n_bracket_iter_exp);
        assert_eq!(convergence_data.termination_reason, reason_exp);
    }

    #[test]
    fn test_root_bisection_root_at_midpoint() {
        root_bisection_test_helper(
            &|x: f64| x.powi(2) - 1.0,
            Interval::new(0.0, 2.0),
            None,
            1.0,
            None,
            Some(2.0 * f64::EPSILON),
            52,
            54,
            0,
            TerminationReason::AbsoluteBracketToleranceSatisfied,
        );
    }

    #[test]
    fn test_root_bisection_root_at_lower_bound() {
        root_bisection_test_helper(
            &|x: f64| x.powi(2) - 1.0,
            Interval::new(1.0, 2.0),
            None,
            1.0,
            None,
            Some(2.0 * f64::EPSILON),
            0,
            2,
            0,
            TerminationReason::RootAtLowerBound,
        );
    }

    #[test]
    fn test_root_bisection_root_at_upper_bound() {
        root_bisection_test_helper(
            &|x: f64| x.powi(2) - 1.0,
            Interval::new(0.0, 1.0),
            None,
            1.0,
            None,
            Some(2.0 * f64::EPSILON),
            0,
            2,
            0,
            TerminationReason::RootAtUpperBound,
        );
    }

    #[test]
    fn test_root_bisection_large_initial_interval() {
        root_bisection_test_helper(
            &|x: f64| x.powi(2) - 1.0,
            Interval::new(0.0, 9999999.0),
            None,
            1.0,
            None,
            None,
            75,
            77,
            0,
            TerminationReason::AbsoluteBracketToleranceSatisfied,
        );
    }

    #[test]
    fn test_root_bisection_root_within_tolerance() {
        root_bisection_test_helper(
            &|x: f64| x.powi(2) - 1.0,
            Interval::new(1.0 - f64::EPSILON, 1.0 + f64::EPSILON),
            None,
            1.0,
            None,
            None,
            0,
            2,
            0,
            TerminationReason::AbsoluteBracketToleranceSatisfied,
        );
    }

    #[test]
    fn test_root_bisection_zero_bracket_width() {
        root_bisection_test_helper(
            &|x: f64| x.powi(2) - 1.0,
            Interval::new(1.0, 1.0),
            None,
            1.0,
            None,
            None,
            0,
            2,
            0,
            TerminationReason::RootAtLowerBound,
        );
    }

    #[test]
    fn test_root_bisection_constant_function() {
        let solver_settings = SolverSettings::default();
        let mut convergence_data = ConvergenceData::default();
        let result = root_bisection(
            &|_x: f64| 1.0,
            Interval::new(0.0, 2.0),
            Some(&solver_settings),
            Some(&mut convergence_data),
        );
        assert!(matches!(
            result.unwrap_err(),
            SolverError::IntervalDoesNotBracketSignChange
        ));
    }

    #[test]
    fn test_root_bisection_batol() {
        let solver_settings = SolverSettings {
            batol: Some(0.1),
            ..Default::default()
        };
        root_bisection_test_helper(
            &|x: f64| x.powi(2) - 1.0,
            Interval::new(0.75, 1.5),
            Some(&solver_settings),
            1.0,
            Some(solver_settings.batol.unwrap() / 2.0),
            Some(0.1),
            3,
            5,
            0,
            TerminationReason::AbsoluteBracketToleranceSatisfied,
        );
    }

    #[test]
    fn test_root_bisection_vtol() {
        let solver_settings = SolverSettings {
            vtol: Some(0.01),
            ..Default::default()
        };
        root_bisection_test_helper(
            &|x: f64| x.powi(2) - 1.0,
            Interval::new(0.75, 1.5),
            Some(&solver_settings),
            1.0,
            Some(solver_settings.vtol.unwrap() / 2.0),
            Some(0.01),
            6,
            8,
            0,
            TerminationReason::ValueToleranceSatisfied,
        );
    }

    #[test]
    fn test_root_bisection_max_iter() {
        let solver_settings = SolverSettings {
            max_iter: Some(10),
            ..Default::default()
        };
        root_bisection_test_helper(
            &|x: f64| x.powi(2) - 1.0,
            Interval::new(0.75, 1.5),
            Some(&solver_settings),
            1.0,
            Some(0.0002),
            Some(0.0004),
            10,
            12,
            0,
            TerminationReason::MaxIterationsReached,
        );
    }

    #[test]
    fn test_root_bisection_max_feval() {
        let solver_settings = SolverSettings {
            max_feval: Some(10),
            ..Default::default()
        };
        root_bisection_test_helper(
            &|x: f64| x.powi(2) - 1.0,
            Interval::new(0.75, 1.5),
            Some(&solver_settings),
            1.0,
            Some(0.001),
            Some(0.002),
            8,
            10,
            0,
            TerminationReason::MaxFunctionEvaluationsReached,
        );
    }

    #[test]
    fn test_root_bisection_rebracket_not_needed() {
        let solver_settings = SolverSettings {
            rebracket: Some(true),
            ..Default::default()
        };
        root_bisection_test_helper(
            &|x: f64| x.powi(2) - 1.0,
            Interval::new(0.0, 2.0),
            Some(&solver_settings),
            1.0,
            None,
            Some(2.0 * f64::EPSILON),
            52,
            54,
            0,
            TerminationReason::AbsoluteBracketToleranceSatisfied,
        );
    }

    #[test]
    fn test_root_bisection_rebracket_successful() {
        let solver_settings = SolverSettings {
            rebracket: Some(true),
            ..Default::default()
        };
        root_bisection_test_helper(
            &|x: f64| x.powi(2) - 1.0,
            Interval::new(1.5, 2.5),
            Some(&solver_settings),
            1.0,
            None,
            Some(2.0 * f64::EPSILON),
            53,
            59,
            2,
            TerminationReason::AbsoluteBracketToleranceSatisfied,
        );
    }

    #[test]
    fn test_root_bisection_rebracket_failed() {
        let solver_settings = SolverSettings {
            rebracket: Some(true),
            ..Default::default()
        };
        let mut convergence_data = ConvergenceData::default();
        let result = root_bisection(
            &|x: f64| x.powi(2) + 1.0,
            Interval::new(-2.0, 2.0),
            Some(&solver_settings),
            Some(&mut convergence_data),
        );
        assert!(matches!(
            result.unwrap_err(),
            SolverError::BracketingIntervalNotFound
        ));
    }

    #[test]
    fn test_root_bisection_rebracket_exceeds_max_feval() {
        let solver_settings = SolverSettings {
            rebracket: Some(true),
            max_feval: Some(5),
            ..Default::default()
        };
        let result = root_bisection(
            &|x: f64| x.powi(2) - 1.0,
            Interval::new(1.5, 2.5),
            Some(&solver_settings),
            None,
        );
        assert!(matches!(
            result.unwrap_err(),
            SolverError::BracketingIntervalNotFound
        ));
    }

    #[test]
    fn test_root_bisection_max_bracket_iter_successful() {
        let solver_settings = SolverSettings {
            rebracket: Some(true),
            max_bracket_iter: Some(5),
            ..Default::default()
        };
        root_bisection_test_helper(
            &|x: f64| x.powi(3) - 1.0,
            Interval::new(1000.0, 1100.0),
            Some(&solver_settings),
            1.0,
            Some(9.0),
            Some(99.0),
            63,
            75,
            5,
            TerminationReason::AbsoluteBracketToleranceSatisfied,
        );
    }

    #[test]
    fn test_root_bisection_max_bracket_iter_failed() {
        let solver_settings = SolverSettings {
            rebracket: Some(true),
            max_bracket_iter: Some(2),
            ..Default::default()
        };
        let result = root_bisection(
            &|x: f64| x.powi(3) - 1.0,
            Interval::new(1000.0, 1100.0),
            Some(&solver_settings),
            None,
        );
        assert!(matches!(
            result.unwrap_err(),
            SolverError::BracketingIntervalNotFound
        ));
    }

    #[test]
    fn test_root_bisection_iterates() {
        let f = |x: f64| x.powi(3) - x - 2.0;
        let ab = Interval::new(1.0, 2.0);
        let mut convergence_data = ConvergenceData::default();
        let solver_settings = SolverSettings {
            max_iter: Some(14),
            ..Default::default()
        };
        let root =
            root_bisection(&f, ab, Some(&solver_settings), Some(&mut convergence_data)).unwrap();
        // The last recorded iterate is the returned root.
        assert_eq!(root, *convergence_data.x_all.last().unwrap());
        assert_arrays_equal_to_decimal!(
            convergence_data.x_all,
            [
                1.5,
                1.75,
                1.625,
                1.5625,
                1.53125,
                1.515625,
                1.5234375,
                1.51953125,
                1.521484375,
                1.5205078125,
                1.52099609375,
                1.521240234375,
                1.5213623046875,
                1.52142333984375,
                1.521392822265625
            ],
            16
        );
        assert_arrays_equal_to_decimal!(
            convergence_data.a_all,
            [
                1.0,
                1.5,
                1.5,
                1.5,
                1.5,
                1.5,
                1.515625,
                1.515625,
                1.51953125,
                1.51953125,
                1.5205078125,
                1.52099609375,
                1.521240234375,
                1.5213623046875,
                1.5213623046875
            ],
            16
        );
        assert_arrays_equal_to_decimal!(
            convergence_data.b_all,
            [
                2.0,
                2.0,
                1.75,
                1.625,
                1.5625,
                1.53125,
                1.53125,
                1.5234375,
                1.5234375,
                1.521484375,
                1.521484375,
                1.521484375,
                1.521484375,
                1.521484375,
                1.52142333984375
            ],
            16
        );
        // The final entry is NaN because `f` is not evaluated at the returned midpoint.
        assert_arrays_equal_to_decimal!(
            convergence_data.f_all,
            [
                -0.125,
                1.609375,
                0.666015625,
                0.252197265625,
                0.059112548828125,
                -0.034053802490234375,
                0.012250423431396484,
                -0.010971248149871826,
                6.221756339073181e-4,
                -5.178886465728283e-3,
                -2.279443317092955e-3,
                -8.289058605441824e-4,
                -1.034331235132413e-4,
                2.593542519662151e-4,
                f64::NAN
            ],
            16
        );
    }
}