use scirs2_core::ndarray::Array1;
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::f64::consts::PI;
use std::fmt;
use crate::error::{IntegrateError, IntegrateResult};
/// Outcome of a vector-valued integration performed by `quad_vec`.
///
/// Every array holds one entry per component of the integrand's output.
#[derive(Clone, Debug)]
pub struct QuadVecResult<T> {
/// Integral estimate: the component-wise sum of the estimates of all
/// subintervals that were still active when integration stopped.
pub integral: Array1<T>,
/// Error estimate: the component-wise sum of the per-subinterval error
/// estimates.
pub error: Array1<T>,
/// Number of integrand evaluations that were counted.
pub nfev: usize,
/// Number of subintervals held when integration stopped.
pub nintervals: usize,
/// `true` when the norm of `error` met the absolute tolerance or the
/// relative tolerance (scaled by the norm of `integral`).
pub success: bool,
}
impl<T: fmt::Display> fmt::Display for QuadVecResult<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"QuadVecResult(\n integral=[{:}],\n error=[{:}],\n nfev={},\n nintervals={},\n success={}\n)",
self.integral
.iter()
.map(|v| format!("{v}"))
.collect::<Vec<_>>()
.join(", "),
self.error
.iter()
.map(|v| format!("{v}"))
.collect::<Vec<_>>()
.join(", "),
self.nfev,
self.nintervals,
self.success
)
}
}
/// Tuning knobs for `quad_vec`.
#[derive(Clone, Debug)]
pub struct QuadVecOptions {
/// Absolute tolerance: integration is considered converged when the norm
/// of the summed error estimate is at most this value.
pub epsabs: f64,
/// Relative tolerance: integration is also considered converged when the
/// norm of the summed error estimate is at most `epsrel` times the norm of
/// the summed integral estimate.
pub epsrel: f64,
/// Norm used to reduce error and integral vectors to a scalar.
pub norm: NormType,
/// Refinement stops once the number of subintervals reaches this value.
pub limit: usize,
/// Quadrature rule applied to each subinterval.
pub rule: QuadRule,
/// Optional breakpoints. Points lying strictly inside the integration
/// range are used to split it into the initial subintervals; all other
/// points are ignored.
pub points: Option<Vec<f64>>,
}
impl Default for QuadVecOptions {
fn default() -> Self {
Self {
epsabs: 1e-10,
epsrel: 1e-8,
norm: NormType::L2,
limit: 50,
rule: QuadRule::GK21,
points: None,
}
}
}
/// Norm used to collapse a vector of errors or integral values into a
/// single scalar for tolerance checks.
///
/// The enum is field-less, so it also derives `Eq` and `Hash`; this lets it
/// be used as a map/set key and in `match` guards that require full
/// equality.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum NormType {
    /// Largest absolute value among the components.
    Max,
    /// Square root of the sum of squared components.
    L2,
}
/// Quadrature rule applied to every subinterval.
///
/// The enum is field-less, so it also derives `Eq` and `Hash`; this lets it
/// be used as a map/set key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum QuadRule {
    /// 15-point Gauss-Kronrod rule (paired with a 7-point Gauss rule for the
    /// error estimate).
    GK15,
    /// 21-point Gauss-Kronrod rule (paired with a 10-point Gauss rule for
    /// the error estimate).
    GK21,
    /// Composite trapezoid rule on 15 equally spaced points.
    Trapezoid,
}
// One piece of the integration range together with its quadrature results.
// Instances are stored in a `BinaryHeap` and compared by `error_norm` only.
#[derive(Clone, Debug)]
struct Subinterval {
// First endpoint of the subinterval.
a: f64,
// Second endpoint of the subinterval.
b: f64,
// Integral estimate over [a, b], one entry per output component.
integral: Array1<f64>,
// Error estimate over [a, b], one entry per output component.
error: Array1<f64>,
// Norm of `error`, cached because it is the key used for heap ordering.
error_norm: f64,
}
impl PartialEq for Subinterval {
    /// Two subintervals are considered equal when their cached error norms
    /// compare equal; endpoints and estimates are deliberately ignored
    /// because only the error norm takes part in heap ordering.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.error_norm, other.error_norm);
        lhs == rhs
    }
}
// Marker impl required by the `Ord` bound of `BinaryHeap`. Equality itself is
// defined by the `PartialEq` impl, which compares the `f64` field `error_norm`
// with `==`, so a NaN norm does not compare equal to itself.
impl Eq for Subinterval {}
impl PartialOrd for Subinterval {
    /// Always yields `Some`, delegating to the total order defined by the
    /// `Ord` impl so that both comparisons agree.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
impl Ord for Subinterval {
    /// Orders subintervals by ascending `error_norm`.
    ///
    /// `BinaryHeap` is a max-heap: `pop()` returns the greatest element.
    /// Adaptive refinement must bisect the subinterval that contributes the
    /// most error, so a larger error norm has to compare as `Greater`. The
    /// previous implementation compared `other` against `self`, which made
    /// the heap hand back the subinterval with the *smallest* error.
    ///
    /// Norms that cannot be compared (NaN) are treated as equal.
    fn cmp(&self, other: &Self) -> Ordering {
        self.error_norm
            .partial_cmp(&other.error_norm)
            .unwrap_or(Ordering::Equal)
    }
}
/// Reduces `array` to a single non-negative scalar using `normtype`.
///
/// * `NormType::Max` - largest absolute component (`0.0` for an empty array).
/// * `NormType::L2`  - square root of the sum of squared components.
///
/// A NaN component makes the result NaN for both norms. Previously the Max
/// norm skipped NaN entries (every comparison with NaN is false), so a NaN
/// error estimate could produce a finite norm and pass a tolerance check,
/// while the L2 norm already propagated NaN.
#[allow(dead_code)]
fn compute_norm(array: &Array1<f64>, normtype: NormType) -> f64 {
    match normtype {
        NormType::Max => array.iter().fold(0.0_f64, |largest, &val| {
            let magnitude = val.abs();
            // Once `largest` is NaN neither test below is true for finite
            // magnitudes, so the NaN is carried through to the end.
            if magnitude.is_nan() || magnitude > largest {
                magnitude
            } else {
                largest
            }
        }),
        NormType::L2 => array
            .iter()
            .fold(0.0_f64, |sum_squares, &val| sum_squares + val * val)
            .sqrt(),
    }
}
#[allow(dead_code)]
pub fn quad_vec<F>(
f: F,
a: f64,
b: f64,
options: Option<QuadVecOptions>,
) -> IntegrateResult<QuadVecResult<f64>>
where
F: Fn(f64) -> Array1<f64>,
{
let options = options.unwrap_or_default();
if !a.is_finite() || !b.is_finite() {
return Err(IntegrateError::ValueError(
"Integration limits must be finite".to_string(),
));
}
if (b - a).abs() <= f64::EPSILON * a.abs().max(b.abs()) {
let fval = f((a + b) / 2.0);
let zeros = Array1::zeros(fval.len());
return Ok(QuadVecResult {
integral: zeros.clone(),
error: zeros,
nfev: 1,
nintervals: 0,
success: true,
});
}
let intervals = if let Some(ref points) = options.points {
let mut sorted_points: Vec<f64> = points.clone();
sorted_points.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
let mut filtered_points: Vec<f64> = Vec::new();
for &point in sorted_points.iter() {
if point > a
&& point < b
&& (filtered_points.is_empty()
|| (point - filtered_points.last().expect("Operation failed")).abs()
> f64::EPSILON)
{
filtered_points.push(point);
}
}
let mut intervals: Vec<(f64, f64)> = Vec::new();
if filtered_points.is_empty() {
intervals.push((a, b));
} else {
intervals.push((a, filtered_points[0]));
for i in 0..filtered_points.len() - 1 {
intervals.push((filtered_points[i], filtered_points[i + 1]));
}
intervals.push((*filtered_points.last().expect("Operation failed"), b));
}
intervals
} else {
vec![(a, b)]
};
let fval = f((intervals[0].0 + intervals[0].1) / 2.0);
let output_size = fval.len();
let mut subintervals = BinaryHeap::new();
let mut nfev = 1;
for (a_i, b_i) in intervals {
let (integral, error, evals) = evaluate_interval(&f, a_i, b_i, output_size, options.rule)?;
nfev += evals;
let error_norm = compute_norm(&error, options.norm);
subintervals.push(Subinterval {
a: a_i,
b: b_i,
integral,
error,
error_norm,
});
}
while subintervals.len() < options.limit {
let interval = match subintervals.pop() {
Some(i) => i,
None => break, };
let total_integral = get_total(&subintervals, &interval, |i| &i.integral);
let total_error = get_total(&subintervals, &interval, |i| &i.error);
let error_norm = compute_norm(&total_error, options.norm);
let abs_tol = options.epsabs;
let rel_tol = options.epsrel * compute_norm(&total_integral, options.norm);
if error_norm <= abs_tol || error_norm <= rel_tol {
subintervals.push(interval);
break;
}
let mid = (interval.a + interval.b) / 2.0;
let (left_integral, left_error, left_evals) =
evaluate_interval(&f, interval.a, mid, output_size, options.rule)?;
let (right_integral, right_error, right_evals) =
evaluate_interval(&f, mid, interval.b, output_size, options.rule)?;
nfev += left_evals + right_evals;
let left_error_norm = compute_norm(&left_error, options.norm);
let right_error_norm = compute_norm(&right_error, options.norm);
subintervals.push(Subinterval {
a: interval.a,
b: mid,
integral: left_integral,
error: left_error,
error_norm: left_error_norm,
});
subintervals.push(Subinterval {
a: mid,
b: interval.b,
integral: right_integral,
error: right_error,
error_norm: right_error_norm,
});
}
let interval_vec: Vec<Subinterval> = subintervals.into_vec();
let mut total_integral = Array1::zeros(output_size);
let mut total_error = Array1::zeros(output_size);
for interval in &interval_vec {
for (i, &val) in interval.integral.iter().enumerate() {
total_integral[i] += val;
}
for (i, &val) in interval.error.iter().enumerate() {
total_error[i] += val;
}
}
let error_norm = compute_norm(&total_error, options.norm);
let abs_tol = options.epsabs;
let rel_tol = options.epsrel * compute_norm(&total_integral, options.norm);
let success = error_norm <= abs_tol || error_norm <= rel_tol;
Ok(QuadVecResult {
integral: total_integral,
error: total_error,
nfev,
nintervals: interval_vec.len(),
success,
})
}
/// Sums one per-subinterval array over `extra` and every entry of `heap`.
///
/// `extract` selects which array of a `Subinterval` to accumulate. The
/// result starts as a copy of `extra`'s array and the heap entries are then
/// added element by element in the heap's iteration order.
#[allow(dead_code)]
fn get_total<F, T>(heap: &BinaryHeap<Subinterval>, extra: &Subinterval, extract: F) -> Array1<T>
where
    F: Fn(&Subinterval) -> &Array1<T>,
    T: Clone + scirs2_core::numeric::Zero,
{
    let mut total = extract(extra).clone();
    for entry in heap {
        for (slot, addend) in extract(entry).iter().cloned().enumerate() {
            total[slot] = total[slot].clone() + addend;
        }
    }
    total
}
#[allow(dead_code)]
fn evaluate_interval<F>(
f: &F,
a: f64,
b: f64,
output_size: usize,
rule: QuadRule,
) -> IntegrateResult<(Array1<f64>, Array1<f64>, usize)>
where
F: Fn(f64) -> Array1<f64>,
{
match rule {
QuadRule::GK15 => {
let nodes = [
-0.9914553711208126f64,
-0.9491079123427585,
-0.8648644233597691,
-0.7415311855993944,
-0.5860872354676911,
-0.4058451513773972,
-0.2077849550078985,
0.0,
0.2077849550078985,
0.4058451513773972,
0.5860872354676911,
0.7415311855993944,
0.8648644233597691,
0.9491079123427585,
0.9914553711208126,
];
let weights_k = [
0.022935322010529224f64,
0.063_092_092_629_978_56,
0.10479001032225018,
0.14065325971552592,
0.169_004_726_639_267_9,
0.190_350_578_064_785_4,
0.20443294007529889,
0.20948214108472782,
0.20443294007529889,
0.190_350_578_064_785_4,
0.169_004_726_639_267_9,
0.14065325971552592,
0.10479001032225018,
0.063_092_092_629_978_56,
0.022935322010529224,
];
let weights_g = [
0.129_484_966_168_869_7_f64,
0.27970539148927664,
0.381_830_050_505_118_9,
0.417_959_183_673_469_4,
0.381_830_050_505_118_9,
0.27970539148927664,
0.129_484_966_168_869_7,
];
evaluate_rule(f, a, b, output_size, &nodes, &weights_g, &weights_k)
}
QuadRule::GK21 => {
let nodes = [
-0.9956571630258081f64,
-0.9739065285171717,
-0.9301574913557082,
-0.8650633666889845,
-0.7808177265864169,
-0.6794095682990244,
-0.5627571346686047,
-0.4333953941292472,
-0.2943928627014602,
-0.1488743389816312,
0.0,
0.1488743389816312,
0.2943928627014602,
0.4333953941292472,
0.5627571346686047,
0.6794095682990244,
0.7808177265864169,
0.8650633666889845,
0.9301574913557082,
0.9739065285171717,
0.9956571630258081,
];
let weights_k = [
0.011694638867371874f64,
0.032558162307964725,
0.054755896574351995,
0.075_039_674_810_919_96,
0.093_125_454_583_697_6,
0.109_387_158_802_297_64,
0.123_491_976_262_065_84,
0.134_709_217_311_473_34,
0.142_775_938_577_060_09,
0.147_739_104_901_338_49,
0.149_445_554_002_916_9,
0.147_739_104_901_338_49,
0.142_775_938_577_060_09,
0.134_709_217_311_473_34,
0.123_491_976_262_065_84,
0.109_387_158_802_297_64,
0.093_125_454_583_697_6,
0.075_039_674_810_919_96,
0.054755896574351995,
0.032558162307964725,
0.011694638867371874,
];
let weights_g = [
0.066_671_344_308_688_14f64,
0.149_451_349_150_580_6,
0.219_086_362_515_982_04,
0.269_266_719_309_996_36,
0.295_524_224_714_752_9,
0.295_524_224_714_752_9,
0.269_266_719_309_996_36,
0.219_086_362_515_982_04,
0.149_451_349_150_580_6,
0.066_671_344_308_688_14,
];
evaluate_rule(f, a, b, output_size, &nodes, &weights_g, &weights_k)
}
QuadRule::Trapezoid => {
let n = 15;
let mut integral = Array1::zeros(output_size);
let mut error = Array1::zeros(output_size);
let h = (b - a) / (n as f64 - 1.0);
let fa = f(a);
let fb = f(b);
for (i, (&fa_i, &fb_i)) in fa.iter().zip(fb.iter()).enumerate() {
integral[i] = 0.5 * (fa_i + fb_i);
}
for j in 1..n - 1 {
let x = a + (j as f64) * h;
let fx = f(x);
for (i, &fx_i) in fx.iter().enumerate() {
integral[i] += fx_i;
}
}
for i in 0..output_size {
integral[i] *= h;
error[i] = 1e-2 * integral[i].abs();
}
Ok((integral, error, n))
}
}
}
#[allow(dead_code)]
fn evaluate_rule<F>(
f: &F,
a: f64,
b: f64,
output_size: usize,
nodes: &[f64],
weights_g: &[f64],
weights_k: &[f64],
) -> IntegrateResult<(Array1<f64>, Array1<f64>, usize)>
where
F: Fn(f64) -> Array1<f64>,
{
let _n = nodes.len();
let mut integral_k = Array1::zeros(output_size);
let mut integral_g = Array1::zeros(output_size);
let mid = (a + b) / 2.0;
let half_length = (b - a) / 2.0;
let mut nfev = 0;
let mut gauss_idx = 0;
for (i, &node) in nodes.iter().enumerate() {
let x = mid + half_length * node;
let fx = f(x);
nfev += 1;
for (j, &fx_j) in fx.iter().enumerate() {
integral_k[j] += weights_k[i] * fx_j;
}
if i % 2 == 1 && gauss_idx < weights_g.len() {
for (j, &fx_j) in fx.iter().enumerate() {
integral_g[j] += weights_g[gauss_idx] * fx_j;
}
gauss_idx += 1;
}
}
integral_k *= half_length;
integral_g *= half_length;
let mut error = Array1::zeros(output_size);
for i in 0..output_size {
let diff = (integral_k[i] - integral_g[i]).abs();
error[i] = (200.0 * diff).powf(1.5_f64);
}
Ok((integral_k, error, nfev))
}
// Unit tests for `quad_vec` and `compute_norm`.
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
use scirs2_core::ndarray::arr1;
// Integrates [x, x^2] over [0, 1] with default options; expects [1/2, 1/3].
#[test]
fn test_simple_integral() {
let f = |x: f64| arr1(&[x, x * x]);
let result = quad_vec(f, 0.0, 1.0, None).expect("Operation failed");
assert_abs_diff_eq!(result.integral[0], 0.5, epsilon = 1e-10);
assert_abs_diff_eq!(result.integral[1], 1.0 / 3.0, epsilon = 1e-10);
assert!(result.success);
}
// Integrates [sin x, cos x] over [0, pi] with default options; expects [2, 0].
#[test]
fn test_trig_functions() {
let f = |x: f64| arr1(&[x.sin(), x.cos()]);
let result = quad_vec(f, 0.0, PI, None).expect("Operation failed");
assert_abs_diff_eq!(result.integral[0], 2.0, epsilon = 1e-10);
assert_abs_diff_eq!(result.integral[1], 0.0, epsilon = 1e-10);
assert!(result.success);
}
// Same integrand as the simple test on [0, 2], split at the breakpoint x = 1;
// expects [2, 8/3].
#[test]
fn test_with_breakpoints() {
let f = |x: f64| arr1(&[x, x * x]);
let options = QuadVecOptions {
points: Some(vec![1.0]),
..Default::default()
};
let result = quad_vec(f, 0.0, 2.0, Some(options)).expect("Operation failed");
assert_abs_diff_eq!(result.integral[0], 2.0, epsilon = 1e-10);
assert_abs_diff_eq!(result.integral[1], 8.0 / 3.0, epsilon = 1e-10);
assert!(result.success);
}
// Integrates sin x over [0, pi] with each available rule. The two
// Gauss-Kronrod rules are checked to 1e-10, the trapezoid rule only to 2e-3.
#[test]
fn test_different_rules() {
let f = |x: f64| arr1(&[x.sin()]);
let options_gk15 = QuadVecOptions {
rule: QuadRule::GK15,
..Default::default()
};
let options_gk21 = QuadVecOptions {
rule: QuadRule::GK21,
..Default::default()
};
let options_trapezoid = QuadVecOptions {
rule: QuadRule::Trapezoid,
..Default::default()
};
let result_gk15 = quad_vec(f, 0.0, PI, Some(options_gk15)).expect("Operation failed");
let result_gk21 = quad_vec(f, 0.0, PI, Some(options_gk21)).expect("Operation failed");
let result_trapezoid =
quad_vec(f, 0.0, PI, Some(options_trapezoid)).expect("Operation failed");
assert_abs_diff_eq!(result_gk15.integral[0], 2.0, epsilon = 1e-10);
assert_abs_diff_eq!(result_gk21.integral[0], 2.0, epsilon = 1e-10);
assert_abs_diff_eq!(result_trapezoid.integral[0], 2.0, epsilon = 2e-3); }
// Checks both norms on [1, -2, 0.5]: Max gives 2, L2 gives sqrt(1 + 4 + 0.25).
#[test]
fn test_error_norms() {
let arr = arr1(&[1.0, -2.0, 0.5]);
let max_norm = compute_norm(&arr, NormType::Max);
assert_abs_diff_eq!(max_norm, 2.0, epsilon = 1e-10);
let l2_norm = compute_norm(&arr, NormType::L2);
assert_abs_diff_eq!(
l2_norm,
(1.0f64 * 1.0 + 2.0 * 2.0 + 0.5 * 0.5).sqrt(),
epsilon = 1e-10
);
}
}