use num_traits::Float;
use super::newton_cotes::Formula;
use crate::integrators::DEFAULT_MAXIMUM_CUT_COUNT;
use crate::DEFAULT_TOL;
use crate::{SolverError, SolverResult};
/// Adaptive Newton–Cotes quadrature of a real function over a single interval.
///
/// Bundles the integrand with the settings used during integration: the Newton–Cotes
/// rule applied to each sub-interval, the error tolerance, and the maximum number of
/// interval bisections allowed before giving up.
pub struct AdaptiveNewtonCotes<T, F> {
    /// Newton–Cotes rule evaluated on every sub-interval.
    formula: Formula,
    /// Target error tolerance, shared out between sub-intervals in proportion to their width.
    tolerance: T,
    /// Number of bisections after which integration stops with an error.
    maximum_cut_count: usize,
    /// The integrand.
    f: F,
}
impl<T, F> AdaptiveNewtonCotes<T, F>
where
    T: Float,
    F: Fn(T) -> T,
{
    /// Creates an integrator for `f`.
    ///
    /// Defaults: Simpson's 1/3 rule, `DEFAULT_MAXIMUM_CUT_COUNT` bisections and a
    /// tolerance of `DEFAULT_TOL`.
    ///
    /// # Panics
    ///
    /// Panics if `DEFAULT_TOL` cannot be represented as `T`.
    pub fn new(f: F) -> Self {
        Self {
            f,
            formula: Formula::SimpsonsOneThird,
            maximum_cut_count: DEFAULT_MAXIMUM_CUT_COUNT,
            tolerance: T::from(DEFAULT_TOL).unwrap(),
        }
    }

    /// Selects the Newton–Cotes rule applied to each sub-interval. Returns `self` for chaining.
    pub fn with_formula(&mut self, formula: Formula) -> &mut Self {
        self.formula = formula;
        self
    }

    /// Sets how many bisections [`integrate`](Self::integrate) may perform before it
    /// fails with `SolverError::MaxIterReached`. Returns `self` for chaining.
    pub fn with_maximum_cut_count(&mut self, maximum_cut_count: usize) -> &mut Self {
        self.maximum_cut_count = maximum_cut_count;
        self
    }

    /// Sets the target error tolerance. Returns `self` for chaining.
    pub fn with_tolerance(&mut self, tolerance: T) -> &mut Self {
        self.tolerance = tolerance;
        self
    }

    /// Integrates `f` from `from_0` to `to_0` by adaptive bisection.
    ///
    /// For each pending sub-interval the selected rule is evaluated once over the whole
    /// sub-interval and once as the sum over its two halves. The two-halves estimate is
    /// accepted when the difference between the estimates is below
    /// `k * tolerance * (half_width / (to_0 - from_0))`. Here `k = 2^p - 1` is the
    /// Richardson factor of the rule: 3 for the trapezium rule and 15 for both Simpson
    /// rules. If the estimate is not accepted, the sub-interval is bisected and both halves
    /// are processed in turn.
    ///
    /// Equal bounds return zero without evaluating `f`. With `to_0 < from_0` every step
    /// width is negative, so the result carries the opposite sign (oriented integral).
    ///
    /// # Errors
    ///
    /// Returns `SolverError::MaxIterReached(cuts)` as soon as more than
    /// `maximum_cut_count` bisections have been needed.
    pub fn integrate(&self, from_0: T, to_0: T) -> SolverResult<T> {
        // An empty interval integrates to zero. Without this guard the relative tolerance
        // below would be `0 / 0 = NaN`, no estimate could ever be accepted, and the loop
        // would keep bisecting `[a, a]` until it reported `MaxIterReached`.
        if from_0 == to_0 {
            return Ok(T::zero());
        }

        type FormulaF<T, F> = fn(&AdaptiveNewtonCotes<T, F>, T, T) -> T;
        let (formula, error_scaling): (FormulaF<T, F>, T) = match self.formula {
            Formula::Trapezium => (Self::trapezium, T::from(3.).unwrap()),
            Formula::SimpsonsOneThird => (Self::simpsons_one_third, T::from(15.).unwrap()),
            Formula::SimpsonsThreeEighths => (Self::simpsons_three_eighths, T::from(15.).unwrap()),
        };
        let scaled_tolerance = error_scaling * self.tolerance;
        let half = T::from(0.5).unwrap();
        let delta_0 = to_0 - from_0;

        // Depth-first stack of pending sub-intervals. Each rejection pops one entry and
        // pushes two, so its length never exceeds the bisection depth + 1. Reserving
        // `maximum_cut_count` slots up front would over-allocate, and it panics on
        // capacity overflow for huge limits such as `usize::MAX`. Reserve a small bounded
        // amount and let the Vec grow if needed.
        let mut intervals = Vec::with_capacity(self.maximum_cut_count.saturating_add(1).min(64));
        intervals.push((from_0, to_0));
        let mut result = T::zero();
        let mut number_of_cuts = 0;

        while let Some((from_i, to_i)) = intervals.pop() {
            // `delta_i` is half the width of the current sub-interval.
            let delta_i = (to_i - from_i) * half;
            let mid_i = from_i + delta_i;
            let integral_full = formula(self, from_i, to_i);
            let integral_split = formula(self, from_i, mid_i) + formula(self, mid_i, to_i);
            let error = (integral_full - integral_split).abs();
            // The tolerance is shared out in proportion to this sub-interval's share of
            // the total width. `delta_i / delta_0` stays positive even for reversed bounds.
            if error < scaled_tolerance * delta_i / delta_0 {
                result = result + integral_split;
            } else {
                intervals.push((from_i, mid_i));
                intervals.push((mid_i, to_i));
                number_of_cuts += 1;
            }
            if number_of_cuts > self.maximum_cut_count {
                return Err(SolverError::MaxIterReached(number_of_cuts));
            }
        }
        Ok(result)
    }

    /// Trapezium rule on `[x0, x1]`: `(x1 - x0) / 2 * (f(x0) + f(x1))`.
    fn trapezium(&self, x0: T, x1: T) -> T {
        let half = T::from(0.5).unwrap();
        let h = x1 - x0;
        let f = &self.f;
        half * h * (f(x0) + f(x1))
    }

    /// Simpson's 1/3 rule on `[x0, x2]`.
    ///
    /// Uses the midpoint `x1 = x0 + h`, with `h = (x2 - x0) / 2`:
    /// `h / 3 * (f(x0) + 4 f(x1) + f(x2))`.
    fn simpsons_one_third(&self, x0: T, x2: T) -> T {
        let third = T::from(1. / 3.).unwrap();
        let four = T::from(4.).unwrap();
        let half = T::from(0.5).unwrap();
        let h = (x2 - x0) * half;
        let x1 = x0 + h;
        let f = &self.f;
        third * h * (f(x0) + four * f(x1) + f(x2))
    }

    /// Simpson's 3/8 rule on `[x0, x3]`.
    ///
    /// Uses the interior points `x1 = x0 + h` and `x2 = x1 + h`, with `h = (x3 - x0) / 3`:
    /// `3h / 8 * (f(x0) + 3 (f(x1) + f(x2)) + f(x3))`.
    fn simpsons_three_eighths(&self, x0: T, x3: T) -> T {
        let three_eighths = T::from(3. / 8.).unwrap();
        let three = T::from(3.0).unwrap();
        let third = three.recip();
        let h = (x3 - x0) * third;
        let x1 = x0 + h;
        let x2 = x1 + h;
        let f = &self.f;
        three_eighths * h * (f(x0) + three * (f(x1) + f(x2)) + f(x3))
    }
}