use crate::distance::EuclideanDistance;
use crate::error::{SpatialError, SpatialResult};
use crate::kdtree::KDTree;
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
/// Inverse-distance-weighted (IDW) interpolator over scattered sample points.
///
/// Holds the sample locations and values, the weighting parameters, and a
/// KD-tree built over the same sample locations for nearest-neighbour queries.
#[derive(Debug, Clone)]
pub struct IDWInterpolator {
/// Sample locations, one point per row.
points: Array2<f64>,
/// Sample values; `values[i]` is the value observed at `points.row(i)`.
values: Array1<f64>,
/// Dimensionality of the samples (number of columns of `points`).
dim: usize,
/// Number of samples (number of rows of `points`).
n_points: usize,
/// Distance exponent of the weights (weight = 1 / distance^power); the
/// constructor and `set_power` reject negative values.
power: f64,
/// `Some(k)`: only the `k` nearest samples contribute to an estimate;
/// `None`: every sample contributes.
n_neighbors: Option<usize>,
/// Spatial index over `points`, used when `n_neighbors` is `Some(k)`.
kdtree: KDTree<f64, EuclideanDistance<f64>>,
}
impl IDWInterpolator {
    /// Squared distance below which a query is treated as coinciding with a
    /// sample (i.e. distance below 1e-5); that sample's value is then returned
    /// verbatim instead of dividing by a vanishing distance.
    const COINCIDENT_DIST_SQ: f64 = 1e-10;

    /// Creates a new IDW interpolator.
    ///
    /// # Arguments
    /// * `points` - sample locations, one point per row.
    /// * `values` - sample values; `values[i]` belongs to row `i` of `points`.
    /// * `power` - distance exponent of the weights (`w = 1 / d^power`); must be
    ///   finite and non-negative.
    /// * `n_neighbors` - `Some(k)` restricts every estimate to the `k` nearest
    ///   samples (`1 <= k <= points.nrows()`); `None` uses all samples.
    ///
    /// # Errors
    /// * `SpatialError::DimensionError` if `points.nrows() != values.len()`.
    /// * `SpatialError::ValueError` if `power` or `n_neighbors` is invalid.
    /// * Any error reported while building the KD-tree.
    pub fn new(
        points: &ArrayView2<'_, f64>,
        values: &ArrayView1<f64>,
        power: f64,
        n_neighbors: Option<usize>,
    ) -> SpatialResult<Self> {
        let n_points = points.nrows();
        let dim = points.ncols();
        if n_points != values.len() {
            return Err(SpatialError::DimensionError(format!(
                "Number of points ({}) must match number of values ({})",
                n_points,
                values.len()
            )));
        }
        Self::validate_power(power)?;
        Self::validate_n_neighbors(n_neighbors, n_points)?;

        // Copy the samples once and build the spatial index from that copy.
        let points = points.to_owned();
        let kdtree = KDTree::new(&points)?;
        Ok(Self {
            points,
            values: values.to_owned(),
            dim,
            n_points,
            power,
            n_neighbors,
            kdtree,
        })
    }

    /// Estimates the value at `point`.
    ///
    /// The estimate is `sum(w_i * v_i) / sum(w_i)` with `w_i = 1 / d_i^power`,
    /// taken over all samples or, when `n_neighbors` is `Some(k)`, over the `k`
    /// nearest ones. If a contributing sample lies within a distance of 1e-5 of
    /// the query, that sample's value is returned exactly.
    ///
    /// The query view does not need to be contiguous in memory.
    ///
    /// # Errors
    /// * `SpatialError::DimensionError` if `point.len()` differs from the sample
    ///   dimension.
    /// * `SpatialError::ComputationError` if the accumulated weight sum is not
    ///   positive (e.g. NaN caused by NaN coordinates).
    /// * Any error returned by the KD-tree neighbour query.
    pub fn interpolate(&self, point: &ArrayView1<f64>) -> SpatialResult<f64> {
        if point.len() != self.dim {
            return Err(SpatialError::DimensionError(format!(
                "Query point has dimension {}, expected {}",
                point.len(),
                self.dim
            )));
        }

        let candidates = self.candidate_neighbors(point)?;

        let mut weighted_sum = 0.0;
        let mut weight_sum = 0.0;
        for &(idx, dist_sq) in &candidates {
            // The nearest sample is always a candidate, so this also catches
            // queries that hit a sample location exactly.
            if dist_sq < Self::COINCIDENT_DIST_SQ {
                return Ok(self.values[idx]);
            }
            // d^power == (d^2)^(power / 2); avoids a sqrt per sample.
            let weight = 1.0 / dist_sq.powf(self.power / 2.0);
            weighted_sum += weight * self.values[idx];
            weight_sum += weight;
        }

        if weight_sum > 0.0 {
            Ok(weighted_sum / weight_sum)
        } else {
            Err(SpatialError::ComputationError(
                "Zero weight sum in IDW interpolation".to_string(),
            ))
        }
    }

    /// Estimates the value at every row of `points`.
    ///
    /// # Errors
    /// * `SpatialError::DimensionError` if `points.ncols()` differs from the
    ///   sample dimension.
    /// * The first error returned by [`IDWInterpolator::interpolate`] for any row.
    pub fn interpolate_many(&self, points: &ArrayView2<'_, f64>) -> SpatialResult<Array1<f64>> {
        if points.ncols() != self.dim {
            return Err(SpatialError::DimensionError(format!(
                "Query points have dimension {}, expected {}",
                points.ncols(),
                self.dim
            )));
        }
        let mut results = Array1::zeros(points.nrows());
        for (out, query) in results.iter_mut().zip(points.rows()) {
            *out = self.interpolate(&query)?;
        }
        Ok(results)
    }

    /// Returns the distance exponent used for the weights.
    pub fn power(&self) -> f64 {
        self.power
    }

    /// Returns the neighbour limit (`None` means all samples are used).
    pub fn n_neighbors(&self) -> Option<usize> {
        self.n_neighbors
    }

    /// Replaces the distance exponent.
    ///
    /// # Errors
    /// `SpatialError::ValueError` if `power` is negative, NaN or infinite; the
    /// interpolator is left unchanged in that case.
    pub fn set_power(&mut self, power: f64) -> SpatialResult<()> {
        Self::validate_power(power)?;
        self.power = power;
        Ok(())
    }

    /// Replaces the neighbour limit.
    ///
    /// # Errors
    /// `SpatialError::ValueError` if `n_neighbors` is `Some(0)` or exceeds the
    /// number of samples; the interpolator is left unchanged in that case.
    pub fn set_n_neighbors(&mut self, n_neighbors: Option<usize>) -> SpatialResult<()> {
        Self::validate_n_neighbors(n_neighbors, self.n_points)?;
        self.n_neighbors = n_neighbors;
        Ok(())
    }

    /// Returns `(sample index, squared distance to point)` for every sample that
    /// contributes to the estimate at `point`.
    ///
    /// The KD-tree is used only to *select* the `k` nearest indices. Squared
    /// distances are recomputed here so that both modes weight samples
    /// identically, independent of the distance convention `KDTree::query`
    /// reports.
    fn candidate_neighbors(&self, point: &ArrayView1<f64>) -> SpatialResult<Vec<(usize, f64)>> {
        match self.n_neighbors {
            None => Ok((0..self.n_points)
                .map(|i| (i, Self::squared_distance(&self.points.row(i), point)))
                .collect()),
            Some(k) => {
                // The tree needs a contiguous slice; strided views (e.g. rows of
                // a transposed array) are copied instead of panicking.
                let owned;
                let coords: &[f64] = match point.as_slice() {
                    Some(slice) => slice,
                    None => {
                        owned = point.to_vec();
                        &owned
                    }
                };
                let (indices, _) = self.kdtree.query(coords, k)?;
                Ok(indices
                    .into_iter()
                    .map(|i| (i, Self::squared_distance(&self.points.row(i), point)))
                    .collect())
            }
        }
    }

    /// Rejects negative, NaN and infinite weight exponents.
    fn validate_power(power: f64) -> SpatialResult<()> {
        if !power.is_finite() || power < 0.0 {
            return Err(SpatialError::ValueError(format!(
                "Power parameter must be a finite, non-negative number, got {power}"
            )));
        }
        Ok(())
    }

    /// Rejects a neighbour limit of zero or one larger than the sample count.
    fn validate_n_neighbors(n_neighbors: Option<usize>, n_points: usize) -> SpatialResult<()> {
        match n_neighbors {
            Some(0) => Err(SpatialError::ValueError(
                "Number of neighbors must be positive".to_string(),
            )),
            Some(k) if k > n_points => Err(SpatialError::ValueError(format!(
                "Number of neighbors ({k}) cannot exceed number of points ({n_points})"
            ))),
            _ => Ok(()),
        }
    }

    /// Squared Euclidean distance over the common prefix of the two points.
    fn squared_distance(p1: &ArrayView1<f64>, p2: &ArrayView1<f64>) -> f64 {
        p1.iter()
            .zip(p2.iter())
            .map(|(a, b)| (a - b) * (a - b))
            .sum()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Corners of the unit square, sampled from the field `x + y`.
    fn unit_square() -> (Array2<f64>, Array1<f64>) {
        (
            array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
            array![0.0, 1.0, 1.0, 2.0],
        )
    }

    /// Builds an interpolator that is expected to be valid.
    fn build(
        points: &Array2<f64>,
        values: &Array1<f64>,
        power: f64,
        n_neighbors: Option<usize>,
    ) -> IDWInterpolator {
        IDWInterpolator::new(&points.view(), &values.view(), power, n_neighbors)
            .expect("Operation failed")
    }

    /// Evaluates a 2-D interpolator at `(x, y)`, expecting success.
    fn eval(interp: &IDWInterpolator, x: f64, y: f64) -> f64 {
        interp
            .interpolate(&array![x, y].view())
            .expect("Operation failed")
    }

    #[test]
    fn test_idw_interpolation_basic() {
        let (points, values) = unit_square();
        for &power in &[1.0, 2.0, 3.0] {
            let interp = build(&points, &values, power, None);
            // Sample locations are reproduced exactly.
            for (corner, &expected) in points.rows().into_iter().zip(values.iter()) {
                let got = eval(&interp, corner[0], corner[1]);
                assert_relative_eq!(got, expected, epsilon = 1e-10);
            }
            // The centre is equidistant from every corner, so it gets their mean.
            assert_relative_eq!(eval(&interp, 0.5, 0.5), 1.0, epsilon = 0.1);
        }
    }

    #[test]
    fn test_idw_with_neighbors() {
        let points = array![
            [0.0, 0.0],
            [1.0, 0.0],
            [0.0, 1.0],
            [1.0, 1.0],
            [0.5, 0.5],
            [0.2, 0.8],
            [0.8, 0.2],
            [0.3, 0.3],
            [0.7, 0.7],
        ];
        // Sample the linear field f(x, y) = x + y at every point.
        let values: Array1<f64> = points
            .rows()
            .into_iter()
            .map(|row| row[0] + row[1])
            .collect();

        let interp_all = build(&points, &values, 2.0, None);
        let interp_3 = build(&points, &values, 2.0, Some(3));

        assert_relative_eq!(eval(&interp_all, 0.6, 0.4), 1.0, epsilon = 0.1);
        assert_relative_eq!(eval(&interp_3, 0.6, 0.4), 1.0, epsilon = 0.1);
    }

    #[test]
    fn test_interpolate_many() {
        let (points, values) = unit_square();
        let interp = build(&points, &values, 2.0, None);
        let queries = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]];

        let results = interp
            .interpolate_many(&queries.view())
            .expect("Operation failed");

        assert_eq!(results.len(), 5);
        let expected = [0.0, 1.0, 1.0, 2.0, 1.0];
        let tolerances = [1e-10, 1e-10, 1e-10, 1e-10, 0.1];
        for ((&got, &want), &tol) in results.iter().zip(&expected).zip(&tolerances) {
            assert_relative_eq!(got, want, epsilon = tol);
        }
    }

    #[test]
    fn test_setter_methods() {
        let (points, values) = unit_square();
        let mut interp = build(&points, &values, 2.0, None);
        assert_eq!(interp.power(), 2.0);
        assert_eq!(interp.n_neighbors(), None);

        interp.set_power(3.0).expect("Operation failed");
        assert_eq!(interp.power(), 3.0);
        interp.set_n_neighbors(Some(2)).expect("Operation failed");
        assert_eq!(interp.n_neighbors(), Some(2));

        // Rejected updates: negative power, zero neighbours, too many neighbours.
        assert!(interp.set_power(-1.0).is_err());
        assert!(interp.set_n_neighbors(Some(0)).is_err());
        assert!(interp.set_n_neighbors(Some(10)).is_err());
    }

    #[test]
    fn test_error_handling() {
        let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]];
        let values = array![0.0, 1.0, 1.0];
        let interp = build(&points, &values, 2.0, None);

        // A query whose dimension does not match the samples.
        assert!(interp.interpolate(&array![0.0].view()).is_err());

        // Constructor arguments that must be rejected.
        let invalid: [(f64, Option<usize>); 3] = [(-1.0, None), (2.0, Some(0)), (2.0, Some(10))];
        for &(power, n_neighbors) in &invalid {
            let outcome =
                IDWInterpolator::new(&points.view(), &values.view(), power, n_neighbors);
            assert!(outcome.is_err());
        }
    }
}