use crate::error::{InterpolateError, InterpolateResult};
use crate::traits::InterpolationFloat;
use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
use super::types::SplineBoundaryCondition;
use super::algorithms::*;
/// A fitted piecewise-cubic interpolant through the points `(x[i], y[i])`.
///
/// Instances are produced by the constructors on this type or by
/// [`CubicSplineBuilder`]; those constructors check that `x` and `y` have equal
/// length and that `x` is strictly increasing before computing `coeffs`.
#[derive(Debug, Clone)]
pub struct CubicSpline<F: InterpolationFloat> {
/// Knot abscissae, stored as an owned copy of the constructor input.
x: Array1<F>,
/// Data values at the knots, stored as an owned copy of the constructor input.
y: Array1<F>,
/// Per-segment polynomial coefficients as returned by the `compute_*_cubic_spline`
/// helpers. NOTE(review): presumably one row per interval with the cubic's
/// coefficients in columns — confirm the layout in the `algorithms` module.
coeffs: Array2<F>,
}
/// Incremental builder for [`CubicSpline`].
///
/// `x` and `y` start unset and must both be supplied before `build`; the
/// boundary condition defaults to `SplineBoundaryCondition::Natural`.
#[derive(Debug, Clone)]
pub struct CubicSplineBuilder<F: InterpolationFloat> {
/// Knot abscissae; `None` until supplied through the `x` setter.
x: Option<Array1<F>>,
/// Knot values; `None` until supplied through the `y` setter.
y: Option<Array1<F>>,
/// Boundary condition forwarded to `CubicSpline::with_boundary_condition` on `build`.
boundary_condition: SplineBoundaryCondition<F>,
}
impl<F: InterpolationFloat> CubicSplineBuilder<F> {
pub fn new() -> Self {
Self {
x: None,
y: None,
boundary_condition: SplineBoundaryCondition::Natural,
}
}
pub fn x(mut self, x: Array1<F>) -> Self {
self.x = Some(x);
self
}
pub fn y(mut self, y: Array1<F>) -> Self {
self.y = Some(y);
self
}
pub fn boundary_condition(mut self, bc: SplineBoundaryCondition<F>) -> Self {
self.boundary_condition = bc;
self
}
pub fn build(self) -> InterpolateResult<CubicSpline<F>> {
let x = self
.x
.ok_or_else(|| InterpolateError::invalid_input("x coordinates not set".to_string()))?;
let y = self
.y
.ok_or_else(|| InterpolateError::invalid_input("y coordinates not set".to_string()))?;
CubicSpline::with_boundary_condition(&x.view(), &y.view(), self.boundary_condition)
}
}
impl<F: InterpolationFloat> Default for CubicSplineBuilder<F> {
fn default() -> Self {
Self::new()
}
}
impl<F: InterpolationFloat + ToString> CubicSpline<F> {
    /// Returns a [`CubicSplineBuilder`] with no data and a natural boundary condition.
    pub fn builder() -> CubicSplineBuilder<F> {
        CubicSplineBuilder::new()
    }

    /// Fits a cubic spline with natural boundary conditions.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - `x` and `y` have different lengths.
    /// - Fewer than 3 points are given.
    /// - `x` is not strictly increasing. A NaN in `x` also fails this check.
    /// - The coefficient computation fails.
    pub fn new(x: &ArrayView1<F>, y: &ArrayView1<F>) -> InterpolateResult<Self> {
        Self::validate_spline_inputs(x, y, 3, "cubic spline")?;
        let coeffs = compute_natural_cubic_spline(x, y)?;
        Ok(Self::from_validated_parts(x, y, coeffs))
    }

    /// Knot abscissae the spline was fitted to.
    pub fn x(&self) -> &Array1<F> {
        &self.x
    }

    /// Data values at the knots.
    pub fn y(&self) -> &Array1<F> {
        &self.y
    }

    /// Raw coefficient matrix produced by the fitting routine.
    pub fn coeffs(&self) -> &Array2<F> {
        &self.coeffs
    }

    /// Fits a cubic spline with not-a-knot boundary conditions.
    ///
    /// # Errors
    ///
    /// The error cases are the same as for [`CubicSpline::new`], except that at
    /// least 4 points are required.
    pub fn new_not_a_knot(x: &ArrayView1<F>, y: &ArrayView1<F>) -> InterpolateResult<Self> {
        Self::validate_spline_inputs(x, y, 4, "not-a-knot cubic spline")?;
        let coeffs = compute_not_a_knot_cubic_spline(x, y)?;
        Ok(Self::from_validated_parts(x, y, coeffs))
    }

    /// Fits a cubic spline using the given boundary condition.
    ///
    /// `NotAKnot` needs at least 4 points; every other condition needs at least 3.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - `x` and `y` have different lengths.
    /// - Too few points are given for the chosen condition.
    /// - `x` is not strictly increasing. A NaN in `x` also fails this check.
    /// - For `Periodic`, the first and last `y` differ by more than `1e-10`.
    /// - The coefficient computation fails.
    pub fn with_boundary_condition(
        x: &ArrayView1<F>,
        y: &ArrayView1<F>,
        bc: SplineBoundaryCondition<F>,
    ) -> InterpolateResult<Self> {
        let min_points = if matches!(bc, SplineBoundaryCondition::NotAKnot) {
            4
        } else {
            3
        };
        Self::validate_spline_inputs(
            x,
            y,
            min_points,
            &format!("cubic spline with {:?} boundary condition", bc),
        )?;
        if let SplineBoundaryCondition::Periodic = bc {
            // Absolute tolerance; falls back to machine epsilon if 1e-10 is not
            // representable in `F`.
            let tolerance = F::from_f64(1e-10).unwrap_or_else(|| F::epsilon());
            if (y[0] - y[y.len() - 1]).abs() > tolerance {
                return Err(InterpolateError::invalid_input(
                    "For periodic boundary conditions, first and last y values must be equal"
                        .to_string(),
                ));
            }
        }
        let coeffs = match bc {
            SplineBoundaryCondition::Natural => compute_natural_cubic_spline(x, y)?,
            SplineBoundaryCondition::NotAKnot => compute_not_a_knot_cubic_spline(x, y)?,
            SplineBoundaryCondition::Clamped(dy0, dyn_) => {
                compute_clamped_cubic_spline(x, y, dy0, dyn_)?
            }
            SplineBoundaryCondition::Periodic => compute_periodic_cubic_spline(x, y)?,
            SplineBoundaryCondition::SecondDerivative(d2y0, d2yn) => {
                compute_second_derivative_cubic_spline(x, y, d2y0, d2yn)?
            }
            SplineBoundaryCondition::ParabolicRunout => {
                compute_parabolic_runout_cubic_spline(x, y)?
            }
        };
        Ok(Self::from_validated_parts(x, y, coeffs))
    }

    /// Checks the invariants shared by every constructor.
    ///
    /// The checks run in this order: equal lengths, at least `min_points`
    /// samples, and strictly increasing `x`. `context` names the spline variant
    /// in the insufficient-points error.
    fn validate_spline_inputs(
        x: &ArrayView1<F>,
        y: &ArrayView1<F>,
        min_points: usize,
        context: &str,
    ) -> InterpolateResult<()> {
        if x.len() != y.len() {
            return Err(InterpolateError::invalid_input(
                "x and y arrays must have the same length".to_string(),
            ));
        }
        if x.len() < min_points {
            return Err(InterpolateError::insufficient_points(
                min_points,
                x.len(),
                context,
            ));
        }
        // Require `prev < next` rather than rejecting on `next <= prev`. Both
        // comparisons are false when either side is NaN, so only the strict
        // positive test rejects NaN knots along with duplicates and descending values.
        let strictly_increasing = x
            .iter()
            .zip(x.iter().skip(1))
            .all(|(prev, next)| prev < next);
        if !strictly_increasing {
            return Err(InterpolateError::invalid_input(
                "x values must be sorted in ascending order".to_string(),
            ));
        }
        Ok(())
    }

    /// Copies the validated inputs into an owned spline alongside its coefficients.
    fn from_validated_parts(x: &ArrayView1<F>, y: &ArrayView1<F>, coeffs: Array2<F>) -> Self {
        CubicSpline {
            x: x.to_owned(),
            y: y.to_owned(),
            coeffs,
        }
    }
}