use crate::bspline::{BSpline, BSplineWorkspace};
#[cfg(test)]
use crate::bspline::ExtrapolateMode;
use crate::error::InterpolateResult;
use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::numeric::{Float, FromPrimitive, Zero};
use std::fmt::{Debug, Display};
#[cfg(feature = "simd")]
use scirs2_core::simd_ops::SimdUnifiedOps;
/// Batch evaluator that owns a [`BSpline`] together with a [`BSplineWorkspace`].
///
/// `eval_batch` evaluates many points through `BSpline::evaluate_with_workspace`, handing the
/// same workspace to every call. `eval_deriv_batch` goes through `BSpline::derivative` instead.
///
/// NOTE(review): despite the `Simd` prefix, nothing in this type calls SIMD kernels directly;
/// any vectorisation would have to happen inside `BSpline` — confirm.
pub struct SimdBSplineEvaluator<T>
where
T: Float
+ FromPrimitive
+ Debug
+ Display
+ Zero
+ Copy
+ std::ops::AddAssign
+ std::ops::MulAssign
+ std::ops::DivAssign
+ std::ops::SubAssign
+ std::ops::RemAssign
+ 'static,
{
/// The spline being evaluated; exposed through `spline()` / `spline_mut()`.
spline: BSpline<T>,
/// Passed mutably to `evaluate_with_workspace` for every point in `eval_batch`;
/// presumably scratch storage reused across evaluations — confirm in `crate::bspline`.
workspace: BSplineWorkspace<T>,
}
impl<T> SimdBSplineEvaluator<T>
where
    T: Float
        + FromPrimitive
        + Debug
        + Display
        + Zero
        + Copy
        + std::ops::AddAssign
        + std::ops::MulAssign
        + std::ops::DivAssign
        + std::ops::SubAssign
        + std::ops::RemAssign
        + 'static,
{
    /// Wraps `spline` in an evaluator with a freshly created workspace.
    pub fn new(spline: BSpline<T>) -> Self {
        Self {
            spline,
            workspace: BSplineWorkspace::new(),
        }
    }

    /// Evaluates the wrapped spline at each entry of `points`, in order.
    ///
    /// Every evaluation is routed through `BSpline::evaluate_with_workspace` using this
    /// evaluator's workspace.
    ///
    /// # Errors
    ///
    /// Stops at the first point whose evaluation fails and returns that error; later points
    /// are not evaluated.
    pub fn eval_batch(&mut self, points: &[T]) -> InterpolateResult<Vec<T>> {
        let mut values = Vec::with_capacity(points.len());
        for &point in points {
            values.push(
                self.spline
                    .evaluate_with_workspace(point, &mut self.workspace)?,
            );
        }
        Ok(values)
    }

    /// For every entry of `points`, collects `BSpline::derivative(point, order)` for each
    /// `order` in `0..=nu`.
    ///
    /// The result has one inner vector per point, each of length `nu + 1`.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by `BSpline::derivative`; evaluation stops there.
    pub fn eval_deriv_batch(&mut self, points: &[T], nu: usize) -> InterpolateResult<Vec<Vec<T>>> {
        let mut table = Vec::with_capacity(points.len());
        for &point in points {
            let mut row = Vec::with_capacity(nu + 1);
            for order in 0..=nu {
                row.push(self.spline.derivative(point, order)?);
            }
            table.push(row);
        }
        Ok(table)
    }

    /// Borrows the wrapped spline.
    pub fn spline(&self) -> &BSpline<T> {
        &self.spline
    }

    /// Mutably borrows the wrapped spline.
    pub fn spline_mut(&mut self) -> &mut BSpline<T> {
        &mut self.spline
    }
}
/// Standalone cubic (degree 3) B-spline described by a knot vector and control coefficients.
///
/// Built through [`SimdCubicBSpline::new`], which requires
/// `knots.len() == coefficients.len() + 4`.
pub struct SimdCubicBSpline<T>
where
T: Float + FromPrimitive + Debug + Display + Zero + Copy + 'static,
{
/// Knot vector; expected to be non-decreasing, as B-spline knot vectors must be.
knots: Array1<T>,
/// Control coefficients, one per cubic basis function.
coefficients: Array1<T>,
}
/// Polynomial degree handled by [`SimdCubicBSpline`].
const CUBIC_DEGREE: usize = 3;

impl<T> SimdCubicBSpline<T>
where
    T: Float + FromPrimitive + Debug + Display + Zero + Copy + 'static,
{
    /// Creates a cubic B-spline from a knot vector and its control coefficients.
    ///
    /// # Errors
    ///
    /// Returns [`crate::error::InterpolateError::InvalidInput`] in three cases:
    /// - `knots.len() != coefficients.len() + 4`;
    /// - fewer than 4 coefficients are supplied, since each cubic piece combines 4 of them;
    /// - the knot vector contains NaN or decreases anywhere.
    pub fn new(knots: Array1<T>, coefficients: Array1<T>) -> InterpolateResult<Self> {
        if knots.len() != coefficients.len() + CUBIC_DEGREE + 1 {
            return Err(crate::error::InterpolateError::InvalidInput {
                message: "For cubic B-spline, knots.len() must equal coefficients.len() + 4"
                    .to_string(),
            });
        }
        // Every evaluation combines CUBIC_DEGREE + 1 consecutive coefficients. With fewer, the
        // span arithmetic in `eval` would underflow and panic.
        if coefficients.len() < CUBIC_DEGREE + 1 {
            return Err(crate::error::InterpolateError::InvalidInput {
                message: "Cubic B-spline requires at least 4 coefficients".to_string(),
            });
        }
        // The basis recurrence and the span search are only meaningful for a sorted knot vector.
        let has_nan = knots.iter().any(|k| k.is_nan());
        let decreases = knots
            .iter()
            .zip(knots.iter().skip(1))
            .any(|(&a, &b)| b < a);
        if has_nan || decreases {
            return Err(crate::error::InterpolateError::InvalidInput {
                message: "Knot vector must be non-decreasing and must not contain NaN"
                    .to_string(),
            });
        }
        Ok(Self {
            knots,
            coefficients,
        })
    }

    /// Evaluates the spline at `x`.
    ///
    /// Let `n` be the number of coefficients. Points at or below `knots[3]` use the first
    /// polynomial piece, and points at or above `knots[n]` use the last one. Values outside the
    /// base interval are therefore polynomial extrapolations of the end pieces.
    ///
    /// # Errors
    ///
    /// Never fails for a spline built with [`SimdCubicBSpline::new`]; the `Result` is kept for
    /// API compatibility.
    pub fn eval(&self, x: T) -> InterpolateResult<T> {
        let span = self.find_span(x);
        let basis = self.basis_functions(span, x);
        // The non-zero basis functions on `span` are N_{span-3}, ..., N_{span}.
        let first = span - CUBIC_DEGREE;
        let value = basis
            .iter()
            .zip(self.coefficients.iter().skip(first))
            .fold(T::zero(), |acc, (&b, &c)| acc + c * b);
        Ok(value)
    }

    /// Evaluates the spline at every entry of `points`, stopping at the first error.
    pub fn eval_batch(&self, points: &[T]) -> InterpolateResult<Vec<T>> {
        points.iter().map(|&x| self.eval(x)).collect()
    }

    /// Returns the knot-span index `k` in `[3, n - 1]`, where `n` is the number of coefficients.
    ///
    /// Inside the base interval the result satisfies `knots[k] <= x < knots[k + 1]`. Outside it,
    /// the first or last span is returned (The NURBS Book, algorithm A2.1).
    fn find_span(&self, x: T) -> usize {
        let last = self.coefficients.len() - 1;
        if x <= self.knots[CUBIC_DEGREE] {
            return CUBIC_DEGREE;
        }
        if x >= self.knots[last + 1] {
            return last;
        }
        // Bisection invariant: knots[low] <= x < knots[high].
        let (mut low, mut high) = (CUBIC_DEGREE, last + 1);
        let mut mid = (low + high) / 2;
        while x < self.knots[mid] || x >= self.knots[mid + 1] {
            if x < self.knots[mid] {
                high = mid;
            } else {
                low = mid;
            }
            mid = (low + high) / 2;
        }
        mid
    }

    /// Computes the four basis functions `N_{span-3,3}(x) ..= N_{span,3}(x)` that can be
    /// non-zero on `span`.
    ///
    /// Uses the triangular Cox–de Boor scheme (The NURBS Book, algorithm A2.2).
    fn basis_functions(&self, span: usize, x: T) -> [T; CUBIC_DEGREE + 1] {
        let mut basis = [T::zero(); CUBIC_DEGREE + 1];
        let mut left = [T::zero(); CUBIC_DEGREE + 1];
        let mut right = [T::zero(); CUBIC_DEGREE + 1];
        basis[0] = T::one();
        for j in 1..=CUBIC_DEGREE {
            left[j] = x - self.knots[span + 1 - j];
            right[j] = self.knots[span + j] - x;
            let mut saved = T::zero();
            for r in 0..j {
                // Knot-interval length, taken straight from the knots so that repeated knots
                // give an exact zero. That case is treated as 0/0 = 0 by convention.
                let denom = self.knots[span + r + 1] - self.knots[span + 1 + r - j];
                let temp = if denom == T::zero() {
                    T::zero()
                } else {
                    basis[r] / denom
                };
                basis[r] = saved + right[r + 1] * temp;
                saved = left[j - r] * temp;
            }
            basis[j] = saved;
        }
        basis
    }
}
/// Container pairing batch-evaluated values with optional derivative data.
///
/// NOTE(review): nothing in the visible source constructs this type. The layout of
/// `derivatives` presumably mirrors `SimdBSplineEvaluator::eval_deriv_batch` (one inner `Vec`
/// per point) — confirm with its producers.
#[derive(Debug, Clone)]
pub struct BatchEvalResult<T> {
/// Evaluated values, presumably one per input point.
pub values: Vec<T>,
/// Optional derivative data; presumably `None` when derivatives were not requested.
pub derivatives: Option<Vec<Vec<T>>>,
}
/// Namespace for small element-wise kernels, available with the `simd` feature.
pub struct SimdBSplineOps;
impl SimdBSplineOps {
    /// Element-wise squared difference `(points[i] - centers[i])^2`.
    ///
    /// When `T::simd_available()` reports support, this dispatches to `simd_sub` followed by
    /// `simd_mul`. Otherwise it falls back to a scalar loop over `points`.
    ///
    /// # Panics
    ///
    /// The scalar path panics if `centers` is shorter than `points`.
    /// NOTE(review): the SIMD path's behaviour on mismatched lengths depends on
    /// `SimdUnifiedOps` — confirm.
    #[cfg(feature = "simd")]
    pub fn squared_distances<T>(points: &ArrayView1<T>, centers: &ArrayView1<T>) -> Array1<T>
    where
        T: Float + SimdUnifiedOps,
    {
        if !T::simd_available() {
            let mut squared: Array1<T> = Array1::zeros(points.len());
            for (i, slot) in squared.iter_mut().enumerate() {
                let delta = points[i] - centers[i];
                *slot = delta * delta;
            }
            return squared;
        }
        let delta = T::simd_sub(points, centers);
        let delta_view = delta.view();
        T::simd_mul(&delta_view, &delta_view)
    }

    /// Dot product `sum_i values[i] * weights[i]`.
    ///
    /// Pairs are taken up to the shorter input's length. Despite the `SimdUnifiedOps` bound,
    /// this is a plain scalar fold.
    #[cfg(feature = "simd")]
    pub fn weighted_sum<T>(values: &ArrayView1<T>, weights: &ArrayView1<T>) -> T
    where
        T: Float + SimdUnifiedOps,
    {
        values
            .iter()
            .zip(weights.iter())
            .fold(T::zero(), |total, (&v, &w)| total + v * w)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_simd_cubic_bspline_eval() {
        let knots = array![0.0, 0.0, 0.0, 0.0, 0.5, 1.0, 1.0, 1.0, 1.0];
        let coefficients = array![1.0, 2.0, 3.0, 2.0, 1.0];
        let spline = SimdCubicBSpline::new(knots, coefficients).expect("Operation failed");
        let result = spline.eval(0.25).expect("Operation failed");
        assert!(result.is_finite());
        let result = spline.eval(0.75).expect("Operation failed");
        assert!(result.is_finite());
    }

    #[test]
    fn test_simd_cubic_bspline_rejects_mismatched_lengths() {
        // A cubic spline with 4 coefficients needs exactly 8 knots; 7 must be rejected.
        let knots = array![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let coefficients = array![1.0, 2.0, 3.0, 4.0];
        assert!(SimdCubicBSpline::new(knots, coefficients).is_err());
    }

    #[test]
    fn test_simd_bspline_batch_eval() {
        let knots = array![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0];
        let coefficients = array![1.0, 2.0, 3.0, 4.0];
        let spline = BSpline::new(
            &knots.view(),
            &coefficients.view(),
            3,
            ExtrapolateMode::Extrapolate,
        )
        .expect("Operation failed");
        let mut evaluator = SimdBSplineEvaluator::new(spline);
        let points = vec![0.0, 0.25, 0.5, 0.75, 1.0];
        let results = evaluator.eval_batch(&points).expect("Operation failed");
        assert_eq!(results.len(), points.len());
        assert_relative_eq!(results[0], 1.0, epsilon = 1e-10);
        assert_relative_eq!(results[4], 4.0, epsilon = 1e-10);
    }

    #[cfg(feature = "simd")]
    #[test]
    fn test_simd_ops_squared_distances() {
        let points = array![1.0, 2.0, 3.0, 4.0];
        let centers = array![0.5, 1.5, 2.5, 3.5];
        // `&centers` had been mangled into an HTML-entity character, which did not compile.
        let distances = SimdBSplineOps::squared_distances(&points.view(), &centers.view());
        assert_eq!(distances.len(), 4);
        for i in 0..4 {
            assert_relative_eq!(distances[i], 0.25, epsilon = 1e-10);
        }
    }

    #[cfg(feature = "simd")]
    #[test]
    fn test_simd_ops_weighted_sum() {
        let values = array![1.0, 2.0, 3.0, 4.0];
        let weights = array![0.1, 0.2, 0.3, 0.4];
        let result = SimdBSplineOps::weighted_sum(&values.view(), &weights.view());
        assert_relative_eq!(result, 3.0, epsilon = 1e-10);
    }
}