use crate::error::{InterpolateError, InterpolateResult};
use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;
/// Strategy a [`PchipInterpolator`] uses to produce values beyond its first and last knots.
///
/// The mode is only consulted when the interpolator was built with extrapolation enabled;
/// otherwise out-of-range queries are rejected before the mode is looked at.
#[derive(Default, Debug, Copy, Clone, PartialEq)]
pub enum PchipExtrapolateMode {
    /// Follow the tangent line through the nearest endpoint (this is the default).
    #[default]
    Linear,
    /// Keep evaluating the cubic of the nearest boundary interval past its end.
    Polynomial,
}
/// Piecewise Cubic Hermite Interpolating Polynomial (PCHIP) interpolator.
///
/// Holds the sample knots together with one slope per knot. Those slopes are computed
/// once at construction and then reused for every evaluation.
#[derive(Clone, Debug)]
pub struct PchipInterpolator<F: Float> {
    /// Knot abscissae; `new` rejects input that is not strictly increasing.
    x: Array1<F>,
    /// Sample value at each knot (same length as `x`).
    y: Array1<F>,
    /// Slope of the interpolant at each knot (same length as `x`).
    derivatives: Array1<F>,
    /// Whether queries outside `[x[0], x[n-1]]` are answered instead of rejected.
    extrapolate: bool,
    /// How out-of-range queries are answered when `extrapolate` is set.
    extrapolate_mode: PchipExtrapolateMode,
}
impl<F: Float + FromPrimitive + Debug> PchipInterpolator<F> {
    /// Builds a PCHIP interpolator from the samples `(x, y)`.
    ///
    /// # Arguments
    /// * `x` - knot abscissae; must be finite and strictly increasing, at least two points.
    /// * `y` - sample values, one per knot.
    /// * `extrapolate` - when `false`, evaluating outside `[x[0], x[n-1]]` is an error.
    ///
    /// The extrapolation mode starts as [`PchipExtrapolateMode::Linear`]; change it with
    /// [`Self::with_extrapolate_mode`].
    ///
    /// # Errors
    /// * `InterpolateError::invalid_input` if `x` and `y` differ in length, if `x`
    ///   contains NaN or an infinity, or if `x` is not strictly increasing.
    /// * `InterpolateError::insufficient_points` if fewer than two points are given.
    pub fn new(x: &ArrayView1<F>, y: &ArrayView1<F>, extrapolate: bool) -> InterpolateResult<Self> {
        if x.len() != y.len() {
            return Err(InterpolateError::invalid_input(
                "x and y arrays must have the same length".to_string(),
            ));
        }
        if x.len() < 2 {
            return Err(InterpolateError::insufficient_points(
                2,
                x.len(),
                "PCHIP interpolation",
            ));
        }
        // NaN compares false against everything and would slip through the ordering check
        // below; infinities would produce infinite/NaN interval widths and slopes.
        if x.iter().any(|&v| !v.is_finite()) {
            return Err(InterpolateError::invalid_input(
                "x values must be finite".to_string(),
            ));
        }
        // Strictly increasing: a repeated knot would create a zero-width interval.
        if x.iter().zip(x.iter().skip(1)).any(|(&prev, &next)| next <= prev) {
            return Err(InterpolateError::invalid_input(
                "x values must be sorted in ascending order".to_string(),
            ));
        }
        let x_arr = x.to_owned();
        let y_arr = y.to_owned();
        let derivatives = Self::find_derivatives(&x_arr, &y_arr)?;
        Ok(PchipInterpolator {
            x: x_arr,
            y: y_arr,
            derivatives,
            extrapolate,
            extrapolate_mode: PchipExtrapolateMode::Linear,
        })
    }

    /// Builder-style setter for how out-of-range points are evaluated.
    ///
    /// Has an effect only if the interpolator was built with `extrapolate = true`.
    pub fn with_extrapolate_mode(mut self, mode: PchipExtrapolateMode) -> Self {
        self.extrapolate_mode = mode;
        self
    }

    /// Evaluates the interpolant at a single point.
    ///
    /// Inside `[x[0], x[n-1]]` the cubic Hermite polynomial of the containing interval is
    /// used. Outside that range the result depends on the configuration:
    /// * an `OutOfBounds` error when extrapolation is disabled;
    /// * the tangent line through the nearest endpoint for [`PchipExtrapolateMode::Linear`];
    /// * the cubic of the nearest boundary interval for [`PchipExtrapolateMode::Polynomial`].
    ///
    /// A NaN `xnew` is not treated as out of range. It is evaluated on the first interval
    /// and yields NaN.
    ///
    /// # Errors
    /// `InterpolateError::OutOfBounds` if `xnew` lies outside the knot range and
    /// extrapolation is disabled.
    pub fn evaluate(&self, xnew: F) -> InterpolateResult<F> {
        let n = self.x.len();
        let below = xnew < self.x[0];
        let above = xnew > self.x[n - 1];
        if (below || above) && !self.extrapolate {
            return Err(InterpolateError::OutOfBounds(
                "xnew is outside the interpolation range".to_string(),
            ));
        }
        if self.extrapolate_mode == PchipExtrapolateMode::Linear {
            // Tangent-line continuation through the nearest endpoint.
            if below {
                let dx = xnew - self.x[0];
                return Ok(self.y[0] + self.derivatives[0] * dx);
            }
            if above {
                let dx = xnew - self.x[n - 1];
                return Ok(self.y[n - 1] + self.derivatives[n - 1] * dx);
            }
        }
        if xnew == self.x[n - 1] {
            // Exact right endpoint: return the sample itself rather than a cubic evaluated at t = 1.
            return Ok(self.y[n - 1]);
        }
        // Polynomial extrapolation reuses the cubic of the boundary interval.
        let idx = if below {
            0
        } else if above {
            n - 2
        } else {
            self.find_segment(xnew)
        };
        let x1 = self.x[idx];
        let x2 = self.x[idx + 1];
        let y1 = self.y[idx];
        let y2 = self.y[idx + 1];
        let d1 = self.derivatives[idx];
        let d2 = self.derivatives[idx + 1];
        let h = x2 - x1;
        let t = (xnew - x1) / h;
        let h00 = Self::h00(t);
        let h10 = Self::h10(t);
        let h01 = Self::h01(t);
        let h11 = Self::h11(t);
        let result = h00 * y1 + h10 * h * d1 + h01 * y2 + h11 * h * d2;
        Ok(result)
    }

    /// Returns the index `i` of the interval `[x[i], x[i+1]]` used to evaluate `xnew`.
    ///
    /// Picks the smallest `i` with `x[i] <= xnew <= x[i+1]`, so an interior knot `x[k]` maps
    /// to interval `k - 1`. A value that compares false against every knot (NaN) maps to
    /// interval 0. This is a binary search over the sorted knots, O(log n).
    fn find_segment(&self, xnew: F) -> usize {
        let n = self.x.len();
        // Lower bound: first index whose knot is not strictly less than `xnew`.
        let (mut lo, mut hi) = (0usize, n);
        while lo < hi {
            let mid = lo + (hi - lo) / 2;
            if self.x[mid] < xnew {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        lo.saturating_sub(1).min(n - 2)
    }

    /// Evaluates the interpolant at every point of `xnew`.
    ///
    /// # Errors
    /// Returns the first error produced by [`Self::evaluate`].
    pub fn evaluate_array(&self, xnew: &ArrayView1<F>) -> InterpolateResult<Array1<F>> {
        let mut result = Array1::zeros(xnew.len());
        for (i, &x) in xnew.iter().enumerate() {
            result[i] = self.evaluate(x)?;
        }
        Ok(result)
    }

    /// Cubic Hermite basis `h00(t) = 2t^3 - 3t^2 + 1` (weight of the left value).
    fn h00(t: F) -> F {
        let two = F::one() + F::one();
        let three = two + F::one();
        (two * t * t * t) - (three * t * t) + F::one()
    }

    /// Cubic Hermite basis `h10(t) = t^3 - 2t^2 + t` (weight of the left slope).
    fn h10(t: F) -> F {
        let two = F::one() + F::one();
        (t * t * t) - (two * t * t) + t
    }

    /// Cubic Hermite basis `h01(t) = -2t^3 + 3t^2` (weight of the right value).
    fn h01(t: F) -> F {
        let two = F::one() + F::one();
        let three = two + F::one();
        -(two * t * t * t) + (three * t * t)
    }

    /// Cubic Hermite basis `h11(t) = t^3 - t^2` (weight of the right slope).
    fn h11(t: F) -> F {
        (t * t * t) - (t * t)
    }

    /// One-sided three-point estimate of an end slope, clamped to preserve shape.
    ///
    /// `h0`/`m0` are the width and secant slope of the boundary interval, and `h1`/`m1`
    /// those of its neighbour. The estimate is clamped as follows:
    /// * it is set to zero if it points against `m0`;
    /// * it is capped at `3 * m0` when `m0` and `m1` differ in sign.
    ///
    /// The sign test treats zero as positive.
    fn edge_case(h0: F, h1: F, m0: F, m1: F) -> F {
        let two = F::one() + F::one();
        let three = two + F::one();
        let d = ((two * h0 + h1) * m0 - h0 * m1) / (h0 + h1);
        let sign_d = if d >= F::zero() { F::one() } else { -F::one() };
        let sign_m0 = if m0 >= F::zero() { F::one() } else { -F::one() };
        let sign_m1 = if m1 >= F::zero() { F::one() } else { -F::one() };
        if sign_d != sign_m0 {
            F::zero()
        } else if (sign_m0 != sign_m1) && (d.abs() > three * m0.abs()) {
            three * m0
        } else {
            d
        }
    }

    /// Computes the shape-preserving slope at every knot.
    ///
    /// Each interior slope depends on the two adjacent secant slopes:
    /// * when both secants have the same strict sign, the slope is their weighted harmonic
    ///   mean, with weights `2h[i] + h[i-1]` and `h[i] + 2h[i-1]`;
    /// * otherwise it is zero (a local extremum or a flat interval).
    ///
    /// The two end slopes come from [`Self::edge_case`]. With only two knots, both slopes
    /// equal the single secant slope.
    fn find_derivatives(x: &Array1<F>, y: &Array1<F>) -> InterpolateResult<Array1<F>> {
        let n = x.len();
        let mut derivatives = Array1::zeros(n);
        // Interval widths h[k] = x[k+1] - x[k] and secant slopes over each interval.
        let h: Vec<F> = (0..n - 1).map(|k| x[k + 1] - x[k]).collect();
        let slopes: Vec<F> = (0..n - 1).map(|k| (y[k + 1] - y[k]) / h[k]).collect();
        if n == 2 {
            // A single interval: the straight line through both points is already monotone.
            derivatives[0] = slopes[0];
            derivatives[1] = slopes[0];
            return Ok(derivatives);
        }
        let two = F::one() + F::one();
        let zero = F::zero();
        for i in 1..n - 1 {
            let (m_prev, m_next) = (slopes[i - 1], slopes[i]);
            // Compare signs directly: a product of two tiny slopes could underflow to zero.
            let same_strict_sign =
                (m_prev > zero && m_next > zero) || (m_prev < zero && m_next < zero);
            if same_strict_sign {
                // No absolute threshold on |slope|. Both slopes are already nonzero, and an
                // epsilon cut-off would flatten data that is merely small in scale. If w/m
                // overflows, 1/inf yields 0, which is still monotone-safe.
                let w1 = two * h[i] + h[i - 1];
                let w2 = h[i] + two * h[i - 1];
                let whmean_inv = (w1 / m_prev + w2 / m_next) / (w1 + w2);
                derivatives[i] = F::one() / whmean_inv;
            }
            // Otherwise the slope stays zero.
        }
        derivatives[0] = Self::edge_case(h[0], h[1], slopes[0], slopes[1]);
        derivatives[n - 1] = Self::edge_case(h[n - 2], h[n - 3], slopes[n - 2], slopes[n - 3]);
        Ok(derivatives)
    }
}
/// One-shot PCHIP interpolation: builds a [`PchipInterpolator`] from `(x, y)` and evaluates
/// it at every point of `xnew`.
///
/// # Errors
/// * Any error returned by [`PchipInterpolator::new`] (invalid or insufficient input).
/// * The first evaluation error from [`PchipInterpolator::evaluate_array`], e.g. a query
///   outside the knot range when `extrapolate` is `false`.
#[allow(dead_code)]
pub fn pchip_interpolate<F: crate::traits::InterpolationFloat>(
    x: &ArrayView1<F>,
    y: &ArrayView1<F>,
    xnew: &ArrayView1<F>,
    extrapolate: bool,
) -> InterpolateResult<Array1<F>> {
    PchipInterpolator::new(x, y, extrapolate).and_then(|interp| interp.evaluate_array(xnew))
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Interpolator over the knots 0..=3 sampled from y = x^2.
    fn squares(extrapolate: bool) -> PchipInterpolator<f64> {
        let xs = array![0.0, 1.0, 2.0, 3.0];
        let ys = array![0.0, 1.0, 4.0, 9.0];
        PchipInterpolator::new(&xs.view(), &ys.view(), extrapolate).expect("Operation failed")
    }

    /// Evaluates at a single point, panicking on error.
    fn at(interp: &PchipInterpolator<f64>, xv: f64) -> f64 {
        interp.evaluate(xv).expect("Operation failed")
    }

    #[test]
    fn test_pchip_interpolation_basic() {
        let interp = squares(false);
        // The interpolant reproduces every knot.
        for &(xv, yv) in &[(0.0, 0.0), (1.0, 1.0), (2.0, 4.0), (3.0, 9.0)] {
            assert_relative_eq!(at(&interp, xv), yv);
        }
        // Midpoints stay strictly between their neighbouring samples.
        for &(xv, lo, hi) in &[(0.5, 0.0, 1.0), (1.5, 1.0, 4.0), (2.5, 4.0, 9.0)] {
            let v = at(&interp, xv);
            assert!(v > lo && v < hi);
        }
    }

    #[test]
    fn test_pchip_monotonicity_preservation() {
        let xs = array![0.0, 1.0, 2.0, 3.0, 4.0, 5.0];
        let ys = array![0.0, 1.0, 0.5, 0.0, 0.5, 2.0];
        let interp =
            PchipInterpolator::new(&xs.view(), &ys.view(), false).expect("Operation failed");
        let sample = |a: f64, b: f64, c: f64| (at(&interp, a), at(&interp, b), at(&interp, c));

        let (a, b, c) = sample(0.25, 0.50, 0.75);
        assert!(a <= b && b <= c);
        let (a, b, c) = sample(1.25, 1.50, 1.75);
        assert!(a >= b && b >= c);
        let (a, b, c) = sample(4.25, 4.50, 4.75);
        assert!(a <= b && b <= c);
    }

    #[test]
    fn test_pchip_extrapolation() {
        let interp_extrap = squares(true);
        let y_minus_1 = at(&interp_extrap, -1.0);
        let y_plus_4 = at(&interp_extrap, 4.0);
        assert!(
            y_plus_4 > 9.0,
            "Extrapolation above should be greater than last point"
        );
        assert!(
            y_minus_1.is_finite(),
            "Extrapolation should produce finite values"
        );
        assert!(
            y_plus_4.is_finite(),
            "Extrapolation should produce finite values"
        );

        let interp_no_extrap = squares(false);
        assert!(interp_no_extrap.evaluate(-1.0).is_err());
        assert!(interp_no_extrap.evaluate(4.0).is_err());
    }

    #[test]
    fn test_pchip_extrapolation_far_beyond_range() {
        let interp = squares(true);

        let y_50 = at(&interp, 50.0);
        assert!(
            y_50 > 9.0,
            "Extrapolation at x=50 should be > 9.0, got {}",
            y_50
        );
        assert!(
            y_50 < 1000.0,
            "Extrapolation should be reasonable (linear), got {}",
            y_50
        );
        assert!(
            y_50.is_finite(),
            "Extrapolation should produce finite values"
        );

        let y_minus_50 = at(&interp, -50.0);
        assert!(
            y_minus_50.is_finite(),
            "Extrapolation should produce finite values"
        );
        assert!(
            y_minus_50.abs() < 1000.0,
            "Extrapolation should be reasonable (linear), got {}",
            y_minus_50
        );
    }

    #[test]
    fn test_pchip_extrapolation_linear_behavior() {
        let interp = squares(true);
        // One unit step beyond the last knot should rise by exactly the end slope.
        let extrap_slope = at(&interp, 5.0) - at(&interp, 4.0);
        let last_derivative = interp.derivatives[interp.derivatives.len() - 1];
        assert_relative_eq!(extrap_slope, last_derivative, epsilon = 1e-10);
    }

    #[test]
    fn test_pchip_interpolate_function() {
        let xs = array![0.0, 1.0, 2.0, 3.0];
        let ys = array![0.0, 1.0, 4.0, 9.0];
        let queries = array![0.5, 1.5, 2.5];
        let out = pchip_interpolate(&xs.view(), &ys.view(), &queries.view(), false)
            .expect("Operation failed");
        assert_eq!(out.len(), 3);
        let bounds = [(0.0, 1.0), (1.0, 4.0), (4.0, 9.0)];
        for (v, &(lo, hi)) in out.iter().zip(bounds.iter()) {
            assert!(*v > lo && *v < hi);
        }
    }

    #[test]
    fn test_pchip_error_conditions() {
        let cases = [
            // Length mismatch.
            (array![0.0, 1.0, 2.0, 3.0], array![0.0, 1.0, 4.0]),
            // Unsorted abscissae.
            (array![0.0, 2.0, 1.0, 3.0], array![0.0, 1.0, 4.0, 9.0]),
            // Too few points.
            (array![0.0], array![0.0]),
        ];
        for (xs, ys) in cases.iter() {
            assert!(PchipInterpolator::new(&xs.view(), &ys.view(), false).is_err());
        }
    }
}