use crate::error::{InterpolateError, InterpolateResult};
use crate::traits::InterpolationFloat;
use scirs2_core::ndarray::{Array1, Array2};
/// Radial basis function kernels `φ(r)` supported by [`rbf_kernel`].
///
/// `r` is the Euclidean distance between two points. The payload of the
/// parameterized variants is the kernel's shape parameter.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum RbfKernel {
    /// Multiquadric: `φ(r) = sqrt(r² + c²)` with shape parameter `c`.
    Multiquadric(f64),
    /// Inverse multiquadric: `φ(r) = 1 / sqrt(r² + c²)`; evaluation errors when the
    /// denominator is effectively zero (e.g. `c == 0` at `r == 0`).
    InverseMultiquadric(f64),
    /// Gaussian: `φ(r) = exp(-r² / eps²)`; evaluation errors when `|eps|` is
    /// effectively zero.
    Gaussian(f64),
    /// Thin-plate spline: `φ(r) = r² ln r`, taken as `0` for `r` below `f64::EPSILON`.
    ThinPlateSpline,
    /// Linear: `φ(r) = r`.
    Linear,
    /// Cubic: `φ(r) = r³`.
    Cubic,
    /// Quintic: `φ(r) = r⁵`.
    Quintic,
}
/// Evaluates the radial basis function `kernel` at distance `r`.
///
/// See [`RbfKernel`] for the formula of each variant.
///
/// # Errors
/// - `InterpolateError::NumericalError` when the inverse-multiquadric denominator
///   `sqrt(r² + c²)` is below `f64::EPSILON`.
/// - an invalid-input error when a Gaussian kernel has `|eps| < f64::EPSILON`.
pub fn rbf_kernel(r: f64, kernel: RbfKernel) -> InterpolateResult<f64> {
    let r_sq = r * r;
    let value = match kernel {
        RbfKernel::Linear => r,
        RbfKernel::Cubic => r_sq * r,
        RbfKernel::Quintic => r_sq * r_sq * r,
        RbfKernel::Multiquadric(c) => (r_sq + c * c).sqrt(),
        RbfKernel::InverseMultiquadric(c) => {
            let root = (r_sq + c * c).sqrt();
            if root < f64::EPSILON {
                return Err(InterpolateError::NumericalError(
                    "InverseMultiquadric denominator is effectively zero".to_string(),
                ));
            }
            1.0 / root
        }
        RbfKernel::Gaussian(eps) => {
            if eps.abs() < f64::EPSILON {
                return Err(InterpolateError::invalid_input(
                    "Gaussian kernel requires eps > 0".to_string(),
                ));
            }
            (-r_sq / (eps * eps)).exp()
        }
        // r² ln r -> 0 as r -> 0, so the (undefined) origin value is pinned to 0.
        RbfKernel::ThinPlateSpline if r < f64::EPSILON => 0.0,
        RbfKernel::ThinPlateSpline => r_sq * r.ln(),
    };
    Ok(value)
}
/// Evaluates `kernel` at distance `r` for any `InterpolationFloat` type.
///
/// It converts `r` to `f64`, delegates to [`rbf_kernel`], and converts the result
/// back to `F`.
///
/// # Errors
/// - `InterpolateError::ComputationError` if either float conversion fails.
/// - any error returned by [`rbf_kernel`], propagated unchanged.
fn eval_kernel_generic<F: InterpolationFloat>(r: F, kernel: RbfKernel) -> InterpolateResult<F> {
    let radius = match r.to_f64() {
        Some(v) => v,
        None => {
            return Err(InterpolateError::ComputationError(
                "float conversion to f64 failed".to_string(),
            ))
        }
    };
    let phi = rbf_kernel(radius, kernel)?;
    match F::from_f64(phi) {
        Some(out) => Ok(out),
        None => Err(InterpolateError::ComputationError(
            "float conversion from f64 failed".to_string(),
        )),
    }
}
/// Degree of the polynomial tail appended to the RBF system for `kernel`.
///
/// - Gaussian and inverse multiquadric: 0 (constant).
/// - Multiquadric and linear: 1.
/// - Thin-plate spline and cubic: 2.
/// - Quintic: 3.
///
/// `poly_basis_at` generates at most quadratic monomials, so any degree ≥ 2
/// produces a quadratic basis.
fn poly_degree(kernel: RbfKernel) -> usize {
    match kernel {
        RbfKernel::Quintic => 3,
        RbfKernel::Cubic | RbfKernel::ThinPlateSpline => 2,
        RbfKernel::Linear | RbfKernel::Multiquadric(_) => 1,
        RbfKernel::Gaussian(_) | RbfKernel::InverseMultiquadric(_) => 0,
    }
}
/// Number of polynomial basis terms for a tail of `degree` in `dim` variables.
///
/// The count must equal the length of the vector built by `poly_basis_at`. That
/// function emits:
/// - a constant term;
/// - the `dim` linear terms when `degree >= 1`;
/// - the `dim * (dim + 1) / 2` quadratic monomials `xᵢ·xⱼ` (`i <= j`) when
///   `degree >= 2`.
///
/// Nothing of higher order is generated, so any `degree >= 2` (including the 3
/// used for `RbfKernel::Quintic`) counts as quadratic here. A smaller count would
/// make `fill_poly_row` and `evaluate_point` index past the polynomial block.
fn poly_terms(degree: usize, dim: usize) -> usize {
    let linear = if degree >= 1 { dim } else { 0 };
    let quadratic = if degree >= 2 { dim * (dim + 1) / 2 } else { 0 };
    1 + linear + quadratic
}
/// Euclidean distance between row `i` of `a` and row `j` of `b`.
///
/// # Errors
/// `InterpolateError::DimensionMismatch` if `a` and `b` have different column counts.
///
/// # Panics
/// Panics (via ndarray indexing) if `i` or `j` is out of range.
fn euclidean_dist_rows<F: InterpolationFloat>(
    a: &Array2<F>,
    i: usize,
    b: &Array2<F>,
    j: usize,
) -> InterpolateResult<F> {
    let dim = a.ncols();
    if dim != b.ncols() {
        return Err(InterpolateError::DimensionMismatch(
            "dimension mismatch in distance computation".to_string(),
        ));
    }
    let sum_sq = (0..dim).fold(F::zero(), |acc, k| {
        let delta = a[[i, k]] - b[[j, k]];
        acc + delta * delta
    });
    Ok(sum_sq.sqrt())
}
/// Euclidean distance between row `i` of `a` and the coordinate slice `pt`.
///
/// Only the first `min(a.ncols(), pt.len())` coordinates are compared, so callers
/// are expected to validate dimensions beforehand.
fn euclidean_dist_point<F: InterpolationFloat>(a: &Array2<F>, i: usize, pt: &[F]) -> F {
    let shared = a.ncols().min(pt.len());
    pt[..shared]
        .iter()
        .enumerate()
        .fold(F::zero(), |acc, (k, &p)| {
            let delta = a[[i, k]] - p;
            acc + delta * delta
        })
        .sqrt()
}
/// Evaluates the polynomial-tail basis at `point`.
///
/// The returned terms are, in order:
/// - `1`;
/// - the coordinates `xᵢ` when `degree >= 1`;
/// - the products `xᵢ·xⱼ` for `i <= j` (row-major over `i`) when `degree >= 2`.
///
/// No higher-order terms are produced.
fn poly_basis_at<F: InterpolationFloat>(point: &[F], degree: usize) -> Vec<F> {
    let dim = point.len();
    let mut terms = Vec::with_capacity(1 + dim + dim * (dim + 1) / 2);
    terms.push(F::one());
    if degree == 0 {
        return terms;
    }
    terms.extend_from_slice(point);
    if degree >= 2 {
        for (i, &xi) in point.iter().enumerate() {
            terms.extend(point[i..].iter().map(|&xj| xi * xj));
        }
    }
    terms
}
/// Writes the polynomial basis of row `point_idx` of `points` into row `row` of
/// `p_mat`, starting at column 0.
///
/// # Panics
/// Panics (via ndarray indexing) if the basis has more terms than `p_mat` has columns.
fn fill_poly_row<F: InterpolationFloat>(
    p_mat: &mut Array2<F>,
    row: usize,
    points: &Array2<F>,
    point_idx: usize,
    degree: usize,
) {
    let coords: Vec<F> = (0..points.ncols())
        .map(|axis| points[[point_idx, axis]])
        .collect();
    for (col, term) in poly_basis_at(&coords, degree).into_iter().enumerate() {
        p_mat[[row, col]] = term;
    }
}
/// Solves the dense square system `a · x = b`.
///
/// It uses Gaussian elimination with partial (row) pivoting, followed by back
/// substitution.
///
/// If every candidate pivot in a column is below `1e-14` in magnitude, `1e-10` is
/// added to the selected pivot (after the row swap) so elimination can continue.
/// This silently regularizes rank-deficient systems, for example the RBF
/// saddle-point matrices built with fewer points than polynomial terms. In that
/// case the returned `x` only approximately satisfies `a · x = b`.
///
/// # Errors
/// - an invalid-input error if `a` is not square or `b.len() != a.nrows()`.
/// - `InterpolateError::LinalgError` if a pivot's magnitude is below `1e-30`.
fn solve_linear_system<F: InterpolationFloat>(
    a: &Array2<F>,
    b: &Array1<F>,
) -> InterpolateResult<Array1<F>> {
    let n = a.nrows();
    if a.ncols() != n || b.len() != n {
        return Err(InterpolateError::invalid_input(
            "Matrix must be square and match RHS".to_string(),
        ));
    }
    let tiny = F::from_f64(1e-14).unwrap_or(F::epsilon());
    let tiny_pivot = F::from_f64(1e-30).unwrap_or(F::epsilon());
    let reg = F::from_f64(1e-10).unwrap_or(F::epsilon());

    // Augmented matrix [a | b].
    let mut aug = Array2::<F>::zeros((n, n + 1));
    for i in 0..n {
        for j in 0..n {
            aug[[i, j]] = a[[i, j]];
        }
        aug[[i, n]] = b[i];
    }

    for col in 0..n {
        // Partial pivoting: pick the row at or below `col` with the largest |entry|.
        let mut max_val = aug[[col, col]].abs();
        let mut max_row = col;
        for row in (col + 1)..n {
            let v = aug[[row, col]].abs();
            if v > max_val {
                max_val = v;
                max_row = row;
            }
        }
        if max_row != col {
            for k in 0..=n {
                aug.swap([col, k], [max_row, k]);
            }
        }
        // Regularize only after the swap, so the perturbation lands on the pivot that
        // is actually used. Applied before the swap, it would be moved below the
        // diagonal and then eliminated with a huge multiplier.
        if max_val < tiny {
            aug[[col, col]] = aug[[col, col]] + reg;
        }
        let pivot = aug[[col, col]];
        if pivot.abs() < tiny_pivot {
            return Err(InterpolateError::LinalgError(
                "Singular or near-singular RBF system matrix".to_string(),
            ));
        }
        for row in (col + 1)..n {
            let factor = aug[[row, col]] / pivot;
            for k in col..=n {
                aug[[row, k]] = aug[[row, k]] - factor * aug[[col, k]];
            }
        }
    }

    // Back substitution on the upper-triangular system.
    let mut x = Array1::<F>::zeros(n);
    for i in (0..n).rev() {
        let mut s = aug[[i, n]];
        for j in (i + 1)..n {
            s = s - aug[[i, j]] * x[j];
        }
        let diag = aug[[i, i]];
        if diag.abs() < tiny_pivot {
            return Err(InterpolateError::LinalgError(
                "Zero pivot in back substitution".to_string(),
            ));
        }
        x[i] = s / diag;
    }
    Ok(x)
}
/// Heuristic shape parameter derived from the data.
///
/// Returns the mean, over all rows of `points`, of the Euclidean distance from each
/// row to its nearest other row. It is presumably intended as `eps`/`c` for the
/// shape-parameterized kernels; confirm against callers.
///
/// Returns `1` when there are fewer than two points, or when the mean is below
/// `F::epsilon()` (e.g. all points coincide). Runs in `O(n² · d)` time.
///
/// # Errors
/// `InterpolateError::ComputationError` if the point count cannot be converted to `F`.
pub fn auto_epsilon<F: InterpolationFloat>(points: &Array2<F>) -> InterpolateResult<F> {
    let count = points.nrows();
    if count < 2 {
        return Ok(F::one());
    }
    let mut total = F::zero();
    for i in 0..count {
        let mut nearest = F::infinity();
        for j in (0..count).filter(|&j| j != i) {
            let dist = euclidean_dist_rows(points, i, points, j)?;
            if dist < nearest {
                nearest = dist;
            }
        }
        total = total + nearest;
    }
    let count_f = F::from_usize(count).ok_or_else(|| {
        InterpolateError::ComputationError("usize to float conversion failed".to_string())
    })?;
    let mean = total / count_f;
    Ok(if mean < F::epsilon() { F::one() } else { mean })
}
/// Radial basis function interpolant built by [`RbfInterpolator::new`].
///
/// Evaluation at `x` computes `Σᵢ weights[i] · φ(‖x − centers[i]‖)` plus the
/// polynomial tail `Σⱼ weights[n + j] · pⱼ(x)`, where `n = centers.nrows()`.
#[derive(Debug, Clone)]
pub struct RbfInterpolator<F: InterpolationFloat> {
    /// Training points, one per row.
    centers: Array2<F>,
    /// Solution of the fitted system: `centers.nrows()` kernel weights followed by
    /// `n_poly` polynomial coefficients.
    weights: Array1<F>,
    /// Kernel `φ` used for every center.
    kernel: RbfKernel,
    /// Polynomial tail degree, as returned by `poly_degree(kernel)`.
    degree: usize,
    /// Number of polynomial tail terms, as returned by `poly_terms(degree, dim)`.
    n_poly: usize,
    /// Number of coordinates per point.
    dim: usize,
}
impl<F: InterpolationFloat> RbfInterpolator<F> {
    /// Fits an exact RBF interpolant through `points` (one sample per row) with
    /// target `values`.
    ///
    /// Assembles the saddle-point system `[Φ P; Pᵀ 0] [w; c] = [values; 0]`:
    /// - `Φ[i][j] = φ(‖xᵢ − xⱼ‖)`;
    /// - `P` holds the polynomial basis of degree `poly_degree(kernel)` at each point.
    ///
    /// The system is then solved with `solve_linear_system`.
    ///
    /// # Errors
    /// - an invalid-input error when `values.len() != points.nrows()`.
    /// - an empty-data error when `points` has no rows.
    /// - any error from kernel evaluation or the linear solve.
    pub fn new(
        points: Array2<F>,
        values: Array1<F>,
        kernel: RbfKernel,
    ) -> InterpolateResult<Self> {
        let (n, dim) = points.dim();
        if values.len() != n {
            return Err(InterpolateError::invalid_input(format!(
                "points has {} rows but values has {} elements",
                n,
                values.len()
            )));
        }
        if n == 0 {
            return Err(InterpolateError::empty_data("RbfInterpolator"));
        }

        let degree = poly_degree(kernel);
        let n_poly = poly_terms(degree, dim);
        let size = n + n_poly;
        let mut system = Array2::<F>::zeros((size, size));

        // Kernel block Φ (top-left n × n).
        for row in 0..n {
            for col in 0..n {
                let dist = euclidean_dist_rows(&points, row, &points, col)?;
                system[[row, col]] = eval_kernel_generic(dist, kernel)?;
            }
        }

        // Polynomial blocks P (top-right) and Pᵀ (bottom-left).
        if n_poly > 0 {
            for i in 0..n {
                let coords = points.row(i).to_vec();
                for (j, term) in poly_basis_at(&coords, degree).into_iter().enumerate() {
                    system[[i, n + j]] = term;
                    system[[n + j, i]] = term;
                }
            }
        }

        // Right-hand side: the data values, then zeros for the polynomial constraints.
        let mut rhs = Array1::<F>::zeros(size);
        for (slot, &v) in rhs.iter_mut().zip(values.iter()) {
            *slot = v;
        }

        let weights = solve_linear_system(&system, &rhs)?;
        Ok(Self {
            centers: points,
            weights,
            kernel,
            degree,
            n_poly,
            dim,
        })
    }

    /// Evaluates the interpolant at a single `point`.
    ///
    /// # Errors
    /// - `InterpolateError::DimensionMismatch` if `point.len() != self.dim()`.
    /// - any kernel-evaluation error.
    pub fn evaluate_point(&self, point: &[F]) -> InterpolateResult<F> {
        if point.len() != self.dim {
            return Err(InterpolateError::DimensionMismatch(format!(
                "expected {} dimensions, got {}",
                self.dim,
                point.len()
            )));
        }
        let n_centers = self.centers.nrows();
        let mut acc = F::zero();
        for (i, &w) in self.weights.iter().take(n_centers).enumerate() {
            let dist = euclidean_dist_point(&self.centers, i, point);
            let phi = eval_kernel_generic(dist, self.kernel)?;
            acc = acc + w * phi;
        }
        if self.n_poly > 0 {
            let tail = poly_basis_at(point, self.degree);
            for (j, term) in tail.into_iter().enumerate() {
                acc = acc + self.weights[n_centers + j] * term;
            }
        }
        Ok(acc)
    }

    /// Evaluates the interpolant at every row of `query_points`.
    ///
    /// # Errors
    /// - `InterpolateError::DimensionMismatch` if the column count differs from
    ///   `self.dim()`.
    /// - the first error from [`Self::evaluate_point`].
    pub fn interpolate(&self, query_points: &Array2<F>) -> InterpolateResult<Array1<F>> {
        if query_points.ncols() != self.dim {
            return Err(InterpolateError::DimensionMismatch(format!(
                "expected {} dimensions, got {}",
                self.dim,
                query_points.ncols()
            )));
        }
        let mut out = Array1::<F>::zeros(query_points.nrows());
        for (slot, row) in out.iter_mut().zip(query_points.outer_iter()) {
            let coords: Vec<F> = row.iter().copied().collect();
            *slot = self.evaluate_point(&coords)?;
        }
        Ok(out)
    }

    /// Kernel used by this interpolant.
    pub fn kernel(&self) -> RbfKernel {
        self.kernel
    }

    /// Training points (one per row).
    pub fn centers(&self) -> &Array2<F> {
        &self.centers
    }

    /// Fitted coefficients: kernel weights followed by polynomial coefficients.
    pub fn weights(&self) -> &Array1<F> {
        &self.weights
    }

    /// Number of coordinates per point.
    pub fn dim(&self) -> usize {
        self.dim
    }
}
/// Smoothing RBF model built by [`RbfSmoothing::new`].
///
/// The kernel system is solved with `lambda` added to its diagonal; with
/// `lambda == 0` it matches the system assembled by `RbfInterpolator::new`.
/// Evaluation is delegated to `inner`.
#[derive(Debug, Clone)]
pub struct RbfSmoothing<F: InterpolationFloat> {
    /// Fitted model used for all evaluations.
    inner: RbfInterpolator<F>,
    /// Non-negative diagonal regularization applied during fitting.
    lambda: F,
}
impl<F: InterpolationFloat> RbfSmoothing<F> {
    /// Fits a smoothing RBF model to `points` (one sample per row) and `values`.
    ///
    /// The system is the same saddle-point system as `RbfInterpolator::new`, except
    /// that `lambda` is added to every diagonal entry of the kernel block before
    /// solving.
    ///
    /// # Errors
    /// - an invalid-input error when `lambda` is negative or NaN.
    /// - an invalid-input error when `values.len() != points.nrows()`.
    /// - an empty-data error when `points` has no rows.
    /// - any error from kernel evaluation or the linear solve.
    pub fn new(
        points: Array2<F>,
        values: Array1<F>,
        kernel: RbfKernel,
        lambda: F,
    ) -> InterpolateResult<Self> {
        // `>=` is false for NaN, so this rejects NaN as well as negative values. A NaN
        // lambda would otherwise poison the diagonal, and the solver's comparisons
        // would let NaN weights through without reporting an error.
        let lambda_is_valid = lambda >= F::zero();
        if !lambda_is_valid {
            return Err(InterpolateError::invalid_input(
                "regularization parameter lambda must be non-negative".to_string(),
            ));
        }
        let n = points.nrows();
        let d = points.ncols();
        if values.len() != n {
            return Err(InterpolateError::invalid_input(format!(
                "points has {} rows but values has {} elements",
                n,
                values.len()
            )));
        }
        if n == 0 {
            return Err(InterpolateError::empty_data("RbfSmoothing"));
        }

        let degree = poly_degree(kernel);
        let n_poly = poly_terms(degree, d);
        let total = n + n_poly;
        let mut mat = Array2::<F>::zeros((total, total));

        // Kernel block Φ + λI (top-left n × n).
        for i in 0..n {
            for j in 0..n {
                let r = euclidean_dist_rows(&points, i, &points, j)?;
                mat[[i, j]] = eval_kernel_generic(r, kernel)?;
            }
            mat[[i, i]] = mat[[i, i]] + lambda;
        }

        // Polynomial blocks P (top-right) and Pᵀ (bottom-left).
        if n_poly > 0 {
            let mut p_mat = Array2::<F>::zeros((n, n_poly));
            for i in 0..n {
                fill_poly_row(&mut p_mat, i, &points, i, degree);
            }
            for i in 0..n {
                for j in 0..n_poly {
                    mat[[i, n + j]] = p_mat[[i, j]];
                    mat[[n + j, i]] = p_mat[[i, j]];
                }
            }
        }

        // Right-hand side: the data values, then zeros for the polynomial constraints.
        let mut rhs = Array1::<F>::zeros(total);
        for i in 0..n {
            rhs[i] = values[i];
        }

        let weights = solve_linear_system(&mat, &rhs)?;
        let inner = RbfInterpolator {
            centers: points,
            weights,
            kernel,
            degree,
            n_poly,
            dim: d,
        };
        Ok(Self { inner, lambda })
    }

    /// Evaluates the fitted model at every row of `query_points`.
    ///
    /// # Errors
    /// Same as `RbfInterpolator::interpolate`.
    pub fn interpolate(&self, query_points: &Array2<F>) -> InterpolateResult<Array1<F>> {
        self.inner.interpolate(query_points)
    }

    /// Evaluates the fitted model at a single `point`.
    ///
    /// # Errors
    /// Same as `RbfInterpolator::evaluate_point`.
    pub fn evaluate_point(&self, point: &[F]) -> InterpolateResult<F> {
        self.inner.evaluate_point(point)
    }

    /// Regularization parameter used during fitting.
    pub fn lambda(&self) -> F {
        self.lambda
    }

    /// Underlying fitted model.
    pub fn inner(&self) -> &RbfInterpolator<F> {
        &self.inner
    }
}
/// Convenience wrapper for one-dimensional RBF interpolation.
///
/// Fits an [`RbfInterpolator`] to the samples `(x_train[i], y_train[i])` and
/// evaluates it at every element of `x_query`.
///
/// # Errors
/// - an invalid-input error when `x_train` and `y_train` differ in length.
/// - an empty-data error when there are no training samples.
/// - `InterpolateError::ShapeError` if reshaping into a column matrix fails.
/// - any error from fitting or evaluation.
pub fn rbf_1d(
    x_train: &Array1<f64>,
    y_train: &Array1<f64>,
    kernel: RbfKernel,
    x_query: &Array1<f64>,
) -> InterpolateResult<Array1<f64>> {
    let n_train = x_train.len();
    if y_train.len() != n_train {
        return Err(InterpolateError::invalid_input(format!(
            "x_train has {} elements but y_train has {}",
            n_train,
            y_train.len()
        )));
    }
    if n_train == 0 {
        return Err(InterpolateError::empty_data("rbf_1d"));
    }
    // Lift a 1-D sample vector into a (len, 1) column matrix of points.
    let as_column = |v: &Array1<f64>| {
        Array2::from_shape_vec((v.len(), 1), v.to_vec())
            .map_err(|e| InterpolateError::ShapeError(e.to_string()))
    };
    let interp = RbfInterpolator::new(as_column(x_train)?, y_train.clone(), kernel)?;
    interp.interpolate(&as_column(x_query)?)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::{Array1, Array2};

    /// Corners of the unit square plus its centre, sampled from `f(x, y) = x + y`.
    fn make_2d_data() -> (Array2<f64>, Array1<f64>) {
        let points = Array2::from_shape_vec(
            (5, 2),
            vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.5, 0.5],
        )
        .expect("shape");
        let values = Array1::from_vec(vec![0.0, 1.0, 1.0, 2.0, 1.0]);
        (points, values)
    }

    /// The integers `0..=4` sampled from `f(x) = x²`.
    fn make_1d_data() -> (Array2<f64>, Array1<f64>) {
        let points =
            Array2::from_shape_vec((5, 1), vec![0.0, 1.0, 2.0, 3.0, 4.0]).expect("shape");
        let values = Array1::from_vec(vec![0.0, 1.0, 4.0, 9.0, 16.0]);
        (points, values)
    }

    /// Fits `kernel` to `make_2d_data()` and asserts that every training value is
    /// reproduced to within `tol`.
    fn assert_interpolates_2d_centers(kernel: RbfKernel, tol: f64) {
        let (pts, vals) = make_2d_data();
        let interp =
            RbfInterpolator::new(pts.clone(), vals.clone(), kernel).expect("construction");
        for i in 0..pts.nrows() {
            let pt = vec![pts[[i, 0]], pts[[i, 1]]];
            let v = interp.evaluate_point(&pt).expect("eval");
            assert_abs_diff_eq!(v, vals[i], epsilon = tol);
        }
    }

    #[test]
    fn test_rbf_kernel_tps_zero_at_origin() {
        let v = rbf_kernel(0.0, RbfKernel::ThinPlateSpline).expect("eval");
        assert_eq!(v, 0.0);
    }

    #[test]
    fn test_rbf_kernel_tps_zero_at_unit_radius() {
        // r² ln r vanishes at r = 1 because ln 1 = 0.
        let v = rbf_kernel(1.0, RbfKernel::ThinPlateSpline).expect("eval");
        assert_abs_diff_eq!(v, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_rbf_kernel_tps_at_two() {
        let r = 2.0_f64;
        let v = rbf_kernel(r, RbfKernel::ThinPlateSpline).expect("eval");
        assert_abs_diff_eq!(v, r * r * r.ln(), epsilon = 1e-12);
    }

    #[test]
    fn test_rbf_kernel_gaussian_at_origin() {
        let v = rbf_kernel(0.0, RbfKernel::Gaussian(1.0)).expect("eval");
        assert_abs_diff_eq!(v, 1.0, epsilon = 1e-12);
    }

    #[test]
    fn test_rbf_kernel_gaussian_decay() {
        let v1 = rbf_kernel(1.0, RbfKernel::Gaussian(1.0)).expect("eval");
        let v2 = rbf_kernel(2.0, RbfKernel::Gaussian(1.0)).expect("eval");
        assert!(v1 > v2, "Gaussian should decay with distance");
        assert!(v1 > 0.0);
    }

    #[test]
    fn test_rbf_kernel_gaussian_zero_eps_error() {
        assert!(rbf_kernel(1.0, RbfKernel::Gaussian(0.0)).is_err());
    }

    #[test]
    fn test_rbf_kernel_multiquadric_at_origin() {
        let c = 1.5;
        let v = rbf_kernel(0.0, RbfKernel::Multiquadric(c)).expect("eval");
        assert_abs_diff_eq!(v, c, epsilon = 1e-12);
    }

    #[test]
    fn test_rbf_kernel_inv_multiquadric_at_origin() {
        let c = 2.0;
        let v = rbf_kernel(0.0, RbfKernel::InverseMultiquadric(c)).expect("eval");
        assert_abs_diff_eq!(v, 1.0 / c, epsilon = 1e-12);
    }

    #[test]
    fn test_rbf_kernel_linear() {
        let v = rbf_kernel(3.5, RbfKernel::Linear).expect("eval");
        assert_abs_diff_eq!(v, 3.5, epsilon = 1e-12);
    }

    #[test]
    fn test_rbf_kernel_cubic() {
        let v = rbf_kernel(2.0, RbfKernel::Cubic).expect("eval");
        assert_abs_diff_eq!(v, 8.0, epsilon = 1e-12);
    }

    #[test]
    fn test_rbf_kernel_quintic() {
        let v = rbf_kernel(2.0, RbfKernel::Quintic).expect("eval");
        assert_abs_diff_eq!(v, 32.0, epsilon = 1e-12);
    }

    #[test]
    fn test_rbf_tps_interpolates_at_centers() {
        assert_interpolates_2d_centers(RbfKernel::ThinPlateSpline, 1e-5);
    }

    #[test]
    fn test_rbf_gaussian_interpolates_at_centers() {
        assert_interpolates_2d_centers(RbfKernel::Gaussian(1.0), 1e-4);
    }

    #[test]
    fn test_rbf_multiquadric_interpolates_at_centers() {
        assert_interpolates_2d_centers(RbfKernel::Multiquadric(1.0), 1e-4);
    }

    #[test]
    fn test_rbf_inv_multiquadric_interpolates_at_centers() {
        assert_interpolates_2d_centers(RbfKernel::InverseMultiquadric(1.0), 1e-4);
    }

    #[test]
    fn test_rbf_cubic_interpolates_at_centers() {
        assert_interpolates_2d_centers(RbfKernel::Cubic, 1e-5);
    }

    #[test]
    fn test_rbf_quintic_interpolates_at_centers() {
        assert_interpolates_2d_centers(RbfKernel::Quintic, 1e-4);
    }

    #[test]
    fn test_rbf_linear_interpolates_at_centers() {
        assert_interpolates_2d_centers(RbfKernel::Linear, 1e-5);
    }

    #[test]
    fn test_rbf_batch_interpolate() {
        let (pts, vals) = make_2d_data();
        let interp = RbfInterpolator::new(pts.clone(), vals.clone(), RbfKernel::ThinPlateSpline)
            .expect("construction");
        let result = interp.interpolate(&pts).expect("batch eval");
        for i in 0..vals.len() {
            assert_abs_diff_eq!(result[i], vals[i], epsilon = 1e-5);
        }
    }

    #[test]
    fn test_rbf_dimension_mismatch() {
        let (pts, vals) = make_2d_data();
        let interp = RbfInterpolator::new(pts, vals, RbfKernel::ThinPlateSpline)
            .expect("construction");
        let bad_query = Array2::from_shape_vec((1, 3), vec![1.0, 2.0, 3.0]).expect("shape");
        assert!(interp.interpolate(&bad_query).is_err());
    }

    #[test]
    fn test_rbf_length_mismatch_error() {
        let pts = Array2::from_shape_vec((3, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0])
            .expect("shape");
        // Three points but only two values.
        let vals = Array1::from_vec(vec![0.0, 1.0]);
        assert!(RbfInterpolator::new(pts, vals, RbfKernel::Gaussian(1.0)).is_err());
    }

    #[test]
    fn test_rbf_empty_data_error() {
        let pts = Array2::<f64>::zeros((0, 2));
        let vals = Array1::<f64>::zeros(0);
        assert!(RbfInterpolator::new(pts, vals, RbfKernel::ThinPlateSpline).is_err());
    }

    #[test]
    fn test_rbf_1d_accessor() {
        let (pts, vals) = make_1d_data();
        let interp =
            RbfInterpolator::new(pts, vals, RbfKernel::ThinPlateSpline).expect("construction");
        assert_eq!(interp.dim(), 1);
        assert_eq!(interp.kernel(), RbfKernel::ThinPlateSpline);
        assert_eq!(interp.centers().nrows(), 5);
    }

    #[test]
    fn test_rbf_3d_data() {
        let points = Array2::from_shape_vec(
            (4, 3),
            vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
        )
        .expect("shape");
        let values = Array1::from_vec(vec![0.0, 1.0, 1.0, 1.0]);
        let interp = RbfInterpolator::new(points.clone(), values.clone(), RbfKernel::Gaussian(1.0))
            .expect("construction");
        for i in 0..points.nrows() {
            let pt = vec![points[[i, 0]], points[[i, 1]], points[[i, 2]]];
            let v = interp.evaluate_point(&pt).expect("eval");
            assert_abs_diff_eq!(v, values[i], epsilon = 1e-4);
        }
    }

    #[test]
    fn test_smoothing_reduces_exact_fit() {
        let (pts, vals) = make_2d_data();
        let smoother =
            RbfSmoothing::new(pts.clone(), vals.clone(), RbfKernel::ThinPlateSpline, 1e-3)
                .expect("construction");
        let result = smoother.interpolate(&pts).expect("eval");
        for i in 0..vals.len() {
            assert!((result[i] - vals[i]).abs() < 0.5, "index {i}");
        }
    }

    #[test]
    fn test_smoothing_lambda_zero_exact() {
        let (pts, vals) = make_2d_data();
        let smoother =
            RbfSmoothing::new(pts.clone(), vals.clone(), RbfKernel::ThinPlateSpline, 0.0)
                .expect("construction");
        let result = smoother.interpolate(&pts).expect("eval");
        for i in 0..vals.len() {
            assert_abs_diff_eq!(result[i], vals[i], epsilon = 1e-5);
        }
    }

    #[test]
    fn test_smoothing_negative_lambda_error() {
        let (pts, vals) = make_2d_data();
        assert!(RbfSmoothing::new(pts, vals, RbfKernel::Gaussian(1.0), -1.0).is_err());
    }

    #[test]
    fn test_smoothing_lambda_accessor() {
        let (pts, vals) = make_2d_data();
        let s = RbfSmoothing::new(pts, vals, RbfKernel::ThinPlateSpline, 0.05).expect("ok");
        assert_abs_diff_eq!(s.lambda(), 0.05, epsilon = 1e-12);
    }

    #[test]
    fn test_rbf_1d_interpolates_at_training() {
        let x_train = Array1::from_vec(vec![0.0_f64, 1.0, 2.0, 3.0, 4.0]);
        let y_train = Array1::from_vec(vec![0.0_f64, 1.0, 4.0, 9.0, 16.0]);
        let result = rbf_1d(&x_train, &y_train, RbfKernel::ThinPlateSpline, &x_train)
            .expect("rbf_1d");
        for i in 0..x_train.len() {
            assert_abs_diff_eq!(result[i], y_train[i], epsilon = 1e-4);
        }
    }

    #[test]
    fn test_rbf_1d_between_points_finite() {
        let x_train = Array1::from_vec(vec![0.0_f64, 1.0, 2.0, 3.0, 4.0]);
        let y_train = Array1::from_vec(vec![0.0_f64, 1.0, 4.0, 9.0, 16.0]);
        let x_query = Array1::from_vec(vec![0.5_f64, 1.5, 2.5, 3.5]);
        let result = rbf_1d(&x_train, &y_train, RbfKernel::Gaussian(1.0), &x_query)
            .expect("rbf_1d");
        assert!(result.iter().all(|v| v.is_finite()), "all finite");
    }

    #[test]
    fn test_rbf_1d_length_mismatch_error() {
        let x_train = Array1::from_vec(vec![0.0_f64, 1.0, 2.0]);
        // Three x samples but only two y samples.
        let y_train = Array1::from_vec(vec![0.0_f64, 1.0]);
        let x_query = Array1::from_vec(vec![0.5_f64]);
        assert!(rbf_1d(&x_train, &y_train, RbfKernel::Gaussian(1.0), &x_query).is_err());
    }

    #[test]
    fn test_rbf_1d_empty_error() {
        let x_train = Array1::<f64>::zeros(0);
        let y_train = Array1::<f64>::zeros(0);
        let x_query = Array1::from_vec(vec![0.5_f64]);
        assert!(rbf_1d(&x_train, &y_train, RbfKernel::Gaussian(1.0), &x_query).is_err());
    }

    #[test]
    fn test_auto_epsilon_positive() {
        let pts = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0])
            .expect("shape");
        let eps = auto_epsilon::<f64>(&pts).expect("auto_eps");
        assert!(eps > 0.0);
        assert!(eps.is_finite());
    }

    #[test]
    fn test_auto_epsilon_single_point() {
        let pts = Array2::from_shape_vec((1, 2), vec![0.0, 0.0]).expect("shape");
        let eps = auto_epsilon::<f64>(&pts).expect("auto_eps");
        assert_abs_diff_eq!(eps, 1.0, epsilon = 1e-12);
    }
}