use crate::bspline::{generate_knots, BSpline, ExtrapolateMode};
use crate::error::{InterpolateError, InterpolateResult};
use scirs2_core::ndarray::ArrayView1;
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::{Debug, Display};
use std::ops::{Add, Div, Mul, Sub};
use super::solver::{solve_constrained_system, solve_penalized_system};
use super::types::{ConstrainedSpline, Constraint};
/// Strategy used to compute the coefficients of a [`ConstrainedSpline`].
///
/// The enum carries no data, so it also derives `Eq` and `Hash` and can be
/// used as a key in maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FittingMethod {
    /// Fit selected by `ConstrainedSpline::least_squares`.
    LeastSquares,
    /// Fit selected by `ConstrainedSpline::interpolate`.
    Interpolation,
    /// Fit selected by `ConstrainedSpline::penalized`; the constrained solve
    /// includes a penalty term weighted by a scalar `lambda`.
    Penalized,
}
impl<T> ConstrainedSpline<T>
where
T: Float
+ FromPrimitive
+ Debug
+ Display
+ Add<Output = T>
+ Sub<Output = T>
+ Mul<Output = T>
+ Div<Output = T>
+ std::ops::AddAssign
+ std::ops::SubAssign
+ std::ops::MulAssign
+ std::ops::DivAssign
+ std::ops::RemAssign
+ 'static
+ std::fmt::LowerExp,
{
pub fn interpolate(
x: &ArrayView1<T>,
y: &ArrayView1<T>,
constraints: Vec<Constraint<T>>,
degree: usize,
extrapolate: ExtrapolateMode,
) -> InterpolateResult<Self> {
Self::fit_internal(
x,
y,
constraints,
degree,
None,
FittingMethod::Interpolation,
extrapolate,
)
}
pub fn least_squares(
x: &ArrayView1<T>,
y: &ArrayView1<T>,
constraints: Vec<Constraint<T>>,
num_knots: usize,
degree: usize,
extrapolate: ExtrapolateMode,
) -> InterpolateResult<Self> {
Self::fit_internal(
x,
y,
constraints,
degree,
Some(num_knots),
FittingMethod::LeastSquares,
extrapolate,
)
}
pub fn penalized(
x: &ArrayView1<T>,
y: &ArrayView1<T>,
constraints: Vec<Constraint<T>>,
num_knots: usize,
degree: usize,
lambda: T,
extrapolate: ExtrapolateMode,
) -> InterpolateResult<Self> {
Self::fit_internal(
x,
y,
constraints,
degree,
Some(num_knots),
FittingMethod::Penalized,
extrapolate,
)
}
#[allow(clippy::too_many_arguments)]
pub(crate) fn fit_internal(
x: &ArrayView1<T>,
y: &ArrayView1<T>,
constraints: Vec<Constraint<T>>,
degree: usize,
num_knots: Option<usize>,
method: FittingMethod,
extrapolate: ExtrapolateMode,
) -> InterpolateResult<Self> {
if x.len() != y.len() {
return Err(InterpolateError::IndexError(format!(
"x and y arrays must have the same length: {} vs {}",
x.len(),
y.len()
)));
}
if x.len() < degree + 1 {
return Err(InterpolateError::IndexError(format!(
"Need at least {} points for degree {} spline",
degree + 1,
degree
)));
}
for i in 1..x.len() {
if x[i] <= x[i - 1] {
return Err(InterpolateError::IndexError(
"x values must be strictly increasing".to_string(),
));
}
}
let _n_internal = num_knots.unwrap_or(x.len() - degree - 1);
let _knots = generate_knots(x, degree, "clamped")?;
let coeffs = match method {
FittingMethod::Interpolation => {
solve_constrained_system(x, y, &_knots.view(), degree, &constraints)?
}
FittingMethod::LeastSquares => {
solve_constrained_system(x, y, &_knots.view(), degree, &constraints)?
}
FittingMethod::Penalized => {
solve_penalized_system(
x,
y,
&_knots.view(),
degree,
&constraints,
T::from_f64(0.1).expect("Operation failed"),
)?
}
};
let bspline = BSpline::new(&_knots.view(), &coeffs.view(), degree, extrapolate)?;
Ok(ConstrainedSpline {
bspline,
constraints,
method,
})
}
}