use crate::adaptive;
use crate::error::QuadratureError;
use crate::gauss_legendre::GaussLegendre;
use crate::result::QuadratureResult;
#[cfg(not(feature = "std"))]
use alloc::{boxed::Box, vec, vec::Vec};
#[cfg(not(feature = "std"))]
use num_traits::Float as _;
/// Trigonometric weight that multiplies the smooth integrand `f`.
///
/// The integrator computes `∫ f(x)·sin(ωx) dx` for [`OscillatoryKernel::Sine`]
/// and `∫ f(x)·cos(ωx) dx` for [`OscillatoryKernel::Cosine`].
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum OscillatoryKernel {
    /// Weight `sin(ωx)`.
    Sine,
    /// Weight `cos(ωx)`.
    Cosine,
}
/// Configurable integrator for `∫_a^b f(x)·w(ωx) dx`, where `w` is the sine or
/// cosine selected by [`OscillatoryKernel`].
///
/// Build it with `new` and adjust it with the `with_*` builder methods.
#[derive(Clone, Debug)]
pub struct OscillatoryIntegrator {
    /// Which trigonometric weight multiplies `f`.
    kernel: OscillatoryKernel,
    /// Angular frequency `ω` of the weight.
    omega: f64,
    /// Requested Chebyshev interpolation degree (raised to at least 4 when integrating).
    order: usize,
    /// Absolute tolerance used by the convergence test.
    abs_tol: f64,
    /// Relative tolerance used by the convergence test.
    rel_tol: f64,
}
impl OscillatoryIntegrator {
    /// Creates an integrator for `∫ f(x)·K(ωx) dx`, where `K` is the sine or
    /// cosine selected by `kernel` and `ω` is `omega`.
    ///
    /// Defaults: interpolation order 32, and both the absolute and the relative
    /// tolerance set to `1.49e-8`.
    #[must_use]
    pub fn new(kernel: OscillatoryKernel, omega: f64) -> Self {
        Self {
            kernel,
            omega,
            order: 32,
            abs_tol: 1.49e-8,
            rel_tol: 1.49e-8,
        }
    }

    /// Sets the Chebyshev interpolation degree `n`.
    ///
    /// `f` is sampled at `n + 1` points. Values below 4 are raised to 4 when
    /// integrating.
    #[must_use]
    pub fn with_order(mut self, n: usize) -> Self {
        self.order = n;
        self
    }

    /// Sets the absolute error tolerance.
    ///
    /// The tolerance is used by the convergence test and forwarded to the
    /// adaptive fallback.
    #[must_use]
    pub fn with_abs_tol(mut self, tol: f64) -> Self {
        self.abs_tol = tol;
        self
    }

    /// Sets the relative error tolerance used by the convergence test.
    #[must_use]
    pub fn with_rel_tol(mut self, tol: f64) -> Self {
        self.rel_tol = tol;
        self
    }

    /// Integrates `f(x)·sin(ωx)` or `f(x)·cos(ωx)` over `[a, b]`.
    ///
    /// Let `θ = ω·(b − a)/2` be the phase across the half-width.
    ///
    /// - If `|θ| < 2`, the weight is only mildly oscillatory and the product is
    ///   handed to the crate's adaptive integrator.
    /// - Otherwise, `f` is interpolated by a degree-`n` Chebyshev series on
    ///   `[a, b]` (sampled at `x_k = mid + half·cos(kπ/n)`). Each term is then
    ///   integrated against the weight using modified Chebyshev moments.
    ///
    /// The error estimate is `|full series − series truncated after n/2 + 1
    /// terms|`. `converged` is `true` when that estimate is at most
    /// `max(abs_tol, rel_tol·|value|)`.
    ///
    /// # Errors
    ///
    /// Returns [`QuadratureError::DegenerateInterval`] when `a`, `b` or `ω` is
    /// not finite, or when the phase `ω·(b − a)/2` overflows to a non-finite
    /// value. Errors from the adaptive fallback are propagated unchanged.
    #[allow(clippy::many_single_char_names)]
    #[allow(clippy::similar_names)]
    #[allow(clippy::cast_precision_loss)]
    pub fn integrate<G>(
        &self,
        a: f64,
        b: f64,
        f: G,
    ) -> Result<QuadratureResult<f64>, QuadratureError>
    where
        G: Fn(f64) -> f64,
    {
        if !a.is_finite() || !b.is_finite() || !self.omega.is_finite() {
            return Err(QuadratureError::DegenerateInterval);
        }
        if (b - a).abs() < f64::EPSILON {
            return Ok(QuadratureResult {
                value: 0.0,
                error_estimate: 0.0,
                num_evals: 0,
                converged: true,
            });
        }
        // Halve each endpoint before combining so that `a + b` / `b - a` cannot
        // overflow for finite endpoints near ±f64::MAX. Scaling by 0.5 is exact
        // for normal floats, so ordinary inputs give the same result as
        // `0.5 * (a + b)` / `0.5 * (b - a)`.
        let half = 0.5 * b - 0.5 * a;
        let mid = 0.5 * a + 0.5 * b;
        let theta = self.omega * half;
        if !theta.is_finite() {
            // ω and the endpoints are finite, but their product overflowed;
            // the oscillatory weight cannot be evaluated meaningfully.
            return Err(QuadratureError::DegenerateInterval);
        }
        if theta.abs() < 2.0 {
            return self.adaptive_fallback(a, b, &f);
        }

        let n = self.order.max(4);
        // Sample f at the Chebyshev–Lobatto points mapped onto [a, b].
        let f_vals: Vec<f64> = (0..=n)
            .map(|k| {
                let t = (k as f64 * core::f64::consts::PI / n as f64).cos();
                f(mid + half * t)
            })
            .collect();
        let num_evals = n + 1;
        let cheb_coeffs = chebyshev_coefficients(&f_vals, n);
        let (moments_cos, moments_sin) = modified_chebyshev_moments(theta, n);

        // With x = mid + half·t we have ω·x = ω·mid + θ·t. The angle-addition
        // formulas turn the weight into cos(θt)/sin(θt) moments rotated by the
        // fixed phase ω·mid, and the Jacobian contributes the factor `half`.
        let omega_mid = self.omega * mid;
        let cos_mid = omega_mid.cos();
        let sin_mid = omega_mid.sin();
        let combine = |sum_c: f64, sum_s: f64| match self.kernel {
            OscillatoryKernel::Cosine => half * (cos_mid * sum_c - sin_mid * sum_s),
            OscillatoryKernel::Sine => half * (sin_mid * sum_c + cos_mid * sum_s),
        };

        let all_terms = n + 1;
        let value = combine(
            Self::dot_prefix(&cheb_coeffs, &moments_cos, all_terms),
            Self::dot_prefix(&cheb_coeffs, &moments_sin, all_terms),
        );
        // Error estimate: compare with the series truncated at degree n/2.
        let half_terms = n / 2 + 1;
        let value_half = combine(
            Self::dot_prefix(&cheb_coeffs, &moments_cos, half_terms),
            Self::dot_prefix(&cheb_coeffs, &moments_sin, half_terms),
        );

        let error = (value - value_half).abs();
        let converged = error <= self.abs_tol.max(self.rel_tol * value.abs());
        Ok(QuadratureResult {
            value,
            error_estimate: error,
            num_evals,
            converged,
        })
    }

    /// Returns `Σ_{j < terms} coeffs[j]·moments[j]`, summed left to right.
    ///
    /// The sum stops early if either slice has fewer than `terms` elements.
    fn dot_prefix(coeffs: &[f64], moments: &[f64], terms: usize) -> f64 {
        coeffs
            .iter()
            .take(terms)
            .zip(moments)
            .map(|(c, m)| c * m)
            .sum()
    }

    /// Integrates the full product `f(x)·sin(ωx)` or `f(x)·cos(ωx)` directly
    /// with the crate's adaptive integrator.
    ///
    /// NOTE(review): only `abs_tol` is forwarded; `rel_tol` is not used on
    /// this path.
    fn adaptive_fallback<G>(
        &self,
        a: f64,
        b: f64,
        f: &G,
    ) -> Result<QuadratureResult<f64>, QuadratureError>
    where
        G: Fn(f64) -> f64,
    {
        let omega = self.omega;
        let integrand = match self.kernel {
            OscillatoryKernel::Sine => {
                Box::new(move |x: f64| f(x) * (omega * x).sin()) as Box<dyn Fn(f64) -> f64>
            }
            OscillatoryKernel::Cosine => Box::new(move |x: f64| f(x) * (omega * x).cos()),
        };
        adaptive::adaptive_integrate(&*integrand, a, b, self.abs_tol)
    }
}
/// Computes the Chebyshev coefficients `c_0..=c_n` of the degree-`n`
/// polynomial interpolating `f_vals[k] = f(cos(kπ/n))` for `k = 0..=n`, so that
/// `f(t) ≈ Σ_j c_j·T_j(t)` on `[-1, 1]`.
///
/// Each coefficient is the discrete cosine sum `(2/n)·Σ''_k f_k·cos(jkπ/n)`,
/// where `Σ''` halves the first and last terms. The end coefficients `c_0`
/// and `c_n` are then halved once more.
///
/// NOTE(review): assumes `f_vals.len() == n + 1` and `n ≥ 1`. For `n = 0` the
/// step `π/n` is infinite and the result is NaN.
#[allow(clippy::cast_precision_loss)]
fn chebyshev_coefficients(f_vals: &[f64], n: usize) -> Vec<f64> {
    let denom = n as f64;
    let step = core::f64::consts::PI / denom;
    let mut out: Vec<f64> = (0..=n)
        .map(|j| {
            let acc = f_vals.iter().enumerate().fold(0.0, |acc, (k, fk)| {
                // Trapezoid-style half weight on the two endpoint samples.
                let w = if k == 0 || k == n { 0.5 } else { 1.0 };
                acc + w * fk * (j as f64 * k as f64 * step).cos()
            });
            2.0 * acc / denom
        })
        .collect();
    out[0] *= 0.5;
    out[n] *= 0.5;
    out
}
/// Computes the modified Chebyshev moments for `j = 0..=n`:
///
/// - `mc[j] = ∫_{-1}^{1} T_j(t)·cos(θt) dt`
/// - `ms[j] = ∫_{-1}^{1} T_j(t)·sin(θt) dt`
///
/// Special cases:
///
/// - Non-finite `θ`: every moment is NaN, which matches the values the
///   quadrature would produce.
/// - `|θ| < 1e-15`: the closed-form integrals of `T_j` are returned, and `ms`
///   is identically zero.
///
/// Otherwise the moments are evaluated with an `m`-point Gauss–Legendre rule,
/// `m = max(n + ⌈|θ|⌉ + 32, 64)`. `T_j(x)` is generated by the recurrence
/// `T_j = 2x·T_{j-1} − T_{j-2}`.
///
/// # Panics
///
/// Panics if `GaussLegendre::new(m)` fails.
/// NOTE(review): `m` grows linearly with `|θ|`, so very large finite phases
/// request an enormous rule. Confirm how `GaussLegendre::new` behaves there.
#[allow(clippy::cast_precision_loss)]
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
fn modified_chebyshev_moments(theta: f64, n: usize) -> (Vec<f64>, Vec<f64>) {
    if !theta.is_finite() {
        // cos(θt) and sin(θt) are NaN for non-finite θ, so every moment is too.
        // Returning here also avoids sizing a rule from a saturated cast below.
        return (vec![f64::NAN; n + 1], vec![f64::NAN; n + 1]);
    }
    let mut mc = vec![0.0; n + 1];
    let mut ms = vec![0.0; n + 1];
    if theta.abs() < 1e-15 {
        // θ ≈ 0: ∫ T_j = 2/(1 − j²) for even j, 0 for odd j; sin(θt) ≡ 0.
        for (j, mcj) in mc.iter_mut().enumerate() {
            if j == 0 {
                *mcj = 2.0;
            } else if j % 2 == 0 {
                *mcj = 2.0 / (1.0 - (j as f64).powi(2));
            }
        }
        return (mc, ms);
    }
    // The rule size grows with |θ| so the oscillating factor stays resolved.
    // `ceil() as usize` saturates at usize::MAX for huge |θ|. The saturating
    // adds stop `m` from wrapping around to a tiny, badly under-resolved rule
    // (and from panicking on overflow in debug builds).
    let theta_points = theta.abs().ceil() as usize;
    let m = n.saturating_add(theta_points).saturating_add(32).max(64);
    let gl = GaussLegendre::new(m)
        .expect("failed to build Gauss-Legendre rule for modified Chebyshev moments");
    let rule = gl.rule();
    for (node, weight) in rule.nodes.iter().zip(rule.weights.iter()) {
        let x = *node;
        let cos_tx = (theta * x).cos();
        let sin_tx = (theta * x).sin();
        // T_0(x) = 1 and T_1(x) = x seed the three-term recurrence.
        let mut t_prev = 1.0;
        let mut t_curr = x;
        mc[0] += weight * cos_tx;
        ms[0] += weight * sin_tx;
        if n >= 1 {
            mc[1] += weight * x * cos_tx;
            ms[1] += weight * x * sin_tx;
        }
        for j in 2..=n {
            let t_next = 2.0 * x * t_curr - t_prev;
            mc[j] += weight * t_next * cos_tx;
            ms[j] += weight * t_next * sin_tx;
            t_prev = t_curr;
            t_curr = t_next;
        }
    }
    (mc, ms)
}
/// Convenience wrapper that computes `∫_a^b f(x)·sin(ωx) dx`.
///
/// Both the absolute and the relative tolerance are set to `tol`. All other
/// settings are the [`OscillatoryIntegrator`] defaults.
///
/// # Errors
///
/// Propagates any error from [`OscillatoryIntegrator::integrate`].
pub fn integrate_oscillatory_sin<G>(
    f: G,
    a: f64,
    b: f64,
    omega: f64,
    tol: f64,
) -> Result<QuadratureResult<f64>, QuadratureError>
where
    G: Fn(f64) -> f64,
{
    let integrator = OscillatoryIntegrator::new(OscillatoryKernel::Sine, omega)
        .with_rel_tol(tol)
        .with_abs_tol(tol);
    integrator.integrate(a, b, f)
}
/// Convenience wrapper that computes `∫_a^b f(x)·cos(ωx) dx`.
///
/// Both the absolute and the relative tolerance are set to `tol`. All other
/// settings are the [`OscillatoryIntegrator`] defaults.
///
/// # Errors
///
/// Propagates any error from [`OscillatoryIntegrator::integrate`].
pub fn integrate_oscillatory_cos<G>(
    f: G,
    a: f64,
    b: f64,
    omega: f64,
    tol: f64,
) -> Result<QuadratureResult<f64>, QuadratureError>
where
    G: Fn(f64) -> f64,
{
    let integrator = OscillatoryIntegrator::new(OscillatoryKernel::Cosine, omega)
        .with_rel_tol(tol)
        .with_abs_tol(tol);
    integrator.integrate(a, b, f)
}
#[cfg(test)]
mod tests {
    use super::*;
    use core::f64::consts::{E, PI};

    #[test]
    fn sin_trivial() {
        // ∫_0^π sin(x) dx = 2. Here θ = π/2 < 2, so the adaptive path runs.
        let value = integrate_oscillatory_sin(|_| 1.0, 0.0, PI, 1.0, 1e-10)
            .unwrap()
            .value;
        assert!((value - 2.0).abs() < 1e-6, "value={}", value);
    }

    #[test]
    fn cos_moderate_omega() {
        // ∫_0^1 cos(10x) dx = sin(10)/10.
        let exact = 10.0_f64.sin() / 10.0;
        let value = integrate_oscillatory_cos(|_| 1.0, 0.0, 1.0, 10.0, 1e-10)
            .unwrap()
            .value;
        assert!(
            (value - exact).abs() < 1e-8,
            "value={}, exact={exact}",
            value
        );
    }

    #[test]
    fn sin_high_omega() {
        // ∫_0^1 sin(100x) dx = (1 − cos(100))/100.
        let exact = (1.0 - 100.0_f64.cos()) / 100.0;
        let value = integrate_oscillatory_sin(|_| 1.0, 0.0, 1.0, 100.0, 1e-10)
            .unwrap()
            .value;
        assert!(
            (value - exact).abs() < 1e-8,
            "value={}, exact={exact}",
            value
        );
    }

    #[test]
    fn cos_with_linear_f() {
        // ∫_0^π x·cos(x) dx = [x·sin(x) + cos(x)]_0^π = −2.
        let value = integrate_oscillatory_cos(|x| x, 0.0, PI, 1.0, 1e-8)
            .unwrap()
            .value;
        assert!((value - (-2.0)).abs() < 1e-4, "value={}", value);
    }

    #[test]
    fn sin_with_exp_f() {
        // ∫_0^1 e^x·sin(ωx) dx = Im[(e^{1+iω} − 1) / (1 + iω)].
        let omega: f64 = 50.0;
        // Real and imaginary parts of the exponent coefficient 1 + iω.
        let (re_k, im_k) = (1.0, omega);
        let num_re = E * omega.cos() - 1.0;
        let num_im = E * omega.sin();
        let denom = re_k * re_k + im_k * im_k;
        let exact = (num_im * re_k - num_re * im_k) / denom;
        let value = integrate_oscillatory_sin(f64::exp, 0.0, 1.0, omega, 1e-8)
            .unwrap()
            .value;
        assert!(
            (value - exact).abs() < 1e-4,
            "value={}, exact={exact}",
            value
        );
    }

    #[test]
    fn small_omega_fallback() {
        // ∫_0^1 cos(0.5x) dx = sin(0.5)/0.5. Here θ = 0.25, so the adaptive path runs.
        let exact = 0.5_f64.sin() / 0.5;
        let value = integrate_oscillatory_cos(|_| 1.0, 0.0, 1.0, 0.5, 1e-10)
            .unwrap()
            .value;
        assert!(
            (value - exact).abs() < 1e-8,
            "value={}, exact={exact}",
            value
        );
    }

    #[test]
    fn zero_interval() {
        let out = integrate_oscillatory_sin(|_| 1.0, 1.0, 1.0, 10.0, 1e-10).unwrap();
        assert_eq!(out.value, 0.0);
    }

    #[test]
    fn nan_input() {
        let outcome = integrate_oscillatory_sin(|_| 1.0, f64::NAN, 1.0, 10.0, 1e-10);
        assert!(outcome.is_err());
    }

    #[test]
    fn inf_inputs_rejected() {
        assert!(integrate_oscillatory_sin(|_| 1.0, f64::INFINITY, 1.0, 10.0, 1e-10).is_err());
        assert!(integrate_oscillatory_cos(|_| 1.0, 0.0, f64::NEG_INFINITY, 10.0, 1e-10).is_err());
        assert!(integrate_oscillatory_sin(|_| 1.0, 0.0, 1.0, f64::INFINITY, 1e-10).is_err());
    }

    #[test]
    fn very_high_omega() {
        // ∫_0^1 sin(1000x) dx = (1 − cos(1000))/1000.
        let exact = (1.0 - 1000.0_f64.cos()) / 1000.0;
        let value = integrate_oscillatory_sin(|_| 1.0, 0.0, 1.0, 1000.0, 1e-8)
            .unwrap()
            .value;
        assert!(
            (value - exact).abs() < 1e-6,
            "value={}, exact={exact}",
            value
        );
    }
}