gauss-quad 0.3.0

Library for applying Gaussian quadrature to integrate a function
Documentation
// Copyright 2019-2024 Dominique Dresen
// Copyright 2023-2026 Johanna Sörngård
// SPDX-License-Identifier: MIT OR Apache-2.0

//! Numerical integration using the Gauss-Hermite quadrature rule.
//!
//! This rule can integrate integrands of the form  
//! e^(-x^2) * f(x)  
//! over the domain (-∞, ∞).
//!
//! # Example
//!
//! Integrate x^2 * e^(-x^2):
//!
//! ```
//! use gauss_quad::hermite::GaussHermite;
//! use approx::assert_abs_diff_eq;
//!
//! let quad = GaussHermite::new(10.try_into().unwrap());
//!
//! let integral = quad.integrate(|x| x.powi(2));
//!
//! assert_abs_diff_eq!(integral, core::f64::consts::PI.sqrt() / 2.0, epsilon = 1e-14);
//! ```

#[cfg(feature = "rayon")]
use rayon::prelude::{IntoParallelRefIterator, ParallelIterator};

use crate::{__impl_node_weight_rule, DMatrixf64, Node, Weight, math::sqrt};

use alloc::boxed::Box;
use core::{f64::consts::PI, num::NonZeroUsize};

/// A Gauss-Hermite quadrature scheme.
///
/// These rules can integrate integrands of the form e^(-x^2) * f(x) over the domain (-∞, ∞).
/// A rule is built with [`GaussHermite::new`], which stores its nodes in ascending order.
///
/// # Example
///
/// Integrate e^(-x^2) * cos(x), whose exact value is sqrt(π) * e^(-1/4):
///
/// ```
/// # use gauss_quad::hermite::GaussHermite;
/// # use approx::assert_abs_diff_eq;
/// # use core::f64::consts::{E, PI};
/// // initialize a Gauss-Hermite rule with 20 nodes
/// let quad = GaussHermite::new(20.try_into().unwrap());
///
/// // numerically integrate a function over (-∞, ∞) using the Gauss-Hermite rule
/// let integral = quad.integrate(|x| x.cos());
///
/// assert_abs_diff_eq!(integral, PI.sqrt() / E.powf(0.25), epsilon = 1e-14);
/// ```
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct GaussHermite {
    // (node, weight) pairs of the rule; `GaussHermite::new` sorts them by node, ascending.
    // NOTE: with the `serde` feature enabled this field name is part of the serialized
    // representation, so renaming it is a format-breaking change.
    node_weight_pairs: Box<[(Node, Weight)]>,
}

impl GaussHermite {
    /// Initializes a Gauss-Hermite quadrature rule with `deg` nodes by computing the
    /// needed nodes and weights.
    ///
    /// A rule with `deg` nodes integrates e^(-x^2) * p(x) exactly for every polynomial p
    /// of degree at most 2 * `deg` - 1.
    ///
    /// The nodes of the returned rule are sorted in ascending order.
    ///
    /// # Algorithm
    ///
    /// Applies the Golub-Welsch algorithm to determine the Gauss-Hermite nodes and weights.
    /// The companion matrix A of the (physicists') Hermite polynomials follows from the
    /// three-term recurrence
    ///
    /// x H_n = 1/2 H_{n+1} + n H_{n-1}.
    ///
    /// The similar matrix D A D^{-1} is symmetric and tridiagonal, with 0 on the diagonal
    /// and sqrt(k/2), k = 1, ..., `deg` - 1, on the off-diagonals. Its eigenvalues are the
    /// nodes. The weight of each node is sqrt(π) (the integral of e^(-x^2) over ℝ) times
    /// the squared first component of the corresponding normalized eigenvector.
    /// See Gil, Segura, Temme - Numerical Methods for Special Functions.
    pub fn new(deg: NonZeroUsize) -> Self {
        let n = deg.get();
        let mut companion_matrix = DMatrixf64::from_element(n, n, 0.0);

        // Fill both off-diagonals of the symmetric tridiagonal matrix with sqrt(k/2).
        // For deg == 1 this loop is empty and the matrix is the 1x1 zero matrix.
        for idx in 0..n - 1 {
            let element = sqrt((idx + 1) as f64 * 0.5);
            companion_matrix[(idx, idx + 1)] = element;
            companion_matrix[(idx + 1, idx)] = element;
        }

        // Eigenvalues are the nodes; eigenvectors are returned normalized.
        let eigen = companion_matrix.symmetric_eigen();

        // Zeroth moment of the weight function e^(-x^2). Computed once here instead of
        // once per node inside the closure below.
        let sqrt_pi = sqrt(PI);

        // Pair every node with its weight. Iterating the first row of the eigenvector
        // matrix directly avoids allocating an intermediate matrix of weights.
        let mut node_weight_pairs: Box<[(Node, Weight)]> = eigen
            .eigenvalues
            .iter()
            .copied()
            .zip(
                eigen
                    .eigenvectors
                    .row(0)
                    .iter()
                    .map(|&first_component| first_component * first_component * sqrt_pi),
            )
            .collect();

        // The eigen-solver does not order its output, so sort the pairs by node.
        node_weight_pairs.sort_unstable_by(|(node1, _), (node2, _)| node1.total_cmp(node2));

        GaussHermite { node_weight_pairs }
    }

    /// Perform quadrature of e^(-x^2) * `integrand`(x) over the domain (-∞, ∞).
    ///
    /// Evaluates `integrand` once at each node and returns the weighted sum of the results.
    pub fn integrate<F>(&self, mut integrand: F) -> f64
    where
        F: FnMut(f64) -> f64,
    {
        self.node_weight_pairs
            .iter()
            .map(|(x_val, w_val)| integrand(*x_val) * w_val)
            .sum()
    }

    #[cfg(feature = "rayon")]
    /// Same as [`integrate`](GaussHermite::integrate) but runs in parallel.
    ///
    /// The summation order differs from the serial version, so the result may differ from
    /// [`integrate`](GaussHermite::integrate) by floating-point rounding.
    pub fn par_integrate<F>(&self, integrand: F) -> f64
    where
        F: Fn(f64) -> f64 + Sync,
    {
        self.node_weight_pairs
            .par_iter()
            .map(|(x_val, w_val)| integrand(*x_val) * w_val)
            .sum()
    }
}

// Invoke the crate-level `__impl_node_weight_rule!` macro for `GaussHermite`.
// NOTE(review): judging by their names, the remaining identifiers name the node, weight,
// borrowing-iterator and owning-iterator types the macro generates. The tests below call
// `as_node_weight_pairs`, `iter`, `nodes`, `weights` and `into_iter`, which presumably come
// from this macro. Confirm against the macro definition.
__impl_node_weight_rule! {GaussHermite, GaussHermiteNodes, GaussHermiteWeights, GaussHermiteIter, GaussHermiteIntoIter}

#[cfg(test)]
mod tests {
    use approx::assert_abs_diff_eq;

    use super::*;

    /// `new` must return its nodes in ascending order for a spread of degrees.
    #[test]
    fn check_sorted() {
        for deg in (2..100).step_by(10) {
            let rule = GaussHermite::new(deg.try_into().unwrap());
            assert!(rule.as_node_weight_pairs().is_sorted());
        }
    }

    /// The one-node rule is the node 0 with weight sqrt(π), so any linear odd
    /// integrand integrates to exactly zero.
    #[test]
    fn check_degree_1() {
        let rule = GaussHermite::new(1.try_into().unwrap());
        assert_eq!(rule.as_node_weight_pairs(), &[(0.0, sqrt(PI))]);
        for constant in (1..100).step_by(10) {
            assert_abs_diff_eq!(rule.integrate(|x| f64::from(constant) * x), 0.0);
        }
    }

    /// Compare the three-node rule against reference nodes and weights.
    #[test]
    fn golub_welsch_3() {
        let rule = GaussHermite::new(3.try_into().unwrap());
        let x_should = [-1.224_744_871_391_589, 0.0, 1.224_744_871_391_589];
        let w_should = [
            0.295_408_975_150_919_35,
            1.181_635_900_603_677_4,
            0.295_408_975_150_919_35,
        ];
        // `zip` stops at the shortest iterator, so without this check a rule with
        // too few nodes would pass the loop below vacuously.
        assert_eq!(rule.as_node_weight_pairs().len(), x_should.len());
        for ((correct_node, correct_weight), (computed_node, computed_weight)) in
            x_should.into_iter().zip(w_should).zip(rule)
        {
            assert_abs_diff_eq!(correct_node, computed_node, epsilon = 1e-15);
            assert_abs_diff_eq!(correct_weight, computed_weight, epsilon = 1e-15);
        }
    }

    #[test]
    fn check_derives() {
        let quad = GaussHermite::new(10.try_into().unwrap());
        let quad_clone = quad.clone();
        assert_eq!(quad, quad_clone);
        let other_quad = GaussHermite::new(3.try_into().unwrap());
        assert_ne!(quad, other_quad);
    }

    /// All three ways of iterating a rule must agree on ∫ x^2 e^(-x^2) dx = sqrt(π) / 2.
    #[test]
    fn check_iterators() {
        let rule = GaussHermite::new(3.try_into().unwrap());
        let ans = sqrt(PI) / 2.0;

        assert_abs_diff_eq!(
            ans,
            rule.iter().fold(0.0, |tot, (n, w)| tot + n * n * w),
            epsilon = 1e-14
        );

        assert_abs_diff_eq!(
            ans,
            rule.nodes()
                .zip(rule.weights())
                .fold(0.0, |tot, (n, w)| tot + n * n * w),
            epsilon = 1e-14
        );

        assert_abs_diff_eq!(
            ans,
            rule.into_iter().fold(0.0, |tot, (n, w)| tot + n * n * w),
            epsilon = 1e-14
        );
    }

    #[test]
    fn integrate_one() {
        let quad = GaussHermite::new(5.try_into().unwrap());
        let integral = quad.integrate(|_x| 1.0);
        assert_abs_diff_eq!(integral, sqrt(PI), epsilon = 1e-14);
    }

    /// A 3-node rule must be exact for polynomials up to degree 2 * 3 - 1 = 5.
    #[test]
    fn integrate_polynomial_exactly() {
        let quad = GaussHermite::new(3.try_into().unwrap());
        // ∫ x^4 e^(-x^2) dx over ℝ = 3 sqrt(π) / 4
        assert_abs_diff_eq!(
            quad.integrate(|x| x * x * x * x),
            3.0 * sqrt(PI) / 4.0,
            epsilon = 1e-13
        );
        // Odd powers vanish by symmetry.
        assert_abs_diff_eq!(quad.integrate(|x| x * x * x * x * x), 0.0, epsilon = 1e-13);
    }

    /// Parallel summation order differs from the serial one, so use the same
    /// tolerance as `integrate_one` rather than a tighter one.
    #[cfg(feature = "rayon")]
    #[test]
    fn par_integrate_one() {
        let quad = GaussHermite::new(5.try_into().unwrap());
        let integral = quad.par_integrate(|_x| 1.0);
        assert_abs_diff_eq!(integral, sqrt(PI), epsilon = 1e-14);
    }
}