use std::cmp::max;
use std::sync::Arc;
use num_complex::Complex;
use num_integer::Integer;
use strength_reduce::StrengthReducedUsize;
use transpose;
use crate::array_utils;
use crate::common::{fft_error_inplace, fft_error_outofplace};
use crate::{common::FftNum, FftDirection};
use crate::{Direction, Fft, Length};
/// FFT of length `width * height` built from two inner FFTs of coprime lengths using the
/// Good-Thomas algorithm.
///
/// The coprimality requirement is checked in [`GoodThomasAlgorithm::new`], which also places
/// the smaller of the two inner lengths in `width`.
pub struct GoodThomasAlgorithm<T> {
    /// Total FFT length; always `width * height`.
    len: usize,
    /// Direction shared by both inner FFTs.
    direction: FftDirection,

    /// Length of the smaller inner FFT (`new` swaps the inputs so that `width <= height`).
    width: usize,
    /// Inner FFT of length `width`.
    width_size_fft: Arc<dyn Fft<T>>,
    /// Strength-reduced `width`, used for the per-row division when remapping outputs.
    reduced_width: StrengthReducedUsize,
    /// Strength-reduced `width + 1`, used for the per-row division when remapping inputs.
    reduced_width_plus_one: StrengthReducedUsize,

    /// Length of the larger inner FFT.
    height: usize,
    /// Inner FFT of length `height`.
    height_size_fft: Arc<dyn Fft<T>>,

    /// Scratch length required by in-place processing (computed once in `new`).
    inplace_scratch_len: usize,
    /// Scratch length required by out-of-place processing (computed once in `new`).
    outofplace_scratch_len: usize,
}
impl<T: FftNum> GoodThomasAlgorithm<T> {
/// Creates a Good-Thomas FFT of length `width_fft.len() * height_fft.len()` from two inner FFTs.
///
/// The two inner lengths must be coprime. If `width_fft` is longer than `height_fft`, the two
/// are swapped internally, so the stored `width` is never larger than the stored `height`.
///
/// # Panics
///
/// Panics if the two inner FFTs have different directions, or if their lengths are not coprime.
pub fn new(mut width_fft: Arc<dyn Fft<T>>, mut height_fft: Arc<dyn Fft<T>>) -> Self {
assert_eq!(
width_fft.fft_direction(), height_fft.fft_direction(),
"width_fft and height_fft must have the same direction. got width direction={}, height direction={}",
width_fft.fft_direction(), height_fft.fft_direction());
let mut width = width_fft.len();
let mut height = height_fft.len();
let direction = width_fft.fft_direction();
// The Good-Thomas index mapping requires coprime lengths.
let gcd = num_integer::gcd(width as i64, height as i64);
assert!(gcd == 1,
"Invalid width and height for Good-Thomas Algorithm (width={}, height={}): Inputs must be coprime",
width,
height);
// Keep the smaller length in `width`. When width < height, a full row of `width` steps of
// size (width + 1) in `reindex_input` advances the destination index by
// width * (width + 1) <= width * height = len. As a result, the index wraps past `len`
// at most once per row, which is the single wrap that `reindex_input` handles.
if width > height {
std::mem::swap(&mut width, &mut height);
std::mem::swap(&mut width_fft, &mut height_fft);
}
let len = width * height;
let width_inplace_scratch = width_fft.get_inplace_scratch_len();
let height_inplace_scratch = height_fft.get_inplace_scratch_len();
let height_outofplace_scratch = height_fft.get_outofplace_scratch_len();
// Out-of-place path (`perform_fft_out_of_place`): both inner FFTs run in place. Each one
// borrows either the caller's scratch or whichever of input/output (each `len` long) is not
// holding data. No extra scratch is needed unless an inner FFT wants more than `len`; in that
// case we request the larger of the two requirements.
let max_inner_inplace_scratch = max(height_inplace_scratch, width_inplace_scratch);
let outofplace_scratch_len = if max_inner_inplace_scratch > len {
max_inner_inplace_scratch
} else {
0
};
// In-place path (`perform_fft_inplace`): we need `len` elements of our own working buffer,
// plus a tail (`inner_scratch`). The tail must cover the height FFT's out-of-place scratch.
// It must also cover the width FFT's in-place scratch whenever `buffer` (length `len`) is
// too small to serve as that scratch.
let inplace_scratch_len = len
+ max(
if width_inplace_scratch > len {
width_inplace_scratch
} else {
0
},
height_outofplace_scratch,
);
Self {
width,
width_size_fft: width_fft,
height,
height_size_fft: height_fft,
// Strength-reduced divisors for the once-per-row divisions in the reindex helpers.
reduced_width: StrengthReducedUsize::new(width),
reduced_width_plus_one: StrengthReducedUsize::new(width + 1),
inplace_scratch_len,
outofplace_scratch_len,
len,
direction,
}
}
/// Copies `source` into `destination` using the Good-Thomas input permutation.
///
/// Callers run the width-size FFT directly over `destination` afterwards.
///
/// Rather than evaluating the index map per element, this walks each `width`-long source row
/// with a running destination index. The index advances by `width + 1` per element and is
/// brought back below `len` by a single subtraction when it would pass `len`. Only one
/// (strength-reduced) division per row is needed, to find where in the row that wrap happens.
fn reindex_input(&self, source: &[Complex<T>], destination: &mut [Complex<T>]) {
let mut destination_index = 0;
for mut source_row in source.chunks_exact(self.width) {
// How many positions, counting the current one and stepping by (width + 1), fit before
// the index passes `len`.
let increments_until_cycle =
1 + (self.len() - destination_index) / self.reduced_width_plus_one;
// If the wrap happens inside this row, write the elements before it, then wrap once.
if increments_until_cycle < self.width {
let (pre_cycle_row, post_cycle_row) = source_row.split_at(increments_until_cycle);
for input_element in pre_cycle_row {
destination[destination_index] = *input_element;
destination_index += self.reduced_width_plus_one.get();
}
// Continue with the rest of the row in the loop below, now wrapped back under `len`.
source_row = post_cycle_row;
destination_index -= self.len();
}
// The whole row (no wrap) or what remains of it (after the wrap).
for input_element in source_row {
destination[destination_index] = *input_element;
destination_index += self.reduced_width_plus_one.get();
}
// The loop overshot by one (width + 1) step. Subtracting `width` leaves the index one past
// the last position written in this row, which is where the next row starts.
destination_index -= self.width;
}
}
/// Copies `source` into `destination` using the Good-Thomas output permutation.
///
/// `source` is read as consecutive `height`-long rows. For row `y`, the destination index
/// starts at `(y * height) % width` and advances by `width` per element. Instead of reducing
/// each index modulo `len`, the row is visited starting at column
/// `height - (y * height) / width` and then wraps around to column 0. Only one
/// (strength-reduced) division per row is required.
fn reindex_output(&self, source: &[Complex<T>], destination: &mut [Complex<T>]) {
for (y, source_chunk) in source.chunks_exact(self.height).enumerate() {
let (quotient, remainder) =
StrengthReducedUsize::div_rem(y * self.height, self.reduced_width);
let mut destination_index = remainder;
// Column at which this row's traversal begins; columns before it are handled after wrapping.
let start_x = self.height - quotient;
for x in start_x..self.height {
destination[destination_index] = source_chunk[x];
destination_index += self.width;
}
for x in 0..start_x {
destination[destination_index] = source_chunk[x];
destination_index += self.width;
}
}
}
/// Computes one FFT of length `len` in place in `buffer`.
///
/// The first `len` elements of `scratch` are a working buffer; the remainder
/// (`inner_scratch`) is handed to the inner FFTs. Data flow:
/// `buffer` -> input reindex -> `scratch` -> width-size FFTs -> transpose -> `buffer`
/// -> height-size FFTs (out-of-place) -> `scratch` -> output reindex -> `buffer`.
fn perform_fft_inplace(&self, buffer: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
let (scratch, inner_scratch) = scratch.split_at_mut(self.len());
self.reindex_input(buffer, scratch);
// `buffer`'s contents now live in `scratch`, so `buffer` is free to act as the width
// FFT's scratch. Use `inner_scratch` instead only when it is the larger of the two.
let width_scratch = if inner_scratch.len() > buffer.len() {
&mut inner_scratch[..]
} else {
&mut buffer[..]
};
self.width_size_fft
.process_with_scratch(scratch, width_scratch);
transpose::transpose(scratch, buffer, self.width, self.height);
self.height_size_fft
.process_outofplace_with_scratch(buffer, scratch, inner_scratch);
self.reindex_output(scratch, buffer);
}
/// Computes one FFT of length `len` from `input` into `output`; `input` is overwritten.
///
/// Data flow: `input` -> input reindex -> `output` -> width-size FFTs -> transpose -> `input`
/// -> height-size FFTs (in place) -> output reindex -> `output`. At each FFT step, the buffer
/// not currently holding data serves as inner-FFT scratch, unless the caller's `scratch`
/// is longer.
fn perform_fft_out_of_place(
&self,
input: &mut [Complex<T>],
output: &mut [Complex<T>],
scratch: &mut [Complex<T>],
) {
self.reindex_input(input, output);
let width_scratch = if scratch.len() > input.len() {
&mut scratch[..]
} else {
&mut input[..]
};
self.width_size_fft
.process_with_scratch(output, width_scratch);
transpose::transpose(output, input, self.width, self.height);
let height_scratch = if scratch.len() > output.len() {
&mut scratch[..]
} else {
&mut output[..]
};
self.height_size_fft
.process_with_scratch(input, height_scratch);
self.reindex_output(input, output);
}
}
// Shared FFT trait boilerplate for `GoodThomasAlgorithm`. The closures supply, in order:
// the FFT length, the in-place scratch length, and the out-of-place scratch length.
// All three values are precomputed in `new`.
boilerplate_fft!(
    GoodThomasAlgorithm,
    |fft: &GoodThomasAlgorithm<_>| fft.len,
    |fft: &GoodThomasAlgorithm<_>| fft.inplace_scratch_len,
    |fft: &GoodThomasAlgorithm<_>| fft.outofplace_scratch_len
);
/// Good-Thomas FFT variant for small inner FFTs.
///
/// Both the input and output index permutations are precomputed in
/// [`GoodThomasAlgorithmSmall::new`] and stored in a single table. The inner FFTs must need no
/// out-of-place scratch and no more in-place scratch than their own length (checked in `new`).
pub struct GoodThomasAlgorithmSmall<T> {
    /// Direction shared by both inner FFTs.
    direction: FftDirection,
    /// Input permutation (first `width * height` entries) followed by the output permutation.
    input_output_map: Box<[usize]>,

    /// Length of the first inner FFT.
    width: usize,
    /// Inner FFT of length `width`.
    width_size_fft: Arc<dyn Fft<T>>,

    /// Length of the second inner FFT.
    height: usize,
    /// Inner FFT of length `height`.
    height_size_fft: Arc<dyn Fft<T>>,
}
impl<T: FftNum> GoodThomasAlgorithmSmall<T> {
    /// Creates a Good-Thomas FFT of length `width_fft.len() * height_fft.len()` from two small
    /// inner FFTs whose lengths are coprime.
    ///
    /// Both the input and output index permutations are computed here once. Each later FFT call
    /// then only does table lookups.
    ///
    /// # Panics
    ///
    /// Panics if:
    /// - `width_fft` and `height_fft` have different directions,
    /// - either inner FFT requires any out-of-place scratch,
    /// - either inner FFT requires more in-place scratch than its own length,
    /// - the two lengths are not coprime.
    pub fn new(width_fft: Arc<dyn Fft<T>>, height_fft: Arc<dyn Fft<T>>) -> Self {
        assert_eq!(
            width_fft.fft_direction(), height_fft.fft_direction(),
            "width_fft and height_fft must have the same direction. got width direction={}, height direction={}",
            width_fft.fft_direction(), height_fft.fft_direction());

        let width = width_fft.len();
        let height = height_fft.len();
        let len = width * height;

        // `perform_fft_inplace` runs the height FFT out-of-place with an empty scratch slice, so
        // it must need no out-of-place scratch. The other limits keep this variant restricted to
        // small, scratch-light inner FFTs.
        assert_eq!(width_fft.get_outofplace_scratch_len(), 0, "GoodThomasAlgorithmSmall should only be used with algorithms that require 0 out-of-place scratch. Width FFT (len={}) requires {}, should require 0", width, width_fft.get_outofplace_scratch_len());
        assert_eq!(height_fft.get_outofplace_scratch_len(), 0, "GoodThomasAlgorithmSmall should only be used with algorithms that require 0 out-of-place scratch. Height FFT (len={}) requires {}, should require 0", height, height_fft.get_outofplace_scratch_len());
        assert!(width_fft.get_inplace_scratch_len() <= width, "GoodThomasAlgorithmSmall should only be used with algorithms that require little inplace scratch. Width FFT (len={}) requires {}, should require {} or less", width, width_fft.get_inplace_scratch_len(), width);
        assert!(height_fft.get_inplace_scratch_len() <= height, "GoodThomasAlgorithmSmall should only be used with algorithms that require little inplace scratch. Height FFT (len={}) requires {}, should require {} or less", height, height_fft.get_inplace_scratch_len(), height);

        // Bezout coefficients: x * width + y * height == gcd. Once gcd == 1, `x` is the inverse
        // of width mod height, and `y` is the inverse of height mod width.
        let gcd_data = i64::extended_gcd(&(width as i64), &(height as i64));
        assert!(gcd_data.gcd == 1,
                "Invalid input width and height to Good-Thomas Algorithm: ({},{}): Inputs must be coprime",
                width,
                height);

        // The coefficients may be negative; shift them by the modulus so they fit in a usize.
        let width_inverse = if gcd_data.x >= 0 {
            gcd_data.x
        } else {
            gcd_data.x + height as i64
        } as usize;
        let height_inverse = if gcd_data.y >= 0 {
            gcd_data.y
        } else {
            gcd_data.y + width as i64
        } as usize;

        // Input map, entry i (x = i % width, y = i / width): the source index read into
        // position i, namely (x * height + y * width) % len.
        let input_iter = (0..len)
            .map(|i| (i % width, i / width))
            .map(|(x, y)| (x * height + y * width) % len);
        // Output map, entry i (y = i % height, x = i / height): the destination index that
        // position i is written to, namely
        // (x * height * height_inverse + y * width * width_inverse) % len.
        let output_iter = (0..len).map(|i| (i % height, i / height)).map(|(y, x)| {
            (x * height * height_inverse + y * width * width_inverse) % len
        });

        // One allocation holds both maps: the first `len` entries are the input map.
        let input_output_map: Vec<usize> = input_iter.chain(output_iter).collect();

        Self {
            // Read the direction before `width_fft` is moved into the struct below.
            direction: width_fft.fft_direction(),
            width,
            width_size_fft: width_fft,
            height,
            height_size_fft: height_fft,
            input_output_map: input_output_map.into_boxed_slice(),
        }
    }

    /// Computes one FFT of length `width * height` from `input` into `output`.
    ///
    /// `input` is overwritten: it serves as intermediate storage and as inner-FFT scratch.
    /// `_scratch` is unused, because this type reports zero out-of-place scratch.
    ///
    /// # Panics
    ///
    /// Panics if `input` or `output` does not hold exactly `width * height` elements.
    fn perform_fft_out_of_place(
        &self,
        input: &mut [Complex<T>],
        output: &mut [Complex<T>],
        _scratch: &mut [Complex<T>],
    ) {
        // Length guards for the `unsafe` transpose below.
        assert_eq!(self.len(), input.len());
        assert_eq!(self.len(), output.len());

        let (input_map, output_map) = self.input_output_map.split_at(self.len());

        // Gather the input into `output` using the precomputed input permutation.
        for (output_element, &input_index) in output.iter_mut().zip(input_map.iter()) {
            *output_element = input[input_index];
        }

        // FFTs of size `width`, with `input` (now free) as scratch.
        self.width_size_fft.process_with_scratch(output, input);

        // SAFETY: the asserts above guarantee `output` and `input` both hold exactly
        // `self.len() == width * height` elements. NOTE(review): that is assumed to be the full
        // precondition of `array_utils::transpose_small`; confirm against its contract.
        unsafe { array_utils::transpose_small(self.width, self.height, output, input) };

        // FFTs of size `height`, with `output` (now free) as scratch.
        self.height_size_fft.process_with_scratch(input, output);

        // Scatter into `output` using the precomputed output permutation.
        for (input_element, &output_index) in input.iter().zip(output_map.iter()) {
            output[output_index] = *input_element;
        }
    }

    /// Computes one FFT of length `width * height` in place in `buffer`.
    ///
    /// # Panics
    ///
    /// Panics if `buffer` or `scratch` does not hold exactly `width * height` elements.
    fn perform_fft_inplace(&self, buffer: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
        // Length guards for the `unsafe` transpose below.
        assert_eq!(self.len(), buffer.len());
        assert_eq!(self.len(), scratch.len());

        let (input_map, output_map) = self.input_output_map.split_at(self.len());

        // Gather the input into `scratch` using the precomputed input permutation.
        for (output_element, &input_index) in scratch.iter_mut().zip(input_map.iter()) {
            *output_element = buffer[input_index];
        }

        // FFTs of size `width`, with `buffer` (now free) as scratch.
        self.width_size_fft.process_with_scratch(scratch, buffer);

        // SAFETY: the asserts above guarantee `scratch` and `buffer` both hold exactly
        // `self.len() == width * height` elements. NOTE(review): that is assumed to be the full
        // precondition of `array_utils::transpose_small`; confirm against its contract.
        unsafe { array_utils::transpose_small(self.width, self.height, scratch, buffer) };

        // FFTs of size `height`, out-of-place from `buffer` into `scratch`. `new` asserted that
        // this needs no out-of-place scratch.
        self.height_size_fft
            .process_outofplace_with_scratch(buffer, scratch, &mut []);

        // Scatter back into `buffer` using the precomputed output permutation.
        for (input_element, &output_index) in scratch.iter().zip(output_map.iter()) {
            buffer[output_index] = *input_element;
        }
    }
}
// Shared FFT trait boilerplate for `GoodThomasAlgorithmSmall`. The closures supply, in order:
// - the FFT length (`width * height`),
// - the in-place scratch length (one full-length buffer),
// - the out-of-place scratch length (none).
boilerplate_fft!(
    GoodThomasAlgorithmSmall,
    |fft: &GoodThomasAlgorithmSmall<_>| fft.width * fft.height,
    |fft: &GoodThomasAlgorithmSmall<_>| fft.len(),
    |_fft: &GoodThomasAlgorithmSmall<_>| 0
);
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::test_utils::check_fft_algorithm;
    use crate::{algorithm::Dft, test_utils::BigScratchAlgorithm};
    use num_integer::gcd;
    use num_traits::Zero;
    use std::sync::Arc;

    /// Builds a naive DFT of the given length and direction as an `Fft` trait object.
    fn dft(len: usize, direction: FftDirection) -> Arc<dyn Fft<f32>> {
        Arc::new(Dft::<f32>::new(len, direction))
    }

    #[test]
    fn test_good_thomas() {
        for width in 1..12usize {
            for height in (1..12usize).filter(|&height| gcd(width, height) == 1) {
                for &direction in &[FftDirection::Forward, FftDirection::Inverse] {
                    test_good_thomas_with_lengths(width, height, direction);
                }
            }
        }
    }

    #[test]
    fn test_good_thomas_small() {
        let butterfly_sizes: [usize; 8] = [2, 3, 4, 5, 6, 7, 8, 16];
        for &width in butterfly_sizes.iter() {
            let coprime_heights = butterfly_sizes
                .iter()
                .cloned()
                .filter(|&height| gcd(width, height) == 1);
            for height in coprime_heights {
                for &direction in &[FftDirection::Forward, FftDirection::Inverse] {
                    test_good_thomas_small_with_lengths(width, height, direction);
                }
            }
        }
    }

    /// Checks a `GoodThomasAlgorithm` built from two DFTs against the reference implementation.
    fn test_good_thomas_with_lengths(width: usize, height: usize, direction: FftDirection) {
        let fft = GoodThomasAlgorithm::new(dft(width, direction), dft(height, direction));
        check_fft_algorithm(&fft, width * height, direction);
    }

    /// Checks a `GoodThomasAlgorithmSmall` built from two DFTs against the reference implementation.
    fn test_good_thomas_small_with_lengths(width: usize, height: usize, direction: FftDirection) {
        let fft = GoodThomasAlgorithmSmall::new(dft(width, direction), dft(height, direction));
        check_fft_algorithm(&fft, width * height, direction);
    }

    /// Runs the in-place path for a fixed width and several coprime heights, so every index the
    /// input/output remapping produces is exercised (an out-of-range index would panic).
    #[test]
    fn test_output_mapping() {
        let width: usize = 15;
        for height in (3..width).filter(|&height| gcd(width, height) == 1) {
            let fft = GoodThomasAlgorithm::new(
                dft(width, FftDirection::Forward),
                dft(height, FftDirection::Forward),
            );
            let mut buffer = vec![Complex::zero(); fft.len()];
            fft.process(&mut buffer);
        }
    }

    /// Exercises the scratch bookkeeping with inner FFTs that demand various scratch sizes.
    #[test]
    fn test_good_thomas_inner_scratch() {
        let scratch_lengths: [usize; 3] = [1, 5, 24];

        let mut inner_ffts: Vec<Arc<dyn Fft<f32>>> = Vec::new();
        for &len in scratch_lengths.iter() {
            for &inplace_scratch in scratch_lengths.iter() {
                for &outofplace_scratch in scratch_lengths.iter() {
                    let inner = BigScratchAlgorithm {
                        len,
                        inplace_scratch,
                        outofplace_scratch,
                        direction: FftDirection::Forward,
                    };
                    inner_ffts.push(Arc::new(inner) as Arc<dyn Fft<f32>>);
                }
            }
        }

        for width_fft in inner_ffts.iter() {
            let other_lengths = inner_ffts
                .iter()
                .filter(|candidate| candidate.len() != width_fft.len());
            for height_fft in other_lengths {
                let fft = GoodThomasAlgorithm::new(Arc::clone(width_fft), Arc::clone(height_fft));

                // In-place path with exactly the reported scratch.
                let mut buffer = vec![Complex::zero(); fft.len()];
                let mut scratch = vec![Complex::zero(); fft.get_inplace_scratch_len()];
                fft.process_with_scratch(&mut buffer, &mut scratch);

                // Out-of-place path with exactly the reported scratch.
                let mut input = vec![Complex::zero(); fft.len()];
                let mut output = vec![Complex::zero(); fft.len()];
                let mut scratch = vec![Complex::zero(); fft.get_outofplace_scratch_len()];
                fft.process_outofplace_with_scratch(&mut input, &mut output, &mut scratch);
            }
        }
    }
}