use crate::error::{InterpolateError, InterpolateResult};
use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;
/// Thin-plate spline fitted to scattered data in any number of dimensions.
///
/// The fitted function is
/// `f(x) = Σ_i coeffs[i] · φ(|x - centers[i]|) + poly_coeffs[0] + Σ_d poly_coeffs[d + 1] · x[d]`,
/// where `φ` is the `tps_kernel` radial basis function.
#[derive(Debug, Clone)]
pub struct ThinPlateSpline<T: Float + FromPrimitive + Debug> {
    /// Data sites, one point per row (`n_points × n_dims`).
    centers: Array2<T>,
    /// Radial-basis weights, one per center.
    coeffs: Array1<T>,
    /// Affine part: the constant term followed by one coefficient per dimension.
    poly_coeffs: Array1<T>,
    /// Value placed on the kernel diagonal when fitting (`0` means exact interpolation).
    smoothing: T,
    /// Optional precomputed basis matrix; `ThinPlateSpline::new` leaves it as `None`.
    basis_values: Option<Array2<T>>,
}
impl<T: Float + FromPrimitive + Debug> ThinPlateSpline<T> {
    /// Borrows the stored basis matrix, if one has been set.
    ///
    /// Splines built by `ThinPlateSpline::new` do not populate it, so this returns `None` for them.
    pub fn basis_values(&self) -> Option<&Array2<T>> {
        Option::as_ref(&self.basis_values)
    }
}
impl<T> ThinPlateSpline<T>
where
T: Float + FromPrimitive + Debug,
{
pub fn new(x: &ArrayView2<T>, y: &ArrayView1<T>, smoothing: T) -> InterpolateResult<Self> {
if x.nrows() != y.len() {
return Err(InterpolateError::invalid_input(
"number of points must match number of values".to_string(),
));
}
if smoothing < T::zero() {
return Err(InterpolateError::invalid_input(
"smoothing parameter must be non-negative".to_string(),
));
}
let n_points = x.nrows();
let n_dims = x.ncols();
if n_points < n_dims + 1 {
return Err(InterpolateError::invalid_input(format!(
"need at least {} points for {} dimensions",
n_dims + 1,
n_dims
)));
}
let poly_terms = 1 + n_dims;
let mut k = Array2::zeros((n_points, n_points));
for i in 0..n_points {
for j in 0..n_points {
if i == j {
k[(i, j)] = smoothing; } else {
let mut dist_sq = T::zero();
for d in 0..n_dims {
let diff = x[(i, d)] - x[(j, d)];
dist_sq = dist_sq + diff * diff;
}
k[(i, j)] = tps_kernel(dist_sq.sqrt());
}
}
}
let mut p = Array2::zeros((n_points, poly_terms));
for i in 0..n_points {
p[(i, 0)] = T::one();
for d in 0..n_dims {
p[(i, d + 1)] = x[(i, d)];
}
}
let mut a = Array2::zeros((n_points + poly_terms, n_points + poly_terms));
for i in 0..n_points {
for j in 0..n_points {
a[(i, j)] = k[(i, j)];
}
}
for i in 0..n_points {
for j in 0..poly_terms {
a[(i, n_points + j)] = p[(i, j)];
a[(n_points + j, i)] = p[(i, j)];
}
}
let mut b = Array1::zeros(n_points + poly_terms);
for i in 0..n_points {
b[i] = y[i];
}
let coeffs_full = {
use scirs2_linalg::solve;
let a_f64 = a.mapv(|x| x.to_f64().expect("Operation failed"));
let b_f64 = b.mapv(|x| x.to_f64().expect("Operation failed"));
match solve(&a_f64.view(), &b_f64.view(), None) {
Ok(solution) => solution.mapv(|x| T::from_f64(x).expect("Operation failed")),
Err(_) => {
use scirs2_linalg::lstsq;
match lstsq(&a_f64.view(), &b_f64.view(), None) {
Ok(result) => result.x.mapv(|x| T::from_f64(x).expect("Operation failed")),
Err(_) => {
return Err(InterpolateError::LinalgError(
"failed to solve linear system".to_string(),
));
}
}
}
}
};
let coeffs = coeffs_full.slice(s![0..n_points]).to_owned();
let poly_coeffs = coeffs_full.slice(s![n_points..]).to_owned();
Ok(ThinPlateSpline {
centers: x.to_owned(),
coeffs,
poly_coeffs,
smoothing,
basis_values: None,
})
}
pub fn evaluate(&self, x: &ArrayView2<T>) -> InterpolateResult<Array1<T>> {
if x.ncols() != self.centers.ncols() {
return Err(InterpolateError::DimensionMismatch(format!(
"expected {} dimensions, got {}",
self.centers.ncols(),
x.ncols()
)));
}
let n_eval = x.nrows();
let n_centers = self.centers.nrows();
let n_dims = self.centers.ncols();
let mut result = Array1::zeros(n_eval);
for i in 0..n_eval {
for j in 0..n_centers {
let mut dist_sq = T::zero();
for d in 0..n_dims {
let diff = x[(i, d)] - self.centers[(j, d)];
dist_sq = dist_sq + diff * diff;
}
let kernel_value = tps_kernel(dist_sq.sqrt());
result[i] = result[i] + self.coeffs[j] * kernel_value;
}
result[i] = result[i] + self.poly_coeffs[0];
for d in 0..n_dims {
result[i] = result[i] + self.poly_coeffs[d + 1] * x[(i, d)];
}
}
Ok(result)
}
pub fn with_smoothing(&self, smoothing: T) -> InterpolateResult<Self> {
if smoothing == self.smoothing {
return Ok(self.clone());
}
let y = self.get_values()?;
ThinPlateSpline::new(&self.centers.view(), &y.view(), smoothing)
}
fn get_values(&self) -> InterpolateResult<Array1<T>> {
let y = self.evaluate(&self.centers.view())?;
Ok(y)
}
}
/// Thin-plate radial basis function `φ(r) = r² · ln(r²)`.
///
/// This is twice the textbook `r² ln r`; the constant factor is absorbed by the fitted
/// weights. The result is zero whenever `r²` is zero, including positive `r` so small that
/// `r * r` underflows. That takes the `0 · ln 0` limit as 0 instead of producing NaN.
#[allow(dead_code)]
fn tps_kernel<T: Float + FromPrimitive>(r: T) -> T {
    let r_squared = r * r;
    if r_squared == T::zero() {
        return T::zero();
    }
    r_squared * r_squared.ln()
}
/// Convenience constructor that fits a [`ThinPlateSpline`] to `values` sampled at `points`.
///
/// Equivalent to `ThinPlateSpline::new(points, values, smoothing)`; see that constructor
/// for the accepted inputs and the errors it can return.
#[allow(dead_code)]
pub fn make_thinplate_interpolator<T: Float + FromPrimitive + Debug>(
    points: &ArrayView2<T>,
    values: &ArrayView1<T>,
    smoothing: T,
) -> InterpolateResult<ThinPlateSpline<T>> {
    ThinPlateSpline::<T>::new(points, values, smoothing)
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array2};

    #[test]
    fn test_thinplate_exact_fit() {
        let points = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0])
            .expect("Operation failed");
        let values = array![0.0, 1.0, 1.0, 2.0];
        let tps = ThinPlateSpline::new(&points.view(), &values.view(), 0.0);
        assert!(tps.is_ok());
        let tps = tps.expect("Operation failed");
        let result = tps.evaluate(&points.view());
        assert!(result.is_ok());
        let interpolated = result.expect("Operation failed");
        for i in 0..values.len() {
            assert!(
                (interpolated[i] - values[i]).abs() < 1e-10,
                "ThinPlateSpline should fit exactly at point {}: {} vs {}",
                i,
                interpolated[i],
                values[i]
            );
        }
    }

    #[test]
    fn test_thinplate_smoothing() {
        let points = Array2::from_shape_vec(
            (5, 2),
            vec![0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.5, 0.5],
        )
        .expect("Operation failed");
        let values = array![0.0, 1.0, 1.0, 2.0, 0.6];
        let tps_exact = ThinPlateSpline::new(&points.view(), &values.view(), 0.0);
        let tps_smooth = ThinPlateSpline::new(&points.view(), &values.view(), 0.1);
        assert!(tps_exact.is_ok());
        assert!(tps_smooth.is_ok());
        let tps_exact = tps_exact.expect("Operation failed");
        let tps_smooth = tps_smooth.expect("Operation failed");
        let result_exact = tps_exact
            .evaluate(&points.view())
            .expect("Operation failed");
        let result_smooth = tps_smooth
            .evaluate(&points.view())
            .expect("Operation failed");
        for i in 0..values.len() {
            // Zero smoothing must still interpolate the data.
            assert!(
                (result_exact[i] - values[i]).abs() < 1e-8,
                "Unsmoothed TPS should fit exactly at point {}: {} vs {}",
                i,
                result_exact[i],
                values[i]
            );
            assert!(
                (result_smooth[i] - values[i]).abs() < 0.5,
                "Smoothed TPS value at point {} should be close to original: {} vs {}",
                i,
                result_smooth[i],
                values[i]
            );
        }
        // The data are not affine, so a positive smoothing must pull at least one value off.
        assert!(
            result_smooth
                .iter()
                .zip(values.iter())
                .any(|(s, v)| (s - v).abs() > 1e-6),
            "Smoothing should change at least one fitted value"
        );
    }

    #[test]
    fn test_thinplate_rejects_invalid_input() {
        let points = Array2::from_shape_vec((3, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0])
            .expect("Operation failed");
        let values = array![0.0, 1.0, 1.0];

        let too_short = array![0.0, 1.0];
        assert!(ThinPlateSpline::new(&points.view(), &too_short.view(), 0.0).is_err());
        assert!(ThinPlateSpline::new(&points.view(), &values.view(), -1.0).is_err());

        // Two points cannot determine an affine function in 2D.
        let two_points =
            Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 1.0, 0.0]).expect("Operation failed");
        let two_values = array![0.0, 1.0];
        assert!(ThinPlateSpline::new(&two_points.view(), &two_values.view(), 0.0).is_err());

        let tps = ThinPlateSpline::new(&points.view(), &values.view(), 0.0)
            .expect("Operation failed");
        let wrong_dims =
            Array2::from_shape_vec((1, 3), vec![0.0, 0.0, 0.0]).expect("Operation failed");
        assert!(tps.evaluate(&wrong_dims.view()).is_err());
    }

    #[test]
    fn test_thinplate_with_smoothing_matches_direct_fit() {
        let points = Array2::from_shape_vec(
            (5, 2),
            vec![0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.5, 0.5],
        )
        .expect("Operation failed");
        let values = array![0.0, 1.0, 1.0, 2.0, 0.6];
        let exact = ThinPlateSpline::new(&points.view(), &values.view(), 0.0)
            .expect("Operation failed");
        let resmoothed = exact.with_smoothing(0.1).expect("Operation failed");
        let direct = ThinPlateSpline::new(&points.view(), &values.view(), 0.1)
            .expect("Operation failed");
        let probe = Array2::from_shape_vec((2, 2), vec![0.25, 0.75, 0.9, 0.1])
            .expect("Operation failed");
        let via_with = resmoothed.evaluate(&probe.view()).expect("Operation failed");
        let via_new = direct.evaluate(&probe.view()).expect("Operation failed");
        for (a, b) in via_with.iter().zip(via_new.iter()) {
            assert!(
                (a - b).abs() < 1e-8,
                "with_smoothing should match a direct fit: {} vs {}",
                a,
                b
            );
        }
    }
}