use crate::error::QuadratureError;
use crate::gauss_jacobi::ln_gamma;
use crate::golub_welsch::golub_welsch;
use crate::rule::QuadratureRule;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
#[cfg(not(feature = "std"))]
use num_traits::Float as _;
/// Generalised Gauss–Laguerre quadrature rule for integrals of the form
/// `∫₀^∞ x^α e^{-x} f(x) dx` (`α = 0` gives the classical Laguerre rule).
#[derive(Clone, Debug)]
pub struct GaussLaguerre {
    /// Nodes and weights, computed once at construction time.
    rule: QuadratureRule<f64>,
    /// Exponent `α` of the weight `x^α e^{-x}`; `new` only accepts finite
    /// values strictly greater than `-1`.
    alpha: f64,
}
impl GaussLaguerre {
    /// Builds an `n`-point generalised Gauss–Laguerre rule for the weight
    /// `x^alpha e^{-x}` on `[0, ∞)`.
    ///
    /// # Errors
    ///
    /// * [`QuadratureError::ZeroOrder`] if `n == 0`. This is checked before `alpha`.
    /// * [`QuadratureError::InvalidInput`] if `alpha` is NaN, infinite, or `<= -1`.
    /// * Any error raised while computing the nodes and weights.
    pub fn new(n: usize, alpha: f64) -> Result<Self, QuadratureError> {
        // NaN fails `is_finite`, so it lands in the "invalid alpha" arm too.
        let alpha_ok = alpha.is_finite() && alpha > -1.0;
        match (n, alpha_ok) {
            (0, _) => Err(QuadratureError::ZeroOrder),
            (_, false) => Err(QuadratureError::InvalidInput(
                "Laguerre parameter alpha must be finite and satisfy alpha > -1",
            )),
            (_, true) => {
                let (nodes, weights) = compute_laguerre(n, alpha)?;
                Ok(Self {
                    rule: QuadratureRule { nodes, weights },
                    alpha,
                })
            }
        }
    }

    /// Returns the weight exponent `alpha` this rule was constructed with.
    #[must_use]
    pub fn alpha(&self) -> f64 {
        self.alpha
    }
}
// Generates the shared read-only accessors over the wrapped `rule`.
// NOTE(review): presumably this provides `nodes()`, `weights()` and `order()`,
// which the unit tests in this file call; confirm against the macro definition.
impl_rule_accessors!(GaussLaguerre, nodes_doc: "Returns the nodes on \\[0, ∞).");
#[allow(clippy::cast_precision_loss)] fn compute_laguerre(n: usize, alpha: f64) -> Result<(Vec<f64>, Vec<f64>), QuadratureError> {
let diag: Vec<f64> = (0..n).map(|k| 2.0 * k as f64 + 1.0 + alpha).collect();
let off_diag_sq: Vec<f64> = (1..n)
.map(|k| {
let k = k as f64;
k * (k + alpha)
})
.collect();
let mu0 = ln_gamma(alpha + 1.0).exp();
golub_welsch(&diag, &off_diag_sq, mu0)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `k!` as an exact product, used as an oracle independent of `ln_gamma`.
    fn factorial(k: i32) -> f64 {
        (1..=k).map(f64::from).product()
    }

    #[test]
    fn zero_order() {
        assert!(GaussLaguerre::new(0, 0.0).is_err());
    }

    #[test]
    fn invalid_alpha() {
        assert!(GaussLaguerre::new(5, -1.0).is_err());
        assert!(GaussLaguerre::new(5, -2.0).is_err());
        assert!(GaussLaguerre::new(5, f64::NAN).is_err());
    }

    #[test]
    fn standard_weight_sum() {
        let gl = GaussLaguerre::new(20, 0.0).unwrap();
        let sum: f64 = gl.weights().iter().sum();
        assert!((sum - 1.0).abs() < 1e-12, "sum={sum}");
    }

    #[test]
    fn generalised_weight_sum() {
        // Γ(α + 1) in closed form, so the check does not reuse `ln_gamma`,
        // which is also what produces the scaling factor under test.
        let sqrt_pi = core::f64::consts::PI.sqrt();
        let cases = [
            (0.5, 0.5 * sqrt_pi),    // Γ(3/2) = √π / 2
            (1.0, 1.0),              // Γ(2) = 1
            (2.0, 2.0),              // Γ(3) = 2
            (3.5, 6.5625 * sqrt_pi), // Γ(9/2) = (105 / 16) √π
        ];
        for (alpha, expected) in cases {
            let gl = GaussLaguerre::new(20, alpha).unwrap();
            let sum: f64 = gl.weights().iter().sum();
            assert!(
                (sum - expected).abs() <= 1e-10 * expected,
                "alpha={alpha}: sum={sum}, expected={expected}"
            );
        }
    }

    #[test]
    fn nodes_positive_and_sorted() {
        let gl = GaussLaguerre::new(20, 0.0).unwrap();
        let nodes = gl.nodes();
        for &x in nodes {
            assert!(x > 0.0, "node={x} is not positive");
        }
        for pair in nodes.windows(2) {
            assert!(
                pair[0] < pair[1],
                "nodes not strictly increasing: {} >= {}",
                pair[0],
                pair[1]
            );
        }
    }

    #[test]
    fn infinite_alpha_rejected() {
        assert!(GaussLaguerre::new(5, f64::INFINITY).is_err());
        assert!(GaussLaguerre::new(5, f64::NEG_INFINITY).is_err());
    }

    #[test]
    fn polynomial_exactness() {
        // ∫₀^∞ x^k e^{-x} dx = k!. A 10-point rule is exact up to degree 19.
        // Only low degrees are checked so the absolute tolerance stays meaningful.
        let n = 10;
        let gl = GaussLaguerre::new(n, 0.0).unwrap();
        for k in 0..5 {
            let numerical: f64 = gl
                .nodes()
                .iter()
                .zip(gl.weights())
                .map(|(&x, &w)| w * x.powi(k))
                .sum();
            let expected = factorial(k);
            assert!(
                (numerical - expected).abs() < 1e-10,
                "k={k}: got={numerical}, expected={expected}"
            );
        }
    }
}