use std::collections::HashMap;
use serde::Deserialize;
use serde::Serialize;
use crate::numerical::elementary::eval_expr;
use crate::symbolic::core::Expr;
/// Numerical integration rule dispatched on by [`quadrature`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuadratureMethod {
    /// Composite trapezoidal rule over the caller's `n_steps` equal panels.
    Trapezoidal,
    /// Composite Simpson rule over `n_steps` panels (an odd count is rounded up to even).
    Simpson,
    /// Adaptive Simpson quadrature; `quadrature` runs it with a fixed absolute
    /// tolerance of `1e-6` and ignores `n_steps`.
    Adaptive,
    /// Romberg integration; `quadrature` runs it with a fixed 6-row tableau and
    /// ignores `n_steps`.
    Romberg,
    /// Fixed 5-point Gauss–Legendre rule; `n_steps` is ignored.
    GaussLegendre,
}
/// Integrates `f` over `range` with the composite trapezoidal rule on `n_steps` equal panels.
///
/// Returns `0.0` when `n_steps` is zero or when the two endpoints differ by less than
/// `f64::EPSILON` (an absolute threshold, independent of the range's magnitude).
/// A reversed range (`range.0 > range.1`) yields the negated integral.
pub fn trapezoidal_rule<F>(
    f: F,
    range: (f64, f64),
    n_steps: usize,
) -> f64
where
    F: Fn(f64) -> f64,
{
    let (lo, hi) = range;
    if n_steps == 0 || (lo - hi).abs() < f64::EPSILON {
        return 0.0;
    }
    let width = (hi - lo) / n_steps as f64;
    // Endpoints carry half weight; every interior node carries full weight.
    let endpoints = 0.5 * (f(lo) + f(hi));
    let weighted_sum = (1..n_steps)
        .map(|i| f((i as f64).mul_add(width, lo)))
        .fold(endpoints, |acc, y| acc + y);
    width * weighted_sum
}
/// Integrates `f` over `range` with the composite Simpson rule.
///
/// Simpson's rule works on pairs of panels, so an odd `n_steps` is rounded up to the
/// next even count. Returns `Ok(0.0)` when `n_steps` is zero or when the endpoints
/// differ by less than `f64::EPSILON` (absolute threshold). A reversed range yields
/// the negated integral.
///
/// # Errors
///
/// The current implementation never returns `Err`; the `Result` is part of the
/// public signature.
pub fn simpson_rule<F>(
    f: F,
    range: (f64, f64),
    n_steps: usize,
) -> Result<f64, String>
where
    F: Fn(f64) -> f64,
{
    let (lo, hi) = range;
    if n_steps == 0 || (lo - hi).abs() < f64::EPSILON {
        return Ok(0.0);
    }
    // Adds 1 only when `n_steps` is odd, making the panel count even.
    let panels = n_steps + n_steps % 2;
    let width = (hi - lo) / panels as f64;
    // Weights follow the 1, 4, 2, 4, ..., 2, 4, 1 pattern; the endpoints seed the fold.
    let weighted_sum = (1..panels).fold(f(lo) + f(hi), |acc, i| {
        let weight = if i % 2 == 1 { 4.0 } else { 2.0 };
        acc + weight * f((i as f64).mul_add(width, lo))
    });
    Ok(width / 3.0 * weighted_sum)
}
/// Integrates `f` over `range` with adaptive Simpson quadrature.
///
/// Each panel is compared with the sum of its two halves. When the Richardson error
/// estimate `|S_halves - S_whole| / 15` is within the panel's share of `tolerance`,
/// the corrected value `S_halves + (S_halves - S_whole) / 15` is accepted. Otherwise
/// both halves are refined with half the tolerance each.
///
/// Termination is guaranteed by three limits:
/// * a per-path bisection depth of 100;
/// * a global budget of `2^20` refinements, after which pending panels return their
///   current (unrefined) Simpson estimate;
/// * an immediate stop when the error estimate is NaN, i.e. when `f` produced NaN (or
///   `inf - inf`) at some sample. The NaN result then propagates to the caller instead
///   of recursing forever.
///
/// Returns `0.0` when the endpoints differ by less than `f64::EPSILON` (absolute
/// threshold). A reversed range yields the negated integral.
pub fn adaptive_quadrature<F>(
    f: F,
    range: (f64, f64),
    tolerance: f64,
) -> f64
where
    F: Fn(f64) -> f64,
{
    // Maximum bisection depth along any single path.
    const MAX_DEPTH: usize = 100;
    // Maximum total number of refinements; bounds total work when the tolerance can
    // never be met (rounding-noise floor, zero/NaN tolerance), which depth alone does not.
    const MAX_REFINEMENTS: usize = 1 << 20;

    // One Simpson panel [a, b] with cached samples at a, its midpoint and b, plus the
    // panel's Simpson estimate. Caching lets each refinement evaluate `f` only twice.
    #[derive(Clone, Copy)]
    struct Panel {
        a: f64,
        b: f64,
        fa: f64,
        fm: f64,
        fb: f64,
        estimate: f64,
    }

    // Bisects `panel` until its error estimate is within `eps` or a limit is reached.
    fn refine<F>(f: &F, panel: Panel, eps: f64, depth: usize, budget: &mut usize) -> f64
    where
        F: Fn(f64) -> f64,
    {
        if depth == 0 || *budget == 0 {
            return panel.estimate;
        }
        *budget -= 1;
        let Panel {
            a,
            b,
            fa,
            fm,
            fb,
            estimate: whole,
        } = panel;
        let mid = f64::midpoint(a, b);
        let fml = f(f64::midpoint(a, mid));
        let fmr = f(f64::midpoint(mid, b));
        let left = (mid - a) / 6.0 * (4.0f64.mul_add(fml, fa) + fm);
        let right = (b - mid) / 6.0 * (4.0f64.mul_add(fmr, fm) + fb);
        let sum_halves = left + right;
        let delta = sum_halves - whole;
        let error = delta.abs() / 15.0;
        if error.is_nan() {
            // A NaN sample poisons every sub-panel too, so `error <= eps` could never
            // hold; stop here and let the NaN surface to the caller.
            return sum_halves;
        }
        if error <= eps {
            return sum_halves + delta / 15.0;
        }
        let half_eps = eps / 2.0;
        let left_panel = Panel {
            a,
            b: mid,
            fa,
            fm: fml,
            fb: fm,
            estimate: left,
        };
        let right_panel = Panel {
            a: mid,
            b,
            fa: fm,
            fm: fmr,
            fb,
            estimate: right,
        };
        refine(f, left_panel, half_eps, depth - 1, budget)
            + refine(f, right_panel, half_eps, depth - 1, budget)
    }

    let (a, b) = range;
    if (a - b).abs() < f64::EPSILON {
        return 0.0;
    }
    let fa = f(a);
    let fm = f(f64::midpoint(a, b));
    let fb = f(b);
    let estimate = (b - a) / 6.0 * (4.0f64.mul_add(fm, fa) + fb);
    let mut budget = MAX_REFINEMENTS;
    refine(
        &f,
        Panel {
            a,
            b,
            fa,
            fm,
            fb,
            estimate,
        },
        tolerance,
        MAX_DEPTH,
        &mut budget,
    )
}
pub fn romberg_integration<F>(
f: F,
range: (f64, f64),
max_steps: usize,
) -> f64
where
F: Fn(f64) -> f64,
{
let (a, b) = range;
if max_steps == 0 {
return 0.0;
}
if (a - b).abs() < f64::EPSILON {
return 0.0;
}
let mut r = vec![vec![0.0; max_steps]; max_steps];
let h = b - a;
r[0][0] = 0.5 * h * (f(a) + f(b));
for i in 1..max_steps {
let steps_prev = 1 << (i - 1);
let h_i = h / f64::from(1 << i);
let mut sum = 0.0;
for k in 1..=steps_prev {
let x = a + f64::from(2 * k - 1) * h_i;
sum += f(x);
}
r[i][0] = 0.5f64.mul_add(r[i - 1][0], h_i * sum);
for j in 1..=i {
let k = 4.0_f64.powi(j as i32);
r[i][j] = k.mul_add(r[i][j - 1], -r[i - 1][j - 1]) / (k - 1.0);
}
}
r[max_steps - 1][max_steps - 1]
}
/// Integrates `f` over `range` with the fixed 5-point Gauss–Legendre rule.
///
/// The reference nodes on `[-1, 1]` are mapped affinely onto `range`, so the rule is
/// exact (up to rounding) for polynomials of degree at most 9. Returns `0.0` when the
/// endpoints differ by less than `f64::EPSILON` (absolute threshold).
pub fn gauss_legendre_quadrature<F>(
    f: F,
    range: (f64, f64),
) -> f64
where
    F: Fn(f64) -> f64,
{
    // (node, weight) pairs of the 5-point rule on [-1, 1].
    const RULE: [(f64, f64); 5] = [
        (0.0, 0.568_888_888_888_889),
        (0.538_469_310_105_683_1, 0.478_628_670_499_366_5),
        (-0.538_469_310_105_683_1, 0.478_628_670_499_366_5),
        (0.906_179_845_938_664, 0.236_926_885_056_189_1),
        (-0.906_179_845_938_664, 0.236_926_885_056_189_1),
    ];
    let (lo, hi) = range;
    if (lo - hi).abs() < f64::EPSILON {
        return 0.0;
    }
    let center = f64::midpoint(lo, hi);
    let radius = (hi - lo) / 2.0;
    let weighted = RULE
        .iter()
        .fold(0.0, |acc, &(node, weight)| acc + weight * f(center + radius * node));
    radius * weighted
}
/// Numerically integrates the symbolic expression `f` with respect to `var` over `range`.
///
/// `var` is bound to each sample point and the expression is evaluated with
/// `eval_expr`. A failed evaluation is mapped to NaN. `n_steps` is only used by the
/// trapezoidal and Simpson rules. The adaptive rule uses a fixed tolerance of `1e-6`,
/// Romberg uses 6 rows, and Gauss–Legendre is a fixed 5-point rule.
///
/// # Errors
///
/// Propagates any error from `simpson_rule`, and returns an error when the selected
/// rule produces NaN (typically because the expression failed to evaluate somewhere).
pub fn quadrature(
    f: &Expr,
    var: &str,
    range: (f64, f64),
    n_steps: usize,
    method: &QuadratureMethod,
) -> Result<f64, String> {
    let integrand = |x: f64| -> f64 {
        let mut bindings = HashMap::new();
        bindings.insert(var.to_owned(), x);
        // Evaluation failures become NaN, which is turned into an Err below.
        eval_expr(f, &bindings).unwrap_or(f64::NAN)
    };
    let value = match *method {
        QuadratureMethod::Trapezoidal => trapezoidal_rule(integrand, range, n_steps),
        QuadratureMethod::Simpson => simpson_rule(integrand, range, n_steps)?,
        QuadratureMethod::Adaptive => adaptive_quadrature(integrand, range, 1e-6),
        QuadratureMethod::Romberg => romberg_integration(integrand, range, 6),
        QuadratureMethod::GaussLegendre => gauss_legendre_quadrature(integrand, range),
    };
    if value.is_nan() {
        Err("Integration resulted in NaN, likely due to evaluation error.".to_string())
    } else {
        Ok(value)
    }
}