use crate::interpolate::error::{InterpolateError, InterpolateResult};
use crate::interpolate::impl_generic::bspline::{build_knot_vector_tensor, compute_basis_matrix};
use crate::interpolate::traits::bspline::BSplineBoundary;
use crate::interpolate::traits::rect_bivariate_spline::BivariateSpline;
use numr::algorithm::linalg::LinearAlgebraAlgorithms;
use numr::ops::{CompareOps, ScalarOps};
use numr::prelude::DType;
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Fits an interpolating tensor-product B-spline surface to values sampled on a
/// rectangular grid.
///
/// Knot vectors and basis matrices are built independently for each axis. The
/// two basis matrices are then combined with a Kronecker product and a single
/// linear system is solved for all coefficients at once.
///
/// # Arguments
///
/// * `client` - runtime client used for every tensor operation.
/// * `x` - grid coordinates along the first axis; `nx` is its leading dimension.
/// * `y` - grid coordinates along the second axis; `ny` is its leading dimension.
/// * `z` - sampled values; must be a 2-D tensor of shape `[nx, ny]`.
/// * `degree_x`, `degree_y` - spline degree along each axis.
/// * `boundary` - boundary condition forwarded to the knot-vector builder for
///   both axes.
///
/// # Returns
///
/// A [`BivariateSpline`] holding both knot vectors, both degrees and the
/// coefficient tensor laid out as `[ncx, ncy]`.
///
/// # Errors
///
/// * [`InterpolateError::ShapeMismatch`] if `z` is not `[nx, ny]`.
/// * [`InterpolateError::InsufficientData`] if an axis has no more points than
///   its degree. The payload always describes the first failing axis (x is
///   checked before y).
/// * [`InterpolateError::NumericalError`] if the linear solve fails.
/// * Any error propagated from knot/basis construction or tensor reshaping.
#[allow(clippy::too_many_arguments)]
pub fn rect_bivariate_spline_fit_impl<R, C>(
    client: &C,
    x: &Tensor<R>,
    y: &Tensor<R>,
    z: &Tensor<R>,
    degree_x: usize,
    degree_y: usize,
    boundary: &BSplineBoundary,
) -> InterpolateResult<BivariateSpline<R>>
where
    R: Runtime<DType = DType>,
    C: ScalarOps<R> + CompareOps<R> + LinearAlgebraAlgorithms<R> + RuntimeClient<R>,
{
    let nx = x.shape()[0];
    let ny = y.shape()[0];

    if z.shape().len() != 2 || z.shape()[0] != nx || z.shape()[1] != ny {
        return Err(InterpolateError::ShapeMismatch {
            expected: nx * ny,
            actual: z.shape().iter().product(),
            context: "rect_bivariate_spline_fit: z must be [nx, ny]".to_string(),
        });
    }

    // Validate each axis on its own so that `required` and `actual` always
    // refer to the axis that is actually too short, instead of mixing the
    // larger requirement of one axis with the smaller point count of the other.
    if nx <= degree_x {
        return Err(InterpolateError::InsufficientData {
            required: degree_x + 1,
            actual: nx,
            context: "rect_bivariate_spline_fit: need at least degree_x+1 points along x"
                .to_string(),
        });
    }
    if ny <= degree_y {
        return Err(InterpolateError::InsufficientData {
            required: degree_y + 1,
            actual: ny,
            context: "rect_bivariate_spline_fit: need at least degree_y+1 points along y"
                .to_string(),
        });
    }

    let knots_x = build_knot_vector_tensor(client, x, degree_x, boundary, nx)?;
    let knots_y = build_knot_vector_tensor(client, y, degree_y, boundary, ny)?;

    // A knot vector of length `len` with degree `k` carries `len - k - 1`
    // coefficients.
    let ncx = knots_x.shape()[0] - degree_x - 1;
    let ncy = knots_y.shape()[0] - degree_y - 1;

    let bx = compute_basis_matrix(client, x, &knots_x, degree_x, ncx)?;
    let by = compute_basis_matrix(client, y, &knots_y, degree_y, ncy)?;

    // Combined design matrix for the whole grid.
    // NOTE(review): assumes `kron` uses the standard Kronecker ordering, so the
    // x index varies fastest in both rows and columns — confirm against numr.
    let a = LinearAlgebraAlgorithms::kron(client, &by, &bx)?;

    // Flatten z with the x index varying fastest to match the row ordering of `a`.
    let z_t = z.transpose(0, 1)?.contiguous()?;
    let z_flat = z_t.reshape(&[nx * ny, 1])?;

    let coeffs_flat = LinearAlgebraAlgorithms::solve(client, &a, &z_flat).map_err(|e| {
        InterpolateError::NumericalError {
            message: format!("Failed to solve bivariate spline system: {}", e),
        }
    })?;

    // The solution is ordered y-major; reshape and transpose back to [ncx, ncy].
    let coefficients = coeffs_flat
        .reshape(&[ncy, ncx])?
        .transpose(0, 1)?
        .contiguous()?;

    Ok(BivariateSpline {
        knots_x,
        knots_y,
        coefficients,
        degree_x,
        degree_y,
    })
}
/// Evaluates the spline at paired query points `(xi[k], yi[k])`.
///
/// # Arguments
///
/// * `client` - runtime client used for the tensor operations.
/// * `spline` - fitted spline (knots, degrees and coefficient tensor).
/// * `xi`, `yi` - query coordinates; their leading dimensions must agree.
///
/// # Errors
///
/// Returns [`InterpolateError::ShapeMismatch`] when `xi` and `yi` differ in
/// length, and propagates any failure of the basis construction or tensor ops.
pub fn rect_bivariate_spline_evaluate_impl<R, C>(
    client: &C,
    spline: &BivariateSpline<R>,
    xi: &Tensor<R>,
    yi: &Tensor<R>,
) -> InterpolateResult<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: ScalarOps<R> + CompareOps<R> + RuntimeClient<R>,
{
    let n_points = xi.shape()[0];
    let n_y_points = yi.shape()[0];
    if n_y_points != n_points {
        return Err(InterpolateError::ShapeMismatch {
            expected: n_points,
            actual: n_y_points,
            context: "rect_bivariate_spline_evaluate: xi and yi must have same length".to_string(),
        });
    }

    // Number of coefficients per axis, recovered from knot count and degree.
    let n_coef_x = spline.knots_x.shape()[0] - spline.degree_x - 1;
    let n_coef_y = spline.knots_y.shape()[0] - spline.degree_y - 1;

    let basis_x = compute_basis_matrix(client, xi, &spline.knots_x, spline.degree_x, n_coef_x)?;
    let basis_y = compute_basis_matrix(client, yi, &spline.knots_y, spline.degree_y, n_coef_y)?;

    // Contract the x basis with the coefficients first, then pair every row
    // with the matching row of the y basis and reduce over axis 1.
    // NOTE(review): the trailing `false` is presumably a keep-dim flag — confirm.
    let x_contracted = client.matmul(&basis_x, &spline.coefficients)?;
    let pointwise = client.mul(&x_contracted, &basis_y)?;
    Ok(client.sum(&pointwise, &[1], false)?)
}
/// Evaluates the spline on the full outer-product grid spanned by `xi` and `yi`.
///
/// Unlike the point-wise evaluator, `xi` and `yi` may have different lengths;
/// the result is the matrix product `Bx · C · Byᵀ`, with one row per entry of
/// `xi` and one column per entry of `yi`.
///
/// # Errors
///
/// Propagates any failure of the basis construction or tensor operations.
pub fn rect_bivariate_spline_evaluate_grid_impl<R, C>(
    client: &C,
    spline: &BivariateSpline<R>,
    xi: &Tensor<R>,
    yi: &Tensor<R>,
) -> InterpolateResult<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: ScalarOps<R> + CompareOps<R> + RuntimeClient<R>,
{
    // Number of coefficients per axis, recovered from knot count and degree.
    let n_coef_x = spline.knots_x.shape()[0] - spline.degree_x - 1;
    let n_coef_y = spline.knots_y.shape()[0] - spline.degree_y - 1;

    let basis_x = compute_basis_matrix(client, xi, &spline.knots_x, spline.degree_x, n_coef_x)?;
    let basis_y = compute_basis_matrix(client, yi, &spline.knots_y, spline.degree_y, n_coef_y)?;

    let x_contracted = client.matmul(&basis_x, &spline.coefficients)?;
    let basis_y_t = basis_y.transpose(0, 1)?.contiguous()?;
    Ok(client.matmul(&x_contracted, &basis_y_t)?)
}
/// Evaluates the mixed partial derivative of order `dx` in x and `dy` in y at
/// paired query points `(xi[k], yi[k])`.
///
/// The coefficient tensor is differentiated along x and then along y, and the
/// resulting lower-degree spline is evaluated point-wise.
///
/// # Arguments
///
/// * `client` - runtime client used for the tensor operations.
/// * `spline` - fitted spline to differentiate.
/// * `xi`, `yi` - query coordinates; their leading dimensions must agree.
/// * `dx`, `dy` - derivative order along each axis.
///
/// # Errors
///
/// Returns [`InterpolateError::ShapeMismatch`] when `xi` and `yi` differ in
/// length, and propagates any failure of the underlying tensor operations.
pub fn rect_bivariate_spline_partial_derivative_impl<R, C>(
    client: &C,
    spline: &BivariateSpline<R>,
    xi: &Tensor<R>,
    yi: &Tensor<R>,
    dx: usize,
    dy: usize,
) -> InterpolateResult<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: ScalarOps<R> + CompareOps<R> + RuntimeClient<R>,
{
    let n_points = xi.shape()[0];
    let device = client.device();

    let n_y_points = yi.shape()[0];
    if n_y_points != n_points {
        return Err(InterpolateError::ShapeMismatch {
            expected: n_points,
            actual: n_y_points,
            context: "rect_bivariate_spline_partial_derivative".to_string(),
        });
    }

    // Differentiating more often than the degree allows yields the zero function.
    // NOTE(review): the zero tensor is always F64, regardless of the dtype of
    // the inputs — confirm this is intended for non-F64 splines.
    let orders_within_degree = dx <= spline.degree_x && dy <= spline.degree_y;
    if !orders_within_degree {
        return Ok(Tensor::zeros(&[n_points], DType::F64, device));
    }

    // Differentiate along x first, then feed the result into the y pass.
    let (dknots_x, coeffs_after_x, ddegree_x) = differentiate_2d_x(
        client,
        &spline.knots_x,
        &spline.coefficients,
        spline.degree_x,
        dx,
    )?;
    let (dknots_y, dcoeffs, ddegree_y) =
        differentiate_2d_y(client, &spline.knots_y, &coeffs_after_x, spline.degree_y, dy)?;

    let n_coef_x = dknots_x.shape()[0] - ddegree_x - 1;
    let n_coef_y = dknots_y.shape()[0] - ddegree_y - 1;

    let basis_x = compute_basis_matrix(client, xi, &dknots_x, ddegree_x, n_coef_x)?;
    let basis_y = compute_basis_matrix(client, yi, &dknots_y, ddegree_y, n_coef_y)?;

    // Same point-wise contraction as plain evaluation, on the derivative spline.
    let x_contracted = client.matmul(&basis_x, &dcoeffs)?;
    let pointwise = client.mul(&x_contracted, &basis_y)?;
    Ok(client.sum(&pointwise, &[1], false)?)
}
/// Integrates the spline over the rectangle `[xa, xb] x [ya, yb]`.
///
/// Each axis contributes a vector of per-basis-function integrals; the double
/// integral is the bilinear form `ix · C · iy` of those vectors with the
/// coefficient tensor.
///
/// # Returns
///
/// A tensor of shape `[1]` holding the value of the integral.
///
/// # Errors
///
/// Propagates any failure of `integrate_basis` or the tensor operations.
pub fn rect_bivariate_spline_integrate_impl<R, C>(
    client: &C,
    spline: &BivariateSpline<R>,
    xa: f64,
    xb: f64,
    ya: f64,
    yb: f64,
) -> InterpolateResult<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: ScalarOps<R> + CompareOps<R> + LinearAlgebraAlgorithms<R> + RuntimeClient<R>,
{
    // Number of coefficients per axis, recovered from knot count and degree.
    let n_coef_x = spline.knots_x.shape()[0] - spline.degree_x - 1;
    let n_coef_y = spline.knots_y.shape()[0] - spline.degree_y - 1;

    let weights_x = integrate_basis(client, &spline.knots_x, spline.degree_x, n_coef_x, xa, xb)?;
    let weights_y = integrate_basis(client, &spline.knots_y, spline.degree_y, n_coef_y, ya, yb)?;

    // Shape the weights as a row and a column so two matmuls give a 1x1 result.
    let row = weights_x.reshape(&[1, n_coef_x])?;
    let col = weights_y.reshape(&[n_coef_y, 1])?;

    let x_integrated = client.matmul(&row, &spline.coefficients)?;
    let total = client.matmul(&x_integrated, &col)?;
    Ok(total.reshape(&[1])?)
}
/// Differentiates a 2-D coefficient tensor `order` times along its first axis.
///
/// Every step applies the B-spline derivative recurrence
/// `c'_i = k * (c_{i+1} - c_i) / (t_{i+k+1} - t_{i+1})`, drops the first and
/// last knot, and lowers the degree by one. Differentiation stops once the
/// degree reaches zero, so at most `degree_x` steps are performed.
///
/// # Returns
///
/// `(knots, coefficients, degree)` of the differentiated spline. With
/// `order == 0` the inputs are returned as clones.
///
/// # Errors
///
/// Propagates any failure of the underlying tensor operations.
fn differentiate_2d_x<R, C>(
    client: &C,
    knots_x: &Tensor<R>,
    coefficients: &Tensor<R>,
    degree_x: usize,
    order: usize,
) -> InterpolateResult<(Tensor<R>, Tensor<R>, usize)>
where
    R: Runtime<DType = DType>,
    C: ScalarOps<R> + CompareOps<R> + RuntimeClient<R>,
{
    if order == 0 {
        return Ok((knots_x.clone(), coefficients.clone(), degree_x));
    }

    let n_cols = coefficients.shape()[1];
    // The degree drops by one per step and cannot go below zero.
    let steps = order.min(degree_x);

    let mut knots = knots_x.clone();
    let mut coeffs = coefficients.clone();

    for step in 0..steps {
        let k = degree_x - step;
        let n_knots = knots.shape()[0];
        let n_diff = coeffs.shape()[0] - 1;

        // Forward difference of neighbouring coefficient rows.
        let upper_rows = coeffs.narrow(0, 1, n_diff)?.contiguous()?;
        let lower_rows = coeffs.narrow(0, 0, n_diff)?.contiguous()?;
        let coeff_diff = client.sub(&upper_rows, &lower_rows)?;

        // Matching knot spans t_{i+k+1} - t_{i+1}.
        let upper_knots = knots.narrow(0, k + 1, n_diff)?.contiguous()?;
        let lower_knots = knots.narrow(0, 1, n_diff)?.contiguous()?;
        let span = client.sub(&upper_knots, &lower_knots)?;
        let span_2d = span
            .reshape(&[n_diff, 1])?
            .broadcast_to(&[n_diff, n_cols])?
            .contiguous()?;

        // Guard against zero-width spans (repeated knots): divide by a tiny
        // floor instead of zero, then zero those entries out with the mask.
        let floor = Tensor::full_scalar(&[n_diff, n_cols], DType::F64, 1e-300, client.device());
        let span_abs = client.abs(&span_2d)?;
        let safe_span = client.maximum(&span_abs, &floor)?;
        let zeros = Tensor::zeros(&[n_diff, n_cols], DType::F64, client.device());
        let nonzero_mask = client.gt(&span_abs, &zeros)?;

        let ratio = client.div(&coeff_diff, &safe_span)?;
        let masked_ratio = client.mul(&ratio, &nonzero_mask)?;
        let next_coeffs = client.mul_scalar(&masked_ratio, k as f64)?;

        // The derivative spline drops the first and the last knot.
        let next_knots = knots.narrow(0, 1, n_knots - 2)?.contiguous()?;

        coeffs = next_coeffs;
        knots = next_knots;
    }

    Ok((knots, coeffs, degree_x - steps))
}
/// Differentiates a 2-D coefficient tensor `order` times along its second axis.
///
/// Implemented by transposing the coefficients, reusing the first-axis routine
/// [`differentiate_2d_x`], and transposing the result back.
///
/// # Returns
///
/// `(knots, coefficients, degree)` of the differentiated spline. With
/// `order == 0` the inputs are returned as clones.
///
/// # Errors
///
/// Propagates any failure of the transposes or of `differentiate_2d_x`.
fn differentiate_2d_y<R, C>(
    client: &C,
    knots_y: &Tensor<R>,
    coefficients: &Tensor<R>,
    degree_y: usize,
    order: usize,
) -> InterpolateResult<(Tensor<R>, Tensor<R>, usize)>
where
    R: Runtime<DType = DType>,
    C: ScalarOps<R> + CompareOps<R> + RuntimeClient<R>,
{
    match order {
        0 => Ok((knots_y.clone(), coefficients.clone(), degree_y)),
        _ => {
            // Bring the y axis to the front so the x routine can act on it.
            let y_major = coefficients.transpose(0, 1)?.contiguous()?;
            let (dknots, dcoeffs_y_major, ddegree) =
                differentiate_2d_x(client, knots_y, &y_major, degree_y, order)?;
            let dcoeffs = dcoeffs_y_major.transpose(0, 1)?.contiguous()?;
            Ok((dknots, dcoeffs, ddegree))
        }
    }
}
/// Computes the definite integral over `[a, b]` of every basis function.
///
/// The antiderivative of each basis function is expressed as a spline of
/// degree `degree + 1` on a knot vector extended by one repeated knot at each
/// end. Those antiderivatives are evaluated at `b` and `a`, and the difference
/// is returned.
///
/// # Arguments
///
/// * `knots` - knot vector of the spline.
/// * `degree` - spline degree.
/// * `n_coeffs` - number of basis functions (coefficients).
/// * `a`, `b` - integration limits.
///
/// # Returns
///
/// A tensor of shape `[n_coeffs]` whose `i`-th entry is the integral of the
/// `i`-th basis function.
///
/// # Errors
///
/// Propagates any failure of the underlying tensor operations.
fn integrate_basis<R, C>(
    client: &C,
    knots: &Tensor<R>,
    degree: usize,
    n_coeffs: usize,
    a: f64,
    b: f64,
) -> InterpolateResult<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: ScalarOps<R> + CompareOps<R> + LinearAlgebraAlgorithms<R> + RuntimeClient<R>,
{
    let device = client.device();
    let anti_degree = degree + 1;
    let knot_count = knots.shape()[0];

    // Pad the knot vector with one extra copy of its first and last knot.
    let head = knots.narrow(0, 0, 1)?.contiguous()?;
    let tail = knots.narrow(0, knot_count - 1, 1)?.contiguous()?;
    let padded_knots = client.cat(&[&head, knots, &tail], 0)?;
    let n_anti_coeffs = padded_knots.shape()[0] - anti_degree - 1;

    // Support width of each basis function, scaled by 1 / (degree + 1).
    let support_end = knots.narrow(0, anti_degree, n_coeffs)?.contiguous()?;
    let support_start = knots.narrow(0, 0, n_coeffs)?.contiguous()?;
    let support = client.sub(&support_end, &support_start)?;
    let scaled_support = client.mul_scalar(&support, 1.0 / anti_degree as f64)?;

    // Column i holds the antiderivative coefficients of basis function i:
    // a cumulative sum down the diagonal matrix, preceded by a zero row.
    let diagonal = LinearAlgebraAlgorithms::diagflat(client, &scaled_support)?;
    let running_sum = client.cumsum(&diagonal, 0)?;
    let leading_zeros = Tensor::zeros(&[1, n_coeffs], DType::F64, device);
    let anti_coeffs = client.cat(&[&leading_zeros, &running_sum], 0)?;

    // Evaluate all antiderivatives at b (row 0) and a (row 1) in one pass.
    let limits = Tensor::from_slice(&[b, a], &[2], device);
    let limit_basis =
        compute_basis_matrix(client, &limits, &padded_knots, anti_degree, n_anti_coeffs)?;
    let anti_values = client.matmul(&limit_basis, &anti_coeffs)?;

    let at_b = anti_values.narrow(0, 0, 1)?.contiguous()?.reshape(&[n_coeffs])?;
    let at_a = anti_values.narrow(0, 1, 1)?.contiguous()?.reshape(&[n_coeffs])?;
    Ok(client.sub(&at_b, &at_a)?)
}
#[cfg(test)]
mod tests {
    use super::*;
    use numr::runtime::cpu::{CpuClient, CpuDevice, CpuRuntime};

    /// Creates a CPU device together with a client bound to it.
    fn setup() -> (CpuDevice, CpuClient) {
        let device = CpuDevice::new();
        let client = CpuClient::new(device.clone());
        (device, client)
    }

    /// Samples `f(i, j)` over an `nx` x `ny` index grid in row-major order.
    fn sample_grid(nx: usize, ny: usize, f: impl Fn(usize, usize) -> f64) -> Vec<f64> {
        let mut values = Vec::with_capacity(nx * ny);
        for i in 0..nx {
            for j in 0..ny {
                values.push(f(i, j));
            }
        }
        values
    }

    /// A degree-1 fit must reproduce the bilinear surface z = x + 2y.
    #[test]
    fn test_bilinear_exact() {
        let (device, client) = setup();
        let x = Tensor::<CpuRuntime>::from_slice(&[0.0, 1.0, 2.0, 3.0], &[4], &device);
        let y = Tensor::<CpuRuntime>::from_slice(&[0.0, 1.0, 2.0], &[3], &device);
        let z_data = sample_grid(4, 3, |i, j| i as f64 + 2.0 * j as f64);
        let z = Tensor::<CpuRuntime>::from_slice(&z_data, &[4, 3], &device);

        let spline =
            rect_bivariate_spline_fit_impl(&client, &x, &y, &z, 1, 1, &BSplineBoundary::NotAKnot)
                .expect("fit failed");

        let xi = Tensor::<CpuRuntime>::from_slice(&[0.5, 1.5, 2.5], &[3], &device);
        let yi = Tensor::<CpuRuntime>::from_slice(&[0.5, 1.0, 1.5], &[3], &device);
        let result = rect_bivariate_spline_evaluate_impl(&client, &spline, &xi, &yi).unwrap();

        let vals: Vec<f64> = result.to_vec();
        let expected = [0.5 + 1.0, 1.5 + 2.0, 2.5 + 3.0];
        for (i, (got, want)) in vals.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - want).abs() < 1e-8,
                "point {}: got {} expected {}",
                i,
                got,
                want
            );
        }
    }

    /// Grid evaluation of a cubic fit to z = x * y has the right shape and
    /// approximately the right corner values.
    #[test]
    fn test_grid_evaluation() {
        let (device, client) = setup();
        let x = Tensor::<CpuRuntime>::from_slice(&[0.0, 1.0, 2.0, 3.0], &[4], &device);
        let y = Tensor::<CpuRuntime>::from_slice(&[0.0, 1.0, 2.0, 3.0], &[4], &device);
        let z_data = sample_grid(4, 4, |i, j| (i as f64) * (j as f64));
        let z = Tensor::<CpuRuntime>::from_slice(&z_data, &[4, 4], &device);

        let spline =
            rect_bivariate_spline_fit_impl(&client, &x, &y, &z, 3, 3, &BSplineBoundary::NotAKnot)
                .expect("fit failed");

        let xi = Tensor::<CpuRuntime>::from_slice(&[0.5, 1.5], &[2], &device);
        let yi = Tensor::<CpuRuntime>::from_slice(&[0.5, 1.5, 2.5], &[3], &device);
        let grid = rect_bivariate_spline_evaluate_grid_impl(&client, &spline, &xi, &yi).unwrap();
        assert_eq!(grid.shape(), &[2, 3]);

        let vals: Vec<f64> = grid.to_vec();
        let first = vals[0];
        let last = vals[5];
        assert!((first - 0.25).abs() < 0.1, "grid[0,0]: {} vs 0.25", first);
        assert!((last - 3.75).abs() < 0.1, "grid[1,2]: {} vs 3.75", last);
    }

    /// First partial derivatives of a cubic fit to z = x^2 + y^2 are close to
    /// the analytic values 2x and 2y.
    #[test]
    fn test_partial_derivative() {
        let (device, client) = setup();
        let x = Tensor::<CpuRuntime>::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0], &[5], &device);
        let y = Tensor::<CpuRuntime>::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0], &[5], &device);
        let z_data = sample_grid(5, 5, |i, j| (i * i + j * j) as f64);
        let z = Tensor::<CpuRuntime>::from_slice(&z_data, &[5, 5], &device);

        let spline =
            rect_bivariate_spline_fit_impl(&client, &x, &y, &z, 3, 3, &BSplineBoundary::NotAKnot)
                .expect("fit failed");

        // d/dx at (2, 1)
        let xi = Tensor::<CpuRuntime>::from_slice(&[2.0], &[1], &device);
        let yi = Tensor::<CpuRuntime>::from_slice(&[1.0], &[1], &device);
        let dzdx = rect_bivariate_spline_partial_derivative_impl(&client, &spline, &xi, &yi, 1, 0)
            .unwrap();
        let dzdx_vals: Vec<f64> = dzdx.to_vec();
        let slope_x = dzdx_vals[0];
        assert!(
            (slope_x - 4.0).abs() < 0.5,
            "dz/dx at (2,1): {} vs 4.0",
            slope_x
        );

        // d/dy at (1, 2)
        let xi2 = Tensor::<CpuRuntime>::from_slice(&[1.0], &[1], &device);
        let yi2 = Tensor::<CpuRuntime>::from_slice(&[2.0], &[1], &device);
        let dzdy =
            rect_bivariate_spline_partial_derivative_impl(&client, &spline, &xi2, &yi2, 0, 1)
                .unwrap();
        let dzdy_vals: Vec<f64> = dzdy.to_vec();
        let slope_y = dzdy_vals[0];
        assert!(
            (slope_y - 4.0).abs() < 0.5,
            "dz/dy at (1,2): {} vs 4.0",
            slope_y
        );
    }

    /// Integrating the constant surface z = 1 over [0, 3] x [0, 2] gives the area 6.
    #[test]
    fn test_integrate_constant() {
        let (device, client) = setup();
        let x = Tensor::<CpuRuntime>::from_slice(&[0.0, 1.0, 2.0, 3.0], &[4], &device);
        let y = Tensor::<CpuRuntime>::from_slice(&[0.0, 1.0, 2.0], &[3], &device);
        let z = Tensor::<CpuRuntime>::from_slice(&[1.0f64; 12], &[4, 3], &device);

        let spline =
            rect_bivariate_spline_fit_impl(&client, &x, &y, &z, 1, 1, &BSplineBoundary::NotAKnot)
                .expect("fit failed");

        let result =
            rect_bivariate_spline_integrate_impl(&client, &spline, 0.0, 3.0, 0.0, 2.0).unwrap();
        let val: Vec<f64> = result.to_vec();
        assert!((val[0] - 6.0).abs() < 1e-6, "integral: {} vs 6.0", val[0]);
    }
}