use crate::error::{SpecialError, SpecialResult};
/// Gauss–Legendre quadrature: nodes and weights for
/// `∫_{-1}^{1} f(x) dx ≈ Σ w_i f(x_i)`.
///
/// Uses the Golub–Welsch construction: the Jacobi matrix has a zero diagonal
/// and off-diagonal entries `k / sqrt(4k² - 1)` for `k = 1..n-1`. Its
/// eigenvalues are the nodes. Each weight is `2` (the total mass of the
/// weight function) times the squared first component of the matching
/// eigenvector. Nodes are returned in ascending order.
///
/// # Errors
/// * `ValueError` if `n == 0`.
/// * Any error produced by the tridiagonal eigensolver is propagated.
///
/// # Panics
/// Panics with "NaN in nodes" if the eigensolver returns a NaN node.
pub fn leggauss(n: usize) -> SpecialResult<(Vec<f64>, Vec<f64>)> {
    if n < 2 {
        return if n == 0 {
            Err(SpecialError::ValueError(
                "Number of quadrature points must be >= 1".to_string(),
            ))
        } else {
            Ok((vec![0.0], vec![2.0]))
        };
    }
    let diag = vec![0.0f64; n];
    let off_diag: Vec<f64> = (1..n)
        .map(|k| {
            let k = k as f64;
            k / (4.0 * k * k - 1.0).sqrt()
        })
        .collect();
    let (eigenvalues, eigenvectors) = symmetric_tridiag_eigensystem(&diag, &off_diag)?;
    // Pair each node with its weight so a single (stable) sort keeps them aligned.
    let mut rule: Vec<(f64, f64)> = eigenvalues
        .into_iter()
        .zip(eigenvectors.iter().map(|v| 2.0 * v[0] * v[0]))
        .collect();
    rule.sort_by(|lhs, rhs| lhs.0.partial_cmp(&rhs.0).expect("NaN in nodes"));
    Ok(rule.into_iter().unzip())
}
/// Gauss–Hermite quadrature (physicists' weight): nodes and weights for
/// `∫_{-∞}^{∞} e^{-x²} f(x) dx ≈ Σ w_i f(x_i)`.
///
/// Golub–Welsch construction: the Jacobi matrix has a zero diagonal and
/// off-diagonal entries `sqrt(k / 2)` for `k = 1..n-1`. Weights are `√π`
/// (the total mass of `e^{-x²}`) times the squared first eigenvector
/// component. Nodes are returned in ascending order.
///
/// # Errors
/// * `ValueError` if `n == 0`.
/// * Any error produced by the tridiagonal eigensolver is propagated.
///
/// # Panics
/// Panics with "NaN in nodes" if the eigensolver returns a NaN node.
pub fn hermgauss(n: usize) -> SpecialResult<(Vec<f64>, Vec<f64>)> {
    let sqrt_pi = std::f64::consts::PI.sqrt();
    if n < 2 {
        return if n == 0 {
            Err(SpecialError::ValueError(
                "Number of quadrature points must be >= 1".to_string(),
            ))
        } else {
            Ok((vec![0.0], vec![sqrt_pi]))
        };
    }
    let diag = vec![0.0f64; n];
    let off_diag: Vec<f64> = (1..n).map(|k| (k as f64 / 2.0).sqrt()).collect();
    let (eigenvalues, eigenvectors) = symmetric_tridiag_eigensystem(&diag, &off_diag)?;
    // Keep node/weight pairs together through the sort.
    let mut rule: Vec<(f64, f64)> = eigenvalues
        .into_iter()
        .zip(eigenvectors.iter().map(|v| sqrt_pi * v[0] * v[0]))
        .collect();
    rule.sort_by(|lhs, rhs| lhs.0.partial_cmp(&rhs.0).expect("NaN in nodes"));
    Ok(rule.into_iter().unzip())
}
/// Generalised Gauss–Laguerre quadrature: nodes and weights for
/// `∫_0^∞ x^α e^{-x} f(x) dx ≈ Σ w_i f(x_i)`.
///
/// Golub–Welsch construction:
/// * The Jacobi matrix has diagonal `2k + α + 1` for `k = 0..n-1`.
/// * Its off-diagonal entries are `sqrt(k (k + α))` for `k = 1..n-1`.
/// * Weights are `Γ(α + 1)` times the squared first eigenvector component.
///
/// Nodes are returned in ascending order.
///
/// # Errors
/// * `ValueError` if `n == 0`.
/// * `DomainError` if `alpha` is not finite, if `alpha <= -1`, or if
///   `Γ(alpha + 1)` is not representable as a finite `f64`.
/// * `ConvergenceError` if the eigensolver fails or produces NaN nodes.
pub fn laggauss(n: usize, alpha: f64) -> SpecialResult<(Vec<f64>, Vec<f64>)> {
    if n == 0 {
        return Err(SpecialError::ValueError(
            "Number of quadrature points must be >= 1".to_string(),
        ));
    }
    // `alpha <= -1.0` is false for NaN, so non-finite values need their own guard.
    if !alpha.is_finite() {
        return Err(SpecialError::DomainError(
            "alpha must be finite for Gauss-Laguerre quadrature".to_string(),
        ));
    }
    if alpha <= -1.0 {
        return Err(SpecialError::DomainError(
            "alpha must be > -1 for Gauss-Laguerre quadrature".to_string(),
        ));
    }
    // Total mass of the weight function. If it overflows, every weight would be infinite.
    let mu0 = gamma_fn(alpha + 1.0);
    if !mu0.is_finite() {
        return Err(SpecialError::DomainError(
            "alpha is too large: Gamma(alpha + 1) overflows f64".to_string(),
        ));
    }
    let diag: Vec<f64> = (0..n).map(|k| 2.0 * (k as f64) + alpha + 1.0).collect();
    if n == 1 {
        // The 1x1 Jacobi matrix is its own eigenvalue; the single weight carries all the mass.
        return Ok((diag, vec![mu0]));
    }
    let sub_diag: Vec<f64> = (1..n)
        .map(|k| {
            let k = k as f64;
            (k * (k + alpha)).sqrt()
        })
        .collect();
    let (eigenvalues, eigenvectors) = symmetric_tridiag_eigensystem(&diag, &sub_diag)?;
    let mut rule: Vec<(f64, f64)> = eigenvalues
        .into_iter()
        .zip(eigenvectors.iter().map(|v| mu0 * v[0] * v[0]))
        .collect();
    // Report a NaN as an error rather than panicking inside the sort comparator.
    if rule.iter().any(|(x, _)| x.is_nan()) {
        return Err(SpecialError::ConvergenceError(
            "Gauss-Laguerre eigenvalue computation produced NaN nodes".to_string(),
        ));
    }
    rule.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("NaN nodes rejected above"));
    Ok(rule.into_iter().unzip())
}
/// Gauss–Chebyshev quadrature (first kind): nodes and weights for
/// `∫_{-1}^{1} f(x) / sqrt(1 - x²) dx ≈ Σ w_i f(x_i)`.
///
/// Uses the closed form `x_k = cos((2k - 1) π / (2n))` for `k = 1..=n`, with
/// every weight equal to `π / n`. The nodes come out in descending order,
/// unlike the eigensolver-based rules in this module, which are sorted ascending.
///
/// # Errors
/// `ValueError` if `n == 0`.
pub fn chebgauss(n: usize) -> SpecialResult<(Vec<f64>, Vec<f64>)> {
    if n == 0 {
        return Err(SpecialError::ValueError(
            "Number of quadrature points must be >= 1".to_string(),
        ));
    }
    let count = n as f64;
    let nodes: Vec<f64> = (1..=n)
        .map(|k| {
            let angle = (2 * k - 1) as f64 * std::f64::consts::PI / (2.0 * count);
            angle.cos()
        })
        .collect();
    Ok((nodes, vec![std::f64::consts::PI / count; n]))
}
/// Gauss–Jacobi quadrature: nodes and weights for
/// `∫_{-1}^{1} (1 - x)^α (1 + x)^β f(x) dx ≈ Σ w_i f(x_i)`.
///
/// Golub–Welsch construction on the Jacobi matrix of the Jacobi recurrence:
/// * diagonal `a_k = (β² - α²) / ((2k + α + β)(2k + α + β + 2))`
/// * off-diagonal `b_k² = 4k(k+α)(k+β)(k+α+β) / ((2k+α+β)² (2k+α+β+1)(2k+α+β-1))`
///
/// Weights are `μ0 = 2^(α+β+1) Γ(α+1) Γ(β+1) / Γ(α+β+2)` times the squared first
/// eigenvector component.
///
/// `a_0` and `b_1` contain removable singularities. They are 0/0 when
/// α + β = 0 and α + β = -1 respectively (e.g. the Chebyshev case α = β = -½).
/// Both are therefore evaluated in closed form with the common factor cancelled.
/// Nodes are returned in ascending order.
///
/// # Errors
/// * `ValueError` if `n == 0`.
/// * `DomainError` if `alpha` or `beta` is not finite, or is `<= -1`.
/// * `ConvergenceError` if the eigensolver fails or produces NaN nodes.
pub fn jacgauss(n: usize, alpha: f64, beta: f64) -> SpecialResult<(Vec<f64>, Vec<f64>)> {
    if n == 0 {
        return Err(SpecialError::ValueError(
            "Number of quadrature points must be >= 1".to_string(),
        ));
    }
    // `x <= -1.0` is false for NaN, so reject non-finite parameters explicitly.
    if !alpha.is_finite() || !beta.is_finite() {
        return Err(SpecialError::DomainError(
            "alpha and beta must be finite for Gauss-Jacobi quadrature".to_string(),
        ));
    }
    if alpha <= -1.0 || beta <= -1.0 {
        return Err(SpecialError::DomainError(
            "alpha and beta must be > -1 for Gauss-Jacobi quadrature".to_string(),
        ));
    }
    let ab = alpha + beta;
    let diag: Vec<f64> = (0..n)
        .map(|k| {
            if k == 0 {
                // (β-α)(β+α) / ((α+β)(α+β+2)) with (α+β) cancelled; exact even when α+β = 0.
                (beta - alpha) / (ab + 2.0)
            } else {
                // For k >= 1, t = 2k + α + β > 0 because α + β > -2.
                let t = 2.0 * k as f64 + ab;
                (beta - alpha) * ab / (t * (t + 2.0))
            }
        })
        .collect();
    let sub_diag: Vec<f64> = (1..n)
        .map(|k| {
            let kf = k as f64;
            let t = 2.0 * kf + ab;
            let b_sq = if k == 1 {
                // At k = 1 the numerator factor (k + α + β) equals (t - 1) in the
                // denominator; cancel it so α + β = -1 does not become 0/0.
                4.0 * (1.0 + alpha) * (1.0 + beta) / (t * t * (t + 1.0))
            } else {
                4.0 * kf * (kf + alpha) * (kf + beta) * (kf + ab)
                    / (t * t * (t + 1.0) * (t - 1.0))
            };
            b_sq.sqrt()
        })
        .collect();
    let mu0 = jacobi_mu0(alpha, beta);
    if n == 1 {
        // Single node at the mean of the weight function, carrying all the mass.
        return Ok((diag, vec![mu0]));
    }
    let (eigenvalues, eigenvectors) = symmetric_tridiag_eigensystem(&diag, &sub_diag)?;
    let mut rule: Vec<(f64, f64)> = eigenvalues
        .into_iter()
        .zip(eigenvectors.iter().map(|v| mu0 * v[0] * v[0]))
        .collect();
    // Report a NaN as an error rather than panicking inside the sort comparator.
    if rule.iter().any(|(x, _)| x.is_nan()) {
        return Err(SpecialError::ConvergenceError(
            "Gauss-Jacobi eigenvalue computation produced NaN nodes".to_string(),
        ));
    }
    rule.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("NaN nodes rejected above"));
    Ok(rule.into_iter().unzip())
}
/// Total mass of the Jacobi weight on `[-1, 1]`:
/// `μ0 = ∫ (1-x)^α (1+x)^β dx = 2^(α+β+1) Γ(α+1) Γ(β+1) / Γ(α+β+2)`.
///
/// The factors are combined as logarithms (via `lgamma_fn`) and exponentiated once.
fn jacobi_mu0(alpha: f64, beta: f64) -> f64 {
    let power_of_two = (alpha + beta + 1.0) * 2.0_f64.ln();
    let log_numerator = power_of_two + lgamma_fn(alpha + 1.0) + lgamma_fn(beta + 1.0);
    let log_denominator = lgamma_fn(alpha + beta + 2.0);
    (log_numerator - log_denominator).exp()
}
/// Gamma function `Γ(x)` for real `x`.
///
/// Uses the Lanczos approximation (g = 7, 9 coefficients) for `x >= 0.5` and
/// the reflection formula `Γ(x) = π / (sin(πx) Γ(1-x))` below that.
/// Returns `+∞` at the poles `x = 0, -1, -2, …` and whenever `Γ(x)` exceeds
/// `f64::MAX` (x greater than about 171.62).
fn gamma_fn(x: f64) -> f64 {
    // Γ(x) overflows f64 for x above ~171.624; everything past this bound is +∞.
    const OVERFLOW_ARG: f64 = 171.7;
    if x <= 0.0 && x.fract() == 0.0 {
        return f64::INFINITY;
    }
    if x > OVERFLOW_ARG {
        return f64::INFINITY;
    }
    if x < 0.5 {
        return std::f64::consts::PI
            / ((std::f64::consts::PI * x).sin() * gamma_fn(1.0 - x));
    }
    let p = [
        676.520_368_121_885_1,
        -1259.139_216_722_402_8,
        771.323_428_777_653_1,
        -176.615_029_162_140_6,
        12.507_343_278_686_905,
        -0.138_571_095_265_720_12,
        9.984_369_578_019_572e-6,
        1.505_632_735_149_311_6e-7,
    ];
    let z = x - 1.0;
    let series = p
        .iter()
        .enumerate()
        .fold(0.999_999_999_999_809_9_f64, |acc, (i, &c)| {
            acc + c / (z + (i as f64) + 1.0)
        });
    let t = z + (p.len() as f64) - 0.5;
    let sqrt_2pi = 2.506_628_274_631_000_7;
    // t^(z+0.5) alone overflows once x reaches about 143, long before Γ(x) does.
    // Split the power into two halves and apply e^{-t} between them so the
    // intermediate values stay in range for every representable result.
    let half_power = t.powf(0.5 * (z + 0.5));
    sqrt_2pi * (half_power * (-t).exp()) * half_power * series
}
/// Natural log of the Gamma function, `ln Γ(x)`, for `x > 0`.
///
/// For `x <= 100` this is `gamma_fn(x).ln()`. Above that it switches to the
/// Stirling series
/// `(x - ½) ln x - x + ½ ln 2π + 1/(12x) - 1/(360x³) + 1/(1260x⁵)`.
/// The series keeps the result finite even where `Γ(x)` itself overflows
/// (`x > ~171.6`). At `x > 100` its truncation error is below 1e-17.
///
/// Returns `+∞` for `x <= 0` (this helper is only meant for positive
/// arguments) and for `x = +∞`.
fn lgamma_fn(x: f64) -> f64 {
    // 0.5 * ln(2π)
    const HALF_LN_2PI: f64 = 0.918_938_533_204_672_8;
    if x <= 0.0 || x == f64::INFINITY {
        return f64::INFINITY;
    }
    if x > 100.0 {
        let inv = 1.0 / x;
        let inv_sq = inv * inv;
        let correction = inv * (1.0 / 12.0 - inv_sq * (1.0 / 360.0 - inv_sq / 1260.0));
        return (x - 0.5) * x.ln() - x + HALF_LN_2PI + correction;
    }
    gamma_fn(x).ln()
}
/// Computes every eigenvalue and eigenvector of the real symmetric tridiagonal
/// matrix with diagonal `diag` and off-diagonal `sub_diag`.
///
/// `sub_diag[i]` couples rows `i` and `i + 1`. For `n >= 2` it must hold exactly
/// `n - 1` values; it is ignored when `n <= 1`.
///
/// Returns `(eigenvalues, eigenvectors)`, where `eigenvectors[i]` is the unit
/// eigenvector belonging to `eigenvalues[i]`. Eigenvalues are not sorted.
///
/// # Errors
/// * `ValueError` if `sub_diag` has the wrong length.
/// * `ConvergenceError` if the QL iteration does not converge.
fn symmetric_tridiag_eigensystem(
    diag: &[f64],
    sub_diag: &[f64],
) -> SpecialResult<(Vec<f64>, Vec<Vec<f64>>)> {
    let n = diag.len();
    if n == 0 {
        return Ok((vec![], vec![]));
    }
    if n == 1 {
        return Ok((vec![diag[0]], vec![vec![1.0]]));
    }
    if sub_diag.len() != n - 1 {
        return Err(SpecialError::ValueError(format!(
            "sub-diagonal must have {} entries, got {}",
            n - 1,
            sub_diag.len()
        )));
    }
    let mut d = diag.to_vec();
    // The QL routine wants an off-diagonal buffer of length n; the last slot is scratch.
    let mut e = Vec::with_capacity(n);
    e.extend_from_slice(sub_diag);
    e.push(0.0);
    // Start from the identity so the accumulated rotations are the eigenvector matrix.
    let mut z: Vec<Vec<f64>> = (0..n)
        .map(|i| {
            let mut row = vec![0.0f64; n];
            row[i] = 1.0;
            row
        })
        .collect();
    tqli_algorithm(&mut d, &mut e, &mut z)?;
    // Eigenvectors are the columns of `z`; transpose so each Vec is one eigenvector.
    let eigenvectors: Vec<Vec<f64>> = (0..n)
        .map(|col| z.iter().map(|row| row[col]).collect())
        .collect();
    Ok((d, eigenvectors))
}

/// Implicit-shift QL iteration for a symmetric tridiagonal matrix (the
/// classic `tqli` routine).
///
/// On entry:
/// * `d` holds the diagonal.
/// * `e[i]` couples `d[i]` and `d[i + 1]`; `e[n - 1]` is scratch.
/// * `z` is the matrix the rotations are applied to (the identity, to get eigenvectors).
///
/// On exit:
/// * `d` holds the eigenvalues.
/// * Column `j` of `z` is the eigenvector for `d[j]`.
/// * `e` is destroyed.
///
/// # Errors
/// `ConvergenceError` if any eigenvalue needs more than 100 sweeps.
fn tqli_algorithm(d: &mut [f64], e: &mut [f64], z: &mut [Vec<f64>]) -> SpecialResult<()> {
    const MAX_ITER: usize = 100;
    let n = d.len();
    if n <= 1 {
        return Ok(());
    }
    e[n - 1] = 0.0;
    for l in 0..n {
        let mut iter_count = 0;
        'sweep: loop {
            // Find the first negligible off-diagonal element at or after `l`;
            // the block l..=m is then decoupled from the rest.
            let mut m = l;
            while m < n - 1 {
                let dd = d[m].abs() + d[m + 1].abs();
                if e[m].abs() + dd == dd {
                    break;
                }
                m += 1;
            }
            if m == l {
                // d[l] has converged.
                break;
            }
            iter_count += 1;
            if iter_count > MAX_ITER {
                return Err(SpecialError::ConvergenceError(
                    "QL algorithm did not converge".to_string(),
                ));
            }
            // Shift from the leading 2x2 block, expressed relative to d[m].
            let mut g = (d[l + 1] - d[l]) / (2.0 * e[l]);
            let mut r = g.hypot(1.0);
            g = d[m] - d[l] + e[l] / (g + g.signum() * r);
            let (mut s, mut c, mut p) = (1.0f64, 1.0f64, 0.0f64);
            // Chase the bulge from the bottom of the block back up to `l`.
            for i in (l..m).rev() {
                let f = s * e[i];
                let b = c * e[i];
                r = f.hypot(g);
                e[i + 1] = r;
                if r == 0.0 {
                    // Underflow: undo the partial shift and restart the sweep
                    // without the end-of-sweep update.
                    d[i + 1] -= p;
                    e[m] = 0.0;
                    continue 'sweep;
                }
                s = f / r;
                c = g / r;
                g = d[i + 1] - p;
                r = (d[i] - g) * s + 2.0 * c * b;
                p = s * r;
                d[i + 1] = g + p;
                // This carried value is what the sweep must feed forward.
                g = c * r - b;
                for row in z.iter_mut() {
                    let t = row[i + 1];
                    row[i + 1] = s * row[i] + c * t;
                    row[i] = c * row[i] - s * t;
                }
            }
            d[l] -= p;
            e[l] = g;
            e[m] = 0.0;
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    // Unit tests for the quadrature rules. Each rule is checked against exact
    // integrals, known closed-form nodes/weights, symmetry, and argument validation.
    use super::*;
    #[test]
    fn test_leggauss_1_point() {
        let (nodes, weights) = leggauss(1).expect("leggauss(1) failed");
        assert_eq!(nodes.len(), 1);
        assert!((nodes[0] - 0.0).abs() < 1e-14);
        assert!((weights[0] - 2.0).abs() < 1e-14);
    }
    #[test]
    fn test_leggauss_2_points() {
        let (nodes, weights) = leggauss(2).expect("leggauss(2) failed");
        assert_eq!(nodes.len(), 2);
        let s = 1.0 / 3.0_f64.sqrt();
        assert!((nodes[0].abs() - s).abs() < 1e-12);
        assert!((nodes[1].abs() - s).abs() < 1e-12);
        // leggauss documents ascending node order.
        assert!(nodes[0] < nodes[1], "nodes should be ascending");
        assert!((weights[0] - 1.0).abs() < 1e-12);
        assert!((weights[1] - 1.0).abs() < 1e-12);
    }
    #[test]
    fn test_leggauss_integrates_x_squared() {
        let (nodes, weights) = leggauss(5).expect("leggauss(5) failed");
        let integral: f64 = nodes.iter().zip(&weights).map(|(x, w)| w * x * x).sum();
        assert!(
            (integral - 2.0 / 3.0).abs() < 1e-12,
            "integral of x^2 = {integral}, expected 2/3"
        );
    }
    #[test]
    fn test_leggauss_integrates_constant() {
        let (_, weights) = leggauss(4).expect("leggauss(4) failed");
        let integral: f64 = weights.iter().sum();
        assert!(
            (integral - 2.0).abs() < 1e-12,
            "integral of 1 = {integral}, expected 2"
        );
    }
    #[test]
    fn test_leggauss_zero_points_error() {
        let result = leggauss(0);
        assert!(result.is_err());
    }
    #[test]
    fn test_leggauss_symmetry() {
        let (nodes, weights) = leggauss(6).expect("leggauss(6) failed");
        for i in 0..3 {
            assert!(
                (nodes[i] + nodes[5 - i]).abs() < 1e-12,
                "nodes not symmetric: {} + {} != 0",
                nodes[i],
                nodes[5 - i]
            );
            assert!(
                (weights[i] - weights[5 - i]).abs() < 1e-12,
                "weights not symmetric"
            );
        }
    }
    #[test]
    fn test_hermgauss_1_point() {
        let (nodes, weights) = hermgauss(1).expect("hermgauss(1) failed");
        assert_eq!(nodes.len(), 1);
        assert!((nodes[0] - 0.0).abs() < 1e-14);
        assert!(
            (weights[0] - std::f64::consts::PI.sqrt()).abs() < 1e-14
        );
    }
    #[test]
    fn test_hermgauss_integrates_one() {
        let (_, weights) = hermgauss(5).expect("hermgauss(5) failed");
        let integral: f64 = weights.iter().sum();
        assert!(
            (integral - std::f64::consts::PI.sqrt()).abs() < 1e-12,
            "integral of 1 with Hermite weight = {integral}"
        );
    }
    #[test]
    fn test_hermgauss_integrates_x_squared() {
        let (nodes, weights) = hermgauss(5).expect("hermgauss(5) failed");
        let integral: f64 = nodes.iter().zip(&weights).map(|(x, w)| w * x * x).sum();
        let expected = std::f64::consts::PI.sqrt() / 2.0;
        assert!(
            (integral - expected).abs() < 1e-12,
            "integral of x^2 * exp(-x^2) = {integral}, expected {expected}"
        );
    }
    #[test]
    fn test_hermgauss_symmetry() {
        let (nodes, weights) = hermgauss(4).expect("hermgauss(4) failed");
        // The weight e^{-x^2} is even, so both nodes and weights mirror about 0.
        for i in 0..2 {
            assert!(
                (nodes[i] + nodes[3 - i]).abs() < 1e-12,
                "Hermite nodes not symmetric"
            );
            assert!(
                (weights[i] - weights[3 - i]).abs() < 1e-12,
                "Hermite weights not symmetric"
            );
        }
    }
    #[test]
    fn test_hermgauss_zero_error() {
        assert!(hermgauss(0).is_err());
    }
    #[test]
    fn test_chebgauss_weights() {
        let (nodes, weights) = chebgauss(4).expect("chebgauss(4) failed");
        assert_eq!(nodes.len(), 4);
        let expected_w = std::f64::consts::PI / 4.0;
        for w in &weights {
            assert!((*w - expected_w).abs() < 1e-14);
        }
    }
    #[test]
    fn test_chebgauss_nodes_in_range() {
        let (nodes, _) = chebgauss(10).expect("chebgauss(10) failed");
        for x in &nodes {
            assert!(*x >= -1.0 && *x <= 1.0, "node out of range: {x}");
        }
    }
    #[test]
    fn test_chebgauss_integrates_one() {
        let (_, weights) = chebgauss(5).expect("chebgauss(5) failed");
        let integral: f64 = weights.iter().sum();
        assert!(
            (integral - std::f64::consts::PI).abs() < 1e-12,
            "integral = {integral}, expected pi"
        );
    }
    #[test]
    fn test_chebgauss_zero_error() {
        assert!(chebgauss(0).is_err());
    }
    #[test]
    fn test_chebgauss_symmetry() {
        let (nodes, _) = chebgauss(6).expect("chebgauss(6) failed");
        for i in 0..3 {
            assert!(
                (nodes[i] + nodes[5 - i]).abs() < 1e-12,
                "Chebyshev nodes not symmetric"
            );
        }
    }
    #[test]
    fn test_laggauss_integrates_one() {
        let (_, weights) = laggauss(5, 0.0).expect("laggauss(5,0) failed");
        let integral: f64 = weights.iter().sum();
        assert!(
            (integral - 1.0).abs() < 1e-10,
            "integral = {integral}, expected 1"
        );
    }
    #[test]
    fn test_laggauss_positive_nodes() {
        let (nodes, _) = laggauss(5, 0.0).expect("laggauss(5,0) failed");
        for x in &nodes {
            assert!(*x > 0.0, "Laguerre node should be positive: {x}");
        }
    }
    #[test]
    fn test_laggauss_alpha_error() {
        assert!(laggauss(5, -1.0).is_err());
        assert!(laggauss(5, -2.0).is_err());
    }
    #[test]
    fn test_laggauss_positive_weights() {
        let (_, weights) = laggauss(5, 0.0).expect("laggauss(5,0) failed");
        for w in &weights {
            assert!(*w > 0.0, "Laguerre weight should be positive: {w}");
        }
    }
    #[test]
    fn test_laggauss_zero_error() {
        assert!(laggauss(0, 0.0).is_err());
    }
    #[test]
    fn test_jacgauss_legendre_case() {
        let (_, weights) = jacgauss(5, 0.0, 0.0).expect("jacgauss(5,0,0) failed");
        let integral: f64 = weights.iter().sum();
        assert!(
            (integral - 2.0).abs() < 1e-10,
            "Jacobi(0,0) total weight = {integral}, expected 2"
        );
    }
    #[test]
    fn test_jacgauss_nodes_in_range() {
        let (nodes, _) = jacgauss(5, 1.0, 2.0).expect("jacgauss(5,1,2) failed");
        for x in &nodes {
            assert!(
                *x >= -1.0 - 1e-10 && *x <= 1.0 + 1e-10,
                "node out of range: {x}"
            );
        }
    }
    #[test]
    fn test_jacgauss_positive_weights() {
        let (_, weights) = jacgauss(5, 0.5, 0.5).expect("jacgauss(5,0.5,0.5) failed");
        for w in &weights {
            assert!(*w > 0.0, "Jacobi weight should be positive: {w}");
        }
    }
    #[test]
    fn test_jacgauss_error_negative_alpha() {
        assert!(jacgauss(5, -1.0, 0.0).is_err());
    }
    #[test]
    fn test_jacgauss_zero_error() {
        assert!(jacgauss(0, 0.0, 0.0).is_err());
    }
}