use crate::error::{InterpolateError, InterpolateResult};
use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
/// Radial weight kernels that can be plugged into [`MovingLeastSquares`].
///
/// Every kernel turns a distance `d` into a weight; the two compactly
/// parameterised kernels additionally scale the distance by a bandwidth `h`.
#[derive(Debug, Clone, PartialEq)]
pub enum WeightFunction {
    /// `exp(-(d / h)^2)`.
    Gaussian,
    /// `(1 - t)^4 * (4t + 1)` with `t = d / h`; zero once `t >= 1`.
    Wendland,
    /// `d^(-p)` where `p` is the carried exponent. The bandwidth is ignored.
    InverseDistance(f64),
}

impl WeightFunction {
    /// Returns the weight assigned to a sample at distance `d` for bandwidth `h`.
    ///
    /// The inverse-distance kernel yields `f64::INFINITY` when `d` is below
    /// `f64::EPSILON`, which callers use to detect a coincident sample.
    #[inline]
    pub fn eval(&self, d: f64, h: f64) -> f64 {
        match *self {
            WeightFunction::Gaussian => {
                let scaled = d / h;
                (-scaled * scaled).exp()
            }
            WeightFunction::Wendland => {
                let scaled = d / h;
                if scaled >= 1.0 {
                    // Outside the compact support.
                    return 0.0;
                }
                let rest = 1.0 - scaled;
                rest * rest * rest * rest * (4.0 * scaled + 1.0)
            }
            WeightFunction::InverseDistance(power) => {
                if d < f64::EPSILON {
                    return f64::INFINITY;
                }
                d.powf(-power)
            }
        }
    }
}
/// Moving-least-squares interpolator over scattered samples.
///
/// Instances are built through `new` / `with_weight`, which validate the
/// inputs before storing them.
#[derive(Debug, Clone)]
pub struct MovingLeastSquares {
// `x`: sample locations, one row per sample and one column per coordinate.
// `y`: sample values, one row per sample (same row count as `x`) and one
//      column per output component.
// `degree`: degree of the local polynomial; the constructor accepts 0, 1 or 2.
x: Array2<f64>, y: Array2<f64>, degree: usize,
// Kernel used to weight samples by their distance to the query point.
weight_fn: WeightFunction,
// Bandwidth handed to `weight_fn.eval` as `h`; validated positive and finite.
bandwidth: f64,
}
impl MovingLeastSquares {
/// Builds an interpolator that uses the Gaussian weight kernel.
///
/// This is a convenience wrapper around [`MovingLeastSquares::with_weight`]
/// and returns exactly the errors that function returns.
pub fn new(
    x: Array2<f64>,
    y: Array2<f64>,
    degree: usize,
    bandwidth: f64,
) -> InterpolateResult<Self> {
    let default_kernel = WeightFunction::Gaussian;
    Self::with_weight(x, y, degree, default_kernel, bandwidth)
}
/// Builds an interpolator with an explicit weight kernel.
///
/// # Errors
///
/// Returns `InterpolateError::InvalidInput` when, checked in this order:
/// * `x` and `y` have different row counts,
/// * no samples were supplied,
/// * `degree` is greater than 2,
/// * `bandwidth` is not strictly positive and finite,
/// * there are fewer samples than polynomial basis terms.
pub fn with_weight(
    x: Array2<f64>,
    y: Array2<f64>,
    degree: usize,
    weight_fn: WeightFunction,
    bandwidth: f64,
) -> InterpolateResult<Self> {
    let n_samples = x.nrows();

    // The first failing check (if any) decides the error message.
    let complaint: Option<String> = if n_samples != y.nrows() {
        Some(format!(
            "MovingLeastSquares: x.nrows()={} != y.nrows()={}",
            x.nrows(),
            y.nrows()
        ))
    } else if n_samples == 0 {
        Some(String::from("MovingLeastSquares: no data points provided"))
    } else if degree > 2 {
        Some(format!(
            "MovingLeastSquares: degree must be 0, 1, or 2; got {}",
            degree
        ))
    } else if !(bandwidth > 0.0 && bandwidth.is_finite()) {
        Some(format!(
            "MovingLeastSquares: bandwidth must be positive and finite; got {}",
            bandwidth
        ))
    } else {
        // Only reached once `degree` is known to be small.
        let dim = x.ncols();
        let n_basis = basis_size(dim, degree);
        if n_samples < n_basis {
            Some(format!(
                "MovingLeastSquares: need at least {} data points for degree={} in dim={}; got {}",
                n_basis, degree, dim, x.nrows()
            ))
        } else {
            None
        }
    };

    match complaint {
        Some(message) => Err(InterpolateError::InvalidInput { message }),
        None => Ok(Self {
            x,
            y,
            degree,
            weight_fn,
            bandwidth,
        }),
    }
}
/// Evaluates the moving-least-squares fit at the query point `xi`.
///
/// `xi` must have one coordinate per column of the stored sample matrix.
/// If the query lies within a small relative tolerance of a sample, or the
/// weight kernel reports an infinite weight for a sample, that sample's
/// stored value row is returned unchanged. Otherwise a weighted
/// least-squares polynomial in coordinates centred on `xi` is fitted to all
/// samples, and its value at `xi` is returned for every output column.
///
/// Sample rows are read through a contiguous slice when the storage allows
/// it and through an owned copy otherwise, so sample matrices that are not
/// in standard (row-major) layout are supported.
///
/// # Errors
///
/// * `InterpolateError::InvalidInput` if `xi.len()` differs from the sample
///   dimension.
/// * Any error produced by the linear solver for the local system.
pub fn eval(&self, xi: &[f64]) -> InterpolateResult<Vec<f64>> {
    let dim = self.x.ncols();
    if xi.len() != dim {
        return Err(InterpolateError::InvalidInput {
            message: format!(
                "MovingLeastSquares::eval: xi.len()={} != dim={}",
                xi.len(),
                dim
            ),
        });
    }
    let n_out = self.y.ncols();
    let n_pts = self.x.nrows();

    // Coincidence tolerance scaled by the magnitude of the query coordinates.
    let tol = f64::EPSILON * (1.0 + xi.iter().map(|&v| v.abs()).fold(0.0_f64, f64::max));
    let mut weights = Vec::with_capacity(n_pts);
    for i in 0..n_pts {
        let row = self.x.row(i);
        let d = euclidean_distance(row, xi);
        if d < tol {
            // The query coincides with a sample: reproduce it exactly.
            return Ok(self.y.row(i).to_vec());
        }
        weights.push(self.weight_fn.eval(d, self.bandwidth));
    }
    // Singular kernels signal "on top of a sample" with an infinite weight.
    if let Some(idx) = weights.iter().position(|&w| w.is_infinite()) {
        return Ok(self.y.row(idx).to_vec());
    }

    // Assemble sqrt(W) * P and sqrt(W) * Y for the weighted normal equations.
    let n_basis = basis_size(dim, self.degree);
    let mut p = Array2::<f64>::zeros((n_pts, n_basis));
    let mut wy = Array2::<f64>::zeros((n_pts, n_out));
    for i in 0..n_pts {
        let row = self.x.row(i);
        // Borrow the row when it is contiguous; copy it when it is strided.
        let owned_row;
        let coords: &[f64] = match row.as_slice() {
            Some(slice) => slice,
            None => {
                owned_row = row.to_vec();
                &owned_row
            }
        };
        let basis = polynomial_basis(coords, xi, self.degree);
        let sqrt_w = weights[i].sqrt();
        for (j, &b) in basis.iter().enumerate() {
            p[[i, j]] = sqrt_w * b;
        }
        for k in 0..n_out {
            wy[[i, k]] = sqrt_w * self.y[[i, k]];
        }
    }

    let ptwy = p.t().dot(&wy);
    let ptp = p.t().dot(&p);
    // Small diagonal shift, relative to the largest diagonal entry, to keep
    // the normal matrix solvable.
    let reg = 1e-12 * ptp.diag().iter().cloned().fold(0.0_f64, f64::max);
    let mut ptp_reg = ptp;
    for j in 0..n_basis {
        ptp_reg[[j, j]] += reg;
    }
    let c = solve_small_system(&ptp_reg, &ptwy)?;

    // Evaluate the fitted polynomial at the query point itself.
    let basis_xi = polynomial_basis_at(xi, self.degree);
    let mut result = vec![0.0_f64; n_out];
    for k in 0..n_out {
        for (j, &bj) in basis_xi.iter().enumerate() {
            result[k] += bj * c[[j, k]];
        }
    }
    Ok(result)
}
/// Evaluates the interpolator at every row of `xi`.
///
/// Returns a matrix with one row per query and one column per output
/// component. The first error reported by [`MovingLeastSquares::eval`]
/// aborts the batch and is returned as-is.
pub fn eval_batch(&self, xi: &Array2<f64>) -> InterpolateResult<Array2<f64>> {
    let query_count = xi.nrows();
    let mut out = Array2::<f64>::zeros((query_count, self.y.ncols()));
    for q in 0..query_count {
        let values = self.eval(&xi.row(q).to_vec())?;
        for (k, &value) in values.iter().enumerate() {
            out[[q, k]] = value;
        }
    }
    Ok(out)
}
/// Maps `query` points through the deformation that carries the control
/// points `src` onto `dst`.
///
/// For each query point `v`, control point `i` gets the weight
/// `1 / |src_i - v|^4`. With `p*` / `q*` the weighted centroids of `src` /
/// `dst`, the result is `q* + N * M^-1 * (v - p*)`, where `M` and `N` are
/// the weighted second-moment matrices of the centred control points. If
/// `M` cannot be inverted, the identity matrix is used in its place. A
/// query that coincides with a control point is mapped to the matching
/// `dst` row directly.
///
/// # Errors
///
/// Returns `InterpolateError::InvalidInput` if `src` and `dst` have
/// different row counts, if there are no control points, or if the column
/// counts of `src`, `dst` and `query` disagree.
pub fn deform(
    src: &Array2<f64>,
    dst: &Array2<f64>,
    query: &Array2<f64>,
) -> InterpolateResult<Array2<f64>> {
    let n_ctrl = src.nrows();
    if n_ctrl != dst.nrows() {
        return Err(InterpolateError::InvalidInput {
            message: format!(
                "MovingLeastSquares::deform: src.nrows()={} != dst.nrows()={}",
                n_ctrl,
                dst.nrows()
            ),
        });
    }
    if n_ctrl == 0 {
        return Err(InterpolateError::InvalidInput {
            message: "MovingLeastSquares::deform: no control points provided".into(),
        });
    }
    let dim = src.ncols();
    let dims_agree = dst.ncols() == dim && query.ncols() == dim;
    if !dims_agree {
        return Err(InterpolateError::InvalidInput {
            message: "MovingLeastSquares::deform: dimension mismatch".into(),
        });
    }

    let query_count = query.nrows();
    let mut deformed = Array2::<f64>::zeros((query_count, dim));
    for q in 0..query_count {
        let target: Vec<f64> = query.row(q).to_vec();

        // Inverse fourth-power distance weights; stop early on a coincident
        // control point.
        let mut weights: Vec<f64> = Vec::with_capacity(n_ctrl);
        let mut weight_total = 0.0_f64;
        let mut coincident: Option<usize> = None;
        for i in 0..n_ctrl {
            let dist_sq: f64 = src
                .row(i)
                .iter()
                .zip(target.iter())
                .map(|(&a, &b)| (a - b) * (a - b))
                .sum();
            if dist_sq < f64::EPSILON * f64::EPSILON {
                coincident = Some(i);
                break;
            }
            let w = 1.0 / (dist_sq * dist_sq);
            weights.push(w);
            weight_total += w;
        }
        if let Some(hit) = coincident {
            for d in 0..dim {
                deformed[[q, d]] = dst[[hit, d]];
            }
            continue;
        }

        // Weighted centroids of the source and destination control points.
        let mut src_centroid = vec![0.0_f64; dim];
        let mut dst_centroid = vec![0.0_f64; dim];
        for (i, &w) in weights.iter().enumerate() {
            for d in 0..dim {
                src_centroid[d] += w * src[[i, d]];
                dst_centroid[d] += w * dst[[i, d]];
            }
        }
        for d in 0..dim {
            src_centroid[d] /= weight_total;
            dst_centroid[d] /= weight_total;
        }

        // Control points and the query expressed relative to the centroids.
        let src_centred: Vec<Vec<f64>> = (0..n_ctrl)
            .map(|i| (0..dim).map(|d| src[[i, d]] - src_centroid[d]).collect())
            .collect();
        let dst_centred: Vec<Vec<f64>> = (0..n_ctrl)
            .map(|i| (0..dim).map(|d| dst[[i, d]] - dst_centroid[d]).collect())
            .collect();
        let query_centred: Vec<f64> = (0..dim).map(|d| target[d] - src_centroid[d]).collect();

        // Weighted moment matrices M = sum w p p^T and N = sum w q p^T.
        let mut moment_src = vec![vec![0.0_f64; dim]; dim];
        let mut moment_cross = vec![vec![0.0_f64; dim]; dim];
        for (i, &w) in weights.iter().enumerate() {
            for r in 0..dim {
                for c in 0..dim {
                    moment_src[r][c] += w * src_centred[i][r] * src_centred[i][c];
                    moment_cross[r][c] += w * dst_centred[i][r] * src_centred[i][c];
                }
            }
        }
        let m_arr = vec_to_array2(&moment_src);
        let n_arr = vec_to_array2(&moment_cross);
        // Best-effort fallback: an uninvertible M is replaced by the identity.
        let m_inv = invert_small(&m_arr).unwrap_or_else(|_| Array2::eye(dim));
        let transform = n_arr.dot(&m_inv);

        for d in 0..dim {
            deformed[[q, d]] = (0..dim).fold(dst_centroid[d], |acc, c| {
                acc + transform[[d, c]] * query_centred[c]
            });
        }
    }
    Ok(deformed)
}
}
/// Number of monomials of total degree at most `degree` in `dim` variables,
/// i.e. the binomial coefficient `C(dim + degree, degree)`.
///
/// The coefficient is accumulated one factor at a time so that each
/// intermediate value is itself a binomial coefficient times a small
/// factor. This keeps the computation exact and avoids overflowing on the
/// full numerator product when the final result still fits in `usize`.
fn basis_size(dim: usize, degree: usize) -> usize {
    // Invariant: after step `i` the accumulator equals C(dim + i, i), so the
    // division by `i` below is always exact.
    (1..=degree).fold(1usize, |count, i| count * (dim + i) / i)
}
/// Polynomial basis of the given `degree` evaluated at the offset `x - xi`.
///
/// Coordinates are paired positionally; surplus coordinates of the longer
/// slice are ignored when forming the offset.
fn polynomial_basis(x: &[f64], xi: &[f64], degree: usize) -> Vec<f64> {
    let mut offset = Vec::with_capacity(x.len().min(xi.len()));
    for (&sample, &centre) in x.iter().zip(xi.iter()) {
        offset.push(sample - centre);
    }
    polynomial_basis_vec(&offset, degree, x.len())
}
/// Basis vector used to read the fitted value at the query point.
///
/// Only the length of `xi` is used: the result has one entry per basis term
/// for that dimension and `degree`, with a leading `1.0` followed by zeros.
fn polynomial_basis_at(xi: &[f64], degree: usize) -> Vec<f64> {
    let term_count = basis_size(xi.len(), degree);
    let mut basis = vec![0.0_f64; term_count];
    basis[0] = 1.0;
    basis
}
/// Monomials of total degree at most `degree` (capped at 2) in the
/// components of `dx`.
///
/// Order: the constant `1`, then the linear terms in input order (for
/// `degree >= 1`), then the products `dx[i] * dx[j]` for `i <= j` in
/// row-major order (for `degree >= 2`). `_dim` is unused.
fn polynomial_basis_vec(dx: &[f64], degree: usize, _dim: usize) -> Vec<f64> {
    let mut terms = vec![1.0_f64];
    if degree >= 1 {
        terms.extend_from_slice(dx);
    }
    if degree >= 2 {
        for (i, &lhs) in dx.iter().enumerate() {
            terms.extend(dx[i..].iter().map(|&rhs| lhs * rhs));
        }
    }
    terms
}
/// Euclidean distance between an array row and a coordinate slice.
///
/// Coordinates are paired positionally; surplus coordinates of the longer
/// input are ignored.
#[inline]
fn euclidean_distance(row: ArrayView1<f64>, xi: &[f64]) -> f64 {
    let squared: f64 = row
        .iter()
        .zip(xi.iter())
        .map(|(&a, &b)| {
            let diff = a - b;
            diff * diff
        })
        .sum();
    squared.sqrt()
}
/// Solves `a * x = b` for a square `a` and one or more right-hand-side
/// columns in `b`, using Gaussian elimination with partial (row) pivoting
/// followed by back substitution.
///
/// # Errors
///
/// Returns `InterpolateError::LinalgError` when the largest available pivot
/// in some column has magnitude below `1e-15`.
///
/// # Panics
///
/// Panics if `a` is not square or if `b` has a different row count than `a`.
fn solve_small_system(a: &Array2<f64>, b: &Array2<f64>) -> InterpolateResult<Array2<f64>> {
let n = a.nrows();
let m = b.ncols();
assert_eq!(a.ncols(), n);
assert_eq!(b.nrows(), n);
// Augmented matrix [a | b] so row operations act on both sides at once.
let mut aug = Array2::<f64>::zeros((n, n + m));
for i in 0..n {
for j in 0..n {
aug[[i, j]] = a[[i, j]];
}
for k in 0..m {
aug[[i, n + k]] = b[[i, k]];
}
}
// Forward elimination.
for col in 0..n {
// Pick the row at or below the diagonal with the largest entry in `col`.
let mut max_val = aug[[col, col]].abs();
let mut max_row = col;
for row in col + 1..n {
if aug[[row, col]].abs() > max_val {
max_val = aug[[row, col]].abs();
max_row = row;
}
}
// NOTE(review): this threshold is absolute, so it depends on the overall
// scale of `a` — confirm that is intended for very small or very large
// weights.
if max_val < 1e-15 {
return Err(InterpolateError::LinalgError(
"MovingLeastSquares: singular or near-singular local system".into(),
));
}
// Bring the pivot row into place.
if max_row != col {
for j in 0..n + m {
let tmp = aug[[col, j]];
aug[[col, j]] = aug[[max_row, j]];
aug[[max_row, j]] = tmp;
}
}
// Eliminate the entries below the pivot.
let pivot = aug[[col, col]];
for row in col + 1..n {
let factor = aug[[row, col]] / pivot;
for j in col..n + m {
let delta = factor * aug[[col, j]];
aug[[row, j]] -= delta;
}
}
}
// Back substitution, one right-hand-side column at a time.
let mut x = Array2::<f64>::zeros((n, m));
for col in (0..n).rev() {
for k in 0..m {
let mut val = aug[[col, n + k]];
for j in col + 1..n {
val -= aug[[col, j]] * x[[j, k]];
}
x[[col, k]] = val / aug[[col, col]];
}
}
Ok(x)
}
/// Copies a slice of rows into a dense 2-D array.
///
/// The column count is taken from the first row (zero when there are no
/// rows). Rows shorter than that leave the remaining entries at zero.
fn vec_to_array2(v: &[Vec<f64>]) -> Array2<f64> {
    let row_count = v.len();
    let col_count = match v.first() {
        Some(first) => first.len(),
        None => 0,
    };
    let mut dense = Array2::<f64>::zeros((row_count, col_count));
    for (r, entries) in v.iter().enumerate() {
        for (c, &entry) in entries.iter().enumerate() {
            dense[[r, c]] = entry;
        }
    }
    dense
}
/// Inverts a small square matrix by solving `a * X = I`.
///
/// Errors and panics are those of `solve_small_system`.
fn invert_small(a: &Array2<f64>) -> InterpolateResult<Array2<f64>> {
    let identity = Array2::<f64>::eye(a.nrows());
    solve_small_system(a, &identity)
}
// Unit tests for `MovingLeastSquares` and `WeightFunction`.
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
use scirs2_core::ndarray::array;
/// A constant field must be reproduced by a linear fit.
#[test]
fn test_mls_constant_function() {
let src = array![[0.0_f64, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]];
let vals = array![[5.0_f64], [5.0], [5.0], [5.0], [5.0]];
let mls = MovingLeastSquares::new(src, vals, 1, 2.0).expect("test: should succeed");
let result = mls.eval(&[0.3, 0.4]).expect("test: should succeed");
assert_abs_diff_eq!(result[0], 5.0, epsilon = 1e-6);
}
/// A 1-D linear function `3x + 1` is reproduced between samples.
#[test]
fn test_mls_linear_function_1d() {
let xs: Vec<f64> = (0..8).map(|i| i as f64).collect();
let ys: Vec<f64> = xs.iter().map(|&x| 3.0 * x + 1.0).collect();
let src = Array2::from_shape_fn((xs.len(), 1), |(i, _)| xs[i]);
let vals = Array2::from_shape_fn((ys.len(), 1), |(i, _)| ys[i]);
let mls = MovingLeastSquares::new(src, vals, 1, 3.0).expect("test: should succeed");
let result = mls.eval(&[3.5]).expect("test: should succeed");
assert_abs_diff_eq!(result[0], 3.0 * 3.5 + 1.0, epsilon = 1e-4);
}
/// A 2-D linear function `x + 2y` is reproduced at an off-sample point.
#[test]
fn test_mls_linear_function_2d() {
let pts = array![
[0.0_f64, 0.0],
[1.0, 0.0],
[0.0, 1.0],
[1.0, 1.0],
[0.5, 0.5],
[0.25, 0.75]
];
let vals = Array2::from_shape_fn((pts.nrows(), 1), |(i, _)| {
pts[[i, 0]] + 2.0 * pts[[i, 1]]
});
let mls = MovingLeastSquares::new(pts, vals, 1, 2.0).expect("test: should succeed");
let query = [0.4, 0.6];
let expected = 0.4 + 2.0 * 0.6;
let result = mls.eval(&query).expect("test: should succeed");
assert_abs_diff_eq!(result[0], expected, epsilon = 1e-4);
}
/// Batch evaluation returns one row per query with the expected values.
#[test]
fn test_mls_eval_batch() {
let src = array![[0.0_f64, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]];
let vals = Array2::from_shape_fn((src.nrows(), 1), |(i, _)| {
src[[i, 0]] + 2.0 * src[[i, 1]]
});
let mls = MovingLeastSquares::new(src, vals, 1, 2.0).expect("test: should succeed");
let queries = array![[0.2_f64, 0.3], [0.7, 0.8]];
let result = mls.eval_batch(&queries).expect("test: should succeed");
assert_eq!(result.shape(), &[2, 1]);
assert_abs_diff_eq!(result[[0, 0]], 0.2 + 2.0 * 0.3, epsilon = 1e-4);
assert_abs_diff_eq!(result[[1, 0]], 0.7 + 2.0 * 0.8, epsilon = 1e-4);
}
/// The Wendland kernel yields a finite, roughly correct value at a sample.
#[test]
fn test_mls_weight_function_wendland() {
let src = array![[0.0_f64, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]];
let vals = Array2::from_shape_fn((src.nrows(), 1), |(i, _)| {
src[[i, 0]] + src[[i, 1]]
});
let mls = MovingLeastSquares::with_weight(
src,
vals,
1,
WeightFunction::Wendland,
2.0,
).expect("test: should succeed");
let result = mls.eval(&[0.5, 0.5]).expect("test: should succeed");
assert!(result[0].is_finite());
assert_abs_diff_eq!(result[0], 1.0, epsilon = 0.1);
}
/// The inverse-distance kernel yields a finite value away from samples.
#[test]
fn test_mls_weight_function_inverse_distance() {
let src = array![[0.0_f64, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]];
let vals = Array2::from_shape_fn((src.nrows(), 1), |(i, _)| src[[i, 0]] + src[[i, 1]]);
let mls = MovingLeastSquares::with_weight(
src,
vals,
1,
WeightFunction::InverseDistance(2.0),
1.0,
).expect("test: should succeed");
let result = mls.eval(&[0.3, 0.4]).expect("test: should succeed");
assert!(result[0].is_finite());
}
/// Querying exactly at a sample returns that sample's stored value.
#[test]
fn test_mls_exact_at_node() {
let src = array![[0.0_f64], [1.0], [2.0], [3.0], [4.0]];
let vals = Array2::from_shape_fn((src.nrows(), 1), |(i, _)| (src[[i, 0]]).powi(2));
let mls = MovingLeastSquares::new(src, vals, 1, 2.0).expect("test: should succeed");
let result = mls.eval(&[2.0]).expect("test: should succeed");
assert_abs_diff_eq!(result[0], 4.0, epsilon = 1e-6);
}
/// A degree above 2 is rejected by the constructor.
#[test]
fn test_mls_invalid_degree() {
let src = array![[0.0_f64, 0.0], [1.0, 0.0]];
let vals = array![[0.0_f64], [1.0]];
let result = MovingLeastSquares::new(src, vals, 3, 1.0);
assert!(result.is_err());
}
/// A negative bandwidth is rejected by the constructor.
#[test]
fn test_mls_invalid_bandwidth() {
let src = array![[0.0_f64], [1.0], [2.0]];
let vals = array![[0.0_f64], [1.0], [4.0]];
let result = MovingLeastSquares::new(src, vals, 1, -1.0);
assert!(result.is_err());
}
/// Deforming with identical source and destination leaves queries in place.
#[test]
fn test_mls_deform_identity() {
let ctrl = array![[0.0_f64, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
let query = array![[0.5_f64, 0.5], [0.25, 0.75]];
let result = MovingLeastSquares::deform(&ctrl, &ctrl, &query).expect("test: should succeed");
for q in 0..query.nrows() {
for d in 0..2 {
assert_abs_diff_eq!(result[[q, d]], query[[q, d]], epsilon = 1e-6);
}
}
}
/// A pure translation of the control points translates the query point.
#[test]
fn test_mls_deform_translation() {
let src = array![[0.0_f64, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
let dst = array![[1.0_f64, 2.0], [2.0, 2.0], [1.0, 3.0], [2.0, 3.0]];
let query = array![[0.5_f64, 0.5]];
let result = MovingLeastSquares::deform(&src, &dst, &query).expect("test: should succeed");
assert_abs_diff_eq!(result[[0, 0]], 1.5, epsilon = 0.2);
assert_abs_diff_eq!(result[[0, 1]], 2.5, epsilon = 0.2);
}
/// The Gaussian kernel is 1 at zero distance and decreases with distance.
#[test]
fn test_weight_function_gaussian() {
let w = WeightFunction::Gaussian;
assert_abs_diff_eq!(w.eval(0.0, 1.0), 1.0, epsilon = 1e-12);
assert!(w.eval(1.0, 1.0) < 1.0);
assert!(w.eval(2.0, 1.0) < w.eval(1.0, 1.0));
}
/// The Wendland kernel is 1 at zero distance and 0 at or beyond the bandwidth.
#[test]
fn test_weight_function_wendland() {
let w = WeightFunction::Wendland;
assert_abs_diff_eq!(w.eval(0.0, 1.0), 1.0, epsilon = 1e-12);
assert_abs_diff_eq!(w.eval(1.0, 1.0), 0.0, epsilon = 1e-12);
assert_abs_diff_eq!(w.eval(2.0, 1.0), 0.0, epsilon = 1e-12);
}
/// The inverse-distance kernel follows `d^-p` and is infinite at zero distance.
#[test]
fn test_weight_function_inverse_distance() {
let w = WeightFunction::InverseDistance(2.0);
assert_abs_diff_eq!(w.eval(1.0, 1.0), 1.0, epsilon = 1e-12);
assert_abs_diff_eq!(w.eval(2.0, 1.0), 0.25, epsilon = 1e-12);
assert!(w.eval(0.0, 1.0).is_infinite());
}
}
/// Polynomial basis used by [`MlsInterpolator`] for its local 2-D fit.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BasisType {
    /// `{1}`.
    Constant,
    /// `{1, x, y}`.
    Linear,
    /// `{1, x, y, x^2, xy, y^2}`.
    Quadratic,
}

impl BasisType {
    /// Number of basis functions in two dimensions: 1, 3 or 6.
    #[inline]
    pub fn basis_size_2d(self) -> usize {
        let degree: usize = match self {
            BasisType::Constant => 0,
            BasisType::Linear => 1,
            BasisType::Quadratic => 2,
        };
        // Count of bivariate monomials of total degree <= `degree`.
        (degree + 1) * (degree + 2) / 2
    }
}
/// Radial weight kernels for [`MlsInterpolator`], each carrying its own
/// parameter.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum WeightFn {
    /// `exp(-(d / h)^2)` with bandwidth `h`.
    Gaussian(f64),
    /// `(1 - t)^4 * (4t + 1)` with `t = d / h`; zero once `t >= 1`.
    Wendland(f64),
    /// `d^(-p)` with exponent `p`.
    InverseDistance(f64),
}

impl WeightFn {
    /// Returns the weight assigned to a sample at distance `d`.
    ///
    /// The inverse-distance kernel yields `f64::INFINITY` when `d` is below
    /// `f64::EPSILON`.
    #[inline]
    pub fn eval(self, d: f64) -> f64 {
        match self {
            WeightFn::Gaussian(bandwidth) => {
                let scaled = d / bandwidth;
                (-scaled * scaled).exp()
            }
            WeightFn::Wendland(bandwidth) => {
                let scaled = d / bandwidth;
                if scaled >= 1.0 {
                    // Outside the compact support.
                    return 0.0;
                }
                let rest = 1.0 - scaled;
                rest * rest * rest * rest * (4.0 * scaled + 1.0)
            }
            WeightFn::InverseDistance(power) => {
                if d < f64::EPSILON {
                    return f64::INFINITY;
                }
                d.powf(-power)
            }
        }
    }
}
/// Scalar moving-least-squares interpolator over scattered 2-D samples.
#[derive(Debug, Clone)]
pub struct MlsInterpolator {
// Sample locations as `(x, y)` pairs.
points: Vec<(f64, f64)>,
// Sample values; the constructor requires one value per point.
values: Vec<f64>,
// Polynomial basis used for the local fit.
basis: BasisType,
// Kernel that weights samples by their distance to the query point.
weight: WeightFn,
}
/// Result alias used by the `MlsInterpolator` API; the error type is the
/// crate's `InterpolateError`.
pub type MlsResult<T> = Result<T, crate::error::InterpolateError>;
impl MlsInterpolator {
pub fn new(
points: &[(f64, f64)],
values: &[f64],
basis: BasisType,
weight: WeightFn,
) -> MlsResult<Self> {
if points.is_empty() {
return Err(crate::error::InterpolateError::InvalidInput {
message: "MlsInterpolator: no data points provided".into(),
});
}
if values.len() != points.len() {
return Err(crate::error::InterpolateError::InvalidInput {
message: format!(
"MlsInterpolator: points.len()={} != values.len()={}",
points.len(),
values.len()
),
});
}
let required = basis.basis_size_2d();
if points.len() < required {
return Err(crate::error::InterpolateError::InvalidInput {
message: format!(
"MlsInterpolator: basis {:?} requires >= {} points; got {}",
basis,
required,
points.len()
),
});
}
Ok(Self {
points: points.to_vec(),
values: values.to_vec(),
basis,
weight,
})
}
/// Evaluates the interpolant at `(x, y)`.
///
/// If the query lies within `f64::EPSILON * 1e6` of a sample, or the kernel
/// reports an infinite weight for a sample, that sample's value is returned
/// directly. Otherwise a weighted least-squares polynomial in coordinates
/// centred on the query is fitted and its constant coefficient is returned.
///
/// # Errors
///
/// Propagates the error of the normal-equation solver when the local
/// system is singular or near-singular.
pub fn evaluate(&self, x: f64, y: f64) -> MlsResult<f64> {
    let sample_count = self.points.len();
    let term_count = self.basis.basis_size_2d();
    let coincidence_tol = f64::EPSILON * 1e6;

    let mut weights = Vec::with_capacity(sample_count);
    for (i, &(px, py)) in self.points.iter().enumerate() {
        let dist = ((x - px) * (x - px) + (y - py) * (y - py)).sqrt();
        if dist < coincidence_tol {
            return Ok(self.values[i]);
        }
        weights.push(self.weight.eval(dist));
    }
    if let Some(hit) = weights.iter().position(|&w| w.is_infinite()) {
        return Ok(self.values[hit]);
    }

    // Row-major design matrix sqrt(W) * P and right-hand side sqrt(W) * f.
    let mut design = vec![0.0_f64; sample_count * term_count];
    let mut rhs = vec![0.0_f64; sample_count];
    for (i, (&(px, py), &w)) in self.points.iter().zip(weights.iter()).enumerate() {
        let root = w.sqrt();
        let terms = mls_2d_basis(px - x, py - y, self.basis);
        for (j, &term) in terms.iter().enumerate() {
            design[i * term_count + j] = root * term;
        }
        rhs[i] = root * self.values[i];
    }

    let coeffs = solve_normal_equations_2d(&design, &rhs, sample_count, term_count)?;
    // The basis is centred on the query, so the constant term is the value.
    Ok(coeffs[0])
}
/// Approximates the gradient `(df/dx, df/dy)` at `(x, y)` by central
/// differences of [`MlsInterpolator::evaluate`].
///
/// The step is `1e-5`, scaled up by `max(|x|, |y|)` when that exceeds 1.
///
/// # Errors
///
/// Returns the first error produced by any of the four evaluations.
pub fn gradient(&self, x: f64, y: f64) -> MlsResult<(f64, f64)> {
    let step = 1e-5_f64.max(1e-5 * x.abs().max(y.abs()));

    let east = self.evaluate(x + step, y)?;
    let west = self.evaluate(x - step, y)?;
    let north = self.evaluate(x, y + step)?;
    let south = self.evaluate(x, y - step)?;

    let span = 2.0 * step;
    Ok(((east - west) / span, (north - south) / span))
}
}
/// Monomials of the chosen 2-D basis evaluated at the offset `(dx, dy)`.
///
/// Order: `1`, then `dx, dy` (linear and quadratic), then
/// `dx^2, dx*dy, dy^2` (quadratic only).
fn mls_2d_basis(dx: f64, dy: f64, basis: BasisType) -> Vec<f64> {
    let order = match basis {
        BasisType::Constant => 0,
        BasisType::Linear => 1,
        BasisType::Quadratic => 2,
    };
    let mut terms = vec![1.0];
    if order >= 1 {
        terms.extend_from_slice(&[dx, dy]);
    }
    if order >= 2 {
        terms.extend_from_slice(&[dx * dx, dx * dy, dy * dy]);
    }
    terms
}
/// Solves the least-squares problem `min |p * c - rhs|` through the
/// regularised normal equations `(p^T p + reg * I) c = p^T rhs`.
///
/// `p` is a row-major `n x n_b` matrix stored flat, `rhs` has `n` entries,
/// and the returned vector has `n_b` coefficients. The system is solved by
/// Gaussian elimination with partial pivoting and back substitution.
///
/// # Errors
///
/// Returns `InterpolateError::LinalgError` when the largest available pivot
/// in some column has magnitude below `1e-15`.
fn solve_normal_equations_2d(
p: &[f64],
rhs: &[f64],
n: usize,
n_b: usize,
) -> MlsResult<Vec<f64>> {
// Accumulate a = p^T p (row-major n_b x n_b) and b = p^T rhs.
let mut a = vec![0.0_f64; n_b * n_b];
let mut b = vec![0.0_f64; n_b];
for i in 0..n {
for j in 0..n_b {
let pij = p[i * n_b + j];
b[j] += pij * rhs[i];
for k in 0..n_b {
a[j * n_b + k] += pij * p[i * n_b + k];
}
}
}
// Diagonal shift relative to the largest diagonal entry, floored so it is
// never exactly zero.
let max_diag = (0..n_b)
.map(|j| a[j * n_b + j].abs())
.fold(0.0_f64, f64::max);
let reg = 1e-12 * max_diag.max(1e-30);
for j in 0..n_b {
a[j * n_b + j] += reg;
}
// Augmented matrix [a | b], row-major with n_b + 1 columns.
let mut aug = vec![0.0_f64; n_b * (n_b + 1)];
for i in 0..n_b {
for j in 0..n_b {
aug[i * (n_b + 1) + j] = a[i * n_b + j];
}
aug[i * (n_b + 1) + n_b] = b[i];
}
// Forward elimination with partial pivoting.
for col in 0..n_b {
let mut max_val = aug[col * (n_b + 1) + col].abs();
let mut max_row = col;
for row in col + 1..n_b {
let v = aug[row * (n_b + 1) + col].abs();
if v > max_val {
max_val = v;
max_row = row;
}
}
// NOTE(review): this threshold is absolute, so very small weights can
// trigger it regardless of conditioning — confirm that is intended.
if max_val < 1e-15 {
return Err(crate::error::InterpolateError::LinalgError(
"MlsInterpolator: singular or near-singular local system; \
try more data points or a larger bandwidth".into(),
));
}
// Bring the pivot row into place.
if max_row != col {
for j in 0..=n_b {
aug.swap(col * (n_b + 1) + j, max_row * (n_b + 1) + j);
}
}
// Eliminate the entries below the pivot.
let pivot = aug[col * (n_b + 1) + col];
for row in col + 1..n_b {
let factor = aug[row * (n_b + 1) + col] / pivot;
for j in col..=n_b {
let delta = factor * aug[col * (n_b + 1) + j];
aug[row * (n_b + 1) + j] -= delta;
}
}
}
// Back substitution.
let mut c = vec![0.0_f64; n_b];
for col in (0..n_b).rev() {
let mut val = aug[col * (n_b + 1) + n_b];
for j in col + 1..n_b {
val -= aug[col * (n_b + 1) + j] * c[j];
}
c[col] = val / aug[col * (n_b + 1) + col];
}
Ok(c)
}
// Unit tests for `MlsInterpolator`, `BasisType` and `WeightFn`.
#[cfg(test)]
mod mls_interp_tests {
use super::{BasisType, MlsInterpolator, WeightFn};
use approx::assert_abs_diff_eq;
// Builds an `n x n` grid on the unit square; returns the points plus their
// x and y coordinates as separate vectors. Requires `n >= 2`, since the
// spacing divides by `n - 1`.
fn grid_pts(n: usize) -> (Vec<(f64, f64)>, Vec<f64>, Vec<f64>) {
let mut pts = Vec::new();
let mut xs = Vec::new();
let mut ys = Vec::new();
for i in 0..n {
for j in 0..n {
let x = i as f64 / (n - 1) as f64;
let y = j as f64 / (n - 1) as f64;
pts.push((x, y));
xs.push(x);
ys.push(y);
}
}
(pts, xs, ys)
}
/// Construction succeeds with enough points for a linear basis.
#[test]
fn test_mls_new_valid() {
let pts = vec![(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0)];
let vals = vec![0.0, 1.0, 1.0, 2.0];
let result = MlsInterpolator::new(&pts, &vals, BasisType::Linear, WeightFn::Gaussian(2.0));
assert!(result.is_ok());
}
/// Empty input is rejected.
#[test]
fn test_mls_new_empty_error() {
let result = MlsInterpolator::new(&[], &[], BasisType::Constant, WeightFn::Gaussian(1.0));
assert!(result.is_err());
}
/// Mismatched point and value counts are rejected.
#[test]
fn test_mls_new_length_mismatch_error() {
let pts = vec![(0.0, 0.0), (1.0, 0.0)];
let vals = vec![0.0];
let result = MlsInterpolator::new(&pts, &vals, BasisType::Linear, WeightFn::Gaussian(1.0));
assert!(result.is_err());
}
/// Three points are too few for the six-term quadratic basis.
#[test]
fn test_mls_new_insufficient_points_error() {
let pts = vec![(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)];
let vals = vec![0.0, 1.0, 1.0];
let result = MlsInterpolator::new(&pts, &vals, BasisType::Quadratic, WeightFn::Gaussian(1.0));
assert!(result.is_err());
}
/// A constant field is reproduced by a linear basis.
#[test]
fn test_mls_constant_field() {
let (pts, _, _) = grid_pts(4);
let vals: Vec<f64> = pts.iter().map(|_| 7.0).collect();
let mls = MlsInterpolator::new(&pts, &vals, BasisType::Linear, WeightFn::Gaussian(2.0))
.expect("test: should succeed");
let v = mls.evaluate(0.35, 0.65).expect("test: should succeed");
assert_abs_diff_eq!(v, 7.0, epsilon = 1e-5);
}
/// A constant field is reproduced by a constant basis.
#[test]
fn test_mls_constant_basis_constant_field() {
let (pts, _, _) = grid_pts(4);
let vals: Vec<f64> = pts.iter().map(|_| 3.5).collect();
let mls = MlsInterpolator::new(&pts, &vals, BasisType::Constant, WeightFn::Gaussian(2.0))
.expect("test: should succeed");
let v = mls.evaluate(0.5, 0.5).expect("test: should succeed");
assert_abs_diff_eq!(v, 3.5, epsilon = 1e-5);
}
/// The linear field `x + 2y` is reproduced at several off-grid points.
#[test]
fn test_mls_linear_field_exact() {
let (pts, _, _) = grid_pts(4);
let vals: Vec<f64> = pts.iter().map(|&(x, y)| x + 2.0 * y).collect();
let mls = MlsInterpolator::new(&pts, &vals, BasisType::Linear, WeightFn::Gaussian(2.0))
.expect("test: should succeed");
for &(qx, qy) in &[(0.3, 0.4), (0.7, 0.2), (0.5, 0.8)] {
let expected = qx + 2.0 * qy;
let v = mls.evaluate(qx, qy).expect("test: should succeed");
assert_abs_diff_eq!(v, expected, epsilon = 0.02);
}
}
/// The linear field `3x + y` is reproduced with the Wendland kernel.
#[test]
fn test_mls_linear_field_wendland() {
let (pts, _, _) = grid_pts(4);
let vals: Vec<f64> = pts.iter().map(|&(x, y)| 3.0 * x + y).collect();
let mls = MlsInterpolator::new(
&pts, &vals, BasisType::Linear, WeightFn::Wendland(2.0),
).expect("test: should succeed");
let v = mls.evaluate(0.5, 0.5).expect("test: should succeed");
assert_abs_diff_eq!(v, 3.0 * 0.5 + 0.5, epsilon = 0.05);
}
/// The linear field `x + y` is reproduced with the inverse-distance kernel.
#[test]
fn test_mls_linear_field_inverse_distance() {
let (pts, _, _) = grid_pts(4);
let vals: Vec<f64> = pts.iter().map(|&(x, y)| x + y).collect();
let mls = MlsInterpolator::new(
&pts, &vals, BasisType::Linear, WeightFn::InverseDistance(2.0),
).expect("test: should succeed");
let v = mls.evaluate(0.4, 0.4).expect("test: should succeed");
assert_abs_diff_eq!(v, 0.8, epsilon = 0.1);
}
/// The quadratic field `x^2 + y^2` is reproduced by the quadratic basis.
#[test]
fn test_mls_quadratic_field() {
let (pts, _, _) = grid_pts(5);
let vals: Vec<f64> = pts.iter().map(|&(x, y)| x * x + y * y).collect();
let mls = MlsInterpolator::new(
&pts, &vals, BasisType::Quadratic, WeightFn::Gaussian(2.0),
).expect("test: should succeed");
let v = mls.evaluate(0.5, 0.5).expect("test: should succeed");
assert_abs_diff_eq!(v, 0.5, epsilon = 0.05);
}
/// Querying exactly at a sample returns that sample's value.
#[test]
fn test_mls_exact_at_node() {
let pts = vec![(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0), (0.5, 0.5)];
let vals = vec![0.0, 1.0, 2.0, 3.0, 1.5];
let mls = MlsInterpolator::new(&pts, &vals, BasisType::Linear, WeightFn::Gaussian(1.0))
.expect("test: should succeed");
let v = mls.evaluate(1.0, 0.0).expect("test: should succeed");
assert_abs_diff_eq!(v, 1.0, epsilon = 1e-6);
}
/// The gradient of the linear field `a*x + b*y` is approximately `(a, b)`.
#[test]
fn test_mls_gradient_linear_field() {
let (pts, _, _) = grid_pts(5);
let (a, b) = (2.0_f64, 3.0_f64);
let vals: Vec<f64> = pts.iter().map(|&(x, y)| a * x + b * y).collect();
let mls = MlsInterpolator::new(&pts, &vals, BasisType::Linear, WeightFn::Gaussian(2.0))
.expect("test: should succeed");
let (dfdx, dfdy) = mls.gradient(0.5, 0.5).expect("test: should succeed");
assert_abs_diff_eq!(dfdx, a, epsilon = 0.1);
assert_abs_diff_eq!(dfdy, b, epsilon = 0.1);
}
/// The gradient of a constant field is approximately zero.
#[test]
fn test_mls_gradient_constant_field() {
let (pts, _, _) = grid_pts(4);
let vals: Vec<f64> = pts.iter().map(|_| 5.0).collect();
let mls = MlsInterpolator::new(&pts, &vals, BasisType::Linear, WeightFn::Gaussian(2.0))
.expect("test: should succeed");
let (dfdx, dfdy) = mls.gradient(0.5, 0.5).expect("test: should succeed");
assert_abs_diff_eq!(dfdx, 0.0, epsilon = 1e-4);
assert_abs_diff_eq!(dfdy, 0.0, epsilon = 1e-4);
}
/// Evaluation on irregularly placed samples yields a finite value.
#[test]
fn test_mls_scattered_data_finite() {
let pts: Vec<(f64, f64)> = vec![
(0.1, 0.2), (0.5, 0.1), (0.9, 0.3),
(0.2, 0.7), (0.6, 0.8), (0.4, 0.5),
(0.8, 0.6), (0.3, 0.9), (0.7, 0.4),
];
let vals: Vec<f64> = pts.iter().map(|&(x, y)| x * x + y).collect();
let mls = MlsInterpolator::new(&pts, &vals, BasisType::Linear, WeightFn::Gaussian(1.0))
.expect("test: should succeed");
let v = mls.evaluate(0.4, 0.6).expect("test: should succeed");
assert!(v.is_finite(), "evaluate returned non-finite value");
}
/// The gradient of a non-polynomial field is finite.
#[test]
fn test_mls_scattered_gradient_finite() {
let pts: Vec<(f64, f64)> = vec![
(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0),
(0.5, 0.0), (0.0, 0.5), (1.0, 0.5), (0.5, 1.0), (0.5, 0.5),
];
let vals: Vec<f64> = pts.iter().map(|&(x, y)| (x + y).sin()).collect();
let mls = MlsInterpolator::new(&pts, &vals, BasisType::Linear, WeightFn::Gaussian(2.0))
.expect("test: should succeed");
let (gx, gy) = mls.gradient(0.3, 0.4).expect("test: should succeed");
assert!(gx.is_finite() && gy.is_finite());
}
/// The 2-D basis sizes are 1, 3 and 6.
#[test]
fn test_basis_size_2d() {
assert_eq!(BasisType::Constant.basis_size_2d(), 1);
assert_eq!(BasisType::Linear.basis_size_2d(), 3);
assert_eq!(BasisType::Quadratic.basis_size_2d(), 6);
}
/// The Gaussian kernel is 1 at zero distance.
#[test]
fn test_weight_fn_gaussian_zero_distance() {
assert_abs_diff_eq!(WeightFn::Gaussian(1.0).eval(0.0), 1.0, epsilon = 1e-12);
}
/// The Wendland kernel vanishes at the bandwidth.
#[test]
fn test_weight_fn_wendland_at_bandwidth() {
assert_abs_diff_eq!(WeightFn::Wendland(1.0).eval(1.0), 0.0, epsilon = 1e-12);
}
/// The inverse-distance kernel is infinite at zero distance.
#[test]
fn test_weight_fn_inverse_distance_zero() {
assert!(WeightFn::InverseDistance(2.0).eval(0.0).is_infinite());
}
}