use crate::error::{IntegrateError, IntegrateResult};
use crate::IntegrateFloat;
use std::f64::consts::PI;
/// Converts an `f64` literal into the generic float type `F`.
///
/// If `F::from_f64` reports that the value cannot be represented (`None`),
/// the result is `F::zero()` instead of an error.
/// NOTE(review): the zero fallback silently masks a failed conversion; it is
/// presumably unreachable for the float types this crate instantiates — confirm.
#[inline]
fn to_f<F: IntegrateFloat>(value: f64) -> F {
    match F::from_f64(value) {
        Some(converted) => converted,
        None => F::zero(),
    }
}
/// A Clenshaw-Curtis quadrature rule of order `n` on the reference interval `[-1, 1]`.
///
/// The rule has `n + 1` nodes (the Chebyshev extreme points) and is exact, up to
/// rounding, for polynomials of degree `n` (degree `n + 1` when `n` is even).
#[derive(Debug, Clone)]
pub struct ClenshawCurtisRule<F: IntegrateFloat> {
    /// Nodes `x_j = cos(j * pi / n)` for `j = 0..=n`, running from `1` down to `-1`.
    pub nodes: Vec<F>,
    /// Weights; `weights[j]` belongs to `nodes[j]`. They sum to 2, the length of `[-1, 1]`.
    pub weights: Vec<F>,
    /// The order `n`, i.e. one less than the number of nodes.
    pub order: usize,
}

impl<F: IntegrateFloat> ClenshawCurtisRule<F> {
    /// Builds the `(n + 1)`-point Clenshaw-Curtis rule of order `n`.
    ///
    /// The weights use the closed form
    ///
    /// `w_j = (c_j / n) * (1 - sum_{k=1}^{floor(n/2)} b_k / (4k^2 - 1) * cos(2 k j pi / n))`
    ///
    /// with `c_j = 1` at the endpoints (`j = 0` and `j = n`) and 2 otherwise, and
    /// `b_k = 1` for `k = n/2` when `n` is even and 2 otherwise.
    ///
    /// # Errors
    ///
    /// Returns `IntegrateError::ValueError` if `n == 0`.
    pub fn new(n: usize) -> IntegrateResult<Self> {
        if n == 0 {
            return Err(IntegrateError::ValueError(
                "Clenshaw-Curtis order n must be >= 1".into(),
            ));
        }
        let n_f64 = n as f64;

        // cos(j*pi/n) == sin(pi*(n - 2j)/(2n)). The sine argument for j and
        // n - j is exactly negated, so (given an odd-symmetric `sin`) the nodes
        // come out antisymmetric and the even-n centre node is exactly 0.
        // Evaluating cos near pi/2 instead leaves a ~1e-17 residue there.
        let nodes: Vec<F> = (0..=n)
            .map(|j| {
                let offset = n_f64 - 2.0 * j as f64;
                to_f::<F>((PI * offset / (2.0 * n_f64)).sin())
            })
            .collect();

        // w_j == w_{n-j}. Computing the first half and mirroring it halves the
        // O(n^2) work and makes the weights exactly symmetric.
        let half_n = n / 2;
        let mut weights = vec![F::zero(); n + 1];
        for j in 0..=half_n {
            // j <= n/2 < n for n >= 1, so j == 0 is the only endpoint in this half;
            // its mirror j == n receives the same endpoint weight below.
            let c_j = if j == 0 { 1.0 } else { 2.0 };
            let theta_j = j as f64 * PI / n_f64;
            let s: f64 = (1..=half_n)
                .map(|k| {
                    // The last term (k == n/2, n even) carries half weight.
                    let b_k = if 2 * k == n { 1.0 } else { 2.0 };
                    b_k / ((4 * k * k) as f64 - 1.0) * (2.0 * k as f64 * theta_j).cos()
                })
                .sum();
            let w_j = to_f::<F>(c_j / n_f64 * (1.0 - s));
            weights[j] = w_j;
            weights[n - j] = w_j;
        }

        Ok(Self {
            nodes,
            weights,
            order: n,
        })
    }

    /// Approximates the integral of `f` over `[a, b]`.
    ///
    /// Nodes are mapped affinely with `x = (a + b)/2 + (b - a)/2 * node`, and the
    /// weighted sum is scaled by the Jacobian `(b - a)/2`. `f` is called once per
    /// (node, weight) pair.
    ///
    /// # Errors
    ///
    /// The current implementation never returns `Err`.
    pub fn integrate<Func>(&self, f: &Func, a: F, b: F) -> IntegrateResult<F>
    where
        Func: Fn(F) -> F,
    {
        let half = to_f::<F>(0.5);
        let mid = (a + b) * half;
        let half_len = (b - a) * half;
        let mut sum = F::zero();
        for (&node, &w) in self.nodes.iter().zip(&self.weights) {
            sum += w * f(mid + half_len * node);
        }
        Ok(sum * half_len)
    }
}
/// Outcome of an adaptive Clenshaw-Curtis integration performed by `quad_cc`.
#[derive(Clone, Debug)]
pub struct ClenshawCurtisResult<F: IntegrateFloat> {
    /// Estimated integral: the sum of the per-panel fine-rule values.
    pub value: F,
    /// Estimated absolute error: the sum of the per-panel `|fine - coarse|` differences.
    pub error: F,
    /// Number of integrand evaluations performed.
    pub n_evals: usize,
    /// Deepest bisection level reached (0 means the interval was never split).
    pub max_level: usize,
    /// `true` when the error estimate met the requested tolerance.
    pub converged: bool,
}
/// Tuning parameters for `quad_cc`.
#[derive(Clone, Debug)]
pub struct ClenshawCurtisOptions<F: IntegrateFloat> {
    /// Absolute tolerance. Convergence requires `error <= atol + rtol * |value|`.
    pub atol: F,
    /// Relative tolerance; see `atol` for how the two combine.
    pub rtol: F,
    /// Order of the coarse per-panel rule. `quad_cc` raises it to an even number
    /// of at least 2, and the fine rule uses twice that order.
    pub initial_order: usize,
    /// Evaluation budget. Refinement stops once this many integrand evaluations
    /// have been spent; the initial evaluation always runs, and the last
    /// subdivision may overshoot the budget.
    pub max_evals: usize,
    /// Bisection depth limit. When the panel selected for refinement is already
    /// at this depth, `quad_cc` stops and reports non-convergence.
    pub max_depth: usize,
}

impl<F: IntegrateFloat> Default for ClenshawCurtisOptions<F> {
    /// Purely relative tolerance `1e-10`, coarse order 8, a budget of 100 000
    /// evaluations and a depth limit of 30.
    ///
    /// NOTE: with `atol == 0`, an integral whose value is (close to) zero gets a
    /// tolerance that is (close to) zero, so it usually runs until a limit is hit
    /// and reports `converged == false`. Set `atol` for such integrands.
    fn default() -> Self {
        let max_evals = 100_000;
        let max_depth = 30;
        Self {
            initial_order: 8,
            max_depth,
            max_evals,
            rtol: to_f::<F>(1e-10),
            atol: F::zero(),
        }
    }
}
/// Adaptively integrates `f` over `[a, b]` with nested Clenshaw-Curtis rules.
///
/// Each panel is integrated with a coarse rule of order `n` and a fine rule of
/// order `2n`. Here `n` is `initial_order` raised to an even number of at least 2.
/// The fine result is the panel value and `|fine - coarse|` its error estimate.
///
/// The panel with the largest error estimate is bisected repeatedly until one of
/// these holds:
/// - the summed estimate satisfies `error <= atol + rtol * |value|`;
/// - `max_evals` evaluations have been spent;
/// - the worst panel cannot be refined, either because it is at `max_depth` or
///   because its floating-point midpoint does not lie strictly inside it.
///
/// The coarse rule's nodes are exactly the even-indexed nodes of the fine rule.
/// The integrand is therefore sampled only `2n + 1` times per panel.
///
/// # Errors
///
/// Returns `IntegrateError::ValueError` unless `a < b`; NaN bounds are rejected.
/// Errors from building the quadrature rules are propagated.
///
/// Running out of budget is not an error. In that case the best estimate is
/// returned with `converged == false`.
pub fn quad_cc<F, Func>(
    f: Func,
    a: F,
    b: F,
    options: Option<ClenshawCurtisOptions<F>>,
) -> IntegrateResult<ClenshawCurtisResult<F>>
where
    F: IntegrateFloat,
    Func: Fn(F) -> F,
{
    let opts = options.unwrap_or_default();
    // `a >= b` is false when either bound is NaN. Requiring `Less` rejects those too.
    if a.partial_cmp(&b) != Some(std::cmp::Ordering::Less) {
        return Err(IntegrateError::ValueError(
            "Lower bound must be strictly less than upper bound".into(),
        ));
    }
    // Even coarse order >= 2, so the fine rule (twice the order) nests it.
    let base_order = opts.initial_order.max(2).next_multiple_of(2);
    let rule_lo = ClenshawCurtisRule::<F>::new(base_order)?;
    let rule_hi = ClenshawCurtisRule::<F>::new(2 * base_order)?;
    let half = to_f::<F>(0.5);

    struct Panel<F: IntegrateFloat> {
        a: F,
        b: F,
        value: F,
        error: F,
        depth: usize,
    }

    // Integrates one panel with both rules. Order-n node j equals order-2n node 2j
    // (cos(j*pi/n) == cos(2j*pi/(2n))). So f is sampled once at the fine nodes,
    // and every second sample is reused by the coarse rule.
    let evaluate_panel = |a_p: F, b_p: F, depth: usize, evals: &mut usize| -> Panel<F> {
        let mid = (a_p + b_p) * half;
        let half_len = (b_p - a_p) * half;
        let samples: Vec<F> = rule_hi
            .nodes
            .iter()
            .map(|&node| f(mid + half_len * node))
            .collect();
        *evals += samples.len();
        let weighted_sum = |weights: &[F], stride: usize| {
            samples
                .iter()
                .step_by(stride)
                .zip(weights)
                .fold(F::zero(), |acc, (&fx, &w)| acc + w * fx)
        };
        let value_hi = weighted_sum(&rule_hi.weights, 1) * half_len;
        let value_lo = weighted_sum(&rule_lo.weights, 2) * half_len;
        Panel {
            a: a_p,
            b: b_p,
            value: value_hi,
            error: (value_hi - value_lo).abs(),
            depth,
        }
    };

    let mut total_evals: usize = 0;
    let mut panels = vec![evaluate_panel(a, b, 0, &mut total_evals)];
    let mut max_level: usize = 0;

    loop {
        // Re-sum the totals from the panels on every pass. Patching running sums
        // with `- old + new` accumulates rounding error and can under-report the
        // global error, declaring convergence too early. The worst-panel scan
        // below is O(panels) anyway, so this costs nothing asymptotically.
        let (global_value, global_error) = panels
            .iter()
            .fold((F::zero(), F::zero()), |(value, error), p| {
                (value + p.value, error + p.error)
            });
        let finish = move |converged: bool| ClenshawCurtisResult {
            value: global_value,
            error: global_error,
            n_evals: total_evals,
            max_level,
            converged,
        };

        let tol = opts.atol + opts.rtol * global_value.abs();
        if global_error <= tol {
            return Ok(finish(true));
        }
        if total_evals >= opts.max_evals {
            return Ok(finish(false));
        }

        let worst_idx = panels
            .iter()
            .enumerate()
            .max_by(|(_, pa), (_, pb)| {
                pa.error
                    .partial_cmp(&pb.error)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|(i, _)| i)
            .unwrap_or(0);
        let worst = panels.swap_remove(worst_idx);
        let mid = (worst.a + worst.b) * half;
        // Bisecting a panel whose midpoint rounds onto an endpoint would only
        // produce a degenerate child, so treat that like the depth limit.
        let bisectable = worst.a < mid && mid < worst.b;
        if worst.depth >= opts.max_depth || !bisectable {
            return Ok(finish(false));
        }

        let child_depth = worst.depth + 1;
        let left = evaluate_panel(worst.a, mid, child_depth, &mut total_evals);
        let right = evaluate_panel(mid, worst.b, child_depth, &mut total_evals);
        max_level = max_level.max(child_depth);
        panels.push(left);
        panels.push(right);
    }
}
pub fn quad_cc_tol<F, Func>(f: Func, a: F, b: F, tol: F) -> IntegrateResult<(F, F)>
where
F: IntegrateFloat,
Func: Fn(F) -> F,
{
let opts = ClenshawCurtisOptions {
atol: tol,
rtol: F::zero(),
..Default::default()
};
let res = quad_cc(f, a, b, Some(opts))?;
Ok((res.value, res.error))
}
#[cfg(test)]
mod tests {
    //! Unit tests for the Clenshaw-Curtis rule and the adaptive driver.
    use super::*;

    #[test]
    fn test_cc_rule_constant() {
        let rule = ClenshawCurtisRule::<f64>::new(4).expect("rule creation");
        let val = rule
            .integrate(&|_x: f64| 1.0, -1.0, 1.0)
            .expect("integrate");
        assert!(
            (val - 2.0).abs() < 1e-14,
            "constant function integral: got {val}"
        );
    }

    #[test]
    fn test_cc_rule_polynomial() {
        let rule = ClenshawCurtisRule::<f64>::new(4).expect("rule creation");
        let val = rule
            .integrate(&|x: f64| x.powi(4), -1.0, 1.0)
            .expect("integrate");
        assert!(
            (val - 0.4).abs() < 1e-12,
            "x^4 integral: got {val}, expected 0.4"
        );
    }

    #[test]
    fn test_cc_rule_mapped_interval() {
        let rule = ClenshawCurtisRule::<f64>::new(6).expect("rule creation");
        let val = rule
            .integrate(&|x: f64| x * x, 0.0, 1.0)
            .expect("integrate");
        assert!((val - 1.0 / 3.0).abs() < 1e-12, "x^2 on [0,1]: got {val}");
    }

    /// Weights of every order must integrate a constant exactly.
    #[test]
    fn test_cc_rule_weights_sum_to_two() {
        for n in 1..=12 {
            let rule = ClenshawCurtisRule::<f64>::new(n).expect("rule creation");
            assert_eq!(rule.order, n);
            assert_eq!(rule.nodes.len(), n + 1);
            assert_eq!(rule.weights.len(), n + 1);
            let total: f64 = rule.weights.iter().sum();
            assert!((total - 2.0).abs() < 1e-13, "n = {n}: weights sum to {total}");
        }
    }

    /// Nodes are antisymmetric and weights symmetric about the centre.
    #[test]
    fn test_cc_rule_is_symmetric() {
        for n in 1..=12 {
            let rule = ClenshawCurtisRule::<f64>::new(n).expect("rule creation");
            for j in 0..=n {
                let node_gap = (rule.nodes[j] + rule.nodes[n - j]).abs();
                let weight_gap = (rule.weights[j] - rule.weights[n - j]).abs();
                assert!(node_gap < 1e-14, "n = {n}, j = {j}: node gap {node_gap}");
                assert!(weight_gap < 1e-14, "n = {n}, j = {j}: weight gap {weight_gap}");
            }
        }
    }

    /// Exercises the odd-order branch of the weight formula.
    #[test]
    fn test_cc_rule_odd_order_polynomial() {
        let rule = ClenshawCurtisRule::<f64>::new(5).expect("rule creation");
        let even = rule
            .integrate(&|x: f64| x.powi(4), -1.0, 1.0)
            .expect("integrate");
        let odd = rule
            .integrate(&|x: f64| x.powi(5) + x, -1.0, 1.0)
            .expect("integrate");
        assert!((even - 0.4).abs() < 1e-12, "x^4 with n = 5: got {even}");
        assert!(odd.abs() < 1e-12, "odd integrand with n = 5: got {odd}");
    }

    /// Order 1 degenerates to the trapezoidal rule.
    #[test]
    fn test_cc_rule_order_one_is_trapezoid() {
        let rule = ClenshawCurtisRule::<f64>::new(1).expect("rule creation");
        assert!((rule.nodes[0] - 1.0).abs() < 1e-15);
        assert!((rule.nodes[1] + 1.0).abs() < 1e-15);
        assert!((rule.weights[0] - 1.0).abs() < 1e-15);
        assert!((rule.weights[1] - 1.0).abs() < 1e-15);
        let val = rule
            .integrate(&|x: f64| x * x, -1.0, 1.0)
            .expect("integrate");
        assert!((val - 2.0).abs() < 1e-14, "trapezoid x^2: got {val}");
    }

    #[test]
    fn test_quad_cc_sin() {
        let res = quad_cc(|x: f64| x.sin(), 0.0, PI, None).expect("quad_cc");
        assert!(res.converged, "should converge");
        assert!(
            (res.value - 2.0).abs() < 1e-10,
            "sin integral: got {}",
            res.value
        );
    }

    #[test]
    fn test_quad_cc_exp() {
        let exact = std::f64::consts::E - 1.0;
        let res = quad_cc(|x: f64| x.exp(), 0.0, 1.0, None).expect("quad_cc");
        assert!(res.converged, "should converge");
        assert!(
            (res.value - exact).abs() < 1e-10,
            "exp integral: got {}, expected {}",
            res.value,
            exact
        );
    }

    #[test]
    fn test_quad_cc_oscillatory() {
        let res = quad_cc(
            |x: f64| (50.0 * x).cos(),
            0.0,
            PI,
            Some(ClenshawCurtisOptions {
                rtol: to_f(1e-8),
                max_evals: 500_000,
                ..Default::default()
            }),
        )
        .expect("quad_cc");
        assert!(
            res.value.abs() < 1e-6,
            "oscillatory integral should be ~0, got {}",
            res.value
        );
    }

    #[test]
    fn test_quad_cc_tol_wrapper() {
        let (val, _err) = quad_cc_tol(|x: f64| x.powi(3), 0.0, 1.0, 1e-12).expect("quad_cc_tol");
        assert!(
            (val - 0.25).abs() < 1e-11,
            "x^3 integral: got {val}, expected 0.25"
        );
    }

    #[test]
    fn test_quad_cc_peaked_function() {
        let exact = 2.0 * 5.0_f64.atan() / 5.0;
        let res = quad_cc(
            |x: f64| 1.0 / (1.0 + 25.0 * x * x),
            -1.0,
            1.0,
            Some(ClenshawCurtisOptions {
                rtol: to_f(1e-10),
                ..Default::default()
            }),
        )
        .expect("quad_cc");
        assert!(
            (res.value - exact).abs() < 1e-8,
            "Runge function integral: got {}, expected {}",
            res.value,
            exact
        );
    }

    /// A low-degree polynomial is resolved by the initial panel alone.
    #[test]
    fn test_quad_cc_polynomial_needs_no_subdivision() {
        let res = quad_cc(|x: f64| x.powi(3), 0.0, 1.0, None).expect("quad_cc");
        assert!(res.converged, "cubic should converge on the first panel");
        assert_eq!(res.max_level, 0);
        assert!(res.n_evals > 0);
        assert!((res.value - 0.25).abs() < 1e-12, "got {}", res.value);
    }

    /// Exhausting the evaluation budget is reported, not treated as an error.
    #[test]
    fn test_quad_cc_reports_budget_exhaustion() {
        let res = quad_cc(
            |x: f64| x.sqrt(),
            0.0,
            1.0,
            Some(ClenshawCurtisOptions {
                max_evals: 1,
                ..Default::default()
            }),
        )
        .expect("quad_cc");
        assert!(!res.converged, "a one-evaluation budget cannot converge");
        assert!(res.n_evals > 0);
        assert_eq!(res.max_level, 0);
    }

    /// Hitting the depth limit is reported, and the coarse estimate is still returned.
    #[test]
    fn test_quad_cc_reports_depth_limit() {
        let res = quad_cc(
            |x: f64| x.sqrt(),
            0.0,
            1.0,
            Some(ClenshawCurtisOptions {
                max_depth: 0,
                ..Default::default()
            }),
        )
        .expect("quad_cc");
        assert!(!res.converged, "depth 0 forbids refinement");
        assert_eq!(res.max_level, 0);
        assert!((res.value - 2.0 / 3.0).abs() < 1e-2, "got {}", res.value);
    }

    #[test]
    fn test_quad_cc_invalid_bounds() {
        let res = quad_cc(|x: f64| x, 1.0, 0.0, None);
        assert!(res.is_err(), "should fail for a >= b");
    }

    #[test]
    fn test_cc_rule_invalid_order() {
        let res = ClenshawCurtisRule::<f64>::new(0);
        assert!(res.is_err(), "order 0 should be invalid");
    }
}