use scirs2_core::ndarray::{Array1, ArrayView1, ArrayView2};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;
use super::natural::NaturalNeighborInterpolator;
use crate::error::{InterpolateError, InterpolateResult};
/// Strategy used to produce a value at a query point when extrapolating.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExtrapolationMethod {
    /// Return the value stored at the single nearest data point.
    NearestNeighbor,
    /// Weighted average of the nearest data points, with weights `1 / d^p`
    /// where `p` is `ExtrapolationParams::idw_power`.
    InverseDistanceWeighting,
    /// First-order expansion from the nearest data point using an estimated
    /// gradient there; falls back to the nearest value if no gradient is available.
    LinearGradient,
    /// Always return `ExtrapolationParams::constant_value`.
    ConstantValue,
}
#[derive(Debug, Clone)]
pub struct ExtrapolationParams<F: Float + FromPrimitive + Debug + ordered_float::FloatCore> {
pub method: ExtrapolationMethod,
pub n_neighbors: usize,
pub idw_power: F,
pub constant_value: F,
}
impl<F: Float + FromPrimitive + Debug + ordered_float::FloatCore> Default
    for ExtrapolationParams<F>
{
    /// Nearest-neighbour extrapolation, 3 neighbours for IDW, an IDW power of 2
    /// and a constant value of zero.
    fn default() -> Self {
        ExtrapolationParams {
            method: ExtrapolationMethod::NearestNeighbor,
            n_neighbors: 3,
            // `1 + 1` is exactly 2 for every float type and, unlike a numeric
            // conversion, cannot fail, so no panic path is needed here.
            idw_power: F::one() + F::one(),
            constant_value: F::zero(),
        }
    }
}
/// Produces values at query points using an [`ExtrapolationMethod`] selected
/// through [`ExtrapolationParams`].
pub trait Extrapolation<F>
where
    F: Float + FromPrimitive + Debug + ordered_float::FloatCore,
{
    /// Extrapolates a single value at `query` according to `params`.
    fn extrapolate(
        &self,
        query: &ArrayView1<F>,
        params: &ExtrapolationParams<F>,
    ) -> InterpolateResult<F>;

    /// Extrapolates one value per row of `queries` according to `params`.
    fn extrapolate_multi(
        &self,
        queries: &ArrayView2<F>,
        params: &ExtrapolationParams<F>,
    ) -> InterpolateResult<Array1<F>>;

    /// Interpolates at `query` when possible and extrapolates otherwise.
    fn interpolate_or_extrapolate(
        &self,
        query: &ArrayView1<F>,
        params: &ExtrapolationParams<F>,
    ) -> InterpolateResult<F>;

    /// Row-wise version of [`Extrapolation::interpolate_or_extrapolate`].
    fn interpolate_or_extrapolate_multi(
        &self,
        queries: &ArrayView2<F>,
        params: &ExtrapolationParams<F>,
    ) -> InterpolateResult<Array1<F>>;
}
impl<
F: Float
+ FromPrimitive
+ Debug
+ scirs2_core::ndarray::ScalarOperand
+ 'static
+ for<'a> std::iter::Sum<&'a F>
+ std::cmp::PartialOrd
+ ordered_float::FloatCore,
> Extrapolation<F> for NaturalNeighborInterpolator<F>
{
fn extrapolate(
&self,
query: &ArrayView1<F>,
params: &ExtrapolationParams<F>,
) -> InterpolateResult<F> {
let dim = query.len();
if dim != self.points.ncols() {
return Err(InterpolateError::DimensionMismatch(format!(
"Query point dimension ({}) does not match data dimension ({})",
dim,
self.points.ncols()
)));
}
match params.method {
ExtrapolationMethod::NearestNeighbor => {
let indices = self.kdtree.query_nearest(query, 1)?;
if indices.is_empty() {
return Err(InterpolateError::InterpolationFailed(
"Nearest neighbor search failed".to_string(),
));
}
Ok(self.values[indices[0]])
}
ExtrapolationMethod::InverseDistanceWeighting => {
let k = params.n_neighbors.min(self.points.nrows());
let indices = self.kdtree.query_nearest(query, k)?;
if indices.is_empty() {
return Err(InterpolateError::InterpolationFailed(
"Nearest neighbor search failed".to_string(),
));
}
let mut weighted_sum = F::zero();
let mut weight_sum = F::zero();
for &idx in &indices {
let point = self.points.row(idx);
let mut dist_sq = F::zero();
for j in 0..dim {
dist_sq =
dist_sq + scirs2_core::numeric::Float::powi(point[j] - query[j], 2);
}
if dist_sq < <F as scirs2_core::numeric::Float>::epsilon() {
return Ok(self.values[idx]);
}
let weight = F::one()
/ scirs2_core::numeric::Float::powf(
dist_sq,
params.idw_power
/ F::from(2.0).expect("Failed to convert constant to float"),
);
weighted_sum = weighted_sum + weight * self.values[idx];
weight_sum = weight_sum + weight;
}
if weight_sum > F::zero() {
Ok(weighted_sum / weight_sum)
} else {
Err(InterpolateError::InterpolationFailed(
"All weights are zero in inverse distance weighting".to_string(),
))
}
}
ExtrapolationMethod::LinearGradient => {
let indices = self.kdtree.query_nearest(query, 1)?;
if indices.is_empty() {
return Err(InterpolateError::InterpolationFailed(
"Nearest neighbor search failed".to_string(),
));
}
let nearest_idx = indices[0];
let nearest_point = self.points.row(nearest_idx);
let nearest_value = self.values[nearest_idx];
let nearest_query = nearest_point.to_owned();
let gradient = match super::gradient::GradientEstimation::gradient(
self,
&nearest_query.view(),
) {
Ok(grad) => grad,
Err(_) => {
return Ok(nearest_value);
}
};
let mut extrapolated_value = nearest_value;
for j in 0..dim {
extrapolated_value =
extrapolated_value + gradient[j] * (query[j] - nearest_point[j]);
}
Ok(extrapolated_value)
}
ExtrapolationMethod::ConstantValue => {
Ok(params.constant_value)
}
}
}
fn extrapolate_multi(
&self,
queries: &ArrayView2<F>,
params: &ExtrapolationParams<F>,
) -> InterpolateResult<Array1<F>> {
let n_queries = queries.nrows();
let dim = queries.ncols();
if dim != self.points.ncols() {
return Err(InterpolateError::DimensionMismatch(format!(
"Query points dimension ({}) does not match data dimension ({})",
dim,
self.points.ncols()
)));
}
let mut results = Array1::zeros(n_queries);
for i in 0..n_queries {
let query = queries.row(i);
results[i] = self.extrapolate(&query, params)?;
}
Ok(results)
}
fn interpolate_or_extrapolate(
&self,
query: &ArrayView1<F>,
params: &ExtrapolationParams<F>,
) -> InterpolateResult<F> {
match self.interpolate(query) {
Ok(value) => Ok(value),
Err(_) => {
self.extrapolate(query, params)
}
}
}
fn interpolate_or_extrapolate_multi(
&self,
queries: &ArrayView2<F>,
params: &ExtrapolationParams<F>,
) -> InterpolateResult<Array1<F>> {
let n_queries = queries.nrows();
let dim = queries.ncols();
if dim != self.points.ncols() {
return Err(InterpolateError::DimensionMismatch(format!(
"Query points dimension ({}) does not match data dimension ({})",
dim,
self.points.ncols()
)));
}
let mut results = Array1::zeros(n_queries);
for i in 0..n_queries {
let query = queries.row(i);
results[i] = self.interpolate_or_extrapolate(&query, params)?;
}
Ok(results)
}
}
#[allow(dead_code)]
pub fn nearest_neighbor_extrapolation<
F: crate::traits::InterpolationFloat + ordered_float::FloatCore,
>() -> ExtrapolationParams<F> {
ExtrapolationParams {
method: ExtrapolationMethod::NearestNeighbor,
..Default::default()
}
}
#[allow(dead_code)]
pub fn inverse_distance_extrapolation<
F: crate::traits::InterpolationFloat + ordered_float::FloatCore,
>(
n_neighbors: usize,
power: F,
) -> ExtrapolationParams<F> {
ExtrapolationParams {
method: ExtrapolationMethod::InverseDistanceWeighting,
n_neighbors,
idw_power: power,
..Default::default()
}
}
#[allow(dead_code)]
pub fn linear_gradient_extrapolation<
F: crate::traits::InterpolationFloat + ordered_float::FloatCore,
>() -> ExtrapolationParams<F> {
ExtrapolationParams {
method: ExtrapolationMethod::LinearGradient,
..Default::default()
}
}
#[allow(dead_code)]
pub fn constant_value_extrapolation<
F: crate::traits::InterpolationFloat + ordered_float::FloatCore,
>(
value: F,
) -> ExtrapolationParams<F> {
ExtrapolationParams {
method: ExtrapolationMethod::ConstantValue,
constant_value: value,
..Default::default()
}
}