use crate::CFft1D;
use num_complex::Complex;
use num_traits::float::{Float, FloatConst};
use num_traits::identities::{one, zero};
use num_traits::{cast, NumAssign};
/// Two-dimensional complex transform built from two one-dimensional
/// [`CFft1D`] transformers and a reusable scratch buffer.
///
/// Input is a slice of rows (`&[Vec<Complex<T>>]`); the instance remembers the
/// shape it was last configured for and reconfigures itself when it is handed
/// input of a different shape.
#[derive(Debug)]
pub struct CFft2D<T> {
/// Number of rows the instance is configured for (compared with `source.len()`).
len_m: usize,
/// Number of elements per row the instance is configured for
/// (compared with `source[0].len()`).
len_n: usize,
/// Normalisation factor `1 / (len_m * len_n)`.
scaler_n: T,
/// Unitary factor, mathematically `1 / sqrt(len_m * len_n)`.
scaler_u: T,
/// One-dimensional transformer configured with `len_m`.
fft_m: CFft1D<T>,
/// One-dimensional transformer configured with `len_n`.
fft_n: CFft1D<T>,
/// Scratch buffer of `len_n` rows, each `len_m` long; it receives the result
/// of the first pass in transposed order.
work: Vec<Vec<Complex<T>>>,
}
impl<T: Float + FloatConst + NumAssign> CFft2D<T> {
    /// Creates an unconfigured instance: both lengths are zero, both scale
    /// factors are zero and the scratch buffer is empty.
    ///
    /// The first transform applied to non-empty input configures the instance
    /// for that input's shape.
    pub fn new() -> Self {
        Self {
            len_m: 0,
            len_n: 0,
            scaler_n: zero(),
            scaler_u: zero(),
            fft_m: CFft1D::new(),
            fft_n: CFft1D::new(),
            work: Vec::new(),
        }
    }

    /// Creates an instance configured for input of `len_m` rows, each holding
    /// `len_n` elements.
    ///
    /// # Panics
    ///
    /// Panics if `len_m * len_n` cannot be converted to `T`.
    pub fn with_len(len_m: usize, len_n: usize) -> Self {
        Self {
            len_m,
            len_n,
            scaler_n: T::one() / cast(len_m * len_n).unwrap(),
            scaler_u: T::one() / cast::<_, T>(len_m * len_n).unwrap().sqrt(),
            fft_m: CFft1D::with_len(len_m),
            fft_n: CFft1D::with_len(len_n),
            // Transposed scratch space: one entry per input column.
            work: vec![vec![zero(); len_m]; len_n],
        }
    }

    /// Reconfigures the instance for input of `len_m` rows, each holding
    /// `len_n` elements.
    ///
    /// The scale factors and both one-dimensional transformers are always
    /// updated; the scratch buffer is reallocated only when its current shape
    /// does not already match.
    ///
    /// # Panics
    ///
    /// Panics if `len_m * len_n` cannot be converted to `T`.
    pub fn setup(&mut self, len_m: usize, len_n: usize) {
        self.len_m = len_m;
        self.len_n = len_n;
        self.scaler_n = T::one() / cast(len_m * len_n).unwrap();
        self.scaler_u = self.scaler_n.sqrt();
        self.fft_m.setup(len_m);
        self.fft_n.setup(len_n);

        // An empty buffer has no inner length to compare, so only the outer
        // length decides in that case.
        let shape_matches = self.work.len() == len_n
            && self.work.first().map_or(true, |row| row.len() == len_m);
        if !shape_matches {
            self.work = vec![vec![zero(); len_m]; len_n];
        }
    }

    /// Forward transform without scaling (same as [`forward0`](Self::forward0)).
    ///
    /// Returns a matrix with the same shape as `source`; empty input yields an
    /// empty result.
    pub fn forward(&mut self, source: &[Vec<Complex<T>>]) -> Vec<Vec<Complex<T>>> {
        self.convert(source, false, one())
    }

    /// Forward transform without scaling.
    pub fn forward0(&mut self, source: &[Vec<Complex<T>>]) -> Vec<Vec<Complex<T>>> {
        self.convert(source, false, one())
    }

    /// Forward transform with every output element multiplied by the unitary
    /// factor `1 / sqrt(len_m * len_n)`.
    pub fn forwardu(&mut self, source: &[Vec<Complex<T>>]) -> Vec<Vec<Complex<T>>> {
        let scaler = self.scaler_u;
        self.convert(source, false, scaler)
    }

    /// Forward transform with every output element multiplied by
    /// `1 / (len_m * len_n)`.
    pub fn forwardn(&mut self, source: &[Vec<Complex<T>>]) -> Vec<Vec<Complex<T>>> {
        let scaler = self.scaler_n;
        self.convert(source, false, scaler)
    }

    /// Backward transform with every output element multiplied by
    /// `1 / (len_m * len_n)`, so that it undoes [`forward`](Self::forward).
    pub fn backward(&mut self, source: &[Vec<Complex<T>>]) -> Vec<Vec<Complex<T>>> {
        let scaler = self.scaler_n;
        self.convert(source, true, scaler)
    }

    /// Backward transform without scaling.
    pub fn backward0(&mut self, source: &[Vec<Complex<T>>]) -> Vec<Vec<Complex<T>>> {
        self.convert(source, true, one())
    }

    /// Backward transform with every output element multiplied by the unitary
    /// factor `1 / sqrt(len_m * len_n)`.
    pub fn backwardu(&mut self, source: &[Vec<Complex<T>>]) -> Vec<Vec<Complex<T>>> {
        let scaler = self.scaler_u;
        self.convert(source, true, scaler)
    }

    /// Shared implementation behind every public transform.
    ///
    /// * `source` - input matrix as a slice of rows. Only the first row's
    ///   length is inspected; all rows are assumed to have that same length.
    /// * `is_back` - `true` selects the backward one-dimensional transforms,
    ///   `false` the forward ones.
    /// * `scaler` - factor applied to every output element.
    ///
    /// Returns a matrix of `source.len()` rows; an empty `source` yields an
    /// empty `Vec`. The instance is reconfigured first when the input shape
    /// differs from the configured one.
    #[inline]
    fn convert(
        &mut self,
        source: &[Vec<Complex<T>>],
        is_back: bool,
        scaler: T,
    ) -> Vec<Vec<Complex<T>>> {
        let first_row = match source.first() {
            Some(row) => row,
            None => return Vec::new(),
        };
        if source.len() != self.len_m || first_row.len() != self.len_n {
            self.setup(source.len(), first_row.len());
        }

        // Pass 1: transform every row. Rows hold `len_n` elements, which is the
        // length `fft_n` is configured for. The result is stored transposed so
        // that pass 2 can again work on contiguous slices.
        for (i, row) in source.iter().enumerate() {
            let transformed = if is_back {
                self.fft_n.backward0(row)
            } else {
                self.fft_n.forward0(row)
            };
            for (j, &value) in transformed.iter().enumerate() {
                self.work[j][i] = value;
            }
        }

        // `vec![Vec::with_capacity(n); m]` would clone the element and lose the
        // reserved capacity, so every row is built individually.
        let (rows, cols) = (self.len_m, self.len_n);
        let mut ret: Vec<Vec<Complex<T>>> =
            (0..rows).map(|_| Vec::with_capacity(cols)).collect();

        // Pass 2: transform every original column (a row of `work`, `len_m`
        // long, hence `fft_m`) and scatter it back into row-major order.
        for column in &self.work {
            let transformed = if is_back {
                self.fft_m.backward0(column)
            } else {
                self.fft_m.forward0(column)
            };
            for (j, &value) in transformed.iter().enumerate() {
                ret[j].push(value * scaler);
            }
        }
        ret
    }
}
impl<T: Float + FloatConst + NumAssign> Default for CFft2D<T> {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert_appro_eq;
    use crate::FloatEps;
    use appro_eq::AbsError;
    use rand::distributions::{Distribution, Standard};
    use rand::{Rng, SeedableRng};
    use rand_xorshift::XorShiftRng;
    use std::fmt::Debug;

    /// Reference implementation: evaluates the two-dimensional transform sum
    /// directly, term by term, and multiplies each output element by `scalar`.
    ///
    /// Output element `(i, k)` is the sum over all `(j, l)` of
    /// `source[j][l] * exp(-2*pi*i * (i*j / rows + k*l / cols))`.
    fn convert<T: Float + FloatConst>(
        source: &[Vec<Complex<T>>],
        scalar: T,
    ) -> Vec<Vec<Complex<T>>> {
        let rows = source.len();
        let cols = source.first().map_or(0, Vec::len);
        let neg_two_pi = -cast::<_, T>(2).unwrap() * T::PI();
        let row_count: T = cast(rows).unwrap();
        let col_count: T = cast(cols).unwrap();

        let mut result = Vec::with_capacity(rows);
        for i in 0..rows {
            let mut out_row = Vec::with_capacity(cols);
            for k in 0..cols {
                let mut total: Complex<T> = zero();
                for (j, src_row) in source.iter().enumerate() {
                    // Sum one input row first, then add it to the running
                    // total, keeping the accumulation order row by row.
                    let mut partial: Complex<T> = zero();
                    for l in 0..cols {
                        let angle = neg_two_pi
                            * ((cast::<_, T>(i * j).unwrap() / row_count)
                                + cast::<_, T>(k * l).unwrap() / col_count);
                        partial = partial + src_row[l] * Complex::<T>::from_polar(T::one(), angle);
                    }
                    total = total + partial;
                }
                out_row.push(total * scalar);
            }
            result.push(out_row);
        }
        result
    }

    /// Checks `forward` against the reference sum and then checks that
    /// `backward` applied to that result reproduces `source`.
    fn test_with_source<T>(fft: &mut CFft2D<T>, source: &[Vec<Complex<T>>])
    where
        T: Float + FloatConst + NumAssign + Debug + AbsError + FloatEps,
    {
        let reference = convert(source, T::one());
        let transformed = fft.forward(source);
        assert_appro_eq(&reference, &transformed);

        let round_trip = fft.backward(&transformed);
        assert_appro_eq(source, &round_trip);
    }

    /// Runs `test_with_source` on ten random `len_m` x `len_n` matrices drawn
    /// from a fixed-seed generator.
    fn test_with_len<T>(fft: &mut CFft2D<T>, len_m: usize, len_n: usize)
    where
        T: Float + FloatConst + NumAssign + Debug + AbsError + FloatEps,
        Standard: Distribution<T>,
    {
        let mut rng = XorShiftRng::from_seed([
            0xDA, 0xE1, 0x4B, 0x0B, 0xFF, 0xC2, 0xFE, 0x64, 0x23, 0xFE, 0x3F, 0x51, 0x6D, 0x3E,
            0xA2, 0xF3,
        ]);
        for _ in 0..10 {
            let mut arr: Vec<Vec<Complex<T>>> = Vec::with_capacity(len_m);
            for _ in 0..len_m {
                let mut row: Vec<Complex<T>> = Vec::with_capacity(len_n);
                for _ in 0..len_n {
                    // Real part is drawn before the imaginary part.
                    let re = rng.gen::<T>();
                    let im = rng.gen::<T>();
                    row.push(Complex::new(re, im));
                }
                arr.push(row);
            }
            test_with_source(fft, &arr);
        }
    }

    /// Every `(rows, cols)` pair exercised below: 1..10 by 1..10, rows varying
    /// slowest.
    fn shapes() -> impl Iterator<Item = (usize, usize)> {
        (1..10).flat_map(|m| (1..10).map(move |n| (m, n)))
    }

    #[test]
    fn f64_new() {
        for (m, n) in shapes() {
            test_with_len(&mut CFft2D::<f64>::new(), m, n);
        }
    }

    #[test]
    fn f32_new() {
        for (m, n) in shapes() {
            test_with_len(&mut CFft2D::<f32>::new(), m, n);
        }
    }

    #[test]
    fn f64_with_len() {
        for (m, n) in shapes() {
            test_with_len(&mut CFft2D::<f64>::with_len(m, n), m, n);
        }
    }

    #[test]
    fn f32_with_len() {
        for (m, n) in shapes() {
            test_with_len(&mut CFft2D::<f32>::with_len(m, n), m, n);
        }
    }
}