use crate::error::{InterpolateError, InterpolateResult};
/// Bicubic Hermite interpolant over a rectilinear 2-D grid.
///
/// Construct with `BicubicInterp::new`, which validates the axes and the
/// sample table and precomputes one 4x4 coefficient matrix per grid cell.
#[derive(Debug, Clone)]
pub struct BicubicInterp {
/// Grid nodes along x (checked by `new` to be strictly increasing).
// NOTE(review): `x_grid`/`y_grid` are public and mutable, but `coefficients`,
// `nx` and `ny` are derived from them at construction and are not rebuilt if
// a caller edits them afterwards.
pub x_grid: Vec<f64>,
/// Grid nodes along y (checked by `new` to be strictly increasing).
pub y_grid: Vec<f64>,
/// Per-cell coefficients: `coefficients[i][j][k][l]` multiplies
/// `tx^k * ty^l` on cell `[x_grid[i], x_grid[i+1]] x [y_grid[j], y_grid[j+1]]`,
/// where `tx`, `ty` are the local coordinates in `[0, 1]`.
coefficients: Vec<Vec<[[f64; 4]; 4]>>,
/// Cached `x_grid.len()`.
nx: usize,
/// Cached `y_grid.len()`.
ny: usize,
}
impl BicubicInterp {
    /// Builds a bicubic Hermite interpolant from samples on a rectilinear grid.
    ///
    /// `values[i][j]` is the sample at `(x_grid[i], y_grid[j])`. The partial
    /// derivatives `f_x`, `f_y` and `f_xy` at every node are estimated with
    /// 4-point Lagrange stencils. One 4x4 coefficient matrix is then
    /// precomputed per grid cell, so [`interpolate`](Self::interpolate) only
    /// has to locate the cell and evaluate a polynomial.
    ///
    /// # Errors
    ///
    /// * `insufficient_points` if either axis has fewer than 4 nodes (the
    ///   derivative stencil needs four points).
    /// * `invalid_input` if either axis contains a NaN or infinite value, or
    ///   is not strictly increasing.
    /// * `dimension_mismatch` if `values` does not have `x_grid.len()` rows of
    ///   `y_grid.len()` entries each.
    pub fn new(
        x_grid: Vec<f64>,
        y_grid: Vec<f64>,
        values: Vec<Vec<f64>>,
    ) -> InterpolateResult<Self> {
        let nx = x_grid.len();
        let ny = y_grid.len();
        if nx < 4 {
            return Err(InterpolateError::insufficient_points(
                4,
                nx,
                "BicubicInterp x_grid",
            ));
        }
        if ny < 4 {
            return Err(InterpolateError::insufficient_points(
                4,
                ny,
                "BicubicInterp y_grid",
            ));
        }
        Self::validate_axis(&x_grid, "x_grid")?;
        Self::validate_axis(&y_grid, "y_grid")?;
        if values.len() != nx {
            return Err(InterpolateError::dimension_mismatch(
                nx,
                values.len(),
                "BicubicInterp: values row count vs nx",
            ));
        }
        for (i, row) in values.iter().enumerate() {
            if row.len() != ny {
                return Err(InterpolateError::dimension_mismatch(
                    ny,
                    row.len(),
                    &format!("BicubicInterp: values row {} length vs ny", i),
                ));
            }
        }
        let coefficients = Self::build_coefficients(&x_grid, &y_grid, &values, nx, ny);
        Ok(Self {
            x_grid,
            y_grid,
            coefficients,
            nx,
            ny,
        })
    }

    /// Checks that one axis holds only finite values in strictly increasing order.
    ///
    /// `name` (e.g. `"x_grid"`) is spliced into the error message.
    fn validate_axis(grid: &[f64], name: &str) -> InterpolateResult<()> {
        // Finiteness is checked first. Every comparison involving NaN is false,
        // so a NaN node would otherwise slip past the `<=` ordering test below.
        // An infinite node would produce NaN cell widths (inf - inf).
        if let Some(i) = grid.iter().position(|v| !v.is_finite()) {
            return Err(InterpolateError::invalid_input(format!(
                "BicubicInterp: {} has non-finite value at index {}: {}",
                name, i, grid[i]
            )));
        }
        if let Some((k, pair)) = grid
            .windows(2)
            .enumerate()
            .find(|(_, pair)| pair[1] <= pair[0])
        {
            // `pair` is `[grid[k], grid[k + 1]]`; report the later index.
            return Err(InterpolateError::invalid_input(format!(
                "BicubicInterp: {} not strictly increasing at index {}: {} <= {}",
                name,
                k + 1,
                pair[1],
                pair[0]
            )));
        }
        Ok(())
    }

    /// Precomputes the bicubic coefficient matrix for every grid cell.
    ///
    /// Entry `[i][j]` of the result belongs to the cell
    /// `[x[i], x[i + 1]] x [y[j], y[j + 1]]`. Each cell is evaluated in local
    /// coordinates `tx, ty` in `[0, 1]`. The node derivatives are therefore
    /// rescaled by the cell widths (chain rule: `df/dtx = hx * df/dx`).
    fn build_coefficients(
        x: &[f64],
        y: &[f64],
        f: &[Vec<f64>],
        nx: usize,
        ny: usize,
    ) -> Vec<Vec<[[f64; 4]; 4]>> {
        let fx = Self::x_derivatives(x, f, nx, ny);
        let fy = Self::y_derivatives(y, f, nx, ny);
        // f_xy = d/dx (f_y): differentiate the already-computed `fy` along x
        // instead of recomputing the y-derivatives from scratch.
        let fxy = Self::x_derivatives(x, &fy, nx, ny);
        let mut coeffs = vec![vec![[[0.0f64; 4]; 4]; ny - 1]; nx - 1];
        for i in 0..nx - 1 {
            let hx = x[i + 1] - x[i];
            for j in 0..ny - 1 {
                let hy = y[j + 1] - y[j];
                // Corner order in every array: (i, j), (i+1, j), (i, j+1), (i+1, j+1).
                let fv = [f[i][j], f[i + 1][j], f[i][j + 1], f[i + 1][j + 1]];
                let fxv = [
                    fx[i][j] * hx,
                    fx[i + 1][j] * hx,
                    fx[i][j + 1] * hx,
                    fx[i + 1][j + 1] * hx,
                ];
                let fyv = [
                    fy[i][j] * hy,
                    fy[i + 1][j] * hy,
                    fy[i][j + 1] * hy,
                    fy[i + 1][j + 1] * hy,
                ];
                let fxyv = [
                    fxy[i][j] * hx * hy,
                    fxy[i + 1][j] * hx * hy,
                    fxy[i][j + 1] * hx * hy,
                    fxy[i + 1][j + 1] * hx * hy,
                ];
                coeffs[i][j] = Self::hermite_coefficients(&fv, &fxv, &fyv, &fxyv);
            }
        }
        coeffs
    }

    /// Converts the corner values and scaled derivatives of one cell into the
    /// coefficients `c` of `p(tx, ty) = sum_{k,l} c[k][l] * tx^k * ty^l`.
    ///
    /// Every input array lists the corners in the order `(0,0), (1,0), (0,1),
    /// (1,1)` of local `(tx, ty)` coordinates. The derivatives must already be
    /// scaled to local coordinates. The result is `C = W * D * W^T`, where `W`
    /// maps `[p(0), p(1), p'(0), p'(1)]` to the coefficients of a 1-D cubic
    /// Hermite polynomial.
    fn hermite_coefficients(
        fv: &[f64; 4],
        fxv: &[f64; 4],
        fyv: &[f64; 4],
        fxyv: &[f64; 4],
    ) -> [[f64; 4]; 4] {
        // Rows follow x: [p(0), p(1), p_x(0), p_x(1)].
        // Columns follow y: [p(.,0), p(.,1), p_y(.,0), p_y(.,1)].
        let d: [[f64; 4]; 4] = [
            [fv[0], fv[2], fyv[0], fyv[2]],
            [fv[1], fv[3], fyv[1], fyv[3]],
            [fxv[0], fxv[2], fxyv[0], fxyv[2]],
            [fxv[1], fxv[3], fxyv[1], fxyv[3]],
        ];
        let wt: [[f64; 4]; 4] = [
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [-3.0, 3.0, -2.0, -1.0],
            [2.0, -2.0, 1.0, 1.0],
        ];
        // tmp = W * D
        let mut tmp = [[0.0f64; 4]; 4];
        for k in 0..4 {
            for s in 0..4 {
                for r in 0..4 {
                    tmp[k][s] += wt[k][r] * d[r][s];
                }
            }
        }
        // c = tmp * W^T
        let mut c = [[0.0f64; 4]; 4];
        for k in 0..4 {
            for l in 0..4 {
                for s in 0..4 {
                    c[k][l] += tmp[k][s] * wt[l][s];
                }
            }
        }
        c
    }

    /// Derivative, at the node `xs[at]`, of the cubic through the four points
    /// `(xs[k], fs[k])`.
    ///
    /// The result is exact for data sampled from polynomials of degree <= 3.
    /// `xs` must be pairwise distinct.
    fn lagrange4_deriv(xs: &[f64; 4], fs: &[f64; 4], at: usize) -> f64 {
        let xp = xs[at];
        let mut result = 0.0;
        for j in 0..4 {
            // Denominator of the Lagrange basis polynomial L_j.
            let mut denom = 1.0;
            for k in 0..4 {
                if k != j {
                    denom *= xs[j] - xs[k];
                }
            }
            // Product rule: L_j'(x) = sum_{m != j} prod_{k != j, m} (x - x_k) / denom.
            let mut numer_sum = 0.0;
            for m in 0..4 {
                if m == j {
                    continue;
                }
                let mut prod = 1.0;
                for k in 0..4 {
                    if k != j && k != m {
                        prod *= xp - xs[k];
                    }
                }
                numer_sum += prod;
            }
            result += fs[j] * numer_sum / denom;
        }
        result
    }

    /// First index `s` of the 4-point window `[s, s + 3]` used at node `i` of
    /// an axis with `n` nodes.
    ///
    /// The window always contains `i` and stays in bounds. Interior nodes use
    /// `[i - 1, i + 2]`, and the two nodes at each end share the outermost
    /// window. Requires `n >= 4`, which `new` guarantees.
    fn stencil_start(i: usize, n: usize) -> usize {
        if i <= 1 {
            0
        } else if i >= n - 2 {
            n - 4
        } else {
            i - 1
        }
    }

    /// Estimates `df/dx` at every node (`f[i][j]` at `(x[i], y[j])`) with
    /// 4-point Lagrange stencils along the first index.
    fn x_derivatives(x: &[f64], f: &[Vec<f64>], nx: usize, ny: usize) -> Vec<Vec<f64>> {
        let mut dx = vec![vec![0.0f64; ny]; nx];
        for i in 0..nx {
            let s = Self::stencil_start(i, nx);
            let xs = [x[s], x[s + 1], x[s + 2], x[s + 3]];
            let at = i - s;
            for j in 0..ny {
                let fs = [f[s][j], f[s + 1][j], f[s + 2][j], f[s + 3][j]];
                dx[i][j] = Self::lagrange4_deriv(&xs, &fs, at);
            }
        }
        dx
    }

    /// Estimates `df/dy` at every node with 4-point Lagrange stencils along
    /// the second index.
    fn y_derivatives(y: &[f64], f: &[Vec<f64>], nx: usize, ny: usize) -> Vec<Vec<f64>> {
        let mut dy = vec![vec![0.0f64; ny]; nx];
        // The stencil depends only on `j`, so it is computed once per column.
        for j in 0..ny {
            let s = Self::stencil_start(j, ny);
            let ys = [y[s], y[s + 1], y[s + 2], y[s + 3]];
            let at = j - s;
            for i in 0..nx {
                let fs = [f[i][s], f[i][s + 1], f[i][s + 2], f[i][s + 3]];
                dy[i][j] = Self::lagrange4_deriv(&ys, &fs, at);
            }
        }
        dy
    }

    /// Evaluates the interpolant at `(x, y)`.
    ///
    /// Coordinates outside the grid are clamped to the nearest boundary
    /// before evaluation. The result is then the interpolant's value on the
    /// grid edge, not an extrapolation.
    ///
    /// # Errors
    ///
    /// Returns `invalid_input` if `x` or `y` is NaN. Otherwise `f64::max`
    /// would silently treat a NaN coordinate as the lower grid bound.
    pub fn interpolate(&self, x: f64, y: f64) -> InterpolateResult<f64> {
        if x.is_nan() || y.is_nan() {
            return Err(InterpolateError::invalid_input(format!(
                "BicubicInterp: interpolation point ({}, {}) contains NaN",
                x, y
            )));
        }
        let x = x.max(self.x_grid[0]).min(self.x_grid[self.nx - 1]);
        let y = y.max(self.y_grid[0]).min(self.y_grid[self.ny - 1]);
        let ix = Self::find_index(&self.x_grid, x);
        let iy = Self::find_index(&self.y_grid, y);
        // Local coordinates of the point inside its cell, both in [0, 1].
        let tx = (x - self.x_grid[ix]) / (self.x_grid[ix + 1] - self.x_grid[ix]);
        let ty = (y - self.y_grid[iy]) / (self.y_grid[iy + 1] - self.y_grid[iy]);
        let c = &self.coefficients[ix][iy];
        // Horner's scheme in both variables:
        // p = sum_k tx^k * (sum_l c[k][l] * ty^l).
        let val = c.iter().rev().fold(0.0, |acc, row| {
            let along_y = row.iter().rev().fold(0.0, |a, &cl| a * ty + cl);
            acc * tx + along_y
        });
        Ok(val)
    }

    /// Evaluates the interpolant on the tensor product of `x_pts` and `y_pts`.
    ///
    /// `result[a][b]` equals `self.interpolate(x_pts[a], y_pts[b])`.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by [`interpolate`](Self::interpolate).
    pub fn interpolate_grid(
        &self,
        x_pts: &[f64],
        y_pts: &[f64],
    ) -> InterpolateResult<Vec<Vec<f64>>> {
        x_pts
            .iter()
            .map(|&x| y_pts.iter().map(|&y| self.interpolate(x, y)).collect())
            .collect()
    }

    /// Returns the index `i` of the cell `[grid[i], grid[i + 1]]` that contains `x`.
    ///
    /// `x` is clamped to the grid range first, and the result always lies in
    /// `0..=grid.len() - 2`, so `i + 1` is a valid node index. A value equal
    /// to an interior node belongs to the cell starting at that node, and the
    /// upper endpoint belongs to the last cell. Requires a sorted grid with at
    /// least two nodes.
    fn find_index(grid: &[f64], x: f64) -> usize {
        let n = grid.len();
        let x = x.max(grid[0]).min(grid[n - 1]);
        // The grid is sorted, so the number of interior nodes
        // grid[1..=n-2] that are <= x is exactly the containing cell's index.
        grid[1..n - 1].partition_point(|&g| g <= x)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `n` evenly spaced points from `a` to `b` inclusive (requires `n >= 2`).
    fn linspace(a: f64, b: f64, n: usize) -> Vec<f64> {
        (0..n)
            .map(|i| a + (b - a) * (i as f64) / ((n - 1) as f64))
            .collect()
    }

    /// An `nx` x `ny` grid on [0, 3]^2 sampled from f(x, y) = x * y.
    fn make_grid(nx: usize, ny: usize) -> (Vec<f64>, Vec<f64>, Vec<Vec<f64>>) {
        let x = linspace(0.0, 3.0, nx);
        let y = linspace(0.0, 3.0, ny);
        let values: Vec<Vec<f64>> = x
            .iter()
            .map(|&xi| y.iter().map(|&yj| xi * yj).collect())
            .collect();
        (x, y, values)
    }

    #[test]
    fn test_bicubic_exact_nodes() {
        let (x, y, values) = make_grid(5, 5);
        let interp = BicubicInterp::new(x.clone(), y.clone(), values).expect("valid");
        for &xi in &x {
            for &yj in &y {
                let v = interp.interpolate(xi, yj).expect("valid");
                let expected = xi * yj;
                assert!(
                    (v - expected).abs() < 1e-10,
                    "At ({},{}) expected {} got {}",
                    xi,
                    yj,
                    expected,
                    v
                );
            }
        }
    }

    #[test]
    fn test_bicubic_linear_exact() {
        let x = linspace(0.0, 4.0, 6);
        let y = linspace(0.0, 4.0, 6);
        let values: Vec<Vec<f64>> = x
            .iter()
            .map(|&xi| y.iter().map(|&yj| 2.0 * xi + 3.0 * yj + 1.0).collect())
            .collect();
        let interp = BicubicInterp::new(x, y, values).expect("valid");
        let test_pts = [(0.7, 1.3), (1.5, 2.5), (2.1, 0.9), (3.3, 3.7)];
        for (xi, yj) in test_pts {
            let v = interp.interpolate(xi, yj).expect("valid");
            let expected = 2.0 * xi + 3.0 * yj + 1.0;
            assert!(
                (v - expected).abs() < 1e-8,
                "linear at ({},{}): expected {}, got {}",
                xi,
                yj,
                expected,
                v
            );
        }
    }

    #[test]
    fn test_bicubic_insufficient_points() {
        let x = vec![0.0, 1.0, 2.0];
        let y = linspace(0.0, 2.0, 5);
        let values: Vec<Vec<f64>> = x
            .iter()
            .map(|&xi| y.iter().map(|&yj| xi + yj).collect())
            .collect();
        assert!(BicubicInterp::new(x, y, values).is_err());
    }

    #[test]
    fn test_bicubic_non_increasing_grid() {
        let y = linspace(0.0, 3.0, 4);
        let values = vec![vec![0.0; 4]; 4];
        // Repeated node.
        let x = vec![0.0, 1.0, 1.0, 2.0];
        assert!(BicubicInterp::new(x, y.clone(), values.clone()).is_err());
        // Decreasing step.
        let x = vec![0.0, 2.0, 1.0, 3.0];
        assert!(BicubicInterp::new(x, y.clone(), values.clone()).is_err());
        // Same checks on the y axis.
        let x = linspace(0.0, 3.0, 4);
        let y_bad = vec![0.0, 1.0, 0.5, 2.0];
        assert!(BicubicInterp::new(x, y_bad, values).is_err());
    }

    #[test]
    fn test_bicubic_dimension_mismatch() {
        let x = linspace(0.0, 3.0, 4);
        let y = linspace(0.0, 3.0, 5);
        // Too few rows.
        assert!(BicubicInterp::new(x.clone(), y.clone(), vec![vec![0.0; 5]; 3]).is_err());
        // One row of the wrong length.
        let mut values = vec![vec![0.0; 5]; 4];
        values[2].pop();
        assert!(BicubicInterp::new(x, y, values).is_err());
    }

    #[test]
    fn test_bicubic_cubic_polynomial() {
        let x = linspace(0.0, 2.0, 8);
        let y = linspace(0.0, 2.0, 8);
        let values: Vec<Vec<f64>> = x
            .iter()
            .map(|&xi| y.iter().map(|_yj| xi * xi * xi).collect())
            .collect();
        let interp = BicubicInterp::new(x, y, values).expect("valid");
        for &xi in &[0.25, 0.75, 1.25, 1.75f64] {
            let v = interp.interpolate(xi, 1.0).expect("valid");
            let expected = xi * xi * xi;
            assert!(
                (v - expected).abs() < 1e-4,
                "x^3 at ({},1): expected {}, got {}",
                xi,
                expected,
                v
            );
        }
    }

    #[test]
    fn test_bicubic_clamping() {
        let (x, y, values) = make_grid(5, 5);
        let interp = BicubicInterp::new(x, y, values).expect("valid");
        // Out-of-range coordinates are clamped onto [0, 3]^2. The sampled
        // function x * y is bilinear, which bicubic Hermite reproduces exactly,
        // so every query must equal x * y at the clamped point.
        let cases = [
            ((-1.0, -1.0), (0.0, 0.0)),
            ((10.0, 10.0), (3.0, 3.0)),
            ((1.1, 10.0), (1.1, 3.0)),
            ((10.0, -5.0), (3.0, 0.0)),
            ((-2.0, 2.0), (0.0, 2.0)),
        ];
        for ((xq, yq), (xc, yc)) in cases {
            let v = interp.interpolate(xq, yq).expect("valid");
            let expected = xc * yc;
            assert!(
                (v - expected).abs() < 1e-10,
                "clamped at ({},{}): expected {}, got {}",
                xq,
                yq,
                expected,
                v
            );
        }
    }

    #[test]
    fn test_bicubic_interpolate_grid_matches_pointwise() {
        let (x, y, values) = make_grid(5, 6);
        let interp = BicubicInterp::new(x, y, values).expect("valid");
        let xs = [0.3, 1.7, 2.9];
        let ys = [0.1, 2.2];
        let grid = interp.interpolate_grid(&xs, &ys).expect("valid");
        assert_eq!(grid.len(), xs.len());
        for (row, &xi) in grid.iter().zip(&xs) {
            assert_eq!(row.len(), ys.len());
            for (&v, &yj) in row.iter().zip(&ys) {
                let pointwise = interp.interpolate(xi, yj).expect("valid");
                assert!((v - pointwise).abs() < 1e-12);
            }
        }
    }
}