use crate::advanced::rbf::{RBFInterpolator, RBFKernel};
use crate::error::{InterpolateError, InterpolateResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::{Debug, Display};
use std::ops::AddAssign;
/// Interpolation strategy used by [`griddata`] and [`griddata_parallel`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GriddataMethod {
    /// Linear-style interpolation: piecewise linear in 1-D, barycentric weights (small
    /// inputs) or inverse-distance weighting of nearby samples in 2-D, and
    /// inverse-distance weighting in higher dimensions.
    Linear,
    /// Value of the closest input sample (Euclidean distance).
    Nearest,
    /// Gradient-enhanced local interpolation in 2-D; other dimensionalities fall back to
    /// a cubic radial basis function.
    Cubic,
    /// Radial basis function interpolation with a linear kernel.
    Rbf,
    /// Radial basis function interpolation with a cubic kernel.
    RbfCubic,
    /// Radial basis function interpolation with a thin-plate-spline kernel.
    RbfThinPlate,
}
/// Interpolates scattered data `values`, sampled at `points`, onto the query rows of `xi`.
///
/// * `points` — one sample location per row; `values[i]` belongs to row `i`.
/// * `xi` — one query location per row, with the same column count as `points`.
/// * `method` — interpolation strategy.
/// * `fill_value` — fallback for queries a method cannot evaluate (NaN when `None`);
///   not every method consults it.
/// * `workers` — `Some(1)` forces sequential evaluation, any other `Some(n)` routes to
///   [`griddata_parallel`], and `None` selects the parallel path only when there are at
///   least 100 queries and at least 50 samples.
///
/// # Errors
/// Returns an error when the inputs fail validation or the selected method fails.
#[allow(dead_code)]
pub fn griddata<F>(
    points: &ArrayView2<F>,
    values: &ArrayView1<F>,
    xi: &ArrayView2<F>,
    method: GriddataMethod,
    fill_value: Option<F>,
    workers: Option<usize>,
) -> InterpolateResult<Array1<F>>
where
    F: Float
        + FromPrimitive
        + Debug
        + Clone
        + Display
        + AddAssign
        + std::ops::SubAssign
        + std::fmt::LowerExp
        + std::ops::MulAssign
        + std::ops::DivAssign
        + Send
        + Sync
        + 'static,
{
    validate_griddata_inputs(points, values, xi)?;

    // Heuristic used only when the caller did not pin a worker count.
    let large_problem = xi.nrows() >= 100 && points.nrows() >= 50;
    let go_parallel = workers.map_or(large_problem, |count| count != 1);
    if go_parallel {
        return griddata_parallel(points, values, xi, method, fill_value, workers);
    }

    // Non-RBF methods return directly; RBF variants only differ by kernel.
    let kernel = match method {
        GriddataMethod::Linear => return griddata_linear(points, values, xi, fill_value),
        GriddataMethod::Nearest => return griddata_nearest(points, values, xi, fill_value),
        GriddataMethod::Cubic => return griddata_cubic(points, values, xi, fill_value),
        GriddataMethod::Rbf => RBFKernel::Linear,
        GriddataMethod::RbfCubic => RBFKernel::Cubic,
        GriddataMethod::RbfThinPlate => RBFKernel::ThinPlateSpline,
    };
    griddata_rbf(points, values, xi, kernel, fill_value)
}
/// Parallel variant of [`griddata`].
///
/// Validates the inputs, builds a `ParallelConfig` (with `workers` threads when given),
/// and dispatches to the per-method parallel helper. Batches of fewer than 100 queries
/// are delegated to the sequential [`griddata`] (with `workers = None`), because thread
/// scheduling would outweigh the work.
///
/// # Errors
/// Returns an error when the inputs fail validation or the selected method fails.
#[allow(dead_code)]
pub fn griddata_parallel<F>(
    points: &ArrayView2<F>,
    values: &ArrayView1<F>,
    xi: &ArrayView2<F>,
    method: GriddataMethod,
    fill_value: Option<F>,
    workers: Option<usize>,
) -> InterpolateResult<Array1<F>>
where
    F: Float
        + FromPrimitive
        + Debug
        + Clone
        + Send
        + Sync
        + Display
        + AddAssign
        + std::ops::SubAssign
        + std::fmt::LowerExp
        + std::ops::MulAssign
        + std::ops::DivAssign
        + 'static,
{
    use crate::parallel::ParallelConfig;
    validate_griddata_inputs(points, values, xi)?;
    let parallel_config = match workers {
        Some(n_workers) => ParallelConfig::new().with_workers(n_workers),
        None => ParallelConfig::new(),
    };
    if xi.nrows() < 100 {
        return griddata(points, values, xi, method, fill_value, None);
    }
    // NOTE: this was previously written `¶llel_config` (an HTML `&para` entity had
    // swallowed the `&pa`), which does not compile.
    let config = &parallel_config;
    match method {
        GriddataMethod::Linear => griddata_linear_parallel(points, values, xi, fill_value, config),
        GriddataMethod::Nearest => {
            griddata_nearest_parallel(points, values, xi, fill_value, config)
        }
        GriddataMethod::Cubic => griddata_cubic_parallel(points, values, xi, fill_value, config),
        GriddataMethod::Rbf => {
            griddata_rbf_parallel(points, values, xi, RBFKernel::Linear, fill_value, config)
        }
        GriddataMethod::RbfCubic => {
            griddata_rbf_parallel(points, values, xi, RBFKernel::Cubic, fill_value, config)
        }
        GriddataMethod::RbfThinPlate => griddata_rbf_parallel(
            points,
            values,
            xi,
            RBFKernel::ThinPlateSpline,
            fill_value,
            config,
        ),
    }
}
#[allow(dead_code)]
fn griddata_linear_parallel<F>(
points: &ArrayView2<F>,
values: &ArrayView1<F>,
xi: &ArrayView2<F>,
fill_value: Option<F>,
config: &crate::parallel::ParallelConfig,
) -> InterpolateResult<Array1<F>>
where
F: Float + FromPrimitive + Debug + Clone + Send + Sync,
{
use scirs2_core::parallel_ops::*;
let n_queries = xi.nrows();
let chunk_size = crate::parallel::estimate_chunk_size(n_queries, 2.0, config);
let results: Result<Vec<F>, InterpolateError> = (0..n_queries)
.into_par_iter()
.with_min_len(chunk_size)
.map(|i| {
let query_point = xi.slice(scirs2_core::ndarray::s![i, ..]);
interpolate_single_linear(points, values, &query_point, fill_value)
})
.collect();
Ok(Array1::from_vec(results?))
}
/// Parallel nearest-neighbour evaluation: every query row of `xi` is resolved
/// independently with `interpolate_single_nearest`, using rayon's `with_min_len` to
/// batch small amounts of work.
///
/// # Errors
/// Propagates the first per-query error, if any.
#[allow(dead_code)]
fn griddata_nearest_parallel<F>(
    points: &ArrayView2<F>,
    values: &ArrayView1<F>,
    xi: &ArrayView2<F>,
    fill_value: Option<F>,
    config: &crate::parallel::ParallelConfig,
) -> InterpolateResult<Array1<F>>
where
    F: Float + FromPrimitive + Debug + Clone + Send + Sync,
{
    use scirs2_core::parallel_ops::*;
    let query_count = xi.nrows();
    let min_len = crate::parallel::estimate_chunk_size(query_count, 1.0, config);
    let per_query = (0..query_count)
        .into_par_iter()
        .with_min_len(min_len)
        .map(|row| interpolate_single_nearest(points, values, &xi.row(row), fill_value))
        .collect::<Result<Vec<F>, InterpolateError>>()?;
    Ok(Array1::from_vec(per_query))
}
#[allow(dead_code)]
fn griddata_cubic_parallel<F>(
points: &ArrayView2<F>,
values: &ArrayView1<F>,
xi: &ArrayView2<F>,
fill_value: Option<F>,
config: &crate::parallel::ParallelConfig,
) -> InterpolateResult<Array1<F>>
where
F: Float + FromPrimitive + Debug + Clone + Send + Sync,
{
use scirs2_core::parallel_ops::*;
let n_queries = xi.nrows();
let chunk_size = crate::parallel::estimate_chunk_size(n_queries, 5.0, config);
let results: Result<Vec<F>, InterpolateError> = (0..n_queries)
.into_par_iter()
.with_min_len(chunk_size)
.map(|i| {
let query_point = xi.slice(scirs2_core::ndarray::s![i, ..]);
interpolate_single_cubic(points, values, &query_point, fill_value)
})
.collect();
Ok(Array1::from_vec(results?))
}
#[allow(dead_code)]
fn griddata_rbf_parallel<F>(
points: &ArrayView2<F>,
values: &ArrayView1<F>,
xi: &ArrayView2<F>,
kernel: RBFKernel,
fill_value: Option<F>,
config: &crate::parallel::ParallelConfig,
) -> InterpolateResult<Array1<F>>
where
F: Float
+ FromPrimitive
+ Debug
+ Clone
+ Send
+ Sync
+ Display
+ AddAssign
+ std::ops::SubAssign
+ std::ops::MulAssign
+ std::ops::DivAssign
+ std::fmt::LowerExp
+ 'static,
{
use scirs2_core::parallel_ops::*;
let rbf_interpolator = RBFInterpolator::new(
points,
values,
kernel,
F::from_f64(1.0).expect("Operation failed"), )?;
let n_queries = xi.nrows();
let chunk_size = crate::parallel::estimate_chunk_size(n_queries, 10.0, config);
let results: Result<Vec<F>, InterpolateError> = (0..n_queries)
.into_par_iter()
.with_min_len(chunk_size)
.map(|i| {
let query_point = xi.slice(scirs2_core::ndarray::s![i, ..]);
let query_2d = query_point
.to_shape((1, query_point.len()))
.expect("Operation failed");
match rbf_interpolator.interpolate(&query_2d.view()) {
Ok(result) => Ok(result[0]),
Err(_) => Ok(fill_value.unwrap_or_else(|| F::nan())),
}
})
.collect();
Ok(Array1::from_vec(results?))
}
/// Per-query helper named for linear interpolation.
///
/// NOTE(review): despite its name, this returns the nearest-neighbour value by
/// delegating to `interpolate_single_nearest`. It does not interpolate linearly.
#[allow(dead_code)]
fn interpolate_single_linear<F>(
    points: &ArrayView2<F>,
    values: &ArrayView1<F>,
    query: &scirs2_core::ndarray::ArrayView1<F>,
    fill_value: Option<F>,
) -> Result<F, InterpolateError>
where
    F: Float + FromPrimitive + Debug + Clone,
{
    let nearest = interpolate_single_nearest(points, values, query, fill_value)?;
    Ok(nearest)
}
/// Returns the value of the sample in `points` closest to `query` (squared Euclidean
/// distance; ties keep the earliest sample).
///
/// If no sample is at a finite distance, `fill_value` is returned (NaN when `None`).
/// This covers empty `points` and non-finite coordinates, and matches the sequential
/// `griddata_nearest`. Previously `fill_value` was ignored: such queries returned
/// `values[0]`, or panicked on empty input.
#[allow(dead_code)]
fn interpolate_single_nearest<F>(
    points: &ArrayView2<F>,
    values: &ArrayView1<F>,
    query: &scirs2_core::ndarray::ArrayView1<F>,
    fill_value: Option<F>,
) -> Result<F, InterpolateError>
where
    F: Float + FromPrimitive + Debug + Clone,
{
    let mut min_distance = F::infinity();
    let mut nearest_idx = None;
    for (i, point) in points.axis_iter(scirs2_core::ndarray::Axis(0)).enumerate() {
        let distance: F = point
            .iter()
            .zip(query.iter())
            .map(|(&p, &q)| (p - q) * (p - q))
            .fold(F::zero(), |acc, x| acc + x);
        // `distance < inf` only holds for finite distances, so `Some` implies finite.
        if distance < min_distance {
            min_distance = distance;
            nearest_idx = Some(i);
        }
    }
    Ok(match nearest_idx {
        Some(idx) => values[idx],
        None => fill_value.unwrap_or_else(|| F::nan()),
    })
}
#[allow(dead_code)]
fn interpolate_single_cubic<F>(
points: &ArrayView2<F>,
values: &ArrayView1<F>,
query: &scirs2_core::ndarray::ArrayView1<F>,
fill_value: Option<F>,
) -> Result<F, InterpolateError>
where
F: Float + FromPrimitive + Debug + Clone,
{
interpolate_single_linear(points, values, query, fill_value)
}
/// Checks the shapes shared by every `griddata` method.
///
/// # Errors
/// * shape mismatch when `points` rows differ from `values` length;
/// * shape mismatch when `points` and `xi` have different column counts;
/// * invalid input when there are fewer than `ncols + 1` sample points.
#[allow(dead_code)]
fn validate_griddata_inputs<F>(
    points: &ArrayView2<F>,
    values: &ArrayView1<F>,
    xi: &ArrayView2<F>,
) -> InterpolateResult<()>
where
    F: Float + Debug,
{
    let (n_points, n_dims) = points.dim();

    let n_values = values.len();
    if n_points != n_values {
        return Err(InterpolateError::shape_mismatch(
            format!("points.nrows() = {}", n_points),
            format!("values.len() = {}", n_values),
            "griddata input validation",
        ));
    }

    let n_query_dims = xi.ncols();
    if n_dims != n_query_dims {
        return Err(InterpolateError::shape_mismatch(
            format!("points.ncols() = {}", n_dims),
            format!("xi.ncols() = {}", n_query_dims),
            "griddata dimension consistency",
        ));
    }

    // A d-dimensional simplex needs d + 1 vertices.
    let required = n_dims + 1;
    if n_points < required {
        return Err(InterpolateError::invalid_input(format!(
            "Need at least {} points for {}-dimensional interpolation, got {}",
            required, n_dims, n_points
        )));
    }
    Ok(())
}
/// Sequential "linear" interpolation, dispatched on dimensionality:
/// * 1-D: piecewise linear between the bracketing samples (`griddata_linear_1d`);
/// * 2-D: barycentric weights for up to 20 samples, otherwise inverse-distance
///   weighting of nearby samples (`griddata_linear_2d`);
/// * other: inverse-distance weighting over all samples (`griddata_linear_nd`).
///
/// Queries that cannot be evaluated receive `fill_value` (NaN when `None`).
///
/// # Errors
/// Returns an error when `points` is empty or a dimension-specific helper fails.
#[allow(dead_code)]
fn griddata_linear<F>(
    points: &ArrayView2<F>,
    values: &ArrayView1<F>,
    xi: &ArrayView2<F>,
    fill_value: Option<F>,
) -> InterpolateResult<Array1<F>>
where
    F: Float + FromPrimitive + Debug + Clone + AddAssign,
{
    if points.nrows() == 0 {
        return Err(InterpolateError::invalid_input(
            "At least one data point is required".to_string(),
        ));
    }
    // The helpers resolve `fill_value` themselves; the former unused local was dropped.
    let mut result = Array1::zeros(xi.nrows());
    match points.ncols() {
        1 => griddata_linear_1d(points, values, xi, fill_value, &mut result)?,
        2 => griddata_linear_2d(points, values, xi, fill_value, &mut result)?,
        _ => griddata_linear_nd(points, values, xi, fill_value, &mut result)?,
    }
    Ok(result)
}
/// Nearest-neighbour interpolation: each query row of `xi` takes the value of the
/// closest row of `points` (Euclidean distance; ties keep the earliest sample).
///
/// When no sample is at a finite distance, the query gets `fill_value` (NaN when `None`).
#[allow(dead_code)]
fn griddata_nearest<F>(
    points: &ArrayView2<F>,
    values: &ArrayView1<F>,
    xi: &ArrayView2<F>,
    fill_value: Option<F>,
) -> InterpolateResult<Array1<F>>
where
    F: Float + FromPrimitive + Debug + Clone,
{
    let fallback = fill_value.unwrap_or_else(|| F::nan());
    let mut result = Array1::zeros(xi.nrows());
    for (row_idx, query) in xi.outer_iter().enumerate() {
        let (best_dist, best_idx) = points.outer_iter().enumerate().fold(
            (F::infinity(), 0usize),
            |(best_dist, best_idx), (cand_idx, candidate)| {
                let dist_sq = query.iter().enumerate().fold(F::zero(), |acc, (k, &q)| {
                    let diff = q - candidate[k];
                    acc + diff * diff
                });
                let dist = dist_sq.sqrt();
                if dist < best_dist {
                    (dist, cand_idx)
                } else {
                    (best_dist, best_idx)
                }
            },
        );
        result[row_idx] = if best_dist.is_finite() {
            values[best_idx]
        } else {
            fallback
        };
    }
    Ok(result)
}
/// Piecewise-linear interpolation for 1-D inputs (column 0 of `points` and `xi`).
///
/// Each query takes the exact sample value on a hit (the first sample in stable sort
/// order with that abscissa). Strictly between two samples, it linearly interpolates
/// between the bracketing pair. Outside the sample range, or for a NaN query, it takes
/// `fill_value` (NaN when `None`).
///
/// The samples are sorted once and each query is located by binary search
/// (`partition_point`), giving O(n log n + q log n) instead of the former O(q * n)
/// linear scan. Results are identical for NaN-free sample coordinates.
#[allow(dead_code)]
fn griddata_linear_1d<F>(
    points: &ArrayView2<F>,
    values: &ArrayView1<F>,
    xi: &ArrayView2<F>,
    fill_value: Option<F>,
    result: &mut Array1<F>,
) -> InterpolateResult<()>
where
    F: Float + FromPrimitive + Debug + Clone,
{
    let n_points = points.nrows();
    let default_fill = fill_value.unwrap_or_else(|| F::nan());

    // Stable sort of sample indices by abscissa; equal abscissae keep input order.
    let mut order: Vec<usize> = (0..n_points).collect();
    order.sort_by(|&a, &b| {
        points[[a, 0]]
            .partial_cmp(&points[[b, 0]])
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    let xs: Vec<F> = order.iter().map(|&idx| points[[idx, 0]]).collect();

    for i in 0..xi.nrows() {
        let query_x = xi[[i, 0]];
        // First sorted position whose abscissa is not strictly below the query
        // (equals 0 for a NaN query, since every comparison is false).
        let upper = xs.partition_point(|&x| x < query_x);
        result[i] = if upper < n_points && xs[upper] == query_x {
            values[order[upper]]
        } else if upper > 0 && upper < n_points {
            // xs[upper - 1] < query_x < xs[upper], so the denominator is non-zero.
            let (x1, x2) = (xs[upper - 1], xs[upper]);
            let (y1, y2) = (values[order[upper - 1]], values[order[upper]]);
            let t = (query_x - x1) / (x2 - x1);
            y1 + t * (y2 - y1)
        } else {
            default_fill
        };
    }
    Ok(())
}
/// 2-D "linear" interpolation into `result`.
///
/// With at most 20 samples, every query uses `interpolate_barycentric_2d`; otherwise
/// every query uses `interpolate_natural_neighbor_2d`. `fill_value` (NaN when `None`)
/// is forwarded as the fallback value.
#[allow(dead_code)]
fn griddata_linear_2d<F>(
    points: &ArrayView2<F>,
    values: &ArrayView1<F>,
    xi: &ArrayView2<F>,
    fill_value: Option<F>,
    result: &mut Array1<F>,
) -> InterpolateResult<()>
where
    F: Float + FromPrimitive + Debug + Clone + AddAssign,
{
    let default_fill = fill_value.unwrap_or_else(|| F::nan());
    // The triangle search is O(n^3), so it is only used for small sample sets.
    let use_triangles = points.nrows() <= 20;
    for (slot, row) in xi.outer_iter().enumerate() {
        let query = [row[0], row[1]];
        result[slot] = if use_triangles {
            interpolate_barycentric_2d(points, values, &query, default_fill)?
        } else {
            interpolate_natural_neighbor_2d(points, values, &query, default_fill)?
        };
    }
    Ok(())
}
/// N-dimensional "linear" interpolation into `result`: each query row is evaluated with
/// `interpolate_idw_linear`, falling back to `fill_value` (NaN when `None`).
#[allow(dead_code)]
fn griddata_linear_nd<F>(
    points: &ArrayView2<F>,
    values: &ArrayView1<F>,
    xi: &ArrayView2<F>,
    fill_value: Option<F>,
    result: &mut Array1<F>,
) -> InterpolateResult<()>
where
    F: Float + FromPrimitive + Debug + Clone + AddAssign,
{
    let default_fill = fill_value.unwrap_or_else(|| F::nan());
    for (i, query) in xi.outer_iter().enumerate() {
        result[i] = interpolate_idw_linear(points, values, &query, default_fill)?;
    }
    Ok(())
}
/// Interpolates at `query` using the barycentric weights of a triangle of samples.
///
/// All index triples `i < j < k` are scanned in lexicographic order:
/// * the first non-degenerate triangle whose weights are all non-negative (it contains
///   `query`) is used immediately;
/// * otherwise the non-degenerate triangle minimising `|w1| + |w2| + |w3| - 1` is used
///   to extrapolate;
/// * if no non-degenerate triangle exists, the value of the nearest sample is returned
///   (or `default_fill` when no sample is closer than infinity).
///
/// The scan is O(n^3) in the number of samples.
#[allow(dead_code)]
fn interpolate_barycentric_2d<F>(
    points: &ArrayView2<F>,
    values: &ArrayView1<F>,
    query: &[F; 2],
    default_fill: F,
) -> InterpolateResult<F>
where
    F: Float + FromPrimitive + Debug + Clone,
{
    let n_points = points.nrows();
    let vertex = |idx: usize| [points[[idx, 0]], points[[idx, 1]]];

    let mut closest_outside: Option<(usize, usize, usize, F, F, F)> = None;
    let mut closest_score = F::infinity();
    for i in 0..n_points {
        let a = vertex(i);
        for j in (i + 1)..n_points {
            let b = vertex(j);
            for k in (j + 1)..n_points {
                let c = vertex(k);
                let (w1, w2, w3) = match compute_barycentric_coordinates(&a, &b, &c, query) {
                    Some(weights) => weights,
                    None => continue,
                };
                if w1 >= F::zero() && w2 >= F::zero() && w3 >= F::zero() {
                    return Ok(w1 * values[i] + w2 * values[j] + w3 * values[k]);
                }
                // Zero for points inside the triangle, grows as the query moves away.
                let score = (w1.abs() + w2.abs() + w3.abs()) - F::one();
                if score < closest_score {
                    closest_score = score;
                    closest_outside = Some((i, j, k, w1, w2, w3));
                }
            }
        }
    }

    if let Some((i, j, k, w1, w2, w3)) = closest_outside {
        return Ok(w1 * values[i] + w2 * values[j] + w3 * values[k]);
    }

    let mut nearest_dist = F::infinity();
    let mut nearest_value = default_fill;
    for idx in 0..n_points {
        let dx = query[0] - points[[idx, 0]];
        let dy = query[1] - points[[idx, 1]];
        let dist_sq = dx * dx + dy * dy;
        if dist_sq < nearest_dist {
            nearest_dist = dist_sq;
            nearest_value = values[idx];
        }
    }
    Ok(nearest_value)
}
/// Barycentric coordinates `(w1, w2, w3)` of `query` with respect to triangle
/// `(p1, p2, p3)`, where `w1 + w2 + w3 == 1` by construction.
///
/// Returns `None` when the triangle is degenerate, i.e. the absolute value of the
/// (doubled, signed) area term falls below `1e-10`. That threshold is absolute, so it
/// depends on coordinate scale.
#[allow(dead_code)]
fn compute_barycentric_coordinates<F>(
    p1: &[F; 2],
    p2: &[F; 2],
    p3: &[F; 2],
    query: &[F; 2],
) -> Option<(F, F, F)>
where
    F: Float + FromPrimitive + Debug + Clone,
{
    // Differences relative to the third vertex.
    let y23 = p2[1] - p3[1];
    let x32 = p3[0] - p2[0];
    let x13 = p1[0] - p3[0];
    let y13 = p1[1] - p3[1];
    let y31 = p3[1] - p1[1];

    let denom = y23 * x13 + x32 * y13;
    if denom.abs() < F::from_f64(1e-10).expect("Operation failed") {
        return None;
    }

    let qx = query[0] - p3[0];
    let qy = query[1] - p3[1];
    let w1 = (y23 * qx + x32 * qy) / denom;
    let w2 = (y31 * qx + x13 * qy) / denom;
    Some((w1, w2, F::one() - w1 - w2))
}
/// Inverse-distance-weighted estimate at `query` from the (up to) 6 nearest samples.
///
/// NOTE(review): despite the name, this is not natural-neighbour (Sibson)
/// interpolation. It weights each of the k = min(6, n) closest samples by
/// `1 / distance`. A sample within squared distance `1e-12` is returned exactly.
/// `default_fill` is returned when there are no samples or the total weight is not
/// positive.
#[allow(dead_code)]
fn interpolate_natural_neighbor_2d<F>(
    points: &ArrayView2<F>,
    values: &ArrayView1<F>,
    query: &[F; 2],
    default_fill: F,
) -> InterpolateResult<F>
where
    F: Float + FromPrimitive + Debug + Clone + AddAssign,
{
    let n_points = points.nrows();
    if n_points == 0 {
        return Ok(default_fill);
    }
    let k = n_points.min(6);

    let mut ranked: Vec<(usize, F)> = (0..n_points)
        .map(|idx| {
            let dx = query[0] - points[[idx, 0]];
            let dy = query[1] - points[[idx, 1]];
            (idx, dx * dx + dy * dy)
        })
        .collect();
    ranked.sort_by(|lhs, rhs| {
        lhs.1
            .partial_cmp(&rhs.1)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let snap_tol = F::from_f64(1e-12).expect("Operation failed");
    let mut weight_total = F::zero();
    let mut weighted_sum = F::zero();
    for &(idx, dist_sq) in &ranked[..k] {
        if dist_sq < snap_tol {
            return Ok(values[idx]);
        }
        let weight = F::one() / dist_sq.sqrt();
        weight_total += weight;
        weighted_sum += weight * values[idx];
    }

    if weight_total > F::zero() {
        Ok(weighted_sum / weight_total)
    } else {
        Ok(default_fill)
    }
}
/// Inverse-distance-weighted estimate at `query` using every sample, with weight
/// `1 / distance`.
///
/// A sample within squared distance `1e-12` of `query` is returned exactly (the first
/// one encountered). `default_fill` is returned for empty input or a non-positive
/// total weight.
#[allow(dead_code)]
fn interpolate_idw_linear<F>(
    points: &ArrayView2<F>,
    values: &ArrayView1<F>,
    query: &ArrayView1<F>,
    default_fill: F,
) -> InterpolateResult<F>
where
    F: Float + FromPrimitive + Debug + Clone + AddAssign,
{
    let (n_points, n_dims) = points.dim();
    if n_points == 0 {
        return Ok(default_fill);
    }

    let snap_tol = F::from_f64(1e-12).expect("Operation failed");
    let mut weight_total = F::zero();
    let mut weighted_sum = F::zero();
    for (i, sample) in points.outer_iter().enumerate() {
        let mut dist_sq = F::zero();
        for d in 0..n_dims {
            let diff = query[d] - sample[d];
            dist_sq += diff * diff;
        }
        if dist_sq < snap_tol {
            return Ok(values[i]);
        }
        let weight = F::one() / dist_sq.sqrt();
        weight_total += weight;
        weighted_sum += weight * values[i];
    }

    if weight_total > F::zero() {
        Ok(weighted_sum / weight_total)
    } else {
        Ok(default_fill)
    }
}
/// Sequential cubic interpolation entry point used by `griddata`; all work is delegated
/// to `clough_tocher_interpolation`.
///
/// # Errors
/// Propagates any error from `clough_tocher_interpolation`.
#[allow(dead_code)]
fn griddata_cubic<F>(
    points: &ArrayView2<F>,
    values: &ArrayView1<F>,
    xi: &ArrayView2<F>,
    fill_value: Option<F>,
) -> InterpolateResult<Array1<F>>
where
    F: Float
        + FromPrimitive
        + Debug
        + Clone
        + Display
        + AddAssign
        + std::ops::SubAssign
        + std::fmt::LowerExp
        + std::ops::MulAssign
        + std::ops::DivAssign
        + Send
        + Sync
        + 'static,
{
    let interpolated = clough_tocher_interpolation(points, values, xi, fill_value)?;
    Ok(interpolated)
}
/// Gradient-enhanced local interpolation for 2-D scattered data.
///
/// NOTE(review): this is not a true Clough–Tocher scheme. It estimates a gradient at
/// every sample (`estimate_gradients`), then blends first-order Taylor estimates from
/// the (up to) 6 nearest samples (`local_cubic_interpolation`). Non-2-D inputs, or fewer
/// than 3 samples, are handled by a cubic RBF (`griddata_rbf`). A query whose local
/// evaluation fails gets `fill_value` (NaN when `None`).
///
/// # Errors
/// Propagates errors from gradient estimation, neighbour search or the RBF fallback.
#[allow(dead_code)]
fn clough_tocher_interpolation<F>(
    points: &ArrayView2<F>,
    values: &ArrayView1<F>,
    xi: &ArrayView2<F>,
    fill_value: Option<F>,
) -> InterpolateResult<Array1<F>>
where
    F: Float
        + FromPrimitive
        + Debug
        + Clone
        + Display
        + AddAssign
        + std::ops::SubAssign
        + std::fmt::LowerExp
        + std::ops::MulAssign
        + std::ops::DivAssign
        + Send
        + Sync
        + 'static,
{
    let n_points = points.nrows();
    if points.ncols() != 2 || n_points < 3 {
        return griddata_rbf(points, values, xi, RBFKernel::Cubic, fill_value);
    }

    let gradients = estimate_gradients(points, values)?;
    let gradient_view = gradients.view();
    let fallback =
        fill_value.unwrap_or_else(|| F::from_f64(f64::NAN).expect("Operation failed"));
    let neighbor_count = n_points.min(6);

    let mut result = Array1::zeros(xi.nrows());
    for (i, query) in xi.outer_iter().enumerate() {
        let (x, y) = (query[0], query[1]);
        let neighbors = find_nearest_neighbors(points, &[x, y], neighbor_count)?;
        result[i] = if neighbors.len() < 3 {
            fallback
        } else {
            local_cubic_interpolation(points, values, &gradient_view, &neighbors, x, y)
                .unwrap_or(fallback)
        };
    }
    Ok(result)
}
/// Estimates a 2-D gradient `(d/dx, d/dy)` at every sample by least squares.
///
/// For each sample, the k nearest samples (k = n/3 clamped to [3, 10]; the sample
/// itself is included) are fitted against the linear model `dv = gx*dx + gy*dy`.
/// A gradient stays zero when fewer than 3 neighbours exist or the solve fails.
///
/// Returns an `(n_points, 2)` array.
///
/// # Errors
/// Propagates errors from `find_nearest_neighbors`.
#[allow(dead_code)]
fn estimate_gradients<F>(
    points: &ArrayView2<F>,
    values: &ArrayView1<F>,
) -> InterpolateResult<Array2<F>>
where
    F: Float + FromPrimitive + Debug + Clone + Display + AddAssign + std::ops::SubAssign,
{
    let n_points = points.nrows();
    // The neighbourhood size depends only on the sample count.
    let k = (n_points / 3).clamp(3, 10);
    // Rows start at zero, which is also the value used for any failed estimate.
    let mut gradients = Array2::zeros((n_points, 2));

    for i in 0..n_points {
        let (px, py, pv) = (points[[i, 0]], points[[i, 1]], values[i]);
        let neighbors = find_nearest_neighbors(points, &[px, py], k)?;
        if neighbors.len() < 3 {
            continue;
        }

        let rows = neighbors.len();
        let mut design = Array2::zeros((rows, 2));
        let mut rhs = Array1::zeros(rows);
        for (row, &nb) in neighbors.iter().enumerate() {
            design[[row, 0]] = points[[nb, 0]] - px;
            design[[row, 1]] = points[[nb, 1]] - py;
            rhs[row] = values[nb] - pv;
        }

        if let Ok(grad) = solve_least_squares(&design, &rhs) {
            gradients[[i, 0]] = grad[0];
            gradients[[i, 1]] = grad[1];
        }
    }
    Ok(gradients)
}
/// Indices of the `k` rows of `points` closest to `query` (squared Euclidean distance),
/// nearest first.
///
/// The sort is stable, so equal distances keep ascending index order. `query` must
/// have at least `points.ncols()` entries.
#[allow(dead_code)]
fn find_nearest_neighbors<F>(
    points: &ArrayView2<F>,
    query: &[F],
    k: usize,
) -> InterpolateResult<Vec<usize>>
where
    F: Float + FromPrimitive + PartialOrd + Clone,
{
    let n_dims = points.ncols();
    let mut ranked: Vec<(F, usize)> = points
        .outer_iter()
        .enumerate()
        .map(|(idx, row)| {
            let dist_sq = (0..n_dims).fold(F::zero(), |acc, d| {
                let diff = row[d] - query[d];
                acc + diff * diff
            });
            (dist_sq, idx)
        })
        .collect();
    ranked.sort_by(|lhs, rhs| {
        lhs.0
            .partial_cmp(&rhs.0)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    ranked.truncate(k);
    Ok(ranked.into_iter().map(|(_, idx)| idx).collect())
}
/// Blends first-order Taylor estimates from `neighbors` at `(x, y)`.
///
/// Each neighbour contributes `v + gx*dx + gy*dy` with weight `1 / (d^3 + 1e-12)`.
/// A neighbour within squared distance `1e-12` returns its value exactly.
///
/// # Errors
/// `ComputationError` when fewer than 3 neighbours are given or the total weight is
/// not positive.
#[allow(dead_code)]
fn local_cubic_interpolation<F>(
    points: &ArrayView2<F>,
    values: &ArrayView1<F>,
    gradients: &ArrayView2<F>,
    neighbors: &[usize],
    x: F,
    y: F,
) -> InterpolateResult<F>
where
    F: Float + FromPrimitive + Debug + Clone + Display + AddAssign + std::ops::SubAssign,
{
    if neighbors.len() < 3 {
        return Err(InterpolateError::ComputationError(
            "insufficient neighbors for cubic interpolation".to_string(),
        ));
    }

    let eps = F::from_f64(1e-12).expect("Operation failed");
    let mut weight_total = F::zero();
    let mut weighted_sum = F::zero();
    for &idx in neighbors {
        let (sx, sy, sample) = (points[[idx, 0]], points[[idx, 1]], values[idx]);
        let (gx, gy) = (gradients[[idx, 0]], gradients[[idx, 1]]);
        let dx = x - sx;
        let dy = y - sy;
        let dist_sq = dx * dx + dy * dy;
        if dist_sq < eps {
            return Ok(sample);
        }
        let estimate = sample + gx * dx + gy * dy;
        // dist_sq * sqrt(dist_sq) == distance^3; eps avoids division by zero.
        let weight = F::one() / (dist_sq * dist_sq.sqrt() + eps);
        weight_total += weight;
        weighted_sum += weight * estimate;
    }

    if weight_total > F::zero() {
        Ok(weighted_sum / weight_total)
    } else {
        Err(InterpolateError::ComputationError(
            "zero total weight in interpolation".to_string(),
        ))
    }
}
/// Least-squares solution of `a * x ≈ b` via the normal equations `(AᵀA) x = Aᵀb`.
///
/// Two unknowns use the closed-form 2x2 inverse, unchanged. Any other number of
/// unknowns uses Gaussian elimination with partial pivoting. That path previously
/// returned an all-zero vector unconditionally, which silently looked like a valid
/// solution. In both paths a (near-)singular normal matrix yields an all-zero vector:
/// `|det| < 1e-14` for 2x2, `|pivot| < 1e-14` or NaN otherwise.
///
/// # Errors
/// `ComputationError` when `a` has fewer rows than columns.
#[allow(dead_code)]
fn solve_least_squares<F>(a: &Array2<F>, b: &Array1<F>) -> InterpolateResult<Array1<F>>
where
    F: Float + FromPrimitive + Debug + Clone + Display + AddAssign + std::ops::SubAssign,
{
    let m = a.nrows();
    let n = a.ncols();
    if m < n {
        return Err(InterpolateError::ComputationError(
            "underdetermined system".to_string(),
        ));
    }
    let mut ata = Array2::zeros((n, n));
    for i in 0..n {
        for j in 0..n {
            let mut sum = F::zero();
            for k in 0..m {
                sum += a[[k, i]] * a[[k, j]];
            }
            ata[[i, j]] = sum;
        }
    }
    let mut atb = Array1::zeros(n);
    for i in 0..n {
        let mut sum = F::zero();
        for k in 0..m {
            sum += a[[k, i]] * b[k];
        }
        atb[i] = sum;
    }
    let eps = F::from_f64(1e-14).expect("Operation failed");
    if n == 2 {
        let det = ata[[0, 0]] * ata[[1, 1]] - ata[[0, 1]] * ata[[1, 0]];
        if det.abs() < eps {
            return Ok(Array1::zeros(n));
        }
        let inv_det = F::one() / det;
        let x0 = (ata[[1, 1]] * atb[0] - ata[[0, 1]] * atb[1]) * inv_det;
        let x1 = (ata[[0, 0]] * atb[1] - ata[[1, 0]] * atb[0]) * inv_det;
        return Ok(Array1::from_vec(vec![x0, x1]));
    }
    Ok(solve_square_system_pivoted(ata, atb, eps))
}

/// Solves the square system `mat * x = rhs` by Gaussian elimination with partial
/// pivoting. Returns zeros when a pivot is NaN or smaller than `eps` in magnitude.
#[allow(dead_code)]
fn solve_square_system_pivoted<F>(mut mat: Array2<F>, mut rhs: Array1<F>, eps: F) -> Array1<F>
where
    F: Float + std::ops::SubAssign,
{
    let n = rhs.len();
    for col in 0..n {
        // Largest-magnitude entry in this column at or below the diagonal.
        let pivot_row = (col..n)
            .max_by(|&r1, &r2| {
                mat[[r1, col]]
                    .abs()
                    .partial_cmp(&mat[[r2, col]].abs())
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .unwrap_or(col);
        let pivot = mat[[pivot_row, col]];
        if pivot.is_nan() || pivot.abs() < eps {
            return Array1::zeros(n);
        }
        if pivot_row != col {
            for c in 0..n {
                mat.swap([col, c], [pivot_row, c]);
            }
            rhs.swap(col, pivot_row);
        }
        for row in (col + 1)..n {
            let factor = mat[[row, col]] / pivot;
            for c in col..n {
                let upper = mat[[col, c]];
                mat[[row, c]] -= factor * upper;
            }
            let upper_rhs = rhs[col];
            rhs[row] -= factor * upper_rhs;
        }
    }
    // Back substitution on the upper-triangular system.
    let mut x = Array1::zeros(n);
    for row in (0..n).rev() {
        let mut acc = rhs[row];
        for c in (row + 1)..n {
            acc -= mat[[row, c]] * x[c];
        }
        x[row] = acc / mat[[row, row]];
    }
    x
}
/// Interpolates with a global radial-basis-function model using `kernel`.
///
/// The shape parameter is estimated from the data by `estimate_rbf_epsilon`. All of `xi`
/// is passed to the model in a single call.
///
/// `_fill_value` is accepted for signature parity with the other `griddata_*` methods
/// but is not consulted: no query is replaced by a fill value on this path. (It was
/// previously named `value`, which triggered an unused-variable warning.)
///
/// # Errors
/// Propagates errors from constructing or evaluating the `RBFInterpolator`.
#[allow(dead_code)]
fn griddata_rbf<F>(
    points: &ArrayView2<F>,
    values: &ArrayView1<F>,
    xi: &ArrayView2<F>,
    kernel: RBFKernel,
    _fill_value: Option<F>,
) -> InterpolateResult<Array1<F>>
where
    F: Float
        + FromPrimitive
        + Debug
        + Clone
        + Display
        + AddAssign
        + std::ops::SubAssign
        + std::fmt::LowerExp
        + std::ops::MulAssign
        + std::ops::DivAssign
        + Send
        + Sync
        + 'static,
{
    let epsilon = estimate_rbf_epsilon(points);
    let interpolator = RBFInterpolator::new(points, values, kernel, epsilon)?;
    interpolator.interpolate(xi)
}
/// Data-driven RBF shape parameter: the mean distance from each of the first (up to)
/// 100 samples to its closest *distinct* sample.
///
/// Zero distances (duplicates) are ignored. Returns `1` when there are fewer than 2
/// samples or no positive spacing was found.
#[allow(dead_code)]
fn estimate_rbf_epsilon<F>(points: &ArrayView2<F>) -> F
where
    F: Float + FromPrimitive,
{
    let n_points = points.nrows();
    if n_points < 2 {
        return F::one();
    }

    let mut spacing_sum = F::zero();
    let mut used = 0;
    for i in 0..n_points.min(100) {
        let anchor = points.row(i);
        let nearest = (0..n_points)
            .filter(|&j| j != i)
            .fold(F::infinity(), |best, j| {
                let other = points.row(j);
                let mut dist_sq = F::zero();
                for d in 0..anchor.len() {
                    let diff = anchor[d] - other[d];
                    dist_sq = dist_sq + diff * diff;
                }
                let dist = dist_sq.sqrt();
                if dist < best && dist > F::zero() {
                    dist
                } else {
                    best
                }
            });
        if nearest.is_finite() {
            spacing_sum = spacing_sum + nearest;
            used += 1;
        }
    }

    if used == 0 {
        F::one()
    } else {
        spacing_sum / F::from_usize(used).unwrap_or(F::one())
    }
}
/// Builds the coordinates of a regular (tensor-product) grid.
///
/// Dimension `d` receives `resolution[d]` evenly spaced samples spanning `bounds[d]`
/// inclusively. A dimension with resolution 1 is sampled only at its midpoint. Rows
/// are ordered with the last dimension varying fastest.
///
/// Returns an array of shape `(product(resolution), bounds.len())`.
///
/// # Errors
/// * shape mismatch when `bounds` and `resolution` have different lengths;
/// * invalid input when the total point count overflows `usize`. Previously the
///   product could wrap silently in release builds and produce a wrongly sized grid.
#[allow(dead_code)]
pub fn make_regular_grid<F>(bounds: &[(F, F)], resolution: &[usize]) -> InterpolateResult<Array2<F>>
where
    F: Float + FromPrimitive + Clone,
{
    if bounds.len() != resolution.len() {
        return Err(InterpolateError::shape_mismatch(
            format!("bounds.len() = {}", bounds.len()),
            format!("resolution.len() = {}", resolution.len()),
            "make_regular_grid dimension consistency",
        ));
    }
    let n_dims = bounds.len();
    let total_points = resolution
        .iter()
        .try_fold(1usize, |acc, &r| acc.checked_mul(r))
        .ok_or_else(|| {
            InterpolateError::invalid_input(format!(
                "Grid resolution {:?} has too many points to index",
                resolution
            ))
        })?;
    let two = F::one() + F::one();
    let mut grid = Array2::zeros((total_points, n_dims));
    for point_idx in 0..total_points {
        // Decode the flat index into per-dimension indices, last dimension fastest.
        // (A zero resolution makes total_points 0, so `% steps` never sees 0.)
        let mut remainder = point_idx;
        for dim in (0..n_dims).rev() {
            let steps = resolution[dim];
            let idx = remainder % steps;
            remainder /= steps;
            let (min_val, max_val) = bounds[dim];
            grid[[point_idx, dim]] = if steps > 1 {
                let t = F::from_usize(idx).expect("Operation failed")
                    / F::from_usize(steps - 1).expect("Operation failed");
                min_val + t * (max_val - min_val)
            } else {
                (min_val + max_val) / two
            };
        }
    }
    Ok(grid)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_griddata_nearest() -> InterpolateResult<()> {
        let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]];
        let values = array![0.0, 1.0, 2.0];
        let xi = array![[0.1, 0.1], [0.9, 0.1]];
        let result = griddata(
            &points.view(),
            &values.view(),
            &xi.view(),
            GriddataMethod::Nearest,
            None,
            None,
        )?;
        assert_eq!(result.len(), 2);
        // (0.1, 0.1) is closest to (0, 0); (0.9, 0.1) is closest to (1, 0).
        assert_abs_diff_eq!(result[0], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result[1], 1.0, epsilon = 1e-10);
        Ok(())
    }

    #[test]
    fn test_griddata_rbf() -> InterpolateResult<()> {
        let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
        // values = x + y, so the centre of the square should be close to 1.
        let values = array![0.0, 1.0, 1.0, 2.0];
        let xi = array![[0.5, 0.5]];
        let result = griddata(
            &points.view(),
            &values.view(),
            &xi.view(),
            GriddataMethod::Rbf,
            None,
            None,
        )?;
        assert_eq!(result.len(), 1);
        assert!((result[0] - 1.0).abs() < 0.5);
        Ok(())
    }

    #[test]
    fn test_griddata_linear_1d() -> InterpolateResult<()> {
        let points = array![[0.0], [1.0], [2.0]];
        let values = array![0.0, 10.0, 20.0];
        let xi = array![[0.5], [1.0], [1.5], [-1.0], [3.0]];
        let result = griddata(
            &points.view(),
            &values.view(),
            &xi.view(),
            GriddataMethod::Linear,
            Some(-1.0),
            None,
        )?;
        assert_eq!(result.len(), 5);
        assert_abs_diff_eq!(result[0], 5.0, epsilon = 1e-12);
        // Exact hit on a sample.
        assert_abs_diff_eq!(result[1], 10.0, epsilon = 1e-12);
        assert_abs_diff_eq!(result[2], 15.0, epsilon = 1e-12);
        // Outside the sample range on either side: fill value.
        assert_abs_diff_eq!(result[3], -1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(result[4], -1.0, epsilon = 1e-12);
        Ok(())
    }

    #[test]
    fn test_make_regular_grid() -> InterpolateResult<()> {
        let bounds = vec![(0.0, 1.0), (0.0, 2.0)];
        let resolution = vec![3, 2];
        let grid = make_regular_grid(&bounds, &resolution)?;
        assert_eq!(grid.nrows(), 6);
        assert_eq!(grid.ncols(), 2);
        assert_abs_diff_eq!(grid[[0, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(grid[[0, 1]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(grid[[5, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(grid[[5, 1]], 2.0, epsilon = 1e-10);
        Ok(())
    }

    #[test]
    fn test_make_regular_grid_single_sample_uses_midpoint() -> InterpolateResult<()> {
        let grid = make_regular_grid(&[(0.0, 2.0)], &[1])?;
        assert_eq!(grid.dim(), (1, 1));
        assert_abs_diff_eq!(grid[[0, 0]], 1.0, epsilon = 1e-12);
        Ok(())
    }

    #[test]
    fn test_make_regular_grid_rejects_mismatched_lengths() {
        assert!(make_regular_grid(&[(0.0, 1.0)], &[2, 2]).is_err());
    }

    #[test]
    fn test_validation() {
        let points = array![[0.0, 0.0], [1.0, 0.0]];
        // One value too many for the two points.
        let values = array![0.0, 1.0, 2.0];
        let xi = array![[0.5, 0.5]];
        let result = griddata(
            &points.view(),
            &values.view(),
            &xi.view(),
            GriddataMethod::Nearest,
            None,
            None,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_validation_rejects_query_dimension_mismatch() {
        let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]];
        let values = array![0.0, 1.0, 2.0];
        let xi = array![[0.5, 0.5, 0.5]];
        let result = griddata(
            &points.view(),
            &values.view(),
            &xi.view(),
            GriddataMethod::Nearest,
            None,
            None,
        );
        assert!(result.is_err());
    }
}