use crate::kernel::{Complex, Float};
use crate::prelude::*;
use super::super::problem::Sign;
/// Direct (naive) discrete Fourier transform solver.
///
/// The solver stores no data of its own; the type parameter `T` only selects
/// the floating-point type its methods operate on.
pub struct DirectSolver<T: Float> {
// Zero-sized marker that ties the struct to `T` without storing a value.
_marker: core::marker::PhantomData<T>,
}
impl<T: Float> Default for DirectSolver<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: Float> DirectSolver<T> {
    /// Creates a new direct DFT solver.
    ///
    /// The solver is stateless; `T` only selects the floating-point type.
    #[must_use]
    pub fn new() -> Self {
        Self {
            _marker: core::marker::PhantomData,
        }
    }

    /// Returns the identifier of this solver, `"dft-direct"`.
    #[must_use]
    pub fn name(&self) -> &'static str {
        "dft-direct"
    }

    /// Accumulates one output bin: the sum over `j` of
    /// `samples[j] * cis(step * ((j * k) mod n))`.
    ///
    /// `step` is the signed base angle `sign * 2π / n`. Callers must pass
    /// `k < n` and an iterator yielding the `n` input samples in order.
    fn accumulate_bin<I>(samples: I, k: usize, n: usize, step: T) -> Complex<T>
    where
        I: Iterator<Item = Complex<T>>,
    {
        let mut sum: Complex<T> = Complex::zero();
        // `idx` tracks (j * k) mod n in exact integer arithmetic. The twiddle
        // factor is periodic in n, so reducing the index first keeps the
        // floating-point angle within one turn instead of letting it grow
        // with j * k and lose low-order bits.
        let mut idx = 0_usize;
        for x_j in samples {
            let twiddle = Complex::cis(step * T::from_usize(idx));
            sum = sum + x_j * twiddle;
            // idx < n and k < n, so one conditional subtraction restores idx < n.
            idx += k;
            if idx >= n {
                idx -= n;
            }
        }
        sum
    }

    /// Computes the unnormalized DFT of `input` into `output`.
    ///
    /// For every `k` in `0..n` (with `n = input.len()`), writes
    /// `output[k] = Σ_j input[j] · cis(s · 2π · j · k / n)` where
    /// `s = sign.value()`. An empty `input` is a no-op.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `input` and `output` differ in length.
    /// In all builds, panics if `output` is shorter than `input`; the check
    /// happens before any element of `output` is written.
    pub fn execute(&self, input: &[Complex<T>], output: &mut [Complex<T>], sign: Sign) {
        let n = input.len();
        debug_assert_eq!(n, output.len(), "Input and output must have same length");
        if n == 0 {
            return;
        }
        let step = T::from_isize(sign.value() as isize) * (T::TWO_PI / T::from_usize(n));
        // Slice once so the length check is hoisted out of the loop.
        for (k, bin) in output[..n].iter_mut().enumerate() {
            *bin = Self::accumulate_bin(input.iter().copied(), k, n, step);
        }
    }

    /// Computes the unnormalized DFT of `data` in place.
    ///
    /// The input is copied into a temporary buffer first, because every
    /// output bin depends on every input sample.
    pub fn execute_inplace(&self, data: &mut [Complex<T>], sign: Sign) {
        if data.is_empty() {
            return;
        }
        let input: Vec<Complex<T>> = data.to_vec();
        self.execute(&input, data, sign);
    }

    /// Raw-pointer variant of [`execute`](Self::execute) for `n` elements.
    ///
    /// When `input` and `output` are the same address the input is copied to
    /// a temporary buffer first, so the transform may be run in place.
    /// `n == 0` is a no-op and dereferences neither pointer.
    ///
    /// # Safety
    ///
    /// When `n > 0` the caller must guarantee that:
    /// - `input` is properly aligned and valid for reads of `n` initialized
    ///   `Complex<T>` values;
    /// - `output` is properly aligned and valid for writes of `n`
    ///   `Complex<T>` values;
    /// - the two ranges are either identical (`input == output`) or do not
    ///   overlap. Partially overlapping ranges are not detected: bins are
    ///   written while the input is still being read, so the result would be
    ///   meaningless.
    pub unsafe fn execute_ptr(
        &self,
        input: *const Complex<T>,
        output: *mut Complex<T>,
        n: usize,
        sign: Sign,
    ) {
        if n == 0 {
            return;
        }
        let step = T::from_isize(sign.value() as isize) * (T::TWO_PI / T::from_usize(n));
        let in_place = core::ptr::eq(input, output as *const Complex<T>);
        if in_place {
            // Snapshot the input so later bins still see the original samples.
            // SAFETY: the caller guarantees `input` is valid for `n` reads.
            let snapshot: Vec<Complex<T>> = (0..n).map(|j| unsafe { *input.add(j) }).collect();
            for k in 0..n {
                let bin = Self::accumulate_bin(snapshot.iter().copied(), k, n, step);
                // SAFETY: `k < n` and the caller guarantees `output` is valid
                // for `n` writes.
                unsafe { *output.add(k) = bin };
            }
        } else {
            for k in 0..n {
                // SAFETY: `j < n` and the caller guarantees `input` is valid
                // for `n` reads.
                let samples = (0..n).map(|j| unsafe { *input.add(j) });
                let bin = Self::accumulate_bin(samples, k, n, step);
                // SAFETY: `k < n` and the caller guarantees `output` is valid
                // for `n` writes.
                unsafe { *output.add(k) = bin };
            }
        }
    }
}
/// Computes the forward DFT of `input` into `output` with the direct algorithm.
///
/// Convenience wrapper that runs a fresh [`DirectSolver`] with
/// [`Sign::Forward`]; no scaling is applied to the result.
pub fn dft_direct<T: Float>(input: &[Complex<T>], output: &mut [Complex<T>]) {
    let solver = DirectSolver::<T>::new();
    solver.execute(input, output, Sign::Forward);
}
/// Computes the backward (inverse-direction) DFT of `input` into `output`.
///
/// Convenience wrapper that runs a fresh [`DirectSolver`] with
/// [`Sign::Backward`]. The result is not divided by the length; see
/// [`idft_direct_normalized`] for the scaled variant.
pub fn idft_direct<T: Float>(input: &[Complex<T>], output: &mut [Complex<T>]) {
    let solver = DirectSolver::<T>::new();
    solver.execute(input, output, Sign::Backward);
}
/// Computes the backward DFT of `input` into `output` and scales it by `1/n`.
///
/// Runs a fresh [`DirectSolver`] with [`Sign::Backward`], then divides every
/// element of `output` by `output.len()`.
pub fn idft_direct_normalized<T: Float>(input: &[Complex<T>], output: &mut [Complex<T>]) {
    let solver = DirectSolver::<T>::new();
    solver.execute(input, output, Sign::Backward);

    let len_t = T::from_usize(output.len());
    output.iter_mut().for_each(|bin| *bin = *bin / len_t);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used by every comparison in this module.
    const TOL: f64 = 1e-10;

    /// Returns `true` when `a` and `b` differ by strictly less than `eps`.
    fn approx_eq(a: f64, b: f64, eps: f64) -> bool {
        let delta = (a - b).abs();
        delta < eps
    }

    /// Component-wise [`approx_eq`] on the real and imaginary parts.
    fn complex_approx_eq(a: Complex<f64>, b: Complex<f64>, eps: f64) -> bool {
        let re_ok = approx_eq(a.re, b.re, eps);
        let im_ok = approx_eq(a.im, b.im, eps);
        re_ok && im_ok
    }

    /// Four-point complex fixture: 1+2i, 3+4i, 5+6i, 7+8i.
    fn complex_ramp() -> [Complex<f64>; 4] {
        [
            Complex::new(1.0, 2.0),
            Complex::new(3.0, 4.0),
            Complex::new(5.0, 6.0),
            Complex::new(7.0, 8.0),
        ]
    }

    /// Four-point purely real fixture: `first`, `first + 1`, `first + 2`, `first + 3`.
    fn real_ramp(first: f64) -> [Complex<f64>; 4] {
        [
            Complex::new(first, 0.0),
            Complex::new(first + 1.0, 0.0),
            Complex::new(first + 2.0, 0.0),
            Complex::new(first + 3.0, 0.0),
        ]
    }

    /// A length-1 transform returns its only sample unchanged.
    #[test]
    fn test_dft_size_1() {
        let sample = Complex::new(3.0_f64, 4.0);
        let signal = [sample];
        let mut spectrum = [Complex::zero()];
        dft_direct(&signal, &mut spectrum);
        assert!(complex_approx_eq(spectrum[0], sample, TOL));
    }

    /// A length-2 transform yields the sum and the difference of the samples.
    #[test]
    fn test_dft_size_2() {
        let signal = [Complex::new(1.0_f64, 0.0), Complex::new(2.0, 0.0)];
        let mut spectrum = [Complex::zero(); 2];
        dft_direct(&signal, &mut spectrum);

        let expected = [Complex::new(3.0_f64, 0.0), Complex::new(-1.0, 0.0)];
        for (got, want) in spectrum.iter().zip(expected.iter()) {
            assert!(complex_approx_eq(*got, *want, TOL));
        }
    }

    /// Spot-checks bins 0 and 2 of the transform of 1, 2, 3, 4.
    #[test]
    fn test_dft_size_4() {
        let signal = real_ramp(1.0);
        let mut spectrum = [Complex::zero(); 4];
        dft_direct(&signal, &mut spectrum);

        let checks = [(0_usize, 10.0_f64), (2, -2.0)];
        for &(bin, re) in checks.iter() {
            assert!(complex_approx_eq(spectrum[bin], Complex::new(re, 0.0), TOL));
        }
    }

    /// Forward transform followed by the normalized inverse is the identity.
    #[test]
    fn test_inverse_recovers_input() {
        let original = complex_ramp();
        let mut spectrum = [Complex::zero(); 4];
        dft_direct(&original, &mut spectrum);

        let mut round_trip = [Complex::zero(); 4];
        idft_direct_normalized(&spectrum, &mut round_trip);

        for (want, got) in original.iter().zip(round_trip.iter()) {
            assert!(complex_approx_eq(*want, *got, TOL));
        }
    }

    /// Time-domain energy equals frequency-domain energy divided by `n`.
    #[test]
    fn test_parseval_theorem() {
        let signal = complex_ramp();
        let mut spectrum = [Complex::zero(); 4];
        dft_direct(&signal, &mut spectrum);

        let time_energy: f64 = signal.iter().map(|x| x.norm_sqr()).sum();
        let freq_energy: f64 = spectrum.iter().map(|x| x.norm_sqr()).sum();
        let n = signal.len() as f64;
        assert!(approx_eq(time_energy, freq_energy / n, TOL));
    }

    /// The transform of `a*x + y` equals `a*DFT(x) + DFT(y)`.
    #[test]
    fn test_linearity() {
        let x = real_ramp(1.0);
        let y = real_ramp(5.0);
        let a = 2.0;

        let mixed: Vec<_> = x
            .iter()
            .zip(y.iter())
            .map(|(&xi, &yi)| xi * a + yi)
            .collect();
        let mut dft_mixed = [Complex::zero(); 4];
        dft_direct(&mixed, &mut dft_mixed);

        let mut dft_x = [Complex::zero(); 4];
        dft_direct(&x, &mut dft_x);
        let mut dft_y = [Complex::zero(); 4];
        dft_direct(&y, &mut dft_y);

        for ((got, fx), fy) in dft_mixed.iter().zip(dft_x.iter()).zip(dft_y.iter()) {
            let expected = *fx * a + *fy;
            assert!(complex_approx_eq(*got, expected, TOL));
        }
    }

    /// The in-place transform matches the out-of-place transform.
    #[test]
    fn test_inplace() {
        let original = complex_ramp();
        let mut reference = [Complex::zero(); 4];
        dft_direct(&original, &mut reference);

        let mut buffer = original;
        DirectSolver::new().execute_inplace(&mut buffer, Sign::Forward);

        for (want, got) in reference.iter().zip(buffer.iter()) {
            assert!(complex_approx_eq(*want, *got, TOL));
        }
    }
}