use crate::error::{InterpolateError, InterpolateResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_linalg::solve;
/// Euclidean (L2) distance between two coordinate slices.
///
/// Coordinates are paired positionally; if the slices differ in length the
/// extra trailing coordinates of the longer one are ignored.
#[inline]
fn euclidean_dist(a: &[f64], b: &[f64]) -> f64 {
    let squared: f64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| {
            let delta = x - y;
            delta.powi(2)
        })
        .sum();
    squared.sqrt()
}
/// Distances from `query` to every row of `points`, sorted ascending.
///
/// Each entry is `(distance, row_index)`. The sort is stable, so rows at equal
/// distance keep their original order; incomparable (NaN) distances are
/// treated as equal to their neighbours.
fn sorted_distances(query: &[f64], points: &Array2<f64>) -> Vec<(f64, usize)> {
    let ncols = points.ncols();
    let mut ranked: Vec<(f64, usize)> = Vec::with_capacity(points.nrows());
    for idx in 0..points.nrows() {
        let coords: Vec<f64> = (0..ncols).map(|c| points[[idx, c]]).collect();
        ranked.push((euclidean_dist(query, &coords), idx));
    }
    ranked.sort_by(|lhs, rhs| {
        lhs.0
            .partial_cmp(&rhs.0)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    ranked
}
/// Weighting scheme used by [`ShepardInterpolant`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ShepardMode {
    /// Classic inverse-distance weighting over every data point, with
    /// weights `r^-power`.
    Global {
        /// Distance exponent.
        power: f64,
    },
    /// Modified (local) Shepard: every node whose radius contains the query
    /// contributes a locally fitted polynomial, weighted by
    /// `((R - r) / (R r))²`. Falls back to global Shepard (power 2) when no
    /// node's radius contains the query.
    Modified {
        /// Influence radius shared by all nodes; `0.0` selects automatic
        /// per-node radii.
        radius: f64,
        /// Neighbour rank for automatic radii: each node's radius is twice
        /// the distance to its `k_auto`-th nearest other node. Unused when
        /// `radius` is non-zero.
        k_auto: usize,
    },
    /// Franke–Little weighting: nodes with `r < radius` get weight
    /// `((radius - r) / (radius r))^power`. Falls back to global Shepard with
    /// the same `power` when no node is within `radius`.
    FrankeLittle {
        /// Influence radius shared by all nodes.
        radius: f64,
        /// Exponent applied to the Franke–Little weight.
        power: f64,
    },
}
/// Shepard-family scattered-data interpolant.
///
/// Holds an owned copy of the data sites and sample values together with the
/// weighting [`ShepardMode`]. Build one with [`ShepardInterpolant::new`] and
/// query it with [`ShepardInterpolant::evaluate`].
#[derive(Debug, Clone)]
pub struct ShepardInterpolant {
    /// Data sites, one per row (`n` rows × `d` coordinates).
    points: Array2<f64>,
    /// Sample values; `values[i]` belongs to row `i` of `points`.
    values: Array1<f64>,
    /// Weighting scheme and its parameters.
    mode: ShepardMode,
    /// Per-node influence radii; present only for [`ShepardMode::Modified`]
    /// with `radius == 0.0` (automatic radii).
    radii: Option<Array1<f64>>,
}
impl ShepardInterpolant {
pub fn new(
points: &ArrayView2<f64>,
values: &ArrayView1<f64>,
mode: ShepardMode,
) -> InterpolateResult<Self> {
let n = points.nrows();
if values.len() != n {
return Err(InterpolateError::DimensionMismatch(format!(
"points has {n} rows, values has {} entries",
values.len()
)));
}
if n == 0 {
return Err(InterpolateError::InsufficientData(
"Shepard interpolation requires at least one data point".to_string(),
));
}
match mode {
ShepardMode::Global { power } if power <= 0.0 => {
return Err(InterpolateError::InvalidInput {
message: format!("Shepard power must be > 0, got {power}"),
});
}
ShepardMode::FrankeLittle { radius, power } => {
if radius <= 0.0 {
return Err(InterpolateError::InvalidInput {
message: format!("Franke-Little radius must be > 0, got {radius}"),
});
}
if power <= 0.0 {
return Err(InterpolateError::InvalidInput {
message: format!("Franke-Little power must be > 0, got {power}"),
});
}
}
_ => {}
}
let pts_owned = points.to_owned();
let radii = match mode {
ShepardMode::Modified { radius, k_auto } if radius == 0.0 => {
let k = k_auto.max(1).min(n - 1);
let mut r_vec = Array1::<f64>::zeros(n);
for i in 0..n {
let qi: Vec<f64> = (0..pts_owned.ncols()).map(|k2| pts_owned[[i, k2]]).collect();
let dists = sorted_distances(&qi, &pts_owned);
let kth = dists.get(k).map(|(d, _)| *d).unwrap_or(1.0);
r_vec[i] = kth * 2.0; }
Some(r_vec)
}
_ => None,
};
Ok(Self {
points: pts_owned,
values: values.to_owned(),
mode,
radii,
})
}
pub fn evaluate(&self, query: &[f64]) -> InterpolateResult<f64> {
let d = self.points.ncols();
if query.len() != d {
return Err(InterpolateError::DimensionMismatch(format!(
"Query has {} dims, points have {d}",
query.len()
)));
}
match self.mode {
ShepardMode::Global { power } => self.eval_global(query, power),
ShepardMode::Modified { radius, .. } => self.eval_modified(query, radius),
ShepardMode::FrankeLittle { radius, power } => {
self.eval_franke_little(query, radius, power)
}
}
}
pub fn evaluate_batch(&self, queries: &ArrayView2<f64>) -> InterpolateResult<Array1<f64>> {
let nq = queries.nrows();
let mut out = Array1::<f64>::zeros(nq);
for i in 0..nq {
let q: Vec<f64> = (0..queries.ncols()).map(|j| queries[[i, j]]).collect();
out[i] = self.evaluate(&q)?;
}
Ok(out)
}
fn eval_global(&self, query: &[f64], power: f64) -> InterpolateResult<f64> {
let n = self.points.nrows();
let d = self.points.ncols();
let mut wsum = 0.0_f64;
let mut fsum = 0.0_f64;
for i in 0..n {
let row: Vec<f64> = (0..d).map(|k| self.points[[i, k]]).collect();
let r = euclidean_dist(query, &row);
if r <= 0.0 {
return Ok(self.values[i]);
}
let w = r.powf(-power);
wsum += w;
fsum += w * self.values[i];
}
if wsum == 0.0 {
return Err(InterpolateError::NumericalError(
"All weights are zero in global Shepard".to_string(),
));
}
Ok(fsum / wsum)
}
fn eval_modified(&self, query: &[f64], explicit_radius: f64) -> InterpolateResult<f64> {
let n = self.points.nrows();
let d = self.points.ncols();
let mut wsum = 0.0_f64;
let mut fsum = 0.0_f64;
for i in 0..n {
let xi: Vec<f64> = (0..d).map(|k| self.points[[i, k]]).collect();
let r = euclidean_dist(query, &xi);
let ri = if explicit_radius > 0.0 {
explicit_radius
} else {
self.radii.as_ref().map(|rv| rv[i]).unwrap_or(explicit_radius)
};
if r >= ri {
continue; }
let w = if r <= 0.0 {
return Ok(self.values[i]);
} else {
let ratio = (ri - r) / (ri * r);
ratio * ratio
};
let q_val = self.local_polynomial_at(i, query, ri)?;
wsum += w;
fsum += w * q_val;
}
if wsum <= 0.0 {
return self.eval_global(query, 2.0);
}
Ok(fsum / wsum)
}
fn local_polynomial_at(
&self,
center_idx: usize,
query: &[f64],
radius: f64,
) -> InterpolateResult<f64> {
let n = self.points.nrows();
let d = self.points.ncols();
let xi: Vec<f64> = (0..d).map(|k| self.points[[center_idx, k]]).collect();
let mut nbr_pts: Vec<Vec<f64>> = Vec::new();
let mut nbr_vals: Vec<f64> = Vec::new();
let mut nbr_dists: Vec<f64> = Vec::new();
for j in 0..n {
let xj: Vec<f64> = (0..d).map(|k| self.points[[j, k]]).collect();
let r_ij = euclidean_dist(&xi, &xj);
if r_ij < radius {
nbr_pts.push(xj);
nbr_vals.push(self.values[j]);
nbr_dists.push(r_ij.max(1e-14));
}
}
if nbr_pts.is_empty() {
return Ok(self.values[center_idx]);
}
let lin_params = 1 + d;
let quad_params = (d + 1) * (d + 2) / 2;
let use_quad = nbr_pts.len() >= quad_params + 1;
let use_lin = nbr_pts.len() >= lin_params;
if !use_lin {
let mut wsum = 0.0_f64;
let mut fsum = 0.0_f64;
for (j, &fj) in nbr_vals.iter().enumerate() {
let w = 1.0 / nbr_dists[j].powi(2);
wsum += w;
fsum += w * fj;
}
return Ok(if wsum > 0.0 {
fsum / wsum
} else {
self.values[center_idx]
});
}
let num_params = if use_quad { quad_params } else { lin_params };
let k = nbr_pts.len();
let mut b = Array2::<f64>::zeros((k, num_params));
let mut rhs_vec = Array1::<f64>::zeros(k);
let mut weights = Array1::<f64>::zeros(k);
for (j, (xj, &fj)) in nbr_pts.iter().zip(nbr_vals.iter()).enumerate() {
let wj = 1.0 / nbr_dists[j].powi(2);
weights[j] = wj;
rhs_vec[j] = fj;
let mut col = 0usize;
b[[j, col]] = 1.0;
col += 1;
for k2 in 0..d {
b[[j, col]] = xj[k2] - xi[k2];
col += 1;
}
if use_quad {
for k2 in 0..d {
for l in k2..d {
b[[j, col]] = (xj[k2] - xi[k2]) * (xj[l] - xi[l]);
col += 1;
}
}
}
}
let mut btb = Array2::<f64>::zeros((num_params, num_params));
let mut btf = Array1::<f64>::zeros(num_params);
for j in 0..k {
let wj = weights[j];
let fj = rhs_vec[j];
for p in 0..num_params {
btf[p] += wj * b[[j, p]] * fj;
for q in 0..num_params {
btb[[p, q]] += wj * b[[j, p]] * b[[j, q]];
}
}
}
let reg = 1e-12 * (0..num_params).map(|p| btb[[p, p]]).sum::<f64>() / num_params as f64;
for p in 0..num_params {
btb[[p, p]] += reg.max(1e-14);
}
let btb_view = btb.view();
let btf_view = btf.view();
let coeffs = solve(&btb_view, &btf_view, None).map_err(|e| {
InterpolateError::LinalgError(format!("Modified Shepard local fit failed: {e}"))
})?;
let mut val = coeffs[0]; let mut col = 1usize;
for k2 in 0..d {
val += coeffs[col] * (query[k2] - xi[k2]);
col += 1;
}
if use_quad {
for k2 in 0..d {
for l in k2..d {
val += coeffs[col] * (query[k2] - xi[k2]) * (query[l] - xi[l]);
col += 1;
}
}
}
Ok(val)
}
fn eval_franke_little(&self, query: &[f64], radius: f64, power: f64) -> InterpolateResult<f64> {
let n = self.points.nrows();
let d = self.points.ncols();
let mut wsum = 0.0_f64;
let mut fsum = 0.0_f64;
for i in 0..n {
let row: Vec<f64> = (0..d).map(|k| self.points[[i, k]]).collect();
let r = euclidean_dist(query, &row);
if r <= 0.0 {
return Ok(self.values[i]);
}
if r >= radius {
continue;
}
let w = ((radius - r) / (radius * r)).powf(power);
wsum += w;
fsum += w * self.values[i];
}
if wsum <= 0.0 {
return self.eval_global(query, power);
}
Ok(fsum / wsum)
}
}
/// One-shot global (inverse-distance) Shepard interpolation at `query`.
///
/// Builds a temporary [`ShepardInterpolant`] in [`ShepardMode::Global`] and
/// evaluates it once. Construct the interpolant directly when evaluating many
/// queries against the same data.
///
/// # Errors
///
/// Propagates any error from [`ShepardInterpolant::new`] or
/// [`ShepardInterpolant::evaluate`].
pub fn basic_shepard(
    query: &[f64],
    points: &ArrayView2<f64>,
    values: &ArrayView1<f64>,
    power: f64,
) -> InterpolateResult<f64> {
    let mode = ShepardMode::Global { power };
    ShepardInterpolant::new(points, values, mode).and_then(|interp| interp.evaluate(query))
}
/// One-shot modified (local) Shepard interpolation at `query`.
///
/// `radius == 0.0` selects automatic per-node radii derived from `k_auto`.
/// A temporary [`ShepardInterpolant`] is built and evaluated once.
///
/// # Errors
///
/// Propagates any error from [`ShepardInterpolant::new`] or
/// [`ShepardInterpolant::evaluate`].
pub fn modified_shepard(
    query: &[f64],
    points: &ArrayView2<f64>,
    values: &ArrayView1<f64>,
    radius: f64,
    k_auto: usize,
) -> InterpolateResult<f64> {
    let mode = ShepardMode::Modified { radius, k_auto };
    ShepardInterpolant::new(points, values, mode).and_then(|interp| interp.evaluate(query))
}
/// One-shot Franke–Little Shepard interpolation at `query`.
///
/// A temporary [`ShepardInterpolant`] in [`ShepardMode::FrankeLittle`] is
/// built and evaluated once.
///
/// # Errors
///
/// Propagates any error from [`ShepardInterpolant::new`] or
/// [`ShepardInterpolant::evaluate`].
pub fn franke_little_shepard(
    query: &[f64],
    points: &ArrayView2<f64>,
    values: &ArrayView1<f64>,
    radius: f64,
    power: f64,
) -> InterpolateResult<f64> {
    let mode = ShepardMode::FrankeLittle { radius, power };
    ShepardInterpolant::new(points, values, mode).and_then(|interp| interp.evaluate(query))
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::{array, Array2};

    /// Packs 1-D abscissae into an `(n, 1)` point matrix.
    fn pts_1d(xs: &[f64]) -> Array2<f64> {
        Array2::from_shape_vec((xs.len(), 1), xs.to_vec()).expect("test: should succeed")
    }

    /// The four corners of the unit square: (0,0), (1,0), (0,1), (1,1).
    fn unit_square() -> Array2<f64> {
        Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0])
            .expect("test: should succeed")
    }

    /// Builds a power-2 global Shepard interpolant over `pts` / `vals`.
    fn global_power2(pts: &Array2<f64>, vals: &Array1<f64>) -> ShepardInterpolant {
        ShepardInterpolant::new(&pts.view(), &vals.view(), ShepardMode::Global { power: 2.0 })
            .expect("test: should succeed")
    }

    #[test]
    fn test_global_shepard_exact_at_nodes() {
        let pts = pts_1d(&[0.0, 1.0, 2.0, 3.0]);
        let vals = array![0.0_f64, 1.0, 4.0, 9.0];
        let interp = global_power2(&pts, &vals);
        for (node, &expected) in vals.iter().enumerate() {
            let got = interp.evaluate(&[node as f64]).expect("test: should succeed");
            assert_abs_diff_eq!(got, expected, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_global_shepard_symmetry() {
        // Corner values follow x + y, so the equidistant centre averages to 1.
        let interp = global_power2(&unit_square(), &array![0.0_f64, 1.0, 1.0, 2.0]);
        let centre = interp.evaluate(&[0.5, 0.5]).expect("test: should succeed");
        assert_abs_diff_eq!(centre, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_modified_shepard_linear() {
        let xs = [0.0, 0.5, 1.0, 1.5, 2.0];
        let pts = pts_1d(&xs);
        // f(x) = 2x sampled at the nodes.
        let vals: Array1<f64> = xs.iter().map(|&x| 2.0 * x).collect();
        let mode = ShepardMode::Modified {
            radius: 1.2,
            k_auto: 4,
        };
        let interp = ShepardInterpolant::new(&pts.view(), &vals.view(), mode)
            .expect("test: should succeed");
        let got = interp.evaluate(&[0.75]).expect("test: should succeed");
        assert_abs_diff_eq!(got, 1.5, epsilon = 1e-6);
    }

    #[test]
    fn test_franke_little_finite() {
        let pts = pts_1d(&[0.0, 1.0, 2.0, 3.0]);
        let vals = array![0.0_f64, 1.0, 4.0, 9.0];
        let got = franke_little_shepard(&[1.5], &pts.view(), &vals.view(), 2.5, 2.0)
            .expect("test: should succeed");
        assert!(got.is_finite());
        assert!(got > 0.0);
    }

    #[test]
    fn test_basic_shepard_free_fn() {
        let pts = pts_1d(&[0.0, 1.0, 2.0]);
        let vals = array![0.0_f64, 1.0, 2.0];
        let got = basic_shepard(&[0.5], &pts.view(), &vals.view(), 2.0)
            .expect("test: should succeed");
        assert!(got.is_finite());
        assert!(got > 0.0 && got < 1.0);
    }

    #[test]
    fn test_modified_shepard_auto_radius() {
        let xs = [0.0, 1.0, 2.0, 3.0, 4.0];
        let pts = pts_1d(&xs);
        // Identity data: f(x) = x.
        let vals: Array1<f64> = xs.iter().copied().collect();
        let mode = ShepardMode::Modified {
            radius: 0.0,
            k_auto: 3,
        };
        let interp = ShepardInterpolant::new(&pts.view(), &vals.view(), mode)
            .expect("test: should succeed");
        let got = interp.evaluate(&[2.5]).expect("test: should succeed");
        assert!(got.is_finite());
        assert!((got - 2.5).abs() < 0.5);
    }

    #[test]
    fn test_batch_equals_individual() {
        let interp = global_power2(&unit_square(), &array![0.0_f64, 1.0, 1.0, 2.0]);
        let queries = Array2::from_shape_vec((3, 2), vec![0.2, 0.3, 0.7, 0.8, 0.5, 0.5])
            .expect("test: should succeed");
        let batch = interp
            .evaluate_batch(&queries.view())
            .expect("test: should succeed");
        assert_eq!(batch.len(), 3);
        for (i, &from_batch) in batch.iter().enumerate() {
            let single = interp
                .evaluate(&[queries[[i, 0]], queries[[i, 1]]])
                .expect("test: should succeed");
            assert_abs_diff_eq!(from_batch, single, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_invalid_power_rejected() {
        let pts = pts_1d(&[0.0, 1.0]);
        let vals = array![0.0_f64, 1.0];
        let mode = ShepardMode::Global { power: -1.0 };
        assert!(ShepardInterpolant::new(&pts.view(), &vals.view(), mode).is_err());
    }

    #[test]
    fn test_empty_points_rejected() {
        let pts = Array2::<f64>::zeros((0, 2));
        let vals = Array1::<f64>::zeros(0);
        let mode = ShepardMode::Global { power: 2.0 };
        assert!(ShepardInterpolant::new(&pts.view(), &vals.view(), mode).is_err());
    }
}