use crate::dft::problem::Sign;
use crate::kernel::{Complex, Float, IoDim};
use crate::prelude::*;
/// Batched 1-D DFT solver: applies the same length-`n` transform to
/// `howmany` vectors laid out in a flat buffer.
///
/// Element `i` of batch `b` is read from offset `b * idist + i * istride`
/// and written to offset `b * odist + i * ostride`; offsets are counted in
/// elements, not bytes.
pub struct VrankGeq1Solver<T: Float> {
    /// Length of each individual transform.
    n: usize,
    /// Number of transforms in the batch.
    howmany: usize,
    /// Distance between consecutive elements of one input vector.
    istride: isize,
    /// Distance between consecutive elements of one output vector.
    ostride: isize,
    /// Distance between the first elements of consecutive input vectors.
    idist: isize,
    /// Distance between the first elements of consecutive output vectors.
    odist: isize,
    /// Ties the solver to the scalar type `T` without storing a value of it.
    _marker: core::marker::PhantomData<T>,
}
impl<T: Float> Default for VrankGeq1Solver<T> {
    /// Builds the default solver: one contiguous transform of length 1.
    fn default() -> Self {
        let transform_len = 1;
        let batch_count = 1;
        Self::new_contiguous(transform_len, batch_count)
    }
}
impl<T: Float> VrankGeq1Solver<T> {
    /// Creates a solver for `howmany` transforms of length `n`.
    ///
    /// `istride`/`ostride` are the element distances inside one input/output
    /// vector; `idist`/`odist` are the element distances between the starts
    /// of consecutive input/output vectors.
    #[must_use]
    pub fn new(
        n: usize,
        howmany: usize,
        istride: isize,
        ostride: isize,
        idist: isize,
        odist: isize,
    ) -> Self {
        Self {
            n,
            howmany,
            istride,
            ostride,
            idist,
            odist,
            _marker: core::marker::PhantomData,
        }
    }

    /// Creates a solver for `howmany` back-to-back vectors of length `n`
    /// (unit strides, batch distance `n` on both input and output).
    #[must_use]
    pub fn new_contiguous(n: usize, howmany: usize) -> Self {
        Self::new(n, howmany, 1, 1, n as isize, n as isize)
    }

    /// Creates a solver from two dimension descriptors.
    ///
    /// `transform_dim` supplies the transform length (`n`) and the element
    /// strides (`is`, `os`); `batch_dim` supplies the batch count (`n`) and
    /// the batch distances (`is`, `os`).
    #[must_use]
    pub fn from_dims(transform_dim: &IoDim, batch_dim: &IoDim) -> Self {
        Self::new(
            transform_dim.n,
            batch_dim.n,
            transform_dim.is,
            transform_dim.os,
            batch_dim.is,
            batch_dim.os,
        )
    }

    /// Returns the identifier of this solver.
    #[must_use]
    pub fn name(&self) -> &'static str {
        "dft-vrank-geq1"
    }

    /// Returns the length of each individual transform.
    #[must_use]
    pub fn n(&self) -> usize {
        self.n
    }

    /// Returns the number of transforms in the batch.
    #[must_use]
    pub fn howmany(&self) -> usize {
        self.howmany
    }

    /// Offset of the first element of batch `batch` for batch distance `dist`.
    ///
    /// Panics instead of wrapping when the product does not fit in `isize`.
    fn batch_base(batch: usize, dist: isize) -> isize {
        (batch as isize)
            .checked_mul(dist)
            .expect("VrankGeq1Solver: batch offset overflows isize")
    }

    /// Slice index of element `i` of a vector starting at `base` with the
    /// given element `stride`.
    ///
    /// Panics when the offset overflows or is negative, so a bad layout is
    /// reported explicitly rather than through a wrapped `usize` index.
    fn strided_index(base: isize, i: usize, stride: isize) -> usize {
        let idx = (i as isize)
            .checked_mul(stride)
            .and_then(|step| base.checked_add(step))
            .expect("VrankGeq1Solver: element offset overflows isize");
        assert!(
            idx >= 0,
            "VrankGeq1Solver: negative element offset {}",
            idx
        );
        idx as usize
    }

    /// Runs `fft_1d` on every vector of the batch, reading from `input` and
    /// writing to `output`.
    ///
    /// Does nothing when `n` or `howmany` is zero. When both strides are 1
    /// and the vectors are packed back to back (or there is only a single
    /// vector), `fft_1d` is called directly on sub-slices of `input` and
    /// `output`. Otherwise each vector is gathered into a temporary buffer,
    /// transformed, and scattered to its output positions.
    ///
    /// # Panics
    ///
    /// Panics if a computed offset is negative, overflows `isize`, or lies
    /// outside `input` / `output`.
    pub fn execute_with<F>(&self, input: &[Complex<T>], output: &mut [Complex<T>], fft_1d: F)
    where
        F: Fn(&[Complex<T>], &mut [Complex<T>]),
    {
        if self.n == 0 || self.howmany == 0 {
            return;
        }
        let n = self.n;
        let unit_strides = self.istride == 1 && self.ostride == 1;
        let packed = self.idist == n as isize && self.odist == n as isize;
        // With a single batch the batch distances are never used, so unit
        // strides alone are enough to operate on the slices directly.
        if unit_strides && (packed || self.howmany == 1) {
            for batch in 0..self.howmany {
                let start = batch * n;
                let end = start + n;
                fft_1d(&input[start..end], &mut output[start..end]);
            }
            return;
        }

        let mut in_buf = vec![Complex::<T>::zero(); n];
        let mut out_buf = vec![Complex::<T>::zero(); n];
        for batch in 0..self.howmany {
            let in_base = Self::batch_base(batch, self.idist);
            let out_base = Self::batch_base(batch, self.odist);
            for (i, slot) in in_buf.iter_mut().enumerate() {
                *slot = input[Self::strided_index(in_base, i, self.istride)];
            }
            fft_1d(&in_buf, &mut out_buf);
            for (i, value) in out_buf.iter().enumerate() {
                output[Self::strided_index(out_base, i, self.ostride)] = *value;
            }
        }
    }

    /// Runs `fft_1d` on every vector of the batch inside `data`.
    ///
    /// Inputs are read using `istride`/`idist` and results are written using
    /// `ostride`/`odist`. Does nothing when `n` or `howmany` is zero.
    ///
    /// * Unit strides with packed vectors (or a single vector): `fft_1d` is
    ///   called directly on sub-slices of `data`.
    /// * Identical input and output layout: each vector is gathered,
    ///   transformed and written back to the positions it was read from.
    /// * Different input and output layout: all vectors are gathered and
    ///   transformed before anything is written, so that writing one batch
    ///   cannot overwrite the not-yet-read input of a later batch.
    ///
    /// # Panics
    ///
    /// Panics if a computed offset is negative, overflows, or lies outside
    /// `data`.
    pub fn execute_inplace_with<F>(&self, data: &mut [Complex<T>], fft_1d: F)
    where
        F: Fn(&mut [Complex<T>]),
    {
        if self.n == 0 || self.howmany == 0 {
            return;
        }
        let n = self.n;
        let unit_strides = self.istride == 1 && self.ostride == 1;
        let packed = self.idist == self.odist && self.idist == n as isize;
        if unit_strides && (packed || self.howmany == 1) {
            for batch in 0..self.howmany {
                let start = batch * n;
                let end = start + n;
                fft_1d(&mut data[start..end]);
            }
            return;
        }

        let same_layout = self.istride == self.ostride && self.idist == self.odist;
        if same_layout {
            // Every batch is written exactly where it was read, so one
            // scratch vector reused across batches is sufficient.
            let mut buf = vec![Complex::<T>::zero(); n];
            for batch in 0..self.howmany {
                let base = Self::batch_base(batch, self.idist);
                for (i, slot) in buf.iter_mut().enumerate() {
                    *slot = data[Self::strided_index(base, i, self.istride)];
                }
                fft_1d(&mut buf);
                for (i, value) in buf.iter().enumerate() {
                    data[Self::strided_index(base, i, self.ostride)] = *value;
                }
            }
            return;
        }

        // Layouts differ: an output position of one batch may coincide with
        // an input position of a later batch. Read and transform everything
        // first, then write the results in batch order.
        let total = n
            .checked_mul(self.howmany)
            .expect("VrankGeq1Solver: batch size overflows usize");
        let mut staged = vec![Complex::<T>::zero(); total];
        for (batch, chunk) in staged.chunks_exact_mut(n).enumerate() {
            let base = Self::batch_base(batch, self.idist);
            for (i, slot) in chunk.iter_mut().enumerate() {
                *slot = data[Self::strided_index(base, i, self.istride)];
            }
            fft_1d(chunk);
        }
        for (batch, chunk) in staged.chunks_exact(n).enumerate() {
            let base = Self::batch_base(batch, self.odist);
            for (i, value) in chunk.iter().enumerate() {
                data[Self::strided_index(base, i, self.ostride)] = *value;
            }
        }
    }

    /// Transforms the whole batch from `input` into `output` in direction
    /// `sign`, choosing the 1-D solver from the transform length.
    ///
    /// Selection order: `NopSolver` for `n <= 1`; `CooleyTukeySolver`
    /// (decimation in time) when it reports `n` as applicable;
    /// `DirectSolver` for `n <= 16`; `GenericSolver` when it reports `n` as
    /// applicable; otherwise `BluesteinSolver`.
    pub fn execute(&self, input: &[Complex<T>], output: &mut [Complex<T>], sign: Sign) {
        use crate::dft::solvers::{
            BluesteinSolver, CooleyTukeySolver, CtVariant, DirectSolver, GenericSolver, NopSolver,
        };
        if self.n <= 1 {
            self.execute_with(input, output, |i, o| NopSolver::new().execute(i, o));
        } else if CooleyTukeySolver::<T>::applicable(self.n) {
            let solver = CooleyTukeySolver::new(CtVariant::Dit);
            self.execute_with(input, output, |i, o| solver.execute(i, o, sign));
        } else if self.n <= 16 {
            let solver = DirectSolver::new();
            self.execute_with(input, output, |i, o| solver.execute(i, o, sign));
        } else if GenericSolver::<T>::applicable(self.n) {
            let solver = GenericSolver::new(self.n);
            self.execute_with(input, output, |i, o| solver.execute(i, o, sign));
        } else {
            let solver = BluesteinSolver::new(self.n);
            self.execute_with(input, output, |i, o| solver.execute(i, o, sign));
        }
    }

    /// Transforms the whole batch inside `data` in direction `sign`.
    ///
    /// Uses the same solver selection order as [`Self::execute`], calling
    /// each solver's in-place entry point.
    pub fn execute_inplace(&self, data: &mut [Complex<T>], sign: Sign) {
        use crate::dft::solvers::{
            BluesteinSolver, CooleyTukeySolver, CtVariant, DirectSolver, GenericSolver, NopSolver,
        };
        if self.n <= 1 {
            self.execute_inplace_with(data, |d| NopSolver::new().execute_inplace(d));
        } else if CooleyTukeySolver::<T>::applicable(self.n) {
            let solver = CooleyTukeySolver::new(CtVariant::Dit);
            self.execute_inplace_with(data, |d| solver.execute_inplace(d, sign));
        } else if self.n <= 16 {
            let solver = DirectSolver::new();
            self.execute_inplace_with(data, |d| solver.execute_inplace(d, sign));
        } else if GenericSolver::<T>::applicable(self.n) {
            let solver = GenericSolver::new(self.n);
            self.execute_inplace_with(data, |d| solver.execute_inplace(d, sign));
        } else {
            let solver = BluesteinSolver::new(self.n);
            self.execute_inplace_with(data, |d| solver.execute_inplace(d, sign));
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when `a` and `b` differ by strictly less than `eps`.
    fn approx_eq(a: f64, b: f64, eps: f64) -> bool {
        let delta = a - b;
        delta.abs() < eps
    }

    /// Component-wise [`approx_eq`] on the real and imaginary parts.
    fn complex_approx_eq(a: Complex<f64>, b: Complex<f64>, eps: f64) -> bool {
        let re_close = approx_eq(a.re, b.re, eps);
        let im_close = approx_eq(a.im, b.im, eps);
        re_close && im_close
    }

    /// Runs a forward then a backward contiguous batch transform and checks
    /// that dividing the result by `n` recovers the generated input within
    /// `eps`. `sample` maps the flat element index to the input value.
    fn assert_roundtrip(
        n: usize,
        howmany: usize,
        sample: impl Fn(f64) -> Complex<f64>,
        eps: f64,
    ) {
        let solver = VrankGeq1Solver::<f64>::new_contiguous(n, howmany);
        let total = n * howmany;
        let original: Vec<Complex<f64>> = (0..total).map(|i| sample(i as f64)).collect();
        let mut transformed = vec![Complex::zero(); total];
        let mut recovered = vec![Complex::zero(); total];

        solver.execute(&original, &mut transformed, Sign::Forward);
        solver.execute(&transformed, &mut recovered, Sign::Backward);

        let scale = n as f64;
        for (expected, got) in original.iter().zip(recovered.iter()) {
            let normalized = Complex::new(got.re / scale, got.im / scale);
            assert!(complex_approx_eq(*expected, normalized, eps));
        }
    }

    #[test]
    fn test_batch_contiguous_power_of_2() {
        let (n, howmany) = (4, 3);
        let solver = VrankGeq1Solver::<f64>::new_contiguous(n, howmany);
        let total = n * howmany;
        let input: Vec<Complex<f64>> = (0..total).map(|i| Complex::new(i as f64, 0.0)).collect();
        let mut output = vec![Complex::zero(); total];

        solver.execute(&input, &mut output, Sign::Forward);

        // First output of each batch equals the sum of that batch's inputs:
        // 0+1+2+3, 4+5+6+7, 8+9+10+11.
        let expected_first = [6.0, 22.0, 38.0];
        for (batch, value) in expected_first.iter().enumerate() {
            let expected = Complex::new(*value, 0.0);
            assert!(complex_approx_eq(output[batch * n], expected, 1e-10));
        }
    }

    #[test]
    fn test_batch_roundtrip() {
        assert_roundtrip(8, 4, |x| Complex::new(x.sin(), x.cos()), 1e-10);
    }

    #[test]
    fn test_batch_strided_column_access() {
        let rows = 4;
        let cols = 4;
        // Transform along columns of a row-major 4x4 matrix: elements of one
        // vector are `cols` apart, consecutive vectors are 1 apart.
        let element_stride = cols as isize;
        let solver =
            VrankGeq1Solver::<f64>::new(rows, cols, element_stride, element_stride, 1, 1);
        let input: Vec<Complex<f64>> = (0..16).map(|i| Complex::new(f64::from(i), 0.0)).collect();
        let mut output = vec![Complex::zero(); 16];

        solver.execute(&input, &mut output, Sign::Forward);

        // Column sums: 0+4+8+12, 1+5+9+13, 2+6+10+14, 3+7+11+15.
        let expected_first = [24.0, 28.0, 32.0, 36.0];
        for (col, value) in expected_first.iter().enumerate() {
            let expected = Complex::new(*value, 0.0);
            assert!(complex_approx_eq(output[col], expected, 1e-10));
        }
    }

    #[test]
    fn test_batch_inplace() {
        let (n, howmany) = (8, 3);
        let solver = VrankGeq1Solver::<f64>::new_contiguous(n, howmany);
        let total = n * howmany;
        let input: Vec<Complex<f64>> = (0..total).map(|i| Complex::new(i as f64, 0.0)).collect();

        let mut out_of_place = vec![Complex::zero(); total];
        solver.execute(&input, &mut out_of_place, Sign::Forward);

        // Reuse the input buffer itself as the in-place working buffer.
        let mut in_place = input;
        solver.execute_inplace(&mut in_place, Sign::Forward);

        for (reference, got) in out_of_place.iter().zip(in_place.iter()) {
            assert!(complex_approx_eq(*reference, *got, 1e-10));
        }
    }

    #[test]
    fn test_batch_non_power_of_2() {
        assert_roundtrip(5, 2, |x| Complex::new(x.sin(), 0.0), 1e-9);
    }

    #[test]
    fn test_batch_from_dims() {
        let transform_dim = IoDim::new(8, 1, 1);
        let batch_dim = IoDim::new(4, 8, 8);
        let solver = VrankGeq1Solver::<f64>::from_dims(&transform_dim, &batch_dim);
        assert_eq!(solver.n(), 8);
        assert_eq!(solver.howmany(), 4);
    }
}