use crate::error::{IntegrateError, IntegrateResult};
/// Family of orthogonal polynomials used as the chaos basis.
///
/// Each family is paired with the input distribution it is orthogonal under,
/// and with the matching Gauss quadrature rule used to fit the expansion.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PolynomialFamily {
    /// Probabilists' Hermite polynomials `He_n`, for standard normal inputs.
    Hermite,
    /// Legendre polynomials `P_n`, for inputs uniform on `[-1, 1]`.
    Legendre,
    /// Laguerre polynomials `L_n`, for unit-rate exponential inputs on `[0, ∞)`.
    Laguerre,
}
/// Configuration for a polynomial chaos expansion fit.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub struct PceConfig {
    /// Number of independent random inputs (the dimension of `ξ`); must be > 0.
    pub n_inputs: usize,
    /// Maximum total polynomial degree of the basis (sum of per-input degrees).
    pub order: usize,
    /// Orthogonal polynomial family used for the basis and the quadrature rule.
    pub polynomial: PolynomialFamily,
    /// Gauss points per input for the tensor-product grid; the sparse-grid path
    /// uses at most this many points in any 1-D rule. Must be > 0.
    pub n_quadrature: usize,
    /// Use the sparse-grid projection path instead of the full tensor-product grid.
    pub use_sparse_grid: bool,
}
impl Default for PceConfig {
    /// Two inputs, total order 3, Hermite basis, five Gauss points per input,
    /// full tensor-product grid.
    fn default() -> Self {
        let polynomial = PolynomialFamily::Hermite;
        Self {
            polynomial,
            order: 3,
            n_inputs: 2,
            use_sparse_grid: false,
            n_quadrature: 5,
        }
    }
}
/// Result of fitting a polynomial chaos expansion.
#[derive(Debug, Clone, PartialEq)]
pub struct PceResult {
    /// Expansion coefficients `c_α`, aligned element-wise with `multi_indices`.
    pub coefficients: Vec<f64>,
    /// Multi-indices `α` (per-input polynomial degrees) of the basis terms.
    pub multi_indices: Vec<Vec<usize>>,
    /// Mean of the model output, read off as the coefficient of the constant term.
    pub mean: f64,
    /// Variance of the model output: sum of the squared non-constant coefficients.
    pub variance: f64,
    /// First-order Sobol index of each input.
    pub sobol_indices: Vec<f64>,
    /// Total-effect Sobol index of each input.
    pub total_sobol: Vec<f64>,
    /// Polynomial family the coefficients refer to (needed to evaluate the expansion).
    pub polynomial: PolynomialFamily,
}
/// Non-intrusive polynomial chaos expansion solver (spectral projection).
///
/// Construct with a configuration, then fit a model and evaluate the fitted
/// surrogate or read its moments and Sobol indices from the result.
#[derive(Debug, Clone)]
pub struct PolynomialChaos {
    config: PceConfig,
}
impl PolynomialChaos {
    /// Creates a solver for `config`.
    ///
    /// The configuration is not validated here; [`PolynomialChaos::fit`] rejects
    /// invalid settings.
    pub fn new(config: PceConfig) -> Self {
        Self { config }
    }

    /// Fits a total-degree polynomial chaos expansion of `f` by spectral projection.
    ///
    /// Every coefficient is approximated as `c_α = E[f(ξ) Ψ_α(ξ)]` with Gauss
    /// quadrature of the configured family, either on a full tensor-product grid or,
    /// when `use_sparse_grid` is set, on a Smolyak sparse grid. Quadrature weights are
    /// normalised to the family's probability measure and the basis is orthonormal
    /// under it, so `c_0` is the mean and the sum of the remaining squared
    /// coefficients is the variance.
    ///
    /// # Errors
    ///
    /// Returns [`IntegrateError::InvalidInput`] when `n_inputs` or `n_quadrature` is
    /// zero, or when a quadrature grid has more points than fit in a `usize`.
    pub fn fit<F: Fn(&[f64]) -> f64>(&self, f: F) -> IntegrateResult<PceResult> {
        let cfg = &self.config;
        if cfg.n_inputs == 0 {
            return Err(IntegrateError::InvalidInput(
                "n_inputs must be > 0".to_string(),
            ));
        }
        if cfg.n_quadrature == 0 {
            return Err(IntegrateError::InvalidInput(
                "n_quadrature must be > 0".to_string(),
            ));
        }
        let multi_indices = Self::generate_multi_indices(cfg.n_inputs, cfg.order);
        let mut coefficients = vec![0.0; multi_indices.len()];
        if cfg.use_sparse_grid {
            self.fit_sparse_grid(&f, &multi_indices, &mut coefficients)?;
        } else {
            self.fit_tensor_product(&f, &multi_indices, &mut coefficients)?;
        }
        // `generate_multi_indices` emits the all-zero (constant) index first.
        let mean = coefficients.first().copied().unwrap_or(0.0);
        let variance: f64 = coefficients.iter().skip(1).map(|c| c * c).sum();
        let (sobol_indices, total_sobol) =
            Self::compute_sobol_indices(&coefficients, &multi_indices, cfg.n_inputs, variance);
        Ok(PceResult {
            coefficients,
            multi_indices,
            mean,
            variance,
            sobol_indices,
            total_sobol,
            polynomial: cfg.polynomial,
        })
    }

    /// Evaluates the fitted expansion `Σ_α c_α Ψ_α(ξ)` at the point `xi`.
    ///
    /// `xi` should hold one coordinate per input. Basis factors for dimensions beyond
    /// `xi.len()` are skipped, because the per-term product stops at the shorter side.
    pub fn evaluate(&self, result: &PceResult, xi: &[f64]) -> f64 {
        result
            .coefficients
            .iter()
            .zip(&result.multi_indices)
            .map(|(c, alpha)| c * self.eval_basis_function(alpha, xi, result.polynomial))
            .sum()
    }

    /// Evaluates the tensor-product basis function `Ψ_α(ξ) = Π_i ψ_{α_i}(ξ_i)`.
    fn eval_basis_function(&self, alpha: &[usize], xi: &[f64], family: PolynomialFamily) -> f64 {
        alpha
            .iter()
            .zip(xi)
            .map(|(&deg, &x)| Self::eval_1d_poly(deg, x, family))
            .product()
    }

    /// Evaluates the orthonormal 1-D polynomial of `degree` for `family` at `x`.
    fn eval_1d_poly(degree: usize, x: f64, family: PolynomialFamily) -> f64 {
        match family {
            PolynomialFamily::Hermite => Self::hermite_poly_normalized(degree, x),
            PolynomialFamily::Legendre => Self::legendre_poly_normalized(degree, x),
            PolynomialFamily::Laguerre => Self::laguerre_poly_normalized(degree, x),
        }
    }

    /// Projects `f` onto the basis using the full tensor product of one
    /// `n_quadrature`-point Gauss rule per input.
    fn fit_tensor_product(
        &self,
        f: &dyn Fn(&[f64]) -> f64,
        multi_indices: &[Vec<usize>],
        coefficients: &mut [f64],
    ) -> IntegrateResult<()> {
        let (pts, wts) = self.get_quadrature_1d(self.config.n_quadrature)?;
        let rules = vec![(pts.as_slice(), wts.as_slice()); self.config.n_inputs];
        self.accumulate_tensor_grid(f, multi_indices, &rules, 1.0, coefficients)
    }

    /// Projects `f` onto the basis using a Smolyak sparse grid (combination technique).
    ///
    /// The level-`l` 1-D rule is the `l`-point Gauss rule of the configured family.
    /// With `d` inputs and accuracy level `k = min(order + 1, n_quadrature) - 1`, the grid is
    /// `Σ_{k-d+1 ≤ |e| ≤ k} (-1)^{k-|e|} C(d-1, k-|e|) ⊗_i Q_{e_i + 1}` over `e_i ≥ 0`.
    /// It integrates every polynomial of total degree ≤ `2k + 1` exactly. When `k = order`,
    /// that covers the projection integrands of a degree-`order` model.
    fn fit_sparse_grid(
        &self,
        f: &dyn Fn(&[f64]) -> f64,
        multi_indices: &[Vec<usize>],
        coefficients: &mut [f64],
    ) -> IntegrateResult<()> {
        let d = self.config.n_inputs;
        let k = (self.config.order + 1)
            .min(self.config.n_quadrature)
            .saturating_sub(1);
        // rules_1d[e] is the (e + 1)-point rule.
        let rules_1d = (1..=k + 1)
            .map(|points| self.get_quadrature_1d(points))
            .collect::<IntegrateResult<Vec<_>>>()?;
        // The level excesses e_i = l_i - 1 with |e| <= k are exactly the total-degree
        // multi-indices of order k.
        for excess in Self::generate_multi_indices(d, k) {
            let s = k - excess.iter().sum::<usize>();
            if s >= d {
                // C(d - 1, s) vanishes: this tensor grid is not part of the combination.
                continue;
            }
            let sign = if s % 2 == 0 { 1.0 } else { -1.0 };
            let scale = sign * Self::binomial(d - 1, s);
            let rules: Vec<(&[f64], &[f64])> = excess
                .iter()
                .map(|&e| (rules_1d[e].0.as_slice(), rules_1d[e].1.as_slice()))
                .collect();
            self.accumulate_tensor_grid(f, multi_indices, &rules, scale, coefficients)?;
        }
        Ok(())
    }

    /// Adds `scale · Σ_q w_q f(ξ_q) Ψ_α(ξ_q)` over one tensor-product grid to every
    /// coefficient. `rules[d]` holds the (nodes, weights) of the rule used for input `d`.
    /// The first input varies fastest.
    fn accumulate_tensor_grid(
        &self,
        f: &dyn Fn(&[f64]) -> f64,
        multi_indices: &[Vec<usize>],
        rules: &[(&[f64], &[f64])],
        scale: f64,
        coefficients: &mut [f64],
    ) -> IntegrateResult<()> {
        let n_points = rules
            .iter()
            .try_fold(1usize, |acc, (pts, _)| acc.checked_mul(pts.len()))
            .ok_or_else(|| {
                IntegrateError::InvalidInput(
                    "quadrature grid has too many points".to_string(),
                )
            })?;
        let mut xi = vec![0.0; rules.len()];
        for qidx in 0..n_points {
            // Decode the flat index as mixed-radix digits, one per input.
            let mut remainder = qidx;
            let mut weight = scale;
            for (coord, (pts, wts)) in xi.iter_mut().zip(rules) {
                let k = remainder % pts.len();
                remainder /= pts.len();
                *coord = pts[k];
                weight *= wts[k];
            }
            let fval = f(&xi);
            for (c, alpha) in coefficients.iter_mut().zip(multi_indices) {
                *c += fval * self.eval_basis_function(alpha, &xi, self.config.polynomial) * weight;
            }
        }
        Ok(())
    }

    /// Binomial coefficient `C(n, k)` as `f64`; zero when `k > n`.
    fn binomial(n: usize, k: usize) -> f64 {
        if k > n {
            return 0.0;
        }
        // Every partial product is itself a binomial coefficient, so it stays integral.
        (1..=k).fold(1.0, |acc, i| acc * (n - k + i) as f64 / i as f64)
    }

    /// Returns all multi-indices of length `n_vars` with total degree ≤ `order`.
    ///
    /// The list is in lexicographic order with the first dimension varying slowest.
    /// Its first entry is the all-zero (constant) index, and it has
    /// `C(n_vars + order, order)` entries.
    pub fn generate_multi_indices(n_vars: usize, order: usize) -> Vec<Vec<usize>> {
        let mut indices = Vec::new();
        let mut current = vec![0usize; n_vars];
        Self::generate_indices_recursive(n_vars, order, 0, 0, &mut current, &mut indices);
        indices
    }

    /// Depth-first enumeration: fixes the degree of `current_dim`, then recurses into
    /// the next dimension with the remaining degree budget.
    fn generate_indices_recursive(
        n_vars: usize,
        max_order: usize,
        current_dim: usize,
        current_sum: usize,
        current: &mut Vec<usize>,
        result: &mut Vec<Vec<usize>>,
    ) {
        if current_dim == n_vars {
            result.push(current.clone());
            return;
        }
        let remaining = max_order.saturating_sub(current_sum);
        for deg in 0..=remaining {
            current[current_dim] = deg;
            Self::generate_indices_recursive(
                n_vars,
                max_order,
                current_dim + 1,
                current_sum + deg,
                current,
                result,
            );
        }
    }

    /// Computes first-order and total Sobol indices from orthonormal PCE coefficients.
    ///
    /// A term contributes `c_α²` to input `i`'s total index whenever `α_i > 0`. It
    /// contributes to the first-order index only if `i` is its sole active input. Both
    /// sums are divided by `variance`. Constant terms (all-zero `α`) contribute
    /// nothing, wherever they appear. If `variance < f64::EPSILON`, all indices are zero.
    pub fn compute_sobol_indices(
        coefficients: &[f64],
        multi_indices: &[Vec<usize>],
        n_vars: usize,
        variance: f64,
    ) -> (Vec<f64>, Vec<f64>) {
        let mut first_order = vec![0.0; n_vars];
        let mut total_order = vec![0.0; n_vars];
        if variance < f64::EPSILON {
            return (first_order, total_order);
        }
        for (alpha, &c) in multi_indices.iter().zip(coefficients) {
            let active: Vec<usize> = alpha
                .iter()
                .enumerate()
                .filter(|&(_, &deg)| deg > 0)
                .map(|(i, _)| i)
                .collect();
            let c_sq = c * c;
            if let [only] = active.as_slice() {
                if *only < n_vars {
                    first_order[*only] += c_sq;
                }
            }
            for &i in active.iter().filter(|&&i| i < n_vars) {
                total_order[i] += c_sq;
            }
        }
        for s in first_order.iter_mut().chain(total_order.iter_mut()) {
            *s /= variance;
        }
        (first_order, total_order)
    }

    /// Returns the `n`-point Gauss rule of the configured family.
    ///
    /// The weights are rescaled to sum to one, so the rule integrates against the
    /// family's probability measure.
    fn get_quadrature_1d(&self, n: usize) -> IntegrateResult<(Vec<f64>, Vec<f64>)> {
        let (pts, mut wts) = match self.config.polynomial {
            PolynomialFamily::Hermite => Self::gauss_hermite(n),
            PolynomialFamily::Legendre => Self::gauss_legendre(n),
            PolynomialFamily::Laguerre => Self::gauss_laguerre(n),
        };
        // Gauss–Legendre weights sum to 2 (Lebesgue measure on [-1, 1]). Dividing by the
        // total turns them into expectations under the uniform law. The Hermite and
        // Laguerre rules are already normalised, so for them this only touches rounding.
        let total: f64 = wts.iter().sum();
        if total > 0.0 {
            wts.iter_mut().for_each(|w| *w /= total);
        }
        Ok((pts, wts))
    }

    /// `n`-point Gauss–Hermite rule for the standard normal; weights sum to one.
    pub fn gauss_hermite(n: usize) -> (Vec<f64>, Vec<f64>) {
        gauss_hermite_newton(n)
    }

    /// `n`-point Gauss–Legendre rule on `[-1, 1]`; weights sum to 2.
    pub fn gauss_legendre(n: usize) -> (Vec<f64>, Vec<f64>) {
        gauss_legendre_newton(n)
    }

    /// `n`-point Gauss–Laguerre rule for the weight `e^{-x}` on `[0, ∞)`; weights sum to one.
    pub fn gauss_laguerre(n: usize) -> (Vec<f64>, Vec<f64>) {
        gauss_laguerre_newton(n)
    }

    /// Probabilists' Hermite polynomial `He_n(x)`, via `He_{k+1} = x He_k - k He_{k-1}`.
    pub fn hermite_poly(n: usize, x: f64) -> f64 {
        if n == 0 {
            return 1.0;
        }
        let mut h_prev = 1.0_f64;
        let mut h_curr = x;
        for k in 1..n {
            let h_next = x * h_curr - k as f64 * h_prev;
            h_prev = h_curr;
            h_curr = h_next;
        }
        h_curr
    }

    /// `He_n(x) / sqrt(n!)`, orthonormal under the standard normal distribution.
    fn hermite_poly_normalized(n: usize, x: f64) -> f64 {
        Self::hermite_poly(n, x) / factorial_sqrt(n)
    }

    /// Legendre polynomial `P_n(x)`, via `(k+1) P_{k+1} = (2k+1) x P_k - k P_{k-1}`.
    pub fn legendre_poly(n: usize, x: f64) -> f64 {
        if n == 0 {
            return 1.0;
        }
        let mut p_prev = 1.0_f64;
        let mut p_curr = x;
        for k in 1..n {
            let kf = k as f64;
            let p_next = ((2.0 * kf + 1.0) * x * p_curr - kf * p_prev) / (kf + 1.0);
            p_prev = p_curr;
            p_curr = p_next;
        }
        p_curr
    }

    /// `sqrt(2n + 1) · P_n(x)`, orthonormal under the uniform distribution on `[-1, 1]`.
    ///
    /// With density 1/2, `E[P_n²] = 1 / (2n + 1)`. This normalisation makes `ψ_0 = 1`,
    /// which is what lets `c_0` be read as the mean.
    fn legendre_poly_normalized(n: usize, x: f64) -> f64 {
        Self::legendre_poly(n, x) * (2.0 * n as f64 + 1.0).sqrt()
    }

    /// Laguerre polynomial `L_n(x)`, via `(k+1) L_{k+1} = (2k+1-x) L_k - k L_{k-1}`.
    pub fn laguerre_poly(n: usize, x: f64) -> f64 {
        if n == 0 {
            return 1.0;
        }
        if n == 1 {
            return 1.0 - x;
        }
        let mut l_prev = 1.0_f64;
        let mut l_curr = 1.0 - x;
        for k in 1..n {
            let kf = k as f64;
            let l_next = ((2.0 * kf + 1.0 - x) * l_curr - kf * l_prev) / (kf + 1.0);
            l_prev = l_curr;
            l_curr = l_next;
        }
        l_curr
    }

    /// `L_n(x)` itself, which is already orthonormal under the unit exponential distribution.
    fn laguerre_poly_normalized(n: usize, x: f64) -> f64 {
        Self::laguerre_poly(n, x)
    }
}
/// Returns `sqrt(n!)`, computed in log space as `exp(½ Σ_{k=1}^{n} ln k)`.
///
/// Working with logarithms avoids overflowing `n!` for large `n`. `n == 0`
/// yields exactly `1.0`.
fn factorial_sqrt(n: usize) -> f64 {
    let half_log_factorial = (1..=n).fold(0.0_f64, |acc, k| acc + (k as f64).ln()) * 0.5;
    half_log_factorial.exp()
}
/// Computes the `n`-point Gauss–Legendre rule on `[-1, 1]` (weights sum to 2) by Newton iteration.
///
/// Only one half of the roots is searched for. Each converged root `r` is stored at
/// both `r` and `-r`, so the nodes come out in ascending order. `n == 0` yields
/// empty vectors, and `n == 1` yields the midpoint rule.
fn gauss_legendre_newton(n: usize) -> (Vec<f64>, Vec<f64>) {
    match n {
        0 => return (vec![], vec![]),
        1 => return (vec![0.0], vec![2.0]),
        _ => {}
    }
    // (P_n(x), P_{n-1}(x)) via Bonnet's recurrence.
    let legendre_pair = |x: f64| -> (f64, f64) {
        let mut lower = 1.0_f64;
        let mut upper = x;
        for j in 1..n {
            let jf = j as f64;
            let next = ((2.0 * jf + 1.0) * x * upper - jf * lower) / (jf + 1.0);
            lower = upper;
            upper = next;
        }
        (upper, lower)
    };
    // P_n'(x) = n (x P_n - P_{n-1}) / (x^2 - 1).
    let derivative =
        |x: f64, p_n: f64, p_nm1: f64| n as f64 * (x * p_n - p_nm1) / (x * x - 1.0);

    let mut nodes = vec![0.0_f64; n];
    let mut weights = vec![0.0_f64; n];
    for i in 0..n.div_ceil(2) {
        // Standard cosine approximation of the i-th largest root.
        let mut root = (std::f64::consts::PI * (i as f64 + 0.75) / (n as f64 + 0.5)).cos();
        for _ in 0..100 {
            let (p_n, p_nm1) = legendre_pair(root);
            let step = p_n / derivative(root, p_n, p_nm1);
            root -= step;
            if step.abs() < f64::EPSILON * root.abs() + 1e-15 {
                break;
            }
        }
        let (p_n, p_nm1) = legendre_pair(root);
        let slope = derivative(root, p_n, p_nm1);
        let weight = 2.0 / ((1.0 - root * root) * slope * slope);
        nodes[i] = -root;
        nodes[n - 1 - i] = root;
        weights[i] = weight;
        weights[n - 1 - i] = weight;
    }
    (nodes, weights)
}
/// Computes the `n`-point Gauss–Hermite rule for probabilists' Hermite polynomials `He_n`.
///
/// Roots of `He_n` are found by Newton iteration using `He_n'(x) = n He_{n-1}(x)`.
/// Raw weights `1 / (n He_{n-1}(x_i)^2)` are normalised to sum to one, so the rule
/// integrates against the standard normal density. Nodes are returned in ascending
/// order. `n == 0` yields empty vectors, and `n == 1` and `n == 2` use closed forms.
///
/// NOTE(review): the cosine initial guesses are not validated for large `n`.
/// Newton could converge to a neighbouring root there and produce duplicate nodes.
/// Confirm before relying on high-order rules.
fn gauss_hermite_newton(n: usize) -> (Vec<f64>, Vec<f64>) {
    if n == 0 {
        return (vec![], vec![]);
    }
    if n == 1 {
        return (vec![0.0], vec![1.0]);
    }
    if n == 2 {
        // He_2(x) = x^2 - 1 has roots ±1, each carrying half of the probability mass.
        return (vec![-1.0, 1.0], vec![0.5, 0.5]);
    }
    let mut pts = vec![0.0_f64; n];
    let mut wts = vec![0.0_f64; n];
    // Only one half of the roots is searched for; symmetry supplies the other half.
    let m = n.div_ceil(2);
    for i in 0..m {
        let nf = n as f64;
        let mut x = (2.0_f64 * nf + 1.0).sqrt()
            * (std::f64::consts::PI * (4.0 * i as f64 + 3.0) / (4.0 * nf + 2.0)).cos();
        for _ in 0..200 {
            let (h_n, h_nm1) = hermite_eval_pair(n, x);
            // He_n'(x) = n He_{n-1}(x).
            let dp = nf * h_nm1;
            if dp.abs() < f64::MIN_POSITIVE {
                break;
            }
            let dx = h_n / dp;
            x -= dx;
            if dx.abs() < f64::EPSILON * (x.abs() + 1.0) {
                break;
            }
        }
        let (_, h_nm1) = hermite_eval_pair(n, x);
        // Unnormalised weight; the common scale factor cancels in the final normalisation.
        let w_raw = if h_nm1.abs() > f64::MIN_POSITIVE {
            1.0 / (nf * h_nm1 * h_nm1)
        } else {
            0.0
        };
        // Place the root symmetrically regardless of the sign Newton converged to.
        pts[i] = -x.abs();
        pts[n - 1 - i] = x.abs();
        wts[i] = w_raw;
        wts[n - 1 - i] = w_raw;
    }
    if n % 2 == 1 {
        // Odd n: the middle root of He_n is exactly 0; overwrite the Newton estimate.
        let mid = n / 2;
        pts[mid] = 0.0;
        let (_, h_nm1) = hermite_eval_pair(n, 0.0);
        let nf = n as f64;
        wts[mid] = if h_nm1.abs() > f64::MIN_POSITIVE {
            1.0 / (nf * h_nm1 * h_nm1)
        } else {
            1.0 / nf
        };
    }
    // Normalise so the weights sum to one (probability measure).
    let w_sum: f64 = wts.iter().sum();
    if w_sum > f64::MIN_POSITIVE {
        for w in &mut wts {
            *w /= w_sum;
        }
    }
    // Sort (node, weight) pairs by node so the output is ascending.
    let mut pairs: Vec<(f64, f64)> = pts.iter().zip(wts.iter()).map(|(&p, &w)| (p, w)).collect();
    pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
    let pts: Vec<f64> = pairs.iter().map(|(p, _)| *p).collect();
    let wts: Vec<f64> = pairs.iter().map(|(_, w)| *w).collect();
    (pts, wts)
}
/// Returns `(He_n(x), He_{n-1}(x))` for probabilists' Hermite polynomials.
///
/// Uses the recurrence `He_{k+1} = x He_k - k He_{k-1}`. For `n == 0`, the
/// "previous" value is reported as `0.0`.
fn hermite_eval_pair(n: usize, x: f64) -> (f64, f64) {
    if n == 0 {
        return (1.0, 0.0);
    }
    let (below, current) = (1..n).fold((1.0_f64, x), |(below, current), k| {
        (current, x * current - k as f64 * below)
    });
    (current, below)
}
/// Computes the `n`-point Gauss–Laguerre rule for `∫_0^∞ g(x) e^{-x} dx` by Newton iteration.
///
/// Initial guesses follow Numerical Recipes' `gaulag` with `α = 0`. The first two
/// use asymptotic formulas. Every later guess is extrapolated from the two
/// previously converged roots, which keeps Newton on distinct roots.
/// (The previous guesses, e.g. `3 / (1 + 2.4 / n)`, sent several starts to the same root.)
/// Weights are `w_i = x_i / ((n + 1) L_{n+1}(x_i))²`, renormalised to sum to one
/// (the total mass of `e^{-x}`). Nodes are returned in ascending order.
/// `n == 0` yields empty vectors.
fn gauss_laguerre_newton(n: usize) -> (Vec<f64>, Vec<f64>) {
    if n == 0 {
        return (vec![], vec![]);
    }
    if n == 1 {
        return (vec![1.0], vec![1.0]);
    }
    let nf = n as f64;
    let mut pts: Vec<f64> = Vec::with_capacity(n);
    let mut wts: Vec<f64> = Vec::with_capacity(n);
    for i in 0..n {
        let mut x = match i {
            0 => 3.0 / (1.0 + 2.4 * nf),
            1 => pts[0] + 15.0 / (1.0 + 2.5 * nf),
            _ => {
                let ai = (i - 1) as f64;
                let (older, prev) = (pts[i - 2], pts[i - 1]);
                prev + (1.0 + 2.55 * ai) / (1.9 * ai) * (prev - older)
            }
        };
        for _ in 0..200 {
            let (l_n, l_nm1) = laguerre_eval_pair(n, x);
            // L_n'(x) = n (L_n(x) - L_{n-1}(x)) / x, with the limit L_n'(0) = -n.
            let dp = if x.abs() > f64::EPSILON {
                nf * (l_n - l_nm1) / x
            } else {
                -nf
            };
            if dp.abs() < f64::MIN_POSITIVE {
                break;
            }
            let dx = l_n / dp;
            // All roots are strictly positive; keep the iterate inside the domain.
            x = (x - dx).max(f64::MIN_POSITIVE);
            if dx.abs() < f64::EPSILON * (x + 1.0) {
                break;
            }
        }
        let denom = (nf + 1.0) * laguerre_eval_n(n + 1, x);
        let w = if denom.abs() > f64::MIN_POSITIVE {
            x / (denom * denom)
        } else {
            1.0 / nf
        };
        pts.push(x);
        wts.push(w);
    }
    let w_sum: f64 = wts.iter().sum();
    if w_sum > f64::MIN_POSITIVE {
        for w in &mut wts {
            *w /= w_sum;
        }
    }
    // The guesses already produce ascending roots; sort anyway as a safeguard.
    let mut pairs: Vec<(f64, f64)> = pts.into_iter().zip(wts).collect();
    pairs.sort_by(|a, b| a.0.total_cmp(&b.0));
    pairs.into_iter().unzip()
}
/// Returns `(L_n(x), L_{n-1}(x))` for the Laguerre polynomials.
///
/// Uses the recurrence `(k+1) L_{k+1} = (2k+1-x) L_k - k L_{k-1}`. For `n == 0`,
/// the "previous" value is reported as `0.0`.
fn laguerre_eval_pair(n: usize, x: f64) -> (f64, f64) {
    if n == 0 {
        return (1.0, 0.0);
    }
    let (below, current) = (1..n).fold((1.0_f64, 1.0 - x), |(below, current), k| {
        let kf = k as f64;
        (current, ((2.0 * kf + 1.0 - x) * current - kf * below) / (kf + 1.0))
    });
    (current, below)
}
/// Evaluates only the Laguerre polynomial `L_n(x)`, discarding `L_{n-1}(x)`.
fn laguerre_eval_n(n: usize, x: f64) -> f64 {
    let (value, _previous) = laguerre_eval_pair(n, x);
    value
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Tolerance for closed-form polynomial values and exact quadrature identities.
    const TOL: f64 = 1e-10;
    /// Tolerance for quantities that only need to be approximately right.
    const LOOSE_TOL: f64 = 1e-6;
    /// Tolerance for PCE fits whose quadrature is exact for the test function.
    const PCE_TOL: f64 = 1e-8;

    #[test]
    fn test_pce_config_default() {
        let cfg = PceConfig::default();
        assert_eq!(cfg.n_inputs, 2);
        assert_eq!(cfg.order, 3);
        assert_eq!(cfg.polynomial, PolynomialFamily::Hermite);
        assert_eq!(cfg.n_quadrature, 5);
        assert!(!cfg.use_sparse_grid);
    }

    #[test]
    fn test_hermite_polynomial_recurrence() {
        // He_2(x) = x^2 - 1 and He_3(x) = x^3 - 3x, evaluated at x = 2.
        assert!((PolynomialChaos::hermite_poly(0, 2.0) - 1.0).abs() < TOL);
        assert!((PolynomialChaos::hermite_poly(1, 2.0) - 2.0).abs() < TOL);
        assert!((PolynomialChaos::hermite_poly(2, 2.0) - (4.0 - 1.0)).abs() < TOL);
        assert!((PolynomialChaos::hermite_poly(3, 2.0) - (8.0 - 6.0)).abs() < TOL);
    }

    #[test]
    fn test_legendre_polynomials() {
        let x = 0.5;
        assert!((PolynomialChaos::legendre_poly(0, x) - 1.0).abs() < TOL);
        assert!((PolynomialChaos::legendre_poly(1, x) - x).abs() < TOL);
        let p2_exact = (3.0 * x * x - 1.0) / 2.0;
        assert!(
            (PolynomialChaos::legendre_poly(2, x) - p2_exact).abs() < TOL,
            "P_2(0.5): got {}, expected {}",
            PolynomialChaos::legendre_poly(2, x),
            p2_exact
        );
        assert!((PolynomialChaos::legendre_poly(0, 1.0) - 1.0).abs() < TOL);
        assert!((PolynomialChaos::legendre_poly(1, 1.0) - 1.0).abs() < TOL);
        assert!((PolynomialChaos::legendre_poly(2, 1.0) - 1.0).abs() < TOL);
    }

    #[test]
    fn test_legendre_poly_at_zero() {
        assert!((PolynomialChaos::legendre_poly(0, 0.0) - 1.0).abs() < TOL);
        assert!((PolynomialChaos::legendre_poly(1, 0.0) - 0.0).abs() < TOL);
        assert!((PolynomialChaos::legendre_poly(2, 0.0) - (-0.5)).abs() < TOL);
    }

    #[test]
    fn test_multi_indices_count() {
        let idx = PolynomialChaos::generate_multi_indices(2, 2);
        assert_eq!(idx.len(), 6, "C(4,2)=6, got {}", idx.len());
        let idx = PolynomialChaos::generate_multi_indices(2, 3);
        assert_eq!(idx.len(), 10, "C(5,3)=10, got {}", idx.len());
        let idx = PolynomialChaos::generate_multi_indices(1, 3);
        assert_eq!(idx.len(), 4, "1D order-3: 4 terms, got {}", idx.len());
        let idx = PolynomialChaos::generate_multi_indices(3, 2);
        assert_eq!(idx.len(), 10, "C(5,2)=10, got {}", idx.len());
    }

    #[test]
    fn test_multi_indices_first_is_zero() {
        let idx = PolynomialChaos::generate_multi_indices(3, 4);
        assert_eq!(idx[0], vec![0, 0, 0]);
    }

    #[test]
    fn test_multi_indices_total_degree_constraint() {
        let order = 3;
        let n = 2;
        let idx = PolynomialChaos::generate_multi_indices(n, order);
        for alpha in &idx {
            assert_eq!(alpha.len(), n, "Multi-index {:?} has wrong length", alpha);
            let total: usize = alpha.iter().sum();
            assert!(
                total <= order,
                "Multi-index {:?} exceeds order {}",
                alpha,
                order
            );
        }
        let unique: HashSet<&Vec<usize>> = idx.iter().collect();
        assert_eq!(unique.len(), idx.len(), "Multi-indices must be distinct");
    }

    #[test]
    fn test_gauss_legendre_1pt() {
        let (pts, wts) = PolynomialChaos::gauss_legendre(1);
        assert_eq!(pts.len(), 1);
        assert!((pts[0] - 0.0).abs() < TOL);
        assert!((wts[0] - 2.0).abs() < LOOSE_TOL);
    }

    #[test]
    fn test_gauss_legendre_2pt_nodes() {
        let (pts, wts) = PolynomialChaos::gauss_legendre(2);
        let node = 1.0 / 3.0_f64.sqrt();
        assert_eq!(pts.len(), 2);
        assert!((pts[0] + node).abs() < TOL, "left node: got {}", pts[0]);
        assert!((pts[1] - node).abs() < TOL, "right node: got {}", pts[1]);
        assert!((wts[0] - 1.0).abs() < TOL);
        assert!((wts[1] - 1.0).abs() < TOL);
    }

    #[test]
    fn test_gauss_legendre_integrates_polynomials() {
        let (pts, wts) = PolynomialChaos::gauss_legendre(3);
        let integral: f64 = pts.iter().zip(wts.iter()).map(|(x, w)| x * x * w).sum();
        assert!(
            (integral - 2.0 / 3.0).abs() < TOL,
            "∫x² dx = 2/3: got {integral}"
        );
        // A 3-point rule is exact up to degree 5.
        let quartic: f64 = pts.iter().zip(wts.iter()).map(|(x, w)| x.powi(4) * w).sum();
        assert!((quartic - 0.4).abs() < TOL, "∫x⁴ dx = 2/5: got {quartic}");
    }

    #[test]
    fn test_gauss_hermite_weights_sum_to_one() {
        let n = 5;
        let (pts, wts) = PolynomialChaos::gauss_hermite(n);
        let w_sum: f64 = wts.iter().sum();
        assert!(
            (w_sum - 1.0).abs() < LOOSE_TOL,
            "Hermite weights sum = {w_sum}, expected 1.0"
        );
        assert_eq!(pts.len(), n);
    }

    #[test]
    fn test_gauss_hermite_3pt_nodes_and_weights() {
        // Probabilists' 3-point rule: nodes 0, ±√3 with probability weights 2/3 and 1/6.
        let (pts, wts) = PolynomialChaos::gauss_hermite(3);
        let r3 = 3.0_f64.sqrt();
        let expected_pts = [-r3, 0.0, r3];
        let expected_wts = [1.0 / 6.0, 2.0 / 3.0, 1.0 / 6.0];
        for i in 0..3 {
            assert!((pts[i] - expected_pts[i]).abs() < TOL, "node {i}: got {}", pts[i]);
            assert!((wts[i] - expected_wts[i]).abs() < TOL, "weight {i}: got {}", wts[i]);
        }
    }

    #[test]
    fn test_pce_linear_function_hermite() {
        let config = PceConfig {
            n_inputs: 2,
            order: 2,
            polynomial: PolynomialFamily::Hermite,
            n_quadrature: 6,
            use_sparse_grid: false,
        };
        let pce = PolynomialChaos::new(config);
        let result = pce.fit(|xi: &[f64]| xi[0]).expect("PCE fit should succeed");
        // The 6-point rule is exact for every projection integrand here.
        assert!(
            result.mean.abs() < PCE_TOL,
            "Mean should be 0 for u=ξ₁: got {}",
            result.mean
        );
        assert!(
            (result.variance - 1.0).abs() < PCE_TOL,
            "Variance should be 1 for u=ξ₁: got {}",
            result.variance
        );
        assert!(
            (result.sobol_indices[0] - 1.0).abs() < PCE_TOL,
            "S_1 should be 1 for u=ξ₁: got {}",
            result.sobol_indices[0]
        );
        assert!(
            result.sobol_indices[1].abs() < PCE_TOL,
            "S_2 should be 0 for u=ξ₁: got {}",
            result.sobol_indices[1]
        );
    }

    #[test]
    fn test_pce_constant_function() {
        let config = PceConfig {
            n_inputs: 2,
            order: 2,
            polynomial: PolynomialFamily::Hermite,
            n_quadrature: 5,
            use_sparse_grid: false,
        };
        let pce = PolynomialChaos::new(config);
        let result = pce.fit(|_| 5.0).expect("PCE fit constant");
        assert!(
            (result.mean - 5.0).abs() < PCE_TOL,
            "Mean of constant 5.0 should be 5: got {}",
            result.mean
        );
        assert!(
            result.variance < PCE_TOL,
            "Variance of constant should be 0: got {}",
            result.variance
        );
    }

    #[test]
    fn test_pce_evaluate_at_origin() {
        let config = PceConfig {
            n_inputs: 1,
            order: 3,
            polynomial: PolynomialFamily::Legendre,
            n_quadrature: 5,
            use_sparse_grid: false,
        };
        let pce = PolynomialChaos::new(config);
        let result = pce.fit(|xi: &[f64]| xi[0] * xi[0]).expect("PCE fit");
        let val = pce.evaluate(&result, &[0.0]);
        assert!(
            val.abs() < PCE_TOL,
            "PCE(0) = u(0) = 0 for quadratic: got {val}"
        );
    }

    #[test]
    fn test_pce_evaluate_reproduces_quadratic() {
        let config = PceConfig {
            n_inputs: 1,
            order: 3,
            polynomial: PolynomialFamily::Legendre,
            n_quadrature: 5,
            use_sparse_grid: false,
        };
        let pce = PolynomialChaos::new(config);
        let result = pce.fit(|xi: &[f64]| xi[0] * xi[0]).expect("PCE fit");
        let val = pce.evaluate(&result, &[0.5]);
        assert!(
            (val - 0.25).abs() < PCE_TOL,
            "PCE(0.5) = u(0.5) = 0.25 for quadratic: got {val}"
        );
    }

    #[test]
    fn test_pce_sobol_indices_sum_leq_one() {
        let config = PceConfig {
            n_inputs: 3,
            order: 2,
            polynomial: PolynomialFamily::Hermite,
            n_quadrature: 4,
            use_sparse_grid: false,
        };
        let pce = PolynomialChaos::new(config);
        let result = pce.fit(|xi: &[f64]| xi[0] + xi[1]).expect("PCE fit");
        assert!(
            (result.variance - 2.0).abs() < PCE_TOL,
            "Variance of ξ₁+ξ₂ should be 2: got {}",
            result.variance
        );
        let s_sum: f64 = result.sobol_indices.iter().sum();
        assert!(
            s_sum <= 1.0 + 1e-6,
            "Sum of first-order Sobol indices should be <= 1: {}",
            s_sum
        );
        let expected = [0.5, 0.5, 0.0];
        for i in 0..3 {
            assert!((result.sobol_indices[i] - expected[i]).abs() < PCE_TOL);
            assert!((result.total_sobol[i] - expected[i]).abs() < PCE_TOL);
        }
    }

    #[test]
    fn test_pce_sparse_grid_runs() {
        let config = PceConfig {
            n_inputs: 2,
            order: 2,
            polynomial: PolynomialFamily::Legendre,
            n_quadrature: 4,
            use_sparse_grid: true,
        };
        let pce = PolynomialChaos::new(config);
        let result = pce.fit(|xi: &[f64]| xi[0] + xi[1]);
        assert!(result.is_ok(), "Sparse grid PCE should succeed");
    }

    #[test]
    fn test_pce_invalid_config_n_inputs_zero() {
        let config = PceConfig {
            n_inputs: 0,
            order: 2,
            ..Default::default()
        };
        let pce = PolynomialChaos::new(config);
        let result = pce.fit(|_| 1.0);
        assert!(result.is_err(), "n_inputs=0 should fail");
    }

    #[test]
    fn test_pce_invalid_config_n_quadrature_zero() {
        let config = PceConfig {
            n_quadrature: 0,
            ..Default::default()
        };
        let pce = PolynomialChaos::new(config);
        let result = pce.fit(|_| 1.0);
        assert!(result.is_err(), "n_quadrature=0 should fail");
    }

    #[test]
    fn test_laguerre_polynomial() {
        assert!((PolynomialChaos::laguerre_poly(0, 1.0) - 1.0).abs() < TOL);
        assert!((PolynomialChaos::laguerre_poly(1, 1.0) - 0.0).abs() < TOL);
        let l2_exact = (1.0 - 4.0 + 2.0) / 2.0;
        assert!(
            (PolynomialChaos::laguerre_poly(2, 1.0) - l2_exact).abs() < TOL,
            "L_2(1) = {}",
            PolynomialChaos::laguerre_poly(2, 1.0)
        );
    }
}