use crate::error::QuadratureError;
use crate::golub_welsch::golub_welsch;
use crate::rule::QuadratureRule;
#[cfg(not(feature = "std"))]
use alloc::{vec, vec::Vec};
#[cfg(not(feature = "std"))]
use num_traits::Float as _;
/// An `n`-point Gauss–Hermite quadrature rule.
///
/// Approximates integrals of the form `∫_{-∞}^{∞} f(x) e^{-x²} dx` by the
/// weighted sum `Σ wᵢ f(xᵢ)` over the stored nodes and weights. Build one with
/// [`GaussHermite::new`].
#[derive(Clone, Debug)]
pub struct GaussHermite {
    /// Nodes and weights, computed once at construction time.
    rule: QuadratureRule<f64>,
}
impl GaussHermite {
    /// Creates an `n`-point Gauss–Hermite rule for the weight `e^{-x²}` on `(-∞, ∞)`.
    ///
    /// # Errors
    ///
    /// Returns [`QuadratureError::ZeroOrder`] when `n == 0`. Any error reported
    /// while computing the nodes and weights is returned unchanged.
    pub fn new(n: usize) -> Result<Self, QuadratureError> {
        match n {
            0 => Err(QuadratureError::ZeroOrder),
            _ => compute_hermite(n).map(|(nodes, weights)| Self {
                rule: QuadratureRule { nodes, weights },
            }),
        }
    }
}
// Generates the public accessors over the wrapped `QuadratureRule`; the tests in
// this file call `nodes()`, `weights()` and `order()` on `GaussHermite`.
// `nodes_doc` presumably supplies the rustdoc text for the `nodes()` accessor.
// NOTE(review): the macro is defined elsewhere in the crate — confirm the exact
// set of generated methods there.
impl_rule_accessors!(GaussHermite, nodes_doc: "Returns the nodes on (-∞, ∞).");
/// Computes the nodes and weights of the `n`-point Gauss–Hermite rule.
///
/// The monic orthogonal polynomials for the weight `e^{-x²}` obey
/// `p_{k+1}(x) = x·p_k(x) − (k/2)·p_{k−1}(x)`. Their symmetric Jacobi matrix
/// therefore has an all-zero diagonal and squared off-diagonal entries `k/2`
/// for `k = 1, …, n−1`. The zeroth moment of the weight is `∫ e^{-x²} dx = √π`.
/// These three inputs are handed to `golub_welsch`, and its result (or error)
/// is returned as-is.
// `k < n`, so the usize → f64 cast is exact for any realistic order (n ≤ 2^53).
#[allow(clippy::cast_precision_loss)]
fn compute_hermite(n: usize) -> Result<(Vec<f64>, Vec<f64>), QuadratureError> {
    // The weight is even, so every diagonal recurrence coefficient vanishes.
    let alpha = vec![0.0; n];
    // Halving an integer-valued f64 is exact, matching `k as f64 / 2.0`.
    let beta_sq: Vec<f64> = (1..n).map(|k| 0.5 * k as f64).collect();
    let total_mass = core::f64::consts::PI.sqrt();
    golub_welsch(&alpha, &beta_sq, total_mass)
}
#[cfg(test)]
mod tests {
    use super::*;
    use core::f64::consts::PI;

    /// Quadrature estimate of `∫ f(x) e^{-x²} dx`: the sum of `f(xᵢ)·wᵢ` over
    /// the rule's nodes and weights, accumulated in node order.
    fn moment(gh: &GaussHermite, f: impl Fn(f64) -> f64) -> f64 {
        gh.nodes()
            .iter()
            .zip(gh.weights())
            .map(|(&x, &w)| f(x) * w)
            .sum()
    }

    #[test]
    fn zero_order() {
        let result = GaussHermite::new(0);
        assert!(result.is_err());
    }

    #[test]
    fn weight_sum() {
        // The weights must integrate the constant 1: ∫ e^{-x²} dx = √π.
        let rule = GaussHermite::new(20).unwrap();
        let sum: f64 = rule.weights().iter().sum();
        assert!((sum - PI.sqrt()).abs() < 1e-12, "sum={sum}");
    }

    #[test]
    fn node_symmetry() {
        // The weight is even, so nodes come in ± pairs; an odd-order rule also
        // has a node at the origin.
        let rule = GaussHermite::new(21).unwrap();
        let n = rule.order();
        let nodes = rule.nodes();
        for (i, &left) in nodes.iter().enumerate().take(n / 2) {
            let right = nodes[n - 1 - i];
            assert!((left + right).abs() < 1e-12, "i={i}: {} vs {}", left, right);
        }
        if n % 2 == 1 {
            assert!(nodes[n / 2].abs() < 1e-14);
        }
    }

    #[test]
    fn weight_symmetry() {
        // Mirrored nodes must carry equal weights.
        let rule = GaussHermite::new(20).unwrap();
        let n = rule.order();
        let weights = rule.weights();
        for (i, &left) in weights.iter().enumerate().take(n / 2) {
            let right = weights[n - 1 - i];
            assert!((left - right).abs() < 1e-12, "i={i}: {} vs {}", left, right);
        }
    }

    #[test]
    fn nodes_sorted() {
        let rule = GaussHermite::new(20).unwrap();
        for (i, pair) in rule.nodes().windows(2).enumerate() {
            let (a, b) = (pair[0], pair[1]);
            assert!(a < b, "i={i}: {} >= {}", a, b);
        }
    }

    #[test]
    fn polynomial_exactness() {
        // Even moments: ∫x^0 e^{-x²} = √π, ∫x² e^{-x²} = √π/2, ∫x⁴ e^{-x²} = 3√π/4.
        // Odd moments vanish by symmetry.
        let gh = GaussHermite::new(10).unwrap();
        let sqrt_pi = PI.sqrt();

        let r0: f64 = gh.weights().iter().sum();
        assert!((r0 - sqrt_pi).abs() < 1e-12);

        let r1 = moment(&gh, |x| x * x);
        assert!((r1 - sqrt_pi / 2.0).abs() < 1e-12, "r1={r1}");

        let r2 = moment(&gh, |x| x.powi(4));
        assert!((r2 - 3.0 * sqrt_pi / 4.0).abs() < 1e-11, "r2={r2}");

        let odd = moment(&gh, |x| x.powi(3));
        assert!(odd.abs() < 1e-12, "odd={odd}");
    }
}