#[cfg(feature = "rayon")]
use rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
use crate::{
__impl_node_weight_rule, FiniteAboveNegOneF64, Node, Weight,
math::{gamma, sqrt},
};
use crate::golub_welsch::golub_welsch;
use alloc::boxed::Box;
use core::num::NonZeroUsize;
/// A Gauss-Laguerre quadrature rule.
///
/// Holds the node-weight pairs computed by [`GaussLaguerre::new`] together with the
/// `alpha` parameter that was used to compute them.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
feature = "rkyv",
derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
#[cfg_attr(feature = "zerocopy", derive(zerocopy::KnownLayout))]
pub struct GaussLaguerre {
/// The (node, weight) pairs of the rule, as returned by `golub_welsch` in `GaussLaguerre::new`.
node_weight_pairs: Box<[(Node, Weight)]>,
/// The `alpha` parameter the rule was built with; exposed through `GaussLaguerre::alpha`.
alpha: FiniteAboveNegOneF64,
}
impl GaussLaguerre {
    /// Builds the Gauss-Laguerre rule of the given `degree` for the parameter `alpha`.
    ///
    /// The node-weight pairs come from [`golub_welsch`], which is handed
    /// * the diagonal coefficients `alpha + 1 + 2 * idx`,
    /// * the off-diagonal coefficients `sqrt((idx + 1) * (idx + 1 + alpha))`, and
    /// * `gamma(alpha + 1)`, i.e. the integral of `x^alpha * e^(-x)` over `[0, ∞)`.
    ///
    /// The supplied `alpha` is stored unchanged and can be read back with
    /// [`GaussLaguerre::alpha`].
    pub fn new(degree: NonZeroUsize, alpha: FiniteAboveNegOneF64) -> Self {
        // Read the raw value once; both coefficient closures only need the number.
        let shape = alpha.get();

        let node_weight_pairs = golub_welsch(
            degree,
            |idx| shape + 1.0 + 2.0 * (idx as f64),
            |idx| {
                let next = 1.0 + (idx as f64);
                sqrt(next * (next + shape))
            },
            gamma(shape + 1.0),
        );

        Self {
            node_weight_pairs,
            alpha,
        }
    }

    /// Applies the rule to `integrand`.
    ///
    /// Evaluates `integrand` once at every node, in storage order, and returns the
    /// sum of `integrand(node) * weight`. This approximates the integral of
    /// `integrand(x) * x^alpha * e^(-x)` over `[0, ∞)`.
    pub fn integrate<F>(&self, mut integrand: F) -> f64
    where
        F: FnMut(f64) -> f64,
    {
        self.node_weight_pairs
            .iter()
            .map(|&(node, weight)| integrand(node) * weight)
            .sum()
    }

    /// Same as [`GaussLaguerre::integrate`], but the per-node terms are evaluated
    /// and summed through a `rayon` parallel iterator.
    ///
    /// Only available with the `rayon` feature. The integrand must be `Sync`
    /// because it is shared between the worker threads.
    #[cfg(feature = "rayon")]
    pub fn par_integrate<F>(&self, integrand: F) -> f64
    where
        F: Fn(f64) -> f64 + Sync,
    {
        self.node_weight_pairs
            .par_iter()
            .map(|&(node, weight)| integrand(node) * weight)
            .sum()
    }

    /// Returns the `alpha` parameter this rule was constructed with.
    #[inline]
    pub const fn alpha(&self) -> FiniteAboveNegOneF64 {
        self.alpha
    }
}
// Invokes the crate-internal `__impl_node_weight_rule!` macro for `GaussLaguerre`,
// handing it the names `GaussLaguerreNodes`, `GaussLaguerreWeights`, `GaussLaguerreIter`
// and `GaussLaguerreIntoIter`.
// NOTE(review): the tests in this file call `iter()`, `nodes()`, `weights()`,
// `as_node_weight_pairs()` and iterate the rule by value, none of which are defined in
// the `impl` block here — presumably this macro generates them along with the four
// named types; confirm against the macro definition.
__impl_node_weight_rule! {GaussLaguerre, GaussLaguerreNodes, GaussLaguerreWeights, GaussLaguerreIter, GaussLaguerreIntoIter}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use core::f64::consts::PI;

    /// Builds a rule from plain numbers, panicking if either value is rejected
    /// by the corresponding checked conversion.
    #[track_caller]
    fn laguerre(degree: usize, alpha: f64) -> GaussLaguerre {
        GaussLaguerre::new(degree.try_into().unwrap(), alpha.try_into().unwrap())
    }

    /// Walks the reference nodes and weights alongside the pairs yielded by
    /// `rule` and requires each to agree within `epsilon`.
    #[track_caller]
    fn assert_matches_reference<const N: usize>(
        rule: GaussLaguerre,
        x_should: [f64; N],
        w_should: [f64; N],
        epsilon: f64,
    ) {
        let reference = x_should.into_iter().zip(w_should);
        for ((correct_node, correct_weight), (computed_node, computed_weight)) in
            reference.zip(rule)
        {
            assert_abs_diff_eq!(correct_node, computed_node, epsilon = epsilon);
            assert_abs_diff_eq!(correct_weight, computed_weight, epsilon = epsilon);
        }
    }

    /// The node-weight pairs must come out sorted for a range of degrees and alphas.
    #[test]
    fn check_sorted() {
        for deg in (2..100).step_by(10) {
            for alpha in [-0.9, -0.5, 0.0, 0.5] {
                let pairs_are_sorted = laguerre(deg, alpha).as_node_weight_pairs().is_sorted();
                assert!(pairs_are_sorted);
            }
        }
    }

    /// With a single node and alpha = 0, `c * x` must integrate to `c`.
    #[test]
    fn check_degree_1() {
        let rule = laguerre(1, 0.0);
        for constant in (0..100).step_by(10).map(f64::from) {
            let integral = rule.integrate(|x| constant * x);
            assert_abs_diff_eq!(integral, constant, epsilon = 1e-13);
        }
    }

    /// `alpha()` must hand back exactly the value given to `new`.
    #[test]
    fn sanity_check_alpha_accessor() {
        let alpha = FiniteAboveNegOneF64::new(1.0).unwrap();
        let rule = GaussLaguerre::new(10.try_into().unwrap(), alpha);
        assert_eq!(rule.alpha(), alpha);
    }

    /// Degree 2, alpha = 5, compared against reference nodes and weights.
    #[test]
    fn golub_welsch_2_alpha_5() {
        assert_matches_reference(
            laguerre(2, 5.0),
            [4.354_248_688_935_409, 9.645_751_311_064_59],
            [82.677_868_380_553_63, 37.322_131_619_446_37],
            1e-12,
        );
    }

    /// Degree 3, alpha = 0, compared against reference nodes and weights.
    #[test]
    fn golub_welsch_3_alpha_0() {
        assert_matches_reference(
            laguerre(3, 0.0),
            [
                0.415_774_556_783_479_1,
                2.294_280_360_279_042,
                6.289_945_082_937_479_4,
            ],
            [
                0.711_093_009_929_173,
                0.278_517_733_569_240_87,
                0.010_389_256_501_586_135,
            ],
            1e-14,
        );
    }

    /// Degree 3, alpha = 1.5, compared against reference nodes and weights.
    #[test]
    fn golub_welsch_3_alpha_1_5() {
        assert_matches_reference(
            laguerre(3, 1.5),
            [
                1.220_402_317_558_883_8,
                3.808_880_721_467_068,
                8.470_716_960_974_048,
            ],
            [
                0.730_637_894_350_016,
                0.566_249_100_686_605_7,
                0.032_453_393_142_515_25,
            ],
            1e-14,
        );
    }

    /// Degree 5, alpha = -0.9, compared against reference nodes and weights.
    #[test]
    fn golub_welsch_5_alpha_negative() {
        assert_matches_reference(
            laguerre(5, -0.9),
            [
                0.020_777_151_319_288_104,
                0.808_997_536_134_602_1,
                2.674_900_020_624_07,
                5.869_026_089_963_398,
                11.126_299_201_958_641,
            ],
            [
                8.738_289_241_242_436,
                0.702_782_353_089_744_5,
                0.070_111_720_632_849_48,
                0.002_312_760_116_115_564,
                1.162_358_758_613_074_8E-5,
            ],
            1e-14,
        );
    }

    /// A clone compares equal; a rule built with a different alpha does not.
    #[test]
    fn check_derives() {
        let original = laguerre(10, 1.0);
        let duplicate = original.clone();
        assert_eq!(original, duplicate);

        let different_alpha = laguerre(10, 2.0);
        assert_ne!(original, different_alpha);
    }

    /// Every way of iterating over the rule must give the same quadrature sum
    /// for `x^2`, which for alpha = 0.5 is Γ(7/2) = 15/8 · √π.
    #[test]
    fn check_iterators() {
        let rule = laguerre(3, 0.5);
        let expected = 15.0 / 8.0 * PI.sqrt();

        // Borrowing iterator over the pairs.
        let via_iter: f64 = rule.iter().fold(0.0, |acc, (x, w)| acc + x * x * w);
        assert_abs_diff_eq!(via_iter, expected, epsilon = 1e-14);

        // Separate node and weight iterators, zipped back together.
        let via_nodes_and_weights: f64 = rule
            .nodes()
            .zip(rule.weights())
            .fold(0.0, |acc, (x, w)| acc + x * x * w);
        assert_abs_diff_eq!(via_nodes_and_weights, expected, epsilon = 1e-14);

        // Consuming iterator; has to come last because it takes the rule by value.
        let via_into_iter: f64 = rule.into_iter().fold(0.0, |acc, (x, w)| acc + x * x * w);
        assert_abs_diff_eq!(via_into_iter, expected, epsilon = 1e-14);
    }

    /// Checks two integrals with known closed forms for alpha = -0.5.
    #[test]
    fn check_some_integrals() {
        let rule = laguerre(10, -0.5);

        // Polynomial integrand: the rule should be accurate to rounding error.
        let x_squared_exact = 3.0 * PI.sqrt() / 4.0;
        assert_abs_diff_eq!(rule.integrate(|x| x * x), x_squared_exact, epsilon = 1e-14);

        // Non-polynomial integrand: only approximated, hence the looser tolerance.
        let sine_exact = (PI.sqrt() * (PI / 8.0).sin()) / (2.0_f64.powf(0.25));
        assert_abs_diff_eq!(rule.integrate(|x| x.sin()), sine_exact, epsilon = 1e-7);
    }

    /// Same integrals as `check_some_integrals`, evaluated through `par_integrate`.
    #[cfg(feature = "rayon")]
    #[test]
    fn par_check_some_integrals() {
        let rule = laguerre(10, -0.5);

        let x_squared_exact = 3.0 * PI.sqrt() / 4.0;
        assert_abs_diff_eq!(
            rule.par_integrate(|x| x * x),
            x_squared_exact,
            epsilon = 1e-14
        );

        let sine_exact = (PI.sqrt() * (PI / 8.0).sin()) / (2.0_f64.powf(0.25));
        assert_abs_diff_eq!(rule.par_integrate(|x| x.sin()), sine_exact, epsilon = 1e-7);
    }
}