use scirs2_core::ndarray::{Array1, ArrayBase, Data, Dimension};
use scirs2_core::numeric::Complex64;
use std::f64::consts::PI;
use crate::error::{FFTError, FFTResult};
use crate::fft::fft;
/// Forward "DCT-V" as implemented by this module.
///
/// The input is flattened in `x.iter()` order. It is then mirrored
/// antisymmetrically into a length-`2n` complex buffer
/// (`[x0 .. x(n-1), -x(n-1) .. -x0]`) and passed through [`fft`]. Each of the
/// first `n` bins is rotated by `π(2k+1)/(4n)`, and the real part of the
/// rotated bin is scaled by `sqrt(2/(2n))`.
///
/// NOTE(review): because the buffer is antisymmetric its entries sum to zero,
/// so FFT bin 0, and therefore output 0, is always zero. This transform is
/// therefore not invertible and does not match the textbook DCT-V; confirm
/// intent.
///
/// # Errors
///
/// Returns [`FFTError::ValueError`] when `x` has no elements, and propagates
/// any error from [`fft`].
#[allow(dead_code)]
pub fn dct_v<S, D>(x: &ArrayBase<S, D>) -> FFTResult<Array1<f64>>
where
    S: Data<Elem = f64>,
    D: Dimension,
{
    let samples: Vec<f64> = x.iter().copied().collect();
    let len = samples.len();
    if len == 0 {
        return Err(FFTError::ValueError("empty array".to_string()));
    }

    // Antisymmetric extension: position `last - idx` mirrors `idx` with a sign flip.
    let last = 2 * len - 1;
    let mut buffer = vec![Complex64::new(0.0, 0.0); 2 * len];
    for (idx, &value) in samples.iter().enumerate() {
        buffer[idx] = Complex64::new(value, 0.0);
        buffer[last - idx] = Complex64::new(-value, 0.0);
    }

    let spectrum = fft(&buffer, None)?;
    let norm = (2.0 / (2.0 * len as f64)).sqrt();
    let coeffs: Vec<f64> = (0..len)
        .map(|k| {
            let theta = PI * (2 * k + 1) as f64 / (4.0 * len as f64);
            let bin = spectrum[k];
            // Real part of `bin * e^{i·theta}`.
            norm * (bin.re * theta.cos() - bin.im * theta.sin())
        })
        .collect();
    Ok(Array1::from_vec(coeffs))
}
/// Intended inverse of [`dct_v`].
///
/// Each coefficient `x[k]` is rotated by `π(2k+1)/(4n)` and scaled by
/// `sqrt(2/n)`. The conjugated values are written into a length-`2n` buffer
/// at positions `k` and `2n-1-k`, the buffer is transformed with [`fft`], and
/// the real parts of the first `n` bins are divided by `2n`.
///
/// NOTE(review): [`dct_v`] always produces 0 in its first output, so no
/// function can recover its input exactly. This routine is only an
/// approximation; the module's own round-trip test tolerates errors up to 10.0.
///
/// # Errors
///
/// Returns [`FFTError::ValueError`] when `x` is empty, and propagates any
/// error from [`fft`].
#[allow(dead_code)]
pub fn idct_v<S>(x: &ArrayBase<S, scirs2_core::ndarray::Ix1>) -> FFTResult<Array1<f64>>
where
    S: Data<Elem = f64>,
{
    let n = x.len();
    if n == 0 {
        return Err(FFTError::ValueError("empty array".to_string()));
    }
    let scale_factor = (2.0_f64 / n as f64).sqrt();

    // Build the already-conjugated buffer in one pass. This avoids building
    // it, cloning it and conjugating the clone; the values are bit-identical
    // because conjugation only flips the sign of the imaginary part.
    // Positions `k` and `2n-1-k` together cover every slot of the buffer.
    let mut fft_input = vec![Complex64::new(0.0, 0.0); 2 * n];
    for k in 0..n {
        let phase = PI * (2 * k + 1) as f64 / (4.0 * n as f64);
        let re = x[k] * phase.cos() * scale_factor;
        let im = x[k] * phase.sin() * scale_factor;
        fft_input[k] = Complex64::new(re, -im);
        fft_input[2 * n - 1 - k] = Complex64::new(-re, -im);
    }

    let ifft_result = fft(&fft_input, None)?;
    let final_scale = 1.0 / (2.0 * n as f64);
    let mut result = Array1::zeros(n);
    for i in 0..n {
        result[i] = ifft_result[i].re * final_scale;
    }
    Ok(result)
}
/// Forward "DCT-VI" as implemented by this module.
///
/// The flattened input (in `x.iter()` order) is laid out in a length-`4n`
/// real buffer as `[x, reverse(x), -x, -reverse(x)]`. The buffer is
/// transformed with [`fft`], and output `k` is half the real part of bin `k`
/// for `k < n`.
///
/// NOTE(review): summing the four mirrored copies, the real part of every
/// even-indexed bin cancels exactly. Every even-indexed output (including
/// output 0) is therefore zero, and the transform is not invertible; confirm
/// intent.
///
/// # Errors
///
/// Returns [`FFTError::ValueError`] when `x` has no elements, and propagates
/// any error from [`fft`].
#[allow(dead_code)]
pub fn dct_vi<S, D>(x: &ArrayBase<S, D>) -> FFTResult<Array1<f64>>
where
    S: Data<Elem = f64>,
    D: Dimension,
{
    let values: Vec<f64> = x.iter().copied().collect();
    let count = values.len();
    if count == 0 {
        return Err(FFTError::ValueError("empty array".to_string()));
    }

    let total = 4 * count;
    let mut buffer = vec![Complex64::new(0.0, 0.0); total];
    for (j, &v) in values.iter().enumerate() {
        let pos = Complex64::new(v, 0.0);
        let neg = Complex64::new(-v, 0.0);
        buffer[j] = pos;
        buffer[2 * count - 1 - j] = pos;
        buffer[2 * count + j] = neg;
        buffer[total - 1 - j] = neg;
    }

    let spectrum = fft(&buffer, None)?;
    let halves: Vec<f64> = (0..count).map(|k| 0.5 * spectrum[k].re).collect();
    Ok(Array1::from_vec(halves))
}
/// Intended inverse of [`dct_vi`]: applies [`dct_vi`] to `x` again and divides
/// every output by `x.len()`.
///
/// NOTE(review): [`dct_vi`] discards information (its even-indexed outputs
/// are always zero), so this cannot reconstruct the original signal.
///
/// # Errors
///
/// Propagates any error from [`dct_vi`], including the empty-input error.
#[allow(dead_code)]
pub fn idct_vi<S>(x: &ArrayBase<S, scirs2_core::ndarray::Ix1>) -> FFTResult<Array1<f64>>
where
    S: Data<Elem = f64>,
{
    let forward = dct_vi(x)?;
    let inv_len = 1.0 / (x.len() as f64);
    Ok(forward.mapv(|v| v * inv_len))
}
/// Forward "DCT-VII" as implemented by this module (direct O(n²) sum):
///
/// `X[k] = sqrt(2/n) · Σ_j x[j] · cos(π·k·(j+½)/n)`
///
/// Multi-dimensional input is flattened in `x.iter()` order first.
///
/// NOTE(review): this is the kernel usually associated with the DCT-II rather
/// than the textbook DCT-VII; confirm the naming is intended.
///
/// # Errors
///
/// Returns [`FFTError::ValueError`] when `x` has no elements.
#[allow(dead_code)]
pub fn dct_vii<S, D>(x: &ArrayBase<S, D>) -> FFTResult<Array1<f64>>
where
    S: Data<Elem = f64>,
    D: Dimension,
{
    let signal: Vec<f64> = x.iter().copied().collect();
    let len = signal.len();
    if len == 0 {
        return Err(FFTError::ValueError("empty array".to_string()));
    }

    let norm = (2.0 / len as f64).sqrt();
    let coeffs: Vec<f64> = (0..len)
        .map(|k| {
            let acc = signal.iter().enumerate().fold(0.0, |acc, (j, &v)| {
                let theta = PI * k as f64 * (j as f64 + 0.5) / len as f64;
                acc + v * theta.cos()
            });
            norm * acc
        })
        .collect();
    Ok(Array1::from_vec(coeffs))
}
#[allow(dead_code)]
pub fn idct_vii<S>(x: &ArrayBase<S, scirs2_core::ndarray::Ix1>) -> FFTResult<Array1<f64>>
where
S: Data<Elem = f64>,
{
let n = x.len();
let mut result = Array1::zeros(n);
let scale = (2.0_f64 / n as f64).sqrt();
for i in 0..n {
let mut sum = 0.0;
for k in 0..n {
let angle = PI * k as f64 * (i as f64 + 0.5) / n as f64;
sum += x[k] * angle.cos();
}
result[i] = scale * sum;
}
Ok(result)
}
/// Forward "DCT-VIII" as implemented by this module (direct O(n²) sum):
///
/// `X[k] = (2/n) · Σ_j x[j] · cos(π·(k+½)·(j+½)/n)`, with `X[0]` additionally
/// multiplied by `sqrt(1/2)`.
///
/// Multi-dimensional input is flattened in `x.iter()` order first.
/// NOTE(review): the kernel is the one usually used for the DCT-IV; confirm
/// the naming is intended.
///
/// # Errors
///
/// Returns [`FFTError::ValueError`] when `x` has no elements.
#[allow(dead_code)]
pub fn dct_viii<S, D>(x: &ArrayBase<S, D>) -> FFTResult<Array1<f64>>
where
    S: Data<Elem = f64>,
    D: Dimension,
{
    let signal: Vec<f64> = x.iter().copied().collect();
    let len = signal.len();
    if len == 0 {
        return Err(FFTError::ValueError("empty array".to_string()));
    }

    let norm = 2.0 / len as f64;
    let mut out = Array1::zeros(len);
    for (k, slot) in out.iter_mut().enumerate() {
        let acc = signal.iter().enumerate().fold(0.0, |acc, (j, &v)| {
            let theta = PI * (k as f64 + 0.5) * (j as f64 + 0.5) / len as f64;
            acc + v * theta.cos()
        });
        let mut value = norm * acc;
        if k == 0 {
            value *= 0.5_f64.sqrt();
        }
        *slot = value;
    }
    Ok(out)
}
#[allow(dead_code)]
pub fn idct_viii<S>(x: &ArrayBase<S, scirs2_core::ndarray::Ix1>) -> FFTResult<Array1<f64>>
where
S: Data<Elem = f64>,
{
dct_viii(x)
}
/// Forward "DST-V" as implemented by this module.
///
/// The flattened input (in `x.iter()` order) is placed as purely imaginary
/// values into a length-`2n` buffer, mirrored symmetrically
/// (`[i·x, i·reverse(x)]`), and transformed with [`fft`]. Each of the first `n`
/// bins is rotated by `π(2k+1)/(4n)`, and the imaginary part of the rotated
/// bin is scaled by `sqrt(2/(2n))`.
///
/// NOTE(review): if [`fft`] is the standard unnormalised forward DFT, this
/// simplifies to `X[k] = 2·sqrt(1/n)·cos(π(4k+1)/(4n)) · Σ_j x[j]·cos(π·k·(j+½)/n)`,
/// i.e. a gain-weighted DCT-II sum rather than a textbook DST-V; confirm
/// intent.
///
/// # Errors
///
/// Returns [`FFTError::ValueError`] when `x` has no elements, and propagates
/// any error from [`fft`].
#[allow(dead_code)]
pub fn dst_v<S, D>(x: &ArrayBase<S, D>) -> FFTResult<Array1<f64>>
where
    S: Data<Elem = f64>,
    D: Dimension,
{
    let samples: Vec<f64> = x.iter().copied().collect();
    let len = samples.len();
    if len == 0 {
        return Err(FFTError::ValueError("empty array".to_string()));
    }

    let mut buffer = vec![Complex64::new(0.0, 0.0); 2 * len];
    for (idx, &value) in samples.iter().enumerate() {
        let imag = Complex64::new(0.0, value);
        buffer[idx] = imag;
        buffer[2 * len - 1 - idx] = imag;
    }

    let spectrum = fft(&buffer, None)?;
    let norm = (2.0 / (2.0 * len as f64)).sqrt();
    let mut out = Array1::zeros(len);
    for (k, slot) in out.iter_mut().enumerate() {
        let theta = PI * (2 * k + 1) as f64 / (4.0 * len as f64);
        let bin = spectrum[k];
        // Imaginary part of `bin * e^{i·theta}`.
        *slot = norm * (bin.im * theta.cos() + bin.re * theta.sin());
    }
    Ok(out)
}
/// Inverse of [`dst_v`].
///
/// Working [`dst_v`]'s symmetric, purely imaginary extension through the FFT
/// gives a gain-weighted DCT-II sum:
///
/// `X[k] = g[k] · C[k]`, where `C[k] = Σ_i x[i]·cos(π·k·(i+½)/n)` and
/// `g[k] = 2·sqrt(1/n)·cos(π·(4k+1)/(4n))`.
///
/// This function divides out `g[k]` and then applies the DCT-III style inverse
/// `x[i] = (2/n)·(C[0]/2 + Σ_{k≥1} C[k]·cos(π·k·(i+½)/n))`.
///
/// `g[k]` is never exactly zero for integer `k`, because `4k+1` is odd and
/// can never equal `2n·(2m+1)`. It can be small for large `n`, which makes
/// the inversion less well conditioned.
///
/// NOTE(review): the gain formula assumes `crate::fft::fft` is the standard
/// unnormalised forward DFT (kernel `e^{-2πi·jk/N}`). If that convention
/// differs, `g[k]` must be re-derived.
///
/// # Errors
///
/// Returns [`FFTError::ValueError`] when `x` is empty.
#[allow(dead_code)]
pub fn idst_v<S>(x: &ArrayBase<S, scirs2_core::ndarray::Ix1>) -> FFTResult<Array1<f64>>
where
    S: Data<Elem = f64>,
{
    let n = x.len();
    if n == 0 {
        return Err(FFTError::ValueError("empty array".to_string()));
    }
    let nf = n as f64;
    // Same normalisation constant that `dst_v` applies.
    let scale = (2.0 / (2.0 * nf)).sqrt();

    // Recover the plain DCT-II sums C[k], pre-halving the DC term as the
    // DCT-III inverse requires.
    let dct2: Vec<f64> = x
        .iter()
        .enumerate()
        .map(|(k, &coeff)| {
            let gain = 2.0 * scale * (PI * (4 * k + 1) as f64 / (4.0 * nf)).cos();
            let weight = if k == 0 { 0.5 } else { 1.0 };
            weight * coeff / gain
        })
        .collect();

    let mut result = Array1::zeros(n);
    for i in 0..n {
        let mut sum = 0.0;
        for (k, &c) in dct2.iter().enumerate() {
            sum += c * (PI * k as f64 * (i as f64 + 0.5) / nf).cos();
        }
        result[i] = 2.0 * sum / nf;
    }
    Ok(result)
}
/// Forward "DST-VI" as implemented by this module (direct O(n²) sum):
///
/// `X[k] = sqrt(2/n) · Σ_j x[j] · sin(π·(k+½)·(j+1)/n)`
///
/// Multi-dimensional input is flattened in `x.iter()` order first.
/// NOTE(review): the kernel is the one usually used for the DST-III; confirm
/// the naming is intended.
///
/// # Errors
///
/// Returns [`FFTError::ValueError`] when `x` has no elements.
#[allow(dead_code)]
pub fn dst_vi<S, D>(x: &ArrayBase<S, D>) -> FFTResult<Array1<f64>>
where
    S: Data<Elem = f64>,
    D: Dimension,
{
    let signal: Vec<f64> = x.iter().copied().collect();
    let len = signal.len();
    if len == 0 {
        return Err(FFTError::ValueError("empty array".to_string()));
    }

    let norm = (2.0 / len as f64).sqrt();
    let coefficient = |k: usize| -> f64 {
        let acc = signal.iter().enumerate().fold(0.0, |acc, (j, &v)| {
            let theta = PI * (k as f64 + 0.5) * (j as f64 + 1.0) / len as f64;
            acc + v * theta.sin()
        });
        norm * acc
    };
    Ok(Array1::from_shape_fn(len, coefficient))
}
#[allow(dead_code)]
pub fn idst_vi<S>(x: &ArrayBase<S, scirs2_core::ndarray::Ix1>) -> FFTResult<Array1<f64>>
where
S: Data<Elem = f64>,
{
let result = dst_vi(x)?;
Ok(result.mapv(|v| v * (x.len() as f64).recip()))
}
/// Forward "DST-VII" as implemented by this module (direct O(n²) sum):
///
/// `X[k] = sqrt(2/n) · Σ_j x[j] · sin(π·(k+1)·(j+½)/n)`
///
/// Multi-dimensional input is flattened in `x.iter()` order first.
/// NOTE(review): the kernel is the one usually used for the DST-II; confirm
/// the naming is intended.
///
/// # Errors
///
/// Returns [`FFTError::ValueError`] when `x` has no elements.
#[allow(dead_code)]
pub fn dst_vii<S, D>(x: &ArrayBase<S, D>) -> FFTResult<Array1<f64>>
where
    S: Data<Elem = f64>,
    D: Dimension,
{
    let signal: Vec<f64> = x.iter().copied().collect();
    let len = signal.len();
    if len == 0 {
        return Err(FFTError::ValueError("empty array".to_string()));
    }

    let norm = (2.0 / len as f64).sqrt();
    let mut out = Array1::zeros(len);
    for (k, slot) in out.iter_mut().enumerate() {
        let mut acc = 0.0;
        for (j, &v) in signal.iter().enumerate() {
            let theta = PI * (k as f64 + 1.0) * (j as f64 + 0.5) / len as f64;
            acc += v * theta.sin();
        }
        *slot = norm * acc;
    }
    Ok(out)
}
#[allow(dead_code)]
pub fn idst_vii<S>(x: &ArrayBase<S, scirs2_core::ndarray::Ix1>) -> FFTResult<Array1<f64>>
where
S: Data<Elem = f64>,
{
dst_vii(x)
}
/// Forward "DST-VIII" as implemented by this module (direct O(n²) sum):
///
/// `X[k] = (2/n) · Σ_j x[j] · sin(π·(k+½)·(j+½)/n)`
///
/// Multi-dimensional input is flattened in `x.iter()` order first.
/// NOTE(review): the kernel is the one usually used for the DST-IV; confirm
/// the naming is intended.
///
/// # Errors
///
/// Returns [`FFTError::ValueError`] when `x` has no elements.
#[allow(dead_code)]
pub fn dst_viii<S, D>(x: &ArrayBase<S, D>) -> FFTResult<Array1<f64>>
where
    S: Data<Elem = f64>,
    D: Dimension,
{
    let signal: Vec<f64> = x.iter().copied().collect();
    let len = signal.len();
    if len == 0 {
        return Err(FFTError::ValueError("empty array".to_string()));
    }

    let norm = 2.0 / len as f64;
    let coeffs: Vec<f64> = (0..len)
        .map(|k| {
            let acc = signal.iter().enumerate().fold(0.0, |acc, (j, &v)| {
                let theta = PI * (k as f64 + 0.5) * (j as f64 + 0.5) / len as f64;
                acc + v * theta.sin()
            });
            norm * acc
        })
        .collect();
    Ok(Array1::from_vec(coeffs))
}
/// Exact inverse of [`dst_viii`].
///
/// [`dst_viii`] computes `(2/n)·S·x`, where `S[k][j] = sin(π(k+½)(j+½)/n)`
/// is symmetric and satisfies `S·S = (n/2)·I`. Applying [`dst_viii`] twice
/// therefore yields `(2/n)·x`, so the second application is rescaled by `n/2`
/// to recover `x`. Without the rescale, the round trip was only correct for
/// `n = 2`.
///
/// # Errors
///
/// Propagates any error from [`dst_viii`], including the empty-input error.
#[allow(dead_code)]
pub fn idst_viii<S>(x: &ArrayBase<S, scirs2_core::ndarray::Ix1>) -> FFTResult<Array1<f64>>
where
    S: Data<Elem = f64>,
{
    let half_len = x.len() as f64 / 2.0;
    let twice = dst_viii(x)?;
    Ok(twice.mapv(|v| v * half_len))
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Round-trips DCT-V; only gross reconstruction errors (> 10.0) fail.
    #[test]
    fn test_dct_v() {
        let input = array![1.0, 2.0, 3.0, 4.0];
        let coeffs = dct_v(&input).expect("Operation failed");
        let restored = idct_v(&coeffs).expect("Operation failed");
        for i in 0..input.len() {
            let (want, got) = (input[i], restored[i]);
            if (want - got).abs() > 10.0 {
                panic!(
                    "DCT-V inverse severely wrong at index {}: expected {}, got {}",
                    i, want, got
                );
            }
        }
    }

    /// Round-trips DST-V; only gross reconstruction errors (> 6.0) fail.
    #[test]
    fn test_dst_v() {
        let input = array![1.0, 2.0, 3.0, 4.0];
        let coeffs = dst_v(&input).expect("Operation failed");
        let restored = idst_v(&coeffs).expect("Operation failed");
        for i in 0..input.len() {
            let (want, got) = (input[i], restored[i]);
            if (want - got).abs() > 6.0 {
                panic!(
                    "DST-V inverse severely wrong at index {}: expected {}, got {}",
                    i, want, got
                );
            }
        }
    }

    /// Smoke test: every forward transform accepts a 5-element input.
    #[test]
    fn test_higher_order_types() {
        let input = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let transforms: [fn(&Array1<f64>) -> FFTResult<Array1<f64>>; 8] = [
            dct_v, dct_vi, dct_vii, dct_viii, dst_v, dst_vi, dst_vii, dst_viii,
        ];
        for transform in transforms.iter() {
            let _ = transform(&input).expect("Operation failed");
        }
    }
}