use crate::error::QuadratureError;
#[cfg(not(feature = "std"))]
use alloc::{vec, vec::Vec};
#[cfg(not(feature = "std"))]
use num_traits::Float as _;
/// Computes quadrature nodes and weights from three-term recurrence data by
/// diagonalising the associated symmetric tridiagonal (Jacobi) matrix.
///
/// * `diag` – the `n` diagonal entries of the matrix.
/// * `off_diag_sq` – the `n - 1` *squared* off-diagonal entries; their square
///   roots are taken here, so a negative entry produces NaN.
/// * `mu0` – scale factor applied to every weight.
///
/// Returns `(nodes, weights)` with the nodes sorted in ascending order and
/// each weight equal to `mu0 * z_k^2`, where `z` is the vector that
/// `symmetric_tridiag_eig` leaves behind when started from `(1, 0, …, 0)`.
///
/// # Errors
///
/// Returns `QuadratureError::InvalidInput` when the eigen-solver reports that
/// it did not converge.
///
/// # Panics
///
/// Panics if `off_diag_sq.len() != diag.len().saturating_sub(1)`.
pub(crate) fn golub_welsch(
    diag: &[f64],
    off_diag_sq: &[f64],
    mu0: f64,
) -> Result<(Vec<f64>, Vec<f64>), QuadratureError> {
    let n = diag.len();
    assert_eq!(off_diag_sq.len(), n.saturating_sub(1));

    // Sizes 0 and 1 are answered directly without an eigen-solve.
    match diag {
        [] => return Ok((Vec::new(), Vec::new())),
        [only] => return Ok((vec![*only], vec![mu0])),
        _ => {}
    }

    let mut eigenvalues = diag.to_vec();
    let mut sub_diagonal: Vec<f64> = off_diag_sq
        .iter()
        .map(|&beta_sq| beta_sq.sqrt())
        .collect();

    // Start from the first unit vector; the solver rotates it in place.
    let mut first_components = vec![0.0; n];
    first_components[0] = 1.0;

    let converged =
        symmetric_tridiag_eig(&mut eigenvalues, &mut sub_diagonal, &mut first_components);
    if !converged {
        return Err(QuadratureError::InvalidInput(
            "QL eigenvalue algorithm did not converge",
        ));
    }

    let mut node_weight: Vec<(f64, f64)> = eigenvalues
        .into_iter()
        .zip(first_components.into_iter().map(|q| mu0 * q * q))
        .collect();

    // Stable sort by node; incomparable (NaN) keys are treated as equal.
    node_weight.sort_by(|lhs, rhs| {
        lhs.0
            .partial_cmp(&rhs.0)
            .unwrap_or(core::cmp::Ordering::Equal)
    });

    Ok(node_weight.into_iter().unzip())
}
/// Overwrites the last diagonal entry so that `x0` becomes a prescribed node
/// of the tridiagonal matrix described by `diag` / `off_diag_sq`.
///
/// * `diag` – diagonal entries; only the final element is modified.
/// * `off_diag_sq` – the `n - 1` squared off-diagonal entries.
/// * `x0` – the value to pin.
///
/// The new entry is obtained from the continued-fraction recurrence
/// `r_1 = x0 - diag[0]`, `r_k = x0 - diag[k-1] - off_diag_sq[k-2] / r_{k-1}`,
/// and finally `diag[n-1] = x0 - off_diag_sq[n-2] / r_{n-1}`.
///
/// Edge cases: an empty `diag` is left untouched, and a single-element `diag`
/// is simply set to `x0`. No guard is applied if some `r_k` is zero; the
/// division then yields a non-finite result.
///
/// # Panics
///
/// Panics if `off_diag_sq.len() != diag.len().saturating_sub(1)`.
pub(crate) fn radau_modify(diag: &mut [f64], off_diag_sq: &[f64], x0: f64) {
    let n = diag.len();
    assert_eq!(off_diag_sq.len(), n.saturating_sub(1));

    // Nothing to modify for an empty matrix (previously an out-of-bounds panic).
    if n == 0 {
        return;
    }
    if n == 1 {
        diag[0] = x0;
        return;
    }

    // Run the recurrence over the interior entries diag[1..n-1], paired with
    // off_diag_sq[0..n-2]; `zip` stops at the shorter (diag) side.
    let mut r = x0 - diag[0];
    for (&alpha, &beta_sq) in diag[1..n - 1].iter().zip(off_diag_sq) {
        r = x0 - alpha - beta_sq / r;
    }
    diag[n - 1] = x0 - off_diag_sq[n - 2] / r;
}
/// In-place implicit-shift QL iteration for a symmetric tridiagonal matrix.
///
/// * `d` – diagonal on entry; overwritten with the computed eigenvalues
///   (in no particular order).
/// * `e` – the `n - 1` off-diagonal entries. They are copied into a scratch
///   buffer and never written back.
/// * `z` – a vector of length at least `n`; every plane rotation applied to
///   the matrix is also applied to the matching pair of entries of `z`.
///
/// Returns `true` when every eigenvalue was isolated within 200 iterations.
/// If one of them exceeds that budget the function records the failure,
/// moves on to the remaining ones, and finally returns `false`.
///
/// # Panics
///
/// Panics if `e.len() != d.len() - 1` (for `d.len() >= 2`) or if `z` is too
/// short to be indexed alongside `d`.
// The single-letter names (d, e, z, l, m, i, f, b, g) follow the conventional
// textbook notation for this algorithm.
#[allow(clippy::many_single_char_names)]
fn symmetric_tridiag_eig(d: &mut [f64], e: &mut [f64], z: &mut [f64]) -> bool {
    const MAX_SWEEPS: u32 = 200;
    const UNDERFLOW_GUARD: f64 = 1e-300;

    let n = d.len();
    if n < 2 {
        return true;
    }

    // Scratch off-diagonal with one extra trailing slot (initially zero).
    let mut sub = vec![0.0; n];
    sub[..n - 1].copy_from_slice(e);

    let mut all_converged = true;
    for l in 0..n {
        let mut sweeps = 0u32;
        'sweep: loop {
            // First index at or after `l` whose off-diagonal entry is
            // negligible relative to its diagonal neighbours; n - 1 if none.
            let m = (l..n - 1)
                .find(|&j| sub[j].abs() <= f64::EPSILON * (d[j].abs() + d[j + 1].abs()))
                .unwrap_or(n - 1);
            if m == l {
                // d[l] has decoupled from the rest of the matrix.
                break;
            }

            sweeps += 1;
            if sweeps > MAX_SWEEPS {
                all_converged = false;
                break;
            }

            // Shift computed from the leading 2x2 block of the active window.
            let ratio = (d[l + 1] - d[l]) / (2.0 * sub[l]);
            let ratio_norm = ratio.hypot(1.0);
            let mut g = d[m] - d[l] + sub[l] / (ratio + ratio_norm.copysign(ratio));

            let mut rot_sin: f64 = 1.0;
            let mut rot_cos: f64 = 1.0;
            let mut carry: f64 = 0.0;

            // Chase the bulge from the bottom of the window up to row `l`.
            for i in (l..m).rev() {
                let f = rot_sin * sub[i];
                let b = rot_cos * sub[i];
                let radius = f.hypot(g);
                sub[i + 1] = radius;
                if radius.abs() < UNDERFLOW_GUARD {
                    // Rotation is degenerate: undo the pending shift on this
                    // row and restart the sweep without the final update.
                    d[i + 1] -= carry;
                    sub[m] = 0.0;
                    continue 'sweep;
                }
                rot_sin = f / radius;
                rot_cos = g / radius;

                let shifted = d[i + 1] - carry;
                let t = (d[i] - shifted) * rot_sin + 2.0 * rot_cos * b;
                carry = rot_sin * t;
                d[i + 1] = shifted + carry;
                g = rot_cos * t - b;

                // Apply the same rotation to the tracked vector.
                let upper = z[i + 1];
                let lower = z[i];
                z[i + 1] = rot_sin * lower + rot_cos * upper;
                z[i] = rot_cos * lower - rot_sin * upper;
            }

            d[l] -= carry;
            sub[l] = g;
            sub[m] = 0.0;
        }
    }
    all_converged
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds the Legendre recurrence coefficients (zero diagonal,
    /// `k^2 / (4k^2 - 1)` squared off-diagonals, `mu0 = 2`) through
    /// `golub_welsch` and checks the resulting 5-point rule: weights sum to 2,
    /// nodes lie strictly inside (-1, 1) in increasing order, and the rule
    /// integrates `x^4` to 2/5.
    #[test]
    fn legendre_recovery() {
        let n = 5;
        let diag = vec![0.0; n];
        let mut off_diag_sq: Vec<f64> = Vec::with_capacity(n - 1);
        for k in 1..n {
            let kf = k as f64;
            off_diag_sq.push(kf * kf / (4.0 * kf * kf - 1.0));
        }
        let mu0 = 2.0;

        let (nodes, weights) = golub_welsch(&diag, &off_diag_sq, mu0).unwrap();

        let sum: f64 = weights.iter().sum();
        assert!((sum - 2.0).abs() < 1e-14, "sum={sum}");

        assert!(nodes[0] > -1.0);
        assert!(*nodes.last().unwrap() < 1.0);
        assert!(nodes.windows(2).all(|pair| pair[0] < pair[1]));

        let r: f64 = nodes
            .iter()
            .zip(weights.iter())
            .map(|(&x, &w)| x.powi(4) * w)
            .sum();
        assert!((r - 2.0 / 5.0).abs() < 1e-14, "r={r}");
    }

    /// Two-point rule with the left endpoint pinned via `radau_modify`:
    /// expects nodes {-1, 1/3} and weights {1/2, 3/2}.
    #[test]
    fn radau_left_n2() {
        let mut diag = [0.0_f64; 2];
        let off_diag_sq = [1.0_f64 / 3.0];
        let mu0 = 2.0;

        radau_modify(&mut diag, &off_diag_sq, -1.0);
        let (nodes, weights) = golub_welsch(&diag, &off_diag_sq, mu0).unwrap();

        let close = |actual: f64, expected: f64| (actual - expected).abs() < 1e-14;
        assert!(close(nodes[0], -1.0));
        assert!(close(nodes[1], 1.0 / 3.0));
        assert!(close(weights[0], 0.5), "w0={}", weights[0]);
        assert!(close(weights[1], 1.5), "w1={}", weights[1]);
    }
}