use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;
/// Locates the knot span containing `x` for a B-spline of degree `k`.
///
/// With `num_basis = knots.len() - k - 1` basis functions, the result lies in
/// `k..=num_basis - 1`:
///
/// * values at or beyond `knots[num_basis]` map to the last span `num_basis - 1`,
///   so the right endpoint is included;
/// * values at or before `knots[k]` map to the first span `k`;
/// * any other value is located by bisection, yielding the index `i` with
///   `knots[i] <= x < knots[i + 1]`.
///
/// Expects a knot vector with at least `k + 2` entries; shorter inputs underflow
/// the basis count.
#[allow(dead_code)]
pub fn find_span<F: crate::traits::InterpolationFloat>(
    x: F,
    knots: &ArrayView1<F>,
    k: usize,
) -> usize {
    let num_basis = knots.len() - k - 1;
    if x >= knots[num_basis] {
        return num_basis - 1;
    }
    if x <= knots[k] {
        return k;
    }
    // Bisection maintaining knots[lo] <= x < knots[hi].
    let (mut lo, mut hi) = (k, num_basis);
    loop {
        let mid = (lo + hi) / 2;
        if x < knots[mid] {
            hi = mid;
        } else if x >= knots[mid + 1] {
            lo = mid;
        } else {
            return mid;
        }
    }
}
/// Evaluates the `k + 1` non-zero B-spline basis functions of degree `k` at `x`.
///
/// `span` is the knot span index for `x` (e.g. from [`find_span`]). Entry `r` of
/// the returned array is the value of basis function `span - k + r`. The values
/// are built degree by degree with the triangular recurrence of
/// *The NURBS Book*, Algorithm A2.2.
#[allow(dead_code)]
pub fn basis_funs<F: crate::traits::InterpolationFloat>(
    x: F,
    span: usize,
    knots: &ArrayView1<F>,
    k: usize,
) -> Array1<F> {
    let mut values = Array1::zeros(k + 1);
    // dl[d] = x - knots[span + 1 - d], dr[d] = knots[span + d] - x; index 0 unused.
    let mut dl = Array1::zeros(k + 1);
    let mut dr = Array1::zeros(k + 1);
    values[0] = F::one();
    for deg in 1..=k {
        dl[deg] = x - knots[span + 1 - deg];
        dr[deg] = knots[span + deg] - x;
        let mut carry = F::zero();
        for r in 0..deg {
            let denom = dr[r + 1] + dl[deg - r];
            let ratio = values[r] / denom;
            values[r] = carry + dr[r + 1] * ratio;
            carry = dl[deg - r] * ratio;
        }
        values[deg] = carry;
    }
    values
}
/// Evaluates the non-zero B-spline basis functions of degree `k` at `x` together
/// with their derivatives up to order `n` (clamped to `k`).
///
/// The structure mirrors Algorithm A2.3 of *The NURBS Book* (Piegl & Tiller).
///
/// # Arguments
///
/// * `x` - evaluation point.
/// * `span` - knot span index for `x`, e.g. as returned by [`find_span`].
/// * `knots` - knot vector.
/// * `k` - spline degree.
/// * `n` - highest derivative order requested; orders above `k` are clamped to `k`.
///
/// # Returns
///
/// An array of shape `(min(n, k) + 1, k + 1)`. Entry `[d, r]` is the `d`-th
/// derivative at `x` of basis function `span - k + r`; row 0 holds the basis values.
#[allow(dead_code)]
pub fn basis_funs_derivatives<F: crate::traits::InterpolationFloat>(
    x: F,
    span: usize,
    knots: &ArrayView1<F>,
    k: usize,
    n: usize,
) -> Array2<F> {
    // Never compute more than k derivative orders (rows are sized from this n).
    let n = n.min(k);
    let mut derivs = Array2::zeros((n + 1, k + 1));
    // ndu: upper triangle `ndu[[r, j]]` (r <= j) holds basis values of degree j;
    // strict lower triangle `ndu[[j, r]]` (r < j) holds knot differences used as divisors.
    // a: two rows used alternately (row indices s1/s2) for the derivative
    // coefficients of the previous and current order.
    let mut ndu = Array2::zeros((k + 1, k + 1));
    let mut a = Array2::zeros((2, k + 1));
    let mut left = Array1::zeros(k + 1);
    let mut right = Array1::zeros(k + 1);
    ndu[[0, 0]] = F::one();
    // Triangular recurrence for the basis values, additionally storing the
    // denominators in the lower triangle of `ndu`.
    for j in 1..=k {
        left[j] = x - knots[span + 1 - j];
        right[j] = knots[span + j] - x;
        let mut saved = F::zero();
        for r in 0..j {
            ndu[[j, r]] = right[r + 1] + left[j - r];
            let temp = ndu[[r, j - 1]] / ndu[[j, r]];
            ndu[[r, j]] = saved + right[r + 1] * temp;
            saved = left[j - r] * temp;
        }
        ndu[[j, j]] = saved;
    }
    // Row 0: the degree-k basis values, read from the last column of ndu.
    for j in 0..=k {
        derivs[[0, j]] = ndu[[j, k]];
    }
    // For each basis function r, build the derivative coefficients order by order.
    for r in 0..=k {
        let mut s1 = 0;
        let mut s2 = 1;
        a[[0, 0]] = F::one();
        for m in 1..=n {
            let mut d = F::zero();
            // Signed offsets: rk is negative whenever r < m, hence isize.
            let rk = r as isize - m as isize;
            let pk = k as isize - m as isize;
            if r >= m {
                a[[s2, 0]] = a[[s1, 0]] / ndu[[(pk + 1) as usize, rk as usize]];
                d = a[[s2, 0]] * ndu[[rk as usize, pk as usize]];
            }
            // j1..=j2 keeps the column index rk + j within 0..=pk.
            let j1 = if rk >= -1 { 1 } else { (-rk) as usize };
            let j2 = if r as isize - 1 <= pk { m - 1 } else { k - r };
            for j in j1..=j2 {
                a[[s2, j]] = (a[[s1, j]] - a[[s1, j - 1]])
                    / ndu[[(pk + 1) as usize, (rk + j as isize) as usize]];
                d += a[[s2, j]] * ndu[[(rk + j as isize) as usize, pk as usize]];
            }
            if r as isize <= pk {
                a[[s2, m]] = -a[[s1, m - 1]] / ndu[[(pk + 1) as usize, r]];
                d += a[[s2, m]] * ndu[[r, pk as usize]];
            }
            derivs[[m, r]] = d;
            // The row just written becomes the "previous order" row for m + 1.
            std::mem::swap(&mut s1, &mut s2);
        }
    }
    // Scale row j by k * (k - 1) * ... * (k - j + 1), i.e. k! / (k - j)!.
    let mut fac = F::from_usize(k).expect("Operation failed");
    for j in 1..=n {
        for i in 0..=k {
            derivs[[j, i]] *= fac;
        }
        fac *= F::from_usize(k - j).expect("Operation failed");
    }
    derivs
}
/// Evaluates a tensor-product B-spline surface of degrees `(kx, ky)` at `(x, y)`.
///
/// `coeffs` is read as a flattened row-major grid with row stride
/// `n_y = knots_y.len() - ky - 1`: the coefficient for basis pair `(i, j)` lives at
/// flat index `i * n_y + j`. Only the `(kx + 1) × (ky + 1)` coefficients whose basis
/// functions are non-zero at the point contribute. Flat indices at or beyond
/// `coeffs.len()` are skipped rather than causing a panic.
#[allow(dead_code)]
pub fn evaluate_bispline<F: crate::traits::InterpolationFloat>(
    x: F,
    y: F,
    knots_x: &ArrayView1<F>,
    knots_y: &ArrayView1<F>,
    coeffs: &ArrayView1<F>,
    kx: usize,
    ky: usize,
) -> F {
    let span_x = find_span(x, knots_x, kx);
    let span_y = find_span(y, knots_y, ky);
    let basis_x = basis_funs(x, span_x, knots_x, kx);
    let basis_y = basis_funs(y, span_y, knots_y, ky);
    // Row stride of the flattened coefficient grid.
    let n_y = knots_y.len() - ky - 1;
    // Index of the first non-zero basis function in each direction.
    let row0 = span_x - kx;
    let col0 = span_y - ky;
    let mut sum = F::zero();
    for i in 0..=kx {
        for j in 0..=ky {
            let idx = (row0 + i) * n_y + (col0 + j);
            if idx < coeffs.len() {
                sum += basis_x[i] * basis_y[j] * coeffs[idx];
            }
        }
    }
    sum
}
/// Evaluates the mixed partial derivative `∂^(dx+dy) s / ∂x^dx ∂y^dy` of a
/// tensor-product B-spline surface `s` at `(x, y)`.
///
/// The coefficient layout matches [`evaluate_bispline`]: a flattened row-major grid
/// with row stride `knots_y.len() - ky - 1` and `knots_x.len() - kx - 1` rows.
/// With `dx == 0 && dy == 0` this delegates to [`evaluate_bispline`]; when either
/// order exceeds the corresponding degree the result is zero. Coefficients outside
/// the grid or past `coeffs.len()` are skipped.
#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
pub fn evaluate_bispline_derivative<F: crate::traits::InterpolationFloat>(
    x: F,
    y: F,
    knots_x: &ArrayView1<F>,
    knots_y: &ArrayView1<F>,
    coeffs: &ArrayView1<F>,
    kx: usize,
    ky: usize,
    dx: usize,
    dy: usize,
) -> F {
    match (dx, dy) {
        (0, 0) => return evaluate_bispline(x, y, knots_x, knots_y, coeffs, kx, ky),
        // Within each span the surface is polynomial of degree kx in x and ky in y.
        _ if dx > kx || dy > ky => return F::zero(),
        _ => {}
    }
    let span_x = find_span(x, knots_x, kx);
    let span_y = find_span(y, knots_y, ky);
    // Row `dx` / `dy` of these tables holds the requested derivative of each
    // non-zero basis function.
    let ders_x = basis_funs_derivatives(x, span_x, knots_x, kx, dx);
    let ders_y = basis_funs_derivatives(y, span_y, knots_y, ky, dy);
    let rows = knots_x.len() - kx - 1;
    let cols = knots_y.len() - ky - 1;
    let mut total = F::zero();
    for (i, row) in (span_x - kx..=span_x).enumerate() {
        if row >= rows {
            continue;
        }
        for (j, col) in (span_y - ky..=span_y).enumerate() {
            if col >= cols {
                continue;
            }
            let flat = row * cols + col;
            if flat < coeffs.len() {
                total += ders_x[[dx, i]] * ders_y[[dy, j]] * coeffs[flat];
            }
        }
    }
    total
}
/// Approximates the double integral of a tensor-product B-spline surface over the
/// rectangle `[xa, xb] × [ya, yb]` using a tensor-product quadrature rule.
///
/// Each axis is mapped from the reference interval `[-1, 1]` by
/// `t ↦ mid + half_width * t`. The integral is therefore
/// `hx * hy * Σ_i Σ_j w_i w_j s(x_i, y_j)`, where `hx = (xb - xa) / 2` and
/// `hy = (yb - ya) / 2`. Reversed bounds give a negated result.
///
/// # Arguments
///
/// * `n_quad` - quadrature points per axis (default 10); the surface is evaluated
///   `n_quad²` times. `Some(0)` yields zero.
///
/// The surface is only piecewise polynomial, so the rule is not exact when the
/// rectangle spans interior knots.
#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
pub fn integrate_bispline<F: crate::traits::InterpolationFloat>(
    xa: F,
    xb: F,
    ya: F,
    yb: F,
    knots_x: &ArrayView1<F>,
    knots_y: &ArrayView1<F>,
    coeffs: &ArrayView1<F>,
    kx: usize,
    ky: usize,
    n_quad: Option<usize>,
) -> F {
    let n = n_quad.unwrap_or(10);
    let (nodes, weights): (Vec<F>, Vec<F>) = gauss_legendre_quadrature(n);
    let two = F::from_f64(2.0).expect("Operation failed");
    let half_width_x = (xb - xa) / two;
    let half_width_y = (yb - ya) / two;
    let mid_x = (xa + xb) / two;
    let mid_y = (ya + yb) / two;
    let mut sum = F::zero();
    for (&tx, &wx) in nodes.iter().zip(weights.iter()) {
        let x = mid_x + half_width_x * tx;
        for (&ty, &wy) in nodes.iter().zip(weights.iter()) {
            let y = mid_y + half_width_y * ty;
            let value = evaluate_bispline(x, y, knots_x, knots_y, coeffs, kx, ky);
            sum += value * wx * wy;
        }
    }
    // The weights integrate over [-1, 1] (they sum to 2), so the affine change of
    // variables contributes exactly hx * hy. The previous extra factor of 4
    // over-counted the integral fourfold.
    sum * half_width_x * half_width_y
}
/// Returns the `n`-point Gauss–Legendre nodes and weights on `[-1, 1]`.
///
/// Nodes are in ascending order, paired index-wise with their weights. The rule
/// integrates polynomials of degree up to `2n - 1` exactly, and its weights sum to 2.
/// Closed forms are used for `n <= 4`. Larger `n` are computed by Newton iteration
/// on the Legendre polynomial `P_n`. `n == 0` yields empty vectors.
///
/// Previously the 4-point branch produced a NaN node (the square root of a negative
/// number) and wrong weights. For `n >= 5`, the old code silently fell back to the
/// midpoint rule.
#[allow(dead_code)]
fn gauss_legendre_quadrature<F: Float + FromPrimitive + Debug>(n: usize) -> (Vec<F>, Vec<F>) {
    let cvt = |v: f64| F::from_f64(v).expect("Operation failed");
    let mut points = Vec::with_capacity(n);
    let mut weights = Vec::with_capacity(n);
    match n {
        0 => {}
        1 => {
            points.push(F::zero());
            weights.push(cvt(2.0));
        }
        2 => {
            let p = cvt(1.0 / 3.0_f64.sqrt());
            points.push(-p);
            points.push(p);
            weights.push(F::one());
            weights.push(F::one());
        }
        3 => {
            let p = cvt((3.0 / 5.0_f64).sqrt());
            points.extend_from_slice(&[-p, F::zero(), p]);
            weights.extend_from_slice(&[cvt(5.0 / 9.0), cvt(8.0 / 9.0), cvt(5.0 / 9.0)]);
        }
        4 => {
            // Nodes ±sqrt(3/7 ∓ (2/7)·sqrt(6/5)); weights (18 ± sqrt(30)) / 36,
            // with the larger weight on the inner pair.
            let r = 2.0 / 7.0 * (6.0_f64 / 5.0).sqrt();
            let p_inner = cvt((3.0 / 7.0 - r).sqrt());
            let p_outer = cvt((3.0 / 7.0 + r).sqrt());
            let w_inner = cvt((18.0 + 30.0_f64.sqrt()) / 36.0);
            let w_outer = cvt((18.0 - 30.0_f64.sqrt()) / 36.0);
            points.extend_from_slice(&[-p_outer, -p_inner, p_inner, p_outer]);
            weights.extend_from_slice(&[w_outer, w_inner, w_inner, w_outer]);
        }
        _ => {
            let (nodes, wts) = gauss_legendre_nodes_f64(n);
            points.extend(nodes.into_iter().map(cvt));
            weights.extend(wts.into_iter().map(cvt));
        }
    }
    (points, weights)
}

/// Computes `n`-point Gauss–Legendre nodes (ascending) and weights in `f64`
/// using Newton's method on `P_n`, seeded with a cosine estimate of each root.
fn gauss_legendre_nodes_f64(n: usize) -> (Vec<f64>, Vec<f64>) {
    const MAX_ITER: usize = 100;
    const TOL: f64 = 1e-15;
    let mut nodes = vec![0.0_f64; n];
    let mut weights = vec![0.0_f64; n];
    let nf = n as f64;
    // Roots are symmetric about 0, so only the non-negative half is solved for.
    for i in 0..(n + 1) / 2 {
        // Initial estimate of the (i + 1)-th largest root.
        let mut z = (std::f64::consts::PI * (i as f64 + 0.75) / (nf + 0.5)).cos();
        for _ in 0..MAX_ITER {
            let (p, dp) = legendre_with_derivative(n, z);
            let step = p / dp;
            z -= step;
            if step.abs() <= TOL {
                break;
            }
        }
        let (_, dp) = legendre_with_derivative(n, z);
        let w = 2.0 / ((1.0 - z * z) * dp * dp);
        nodes[i] = -z;
        nodes[n - 1 - i] = z;
        weights[i] = w;
        weights[n - 1 - i] = w;
    }
    (nodes, weights)
}

/// Evaluates `P_n(z)` and `P_n'(z)` for `|z| < 1` via the three-term recurrence
/// `(j + 1) P_{j+1} = (2j + 1) z P_j - j P_{j-1}`.
fn legendre_with_derivative(n: usize, z: f64) -> (f64, f64) {
    let mut p_curr = 1.0_f64; // P_0
    let mut p_prev = 0.0_f64; // P_{-1}
    for j in 0..n {
        let jf = j as f64;
        let p_next = ((2.0 * jf + 1.0) * z * p_curr - jf * p_prev) / (jf + 1.0);
        p_prev = p_curr;
        p_curr = p_next;
    }
    // P_n'(z) = n (z P_n(z) - P_{n-1}(z)) / (z^2 - 1)
    let deriv = n as f64 * (z * p_curr - p_prev) / (z * z - 1.0);
    (p_curr, deriv)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Clamped cubic knot vector on [0, 4] with unit interior spacing.
    fn clamped_cubic_knots() -> scirs2_core::ndarray::Array1<f64> {
        array![0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 4.0, 4.0, 4.0, 4.0]
    }

    #[test]
    fn test_find_span() {
        let knots = clamped_cubic_knots();
        let view = knots.view();
        // (x, expected span) pairs, including both clamped endpoints.
        let cases = [(0.0, 3), (0.5, 3), (1.0, 4), (2.0, 5), (3.0, 6), (4.0, 6)];
        for &(x, expected) in cases.iter() {
            assert_eq!(find_span(x, &view, 3), expected);
        }
    }

    #[test]
    fn test_basis_funs() {
        let knots = clamped_cubic_knots();
        let view = knots.view();
        let degree = 3;
        let span = find_span(0.5, &view, degree);
        let basis = basis_funs(0.5, span, &view, degree);
        // Partition of unity: the non-zero basis functions sum to one.
        let total: f64 = basis.iter().sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_evaluate_bispline() {
        let knots_x = array![0.0, 0.0, 1.0, 1.0];
        let knots_y = array![0.0, 0.0, 1.0, 1.0];
        // Bilinear patch with a single unit coefficient at the (1, 1) corner.
        let coeffs = array![0.0, 0.0, 0.0, 1.0];
        let eval = |x: f64, y: f64| {
            evaluate_bispline(
                x,
                y,
                &knots_x.view(),
                &knots_y.view(),
                &coeffs.view(),
                1,
                1,
            )
        };
        let cases = [
            (0.0, 0.0, 0.0),
            (0.0, 1.0, 0.0),
            (1.0, 0.0, 0.0),
            (1.0, 1.0, 1.0),
            (0.5, 0.5, 0.25),
        ];
        for &(x, y, expected) in cases.iter() {
            assert_relative_eq!(eval(x, y), expected, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_integrate_bispline() {
        let knots_x = array![0.0, 0.0, 1.0, 1.0];
        let knots_y = array![0.0, 0.0, 1.0, 1.0];
        // All-ones coefficients: the bilinear surface is constant over the domain.
        let coeffs = array![1.0, 1.0, 1.0, 1.0];
        let integrate_to = |upper: f64| {
            integrate_bispline(
                0.0,
                upper,
                0.0,
                upper,
                &knots_x.view(),
                &knots_y.view(),
                &coeffs.view(),
                1,
                1,
                Some(5),
            )
        };
        let integral = integrate_to(1.0);
        assert!(
            integral > 0.0 && integral < 10.0,
            "Integral should be positive and reasonable: {}",
            integral
        );
        assert!(integral.is_finite());
        let integral_half = integrate_to(0.5);
        assert!(
            integral_half > 0.0 && integral_half < integral,
            "Quarter domain integral should be positive and less than full: {} vs {}",
            integral_half,
            integral
        );
        assert!(integral_half.is_finite());
    }
}