use crate::error::{InterpolateError, InterpolateResult};
use scirs2_core::numeric::{Float, FromPrimitive, Zero};
use std::collections::HashMap;
use std::fmt::{Debug, Display};
use std::ops::{AddAssign, MulAssign};
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct MultiIndex {
pub indices: Vec<usize>,
}
impl MultiIndex {
pub fn new(indices: Vec<usize>) -> Self {
Self { indices }
}
pub fn l1_norm(&self) -> usize {
self.indices.iter().sum()
}
pub fn linf_norm(&self) -> usize {
self.indices.iter().max().copied().unwrap_or(0)
}
pub fn dim(&self) -> usize {
self.indices.len()
}
pub fn is_admissible(&self, max_level: usize, dim: usize) -> bool {
self.l1_norm() <= max_level
}
}
#[derive(Debug, Clone, PartialEq)]
pub struct GridPoint<F: Float> {
pub coords: Vec<F>,
pub index: MultiIndex,
pub surplus: F,
pub value: F,
}
#[derive(Debug)]
pub struct SparseGridInterpolator<F>
where
F: Float + FromPrimitive + Debug + Display + Zero + Copy + AddAssign + MulAssign,
{
dimension: usize,
bounds: Vec<(F, F)>,
max_level: usize,
grid_points: HashMap<MultiIndex, GridPoint<F>>,
#[allow(dead_code)]
adaptive: bool,
tolerance: F,
stats: SparseGridStats,
}
#[derive(Debug, Default)]
pub struct SparseGridStats {
pub num_points: usize,
pub num_evaluations: usize,
pub max_level_reached: usize,
pub error_estimate: f64,
}
#[derive(Debug)]
pub struct SparseGridBuilder<F>
where
F: Float + FromPrimitive + Debug + Display + Zero + Copy + AddAssign + MulAssign,
{
bounds: Option<Vec<(F, F)>>,
max_level: usize,
adaptive: bool,
tolerance: F,
initial_points: Option<Vec<Vec<F>>>,
}
impl<F> Default for SparseGridBuilder<F>
where
F: Float + FromPrimitive + Debug + Display + Zero + Copy + AddAssign + MulAssign,
{
fn default() -> Self {
Self {
bounds: None,
max_level: 3,
adaptive: false,
tolerance: F::from_f64(1e-6).expect("Operation failed"),
initial_points: None,
}
}
}
impl<F> SparseGridBuilder<F>
where
F: Float + FromPrimitive + Debug + Display + Zero + Copy + AddAssign + MulAssign,
{
pub fn new() -> Self {
Self::default()
}
pub fn with_bounds(mut self, bounds: Vec<(F, F)>) -> Self {
self.bounds = Some(bounds);
self
}
pub fn with_max_level(mut self, maxlevel: usize) -> Self {
self.max_level = maxlevel;
self
}
pub fn with_adaptive_refinement(mut self, adaptive: bool) -> Self {
self.adaptive = adaptive;
self
}
pub fn with_tolerance(mut self, tolerance: F) -> Self {
self.tolerance = tolerance;
self
}
pub fn with_initial_points(mut self, points: Vec<Vec<F>>) -> Self {
self.initial_points = Some(points);
self
}
pub fn build<Func>(self, func: Func) -> InterpolateResult<SparseGridInterpolator<F>>
where
Func: Fn(&[F]) -> F,
{
let bounds = self.bounds.ok_or_else(|| {
InterpolateError::invalid_input("Bounds must be specified".to_string())
})?;
if bounds.is_empty() {
return Err(InterpolateError::invalid_input(
"At least one dimension required".to_string(),
));
}
let dimension = bounds.len();
let mut interpolator = SparseGridInterpolator {
dimension,
bounds,
max_level: self.max_level,
grid_points: HashMap::new(),
adaptive: self.adaptive,
tolerance: self.tolerance,
stats: SparseGridStats::default(),
};
interpolator.generate_smolyak_grid(&func)?;
if self.adaptive {
interpolator.adaptive_refinement(&func)?;
}
Ok(interpolator)
}
pub fn build_from_data(
self,
points: &[Vec<F>],
values: &[F],
) -> InterpolateResult<SparseGridInterpolator<F>> {
if points.len() != values.len() {
return Err(InterpolateError::invalid_input(
"Number of points must match number of values".to_string(),
));
}
let bounds = self.bounds.ok_or_else(|| {
InterpolateError::invalid_input("Bounds must be specified".to_string())
})?;
let dimension = bounds.len();
if points.is_empty() {
return Err(InterpolateError::invalid_input(
"At least one data point required".to_string(),
));
}
for point in points {
if point.len() != dimension {
return Err(InterpolateError::invalid_input(
"All points must have the same dimensionality".to_string(),
));
}
}
let mut interpolator = SparseGridInterpolator {
dimension,
bounds,
max_level: self.max_level,
grid_points: HashMap::new(),
adaptive: false, tolerance: self.tolerance,
stats: SparseGridStats::default(),
};
interpolator.build_from_scattered_data(points, values)?;
Ok(interpolator)
}
}
impl<F> SparseGridInterpolator<F>
where
F: Float + FromPrimitive + Debug + Display + Zero + Copy + AddAssign + MulAssign,
{
fn generate_smolyak_grid<Func>(&mut self, func: &Func) -> InterpolateResult<()>
where
Func: Fn(&[F]) -> F,
{
let multi_indices = self.generate_admissible_indices();
for multi_idx in multi_indices {
self.add_hierarchical_points(&multi_idx, func)?;
}
self.stats.num_points = self.grid_points.len();
self.stats.max_level_reached = self.max_level;
Ok(())
}
fn generate_admissible_indices(&self) -> Vec<MultiIndex> {
let mut indices = Vec::new();
self.generate_indices_recursive(Vec::new(), 0, self.max_level, &mut indices);
indices
}
fn generate_indices_recursive(
&self,
current: Vec<usize>,
dim: usize,
remaining_sum: usize,
indices: &mut Vec<MultiIndex>,
) {
if dim == self.dimension {
if current.iter().sum::<usize>() <= self.max_level {
indices.push(MultiIndex::new(current));
}
return;
}
for i in 0..=remaining_sum {
let mut next = current.clone();
next.push(i);
self.generate_indices_recursive(next, dim + 1, remaining_sum, indices);
}
}
fn add_hierarchical_points<Func>(
&mut self,
multi_idx: &MultiIndex,
func: &Func,
) -> InterpolateResult<()>
where
Func: Fn(&[F]) -> F,
{
let points = self.generate_tensor_product_points(multi_idx);
for point_coords in points {
let grid_point_idx = self.coords_to_multi_index(&point_coords, multi_idx);
#[allow(clippy::map_entry)]
if !self.grid_points.contains_key(&grid_point_idx) {
let value = func(&point_coords);
self.stats.num_evaluations += 1;
let surplus = self.compute_hierarchical_surplus(&point_coords, value, multi_idx)?;
let grid_point = GridPoint {
coords: point_coords,
index: grid_point_idx.clone(),
surplus,
value,
};
self.grid_points.insert(grid_point_idx, grid_point);
}
}
Ok(())
}
fn generate_tensor_product_points(&self, multiidx: &MultiIndex) -> Vec<Vec<F>> {
let mut points = vec![Vec::new()];
for (dim, &level) in multiidx.indices.iter().enumerate() {
let dim_points = self.generate_1d_points(level, dim);
let mut new_points = Vec::new();
for point in &points {
for &dim_point in &dim_points {
let mut new_point = point.clone();
new_point.push(dim_point);
new_points.push(new_point);
}
}
points = new_points;
}
points
}
fn generate_1d_points(&self, level: usize, dim: usize) -> Vec<F> {
let (min_bound, max_bound) = self.bounds[dim];
let range = max_bound - min_bound;
if level == 0 {
vec![min_bound + range / F::from_f64(2.0).expect("Operation failed")]
} else {
let n_points = (1 << level) + 1;
let mut points = Vec::new();
for i in 0..n_points {
let t = F::from_usize(i).expect("Operation failed")
/ F::from_usize(n_points - 1).expect("Operation failed");
points.push(min_bound + t * range);
}
points
}
}
fn coords_to_multi_index(&self, coords: &[F], baseidx: &MultiIndex) -> MultiIndex {
let mut indices = baseidx.indices.clone();
for (i, &coord) in coords.iter().enumerate() {
let discretized = (coord * F::from_f64(1000.0).expect("Operation failed"))
.round()
.to_usize()
.unwrap_or(0);
indices[i] += discretized % 100; }
MultiIndex::new(indices)
}
fn compute_hierarchical_surplus(
&self,
coords: &[F],
value: F,
idx: &MultiIndex,
) -> InterpolateResult<F> {
Ok(value)
}
fn build_from_scattered_data(
&mut self,
points: &[Vec<F>],
values: &[F],
) -> InterpolateResult<()> {
for (i, (point, &value)) in points.iter().zip(values.iter()).enumerate() {
let multi_idx = MultiIndex::new(vec![i; self.dimension]);
let grid_point = GridPoint {
coords: point.clone(),
index: multi_idx.clone(),
surplus: value, value,
};
self.grid_points.insert(multi_idx, grid_point);
}
self.stats.num_points = self.grid_points.len();
self.stats.num_evaluations = points.len();
Ok(())
}
fn adaptive_refinement<Func>(&mut self, func: &Func) -> InterpolateResult<()>
where
Func: Fn(&[F]) -> F,
{
let max_iterations = 10;
for _iteration in 0..max_iterations {
let refinement_candidates = self.identify_refinement_candidates()?;
if refinement_candidates.is_empty() {
break; }
for candidate in refinement_candidates.iter().take(10) {
self.refine_around_point(candidate, func)?;
}
self.stats.num_points = self.grid_points.len();
if self.estimate_error()? < self.tolerance {
break;
}
}
Ok(())
}
fn identify_refinement_candidates(&self) -> InterpolateResult<Vec<MultiIndex>> {
let mut candidates = Vec::new();
for (idx, point) in &self.grid_points {
if point.surplus.abs() > self.tolerance {
candidates.push(idx.clone());
}
}
candidates.sort_by(|a, b| {
let surplus_a = self.grid_points[a].surplus.abs();
let surplus_b = self.grid_points[b].surplus.abs();
surplus_b
.partial_cmp(&surplus_a)
.unwrap_or(std::cmp::Ordering::Equal)
});
Ok(candidates)
}
fn refine_around_point<Func>(
&mut self,
center_idx: &MultiIndex,
func: &Func,
) -> InterpolateResult<()>
where
Func: Fn(&[F]) -> F,
{
if let Some(center_point) = self.grid_points.get(center_idx) {
let center_coords = center_point.coords.clone();
for dim in 0..self.dimension {
for direction in [-1.0, 1.0] {
let mut new_coords = center_coords.clone();
let step = (self.bounds[dim].1 - self.bounds[dim].0)
/ F::from_f64(32.0).expect("Operation failed");
new_coords[dim] += F::from_f64(direction).expect("Operation failed") * step;
if new_coords[dim] >= self.bounds[dim].0
&& new_coords[dim] <= self.bounds[dim].1
{
let new_idx = self.coords_to_multi_index(&new_coords, center_idx);
#[allow(clippy::map_entry)]
if !self.grid_points.contains_key(&new_idx) {
let value = func(&new_coords);
self.stats.num_evaluations += 1;
let surplus =
self.compute_hierarchical_surplus(&new_coords, value, &new_idx)?;
let grid_point = GridPoint {
coords: new_coords,
index: new_idx.clone(),
surplus,
value,
};
self.grid_points.insert(new_idx, grid_point);
}
}
}
}
}
Ok(())
}
fn estimate_error(&self) -> InterpolateResult<F> {
let max_surplus = self
.grid_points
.values()
.map(|p| p.surplus.abs())
.max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.unwrap_or(F::zero());
Ok(max_surplus)
}
pub fn interpolate(&self, query: &[F]) -> InterpolateResult<F> {
if query.len() != self.dimension {
return Err(InterpolateError::invalid_input(
"Query point dimension mismatch".to_string(),
));
}
for (i, &coord) in query.iter().enumerate() {
if coord < self.bounds[i].0 || coord > self.bounds[i].1 {
return Err(InterpolateError::OutOfBounds(
"Query point outside interpolation domain".to_string(),
));
}
}
let mut result = F::zero();
for point in self.grid_points.values() {
let weight = self.compute_hierarchical_weight(query, &point.coords);
result += weight * point.surplus;
}
Ok(result)
}
fn compute_hierarchical_weight(&self, query: &[F], gridpoint: &[F]) -> F {
let mut weight = F::one();
for i in 0..self.dimension {
let level_spacing =
F::from_f64(2.0_f64.powi(-(self.max_level as i32))).expect("Operation failed");
let h = (self.bounds[i].1 - self.bounds[i].0) * level_spacing;
let dist = (query[i] - gridpoint[i]).abs();
if dist <= h {
weight *= F::one() - dist / h;
} else {
let broad_h = h * F::from_f64(4.0).expect("Operation failed");
if dist <= broad_h {
weight *=
F::from_f64(0.25).expect("Operation failed") * (F::one() - dist / broad_h);
} else {
return F::zero(); }
}
}
weight
}
pub fn interpolate_multi(&self, queries: &[Vec<F>]) -> InterpolateResult<Vec<F>> {
queries.iter().map(|q| self.interpolate(q)).collect()
}
pub fn num_points(&self) -> usize {
self.stats.num_points
}
pub fn num_evaluations(&self) -> usize {
self.stats.num_evaluations
}
pub fn stats(&self) -> &SparseGridStats {
&self.stats
}
pub fn dimension(&self) -> usize {
self.dimension
}
pub fn bounds(&self) -> &[(F, F)] {
&self.bounds
}
}
#[allow(dead_code)]
pub fn make_sparse_grid_interpolator<F, Func>(
bounds: Vec<(F, F)>,
max_level: usize,
func: Func,
) -> InterpolateResult<SparseGridInterpolator<F>>
where
F: Float + FromPrimitive + Debug + Display + Zero + Copy + AddAssign + MulAssign,
Func: Fn(&[F]) -> F,
{
SparseGridBuilder::new()
.with_bounds(bounds)
.with_max_level(max_level)
.build(func)
}
#[allow(dead_code)]
pub fn make_adaptive_sparse_grid_interpolator<F, Func>(
bounds: Vec<(F, F)>,
max_level: usize,
tolerance: F,
func: Func,
) -> InterpolateResult<SparseGridInterpolator<F>>
where
F: Float + FromPrimitive + Debug + Display + Zero + Copy + AddAssign + MulAssign,
Func: Fn(&[F]) -> F,
{
SparseGridBuilder::new()
.with_bounds(bounds)
.with_max_level(max_level)
.with_adaptive_refinement(true)
.with_tolerance(tolerance)
.build(func)
}
#[allow(dead_code)]
pub fn make_sparse_grid_from_data<F>(
bounds: Vec<(F, F)>,
points: &[Vec<F>],
values: &[F],
) -> InterpolateResult<SparseGridInterpolator<F>>
where
F: Float + FromPrimitive + Debug + Display + Zero + Copy + AddAssign + MulAssign,
{
SparseGridBuilder::new()
.with_bounds(bounds)
.build_from_data(points, values)
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
#[test]
fn test_multi_index() {
let idx = MultiIndex::new(vec![1, 2, 3]);
assert_eq!(idx.l1_norm(), 6);
assert_eq!(idx.linf_norm(), 3);
assert_eq!(idx.dim(), 3);
assert!(idx.is_admissible(8, 3)); assert!(!idx.is_admissible(5, 3)); }
#[test]
fn test_sparse_grid_1d() {
let bounds = vec![(0.0, 1.0)];
let interpolator = make_sparse_grid_interpolator(
bounds,
3,
|x: &[f64]| x[0] * x[0], )
.expect("Operation failed");
let result = interpolator.interpolate(&[0.5]).expect("Operation failed");
assert!((0.0..=1.0).contains(&result));
assert!(interpolator.num_points() > 0);
}
#[test]
fn test_sparse_grid_2d() {
let bounds = vec![(0.0, 1.0), (0.0, 1.0)];
let interpolator = make_sparse_grid_interpolator(
bounds,
2,
|x: &[f64]| x[0] + x[1], )
.expect("Operation failed");
let result = interpolator
.interpolate(&[0.5, 0.5])
.expect("Operation failed");
assert_relative_eq!(result, 1.0, epsilon = 0.5);
let num_points = interpolator.num_points();
assert!(num_points > 0);
assert!(num_points < 100); }
#[test]
fn test_adaptive_sparse_grid() {
let bounds = vec![(0.0, 1.0), (0.0, 1.0)];
let interpolator = make_adaptive_sparse_grid_interpolator(
bounds,
3,
1e-3,
|x: &[f64]| (x[0] - 0.5).powi(2) + (x[1] - 0.5).powi(2), )
.expect("Operation failed");
let result = interpolator
.interpolate(&[0.5, 0.5])
.expect("Operation failed");
assert_relative_eq!(result, 0.0, epsilon = 0.1);
let result_corner = interpolator
.interpolate(&[0.0, 0.0])
.expect("Operation failed");
assert_relative_eq!(result_corner, 0.5, epsilon = 8.0);
}
#[test]
fn test_high_dimensional_sparse_grid() {
let bounds = vec![(0.0, 1.0); 5]; let interpolator = make_sparse_grid_interpolator(
bounds,
2,
|x: &[f64]| x.iter().sum::<f64>(), )
.expect("Operation failed");
let query = vec![0.2; 5];
let result = interpolator.interpolate(&query).expect("Operation failed");
assert_relative_eq!(result, 1.0, epsilon = 1.0);
let num_points = interpolator.num_points();
assert!(num_points > 0);
assert!(num_points < 1000); }
#[test]
fn test_sparse_grid_from_data() {
let bounds = vec![(0.0, 1.0), (0.0, 1.0)];
let points = vec![
vec![0.0, 0.0],
vec![1.0, 0.0],
vec![0.0, 1.0],
vec![1.0, 1.0],
vec![0.5, 0.5],
];
let values = vec![0.0, 1.0, 1.0, 2.0, 1.0];
let interpolator =
make_sparse_grid_from_data(bounds, &points, &values).expect("Operation failed");
for (point, &expected) in points.iter().zip(values.iter()) {
let result = interpolator.interpolate(point).expect("Operation failed");
assert_relative_eq!(result, expected, epsilon = 0.1);
}
}
#[test]
fn test_multi_interpolation() {
let bounds = vec![(0.0, 1.0), (0.0, 1.0)];
let interpolator = make_sparse_grid_interpolator(
bounds,
2,
|x: &[f64]| x[0] * x[1], )
.expect("Operation failed");
let queries = vec![
vec![0.25, 0.25],
vec![0.75, 0.25],
vec![0.25, 0.75],
vec![0.75, 0.75],
];
let results = interpolator
.interpolate_multi(&queries)
.expect("Operation failed");
assert_eq!(results.len(), 4);
for result in results {
assert!((0.0..=1.0).contains(&result));
}
}
#[test]
fn test_builder_pattern() {
let bounds = vec![(0.0, 1.0), (0.0, 1.0)];
let interpolator = SparseGridBuilder::new()
.with_bounds(bounds)
.with_max_level(2)
.with_adaptive_refinement(false)
.with_tolerance(1e-4)
.build(|x: &[f64]| x[0] + x[1])
.expect("Operation failed");
assert_eq!(interpolator.dimension(), 2);
assert!(interpolator.num_points() > 0);
}
#[test]
fn test_error_handling() {
let bounds = vec![(0.0, 1.0), (0.0, 1.0)];
let interpolator = make_sparse_grid_interpolator(bounds, 2, |x: &[f64]| x[0] + x[1])
.expect("Operation failed");
let result = interpolator.interpolate(&[0.5]);
assert!(result.is_err());
let result = interpolator.interpolate(&[1.5, 0.5]);
assert!(result.is_err());
}
}