use crate::rootfinder::{Rootfinder, RootfinderData};
pub struct Bisection<F>
where
F: Fn(f64) -> f64,
{
function: F,
guess: f64,
data: RootfinderData,
}
impl<F> Bisection<F>
where
F: Fn(f64) -> f64,
{
pub fn new(function: F, guess: f64, data: RootfinderData) -> Self {
Self {
function,
guess,
data,
}
}
}
impl<F> Rootfinder<F> for Bisection<F>
where
F: Fn(f64) -> f64,
{
fn value(&self, x: f64) -> f64 {
(self.function)(x)
}
fn derivative(&self, _: f64) -> f64 {
0.0
}
fn solve_impl(&mut self) -> f64 {
let mut dx: f64;
let mut x_mid: f64;
let mut f_mid: f64;
if self.data.y_min < 0.0 {
dx = self.data.x_max - self.data.x_min;
self.data.root = self.data.x_min;
} else {
dx = self.data.x_min - self.data.x_max;
self.data.root = self.data.x_max;
}
while self.data.iteration_count <= Self::MAX_ITERATIONS {
dx /= 2.0;
x_mid = self.data.root + dx;
f_mid = self.value(x_mid);
self.data.increment_evaluation_count();
if f_mid <= 0.0 {
self.data.root = x_mid;
}
if dx.abs() < self.data.accuracy || RootfinderData::close(f_mid, 0.0) {
return self.data.root;
}
}
0.0
}
fn solve(&mut self) -> f64 {
assert!(self.data.accuracy > 0., "accuracy must be positive");
self.data.accuracy = f64::max(self.data.accuracy, f64::EPSILON);
let growth_factor = 1.6;
let mut flipflop = -1;
self.data.root = self.guess;
self.data.y_max = self.value(self.data.root);
if RootfinderData::close(self.data.y_max, 0.0) {
return self.data.root;
} else if self.data.y_max > 0.0 {
self.data.x_min = self
.data
.enforce_bounds(self.data.root - self.data.stepsize);
self.data.y_min = self.value(self.data.x_min);
self.data.x_max = self.data.root;
} else {
self.data.x_min = self.data.root;
self.data.y_min = self.data.y_max;
self.data.x_max = self
.data
.enforce_bounds(self.data.root + self.data.stepsize);
self.data.y_max = self.value(self.data.x_max);
}
self.data.iteration_count = 2;
while self.data.iteration_count <= Self::MAX_ITERATIONS {
if self.data.y_min * self.data.y_max <= 0.0 {
if RootfinderData::close(self.data.y_min, 0.0) {
return self.data.x_min;
}
if RootfinderData::close(self.data.y_max, 0.0) {
return self.data.x_max;
}
self.data.root = 0.5 * (self.data.x_max + self.data.x_min);
return self.solve_impl();
}
if self.data.y_min.abs() < self.data.y_max.abs() {
self.data.x_min = self.data.enforce_bounds(
self.data.x_min + growth_factor * (self.data.x_min - self.data.x_max),
);
self.data.y_min = self.value(self.data.x_min);
} else if self.data.y_min.abs() > self.data.y_max.abs() {
self.data.x_max = self.data.enforce_bounds(
self.data.x_max + growth_factor * (self.data.x_max - self.data.x_min),
);
self.data.y_max = self.value(self.data.x_max);
} else if flipflop == -1 {
self.data.x_min = self.data.enforce_bounds(
self.data.x_min + growth_factor * (self.data.x_min - self.data.x_max),
);
self.data.y_min = self.value(self.data.x_min);
self.data.increment_evaluation_count();
flipflop = 1;
} else if flipflop == 1 {
self.data.x_max = self.data.enforce_bounds(
self.data.x_max + growth_factor * (self.data.x_max - self.data.x_min),
);
self.data.y_max = self.value(self.data.x_max);
flipflop = -1;
}
self.data.increment_evaluation_count();
}
0.0
}
}
#[cfg(test)]
mod TESTS_bisection_solver {
use super::*;
use std::f64::consts::SQRT_2;
#[test]
fn test_bisection_solver() {
let f = |x: f64| x.powi(2) - 2.0;
let data = RootfinderData::new(1e-15, 1e-5, 0.0, 2.0, true);
let mut solver = Bisection::new(f, 1.0, data);
let root = solver.solve();
assert!((root - SQRT_2) < 1e-15);
}
}