use std::sync::Arc;
use num_complex::Complex;
use strength_reduce::StrengthReducedUsize;
use transpose;
use crate::common::{verify_length, verify_length_divisible, FFTnum};
use crate::array_utils;
use crate::math_utils;
use crate::algorithm::butterflies::FFTButterfly;
use crate::{IsInverse, Length, FFT};
/// FFT of length `width * height`, computed from two inner FFTs of those lengths.
///
/// The constructor asserts that the two inner lengths are coprime and that both
/// inner FFTs share the same direction, and it stores them so that `width <= height`.
pub struct GoodThomasAlgorithm<T> {
/// Length of the inner FFT applied first (the shorter of the two after construction).
width: usize,
/// Inner FFT of length `width`.
width_size_fft: Arc<dyn FFT<T>>,
/// Length of the inner FFT applied second (the longer of the two after construction).
height: usize,
/// Inner FFT of length `height`.
height_size_fft: Arc<dyn FFT<T>>,
/// Precomputed divisor for `width`, used for the single `div_rem` per output chunk.
reduced_width: StrengthReducedUsize,
/// Precomputed divisor for `width + 1`, the stride used while scattering the input.
reduced_width_plus_one: StrengthReducedUsize,
/// Direction flag copied from the inner FFTs; reported by `is_inverse()`.
inverse: bool,
}
impl<T: FFTnum> GoodThomasAlgorithm<T> {
    /// Creates an FFT of length `width_fft.len() * height_fft.len()`.
    ///
    /// The inner FFTs may be passed in either order: if the first one is longer than
    /// the second, the two are swapped so that the stored `width` never exceeds `height`.
    ///
    /// # Panics
    ///
    /// Panics if the inner FFTs disagree on `is_inverse()`, or if their lengths are
    /// not coprime.
    pub fn new(mut width_fft: Arc<dyn FFT<T>>, mut height_fft: Arc<dyn FFT<T>>) -> Self {
        assert_eq!(
            width_fft.is_inverse(), height_fft.is_inverse(),
            "width_fft and height_fft must both be inverse, or neither. got width inverse={}, height inverse={}",
            width_fft.is_inverse(), height_fft.is_inverse());

        let inverse = width_fft.is_inverse();
        let (mut width, mut height) = (width_fft.len(), height_fft.len());

        // The index remapping below relies on the two sizes sharing no common factor.
        let coprime = num_integer::gcd(width as i64, height as i64) == 1;
        assert!(coprime,
                "Invalid width and height for Good-Thomas Algorithm (width={}, height={}): Inputs must be coprime",
                width,
                height);

        // The input scatter in `perform_fft` handles at most one wrap-around per row,
        // which needs the shorter FFT to play the role of "width".
        if height < width {
            std::mem::swap(&mut width, &mut height);
            std::mem::swap(&mut width_fft, &mut height_fft);
        }

        let reduced_width = StrengthReducedUsize::new(width);
        let reduced_width_plus_one = StrengthReducedUsize::new(width + 1);

        Self {
            width,
            width_size_fft: width_fft,
            height,
            height_size_fft: height_fft,
            reduced_width,
            reduced_width_plus_one,
            inverse,
        }
    }

    /// Computes one FFT of length `self.len()` from `input` into `output`.
    ///
    /// Both buffers are used as scratch space along the way, so `input` does not keep
    /// its original contents.
    fn perform_fft(&self, input: &mut [Complex<T>], output: &mut [Complex<T>]) {
        let total_len = self.len();
        let stride = self.reduced_width_plus_one.get();

        // Step 1: scatter `input` into `output`. Consecutive elements of a row land
        // `width + 1` slots apart, modulo `total_len`. Instead of reducing every index,
        // work out once per row how many writes fit before the index passes the end of
        // the buffer, and subtract `total_len` exactly at that point.
        let mut dest = 0;
        for row in input.chunks_exact(self.width) {
            let writes_before_wrap = 1 + (total_len - dest) / self.reduced_width_plus_one;
            let wraps = writes_before_wrap < self.width;

            // When the row does not wrap, `head` is the whole row and `tail` is empty.
            let split_point = if wraps { writes_before_wrap } else { row.len() };
            let (head, tail) = row.split_at(split_point);

            for value in head {
                output[dest] = *value;
                dest += stride;
            }
            if wraps {
                dest -= total_len;
            }
            for value in tail {
                output[dest] = *value;
                dest += stride;
            }

            // The loops advanced one stride past the last write; stepping back by
            // `width` leaves `dest` one past the last written index, which is where
            // the next row starts.
            dest -= self.width;
        }

        // Step 2: inner FFTs of size `width`, reading `output` and writing `input`.
        self.width_size_fft.process_multi(output, input);

        // Step 3: transpose from `input` into `output`.
        transpose::transpose(input, output, self.width, self.height);

        // Step 4: inner FFTs of size `height`, reading `output` and writing `input`.
        self.height_size_fft.process_multi(output, input);

        // Step 5: scatter `input` into `output`. Element `x` of chunk `y` belongs at
        // `(y * height + x * width) % total_len`. Writing `y * height` as
        // `wrapped * width + base`, the target is `base + ((wrapped + x) % height) * width`,
        // so each chunk is copied as a rotation: the last `wrapped` elements first,
        // then the rest, with the output index simply advancing by `width`.
        for (y, chunk) in input.chunks_exact(self.height).enumerate() {
            let (wrapped, base) =
                StrengthReducedUsize::div_rem(y * self.height, self.reduced_width);
            let rotate_at = self.height - wrapped;
            let (front, back) = chunk.split_at(rotate_at);

            let mut dest = base;
            for value in back {
                output[dest] = *value;
                dest += self.width;
            }
            for value in front {
                output[dest] = *value;
                dest += self.width;
            }
        }
    }
}
impl<T: FFTnum> FFT<T> for GoodThomasAlgorithm<T> {
    /// Computes a single FFT of length `self.len()` from `input` into `output`.
    ///
    /// Buffer lengths are checked by `verify_length` before any work is done.
    fn process(&self, input: &mut [Complex<T>], output: &mut [Complex<T>]) {
        let fft_len = self.len();
        verify_length(input, output, fft_len);
        self.perform_fft(input, output);
    }

    /// Computes one FFT per `self.len()`-sized chunk, pairing each input chunk with
    /// the output chunk at the same position.
    ///
    /// Buffer lengths are checked by `verify_length_divisible` before any work is done.
    fn process_multi(&self, input: &mut [Complex<T>], output: &mut [Complex<T>]) {
        let fft_len = self.len();
        verify_length_divisible(input, output, fft_len);

        let sources = input.chunks_exact_mut(fft_len);
        let destinations = output.chunks_exact_mut(fft_len);
        sources
            .zip(destinations)
            .for_each(|(src, dst)| self.perform_fft(src, dst));
    }
}
impl<T> Length for GoodThomasAlgorithm<T> {
    /// Total FFT length: the product of the two inner FFT lengths.
    #[inline(always)]
    fn len(&self) -> usize {
        let Self { width, height, .. } = *self;
        width * height
    }
}
impl<T> IsInverse for GoodThomasAlgorithm<T> {
    /// Returns the direction flag captured from the inner FFTs at construction.
    #[inline(always)]
    fn is_inverse(&self) -> bool {
        let Self { inverse, .. } = *self;
        inverse
    }
}
/// FFT of length `width * height` built from two butterfly FFTs, using index tables
/// that are computed once in the constructor.
///
/// The constructor asserts that the two lengths are coprime and that both butterflies
/// share the same direction.
pub struct GoodThomasAlgorithmDoubleButterfly<T> {
/// Length of the butterfly applied first.
width: usize,
/// Butterfly of length `width`.
width_size_fft: Arc<dyn FFTButterfly<T>>,
/// Length of the butterfly applied second.
height: usize,
/// Butterfly of length `height`.
height_size_fft: Arc<dyn FFTButterfly<T>>,
/// Two index tables stored back to back, each `width * height` entries long:
/// first the source index for every slot of the reordered input, then the
/// destination index for every element of the final result.
input_output_map: Box<[usize]>,
/// Direction flag copied from the butterflies; reported by `is_inverse()`.
inverse: bool,
}
impl<T: FFTnum> GoodThomasAlgorithmDoubleButterfly<T> {
    /// Creates an FFT of length `width_fft.len() * height_fft.len()` from two butterflies.
    ///
    /// The input and output reordering tables are computed here, once, so that
    /// `perform_fft` only has to follow them.
    ///
    /// # Panics
    ///
    /// Panics if the butterflies disagree on `is_inverse()`, or if their lengths are
    /// not coprime.
    pub fn new(width_fft: Arc<dyn FFTButterfly<T>>, height_fft: Arc<dyn FFTButterfly<T>>) -> Self {
        // The messages name the actual parameters (`width_fft` / `height_fft`), matching
        // the wording used by `GoodThomasAlgorithm::new`.
        assert_eq!(
            width_fft.is_inverse(), height_fft.is_inverse(),
            "width_fft and height_fft must both be inverse, or neither. got width inverse={}, height inverse={}",
            width_fft.is_inverse(), height_fft.is_inverse());

        let width = width_fft.len();
        let height = height_fft.len();
        let len = width * height;

        // NOTE(review): presumably the second and third results are the coefficients
        // with `width * width_inverse + height * height_inverse == gcd`, i.e. the
        // inverse of each size modulo the other -- confirm against `math_utils`.
        let (gcd, mut width_inverse, mut height_inverse) =
            math_utils::extended_euclidean_algorithm(width as i64, height as i64);
        assert!(
            gcd == 1,
            "Invalid width and height for Good-Thomas Algorithm (width={}, height={}): Inputs must be coprime",
            width,
            height
        );

        // Either coefficient may come back negative; shift it into the non-negative
        // range before it is cast to `usize` below.
        if width_inverse < 0 {
            width_inverse += height as i64;
        }
        if height_inverse < 0 {
            height_inverse += width as i64;
        }

        // Source index for slot `i` of the reordered input, where `i = y * width + x`.
        let input_iter = (0..len)
            .map(|i| (i % width, i / width))
            .map(|(x, y)| (x * height + y * width) % len);

        // Destination index for element `i` of the final result, where `i = x * height + y`.
        let output_iter = (0..len).map(|i| (i % height, i / height)).map(|(y, x)| {
            (x * height * height_inverse as usize + y * width * width_inverse as usize) % len
        });

        // Both tables live in one allocation: input table first, output table second.
        let input_output_map: Vec<usize> = input_iter.chain(output_iter).collect();

        Self {
            inverse: width_fft.is_inverse(),
            width,
            width_size_fft: width_fft,
            height,
            height_size_fft: height_fft,
            input_output_map: input_output_map.into_boxed_slice(),
        }
    }

    /// Computes one FFT of length `self.len()` from `input` into `output`.
    ///
    /// Both buffers are used as scratch space, so `input` does not keep its original
    /// contents.
    ///
    /// # Safety
    ///
    /// Every caller in this file validates buffer lengths first (`verify_length` or
    /// `verify_length_divisible` followed by chunking into `self.len()`-sized pieces),
    /// so `input` and `output` are expected to be exactly `self.len()` elements long.
    /// NOTE(review): the precise requirements come from the butterfly and transpose
    /// helpers called below, which are defined outside this file -- confirm there.
    unsafe fn perform_fft(&self, input: &mut [Complex<T>], output: &mut [Complex<T>]) {
        let (input_map, output_map) = self.input_output_map.split_at(self.len());

        // Gather the input into `output` in the reordered layout.
        for (output_element, &input_index) in output.iter_mut().zip(input_map.iter()) {
            *output_element = input[input_index];
        }

        // Butterflies of size `width`, in place on `output`.
        self.width_size_fft.process_multi_inplace(output);

        // Transpose from `output` into `input`.
        array_utils::transpose_small(self.width, self.height, output, input);

        // Butterflies of size `height`, in place on `input`.
        self.height_size_fft.process_multi_inplace(input);

        // Scatter the result into `output` using the second table.
        for (input_element, &output_index) in input.iter().zip(output_map.iter()) {
            output[output_index] = *input_element;
        }
    }
}
impl<T: FFTnum> FFT<T> for GoodThomasAlgorithmDoubleButterfly<T> {
    /// Computes a single FFT of length `self.len()` from `input` into `output`.
    ///
    /// Buffer lengths are checked by `verify_length` before the unsafe core runs.
    fn process(&self, input: &mut [Complex<T>], output: &mut [Complex<T>]) {
        let fft_len = self.len();
        verify_length(input, output, fft_len);

        // Lengths were validated just above.
        // NOTE(review): confirm that this is everything `perform_fft` requires.
        unsafe { self.perform_fft(input, output) };
    }

    /// Computes one FFT per `self.len()`-sized chunk, pairing each input chunk with
    /// the output chunk at the same position.
    ///
    /// Buffer lengths are checked by `verify_length_divisible` before the unsafe core runs.
    fn process_multi(&self, input: &mut [Complex<T>], output: &mut [Complex<T>]) {
        let fft_len = self.len();
        verify_length_divisible(input, output, fft_len);

        let chunk_pairs = input
            .chunks_exact_mut(fft_len)
            .zip(output.chunks_exact_mut(fft_len));
        for (src, dst) in chunk_pairs {
            // Every chunk produced by `chunks_exact_mut` is exactly `fft_len` long.
            unsafe { self.perform_fft(src, dst) };
        }
    }
}
impl<T> Length for GoodThomasAlgorithmDoubleButterfly<T> {
    /// Total FFT length: the product of the two butterfly lengths.
    #[inline(always)]
    fn len(&self) -> usize {
        let Self { width, height, .. } = *self;
        width * height
    }
}
impl<T> IsInverse for GoodThomasAlgorithmDoubleButterfly<T> {
    /// Returns the direction flag captured from the butterflies at construction.
    #[inline(always)]
    fn is_inverse(&self) -> bool {
        let Self { inverse, .. } = *self;
        inverse
    }
}
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::algorithm::DFT;
    use crate::test_utils::{check_fft_algorithm, make_butterfly};
    use num_integer::gcd;
    use std::sync::Arc;

    /// Wraps a `DFT` of the given size and direction as a shared `FFT<f32>` trait object.
    fn dft(len: usize, inverse: bool) -> Arc<dyn FFT<f32>> {
        Arc::new(DFT::new(len, inverse)) as Arc<dyn FFT<f32>>
    }

    /// Checks every coprime `(width, height)` pair below 12, in both directions.
    #[test]
    fn test_good_thomas() {
        for width in 1..12 {
            for height in 1..12 {
                if gcd(width, height) != 1 {
                    continue;
                }
                for &inverse in &[false, true] {
                    test_good_thomas_with_lengths(width, height, inverse);
                }
            }
        }
    }

    /// Checks every coprime pair of the listed butterfly sizes, in both directions.
    #[test]
    fn test_good_thomas_double_butterfly() {
        let butterfly_sizes = [2, 3, 4, 5, 6, 7, 8, 16];
        for &width in &butterfly_sizes {
            for &height in &butterfly_sizes {
                if gcd(width, height) != 1 {
                    continue;
                }
                for &inverse in &[false, true] {
                    test_good_thomas_butterfly_with_lengths(width, height, inverse);
                }
            }
        }
    }

    /// Builds a `GoodThomasAlgorithm` over two DFTs and runs the shared FFT checker on it.
    fn test_good_thomas_with_lengths(width: usize, height: usize, inverse: bool) {
        let fft = GoodThomasAlgorithm::new(dft(width, inverse), dft(height, inverse));
        check_fft_algorithm(&fft, width * height, inverse);
    }

    /// Builds a `GoodThomasAlgorithmDoubleButterfly` over two butterflies and runs the
    /// shared FFT checker on it.
    fn test_good_thomas_butterfly_with_lengths(width: usize, height: usize, inverse: bool) {
        let width_fft = make_butterfly(width, inverse);
        let height_fft = make_butterfly(height, inverse);
        let fft = GoodThomasAlgorithmDoubleButterfly::new(width_fft, height_fft);
        check_fft_algorithm(&fft, width * height, inverse);
    }

    /// Runs the algorithm with the larger size passed as "width", for every coprime
    /// height in `3..15`. There are no assertions: the test passes if `process`
    /// completes without panicking.
    #[test]
    fn test_output_mapping() {
        let width = 15;
        for height in 3..width {
            if gcd(width, height) != 1 {
                continue;
            }
            let fft = GoodThomasAlgorithm::new(dft(width, false), dft(height, false));

            let zero = Complex { re: 0.0, im: 0.0 };
            let mut input = vec![zero; fft.len()];
            let mut output = vec![zero; fft.len()];
            fft.process(&mut input, &mut output);
        }
    }
}