#[cfg(feature = "rayon")]
use rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
use crate::{__impl_node_weight_rule, Node, Weight, golub_welsch::golub_welsch, math::sqrt};
use alloc::boxed::Box;
use core::{f64::consts::PI, num::NonZeroUsize};
/// A Gauss–Hermite quadrature rule.
///
/// Holds the nodes `x_i` and weights `w_i` of a rule that approximates
/// `∫_{-∞}^{∞} f(x) e^{-x²} dx` by `Σ_i w_i · f(x_i)`.
/// Construct one with [`GaussHermite::new`].
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "rkyv", derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize))]
#[cfg_attr(feature = "zerocopy", derive(zerocopy::KnownLayout))]
pub struct GaussHermite {
    /// One `(node, weight)` pair per quadrature point.
    node_weight_pairs: Box<[(Node, Weight)]>,
}
impl GaussHermite {
    /// Creates a Gauss–Hermite quadrature rule with `deg` nodes.
    ///
    /// The rule targets integrals of the form `∫_{-∞}^{∞} f(x) e^{-x²} dx`. Its nodes and
    /// weights are computed with the Golub–Welsch algorithm from the Hermite recurrence:
    /// - the Jacobi-matrix diagonal is zero;
    /// - the off-diagonal entries are `sqrt((k + 1) / 2)` for `k = 0, 1, …`;
    /// - the zeroth moment is `∫ e^{-x²} dx = sqrt(π)`.
    ///
    /// As a Gaussian rule it is exact (up to rounding) when `f` is a polynomial of degree at
    /// most `2 * deg - 1`.
    pub fn new(deg: NonZeroUsize) -> Self {
        let node_weight_pairs = golub_welsch(
            deg,
            // The weight e^{-x²} is symmetric about 0, so every diagonal entry vanishes.
            |_| 0.0,
            // Off-diagonal entry k (0-based) is sqrt((k + 1) / 2).
            |idx| sqrt((1.0 + idx as f64) * 0.5),
            // Total mass of the weight function over ℝ: ∫ e^{-x²} dx = sqrt(π).
            sqrt(PI),
        );
        Self { node_weight_pairs }
    }

    /// Approximates `∫_{-∞}^{∞} integrand(x) e^{-x²} dx` by `Σ_i w_i · integrand(x_i)`.
    ///
    /// The Gaussian factor `e^{-x²}` is already folded into the weights, so pass only the
    /// remaining factor. For example, `integrand = |_| 1.0` yields approximately `sqrt(π)`.
    ///
    /// `integrand` is called exactly once per node, in storage order.
    pub fn integrate<F>(&self, mut integrand: F) -> f64
    where
        F: FnMut(f64) -> f64,
    {
        self.node_weight_pairs
            .iter()
            .map(|(x_val, w_val)| integrand(*x_val) * w_val)
            .sum()
    }

    /// Parallel counterpart of [`GaussHermite::integrate`]. It evaluates `integrand` at the
    /// nodes on the rayon thread pool.
    ///
    /// `integrand` may run concurrently on several threads, hence the `Fn + Sync` bound
    /// instead of `FnMut`. Rayon does not specify the order in which partial sums are
    /// combined, so the result may differ from [`GaussHermite::integrate`] in the last
    /// few bits.
    #[cfg(feature = "rayon")]
    pub fn par_integrate<F>(&self, integrand: F) -> f64
    where
        F: Fn(f64) -> f64 + Sync,
    {
        self.node_weight_pairs
            .par_iter()
            .map(|(x_val, w_val)| integrand(*x_val) * w_val)
            .sum()
    }
}
// Expand the crate's `__impl_node_weight_rule!` boilerplate for `GaussHermite`. The other
// four identifiers are presumably the names of the iterator types it generates.
// NOTE(review): the tests in this file call `as_node_weight_pairs`, `iter`, `nodes`,
// `weights` and `into_iter` on `GaussHermite`, and none of these is defined in the `impl`
// block above. They presumably come from this expansion; confirm against the macro.
__impl_node_weight_rule! {GaussHermite, GaussHermiteNodes, GaussHermiteWeights, GaussHermiteIter, GaussHermiteIntoIter}
#[cfg(test)]
mod tests {
    use approx::assert_abs_diff_eq;

    use super::*;

    #[test]
    fn check_sorted() {
        // Tuples compare lexicographically, so this checks that nodes are ascending.
        for deg in (2..100).step_by(10) {
            let rule = GaussHermite::new(deg.try_into().unwrap());
            assert!(rule.as_node_weight_pairs().is_sorted());
        }
    }

    #[test]
    fn check_degree_1() {
        // The one-point rule is the single node 0 carrying the whole mass sqrt(π).
        let rule = GaussHermite::new(1.try_into().unwrap());
        assert_eq!(rule.as_node_weight_pairs(), &[(0.0, sqrt(PI))]);
        // ...so it integrates every odd linear function c·x against e^{-x²} to zero.
        for constant in (1..100).step_by(10) {
            assert_abs_diff_eq!(rule.integrate(|x| f64::from(constant) * x), 0.0);
        }
    }

    #[test]
    fn golub_welsch_3() {
        let rule = GaussHermite::new(3.try_into().unwrap());
        // Reference 3-point rule: nodes 0 and ±sqrt(3/2).
        // The weights are sqrt(π)/6 at ±sqrt(3/2) and 2·sqrt(π)/3 at 0.
        let x_should = [-1.224_744_871_391_589, 0.0, 1.224_744_871_391_589];
        let w_should = [
            0.295_408_975_150_919_35,
            1.181_635_900_603_677_4,
            0.295_408_975_150_919_35,
        ];
        // `zip` stops at the shorter side, so pin the length first. Otherwise a rule
        // with too few points would pass vacuously.
        assert_eq!(rule.as_node_weight_pairs().len(), x_should.len());
        for ((correct_node, correct_weight), (computed_node, computed_weight)) in
            x_should.into_iter().zip(w_should).zip(rule)
        {
            assert_abs_diff_eq!(correct_node, computed_node, epsilon = 1e-15);
            assert_abs_diff_eq!(correct_weight, computed_weight, epsilon = 1e-15);
        }
    }

    #[test]
    fn check_derives() {
        let quad = GaussHermite::new(10.try_into().unwrap());
        let quad_clone = quad.clone();
        assert_eq!(quad, quad_clone);
        let other_quad = GaussHermite::new(3.try_into().unwrap());
        assert_ne!(quad, other_quad);
    }

    #[test]
    fn check_iterators() {
        // ∫ x² e^{-x²} dx = sqrt(π)/2, which the 3-point rule integrates exactly.
        let rule = GaussHermite::new(3.try_into().unwrap());
        let ans = sqrt(PI) / 2.0;
        assert_abs_diff_eq!(
            ans,
            rule.iter().fold(0.0, |tot, (n, w)| tot + n * n * w),
            epsilon = 1e-14
        );
        assert_abs_diff_eq!(
            ans,
            rule.nodes()
                .zip(rule.weights())
                .fold(0.0, |tot, (n, w)| tot + n * n * w),
            epsilon = 1e-14
        );
        assert_abs_diff_eq!(
            ans,
            rule.into_iter().fold(0.0, |tot, (n, w)| tot + n * n * w),
            epsilon = 1e-14
        );
    }

    #[test]
    fn integrate_one() {
        // The weights sum to the total mass of e^{-x²}, i.e. sqrt(π).
        let quad = GaussHermite::new(5.try_into().unwrap());
        let integral = quad.integrate(|_x| 1.0);
        assert_abs_diff_eq!(integral, sqrt(PI), epsilon = 1e-14);
    }

    #[cfg(feature = "rayon")]
    #[test]
    fn par_integrate_one() {
        // Same tolerance as `integrate_one`. Rayon's reduction order is unspecified, so the
        // parallel sum gets no tighter bound than the sequential one.
        let quad = GaussHermite::new(5.try_into().unwrap());
        let integral = quad.par_integrate(|_x| 1.0);
        assert_abs_diff_eq!(integral, sqrt(PI), epsilon = 1e-14);
    }

    #[cfg(feature = "rayon")]
    #[test]
    fn par_integrate_matches_integrate() {
        // The sequential and parallel entry points must agree up to summation-order rounding.
        let quad = GaussHermite::new(7.try_into().unwrap());
        let f = |x: f64| x * x + 1.0;
        assert_abs_diff_eq!(quad.par_integrate(f), quad.integrate(f), epsilon = 1e-14);
    }
}