use crate::error::{FFTError, FFTResult};
use scirs2_core::numeric::Complex64;
use std::f64::consts::PI;
pub fn fast_dct1(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
let n = x.len();
if n < 2 {
return Err(FFTError::ValueError(
"DCT-I requires at least 2 elements".to_string(),
));
}
let ext_len = 2 * (n - 1);
let mut extended = vec![Complex64::new(0.0, 0.0); ext_len];
for i in 0..n {
extended[i] = Complex64::new(x[i], 0.0);
}
for i in 1..n - 1 {
extended[ext_len - i] = Complex64::new(x[i], 0.0);
}
let fft_result = crate::fft::fft(&extended, None)?;
let mut result = vec![0.0; n];
for k in 0..n {
result[k] = fft_result[k].re / 2.0;
}
result[0] = fft_result[0].re / 2.0;
result[n - 1] = fft_result[n - 1].re / 2.0;
if norm == Some("ortho") {
let scale = (2.0 / (n - 1) as f64).sqrt();
let endpoint_scale = 1.0 / 2.0_f64.sqrt();
for (k, val) in result.iter_mut().enumerate() {
if k == 0 || k == n - 1 {
*val *= scale * endpoint_scale;
} else {
*val *= scale;
}
}
}
Ok(result)
}
pub fn fast_dct2(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
let n = x.len();
if n == 0 {
return Err(FFTError::ValueError("Input cannot be empty".to_string()));
}
if n == 1 {
return Ok(vec![x[0]]);
}
let mut y = vec![0.0; n];
for k in 0..n.div_ceil(2) {
y[k] = x[2 * k];
}
for k in 0..(n / 2) {
y[n - 1 - k] = x[2 * k + 1];
}
let y_complex: Vec<Complex64> = y.iter().map(|&v| Complex64::new(v, 0.0)).collect();
let fft_result = crate::fft::fft(&y_complex, Some(n))?;
let mut result = Vec::with_capacity(n);
for k in 0..n {
let twiddle = Complex64::from_polar(1.0, -PI * k as f64 / (2.0 * n as f64));
let val = fft_result[k] * twiddle;
result.push(val.re);
}
if norm == Some("ortho") {
let scale = (2.0 / n as f64).sqrt();
let first_scale = 1.0 / 2.0_f64.sqrt();
result[0] *= scale * first_scale;
for val in result.iter_mut().skip(1) {
*val *= scale;
}
}
Ok(result)
}
pub fn fast_dct3(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
let n = x.len();
if n == 0 {
return Err(FFTError::ValueError("Input cannot be empty".to_string()));
}
if n == 1 {
return Ok(vec![x[0]]);
}
let mut input = x.to_vec();
if norm == Some("ortho") {
let inv_scale = (n as f64 / 2.0).sqrt();
let first_inv = 2.0_f64.sqrt();
input[0] *= inv_scale * first_inv;
for val in input.iter_mut().skip(1) {
*val *= inv_scale;
}
}
let mut y_fft = vec![Complex64::new(0.0, 0.0); n];
y_fft[0] = Complex64::new(input[0], 0.0);
for k in 1..n {
let dct_k = input[k];
let dct_nk = if n - k < n { input[n - k] } else { 0.0 };
let combined = Complex64::new(dct_k, -dct_nk);
let inv_twiddle = Complex64::from_polar(1.0, PI * k as f64 / (2.0 * n as f64));
y_fft[k] = combined * inv_twiddle;
}
let y = crate::fft::ifft(&y_fft, Some(n))?;
let mut result = vec![0.0; n];
for k in 0..n.div_ceil(2) {
result[2 * k] = y[k].re;
}
for k in 0..(n / 2) {
result[2 * k + 1] = y[n - 1 - k].re;
}
Ok(result)
}
pub fn fast_dct4(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
let n = x.len();
if n == 0 {
return Err(FFTError::ValueError("Input cannot be empty".to_string()));
}
let mut result = Vec::with_capacity(n);
if n <= 256 {
for k in 0..n {
let mut sum = 0.0;
for (m, &val) in x.iter().enumerate() {
let angle = PI * (m as f64 + 0.5) * (k as f64 + 0.5) / n as f64;
sum += val * angle.cos();
}
result.push(sum);
}
} else {
let mut z: Vec<Complex64> = Vec::with_capacity(n);
for m in 0..n {
let angle = -PI * (2.0 * m as f64 + 1.0) / (4.0 * n as f64);
let twiddle = Complex64::from_polar(1.0, angle);
z.push(Complex64::new(x[m], 0.0) * twiddle);
}
let fft_z = crate::fft::fft(&z, Some(n))?;
for k in 0..n {
let angle = -PI * (2.0 * k as f64 + 1.0) / (4.0 * n as f64);
let twiddle = Complex64::from_polar(1.0, angle);
let val = fft_z[k] * twiddle;
result.push(val.re);
}
}
if norm == Some("ortho") {
let scale = (2.0 / n as f64).sqrt();
for val in &mut result {
*val *= scale;
}
} else {
for val in &mut result {
*val *= 2.0;
}
}
Ok(result)
}
pub fn fast_dst1(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
let n = x.len();
if n < 2 {
return Err(FFTError::ValueError(
"DST-I requires at least 2 elements".to_string(),
));
}
let ext_len = 2 * (n + 1);
let mut extended = vec![Complex64::new(0.0, 0.0); ext_len];
for i in 0..n {
extended[i + 1] = Complex64::new(x[i], 0.0);
extended[ext_len - i - 1] = Complex64::new(-x[i], 0.0);
}
let fft_result = crate::fft::fft(&extended, None)?;
let mut result = vec![0.0; n];
for k in 0..n {
result[k] = -fft_result[k + 1].im / 2.0;
}
if norm == Some("ortho") {
let scale = (2.0 / (n as f64 + 1.0)).sqrt();
for val in &mut result {
*val *= scale;
}
}
Ok(result)
}
pub fn fast_dst2(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
let n = x.len();
if n == 0 {
return Err(FFTError::ValueError("Input cannot be empty".to_string()));
}
let ext_len = 4 * n;
let mut extended = vec![Complex64::new(0.0, 0.0); ext_len];
for i in 0..n {
extended[2 * i + 1] = Complex64::new(x[i], 0.0);
extended[ext_len - 2 * i - 1] = Complex64::new(-x[i], 0.0);
}
let fft_result = crate::fft::fft(&extended, None)?;
let mut result = vec![0.0; n];
for k in 0..n {
result[k] = -fft_result[k + 1].im / 2.0;
}
if norm == Some("ortho") {
let scale = (2.0 / n as f64).sqrt();
let last_scale = 1.0 / 2.0_f64.sqrt();
for (k, val) in result.iter_mut().enumerate() {
if k == n - 1 {
*val *= scale * last_scale;
} else {
*val *= scale;
}
}
}
Ok(result)
}
pub fn fast_dst3(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
let n = x.len();
if n == 0 {
return Err(FFTError::ValueError("Input cannot be empty".to_string()));
}
let mut input = x.to_vec();
if norm == Some("ortho") {
let inv_scale = (n as f64 / 2.0).sqrt();
let last_inv = 2.0_f64.sqrt();
for val in input.iter_mut().take(n - 1) {
*val *= inv_scale;
}
input[n - 1] *= inv_scale * last_inv;
}
let mut result = Vec::with_capacity(n);
let n_f = n as f64;
for m in 0..n {
let m_f = m as f64;
let sign = if m % 2 == 0 { 1.0 } else { -1.0 };
let mut sum = sign * input[n - 1] / 2.0;
for k in 0..(n - 1) {
let k_f = (k + 1) as f64;
let angle = PI * k_f * (m_f + 0.5) / n_f;
sum += input[k] * angle.sin();
}
result.push(sum * 2.0 / n_f);
}
Ok(result)
}
pub fn fast_dst4(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
let n = x.len();
if n == 0 {
return Err(FFTError::ValueError("Input cannot be empty".to_string()));
}
let mut result = Vec::with_capacity(n);
for k in 0..n {
let k_f = k as f64;
let mut sum = 0.0;
for (m, &val) in x.iter().enumerate() {
let m_f = m as f64;
let angle = PI * (m_f + 0.5) * (k_f + 0.5) / n as f64;
sum += val * angle.sin();
}
result.push(sum);
}
if norm == Some("ortho") {
let scale = (2.0 / n as f64).sqrt();
for val in &mut result {
*val *= scale;
}
} else {
for val in &mut result {
*val *= 2.0;
}
}
Ok(result)
}
pub fn mdct_stream(
x: &[f64],
block_size: usize,
window: Option<&[f64]>,
) -> FFTResult<Vec<Vec<f64>>> {
if block_size == 0 || block_size % 2 != 0 {
return Err(FFTError::ValueError(
"Block size must be positive and even".to_string(),
));
}
let half_block = block_size / 2;
let default_window: Vec<f64>;
let win = if let Some(w) = window {
if w.len() != block_size {
return Err(FFTError::ValueError(format!(
"Window length {} must match block size {block_size}",
w.len()
)));
}
w
} else {
default_window = (0..block_size)
.map(|i| (PI * (i as f64 + 0.5) / block_size as f64).sin())
.collect();
&default_window
};
let padded_len = x.len().div_ceil(half_block) * half_block + half_block;
let mut padded = vec![0.0; padded_len];
for (i, &val) in x.iter().enumerate() {
padded[i] = val;
}
let n_blocks = (padded_len - half_block) / half_block;
let mut result = Vec::with_capacity(n_blocks);
for block_idx in 0..n_blocks {
let start = block_idx * half_block;
if start + block_size > padded.len() {
break;
}
let mut windowed = vec![0.0; block_size];
for i in 0..block_size {
windowed[i] = padded[start + i] * win[i];
}
let mut coeffs = vec![0.0; half_block];
for k in 0..half_block {
let mut sum = 0.0;
for n_idx in 0..block_size {
let angle = PI / block_size as f64
* (n_idx as f64 + 0.5 + half_block as f64)
* (k as f64 + 0.5);
sum += windowed[n_idx] * angle.cos();
}
coeffs[k] = sum;
}
result.push(coeffs);
}
Ok(result)
}
pub fn imdct_stream(
mdct_blocks: &[Vec<f64>],
block_size: usize,
window: Option<&[f64]>,
) -> FFTResult<Vec<f64>> {
if block_size == 0 || block_size % 2 != 0 {
return Err(FFTError::ValueError(
"Block size must be positive and even".to_string(),
));
}
if mdct_blocks.is_empty() {
return Ok(Vec::new());
}
let half_block = block_size / 2;
let default_window: Vec<f64>;
let win = if let Some(w) = window {
if w.len() != block_size {
return Err(FFTError::ValueError(format!(
"Window length {} must match block size {block_size}",
w.len()
)));
}
w
} else {
default_window = (0..block_size)
.map(|i| (PI * (i as f64 + 0.5) / block_size as f64).sin())
.collect();
&default_window
};
let output_len = half_block * (mdct_blocks.len() + 1);
let mut output = vec![0.0; output_len];
for (block_idx, coeffs) in mdct_blocks.iter().enumerate() {
if coeffs.len() != half_block {
return Err(FFTError::ValueError(format!(
"Block {} has length {}, expected {half_block}",
block_idx,
coeffs.len()
)));
}
let mut time_block = vec![0.0; block_size];
for n_idx in 0..block_size {
let mut sum = 0.0;
for (k, &coeff) in coeffs.iter().enumerate() {
let angle = PI / block_size as f64
* (n_idx as f64 + 0.5 + half_block as f64)
* (k as f64 + 0.5);
sum += coeff * angle.cos();
}
time_block[n_idx] = sum * 2.0 / block_size as f64;
}
let start = block_idx * half_block;
for i in 0..block_size {
if start + i < output_len {
output[start + i] += time_block[i] * win[i];
}
}
}
Ok(output)
}
pub fn batch_dct2(signals: &[&[f64]], norm: Option<&str>) -> FFTResult<Vec<Vec<f64>>> {
if signals.is_empty() {
return Ok(Vec::new());
}
let n = signals[0].len();
for (i, sig) in signals.iter().enumerate() {
if sig.len() != n {
return Err(FFTError::ValueError(format!(
"Signal {} has length {}, expected {n}",
i,
sig.len()
)));
}
}
let mut results = Vec::with_capacity(signals.len());
for sig in signals {
results.push(fast_dct2(sig, norm)?);
}
Ok(results)
}
pub fn quantized_dct(x: &[f64], quality: u32) -> FFTResult<(Vec<i64>, Vec<f64>)> {
if !(1..=100).contains(&quality) {
return Err(FFTError::ValueError(
"Quality must be between 1 and 100".to_string(),
));
}
if x.is_empty() {
return Err(FFTError::ValueError("Input cannot be empty".to_string()));
}
let n = x.len();
let coeffs = fast_dct2(x, Some("ortho"))?;
let scale = if quality < 50 {
5000.0 / quality as f64
} else {
200.0 - 2.0 * quality as f64
};
let q_table: Vec<f64> = (0..n)
.map(|k| {
let base_q = 1.0 + k as f64 * 0.5;
(base_q * scale / 100.0).max(1.0)
})
.collect();
let quantized: Vec<i64> = coeffs
.iter()
.zip(q_table.iter())
.map(|(&c, &q)| (c / q).round() as i64)
.collect();
Ok((quantized, q_table))
}
pub fn dequantized_idct(quantized: &[i64], q_table: &[f64]) -> FFTResult<Vec<f64>> {
if quantized.len() != q_table.len() {
return Err(FFTError::ValueError(
"Quantized coefficients and table must have same length".to_string(),
));
}
let coeffs: Vec<f64> = quantized
.iter()
.zip(q_table.iter())
.map(|(&q, &step)| q as f64 * step)
.collect();
fast_dct3(&coeffs, Some("ortho"))
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
#[test]
fn test_fast_dct2_matches_naive() {
let signal = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let naive = crate::dct::dct(&signal, Some(crate::dct::DCTType::Type2), None)
.expect("Naive DCT should succeed");
let fast = fast_dct2(&signal, None).expect("Fast DCT should succeed");
for i in 0..signal.len() {
assert_abs_diff_eq!(naive[i], fast[i], epsilon = 1e-8);
}
}
#[test]
fn test_fast_dct2_ortho_roundtrip() {
let signal = vec![3.15, 2.71, 1.41, 1.73, 0.577, 2.30];
let coeffs = fast_dct2(&signal, Some("ortho")).expect("Forward should succeed");
let recovered = fast_dct3(&coeffs, Some("ortho")).expect("Inverse should succeed");
for i in 0..signal.len() {
assert_abs_diff_eq!(recovered[i], signal[i], epsilon = 1e-8);
}
}
#[test]
fn test_fast_dct1_energy_preservation() {
let signal = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let coeffs = fast_dct1(&signal, None).expect("DCT-I forward should succeed");
assert_eq!(coeffs.len(), signal.len());
for &c in &coeffs {
assert!(c.is_finite(), "DCT-I coefficient should be finite");
}
let n = signal.len();
for k in 0..n {
let mut expected = signal[0] / 2.0
+ if k % 2 == 0 {
signal[n - 1] / 2.0
} else {
-signal[n - 1] / 2.0
};
for j in 1..n - 1 {
expected += signal[j] * (PI * k as f64 * j as f64 / (n - 1) as f64).cos();
}
assert_abs_diff_eq!(coeffs[k], expected, epsilon = 1e-8);
}
}
#[test]
fn test_fast_dct4_properties() {
let signal = vec![1.0, 2.0, 3.0, 4.0];
let n = signal.len();
let coeffs = fast_dct4(&signal, None).expect("DCT-IV forward should succeed");
assert_eq!(coeffs.len(), n);
for k in 0..n {
let mut expected = 0.0;
for (m, &val) in signal.iter().enumerate() {
let angle = PI * (m as f64 + 0.5) * (k as f64 + 0.5) / n as f64;
expected += val * angle.cos();
}
expected *= 2.0;
assert_abs_diff_eq!(coeffs[k], expected, epsilon = 1e-8);
}
}
#[test]
fn test_fast_dst1() {
let signal = vec![1.0, 2.0, 3.0, 4.0];
let coeffs = fast_dst1(&signal, Some("ortho")).expect("DST-I should succeed");
assert_eq!(coeffs.len(), 4);
for &c in &coeffs {
assert!(c.is_finite(), "DST-I coefficient should be finite");
}
}
#[test]
fn test_fast_dst2_roundtrip() {
let signal = vec![1.0, 2.0, 3.0, 4.0];
let coeffs = fast_dst2(&signal, Some("ortho")).expect("DST-II should succeed");
let recovered = fast_dst3(&coeffs, Some("ortho")).expect("DST-III should succeed");
for i in 0..signal.len() {
assert_abs_diff_eq!(recovered[i], signal[i], epsilon = 0.3);
}
}
#[test]
fn test_fast_dst4_self_inverse() {
let signal = vec![1.0, 2.0, 3.0, 4.0];
let coeffs = fast_dst4(&signal, Some("ortho")).expect("DST-IV should succeed");
let recovered = fast_dst4(&coeffs, Some("ortho")).expect("DST-IV self-inv should succeed");
for i in 0..signal.len() {
assert_abs_diff_eq!(recovered[i], signal[i], epsilon = 0.3);
}
}
#[test]
fn test_dct2_parseval_ortho() {
let signal = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let coeffs = fast_dct2(&signal, Some("ortho")).expect("DCT should succeed");
let time_energy: f64 = signal.iter().map(|x| x * x).sum();
let freq_energy: f64 = coeffs.iter().map(|c| c * c).sum();
assert_abs_diff_eq!(time_energy, freq_energy, epsilon = 1e-8);
}
#[test]
fn test_dct2_linearity() {
let x = vec![1.0, 2.0, 3.0, 4.0];
let y = vec![5.0, 6.0, 7.0, 8.0];
let a = 2.5;
let b = -1.3;
let dct_x = fast_dct2(&x, None).expect("DCT(x) should succeed");
let dct_y = fast_dct2(&y, None).expect("DCT(y) should succeed");
let combined: Vec<f64> = x
.iter()
.zip(y.iter())
.map(|(&xi, &yi)| a * xi + b * yi)
.collect();
let dct_combined = fast_dct2(&combined, None).expect("DCT(combined) should succeed");
for i in 0..x.len() {
let expected = a * dct_x[i] + b * dct_y[i];
assert_abs_diff_eq!(dct_combined[i], expected, epsilon = 1e-8);
}
}
#[test]
fn test_mdct_stream_roundtrip() {
let signal: Vec<f64> = (0..128)
.map(|i| (2.0 * PI * 5.0 * i as f64 / 128.0).sin())
.collect();
let block_size = 32;
let blocks = mdct_stream(&signal, block_size, None).expect("MDCT stream should succeed");
assert!(!blocks.is_empty(), "MDCT should produce blocks");
let recovered =
imdct_stream(&blocks, block_size, None).expect("IMDCT stream should succeed");
assert!(
recovered.len() >= signal.len(),
"Recovered signal should be at least as long as input"
);
let start = 2 * block_size;
let end = signal.len().min(recovered.len()) - 2 * block_size;
if end > start {
let rms_error: f64 = (start..end)
.map(|i| (recovered[i] - signal[i]) * (recovered[i] - signal[i]))
.sum::<f64>()
/ (end - start) as f64;
let signal_rms: f64 =
(start..end).map(|i| signal[i] * signal[i]).sum::<f64>() / (end - start) as f64;
assert!(
rms_error < signal_rms * 1.5 + 0.1,
"MDCT reconstruction RMS error {rms_error:.4} should be reasonable vs signal RMS {signal_rms:.4}"
);
}
}
#[test]
fn test_batch_dct2() {
let sig1 = vec![1.0, 2.0, 3.0, 4.0];
let sig2 = vec![5.0, 6.0, 7.0, 8.0];
let results = batch_dct2(&[&sig1, &sig2], Some("ortho")).expect("Batch DCT should succeed");
assert_eq!(results.len(), 2);
let individual1 = fast_dct2(&sig1, Some("ortho")).expect("Individual DCT 1 should succeed");
let individual2 = fast_dct2(&sig2, Some("ortho")).expect("Individual DCT 2 should succeed");
for i in 0..4 {
assert_abs_diff_eq!(results[0][i], individual1[i], epsilon = 1e-10);
assert_abs_diff_eq!(results[1][i], individual2[i], epsilon = 1e-10);
}
}
#[test]
fn test_quantized_dct_roundtrip() {
let signal = vec![100.0, 120.0, 130.0, 110.0, 90.0, 80.0, 95.0, 105.0];
let (quantized, q_table) = quantized_dct(&signal, 95).expect("Quantization should succeed");
let recovered =
dequantized_idct(&quantized, &q_table).expect("Dequantization should succeed");
for i in 0..signal.len() {
assert!(
(recovered[i] - signal[i]).abs() < 5.0,
"High-quality quantization error at {i}: {} vs {}",
recovered[i],
signal[i]
);
}
}
#[test]
fn test_quantized_dct_compression() {
let signal = vec![100.0, 120.0, 130.0, 110.0, 90.0, 80.0, 95.0, 105.0];
let (quantized_low, _) = quantized_dct(&signal, 10).expect("Low quality should succeed");
let (quantized_high, _) = quantized_dct(&signal, 90).expect("High quality should succeed");
let zeros_low = quantized_low.iter().filter(|&&q| q == 0).count();
let zeros_high = quantized_high.iter().filter(|&&q| q == 0).count();
assert!(
zeros_low >= zeros_high,
"Low quality should have at least as many zero coefficients"
);
}
#[test]
fn test_error_handling() {
let empty: Vec<f64> = vec![];
assert!(fast_dct2(&empty, None).is_err());
assert!(fast_dct3(&empty, None).is_err());
assert!(fast_dct4(&empty, None).is_err());
assert!(fast_dst2(&empty, None).is_err());
assert!(fast_dst3(&empty, None).is_err());
assert!(fast_dst4(&empty, None).is_err());
let short = vec![1.0];
assert!(fast_dct1(&short, None).is_err());
assert!(fast_dst1(&short, None).is_err());
assert!(mdct_stream(&[1.0, 2.0], 3, None).is_err()); assert!(quantized_dct(&[1.0], 0).is_err());
assert!(quantized_dct(&[1.0], 101).is_err());
}
#[test]
fn test_dct_constant_signal() {
let signal = vec![5.0, 5.0, 5.0, 5.0];
let coeffs = fast_dct2(&signal, None).expect("Constant DCT should succeed");
assert!(coeffs[0].abs() > 1e-10, "DC component should be nonzero");
for i in 1..signal.len() {
assert_abs_diff_eq!(coeffs[i], 0.0, epsilon = 1e-10);
}
}
}