use crate::interpolate::error::{InterpolateError, InterpolateResult};
use crate::interpolate::impl_generic::bspline::build_knot_vector_tensor;
use crate::interpolate::impl_generic::rect_bivariate_spline::rect_bivariate_spline_evaluate_impl;
use crate::interpolate::traits::bspline::BSplineBoundary;
use crate::interpolate::traits::rect_bivariate_spline::BivariateSpline;
use numr::algorithm::linalg::LinearAlgebraAlgorithms;
use numr::ops::{CompareOps, ScalarOps, UtilityOps};
use numr::prelude::DType;
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
use super::bspline::compute_basis_matrix;
/// Builds `n` evenly spaced samples running from `min_t` to `max_t` (both inclusive).
///
/// The endpoints are passed as tensors and the grid is assembled purely from
/// tensor operations on `client`; no element values are read back in this function.
///
/// # Arguments
/// * `client` - runtime client used for every tensor operation.
/// * `min_t` - lower endpoint; the callers in this file pass a one-element tensor.
/// * `max_t` - upper endpoint; same layout as `min_t`.
/// * `n` - number of samples. `n == 1` yields just `min_t` reshaped to `[1]`.
///   NOTE(review): `n == 0` would underflow in `n - 1`; the callers in this file
///   always pass at least `k + 1 >= 1`.
///
/// # Errors
/// Propagates any failure reported by the underlying tensor operations.
fn on_device_linspace<R, C>(
    client: &C,
    min_t: &Tensor<R>,
    max_t: &Tensor<R>,
    n: usize,
) -> InterpolateResult<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: ScalarOps<R> + UtilityOps<R> + RuntimeClient<R>,
{
    let device = client.device();

    // A single sample collapses to the lower endpoint.
    if n == 1 {
        let single = min_t.reshape(&[1])?;
        return Ok(single.contiguous()?);
    }

    // fractions[i] = i / (n - 1), i.e. 0.0 ..= 1.0 in n steps.
    let indices = client.arange(0.0, n as f64, 1.0, DType::F64)?;
    let last_index = Tensor::full_scalar(&[1], DType::F64, (n - 1) as f64, device);
    let fractions = client.div(&indices, &last_index)?;

    // Expand the span and the start value to length n so they combine element-wise.
    let span = client.sub(max_t, min_t)?;
    let span_n = span.broadcast_to(&[n])?.contiguous()?;
    let start_n = min_t.broadcast_to(&[n])?.contiguous()?;

    // grid[i] = min + (max - min) * fractions[i]
    let offsets = client.mul(&span_n, &fractions)?;
    let grid = client.add(&start_n, &offsets)?;
    Ok(grid)
}
/// Fits a tensor-product B-spline surface to scattered samples `(x[i], y[i], z[i])`
/// by (optionally weighted, optionally regularised) linear least squares.
///
/// Steps performed:
/// 1. Validate that `x`, `y`, `z` share the same leading length `m` and that
///    `m >= (kx + 1) * (ky + 1)`.
/// 2. Build a uniform grid per axis spanning the data range, with
///    `max(floor(sqrt(m)), k + 1)` points, and derive knot vectors from it via
///    `build_knot_vector_tensor` with `BSplineBoundary::NotAKnot`.
/// 3. Assemble the design matrix `A[i, iy * ncx + ix] = Bx[i, ix] * By[i, iy]`.
/// 4. If `weights` is given, scale each row of `A` and each entry of `z` by it.
/// 5. Solve with `lstsq`. When `smoothing > 0`, the system is augmented with
///    `sqrt(smoothing) * I` rows and a zero right-hand side, i.e. a ridge penalty
///    of `smoothing * ||c||^2` on the coefficients.
///
/// # Arguments
/// * `client` - runtime client used for every tensor operation.
/// * `x`, `y` - sample coordinates, length `m` along axis 0.
/// * `z` - sample values, length `m` along axis 0.
/// * `weights` - optional per-sample weights; reshaped to `[m, 1]`.
/// * `smoothing` - ridge strength; values `<= 0.0` select a plain least-squares fit.
/// * `kx`, `ky` - spline degree along each axis.
///
/// # Returns
/// A [`BivariateSpline`] holding both knot vectors, the degrees, and the
/// coefficients laid out as `[ncx, ncy]`.
///
/// # Errors
/// * `InterpolateError::ShapeMismatch` if `y` or `z` differs in length from `x`.
/// * `InterpolateError::InsufficientData` if `m < (kx + 1) * (ky + 1)`.
/// * `InterpolateError::NumericalError` if `lstsq` fails.
/// * Any error propagated from the underlying tensor operations.
///
/// # Panics
/// Indexes `shape()[0]` on each input, so rank-0 tensors are presumably not
/// supported (NOTE(review): confirm against `Tensor::shape`).
// Eight parameters mirror the public fitting API; bundling them would change callers.
#[allow(clippy::too_many_arguments)]
pub fn smooth_bivariate_spline_fit_impl<R, C>(
    client: &C,
    x: &Tensor<R>,
    y: &Tensor<R>,
    z: &Tensor<R>,
    weights: Option<&Tensor<R>>,
    smoothing: f64,
    kx: usize,
    ky: usize,
) -> InterpolateResult<BivariateSpline<R>>
where
    R: Runtime<DType = DType>,
    C: ScalarOps<R> + CompareOps<R> + UtilityOps<R> + LinearAlgebraAlgorithms<R> + RuntimeClient<R>,
{
    let device = client.device();

    // --- Input validation -------------------------------------------------
    let m = x.shape()[0];
    let y_len = y.shape()[0];
    let z_len = z.shape()[0];
    if y_len != m || z_len != m {
        // Report the length that actually disagrees with `x`. Taking the minimum
        // of the two would echo `m` itself whenever the offending tensor is the
        // longer one, producing an "expected m, actual m" message.
        let actual = if y_len != m { y_len } else { z_len };
        return Err(InterpolateError::ShapeMismatch {
            expected: m,
            actual,
            context: "smooth_bivariate_spline_fit: x, y, z must have same length".to_string(),
        });
    }

    let min_points = (kx + 1) * (ky + 1);
    if m < min_points {
        return Err(InterpolateError::InsufficientData {
            required: min_points,
            actual: m,
            context: "smooth_bivariate_spline_fit".to_string(),
        });
    }

    // --- Knot construction ------------------------------------------------
    // Roughly sqrt(m) grid points per axis, but never fewer than degree + 1.
    let max_per_axis = (m as f64).sqrt().floor() as usize;
    let nx_grid = max_per_axis.max(kx + 1);
    let ny_grid = max_per_axis.max(ky + 1);

    // Data range per axis is taken from the first/last element of the sorted
    // coordinates, keeping the extrema as one-element tensors.
    let x_sorted = client.sort(x, 0, false)?;
    let y_sorted = client.sort(y, 0, false)?;
    let x_min = x_sorted.narrow(0, 0, 1)?;
    let x_max = x_sorted.narrow(0, m - 1, 1)?;
    let y_min = y_sorted.narrow(0, 0, 1)?;
    let y_max = y_sorted.narrow(0, m - 1, 1)?;

    let x_grid = on_device_linspace(client, &x_min, &x_max, nx_grid)?;
    let y_grid = on_device_linspace(client, &y_min, &y_max, ny_grid)?;

    let knots_x =
        build_knot_vector_tensor(client, &x_grid, kx, &BSplineBoundary::NotAKnot, nx_grid)?;
    let knots_y =
        build_knot_vector_tensor(client, &y_grid, ky, &BSplineBoundary::NotAKnot, ny_grid)?;

    // Number of basis functions per axis: len(knots) - degree - 1.
    let ncx = knots_x.shape()[0] - kx - 1;
    let ncy = knots_y.shape()[0] - ky - 1;
    let n_coeffs = ncx * ncy;

    // --- Design matrix ----------------------------------------------------
    // bx is used as [m, ncx] and by as [m, ncy]; broadcasting both to
    // [m, ncy, ncx] and multiplying gives the tensor-product basis per sample.
    let bx = compute_basis_matrix(client, x, &knots_x, kx, ncx)?;
    let by = compute_basis_matrix(client, y, &knots_y, ky, ncy)?;
    let bx_expanded = bx
        .unsqueeze(1)?
        .broadcast_to(&[m, ncy, ncx])?
        .contiguous()?;
    let by_expanded = by
        .unsqueeze(2)?
        .broadcast_to(&[m, ncy, ncx])?
        .contiguous()?;
    let a_3d = client.mul(&bx_expanded, &by_expanded)?;
    let a = a_3d.reshape(&[m, n_coeffs])?;
    let z_col = z.reshape(&[m, 1])?;

    // --- Optional row weighting --------------------------------------------
    let (a_weighted, z_weighted) = if let Some(w) = weights {
        let w_col = w.reshape(&[m, 1])?;
        let w_broad = w_col.broadcast_to(&[m, n_coeffs])?.contiguous()?;
        (client.mul(&a, &w_broad)?, client.mul(&z_col, &w_col)?)
    } else {
        // Neither tensor is needed again, so hand them over instead of cloning.
        (a, z_col)
    };

    // --- Optional ridge augmentation ----------------------------------------
    let (lhs, rhs) = if smoothing <= 0.0 {
        (a_weighted, z_weighted)
    } else {
        // Stacking sqrt(lambda) * I under A (with zeros under z) makes the
        // least-squares objective ||A c - z||^2 + lambda * ||c||^2.
        let sqrt_lambda = smoothing.sqrt();
        let eye_vals = Tensor::full_scalar(&[n_coeffs], DType::F64, sqrt_lambda, device);
        let penalty = LinearAlgebraAlgorithms::diagflat(client, &eye_vals)?;
        let zero_rhs = Tensor::zeros(&[n_coeffs, 1], DType::F64, device);
        let a_aug = client.cat(&[&a_weighted, &penalty], 0)?;
        let z_aug = client.cat(&[&z_weighted, &zero_rhs], 0)?;
        (a_aug, z_aug)
    };

    let coeffs_flat = LinearAlgebraAlgorithms::lstsq(client, &lhs, &rhs).map_err(|e| {
        InterpolateError::NumericalError {
            message: format!("lstsq failed: {}", e),
        }
    })?;

    // The flat solution is ordered [iy * ncx + ix]; reshape to [ncy, ncx] and
    // transpose so the stored coefficients are indexed [ix, iy].
    let coefficients = coeffs_flat
        .reshape(&[ncy, ncx])?
        .transpose(0, 1)?
        .contiguous()?;

    Ok(BivariateSpline {
        knots_x,
        knots_y,
        coefficients,
        degree_x: kx,
        degree_y: ky,
    })
}
/// Evaluates a fitted smoothing spline at the query coordinates `xi` / `yi`.
///
/// This is a thin forwarder: all work is done by
/// `rect_bivariate_spline_evaluate_impl`, which is called with the same
/// arguments in the same order.
///
/// # Arguments
/// * `client` - runtime client used for the tensor operations.
/// * `spline` - spline produced by the fitting routine.
/// * `xi`, `yi` - query coordinates along each axis.
///
/// # Errors
/// Returns whatever error `rect_bivariate_spline_evaluate_impl` reports.
pub fn smooth_bivariate_spline_evaluate_impl<R, C>(
    client: &C,
    spline: &BivariateSpline<R>,
    xi: &Tensor<R>,
    yi: &Tensor<R>,
) -> InterpolateResult<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: ScalarOps<R> + CompareOps<R> + RuntimeClient<R>,
{
    let values = rect_bivariate_spline_evaluate_impl(client, spline, xi, yi)?;
    Ok(values)
}
#[cfg(test)]
mod tests {
    use super::*;
    use numr::runtime::cpu::{CpuClient, CpuDevice, CpuRuntime};

    /// Creates a CPU device together with a client bound to it.
    fn setup() -> (CpuDevice, CpuClient) {
        let device = CpuDevice::new();
        let client = CpuClient::new(device.clone());
        (device, client)
    }

    /// Flattened coordinates of the integer lattice `{0..side} x {0..side}`,
    /// with the x index varying slowest (point `i * side + j` is `(i, j)`).
    fn lattice(side: usize) -> (Vec<f64>, Vec<f64>) {
        (0..side)
            .flat_map(|i| (0..side).map(move |j| (i as f64, j as f64)))
            .unzip()
    }

    /// Unsmoothed cubic fit of the plane `z = x + y` on a 5x5 lattice should
    /// reproduce the plane at interior lattice points (within a loose tolerance).
    #[test]
    fn test_smooth_interpolation_exact() {
        let (device, client) = setup();
        let (xv, yv) = lattice(5);
        let n = xv.len();
        let zv: Vec<f64> = xv.iter().zip(&yv).map(|(a, b)| a + b).collect();

        let x = Tensor::<CpuRuntime>::from_slice(&xv, &[n], &device);
        let y = Tensor::<CpuRuntime>::from_slice(&yv, &[n], &device);
        let z = Tensor::<CpuRuntime>::from_slice(&zv, &[n], &device);
        let spline =
            smooth_bivariate_spline_fit_impl(&client, &x, &y, &z, None, 0.0, 3, 3).unwrap();

        let xi = Tensor::<CpuRuntime>::from_slice(&[1.0, 2.0, 3.0], &[3], &device);
        let yi = Tensor::<CpuRuntime>::from_slice(&[1.0, 2.0, 3.0], &[3], &device);
        let result = smooth_bivariate_spline_evaluate_impl(&client, &spline, &xi, &yi).unwrap();
        let vals: Vec<f64> = result.to_vec();

        // The i-th value is checked against the plane at (i + 1, i + 1).
        for (i, &v) in vals.iter().enumerate() {
            let expected = (i + 1) as f64 * 2.0;
            assert!(
                (v - expected).abs() < 0.5,
                "point {}: got {} expected {}",
                i,
                v,
                expected
            );
        }
    }

    /// A lightly smoothed fit of `z = x + y` plus fixed pseudo-noise on a 6x6
    /// lattice should stay within 1.0 of the noise-free plane at (2,2) and (3,3).
    #[test]
    fn test_smoothing_reduces_noise() {
        let (device, client) = setup();
        let (xv, yv) = lattice(6);
        let n = xv.len();

        // Deterministic perturbation, one entry per lattice point in push order.
        let noise: [f64; 36] = [
            0.3, -0.2, 0.5, -0.4, 0.1, -0.3, 0.2, -0.1, 0.4, -0.5, 0.3, -0.2, 0.1, -0.3, 0.4, -0.1,
            0.2, -0.4, 0.5, -0.2, 0.3, -0.1, 0.4, -0.3, 0.2, -0.5, 0.1, -0.2, 0.3, -0.4, 0.5, -0.1,
            0.2, -0.3, 0.4, -0.2,
        ];
        let zv_noisy: Vec<f64> = xv
            .iter()
            .zip(&yv)
            .zip(&noise)
            .map(|((a, b), e)| a + b + e)
            .collect();

        let x = Tensor::<CpuRuntime>::from_slice(&xv, &[n], &device);
        let y = Tensor::<CpuRuntime>::from_slice(&yv, &[n], &device);
        let z_noisy = Tensor::<CpuRuntime>::from_slice(&zv_noisy, &[n], &device);
        let spline_smooth =
            smooth_bivariate_spline_fit_impl(&client, &x, &y, &z_noisy, None, 0.01, 3, 3).unwrap();

        let xi = Tensor::<CpuRuntime>::from_slice(&[2.0, 3.0], &[2], &device);
        let yi = Tensor::<CpuRuntime>::from_slice(&[2.0, 3.0], &[2], &device);
        let result =
            smooth_bivariate_spline_evaluate_impl(&client, &spline_smooth, &xi, &yi).unwrap();
        let vals: Vec<f64> = result.to_vec();

        assert!(
            (vals[0] - 4.0).abs() < 1.0,
            "smoothed at (2,2): {} vs 4.0",
            vals[0]
        );
        assert!(
            (vals[1] - 6.0).abs() < 1.0,
            "smoothed at (3,3): {} vs 6.0",
            vals[1]
        );
    }
}