use crate::{SolverError, SolverResult};
use num_traits::Float;
/// Number of equal-width subintervals an integrator starts with when no
/// explicit count is configured. The value is even, so Simpson's 1/3 rule
/// can be applied with whole two-interval panels.
pub const DEFAULT_SUBDIVISIONS: usize = 1_000;
/// Selects which composite Newton–Cotes rule an integrator applies.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Formula {
    /// Composite trapezium rule: linear interpolation between neighbouring samples.
    Trapezium,
    /// Composite Simpson's 1/3 rule (weights 1, 4, 1 over two intervals).
    /// This is the rule `NewtonCotes::new` starts with.
    SimpsonsOneThird,
    /// Composite Simpson's 3/8 rule (weights 1, 3, 3, 1 over three intervals).
    SimpsonsThreeEighths,
}
/// Numerical integrator that samples a function on an equally spaced grid
/// and combines the samples with a composite Newton–Cotes rule.
pub struct NewtonCotes<T, F> {
    /// The integrand, evaluated at every grid point.
    f: F,
    /// How many equal-width subintervals the integration range is split into.
    subdivisions: usize,
    /// The quadrature rule, called with the sampled values and the step width.
    formula: fn(&Self, &[T], T) -> T,
}
impl<T, F> NewtonCotes<T, F>
where
    F: Fn(T) -> T,
    T: Float,
{
    /// Creates an integrator for `f` that uses Simpson's 1/3 rule over
    /// [`DEFAULT_SUBDIVISIONS`] equal-width subintervals.
    pub fn new(f: F) -> Self {
        Self {
            f,
            subdivisions: DEFAULT_SUBDIVISIONS,
            formula: Self::simpsons_one_third,
        }
    }

    /// Selects the composite Newton–Cotes rule used by [`Self::integrate`].
    pub fn with_formula(&mut self, formula: Formula) -> &mut Self {
        self.formula = match formula {
            Formula::Trapezium => Self::trapezium,
            Formula::SimpsonsOneThird => Self::simpsons_one_third,
            Formula::SimpsonsThreeEighths => Self::simpsons_three_eighths,
        };
        self
    }

    /// Sets the number of equal-width subintervals used by [`Self::integrate`].
    ///
    /// The count must be at least 1; `integrate` rejects 0.
    ///
    /// Any positive count is integrated over the whole range. When the count is
    /// not a multiple of the selected Simpson rule's panel width (2 intervals for
    /// 1/3, 3 for 3/8), the leftover intervals are covered by the other Simpson
    /// rule. A single interval falls back to the trapezium rule.
    pub fn with_subdivisions(&mut self, subdivisions: usize) -> &mut Self {
        self.subdivisions = subdivisions;
        self
    }

    /// Approximates the definite integral of `f` over `[from, to]`.
    ///
    /// The integrand is sampled at `subdivisions + 1` equally spaced points,
    /// including both endpoints. The selected rule is then applied to those
    /// samples.
    ///
    /// # Errors
    ///
    /// * [`SolverError::IncorrectInput`] if `from > to`, or if the subdivision
    ///   count is 0 (the step width would be infinite or NaN).
    /// * [`SolverError::TypeConversionError`] if the subdivision count or a
    ///   sample index cannot be represented as `T`.
    pub fn integrate(&self, from: T, to: T) -> SolverResult<T> {
        if from > to {
            return Err(SolverError::IncorrectInput {
                details: "the input to-value should be at least as big as the from-value",
            });
        }
        if self.subdivisions == 0 {
            return Err(SolverError::IncorrectInput {
                details: "the number of subdivisions should be at least 1",
            });
        }
        let subdivisions_as_t =
            T::from(self.subdivisions).ok_or(SolverError::TypeConversionError)?;
        let delta = (to - from) / subdivisions_as_t;
        // Sample f at every grid point, in order. Any index that cannot be
        // converted to T aborts with an error instead of panicking.
        let subdivision_values = (0..=self.subdivisions)
            .map(|i| T::from(i).map(|i_as_t| (self.f)(from + i_as_t * delta)))
            .collect::<Option<Vec<T>>>()
            .ok_or(SolverError::TypeConversionError)?;
        Ok((self.formula)(self, &subdivision_values, delta))
    }

    /// Composite trapezium rule: `delta / 2 * Σ (y[i] + y[i + 1])`.
    fn trapezium(&self, subdivision_values: &[T], delta: T) -> T {
        let half = Self::constant(0.5);
        half * delta
            * subdivision_values
                .windows(2)
                .map(|f| f[0] + f[1])
                .fold(T::zero(), T::add)
    }

    /// Composite Simpson's 1/3 rule over the whole sample range.
    ///
    /// If the interval count is odd (and at least 3), the final three intervals
    /// use Simpson's 3/8 rule so that none are dropped. Both rules are exact for
    /// cubics, so the accuracy order is unchanged. A single interval uses the
    /// trapezium rule.
    fn simpsons_one_third(&self, subdivision_values: &[T], delta: T) -> T {
        let intervals = subdivision_values.len().saturating_sub(1);
        match intervals {
            0 => T::zero(),
            1 => self.trapezium(subdivision_values, delta),
            n if n % 2 == 0 => Self::one_third_panels(subdivision_values, delta),
            n => {
                // Odd n >= 3: an even count (n - 3) via 1/3, then one 3/8 panel.
                let split = n - 3;
                Self::one_third_panels(&subdivision_values[..=split], delta)
                    + Self::three_eighths_panels(&subdivision_values[split..], delta)
            }
        }
    }

    /// Composite Simpson's 3/8 rule over the whole sample range.
    ///
    /// If the interval count is not a multiple of 3, the trailing 2 intervals
    /// (remainder 2) or 4 intervals (remainder 1) use Simpson's 1/3 rule, so no
    /// interval is dropped. A single interval uses the trapezium rule.
    fn simpsons_three_eighths(&self, subdivision_values: &[T], delta: T) -> T {
        let intervals = subdivision_values.len().saturating_sub(1);
        match (intervals, intervals % 3) {
            (0, _) => T::zero(),
            (1, _) => self.trapezium(subdivision_values, delta),
            (_, 0) => Self::three_eighths_panels(subdivision_values, delta),
            (n, remainder) => {
                // The tail must have an even interval count (Simpson's 1/3), and
                // the head a multiple of 3. Here n >= 2, and n >= 4 when remainder is 1.
                let tail = if remainder == 2 { 2 } else { 4 };
                let split = n - tail;
                Self::three_eighths_panels(&subdivision_values[..=split], delta)
                    + Self::one_third_panels(&subdivision_values[split..], delta)
            }
        }
    }

    /// Sums Simpson's 1/3 panels (weights 1, 4, 1) over consecutive
    /// non-overlapping triples of `values`. `values` should span an even number
    /// of intervals; otherwise the last interval is ignored.
    fn one_third_panels(values: &[T], delta: T) -> T {
        let one_third = Self::constant(1. / 3.);
        let four = Self::constant(4.);
        one_third
            * delta
            * values
                .windows(3)
                .step_by(2)
                .map(|f| f[0] + four * f[1] + f[2])
                .fold(T::zero(), T::add)
    }

    /// Sums Simpson's 3/8 panels (weights 1, 3, 3, 1) over consecutive
    /// non-overlapping quadruples of `values`. `values` should span a multiple
    /// of 3 intervals; otherwise the trailing intervals are ignored.
    fn three_eighths_panels(values: &[T], delta: T) -> T {
        let three_eighths = Self::constant(3. / 8.);
        let three = Self::constant(3.);
        three_eighths
            * delta
            * values
                .windows(4)
                .step_by(3)
                .map(|f| f[0] + three * (f[1] + f[2]) + f[3])
                .fold(T::zero(), T::add)
    }

    /// Converts a small literal weight into `T`.
    fn constant(value: f64) -> T {
        T::from(value).expect("small finite f64 constants are representable by every Float type")
    }
}