// gauss-quad 0.3.1
//
// Integrate functions with Gaussian quadrature.
// Documentation
// Copyright 2019-2024 Dominique Dresen
// Copyright 2023-2026 Johanna Sörngård
// SPDX-License-Identifier: MIT OR Apache-2.0

//! Numerical integration using the Gauss-Hermite quadrature rule.
//!
//! This rule can integrate integrands of the form  
//! e^(-x^2) * f(x)  
//! over the domain (-∞, ∞).
//!
//! # Example
//!
//! Integrate x^2 * e^(-x^2):
//!
//! ```
//! use gauss_quad::hermite::GaussHermite;
//! use approx::assert_abs_diff_eq;
//!
//! let quad = GaussHermite::new(10.try_into().unwrap());
//!
//! let integral = quad.integrate(|x| x.powi(2));
//!
//! assert_abs_diff_eq!(integral, core::f64::consts::PI.sqrt() / 2.0, epsilon = 1e-14);
//! ```

#[cfg(feature = "rayon")]
use rayon::prelude::{IntoParallelRefIterator, ParallelIterator};

use crate::{__impl_node_weight_rule, Node, Weight, golub_welsch::golub_welsch, math::sqrt};

use alloc::boxed::Box;
use core::{f64::consts::PI, num::NonZeroUsize};

/// A Gauss-Hermite quadrature scheme.
///
/// These rules can integrate integrands of the form e^(-x^2) * f(x) over the domain (-∞, ∞).
///
/// The nodes and weights are computed once, when the rule is created with
/// [`GaussHermite::new`], and are then reused by every call to
/// [`GaussHermite::integrate`].
///
/// # Example
///
/// Integrate e^(-x^2) * cos(x):
///
/// ```
/// # use gauss_quad::hermite::GaussHermite;
/// # use approx::assert_abs_diff_eq;
/// # use core::f64::consts::{E, PI};
/// // initialize a Gauss-Hermite rule with 20 nodes
/// let quad = GaussHermite::new(20.try_into().unwrap());
///
/// // numerically integrate a function over (-∞, ∞) using the Gauss-Hermite rule
/// let integral = quad.integrate(|x| x.cos());
///
/// assert_abs_diff_eq!(integral, PI.sqrt() / E.powf(0.25), epsilon = 1e-14);
/// ```
#[derive(Debug, Clone, PartialEq)]
// The derives below are optional: each one is only compiled in when the
// corresponding cargo feature (`serde`, `rkyv`, `zerocopy`) is enabled.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
#[cfg_attr(feature = "zerocopy", derive(zerocopy::KnownLayout))]
pub struct GaussHermite {
    // The (node, weight) pairs of the rule, stored as a boxed slice.
    // Filled in by `GaussHermite::new` from the output of `golub_welsch`.
    node_weight_pairs: Box<[(Node, Weight)]>,
}

impl GaussHermite {
    /// Creates a Gauss-Hermite quadrature rule of degree `deg`, computing its nodes and
    /// weights up front.
    ///
    /// A rule of degree n integrates e^(-x^2) * p(x) exactly for every polynomial p of
    /// degree at most 2n-1.
    ///
    /// The nodes and weights come from the crate's Golub-Welsch routine.
    pub fn new(deg: NonZeroUsize) -> Self {
        Self {
            // NOTE(review): the two closures are presumably the diagonal (all zeros) and
            // off-diagonal (sqrt((idx + 1) / 2)) entries of the Jacobi matrix, and sqrt(PI)
            // the integral of the weight function e^(-x^2) -- confirm against `golub_welsch`.
            node_weight_pairs: golub_welsch(
                deg,
                |_| 0.0,
                |idx| sqrt((1.0 + idx as f64) * 0.5),
                sqrt(PI),
            ),
        }
    }

    /// Approximates the integral of e^(-x^2) * `integrand`(x) over the domain (-∞, ∞).
    ///
    /// `integrand` is evaluated once at every node of the rule, in the order the nodes are
    /// stored; each value is multiplied by the matching weight and the products are summed.
    /// The factor e^(-x^2) is supplied by the rule itself, so `integrand` should only
    /// contain the remaining part of the function.
    pub fn integrate<F>(&self, mut integrand: F) -> f64
    where
        F: FnMut(f64) -> f64,
    {
        self.node_weight_pairs
            .iter()
            .map(|&(node, weight)| integrand(node) * weight)
            .sum()
    }

    /// Parallel counterpart of [`integrate`](GaussHermite::integrate).
    ///
    /// The weighted evaluations are computed and summed through a `rayon` parallel
    /// iterator, which is why `integrand` has to be `Fn + Sync` here.
    /// Only available when the `rayon` feature is enabled.
    #[cfg(feature = "rayon")]
    pub fn par_integrate<F>(&self, integrand: F) -> f64
    where
        F: Fn(f64) -> f64 + Sync,
    {
        self.node_weight_pairs
            .par_iter()
            .map(|&(node, weight)| integrand(node) * weight)
            .sum()
    }
}

// Invokes the crate-internal `__impl_node_weight_rule!` macro for `GaussHermite`, handing it
// the names `GaussHermiteNodes`, `GaussHermiteWeights`, `GaussHermiteIter` and
// `GaussHermiteIntoIter`.
// NOTE(review): the macro is defined elsewhere. Presumably it declares those four types and
// generates the accessor/iterator API that the tests in this file rely on (`iter`, `nodes`,
// `weights`, `as_node_weight_pairs`, `IntoIterator`) -- confirm against the macro definition.
__impl_node_weight_rule! {GaussHermite, GaussHermiteNodes, GaussHermiteWeights, GaussHermiteIter, GaussHermiteIntoIter}

#[cfg(test)]
mod tests {
    use approx::assert_abs_diff_eq;

    use super::*;

    /// Builds the Gauss-Hermite rule of the given degree. Panics if `deg` is zero.
    fn rule_of_degree(deg: usize) -> GaussHermite {
        GaussHermite::new(deg.try_into().unwrap())
    }

    /// The node-weight pairs must come out sorted for a range of degrees.
    #[test]
    fn check_sorted() {
        (2..100).step_by(10).map(rule_of_degree).for_each(|rule| {
            assert!(rule.as_node_weight_pairs().is_sorted());
        });
    }

    /// The degree-1 rule has a single node at the origin with weight sqrt(pi),
    /// so every linear function without a constant term integrates to zero.
    #[test]
    fn check_degree_1() {
        let rule = rule_of_degree(1);
        assert_eq!(rule.as_node_weight_pairs(), &[(0.0, sqrt(PI))]);
        for slope in (1..100).step_by(10) {
            let linear = |x: f64| f64::from(slope) * x;
            assert_abs_diff_eq!(rule.integrate(linear), 0.0);
        }
    }

    /// Compares the degree-3 rule against reference nodes and weights.
    #[test]
    fn golub_welsch_3() {
        let expected_pairs = [
            (-1.224_744_871_391_589, 0.295_408_975_150_919_35),
            (0.0, 1.181_635_900_603_677_4),
            (1.224_744_871_391_589, 0.295_408_975_150_919_35),
        ];
        let rule = rule_of_degree(3);
        for ((correct_node, correct_weight), (computed_node, computed_weight)) in
            expected_pairs.into_iter().zip(rule)
        {
            assert_abs_diff_eq!(correct_node, computed_node, epsilon = 1e-15);
            assert_abs_diff_eq!(correct_weight, computed_weight, epsilon = 1e-15);
        }
    }

    /// `Clone` yields an equal rule; rules of different degree compare unequal.
    #[test]
    fn check_derives() {
        let original = rule_of_degree(10);
        let duplicate = original.clone();
        assert_eq!(original, duplicate);
        let smaller = rule_of_degree(3);
        assert_ne!(original, smaller);
    }

    /// Integrates x^2 by hand through each iterator flavour and checks that
    /// all of them give sqrt(pi) / 2.
    #[test]
    fn check_iterators() {
        let rule = rule_of_degree(3);
        let expected = sqrt(PI) / 2.0;

        // Borrowing iterator over the pairs.
        let via_iter: f64 = rule
            .iter()
            .fold(0.0, |acc, (node, weight)| acc + node * node * weight);
        assert_abs_diff_eq!(expected, via_iter, epsilon = 1e-14);

        // Separate node and weight iterators zipped back together.
        let via_zip: f64 = rule
            .nodes()
            .zip(rule.weights())
            .fold(0.0, |acc, (node, weight)| acc + node * node * weight);
        assert_abs_diff_eq!(expected, via_zip, epsilon = 1e-14);

        // Consuming iterator; has to come last because it takes the rule by value.
        let via_into_iter: f64 = rule
            .into_iter()
            .fold(0.0, |acc, (node, weight)| acc + node * node * weight);
        assert_abs_diff_eq!(expected, via_into_iter, epsilon = 1e-14);
    }

    /// Integrating the constant 1 gives the integral of e^(-x^2), i.e. sqrt(pi).
    #[test]
    fn integrate_one() {
        let rule = rule_of_degree(5);
        let total = rule.integrate(|_| 1.0);
        assert_abs_diff_eq!(total, sqrt(PI), epsilon = 1e-14);
    }

    /// Same check as `integrate_one`, but through the parallel code path.
    #[cfg(feature = "rayon")]
    #[test]
    fn par_integrate_one() {
        let rule = rule_of_degree(5);
        let total = rule.par_integrate(|_| 1.0);
        assert_abs_diff_eq!(total, sqrt(PI), epsilon = 1e-15);
    }
}