use scirs2_core::ndarray::{Array, Array1, Axis, Dimension};
use scirs2_core::numeric::{Float, FromPrimitive, One, Zero};
use std::fmt::Debug;
use crate::error::{NdimageError, NdimageResult};
/// Returns the poles of the recursive prefilter that turns samples into
/// B-spline coefficients of degree `order`.
///
/// Each pole is a root, inside the unit circle, of the characteristic
/// polynomial of the integer-sampled B-spline of that degree (for example
/// `z^2 + 4z + 1` for the cubic case). All of them are negative reals.
/// They are computed from exact closed forms rather than truncated
/// literals so every precision of `T` gets a fully accurate value.
///
/// Orders 0 and 1 need no prefiltering and yield an empty vector. Unsupported
/// orders (> 5) also yield an empty vector, so callers must validate the
/// order beforehand.
#[allow(dead_code)]
fn get_spline_poles<T: Float + FromPrimitive>(order: usize) -> Vec<T> {
    let c = |v: f64| T::from_f64(v).expect("Operation failed");
    match order {
        // Root of z^2 + 6z + 1.
        2 => vec![c(8.0).sqrt() - c(3.0)],
        // Root of z^2 + 4z + 1.
        3 => vec![c(3.0).sqrt() - c(2.0)],
        // Roots of z^4 + 76z^3 + 230z^2 + 76z + 1 (~ -0.3613, -0.0137).
        4 => {
            let inner = c(438976.0).sqrt();
            let s = c(304.0).sqrt();
            vec![
                (c(664.0) - inner).sqrt() + s - c(19.0),
                (c(664.0) + inner).sqrt() - s - c(19.0),
            ]
        }
        // Roots of z^4 + 26z^3 + 66z^2 + 26z + 1 (~ -0.4306, -0.0431).
        5 => {
            let inner = c(4436.25).sqrt();
            let s = c(26.25).sqrt();
            vec![
                (c(67.5) - inner).sqrt() + s - c(6.5),
                (c(67.5) + inner).sqrt() - s - c(6.5),
            ]
        }
        _ => Vec::new(),
    }
}
/// Returns the value that seeds the causal (forward) recursion for one pole
/// under mirror-symmetric boundary conditions, i.e. `sum_k pole^k * x[k]`
/// over the mirror-extended signal.
///
/// * Fast path: when `|pole|^k` drops below `tolerance` before the end of
///   `coeffs`, the remaining terms are negligible. The plain truncated sum is
///   returned.
/// * Exact path: for signals shorter than that horizon, the infinite mirror
///   sum is evaluated in closed form. One period `x[0..N-1], x[N-2..1]` is
///   summed and divided by `1 - pole^(2N-2)`. Truncating here would drop
///   non-negligible mirrored terms.
///
/// An empty slice yields zero and a single sample is returned unchanged.
#[allow(dead_code)]
fn get_initial_causal_coefficient<T: Float + FromPrimitive>(
    coeffs: &[T],
    pole: T,
    tolerance: T,
) -> T {
    let n = coeffs.len();
    match n {
        0 => return T::zero(),
        1 => return coeffs[0],
        _ => {}
    }

    // Fast path: truncated geometric sum once the pole powers become negligible.
    let mut partial = T::zero();
    let mut z_power = T::one();
    for &coeff in coeffs {
        partial = partial + coeff * z_power;
        z_power = z_power * pole;
        if z_power.abs() < tolerance {
            return partial;
        }
    }

    // Exact path: closed-form sum over the mirror-extended signal.
    let last = n - 1;
    let mut z_last = T::one(); // pole^(n-1)
    for _ in 0..last {
        z_last = z_last * pole;
    }
    let mut sum = coeffs[0] + z_last * coeffs[last];
    let mut z_i = pole;
    for i in 1..last {
        // x[i] appears at offset i and, mirrored, at offset 2(n-1) - i.
        sum = sum + z_i * (coeffs[i] + z_last * coeffs[last - i]);
        z_i = z_i * pole;
    }
    sum / (T::one() - z_last * z_last)
}
/// Returns the value that seeds the anti-causal (backward) recursion for one
/// pole under mirror-symmetric boundary conditions.
///
/// `coeffs` is expected to already hold the output of the causal pass. The
/// seed is `z / (z^2 - 1) * (c[N-1] + z * c[N-2])`. The pole multiplies the
/// *second-to-last* sample, not the last one.
///
/// With fewer than two samples there is no mirror neighbour. An empty slice
/// yields zero. A single sample is returned unchanged, so a length-1 line
/// passes through the backward pass instead of being zeroed.
#[allow(dead_code)]
fn get_initial_anti_causal_coefficient<T: Float + FromPrimitive>(coeffs: &[T], pole: T) -> T {
    match coeffs {
        [] => T::zero(),
        [only] => *only,
        [.., before_last, last] => {
            (pole / (pole * pole - T::one())) * (*last + pole * *before_last)
        }
    }
}
/// Runs the first-order causal (forward) recursion in place.
///
/// The first element is overwritten with `initialcoeff`. Each later element
/// then becomes `c[i] + pole * c[i - 1]`, using the already-updated
/// predecessor. An empty slice is left untouched.
#[allow(dead_code)]
fn apply_causal_filter<T: Float + FromPrimitive>(coeffs: &mut [T], pole: T, initialcoeff: T) {
    if let Some((head, tail)) = coeffs.split_first_mut() {
        *head = initialcoeff;
        let mut previous = initialcoeff;
        for value in tail.iter_mut() {
            *value = *value + pole * previous;
            previous = *value;
        }
    }
}
/// Runs the first-order anti-causal (backward) recursion in place.
///
/// The last element is overwritten with `initialcoeff`. Walking towards the
/// front, each element then becomes `pole * (c[i + 1] - c[i])`, using the
/// already-updated successor. An empty slice is left untouched.
#[allow(dead_code)]
fn apply_anti_causal_filter<T: Float + FromPrimitive>(coeffs: &mut [T], pole: T, initialcoeff: T) {
    if let Some((tail, front)) = coeffs.split_last_mut() {
        *tail = initialcoeff;
        let mut following = initialcoeff;
        for value in front.iter_mut().rev() {
            *value = pole * (following - *value);
            following = *value;
        }
    }
}
/// Applies the B-spline prefilter along every axis of `input`.
///
/// `order` selects the spline order in `1..=5` and defaults to 3. Order 1
/// needs no prefiltering and yields an unmodified copy. Higher orders filter
/// a copy of `input` axis by axis through `spline_filter_axis`, starting at
/// axis 0.
///
/// # Errors
/// Returns [`NdimageError::InvalidInput`] for a 0-dimensional `input` or an
/// order outside `1..=5`. Errors from the per-axis filter are propagated.
#[allow(dead_code)]
pub fn spline_filter<T, D>(input: &Array<T, D>, order: Option<usize>) -> NdimageResult<Array<T, D>>
where
    T: Float + FromPrimitive + Debug + std::ops::AddAssign + std::ops::DivAssign + 'static,
    D: Dimension + scirs2_core::ndarray::RemoveAxis + 'static,
    usize: scirs2_core::ndarray::NdIndex<<D as scirs2_core::ndarray::Dimension>::Smaller>,
{
    let ndim = input.ndim();
    if ndim == 0 {
        return Err(NdimageError::InvalidInput(
            "Input array cannot be 0-dimensional".into(),
        ));
    }

    let spline_order = order.unwrap_or(3);
    if !(1..=5).contains(&spline_order) {
        return Err(NdimageError::InvalidInput(format!(
            "Spline order must be between 1 and 5, got {}",
            spline_order
        )));
    }

    let mut filtered = input.to_owned();
    if spline_order > 1 {
        (0..ndim).try_for_each(|ax| spline_filter_axis(&mut filtered, spline_order, ax))?;
    }
    Ok(filtered)
}
/// Applies the B-spline prefilter along a single axis of `input`.
///
/// * `order`: spline order in `1..=5`, default 3.
/// * `axis`: axis to filter, default 0.
///
/// Order 1 returns an unmodified copy of `input`. Higher orders filter a copy
/// in place along `axis` via `spline_filter_axis`.
///
/// # Errors
/// Returns [`NdimageError::InvalidInput`] when `input` is 0-dimensional, when
/// `order` is outside `1..=5`, or when `axis >= input.ndim()`, checked in
/// that order. Errors from the per-axis filter are propagated.
#[allow(dead_code)]
pub fn spline_filter1d<T, D>(
    input: &Array<T, D>,
    order: Option<usize>,
    axis: Option<usize>,
) -> NdimageResult<Array<T, D>>
where
    T: Float + FromPrimitive + Debug + std::ops::AddAssign + std::ops::DivAssign + 'static,
    D: Dimension + scirs2_core::ndarray::RemoveAxis + 'static,
    usize: scirs2_core::ndarray::NdIndex<<D as scirs2_core::ndarray::Dimension>::Smaller>,
{
    let ndim = input.ndim();
    if ndim == 0 {
        return Err(NdimageError::InvalidInput(
            "Input array cannot be 0-dimensional".into(),
        ));
    }

    let spline_order = order.unwrap_or(3);
    if !(1..=5).contains(&spline_order) {
        return Err(NdimageError::InvalidInput(format!(
            "Spline order must be between 1 and 5, got {}",
            spline_order
        )));
    }

    let axis_val = axis.unwrap_or(0);
    if axis_val >= ndim {
        return Err(NdimageError::InvalidInput(format!(
            "Axis {} is out of bounds for array of dimension {}",
            axis_val, ndim
        )));
    }

    let mut output = input.to_owned();
    if spline_order > 1 {
        spline_filter_axis(&mut output, spline_order, axis_val)?;
    }
    Ok(output)
}
#[allow(dead_code)]
pub fn bspline<T>(
positions: &Array<T, scirs2_core::ndarray::Ix1>,
order: Option<usize>,
derivative: Option<usize>,
) -> NdimageResult<Array<T, scirs2_core::ndarray::Ix1>>
where
T: Float + FromPrimitive + Debug,
{
let spline_order = order.unwrap_or(3);
let deriv = derivative.unwrap_or(0);
if spline_order == 0 || spline_order > 5 {
return Err(NdimageError::InvalidInput(format!(
"Spline order must be between 1 and 5, got {}",
spline_order
)));
}
if deriv > spline_order {
return Err(NdimageError::InvalidInput(format!(
"Derivative order must be less than or equal to spline order (got {} for order {})",
deriv, spline_order
)));
}
let mut result = Array1::<T>::zeros(positions.len());
for (i, &pos) in positions.iter().enumerate() {
result[i] = evaluate_bspline_basis(pos, spline_order, deriv);
}
Ok(result)
}
#[allow(dead_code)]
fn spline_filter_axis<T, D>(data: &mut Array<T, D>, order: usize, axis: usize) -> NdimageResult<()>
where
T: Float + FromPrimitive + Clone,
D: Dimension + scirs2_core::ndarray::RemoveAxis,
usize: scirs2_core::ndarray::NdIndex<<D as scirs2_core::ndarray::Dimension>::Smaller>,
{
let poles = get_spline_poles::<T>(order);
if poles.is_empty() {
return Ok(());
}
let tolerance = T::from_f64(1e-10).expect("Operation failed");
let axis_len = data.shape()[axis];
for mut lane in data.axis_iter_mut(Axis(axis)) {
let mut coeffs: Vec<T> = lane.iter().cloned().collect();
for &pole in &poles {
let initial_causal = get_initial_causal_coefficient(&coeffs, pole, tolerance);
apply_causal_filter(&mut coeffs, pole, initial_causal);
let initial_anti_causal = get_initial_anti_causal_coefficient(&coeffs, pole);
apply_anti_causal_filter(&mut coeffs, pole, initial_anti_causal);
}
for (i, &coeff) in coeffs.iter().enumerate() {
lane[i] = coeff;
}
}
Ok(())
}
/// Evaluates the centred B-spline basis of degree `order`, or its
/// `derivative`-th derivative, at `x`.
///
/// Uses the truncated-power representation, with `n = order` and
/// `d = derivative`:
///
/// `beta_n^(d)(x) = 1/(n-d)! * sum_{k=0}^{n+1} (-1)^k C(n+1, k) (x + (n+1)/2 - k)_+^(n-d)`
///
/// This works for any order, including 4 and 5 and all derivatives up to
/// `order`.
///
/// * The basis is supported on the open interval `(-(n+1)/2, (n+1)/2)`.
///   Outside it, and for NaN input, zero is returned.
/// * When `derivative == order` the result is piecewise constant. At a knot
///   the mean of the two one-sided limits is returned.
/// * `derivative > order` yields zero.
/// * Order 0 is the centred box on `(-1/2, 1/2)`. `bspline` rejects order 0,
///   so this case is only reachable by direct calls.
#[allow(dead_code)]
fn evaluate_bspline_basis<T: Float + FromPrimitive>(x: T, order: usize, derivative: usize) -> T {
    if derivative > order {
        return T::zero();
    }
    let to_t = |v: usize| T::from_usize(v).expect("Operation failed");
    let half = T::from_f64(0.5).expect("Operation failed");
    let knots = order + 1;
    let half_support = to_t(knots) * half;
    let abs_x = x.abs();

    if abs_x < half_support {
        let degree = order - derivative;
        // Evaluate at -|x| so that only the few truncated powers to the left of
        // the argument contribute. This limits cancellation near the support edge.
        let shift = half_support - abs_x;
        let mut sum = T::zero();
        let mut binom = T::one(); // C(order + 1, k)
        for k in 0..=knots {
            let y = shift - to_t(k);
            if y < T::zero() {
                // y decreases with k, so every remaining truncated power is zero.
                break;
            }
            let power = if degree == 0 {
                // Step function; half weight exactly at the jump.
                if y > T::zero() {
                    T::one()
                } else {
                    half
                }
            } else {
                (0..degree).fold(T::one(), |acc, _| acc * y)
            };
            let term = binom * power;
            sum = if k % 2 == 0 { sum + term } else { sum - term };
            binom = binom * to_t(knots - k) / to_t(k + 1);
        }
        let factorial = (2..=degree).fold(T::one(), |acc, m| acc * to_t(m));
        let value = sum / factorial;
        // The basis is even, so its d-th derivative has parity (-1)^d.
        if derivative % 2 == 1 && x > T::zero() {
            -value
        } else {
            value
        }
    } else {
        T::zero()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    #[test]
    fn test_spline_filter() {
        let input: Array2<f64> = Array2::eye(3);
        let result = spline_filter(&input, None).expect("Operation failed");
        assert_eq!(result.shape(), input.shape());
    }

    #[test]
    fn test_spline_filter1d() {
        let input: Array2<f64> = Array2::eye(3);
        let result = spline_filter1d(&input, None, None).expect("Operation failed");
        assert_eq!(result.shape(), input.shape());
    }

    #[test]
    fn test_bspline() {
        let positions = Array1::linspace(0.0, 2.0, 5);
        let result = bspline(&positions, None, None).expect("Operation failed");
        assert_eq!(result.len(), positions.len());
    }

    /// Orders outside 1..=5 and derivatives above the order are rejected.
    #[test]
    fn test_invalid_parameters_are_rejected() {
        let input: Array2<f64> = Array2::eye(3);
        assert!(spline_filter(&input, Some(0)).is_err());
        assert!(spline_filter(&input, Some(6)).is_err());
        assert!(spline_filter1d(&input, Some(6), None).is_err());

        let positions = Array1::linspace(-1.0, 1.0, 3);
        assert!(bspline(&positions, Some(0), None).is_err());
        assert!(bspline(&positions, Some(2), Some(3)).is_err());
    }

    /// An axis index equal to the number of dimensions is out of bounds.
    #[test]
    fn test_spline_filter1d_rejects_out_of_range_axis() {
        let input: Array2<f64> = Array2::eye(3);
        assert!(spline_filter1d(&input, None, Some(2)).is_err());
    }

    /// Linear splines need no prefiltering, so the input comes back unchanged.
    #[test]
    fn test_linear_order_is_identity() {
        let input = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("Operation failed");
        assert_eq!(
            spline_filter(&input, Some(1)).expect("Operation failed"),
            input
        );
        assert_eq!(
            spline_filter1d(&input, Some(1), Some(1)).expect("Operation failed"),
            input
        );
    }

    /// Cubic B-spline values at integer offsets: 2/3 at 0, 1/6 at +-1, 0 at 2.
    #[test]
    fn test_cubic_bspline_values() {
        let positions = Array1::from_vec(vec![0.0, 1.0, -1.0, 2.0]);
        let result = bspline(&positions, Some(3), None).expect("Operation failed");
        let expected = [2.0 / 3.0, 1.0 / 6.0, 1.0 / 6.0, 0.0];
        for (got, want) in result.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-12, "got {}, want {}", got, want);
        }
    }
}