use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::collections::HashMap;
use std::fmt::Debug;
use super::voronoi_cell::VoronoiDiagram;
use crate::error::{InterpolateError, InterpolateResult};
use crate::spatial::kdtree::KdTree;
/// Weighting scheme used by [`NaturalNeighborInterpolator`] to combine the values of the
/// natural neighbours of a query point.
///
/// `Hash` is derived alongside `Eq` so the method can be used as a key in hash-based
/// collections (e.g. caching interpolators per method).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InterpolationMethod {
    /// Weighted average using the natural-neighbour weights produced by the Voronoi diagram,
    /// normalized by their sum.
    Sibson,
    /// Inverse-Euclidean-distance weighted average over the natural-neighbour set.
    Laplace,
}
/// Scattered-data interpolator based on natural-neighbour (Voronoi) weights.
///
/// Construct it with [`NaturalNeighborInterpolator::new`]. That constructor validates the
/// inputs and precomputes both the Voronoi diagram and a KD-tree over the data sites.
///
/// NOTE(review): `points` and `values` are public and mutable, but the Voronoi diagram and
/// KD-tree are built from them only once, in `new`. Mutating these fields afterwards leaves
/// the precomputed structures out of sync with the data.
#[derive(Debug, Clone)]
pub struct NaturalNeighborInterpolator<F>
where
    F: Float
        + FromPrimitive
        + Debug
        + scirs2_core::ndarray::ScalarOperand
        + 'static
        + std::cmp::PartialOrd
        + ordered_float::FloatCore,
{
    /// Voronoi diagram of the data sites; supplies the natural-neighbour weights.
    voronoi_diagram: VoronoiDiagram<F>,
    /// Data sites, one per row.
    pub points: Array2<F>,
    /// Data values; `values[i]` belongs to row `i` of `points`.
    pub values: Array1<F>,
    /// Weighting scheme applied to the natural neighbours.
    method: InterpolationMethod,
    /// KD-tree over `points`, used for the nearest-neighbour fallback.
    pub kdtree: KdTree<F>,
}
impl<
F: Float
+ FromPrimitive
+ Debug
+ scirs2_core::ndarray::ScalarOperand
+ 'static
+ std::cmp::PartialOrd
+ ordered_float::FloatCore,
> NaturalNeighborInterpolator<F>
{
pub fn new(
points: Array2<F>,
values: Array1<F>,
method: InterpolationMethod,
) -> InterpolateResult<Self> {
let n_points = points.nrows();
let dim = points.ncols();
if n_points != values.len() {
return Err(InterpolateError::DimensionMismatch(format!(
"Number of points ({}) does not match number of values ({})",
n_points,
values.len()
)));
}
if n_points < 3 {
return Err(InterpolateError::InsufficientData(
"At least 3 data points are required for Natural Neighbor interpolation"
.to_string(),
));
}
if dim != 2 && dim != 3 {
return Err(InterpolateError::UnsupportedOperation(format!(
"Natural Neighbor interpolation for {dim}-dimensional data not yet implemented"
)));
}
let voronoi_diagram = VoronoiDiagram::new(points.view(), values.view(), None)?;
let kdtree = KdTree::new(points.view())?;
Ok(NaturalNeighborInterpolator {
voronoi_diagram,
points,
values,
method,
kdtree,
})
}
pub fn interpolate(&self, query: &ArrayView1<F>) -> InterpolateResult<F> {
let dim = query.len();
if dim != self.points.ncols() {
return Err(InterpolateError::DimensionMismatch(format!(
"Query point dimension ({}) does not match data dimension ({})",
dim,
self.points.ncols()
)));
}
for i in 0..self.points.nrows() {
let point = self.points.row(i);
let mut is_same = true;
for j in 0..dim {
if scirs2_core::numeric::Float::abs(point[j] - query[j])
> <F as scirs2_core::numeric::Float>::epsilon()
{
is_same = false;
break;
}
}
if is_same {
return Ok(self.values[i]);
}
}
let neighbor_weights = match self.voronoi_diagram.natural_neighbors(query) {
Ok(weights) => weights,
Err(_) => {
let (idx, _) = self.kdtree.nearest_neighbor(&query.to_vec())?;
let mut weights = HashMap::new();
weights.insert(idx, F::one());
weights
}
};
if neighbor_weights.is_empty() {
let (idx, _) = self.kdtree.nearest_neighbor(&query.to_vec())?;
return Ok(self.values[idx]);
}
match self.method {
InterpolationMethod::Sibson => {
let mut interpolated_value = F::zero();
let mut total_weight = F::zero();
for (idx, weight) in neighbor_weights.iter() {
interpolated_value = interpolated_value + self.values[*idx] * *weight;
total_weight = total_weight + *weight;
}
if total_weight > F::zero() {
interpolated_value = interpolated_value / total_weight;
} else {
return Err(InterpolateError::InterpolationFailed(
"Total weight is zero in Sibson interpolation".to_string(),
));
}
Ok(interpolated_value)
}
InterpolationMethod::Laplace => {
let mut interpolated_value = F::zero();
let mut total_weight = F::zero();
for (idx_, _) in neighbor_weights.iter() {
let site = &self.voronoi_diagram.cells[*idx_].site;
let mut distance = F::zero();
for j in 0..dim {
distance =
distance + scirs2_core::numeric::Float::powi(site[j] - query[j], 2);
}
distance = distance.sqrt();
if distance < <F as scirs2_core::numeric::Float>::epsilon() {
return Ok(self.values[*idx_]);
}
let weight = F::one() / distance;
interpolated_value = interpolated_value + self.values[*idx_] * weight;
total_weight = total_weight + weight;
}
if total_weight > F::zero() {
interpolated_value = interpolated_value / total_weight;
} else {
return Err(InterpolateError::InterpolationFailed(
"Total weight is zero in Laplace interpolation".to_string(),
));
}
Ok(interpolated_value)
}
}
}
pub fn interpolate_multi(&self, queries: &ArrayView2<F>) -> InterpolateResult<Array1<F>> {
let n_queries = queries.nrows();
let dim = queries.ncols();
if dim != self.points.ncols() {
return Err(InterpolateError::DimensionMismatch(format!(
"Query point dimension ({}) does not match data dimension ({})",
dim,
self.points.ncols()
)));
}
let mut results = Array1::zeros(n_queries);
for i in 0..n_queries {
let query = queries.row(i);
results[i] = self.interpolate(&query)?;
}
Ok(results)
}
pub fn voronoi_diagram(&self) -> &VoronoiDiagram<F> {
&self.voronoi_diagram
}
pub fn set_method(&mut self, method: InterpolationMethod) {
self.method = method;
}
pub fn method(&self) -> InterpolationMethod {
self.method
}
}
/// Builds a [`NaturalNeighborInterpolator`] that uses the given weighting `method`.
///
/// This is a thin convenience wrapper around [`NaturalNeighborInterpolator::new`]; see that
/// constructor for the validation rules and error conditions.
// Nothing in this file calls this helper, so `dead_code` is silenced; presumably it is kept
// as public convenience API.
#[allow(dead_code)]
pub fn make_natural_neighbor_interpolator<F>(
    points: Array2<F>,
    values: Array1<F>,
    method: InterpolationMethod,
) -> InterpolateResult<NaturalNeighborInterpolator<F>>
where
    F: Float
        + FromPrimitive
        + Debug
        + scirs2_core::ndarray::ScalarOperand
        + 'static
        + std::cmp::PartialOrd
        + ordered_float::FloatCore,
{
    NaturalNeighborInterpolator::<F>::new(points, values, method)
}
/// Builds a [`NaturalNeighborInterpolator`] preconfigured with
/// [`InterpolationMethod::Sibson`] weighting.
///
/// Errors are those of [`NaturalNeighborInterpolator::new`].
// Nothing in this file calls this helper, so `dead_code` is silenced; presumably it is kept
// as public convenience API.
#[allow(dead_code)]
pub fn make_sibson_interpolator<F>(
    points: Array2<F>,
    values: Array1<F>,
) -> InterpolateResult<NaturalNeighborInterpolator<F>>
where
    F: Float
        + FromPrimitive
        + Debug
        + scirs2_core::ndarray::ScalarOperand
        + 'static
        + std::cmp::PartialOrd
        + ordered_float::FloatCore,
{
    let method = InterpolationMethod::Sibson;
    NaturalNeighborInterpolator::<F>::new(points, values, method)
}
/// Builds a [`NaturalNeighborInterpolator`] preconfigured with
/// [`InterpolationMethod::Laplace`] weighting.
///
/// Errors are those of [`NaturalNeighborInterpolator::new`].
// Nothing in this file calls this helper, so `dead_code` is silenced; presumably it is kept
// as public convenience API.
#[allow(dead_code)]
pub fn make_laplace_interpolator<F>(
    points: Array2<F>,
    values: Array1<F>,
) -> InterpolateResult<NaturalNeighborInterpolator<F>>
where
    F: Float
        + FromPrimitive
        + Debug
        + scirs2_core::ndarray::ScalarOperand
        + 'static
        + std::cmp::PartialOrd
        + ordered_float::FloatCore,
{
    let method = InterpolationMethod::Laplace;
    NaturalNeighborInterpolator::<F>::new(points, values, method)
}