use scirs2_core::ndarray::{Array1, Array2, ArrayD, IxDyn};
use crate::error::{InterpolateError, InterpolateResult};
/// An axis-aligned rectilinear grid described by one coordinate array per
/// dimension.
///
/// Invariant (enforced by [`GridSpec::new`]): every axis is non-empty and
/// strictly increasing. `nearest_index`'s binary search relies on this.
#[derive(Debug, Clone)]
pub struct GridSpec {
// One strictly-increasing coordinate array per grid dimension.
pub axes: Vec<Array1<f64>>,
}
impl GridSpec {
    /// Validates that every axis is non-empty and strictly increasing, then
    /// wraps the axes into a `GridSpec`.
    ///
    /// # Errors
    /// Returns `InterpolateError::InvalidInput` when an axis is empty or not
    /// strictly increasing.
    pub fn new(axes: Vec<Array1<f64>>) -> InterpolateResult<Self> {
        for (dim, ax) in axes.iter().enumerate() {
            if ax.is_empty() {
                return Err(InterpolateError::InvalidInput {
                    message: format!("axis {dim} is empty"),
                });
            }
            for i in 1..ax.len() {
                if ax[i] <= ax[i - 1] {
                    return Err(InterpolateError::InvalidInput {
                        message: format!("axis {dim} is not strictly increasing"),
                    });
                }
            }
        }
        Ok(Self { axes })
    }

    /// Builds a uniformly spaced grid. Each `(lo, hi, n)` spec produces an
    /// axis of `max(n, 1)` evenly spaced points from `lo` to `hi` inclusive.
    ///
    /// # Panics
    /// Panics when `specs.len() != dim`, or when a multi-point spec has
    /// `hi <= lo` (such an axis would violate the strictly-increasing
    /// invariant that `new` enforces and that `nearest_index` relies on).
    pub fn uniform(dim: usize, specs: &[(f64, f64, usize)]) -> Self {
        assert_eq!(specs.len(), dim, "specs length must equal dim");
        let axes: Vec<Array1<f64>> = specs
            .iter()
            .map(|&(lo, hi, n)| {
                let n_pts = n.max(1);
                if n_pts == 1 {
                    Array1::from_vec(vec![lo])
                } else {
                    // Bug fix: previously a decreasing (hi < lo) or degenerate
                    // (hi == lo) range silently produced an axis that violates
                    // the strictly-increasing invariant, breaking the binary
                    // search in nearest_index.
                    assert!(hi > lo, "uniform axis requires hi > lo when n > 1");
                    let step = (hi - lo) / (n_pts as f64 - 1.0);
                    Array1::from_iter((0..n_pts).map(|i| {
                        // Pin the final point to exactly `hi` so the endpoint
                        // is not perturbed by floating-point accumulation.
                        if i == n_pts - 1 {
                            hi
                        } else {
                            lo + i as f64 * step
                        }
                    }))
                }
            })
            .collect();
        Self { axes }
    }

    /// Number of grid dimensions.
    pub fn ndim(&self) -> usize {
        self.axes.len()
    }

    /// Length of each axis, in dimension order.
    pub fn shape(&self) -> Vec<usize> {
        self.axes.iter().map(|ax| ax.len()).collect()
    }

    /// Total number of grid points (product of the axis lengths).
    pub fn n_cells(&self) -> usize {
        self.axes.iter().map(|ax| ax.len()).product()
    }

    /// Index of the grid point on axis `dim` nearest to `val`, clamped to the
    /// axis ends. Exact ties between neighbors break toward the lower index.
    ///
    /// # Panics
    /// Panics when `dim >= self.ndim()`.
    pub fn nearest_index(&self, dim: usize, val: f64) -> usize {
        let ax = &self.axes[dim];
        let n = ax.len();
        // Lower bound: index of the first element >= val (n if none).
        let pos = {
            let mut lo_idx = 0_usize;
            let mut hi_idx = n;
            while lo_idx < hi_idx {
                let mid = lo_idx + (hi_idx - lo_idx) / 2;
                if ax[mid] < val {
                    lo_idx = mid + 1;
                } else {
                    hi_idx = mid;
                }
            }
            lo_idx
        };
        if pos == 0 {
            0
        } else if pos == n {
            n - 1
        } else {
            // val lies between ax[pos - 1] and ax[pos]; pick the nearer one.
            let lo = ax[pos - 1];
            let hi = ax[pos];
            if (val - lo).abs() <= (val - hi).abs() {
                pos - 1
            } else {
                pos
            }
        }
    }
}
/// How the values that land in a single grid cell are reduced to one number.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Aggregator {
/// Arithmetic mean of the cell's values.
Mean,
/// Middle value after sorting (mean of the two middle values for even counts).
Median,
/// Largest value in the cell.
Max,
/// Smallest value in the cell.
Min,
/// Number of points that fell in the cell.
Count,
}
/// Overall scheme for resampling scattered values onto a grid.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ResampleStrategy {
/// Bucket each point into its nearest cell, then reduce each bucket with
/// the given aggregator.
Rasterize(Aggregator),
// NOTE(review): as implemented in `aggregate`, Conservative behaves like
// Rasterize(Mean) except that empty cells yield 0.0 instead of NaN —
// confirm whether area-weighted conservative regridding is intended.
Conservative,
}
/// Resamples scattered points onto a regular grid.
///
/// Each point is assigned to the grid node nearest to it along every axis;
/// the values collected at one node are reduced with `strategy` (nodes that
/// receive no points become NaN, or 0.0 for `Count` / `Conservative` — see
/// `aggregate`).
///
/// # Errors
/// Returns `DimensionMismatch` when `points` has a different column count
/// than the grid has axes, or when `values` has a different length than
/// `points` has rows. Returns `ShapeError` if the output array cannot be
/// constructed.
pub fn resample_scattered_to_grid(
    points: &Array2<f64>,
    values: &Array1<f64>,
    grid: &GridSpec,
    strategy: ResampleStrategy,
) -> InterpolateResult<ArrayD<f64>> {
    let n = points.nrows();
    let d = points.ncols();
    if d != grid.ndim() {
        return Err(InterpolateError::DimensionMismatch(format!(
            "points has {d} columns but grid has {} axes",
            grid.ndim()
        )));
    }
    if values.len() != n {
        return Err(InterpolateError::DimensionMismatch(format!(
            "points has {n} rows but values has {} elements",
            values.len()
        )));
    }
    let shape = grid.shape();
    let total_cells = grid.n_cells();
    // Row-major strides, hoisted out of the point loop. Previously
    // stride_for was recomputed for every point and dimension (O(n * d^2)
    // redundant work).
    let strides: Vec<usize> = (0..d).map(|dim| stride_for(&shape, dim)).collect();
    let mut buckets: Vec<Vec<f64>> = vec![Vec::new(); total_cells];
    for row in 0..n {
        // Fold the per-axis nearest indices directly into a flat row-major
        // index; no per-row index Vec needed.
        let flat_idx: usize = (0..d)
            .map(|dim| grid.nearest_index(dim, points[[row, dim]]) * strides[dim])
            .sum();
        // nearest_index always returns an in-range index, so this guard is
        // purely defensive against future invariant changes.
        if flat_idx < total_cells {
            buckets[flat_idx].push(values[row]);
        }
    }
    let raw_data: Vec<f64> = buckets
        .into_iter()
        .map(|mut bucket| aggregate(bucket.as_mut_slice(), strategy))
        .collect();
    let out = ArrayD::from_shape_vec(IxDyn(&shape), raw_data).map_err(|e| {
        InterpolateError::ShapeError(format!("failed to construct output array: {e}"))
    })?;
    Ok(out)
}
/// Row-major stride for dimension `dim`: the product of all trailing
/// dimension sizes (1 for the last dimension).
fn stride_for(shape: &[usize], dim: usize) -> usize {
    shape.iter().skip(dim + 1).product()
}
fn aggregate(bucket: &mut [f64], strategy: ResampleStrategy) -> f64 {
if bucket.is_empty() {
return match strategy {
ResampleStrategy::Rasterize(Aggregator::Count) | ResampleStrategy::Conservative => 0.0,
_ => f64::NAN,
};
}
match strategy {
ResampleStrategy::Rasterize(Aggregator::Mean) | ResampleStrategy::Conservative => {
bucket.iter().sum::<f64>() / bucket.len() as f64
}
ResampleStrategy::Rasterize(Aggregator::Median) => {
bucket.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let mid = bucket.len() / 2;
if bucket.len() % 2 == 0 {
(bucket[mid - 1] + bucket[mid]) * 0.5
} else {
bucket[mid]
}
}
ResampleStrategy::Rasterize(Aggregator::Max) => {
bucket.iter().copied().fold(f64::NEG_INFINITY, f64::max)
}
ResampleStrategy::Rasterize(Aggregator::Min) => {
bucket.iter().copied().fold(f64::INFINITY, f64::min)
}
ResampleStrategy::Rasterize(Aggregator::Count) => bucket.len() as f64,
}
}