use super::Interpolator;
use crate::error::Result;
use crate::interpolation::common;
pub struct BicubicInterpolator;
impl Interpolator for BicubicInterpolator {
fn interpolate(&self, data: &[f32], shape: &[usize], indices: &[f64]) -> Result<f32> {
if indices.len() != shape.len() {
return Err(crate::error::RossbyError::Interpolation {
message: format!(
"Dimension mismatch: indices has {} dimensions but shape has {} dimensions",
indices.len(),
shape.len()
),
});
}
if indices.is_empty() {
if data.len() != 1 {
return Err(crate::error::RossbyError::Interpolation {
message: "Expected scalar data (length 1) for 0D interpolation".to_string(),
});
}
return Ok(data[0]);
}
for (i, &size) in shape.iter().enumerate() {
if size < 4 {
return Err(crate::error::RossbyError::Interpolation {
message: format!(
"Dimension {} has size {}, but bicubic interpolation requires at least 4 points per dimension. Consider using bilinear interpolation instead.",
i, size
),
});
}
}
interpolate_nd(data, shape, indices, 0)
}
fn name(&self) -> &str {
"bicubic"
}
}
fn interpolate_nd(data: &[f32], shape: &[usize], indices: &[f64], dim: usize) -> Result<f32> {
if dim == indices.len() {
let mut idx_array = Vec::with_capacity(indices.len());
for &index in indices {
idx_array.push(index.floor() as usize);
}
let flat_idx = common::flat_index(&idx_array, shape)?;
if flat_idx >= data.len() {
return Err(crate::error::RossbyError::Interpolation {
message: format!(
"Index out of bounds: calculated index {} exceeds data length {}",
flat_idx,
data.len()
),
});
}
return Ok(data[flat_idx]);
}
let idx = common::clamp_index(indices[dim], shape[dim]);
let i = idx.floor() as usize;
let frac = idx - i as f64;
let mut positions = [0; 4];
positions[0] = if i > 0 { i - 1 } else { 0 };
positions[1] = i;
positions[2] = (i + 1).min(shape[dim] - 1);
positions[3] = (i + 2).min(shape[dim] - 1);
let mut new_indices = indices.to_vec();
let mut values = [0.0; 4];
for j in 0..4 {
new_indices[dim] = positions[j] as f64;
values[j] = interpolate_nd(data, shape, &new_indices, dim + 1)?;
}
let weights = common::cubic_weights(frac);
let mut result = 0.0;
for j in 0..4 {
result += values[j] as f64 * weights[j];
}
Ok(result as f32)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bicubic_interpolation_1d() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = vec![5];
        let interpolator = BicubicInterpolator;
        assert_eq!(
            interpolator.interpolate(&data, &shape, &[1.0]).unwrap(),
            2.0
        );
        assert_eq!(
            interpolator.interpolate(&data, &shape, &[2.0]).unwrap(),
            3.0
        );
        assert_eq!(
            interpolator.interpolate(&data, &shape, &[3.0]).unwrap(),
            4.0
        );
        assert!((interpolator.interpolate(&data, &shape, &[1.5]).unwrap() - 2.5).abs() < 1e-5);
        assert_eq!(
            interpolator.interpolate(&data, &shape, &[0.0]).unwrap(),
            1.0
        );
        assert_eq!(
            interpolator.interpolate(&data, &shape, &[4.0]).unwrap(),
            5.0
        );
    }

    #[test]
    fn test_bicubic_interpolation_2d() {
        let data = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
        ];
        let shape = vec![4, 4];
        let interpolator = BicubicInterpolator;
        assert_eq!(
            interpolator
                .interpolate(&data, &shape, &[1.0, 1.0])
                .unwrap(),
            6.0
        );
        assert_eq!(
            interpolator
                .interpolate(&data, &shape, &[2.0, 2.0])
                .unwrap(),
            11.0
        );
        let center_value = interpolator
            .interpolate(&data, &shape, &[1.5, 1.5])
            .unwrap();
        assert!((center_value - 8.5).abs() < 1e-5);
        let v1 = interpolator
            .interpolate(&data, &shape, &[1.5, 1.0])
            .unwrap();
        let v2 = interpolator
            .interpolate(&data, &shape, &[1.5, 1.25])
            .unwrap();
        let v3 = interpolator
            .interpolate(&data, &shape, &[1.5, 1.5])
            .unwrap();
        let v4 = interpolator
            .interpolate(&data, &shape, &[1.5, 1.75])
            .unwrap();
        let v5 = interpolator
            .interpolate(&data, &shape, &[1.5, 2.0])
            .unwrap();
        assert!(v1 < v2);
        assert!(v2 < v3);
        assert!(v3 < v4);
        assert!(v4 < v5);
    }

    /// Covers the rank-0 branch: a single stored value is returned as-is, and any other
    /// data length is rejected.
    #[test]
    fn test_bicubic_interpolation_0d() {
        let interpolator = BicubicInterpolator;
        assert_eq!(interpolator.interpolate(&[7.0], &[], &[]).unwrap(), 7.0);
        assert!(interpolator.interpolate(&[1.0, 2.0], &[], &[]).is_err());
        assert!(interpolator.interpolate(&[], &[], &[]).is_err());
    }

    /// Exercises the per-axis recursion beyond two dimensions. The data is 1..=64 on a
    /// 4x4x4 grid, and only points with equal coordinates on every axis are probed. That
    /// keeps the expected values independent of the memory layout (grid point (k, k, k)
    /// holds 21k + 1 either way). Like the 1-D and 2-D tests, this relies on the kernel
    /// reproducing linear data exactly away from the edges.
    #[test]
    fn test_bicubic_interpolation_3d() {
        let data: Vec<f32> = (1..=64).map(|v| v as f32).collect();
        let shape = vec![4, 4, 4];
        let interpolator = BicubicInterpolator;
        let at = |p: &[f64]| interpolator.interpolate(&data, &shape, p).unwrap();
        assert!((at(&[1.0, 1.0, 1.0]) - 22.0).abs() < 1e-4);
        assert!((at(&[2.0, 2.0, 2.0]) - 43.0).abs() < 1e-4);
        assert!((at(&[1.5, 1.5, 1.5]) - 32.5).abs() < 1e-4);
    }

    #[test]
    fn test_bicubic_error_cases() {
        let data = vec![1.0, 2.0, 3.0];
        let shape = vec![3];
        let interpolator = BicubicInterpolator;
        let result = interpolator.interpolate(&data, &shape, &[1.0]);
        assert!(result.is_err());
        let data = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
        ];
        let shape = vec![4, 4];
        let result = interpolator.interpolate(&data, &shape, &[1.0]);
        assert!(result.is_err());
        let result = interpolator.interpolate(&data, &shape, &[1.0, 1.0, 1.0]);
        assert!(result.is_err());
    }
}