use crate::error::{FFTError, FFTResult};
use crate::fft::{fft, ifft};
use scirs2_core::numeric::Complex64;
use std::f64::consts::PI;
/// Which of the four discrete cosine transform variants to compute.
///
/// Used by [`dct`] (forward) and [`idct`] (inverse); types II and III are
/// each other's inverse up to a scale factor, types I and IV are their own.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DCTType {
    /// DCT type I.
    DCT1,
    /// DCT type II.
    DCT2,
    /// DCT type III.
    DCT3,
    /// DCT type IV.
    DCT4,
}
/// Which of the four discrete sine transform variants to compute.
///
/// Used by [`dst`] to select the corresponding kernel.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DSTType {
    /// DST type I.
    DST1,
    /// DST type II.
    DST2,
    /// DST type III.
    DST3,
    /// DST type IV.
    DST4,
}
pub fn dct(x: &[f64], dct_type: DCTType) -> FFTResult<Vec<f64>> {
if x.is_empty() {
return Err(FFTError::ValueError("dct: input must not be empty".to_string()));
}
match dct_type {
DCTType::DCT1 => dct1(x),
DCTType::DCT2 => dct2(x),
DCTType::DCT3 => dct3(x),
DCTType::DCT4 => dct4(x),
}
}
/// Inverse of [`dct`] for the given transform type.
///
/// The unnormalized kernels in this module satisfy:
/// - DCT-I applied twice multiplies by `2 * (n - 1)`;
/// - DCT-II and DCT-III are inverses of each other up to a factor `2 * n`;
/// - DCT-IV applied twice multiplies by `2 * n`.
///
/// Each inverse is therefore the matching forward kernel followed by a
/// division by that factor. For a single-element input the DCT-I kernel is the
/// identity, so the input is returned unchanged; the general factor
/// `2 * (n - 1)` would be zero there and previously produced inf/NaN.
///
/// # Errors
/// Returns `FFTError::ValueError` when `x` is empty; errors from the forward
/// kernels are propagated.
pub fn idct(x: &[f64], dct_type: DCTType) -> FFTResult<Vec<f64>> {
    if x.is_empty() {
        return Err(FFTError::ValueError("idct: input must not be empty".to_string()));
    }
    let n = x.len();
    let (y, scale) = match dct_type {
        DCTType::DCT1 => {
            // dct1 returns its input unchanged for n == 1, so its inverse is the identity.
            if n == 1 {
                return Ok(x.to_vec());
            }
            (dct1(x)?, 2.0 * (n - 1) as f64)
        }
        DCTType::DCT2 => (dct3(x)?, 2.0 * n as f64),
        DCTType::DCT3 => (dct2(x)?, 2.0 * n as f64),
        DCTType::DCT4 => (dct4(x)?, 2.0 * n as f64),
    };
    Ok(y.into_iter().map(|v| v / scale).collect())
}
/// Unnormalized DCT-I:
/// `y[k] = x[0] + (-1)^k * x[N-1] + 2 * sum_{j=1}^{N-2} x[j] * cos(pi * k * j / (N - 1))`.
///
/// Computed as the real part of the length-`2 * (N - 1)` FFT of the even
/// extension `[x0, x1, ..., x_{N-1}, x_{N-2}, ..., x1]`, keeping the first `N`
/// bins. A single-element input is returned unchanged.
fn dct1(x: &[f64]) -> FFTResult<Vec<f64>> {
    let len = x.len();
    if len == 1 {
        return Ok(x.to_vec());
    }
    let period = 2 * (len - 1);
    // Append the interior samples (both endpoints excluded) in reverse order.
    let mirrored: Vec<Complex64> = x
        .iter()
        .chain(x[1..len - 1].iter().rev())
        .map(|&v| Complex64::new(v, 0.0))
        .collect();
    let spectrum = fft(&mirrored, Some(period))?;
    Ok(spectrum.iter().take(len).map(|c| c.re).collect())
}
/// Unnormalized DCT-II: `y[k] = 2 * sum_j x[j] * cos(pi * k * (2j + 1) / (2N))`.
///
/// Computed from the length-`2N` FFT of the symmetric extension
/// `[x, reverse(x)]`. Multiplying bin `k` by the post-twiddle
/// `exp(-i * pi * k / (2N))` turns it into a purely real value equal to the
/// DCT-II coefficient, so only the real part of the first `N` bins is kept.
fn dct2(x: &[f64]) -> FFTResult<Vec<f64>> {
    let n = x.len();
    let m = 2 * n;
    let ext: Vec<Complex64> = x
        .iter()
        .chain(x.iter().rev())
        .map(|&v| Complex64::new(v, 0.0))
        .collect();
    let fft_out = fft(&ext, Some(m))?;
    let denom = m as f64;
    Ok(fft_out
        .iter()
        .take(n)
        .enumerate()
        .map(|(k, &c)| {
            let angle = -PI * k as f64 / denom;
            (c * Complex64::new(angle.cos(), angle.sin())).re
        })
        .collect())
}
/// Unnormalized DCT-III:
/// `y[k] = x[0] + 2 * sum_{j>=1} x[j] * cos(pi * j * (2k + 1) / (2N))`.
///
/// Each input is pre-twiddled by `exp(i * pi * j / (2N))` and given a Hermitian
/// tail (`spec[2N - j] = conj(spec[j])`, `spec[N] = 0`), so the length-`2N`
/// inverse FFT is real-valued. The first `N` outputs are multiplied by `2N`
/// to undo the inverse FFT's normalization (assumes `ifft` applies the usual
/// `1 / len` factor, which the original `* 2N` rescale also relies on).
fn dct3(x: &[f64]) -> FFTResult<Vec<f64>> {
    let n = x.len();
    let m = 2 * n;
    let mut spec: Vec<Complex64> = x
        .iter()
        .enumerate()
        .map(|(k, &v)| {
            let angle = PI * k as f64 / m as f64;
            Complex64::new(v * angle.cos(), v * angle.sin())
        })
        .collect();
    spec.resize(m, Complex64::new(0.0, 0.0));
    for k in 1..n {
        spec[m - k] = spec[k].conj();
    }
    // `spec` already has length 2N; pass it directly instead of copying it.
    let y = ifft(&spec, Some(m))?;
    let gain = m as f64;
    Ok(y.iter().take(n).map(|c| c.re * gain).collect())
}
/// Unnormalized DCT-IV:
/// `y[k] = 2 * sum_j x[j] * cos(pi * (2j + 1) * (2k + 1) / (4N))`.
///
/// The phase splits as
/// `(2j+1)(2k+1)/(4N) = (2j+1)/(4N) + k*j/N + k/(2N)`.
/// The middle term `exp(-i*pi*k*j/N) = exp(-2*pi*i*k*j/(2N))` is a DFT kernel
/// of length `2N`, not `N`. The pre-twiddled input is therefore zero-padded to
/// `2N` before the FFT, then post-twiddled, and only the first `N` bins are kept.
/// (A length-`N` FFT yields the term `2*k*j/N` instead, which is wrong for `N >= 2`.)
fn dct4(x: &[f64]) -> FFTResult<Vec<f64>> {
    let n = x.len();
    let m = 2 * n;
    let four_n = 4.0 * n as f64;
    let mut z: Vec<Complex64> = x
        .iter()
        .enumerate()
        .map(|(j, &v)| {
            let angle = -PI * (2 * j + 1) as f64 / four_n;
            Complex64::new(v * angle.cos(), v * angle.sin())
        })
        .collect();
    // Explicit zero padding to the 2N-point transform length.
    z.resize(m, Complex64::new(0.0, 0.0));
    let z_fft = fft(&z, Some(m))?;
    let two_n = m as f64;
    Ok(z_fft
        .iter()
        .take(n)
        .enumerate()
        .map(|(k, &c)| {
            let angle = -PI * k as f64 / two_n;
            2.0 * (c * Complex64::new(angle.cos(), angle.sin())).re
        })
        .collect())
}
/// Modified discrete cosine transform of one frame of `2 * n` samples.
///
/// Applies the sine window `w[i] = sin(pi / (2n) * (i + 0.5))` and returns the
/// `n` coefficients
/// `X[k] = sqrt(2 / n) * sum_i w[i] * x[i] * cos(pi / n * (i + 0.5 + n / 2) * (k + 0.5))`.
///
/// For even `n`, the windowed frame is viewed as quarters `(a, b, c, d)` of length
/// `n / 2`. It is folded into `(-c_r - d, a - b_r)` (`_r` = reversed) and handed to
/// the DCT-IV kernel. For odd `n` the frame cannot be split into whole quarters, so
/// the definition is evaluated directly (O(n^2)).
///
/// # Errors
/// Returns `FFTError::ValueError` if `n == 0` or `x.len() != 2 * n`; errors from
/// the DCT-IV kernel are propagated.
pub fn mdct(x: &[f64], n: usize) -> FFTResult<Vec<f64>> {
    if n == 0 {
        return Err(FFTError::ValueError("mdct: n must be > 0".to_string()));
    }
    if x.len() != 2 * n {
        return Err(FFTError::ValueError(format!(
            "mdct: input length {} must equal 2*n = {}",
            x.len(),
            2 * n
        )));
    }
    let nf = n as f64;
    let windowed: Vec<f64> = x
        .iter()
        .enumerate()
        .map(|(i, &xi)| xi * (PI / (2.0 * nf) * (i as f64 + 0.5)).sin())
        .collect();

    if n % 2 == 1 {
        // Odd n: the quarter-block fold needs n / 2 to be exact, so evaluate
        // the MDCT definition directly with the same sqrt(2/n) scaling.
        let scale = (2.0 / nf).sqrt();
        return Ok((0..n)
            .map(|k| {
                let freq = k as f64 + 0.5;
                let sum: f64 = windowed
                    .iter()
                    .enumerate()
                    .map(|(i, &v)| v * (PI / nf * (i as f64 + 0.5 + nf / 2.0) * freq).cos())
                    .sum();
                scale * sum
            })
            .collect());
    }

    // Even n: fold quarters a = w[0..h], b = w[h..2h], c = w[2h..3h], d = w[3h..4h]
    // into (-c_r - d, a - b_r), which turns the MDCT into a DCT-IV of length n.
    let half = n / 2;
    let folded: Vec<f64> = (0..n)
        .map(|k| {
            if k < half {
                -windowed[3 * half - 1 - k] - windowed[3 * half + k]
            } else {
                windowed[k - half] - windowed[3 * half - 1 - k]
            }
        })
        .collect();
    let result = dct4(&folded)?;
    // dct4 returns 2 * sum(...), so multiplying by 1/sqrt(2n) yields sqrt(2/n) * sum(...).
    let scale = 1.0 / (2.0 * nf).sqrt();
    Ok(result.iter().map(|&v| v * scale).collect())
}
/// Inverse MDCT: maps `n` coefficients to a windowed frame of `2 * n` samples.
///
/// Computes
/// `y[i] = sqrt(2 / n) * w[i] * sum_k X[k] * cos(pi / n * (i + 0.5 + n / 2) * (k + 0.5))`
/// with the sine window `w[i] = sin(pi / (2n) * (i + 0.5))`, the same window `mdct` applies.
///
/// `mdct` scales its output by `sqrt(2 / n)`. Using `sqrt(2 / n)` here as well makes
/// the product of the forward and inverse factors `2 / n`. With the sine window
/// (`w[i]^2 + w[i + n]^2 = 1`) that is exactly what time-domain aliasing cancellation
/// needs: overlap-adding consecutive 50%-overlapped output frames reconstructs the
/// signal. Evaluated directly in O(n^2).
///
/// # Errors
/// Returns `FFTError::ValueError` if `n == 0` or `x.len() != n`.
pub fn imdct(x: &[f64], n: usize) -> FFTResult<Vec<f64>> {
    if n == 0 {
        return Err(FFTError::ValueError("imdct: n must be > 0".to_string()));
    }
    if x.len() != n {
        return Err(FFTError::ValueError(format!(
            "imdct: input length {} must equal n = {}",
            x.len(),
            n
        )));
    }
    let nf = n as f64;
    let scale = (2.0 / nf).sqrt();
    let out = (0..2 * n)
        .map(|i| {
            let phase = i as f64 + 0.5 + nf / 2.0;
            let sum: f64 = x
                .iter()
                .enumerate()
                .map(|(k, &xk)| xk * (PI / nf * phase * (k as f64 + 0.5)).cos())
                .sum();
            let w = (PI / (2.0 * nf) * (i as f64 + 0.5)).sin();
            scale * sum * w
        })
        .collect();
    Ok(out)
}
pub fn dst(x: &[f64], dst_type: DSTType) -> FFTResult<Vec<f64>> {
if x.is_empty() {
return Err(FFTError::ValueError("dst: input must not be empty".to_string()));
}
match dst_type {
DSTType::DST1 => dst1(x),
DSTType::DST2 => dst2(x),
DSTType::DST3 => dst3(x),
DSTType::DST4 => dst4(x),
}
}
/// Unnormalized DST-I:
/// `y[k] = 2 * sum_j x[j] * sin(pi * (k + 1) * (j + 1) / (N + 1))`.
///
/// Built from the length-`2 * (N + 1)` FFT of the odd extension
/// `[0, x, 0, -reverse(x)]`. Bins `1..=N` are `-2i` times the DST-I sums, so
/// the negated imaginary parts of those bins are returned.
fn dst1(x: &[f64]) -> FFTResult<Vec<f64>> {
    let len = x.len();
    let period = 2 * (len + 1);
    let mut odd_ext: Vec<f64> = Vec::with_capacity(period);
    odd_ext.push(0.0);
    odd_ext.extend_from_slice(x);
    odd_ext.push(0.0);
    odd_ext.extend(x.iter().rev().map(|&v| -v));
    let as_complex: Vec<Complex64> = odd_ext.iter().map(|&v| Complex64::new(v, 0.0)).collect();
    let spectrum = fft(&as_complex, Some(period))?;
    Ok(spectrum.iter().skip(1).take(len).map(|c| -c.im).collect())
}
/// Unnormalized DST-II:
/// `y[k] = 2 * sum_j x[j] * sin(pi * (k + 1) * (2j + 1) / (2N))`.
///
/// Substituting `k + 1 = N - m` gives
/// `sin(pi * (N - m) * (2j + 1) / (2N)) = (-1)^j * cos(pi * m * (2j + 1) / (2N))`.
/// Hence `DST-II(x)[k] = DCT-II(s)[N - 1 - k]` with `s[j] = (-1)^j * x[j]`:
/// alternate the input signs, apply DCT-II, then reverse the output.
fn dst2(x: &[f64]) -> FFTResult<Vec<f64>> {
    let alternated: Vec<f64> = x
        .iter()
        .enumerate()
        .map(|(j, &v)| if j % 2 == 0 { v } else { -v })
        .collect();
    let mut y = dct2(&alternated)?;
    y.reverse();
    Ok(y)
}
/// Unnormalized DST-III:
/// `y[k] = (-1)^k * x[N-1] + 2 * sum_{j=0}^{N-2} x[j] * sin(pi * (2k + 1) * (j + 1) / (2N))`.
///
/// DST-III is the transpose of DST-II (with the last column halved), and
/// DST-II = reverse ∘ DCT-II ∘ alternate. Therefore
/// `DST-III(x)[k] = (-1)^k * DCT-III(reverse(x))[k]`: reverse the input,
/// apply DCT-III, then alternate the output signs.
fn dst3(x: &[f64]) -> FFTResult<Vec<f64>> {
    let reversed: Vec<f64> = x.iter().rev().copied().collect();
    let y = dct3(&reversed)?;
    Ok(y
        .into_iter()
        .enumerate()
        .map(|(k, v)| if k % 2 == 0 { v } else { -v })
        .collect())
}
/// Unnormalized DST-IV:
/// `y[k] = 2 * sum_j x[j] * sin(pi * (2j + 1) * (2k + 1) / (4N))`.
///
/// Reversing the input maps `2j + 1` to `2N - (2j + 1)`. That turns the DCT-IV
/// cosine into `(-1)^k` times the DST-IV sine, so
/// `DST-IV(x)[k] = (-1)^k * DCT-IV(reverse(x))[k]`.
fn dst4(x: &[f64]) -> FFTResult<Vec<f64>> {
    let reversed: Vec<f64> = x.iter().rev().copied().collect();
    let y = dct4(&reversed)?;
    Ok(y
        .into_iter()
        .enumerate()
        .map(|(k, v)| if k % 2 == 0 { v } else { -v })
        .collect())
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_dct2_idct2_roundtrip() {
        let x = vec![1.0, 2.0, 3.0, 4.0];
        let y = dct(&x, DCTType::DCT2).expect("failed to create y");
        let z = idct(&y, DCTType::DCT2).expect("failed to create z");
        for (a, b) in x.iter().zip(z.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_dct1_known() {
        // y[k] = x0 + (-1)^k x2 + 2 x1 cos(pi k / 2) for a constant length-3 input
        // gives [4, 0, 0].
        let x = vec![1.0, 1.0, 1.0];
        let y = dct(&x, DCTType::DCT1).expect("failed to create y");
        assert_eq!(y.len(), 3);
        assert_relative_eq!(y[0], 4.0, epsilon = 1e-10);
        assert_relative_eq!(y[1], 0.0, epsilon = 1e-10);
        assert_relative_eq!(y[2], 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_dct1_roundtrip() {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = dct(&x, DCTType::DCT1).expect("failed to create y");
        let z = idct(&y, DCTType::DCT1).expect("failed to create z");
        for (a, b) in x.iter().zip(z.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-9);
        }
    }

    #[test]
    fn test_dct3_roundtrip_via_dct2() {
        let x = vec![3.0, -1.0, 2.0, 0.5];
        let y = dct(&x, DCTType::DCT2).expect("failed to create y");
        let z = idct(&y, DCTType::DCT2).expect("failed to create z");
        for (a, b) in x.iter().zip(z.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-9);
        }
    }

    #[test]
    fn test_dct4_involution() {
        let x = vec![1.0, -2.0, 3.0, -4.0];
        let y = dct(&x, DCTType::DCT4).expect("failed to create y");
        let z = idct(&y, DCTType::DCT4).expect("failed to create z");
        for (a, b) in x.iter().zip(z.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-9);
        }
    }

    #[test]
    fn test_dct2_dc_term() {
        let n = 4;
        let val = 3.0_f64;
        let x = vec![val; n];
        let y = dct(&x, DCTType::DCT2).expect("failed to create y");
        assert_relative_eq!(y[0], 2.0 * n as f64 * val, epsilon = 1e-10);
        for &yk in &y[1..] {
            assert_relative_eq!(yk, 0.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_dst1_roundtrip() {
        let x = vec![1.0, 2.0, 3.0, 4.0];
        let y = dst(&x, DSTType::DST1).expect("failed to create y");
        assert_eq!(y.len(), 4);
        let z = dst(&y, DSTType::DST1).expect("failed to create z");
        // DST-I applied twice multiplies by 2 * (n + 1).
        let scale = 2.0 * (x.len() + 1) as f64;
        for (a, b) in x.iter().zip(z.iter()) {
            assert_relative_eq!(*a, b / scale, epsilon = 1e-9);
        }
    }

    #[test]
    fn test_dst2_length() {
        let x = vec![1.0, -1.0, 2.0, 0.5];
        let y = dst(&x, DSTType::DST2).expect("failed to create y");
        assert_eq!(y.len(), 4);
    }

    #[test]
    fn test_dst4_length() {
        let x = vec![1.0, 2.0, 3.0, 4.0];
        let y = dst(&x, DSTType::DST4).expect("failed to create y");
        assert_eq!(y.len(), 4);
    }

    #[test]
    fn test_mdct_output_length() {
        let n = 8;
        let x: Vec<f64> = (0..2 * n).map(|i| (i as f64 * 0.3).sin()).collect();
        let y = mdct(&x, n).expect("failed to create y");
        assert_eq!(y.len(), n);
    }

    #[test]
    fn test_imdct_output_length() {
        let n = 8;
        let x: Vec<f64> = (0..n).map(|i| (i as f64 * 0.4).cos()).collect();
        let y = imdct(&x, n).expect("failed to create y");
        assert_eq!(y.len(), 2 * n);
    }

    #[test]
    fn test_dct_empty_error() {
        assert!(dct(&[], DCTType::DCT2).is_err());
    }

    #[test]
    fn test_idct_empty_error() {
        assert!(idct(&[], DCTType::DCT2).is_err());
    }

    #[test]
    fn test_dst_empty_error() {
        assert!(dst(&[], DSTType::DST2).is_err());
    }

    #[test]
    fn test_mdct_wrong_length() {
        assert!(mdct(&[1.0, 2.0, 3.0], 4).is_err());
    }

    #[test]
    fn test_mdct_zero_n_error() {
        assert!(mdct(&[], 0).is_err());
    }

    #[test]
    fn test_imdct_bad_input() {
        assert!(imdct(&[], 0).is_err());
        assert!(imdct(&[1.0, 2.0], 3).is_err());
    }
}