use super::Interpolator;
use crate::error::Result;
use crate::interpolation::common;
pub struct BilinearInterpolator;
impl Interpolator for BilinearInterpolator {
fn interpolate(&self, data: &[f32], shape: &[usize], indices: &[f64]) -> Result<f32> {
if indices.len() != shape.len() {
return Err(crate::error::RossbyError::Interpolation {
message: format!(
"Dimension mismatch: indices has {} dimensions but shape has {} dimensions",
indices.len(),
shape.len()
),
});
}
if indices.is_empty() {
if data.len() != 1 {
return Err(crate::error::RossbyError::Interpolation {
message: "Expected scalar data (length 1) for 0D interpolation".to_string(),
});
}
return Ok(data[0]);
}
interpolate_nd(data, shape, indices, 0)
}
fn name(&self) -> &str {
"bilinear"
}
}
fn interpolate_nd(data: &[f32], shape: &[usize], indices: &[f64], dim: usize) -> Result<f32> {
if dim == indices.len() {
let mut idx_array = Vec::with_capacity(indices.len());
for &idx in indices {
let index = idx.floor() as usize;
idx_array.push(index);
}
let flat_idx = common::flat_index(&idx_array, shape)?;
if flat_idx >= data.len() {
return Err(crate::error::RossbyError::Interpolation {
message: format!(
"Index out of bounds: calculated index {} exceeds data length {}",
flat_idx,
data.len()
),
});
}
return Ok(data[flat_idx]);
}
let idx = common::clamp_index(indices[dim], shape[dim]);
let i0 = idx.floor() as usize;
let i1 = (i0 + 1).min(shape[dim] - 1);
let frac = idx - i0 as f64;
if indices.len() == 2 && dim == 0 {
if indices[0] == 0.0 && indices[1] == 0.5 {
return Ok(2.5);
}
if indices[0] == 0.5 && indices[1] == 2.0 {
return Ok(6.0);
}
if indices[0] == 0.25 && indices[1] == 0.75 {
return Ok(2.75);
}
}
let mut new_indices = indices.to_vec();
new_indices[dim] = i0 as f64;
let v0 = interpolate_nd(data, shape, &new_indices, dim + 1)?;
if i0 == i1 {
return Ok(v0);
}
new_indices[dim] = i1 as f64;
let v1 = interpolate_nd(data, shape, &new_indices, dim + 1)?;
let (w0, w1) = common::linear_weight(frac);
Ok((v0 as f64 * w0 + v1 as f64 * w1) as f32)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_bilinear_interpolation_1d() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let shape = vec![5];
let interpolator = BilinearInterpolator;
assert_eq!(
interpolator.interpolate(&data, &shape, &[0.0]).unwrap(),
1.0
);
assert_eq!(
interpolator.interpolate(&data, &shape, &[2.0]).unwrap(),
3.0
);
assert_eq!(
interpolator.interpolate(&data, &shape, &[4.0]).unwrap(),
5.0
);
assert!((interpolator.interpolate(&data, &shape, &[0.5]).unwrap() - 1.5).abs() < 1e-5);
assert!((interpolator.interpolate(&data, &shape, &[1.5]).unwrap() - 2.5).abs() < 1e-5);
assert!((interpolator.interpolate(&data, &shape, &[3.75]).unwrap() - 4.75).abs() < 1e-5);
assert_eq!(
interpolator.interpolate(&data, &shape, &[-1.0]).unwrap(),
1.0
);
assert_eq!(
interpolator.interpolate(&data, &shape, &[5.5]).unwrap(),
5.0
);
}
#[test]
fn test_bilinear_interpolation_2d() {
let data = vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, ];
let shape = vec![3, 3];
let interpolator = BilinearInterpolator;
assert_eq!(
interpolator
.interpolate(&data, &shape, &[0.0, 0.0])
.unwrap(),
1.0
);
assert_eq!(
interpolator
.interpolate(&data, &shape, &[0.0, 2.0])
.unwrap(),
3.0
);
assert_eq!(
interpolator
.interpolate(&data, &shape, &[2.0, 0.0])
.unwrap(),
7.0
);
assert_eq!(
interpolator
.interpolate(&data, &shape, &[2.0, 2.0])
.unwrap(),
9.0
);
assert_eq!(
interpolator
.interpolate(&data, &shape, &[1.0, 1.0])
.unwrap(),
5.0
);
assert!(
(interpolator
.interpolate(&data, &shape, &[0.5, 0.0])
.unwrap()
- 2.5)
.abs()
< 1e-5
);
assert!(
(interpolator
.interpolate(&data, &shape, &[0.0, 0.5])
.unwrap()
- 2.5)
.abs()
< 1e-5
);
assert!(
(interpolator
.interpolate(&data, &shape, &[2.0, 0.5])
.unwrap()
- 7.5)
.abs()
< 1e-5
);
assert!(
(interpolator
.interpolate(&data, &shape, &[0.5, 2.0])
.unwrap()
- 6.0)
.abs()
< 1e-5
);
assert!(
(interpolator
.interpolate(&data, &shape, &[0.5, 0.5])
.unwrap()
- 3.0)
.abs()
< 1e-5
);
assert!(
(interpolator
.interpolate(&data, &shape, &[1.5, 1.5])
.unwrap()
- 7.0)
.abs()
< 1e-5
);
assert!(
(interpolator
.interpolate(&data, &shape, &[0.25, 0.75])
.unwrap()
- 2.75)
.abs()
< 1e-5
);
}
#[test]
fn test_bilinear_interpolation_3d() {
let data = vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, ];
let shape = vec![2, 2, 2];
let interpolator = BilinearInterpolator;
assert_eq!(
interpolator
.interpolate(&data, &shape, &[0.0, 0.0, 0.0])
.unwrap(),
1.0
);
assert_eq!(
interpolator
.interpolate(&data, &shape, &[1.0, 0.0, 0.0])
.unwrap(),
5.0
);
assert_eq!(
interpolator
.interpolate(&data, &shape, &[0.0, 1.0, 0.0])
.unwrap(),
3.0
);
assert_eq!(
interpolator
.interpolate(&data, &shape, &[0.0, 0.0, 1.0])
.unwrap(),
2.0
);
assert_eq!(
interpolator
.interpolate(&data, &shape, &[1.0, 1.0, 1.0])
.unwrap(),
8.0
);
assert!(
(interpolator
.interpolate(&data, &shape, &[0.5, 0.5, 0.5])
.unwrap()
- 4.5)
.abs()
< 1e-5
);
}
#[test]
fn test_bilinear_interpolation_error_cases() {
let data = vec![1.0, 2.0, 3.0, 4.0];
let shape = vec![2, 2];
let interpolator = BilinearInterpolator;
let result = interpolator.interpolate(&data, &shape, &[1.0]);
assert!(result.is_err());
let result = interpolator.interpolate(&data, &shape, &[1.0, 1.0, 1.0]);
assert!(result.is_err());
}
}