use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use scirs2_core::numeric::{Float, FromPrimitive};
use scirs2_core::parallel_ops::*;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::sync::Arc;
use super::{estimate_chunk_size, ParallelConfig, ParallelEvaluate};
use crate::error::{InterpolateError, InterpolateResult};
use crate::local::mls::{PolynomialBasis, WeightFunction};
use crate::local::polynomial::{
LocalPolynomialConfig, LocalPolynomialRegression, RegressionResult,
};
use crate::spatial::kdtree::KdTree;
/// Parallel LOESS: pairs a sequential [`LocalPolynomialRegression`] with a
/// [`KdTree`] built over the same training points, so batch queries can be
/// evaluated in parallel using nearest-neighbor-limited local fits.
#[derive(Debug, Clone)]
pub struct ParallelLocalPolynomialRegression<F>
where
    F: Float
        + FromPrimitive
        + Debug
        + Send
        + Sync
        + 'static
        + std::cmp::PartialOrd
        + ordered_float::FloatCore,
{
    // Sequential model owning the training points/values and configuration.
    loess: LocalPolynomialRegression<F>,
    // Spatial index over the same training points, used to cap each local fit
    // at `max_points` nearest neighbors.
    kdtree: KdTree<F>,
    // NOTE(review): `F` already appears in both fields above, so this
    // PhantomData looks redundant; kept because both constructors set it.
    _phantom: PhantomData<F>,
}
impl<F> ParallelLocalPolynomialRegression<F>
where
F: Float
+ FromPrimitive
+ Debug
+ Send
+ Sync
+ 'static
+ std::cmp::PartialOrd
+ ordered_float::FloatCore,
{
pub fn new(points: Array2<F>, values: Array1<F>, bandwidth: F) -> InterpolateResult<Self> {
let loess = LocalPolynomialRegression::new(points.clone(), values, bandwidth)?;
let kdtree = KdTree::new(points)?;
Ok(Self {
loess,
kdtree,
_phantom: PhantomData,
})
}
pub fn with_config(
points: Array2<F>,
values: Array1<F>,
config: LocalPolynomialConfig<F>,
) -> InterpolateResult<Self> {
let loess = LocalPolynomialRegression::with_config(points.clone(), values, config)?;
let kdtree = KdTree::new(points)?;
Ok(Self {
loess,
kdtree,
_phantom: PhantomData,
})
}
pub fn fit_at_point(&self, x: &ArrayView1<F>) -> InterpolateResult<RegressionResult<F>> {
self.loess.fit_at_point(x)
}
pub fn fit_multiple_parallel(
&self,
points: &ArrayView2<F>,
config: &ParallelConfig,
) -> InterpolateResult<Array1<F>> {
self.evaluate_parallel(points, config)
}
pub fn fit_with_kdtree(
&self,
points: &ArrayView2<F>,
config: &ParallelConfig,
) -> InterpolateResult<Array1<F>> {
if points.shape()[1] != self.loess.points().shape()[1] {
return Err(InterpolateError::DimensionMismatch(
"Query points dimension must match training points".to_string(),
));
}
let npoints = points.shape()[0];
let values = self.loess.values();
let cost_factor = match self.loess.config().basis {
PolynomialBasis::Constant => 1.0,
PolynomialBasis::Linear => 2.0,
PolynomialBasis::Quadratic => 4.0,
};
let chunk_size = estimate_chunk_size(npoints, cost_factor, config);
let maxpoints = self.loess.config().max_points.unwrap_or(50);
let values_arc = Arc::new(values.clone());
let points_arc = Arc::new(self.loess.points().clone());
let weight_fn = self.loess.config().weight_fn;
let bandwidth = self.loess.config().bandwidth;
let basis = self.loess.config().basis;
let results: Vec<F> = points
.axis_chunks_iter(Axis(0), chunk_size)
.into_par_iter()
.flat_map(|chunk| {
let values_ref: Arc<Array1<F>> = Arc::clone(&values_arc);
let points_ref: Arc<Array2<F>> = Arc::clone(&points_arc);
let mut chunk_results = Vec::with_capacity(chunk.shape()[0]);
for i in 0..chunk.shape()[0] {
let query = chunk.slice(scirs2_core::ndarray::s![i, ..]);
let neighbors =
match self.kdtree.k_nearest_neighbors(&query.to_vec(), maxpoints) {
Ok(n) => n,
Err(_) => {
let mean = values_ref.fold(F::zero(), |acc, &v| acc + v)
/ F::from_usize(values_ref.len()).expect("Operation failed");
chunk_results.push(mean);
continue;
}
};
if neighbors.is_empty() {
let mean = values_ref.fold(F::zero(), |acc, &v| acc + v)
/ F::from_usize(values_ref.len()).expect("Operation failed");
chunk_results.push(mean);
continue;
}
let n_local = neighbors.len();
let mut localpoints = Array2::zeros((n_local, query.len()));
let mut local_values = Array1::zeros(n_local);
let mut weights = Array1::zeros(n_local);
for (j, &(idx, dist)) in neighbors.iter().enumerate() {
localpoints
.slice_mut(scirs2_core::ndarray::s![j, ..])
.assign(&points_ref.slice(scirs2_core::ndarray::s![idx, ..]));
local_values[j] = values_ref[idx];
weights[j] = apply_weight(dist / bandwidth, weight_fn);
}
match fit_local_polynomial(
&localpoints.view(),
&local_values,
&query,
&weights,
basis,
) {
Ok(result) => chunk_results.push(result),
Err(_) => {
let mut weighted_sum = F::zero();
let mut weight_sum = F::zero();
for j in 0..n_local {
weighted_sum = weighted_sum + weights[j] * local_values[j];
weight_sum = weight_sum + weights[j];
}
let result = if weight_sum > F::zero() {
weighted_sum / weight_sum
} else {
local_values.fold(F::zero(), |acc, &v| acc + v)
/ F::from_usize(n_local).expect("Operation failed")
};
chunk_results.push(result);
}
}
}
chunk_results
})
.collect();
Ok(Array1::from_vec(results))
}
}
/// Parallel evaluation delegates to the k-d-tree-accelerated batch fit.
impl<F> ParallelEvaluate<F, Array1<F>> for ParallelLocalPolynomialRegression<F>
where
    F: Float
        + FromPrimitive
        + Debug
        + Send
        + Sync
        + 'static
        + std::cmp::PartialOrd
        + ordered_float::FloatCore,
{
    /// Evaluates the regression at each row of `points` in parallel.
    fn evaluate_parallel(
        &self,
        points: &ArrayView2<F>,
        config: &ParallelConfig,
    ) -> InterpolateResult<Array1<F>> {
        self.fit_with_kdtree(points, config)
    }
}
/// Evaluates the kernel weight for a normalized distance `r` (distance divided
/// by bandwidth) under the given weight function. Compactly supported kernels
/// (Wendland C2, cubic spline) return zero outside their support.
#[allow(dead_code)]
fn apply_weight<F: Float + FromPrimitive>(r: F, weightfn: WeightFunction) -> F {
    // Shorthand for lifting f64 constants into F.
    let c = |v: f64| F::from_f64(v).expect("Operation failed");
    match weightfn {
        // exp(-r^2): infinite support, always positive.
        WeightFunction::Gaussian => F::exp(-r * r),
        // (1 - r)^4 (4r + 1) on [0, 1), zero beyond.
        WeightFunction::WendlandC2 => {
            if r >= F::one() {
                return F::zero();
            }
            let one_minus_r = F::one() - r;
            one_minus_r.powi(4) * (c(4.0) * r + F::one())
        }
        // 1 / (eps + r^2); the small epsilon keeps r = 0 finite.
        WeightFunction::InverseDistance => {
            let denom = c(1e-10) + r * r;
            F::one() / denom
        }
        // Piecewise cubic with breakpoints at 1/3 and 1.
        WeightFunction::CubicSpline => {
            let third = c(1.0 / 3.0);
            if r < third {
                let r2 = r * r;
                let r3 = r2 * r;
                c(2.0 / 3.0) - c(9.0) * r2 + c(19.0) * r3
            } else if r < F::one() {
                let t = c(2.0) - c(3.0) * r;
                third * t.powi(3)
            } else {
                F::zero()
            }
        }
    }
}
/// Weighted least-squares fit of a local polynomial (centered at `query`) to a
/// neighbor set, returning the fitted value at the query point. Because the
/// basis is centered on `query`, the intercept coefficient IS the prediction.
///
/// # Errors
/// Returns `InterpolateError::ComputationError` when the normal-equation
/// system cannot be solved (only reachable with the `linalg` feature).
#[allow(dead_code)]
fn fit_local_polynomial<F: Float + FromPrimitive + 'static>(
    localpoints: &ArrayView2<F>,
    local_values: &Array1<F>,
    query: &ArrayView1<F>,
    weights: &Array1<F>,
    basis: PolynomialBasis,
) -> InterpolateResult<F> {
    let npoints = localpoints.shape()[0];
    let n_dims = localpoints.shape()[1];
    // Number of basis functions: 1 (constant), 1 + d (linear), or the count of
    // all monomials up to degree two (quadratic).
    let n_basis = match basis {
        PolynomialBasis::Constant => 1,
        PolynomialBasis::Linear => n_dims + 1,
        PolynomialBasis::Quadratic => ((n_dims + 1) * (n_dims + 2)) / 2,
    };
    // Design matrix evaluated at (point - query); the first column is the
    // intercept.
    let mut basis_matrix = Array2::zeros((npoints, n_basis));
    for i in 0..npoints {
        let point = localpoints.row(i);
        let mut col = 0;
        basis_matrix[[i, col]] = F::one();
        col += 1;
        if basis == PolynomialBasis::Linear || basis == PolynomialBasis::Quadratic {
            for j in 0..n_dims {
                basis_matrix[[i, col]] = point[j] - query[j];
                col += 1;
            }
        }
        if basis == PolynomialBasis::Quadratic {
            // Upper-triangular pairs (k >= j) enumerate the quadratic terms.
            for j in 0..n_dims {
                for k in j..n_dims {
                    let term_j = point[j] - query[j];
                    let term_k = point[k] - query[k];
                    basis_matrix[[i, col]] = term_j * term_k;
                    col += 1;
                }
            }
        }
    }
    // Weighted least squares via sqrt-weight row scaling: solving
    // (W^{1/2} X)^T (W^{1/2} X) c = (W^{1/2} X)^T W^{1/2} y.
    let mut w_basis = Array2::zeros((npoints, n_basis));
    let mut w_values = Array1::zeros(npoints);
    for i in 0..npoints {
        let sqrt_w = weights[i].sqrt();
        for j in 0..n_basis {
            w_basis[[i, j]] = basis_matrix[[i, j]] * sqrt_w;
        }
        w_values[i] = local_values[i] * sqrt_w;
    }
    #[cfg(feature = "linalg")]
    let xtx = w_basis.t().dot(&w_basis);
    // Without `linalg` the product is still computed for parity but unused;
    // the underscore suppresses the unused-variable warning.
    #[cfg(not(feature = "linalg"))]
    let _xtx = w_basis.t().dot(&w_basis);
    let xty = w_basis.t().dot(&w_values);
    // Solve the normal equations in f64 precision, then convert back to F.
    #[cfg(feature = "linalg")]
    let coefficients = {
        use scirs2_linalg::solve;
        let xtx_f64 = xtx.mapv(|x| x.to_f64().expect("Operation failed"));
        let xty_f64 = xty.mapv(|x| x.to_f64().expect("Operation failed"));
        solve(&xtx_f64.view(), &xty_f64.view(), None)
            .map_err(|_| {
                InterpolateError::ComputationError("Failed to solve linear system".to_string())
            })?
            .mapv(|x| F::from_f64(x).expect("Operation failed"))
    };
    // NOTE(review): without the `linalg` feature this silently returns zero
    // coefficients (i.e. a prediction of 0) rather than an error, so callers'
    // Err-based fallbacks never trigger — confirm this degraded behavior is
    // intended.
    #[cfg(not(feature = "linalg"))]
    let coefficients = {
        Array1::zeros(xty.len())
    };
    // Centered basis: the intercept coefficient is the value at `query`.
    Ok(coefficients[0])
}
/// Convenience constructor for a parallel LOESS model with default settings.
///
/// # Errors
/// Propagates any construction error from the underlying model or k-d tree.
#[allow(dead_code)]
pub fn make_parallel_loess<F>(
    points: Array2<F>,
    values: Array1<F>,
    bandwidth: F,
) -> InterpolateResult<ParallelLocalPolynomialRegression<F>>
where
    F: Float
        + FromPrimitive
        + Debug
        + Send
        + Sync
        + 'static
        // Relaxed from `Ord` to `PartialOrd`: floats (f32/f64) do not implement
        // the total order `Ord`, so the stricter bound made this helper
        // uncallable with standard float types, while the regression type
        // itself only requires `PartialOrd`.
        + std::cmp::PartialOrd
        + ordered_float::FloatCore,
{
    ParallelLocalPolynomialRegression::new(points, values, bandwidth)
}
/// Convenience constructor for a parallel LOESS model configured for robust
/// standard errors with confidence intervals at `confidence_level`, using a
/// Gaussian kernel and a linear local basis.
///
/// # Errors
/// Propagates any construction error from the underlying model or k-d tree.
#[allow(dead_code)]
pub fn make_parallel_robust_loess<F>(
    points: Array2<F>,
    values: Array1<F>,
    bandwidth: F,
    confidence_level: F,
) -> InterpolateResult<ParallelLocalPolynomialRegression<F>>
where
    F: Float
        + FromPrimitive
        + Debug
        + Send
        + Sync
        + 'static
        // Relaxed from `Ord` to `PartialOrd`: floats (f32/f64) do not implement
        // the total order `Ord`, so the stricter bound made this helper
        // uncallable with standard float types, while the regression type
        // itself only requires `PartialOrd`.
        + std::cmp::PartialOrd
        + ordered_float::FloatCore,
{
    let config = LocalPolynomialConfig {
        bandwidth,
        weight_fn: WeightFunction::Gaussian,
        basis: PolynomialBasis::Linear,
        confidence_level: Some(confidence_level),
        robust_se: true,
        max_points: None,
        epsilon: F::from_f64(1e-10).expect("Operation failed"),
    };
    ParallelLocalPolynomialRegression::with_config(points, values, config)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// The parallel (neighbor-truncated) fit should roughly track the full
    /// sequential fit on a noisy sine curve. Only finiteness is asserted:
    /// the k-d-tree truncation makes the two estimates differ by design, so
    /// the per-point differences are printed for manual inspection instead.
    #[test]
    fn test_parallel_loess_matches_sequential() {
        let x = Array1::linspace(0.0, 10.0, 50);
        let mut y = Array1::zeros(50);
        for (i, &x_val) in x.iter().enumerate() {
            // sin(x) plus small uniform noise in [-0.05, 0.05].
            y[i] = x_val.sin() + 0.1 * (scirs2_core::random::random::<f64>() - 0.5);
        }
        // 1-D inputs as an (n, 1) matrix.
        let points = x.clone().insert_axis(Axis(1));
        let sequential_loess = LocalPolynomialRegression::new(points.clone(), y.clone(), 0.3)
            .expect("Operation failed");
        let parallel_loess = ParallelLocalPolynomialRegression::new(points.clone(), y.clone(), 0.3)
            .expect("Operation failed");
        let test_x = Array1::linspace(1.0, 9.0, 10);
        let testpoints = test_x.clone().insert_axis(Axis(1));
        let mut sequential_values = Array1::zeros(10);
        for i in 0..10 {
            let result = sequential_loess
                .fit_at_point(&testpoints.row(i))
                .expect("Operation failed");
            sequential_values[i] = result.value;
        }
        let config = ParallelConfig::new();
        let parallel_values = parallel_loess
            .fit_multiple_parallel(&testpoints.view(), &config)
            .expect("Operation failed");
        for i in 0..10 {
            assert!(parallel_values[i].is_finite());
            let difference = (sequential_values[i] - parallel_values[i]).abs();
            println!("Difference at point {}: {}", i, difference);
        }
    }

    /// The result must not depend on the worker count: 1-, 2-, and 4-worker
    /// runs over the same quadratic-plus-noise data should agree closely.
    #[test]
    fn test_parallel_loess_with_different_thread_counts() {
        let npoints = 100;
        let x = Array1::linspace(0.0, 10.0, npoints);
        let mut y = Array1::zeros(npoints);
        for (i, &x_val) in x.iter().enumerate() {
            // x^2 plus uniform noise in [-2.5, 2.5].
            y[i] = x_val.powi(2) + (scirs2_core::random::random::<f64>() - 0.5) * 5.0;
        }
        let points = x.clone().insert_axis(Axis(1));
        let config = LocalPolynomialConfig {
            bandwidth: 0.2,
            basis: PolynomialBasis::Quadratic,
            ..LocalPolynomialConfig::default()
        };
        let parallel_loess =
            ParallelLocalPolynomialRegression::with_config(points.clone(), y.clone(), config)
                .expect("Operation failed");
        let test_x = Array1::linspace(1.0, 9.0, 20);
        let testpoints = test_x.clone().insert_axis(Axis(1));
        let configs = vec![
            ParallelConfig::new().with_workers(1),
            ParallelConfig::new().with_workers(2),
            ParallelConfig::new().with_workers(4),
        ];
        let mut results = Vec::new();
        for config in &configs {
            let result = parallel_loess
                .fit_multiple_parallel(&testpoints.view(), config)
                .expect("Operation failed");
            results.push(result);
        }
        // All worker counts must match the single-worker baseline.
        for i in 1..results.len() {
            for j in 0..20 {
                assert_abs_diff_eq!(results[0][j], results[i][j], epsilon = 0.1);
            }
        }
    }
}