use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;
use super::natural::{InterpolationMethod, NaturalNeighborInterpolator};
use crate::error::{InterpolateError, InterpolateResult};
/// Common interface for interpolators that evaluate a scalar value at a single query point.
pub trait Interpolator<F>
where
    F: Float + FromPrimitive + Debug + ordered_float::FloatCore,
{
    /// Evaluates the interpolant at `query`.
    ///
    /// # Errors
    /// Returns an error when the implementation cannot evaluate at `query`.
    fn interpolate(&self, query: &ArrayView1<F>) -> InterpolateResult<F>;
}
impl<F> Interpolator<F> for NaturalNeighborInterpolator<F>
where
    F: Float
        + FromPrimitive
        + Debug
        + scirs2_core::ndarray::ScalarOperand
        + 'static
        + std::cmp::PartialOrd
        + ordered_float::FloatCore,
{
    /// Forwards to the inherent [`NaturalNeighborInterpolator::interpolate`] method.
    fn interpolate(&self, query: &ArrayView1<F>) -> InterpolateResult<F> {
        // Type-qualified path: inherent associated functions take precedence over trait
        // methods here, so this does not recurse into the trait method.
        <NaturalNeighborInterpolator<F>>::interpolate(self, query)
    }
}
/// Estimation of the spatial gradient of an interpolated scalar field.
pub trait GradientEstimation<F>
where
    F: Float + FromPrimitive + Debug + ordered_float::FloatCore,
{
    /// Estimated gradient at a single query point.
    fn gradient(&self, query: &ArrayView1<F>) -> InterpolateResult<Array1<F>>;

    /// Estimated gradients for a batch of query points given as the rows of `queries`.
    ///
    /// Implementations are expected to return one gradient row per query row, in order.
    fn gradient_multi(&self, queries: &ArrayView2<F>) -> InterpolateResult<Array2<F>>;
}
impl<
        F: Float
            + FromPrimitive
            + Debug
            + scirs2_core::ndarray::ScalarOperand
            + 'static
            + for<'a> std::iter::Sum<&'a F>
            + std::cmp::PartialOrd
            + ordered_float::FloatCore,
    > GradientEstimation<F> for NaturalNeighborInterpolator<F>
{
    /// Estimates the gradient of the interpolated field at `query`.
    ///
    /// * `Sibson`: weighted least-squares fit of a local linear model,
    ///   minimising `Σ (λ_i / |x_i − x|²) (f_i − f(x) − gᵀ(x_i − x))²`, where `λ_i` are the
    ///   natural-neighbour weights. This reproduces the exact gradient of linear data whenever
    ///   the neighbour offsets span the space.
    /// * `Laplace`: weighted average of directional differences towards each neighbour.
    ///
    /// Falls back to `finite_difference_gradient` when no natural neighbours are found, or
    /// (Sibson) when the least-squares system is singular (too few / degenerate directions).
    ///
    /// # Errors
    /// `DimensionMismatch` if `query.len()` differs from the data dimension; errors from the
    /// natural-neighbour search and from interpolating at `query` are propagated.
    fn gradient(&self, query: &ArrayView1<F>) -> InterpolateResult<Array1<F>> {
        let dim = query.len();
        if dim != self.points.ncols() {
            return Err(InterpolateError::DimensionMismatch(format!(
                "Query point dimension ({}) does not match data dimension ({})",
                dim,
                self.points.ncols()
            )));
        }
        let neighbor_weights = self.voronoi_diagram().natural_neighbors(query)?;
        if neighbor_weights.is_empty() {
            return finite_difference_gradient(self, query);
        }
        // Both estimators work with value differences relative to the value at the query.
        let center_value = self.interpolate(query)?;
        match self.method() {
            InterpolationMethod::Sibson => {
                // Normal equations `A g = b` of the weighted least-squares fit.
                let mut normal = Array2::<F>::zeros((dim, dim));
                let mut rhs = Array1::<F>::zeros(dim);
                for (idx, weight) in neighbor_weights.iter() {
                    let offset: Array1<F> = self
                        .points
                        .row(*idx)
                        .iter()
                        .zip(query.iter())
                        .map(|(&p, &q)| p - q)
                        .collect();
                    let dist_sq = offset.iter().fold(F::zero(), |acc, &x| acc + x * x);
                    // A neighbour coinciding with the query carries no directional
                    // information and would make λ_i / |x_i − x|² blow up.
                    if dist_sq.sqrt() < <F as Float>::epsilon() {
                        continue;
                    }
                    let fit_weight = *weight / dist_sq;
                    let value_diff = self.values[*idx] - center_value;
                    for r in 0..dim {
                        rhs[r] = rhs[r] + fit_weight * value_diff * offset[r];
                        for c in 0..dim {
                            normal[[r, c]] = normal[[r, c]] + fit_weight * offset[r] * offset[c];
                        }
                    }
                }
                match solve_linear_system(normal, rhs) {
                    Some(gradient) => Ok(gradient),
                    // Neighbour offsets do not determine every gradient component.
                    None => finite_difference_gradient(self, query),
                }
            }
            InterpolationMethod::Laplace => {
                // NOTE: for linear data f = a·x + b this yields M·a, where M is the
                // weight-averaged outer product of the unit offsets, rather than `a` itself.
                let mut gradient = Array1::zeros(dim);
                let mut total_weight = F::zero();
                for (idx, weight) in neighbor_weights.iter() {
                    let neighbor_point = self.points.row(*idx);
                    let neighbor_value = self.values[*idx];
                    let mut distance = F::zero();
                    for d in 0..dim {
                        distance = distance
                            + scirs2_core::numeric::Float::powi(neighbor_point[d] - query[d], 2);
                    }
                    distance = distance.sqrt();
                    if distance < <F as scirs2_core::numeric::Float>::epsilon() {
                        continue;
                    }
                    let value_diff = neighbor_value - center_value;
                    // Directional derivative along the unit vector towards this neighbour.
                    let dir_deriv = value_diff / distance;
                    for d in 0..dim {
                        let coordinate_diff = neighbor_point[d] - query[d];
                        gradient[d] =
                            gradient[d] + *weight * dir_deriv * coordinate_diff / distance;
                    }
                    total_weight = total_weight + *weight;
                }
                if total_weight > F::zero() {
                    gradient = gradient / total_weight;
                }
                Ok(gradient)
            }
        }
    }

    /// Estimates gradients for every row of `queries`; row `i` of the result is the gradient
    /// at row `i` of `queries`.
    ///
    /// # Errors
    /// `DimensionMismatch` if the column count differs from the data dimension; the first
    /// per-row error from [`GradientEstimation::gradient`] is propagated.
    fn gradient_multi(&self, queries: &ArrayView2<F>) -> InterpolateResult<Array2<F>> {
        let n_queries = queries.nrows();
        let dim = queries.ncols();
        if dim != self.points.ncols() {
            return Err(InterpolateError::DimensionMismatch(format!(
                "Query points dimension ({}) does not match data dimension ({})",
                dim,
                self.points.ncols()
            )));
        }
        let mut gradients = Array2::zeros((n_queries, dim));
        for i in 0..n_queries {
            let query = queries.row(i);
            let gradient = self.gradient(&query)?;
            gradients.row_mut(i).assign(&gradient);
        }
        Ok(gradients)
    }
}

/// Solves the square linear system `a · x = b` by Gaussian elimination with partial pivoting.
///
/// Returns `None` when `a` is numerically singular, i.e. some pivot magnitude is at most
/// `max|a_ij| · ε · n` (or `a` is entirely zero).
fn solve_linear_system<F>(mut a: Array2<F>, mut b: Array1<F>) -> Option<Array1<F>>
where
    F: Float + FromPrimitive + Debug + ordered_float::FloatCore,
{
    let n = b.len();
    let scale = a
        .iter()
        .fold(F::zero(), |acc, &v| <F as Float>::max(acc, <F as Float>::abs(v)));
    if scale <= F::zero() {
        return None;
    }
    let tol = scale * <F as Float>::epsilon() * F::from_usize(n).unwrap_or(F::one());

    // Forward elimination to upper-triangular form.
    for col in 0..n {
        let mut pivot_row = col;
        let mut pivot_abs = <F as Float>::abs(a[[col, col]]);
        for row in (col + 1)..n {
            let candidate = <F as Float>::abs(a[[row, col]]);
            if candidate > pivot_abs {
                pivot_abs = candidate;
                pivot_row = row;
            }
        }
        if pivot_abs <= tol {
            return None;
        }
        if pivot_row != col {
            for k in 0..n {
                a.swap([col, k], [pivot_row, k]);
            }
            b.swap(col, pivot_row);
        }
        let pivot = a[[col, col]];
        for row in (col + 1)..n {
            let factor = a[[row, col]] / pivot;
            for k in col..n {
                let upper = a[[col, k]];
                a[[row, k]] = a[[row, k]] - factor * upper;
            }
            let upper_rhs = b[col];
            b[row] = b[row] - factor * upper_rhs;
        }
    }

    // Back substitution.
    let mut x = Array1::<F>::zeros(n);
    for row in (0..n).rev() {
        let mut acc = b[row];
        for k in (row + 1)..n {
            acc = acc - a[[row, k]] * x[k];
        }
        x[row] = acc / a[[row, row]];
    }
    Some(x)
}
/// Numerically estimates the gradient of `interpolator` at `query` by finite differences.
///
/// The step along axis `d` is `h_d = ε^(1/3) · max(1, |query[d]|)`. This is the usual
/// near-optimal step for central differences, scaled with the coordinate so that `x ± h`
/// stays distinguishable from `x`. A fixed absolute step such as `1e-6` falls below one ulp
/// of `f32` once `|x|` exceeds about 8. The denominator is the step actually realised after
/// rounding.
///
/// Per axis:
/// * central difference when both `x + h` and `x − h` can be interpolated;
/// * one-sided difference against the centre value when only one side succeeds and the
///   centre itself could be interpolated;
/// * `0` when no difference can be formed.
///
/// Interpolation failures are treated as missing samples, so this never returns `Err`; the
/// `Result` return type is kept for call-site compatibility.
fn finite_difference_gradient<F, T>(
    interpolator: &T,
    query: &ArrayView1<F>,
) -> InterpolateResult<Array1<F>>
where
    F: Float + FromPrimitive + Debug + ordered_float::FloatCore,
    T: GradientEstimation<F> + Interpolator<F>,
{
    let dim = query.len();
    let mut gradient = Array1::zeros(dim);
    let center_value = interpolator.interpolate(query).ok();
    let base_step = <F as Float>::epsilon().cbrt();
    for d in 0..dim {
        let step = base_step * <F as Float>::max(F::one(), <F as Float>::abs(query[d]));
        let mut forward_query = query.to_owned();
        forward_query[d] = forward_query[d] + step;
        let mut backward_query = query.to_owned();
        backward_query[d] = backward_query[d] - step;
        // Steps as actually represented after rounding `x ± h`.
        let forward_step = forward_query[d] - query[d];
        let backward_step = query[d] - backward_query[d];
        let forward = interpolator.interpolate(&forward_query.view()).ok();
        let backward = interpolator.interpolate(&backward_query.view()).ok();
        gradient[d] = match (forward, backward, center_value) {
            (Some(f), Some(b), _) => (f - b) / (forward_step + backward_step),
            (Some(f), None, Some(c)) => (f - c) / forward_step,
            (None, Some(b), Some(c)) => (c - b) / backward_step,
            _ => F::zero(),
        };
    }
    Ok(gradient)
}
/// Interpolated value together with the estimated gradient at a single query point.
#[derive(Debug, Clone)]
pub struct InterpolateWithGradientResult<
    F: Float + FromPrimitive + Debug + ordered_float::FloatCore,
> {
    /// Interpolated value at the query point.
    pub value: F,
    /// Estimated gradient at the query point (for `NaturalNeighborInterpolator`, one entry
    /// per coordinate of the query).
    pub gradient: Array1<F>,
}
/// Joint evaluation of an interpolant's value and gradient.
pub trait InterpolateWithGradient<F>
where
    F: Float + FromPrimitive + Debug + ordered_float::FloatCore,
{
    /// Value and gradient at a single query point.
    fn interpolate_with_gradient(
        &self,
        query: &ArrayView1<F>,
    ) -> InterpolateResult<InterpolateWithGradientResult<F>>;

    /// Value and gradient for each row of `queries`.
    ///
    /// Implementations are expected to return one result per query row, in order.
    fn interpolate_with_gradient_multi(
        &self,
        queries: &ArrayView2<F>,
    ) -> InterpolateResult<Vec<InterpolateWithGradientResult<F>>>;
}
impl<F> InterpolateWithGradient<F> for NaturalNeighborInterpolator<F>
where
    F: Float
        + FromPrimitive
        + Debug
        + scirs2_core::ndarray::ScalarOperand
        + 'static
        + for<'a> std::iter::Sum<&'a F>
        + std::cmp::PartialOrd
        + ordered_float::FloatCore,
{
    /// Interpolates at `query`, then estimates the gradient there.
    ///
    /// # Errors
    /// The interpolation error if that fails first, otherwise any gradient-estimation error.
    fn interpolate_with_gradient(
        &self,
        query: &ArrayView1<F>,
    ) -> InterpolateResult<InterpolateWithGradientResult<F>> {
        let value = self.interpolate(query)?;
        self.gradient(query)
            .map(|gradient| InterpolateWithGradientResult { value, gradient })
    }

    /// Batch version: all values are interpolated first, then all gradients are estimated,
    /// and the two are paired row by row.
    ///
    /// # Errors
    /// The first error from batch interpolation or from batch gradient estimation.
    fn interpolate_with_gradient_multi(
        &self,
        queries: &ArrayView2<F>,
    ) -> InterpolateResult<Vec<InterpolateWithGradientResult<F>>> {
        let all_values = self.interpolate_multi(queries)?;
        let all_gradients = self.gradient_multi(queries)?;
        let paired: Vec<InterpolateWithGradientResult<F>> = (0..queries.nrows())
            .map(|row| InterpolateWithGradientResult {
                value: all_values[row],
                gradient: all_gradients.row(row).to_owned(),
            })
            .collect();
        Ok(paired)
    }
}