use super::complex::Complex;
use crate::HisabError;
#[must_use = "returns an error if input length is not a power of two"]
pub fn fft(data: &mut [Complex]) -> Result<(), HisabError> {
let n = data.len();
if n <= 1 {
return Ok(());
}
if !n.is_power_of_two() {
return Err(HisabError::InvalidInput(
"FFT requires power-of-2 length".into(),
));
}
let mut j = 0usize;
for i in 1..n {
let mut bit = n >> 1;
while j & bit != 0 {
j ^= bit;
bit >>= 1;
}
j ^= bit;
if i < j {
data.swap(i, j);
}
}
let mut len = 2;
while len <= n {
let half = len / 2;
let angle = -2.0 * std::f64::consts::PI / len as f64;
let wn = Complex::new(angle.cos(), angle.sin());
let mut start = 0;
while start < n {
let mut w = Complex::new(1.0, 0.0);
for k in 0..half {
let u = data[start + k];
let t = w * data[start + k + half];
data[start + k] = u + t;
data[start + k + half] = u - t;
w *= wn;
}
start += len;
}
len <<= 1;
}
Ok(())
}
#[must_use = "returns an error if input length is not a power of two"]
pub fn ifft(data: &mut [Complex]) -> Result<(), HisabError> {
let n = data.len();
for d in data.iter_mut() {
*d = d.conj();
}
fft(data)?;
let scale = 1.0 / n as f64;
for d in data.iter_mut() {
*d = d.conj() * scale;
}
Ok(())
}
#[must_use = "returns an error if dimensions are invalid"]
pub fn fft_2d(data: &mut [Complex], rows: usize, cols: usize) -> Result<(), HisabError> {
if data.len() != rows * cols {
return Err(HisabError::InvalidInput(format!(
"data length {} != rows*cols {}",
data.len(),
rows * cols
)));
}
for r in 0..rows {
let row = &mut data[r * cols..(r + 1) * cols];
fft(row)?;
}
let mut col_buf = vec![Complex::new(0.0, 0.0); rows];
for c in 0..cols {
for r in 0..rows {
col_buf[r] = data[r * cols + c];
}
fft(&mut col_buf)?;
for r in 0..rows {
data[r * cols + c] = col_buf[r];
}
}
Ok(())
}
#[must_use = "returns an error if dimensions are invalid"]
pub fn ifft_2d(data: &mut [Complex], rows: usize, cols: usize) -> Result<(), HisabError> {
if data.len() != rows * cols {
return Err(HisabError::InvalidInput(format!(
"data length {} != rows*cols {}",
data.len(),
rows * cols
)));
}
for r in 0..rows {
let row = &mut data[r * cols..(r + 1) * cols];
ifft(row)?;
}
let mut col_buf = vec![Complex::new(0.0, 0.0); rows];
for c in 0..cols {
for r in 0..rows {
col_buf[r] = data[r * cols + c];
}
ifft(&mut col_buf)?;
for r in 0..rows {
data[r * cols + c] = col_buf[r];
}
}
Ok(())
}
#[must_use = "contains the DST coefficients or an error"]
pub fn dst(data: &[f64]) -> Result<Vec<f64>, HisabError> {
let n = data.len();
if n == 0 {
return Err(HisabError::InvalidInput(
"DST requires non-empty input".into(),
));
}
let np1 = (n + 1) as f64;
let mut out = Vec::with_capacity(n);
for k in 0..n {
let mut sum = 0.0;
for (i, &x) in data.iter().enumerate() {
sum += x * (std::f64::consts::PI * (i + 1) as f64 * (k + 1) as f64 / np1).sin();
}
out.push(sum);
}
Ok(out)
}
#[must_use = "contains the inverse DST result or an error"]
pub fn idst(data: &[f64]) -> Result<Vec<f64>, HisabError> {
let n = data.len();
let mut out = dst(data)?;
let scale = 2.0 / (n + 1) as f64;
for v in &mut out {
*v *= scale;
}
Ok(out)
}
#[must_use = "contains the DCT coefficients or an error"]
pub fn dct(data: &[f64]) -> Result<Vec<f64>, HisabError> {
let n = data.len();
if n == 0 {
return Err(HisabError::InvalidInput(
"DCT requires non-empty input".into(),
));
}
let two_n = 2.0 * n as f64;
let mut out = Vec::with_capacity(n);
for k in 0..n {
let mut sum = 0.0;
for (i, &x) in data.iter().enumerate() {
sum += x * (std::f64::consts::PI * (2.0 * i as f64 + 1.0) * k as f64 / two_n).cos();
}
out.push(sum);
}
Ok(out)
}
#[must_use = "contains the inverse DCT result or an error"]
pub fn idct(data: &[f64]) -> Result<Vec<f64>, HisabError> {
let n = data.len();
if n == 0 {
return Err(HisabError::InvalidInput(
"IDCT requires non-empty input".into(),
));
}
let two_n = 2.0 * n as f64;
let inv_n = 1.0 / n as f64;
let dc = data[0] * inv_n;
let scale = 2.0 * inv_n;
let mut out = Vec::with_capacity(n);
for i in 0..n {
let mut sum = dc;
for (k, &x) in data.iter().enumerate().skip(1) {
sum += scale
* x
* (std::f64::consts::PI * k as f64 * (2.0 * i as f64 + 1.0) / two_n).cos();
}
out.push(sum);
}
Ok(out)
}