use crate::dft::problem::Sign;
use crate::kernel::{Complex, Float};
use crate::prelude::*;
/// Alignment target for scratch buffers.
///
/// In this file the value is only read (and immediately discarded) by
/// `aligned_buffer`; no allocation here actually applies it.
/// NOTE(review): presumably a byte count matching a common cache-line size —
/// confirm before relying on it.
const BUFFER_ALIGN: usize = 64;
/// Wrapper that lets an FFT callback written for contiguous data operate on
/// strided slices by copying through contiguous scratch buffers.
pub struct BufferedSolver<T: Float> {
/// Number of elements gathered, transformed and scattered per call.
n: usize,
/// Element step between consecutive input samples (`1` means contiguous).
input_stride: isize,
/// Element step between consecutive output samples (`1` means contiguous).
output_stride: isize,
/// Ties the solver to the scalar type `T` without storing a value of it.
_marker: core::marker::PhantomData<T>,
}
/// The default solver covers a single element with unit input and output
/// strides, so it never needs buffering.
impl<T: Float> Default for BufferedSolver<T> {
    fn default() -> Self {
        // Size 1, input stride 1, output stride 1.
        Self::new(1, 1, 1)
    }
}
impl<T: Float> BufferedSolver<T> {
    /// Creates a solver for `n` elements whose input and output are both
    /// contiguous (stride 1).
    #[must_use]
    pub fn new_contiguous(n: usize) -> Self {
        Self::new(n, 1, 1)
    }

    /// Creates a solver for `n` elements read with `input_stride` and written
    /// with `output_stride`, both measured in elements.
    ///
    /// Strides are not validated here; they are checked against the actual
    /// slice lengths when the solver is executed.
    #[must_use]
    pub fn new(n: usize, input_stride: isize, output_stride: isize) -> Self {
        Self {
            n,
            input_stride,
            output_stride,
            _marker: core::marker::PhantomData,
        }
    }

    /// Returns the fixed identifier of this solver, `"dft-buffered"`.
    #[must_use]
    pub fn name(&self) -> &'static str {
        "dft-buffered"
    }

    /// Returns the number of elements processed per call.
    #[must_use]
    pub fn n(&self) -> usize {
        self.n
    }

    /// Returns `true` when either stride differs from 1, i.e. when the
    /// execute methods copy through scratch buffers instead of forwarding the
    /// caller's slices directly.
    #[must_use]
    pub fn needs_buffering(&self) -> bool {
        !(self.input_stride == 1 && self.output_stride == 1)
    }

    /// Checks that every strided position `0, stride, ..., (n - 1) * stride`
    /// lies inside a slice of length `len`, and returns the step as `usize`.
    ///
    /// `role` only labels the slice in the panic message.
    ///
    /// # Panics
    ///
    /// Panics when the stride is negative (for `n >= 2`), when the last index
    /// overflows, or when it falls outside the slice.
    fn checked_step(&self, stride: isize, len: usize, role: &str) -> usize {
        let last = self.n.saturating_sub(1);
        let step = if last == 0 {
            // A single element only ever touches index 0, whatever the stride.
            Some(0)
        } else if stride >= 0 {
            // Lossless: a non-negative `isize` always fits in `usize`.
            Some(stride as usize)
        } else {
            None
        };
        match step {
            // The indices grow monotonically, so bounding the last one bounds all.
            Some(step) if matches!(step.checked_mul(last), Some(end) if end < len) => step,
            _ => panic!(
                "BufferedSolver: {} stride {} with n = {} does not fit a slice of length {}",
                role, stride, self.n, len
            ),
        }
    }

    /// Runs `fft_fn` out of place, buffering strided data when necessary.
    ///
    /// * `n == 0`: returns immediately without calling `fft_fn`.
    /// * Both strides equal to 1: `input` and `output` are handed to `fft_fn`
    ///   unchanged (no length check is made here).
    /// * Otherwise: `n` samples are gathered from `input[i * input_stride]`
    ///   into a contiguous scratch buffer, `fft_fn` transforms it into a second
    ///   scratch buffer, and the result is scattered to
    ///   `output[i * output_stride]`. Other positions of `output` are untouched.
    ///
    /// # Panics
    ///
    /// On the buffered path, panics before any work is done if a stride is
    /// negative (with `n >= 2`) or if the strided positions do not fit in
    /// `input` / `output`.
    pub fn execute<F>(&self, input: &[Complex<T>], output: &mut [Complex<T>], sign: Sign, fft_fn: F)
    where
        F: FnOnce(&[Complex<T>], &mut [Complex<T>], Sign),
    {
        if self.n == 0 {
            return;
        }
        if self.input_stride == 1 && self.output_stride == 1 {
            fft_fn(input, output, sign);
            return;
        }

        // Validate both sides up front so a bad stride fails with a clear
        // message before the callback runs.
        let in_step = self.checked_step(self.input_stride, input.len(), "input");
        let out_step = self.checked_step(self.output_stride, output.len(), "output");

        let mut in_buf: Vec<Complex<T>> = aligned_buffer(self.n);
        let mut out_buf: Vec<Complex<T>> = aligned_buffer(self.n);

        // Gather the strided input into contiguous storage.
        for (i, slot) in in_buf.iter_mut().enumerate() {
            *slot = input[i * in_step];
        }

        fft_fn(&in_buf, &mut out_buf, sign);

        // Scatter the contiguous result back to the strided output.
        for (i, value) in out_buf.iter().enumerate() {
            output[i * out_step] = *value;
        }
    }

    /// Runs `fft_fn` in place on `data`, buffering strided data when necessary.
    ///
    /// * `n == 0`: returns immediately without calling `fft_fn`.
    /// * Both strides equal to 1: `data` is handed to `fft_fn` unchanged.
    /// * Otherwise: `n` samples are gathered from `data[i * input_stride]` into
    ///   a scratch buffer, transformed there by `fft_fn`, and written back to
    ///   `data[i * output_stride]`.
    ///
    /// # Panics
    ///
    /// On the buffered path, panics before `data` is modified if a stride is
    /// negative (with `n >= 2`) or if the strided positions do not fit in
    /// `data`.
    pub fn execute_inplace<F>(&self, data: &mut [Complex<T>], sign: Sign, fft_fn: F)
    where
        F: FnOnce(&mut [Complex<T>], Sign),
    {
        if self.n == 0 {
            return;
        }
        if self.input_stride == 1 && self.output_stride == 1 {
            fft_fn(data, sign);
            return;
        }

        // Checking the scatter side before gathering guarantees `data` is
        // never left half-overwritten by a late out-of-bounds panic.
        let in_step = self.checked_step(self.input_stride, data.len(), "input");
        let out_step = self.checked_step(self.output_stride, data.len(), "output");

        let mut buf: Vec<Complex<T>> = aligned_buffer(self.n);
        for (i, slot) in buf.iter_mut().enumerate() {
            *slot = data[i * in_step];
        }

        fft_fn(&mut buf, sign);

        for (i, value) in buf.iter().enumerate() {
            data[i * out_step] = *value;
        }
    }

    /// Out-of-place transform using a decimation-in-time `CooleyTukeySolver`
    /// as the FFT callback of [`Self::execute`].
    ///
    /// # Panics
    ///
    /// Panics when `CooleyTukeySolver::applicable(n)` is false (per the panic
    /// message, when `n` is not a power of two), and under the same conditions
    /// as [`Self::execute`].
    pub fn execute_ct(&self, input: &[Complex<T>], output: &mut [Complex<T>], sign: Sign) {
        use super::{CooleyTukeySolver, CtVariant};
        assert!(
            CooleyTukeySolver::<T>::applicable(self.n),
            "BufferedSolver::execute_ct requires power-of-2 size"
        );
        let solver = CooleyTukeySolver::new(CtVariant::Dit);
        self.execute(input, output, sign, |i, o, s| solver.execute(i, o, s));
    }

    /// In-place transform using a decimation-in-time `CooleyTukeySolver` as
    /// the FFT callback of [`Self::execute_inplace`].
    ///
    /// # Panics
    ///
    /// Panics when `CooleyTukeySolver::applicable(n)` is false (per the panic
    /// message, when `n` is not a power of two), and under the same conditions
    /// as [`Self::execute_inplace`].
    pub fn execute_ct_inplace(&self, data: &mut [Complex<T>], sign: Sign) {
        use super::{CooleyTukeySolver, CtVariant};
        assert!(
            CooleyTukeySolver::<T>::applicable(self.n),
            "BufferedSolver::execute_ct_inplace requires power-of-2 size"
        );
        let solver = CooleyTukeySolver::new(CtVariant::Dit);
        self.execute_inplace(data, sign, |d, s| solver.execute_inplace(d, s));
    }
}
/// Allocates a scratch buffer of `n` zero-valued complex samples.
///
/// Despite the name, the buffer is an ordinary `Vec`: `BUFFER_ALIGN` is only
/// referenced so the constant counts as used, and no extra alignment is
/// requested from the allocator.
fn aligned_buffer<T: Float>(n: usize) -> Vec<Complex<T>> {
    // Reference the constant without acting on it.
    let _ = BUFFER_ALIGN;
    let mut scratch = Vec::with_capacity(n);
    scratch.resize(n, Complex::zero());
    scratch
}
#[cfg(test)]
mod tests {
    use super::super::{CooleyTukeySolver, CtVariant};
    use super::*;

    /// Absolute tolerance shared by every comparison in this module.
    const TOL: f64 = 1e-10;

    /// True when two reals differ by strictly less than `eps`.
    fn approx_eq(a: f64, b: f64, eps: f64) -> bool {
        let gap = (a - b).abs();
        gap < eps
    }

    /// Component-wise approximate equality of two complex values.
    fn complex_approx_eq(a: Complex<f64>, b: Complex<f64>, eps: f64) -> bool {
        let real_ok = approx_eq(a.re, b.re, eps);
        let imag_ok = approx_eq(a.im, b.im, eps);
        real_ok && imag_ok
    }

    /// Real-valued ramp `0, 1, ..., len - 1`.
    fn ramp(len: usize) -> Vec<Complex<f64>> {
        let mut samples = Vec::with_capacity(len);
        for k in 0..len {
            samples.push(Complex::new(k as f64, 0.0));
        }
        samples
    }

    /// Ramp samples at even positions with `filler` at every odd position,
    /// i.e. a stride-2 layout of `ramp(len)`.
    fn interleaved_ramp(len: usize, filler: Complex<f64>) -> Vec<Complex<f64>> {
        let mut samples = Vec::with_capacity(2 * len);
        for value in ramp(len) {
            samples.push(value);
            samples.push(filler);
        }
        samples
    }

    #[test]
    fn test_buffered_contiguous() {
        let len = 8;
        let solver = BufferedSolver::<f64>::new_contiguous(len);
        assert!(!solver.needs_buffering());

        let input = ramp(len);
        let mut output = vec![Complex::zero(); len];
        solver.execute_ct(&input, &mut output, Sign::Forward);

        // Expected first bin: 0 + 1 + ... + 7.
        let first_bin = output[0];
        assert!(complex_approx_eq(first_bin, Complex::new(28.0, 0.0), TOL));
    }

    #[test]
    fn test_buffered_strided() {
        let len = 4;
        let solver = BufferedSolver::<f64>::new(len, 2, 2);
        assert!(solver.needs_buffering());

        // The odd slots hold a sentinel that must not leak into the result.
        let input = interleaved_ramp(len, Complex::new(99.0, 99.0));
        let mut output = vec![Complex::zero(); 2 * len];
        solver.execute_ct(&input, &mut output, Sign::Forward);

        // Expected first bin: 0 + 1 + 2 + 3.
        let first_bin = output[0];
        assert!(complex_approx_eq(first_bin, Complex::new(6.0, 0.0), TOL));
    }

    #[test]
    fn test_buffered_roundtrip() {
        let len = 8;
        let solver = BufferedSolver::<f64>::new_contiguous(len);

        let mut original: Vec<Complex<f64>> = Vec::with_capacity(len);
        for k in 0..len {
            let phase = k as f64;
            original.push(Complex::new(phase.sin(), phase.cos()));
        }

        let mut transformed = vec![Complex::zero(); len];
        solver.execute_ct(&original, &mut transformed, Sign::Forward);

        let mut recovered = vec![Complex::zero(); len];
        solver.execute_ct(&transformed, &mut recovered, Sign::Backward);

        // The round trip is compared after dividing by the length.
        let scale = len as f64;
        for idx in 0..len {
            let expected = original[idx];
            let got = recovered[idx];
            let normalized = Complex::new(got.re / scale, got.im / scale);
            assert!(complex_approx_eq(expected, normalized, TOL));
        }
    }

    #[test]
    fn test_buffered_inplace() {
        let len = 8;
        let solver = BufferedSolver::<f64>::new_contiguous(len);

        let input = ramp(len);
        let mut out_of_place = vec![Complex::zero(); len];
        solver.execute_ct(&input, &mut out_of_place, Sign::Forward);

        // Reuse the input storage for the in-place run.
        let mut in_place = input;
        solver.execute_ct_inplace(&mut in_place, Sign::Forward);

        for idx in 0..len {
            assert!(complex_approx_eq(out_of_place[idx], in_place[idx], TOL));
        }
    }

    #[test]
    fn test_buffered_generic_callback() {
        let len = 8;
        let solver = BufferedSolver::<f64>::new(len, 2, 2);
        let ct_solver = CooleyTukeySolver::<f64>::new(CtVariant::Dit);

        let input = interleaved_ramp(len, Complex::new(0.0, 0.0));
        let mut output = vec![Complex::zero(); 2 * len];
        solver.execute(&input, &mut output, Sign::Forward, |src, dst, direction| {
            ct_solver.execute(src, dst, direction);
        });

        // Expected first bin: 0 + 1 + ... + 7.
        let first_bin = output[0];
        assert!(complex_approx_eq(first_bin, Complex::new(28.0, 0.0), TOL));
    }
}