use crate::DType;
use numr::error::{Error, Result};
use numr::ops::TensorOps;
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
use crate::integrate::{QuadOptions, QuadResult};
pub fn quad_impl<R, C, F>(
client: &C,
f: F,
a: f64,
b: f64,
options: &QuadOptions,
) -> Result<QuadResult<R>>
where
R: Runtime<DType = DType>,
C: TensorOps<R> + RuntimeClient<R>,
F: Fn(&Tensor<R>) -> Result<Tensor<R>>,
{
if a >= b {
return Err(Error::InvalidArgument {
arg: "a/b",
reason: format!("quad: invalid interval [{}, {}]", a, b),
});
}
if options.limit == 0 {
return Err(Error::InvalidArgument {
arg: "limit",
reason: "quad: limit must be at least 1".to_string(),
});
}
let mut intervals: Vec<(f64, f64, f64, f64)> = Vec::new();
let mut neval = 0;
let (integral, error, evals) = gauss_kronrod_15(client, &f, a, b)?;
neval += evals;
intervals.push((a, b, integral, error));
let mut total_integral = integral;
let mut total_error = error;
let mut subdivisions = 0;
while subdivisions < options.limit {
let tolerance = options.atol + options.rtol * total_integral.abs();
if total_error <= tolerance {
return Ok(QuadResult {
integral: Tensor::<R>::from_slice(&[total_integral], &[], client.device()),
error: total_error,
neval,
converged: true,
});
}
let max_idx = match intervals
.iter()
.enumerate()
.max_by(|a, b| {
a.1.3
.partial_cmp(&b.1.3)
.unwrap_or(std::cmp::Ordering::Less)
})
.map(|(i, _)| i)
{
Some(idx) => idx,
None => {
return Err(Error::InvalidArgument {
arg: "intervals",
reason: "quad: internal error - no intervals available".to_string(),
});
}
};
let (ia, ib, old_integral, old_error) = intervals.swap_remove(max_idx);
let mid = (ia + ib) / 2.0;
let (int1, err1, evals1) = gauss_kronrod_15(client, &f, ia, mid)?;
let (int2, err2, evals2) = gauss_kronrod_15(client, &f, mid, ib)?;
neval += evals1 + evals2;
total_integral = total_integral - old_integral + int1 + int2;
total_error = total_error - old_error + err1 + err2;
intervals.push((ia, mid, int1, err1));
intervals.push((mid, ib, int2, err2));
subdivisions += 1;
}
Ok(QuadResult {
integral: Tensor::<R>::from_slice(&[total_integral], &[], client.device()),
error: total_error,
neval,
converged: false,
})
}
fn gauss_kronrod_15<R, C, F>(client: &C, f: &F, a: f64, b: f64) -> Result<(f64, f64, usize)>
where
R: Runtime<DType = DType>,
C: TensorOps<R> + RuntimeClient<R>,
F: Fn(&Tensor<R>) -> Result<Tensor<R>>,
{
const XGK: [f64; 15] = [
-0.9914553711208126,
-0.9491079123427585,
-0.8648644233597691,
-0.7415311855993944,
-0.5860872354676911,
-0.4058451513773972,
-0.2077849550078985,
0.0,
0.2077849550078985,
0.4058451513773972,
0.5860872354676911,
0.7415311855993944,
0.8648644233597691,
0.9491079123427585,
0.9914553711208126,
];
const WGK: [f64; 15] = [
0.022935322010529224,
0.063_092_092_629_978_56,
0.10479001032225018,
0.14065325971552592,
0.169_004_726_639_267_9,
0.190_350_578_064_785_4,
0.20443294007529889,
0.20948214108472782,
0.20443294007529889,
0.190_350_578_064_785_4,
0.169_004_726_639_267_9,
0.14065325971552592,
0.10479001032225018,
0.063_092_092_629_978_56,
0.022935322010529224,
];
const WG: [f64; 7] = [
0.129_484_966_168_869_7,
0.27970539148927664,
0.381_830_050_505_118_9,
0.417_959_183_673_469_4,
0.381_830_050_505_118_9,
0.27970539148927664,
0.129_484_966_168_869_7,
];
let mid = (a + b) / 2.0;
let half_width = (b - a) / 2.0;
let eval_points: Vec<f64> = XGK.iter().map(|&x| mid + half_width * x).collect();
let x_tensor = Tensor::<R>::from_slice(&eval_points, &[15], client.device());
let f_values = f(&x_tensor)?;
let fvals: Vec<f64> = f_values.to_vec();
let mut result_kronrod = 0.0;
for (i, &fval) in fvals.iter().enumerate() {
result_kronrod += WGK[i] * fval;
}
result_kronrod *= half_width;
let mut result_gauss = 0.0;
for (i, &w) in WG.iter().enumerate() {
result_gauss += w * fvals[2 * i + 1];
}
result_gauss *= half_width;
let error = (result_kronrod - result_gauss).abs();
Ok((result_kronrod, error, 15))
}