use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::numeric::{Float, FromPrimitive, Zero};
use std::fmt::{Debug, Display};
use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, RemAssign, Sub, SubAssign};
use crate::error::{InterpolateError, InterpolateResult};
use super::core::BSpline;
use super::solvers::{solve_least_squares, solve_linear_system};
use super::types::ExtrapolateMode;
/// Constructs a B-spline of degree `k` through the points `(x[i], y[i])`.
///
/// The knot vector has `x.len() + k + 1` entries and is clamped: the first
/// `k + 1` knots equal `x[0]` and the last `k + 1` knots equal the final `x`.
/// Interior knots depend on the degree:
///
/// * `k <= 1`: the interior data sites themselves, with the coefficients taken
///   to be `y` directly. This is the construction SciPy's `make_interp_spline`
///   uses for these degrees; no linear system is assembled or solved.
/// * `k >= 2`: the mean of `k` consecutive data sites. The collocation system
///   `A c = y` with `A[i][j] = B_j(x[i])` is then solved for the coefficients.
///
/// For `k >= 2`, if both the direct solve and the least-squares solve fail the
/// function retries with degree `k - 1`, so the returned spline can have a
/// lower degree than requested.
///
/// # Errors
///
/// Returns an error when `x` and `y` differ in length, when fewer than `k + 1`
/// points are supplied, or when `x` is not strictly increasing. Errors from
/// basis construction, basis evaluation and `BSpline::new` are propagated.
pub fn make_interp_bspline<T>(
    x: &ArrayView1<T>,
    y: &ArrayView1<T>,
    k: usize,
    extrapolate: ExtrapolateMode,
) -> InterpolateResult<BSpline<T>>
where
    T: Float
        + FromPrimitive
        + Debug
        + Display
        + Add<Output = T>
        + Sub<Output = T>
        + Mul<Output = T>
        + Div<Output = T>
        + Zero
        + AddAssign
        + SubAssign
        + MulAssign
        + DivAssign
        + RemAssign,
{
    if x.len() != y.len() {
        return Err(InterpolateError::invalid_input(
            "x and y arrays must have the same length".to_string(),
        ));
    }
    if x.len() < k + 1 {
        return Err(InterpolateError::insufficient_points(
            k + 1,
            x.len(),
            &format!("degree {} B-spline", k),
        ));
    }
    let n = x.len();
    if (1..n).any(|i| x[i] <= x[i - 1]) {
        return Err(InterpolateError::invalid_input(
            "x values must be sorted in ascending order".to_string(),
        ));
    }

    // Clamp both ends. `n >= k + 1`, so the two ranges never overlap.
    let x_min = x[0];
    let x_max = x[n - 1];
    let mut t = Array1::zeros(n + k + 1);
    for i in 0..=k {
        t[i] = x_min;
        t[n + i] = x_max;
    }

    if k <= 1 {
        // Knots at the data sites: the coefficients are the data values, so
        // the collocation matrix is not needed at all.
        for i in 1..(n - k) {
            t[k + i] = x[i];
        }
        return BSpline::new(&t.view(), y, k, extrapolate);
    }

    // k >= 2: interior knots are averages of k consecutive data sites.
    let degree = T::from_usize(k).ok_or_else(|| {
        InterpolateError::invalid_input(format!(
            "degree {} cannot be represented in the scalar type",
            k
        ))
    })?;
    for i in 1..(n - k) {
        let mut sum = T::zero();
        for j in 0..k {
            sum += x[i + j];
        }
        t[k + i] = sum / degree;
    }

    // Collocation matrix, filled column by column so that every basis
    // function is constructed once instead of once per matrix entry.
    let mut a = scirs2_core::ndarray::Array2::zeros((n, n));
    for j in 0..n {
        let basis = BSpline::basis_element(k, j, &t.view(), extrapolate)?;
        for i in 0..n {
            a[(i, j)] = basis.evaluate(x[i])?;
        }
    }

    let solved =
        solve_linear_system(&a.view(), y).or_else(|_| solve_least_squares(&a.view(), y));
    let c = match solved {
        Ok(coeffs) => coeffs,
        // Both solvers failed; k >= 2 here, so a lower degree is available.
        Err(_) => return make_interp_bspline(x, y, k - 1, extrapolate),
    };
    BSpline::new(&t.view(), &c.view(), k, extrapolate)
}
/// Generates a clamped knot vector of length `x.len() + k + 1` for a degree-`k`
/// B-spline over the data sites `x`.
///
/// The first `k + 1` knots equal `x[0]` and the last `k + 1` knots equal the
/// final element of `x`. `knot_style` selects the interior knots:
///
/// * `"uniform"`: equally spaced between the first and last data site.
/// * `"average"`: the mean of `k` consecutive data sites starting at `x[i]`.
///   For `k == 0` the window would be empty, so the data site itself is used
///   instead of dividing by zero.
/// * `"clamped"`: the data sites `x[1..x.len() - k]`.
///
/// # Errors
///
/// Returns an error when fewer than `k + 1` points are supplied (this includes
/// empty input), when `x` is not strictly increasing, or when `knot_style` is
/// not one of the three names above.
pub fn generate_knots<T>(
    x: &ArrayView1<T>,
    k: usize,
    knot_style: &str,
) -> InterpolateResult<Array1<T>>
where
    T: Float
        + FromPrimitive
        + Debug
        + Add<Output = T>
        + Sub<Output = T>
        + Mul<Output = T>
        + Zero
        + AddAssign
        + SubAssign
        + MulAssign
        + DivAssign
        + RemAssign,
{
    let n = x.len();
    // With fewer than k + 1 sites the two clamped ends would overlap, and the
    // index arithmetic below (`n - 1`, `n - k`) would underflow.
    if n < k + 1 {
        return Err(InterpolateError::insufficient_points(
            k + 1,
            n,
            &format!("degree {} B-spline", k),
        ));
    }
    if (1..n).any(|i| x[i] <= x[i - 1]) {
        return Err(InterpolateError::invalid_input(
            "x values must be sorted in ascending order".to_string(),
        ));
    }

    // Converts a count into the scalar type without panicking.
    let scalar = |value: usize| -> InterpolateResult<T> {
        T::from_usize(value).ok_or_else(|| {
            InterpolateError::invalid_input(format!(
                "cannot represent {} in the scalar type",
                value
            ))
        })
    };

    let x_min = x[0];
    let x_max = x[n - 1];
    let mut t = Array1::zeros(n + k + 1);
    // Every style clamps both ends in the same way.
    for i in 0..=k {
        t[i] = x_min;
        t[n + i] = x_max;
    }

    match knot_style {
        "uniform" => {
            // n - k >= 1 is guaranteed by the length check above.
            let step = (x_max - x_min) / scalar(n - k)?;
            for i in (k + 1)..n {
                t[i] = x_min + scalar(i - k)? * step;
            }
        }
        "average" => {
            // A window of at least one site keeps k == 0 well defined.
            let window = k.max(1);
            let divisor = scalar(window)?;
            for i in 1..(n - k) {
                let mut sum = T::zero();
                for j in 0..window {
                    sum += x[i + j];
                }
                t[i + k] = sum / divisor;
            }
        }
        "clamped" => {
            for i in 1..(n - k) {
                t[i + k] = x[i];
            }
        }
        _ => {
            return Err(InterpolateError::invalid_input(format!(
                "unknown knot style: {}. Use one of 'uniform', 'average', or 'clamped'",
                knot_style
            )));
        }
    }
    Ok(t)
}
/// Fits a degree-`k` B-spline on the knot vector `t` to the points
/// `(x[i], y[i])` in the (optionally weighted) least-squares sense.
///
/// The design matrix has one row per data point and one column per basis
/// function (`t.len() - k - 1` columns). When weights are supplied, row `i`
/// of the matrix and `y[i]` are both multiplied by `sqrt(w[i])` before the
/// least-squares solve.
///
/// # Errors
///
/// Returns an error when `x` and `y` differ in length, when `t` has fewer than
/// `2 * (k + 1)` knots, when `w` has a different length than `x`, or when a
/// weight is negative or NaN (its square root would otherwise inject NaN into
/// the system). Errors from basis construction, basis evaluation, the
/// least-squares solver and `BSpline::new` are propagated.
pub fn make_lsq_bspline<T>(
    x: &ArrayView1<T>,
    y: &ArrayView1<T>,
    t: &ArrayView1<T>,
    k: usize,
    w: Option<&ArrayView1<T>>,
    extrapolate: ExtrapolateMode,
) -> InterpolateResult<BSpline<T>>
where
    T: Float
        + FromPrimitive
        + Debug
        + Display
        + Add<Output = T>
        + Sub<Output = T>
        + Mul<Output = T>
        + Div<Output = T>
        + Zero
        + AddAssign
        + SubAssign
        + MulAssign
        + DivAssign
        + RemAssign,
{
    if x.len() != y.len() {
        return Err(InterpolateError::invalid_input(
            "x and y arrays must have the same length".to_string(),
        ));
    }
    if t.len() < 2 * (k + 1) {
        return Err(InterpolateError::invalid_input(format!(
            "need at least 2(k+1) = {} knots for degree {} spline",
            2 * (k + 1),
            k
        )));
    }
    // Validate the weights up front, before any matrix is assembled.
    if let Some(weights) = w {
        if weights.len() != x.len() {
            return Err(InterpolateError::invalid_input(
                "weights array must have the same length as x and y".to_string(),
            ));
        }
        if weights.iter().any(|&wi| wi.is_nan() || wi < T::zero()) {
            return Err(InterpolateError::invalid_input(
                "weights must be non-negative and not NaN".to_string(),
            ));
        }
    }

    let m = x.len();
    let n = t.len() - k - 1;

    // Fill the design matrix column by column so each basis function is
    // constructed once rather than once per matrix entry.
    let mut b = scirs2_core::ndarray::Array2::zeros((m, n));
    for j in 0..n {
        let basis = BSpline::basis_element(k, j, t, extrapolate)?;
        for i in 0..m {
            b[(i, j)] = basis.evaluate(x[i])?;
        }
    }

    let mut rhs = y.to_owned();
    if let Some(weights) = w {
        // Scale rows in place; no second matrix is needed.
        for i in 0..m {
            let sqrt_w = weights[i].sqrt();
            for j in 0..n {
                b[(i, j)] = b[(i, j)] * sqrt_w;
            }
            rhs[i] = rhs[i] * sqrt_w;
        }
    }

    let c = solve_least_squares(&b.view(), &rhs.view())?;
    BSpline::new(t, &c.view(), k, extrapolate)
}
/// Fits a degree-`k` least-squares B-spline with an automatically chosen knot
/// vector.
///
/// The number of interior knots starts at `max(k + 1, x.len() / 4)` and is
/// reduced by `floor(10 * smoothing_factor)`, but by no more than half of the
/// starting value. A factor that cannot be converted to `usize` (for example
/// a negative one) leaves the count unchanged. The count is then capped at
/// `x.len() - (k + 1)` so that the spline never has more coefficients than
/// there are data points. Interior knots are picked from `x` at evenly spaced
/// index positions and both ends are clamped with `k + 1` repeated knots.
///
/// `x` is not checked for ordering here; the knot vector is only
/// non-decreasing when `x` is.
///
/// # Errors
///
/// Returns an error when `x` and `y` differ in length or when fewer than
/// `k + 1` points are supplied. Errors from `make_lsq_bspline` are propagated.
pub fn make_auto_bspline<T>(
    x: &ArrayView1<T>,
    y: &ArrayView1<T>,
    k: usize,
    smoothing_factor: T,
    extrapolate: ExtrapolateMode,
) -> InterpolateResult<BSpline<T>>
where
    T: Float
        + FromPrimitive
        + Debug
        + Display
        + Add<Output = T>
        + Sub<Output = T>
        + Mul<Output = T>
        + Div<Output = T>
        + Zero
        + AddAssign
        + SubAssign
        + MulAssign
        + DivAssign
        + RemAssign,
{
    if x.len() != y.len() {
        return Err(InterpolateError::invalid_input(
            "x and y arrays must have the same length".to_string(),
        ));
    }
    if x.len() < k + 1 {
        return Err(InterpolateError::insufficient_points(
            k + 1,
            x.len(),
            &format!("degree {} B-spline", k),
        ));
    }

    let n = x.len();
    let base_knots = std::cmp::max(k + 1, n / 4);
    let ten = T::from_f64(10.0).ok_or_else(|| {
        InterpolateError::invalid_input(
            "cannot represent 10.0 in the scalar type".to_string(),
        )
    })?;
    let smoothing_adjustment = (smoothing_factor * ten).to_usize().unwrap_or(0);
    let requested_knots = if smoothing_adjustment > base_knots / 2 {
        base_knots / 2
    } else {
        base_knots - smoothing_adjustment
    };
    // The spline has `internal + k + 1` coefficients; keep that at or below
    // the number of data points so the fit is not underdetermined.
    let num_internal_knots = requested_knots.min(n - (k + 1));

    let mut t = Array1::zeros(num_internal_knots + 2 * (k + 1));
    let t_len = t.len();
    for i in 0..=k {
        t[i] = x[0];
        t[t_len - 1 - i] = x[n - 1];
    }
    for i in 0..num_internal_knots {
        // Fractional position in (0, 1), mapped onto an index into `x`.
        let position = (i + 1) as f64 / (num_internal_knots + 1) as f64;
        let index = ((position * (n - 1) as f64) as usize).min(n - 1);
        t[k + 1 + i] = x[index];
    }

    make_lsq_bspline(x, y, &t.view(), k, None, extrapolate)
}
/// Fits a degree-`k` B-spline to `(x[i], y[i])` with a second-difference
/// roughness penalty on the coefficients.
///
/// Knots come from `generate_knots(x, k, "clamped")`. With design matrix `B`
/// (`B[i][j] = B_j(x[i])`) the coefficients solve
/// `(BᵀB + R) c = Bᵀy`, where `R` accumulates, for every run of three
/// consecutive coefficients, `lambda` times the outer product of
/// `[1, -2, 1]` with itself. `R` is zero when `k < 2`, in which case the
/// system is the plain normal equations.
///
/// # Errors
///
/// Returns an error when `x` and `y` differ in length or when fewer than
/// `k + 1` points are supplied. Errors from knot generation, basis
/// construction and evaluation, the matrix helpers, the linear solver and
/// `BSpline::new` are propagated.
pub fn make_smoothing_bspline<T>(
    x: &ArrayView1<T>,
    y: &ArrayView1<T>,
    k: usize,
    lambda: T,
    extrapolate: ExtrapolateMode,
) -> InterpolateResult<BSpline<T>>
where
    T: Float
        + FromPrimitive
        + Debug
        + Display
        + Add<Output = T>
        + Sub<Output = T>
        + Mul<Output = T>
        + Div<Output = T>
        + Zero
        + AddAssign
        + SubAssign
        + MulAssign
        + DivAssign
        + RemAssign
        + Copy,
{
    if x.len() != y.len() {
        return Err(InterpolateError::invalid_input(
            "x and y arrays must have the same length".to_string(),
        ));
    }
    // Guards the knot generation against empty input and keeps `n - 2` below
    // from underflowing (k >= 2 implies at least three coefficients).
    if x.len() < k + 1 {
        return Err(InterpolateError::insufficient_points(
            k + 1,
            x.len(),
            &format!("degree {} B-spline", k),
        ));
    }

    let t = generate_knots(x, k, "clamped")?;
    let m = x.len();
    let n = t.len() - k - 1;

    // Design matrix, one basis function (column) at a time.
    let mut b = scirs2_core::ndarray::Array2::zeros((m, n));
    for j in 0..n {
        let basis = BSpline::basis_element(k, j, &t.view(), extrapolate)?;
        for i in 0..m {
            b[(i, j)] = basis.evaluate(x[i])?;
        }
    }

    let mut reg_matrix = scirs2_core::ndarray::Array2::zeros((n, n));
    if k >= 2 {
        let constant = |value: f64| -> InterpolateResult<T> {
            T::from_f64(value).ok_or_else(|| {
                InterpolateError::invalid_input(format!(
                    "cannot represent {} in the scalar type",
                    value
                ))
            })
        };
        let two = constant(2.0)?;
        let four = constant(4.0)?;
        // lambda * [1, -2, 1]ᵀ [1, -2, 1]
        let penalty = [
            [lambda, -two * lambda, lambda],
            [-two * lambda, four * lambda, -two * lambda],
            [lambda, -two * lambda, lambda],
        ];
        for i in 0..n - 2 {
            for (r, row) in penalty.iter().enumerate() {
                for (c, &value) in row.iter().enumerate() {
                    reg_matrix[(i + r, i + c)] += value;
                }
            }
        }
    }

    let bt = super::solvers::transpose_matrix(&b.view());
    let btb = super::solvers::matrix_multiply(&bt.view(), &b.view())?;
    let system_matrix = btb + reg_matrix;
    let rhs = super::solvers::matrix_vector_multiply(&bt.view(), y)?;
    let c = solve_linear_system(&system_matrix.view(), &rhs.view())?;
    BSpline::new(&t.view(), &c.view(), k, extrapolate)
}
/// Builds a degree-`k` B-spline in `ExtrapolateMode::Periodic` from the points
/// `(x[i], y[i])`.
///
/// The knot vector has `x.len() + 2 * k + 1` entries:
///
/// * `k` leading knots: the last `k` data-site offsets from `x[0]`, shifted to
///   start at `x[0] - period`;
/// * the data sites themselves;
/// * `k + 1` trailing knots: the first `k + 1` data-site offsets from `x[0]`,
///   added to the last data site.
///
/// The value array handed to `solve_periodic_system` is `y` followed by its
/// first `k` entries (length `x.len() + k`), and the result of that call is
/// used as the coefficient array.
///
/// `x` is not checked for ordering and `period` is not validated here.
///
/// # Errors
///
/// Returns an error when `x` and `y` differ in length or when fewer than
/// `k + 1` points are supplied (the knot construction indexes `x[k]` and
/// `x[x.len() - k]`). Errors from `solve_periodic_system` and `BSpline::new`
/// are propagated.
pub fn make_periodic_bspline<T>(
    x: &ArrayView1<T>,
    y: &ArrayView1<T>,
    k: usize,
    period: T,
) -> InterpolateResult<BSpline<T>>
where
    T: Float
        + FromPrimitive
        + Debug
        + Display
        + Add<Output = T>
        + Sub<Output = T>
        + Mul<Output = T>
        + Div<Output = T>
        + Zero
        + AddAssign
        + SubAssign
        + MulAssign
        + DivAssign
        + RemAssign,
{
    if x.len() != y.len() {
        return Err(InterpolateError::invalid_input(
            "x and y arrays must have the same length".to_string(),
        ));
    }
    let n = x.len();
    // Without this check empty or short input panics on the indexing below.
    if n < k + 1 {
        return Err(InterpolateError::insufficient_points(
            k + 1,
            n,
            &format!("degree {} B-spline", k),
        ));
    }

    let x_min = x[0];
    let x_max = x[n - 1];
    let mut t = Array1::zeros(n + 2 * k + 1);
    for i in 0..k {
        t[i] = x_min - period + (x[n - k + i] - x[0]);
    }
    for i in 0..n {
        t[k + i] = x[i];
    }
    for i in 0..=k {
        t[k + n + i] = x_max + (x[i] - x[0]);
    }

    // Wrap the first k values around so the value array matches the number
    // of coefficients implied by the knot vector (n + k).
    let mut extended_y = Array1::zeros(n + k);
    for i in 0..n {
        extended_y[i] = y[i];
    }
    for i in 0..k {
        extended_y[n + i] = y[i];
    }

    let c = solve_periodic_system(&t.view(), &extended_y.view(), k)?;
    BSpline::new(&t.view(), &c.view(), k, ExtrapolateMode::Periodic)
}
/// Produces the coefficient array for a periodic spline.
///
/// No system is actually solved: the knot vector `_t` and the degree `_k` are
/// ignored and the result is an owned copy of `y`.
fn solve_periodic_system<T>(
    _t: &ArrayView1<T>,
    y: &ArrayView1<T>,
    _k: usize,
) -> InterpolateResult<Array1<T>>
where
    T: Float + FromPrimitive + Debug + Display + Zero + Copy,
{
    let coefficients: Array1<T> = y.to_owned();
    Ok(coefficients)
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Both the uniform and the clamped style yield `n + k + 1` knots, and the
    /// clamped style repeats each end value `k + 1` times.
    #[test]
    fn test_generate_knots() {
        let k = 3;
        let x = array![0.0, 1.0, 2.0, 3.0, 4.0];
        let expected_len = x.len() + k + 1;

        let uniform_knots = generate_knots(&x.view(), k, "uniform").expect("Operation failed");
        assert_eq!(uniform_knots.len(), expected_len);

        let clamped_knots = generate_knots(&x.view(), k, "clamped").expect("Operation failed");
        assert_eq!(clamped_knots.len(), expected_len);

        let tail_start = x.len();
        for offset in 0..=k {
            assert_eq!(clamped_knots[offset], 0.0);
            assert_eq!(clamped_knots[tail_start + offset], 4.0);
        }
    }

    /// The automatic fit succeeds on samples of `x^2` and can be evaluated
    /// inside the data range.
    #[test]
    fn test_make_auto_bspline() {
        let x = array![0.0, 1.0, 2.0, 3.0, 4.0];
        let y = array![0.0, 1.0, 4.0, 9.0, 16.0];
        let k = 2;
        let smoothing = 0.1;

        let fitted = make_auto_bspline(
            &x.view(),
            &y.view(),
            k,
            smoothing,
            ExtrapolateMode::Extrapolate,
        );
        assert!(fitted.is_ok());

        let spline = fitted.expect("Operation failed");
        let value = spline.evaluate(2.5);
        assert!(value.is_ok());
    }

    /// An unrecognised knot style name is rejected.
    #[test]
    fn test_knot_style_validation() {
        let x = array![0.0, 1.0, 2.0, 3.0, 4.0];
        let k = 2;
        let outcome = generate_knots(&x.view(), k, "invalid_style");
        assert!(outcome.is_err());
    }
}