use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;
use std::marker::PhantomData;
use crate::error::{InterpolateError, InterpolateResult};
/// Radial weight kernels for [`MovingLeastSquares`].
///
/// Each kernel is evaluated at the normalised distance `r = distance / bandwidth`;
/// the resulting weights are normalised to sum to one before the local fit.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WeightFunction {
    /// `exp(-r²)`; no compact support, so samples are not truncated by distance.
    Gaussian,
    /// Wendland C2 kernel `(1 - r)^4 (4 r + 1)` for `r < 1`, zero otherwise.
    WendlandC2,
    /// `1 / (epsilon + r²)`, where `epsilon` is set with `with_epsilon`.
    InverseDistance,
    /// Piecewise-cubic spline kernel that is zero for `r >= 1`; only samples within
    /// `bandwidth` of the query point are selected.
    CubicSpline,
}
/// Local polynomial basis fitted around each query point by [`MovingLeastSquares`].
///
/// With `d` input dimensions the number of basis terms is 1, `d + 1` and
/// `(d + 1)(d + 2) / 2` respectively.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PolynomialBasis {
    /// Only the constant term: the fit is a weighted average of nearby values.
    Constant,
    /// Constant term plus one linear term per dimension.
    Linear,
    /// Linear terms plus every product of two coordinates `x_j * x_k` with `j <= k`.
    Quadratic,
}
#[derive(Debug, Clone)]
pub struct MovingLeastSquares<F>
where
F: Float + FromPrimitive + Debug + 'static + std::cmp::PartialOrd,
{
points: Array2<F>,
values: Array1<F>,
weight_fn: WeightFunction,
basis: PolynomialBasis,
bandwidth: F,
epsilon: F,
max_points: Option<usize>,
_phantom: PhantomData<F>,
}
impl<F> MovingLeastSquares<F>
where
    F: Float + FromPrimitive + Debug + 'static + std::cmp::PartialOrd,
{
    /// Creates a moving-least-squares interpolator from scattered samples.
    ///
    /// * `points` – sample locations, one per row.
    /// * `values` – sample value for each row of `points`.
    /// * `weight_fn` – radial kernel, evaluated at `distance / bandwidth`.
    /// * `basis` – polynomial basis fitted locally around each query point.
    /// * `bandwidth` – kernel length scale; must be positive.
    ///
    /// `epsilon` defaults to `1e-10` and `max_points` to `None` (no cap).
    ///
    /// # Errors
    /// `DimensionMismatch` if the number of points and values differ; `InvalidValue`
    /// for fewer than two points or a bandwidth that is not positive (including NaN).
    pub fn new(
        points: Array2<F>,
        values: Array1<F>,
        weight_fn: WeightFunction,
        basis: PolynomialBasis,
        bandwidth: F,
    ) -> InterpolateResult<Self> {
        if points.shape()[0] != values.len() {
            return Err(InterpolateError::DimensionMismatch(
                "Number of points must match number of values".to_string(),
            ));
        }
        if points.shape()[0] < 2 {
            return Err(InterpolateError::InvalidValue(
                "At least 2 points are required for MLS interpolation".to_string(),
            ));
        }
        // `NaN <= 0` is false, so NaN has to be rejected explicitly.
        if bandwidth.is_nan() || bandwidth <= F::zero() {
            return Err(InterpolateError::InvalidValue(
                "Bandwidth parameter must be positive".to_string(),
            ));
        }
        Ok(Self {
            points,
            values,
            weight_fn,
            basis,
            bandwidth,
            epsilon: Self::lit(1e-10),
            max_points: None,
            _phantom: PhantomData,
        })
    }

    /// Limits each query to its `maxpoints` nearest samples.
    ///
    /// If that leaves fewer samples than the polynomial basis has terms, the nearest
    /// samples up to the basis size are used instead.
    pub fn with_max_points(mut self, maxpoints: usize) -> Self {
        self.max_points = Some(maxpoints);
        self
    }

    /// Sets the regulariser added to `r²` in the inverse-distance kernel (default `1e-10`).
    pub fn with_epsilon(mut self, epsilon: F) -> Self {
        self.epsilon = epsilon;
        self
    }

    /// Evaluates the MLS approximation at a single query point `x`.
    ///
    /// # Errors
    /// `DimensionMismatch` if `x` does not have the training points' dimension;
    /// `ComputationError` if the local system is singular and no usable weights remain.
    pub fn evaluate(&self, x: &ArrayView1<F>) -> InterpolateResult<F> {
        if x.len() != self.points.shape()[1] {
            return Err(InterpolateError::DimensionMismatch(
                "Query point dimension must match training points".to_string(),
            ));
        }
        let (indices, distances) = self.find_relevant_points(x)?;
        if indices.is_empty() {
            return Err(InterpolateError::invalid_input(
                "No points found within effective range".to_string(),
            ));
        }
        let weights = self.compute_weights(&distances)?;
        let design = self.create_basis_functions(&indices, x)?;
        self.solve_weighted_least_squares(&indices, &weights, &design)
    }

    /// Evaluates the approximation at every row of `points`.
    ///
    /// # Errors
    /// `DimensionMismatch` on a dimension mismatch; otherwise the first error
    /// returned by [`Self::evaluate`].
    pub fn evaluate_multi(&self, points: &ArrayView2<F>) -> InterpolateResult<Array1<F>> {
        if points.shape()[1] != self.points.shape()[1] {
            return Err(InterpolateError::DimensionMismatch(
                "Query points dimension must match training points".to_string(),
            ));
        }
        let mut results = Array1::zeros(points.shape()[0]);
        for (slot, row) in results.iter_mut().zip(points.outer_iter()) {
            *slot = self.evaluate(&row)?;
        }
        Ok(results)
    }

    /// Converts an `f64` literal into `F`.
    fn lit(value: f64) -> F {
        F::from_f64(value).expect("Operation failed")
    }

    /// Number of polynomial basis terms for `n_dims`-dimensional inputs.
    fn n_basis(&self, n_dims: usize) -> usize {
        match self.basis {
            PolynomialBasis::Constant => 1,
            PolynomialBasis::Linear => n_dims + 1,
            PolynomialBasis::Quadratic => (n_dims + 1) * (n_dims + 2) / 2,
        }
    }

    /// Selects the samples used for a query at `x`, returning their indices and
    /// Euclidean distances to `x`, nearest first.
    ///
    /// Candidates are the `max_points` nearest samples (all samples if unset). For
    /// compactly supported kernels only candidates within `bandwidth` are kept. If fewer
    /// samples than basis terms remain, the nearest samples up to the basis size are
    /// returned instead, even if that ignores `max_points` or the support radius.
    fn find_relevant_points(&self, x: &ArrayView1<F>) -> InterpolateResult<(Vec<usize>, Vec<F>)> {
        let n_points = self.points.shape()[0];
        let n_dims = self.points.shape()[1];

        let mut ranked: Vec<(usize, F)> = self
            .points
            .outer_iter()
            .enumerate()
            .map(|(i, p)| {
                let d2 = x.iter().zip(p.iter()).fold(F::zero(), |acc, (&xj, &pj)| {
                    let diff = xj - pj;
                    acc + diff * diff
                });
                (i, d2.sqrt())
            })
            .collect();
        // Stable sort: tied distances keep their original index order.
        ranked.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

        let limit = self.max_points.map_or(n_points, |m| m.min(n_points));
        let radius = match self.weight_fn {
            WeightFunction::WendlandC2 | WeightFunction::CubicSpline => self.bandwidth,
            WeightFunction::Gaussian | WeightFunction::InverseDistance => F::infinity(),
        };
        let (indices, dists): (Vec<usize>, Vec<F>) = ranked
            .iter()
            .take(limit)
            .filter(|&&(_, d)| d <= radius)
            .copied()
            .unzip();

        if indices.len() < self.n_basis(n_dims) {
            return Ok(ranked.iter().take(self.n_basis(n_dims)).copied().unzip());
        }
        Ok((indices, dists))
    }

    /// Evaluates the configured kernel at the normalised distance `r = d / bandwidth`.
    /// Every kernel is non-negative for a non-negative `epsilon`.
    fn kernel(&self, r: F) -> F {
        let one = F::one();
        match self.weight_fn {
            WeightFunction::Gaussian => (-r * r).exp(),
            WeightFunction::WendlandC2 => {
                if r < one {
                    (one - r).powi(4) * (Self::lit(4.0) * r + one)
                } else {
                    F::zero()
                }
            }
            WeightFunction::InverseDistance => one / (self.epsilon + r * r),
            WeightFunction::CubicSpline => {
                // C2 cubic spline with support [0, 1):
                //   2/3 - 4 r² + 4 r³   for r <= 1/2
                //   4/3 (1 - r)³        for 1/2 < r < 1
                // Both pieces equal 1/6 at r = 1/2 with matching first and second
                // derivatives, and the kernel never goes negative.
                if r <= Self::lit(0.5) {
                    let r2 = r * r;
                    Self::lit(2.0 / 3.0) - Self::lit(4.0) * r2 + Self::lit(4.0) * r2 * r
                } else if r < one {
                    Self::lit(4.0 / 3.0) * (one - r).powi(3)
                } else {
                    F::zero()
                }
            }
        }
    }

    /// Kernel weights for the given distances, normalised to sum to one.
    /// Falls back to uniform weights when every kernel value is zero.
    fn compute_weights(&self, distances: &[F]) -> InterpolateResult<Array1<F>> {
        let mut weights: Array1<F> = distances
            .iter()
            .map(|&d| self.kernel(d / self.bandwidth))
            .collect();
        let total = weights.sum();
        if total > F::zero() {
            weights.mapv_inplace(|w| w / total);
        } else if !weights.is_empty() {
            weights.fill(Self::lit(1.0 / weights.len() as f64));
        }
        Ok(weights)
    }

    /// Basis monomials at local coordinates `u`, in column order: `1`, then `u_j`
    /// (linear and quadratic bases), then `u_j * u_k` for `j <= k` (quadratic only).
    fn monomials(&self, u: &[F]) -> Vec<F> {
        let mut terms = Vec::with_capacity(self.n_basis(u.len()));
        terms.push(F::one());
        if matches!(self.basis, PolynomialBasis::Linear | PolynomialBasis::Quadratic) {
            terms.extend_from_slice(u);
        }
        if self.basis == PolynomialBasis::Quadratic {
            for (j, &uj) in u.iter().enumerate() {
                terms.extend(u[j..].iter().map(|&uk| uj * uk));
            }
        }
        terms
    }

    /// Builds the design matrix (one row per selected sample, one column per basis
    /// term) in local coordinates `u = (p - x) / bandwidth`.
    ///
    /// Centring on the query point and scaling by the bandwidth keeps the normal
    /// equations well conditioned even for data far from the origin. The span of the
    /// basis, and therefore the fitted function, is unchanged.
    fn create_basis_functions(
        &self,
        indices: &[usize],
        x: &ArrayView1<F>,
    ) -> InterpolateResult<Array2<F>> {
        let n_dims = x.len();
        let mut design = Array2::zeros((indices.len(), self.n_basis(n_dims)));
        let mut local = vec![F::zero(); n_dims];
        for (mut row, &idx) in design.outer_iter_mut().zip(indices) {
            for ((u, &p), &q) in local.iter_mut().zip(self.points.row(idx)).zip(x.iter()) {
                *u = (p - q) / self.bandwidth;
            }
            for (dst, term) in row.iter_mut().zip(self.monomials(&local)) {
                *dst = term;
            }
        }
        Ok(design)
    }

    /// Solves the weighted normal equations `(Bᵀ W B) c = Bᵀ W v` and returns the
    /// fitted value at the query point.
    ///
    /// The design matrix is in local coordinates centred on the query point, so the
    /// fitted value there is the constant coefficient `c[0]`. If the system is
    /// numerically singular, the weighted mean of the sample values is returned instead.
    ///
    /// # Errors
    /// `ComputationError` if the system is singular and the weights do not sum to a
    /// positive value.
    fn solve_weighted_least_squares(
        &self,
        indices: &[usize],
        weights: &Array1<F>,
        basis: &Array2<F>,
    ) -> InterpolateResult<F> {
        let samples: Array1<F> = indices.iter().map(|&idx| self.values[idx]).collect();

        // W B: scale every design-matrix row by its weight.
        let mut weighted_basis = basis.clone();
        for (mut row, &w) in weighted_basis.outer_iter_mut().zip(weights.iter()) {
            row.mapv_inplace(|b| b * w);
        }
        let lhs = basis.t().dot(&weighted_basis);
        let rhs = weighted_basis.t().dot(&samples);

        if let Some(coeffs) = Self::solve_dense(lhs, rhs) {
            return Ok(coeffs[0]);
        }

        // Singular system (too few effectively weighted samples or degenerate
        // geometry): fall back to the weighted mean, i.e. a constant fit.
        let (weighted_sum, total_weight) = weights
            .iter()
            .zip(samples.iter())
            .fold((F::zero(), F::zero()), |(s, t), (&w, &v)| (s + w * v, t + w));
        if total_weight > F::zero() {
            Ok(weighted_sum / total_weight)
        } else {
            Err(InterpolateError::ComputationError(
                "Failed to solve weighted least squares system".to_string(),
            ))
        }
    }

    /// Solves the small dense system `a · c = b` by Gaussian elimination with partial
    /// pivoting.
    ///
    /// Returns `None` when a pivot is NaN or not larger than
    /// `n · machine-epsilon · max|a_ij|`, i.e. when the system is numerically singular.
    fn solve_dense(mut a: Array2<F>, mut b: Array1<F>) -> Option<Array1<F>> {
        let n = b.len();
        let scale = a.iter().fold(F::zero(), |m, &v| m.max(v.abs()));
        let tol = F::epsilon() * F::from_usize(n)? * scale;

        for col in 0..n {
            // Bring the largest remaining entry of this column onto the diagonal.
            let pivot_row = (col..n).max_by(|&r, &s| {
                a[[r, col]]
                    .abs()
                    .partial_cmp(&a[[s, col]].abs())
                    .unwrap_or(std::cmp::Ordering::Equal)
            })?;
            let pivot = a[[pivot_row, col]];
            if pivot.is_nan() || pivot.abs() <= tol {
                return None;
            }
            if pivot_row != col {
                for k in col..n {
                    a.swap([col, k], [pivot_row, k]);
                }
                b.swap(col, pivot_row);
            }
            for row in (col + 1)..n {
                let factor = a[[row, col]] / pivot;
                for k in col..n {
                    a[[row, k]] = a[[row, k]] - factor * a[[col, k]];
                }
                b[row] = b[row] - factor * b[col];
            }
        }

        // Back substitution on the upper-triangular system.
        let mut c = Array1::zeros(n);
        for row in (0..n).rev() {
            let rest = ((row + 1)..n).fold(b[row], |acc, k| acc - a[[row, k]] * c[k]);
            c[row] = rest / a[[row, row]];
        }
        Some(c)
    }

    /// The radial kernel used for weighting.
    pub fn weight_fn(&self) -> WeightFunction {
        self.weight_fn
    }

    /// The kernel length scale.
    pub fn bandwidth(&self) -> F {
        self.bandwidth
    }

    /// The training sample locations, one per row.
    pub fn points(&self) -> &Array2<F> {
        &self.points
    }

    /// The training sample values.
    pub fn values(&self) -> &Array1<F> {
        &self.values
    }

    /// The local polynomial basis.
    pub fn basis(&self) -> PolynomialBasis {
        self.basis
    }

    /// The optional cap on nearest samples per query.
    pub fn max_points(&self) -> Option<usize> {
        self.max_points
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_mls_constant_basis() {
        let points = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0])
            .expect("Operation failed");
        let values = Array1::from_vec(vec![0.0, 1.0, 1.0, 2.0]);
        let mls = MovingLeastSquares::new(
            points,
            values,
            WeightFunction::Gaussian,
            PolynomialBasis::Constant,
            0.5,
        )
        .expect("Operation failed");
        // The centre is equidistant from all four corners, so all weights are equal
        // and the constant fit is the plain mean of the values.
        let center = array![0.5, 0.5];
        let val = mls.evaluate(&center.view()).expect("Operation failed");
        assert_abs_diff_eq!(val, 1.0, epsilon = 0.1);
    }

    #[test]
    fn test_mls_linear_basis() {
        // Corners of the unit square sampled from f(x, y) = x + y.
        let points = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0])
            .expect("Operation failed");
        let values = Array1::from_vec(vec![0.0, 1.0, 1.0, 2.0]);
        let mls = MovingLeastSquares::new(
            points,
            values,
            WeightFunction::Gaussian,
            PolynomialBasis::Linear,
            1.0,
        )
        .expect("Operation failed");
        let test_points = Array2::from_shape_vec(
            (5, 2),
            vec![0.5, 0.5, 0.25, 0.25, 0.75, 0.25, 0.25, 0.75, 0.75, 0.75],
        )
        .expect("Operation failed");
        let expected = Array1::from_vec(vec![1.0, 0.5, 1.0, 1.0, 1.5]);
        let results = mls
            .evaluate_multi(&test_points.view())
            .expect("Operation failed");
        for (result, expect) in results.iter().zip(expected.iter()) {
            assert_abs_diff_eq!(result, expect, epsilon = 0.5);
        }
    }

    #[test]
    fn test_different_weight_functions() {
        // Samples of f(x, y) = x + y placed symmetrically about the query (0.5, 0.5),
        // so every radial kernel should reproduce f(0.5, 0.5) = 1.
        let points = Array2::from_shape_vec(
            (6, 2),
            vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.3, 0.3, 0.7, 0.7],
        )
        .expect("Operation failed");
        let values = Array1::from_vec(vec![0.0, 1.0, 1.0, 2.0, 0.6, 1.4]);
        let weight_fns = [
            WeightFunction::Gaussian,
            WeightFunction::InverseDistance,
            WeightFunction::WendlandC2,
            WeightFunction::CubicSpline,
        ];
        let query = array![0.5, 0.5];
        let expected = 0.5 + 0.5;
        for &weight_fn in &weight_fns {
            let mls = MovingLeastSquares::new(
                points.clone(),
                values.clone(),
                weight_fn,
                PolynomialBasis::Linear,
                2.0,
            )
            .expect("Operation failed");
            match mls.evaluate(&query.view()) {
                Ok(val) => {
                    if val.is_finite() {
                        assert!(
                            (val - expected).abs() < 0.5,
                            "Weight function {:?}: result {:.6} differs too much from expected {:.6}",
                            weight_fn,
                            val,
                            expected
                        );
                    } else {
                        panic!(
                            "Weight function {:?} produced non-finite result: {}",
                            weight_fn, val
                        );
                    }
                }
                Err(e) => {
                    panic!("Weight function {:?} failed with error: {}", weight_fn, e);
                }
            }
        }
    }
}