use num_complex::Complex64;
use serde::{Deserialize, Serialize};
/// Stateless namespace for discrete Fourier transform routines.
///
/// `Dft` has no fields; the routines in its `impl` block are associated
/// functions (no `self`), called as e.g. `Dft::fft(&signal)`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Dft;
impl Dft {
    /// Computes the discrete Fourier transform of `signal` straight from the
    /// definition, in O(n²) time.
    ///
    /// With `n = signal.len()`, bin `k` of the result is
    /// `Σ_j signal[j] · e^(-2πi·k·j/n)`. The transform is unnormalized (no
    /// `1/n` factor); [`Dft::idft`] applies the scaling on the way back.
    /// Any length is accepted; an empty input yields an empty output.
    pub fn dft(signal: &[Complex64]) -> Vec<Complex64> {
        let n = signal.len();
        // Angle per unit of phase index. Infinite when n == 0, but the loops
        // below never run in that case.
        let step = -2.0 * std::f64::consts::PI / n as f64;
        let mut result = Vec::with_capacity(n);
        for k in 0..n {
            let mut sum = Complex64::new(0.0, 0.0);
            // `phase` is (k * j) mod n for the current sample j. Reducing the
            // exponent modulo n keeps every angle in (-2π, 0], so sin/cos get a
            // small argument and stay accurate for long signals, and the k * j
            // product (which could overflow) is never formed.
            let mut phase = 0usize;
            for &x in signal {
                sum += x * Complex64::from_polar(1.0, step * phase as f64);
                phase += k;
                if phase >= n {
                    phase -= n;
                }
            }
            result.push(sum);
        }
        result
    }

    /// Inverts [`Dft::dft`]: returns `x` such that `Dft::dft(&x)` ≈ `spectrum`.
    ///
    /// Computed as `conj(dft(conj(spectrum))) / n`, so it reuses the forward
    /// kernel. Any length is accepted; an empty input yields an empty output.
    pub fn idft(spectrum: &[Complex64]) -> Vec<Complex64> {
        Self::inverse_via(spectrum, Self::dft)
    }

    /// Computes the same transform as [`Dft::dft`], using an iterative radix-2
    /// Cooley–Tukey FFT (O(n log n)) when `signal.len()` is a power of two.
    ///
    /// Lengths that are not a power of two are handled by falling back to the
    /// O(n²) [`Dft::dft`], so the result is correct for every length (an empty
    /// input yields an empty output). Use [`Dft::zero_pad_to_power_of_2`] first
    /// if the fast path matters more than preserving the exact length.
    pub fn fft(signal: &[Complex64]) -> Vec<Complex64> {
        let n = signal.len();
        if n <= 1 {
            return signal.to_vec();
        }
        if !Self::is_power_of_2(n) {
            // The radix-2 butterflies below require a power-of-two length and
            // would index past the end of the buffer otherwise.
            return Self::dft(signal);
        }
        let mut x = Self::bit_reverse_copy(signal);
        // `half` is the size of each half-butterfly; blocks of `2 * half`
        // elements are combined at every stage.
        let mut half = 1usize;
        while half < n {
            // Principal twiddle for this stage: e^(-iπ/half) = e^(-2πi/(2·half)).
            let wm = Complex64::from_polar(1.0, -std::f64::consts::PI / half as f64);
            for block in x.chunks_exact_mut(2 * half) {
                let (lo, hi) = block.split_at_mut(half);
                // The twiddle is advanced by repeated multiplication, so its
                // rounding error grows slowly with the stage size.
                let mut w = Complex64::new(1.0, 0.0);
                for (a, b) in lo.iter_mut().zip(hi.iter_mut()) {
                    let t = w * *b;
                    let u = *a;
                    *a = u + t;
                    *b = u - t;
                    w *= wm;
                }
            }
            half *= 2;
        }
        x
    }

    /// Inverts [`Dft::fft`]: returns `x` such that `Dft::fft(&x)` ≈ `spectrum`.
    ///
    /// Computed as `conj(fft(conj(spectrum))) / n`; like [`Dft::fft`] it accepts
    /// any length and is fastest for powers of two.
    pub fn ifft(spectrum: &[Complex64]) -> Vec<Complex64> {
        Self::inverse_via(spectrum, Self::fft)
    }

    /// Shared inverse kernel: `conj(forward(conj(spectrum))) / n`.
    ///
    /// Valid for a `forward` that computes the unnormalized transform with the
    /// `e^(-2πi·k·j/n)` sign convention used by [`Dft::dft`] and [`Dft::fft`].
    fn inverse_via(
        spectrum: &[Complex64],
        forward: fn(&[Complex64]) -> Vec<Complex64>,
    ) -> Vec<Complex64> {
        // Division by zero never happens: an empty spectrum has no elements.
        let n = spectrum.len() as f64;
        let conjugated: Vec<Complex64> = spectrum.iter().map(Complex64::conj).collect();
        forward(&conjugated)
            .into_iter()
            .map(|c| c.conj() / n)
            .collect()
    }

    /// Returns a copy of `signal` permuted into bit-reversed index order, the
    /// input ordering the in-place radix-2 butterflies in [`Dft::fft`] expect.
    ///
    /// Only meaningful for power-of-two lengths (checked in debug builds).
    fn bit_reverse_copy(signal: &[Complex64]) -> Vec<Complex64> {
        let n = signal.len();
        debug_assert!(
            n == 0 || Self::is_power_of_2(n),
            "bit_reverse_copy requires a power-of-two length"
        );
        // For a power of two, the number of trailing zeros is exactly log2(n);
        // this avoids floating-point log2.
        let bits = n.trailing_zeros() as usize;
        let mut result = vec![Complex64::new(0.0, 0.0); n];
        for (i, &value) in signal.iter().enumerate() {
            result[Self::reverse_bits(i, bits)] = value;
        }
        result
    }

    /// Reverses the lowest `bits` bits of `n`; higher bits of `n` are ignored.
    ///
    /// For example, `reverse_bits(0b001, 3) == 0b100`.
    fn reverse_bits(mut n: usize, bits: usize) -> usize {
        let mut result = 0usize;
        for _ in 0..bits {
            result = (result << 1) | (n & 1);
            n >>= 1;
        }
        result
    }

    /// Returns `|X_k|²` for every bin of `Dft::fft(signal)`.
    ///
    /// No normalization is applied, so a constant signal of ones with length n
    /// gives `n²` in bin 0.
    pub fn power_spectrum(signal: &[Complex64]) -> Vec<f64> {
        Self::fft(signal).iter().map(|c| c.norm_sqr()).collect()
    }

    /// Returns the frequency of each of the `n` bins of an n-point transform
    /// sampled at rate `fs`: bin `k` maps to `k · fs / n`, in the units of `fs`.
    ///
    /// No wrapping to negative frequencies is done. Bins with `k > n/2` are
    /// reported as `k · fs / n`, which aliases to `(k − n) · fs / n`.
    /// `n == 0` yields an empty vector.
    pub fn frequency_bins(n: usize, fs: f64) -> Vec<f64> {
        (0..n).map(|k| k as f64 * fs / n as f64).collect()
    }

    /// Returns `true` if `n` is a power of two (1, 2, 4, …); `0` is not.
    pub fn is_power_of_2(n: usize) -> bool {
        n.is_power_of_two()
    }

    /// Returns `signal` extended with trailing zeros to the next power-of-two
    /// length; a signal whose length is already a power of two is copied as-is.
    ///
    /// An empty signal becomes a single zero sample, since 1 = 2⁰ is the
    /// smallest power of two.
    pub fn zero_pad_to_power_of_2(signal: &[Complex64]) -> Vec<Complex64> {
        // Exact integer computation; floating-point log2 can undershoot for
        // very large lengths.
        let target = signal.len().next_power_of_two();
        let mut padded = Vec::with_capacity(target);
        padded.extend_from_slice(signal);
        padded.resize(target, Complex64::new(0.0, 0.0));
        padded
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `true` when `lhs` and `rhs` agree to strictly within `tol` in both the
    /// real and the imaginary component.
    fn complex_approx_eq(lhs: Complex64, rhs: Complex64, tol: f64) -> bool {
        let diff = lhs - rhs;
        diff.re.abs() < tol && diff.im.abs() < tol
    }

    /// A constant signal puts all of its energy in bin 0 (8 samples × 3).
    #[test]
    fn test_dft_dc_signal() {
        let constant = vec![Complex64::new(3.0, 0.0); 8];
        let bins = Dft::dft(&constant);
        assert!(complex_approx_eq(bins[0], Complex64::new(24.0, 0.0), 1e-10));
        for (k, bin) in (1..).zip(&bins[1..8]) {
            assert!(bin.norm() < 1e-10, "Non-DC bin {k} should be ~0");
        }
    }

    /// A real cosine at 4 cycles per window appears at bin 4 and its mirror,
    /// bin 64 - 4 = 60.
    #[test]
    fn test_dft_single_tone() {
        let n = 64;
        let cosine: Vec<Complex64> = (0..n)
            .map(|sample| {
                let t = sample as f64 / n as f64;
                let arg = 2.0 * std::f64::consts::PI * 4.0 * t;
                Complex64::new(arg.cos(), 0.0)
            })
            .collect();
        let spectrum = Dft::dft(&cosine);
        let magnitude = |bin: usize| spectrum[bin].norm();
        assert!(magnitude(4) > 20.0, "Bin 4 magnitude: {}", magnitude(4));
        assert!(magnitude(60) > 20.0, "Bin 60 magnitude: {}", magnitude(60));
    }

    /// `idft` undoes `dft` on an arbitrary complex signal.
    #[test]
    fn test_idft_roundtrip() {
        let original = [
            (1.0, 0.0),
            (2.0, 1.0),
            (-1.0, 0.5),
            (0.0, -1.0),
            (3.0, 0.0),
            (-2.0, 0.0),
            (1.0, 1.0),
            (0.5, -0.5),
        ]
        .map(|(re, im)| Complex64::new(re, im));
        let spectrum = Dft::dft(&original);
        let recovered = Dft::idft(&spectrum);
        original.iter().zip(&recovered).for_each(|(a, b)| {
            assert!(complex_approx_eq(*a, *b, 1e-10), "Roundtrip mismatch: {a} vs {b}");
        });
    }

    /// The radix-2 FFT agrees with the direct DFT on a power-of-two input.
    #[test]
    fn test_fft_matches_dft() {
        let signal: Vec<Complex64> = (0..16)
            .map(|idx| {
                let t = idx as f64;
                Complex64::new((t * 0.5).sin(), (t * 0.3).cos())
            })
            .collect();
        let reference = Dft::dft(&signal);
        let fast = Dft::fft(&signal);
        for (a, b) in reference.iter().zip(&fast) {
            assert!(complex_approx_eq(*a, *b, 1e-10), "FFT != DFT: {a} vs {b}");
        }
    }

    /// `ifft` undoes `fft` on a 32-sample complex signal.
    #[test]
    fn test_ifft_roundtrip() {
        let signal: Vec<Complex64> = (0..32)
            .map(|idx| {
                let t = idx as f64;
                Complex64::new(t.cos(), (t * 0.7).sin())
            })
            .collect();
        let spectrum = Dft::fft(&signal);
        let recovered = Dft::ifft(&spectrum);
        for (a, b) in signal.iter().zip(&recovered) {
            assert!(complex_approx_eq(*a, *b, 1e-10), "IFFT roundtrip: {a} vs {b}");
        }
    }

    /// FFT of a constant: everything in bin 0 (8 samples × 5), zeros elsewhere.
    #[test]
    fn test_fft_dc() {
        let constant = vec![Complex64::new(5.0, 0.0); 8];
        let bins = Dft::fft(&constant);
        assert!(complex_approx_eq(bins[0], Complex64::new(40.0, 0.0), 1e-10));
        assert!(bins[1..8].iter().all(|bin| bin.norm() < 1e-10));
    }

    /// Four ones give |4|² = 16 in bin 0 and nothing in bin 1.
    #[test]
    fn test_power_spectrum() {
        let ones = vec![Complex64::new(1.0, 0.0); 4];
        let power = Dft::power_spectrum(&ones);
        assert!((power[0] - 16.0).abs() < 1e-10);
        assert!(power[1] < 1e-10);
    }

    /// An 8-point transform at 1 kHz has 125 Hz bin spacing.
    #[test]
    fn test_frequency_bins() {
        let bins = Dft::frequency_bins(8, 1000.0);
        assert_eq!(bins.len(), 8);
        for (idx, hz) in [(0, 0.0), (1, 125.0), (4, 500.0)] {
            assert!((bins[idx] - hz).abs() < 1e-10);
        }
    }

    /// 3-bit reversals of a few sample indices.
    #[test]
    fn test_bit_reverse() {
        for (input, expected) in [(0, 0), (1, 4), (3, 6), (6, 3)] {
            assert_eq!(Dft::reverse_bits(input, 3), expected);
        }
    }

    /// Powers of two are accepted; 3 and 0 are rejected.
    #[test]
    fn test_is_power_of_2() {
        for n in [1, 256] {
            assert!(Dft::is_power_of_2(n));
        }
        for n in [3, 0] {
            assert!(!Dft::is_power_of_2(n));
        }
    }

    /// Five samples pad up to eight.
    #[test]
    fn test_zero_pad() {
        let five = vec![Complex64::new(1.0, 0.0); 5];
        let padded = Dft::zero_pad_to_power_of_2(&five);
        let len = padded.len();
        assert_eq!(len, 8);
        assert!(Dft::is_power_of_2(len));
    }
}