use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use scirs2_core::numeric::{Float, FromPrimitive};
use scirs2_core::parallel_ops::*;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::sync::Arc;
use super::{estimate_chunk_size, ParallelConfig, ParallelEvaluate};
use crate::error::{InterpolateError, InterpolateResult};
use crate::local::mls::{MovingLeastSquares, PolynomialBasis, WeightFunction};
use crate::spatial::kdtree::KdTree;
/// Moving least squares interpolator paired with a kd-tree spatial index for
/// parallel batch evaluation.
///
/// Wraps the sequential [`MovingLeastSquares`] model together with a
/// [`KdTree`] built over the same training points, so batch queries can locate
/// nearest neighbors quickly and be evaluated across rayon worker threads.
#[derive(Debug, Clone)]
pub struct ParallelMovingLeastSquares<F>
where
    F: Float
        + FromPrimitive
        + Debug
        + Send
        + Sync
        + 'static
        + std::cmp::PartialOrd
        + ordered_float::FloatCore,
{
    // Underlying sequential MLS model; owns the training points and values.
    mls: MovingLeastSquares<F>,
    // Spatial index over the same training points, used by the fast
    // `predict_with_kdtree` path for k-nearest-neighbor queries.
    kdtree: KdTree<F>,
    // NOTE(review): `F` already appears in both fields above, so this
    // PhantomData is redundant; kept because the constructor initializes it
    // and removing it would change the struct literal in `new`.
    _phantom: PhantomData<F>,
}
impl<F> ParallelMovingLeastSquares<F>
where
    F: Float
        + FromPrimitive
        + Debug
        + Send
        + Sync
        + 'static
        + std::cmp::PartialOrd
        + ordered_float::FloatCore,
{
    /// Build a parallel MLS interpolator from training `points` (n x d) and
    /// their scalar `values` (length n).
    ///
    /// `points` is cloned because both the MLS model and the kd-tree need
    /// ownership of the coordinates.
    ///
    /// # Errors
    /// Propagates any error from [`MovingLeastSquares::new`] or
    /// [`KdTree::new`] (e.g. shape/consistency validation failures).
    pub fn new(
        points: Array2<F>,
        values: Array1<F>,
        weight_fn: WeightFunction,
        basis: PolynomialBasis,
        bandwidth: F,
    ) -> InterpolateResult<Self> {
        let mls = MovingLeastSquares::new(points.clone(), values, weight_fn, basis, bandwidth)?;
        let kdtree = KdTree::new(points)?;
        Ok(Self {
            mls,
            kdtree,
            _phantom: PhantomData,
        })
    }
    /// Builder-style setter: cap the number of neighbor points considered per
    /// query (forwarded to the underlying MLS model).
    pub fn with_max_points(mut self, maxpoints: usize) -> Self {
        self.mls = self.mls.with_max_points(maxpoints);
        self
    }
    /// Builder-style setter: numerical tolerance forwarded to the underlying
    /// MLS model.
    pub fn with_epsilon(mut self, epsilon: F) -> Self {
        self.mls = self.mls.with_epsilon(epsilon);
        self
    }
    /// Evaluate at a single query point via the sequential MLS path
    /// (no kd-tree shortcut, no parallelism).
    pub fn evaluate(&self, x: &ArrayView1<F>) -> InterpolateResult<F> {
        self.mls.evaluate(x)
    }
    /// Evaluate many query points in parallel; thin alias for
    /// [`ParallelEvaluate::evaluate_parallel`].
    pub fn evaluate_multi_parallel(
        &self,
        points: &ArrayView2<F>,
        config: &ParallelConfig,
    ) -> InterpolateResult<Array1<F>> {
        self.evaluate_parallel(points, config)
    }
    /// Fast parallel batch prediction using the kd-tree.
    ///
    /// For each query row this finds up to `max_points` nearest neighbors
    /// (default 50 when unset) and returns the weight-normalized average of
    /// their training values, with weights taken from the model's weight
    /// function evaluated at `distance / bandwidth`.
    ///
    /// NOTE(review): this path is a Shepard-style weighted average — it does
    /// NOT solve the local MLS polynomial system, so results only approximate
    /// `MovingLeastSquares::evaluate` (see the loose tolerance in the tests).
    ///
    /// Failed or empty neighbor queries contribute `F::zero()` for that row
    /// rather than failing the whole batch.
    ///
    /// # Errors
    /// Returns [`InterpolateError::DimensionMismatch`] when the query points'
    /// dimensionality differs from the training points'.
    pub fn predict_with_kdtree(
        &self,
        points: &ArrayView2<F>,
        config: &ParallelConfig,
    ) -> InterpolateResult<Array1<F>> {
        if points.shape()[1] != self.mls.points().shape()[1] {
            return Err(InterpolateError::DimensionMismatch(
                "Query points dimension must match training points".to_string(),
            ));
        }
        let n_points = points.shape()[0];
        let _n_dims = points.shape()[1];
        let values = self.mls.values();
        // Rough per-point cost weighting by basis size, used to size chunks.
        let cost_factor = match self.mls.basis() {
            PolynomialBasis::Constant => 1.0,
            PolynomialBasis::Linear => 2.0,
            PolynomialBasis::Quadratic => 4.0,
        };
        let chunk_size = estimate_chunk_size(n_points, cost_factor, config);
        let max_neighbors = self.mls.max_points().unwrap_or(50);
        // NOTE(review): this Arc-of-clone looks redundant — `self` (and thus
        // `self.mls.values()`) is already captured by reference in the rayon
        // closure below, which is Sync; confirm before removing.
        let values_arc = Arc::new(values.clone());
        let weight_fn = self.mls.weight_fn();
        let bandwidth = self.mls.bandwidth();
        // Chunked parallel map; rayon's indexed iterators keep chunk order, so
        // flat_map + collect reassembles results in query order.
        let results: Vec<F> = points
            .axis_chunks_iter(Axis(0), chunk_size)
            .into_par_iter()
            .flat_map(|chunk| {
                let values_ref = Arc::clone(&values_arc);
                let mut chunk_results = Vec::with_capacity(chunk.shape()[0]);
                for i in 0..chunk.shape()[0] {
                    let query = chunk.slice(scirs2_core::ndarray::s![i, ..]);
                    // Best-effort: a failed neighbor lookup yields zero for
                    // this row instead of aborting the batch.
                    let neighbors = match self
                        .kdtree
                        .k_nearest_neighbors(&query.to_vec(), max_neighbors)
                    {
                        Ok(n) => n,
                        Err(_) => {
                            chunk_results.push(F::zero());
                            continue;
                        }
                    };
                    if neighbors.is_empty() {
                        chunk_results.push(F::zero());
                        continue;
                    }
                    // Shepard weighted average over the neighbor set.
                    let mut weight_sum = F::zero();
                    let mut weighted_sum = F::zero();
                    for (idx, dist) in neighbors.iter() {
                        let weight = apply_weight(*dist / bandwidth, weight_fn);
                        weight_sum = weight_sum + weight;
                        weighted_sum = weighted_sum + weight * values_ref[*idx];
                    }
                    // Guard against an all-zero weight sum (e.g. all neighbors
                    // outside a compact-support kernel's radius).
                    let result = if weight_sum > F::zero() {
                        weighted_sum / weight_sum
                    } else {
                        F::zero()
                    };
                    chunk_results.push(result);
                }
                chunk_results
            })
            .collect();
        Ok(Array1::from_vec(results))
    }
}
/// Trait-based entry point for parallel batch evaluation.
///
/// Delegates straight to [`ParallelMovingLeastSquares::predict_with_kdtree`],
/// the kd-tree-accelerated batch path.
impl<F> ParallelEvaluate<F, Array1<F>> for ParallelMovingLeastSquares<F>
where
    F: Float
        + FromPrimitive
        + Debug
        + Send
        + Sync
        + 'static
        + std::cmp::PartialOrd
        + ordered_float::FloatCore,
{
    fn evaluate_parallel(
        &self,
        points: &ArrayView2<F>,
        config: &ParallelConfig,
    ) -> InterpolateResult<Array1<F>> {
        // Single fast path: neighbor lookup via the kd-tree, then a parallel
        // weighted combine of neighbor values.
        Self::predict_with_kdtree(self, points, config)
    }
}
/// Evaluate the weight kernel `weightfn` at normalized distance `r`
/// (`r = distance / bandwidth`).
///
/// Used by the kd-tree fast path to weight neighbor contributions. All
/// kernels return non-negative weights that decay with distance; the
/// compact-support kernels (Wendland C2, cubic spline) return zero for
/// `r >= 1`.
#[allow(dead_code)]
fn apply_weight<F: Float + FromPrimitive>(r: F, weightfn: WeightFunction) -> F {
    // Lift an f64 literal into F. Infallible for any practical float type;
    // the message states the invariant instead of a generic "Operation failed".
    let c = |v: f64| F::from_f64(v).expect("f64 literal must be representable in F");
    match weightfn {
        // exp(-r^2): smooth, strictly positive, global support.
        WeightFunction::Gaussian => (-r * r).exp(),
        // Wendland C2: (1 - r)^4 (4r + 1) on [0, 1), zero outside.
        WeightFunction::WendlandC2 => {
            if r < F::one() {
                let t = F::one() - r;
                let factor = c(4.0) * r + F::one();
                t.powi(4) * factor
            } else {
                F::zero()
            }
        }
        // Regularized inverse square distance; the 1e-10 floor avoids
        // division by zero when the query coincides with a data point.
        WeightFunction::InverseDistance => F::one() / (c(1e-10) + r * r),
        // Piecewise cubic with support [0, 1).
        // NOTE(review): the two pieces do not meet at r = 1/3
        // (10/27 vs 9/27), so this kernel is discontinuous there; the
        // classic cubic-spline kernel switches at r = 1/2 with pieces
        // 2/3 - 4r^2 + 4r^3 and (4/3)(1 - r)^3. Verify against the
        // definition used by crate::local::mls before changing — the
        // parallel path must match the sequential one.
        WeightFunction::CubicSpline => {
            if r < c(1.0 / 3.0) {
                let r2 = r * r;
                let r3 = r2 * r;
                c(2.0 / 3.0) - c(9.0) * r2 + c(19.0) * r3
            } else if r < F::one() {
                let t = c(2.0) - c(3.0) * r;
                c(1.0 / 3.0) * t.powi(3)
            } else {
                F::zero()
            }
        }
    }
}
/// Convenience constructor: parallel MLS interpolator with Gaussian weights
/// and a linear polynomial basis.
///
/// The bound is `std::cmp::PartialOrd` (not `std::cmp::Ord`): `Ord` is not
/// implemented for `f32`/`f64`, and requiring it would make this helper
/// uncallable with the floating-point element types it exists for. This also
/// matches the bounds on [`ParallelMovingLeastSquares`] itself.
///
/// # Errors
/// Propagates any construction error from [`ParallelMovingLeastSquares::new`].
#[allow(dead_code)]
pub fn make_parallel_mls<F>(
    points: Array2<F>,
    values: Array1<F>,
    bandwidth: F,
) -> InterpolateResult<ParallelMovingLeastSquares<F>>
where
    F: Float
        + FromPrimitive
        + Debug
        + Send
        + Sync
        + 'static
        + std::cmp::PartialOrd
        + ordered_float::FloatCore,
{
    ParallelMovingLeastSquares::new(
        points,
        values,
        WeightFunction::Gaussian,
        PolynomialBasis::Linear,
        bandwidth,
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// The kd-tree parallel path should roughly track the sequential MLS
    /// implementation on a small 2-D data set.
    ///
    /// The tolerance is deliberately loose: the parallel path is a
    /// Shepard-style weighted average, not a full local MLS solve, so the two
    /// results are only expected to agree approximately. With values spanning
    /// roughly 0..2, epsilon = 2.1 makes this a smoke test (both paths run,
    /// return finite values, and stay in the same ballpark).
    #[test]
    fn test_parallel_mls_matches_sequential() {
        let points = Array2::from_shape_vec(
            (5, 2),
            vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.5, 0.5],
        )
        .expect("5x2 shape matches 10 coordinates");
        let values = array![0.0, 1.0, 1.0, 2.0, 1.0];
        let sequential_mls = MovingLeastSquares::new(
            points.clone(),
            values.clone(),
            WeightFunction::Gaussian,
            PolynomialBasis::Linear,
            0.5,
        )
        .expect("sequential MLS construction succeeds on valid input");
        let parallel_mls = ParallelMovingLeastSquares::new(
            points.clone(),
            values.clone(),
            WeightFunction::Gaussian,
            PolynomialBasis::Linear,
            0.5,
        )
        .expect("parallel MLS construction succeeds on valid input");
        let test_points = Array2::from_shape_vec((3, 2), vec![0.25, 0.25, 0.75, 0.75, 0.5, 0.0])
            .expect("3x2 shape matches 6 coordinates");
        let sequential_results = sequential_mls
            .evaluate_multi(&test_points.view())
            .expect("sequential evaluation succeeds");
        let config = ParallelConfig::new();
        let parallel_results = parallel_mls
            .evaluate_parallel(&test_points.view(), &config)
            .expect("parallel evaluation succeeds");
        for i in 0..3 {
            assert_abs_diff_eq!(sequential_results[i], parallel_results[i], epsilon = 2.1);
        }
    }

    /// Results must be identical (up to small tolerance) regardless of how
    /// many worker threads are configured.
    #[test]
    fn test_parallel_mls_with_different_thread_counts() {
        // Synthetic smooth field sampled on a 100-point quasi-grid.
        let n_points = 100;
        let mut points_vec = Vec::with_capacity(n_points * 2);
        let mut values_vec = Vec::with_capacity(n_points);
        for i in 0..n_points {
            let x = i as f64 / n_points as f64;
            let y = (i % 10) as f64 / 10.0;
            points_vec.push(x);
            points_vec.push(y);
            let value =
                (2.0 * std::f64::consts::PI * x).sin() * (2.0 * std::f64::consts::PI * y).cos();
            values_vec.push(value);
        }
        let points =
            Array2::from_shape_vec((n_points, 2), points_vec).expect("100x2 shape matches data");
        let values = Array1::from_vec(values_vec);
        let parallel_mls = ParallelMovingLeastSquares::new(
            points.clone(),
            values.clone(),
            WeightFunction::Gaussian,
            PolynomialBasis::Linear,
            0.1,
        )
        .expect("parallel MLS construction succeeds on valid input");
        let test_points = Array2::from_shape_vec(
            (10, 2),
            vec![
                0.1, 0.1, 0.2, 0.2, 0.3, 0.3, 0.4, 0.4, 0.5, 0.5, 0.6, 0.6, 0.7, 0.7, 0.8, 0.8,
                0.9, 0.9, 0.5, 0.1,
            ],
        )
        .expect("10x2 shape matches 20 coordinates");
        // Evaluate the same queries under 1, 2, and 4 workers.
        let configs = vec![
            ParallelConfig::new().with_workers(1),
            ParallelConfig::new().with_workers(2),
            ParallelConfig::new().with_workers(4),
        ];
        let mut results = Vec::new();
        for config in &configs {
            let result = parallel_mls
                .evaluate_parallel(&test_points.view(), config)
                .expect("parallel evaluation succeeds");
            results.push(result);
        }
        // Thread count must not change the numerical outcome.
        for i in 1..results.len() {
            for j in 0..10 {
                assert_abs_diff_eq!(results[0][j], results[i][j], epsilon = 0.01);
            }
        }
    }
}