use crate::error::QuadratureError;
use crate::gauss_legendre::legendre_eval;
use crate::rule::QuadratureRule;
#[cfg(not(feature = "std"))]
use alloc::{vec, vec::Vec};
#[cfg(not(feature = "std"))]
use num_traits::Float as _;
/// An `n`-point Gauss–Lobatto quadrature rule on \[-1, 1\].
///
/// Both endpoints, -1 and 1, are always nodes of the rule. The remaining
/// `n - 2` nodes lie strictly inside the interval. Construct one with
/// [`GaussLobatto::new`].
#[derive(Clone, Debug)]
pub struct GaussLobatto {
    /// Nodes (in ascending order) paired with their quadrature weights.
    rule: QuadratureRule<f64>,
}
impl GaussLobatto {
    /// Builds the `n`-point Gauss–Lobatto rule on \[-1, 1\].
    ///
    /// # Errors
    ///
    /// - [`QuadratureError::ZeroOrder`] if `n == 0`.
    /// - [`QuadratureError::InvalidInput`] if `n == 1`. A Lobatto rule always
    ///   contains both endpoints, so it needs at least two points.
    pub fn new(n: usize) -> Result<Self, QuadratureError> {
        match n {
            0 => Err(QuadratureError::ZeroOrder),
            1 => Err(QuadratureError::InvalidInput(
                "Gauss-Lobatto requires at least 2 points",
            )),
            _ => {
                let (nodes, weights) = compute_lobatto(n);
                let rule = QuadratureRule { nodes, weights };
                Ok(Self { rule })
            }
        }
    }
}
// Crate-level macro that adds the rule accessor methods to `GaussLobatto`.
// The tests in this file call `nodes()`, `weights()`, `order()` and `rule()`,
// which are presumably generated here; confirm against the
// `impl_rule_accessors!` definition. `nodes_doc` is presumably used as the
// doc text of the nodes accessor.
impl_rule_accessors!(GaussLobatto, nodes_doc: "Returns the nodes on \\[-1, 1\\].");
/// Computes the `n` Gauss–Lobatto nodes and weights on \[-1, 1\], nodes ascending.
///
/// The endpoints ±1 are always nodes. The `n - 2` interior nodes are the
/// roots of `P'_{n-1}`, the derivative of the Legendre polynomial of degree
/// `n - 1`. Every weight equals `2 / (n (n - 1) P_{n-1}(x)^2)`. At the
/// endpoints, `P_{n-1}(±1)^2 = 1`, so the end weights reduce to
/// `2 / (n (n - 1))`.
///
/// The rule is symmetric about 0, so only the negative interior roots are
/// solved numerically. Their mirror images fill the positive half, and for
/// odd `n` the middle node is set to exactly 0. This makes the returned rule
/// exactly symmetric and halves the number of Newton solves.
///
/// Callers must pass `n >= 2`; [`GaussLobatto::new`] enforces this.
#[allow(clippy::cast_precision_loss)] // n is a point count far below 2^53, so `as f64` is exact.
fn compute_lobatto(n: usize) -> (Vec<f64>, Vec<f64>) {
    debug_assert!(n >= 2, "Gauss-Lobatto requires at least 2 points");
    let nm1 = n - 1;
    let n_f = n as f64;
    // Factor n(n-1), shared by every weight formula.
    let scale = n_f * (n_f - 1.0);
    let weight_at = |x: f64| {
        let (p_nm1, _) = legendre_eval(nm1, x);
        2.0 / (scale * p_nm1 * p_nm1)
    };

    let mut nodes = vec![0.0_f64; n];
    let mut weights = vec![0.0_f64; n];

    let w_end = 2.0 / scale;
    nodes[0] = -1.0;
    weights[0] = w_end;
    nodes[nm1] = 1.0;
    weights[nm1] = w_end;

    // Interior node i (0-based) is stored at index i + 1. Its mirror image
    // lives at index n - 2 - i.
    let m = n - 2;
    let half = m / 2;
    for k in 0..half {
        // Initial guess: the Chebyshev–Gauss–Lobatto point -cos(theta).
        // Here theta < pi/2, so the guess (and the root) is negative.
        let theta = core::f64::consts::PI * (k as f64 + 1.0) / (m as f64 + 1.0);
        let x = newton_lobatto_root(nm1, -(theta.cos()));
        let w = weight_at(x);

        let lo = k + 1;
        let hi = n - 2 - k;
        nodes[lo] = x;
        weights[lo] = w;
        nodes[hi] = -x;
        weights[hi] = w;
    }

    // For odd n, P_{n-1} has even degree, so P'_{n-1} is odd and has a root
    // exactly at 0.
    if m % 2 == 1 {
        let mid = n / 2;
        nodes[mid] = 0.0;
        weights[mid] = weight_at(0.0);
    }

    debug_assert!(
        nodes.windows(2).all(|pair| pair[0] < pair[1]),
        "Gauss-Lobatto nodes must be strictly increasing"
    );
    (nodes, weights)
}

/// Refines `x` by Newton's method towards a root of `P'_{nm1}` inside (-1, 1).
///
/// `P''` comes from the Legendre differential equation
/// `(1 - x^2) P'' = 2x P' - k(k + 1) P` with `k = nm1`, so only `P` and `P'`
/// need to be evaluated. Iterates are clamped strictly inside (-1, 1) to keep
/// `1 - x^2` non-zero. Refinement stops after 100 steps, when the step is
/// below ~1e-15 relative, or when `P''` is too small to divide by.
#[allow(clippy::cast_precision_loss)] // nm1 is a point count; `as f64` is exact.
fn newton_lobatto_root(nm1: usize, mut x: f64) -> f64 {
    let k = nm1 as f64;
    let k_k1 = k * (k + 1.0);
    for _ in 0..100 {
        let (p, dp) = legendre_eval(nm1, x);
        let d2p = (2.0 * x * dp - k_k1 * p) / (1.0 - x * x);
        if d2p.abs() < 1e-300 {
            break;
        }
        let dx = dp / d2p;
        x = (x - dx).clamp(-1.0 + f64::EPSILON, 1.0 - f64::EPSILON);
        if dx.abs() < 1e-15 * (1.0 + x.abs()) {
            break;
        }
    }
    x
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Exact value of the integral of `x^deg` over \[-1, 1\], for `deg >= 0`.
    fn monomial_integral(deg: i32) -> f64 {
        if deg % 2 == 1 {
            0.0
        } else {
            2.0 / (f64::from(deg) + 1.0)
        }
    }

    #[test]
    fn too_few_points() {
        assert!(GaussLobatto::new(0).is_err());
        assert!(GaussLobatto::new(1).is_err());
        assert!(GaussLobatto::new(2).is_ok());
    }

    #[test]
    fn zero_order_returns_zero_order_error() {
        assert_eq!(
            GaussLobatto::new(0).unwrap_err(),
            QuadratureError::ZeroOrder
        );
    }

    #[test]
    fn one_point_returns_invalid_input() {
        assert!(matches!(
            GaussLobatto::new(1).unwrap_err(),
            QuadratureError::InvalidInput(_)
        ));
    }

    #[test]
    fn two_point_trapezoid() {
        let gl = GaussLobatto::new(2).unwrap();
        assert_eq!(gl.nodes(), &[-1.0, 1.0]);
        assert!((gl.weights()[0] - 1.0).abs() < 1e-14);
        assert!((gl.weights()[1] - 1.0).abs() < 1e-14);
    }

    #[test]
    fn known_small_rules() {
        // n = 4: interior nodes ±1/√5, weights 1/6 and 5/6.
        let r = 1.0 / 5.0_f64.sqrt();
        let gl = GaussLobatto::new(4).unwrap();
        let xs = [-1.0, -r, r, 1.0];
        let ws = [1.0 / 6.0, 5.0 / 6.0, 5.0 / 6.0, 1.0 / 6.0];
        for i in 0..4 {
            assert!((gl.nodes()[i] - xs[i]).abs() < 1e-14, "n=4 node {i}");
            assert!((gl.weights()[i] - ws[i]).abs() < 1e-14, "n=4 weight {i}");
        }

        // n = 5: interior nodes 0 and ±√(3/7), weights 1/10, 49/90, 32/45.
        let s = (3.0_f64 / 7.0).sqrt();
        let gl = GaussLobatto::new(5).unwrap();
        let xs = [-1.0, -s, 0.0, s, 1.0];
        let ws = [0.1, 49.0 / 90.0, 32.0 / 45.0, 49.0 / 90.0, 0.1];
        for i in 0..5 {
            assert!((gl.nodes()[i] - xs[i]).abs() < 1e-14, "n=5 node {i}");
            assert!((gl.weights()[i] - ws[i]).abs() < 1e-14, "n=5 weight {i}");
        }
    }

    #[test]
    fn weight_sum() {
        for n in [3, 5, 10, 20, 50] {
            let gl = GaussLobatto::new(n).unwrap();
            let sum: f64 = gl.weights().iter().sum();
            assert!((sum - 2.0).abs() < 1e-12, "n={n}: sum={sum}");
        }
    }

    #[test]
    fn endpoints() {
        let gl = GaussLobatto::new(10).unwrap();
        assert_eq!(gl.nodes()[0], -1.0);
        assert_eq!(*gl.nodes().last().unwrap(), 1.0);
    }

    #[test]
    fn nodes_sorted() {
        let gl = GaussLobatto::new(20).unwrap();
        for i in 0..gl.order() - 1 {
            assert!(gl.nodes()[i] < gl.nodes()[i + 1]);
        }
    }

    #[test]
    fn symmetry() {
        let gl = GaussLobatto::new(15).unwrap();
        let n = gl.order();
        for i in 0..n / 2 {
            assert!(
                (gl.nodes()[i] + gl.nodes()[n - 1 - i]).abs() < 1e-13,
                "i={i}: {} vs {}",
                gl.nodes()[i],
                gl.nodes()[n - 1 - i]
            );
            assert!(
                (gl.weights()[i] - gl.weights()[n - 1 - i]).abs() < 1e-13,
                "i={i}: {} vs {}",
                gl.weights()[i],
                gl.weights()[n - 1 - i]
            );
        }
    }

    #[test]
    fn polynomial_exactness() {
        // Any symmetric rule integrates odd monomials to 0 for free. To test
        // real exactness, check every degree up to the Lobatto limit 2n - 3,
        // for several n.
        for n in [3, 4, 5, 10, 20] {
            let gl = GaussLobatto::new(n).unwrap();
            let max_deg = i32::try_from(2 * n - 3).unwrap();
            for deg in 0..=max_deg {
                let expected = monomial_integral(deg);
                let result = gl.rule().integrate(-1.0, 1.0, |x: f64| x.powi(deg));
                assert!(
                    (result - expected).abs() < 1e-11,
                    "n={n}, deg={deg}: result={result}, expected={expected}"
                );
            }
        }
    }

    #[test]
    fn exactness_degree_is_sharp() {
        // An n-point Lobatto rule must NOT integrate x^(2n-2) exactly. The
        // expected error is about 0.27, 1.5e-2 and 1.3e-5 for n = 3, 5, 10.
        for n in [3, 5, 10] {
            let gl = GaussLobatto::new(n).unwrap();
            let deg = i32::try_from(2 * n - 2).unwrap();
            let result = gl.rule().integrate(-1.0, 1.0, |x: f64| x.powi(deg));
            let err = (result - monomial_integral(deg)).abs();
            assert!(err > 1e-8, "n={n}, deg={deg}: unexpectedly exact (err={err})");
        }
    }
}