use crate::error::{FFTError, FFTResult};
use crate::fft::{fft, ifft};
use scirs2_core::numeric::{Complex64, NumCast, Zero};
use std::f64::consts::PI;
use std::fmt::Debug;
/// Converts every element of `x` to `f64`.
///
/// Stops at the first element that `NumCast` cannot represent as `f64` and
/// reports it as `FFTError::ValueError`.
fn cast_f64<T: NumCast + Copy + Debug>(x: &[T]) -> FFTResult<Vec<f64>> {
    let mut converted = Vec::with_capacity(x.len());
    for &v in x {
        let as_float: Option<f64> = NumCast::from(v);
        match as_float {
            Some(value) => converted.push(value),
            None => return Err(FFTError::ValueError(format!("Cannot cast {v:?} to f64"))),
        }
    }
    Ok(converted)
}
/// Type-II discrete cosine transform of `x`.
///
/// Without normalization the result is `y[k] = 2 Σ_n x[n] cos(πk(2n+1)/(2N))`;
/// `norm = Some("ortho")` applies orthonormal scaling instead.
///
/// # Errors
/// `FFTError::ValueError` when `x` is empty or an element cannot be converted
/// to `f64`; errors from the underlying FFT are propagated.
pub fn dct2<T>(x: &[T], norm: Option<&str>) -> FFTResult<Vec<f64>>
where
    T: NumCast + Copy + Debug,
{
    if x.is_empty() {
        Err(FFTError::ValueError("dct2: input is empty".into()))
    } else {
        dct2_f64(&cast_f64(x)?, norm)
    }
}
fn dct2_f64(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
let n = x.len();
let ext_len = 2 * n;
let mut extended = vec![Complex64::zero(); ext_len];
for k in 0..n {
extended[k] = Complex64::new(x[k], 0.0);
extended[ext_len - 1 - k] = Complex64::new(x[k], 0.0);
}
let y_fft = fft(&extended, None)?;
let mut result = Vec::with_capacity(n);
for k in 0..n {
let phase = -PI * k as f64 / ext_len as f64;
let twiddle = Complex64::new(phase.cos(), phase.sin());
let val = y_fft[k] * twiddle;
result.push(val.re);
}
if norm == Some("ortho") {
let scale0 = 1.0 / (4.0 * n as f64).sqrt();
let scale_k = 1.0 / (2.0 * n as f64).sqrt();
result[0] *= scale0;
for v in result.iter_mut().skip(1) {
*v *= scale_k;
}
}
Ok(result)
}
/// Inverse of [`dct2`] under the same `norm` setting.
///
/// # Errors
/// `FFTError::ValueError` when `x` is empty or an element cannot be converted
/// to `f64`; errors from the underlying FFT are propagated.
pub fn idct2<T>(x: &[T], norm: Option<&str>) -> FFTResult<Vec<f64>>
where
    T: NumCast + Copy + Debug,
{
    match x {
        [] => Err(FFTError::ValueError("idct2: input is empty".into())),
        values => {
            let samples = cast_f64(values)?;
            idct2_f64(&samples, norm)
        }
    }
}
/// Inverse of `dct2_f64`.
///
/// The orthonormal DCT-III is the exact inverse of the orthonormal DCT-II.
/// The unnormalized pair satisfies `dct3(dct2(x)) = 2N·x`, so in that mode the
/// DCT-III output is divided by `2N` to give back `x`.
fn idct2_f64(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
    let mut out = dct3_f64(x, norm)?;
    if norm != Some("ortho") {
        let scale = 1.0 / (2.0 * x.len() as f64);
        for v in &mut out {
            *v *= scale;
        }
    }
    Ok(out)
}
/// Type-III discrete cosine transform of `x`.
///
/// `norm = Some("ortho")` requests orthonormal scaling; any other value gives
/// the unnormalized transform.
///
/// # Errors
/// `FFTError::ValueError` when `x` is empty or an element cannot be converted
/// to `f64`; errors from the underlying FFT are propagated.
pub fn dct3<T>(x: &[T], norm: Option<&str>) -> FFTResult<Vec<f64>>
where
    T: NumCast + Copy + Debug,
{
    let has_samples = !x.is_empty();
    if !has_samples {
        return Err(FFTError::ValueError("dct3: input is empty".into()));
    }
    let samples = cast_f64(x)?;
    dct3_f64(&samples, norm)
}
/// Type-III DCT kernel on `f64` samples.
///
/// Unnormalized output: `y[m] = x[0] + 2 Σ_{k=1}^{N-1} x[k] cos(πk(2m+1)/(2N))`,
/// so `dct3(dct2(x)) = 2N·x`.
///
/// With `norm = Some("ortho")` the input is pre-scaled so the result is the
/// orthonormal DCT-III, `x0/√N + √(2/N) Σ_{k≥1} x_k cos(..)`, which exactly
/// inverts the orthonormal `dct2_f64`.
fn dct3_f64(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
    let n = x.len();
    if norm == Some("ortho") {
        // The unnormalized kernel weights x0 by 1 and the rest by 2.
        // Dividing by √N and √(2N) respectively yields the orthonormal weights.
        let scale0 = 1.0 / (n as f64).sqrt();
        let scale_k = 1.0 / (2.0 * n as f64).sqrt();
        let scaled: Vec<f64> = x
            .iter()
            .enumerate()
            .map(|(k, &v)| v * if k == 0 { scale0 } else { scale_k })
            .collect();
        return dct3_core(&scaled, n);
    }
    dct3_core(x, n)
}

/// Unnormalized type-III DCT of the first `n` samples of `x`.
///
/// Places `w_k·x_k·e^{iπk/(2N)}` (with `w_0 = 1`, `w_k = 2` otherwise) into a
/// zero-padded buffer of length 2N. Then `2N · ifft(buf)[m]` equals
/// `Σ_k w_k x_k e^{iπk(2m+1)/(2N)}`, whose real part is the DCT-III.
///
/// Using a 2N-point transform avoids the output permutation that an N-point
/// transform would require.
fn dct3_core(x: &[f64], n: usize) -> FFTResult<Vec<f64>> {
    let two_n = 2 * n;
    let mut buf = vec![Complex64::zero(); two_n];
    for (k, (slot, &xk)) in buf.iter_mut().zip(x.iter().take(n)).enumerate() {
        let weight = if k == 0 { 1.0 } else { 2.0 };
        let phase = PI * k as f64 / two_n as f64;
        *slot = Complex64::new(phase.cos(), phase.sin()) * (weight * xk);
    }
    let inv = ifft(&buf, None)?;
    // `ifft` is assumed to carry the 1/len normalization (the original kernel
    // relied on the same convention); undo it to recover the plain sum.
    let scale = two_n as f64;
    Ok(inv.iter().take(n).map(|c| c.re * scale).collect())
}
/// Type-IV discrete cosine transform of `x`.
///
/// `norm = Some("ortho")` requests orthonormal scaling; any other value gives
/// the unnormalized transform.
///
/// # Errors
/// `FFTError::ValueError` when `x` is empty or an element cannot be converted
/// to `f64`; errors from the underlying FFT are propagated.
pub fn dct4<T>(x: &[T], norm: Option<&str>) -> FFTResult<Vec<f64>>
where
    T: NumCast + Copy + Debug,
{
    match x.len() {
        0 => Err(FFTError::ValueError("dct4: input is empty".into())),
        _ => cast_f64(x).and_then(|samples| dct4_f64(&samples, norm)),
    }
}
fn dct4_f64(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
let n = x.len();
let mut z = vec![Complex64::zero(); n];
for k in 0..n {
let phase = -PI * (2 * k + 1) as f64 / (4.0 * n as f64);
let twiddle = Complex64::new(phase.cos(), phase.sin());
z[k] = Complex64::new(x[k], 0.0) * twiddle;
}
let ifft_result = ifft(&z, None)?;
let out_twiddle_phase = PI / (4.0 * n as f64);
let out_twiddle = Complex64::new(out_twiddle_phase.cos(), out_twiddle_phase.sin());
let scale = 2.0 * n as f64;
let mut result: Vec<f64> = ifft_result
.iter()
.map(|&c| (out_twiddle * c).re * scale)
.collect();
if norm == Some("ortho") {
let s = 1.0 / (2.0 * n as f64).sqrt();
for v in &mut result {
*v *= s;
}
}
Ok(result)
}
/// Type-II discrete sine transform of `x`.
///
/// Without normalization the result is `y[k] = 2 Σ_n x[n] sin(π(k+1)(2n+1)/(2N))`;
/// `norm = Some("ortho")` applies orthonormal scaling instead.
///
/// # Errors
/// `FFTError::ValueError` when `x` is empty or an element cannot be converted
/// to `f64`; errors from the underlying FFT are propagated.
pub fn dst2<T>(x: &[T], norm: Option<&str>) -> FFTResult<Vec<f64>>
where
    T: NumCast + Copy + Debug,
{
    match x {
        [] => Err(FFTError::ValueError("dst2: input is empty".into())),
        _ => {
            let samples = cast_f64(x)?;
            dst2_f64(&samples, norm)
        }
    }
}
fn dst2_f64(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
let n = x.len();
let ext_len = 2 * n;
let mut extended = vec![Complex64::zero(); ext_len];
for k in 0..n {
extended[k] = Complex64::new(x[k], 0.0);
extended[ext_len - 1 - k] = Complex64::new(-x[k], 0.0);
}
let y_fft = fft(&extended, None)?;
let mut result = Vec::with_capacity(n);
for k in 0..n {
let phase = -PI * (k + 1) as f64 / ext_len as f64;
let twiddle = Complex64::new(phase.cos(), phase.sin());
let val = y_fft[k + 1] * twiddle;
result.push(-val.im);
}
if norm == Some("ortho") {
let scale_k = 1.0 / (2.0 * n as f64).sqrt();
let scale_n = 1.0 / (4.0 * n as f64).sqrt();
for (k, v) in result.iter_mut().enumerate() {
*v *= if k == n - 1 { scale_n } else { scale_k };
}
}
Ok(result)
}
/// Inverse of [`dst2`] under the same `norm` setting.
///
/// # Errors
/// `FFTError::ValueError` when `x` is empty or an element cannot be converted
/// to `f64`; errors from the underlying FFT are propagated.
pub fn idst2<T>(x: &[T], norm: Option<&str>) -> FFTResult<Vec<f64>>
where
    T: NumCast + Copy + Debug,
{
    if x.is_empty() {
        Err(FFTError::ValueError("idst2: input is empty".into()))
    } else {
        cast_f64(x).and_then(|samples| idst2_f64(&samples, norm))
    }
}
/// Inverse of `dst2_f64`.
///
/// The unnormalized DST-III satisfies `dst3(dst2(x)) = 2N·x`, so in that mode
/// the DST-III output is divided by `2N`.
///
/// The orthonormal inverse is `x_n = X_{N-1}(-1)^n/√N + √(2/N) Σ_{k<N-1} X_k sin(..)`.
/// Because the unnormalized DST-III weights the last input by 1 and the others
/// by 2, the inputs are pre-divided by `√N` (last) and `√(2N)` (others).
fn idst2_f64(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
    let n = x.len();
    if norm == Some("ortho") {
        let scale_k = 1.0 / (2.0 * n as f64).sqrt();
        let scale_last = 1.0 / (n as f64).sqrt();
        let scaled: Vec<f64> = x
            .iter()
            .enumerate()
            .map(|(k, &v)| v * if k + 1 == n { scale_last } else { scale_k })
            .collect();
        dst3_f64(&scaled)
    } else {
        let scale = 1.0 / (2.0 * n as f64);
        Ok(dst3_f64(x)?.into_iter().map(|v| v * scale).collect())
    }
}
/// Unnormalized type-III DST:
/// `y[m] = Σ_k w_k x[k] sin(π(k+1)(2m+1)/(2N))`, with `w_{N-1} = 1` and
/// `w_k = 2` otherwise.
///
/// Sample `k` is placed at index `k+1` of a zero-padded buffer of length 2N,
/// pre-rotated by `e^{iπ(k+1)/(2N)}`. Then `2N · ifft(buf)[m]` equals
/// `Σ_k w_k x_k e^{iπ(k+1)(2m+1)/(2N)}`, whose imaginary part is the DST-III.
fn dst3_f64(x: &[f64]) -> FFTResult<Vec<f64>> {
    let n = x.len();
    let two_n = 2 * n;
    let mut buf = vec![Complex64::zero(); two_n];
    for (k, &xk) in x.iter().enumerate() {
        let weight = if k + 1 == n { 1.0 } else { 2.0 };
        let phase = PI * (k + 1) as f64 / two_n as f64;
        buf[k + 1] = Complex64::new(phase.cos(), phase.sin()) * (weight * xk);
    }
    let inv = ifft(&buf, None)?;
    // `ifft` is assumed to carry the 1/len normalization; undo it.
    let scale = two_n as f64;
    Ok(inv.iter().take(n).map(|c| c.im * scale).collect())
}
pub fn mdct(x: &[f64], n: usize) -> FFTResult<Vec<f64>> {
if n == 0 {
return Err(FFTError::ValueError("mdct: n must be > 0".into()));
}
if x.len() != 2 * n {
return Err(FFTError::ValueError(format!(
"mdct: input length {} must be 2*n = {}",
x.len(),
2 * n
)));
}
let half = n / 2;
let mut y = vec![0.0_f64; n];
for k in 0..n {
y[k] = if k < half {
-x[half + k] - x[half - 1 - k]
} else {
let m = k - half;
x[m + n] - x[n + half - 1 - m] };
let _ = m_safe_fold(x, n, k); }
for k in 0..n {
y[k] = fold_mdct(x, n, k);
}
dct4_f64(&y, None)
}
/// Inverse MDCT: expands `n` coefficients into `2n` time-aliased samples,
/// `y[j] = (1/n) Σ_k x[k] · cos(π/n · (j + 1/2 + n/2) · (k + 1/2))`.
///
/// Pair it with an MDCT that uses the same kernel and no extra scaling. Then
/// overlap-adding the second half of one frame's output with the first half of
/// the next (hop `n`) cancels the aliasing and reconstructs the signal.
///
/// For even `n` the result comes from one type-IV DCT followed by an
/// "unfold" that uses the kernel's symmetries `c(-1-s) = c(s)`,
/// `c(2n-1-s) = -c(s)` and `c(s+2n) = -c(s)`. Odd `n` falls back to the
/// direct O(n²) sum.
///
/// # Errors
/// Returns `FFTError::ValueError` if `n == 0` or `x.len() != n`.
pub fn imdct(x: &[f64], n: usize) -> FFTResult<Vec<f64>> {
    if n == 0 {
        return Err(FFTError::ValueError("imdct: n must be > 0".into()));
    }
    if x.len() != n {
        return Err(FFTError::ValueError(format!(
            "imdct: input length {} must equal n = {n}",
            x.len()
        )));
    }
    if n % 2 != 0 {
        return direct_imdct(x, n);
    }
    let z = dct4_f64(x, None)?;
    // `dct4_f64` is twice the plain DCT-IV sum, so the 1/n IMDCT scale
    // becomes 1/(2n) here.
    let scale = 1.0 / (2.0 * n as f64);
    let half = n / 2;
    let samples = (0..2 * n)
        .map(|j| {
            let v = if j < half {
                z[j + half]
            } else if j < 3 * half {
                -z[3 * half - 1 - j]
            } else {
                -z[j - 3 * half]
            };
            v * scale
        })
        .collect();
    Ok(samples)
}
/// Direct O(n²) evaluation of the inverse MDCT:
/// `y[j] = (1/n) Σ_{k<n} x[k] · cos(π/n · (j + 1/2 + n/2) · (k + 1/2))`
/// for `j in 0..2n`.
///
/// Only the first `n` entries of `x` are read.
fn direct_imdct(x: &[f64], n: usize) -> FFTResult<Vec<f64>> {
    let nf = n as f64;
    let offset = 0.5 + nf / 2.0;
    let out = (0..2 * n)
        .map(|j| {
            let t = j as f64 + offset;
            let sum = x
                .iter()
                .take(n)
                .enumerate()
                .map(|(k, &xk)| xk * (PI / nf * t * (k as f64 + 0.5)).cos())
                .sum::<f64>();
            sum / nf
        })
        .collect();
    Ok(out)
}
/// Folded MDCT input sample `k` (of `n`) for a frame `x` of length `2n`.
///
/// Split the frame into quarters `(a, b, c, d)` of length `n/2`, with `_r`
/// denoting reversal. The MDCT equals the plain DCT-IV of
/// `(-c_r - d, a - b_r)`. This returns element `k` of that folded vector.
///
/// Requires even `n`: the quarter split is undefined for odd `n`.
#[inline]
fn fold_mdct(x: &[f64], n: usize, k: usize) -> f64 {
    let half = n / 2;
    if k < half {
        // -c_r[k] - d[k]
        -x[3 * half - 1 - k] - x[3 * half + k]
    } else {
        // a[m] - b_r[m]
        let m = k - half;
        x[m] - x[n - 1 - m]
    }
}
// No-op placeholder: it ignores its arguments and always yields zero, and no
// caller uses its result. `dead_code` is allowed so it can stay until it is
// removed.
#[allow(dead_code)]
#[inline]
fn m_safe_fold(_x: &[f64], _n: usize, _k: usize) -> f64 {
    f64::default()
}
// No-op placeholder: it ignores its arguments and always yields `(0.0, 0.0)`.
// No caller uses its result. `dead_code` is allowed so it can stay until it
// is removed.
#[allow(dead_code)]
#[inline]
fn unfold_imdct(_z: &[f64], _n: usize, _k: usize, _scale: f64) -> (f64, f64) {
    Default::default()
}
// No-op placeholder: it ignores its arguments and always yields zero, and no
// caller uses its result. `dead_code` is allowed so it can stay until it is
// removed.
#[allow(dead_code)]
#[inline]
fn imdct_unfold(_z: &[f64], _n: usize, _k: usize) -> f64 {
    0_f64
}
/// Separable 2-D type-II DCT of a row-major `rows × cols` block.
///
/// Applies `dct2_f64` with the given `norm` to every row, then to every column.
///
/// # Errors
/// `FFTError::ValueError` if either dimension is zero, if `rows * cols`
/// overflows `usize`, or if `block.len() != rows * cols`. FFT errors are
/// propagated.
pub fn dct2_2d(block: &[f64], rows: usize, cols: usize, norm: Option<&str>) -> FFTResult<Vec<f64>> {
    if rows == 0 || cols == 0 {
        return Err(FFTError::ValueError("dct2_2d: dimensions must be > 0".into()));
    }
    // Checked multiply: a wrapped product could otherwise match `block.len()`
    // and lead to out-of-bounds indexing below.
    let expected = rows.checked_mul(cols).ok_or_else(|| {
        FFTError::ValueError(format!("dct2_2d: {rows}x{cols} block size overflows usize"))
    })?;
    if block.len() != expected {
        return Err(FFTError::ValueError(format!(
            "dct2_2d: expected {} elements, got {}",
            expected,
            block.len()
        )));
    }

    let mut buf = block.to_vec();
    for row in buf.chunks_exact_mut(cols) {
        let transformed = dct2_f64(row, norm)?;
        row.copy_from_slice(&transformed);
    }

    let mut column = vec![0.0_f64; rows];
    for j in 0..cols {
        for (i, slot) in column.iter_mut().enumerate() {
            *slot = buf[i * cols + j];
        }
        let transformed = dct2_f64(&column, norm)?;
        for (i, v) in transformed.into_iter().enumerate() {
            buf[i * cols + j] = v;
        }
    }
    Ok(buf)
}
/// Separable 2-D inverse of [`dct2_2d`] for a row-major `rows × cols` block.
///
/// Applies `idct2_f64` with the given `norm` to every column, then to every row.
///
/// # Errors
/// `FFTError::ValueError` if either dimension is zero, if `rows * cols`
/// overflows `usize`, or if `block.len() != rows * cols`. FFT errors are
/// propagated.
pub fn idct2_2d(block: &[f64], rows: usize, cols: usize, norm: Option<&str>) -> FFTResult<Vec<f64>> {
    if rows == 0 || cols == 0 {
        return Err(FFTError::ValueError("idct2_2d: dimensions must be > 0".into()));
    }
    // Checked multiply: a wrapped product could otherwise match `block.len()`
    // and lead to out-of-bounds indexing below.
    let expected = rows.checked_mul(cols).ok_or_else(|| {
        FFTError::ValueError(format!("idct2_2d: {rows}x{cols} block size overflows usize"))
    })?;
    if block.len() != expected {
        return Err(FFTError::ValueError(format!(
            "idct2_2d: expected {} elements, got {}",
            expected,
            block.len()
        )));
    }

    let mut buf = block.to_vec();
    let mut column = vec![0.0_f64; rows];
    for j in 0..cols {
        for (i, slot) in column.iter_mut().enumerate() {
            *slot = buf[i * cols + j];
        }
        let restored = idct2_f64(&column, norm)?;
        for (i, v) in restored.into_iter().enumerate() {
            buf[i * cols + j] = v;
        }
    }

    for row in buf.chunks_exact_mut(cols) {
        let restored = idct2_f64(row, norm)?;
        row.copy_from_slice(&restored);
    }
    Ok(buf)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Asserts `|a[i] - b[i]| <= tol` element-wise, naming the failing index.
    ///
    /// A plain `assert!` is used because `approx` has no `var_name` option
    /// for labelling the failing element.
    fn assert_f64_close(a: &[f64], b: &[f64], tol: f64, label: &str) {
        assert_eq!(a.len(), b.len(), "{label}: length mismatch");
        for (i, (&ai, &bi)) in a.iter().zip(b).enumerate() {
            assert!(
                (ai - bi).abs() <= tol,
                "{label}[{i}]: {ai} vs {bi} (tolerance {tol})"
            );
        }
    }

    /// Deterministic, non-symmetric test signal.
    fn sample_signal(len: usize) -> Vec<f64> {
        (0..len)
            .map(|k| (k as f64 * 0.7).sin() + 0.25 * k as f64 - 1.0)
            .collect()
    }

    /// O(N²) reference: `x0 + 2 Σ_{k≥1} x_k cos(πk(2m+1)/(2N))`.
    fn naive_dct3(x: &[f64]) -> Vec<f64> {
        let nf = x.len() as f64;
        (0..x.len())
            .map(|m| {
                x.iter()
                    .enumerate()
                    .map(|(k, &xk)| {
                        let w = if k == 0 { 1.0 } else { 2.0 };
                        w * xk * (PI * k as f64 * (2 * m + 1) as f64 / (2.0 * nf)).cos()
                    })
                    .sum::<f64>()
            })
            .collect()
    }

    /// O(N²) reference: `2 Σ x_k cos(π(2k+1)(2m+1)/(4N))`.
    fn naive_dct4(x: &[f64]) -> Vec<f64> {
        let nf = x.len() as f64;
        (0..x.len())
            .map(|m| {
                x.iter()
                    .enumerate()
                    .map(|(k, &xk)| {
                        2.0 * xk
                            * (PI * (2 * k + 1) as f64 * (2 * m + 1) as f64 / (4.0 * nf)).cos()
                    })
                    .sum::<f64>()
            })
            .collect()
    }

    /// O(N²) reference: `2 Σ x_n sin(π(k+1)(2n+1)/(2N))`.
    fn naive_dst2(x: &[f64]) -> Vec<f64> {
        let nf = x.len() as f64;
        (0..x.len())
            .map(|k| {
                x.iter()
                    .enumerate()
                    .map(|(j, &xj)| {
                        2.0 * xj * (PI * (k + 1) as f64 * (2 * j + 1) as f64 / (2.0 * nf)).sin()
                    })
                    .sum::<f64>()
            })
            .collect()
    }

    /// MDCT kernel `cos(π/N (j + 1/2 + N/2)(k + 1/2))`.
    fn mdct_kernel(n: usize, j: usize, k: usize) -> f64 {
        let nf = n as f64;
        (PI / nf * (j as f64 + 0.5 + nf / 2.0) * (k as f64 + 0.5)).cos()
    }

    fn naive_mdct(x: &[f64], n: usize) -> Vec<f64> {
        (0..n)
            .map(|k| {
                x.iter()
                    .enumerate()
                    .map(|(j, &xj)| xj * mdct_kernel(n, j, k))
                    .sum::<f64>()
            })
            .collect()
    }

    fn naive_imdct(c: &[f64], n: usize) -> Vec<f64> {
        (0..2 * n)
            .map(|j| {
                c.iter()
                    .enumerate()
                    .map(|(k, &ck)| ck * mdct_kernel(n, j, k))
                    .sum::<f64>()
                    / n as f64
            })
            .collect()
    }

    #[test]
    fn test_dct2_dc_component() {
        let x = vec![1.0_f64, 1.0, 1.0, 1.0];
        let coeffs = dct2(&x, None).expect("dct2");
        assert_relative_eq!(coeffs[0], 8.0, epsilon = 1e-9);
        for k in 1..4 {
            assert_relative_eq!(coeffs[k].abs(), 0.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn test_dct2_roundtrip_ortho() {
        let n = 8;
        let x: Vec<f64> = (0..n).map(|k| k as f64).collect();
        let spectrum = dct2(&x, Some("ortho")).expect("dct2");
        let x_rec = idct2(&spectrum, Some("ortho")).expect("idct2");
        assert_f64_close(&x, &x_rec, 1e-9, "dct2 ortho roundtrip");
    }

    #[test]
    fn test_dct2_length_16() {
        let n = 16;
        let x: Vec<f64> = (0..n).map(|k| (k as f64 / n as f64 * PI).sin()).collect();
        let spectrum = dct2(&x, None).expect("dct2");
        let x_rec = idct2(&spectrum, None).expect("idct2");
        assert_f64_close(&x, &x_rec, 1e-9, "dct2 unorm roundtrip 16");
    }

    #[test]
    fn test_dct2_empty_error() {
        let empty: Vec<f64> = vec![];
        assert!(dct2(&empty, None).is_err());
    }

    #[test]
    fn test_dct3_inverse_of_dct2_ortho() {
        let n = 8;
        let x: Vec<f64> = (0..n).map(|k| k as f64).collect();
        let spectrum = dct2(&x, Some("ortho")).expect("dct2");
        let x_rec = dct3(&spectrum, Some("ortho")).expect("dct3");
        assert_f64_close(&x, &x_rec, 1e-8, "dct3 is inverse of dct2 ortho");
    }

    #[test]
    fn test_dct3_matches_definition() {
        // Odd and even lengths: an index-permutation bug only shows for N >= 3.
        for n in [1, 2, 5, 8] {
            let x = sample_signal(n);
            let got = dct3(&x, None).expect("dct3");
            assert_f64_close(&naive_dct3(&x), &got, 1e-9, "dct3 vs definition");
        }
    }

    #[test]
    fn test_dct4_matches_definition() {
        for n in [1, 3, 8] {
            let x = sample_signal(n);
            let got = dct4(&x, None).expect("dct4");
            assert_f64_close(&naive_dct4(&x), &got, 1e-9, "dct4 vs definition");
        }
    }

    #[test]
    fn test_dct4_self_inverse() {
        let n = 8;
        let x: Vec<f64> = (0..n).map(|k| k as f64).collect();
        let spectrum = dct4(&x, None).expect("dct4");
        let x2 = dct4(&spectrum, None).expect("dct4 double");
        let scale = 2.0 * n as f64;
        for (a, b) in x.iter().zip(x2.iter()) {
            assert_relative_eq!(a * scale, *b, epsilon = 1e-7);
        }
    }

    #[test]
    fn test_dct4_ortho_self_inverse() {
        let n = 8;
        let x: Vec<f64> = (0..n).map(|k| (k as f64).sin()).collect();
        let spectrum = dct4(&x, Some("ortho")).expect("dct4 ortho");
        let x2 = dct4(&spectrum, Some("ortho")).expect("dct4 ortho double");
        assert_f64_close(&x, &x2, 1e-7, "dct4 ortho self-inverse");
    }

    #[test]
    fn test_dst2_length() {
        let x = vec![1.0_f64, 2.0, 3.0, 4.0];
        let coeffs = dst2(&x, None).expect("dst2");
        assert_eq!(coeffs.len(), 4);
    }

    #[test]
    fn test_dst2_matches_definition() {
        for n in [1, 4, 7] {
            let x = sample_signal(n);
            let got = dst2(&x, None).expect("dst2");
            assert_f64_close(&naive_dst2(&x), &got, 1e-9, "dst2 vs definition");
        }
    }

    #[test]
    fn test_dst2_idst2_roundtrip_ortho() {
        let n = 8;
        let x: Vec<f64> = (0..n).map(|k| k as f64).collect();
        let spectrum = dst2(&x, Some("ortho")).expect("dst2");
        let x_rec = idst2(&spectrum, Some("ortho")).expect("idst2");
        assert_f64_close(&x, &x_rec, 1e-8, "dst2 ortho roundtrip");
    }

    #[test]
    fn test_dst2_idst2_roundtrip_unnormalized() {
        let x = sample_signal(7);
        let spectrum = dst2(&x, None).expect("dst2");
        let x_rec = idst2(&spectrum, None).expect("idst2");
        assert_f64_close(&x, &x_rec, 1e-9, "dst2 unnormalized roundtrip");
    }

    #[test]
    fn test_dst2_empty_error() {
        let empty: Vec<f64> = vec![];
        assert!(dst2(&empty, None).is_err());
    }

    #[test]
    fn test_mdct_output_length() {
        let n = 8;
        let frame: Vec<f64> = (0..2 * n).map(|k| k as f64).collect();
        let coeffs = mdct(&frame, n).expect("mdct");
        assert_eq!(coeffs.len(), n);
    }

    #[test]
    fn test_mdct_matches_definition() {
        for n in [2, 3, 8] {
            let frame = sample_signal(2 * n);
            let got = mdct(&frame, n).expect("mdct");
            assert_f64_close(&naive_mdct(&frame, n), &got, 1e-9, "mdct vs definition");
        }
    }

    #[test]
    fn test_imdct_output_length() {
        let n = 8;
        let frame: Vec<f64> = (0..2 * n).map(|k| (k as f64 * 0.3).sin()).collect();
        let coeffs = mdct(&frame, n).expect("mdct");
        let restored = imdct(&coeffs, n).expect("imdct");
        assert_eq!(restored.len(), 2 * n);
    }

    #[test]
    fn test_imdct_matches_definition() {
        for n in [2, 3, 8] {
            let coeffs = sample_signal(n);
            let got = imdct(&coeffs, n).expect("imdct");
            assert_f64_close(&naive_imdct(&coeffs, n), &got, 1e-9, "imdct vs definition");
        }
    }

    #[test]
    fn test_mdct_imdct_tdac_reconstruction() {
        // Two frames with hop n; overlap-adding the shared half cancels the
        // time-domain aliasing and recovers the middle of the signal.
        let n = 8;
        let signal = sample_signal(3 * n);
        let first = imdct(&mdct(&signal[..2 * n], n).expect("mdct 0"), n).expect("imdct 0");
        let second = imdct(&mdct(&signal[n..], n).expect("mdct 1"), n).expect("imdct 1");
        let overlap: Vec<f64> = first[n..]
            .iter()
            .zip(&second[..n])
            .map(|(a, b)| a + b)
            .collect();
        assert_f64_close(&signal[n..2 * n], &overlap, 1e-9, "mdct TDAC");
    }

    #[test]
    fn test_mdct_invalid_length() {
        let n = 8;
        let frame = vec![0.0_f64; 2 * n + 1];
        assert!(mdct(&frame, n).is_err());
    }

    #[test]
    fn test_mdct_zero_input() {
        let n = 4;
        let frame = vec![0.0_f64; 2 * n];
        let coeffs = mdct(&frame, n).expect("mdct zero");
        for &v in &coeffs {
            assert_relative_eq!(v, 0.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_dct2_2d_roundtrip() {
        let block: Vec<f64> = (0..16).map(|k| k as f64).collect();
        let coeffs = dct2_2d(&block, 4, 4, Some("ortho")).expect("dct2_2d");
        let recovered = idct2_2d(&coeffs, 4, 4, Some("ortho")).expect("idct2_2d");
        assert_f64_close(&block, &recovered, 1e-9, "dct2_2d roundtrip");
    }

    #[test]
    fn test_dct2_2d_dc() {
        let rows = 4;
        let cols = 4;
        let block = vec![1.0_f64; rows * cols];
        let coeffs = dct2_2d(&block, rows, cols, None).expect("dct2_2d");
        let dc = coeffs[0].abs();
        for (k, &v) in coeffs.iter().enumerate().skip(1) {
            assert!(v.abs() <= dc + 1e-9, "non-DC bin {k} larger than DC");
        }
    }

    #[test]
    fn test_dct2_2d_wrong_size() {
        let block = vec![0.0_f64; 10];
        assert!(dct2_2d(&block, 4, 4, None).is_err());
    }
}