use crate::error::{IntegrateError, IntegrateResult};
use crate::IntegrateFloat;
use std::f64::consts::PI;
use std::fmt::Debug;
/// Tuning knobs for [`quad`].
#[derive(Debug, Clone)]
pub struct QuadOptions<F: IntegrateFloat> {
    /// Absolute error tolerance (default `1.49e-8`).
    pub abs_tol: F,
    /// Relative error tolerance, scaled by the magnitude of the current
    /// estimate (default `1.49e-8`). Not used when `use_abs_error` is set.
    pub rel_tol: F,
    /// Integrand-evaluation budget for the adaptive scheme (default `500`).
    pub max_evals: usize,
    /// When `true`, estimates are accepted against `abs_tol` alone; otherwise
    /// the tolerance is `abs_tol + rel_tol * |estimate|` (default `false`).
    pub use_abs_error: bool,
    /// When `true`, `quad` uses a fixed 1000-interval composite Simpson rule
    /// instead of the adaptive scheme (default `false`).
    pub use_simpson: bool,
}

impl<F: IntegrateFloat> Default for QuadOptions<F> {
    fn default() -> Self {
        // Both tolerances share the same default magnitude.
        let default_tol = F::from_f64(1.49e-8).expect("Operation failed");
        Self {
            abs_tol: default_tol,
            rel_tol: default_tol,
            max_evals: 500,
            use_abs_error: false,
            use_simpson: false,
        }
    }
}

/// Outcome of a [`quad`] call.
#[derive(Debug, Clone)]
pub struct QuadResult<F: IntegrateFloat> {
    /// Estimated value of the integral.
    pub value: F,
    /// Error estimate reported by the integrator.
    pub abs_error: F,
    /// Number of integrand evaluations spent.
    pub n_evals: usize,
    /// Whether the integrator reports that it met the requested tolerance.
    pub converged: bool,
}
/// Composite trapezoidal rule for `f` over `[a, b]` with `n` equal sub-intervals.
///
/// The two endpoints carry weight 1/2 and the `n - 1` interior nodes weight 1.
/// The weighted sum is scaled by the step `h = (b - a) / n`. `f` is called at
/// `a`, then at `b`, then at the interior nodes in increasing index order.
///
/// Returns zero when `n == 0`.
#[allow(dead_code)]
pub fn trapezoid<F, Func>(f: Func, a: F, b: F, n: usize) -> F
where
    F: IntegrateFloat,
    Func: Fn(F) -> F,
{
    if n == 0 {
        return F::zero();
    }
    let step = (b - a) / F::from_usize(n).expect("Operation failed");
    let half = F::from_f64(0.5).expect("Operation failed");
    let endpoints = half * (f(a) + f(b));
    let total = (1..n)
        .map(|k| a + F::from_usize(k).expect("Operation failed") * step)
        .fold(endpoints, |acc, node| acc + f(node));
    total * step
}
/// Composite Simpson's rule for `f` over `[a, b]` with `n` equal sub-intervals.
///
/// Interior nodes with odd index get weight 4, interior nodes with even index
/// get weight 2, and the endpoints get weight 1. The weighted sum is scaled by
/// `h / 3` with `h = (b - a) / n`. The interior nodes are evaluated first, in
/// increasing order, followed by `f(a)` and then `f(b)`.
///
/// Returns `Ok(0)` when `n == 0`.
///
/// # Errors
///
/// Returns [`IntegrateError::ValueError`] when `n` is odd, because Simpson's
/// rule needs an even number of intervals.
#[allow(dead_code)]
pub fn simpson<F, Func>(mut f: Func, a: F, b: F, n: usize) -> IntegrateResult<F>
where
    F: IntegrateFloat,
    Func: FnMut(F) -> F,
{
    if n == 0 {
        return Ok(F::zero());
    }
    if !n.is_multiple_of(2) {
        return Err(IntegrateError::ValueError(
            "Number of intervals must be even".to_string(),
        ));
    }
    let step = (b - a) / F::from_usize(n).expect("Operation failed");
    // Keep the even- and odd-indexed partial sums apart so each weight is
    // applied once to its whole sum.
    let (even_total, odd_total) = (1..n).fold((F::zero(), F::zero()), |(even, odd), k| {
        let node = a + F::from_usize(k).expect("Operation failed") * step;
        if k % 2 == 0 {
            (even + f(node), odd)
        } else {
            (even, odd + f(node))
        }
    });
    let two = F::from_f64(2.0).expect("Operation failed");
    let four = F::from_f64(4.0).expect("Operation failed");
    let three = F::from_f64(3.0).expect("Operation failed");
    Ok((f(a) + f(b) + two * even_total + four * odd_total) * step / three)
}
/// Numerically integrate `f` over `[a, b]`.
///
/// By default this delegates to the adaptive bisection scheme in
/// `adaptive_quad_impl`. When `options.use_simpson` is set, a fixed composite
/// Simpson rule with 1000 intervals is used instead. Its error is estimated as
/// the difference from the 500-interval rule. `converged` then reports whether
/// that estimate is within `abs_tol`, plus `rel_tol * |value|` unless
/// `use_abs_error` is set.
///
/// `options` falls back to `QuadOptions::default()` when `None`.
///
/// # Errors
///
/// Propagates any error from the underlying rule. In particular, the adaptive
/// scheme returns `IntegrateError::ConvergenceError` when its evaluation budget
/// (`max_evals`) runs out before the tolerance is met.
#[allow(dead_code)]
pub fn quad<F, Func>(
    f: Func,
    a: F,
    b: F,
    options: Option<QuadOptions<F>>,
) -> IntegrateResult<QuadResult<F>>
where
    F: IntegrateFloat,
    Func: Fn(F) -> F + Copy,
{
    let opts = options.unwrap_or_default();

    if opts.use_simpson {
        const N_FINE: usize = 1000;
        const N_COARSE: usize = N_FINE / 2;
        let value = simpson(f, a, b, N_FINE)?;
        let coarse = simpson(f, a, b, N_COARSE)?;
        // Estimate the error from the disagreement of two rules, the same
        // `|fine - coarse|` measure the adaptive path uses. This makes
        // `abs_error` depend on the integrand.
        let abs_error = (value - coarse).abs();
        let tolerance = if opts.use_abs_error {
            opts.abs_tol
        } else {
            opts.abs_tol + opts.rel_tol * value.abs()
        };
        return Ok(QuadResult {
            value,
            abs_error,
            // An n-interval composite Simpson rule evaluates `f` at n + 1 nodes.
            n_evals: (N_FINE + 1) + (N_COARSE + 1),
            // `<=` is false for NaN, so a NaN estimate is not reported as converged.
            converged: abs_error <= tolerance,
        });
    }

    let mut n_evals = 0;
    let (value, abs_error, converged) = adaptive_quad_impl(f, a, b, &mut n_evals, &opts)?;
    Ok(QuadResult {
        value,
        abs_error,
        n_evals,
        converged,
    })
}
#[allow(dead_code)]
fn adaptive_quad_impl<F, Func>(
f: Func,
a: F,
b: F,
n_evals: &mut usize,
options: &QuadOptions<F>,
) -> IntegrateResult<(F, F, bool)>
where
F: IntegrateFloat,
Func: Fn(F) -> F + Copy,
{
let n_initial = 10; let mut eval_count_coarse = 0;
let coarse_result = {
let f_with_count = |x: F| {
eval_count_coarse += 1;
f(x)
};
simpson(f_with_count, a, b, n_initial)?
};
*n_evals += eval_count_coarse;
let n_refined = 20; let mut eval_count_refined = 0;
let refined_result = {
let f_with_count = |x: F| {
eval_count_refined += 1;
f(x)
};
simpson(f_with_count, a, b, n_refined)?
};
*n_evals += eval_count_refined;
let error = (refined_result - coarse_result).abs();
let tolerance = if options.use_abs_error {
options.abs_tol
} else {
options.abs_tol + options.rel_tol * refined_result.abs()
};
let converged = error <= tolerance || *n_evals >= options.max_evals;
if *n_evals >= options.max_evals && error > tolerance {
return Err(IntegrateError::ConvergenceError(format!(
"Failed to converge after {} function evaluations",
*n_evals
)));
}
if !converged {
let mid = (a + b) / F::from_f64(2.0).expect("Operation failed");
let (left_value, left_error, left_converged) =
adaptive_quad_impl(f, a, mid, n_evals, options)?;
let (right_value, right_error, right_converged) =
adaptive_quad_impl(f, mid, b, n_evals, options)?;
let value = left_value + right_value;
let abs_error = left_error + right_error;
let sub_converged = left_converged && right_converged;
return Ok((value, abs_error, sub_converged));
}
Ok((refined_result, error, converged))
}
/// Composite Simpson's rule that also tallies integrand evaluations.
///
/// Uses the same weights and arithmetic as [`simpson`], but takes `f` by
/// mutable reference. `count` is incremented once per call to `f`, which is
/// `n + 1` calls in total for a non-zero even `n`. The endpoints are evaluated
/// first (`f(a)`, then `f(b)`), followed by the interior nodes in increasing
/// order.
///
/// Returns `Ok(0)` without touching `count` when `n == 0`.
///
/// # Errors
///
/// Returns [`IntegrateError::ValueError`] when `n` is odd. `count` is left
/// unchanged in that case.
#[allow(dead_code)]
fn simpson_with_count<F, Func>(
    f: &mut Func,
    a: F,
    b: F,
    n: usize,
    count: &mut usize,
) -> IntegrateResult<F>
where
    F: IntegrateFloat,
    Func: FnMut(F) -> F,
{
    if n == 0 {
        return Ok(F::zero());
    }
    if !n.is_multiple_of(2) {
        return Err(IntegrateError::ValueError(
            "Number of intervals must be even".to_string(),
        ));
    }
    let step = (b - a) / F::from_usize(n).expect("Operation failed");

    // Both endpoint evaluations are tallied up front.
    *count += 2;
    let f_left = f(a);
    let f_right = f(b);

    let mut even_total = F::zero();
    let mut odd_total = F::zero();
    for k in 1..n {
        let node = a + F::from_usize(k).expect("Operation failed") * step;
        *count += 1;
        let sample = f(node);
        let slot = if k % 2 == 0 {
            &mut even_total
        } else {
            &mut odd_total
        };
        *slot += sample;
    }

    let two = F::from_f64(2.0).expect("Operation failed");
    let four = F::from_f64(4.0).expect("Operation failed");
    let three = F::from_f64(3.0).expect("Operation failed");
    Ok((f_left + f_right + two * even_total + four * odd_total) * step / three)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use std::f64::consts::PI;

    #[test]
    fn test_trapezoid_rule() {
        let result = trapezoid(|x| x * x, 0.0, 1.0, 100);
        assert_relative_eq!(result, 1.0 / 3.0, epsilon = 1e-4);
        let result = trapezoid(|x| x.sin(), 0.0, PI, 1000);
        assert_relative_eq!(result, 2.0, epsilon = 1e-4);
    }

    #[test]
    fn test_simpson_rule() {
        let result = simpson(|x| x * x, 0.0, 1.0, 100).expect("Operation failed");
        assert_relative_eq!(result, 1.0 / 3.0, epsilon = 1e-8);
        let result = simpson(|x| x.sin(), 0.0, PI, 100).expect("Operation failed");
        assert_relative_eq!(result, 2.0, epsilon = 1e-6);
        // An odd interval count is rejected.
        let error = simpson(|x| x * x, 0.0, 1.0, 99);
        assert!(error.is_err());
    }

    #[test]
    fn test_zero_intervals_yield_zero() {
        // Both fixed rules short-circuit to zero when asked for no intervals.
        assert_eq!(trapezoid(|x: f64| x + 1.0, 0.0, 1.0, 0), 0.0);
        assert_eq!(
            simpson(|x: f64| x + 1.0, 0.0, 1.0, 0).expect("Operation failed"),
            0.0
        );
    }

    #[test]
    fn test_simpson_with_count_tallies_evaluations() {
        let mut calls = 0usize;
        let mut count = 0usize;
        let mut integrand = |x: f64| {
            calls += 1;
            x * x
        };
        let result = simpson_with_count(&mut integrand, 0.0, 1.0, 10, &mut count)
            .expect("Operation failed");
        assert_relative_eq!(result, 1.0 / 3.0, epsilon = 1e-12);
        // An n-interval rule samples n + 1 nodes, and the tally must match.
        assert_eq!(count, 11);
        assert_eq!(calls, 11);
    }

    #[test]
    fn test_adaptive_quad() {
        let result = quad(|x| x * x, 0.0, 1.0, None).expect("Operation failed");
        assert_relative_eq!(result.value, 1.0 / 3.0, epsilon = 1e-8);
        assert!(result.converged);
        let options = QuadOptions {
            use_simpson: true,
            ..Default::default()
        };
        let result = quad(|x: f64| x.cos(), 0.0, PI / 2.0, Some(options))
            .expect("Failed to integrate");
        assert_relative_eq!(result.value, 1.0, epsilon = 1e-6);
    }
}