use crate::error::QuadratureError;
use crate::golub_welsch::{golub_welsch, radau_modify};
use crate::rule::QuadratureRule;
#[cfg(not(feature = "std"))]
use alloc::{vec, vec::Vec};
/// A Gauss–Radau quadrature rule with nodes on `[-1, 1]`.
///
/// Build one with [`GaussRadau::left`] (a node pinned at `-1`) or
/// [`GaussRadau::right`] (a node pinned at `+1`).
#[derive(Debug, Clone)]
pub struct GaussRadau {
/// Node/weight pairs that make up this rule.
rule: QuadratureRule<f64>,
}
impl GaussRadau {
    /// Creates the `n`-point Gauss–Radau rule whose prescribed node is the left
    /// endpoint `-1`.
    ///
    /// # Errors
    ///
    /// Returns [`QuadratureError::ZeroOrder`] when `n == 0`. Any error produced
    /// while computing the nodes and weights is passed through.
    pub fn left(n: usize) -> Result<Self, QuadratureError> {
        match n {
            0 => Err(QuadratureError::ZeroOrder),
            _ => compute_radau_left(n).map(|(nodes, weights)| Self {
                rule: QuadratureRule { nodes, weights },
            }),
        }
    }

    /// Creates the `n`-point Gauss–Radau rule whose prescribed node is the right
    /// endpoint `+1`.
    ///
    /// The rule is obtained by mirroring the left rule: every node is negated and
    /// both sequences are reversed so the nodes stay in ascending order.
    ///
    /// # Errors
    ///
    /// Returns [`QuadratureError::ZeroOrder`] when `n == 0`. Any error produced
    /// while computing the nodes and weights is passed through.
    pub fn right(n: usize) -> Result<Self, QuadratureError> {
        if n == 0 {
            return Err(QuadratureError::ZeroOrder);
        }

        let (left_nodes, left_weights) = compute_radau_left(n)?;
        // Reflection x -> -x flips the ordering, hence the `rev()` on both vectors.
        let nodes: Vec<f64> = left_nodes.into_iter().rev().map(|x| -x).collect();
        let weights: Vec<f64> = left_weights.into_iter().rev().collect();

        Ok(Self {
            rule: QuadratureRule { nodes, weights },
        })
    }
}
// NOTE(review): the macro definition is not visible in this file. It presumably
// generates the accessor methods for `GaussRadau` — the tests in this file call
// `rule()`, `order()`, `nodes()` and `weights()`, none of which are written out
// here — with `nodes_doc` apparently supplying the doc text for `nodes()`.
// Confirm against the macro definition.
impl_rule_accessors!(GaussRadau, nodes_doc: "Returns the nodes on \\[-1, 1\\].");
/// Computes the nodes and weights of the `n`-point Gauss–Radau rule with a
/// prescribed node at `-1`.
///
/// The one-point rule is returned directly (node `-1`, weight `2`). Otherwise the
/// function prepares Jacobi-matrix data — an all-zero diagonal of length `n` and
/// the squared off-diagonal entries `k² / (4k² − 1)` for `k = 1..n`, i.e. the
/// Legendre recurrence coefficients — passes it through [`radau_modify`] with the
/// endpoint `-1.0`, and hands the result to [`golub_welsch`] together with the
/// zeroth moment `2.0`.
///
/// `n == 0` is not rejected here; the public constructors filter it out before
/// calling this function.
///
/// # Errors
///
/// Forwards whatever error [`golub_welsch`] returns.
// The only cast is the recurrence index `k: usize -> f64` below.
#[allow(clippy::cast_precision_loss)]
fn compute_radau_left(n: usize) -> Result<(Vec<f64>, Vec<f64>), QuadratureError> {
    if n == 1 {
        return Ok((vec![-1.0], vec![2.0]));
    }

    // Squared off-diagonal Jacobi entry for index `k`.
    let beta = |k: usize| {
        let kf = k as f64;
        kf * kf / (4.0 * kf * kf - 1.0)
    };
    let offdiag_squared: Vec<f64> = (1..n).map(beta).collect();

    // The diagonal starts out as all zeros; `radau_modify` receives it mutably
    // together with the prescribed endpoint.
    let mut diagonal = vec![0.0; n];
    radau_modify(&mut diagonal, &offdiag_squared, -1.0);

    // Zeroth moment: the integral of the constant 1 over [-1, 1].
    let mu0 = 2.0;
    golub_welsch(&diagonal, &offdiag_squared, mu0)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Both constructors refuse to build a rule with zero points.
    #[test]
    fn zero_order() {
        let left = GaussRadau::left(0);
        assert!(left.is_err());
        let right = GaussRadau::right(0);
        assert!(right.is_err());
    }

    /// The one-point left rule is the single node `-1` carrying weight `2`.
    #[test]
    fn single_point() {
        let rule = GaussRadau::left(1).unwrap();
        assert_eq!(rule.nodes(), &[-1.0]);
        let weight = rule.weights()[0];
        assert!((weight - 2.0).abs() < 1e-14);
    }

    /// A left rule starts at `-1` and stays strictly below `+1`.
    #[test]
    fn left_endpoint() {
        let rule = GaussRadau::left(10).unwrap();
        let nodes = rule.nodes();
        assert!((nodes[0] + 1.0).abs() < 1e-14);
        assert!(*nodes.last().unwrap() < 1.0);
    }

    /// A right rule starts strictly above `-1` and ends at `+1`.
    #[test]
    fn right_endpoint() {
        let rule = GaussRadau::right(10).unwrap();
        let nodes = rule.nodes();
        assert!(nodes[0] > -1.0);
        assert!((*nodes.last().unwrap() - 1.0).abs() < 1e-14);
    }

    /// The weights of either variant add up to the interval length, `2`.
    #[test]
    fn weight_sum() {
        for n in [2, 5, 10, 20] {
            let left = GaussRadau::left(n).unwrap();
            let sum = left.weights().iter().sum::<f64>();
            assert!((sum - 2.0).abs() < 1e-12, "n={n}: sum={sum}");

            let right = GaussRadau::right(n).unwrap();
            let sum = right.weights().iter().sum::<f64>();
            assert!((sum - 2.0).abs() < 1e-12, "right n={n}: sum={sum}");
        }
    }

    /// Nodes of a left rule are strictly increasing.
    #[test]
    fn nodes_sorted() {
        let rule = GaussRadau::left(20).unwrap();
        let nodes = rule.nodes();
        for i in 0..rule.order() - 1 {
            let (lower, upper) = (nodes[i], nodes[i + 1]);
            assert!(lower < upper, "i={i}: {} >= {}", lower, upper);
        }
    }

    /// The 10-point left rule integrates the monomial of degree `2n - 2` over
    /// `[-1, 1]` to within `1e-10` of the exact value `2 / (deg + 1)`.
    #[test]
    fn polynomial_exactness() {
        let n = 10;
        let rule = GaussRadau::left(n).unwrap();
        let max_deg = 2 * n - 2;

        // Exact integral of x^max_deg over [-1, 1] for an even exponent.
        let expected = 2.0 / (max_deg as f64 + 1.0);
        let monomial = |x: f64| x.powi(max_deg as i32);
        let result = rule.rule().integrate(-1.0, 1.0, monomial);

        let error = (result - expected).abs();
        assert!(
            error < 1e-10,
            "deg={max_deg}: result={result}, expected={expected}"
        );
    }
}