use crate::error::{SpatialError, SpatialResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::random::{thread_rng, Rng};
/// Solves the dense square linear system `a · x = b`.
///
/// Uses Gaussian elimination with partial pivoting on the augmented matrix
/// `[a | b]`. A jitter of `1e-10` is added to every diagonal entry first, so
/// the vector returned actually solves `(a + 1e-10 · I) · x = b` (presumably
/// to stabilise nearly singular kernel matrices).
///
/// An empty system (`n == 0`) yields an empty solution.
///
/// # Errors
///
/// * `SpatialError::DimensionError` if `a` is not square or `b.len()` differs
///   from the order of `a`.
/// * `SpatialError::ComputationError` if `a` or `b` contains NaN/infinite
///   entries, if the largest available pivot in some column is below `1e-10`
///   in magnitude (singular or nearly singular matrix), or if elimination
///   overflows to a non-finite solution.
fn solve_linear_system(a: Array2<f64>, b: Array1<f64>) -> SpatialResult<Array1<f64>> {
    // Diagonal jitter and smallest acceptable pivot magnitude.
    const TOL: f64 = 1e-10;

    let n = a.nrows();
    if n != a.ncols() {
        return Err(SpatialError::DimensionError(
            "Matrix A must be square".to_string(),
        ));
    }
    if n != b.len() {
        return Err(SpatialError::DimensionError(
            "Matrix A and vector b dimensions must match".to_string(),
        ));
    }
    // NaN compares false against everything, so without this check a NaN input
    // would slip past the pivot test below and come back as an `Ok` NaN vector.
    if a.iter().chain(b.iter()).any(|v| !v.is_finite()) {
        return Err(SpatialError::ComputationError(
            "Matrix A and vector b must contain only finite values".to_string(),
        ));
    }

    // Augmented matrix [a + TOL·I | b]; `a` is owned, so it is read directly.
    let mut aug = Array2::<f64>::zeros((n, n + 1));
    for i in 0..n {
        for j in 0..n {
            aug[[i, j]] = a[[i, j]];
        }
        aug[[i, i]] += TOL;
        aug[[i, n]] = b[i];
    }

    // Forward elimination with partial pivoting.
    for col in 0..n {
        let mut pivot_row = col;
        let mut pivot_mag = aug[[col, col]].abs();
        for row in col + 1..n {
            let mag = aug[[row, col]].abs();
            if mag > pivot_mag {
                pivot_row = row;
                pivot_mag = mag;
            }
        }
        // A NaN pivot can still arise from overflow (inf - inf) during elimination.
        if pivot_mag < TOL || pivot_mag.is_nan() {
            return Err(SpatialError::ComputationError(
                "Matrix is singular or nearly singular".to_string(),
            ));
        }
        if pivot_row != col {
            for j in 0..=n {
                aug.swap([col, j], [pivot_row, j]);
            }
        }
        for row in col + 1..n {
            let factor = aug[[row, col]] / aug[[col, col]];
            aug[[row, col]] = 0.0;
            for k in col + 1..=n {
                aug[[row, k]] -= factor * aug[[col, k]];
            }
        }
    }

    // Back substitution on the upper-triangular system.
    let mut x = Array1::<f64>::zeros(n);
    for i in (0..n).rev() {
        let mut acc = aug[[i, n]];
        for j in i + 1..n {
            acc -= aug[[i, j]] * x[j];
        }
        x[i] = acc / aug[[i, i]];
    }

    if x.iter().any(|v| !v.is_finite()) {
        return Err(SpatialError::ComputationError(
            "Linear solve produced non-finite values".to_string(),
        ));
    }
    Ok(x)
}
/// Radial basis function used by the `RBFInterpolator`.
///
/// Each variant maps a distance `r` and a shape parameter `ε` to a kernel
/// value. Only `Gaussian`, `Multiquadric` and `InverseMultiquadric` read `ε`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RBFKernel {
    /// `φ(r) = exp(-ε² r²)`.
    Gaussian,
    /// `φ(r) = sqrt(1 + (ε r)²)`.
    Multiquadric,
    /// `φ(r) = 1 / sqrt(1 + (ε r)²)`.
    InverseMultiquadric,
    /// `φ(r) = r² ln r`, taken to be `0` when `r < 1e-10`.
    ThinPlateSpline,
    /// `φ(r) = r`.
    Linear,
    /// `φ(r) = r³`.
    Cubic,
}
impl RBFKernel {
    /// Evaluates the kernel at distance `r` with shape parameter `epsilon`.
    fn apply(&self, r: f64, epsilon: f64) -> f64 {
        // `1 + (ε r)²`, the common radicand of both multiquadric variants.
        let radicand = || 1.0 + (epsilon * r).powi(2);
        match *self {
            Self::Gaussian => (-epsilon * epsilon * r * r).exp(),
            Self::Multiquadric => radicand().sqrt(),
            Self::InverseMultiquadric => 1.0 / radicand().sqrt(),
            // r² ln r → 0 as r → 0, but ln(0) would give NaN; short-circuit it.
            Self::ThinPlateSpline if r < 1e-10 => 0.0,
            Self::ThinPlateSpline => r * r * r.ln(),
            Self::Linear => r,
            Self::Cubic => r.powi(3),
        }
    }
}
/// Radial basis function (RBF) interpolator over scattered sample points.
///
/// Constructed with `RBFInterpolator::new`, which solves for the RBF weights
/// (and, optionally, linear-polynomial coefficients); evaluated with
/// `interpolate` / `interpolate_many`.
#[derive(Debug, Clone)]
pub struct RBFInterpolator {
    /// Sample locations, one point per row.
    points: Array2<f64>,
    /// Copy of the sample values passed to `new`; evaluation itself uses
    /// `weights` / `poly_coefs` instead.
    _values: Array1<f64>,
    /// Number of coordinates per point (columns of `points`).
    dim: usize,
    /// Number of sample points (rows of `points`).
    n_points: usize,
    /// Radial basis function used for both fitting and evaluation.
    kernel: RBFKernel,
    /// Shape parameter passed to the kernel.
    epsilon: f64,
    /// Whether a degree-1 polynomial term was requested at construction.
    polynomial: bool,
    /// One RBF coefficient per sample point.
    weights: Array1<f64>,
    /// `Some([c0, c1, .., c_dim])` when `polynomial` is set (constant term
    /// first, then one coefficient per coordinate); `None` otherwise.
    poly_coefs: Option<Array1<f64>>,
}
impl RBFInterpolator {
pub fn new(
points: &ArrayView2<'_, f64>,
values: &ArrayView1<f64>,
kernel: RBFKernel,
epsilon: Option<f64>,
polynomial: Option<bool>,
) -> SpatialResult<Self> {
let n_points = points.nrows();
let dim = points.ncols();
if n_points != values.len() {
return Err(SpatialError::DimensionError(format!(
"Number of points ({}) must match number of values ({})",
n_points,
values.len()
)));
}
if n_points < dim + 1 {
return Err(SpatialError::ValueError(format!(
"At least {} points required for {}D interpolation",
dim + 1,
dim
)));
}
let epsilon = epsilon.unwrap_or_else(|| Self::default_epsilon(kernel, points));
let polynomial = polynomial.unwrap_or(false);
let (weights, poly_coefs) =
Self::solve_rbf_system(points, values, kernel, epsilon, polynomial)?;
Ok(Self {
points: points.to_owned(),
_values: values.to_owned(),
dim,
n_points,
kernel,
epsilon,
polynomial,
weights,
poly_coefs,
})
}
fn default_epsilon(kernel: RBFKernel, points: &ArrayView2<'_, f64>) -> f64 {
match kernel {
RBFKernel::Gaussian => {
let avg_dist = Self::average_distance(points);
if avg_dist > 0.0 {
1.0 / (2.0 * avg_dist * avg_dist)
} else {
1.0
}
}
RBFKernel::Multiquadric | RBFKernel::InverseMultiquadric => {
let avg_dist = Self::average_distance(points);
if avg_dist > 0.0 {
1.0 / avg_dist
} else {
1.0
}
}
_ => 1.0,
}
}
fn average_distance(points: &ArrayView2<'_, f64>) -> f64 {
let n_points = points.nrows();
if n_points <= 1 {
return 0.0;
}
let max_pairs = 1000;
let mut total_dist = 0.0;
let mut n_pairs = 0;
if n_points * (n_points - 1) / 2 <= max_pairs {
for i in 0..n_points {
for j in (i + 1)..n_points {
let pi = points.row(i);
let pj = points.row(j);
total_dist += Self::euclidean_distance(&pi, &pj);
n_pairs += 1;
}
}
} else {
let mut rng = thread_rng();
let mut seen_pairs = std::collections::HashSet::new();
for _ in 0..max_pairs {
let i = rng.random_range(0..n_points);
let j = rng.random_range(0..n_points);
if i != j {
let pair = if i < j { (i, j) } else { (j, i) };
if !seen_pairs.contains(&pair) {
seen_pairs.insert(pair);
let pi = points.row(i);
let pj = points.row(j);
total_dist += Self::euclidean_distance(&pi, &pj);
n_pairs += 1;
}
}
}
}
if n_pairs > 0 {
total_dist / (n_pairs as f64)
} else {
1.0
}
}
fn solve_rbf_system(
points: &ArrayView2<'_, f64>,
values: &ArrayView1<f64>,
kernel: RBFKernel,
epsilon: f64,
polynomial: bool,
) -> SpatialResult<(Array1<f64>, Option<Array1<f64>>)> {
let n_points = points.nrows();
let dim = points.ncols();
if !polynomial {
let mut a = Array2::zeros((n_points, n_points));
for i in 0..n_points {
let pi = points.row(i);
for j in 0..n_points {
let pj = points.row(j);
let dist = Self::euclidean_distance(&pi, &pj);
a[[i, j]] = kernel.apply(dist, epsilon);
}
}
let trans_a = a.t();
let ata = trans_a.dot(&a);
let atb = trans_a.dot(&values.to_owned());
let weights = solve_linear_system(ata, atb);
match weights {
Ok(weights) => Ok((weights, None)),
Err(e) => Err(SpatialError::ComputationError(format!(
"Failed to solve RBF system: {e}"
))),
}
} else {
let poly_terms = dim + 1;
let mut aug_matrix = Array2::zeros((n_points + poly_terms, n_points + poly_terms));
let mut aug_values = Array1::zeros(n_points + poly_terms);
for i in 0..n_points {
let pi = points.row(i);
for j in 0..n_points {
let pj = points.row(j);
let dist = Self::euclidean_distance(&pi, &pj);
aug_matrix[[i, j]] = kernel.apply(dist, epsilon);
}
}
for i in 0..n_points {
aug_matrix[[i, n_points]] = 1.0;
aug_matrix[[n_points, i]] = 1.0;
for j in 0..dim {
aug_matrix[[i, n_points + 1 + j]] = points[[i, j]];
aug_matrix[[n_points + 1 + j, i]] = points[[i, j]];
}
}
for i in 0..n_points {
aug_values[i] = values[i];
}
let trans_a = aug_matrix.t();
let ata = trans_a.dot(&aug_matrix);
let atb = trans_a.dot(&aug_values);
let solution = solve_linear_system(ata, atb);
match solution {
Ok(solution) => {
let weights = solution
.slice(scirs2_core::ndarray::s![0..n_points])
.to_owned();
let poly_coefs = solution
.slice(scirs2_core::ndarray::s![n_points..])
.to_owned();
Ok((weights, Some(poly_coefs)))
}
Err(e) => Err(SpatialError::ComputationError(format!(
"Failed to solve augmented RBF system: {e}"
))),
}
}
}
pub fn interpolate(&self, point: &ArrayView1<f64>) -> SpatialResult<f64> {
if point.len() != self.dim {
return Err(SpatialError::DimensionError(format!(
"Query _point has dimension {}, expected {}",
point.len(),
self.dim
)));
}
let mut result = 0.0;
for i in 0..self.n_points {
let pi = self.points.row(i);
let dist = Self::euclidean_distance(&pi, point);
result += self.weights[i] * self.kernel.apply(dist, self.epsilon);
}
if let Some(ref poly_coefs) = self.poly_coefs {
result += poly_coefs[0];
for j in 0..self.dim {
result += poly_coefs[j + 1] * point[j];
}
}
Ok(result)
}
pub fn interpolate_many(&self, points: &ArrayView2<'_, f64>) -> SpatialResult<Array1<f64>> {
if points.ncols() != self.dim {
return Err(SpatialError::DimensionError(format!(
"Query _points have dimension {}, expected {}",
points.ncols(),
self.dim
)));
}
let n_queries = points.nrows();
let mut results = Array1::zeros(n_queries);
for i in 0..n_queries {
let point = points.row(i);
results[i] = self.interpolate(&point)?;
}
Ok(results)
}
pub fn kernel(&self) -> RBFKernel {
self.kernel
}
pub fn epsilon(&self) -> f64 {
self.epsilon
}
pub fn has_polynomial(&self) -> bool {
self.polynomial
}
fn euclidean_distance(p1: &ArrayView1<f64>, p2: &ArrayView1<f64>) -> f64 {
let mut sum_sq = 0.0;
for i in 0..p1.len().min(p2.len()) {
let diff = p1[i] - p2[i];
sum_sq += diff * diff;
}
sum_sq.sqrt()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Unit-square corners, in the same row order as `unit_square()`.
    const CORNERS: [[f64; 2]; 4] = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];

    /// Sample points at the unit-square corners, one point per row.
    fn unit_square() -> Array2<f64> {
        array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    }

    /// Evaluates `interp` at `(x, y)`, panicking if the query fails.
    fn eval_at(interp: &RBFInterpolator, x: f64, y: f64) -> f64 {
        interp
            .interpolate(&array![x, y].view())
            .expect("Operation failed")
    }

    /// Asserts that `interp` reproduces `expected[k]` at `CORNERS[k]` to within 1e-6.
    fn assert_hits_corners(interp: &RBFInterpolator, expected: [f64; 4]) {
        for (corner, &want) in CORNERS.iter().zip(expected.iter()) {
            assert_relative_eq!(eval_at(interp, corner[0], corner[1]), want, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_rbf_interpolation_basic() {
        let points = unit_square();
        let values = array![0.0, 1.0, 1.0, 2.0];
        let kernels = [
            RBFKernel::Gaussian,
            RBFKernel::Multiquadric,
            RBFKernel::InverseMultiquadric,
            RBFKernel::ThinPlateSpline,
            RBFKernel::Linear,
            RBFKernel::Cubic,
        ];
        for &kernel in kernels.iter() {
            let interp = RBFInterpolator::new(&points.view(), &values.view(), kernel, None, None)
                .expect("Operation failed");
            assert_hits_corners(&interp, [0.0, 1.0, 1.0, 2.0]);
            // Away from the nodes the value is kernel-dependent; only finiteness is checked.
            assert!(eval_at(&interp, 0.5, 0.5).is_finite());
        }
    }

    #[test]
    fn test_rbf_with_polynomial() {
        let points = unit_square();
        // Samples of the plane f(x, y) = 1 + 2x + 3y.
        let values = array![1.0, 3.0, 4.0, 6.0];
        let interp = RBFInterpolator::new(
            &points.view(),
            &values.view(),
            RBFKernel::Gaussian,
            Some(1.0),
            Some(true),
        )
        .expect("Operation failed");
        assert!(interp.has_polynomial());
        assert_hits_corners(&interp, [1.0, 3.0, 4.0, 6.0]);
        // The linear tail should extrapolate the plane: f(2, 2) = 11.
        assert_relative_eq!(eval_at(&interp, 2.0, 2.0), 11.0, epsilon = 0.1);
    }

    #[test]
    fn test_interpolate_many() {
        let points = unit_square();
        let values = array![0.0, 1.0, 1.0, 2.0];
        let interp = RBFInterpolator::new(
            &points.view(),
            &values.view(),
            RBFKernel::Gaussian,
            None,
            None,
        )
        .expect("Operation failed");
        let query_points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]];
        let results = interp
            .interpolate_many(&query_points.view())
            .expect("Operation failed");
        assert_eq!(results.len(), 5);
        // (expected value, tolerance): exact at the nodes, loose at the centre.
        let expected = [(0.0, 1e-6), (1.0, 1e-6), (1.0, 1e-6), (2.0, 1e-6), (1.0, 0.1)];
        for (&got, &(want, tol)) in results.iter().zip(expected.iter()) {
            assert_relative_eq!(got, want, epsilon = tol);
        }
    }

    #[test]
    fn test_error_handling() {
        let gaussian = |points: &Array2<f64>, values: &Array1<f64>| {
            RBFInterpolator::new(&points.view(), &values.view(), RBFKernel::Gaussian, None, None)
        };

        // A single 2-D point is below the dim + 1 minimum.
        assert!(gaussian(&array![[0.0, 0.0]], &array![0.0]).is_err());

        // Three points but only two values.
        assert!(gaussian(&array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]], &array![0.0, 1.0]).is_err());

        // A valid 2-D interpolator must reject a 3-D query.
        let interp = gaussian(
            &array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]],
            &array![0.0, 1.0, 2.0],
        )
        .expect("Operation failed");
        assert!(interp.interpolate(&array![0.0, 0.0, 0.0].view()).is_err());
    }
}