use scirs2_core::ndarray::ArrayStatCompat;
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;
use std::marker::PhantomData;
use super::mls::{PolynomialBasis, WeightFunction};
use crate::error::{InterpolateError, InterpolateResult};
use statrs::statistics::Statistics;
/// Outcome of a single local polynomial fit evaluated at one query point.
#[derive(Debug, Clone)]
pub struct RegressionResult<F: Float> {
/// Fitted value at the query point (the first fitted coefficient).
pub value: F,
/// Standard error reported for `value`; zero in the fallback results that
/// carry no variance estimate.
pub std_error: F,
/// `(lower, upper)` bounds around `value`; `None` when no confidence level
/// is configured or when the fit falls back to a result without variance
/// estimates.
pub confidence_interval: Option<(F, F)>,
/// Fitted coefficients of the local polynomial; entry 0 equals `value`.
pub coefficients: Array1<F>,
/// Effective degrees of freedom: the sum of the leverage values, or 1 in
/// the fallback results.
pub effective_df: F,
/// Weighted coefficient of determination of the local fit; zero in the
/// fallback results or when the weighted total sum of squares is not positive.
pub r_squared: F,
}
/// Settings controlling how each local polynomial fit is performed.
#[derive(Debug, Clone)]
pub struct LocalPolynomialConfig<F: Float> {
/// Scale by which point distances are divided before weighting; the
/// regression constructor rejects non-positive values.
pub bandwidth: F,
/// Weight function applied to the scaled distances.
pub weight_fn: WeightFunction,
/// Polynomial basis (constant, linear or quadratic) used for the local fit.
pub basis: PolynomialBasis,
/// Confidence level for the reported interval; `None` disables the interval.
pub confidence_level: Option<F>,
/// Selects the leverage-adjusted residual estimate for the standard error
/// instead of the mean-squared-error based one.
pub robust_se: bool,
/// Upper bound on how many nearest training points are scanned per fit;
/// `None` scans all of them.
pub max_points: Option<usize>,
/// Constant added to the squared scaled distance in the denominator of the
/// inverse-distance weight.
pub epsilon: F,
}
impl<F: Float + FromPrimitive> Default for LocalPolynomialConfig<F> {
    /// Builds the default configuration: bandwidth 0.2, Gaussian weights, a
    /// linear basis, no confidence interval, non-robust standard errors, no
    /// cap on the number of points and an epsilon of 1e-10.
    fn default() -> Self {
        // Converts an `f64` literal into the scalar type of the configuration.
        let constant = |value: f64| F::from_f64(value).expect("Operation failed");

        Self {
            bandwidth: constant(0.2),
            weight_fn: WeightFunction::Gaussian,
            basis: PolynomialBasis::Linear,
            confidence_level: None,
            robust_se: false,
            max_points: None,
            epsilon: constant(1e-10),
        }
    }
}
/// Local polynomial regression model holding the training data and the fit
/// configuration.
#[derive(Debug, Clone)]
pub struct LocalPolynomialRegression<F>
where
F: Float + FromPrimitive + Debug,
{
/// Training point coordinates, one row per point and one column per dimension.
points: Array2<F>,
/// Response value for each row of `points`.
values: Array1<F>,
/// Settings used for every local fit.
config: LocalPolynomialConfig<F>,
/// Sample standard deviation of `values` (n - 1 denominator), computed at
/// construction time.
response_sd: F,
/// Marker for the scalar type; carries no data.
_phantom: PhantomData<F>,
}
impl<F> LocalPolynomialRegression<F>
where
    F: Float + FromPrimitive + Debug,
{
    /// Returns the sample standard deviation of the training responses that
    /// was computed when the model was constructed.
    pub fn response_sd(&self) -> F {
        let Self { response_sd, .. } = self;
        *response_sd
    }
}
impl<F> LocalPolynomialRegression<F>
where
F: Float + FromPrimitive + Debug + 'static,
{
/// Creates a model with the default configuration, overriding only the
/// bandwidth.
///
/// # Errors
///
/// Propagates every validation error reported by `with_config`.
pub fn new(points: Array2<F>, values: Array1<F>, bandwidth: F) -> InterpolateResult<Self> {
    let defaults = LocalPolynomialConfig::default();
    Self::with_config(
        points,
        values,
        LocalPolynomialConfig {
            bandwidth,
            ..defaults
        },
    )
}
/// Creates a model from training data and an explicit configuration.
///
/// The sample standard deviation of `values` (n - 1 denominator) is computed
/// once here and stored on the model.
///
/// # Errors
///
/// * `DimensionMismatch` when the number of rows in `points` differs from the
///   number of `values`.
/// * `InvalidValue` when fewer than two points are supplied.
/// * `InvalidValue` when the bandwidth is NaN, zero or negative.
pub fn with_config(
    points: Array2<F>,
    values: Array1<F>,
    config: LocalPolynomialConfig<F>,
) -> InterpolateResult<Self> {
    let n_samples = points.shape()[0];
    if n_samples != values.len() {
        return Err(InterpolateError::DimensionMismatch(
            "Number of points must match number of values".to_string(),
        ));
    }
    if n_samples < 2 {
        return Err(InterpolateError::InvalidValue(
            "At least 2 points are required for local polynomial regression".to_string(),
        ));
    }
    // A NaN bandwidth compares false against zero, so it has to be rejected
    // explicitly; otherwise every weight computed later would be NaN.
    if config.bandwidth.is_nan() || config.bandwidth <= F::zero() {
        return Err(InterpolateError::InvalidValue(
            "Bandwidth parameter must be positive".to_string(),
        ));
    }

    // n_samples >= 2 is guaranteed above, so the n - 1 denominator is non-zero.
    let mean = values.sum() / F::from_usize(n_samples).expect("Operation failed");
    let sum_squared_dev = values.fold(F::zero(), |acc, &v| acc + (v - mean).powi(2));
    let variance = sum_squared_dev / F::from_usize(n_samples - 1).expect("Operation failed");
    let response_sd = variance.sqrt();

    Ok(Self {
        points,
        values,
        config,
        response_sd,
        _phantom: PhantomData,
    })
}
/// Fits the local polynomial around the query point `x` and returns the
/// fitted value together with its diagnostics.
///
/// # Errors
///
/// * `DimensionMismatch` when `x` does not have one entry per training
///   dimension.
/// * An invalid-input error when no training point is selected for the fit.
/// * Any error raised by the selection, weighting, basis or solve steps.
pub fn fit_at_point(&self, x: &ArrayView1<F>) -> InterpolateResult<RegressionResult<F>> {
    let expected_dims = self.points.shape()[1];
    if x.len() != expected_dims {
        return Err(InterpolateError::DimensionMismatch(
            "Query point dimension must match training points".to_string(),
        ));
    }

    // Neighbourhood of the query point and the matching distances.
    let (neighbour_ids, neighbour_dists) = self.find_relevant_points(x)?;
    if neighbour_ids.is_empty() {
        return Err(InterpolateError::invalid_input(
            "No points found within effective range".to_string(),
        ));
    }

    let kernel_weights = self.compute_weights(&neighbour_dists)?;
    let neighbour_points = self.extract_local_points(&neighbour_ids);
    let design = self.create_basis_functions(&neighbour_points, x)?;
    let neighbour_values = self.extract_local_values(&neighbour_ids);

    self.fit_weighted_least_squares(
        &neighbour_points,
        &neighbour_values,
        x,
        &kernel_weights,
        &design,
    )
}
/// Evaluates the fitted value at every row of `points`.
///
/// # Errors
///
/// * `DimensionMismatch` when the query rows do not have one column per
///   training dimension.
/// * The first error produced by `fit_at_point`, if any row fails.
pub fn fit_multiple(&self, points: &ArrayView2<F>) -> InterpolateResult<Array1<F>> {
    let query_dims = points.shape()[1];
    if query_dims != self.points.shape()[1] {
        return Err(InterpolateError::DimensionMismatch(
            "Query points dimension must match training points".to_string(),
        ));
    }

    // Rows are evaluated in order; collecting into a `Result` stops at the
    // first failing row.
    let fitted = points
        .outer_iter()
        .map(|query| self.fit_at_point(&query).map(|fit| fit.value))
        .collect::<InterpolateResult<Vec<F>>>()?;

    Ok(Array1::from_vec(fitted))
}
/// Selects the training points that take part in the local fit around `x`.
///
/// Returns the indices of the selected points and their Euclidean distances
/// to `x`, both ordered from nearest to farthest.
///
/// Selection rules:
/// * at most `max_points` nearest points are scanned (all points when unset);
/// * for the compactly supported weight functions (`WendlandC2`,
///   `CubicSpline`) only points within one bandwidth are kept;
/// * if fewer points remain than the basis has terms, the nearest points are
///   taken instead, regardless of the two rules above.
fn find_relevant_points(&self, x: &ArrayView1<F>) -> InterpolateResult<(Vec<usize>, Vec<F>)> {
    let n_points = self.points.shape()[0];
    let n_dims = self.points.shape()[1];

    let mut distances: Vec<(usize, F)> = (0..n_points)
        .map(|i| {
            let d_squared = (0..n_dims).fold(F::zero(), |acc, j| {
                let diff = x[j] - self.points[[i, j]];
                acc + diff * diff
            });
            (i, d_squared.sqrt())
        })
        .collect();

    // `sort_by` requires a total order and may panic otherwise. Treating an
    // incomparable pair as equal is not transitive once a NaN distance is
    // present, so NaN is ordered after every number instead.
    distances.sort_by(|a, b| {
        a.1.partial_cmp(&b.1)
            .unwrap_or_else(|| a.1.is_nan().cmp(&b.1.is_nan()))
    });

    let limit = self
        .config
        .max_points
        .map_or(n_points, |cap| std::cmp::min(cap, n_points));

    let effective_radius = match self.config.weight_fn {
        WeightFunction::WendlandC2 | WeightFunction::CubicSpline => self.config.bandwidth,
        WeightFunction::Gaussian | WeightFunction::InverseDistance => F::infinity(),
    };

    let within_radius: (Vec<usize>, Vec<F>) = distances
        .iter()
        .take(limit)
        .copied()
        .filter(|&(_, dist)| dist <= effective_radius)
        .unzip();

    // Number of basis terms; the fit needs at least this many points.
    let min_points = match self.config.basis {
        PolynomialBasis::Constant => 1,
        PolynomialBasis::Linear => n_dims + 1,
        PolynomialBasis::Quadratic => ((n_dims + 1) * (n_dims + 2)) / 2,
    };

    let selected = if within_radius.0.len() < min_points {
        // Too few points survived the filters: fall back to the nearest ones.
        distances.iter().take(min_points).copied().unzip()
    } else {
        within_radius
    };

    Ok(selected)
}
fn compute_weights(&self, distances: &[F]) -> InterpolateResult<Array1<F>> {
let n = distances.len();
let mut weights = Array1::zeros(n);
for (i, &d) in distances.iter().enumerate() {
let r = d / self.config.bandwidth;
let weight = match self.config.weight_fn {
WeightFunction::Gaussian => (-r * r).exp(),
WeightFunction::WendlandC2 => {
if r < F::one() {
let t = F::one() - r;
let factor = F::from_f64(4.0).expect("Operation failed") * r + F::one();
t.powi(4) * factor
} else {
F::zero()
}
}
WeightFunction::InverseDistance => F::one() / (self.config.epsilon + r * r),
WeightFunction::CubicSpline => {
if r < F::from_f64(1.0 / 3.0).expect("Operation failed") {
let r2 = r * r;
let r3 = r2 * r;
F::from_f64(2.0 / 3.0).expect("Operation failed")
- F::from_f64(9.0).expect("Operation failed") * r2
+ F::from_f64(19.0).expect("Operation failed") * r3
} else if r < F::one() {
let t = F::from_f64(2.0).expect("Operation failed")
- F::from_f64(3.0).expect("Operation failed") * r;
F::from_f64(1.0 / 3.0).expect("Operation failed") * t.powi(3)
} else {
F::zero()
}
}
};
weights[i] = weight;
}
let sum = weights.sum();
if sum > F::zero() {
weights.mapv_inplace(|w| w / sum);
} else {
weights.fill(F::from_f64(1.0 / n as f64).expect("Operation failed"));
}
Ok(weights)
}
/// Copies the training rows named by `indices`, in that order, into a new
/// matrix with one row per index.
fn extract_local_points(&self, indices: &[usize]) -> Array2<F> {
    let n_dims = self.points.shape()[1];
    Array2::from_shape_fn((indices.len(), n_dims), |(row, col)| {
        self.points[[indices[row], col]]
    })
}
fn extract_local_values(&self, indices: &[usize]) -> Array1<F> {
let mut local_values = Array1::zeros(indices.len());
for (i, &idx) in indices.iter().enumerate() {
local_values[i] = self.values[idx];
}
local_values
}
/// Builds the design matrix of the local polynomial, centred on `x`.
///
/// Every row starts with the constant term 1. Linear and quadratic bases
/// append the per-dimension offsets `point[j] - x[j]`; the quadratic basis
/// additionally appends the products of offsets for every pair `j <= k`.
fn create_basis_functions(
    &self,
    local_points: &Array2<F>,
    x: &ArrayView1<F>,
) -> InterpolateResult<Array2<F>> {
    let (n_rows, n_dims) = local_points.dim();

    let (n_columns, with_linear, with_quadratic) = match self.config.basis {
        PolynomialBasis::Constant => (1, false, false),
        PolynomialBasis::Linear => (n_dims + 1, true, false),
        PolynomialBasis::Quadratic => (((n_dims + 1) * (n_dims + 2)) / 2, true, true),
    };

    let mut design = Array2::zeros((n_rows, n_columns));
    for (row, sample) in local_points.outer_iter().enumerate() {
        // Offsets of this sample from the query point, one per dimension.
        let offsets: Vec<F> = (0..n_dims).map(|j| sample[j] - x[j]).collect();

        let mut column = 0;
        design[[row, column]] = F::one();
        column += 1;

        if with_linear {
            for &offset in &offsets {
                design[[row, column]] = offset;
                column += 1;
            }
        }

        if with_quadratic {
            for j in 0..n_dims {
                for k in j..n_dims {
                    design[[row, column]] = offsets[j] * offsets[k];
                    column += 1;
                }
            }
        }
    }

    Ok(design)
}
/// Solves the weighted least-squares problem of one local fit and derives
/// its diagnostics.
///
/// * `local_points` - selected training points; only the row count is used.
/// * `local_values` - responses of the selected points.
/// * `_x` - query point; unused because `basis` is already centred on it.
/// * `weights` - fitting weights, one per selected point.
/// * `basis` - design matrix, one row per selected point.
///
/// With the `linalg` feature the normal equations are solved in `f64`. If the
/// solve fails, the weighted mean of the responses is used as the constant
/// coefficient and the remaining coefficients are zero. If the normal matrix
/// cannot be inverted, the result carries no variance estimates (zero
/// standard error, no interval, one effective degree of freedom, zero R^2).
/// Without the `linalg` feature the weighted-mean result is always returned.
///
/// Confidence intervals use a normal approximation: the critical value is
/// the standard normal quantile for the two-sided confidence level.
// The routine keeps the whole solve-and-diagnose sequence in one place.
#[allow(clippy::too_many_lines)]
fn fit_weighted_least_squares(
    &self,
    local_points: &Array2<F>,
    local_values: &Array1<F>,
    _x: &ArrayView1<F>,
    weights: &Array1<F>,
    basis: &Array2<F>,
) -> InterpolateResult<RegressionResult<F>> {
    let n_points = local_points.shape()[0];
    let n_basis = basis.shape()[1];

    // Scale rows by sqrt(w) so that ordinary least squares on the scaled
    // system equals weighted least squares on the original one.
    let mut w_basis = Array2::zeros((n_points, n_basis));
    let mut w_values = Array1::zeros(n_points);
    for i in 0..n_points {
        let sqrt_w = weights[i].sqrt();
        for j in 0..n_basis {
            w_basis[[i, j]] = basis[[i, j]] * sqrt_w;
        }
        w_values[i] = local_values[i] * sqrt_w;
    }

    // The normal matrix is only needed when a linear solver is available.
    #[cfg(feature = "linalg")]
    let xtx = w_basis.t().dot(&w_basis);
    let xty = w_basis.t().dot(&w_values);

    // Fallback coefficients: the weighted mean of the responses as the
    // constant term, zeros elsewhere.
    let weighted_mean_fit = || {
        let mut mean = F::zero();
        let mut sum_weights = F::zero();
        for i in 0..n_points {
            mean = mean + weights[i] * local_values[i];
            sum_weights = sum_weights + weights[i];
        }
        if sum_weights > F::zero() {
            mean = mean / sum_weights;
        } else {
            mean = local_values.mean_or(F::zero());
        }
        let mut result = Array1::zeros(xty.len());
        result[0] = mean;
        result
    };

    // Converted once and shared by the solve and the inversion below.
    #[cfg(feature = "linalg")]
    let xtx_f64 = xtx.mapv(|v| v.to_f64().expect("Operation failed"));

    #[cfg(feature = "linalg")]
    let coefficients = {
        use scirs2_linalg::solve;
        let xty_f64 = xty.mapv(|v| v.to_f64().expect("Operation failed"));
        match solve(&xtx_f64.view(), &xty_f64.view(), None) {
            Ok(solution) => solution.mapv(|v| F::from_f64(v).expect("Operation failed")),
            Err(_) => weighted_mean_fit(),
        }
    };
    #[cfg(not(feature = "linalg"))]
    let coefficients = weighted_mean_fit();

    // The basis is centred on the query point, so the constant coefficient
    // is the fitted value there.
    let fitted_value = coefficients[0];

    #[cfg(feature = "linalg")]
    let xtx_inv = {
        use scirs2_linalg::inv;
        match inv(&xtx_f64.view(), None) {
            Ok(inverse) => inverse.mapv(|v| F::from_f64(v).expect("Operation failed")),
            Err(_) => {
                return Ok(RegressionResult {
                    value: fitted_value,
                    std_error: F::zero(),
                    confidence_interval: None,
                    coefficients,
                    effective_df: F::from_f64(1.0).expect("Operation failed"),
                    r_squared: F::zero(),
                });
            }
        }
    };

    #[cfg(not(feature = "linalg"))]
    {
        Ok(RegressionResult {
            value: fitted_value,
            std_error: F::zero(),
            confidence_interval: None,
            coefficients,
            effective_df: F::from_f64(1.0).expect("Operation failed"),
            r_squared: F::zero(),
        })
    }
    #[cfg(feature = "linalg")]
    {
        /// Returns `z` such that a standard normal variable exceeds `z` with
        /// probability `tail`, for `tail` in `(0, 0.5]`, using Acklam's
        /// rational approximation of the inverse normal CDF.
        fn upper_tail_normal_quantile(tail: f64) -> f64 {
            const A: [f64; 6] = [
                -3.969683028665376e+01,
                2.209460984245205e+02,
                -2.759285104469687e+02,
                1.383577518672690e+02,
                -3.066479806614716e+01,
                2.506628277459239e+00,
            ];
            const B: [f64; 5] = [
                -5.447609879822406e+01,
                1.615858368580409e+02,
                -1.556989798598866e+02,
                6.680131188771972e+01,
                -1.328068155288572e+01,
            ];
            const C: [f64; 6] = [
                -7.784894002430293e-03,
                -3.223964580411365e-01,
                -2.400758277161838e+00,
                -2.549732539343734e+00,
                4.374664141464968e+00,
                2.938163982698783e+00,
            ];
            const D: [f64; 4] = [
                7.784695709041462e-03,
                3.224671290700398e-01,
                2.445134137142996e+00,
                3.754408661907416e+00,
            ];
            // Below this tail probability the tail formula is used.
            const TAIL_SPLIT: f64 = 0.02425;

            // Horner evaluation, highest-order coefficient first.
            let horner =
                |coeffs: &[f64], t: f64| coeffs.iter().fold(0.0, |acc, &c| acc * t + c);

            if tail < TAIL_SPLIT {
                let q = (-2.0 * tail.ln()).sqrt();
                // The rational function yields the (negative) lower quantile.
                -(horner(&C, q) / (horner(&D, q) * q + 1.0))
            } else {
                let q = 0.5 - tail;
                let r = q * q;
                horner(&A, r) * q / (horner(&B, r) * r + 1.0)
            }
        }

        // Smallest tail probability used; confidence levels at or above 1
        // are treated as this tail so the critical value stays finite.
        const MIN_TAIL: f64 = 1e-16;

        let fitted_local = basis.dot(&coefficients);
        let residuals = local_values - &fitted_local;

        // Weighted residual and total sums of squares.
        let ssr = residuals
            .iter()
            .zip(weights.iter())
            .fold(F::zero(), |acc, (&r, &w)| acc + w * r * r);
        let mean = local_values
            .iter()
            .zip(weights.iter())
            .fold(F::zero(), |acc, (&y, &w)| acc + w * y)
            / weights.sum();
        let sst = local_values
            .iter()
            .zip(weights.iter())
            .fold(F::zero(), |acc, (&y, &w)| acc + w * (y - mean) * (y - mean));
        let r_squared = if sst > F::zero() {
            F::one() - (ssr / sst)
        } else {
            F::zero()
        };

        // Diagonal of the hat matrix of the scaled system.
        let mut leverage = Array1::zeros(n_points);
        for i in 0..n_points {
            let w_row = w_basis.row(i);
            let h_ii = w_row.dot(&xtx_inv.dot(&w_row));
            leverage[i] = h_ii;
        }
        let effective_df = leverage.sum();

        let xtx_inv_row1 = xtx_inv.row(0);
        let n_effective = F::from_usize(n_points).expect("Operation failed") - effective_df;
        let mse = if n_effective > F::zero() {
            ssr / n_effective
        } else {
            F::zero()
        };

        let std_error = if self.config.robust_se {
            let mut sum_squared_weighted_residuals = F::zero();
            for i in 0..n_points {
                // Inflate each residual by its leverage where that is defined.
                let adjusted_residual = if leverage[i] < F::one() {
                    residuals[i] / (F::one() - leverage[i])
                } else {
                    residuals[i]
                };
                let weighted_residual = basis[[i, 0]] * adjusted_residual;
                sum_squared_weighted_residuals =
                    sum_squared_weighted_residuals + weighted_residual * weighted_residual;
            }
            (xtx_inv_row1[0] * sum_squared_weighted_residuals).sqrt()
        } else {
            (xtx_inv_row1[0] * mse).sqrt()
        };

        let confidence_interval = self.config.confidence_level.map(|level| {
            // Two-sided interval: half of the excluded probability lies in
            // each tail. A fixed lookup table compared against a rounded
            // `half_alpha` picked the wrong entry at its boundaries (a 0.95
            // level landed on 1.645), so the quantile is computed directly.
            // A NaN level propagates to NaN bounds.
            let level = level.to_f64().unwrap_or(f64::NAN);
            let tail = ((1.0 - level) / 2.0).clamp(MIN_TAIL, 0.5);
            let z_critical =
                F::from_f64(upper_tail_normal_quantile(tail)).expect("Operation failed");
            let margin = z_critical * std_error;
            (fitted_value - margin, fitted_value + margin)
        });

        Ok(RegressionResult {
            value: fitted_value,
            std_error,
            confidence_interval,
            coefficients,
            effective_df,
            r_squared,
        })
    }
}
pub fn config(&self) -> &LocalPolynomialConfig<F> {
&self.config
}
pub fn points(&self) -> &Array2<F> {
&self.points
}
pub fn values(&self) -> &Array1<F> {
&self.values
}
/// Picks, among `bandwidths`, the candidate with the smallest mean squared
/// error between the fitted values and the training responses.
///
/// The error is measured at the training points themselves, with each point
/// taking part in its own fit (this is not a leave-one-out estimate).
/// Candidates that are NaN, zero or negative are skipped. If no candidate
/// yields a comparable (non-NaN) error, the first valid candidate is
/// returned.
///
/// # Errors
///
/// * `InvalidValue` when `bandwidths` is empty.
/// * `InvalidValue` when `bandwidths` contains no positive candidate.
/// * Any error raised by `fit_at_point` while evaluating a candidate.
pub fn select_bandwidth(&self, bandwidths: &[F]) -> InterpolateResult<F> {
    if bandwidths.is_empty() {
        return Err(InterpolateError::InvalidValue(
            "Bandwidths array cannot be empty".to_string(),
        ));
    }

    let n_points = self.points.shape()[0];
    let n_points_f = F::from_usize(n_points).expect("Operation failed");

    // One working copy is enough: only the bandwidth changes between
    // candidates, and the training data were validated at construction.
    let mut model = self.clone();
    let mut first_valid: Option<F> = None;
    let mut best: Option<(F, F)> = None;

    for &bandwidth in bandwidths {
        if bandwidth.is_nan() || bandwidth <= F::zero() {
            continue;
        }
        first_valid.get_or_insert(bandwidth);
        model.config.bandwidth = bandwidth;

        let mut total_squared_error = F::zero();
        for i in 0..n_points {
            let result = model.fit_at_point(&self.points.row(i))?;
            let error = result.value - self.values[i];
            total_squared_error = total_squared_error + error * error;
        }
        let mse = total_squared_error / n_points_f;
        if mse.is_nan() {
            continue;
        }

        let improves = match best {
            Some((_, best_mse)) => mse < best_mse,
            None => true,
        };
        if improves {
            best = Some((bandwidth, mse));
        }
    }

    best.map(|(bandwidth, _)| bandwidth)
        .or(first_valid)
        .ok_or_else(|| {
            InterpolateError::InvalidValue(
                "No positive bandwidth candidates were provided".to_string(),
            )
        })
}
}
#[allow(dead_code)]
pub fn make_loess<F>(
points: Array2<F>,
values: Array1<F>,
bandwidth: F,
) -> InterpolateResult<LocalPolynomialRegression<F>>
where
F: Float + FromPrimitive + Debug + 'static,
{
LocalPolynomialRegression::new(points, values, bandwidth)
}
#[allow(dead_code)]
pub fn make_robust_loess<F>(
points: Array2<F>,
values: Array1<F>,
bandwidth: F,
confidence_level: F,
) -> InterpolateResult<LocalPolynomialRegression<F>>
where
F: Float + FromPrimitive + Debug + 'static,
{
let config = LocalPolynomialConfig {
bandwidth,
weight_fn: WeightFunction::Gaussian,
basis: PolynomialBasis::Linear,
confidence_level: Some(confidence_level),
robust_se: true,
max_points: None,
epsilon: F::from_f64(1e-10).expect("Operation failed"),
};
LocalPolynomialRegression::with_config(points, values, config)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::{array, Axis};

    /// Samples of `y = x^2` at `x = 0, 1, ..., 5`, with `x` as a column matrix.
    fn quadratic_samples() -> (Array2<f64>, Array1<f64>) {
        let x = array![0.0, 1.0, 2.0, 3.0, 4.0, 5.0].insert_axis(Axis(1));
        let y = array![0.0, 1.0, 4.0, 9.0, 16.0, 25.0];
        (x, y)
    }

    /// The fit at 2.5 should land near the true value 6.25.
    #[test]
    fn test_local_polynomial_regression() {
        let (x, y) = quadratic_samples();
        let model = LocalPolynomialRegression::new(x, y, 1.5).expect("Operation failed");

        let query = array![2.5];
        let fit = model.fit_at_point(&query.view()).expect("Operation failed");

        assert_abs_diff_eq!(fit.value, 6.25, epsilon = 1.5);
    }

    /// With a confidence level configured, the interval must bracket the true
    /// value when a solver is available and be absent otherwise.
    #[test]
    fn test_confidence_intervals() {
        let (x, y) = quadratic_samples();
        let config = LocalPolynomialConfig {
            bandwidth: 1.5,
            confidence_level: Some(0.95),
            ..LocalPolynomialConfig::default()
        };
        let model =
            LocalPolynomialRegression::with_config(x, y, config).expect("Operation failed");

        let query = array![2.5];
        let fit = model.fit_at_point(&query.view()).expect("Operation failed");

        #[cfg(feature = "linalg")]
        {
            let (lower, upper) = fit.confidence_interval.expect("Operation failed");
            assert!(lower < 6.25);
            assert!(upper > 6.25);
        }
        #[cfg(not(feature = "linalg"))]
        {
            assert!(fit.confidence_interval.is_none());
        }
    }
}