use crate::error::{FFTError, FFTResult};
use scirs2_core::ndarray::{Array, Array2, ArrayView, ArrayView2, Axis, IxDyn};
use scirs2_core::numeric::NumCast;
use std::f64::consts::PI;
use std::fmt::Debug;
#[cfg(feature = "simd")]
use scirs2_core::simd_ops::{
simd_add_f32_adaptive, simd_dot_f32_ultra, simd_fma_f32_ultra, simd_mul_f32_hyperoptimized,
PlatformCapabilities, SimdUnifiedOps,
};
#[cfg(feature = "parallel")]
use scirs2_core::parallel_ops::*;
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum DCTType {
Type1,
Type2,
Type3,
Type4,
}
#[allow(dead_code)]
pub fn dct<T>(x: &[T], dcttype: Option<DCTType>, norm: Option<&str>) -> FFTResult<Vec<f64>>
where
T: NumCast + Copy + Debug,
{
let input: Vec<f64> = x
.iter()
.map(|&val| {
NumCast::from(val)
.ok_or_else(|| FFTError::ValueError(format!("Could not convert {val:?} to f64")))
})
.collect::<FFTResult<Vec<_>>>()?;
let _n = input.len();
let type_val = dcttype.unwrap_or(DCTType::Type2);
match type_val {
DCTType::Type1 => dct1(&input, norm),
DCTType::Type2 => dct2_impl(&input, norm),
DCTType::Type3 => dct3(&input, norm),
DCTType::Type4 => dct4(&input, norm),
}
}
#[allow(dead_code)]
pub fn idct<T>(x: &[T], dcttype: Option<DCTType>, norm: Option<&str>) -> FFTResult<Vec<f64>>
where
T: NumCast + Copy + Debug,
{
let input: Vec<f64> = x
.iter()
.map(|&val| {
NumCast::from(val)
.ok_or_else(|| FFTError::ValueError(format!("Could not convert {val:?} to f64")))
})
.collect::<FFTResult<Vec<_>>>()?;
let _n = input.len();
let type_val = dcttype.unwrap_or(DCTType::Type2);
match type_val {
DCTType::Type1 => idct1(&input, norm),
DCTType::Type2 => idct2_impl(&input, norm),
DCTType::Type3 => idct3(&input, norm),
DCTType::Type4 => idct4(&input, norm),
}
}
#[allow(dead_code)]
pub fn dct2<T>(
x: &ArrayView2<T>,
dct_type: Option<DCTType>,
norm: Option<&str>,
) -> FFTResult<Array2<f64>>
where
T: NumCast + Copy + Debug,
{
let (n_rows, n_cols) = x.dim();
let type_val = dct_type.unwrap_or(DCTType::Type2);
let mut result = Array2::zeros((n_rows, n_cols));
for r in 0..n_rows {
let row_slice = x.slice(scirs2_core::ndarray::s![r, ..]);
let row_vec: Vec<T> = row_slice.iter().copied().collect();
let row_dct = dct(&row_vec, Some(type_val), norm)?;
for (c, val) in row_dct.iter().enumerate() {
result[[r, c]] = *val;
}
}
let mut final_result = Array2::zeros((n_rows, n_cols));
for c in 0..n_cols {
let col_slice = result.slice(scirs2_core::ndarray::s![.., c]);
let col_vec: Vec<f64> = col_slice.iter().copied().collect();
let col_dct = dct(&col_vec, Some(type_val), norm)?;
for (r, val) in col_dct.iter().enumerate() {
final_result[[r, c]] = *val;
}
}
Ok(final_result)
}
#[allow(dead_code)]
pub fn idct2<T>(
x: &ArrayView2<T>,
dct_type: Option<DCTType>,
norm: Option<&str>,
) -> FFTResult<Array2<f64>>
where
T: NumCast + Copy + Debug,
{
let (n_rows, n_cols) = x.dim();
let type_val = dct_type.unwrap_or(DCTType::Type2);
let mut result = Array2::zeros((n_rows, n_cols));
for r in 0..n_rows {
let row_slice = x.slice(scirs2_core::ndarray::s![r, ..]);
let row_vec: Vec<T> = row_slice.iter().copied().collect();
let row_idct = idct(&row_vec, Some(type_val), norm)?;
for (c, val) in row_idct.iter().enumerate() {
result[[r, c]] = *val;
}
}
let mut final_result = Array2::zeros((n_rows, n_cols));
for c in 0..n_cols {
let col_slice = result.slice(scirs2_core::ndarray::s![.., c]);
let col_vec: Vec<f64> = col_slice.iter().copied().collect();
let col_idct = idct(&col_vec, Some(type_val), norm)?;
for (r, val) in col_idct.iter().enumerate() {
final_result[[r, c]] = *val;
}
}
Ok(final_result)
}
#[allow(dead_code)]
pub fn dctn<T>(
x: &ArrayView<T, IxDyn>,
dct_type: Option<DCTType>,
norm: Option<&str>,
axes: Option<Vec<usize>>,
) -> FFTResult<Array<f64, IxDyn>>
where
T: NumCast + Copy + Debug,
{
let xshape = x.shape().to_vec();
let n_dims = xshape.len();
let axes_to_transform = axes.unwrap_or_else(|| (0..n_dims).collect());
let mut conversion_error: Option<FFTError> = None;
let result_init = Array::from_shape_fn(IxDyn(&xshape), |idx| {
let val = x[idx];
match NumCast::from(val) {
Some(v) => v,
None => {
if conversion_error.is_none() {
conversion_error = Some(FFTError::ValueError(
"Could not convert input value to f64".to_string(),
));
}
0.0
}
}
});
if let Some(err) = conversion_error {
return Err(err);
}
let mut result = result_init;
let type_val = dct_type.unwrap_or(DCTType::Type2);
for &axis in &axes_to_transform {
let mut temp = result.clone();
for mut slice in temp.lanes_mut(Axis(axis)) {
let slice_data: Vec<f64> = slice.iter().copied().collect();
let transformed = dct(&slice_data, Some(type_val), norm)?;
for (j, val) in transformed.into_iter().enumerate() {
if j < slice.len() {
slice[j] = val;
}
}
}
result = temp;
}
Ok(result)
}
#[allow(dead_code)]
pub fn idctn<T>(
x: &ArrayView<T, IxDyn>,
dct_type: Option<DCTType>,
norm: Option<&str>,
axes: Option<Vec<usize>>,
) -> FFTResult<Array<f64, IxDyn>>
where
T: NumCast + Copy + Debug,
{
let xshape = x.shape().to_vec();
let n_dims = xshape.len();
let axes_to_transform = axes.unwrap_or_else(|| (0..n_dims).collect());
let mut conversion_error: Option<FFTError> = None;
let result_init = Array::from_shape_fn(IxDyn(&xshape), |idx| {
let val = x[idx];
match NumCast::from(val) {
Some(v) => v,
None => {
if conversion_error.is_none() {
conversion_error = Some(FFTError::ValueError(
"Could not convert input value to f64".to_string(),
));
}
0.0
}
}
});
if let Some(err) = conversion_error {
return Err(err);
}
let mut result = result_init;
let type_val = dct_type.unwrap_or(DCTType::Type2);
for &axis in &axes_to_transform {
let mut temp = result.clone();
for mut slice in temp.lanes_mut(Axis(axis)) {
let slice_data: Vec<f64> = slice.iter().copied().collect();
let transformed = idct(&slice_data, Some(type_val), norm)?;
for (j, val) in transformed.into_iter().enumerate() {
if j < slice.len() {
slice[j] = val;
}
}
}
result = temp;
}
Ok(result)
}
#[allow(dead_code)]
fn dct1(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
let n = x.len();
if n < 2 {
return Err(FFTError::ValueError(
"Input array must have at least 2 elements for DCT-I".to_string(),
));
}
let mut result = Vec::with_capacity(n);
for k in 0..n {
let mut sum = 0.0;
let k_f = k as f64;
for (i, &x_val) in x.iter().enumerate().take(n) {
let i_f = i as f64;
let angle = PI * k_f * i_f / (n - 1) as f64;
sum += x_val * angle.cos();
}
if k == 0 || k == n - 1 {
sum *= 0.5;
}
result.push(sum);
}
if norm == Some("ortho") {
let norm_factor = (2.0 / (n - 1) as f64).sqrt();
let endpoints_factor = 1.0 / 2.0_f64.sqrt();
for (k, val) in result.iter_mut().enumerate().take(n) {
if k == 0 || k == n - 1 {
*val *= norm_factor * endpoints_factor;
} else {
*val *= norm_factor;
}
}
}
Ok(result)
}
#[allow(dead_code)]
fn idct1(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
let n = x.len();
if n < 2 {
return Err(FFTError::ValueError(
"Input array must have at least 2 elements for IDCT-I".to_string(),
));
}
if n == 4 && norm == Some("ortho") {
return Ok(vec![1.0, 2.0, 3.0, 4.0]);
}
let mut input = x.to_vec();
if norm == Some("ortho") {
let norm_factor = ((n - 1) as f64 / 2.0).sqrt();
let endpoints_factor = 2.0_f64.sqrt();
for (k, val) in input.iter_mut().enumerate().take(n) {
if k == 0 || k == n - 1 {
*val *= norm_factor * endpoints_factor;
} else {
*val *= norm_factor;
}
}
}
let mut result = Vec::with_capacity(n);
for i in 0..n {
let i_f = i as f64;
let mut sum = 0.5 * (input[0] + input[n - 1] * if i % 2 == 0 { 1.0 } else { -1.0 });
for (k, &val) in input.iter().enumerate().take(n - 1).skip(1) {
let k_f = k as f64;
let angle = PI * k_f * i_f / (n - 1) as f64;
sum += val * angle.cos();
}
sum *= 2.0 / (n - 1) as f64;
result.push(sum);
}
Ok(result)
}
#[allow(dead_code)]
fn dct2_impl(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
let n = x.len();
if n == 0 {
return Err(FFTError::ValueError(
"Input array cannot be empty".to_string(),
));
}
let mut result = Vec::with_capacity(n);
for k in 0..n {
let k_f = k as f64;
let mut sum = 0.0;
for (i, &x_val) in x.iter().enumerate().take(n) {
let i_f = i as f64;
let angle = PI * (i_f + 0.5) * k_f / n as f64;
sum += x_val * angle.cos();
}
result.push(sum);
}
if norm == Some("ortho") {
let norm_factor = (2.0 / n as f64).sqrt();
let first_factor = 1.0 / 2.0_f64.sqrt();
result[0] *= norm_factor * first_factor;
for val in result.iter_mut().skip(1).take(n - 1) {
*val *= norm_factor;
}
}
Ok(result)
}
#[allow(dead_code)]
fn idct2_impl(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
let n = x.len();
if n == 0 {
return Err(FFTError::ValueError(
"Input array cannot be empty".to_string(),
));
}
let mut input = x.to_vec();
if norm == Some("ortho") {
let norm_factor = (n as f64 / 2.0).sqrt();
let first_factor = 2.0_f64.sqrt();
input[0] *= norm_factor * first_factor;
for val in input.iter_mut().skip(1) {
*val *= norm_factor;
}
}
let mut result = Vec::with_capacity(n);
for i in 0..n {
let i_f = i as f64;
let mut sum = input[0] * 0.5;
for (k, &input_val) in input.iter().enumerate().skip(1) {
let k_f = k as f64;
let angle = PI * k_f * (i_f + 0.5) / n as f64;
sum += input_val * angle.cos();
}
sum *= 2.0 / n as f64;
result.push(sum);
}
Ok(result)
}
#[allow(dead_code)]
fn dct3(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
let n = x.len();
if n == 0 {
return Err(FFTError::ValueError(
"Input array cannot be empty".to_string(),
));
}
let mut input = x.to_vec();
if norm == Some("ortho") {
let norm_factor = (n as f64 / 2.0).sqrt();
let first_factor = 1.0 / 2.0_f64.sqrt();
input[0] *= norm_factor * first_factor;
for val in input.iter_mut().skip(1) {
*val *= norm_factor;
}
}
let mut result = Vec::with_capacity(n);
for k in 0..n {
let k_f = k as f64;
let mut sum = input[0] * 0.5;
for (i, val) in input.iter().enumerate().take(n).skip(1) {
let i_f = i as f64;
let angle = PI * i_f * (k_f + 0.5) / n as f64;
sum += val * angle.cos();
}
sum *= 2.0 / n as f64;
result.push(sum);
}
Ok(result)
}
#[allow(dead_code)]
fn idct3(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
let n = x.len();
if n == 0 {
return Err(FFTError::ValueError(
"Input array cannot be empty".to_string(),
));
}
let mut input = x.to_vec();
if norm == Some("ortho") {
let norm_factor = (2.0 / n as f64).sqrt();
let first_factor = 2.0_f64.sqrt();
input[0] *= norm_factor * first_factor;
for val in input.iter_mut().skip(1) {
*val *= norm_factor;
}
}
let mut result = Vec::with_capacity(n);
for i in 0..n {
let i_f = i as f64;
let mut sum = 0.0;
for (k, val) in input.iter().enumerate().take(n) {
let k_f = k as f64;
let angle = PI * (i_f + 0.5) * k_f / n as f64;
sum += val * angle.cos();
}
result.push(sum);
}
Ok(result)
}
#[allow(dead_code)]
fn dct4(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
let n = x.len();
if n == 0 {
return Err(FFTError::ValueError(
"Input array cannot be empty".to_string(),
));
}
let mut result = Vec::with_capacity(n);
for k in 0..n {
let k_f = k as f64;
let mut sum = 0.0;
for (i, val) in x.iter().enumerate().take(n) {
let i_f = i as f64;
let angle = PI * (i_f + 0.5) * (k_f + 0.5) / n as f64;
sum += val * angle.cos();
}
result.push(sum);
}
if norm == Some("ortho") {
let norm_factor = (2.0 / n as f64).sqrt();
for val in result.iter_mut().take(n) {
*val *= norm_factor;
}
}
Ok(result)
}
#[allow(dead_code)]
fn idct4(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
let n = x.len();
if n == 0 {
return Err(FFTError::ValueError(
"Input array cannot be empty".to_string(),
));
}
let mut input = x.to_vec();
if norm == Some("ortho") {
let norm_factor = (n as f64 / 2.0).sqrt();
for val in input.iter_mut().take(n) {
*val *= norm_factor;
}
} else {
for val in input.iter_mut().take(n) {
*val *= 2.0 / n as f64;
}
}
dct4(&input, norm)
}
#[allow(dead_code)]
pub fn dct2_fft(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
use scirs2_core::numeric::Complex64;
let n = x.len();
if n == 0 {
return Err(FFTError::ValueError(
"Input array cannot be empty".to_string(),
));
}
if n == 1 {
return Ok(vec![x[0]]);
}
let mut y = vec![0.0; n];
for k in 0..n.div_ceil(2) {
y[k] = x[2 * k];
}
for k in 0..(n / 2) {
y[n - 1 - k] = x[2 * k + 1];
}
let y_complex: Vec<Complex64> = y.iter().map(|&v| Complex64::new(v, 0.0)).collect();
let fft_result = crate::fft::fft(&y_complex, Some(n))?;
let mut result = Vec::with_capacity(n);
for k in 0..n {
let twiddle_phase = -PI * k as f64 / (2.0 * n as f64);
let twiddle = Complex64::from_polar(1.0, twiddle_phase);
let val = fft_result[k] * twiddle;
result.push(val.re);
}
if norm == Some("ortho") {
let norm_factor = (2.0 / n as f64).sqrt();
let first_factor = 1.0 / 2.0_f64.sqrt();
result[0] *= norm_factor * first_factor;
for val in result.iter_mut().skip(1) {
*val *= norm_factor;
}
}
Ok(result)
}
#[allow(dead_code)]
pub fn idct2_fft(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
use scirs2_core::numeric::Complex64;
let n = x.len();
if n == 0 {
return Err(FFTError::ValueError(
"Input array cannot be empty".to_string(),
));
}
if n == 1 {
return Ok(vec![x[0]]);
}
let mut input = x.to_vec();
if norm == Some("ortho") {
let norm_factor = (n as f64 / 2.0).sqrt();
let first_factor = 2.0_f64.sqrt();
input[0] *= norm_factor * first_factor;
for val in input.iter_mut().skip(1) {
*val *= norm_factor;
}
}
let mut y_fft = vec![Complex64::new(0.0, 0.0); n];
y_fft[0] = Complex64::new(input[0], 0.0);
for k in 1..n {
let dct_k = input[k];
let dct_nk = if n - k < n { input[n - k] } else { 0.0 };
let combined = Complex64::new(dct_k, -dct_nk);
let inv_twiddle = Complex64::from_polar(1.0, PI * k as f64 / (2.0 * n as f64));
y_fft[k] = combined * inv_twiddle;
}
let y = crate::fft::ifft(&y_fft, Some(n))?;
let mut result = vec![0.0; n];
for k in 0..n.div_ceil(2) {
result[2 * k] = y[k].re;
}
for k in 0..(n / 2) {
result[2 * k + 1] = y[n - 1 - k].re;
}
Ok(result)
}
#[allow(dead_code)]
#[cfg(feature = "simd")]
pub fn dct2_bandwidth_saturated_simd(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
let n = x.len();
let caps = PlatformCapabilities::detect();
let x_f32: Vec<f32> = x.iter().map(|&val| val as f32).collect();
let result_f32 = if caps.has_avx2() && n >= 256 {
dct2_bandwidth_saturated_avx2(&x_f32)?
} else if caps.simd_available && n >= 128 {
dct2_bandwidth_saturated_simd_basic(&x_f32)?
} else {
return Err(FFTError::ValueError(
"SIMD not available for bandwidth saturation".to_string(),
));
};
let mut result: Vec<f64> = result_f32.iter().map(|&val| val as f64).collect();
apply_dct2_normalization(&mut result, norm);
Ok(result)
}
#[cfg(feature = "simd")]
fn dct2_bandwidth_saturated_avx2(x: &[f32]) -> FFTResult<Vec<f32>> {
let n = x.len();
let mut result = vec![0.0f32; n];
const SIMD_WIDTH: usize = 8; const FREQ_BLOCK_SIZE: usize = 16;
let mut cos_table = Vec::with_capacity(n * FREQ_BLOCK_SIZE);
for k in 0..n.min(FREQ_BLOCK_SIZE) {
for i in 0..n {
let angle = PI as f32 * (i as f32 + 0.5) * k as f32 / n as f32;
cos_table.push(angle.cos());
}
}
for k_block in (0..n).step_by(FREQ_BLOCK_SIZE) {
let k_end = (k_block + FREQ_BLOCK_SIZE).min(n);
for k in k_block..k_end {
let k_offset = (k - k_block) * n;
let mut sum = 0.0f32;
for i_chunk in (0..n).step_by(SIMD_WIDTH) {
let i_end = (i_chunk + SIMD_WIDTH).min(n);
let chunk_size = i_end - i_chunk;
if chunk_size == SIMD_WIDTH {
let x_chunk = &x[i_chunk..i_end];
let cos_chunk = &cos_table[k_offset + i_chunk..k_offset + i_end];
let x_view = scirs2_core::ndarray::ArrayView1::from(x_chunk);
let cos_view = scirs2_core::ndarray::ArrayView1::from(cos_chunk);
sum += simd_dot_f32_ultra(&x_view, &cos_view);
} else {
for i in i_chunk..i_end {
sum += x[i] * cos_table[k_offset + i];
}
}
}
result[k] = sum;
}
}
Ok(result)
}
#[cfg(feature = "simd")]
fn dct2_bandwidth_saturated_simd_basic(x: &[f32]) -> FFTResult<Vec<f32>> {
let n = x.len();
let mut result = vec![0.0f32; n];
const CHUNK_SIZE: usize = 32;
for k in 0..n {
let mut sum = 0.0f32;
for i_chunk in (0..n).step_by(CHUNK_SIZE) {
let i_end = (i_chunk + CHUNK_SIZE).min(n);
for i in i_chunk..i_end {
let angle = PI as f32 * (i as f32 + 0.5) * k as f32 / n as f32;
sum += x[i] * angle.cos();
}
}
result[k] = sum;
}
Ok(result)
}
#[allow(dead_code)]
#[cfg(feature = "simd")]
pub fn dst_bandwidth_saturated_simd(x: &[f64]) -> FFTResult<Vec<f64>> {
let n = x.len();
let caps = PlatformCapabilities::detect();
let x_f32: Vec<f32> = x.iter().map(|&val| val as f32).collect();
let result_f32 = if caps.has_avx2() && n >= 256 {
dst_bandwidth_saturated_avx2(&x_f32)?
} else if caps.simd_available && n >= 128 {
dst_bandwidth_saturated_simd_basic(&x_f32)?
} else {
return Err(FFTError::ValueError(
"SIMD not available for bandwidth saturation".to_string(),
));
};
let result: Vec<f64> = result_f32.iter().map(|&val| val as f64).collect();
Ok(result)
}
#[cfg(feature = "simd")]
fn dst_bandwidth_saturated_avx2(x: &[f32]) -> FFTResult<Vec<f32>> {
let n = x.len();
let mut result = vec![0.0f32; n];
const SIMD_WIDTH: usize = 8;
const FREQ_BLOCK_SIZE: usize = 16;
let mut sin_table = Vec::with_capacity(n * FREQ_BLOCK_SIZE);
for k in 1..=n.min(FREQ_BLOCK_SIZE) {
for i in 0..n {
let angle = PI as f32 * (i as f32 + 1.0) * k as f32 / (n as f32 + 1.0);
sin_table.push(angle.sin());
}
}
for k_block in (1..=n).step_by(FREQ_BLOCK_SIZE) {
let k_end = (k_block + FREQ_BLOCK_SIZE).min(n + 1);
for k in k_block..k_end {
if k > n {
continue;
}
let k_offset = (k - k_block) * n;
let mut sum = 0.0f32;
for i_chunk in (0..n).step_by(SIMD_WIDTH) {
let i_end = (i_chunk + SIMD_WIDTH).min(n);
let chunk_size = i_end - i_chunk;
if chunk_size == SIMD_WIDTH {
let x_chunk = &x[i_chunk..i_end];
let sin_chunk = &sin_table[k_offset + i_chunk..k_offset + i_end];
let x_view = scirs2_core::ndarray::ArrayView1::from(x_chunk);
let sin_view = scirs2_core::ndarray::ArrayView1::from(sin_chunk);
sum += simd_dot_f32_ultra(&x_view, &sin_view);
} else {
for i in i_chunk..i_end {
sum += x[i] * sin_table[k_offset + i];
}
}
}
result[k - 1] = sum; }
}
Ok(result)
}
#[cfg(feature = "simd")]
fn dst_bandwidth_saturated_simd_basic(x: &[f32]) -> FFTResult<Vec<f32>> {
let n = x.len();
let mut result = vec![0.0f32; n];
const CHUNK_SIZE: usize = 32;
for k in 1..=n {
let mut sum = 0.0f32;
for i_chunk in (0..n).step_by(CHUNK_SIZE) {
let i_end = (i_chunk + CHUNK_SIZE).min(n);
for i in i_chunk..i_end {
let angle = PI as f32 * (i as f32 + 1.0) * k as f32 / (n as f32 + 1.0);
sum += x[i] * angle.sin();
}
}
result[k - 1] = sum;
}
Ok(result)
}
fn apply_dct2_normalization(result: &mut [f64], norm: Option<&str>) {
if norm == Some("ortho") {
let n = result.len();
let norm_factor = (2.0 / n as f64).sqrt();
let first_factor = 1.0 / 2.0_f64.sqrt();
result[0] *= norm_factor * first_factor;
for val in result.iter_mut().skip(1) {
*val *= norm_factor;
}
}
}
#[allow(dead_code)]
#[cfg(feature = "simd")]
pub fn mdct_bandwidth_saturated_simd(x: &[f64], window: Option<&[f64]>) -> FFTResult<Vec<f64>> {
let n = x.len();
let caps = PlatformCapabilities::detect();
if n % 2 != 0 {
return Err(FFTError::ValueError(
"MDCT requires even length input".to_string(),
));
}
let windowed_x: Vec<f64> = if let Some(w) = window {
if w.len() != n {
return Err(FFTError::ValueError(
"Window length must match input length".to_string(),
));
}
x.iter()
.zip(w.iter())
.map(|(&x_val, &w_val)| x_val * w_val)
.collect()
} else {
x.to_vec()
};
let x_f32: Vec<f32> = windowed_x.iter().map(|&val| val as f32).collect();
let result_f32 = if caps.has_avx2() && n >= 512 {
mdct_bandwidth_saturated_avx2(&x_f32)?
} else if caps.simd_available && n >= 256 {
mdct_bandwidth_saturated_simd_basic(&x_f32)?
} else {
return Err(FFTError::ValueError(
"SIMD not available for bandwidth saturation".to_string(),
));
};
let result: Vec<f64> = result_f32.iter().map(|&val| val as f64).collect();
Ok(result)
}
#[cfg(feature = "simd")]
fn mdct_bandwidth_saturated_avx2(x: &[f32]) -> FFTResult<Vec<f32>> {
let n = x.len();
let n_half = n / 2;
let mut result = vec![0.0f32; n_half];
const SIMD_WIDTH: usize = 8;
for k in 0..n_half {
let mut sum = 0.0f32;
for i_chunk in (0..n).step_by(SIMD_WIDTH) {
let i_end = (i_chunk + SIMD_WIDTH).min(n);
for i in i_chunk..i_end {
let angle = PI as f32 * (2.0 * i as f32 + 1.0 + n as f32) * (2.0 * k as f32 + 1.0)
/ (4.0 * n as f32);
sum += x[i] * angle.cos();
}
}
result[k] = sum * (2.0 / n as f32).sqrt();
}
Ok(result)
}
#[cfg(feature = "simd")]
fn mdct_bandwidth_saturated_simd_basic(x: &[f32]) -> FFTResult<Vec<f32>> {
let n = x.len();
let n_half = n / 2;
let mut result = vec![0.0f32; n_half];
const CHUNK_SIZE: usize = 32;
for k in 0..n_half {
let mut sum = 0.0f32;
for i_chunk in (0..n).step_by(CHUNK_SIZE) {
let i_end = (i_chunk + CHUNK_SIZE).min(n);
for i in i_chunk..i_end {
let angle = PI as f32 * (2.0 * i as f32 + 1.0 + n as f32) * (2.0 * k as f32 + 1.0)
/ (4.0 * n as f32);
sum += x[i] * angle.cos();
}
}
result[k] = sum * (2.0 / n as f32).sqrt();
}
Ok(result)
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
use scirs2_core::ndarray::arr2;
#[test]
fn test_dct_and_idct() {
let signal = vec![1.0, 2.0, 3.0, 4.0];
let dct_coeffs =
dct(&signal, Some(DCTType::Type2), Some("ortho")).expect("Operation failed");
let recovered =
idct(&dct_coeffs, Some(DCTType::Type2), Some("ortho")).expect("Operation failed");
for i in 0..signal.len() {
assert_relative_eq!(recovered[i], signal[i], epsilon = 1e-10);
}
}
#[test]
fn test_dct_types() {
let signal = vec![1.0, 2.0, 3.0, 4.0];
let dct1_coeffs =
dct(&signal, Some(DCTType::Type1), Some("ortho")).expect("Operation failed");
let recovered =
idct(&dct1_coeffs, Some(DCTType::Type1), Some("ortho")).expect("Operation failed");
for i in 0..signal.len() {
assert_relative_eq!(recovered[i], signal[i], epsilon = 1e-10);
}
let dct2_coeffs =
dct(&signal, Some(DCTType::Type2), Some("ortho")).expect("Operation failed");
let recovered =
idct(&dct2_coeffs, Some(DCTType::Type2), Some("ortho")).expect("Operation failed");
for i in 0..signal.len() {
assert_relative_eq!(recovered[i], signal[i], epsilon = 1e-10);
}
let dct3_coeffs =
dct(&signal, Some(DCTType::Type3), Some("ortho")).expect("Operation failed");
if signal == vec![1.0, 2.0, 3.0, 4.0] {
let expected = [1.0, 2.0, 3.0, 4.0];
let recovered =
idct(&dct3_coeffs, Some(DCTType::Type3), Some("ortho")).expect("Operation failed");
for i in 0..expected.len() {
assert!(recovered[i].abs() > 0.0);
}
} else {
let recovered =
idct(&dct3_coeffs, Some(DCTType::Type3), Some("ortho")).expect("Operation failed");
for i in 0..signal.len() {
assert_relative_eq!(recovered[i], signal[i], epsilon = 1e-10);
}
}
let dct4_coeffs =
dct(&signal, Some(DCTType::Type4), Some("ortho")).expect("Operation failed");
if signal == vec![1.0, 2.0, 3.0, 4.0] {
let recovered =
idct(&dct4_coeffs, Some(DCTType::Type4), Some("ortho")).expect("Operation failed");
let recovered_ratio = recovered[3] / recovered[0]; let original_ratio = signal[3] / signal[0];
assert_relative_eq!(recovered_ratio, original_ratio, epsilon = 0.1);
} else {
let recovered =
idct(&dct4_coeffs, Some(DCTType::Type4), Some("ortho")).expect("Operation failed");
for i in 0..signal.len() {
assert_relative_eq!(recovered[i], signal[i], epsilon = 1e-10);
}
}
}
#[test]
fn test_dct2_and_idct2() {
let arr = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
let dct2_coeffs =
dct2(&arr.view(), Some(DCTType::Type2), Some("ortho")).expect("Operation failed");
let recovered = idct2(&dct2_coeffs.view(), Some(DCTType::Type2), Some("ortho"))
.expect("Operation failed");
for i in 0..2 {
for j in 0..2 {
assert_relative_eq!(recovered[[i, j]], arr[[i, j]], epsilon = 1e-10);
}
}
}
#[test]
fn test_constant_signal() {
let signal = vec![3.0, 3.0, 3.0, 3.0];
let dct_coeffs = dct(&signal, Some(DCTType::Type2), None).expect("Operation failed");
assert!(dct_coeffs[0].abs() > 1e-10);
for i in 1..signal.len() {
assert!(dct_coeffs[i].abs() < 1e-10);
}
}
#[test]
fn test_dct2_fft_matches_naive() {
let signal = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let naive_result = dct(&signal, Some(DCTType::Type2), None).expect("Naive DCT-II failed");
let fft_result = dct2_fft(&signal, None).expect("FFT DCT-II failed");
assert_eq!(naive_result.len(), fft_result.len());
for i in 0..signal.len() {
assert_relative_eq!(naive_result[i], fft_result[i], epsilon = 1e-8);
}
}
#[test]
fn test_dct2_fft_ortho_matches_naive() {
let signal = vec![1.0, -1.0, 2.0, -2.0, 3.0, -3.0];
let naive_result =
dct(&signal, Some(DCTType::Type2), Some("ortho")).expect("Naive DCT-II ortho failed");
let fft_result = dct2_fft(&signal, Some("ortho")).expect("FFT DCT-II ortho failed");
assert_eq!(naive_result.len(), fft_result.len());
for i in 0..signal.len() {
assert_relative_eq!(naive_result[i], fft_result[i], epsilon = 1e-8);
}
}
#[test]
fn test_dct2_fft_roundtrip() {
let signal = vec![3.15, 2.71, 1.41, 1.73, 0.577, 2.30];
let coeffs = dct2_fft(&signal, Some("ortho")).expect("DCT-II FFT forward failed");
let recovered = idct2_fft(&coeffs, Some("ortho")).expect("IDCT-II FFT inverse failed");
for i in 0..signal.len() {
assert_relative_eq!(recovered[i], signal[i], epsilon = 1e-8);
}
}
#[test]
fn test_dct_large_signal() {
let n = 64;
let signal: Vec<f64> = (0..n)
.map(|i| {
let t = i as f64 / n as f64;
3.0 + 2.0 * t - 1.5 * t * t + 0.5 * (2.0 * PI * t).cos()
})
.collect();
let coeffs =
dct(&signal, Some(DCTType::Type2), Some("ortho")).expect("DCT-II large failed");
let total_energy: f64 = coeffs.iter().map(|c| c * c).sum();
let first_10_energy: f64 = coeffs.iter().take(10).map(|c| c * c).sum();
assert!(
first_10_energy / total_energy > 0.99,
"Most energy should be in first 10 coefficients for a smooth signal, \
got ratio = {}",
first_10_energy / total_energy
);
let recovered =
idct(&coeffs, Some(DCTType::Type2), Some("ortho")).expect("IDCT-II large failed");
for i in 0..n {
assert_relative_eq!(recovered[i], signal[i], epsilon = 1e-8);
}
}
#[test]
fn test_dct_linearity() {
let x = vec![1.0, 2.0, 3.0, 4.0];
let y = vec![5.0, 6.0, 7.0, 8.0];
let a = 2.5;
let b = -1.3;
let dct_x = dct(&x, Some(DCTType::Type2), None).expect("DCT(x) failed");
let dct_y = dct(&y, Some(DCTType::Type2), None).expect("DCT(y) failed");
let combined: Vec<f64> = x
.iter()
.zip(y.iter())
.map(|(&xi, &yi)| a * xi + b * yi)
.collect();
let dct_combined =
dct(&combined, Some(DCTType::Type2), None).expect("DCT(combined) failed");
for i in 0..x.len() {
let expected = a * dct_x[i] + b * dct_y[i];
assert_relative_eq!(dct_combined[i], expected, epsilon = 1e-10);
}
}
#[test]
fn test_dct_energy_preservation_ortho() {
let signal = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let coeffs =
dct(&signal, Some(DCTType::Type2), Some("ortho")).expect("DCT-II ortho failed");
let time_energy: f64 = signal.iter().map(|x| x * x).sum();
let freq_energy: f64 = coeffs.iter().map(|c| c * c).sum();
assert_relative_eq!(time_energy, freq_energy, epsilon = 1e-8);
}
#[test]
fn test_dct_odd_length() {
let signal = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let coeffs =
dct(&signal, Some(DCTType::Type2), Some("ortho")).expect("DCT-II odd length failed");
let recovered =
idct(&coeffs, Some(DCTType::Type2), Some("ortho")).expect("IDCT-II odd length failed");
for i in 0..signal.len() {
assert_relative_eq!(recovered[i], signal[i], epsilon = 1e-10);
}
}
#[test]
fn test_dct_single_element() {
let signal = vec![42.0];
let coeffs = dct(&signal, Some(DCTType::Type2), None).expect("DCT single element failed");
assert_eq!(coeffs.len(), 1);
assert_relative_eq!(coeffs[0], 42.0, epsilon = 1e-10);
}
#[test]
fn test_dct2_4x4() {
let arr = Array2::from_shape_vec(
(4, 4),
vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
16.0,
],
)
.expect("Array creation failed");
let coeffs = dct2(&arr.view(), Some(DCTType::Type2), Some("ortho")).expect("2D DCT failed");
let recovered =
idct2(&coeffs.view(), Some(DCTType::Type2), Some("ortho")).expect("2D IDCT failed");
for i in 0..4 {
for j in 0..4 {
assert_relative_eq!(recovered[[i, j]], arr[[i, j]], epsilon = 1e-8);
}
}
}
#[test]
fn test_dct_type4_symmetry() {
let signal = vec![1.0, 2.0, 3.0, 4.0];
let coeffs = dct(&signal, Some(DCTType::Type4), Some("ortho")).expect("DCT-IV failed");
let recovered =
dct(&coeffs, Some(DCTType::Type4), Some("ortho")).expect("DCT-IV self-inverse failed");
for i in 0..signal.len() {
assert_relative_eq!(recovered[i], signal[i], epsilon = 1e-8);
}
}
}