use crate::error::{InterpolateError, InterpolateResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;
/// Piecewise-cubic interpolant through a set of `(x, y)` points, storing one
/// cubic polynomial per interval between consecutive knots.
///
/// Instances are created with `AkimaSpline::new`, which validates the input
/// and precomputes the polynomial coefficients.
#[derive(Debug, Clone)]
pub struct AkimaSpline<F: Float + FromPrimitive> {
// Knot abscissae, an owned copy of the `x` view passed to `new`.
x: Array1<F>,
// Knot ordinates, an owned copy of the `y` view passed to `new`.
y: Array1<F>,
// Coefficient table of shape `(n - 1, 4)` for `n` knots. Row `i` holds
// `[a, b, c, d]` of the cubic `a + b*dx + c*dx^2 + d*dx^3`, where
// `dx = xnew - x[i]`.
coeffs: Array2<F>,
}
impl<F: Float + FromPrimitive + Debug> AkimaSpline<F> {
    /// Builds an Akima spline through the points `(x[i], y[i])`.
    ///
    /// The derivative at each knot is a weighted average of the two adjacent
    /// secant slopes, with weights taken from the absolute differences of the
    /// neighbouring slopes. Two extra slopes are extrapolated on each side so
    /// that the same formula applies at the boundary knots.
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error when
    /// * `x` and `y` differ in length,
    /// * fewer than 5 points are supplied, or
    /// * `x` is not strictly increasing (a NaN anywhere in `x` also fails
    ///   this check).
    pub fn new(x: &ArrayView1<F>, y: &ArrayView1<F>) -> InterpolateResult<Self> {
        if x.len() != y.len() {
            return Err(InterpolateError::invalid_input(
                "x and y arrays must have the same length".to_string(),
            ));
        }
        if x.len() < 5 {
            return Err(InterpolateError::invalid_input(
                "at least 5 points are required for Akima spline".to_string(),
            ));
        }
        // Phrased as "every step is an increase" rather than "no step is a
        // decrease": a comparison with NaN is always false, so NaN knots are
        // rejected here instead of slipping through.
        let strictly_increasing = x
            .iter()
            .zip(x.iter().skip(1))
            .all(|(prev, next)| next > prev);
        if !strictly_increasing {
            return Err(InterpolateError::invalid_input(
                "x array must be strictly increasing".to_string(),
            ));
        }

        let n = x.len();
        // Small constants built from `one()` so no fallible conversion is needed.
        let two = F::one() + F::one();
        let three = two + F::one();

        // slopes[i + 2] is the secant slope of segment i; indices 0, 1 and
        // n + 1, n + 2 hold the extrapolated slopes beyond either end.
        let mut slopes = Array1::<F>::zeros(n + 3);
        for i in 0..n - 1 {
            slopes[i + 2] = (y[i + 1] - y[i]) / (x[i + 1] - x[i]);
        }
        slopes[1] = two * slopes[2] - slopes[3];
        slopes[0] = three * slopes[2] - two * slopes[3];
        slopes[n + 1] = two * slopes[n] - slopes[n - 1];
        slopes[n + 2] = three * slopes[n] - two * slopes[n - 1];

        // Knot derivatives: weighted mean of the two slopes adjacent to knot i.
        let mut derivatives = Array1::<F>::zeros(n);
        for i in 0..n {
            let w1 = (slopes[i + 3] - slopes[i + 2]).abs();
            let w2 = (slopes[i + 1] - slopes[i]).abs();
            let weight_sum = w1 + w2;
            derivatives[i] = if weight_sum == F::zero() {
                // Both weights vanish: fall back to the plain average.
                (slopes[i + 1] + slopes[i + 2]) / two
            } else {
                (w1 * slopes[i + 1] + w2 * slopes[i + 2]) / weight_sum
            };
        }

        // Cubic Hermite coefficients of each segment in powers of (xnew - x[i]).
        let mut coeffs = Array2::<F>::zeros((n - 1, 4));
        for i in 0..n - 1 {
            let dx = x[i + 1] - x[i];
            let dy = y[i + 1] - y[i];
            coeffs[[i, 0]] = y[i];
            coeffs[[i, 1]] = derivatives[i];
            coeffs[[i, 2]] = (three * dy / dx - two * derivatives[i] - derivatives[i + 1]) / dx;
            coeffs[[i, 3]] = (derivatives[i] + derivatives[i + 1] - two * dy / dx) / (dx * dx);
        }

        Ok(Self {
            x: x.to_owned(),
            y: y.to_owned(),
            coeffs,
        })
    }

    /// Returns the index `i` of the segment used to evaluate at `xnew`.
    ///
    /// Segments are treated as half-open intervals `[x[i], x[i + 1])`, so an
    /// interior knot belongs to the segment that starts there. The last knot
    /// maps to the last segment. The search is a binary search over the knots.
    ///
    /// # Errors
    ///
    /// Returns `InterpolateError::OutOfBounds` when `xnew` lies below the
    /// first knot or above the last one. A NaN `xnew` fails neither
    /// comparison and is mapped to segment 0.
    fn locate_segment(&self, xnew: F) -> InterpolateResult<usize> {
        let last = self.x.len() - 1;
        if xnew < self.x[0] || xnew > self.x[last] {
            return Err(InterpolateError::OutOfBounds(
                "xnew is outside the interpolation range".to_string(),
            ));
        }
        // Narrow [lo, hi] until it is a single segment; `lo` only advances
        // onto knots that are <= xnew, and never past index `last - 1`.
        let (mut lo, mut hi) = (0, last);
        while hi - lo > 1 {
            let mid = lo + (hi - lo) / 2;
            if xnew >= self.x[mid] {
                lo = mid;
            } else {
                hi = mid;
            }
        }
        Ok(lo)
    }

    /// Evaluates the spline at `xnew`.
    ///
    /// At a knot the stored data value is reproduced exactly.
    ///
    /// # Errors
    ///
    /// Returns `InterpolateError::OutOfBounds` when `xnew` is outside
    /// `[x[0], x[n - 1]]`.
    pub fn evaluate(&self, xnew: F) -> InterpolateResult<F> {
        let idx = self.locate_segment(xnew)?;
        let last = self.x.len() - 1;
        if xnew == self.x[last] {
            // The last knot has no segment of its own; return its data value.
            return Ok(self.y[last]);
        }
        let dx = xnew - self.x[idx];
        let a = self.coeffs[[idx, 0]];
        let b = self.coeffs[[idx, 1]];
        let c = self.coeffs[[idx, 2]];
        let d = self.coeffs[[idx, 3]];
        Ok(a + b * dx + c * dx * dx + d * dx * dx * dx)
    }

    /// Evaluates the spline at every element of `xnew`.
    ///
    /// # Errors
    ///
    /// Stops at the first element for which [`Self::evaluate`] fails and
    /// returns that error.
    pub fn evaluate_array(&self, xnew: &ArrayView1<F>) -> InterpolateResult<Array1<F>> {
        let values = xnew
            .iter()
            .map(|&x| self.evaluate(x))
            .collect::<InterpolateResult<Vec<F>>>()?;
        Ok(Array1::from_vec(values))
    }

    /// Evaluates the first derivative of the spline at `xnew`.
    ///
    /// At the last knot the derivative of the final segment is evaluated at
    /// its right end.
    ///
    /// # Errors
    ///
    /// Returns `InterpolateError::OutOfBounds` when `xnew` is outside
    /// `[x[0], x[n - 1]]`.
    pub fn derivative(&self, xnew: F) -> InterpolateResult<F> {
        let idx = self.locate_segment(xnew)?;
        let two = F::one() + F::one();
        let three = two + F::one();
        let dx = xnew - self.x[idx];
        let b = self.coeffs[[idx, 1]];
        let c = self.coeffs[[idx, 2]];
        let d = self.coeffs[[idx, 3]];
        Ok(b + two * c * dx + three * d * dx * dx)
    }
}
/// Convenience constructor: builds an [`AkimaSpline`] from the knots `x` and
/// the values `y`.
///
/// # Errors
///
/// Forwards whatever error `AkimaSpline::new` returns for the given input.
#[allow(dead_code)]
pub fn make_akima_spline<F>(
    x: &ArrayView1<F>,
    y: &ArrayView1<F>,
) -> InterpolateResult<AkimaSpline<F>>
where
    F: crate::traits::InterpolationFloat,
{
    AkimaSpline::<F>::new(x, y)
}
/// One-shot helper: fits an [`AkimaSpline`] to `(x, y)` and evaluates it at
/// every point of `xnew`.
///
/// # Errors
///
/// Returns the construction error from `AkimaSpline::new`, or the evaluation
/// error from `AkimaSpline::evaluate_array`, whichever occurs first.
#[allow(dead_code)]
pub fn akima_interpolate<F>(
    x: &ArrayView1<F>,
    y: &ArrayView1<F>,
    xnew: &ArrayView1<F>,
) -> InterpolateResult<Array1<F>>
where
    F: crate::traits::InterpolationFloat,
{
    AkimaSpline::new(x, y).and_then(|fitted| fitted.evaluate_array(xnew))
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// The spline reproduces every knot, stays within loose bounds between
    /// knots, and rejects queries outside the knot range.
    #[test]
    fn test_akima_spline() {
        let x = array![0.0, 1.0, 2.0, 3.0, 4.0, 5.0];
        let y = array![0.0, 1.0, 4.0, 20.0, 16.0, 25.0];
        let spline = AkimaSpline::new(&x.view(), &y.view()).expect("Operation failed");

        // Interpolation property: each knot maps back to its data value.
        for (&knot, &expected) in x.iter().zip(y.iter()) {
            let value = spline.evaluate(knot).expect("Operation failed");
            assert_abs_diff_eq!(value, expected, epsilon = 1e-10);
        }

        // Midpoints stay between loose bounds around the neighbouring data.
        let mid_rising = spline.evaluate(2.5).expect("Operation failed");
        assert!(mid_rising > 4.0);
        assert!(mid_rising < 20.0);
        let mid_falling = spline.evaluate(3.5).expect("Operation failed");
        assert!(mid_falling < 20.0);
        assert!(mid_falling > 16.0);

        // Queries outside [0, 5] must be reported as errors.
        for &outside in &[-1.0, 6.0] {
            assert!(spline.evaluate(outside).is_err());
        }
    }

    /// On samples of `x^2` the spline derivative is close to `2x` at the
    /// interior knots.
    #[test]
    fn test_akima_spline_derivative() {
        let x = array![0.0, 1.0, 2.0, 3.0, 4.0];
        let y = array![0.0, 1.0, 4.0, 9.0, 16.0];
        let spline = AkimaSpline::new(&x.view(), &y.view()).expect("Operation failed");

        for &(knot, expected) in &[(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)] {
            let slope = spline.derivative(knot).expect("Operation failed");
            assert!((slope - expected).abs() < 0.3);
        }
    }

    /// The convenience constructor yields a usable spline.
    #[test]
    fn test_make_akima_spline() {
        let x = array![0.0, 1.0, 2.0, 3.0, 4.0];
        let y = array![0.0, 1.0, 4.0, 9.0, 16.0];
        let spline = make_akima_spline(&x.view(), &y.view()).expect("Operation failed");

        let midpoint = spline.evaluate(2.5).expect("Operation failed");
        assert_abs_diff_eq!(midpoint, 6.25, epsilon = 0.5);
    }
}