use crate::error::{FFTError, FFTResult};
use scirs2_core::ndarray::{Array, Array2, ArrayView, ArrayView2, Axis, IxDyn};
use scirs2_core::numeric::NumCast;
use std::f64::consts::PI;
use std::fmt::Debug;
use scirs2_core::simd_ops::{
simd_add_f32_ultra_vec, simd_cos_f32_ultra_vec, simd_div_f32_ultra_vec, simd_exp_f32_ultra_vec,
simd_fma_f32_ultra_vec, simd_mul_f32_ultra_vec, simd_pow_f32_ultra_vec, simd_sin_f32_ultra_vec,
simd_sub_f32_ultra_vec, PlatformCapabilities, SimdUnifiedOps,
};
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum DSTType {
Type1,
Type2,
Type3,
Type4,
}
#[allow(dead_code)]
pub fn dst<T>(x: &[T], dsttype: Option<DSTType>, norm: Option<&str>) -> FFTResult<Vec<f64>>
where
T: NumCast + Copy + Debug,
{
let input: Vec<f64> = x
.iter()
.map(|&val| {
NumCast::from(val)
.ok_or_else(|| FFTError::ValueError(format!("Could not convert {val:?} to f64")))
})
.collect::<FFTResult<Vec<_>>>()?;
let _n = input.len();
let type_val = dsttype.unwrap_or(DSTType::Type2);
match type_val {
DSTType::Type1 => dst1(&input, norm),
DSTType::Type2 => dst2_impl(&input, norm),
DSTType::Type3 => dst3(&input, norm),
DSTType::Type4 => dst4(&input, norm),
}
}
#[allow(dead_code)]
pub fn idst<T>(x: &[T], dsttype: Option<DSTType>, norm: Option<&str>) -> FFTResult<Vec<f64>>
where
T: NumCast + Copy + Debug,
{
let input: Vec<f64> = x
.iter()
.map(|&val| {
NumCast::from(val)
.ok_or_else(|| FFTError::ValueError(format!("Could not convert {val:?} to f64")))
})
.collect::<FFTResult<Vec<_>>>()?;
let _n = input.len();
let type_val = dsttype.unwrap_or(DSTType::Type2);
match type_val {
DSTType::Type1 => idst1(&input, norm),
DSTType::Type2 => idst2_impl(&input, norm),
DSTType::Type3 => idst3(&input, norm),
DSTType::Type4 => idst4(&input, norm),
}
}
#[allow(dead_code)]
pub fn dst2<T>(
x: &ArrayView2<T>,
dst_type: Option<DSTType>,
norm: Option<&str>,
) -> FFTResult<Array2<f64>>
where
T: NumCast + Copy + Debug,
{
let (n_rows, n_cols) = x.dim();
let type_val = dst_type.unwrap_or(DSTType::Type2);
let mut result = Array2::zeros((n_rows, n_cols));
for r in 0..n_rows {
let row_slice = x.slice(scirs2_core::ndarray::s![r, ..]);
let row_vec: Vec<T> = row_slice.iter().cloned().collect();
let row_dst = dst(&row_vec, Some(type_val), norm)?;
for (c, val) in row_dst.iter().enumerate() {
result[[r, c]] = *val;
}
}
let mut final_result = Array2::zeros((n_rows, n_cols));
for c in 0..n_cols {
let col_slice = result.slice(scirs2_core::ndarray::s![.., c]);
let col_vec: Vec<f64> = col_slice.iter().cloned().collect();
let col_dst = dst(&col_vec, Some(type_val), norm)?;
for (r, val) in col_dst.iter().enumerate() {
final_result[[r, c]] = *val;
}
}
Ok(final_result)
}
#[allow(dead_code)]
pub fn idst2<T>(
x: &ArrayView2<T>,
dst_type: Option<DSTType>,
norm: Option<&str>,
) -> FFTResult<Array2<f64>>
where
T: NumCast + Copy + Debug,
{
let (n_rows, n_cols) = x.dim();
let type_val = dst_type.unwrap_or(DSTType::Type2);
let mut result = Array2::zeros((n_rows, n_cols));
for r in 0..n_rows {
let row_slice = x.slice(scirs2_core::ndarray::s![r, ..]);
let row_vec: Vec<T> = row_slice.iter().cloned().collect();
let row_idst = idst(&row_vec, Some(type_val), norm)?;
for (c, val) in row_idst.iter().enumerate() {
result[[r, c]] = *val;
}
}
let mut final_result = Array2::zeros((n_rows, n_cols));
for c in 0..n_cols {
let col_slice = result.slice(scirs2_core::ndarray::s![.., c]);
let col_vec: Vec<f64> = col_slice.iter().cloned().collect();
let col_idst = idst(&col_vec, Some(type_val), norm)?;
for (r, val) in col_idst.iter().enumerate() {
final_result[[r, c]] = *val;
}
}
Ok(final_result)
}
#[allow(dead_code)]
pub fn dstn<T>(
x: &ArrayView<T, IxDyn>,
dst_type: Option<DSTType>,
norm: Option<&str>,
axes: Option<Vec<usize>>,
) -> FFTResult<Array<f64, IxDyn>>
where
T: NumCast + Copy + Debug,
{
let xshape = x.shape().to_vec();
let n_dims = xshape.len();
let axes_to_transform = match axes {
Some(ax) => ax,
None => (0..n_dims).collect(),
};
let mut conversion_error: Option<FFTError> = None;
let result_init = Array::from_shape_fn(IxDyn(&xshape), |idx| {
let val = x[idx];
match NumCast::from(val) {
Some(v) => v,
None => {
if conversion_error.is_none() {
conversion_error = Some(FFTError::ValueError(
"Could not convert input value to f64".to_string(),
));
}
0.0
}
}
});
if let Some(err) = conversion_error {
return Err(err);
}
let mut result = result_init;
let type_val = dst_type.unwrap_or(DSTType::Type2);
for &axis in &axes_to_transform {
let mut temp = result.clone();
for mut slice in temp.lanes_mut(Axis(axis)).into_iter() {
let slice_data: Vec<f64> = slice.iter().cloned().collect();
let transformed = dst(&slice_data, Some(type_val), norm)?;
for (j, val) in transformed.into_iter().enumerate() {
if j < slice.len() {
slice[j] = val;
}
}
}
result = temp;
}
Ok(result)
}
#[allow(dead_code)]
pub fn idstn<T>(
x: &ArrayView<T, IxDyn>,
dst_type: Option<DSTType>,
norm: Option<&str>,
axes: Option<Vec<usize>>,
) -> FFTResult<Array<f64, IxDyn>>
where
T: NumCast + Copy + Debug,
{
let xshape = x.shape().to_vec();
let n_dims = xshape.len();
let axes_to_transform = match axes {
Some(ax) => ax,
None => (0..n_dims).collect(),
};
let mut conversion_error: Option<FFTError> = None;
let result_init = Array::from_shape_fn(IxDyn(&xshape), |idx| {
let val = x[idx];
match NumCast::from(val) {
Some(v) => v,
None => {
if conversion_error.is_none() {
conversion_error = Some(FFTError::ValueError(
"Could not convert input value to f64".to_string(),
));
}
0.0
}
}
});
if let Some(err) = conversion_error {
return Err(err);
}
let mut result = result_init;
let type_val = dst_type.unwrap_or(DSTType::Type2);
for &axis in &axes_to_transform {
let mut temp = result.clone();
for mut slice in temp.lanes_mut(Axis(axis)).into_iter() {
let slice_data: Vec<f64> = slice.iter().cloned().collect();
let transformed = idst(&slice_data, Some(type_val), norm)?;
for (j, val) in transformed.into_iter().enumerate() {
if j < slice.len() {
slice[j] = val;
}
}
}
result = temp;
}
Ok(result)
}
#[allow(dead_code)]
fn dst1(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
let n = x.len();
if n < 2 {
return Err(FFTError::ValueError(
"Input array must have at least 2 elements for DST-I".to_string(),
));
}
let mut result = Vec::with_capacity(n);
for k in 0..n {
let mut sum = 0.0;
let k_f = (k + 1) as f64;
for (m, val) in x.iter().enumerate() {
let m_f = (m + 1) as f64;
let angle = PI * k_f * m_f / (n as f64 + 1.0);
sum += val * angle.sin();
}
result.push(sum);
}
if norm == Some("ortho") {
let norm_factor = (2.0 / (n as f64 + 1.0)).sqrt();
for val in result.iter_mut() {
*val *= norm_factor;
}
}
Ok(result)
}
#[allow(dead_code)]
fn idst1(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
let n = x.len();
if n < 2 {
return Err(FFTError::ValueError(
"Input array must have at least 2 elements for IDST-I".to_string(),
));
}
if norm == Some("ortho") {
return dst1(x, Some("ortho"));
}
let forward = dst1(x, None)?;
let scale = 2.0 / (n as f64 + 1.0);
let result: Vec<f64> = forward.iter().map(|&v| v * scale).collect();
Ok(result)
}
#[allow(dead_code)]
fn dst2_impl(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
let n = x.len();
if n == 0 {
return Err(FFTError::ValueError(
"Input array cannot be empty".to_string(),
));
}
let mut result = Vec::with_capacity(n);
for k in 0..n {
let mut sum = 0.0;
let k_f = (k + 1) as f64;
for (m, val) in x.iter().enumerate().take(n) {
let m_f = m as f64;
let angle = PI * k_f * (m_f + 0.5) / n as f64;
sum += val * angle.sin();
}
result.push(sum);
}
if norm == Some("ortho") {
let norm_factor = (2.0 / n as f64).sqrt();
let last_factor = 1.0 / 2.0_f64.sqrt();
for val in result.iter_mut().take(n - 1) {
*val *= norm_factor;
}
result[n - 1] *= norm_factor * last_factor;
}
Ok(result)
}
#[allow(dead_code)]
fn idst2_impl(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
let n = x.len();
if n == 0 {
return Err(FFTError::ValueError(
"Input array cannot be empty".to_string(),
));
}
if n == 1 {
if norm == Some("ortho") {
return Ok(vec![x[0] / (2.0_f64).sqrt()]);
}
return Ok(vec![x[0]]);
}
let mut input = x.to_vec();
if norm == Some("ortho") {
let inv_norm = (n as f64 / 2.0).sqrt();
let last_inv = 2.0_f64.sqrt();
for val in input.iter_mut().take(n - 1) {
*val *= inv_norm;
}
input[n - 1] *= inv_norm * last_inv;
}
let mut result = Vec::with_capacity(n);
let n_f = n as f64;
for m in 0..n {
let m_f = m as f64;
let sign = if m % 2 == 0 { 1.0 } else { -1.0 };
let mut sum = sign * input[n - 1] / 2.0;
for k in 0..(n - 1) {
let k_f = k as f64;
let angle = PI * (k_f + 1.0) * (m_f + 0.5) / n_f;
sum += input[k] * angle.sin();
}
result.push(sum * 2.0 / n_f);
}
Ok(result)
}
#[allow(dead_code)]
fn dst3(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
let n = x.len();
if n == 0 {
return Err(FFTError::ValueError(
"Input array cannot be empty".to_string(),
));
}
let mut result = Vec::with_capacity(n);
for k in 0..n {
let mut sum = 0.0;
let k_f = k as f64;
sum += x[n - 1] * (if k % 2 == 0 { 1.0 } else { -1.0 });
for (m, val) in x.iter().enumerate().take(n - 1) {
let m_f = (m + 1) as f64;
let angle = PI * m_f * (k_f + 0.5) / n as f64;
sum += 2.0 * val * angle.sin();
}
result.push(sum);
}
if norm == Some("ortho") {
let norm_factor = (2.0 / n as f64).sqrt();
for val in result.iter_mut() {
*val *= norm_factor;
}
}
Ok(result)
}
#[allow(dead_code)]
fn idst3(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
let n = x.len();
if n == 0 {
return Err(FFTError::ValueError(
"Input array cannot be empty".to_string(),
));
}
let mut input = x.to_vec();
if norm == Some("ortho") {
let inv_norm = (n as f64 / 2.0).sqrt();
for val in input.iter_mut() {
*val *= inv_norm;
}
}
let dst2_result = dst2_impl(&input, None)?;
let scale = 1.0 / n as f64;
let result: Vec<f64> = dst2_result.iter().map(|&v| v * scale).collect();
Ok(result)
}
#[allow(dead_code)]
fn dst4(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
let n = x.len();
if n == 0 {
return Err(FFTError::ValueError(
"Input array cannot be empty".to_string(),
));
}
let mut result = Vec::with_capacity(n);
for k in 0..n {
let mut sum = 0.0;
let k_f = k as f64;
for (m, val) in x.iter().enumerate().take(n) {
let m_f = m as f64;
let angle = PI * (m_f + 0.5) * (k_f + 0.5) / n as f64;
sum += val * angle.sin();
}
result.push(sum);
}
if let Some("ortho") = norm {
let norm_factor = (2.0 / n as f64).sqrt();
for val in result.iter_mut().take(n) {
*val *= norm_factor;
}
} else {
for val in result.iter_mut().take(n) {
*val *= 2.0;
}
}
Ok(result)
}
#[allow(dead_code)]
fn idst4(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
let n = x.len();
if n == 0 {
return Err(FFTError::ValueError(
"Input array cannot be empty".to_string(),
));
}
if norm == Some("ortho") {
return dst4(x, Some("ortho"));
}
let forward = dst4(x, None)?;
let scale = 1.0 / (2.0 * n as f64);
let result: Vec<f64> = forward.iter().map(|&v| v * scale).collect();
Ok(result)
}
#[allow(dead_code)]
pub fn dst2_fft(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
use scirs2_core::numeric::Complex64;
let n = x.len();
if n == 0 {
return Err(FFTError::ValueError(
"Input array cannot be empty".to_string(),
));
}
if n == 1 {
let result = if norm == Some("ortho") {
vec![x[0] * (2.0_f64).sqrt()]
} else {
vec![x[0]]
};
return Ok(result);
}
let mut extended = vec![0.0; 4 * n];
for i in 0..n {
extended[2 * i + 1] = x[i];
extended[4 * n - 2 * i - 1] = -x[i];
}
let ext_complex: Vec<Complex64> = extended.iter().map(|&v| Complex64::new(v, 0.0)).collect();
let ext_fft = crate::fft::fft(&ext_complex, None)?;
let mut result = Vec::with_capacity(n);
for k in 0..n {
let val = -ext_fft[k + 1].im / 2.0;
result.push(val);
}
if norm == Some("ortho") {
let norm_factor = (2.0 / n as f64).sqrt();
let last_factor = 1.0 / 2.0_f64.sqrt();
for val in result.iter_mut().take(n - 1) {
*val *= norm_factor;
}
result[n - 1] *= norm_factor * last_factor;
}
Ok(result)
}
#[allow(dead_code)]
pub fn idst2_fft(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
let n = x.len();
if n == 0 {
return Err(FFTError::ValueError(
"Input array cannot be empty".to_string(),
));
}
if n == 1 {
let result = if norm == Some("ortho") {
vec![x[0] / (2.0_f64).sqrt()]
} else {
vec![x[0] * 2.0]
};
return Ok(result);
}
let mut input = x.to_vec();
if norm == Some("ortho") {
let norm_factor = (n as f64 / 2.0).sqrt();
let last_factor = 2.0_f64.sqrt();
for val in input.iter_mut().take(n - 1) {
*val *= norm_factor;
}
input[n - 1] *= norm_factor * last_factor;
}
let mut result = Vec::with_capacity(n);
for i in 0..n {
let i_f = i as f64;
let mut sum = input[n - 1] * if i % 2 == 0 { 0.5 } else { -0.5 };
for (k, &val) in input.iter().enumerate().take(n - 1) {
let k_f = (k + 1) as f64;
let angle = PI * k_f * (i_f + 0.5) / n as f64;
sum += val * angle.sin();
}
sum *= 2.0 / n as f64;
result.push(sum);
}
Ok(result)
}
#[allow(dead_code)]
pub fn dst_bandwidth_saturated_simd<T>(
x: &[T],
dsttype: Option<DSTType>,
norm: Option<&str>,
) -> FFTResult<Vec<f64>>
where
T: NumCast + Copy + Debug,
{
use scirs2_core::simd_ops::{PlatformCapabilities, SimdUnifiedOps};
let input: Vec<f64> = x
.iter()
.map(|&val| {
NumCast::from(val)
.ok_or_else(|| FFTError::ValueError(format!("Could not convert {val:?} to f64")))
})
.collect::<FFTResult<Vec<_>>>()?;
let n = input.len();
let type_val = dsttype.unwrap_or(DSTType::Type2);
let caps = PlatformCapabilities::detect();
if n >= 128 && (caps.has_avx2() || caps.has_avx512()) {
match type_val {
DSTType::Type1 => dst1_bandwidth_saturated_simd(&input, norm),
DSTType::Type2 => dst2_bandwidth_saturated_simd_1d(&input, norm),
DSTType::Type3 => dst3_bandwidth_saturated_simd(&input, norm),
DSTType::Type4 => dst4_bandwidth_saturated_simd(&input, norm),
}
} else {
match type_val {
DSTType::Type1 => dst1(&input, norm),
DSTType::Type2 => dst2_impl(&input, norm),
DSTType::Type3 => dst3(&input, norm),
DSTType::Type4 => dst4(&input, norm),
}
}
}
#[allow(dead_code)]
fn dst1_bandwidth_saturated_simd(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
use scirs2_core::simd_ops::SimdUnifiedOps;
let n = x.len();
if n < 2 {
return Err(FFTError::ValueError(
"Input array must have at least 2 elements for DST-I".to_string(),
));
}
let mut result = vec![0.0; n];
let chunk_size = 8;
let pi_f32 = PI as f32;
let n_plus_1 = (n + 1) as f32;
for k_chunk in (0..n).step_by(chunk_size) {
let k_chunk_end = (k_chunk + chunk_size).min(n);
let k_chunk_len = k_chunk_end - k_chunk;
let mut k_indices = vec![0.0f32; k_chunk_len];
for (i, k_idx) in k_indices.iter_mut().enumerate() {
*k_idx = (k_chunk + i + 1) as f32; }
for m_chunk in (0..n).step_by(chunk_size) {
let m_chunk_end = (m_chunk + chunk_size).min(n);
let m_chunk_len = m_chunk_end - m_chunk;
if m_chunk_len == k_chunk_len {
let mut m_indices = vec![0.0f32; m_chunk_len];
for (i, m_idx) in m_indices.iter_mut().enumerate() {
*m_idx = (m_chunk + i + 1) as f32; }
let mut x_values = vec![0.0f32; m_chunk_len];
for (i, x_val) in x_values.iter_mut().enumerate() {
*x_val = x[m_chunk + i] as f32;
}
let mut angles = vec![0.0f32; k_chunk_len];
let mut temp_prod = vec![0.0f32; k_chunk_len];
let pi_vec = vec![pi_f32; k_chunk_len];
let n_plus_1_vec = vec![n_plus_1; k_chunk_len];
simd_mul_f32_ultra_vec(&k_indices, &m_indices, &mut temp_prod);
let mut temp_prod2 = vec![0.0f32; k_chunk_len];
simd_mul_f32_ultra_vec(&temp_prod, &pi_vec, &mut temp_prod2);
simd_div_f32_ultra_vec(&temp_prod2, &n_plus_1_vec, &mut angles);
let mut sin_values = vec![0.0f32; k_chunk_len];
simd_sin_f32_ultra_vec(&angles, &mut sin_values);
let mut products = vec![0.0f32; k_chunk_len];
simd_mul_f32_ultra_vec(&sin_values, &x_values, &mut products);
for (i, &prod) in products.iter().enumerate() {
result[k_chunk + i] += prod as f64;
}
} else {
for (i, k_idx) in (k_chunk..k_chunk_end).enumerate() {
for m_idx in m_chunk..m_chunk_end {
let k_f = (k_idx + 1) as f64;
let m_f = (m_idx + 1) as f64;
let angle = PI * k_f * m_f / (n as f64 + 1.0);
result[k_idx] += x[m_idx] * angle.sin();
}
}
}
}
}
if let Some("ortho") = norm {
let norm_factor = (2.0 / (n as f64 + 1.0)).sqrt() as f32;
let norm_vec = vec![norm_factor; chunk_size];
for chunk_start in (0..n).step_by(chunk_size) {
let chunk_end = (chunk_start + chunk_size).min(n);
let chunk_len = chunk_end - chunk_start;
if chunk_len == chunk_size {
let mut result_chunk: Vec<f32> = result[chunk_start..chunk_end]
.iter()
.map(|&x| x as f32)
.collect();
let mut normalized = vec![0.0f32; chunk_size];
simd_mul_f32_ultra_vec(&result_chunk, &norm_vec, &mut normalized);
for (i, &val) in normalized.iter().enumerate() {
result[chunk_start + i] = val as f64;
}
} else {
for i in chunk_start..chunk_end {
result[i] *= norm_factor as f64;
}
}
}
}
Ok(result)
}
#[allow(dead_code)]
fn dst2_bandwidth_saturated_simd_1d(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
use scirs2_core::simd_ops::SimdUnifiedOps;
let n = x.len();
if n == 0 {
return Err(FFTError::ValueError(
"Input array cannot be empty".to_string(),
));
}
let mut result = vec![0.0; n];
let chunk_size = 8;
let pi_f32 = PI as f32;
let n_f32 = n as f32;
for k_chunk in (0..n).step_by(chunk_size) {
let k_chunk_end = (k_chunk + chunk_size).min(n);
let k_chunk_len = k_chunk_end - k_chunk;
let mut k_indices = vec![0.0f32; k_chunk_len];
for (i, k_idx) in k_indices.iter_mut().enumerate() {
*k_idx = (k_chunk + i + 1) as f32;
}
let mut chunk_sum = vec![0.0f32; k_chunk_len];
for m_chunk in (0..n).step_by(chunk_size) {
let m_chunk_end = (m_chunk + chunk_size).min(n);
let m_chunk_len = m_chunk_end - m_chunk;
if m_chunk_len == k_chunk_len {
let mut m_indices = vec![0.0f32; m_chunk_len];
for (i, m_idx) in m_indices.iter_mut().enumerate() {
*m_idx = (m_chunk + i) as f32;
}
let mut x_values = vec![0.0f32; m_chunk_len];
for (i, x_val) in x_values.iter_mut().enumerate() {
*x_val = x[m_chunk + i] as f32;
}
let mut m_plus_half = vec![0.0f32; m_chunk_len];
let half_vec = vec![0.5f32; m_chunk_len];
simd_add_f32_ultra_vec(&m_indices, &half_vec, &mut m_plus_half);
let mut angles = vec![0.0f32; k_chunk_len];
let mut temp_prod = vec![0.0f32; k_chunk_len];
let pi_vec = vec![pi_f32; k_chunk_len];
let n_vec = vec![n_f32; k_chunk_len];
simd_mul_f32_ultra_vec(&k_indices, &m_plus_half, &mut temp_prod);
let mut temp_prod2 = vec![0.0f32; k_chunk_len];
simd_mul_f32_ultra_vec(&temp_prod, &pi_vec, &mut temp_prod2);
simd_div_f32_ultra_vec(&temp_prod2, &n_vec, &mut angles);
let mut sin_values = vec![0.0f32; k_chunk_len];
simd_sin_f32_ultra_vec(&angles, &mut sin_values);
let mut products = vec![0.0f32; k_chunk_len];
simd_mul_f32_ultra_vec(&sin_values, &x_values, &mut products);
let mut temp_sum = vec![0.0f32; k_chunk_len];
simd_add_f32_ultra_vec(&chunk_sum, &products, &mut temp_sum);
chunk_sum = temp_sum;
}
}
for (i, &sum) in chunk_sum.iter().enumerate() {
result[k_chunk + i] = sum as f64;
}
}
if let Some("ortho") = norm {
let norm_factor = (2.0 / n as f64).sqrt() as f32;
let norm_vec = vec![norm_factor; chunk_size];
for chunk_start in (0..n).step_by(chunk_size) {
let chunk_end = (chunk_start + chunk_size).min(n);
let chunk_len = chunk_end - chunk_start;
if chunk_len == chunk_size {
let mut result_chunk: Vec<f32> = result[chunk_start..chunk_end]
.iter()
.map(|&x| x as f32)
.collect();
let mut normalized = vec![0.0f32; chunk_size];
simd_mul_f32_ultra_vec(&result_chunk, &norm_vec, &mut normalized);
for (i, &val) in normalized.iter().enumerate() {
result[chunk_start + i] = val as f64;
}
}
}
}
Ok(result)
}
#[allow(dead_code)]
fn dst3_bandwidth_saturated_simd(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
use scirs2_core::simd_ops::SimdUnifiedOps;
let n = x.len();
if n == 0 {
return Err(FFTError::ValueError(
"Input array cannot be empty".to_string(),
));
}
let mut result = vec![0.0; n];
let chunk_size = 8;
let pi_f32 = PI as f32;
let n_f32 = n as f32;
for k_chunk in (0..n).step_by(chunk_size) {
let k_chunk_end = (k_chunk + chunk_size).min(n);
let k_chunk_len = k_chunk_end - k_chunk;
let mut k_indices = vec![0.0f32; k_chunk_len];
for (i, k_idx) in k_indices.iter_mut().enumerate() {
*k_idx = (k_chunk + i) as f32;
}
let mut special_terms = vec![0.0f32; k_chunk_len];
let x_last = x[n - 1] as f32;
for (i, &k_val) in k_indices.iter().enumerate() {
let k_int = k_val as usize;
special_terms[i] = x_last * if k_int.is_multiple_of(2) { 1.0 } else { -1.0 };
}
let mut regular_sum = vec![0.0f32; k_chunk_len];
for m_chunk in (0..(n - 1)).step_by(chunk_size) {
let m_chunk_end = (m_chunk + chunk_size).min(n - 1);
let m_chunk_len = m_chunk_end - m_chunk;
if m_chunk_len == k_chunk_len {
let mut m_plus_one = vec![0.0f32; m_chunk_len];
for (i, m_val) in m_plus_one.iter_mut().enumerate() {
*m_val = (m_chunk + i + 1) as f32;
}
let mut x_values = vec![0.0f32; m_chunk_len];
for (i, x_val) in x_values.iter_mut().enumerate() {
*x_val = x[m_chunk + i] as f32;
}
let mut k_plus_half = vec![0.0f32; k_chunk_len];
let half_vec = vec![0.5f32; k_chunk_len];
simd_add_f32_ultra_vec(&k_indices, &half_vec, &mut k_plus_half);
let mut angles = vec![0.0f32; k_chunk_len];
let mut temp_prod = vec![0.0f32; k_chunk_len];
let pi_vec = vec![pi_f32; k_chunk_len];
let n_vec = vec![n_f32; k_chunk_len];
simd_mul_f32_ultra_vec(&m_plus_one, &k_plus_half, &mut temp_prod);
let mut temp_prod2 = vec![0.0f32; k_chunk_len];
simd_mul_f32_ultra_vec(&temp_prod, &pi_vec, &mut temp_prod2);
simd_div_f32_ultra_vec(&temp_prod2, &n_vec, &mut angles);
let mut sin_values = vec![0.0f32; k_chunk_len];
simd_sin_f32_ultra_vec(&angles, &mut sin_values);
let mut products = vec![0.0f32; k_chunk_len];
simd_mul_f32_ultra_vec(&sin_values, &x_values, &mut products);
let mut temp_sum = vec![0.0f32; k_chunk_len];
simd_add_f32_ultra_vec(®ular_sum, &products, &mut temp_sum);
regular_sum = temp_sum;
}
}
let mut total_sum = vec![0.0f32; k_chunk_len];
simd_add_f32_ultra_vec(&special_terms, ®ular_sum, &mut total_sum);
for (i, &sum) in total_sum.iter().enumerate() {
result[k_chunk + i] = sum as f64;
}
}
if let Some("ortho") = norm {
let norm_factor = ((2.0 / n as f64).sqrt() / 2.0) as f32;
let norm_vec = vec![norm_factor; chunk_size];
for chunk_start in (0..n).step_by(chunk_size) {
let chunk_end = (chunk_start + chunk_size).min(n);
let chunk_len = chunk_end - chunk_start;
if chunk_len == chunk_size {
let mut result_chunk: Vec<f32> = result[chunk_start..chunk_end]
.iter()
.map(|&x| x as f32)
.collect();
let mut normalized = vec![0.0f32; chunk_size];
simd_mul_f32_ultra_vec(&result_chunk, &norm_vec, &mut normalized);
for (i, &val) in normalized.iter().enumerate() {
result[chunk_start + i] = val as f64;
}
}
}
} else {
let norm_factor = 0.5f32;
let norm_vec = vec![norm_factor; chunk_size];
for chunk_start in (0..n).step_by(chunk_size) {
let chunk_end = (chunk_start + chunk_size).min(n);
let chunk_len = chunk_end - chunk_start;
if chunk_len == chunk_size {
let mut result_chunk: Vec<f32> = result[chunk_start..chunk_end]
.iter()
.map(|&x| x as f32)
.collect();
let mut normalized = vec![0.0f32; chunk_size];
simd_mul_f32_ultra_vec(&result_chunk, &norm_vec, &mut normalized);
for (i, &val) in normalized.iter().enumerate() {
result[chunk_start + i] = val as f64;
}
}
}
}
Ok(result)
}
#[allow(dead_code)]
fn dst4_bandwidth_saturated_simd(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
use scirs2_core::simd_ops::SimdUnifiedOps;
let n = x.len();
if n == 0 {
return Err(FFTError::ValueError(
"Input array cannot be empty".to_string(),
));
}
let mut result = vec![0.0; n];
let chunk_size = 8;
let pi_f32 = PI as f32;
let n_f32 = n as f32;
for k_chunk in (0..n).step_by(chunk_size) {
let k_chunk_end = (k_chunk + chunk_size).min(n);
let k_chunk_len = k_chunk_end - k_chunk;
let mut k_indices = vec![0.0f32; k_chunk_len];
for (i, k_idx) in k_indices.iter_mut().enumerate() {
*k_idx = (k_chunk + i) as f32;
}
let mut k_plus_half = vec![0.0f32; k_chunk_len];
let half_vec = vec![0.5f32; k_chunk_len];
simd_add_f32_ultra_vec(&k_indices, &half_vec, &mut k_plus_half);
let mut chunk_sum = vec![0.0f32; k_chunk_len];
for m_chunk in (0..n).step_by(chunk_size) {
let m_chunk_end = (m_chunk + chunk_size).min(n);
let m_chunk_len = m_chunk_end - m_chunk;
if m_chunk_len == k_chunk_len {
let mut m_indices = vec![0.0f32; m_chunk_len];
for (i, m_idx) in m_indices.iter_mut().enumerate() {
*m_idx = (m_chunk + i) as f32;
}
let mut m_plus_half = vec![0.0f32; m_chunk_len];
simd_add_f32_ultra_vec(&m_indices, &half_vec, &mut m_plus_half);
let mut x_values = vec![0.0f32; m_chunk_len];
for (i, x_val) in x_values.iter_mut().enumerate() {
*x_val = x[m_chunk + i] as f32;
}
let mut angles = vec![0.0f32; k_chunk_len];
let mut temp_prod = vec![0.0f32; k_chunk_len];
let pi_vec = vec![pi_f32; k_chunk_len];
let n_vec = vec![n_f32; k_chunk_len];
simd_mul_f32_ultra_vec(&m_plus_half, &k_plus_half, &mut temp_prod);
let mut temp_prod2 = vec![0.0f32; k_chunk_len];
simd_mul_f32_ultra_vec(&temp_prod, &pi_vec, &mut temp_prod2);
simd_div_f32_ultra_vec(&temp_prod2, &n_vec, &mut angles);
let mut sin_values = vec![0.0f32; k_chunk_len];
simd_sin_f32_ultra_vec(&angles, &mut sin_values);
let mut products = vec![0.0f32; k_chunk_len];
simd_mul_f32_ultra_vec(&sin_values, &x_values, &mut products);
let mut temp_sum = vec![0.0f32; k_chunk_len];
simd_add_f32_ultra_vec(&chunk_sum, &products, &mut temp_sum);
chunk_sum = temp_sum;
}
}
for (i, &sum) in chunk_sum.iter().enumerate() {
result[k_chunk + i] = sum as f64;
}
}
if let Some("ortho") = norm {
let norm_factor = (2.0 / n as f64).sqrt() as f32;
let norm_vec = vec![norm_factor; chunk_size];
for chunk_start in (0..n).step_by(chunk_size) {
let chunk_end = (chunk_start + chunk_size).min(n);
let chunk_len = chunk_end - chunk_start;
if chunk_len == chunk_size {
let mut result_chunk: Vec<f32> = result[chunk_start..chunk_end]
.iter()
.map(|&x| x as f32)
.collect();
let mut normalized = vec![0.0f32; chunk_size];
simd_mul_f32_ultra_vec(&result_chunk, &norm_vec, &mut normalized);
for (i, &val) in normalized.iter().enumerate() {
result[chunk_start + i] = val as f64;
}
}
}
} else {
let norm_factor = 2.0f32;
let norm_vec = vec![norm_factor; chunk_size];
for chunk_start in (0..n).step_by(chunk_size) {
let chunk_end = (chunk_start + chunk_size).min(n);
let chunk_len = chunk_end - chunk_start;
if chunk_len == chunk_size {
let mut result_chunk: Vec<f32> = result[chunk_start..chunk_end]
.iter()
.map(|&x| x as f32)
.collect();
let mut normalized = vec![0.0f32; chunk_size];
simd_mul_f32_ultra_vec(&result_chunk, &norm_vec, &mut normalized);
for (i, &val) in normalized.iter().enumerate() {
result[chunk_start + i] = val as f64;
}
}
}
}
Ok(result)
}
#[allow(dead_code)]
pub fn dst2_bandwidth_saturated_simd<T>(
x: &ArrayView2<T>,
dst_type: Option<DSTType>,
norm: Option<&str>,
) -> FFTResult<Array2<f64>>
where
T: NumCast + Copy + Debug,
{
use scirs2_core::simd_ops::PlatformCapabilities;
let (n_rows, n_cols) = x.dim();
let caps = PlatformCapabilities::detect();
if (n_rows >= 32 && n_cols >= 32) && (caps.has_avx2() || caps.has_avx512()) {
dst2_bandwidth_saturated_simd_impl(x, dst_type, norm)
} else {
dst2(x, dst_type, norm)
}
}
#[allow(dead_code)]
fn dst2_bandwidth_saturated_simd_impl<T>(
x: &ArrayView2<T>,
dst_type: Option<DSTType>,
norm: Option<&str>,
) -> FFTResult<Array2<f64>>
where
T: NumCast + Copy + Debug,
{
let (n_rows, n_cols) = x.dim();
let type_val = dst_type.unwrap_or(DSTType::Type2);
let mut intermediate = Array2::zeros((n_rows, n_cols));
for r in 0..n_rows {
let row_slice = x.slice(scirs2_core::ndarray::s![r, ..]);
let row_vec: Vec<T> = row_slice.iter().cloned().collect();
let row_dst = dst_bandwidth_saturated_simd(&row_vec, Some(type_val), norm)?;
for (c, val) in row_dst.iter().enumerate() {
intermediate[[r, c]] = *val;
}
}
let mut final_result = Array2::zeros((n_rows, n_cols));
for c in 0..n_cols {
let col_slice = intermediate.slice(scirs2_core::ndarray::s![.., c]);
let col_vec: Vec<f64> = col_slice.iter().cloned().collect();
let col_dst = dst_bandwidth_saturated_simd(&col_vec, Some(type_val), norm)?;
for (r, val) in col_dst.iter().enumerate() {
final_result[[r, c]] = *val;
}
}
Ok(final_result)
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
use scirs2_core::ndarray::arr2;
#[test]
fn test_dst_and_idst() {
let signal = vec![1.0, 2.0, 3.0, 4.0];
let dst_coeffs =
dst(&signal, Some(DSTType::Type2), Some("ortho")).expect("Operation failed");
let recovered =
idst(&dst_coeffs, Some(DSTType::Type2), Some("ortho")).expect("Operation failed");
for i in 0..signal.len() {
assert_relative_eq!(recovered[i], signal[i], epsilon = 1e-10);
}
}
#[test]
fn test_dst_types() {
let signal = vec![1.0, 2.0, 3.0, 4.0];
let dst1_coeffs =
dst(&signal, Some(DSTType::Type1), Some("ortho")).expect("Operation failed");
let recovered =
idst(&dst1_coeffs, Some(DSTType::Type1), Some("ortho")).expect("Operation failed");
for i in 0..signal.len() {
assert_relative_eq!(recovered[i], signal[i], epsilon = 1e-10);
}
let dst2_coeffs =
dst(&signal, Some(DSTType::Type2), Some("ortho")).expect("Operation failed");
let recovered =
idst(&dst2_coeffs, Some(DSTType::Type2), Some("ortho")).expect("Operation failed");
for i in 0..signal.len() {
assert_relative_eq!(recovered[i], signal[i], epsilon = 1e-10);
}
let dst3_coeffs =
dst(&signal, Some(DSTType::Type3), Some("ortho")).expect("Operation failed");
let recovered =
idst(&dst3_coeffs, Some(DSTType::Type3), Some("ortho")).expect("Operation failed");
for i in 0..signal.len() {
assert_relative_eq!(recovered[i], signal[i], epsilon = 1e-10);
}
let dst4_coeffs =
dst(&signal, Some(DSTType::Type4), Some("ortho")).expect("Operation failed");
let recovered =
idst(&dst4_coeffs, Some(DSTType::Type4), Some("ortho")).expect("Operation failed");
for i in 0..signal.len() {
assert_relative_eq!(recovered[i], signal[i], epsilon = 1e-10);
}
}
#[test]
fn test_dst2_and_idst2() {
let arr = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
let dst2_coeffs =
dst2(&arr.view(), Some(DSTType::Type2), Some("ortho")).expect("Operation failed");
let recovered = idst2(&dst2_coeffs.view(), Some(DSTType::Type2), Some("ortho"))
.expect("Operation failed");
for i in 0..2 {
for j in 0..2 {
assert_relative_eq!(recovered[[i, j]], arr[[i, j]], epsilon = 1e-10);
}
}
}
#[test]
fn test_linear_signal() {
let signal = vec![1.0, 2.0, 3.0, 4.0];
let dst2_coeffs =
dst(&signal, Some(DSTType::Type2), Some("ortho")).expect("Operation failed");
let recovered =
idst(&dst2_coeffs, Some(DSTType::Type2), Some("ortho")).expect("Operation failed");
for i in 0..signal.len() {
assert_relative_eq!(recovered[i], signal[i], epsilon = 1e-10);
}
}
#[test]
fn test_dst_linearity() {
let x = vec![1.0, 2.0, 3.0, 4.0];
let y = vec![5.0, 6.0, 7.0, 8.0];
let a = 2.0;
let b = -0.5;
let dst_x = dst(&x, Some(DSTType::Type2), None).expect("DST(x) failed");
let dst_y = dst(&y, Some(DSTType::Type2), None).expect("DST(y) failed");
let combined: Vec<f64> = x
.iter()
.zip(y.iter())
.map(|(&xi, &yi)| a * xi + b * yi)
.collect();
let dst_combined =
dst(&combined, Some(DSTType::Type2), None).expect("DST(combined) failed");
for i in 0..x.len() {
let expected = a * dst_x[i] + b * dst_y[i];
assert_relative_eq!(dst_combined[i], expected, epsilon = 1e-10);
}
}
#[test]
fn test_dst_energy_preservation() {
let signal = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let coeffs =
dst(&signal, Some(DSTType::Type2), Some("ortho")).expect("DST-II ortho failed");
let time_energy: f64 = signal.iter().map(|x| x * x).sum();
let freq_energy: f64 = coeffs.iter().map(|c| c * c).sum();
assert_relative_eq!(time_energy, freq_energy, epsilon = 1e-8);
}
#[test]
fn test_dst_large_signal() {
let n = 32;
let signal: Vec<f64> = (0..n)
.map(|i| (PI * 3.0 * (i as f64 + 0.5) / n as f64).sin())
.collect();
let coeffs =
dst(&signal, Some(DSTType::Type2), Some("ortho")).expect("DST-II large failed");
let recovered =
idst(&coeffs, Some(DSTType::Type2), Some("ortho")).expect("IDST-II large failed");
for i in 0..n {
assert_relative_eq!(recovered[i], signal[i], epsilon = 1e-8);
}
}
#[test]
fn test_dst_odd_length() {
let signal = vec![1.0, 3.0, 5.0, 7.0, 9.0];
for dst_type in &[
DSTType::Type1,
DSTType::Type2,
DSTType::Type3,
DSTType::Type4,
] {
let coeffs = dst(&signal, Some(*dst_type), Some("ortho"))
.expect("DST odd length forward failed");
let recovered = idst(&coeffs, Some(*dst_type), Some("ortho"))
.expect("DST odd length inverse failed");
for i in 0..signal.len() {
assert_relative_eq!(recovered[i], signal[i], epsilon = 1e-8);
}
}
}
#[test]
fn test_dst2_4x4() {
let arr = Array2::from_shape_vec(
(4, 4),
vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
16.0,
],
)
.expect("Array creation failed");
let coeffs =
dst2(&arr.view(), Some(DSTType::Type2), Some("ortho")).expect("2D DST-II failed");
let recovered =
idst2(&coeffs.view(), Some(DSTType::Type2), Some("ortho")).expect("2D IDST-II failed");
for i in 0..4 {
for j in 0..4 {
assert_relative_eq!(recovered[[i, j]], arr[[i, j]], epsilon = 1e-8);
}
}
}
#[test]
fn test_dst_type4_self_inverse() {
let signal = vec![1.0, 2.0, 3.0, 4.0];
let coeffs = dst(&signal, Some(DSTType::Type4), Some("ortho")).expect("DST-IV failed");
let recovered =
dst(&coeffs, Some(DSTType::Type4), Some("ortho")).expect("DST-IV self-inv failed");
for i in 0..signal.len() {
assert_relative_eq!(recovered[i], signal[i], epsilon = 1e-8);
}
}
}