use crate::error::InterpolateError;
use scirs2_core::ndarray::Array2;
use std::collections::HashMap;
/// One-dimensional quadrature family used for each factor of the Smolyak tensor products.
///
/// All rules are posed on `[-1, 1]`; at level 1 every family is the single node `0`
/// with weight `2`.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SmolyakRule {
    /// Nested Clenshaw–Curtis: level 1 is the midpoint, level `l >= 2` uses
    /// `n = 2^(l-1) + 1` nodes `-cos(pi * j / (n - 1))`.
    ClenshawCurtis,
    /// Gauss–Legendre with `l` nodes at level `l` (not nested).
    GaussLegendre,
    /// Gauss–Patterson-style rule: tabulated nodes for the lowest levels and a
    /// Gauss–Legendre fallback for higher levels (see `gauss_patterson`).
    GaussPatterson,
}
/// Parameters for building a Smolyak sparse grid.
#[derive(Debug, Clone)]
pub struct SmolyakConfig {
    /// Number of dimensions `d`; `smolyak_grid` rejects `0`.
    pub dim: usize,
    /// Sparse-grid level `q`. The grid combines tensor products of 1-D rules over
    /// multi-indices `i` (every `i_k >= 1`) with `q + 1 <= |i| <= q + d`, so level 0
    /// is the single tensor product of level-1 rules.
    pub level: usize,
    /// Family of 1-D rules used in every tensor-product factor.
    pub rule: SmolyakRule,
}
impl Default for SmolyakConfig {
fn default() -> Self {
Self {
dim: 2,
level: 3,
rule: SmolyakRule::ClenshawCurtis,
}
}
}
/// Nodes and weights of a Smolyak sparse grid on `[-1, 1]^dim`.
#[derive(Debug, Clone)]
pub struct SmolyakGrid {
    /// One row per node; the number of columns equals the configured `dim`.
    pub points: Array2<f64>,
    /// Quadrature weight of each node, aligned with the rows of `points`. The Smolyak
    /// combination coefficients can be negative, so individual weights may be negative.
    pub weights: Vec<f64>,
}
/// Builds a Smolyak sparse grid on `[-1, 1]^dim`.
///
/// With `q = config.level` and `d = config.dim`, the grid is the Smolyak combination
/// `Σ (-1)^(q+d-|i|) · C(d-1, q+d-|i|) · (U^{i_1} ⊗ … ⊗ U^{i_d})` over multi-indices
/// `i` with every `i_k >= 1` and `q + 1 <= |i| <= q + d`, where `U^{l}` is the 1-D
/// rule of level `l` chosen by `config.rule`. Nodes that coincide after `PointKey`
/// rounding are merged and their weights summed. The returned nodes are sorted
/// lexicographically, so the output does not depend on `HashMap` iteration order.
///
/// # Errors
/// - `InterpolateError::InvalidInput` if `config.dim == 0`.
/// - Any error produced while constructing a 1-D rule.
pub fn smolyak_grid(config: &SmolyakConfig) -> Result<SmolyakGrid, InterpolateError> {
    if config.dim == 0 {
        return Err(InterpolateError::InvalidInput {
            message: "Smolyak: dim must be >= 1".into(),
        });
    }
    let d = config.dim;
    let q = config.level;
    let mut point_map: HashMap<PointKey, f64> = HashMap::new();
    for multi_idx in gen_multi_indices(d, q + 1, q + d) {
        let coeff = smolyak_coefficient(d, q, &multi_idx);
        if coeff == 0 {
            continue;
        }
        let rules_1d = multi_idx
            .iter()
            .map(|&level| rule_1d(level, &config.rule))
            .collect::<Result<Vec<_>, _>>()?;
        for (pt, wt) in tensor_product_points_weights(&rules_1d) {
            *point_map.entry(PointKey::from_slice(&pt)).or_default() += coeff as f64 * wt;
        }
    }
    let mut points_list: Vec<(PointKey, f64)> = point_map.into_iter().collect();
    // `PointKey` implements `Ord`, so no fallible float comparison is needed here.
    points_list.sort_by(|a, b| a.0.cmp(&b.0));
    let n = points_list.len();
    if n == 0 {
        // Defensive fallback: for d >= 1 the multi-index with |i| = q + d always has
        // coefficient 1, so the map is never empty in practice.
        let mid = vec![0.0f64; d];
        let points = Array2::from_shape_vec((1, d), mid)
            .map_err(|e| InterpolateError::ComputationError(format!("Smolyak grid shape: {e}")))?;
        return Ok(SmolyakGrid {
            points,
            weights: vec![2.0f64.powi(d as i32)],
        });
    }
    let mut pts_flat = Vec::with_capacity(n * d);
    let mut weights = Vec::with_capacity(n);
    for (key, w) in &points_list {
        pts_flat.extend_from_slice(&key.coords);
        weights.push(*w);
    }
    let points = Array2::from_shape_vec((n, d), pts_flat).map_err(|e| {
        InterpolateError::ComputationError(format!("Smolyak grid Array2 shape: {e}"))
    })?;
    Ok(SmolyakGrid { points, weights })
}
/// Approximates `∫_{[-1,1]^dim} f(x) dx` with a Smolyak sparse-grid rule.
///
/// Builds the grid for `(dim, level, rule)` and returns `Σ w_i · f(p_i)`, accumulated
/// in node order.
///
/// # Errors
/// Propagates any error from [`smolyak_grid`] (e.g. `dim == 0`).
pub fn smolyak_quadrature<F>(
    f: F,
    dim: usize,
    level: usize,
    rule: SmolyakRule,
) -> Result<f64, InterpolateError>
where
    F: Fn(&[f64]) -> f64,
{
    let grid = smolyak_grid(&SmolyakConfig { dim, level, rule })?;
    let total = grid
        .weights
        .iter()
        .enumerate()
        .fold(0.0f64, |acc, (row, &weight)| {
            let node: Vec<f64> = grid.points.row(row).iter().copied().collect();
            acc + weight * f(&node)
        });
    Ok(total)
}
/// Samples `f` on a Smolyak grid and returns a scattered-data interpolant of it.
///
/// `f` is evaluated once at every node of `smolyak_grid(config)`. The returned closure
/// is **not** the Smolyak (Lagrange-combination) interpolant. It is Shepard
/// inverse-distance weighting with weights `1 / |x - p_i|^2`, and if `x` is within
/// `1e-14` of a node, that node's value is returned directly. The result is a weighted
/// average of the samples, so constants are reproduced but higher-degree polynomials
/// are not.
///
/// Each node is zipped against `x`. If `x.len()` differs from `config.dim`, only the
/// first `min(x.len(), dim)` coordinates contribute to the distance.
///
/// # Errors
/// Propagates any error from [`smolyak_grid`] (e.g. `dim == 0`).
pub fn smolyak_interpolant<F>(
    f: F,
    config: &SmolyakConfig,
) -> Result<impl Fn(&[f64]) -> f64, InterpolateError>
where
    F: Fn(&[f64]) -> f64,
{
    let grid = smolyak_grid(config)?;
    let d = config.dim;
    // Row-major copy of the nodes: chunk `i` of length `d` is row `i` of the grid.
    let nodes: Vec<f64> = grid.points.iter().copied().collect();
    let values: Vec<f64> = nodes.chunks(d).map(|node| f(node)).collect();
    Ok(move |x: &[f64]| -> f64 {
        const SNAP_TOL: f64 = 1e-14;
        let snap2 = SNAP_TOL * SNAP_TOL;
        let mut weighted = 0.0f64;
        let mut total = 0.0f64;
        for (node, &value) in nodes.chunks(d).zip(values.iter()) {
            let dist2: f64 = node
                .iter()
                .zip(x.iter())
                .map(|(a, b)| (a - b) * (a - b))
                .sum();
            if dist2 < snap2 {
                return value;
            }
            let w = 1.0 / dist2;
            weighted += w * value;
            total += w;
        }
        if total.abs() < 1e-300 {
            0.0
        } else {
            weighted / total
        }
    })
}
/// Returns `(nodes, weights)` of the 1-D rule of `level` for the chosen family.
///
/// # Errors
/// `InterpolateError::InvalidInput` when `level == 0`; otherwise whatever the selected
/// family's constructor returns.
fn rule_1d(level: usize, rule: &SmolyakRule) -> Result<(Vec<f64>, Vec<f64>), InterpolateError> {
    match (level, rule) {
        (0, _) => Err(InterpolateError::InvalidInput {
            message: "rule_1d: level must be >= 1".into(),
        }),
        (_, SmolyakRule::ClenshawCurtis) => clenshaw_curtis(level),
        (_, SmolyakRule::GaussLegendre) => gauss_legendre(level),
        (_, SmolyakRule::GaussPatterson) => gauss_patterson(level),
    }
}
/// Number of nested Clenshaw–Curtis nodes at `level`: 1 at level 1, else `2^(level-1) + 1`.
///
/// `level` must be at least 1 (callers validate this).
fn cc_npoints(level: usize) -> usize {
    match level {
        1 => 1,
        l => (1usize << (l - 1)) + 1,
    }
}
/// Nested Clenshaw–Curtis rule on `[-1, 1]` for `level >= 1`.
///
/// Level 1 is the midpoint rule `(0, 2)`. Otherwise the `n = cc_npoints(level)` nodes are
/// `-cos(pi * j / (n - 1))` for `j = 0..n`, in ascending order, with weights from
/// `cc_weights`.
fn clenshaw_curtis(level: usize) -> Result<(Vec<f64>, Vec<f64>), InterpolateError> {
    let count = cc_npoints(level);
    match count {
        1 => Ok((vec![0.0], vec![2.0])),
        _ => {
            let intervals = (count - 1) as f64;
            let nodes: Vec<f64> = (0..count)
                .map(|j| -(std::f64::consts::PI * j as f64 / intervals).cos())
                .collect();
            let weights = cc_weights(count)?;
            Ok((nodes, weights))
        }
    }
}
/// Clenshaw–Curtis weights for the `n` nodes `-cos(pi * j / (n - 1))`.
///
/// Uses the cosine-series form
/// `w_j = c_j / (n-1) · (1 - Σ_{k=1}^{⌊(n-1)/2⌋} b_k / (4k² - 1) · cos(2kθ_j))`,
/// with `θ_j = pi j / (n-1)`, `c_j = 1` at the endpoints and `2` elsewhere, and
/// `b_k = 1` when `2k = n - 1`, else `2`. For `n == 1` this returns `[2.0]`.
/// It never returns `Err`; the `Result` keeps the signature uniform with the other rules.
fn cc_weights(n: usize) -> Result<Vec<f64>, InterpolateError> {
    if n == 1 {
        return Ok(vec![2.0]);
    }
    let pi = std::f64::consts::PI;
    let weights: Vec<f64> = (0..n)
        .map(|j| {
            let intervals = (n - 1) as f64;
            let theta = pi * j as f64 / intervals;
            let series = (1..=(n - 1) / 2).fold(0.0f64, |acc, k| {
                let b_k = if 2 * k == n - 1 { 1.0 } else { 2.0 };
                acc + b_k / (1.0 - 4.0 * (k as f64).powi(2)) * (2.0 * k as f64 * theta).cos()
            });
            let base = 1.0 + series;
            if j == 0 || j == n - 1 {
                base / intervals
            } else {
                2.0 * base / intervals
            }
        })
        .collect();
    Ok(weights)
}
/// `level`-point Gauss–Legendre rule on `[-1, 1]`, with nodes sorted ascending.
///
/// Each non-negative root of `P_n` is found by at most 100 Newton steps, starting from
/// `cos(pi (i + 0.75) / (n + 0.5))`. Its mirror image is added by symmetry, and each
/// weight is `2 / ((1 - x²) P_n'(x)²)`.
///
/// # Errors
/// `InterpolateError::InvalidInput` when `level == 0`.
fn gauss_legendre(level: usize) -> Result<(Vec<f64>, Vec<f64>), InterpolateError> {
    let n = level;
    match n {
        0 => {
            return Err(InterpolateError::InvalidInput {
                message: "gauss_legendre: level must be >= 1".into(),
            })
        }
        1 => return Ok((vec![0.0], vec![2.0])),
        _ => {}
    }
    let pi = std::f64::consts::PI;
    let mut pairs: Vec<(f64, f64)> = Vec::with_capacity(n);
    for i in 0..(n + 1) / 2 {
        let mut root = (pi * (i as f64 + 0.75) / (n as f64 + 0.5)).cos();
        for _ in 0..100 {
            let (p, dp) = legendre_pn_dpn(n, root);
            let step = p / dp;
            root -= step;
            if step.abs() < 1e-15 {
                break;
            }
        }
        let (_, dp) = legendre_pn_dpn(n, root);
        let weight = 2.0 / ((1.0 - root * root) * dp * dp);
        pairs.push((-root, weight));
        // For odd n the middle root is (numerically) zero and must not be duplicated.
        let is_center = 2 * i + 1 == n;
        if !is_center {
            pairs.push((root, weight));
        }
    }
    pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
    Ok(pairs.into_iter().unzip())
}
/// Evaluates the Legendre polynomial `P_n(x)` and its derivative `P_n'(x)`.
///
/// Uses the three-term recurrence `k P_k = (2k-1) x P_{k-1} - (k-1) P_{k-2}`. The
/// derivative comes from `P_n' = n (x P_n - P_{n-1}) / (x² - 1)`, which divides by zero
/// at `x = ±1` when `n >= 2`.
fn legendre_pn_dpn(n: usize, x: f64) -> (f64, f64) {
    match n {
        0 => (1.0, 0.0),
        1 => (x, 1.0),
        _ => {
            let (lower, upper) = (2..=n).fold((1.0f64, x), |(prev, curr), k| {
                let next = ((2 * k - 1) as f64 * x * curr - (k - 1) as f64 * prev) / k as f64;
                (curr, next)
            });
            let derivative = n as f64 * (x * upper - lower) / (x * x - 1.0);
            (upper, derivative)
        }
    }
}
fn gauss_patterson(level: usize) -> Result<(Vec<f64>, Vec<f64>), InterpolateError> {
match level {
1 => Ok((vec![0.0], vec![2.0])),
2 => {
let s = 1.0f64 / 3.0f64.sqrt();
Ok((vec![-s, 0.0, s], vec![5.0 / 9.0, 8.0 / 9.0, 5.0 / 9.0]))
}
3 => {
let pts = vec![
-0.96049126870802028,
-0.77459666924148338,
-0.43424374934680255,
0.0,
0.43424374934680255,
0.77459666924148338,
0.96049126870802028,
];
let wts = vec![
0.10465622602646727,
0.26848808986833345,
0.40139741477596222,
0.45091653865847415,
0.40139741477596222,
0.26848808986833345,
0.10465622602646727,
];
Ok((pts, wts))
}
4 => {
gauss_legendre(15)
}
_ => {
let n = 1usize << (level - 1); gauss_legendre(n)
}
}
}
/// All multi-indices of length `d` with entries `>= 1` whose sum lies in `lo..=hi`.
///
/// The indices are emitted in lexicographic order, with the first coordinate varying
/// slowest. If `d == 0`, the result is empty.
fn gen_multi_indices(d: usize, lo: usize, hi: usize) -> Vec<Vec<usize>> {
    let mut out = Vec::new();
    if d > 0 {
        let mut scratch = vec![1usize; d];
        gen_mi_rec(d, lo, hi, 0, &mut scratch, &mut out);
    }
    out
}

/// Depth-first worker for `gen_multi_indices`: fills `current[dim..]` and pushes complete
/// indices into `result`.
///
/// Each coordinate is bounded so that, with every later coordinate at its minimum of 1,
/// the total cannot exceed `hi`. `current[dim]` is reset to 1 before returning.
fn gen_mi_rec(
    d: usize,
    lo: usize,
    hi: usize,
    dim: usize,
    current: &mut Vec<usize>,
    result: &mut Vec<Vec<usize>>,
) {
    if dim < d {
        let used: usize = current[..dim].iter().sum();
        let reserved = d - dim - 1;
        let max_here = hi.saturating_sub(used + reserved);
        for value in 1..=max_here {
            current[dim] = value;
            gen_mi_rec(d, lo, hi, dim + 1, current, result);
        }
        current[dim] = 1;
    } else {
        let total: usize = current.iter().sum();
        if (lo..=hi).contains(&total) {
            result.push(current.clone());
        }
    }
}
/// Smolyak combination coefficient `(-1)^k · C(d-1, k)` with `k = q + d - |idx|`.
///
/// Returns 0 when `|idx|` lies outside `q + 1 ..= q + d`.
fn smolyak_coefficient(d: usize, q: usize, idx: &[usize]) -> i64 {
    let total: usize = idx.iter().sum();
    let top = q + d;
    if !(q + 1..=top).contains(&total) {
        return 0;
    }
    let k = top - total;
    let magnitude = binom(d - 1, k) as i64;
    if k % 2 == 0 {
        magnitude
    } else {
        -magnitude
    }
}
/// Binomial coefficient `C(n, k)`; 0 when `k > n`.
///
/// Uses the multiplicative formula over the smaller of `k` and `n - k`. Each partial
/// product divides exactly. Intermediate products saturate at `usize::MAX` rather than
/// overflow.
fn binom(n: usize, k: usize) -> usize {
    match k {
        _ if k > n => 0,
        _ if k == 0 || k == n => 1,
        _ => {
            let r = k.min(n - k);
            (0..r).fold(1usize, |acc, i| acc.saturating_mul(n - i) / (i + 1))
        }
    }
}
/// Full tensor product of 1-D `(nodes, weights)` rules.
///
/// Each output point takes one node from every rule, in rule order. Its weight is the
/// product of the chosen weights, accumulated left to right from `1.0`. The last rule
/// varies fastest. An empty `rules` slice yields the single empty point with weight `1.0`.
fn tensor_product_points_weights(rules: &[(Vec<f64>, Vec<f64>)]) -> Vec<(Vec<f64>, f64)> {
    let seed: Vec<(Vec<f64>, f64)> = vec![(Vec::new(), 1.0)];
    rules
        .iter()
        .fold(seed, |partial: Vec<(Vec<f64>, f64)>, (nodes, weights)| {
            partial
                .iter()
                .flat_map(|(prefix, prefix_w)| {
                    nodes.iter().zip(weights.iter()).map(move |(&node, &w)| {
                        let mut point = Vec::with_capacity(prefix.len() + 1);
                        point.extend_from_slice(prefix);
                        point.push(node);
                        (point, prefix_w * w)
                    })
                })
                .collect()
        })
}
/// Hash-map key that identifies a grid node up to rounding.
///
/// Coordinates are snapped so that nodes produced by different 1-D rules collapse to
/// one key. Values with `|x| < 1e-14` become exactly `0.0`; all others are rounded to
/// the nearest multiple of `1e-12`, and a resulting signed zero is folded to `+0.0`.
/// `bits` holds the IEEE-754 bit patterns of the snapped `coords` and drives
/// `Eq`/`Hash`. `Ord` is a lexicographic total order on the same snapped values, so
/// `a.cmp(&b) == Equal` exactly when `a == b`.
#[derive(Debug, Clone)]
struct PointKey {
    /// Bit patterns of `coords`; used for equality and hashing.
    bits: Vec<u64>,
    /// Snapped coordinates; used for ordering and emitted as the node position.
    coords: Vec<f64>,
}

impl PartialEq for PointKey {
    fn eq(&self, other: &Self) -> bool {
        self.bits == other.bits
    }
}

impl Eq for PointKey {}

impl std::hash::Hash for PointKey {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.bits.hash(state);
    }
}

impl PointKey {
    /// Snaps a single coordinate to its canonical keyed value.
    fn snap(x: f64) -> f64 {
        let snapped = if x.abs() < 1e-14 {
            0.0
        } else {
            (x * 1e12).round() * 1e-12
        };
        // `round` keeps the sign, so e.g. -4e-13 becomes -0.0. Its bit pattern differs
        // from +0.0, which would split one node across two keys; fold both zeros to +0.0.
        if snapped == 0.0 {
            0.0
        } else {
            snapped
        }
    }

    /// Builds the key for the point whose coordinates are `pts`.
    fn from_slice(pts: &[f64]) -> Self {
        let coords: Vec<f64> = pts.iter().map(|&x| Self::snap(x)).collect();
        let bits = coords.iter().map(|c| c.to_bits()).collect();
        Self { bits, coords }
    }
}

impl PartialOrd for PointKey {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for PointKey {
    /// Lexicographic order on the snapped coordinates.
    ///
    /// `f64::total_cmp` returns `Equal` only for identical bit patterns, which keeps this
    /// order consistent with the bit-based `Eq`. When one key is a prefix of the other,
    /// the shorter key sorts first.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.coords
            .iter()
            .zip(&other.coords)
            .map(|(a, b)| a.total_cmp(b))
            .find(|ord| ord.is_ne())
            .unwrap_or_else(|| self.coords.len().cmp(&other.coords.len()))
    }
}
// Unit tests for the Smolyak grid, quadrature, interpolant and 1-D rule helpers.
// Level convention exercised below: in 1-D, `level = q` evaluates only the 1-D rule of
// level `q + 1` (the single multi-index with `|i| = q + 1`).
#[cfg(test)]
mod tests {
    use super::*;
    // Tolerance for results that should be exact up to rounding.
    const TOL: f64 = 1e-10;
    // Integrates the constant 1 over [-1, 1]^dim; the exact value is 2^dim.
    fn integral_const(dim: usize, level: usize, rule: SmolyakRule) -> f64 {
        smolyak_quadrature(|_x| 1.0, dim, level, rule).expect("quadrature ok")
    }
    #[test]
    fn test_smolyak_1d_cc() {
        // 1-D level 2 evaluates the level-3 Clenshaw–Curtis rule (5 nodes).
        let config = SmolyakConfig {
            dim: 1,
            level: 2,
            rule: SmolyakRule::ClenshawCurtis,
        };
        let grid = smolyak_grid(&config).expect("grid ok");
        assert!(
            grid.points.nrows() >= 1,
            "1D CC level-2: expected >= 1 points, got {}",
            grid.points.nrows()
        );
        let wsum: f64 = grid.weights.iter().sum();
        assert!((wsum - 2.0).abs() < 1e-8, "weight sum = {wsum}");
    }
    #[test]
    fn test_smolyak_2d_cc_points() {
        // Level 2 in 2-D combines multi-indices with |i| in 3..=4.
        let config = SmolyakConfig {
            dim: 2,
            level: 2,
            rule: SmolyakRule::ClenshawCurtis,
        };
        let grid = smolyak_grid(&config).expect("grid ok");
        assert!(
            grid.points.nrows() >= 5,
            "2D CC level-2: expected >= 5 points, got {}",
            grid.points.nrows()
        );
    }
    #[test]
    fn test_smolyak_quadrature_const_1d() {
        let val = integral_const(1, 3, SmolyakRule::ClenshawCurtis);
        assert!((val - 2.0).abs() < TOL, "1D const integral = {val}");
    }
    #[test]
    fn test_smolyak_quadrature_const_2d() {
        let val = integral_const(2, 3, SmolyakRule::ClenshawCurtis);
        assert!((val - 4.0).abs() < 1e-8, "2D const integral = {val}");
    }
    #[test]
    fn test_smolyak_quadrature_const_3d() {
        let val = integral_const(3, 3, SmolyakRule::ClenshawCurtis);
        assert!((val - 8.0).abs() < 1e-6, "3D const integral = {val}");
    }
    #[test]
    fn test_smolyak_quadrature_linear_1d() {
        // Odd integrand on a symmetric rule: the integral is 0.
        let val =
            smolyak_quadrature(|x| x[0], 1, 2, SmolyakRule::ClenshawCurtis).expect("quadrature ok");
        assert!(val.abs() < TOL, "∫ x dx = {val}");
    }
    #[test]
    fn test_smolyak_quadrature_linear_2d() {
        let val = smolyak_quadrature(|x| x[0] + x[1], 2, 3, SmolyakRule::ClenshawCurtis)
            .expect("quadrature ok");
        assert!(val.abs() < 1e-8, "∫∫ (x+y) = {val}");
    }
    #[test]
    fn test_smolyak_quadrature_x_squared() {
        // 1-D level 2 uses 5 Clenshaw–Curtis nodes, which integrate x^2 exactly.
        let val = smolyak_quadrature(|x| x[0] * x[0], 1, 2, SmolyakRule::ClenshawCurtis)
            .expect("quadrature ok");
        assert!((val - 2.0 / 3.0).abs() < TOL, "∫ x^2 dx = {val}");
    }
    #[test]
    fn test_smolyak_quadrature_polynomial_degree4() {
        // 1-D level 3 uses 9 Clenshaw–Curtis nodes, which integrate x^4 exactly.
        let val = smolyak_quadrature(|x| x[0].powi(4), 1, 3, SmolyakRule::ClenshawCurtis)
            .expect("quadrature ok");
        assert!((val - 2.0 / 5.0).abs() < TOL, "∫ x^4 dx = {val}");
    }
    #[test]
    fn test_smolyak_quadrature_gauss_legendre() {
        // 1-D level 3 uses the 4-point Gauss–Legendre rule.
        let val = smolyak_quadrature(|x| x[0] * x[0], 1, 3, SmolyakRule::GaussLegendre)
            .expect("quadrature ok");
        assert!((val - 2.0 / 3.0).abs() < 1e-8, "GL ∫ x^2 = {val}");
    }
    #[test]
    fn test_smolyak_interpolant_constant() {
        // The interpolant is a normalized weighted average of samples, so constants are reproduced.
        let config = SmolyakConfig {
            dim: 2,
            level: 2,
            rule: SmolyakRule::ClenshawCurtis,
        };
        let interp = smolyak_interpolant(|_x| 5.0, &config).expect("interpolant ok");
        let val = interp(&[0.3, -0.2]);
        assert!((val - 5.0).abs() < 0.1, "constant interpolant: {val}");
    }
    #[test]
    fn test_smolyak_grid_weight_sum() {
        // Weights must integrate the constant 1, giving the volume 2^d of [-1, 1]^d.
        for d in 1..=3 {
            for q in 1..=3 {
                let config = SmolyakConfig {
                    dim: d,
                    level: q,
                    rule: SmolyakRule::ClenshawCurtis,
                };
                let grid = smolyak_grid(&config).expect("grid ok");
                let wsum: f64 = grid.weights.iter().sum();
                let expected = 2.0f64.powi(d as i32);
                assert!(
                    (wsum - expected).abs() < 1e-6,
                    "d={d} q={q}: weight sum = {wsum}, expected {expected}"
                );
            }
        }
    }
    #[test]
    fn test_binom() {
        assert_eq!(binom(0, 0), 1);
        assert_eq!(binom(5, 2), 10);
        assert_eq!(binom(10, 3), 120);
        assert_eq!(binom(3, 4), 0);
    }
    #[test]
    fn test_gauss_legendre_2pt() {
        // The 2-point rule has nodes ±1/sqrt(3) and unit weights.
        let (pts, wts) = gauss_legendre(2).expect("GL ok");
        assert_eq!(pts.len(), 2);
        let s = 1.0f64 / 3.0f64.sqrt();
        assert!((pts[0] + s).abs() < 1e-12);
        assert!((pts[1] - s).abs() < 1e-12);
        assert!((wts[0] - 1.0).abs() < 1e-12);
    }
    #[test]
    fn test_cc_level1() {
        let (pts, wts) = clenshaw_curtis(1).expect("CC ok");
        assert_eq!(pts.len(), 1);
        assert_eq!(pts[0], 0.0);
        assert_eq!(wts[0], 2.0);
    }
    #[test]
    fn test_cc_level2() {
        // Level 2 is the 3-node rule on {-1, 0, 1} (Simpson weights).
        let (pts, wts) = clenshaw_curtis(2).expect("CC ok");
        assert_eq!(pts.len(), 3);
        assert!((pts[0] + 1.0).abs() < 1e-12);
        assert!((pts[1]).abs() < 1e-12);
        assert!((pts[2] - 1.0).abs() < 1e-12);
        let wsum: f64 = wts.iter().sum();
        assert!((wsum - 2.0).abs() < 1e-10, "CC level-2 weight sum: {wsum}");
    }
}