use crate::dft::problem::Sign;
use crate::kernel::{Complex, Float};
use crate::prelude::*;
/// Selects how an `IndirectSolver` maps logical element indices to positions
/// in the caller's storage.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
// `IndexArray` is reserved for a future index-table layout; no `IndirectSolver`
// constructor selects it yet, so without this the variant is reported as dead.
#[allow(dead_code)]
pub enum IndirectStrategy {
    /// Logical element `k` lives at `base + k * stride`.
    Stride,
    /// Element positions come from an explicit index table (not yet implemented).
    IndexArray,
}
pub struct IndirectSolver<T: Float> {
n: usize,
input_stride: isize,
output_stride: isize,
strategy: IndirectStrategy,
_marker: core::marker::PhantomData<T>,
}
/// The default solver handles a single element with unit input and output
/// strides.
impl<T: Float> Default for IndirectSolver<T> {
    fn default() -> Self {
        // length 1, input stride 1, output stride 1
        Self::new(1, 1, 1)
    }
}
impl<T: Float> IndirectSolver<T> {
    /// Creates a solver for `n` elements stored contiguously (stride 1 for both
    /// input and output).
    #[must_use]
    pub fn new_contiguous(n: usize) -> Self {
        Self::new(n, 1, 1)
    }

    /// Creates a solver for `n` elements whose input and output share the same
    /// element `stride`.
    #[must_use]
    pub fn new_uniform(n: usize, stride: isize) -> Self {
        Self::new(n, stride, stride)
    }

    /// Creates a solver for `n` elements with independent input and output
    /// element strides.
    ///
    /// Strides may be negative, provided every addressed index stays inside the
    /// slices later passed to the `execute*` methods.
    #[must_use]
    pub fn new(n: usize, input_stride: isize, output_stride: isize) -> Self {
        Self {
            n,
            input_stride,
            output_stride,
            strategy: IndirectStrategy::Stride,
            _marker: core::marker::PhantomData,
        }
    }

    /// Identifier of this solver: `"dft-indirect"`.
    #[must_use]
    pub fn name(&self) -> &'static str {
        "dft-indirect"
    }

    /// Transform length.
    #[must_use]
    pub fn n(&self) -> usize {
        self.n
    }

    /// Element distance between consecutive logical input samples.
    #[must_use]
    pub fn input_stride(&self) -> isize {
        self.input_stride
    }

    /// Element distance between consecutive logical output samples.
    #[must_use]
    pub fn output_stride(&self) -> isize {
        self.output_stride
    }

    /// `true` when both strides are 1 (base offsets are not considered).
    #[must_use]
    pub fn is_contiguous(&self) -> bool {
        self.input_stride == 1 && self.output_stride == 1
    }

    /// `true` when the input and output strides are equal.
    #[must_use]
    pub fn has_uniform_stride(&self) -> bool {
        self.input_stride == self.output_stride
    }

    /// Runs `fft_fn` on the `n` strided input elements and stores its result at
    /// the `n` strided output positions.
    ///
    /// Logical input element `k` is read from `input[input_base + k * input_stride]`
    /// and output element `k` is written to `output[output_base + k * output_stride]`.
    /// `fft_fn` always receives contiguous slices of length `n`. A side with stride 1
    /// is handed over directly as a sub-slice; any other side goes through a
    /// scratch buffer. `fft_fn` should therefore overwrite its whole output slice,
    /// because the slice's initial contents are unspecified. Does nothing when `n == 0`.
    ///
    /// # Panics
    ///
    /// Panics if any addressed index is negative or out of bounds for its slice.
    pub fn execute<F>(
        &self,
        input: &[Complex<T>],
        input_base: usize,
        output: &mut [Complex<T>],
        output_base: usize,
        sign: Sign,
        fft_fn: F,
    ) where
        F: FnOnce(&[Complex<T>], &mut [Complex<T>], Sign),
    {
        let n = self.n;
        if n == 0 {
            return;
        }

        // Borrow the input in place when it is already unit-stride; otherwise
        // gather it into a contiguous scratch buffer that outlives `src`.
        let gathered;
        let src: &[Complex<T>] = if self.input_is_unit_stride() {
            &input[input_base..][..n]
        } else {
            let mut buf = vec![Complex::zero(); n];
            self.gather(input, input_base, &mut buf);
            gathered = buf;
            &gathered
        };

        if self.output_is_unit_stride() {
            fft_fn(src, &mut output[output_base..][..n], sign);
        } else {
            let mut out_buf = vec![Complex::zero(); n];
            fft_fn(src, &mut out_buf, sign);
            self.scatter(&out_buf, output, output_base);
        }
    }

    /// In-place variant of [`Self::execute`] over a single slice.
    ///
    /// Logical element `k` is read from `data[base + k * input_stride]`. The
    /// transformed element `k` is written back to `data[base + k * output_stride]`.
    /// When both strides are 1, `fft_fn` operates directly on `data[base..base + n]`.
    /// Does nothing when `n == 0`.
    ///
    /// # Panics
    ///
    /// Panics if any addressed index is negative or out of bounds for `data`.
    pub fn execute_inplace<F>(&self, data: &mut [Complex<T>], base: usize, sign: Sign, fft_fn: F)
    where
        F: FnOnce(&mut [Complex<T>], Sign),
    {
        let n = self.n;
        if n == 0 {
            return;
        }
        if self.input_is_unit_stride() && self.output_is_unit_stride() {
            fft_fn(&mut data[base..][..n], sign);
            return;
        }
        let mut buf = vec![Complex::zero(); n];
        self.gather(data, base, &mut buf);
        fft_fn(&mut buf, sign);
        self.scatter(&buf, data, base);
    }

    /// [`Self::execute`] with both base offsets set to 0.
    pub fn execute_simple<F>(
        &self,
        input: &[Complex<T>],
        output: &mut [Complex<T>],
        sign: Sign,
        fft_fn: F,
    ) where
        F: FnOnce(&[Complex<T>], &mut [Complex<T>], Sign),
    {
        self.execute(input, 0, output, 0, sign, fft_fn);
    }

    /// `true` when input element `k` lives at `input_base + k`, so the input can
    /// be passed to the kernel as a plain sub-slice.
    fn input_is_unit_stride(&self) -> bool {
        self.strategy == IndirectStrategy::Stride && self.input_stride == 1
    }

    /// `true` when output element `k` lives at `output_base + k`, so the kernel
    /// can write straight into the caller's slice.
    fn output_is_unit_stride(&self) -> bool {
        self.strategy == IndirectStrategy::Stride && self.output_stride == 1
    }

    /// Computes `base + i * stride` as a slice index.
    ///
    /// # Panics
    ///
    /// Panics with a descriptive message if the result is negative or does not
    /// fit in `usize`. Plain `as` casts would wrap silently in that case.
    fn strided_index(base: usize, i: usize, stride: isize) -> usize {
        // `i` is always the position inside a slice, hence at most `isize::MAX`,
        // so this cast is lossless.
        (i as isize)
            .checked_mul(stride)
            .and_then(|offset| base.checked_add_signed(offset))
            .unwrap_or_else(|| {
                panic!(
                    "IndirectSolver: strided index {} + {} * {} is out of range",
                    base, i, stride
                )
            })
    }

    /// Copies the `n` logical input elements starting at `src_base` (stepping by
    /// `input_stride`) into `dst[..n]`.
    fn gather(&self, src: &[Complex<T>], src_base: usize, dst: &mut [Complex<T>]) {
        match self.strategy {
            IndirectStrategy::Stride => {
                for (i, slot) in dst[..self.n].iter_mut().enumerate() {
                    *slot = src[Self::strided_index(src_base, i, self.input_stride)];
                }
            }
            IndirectStrategy::IndexArray => {
                unimplemented!("Index array strategy not yet implemented")
            }
        }
    }

    /// Copies `src[..n]` to the `n` logical output positions starting at
    /// `dst_base` (stepping by `output_stride`).
    fn scatter(&self, src: &[Complex<T>], dst: &mut [Complex<T>], dst_base: usize) {
        match self.strategy {
            IndirectStrategy::Stride => {
                for (i, &value) in src[..self.n].iter().enumerate() {
                    dst[Self::strided_index(dst_base, i, self.output_stride)] = value;
                }
            }
            IndirectStrategy::IndexArray => {
                unimplemented!("Index array strategy not yet implemented")
            }
        }
    }

    /// Transforms the strided data with a decimation-in-time `CooleyTukeySolver`.
    /// Does nothing when `n == 0`.
    ///
    /// # Panics
    ///
    /// Panics if `CooleyTukeySolver::applicable(n)` rejects the size (the
    /// message states a power-of-two size is required), or on out-of-range
    /// indices as in [`Self::execute`].
    pub fn execute_ct(
        &self,
        input: &[Complex<T>],
        input_base: usize,
        output: &mut [Complex<T>],
        output_base: usize,
        sign: Sign,
    ) {
        use super::{CooleyTukeySolver, CtVariant};
        // Match `execute`: an empty transform is a no-op rather than a size error.
        if self.n == 0 {
            return;
        }
        if !CooleyTukeySolver::<T>::applicable(self.n) {
            panic!("IndirectSolver::execute_ct requires power-of-2 size");
        }
        let solver = CooleyTukeySolver::new(CtVariant::Dit);
        self.execute(input, input_base, output, output_base, sign, |i, o, s| {
            solver.execute(i, o, s);
        });
    }

    /// Transforms the strided data with a `BluesteinSolver` built for `n`.
    /// Does nothing when `n == 0`.
    ///
    /// # Panics
    ///
    /// Panics on out-of-range indices as in [`Self::execute`].
    pub fn execute_bluestein(
        &self,
        input: &[Complex<T>],
        input_base: usize,
        output: &mut [Complex<T>],
        output_base: usize,
        sign: Sign,
    ) {
        use super::BluesteinSolver;
        // Avoid constructing a zero-length solver; `execute` would skip it anyway.
        if self.n == 0 {
            return;
        }
        let solver = BluesteinSolver::new(self.n);
        self.execute(input, input_base, output, output_base, sign, |i, o, s| {
            solver.execute(i, o, s);
        });
    }

    /// Picks a kernel based on `n` and runs it on the strided data.
    ///
    /// The sizes are handled as follows:
    /// - `n == 0`: does nothing.
    /// - `n == 1`: copies the single element, since a length-1 DFT is the identity.
    /// - Sizes accepted by `CooleyTukeySolver::applicable`: [`Self::execute_ct`].
    /// - Primes up to 1021: `RaderSolver`, falling back to Bluestein if
    ///   `RaderSolver::new` returns `None`.
    /// - 7-smooth sizes (all prime factors at most 7): `GenericSolver`.
    /// - Anything else: [`Self::execute_bluestein`].
    ///
    /// # Panics
    ///
    /// Panics on out-of-range indices as in [`Self::execute`].
    pub fn execute_auto(
        &self,
        input: &[Complex<T>],
        input_base: usize,
        output: &mut [Complex<T>],
        output_base: usize,
        sign: Sign,
    ) {
        use super::{CooleyTukeySolver, GenericSolver, RaderSolver};
        use crate::kernel::is_prime;
        let n = self.n;
        if n == 0 {
            return;
        }
        if n == 1 {
            output[output_base] = input[input_base];
            return;
        }
        if CooleyTukeySolver::<T>::applicable(n) {
            self.execute_ct(input, input_base, output, output_base, sign);
        } else if is_prime(n) && n <= 1021 {
            if let Some(solver) = RaderSolver::new(n) {
                self.execute(input, input_base, output, output_base, sign, |i, o, s| {
                    solver.execute(i, o, s);
                });
            } else {
                self.execute_bluestein(input, input_base, output, output_base, sign);
            }
        } else if is_smooth(n, 7) {
            let solver = GenericSolver::new(n);
            self.execute(input, input_base, output, output_base, sign, |i, o, s| {
                solver.execute(i, o, s);
            });
        } else {
            self.execute_bluestein(input, input_base, output, output_base, sign);
        }
    }
}
/// Returns `true` when every prime factor of `n` is at most `b`, i.e. `n` is
/// `b`-smooth. `0` and `1` are treated as smooth.
fn is_smooth(n: usize, b: usize) -> bool {
    if n < 2 {
        return true;
    }
    let mut rest = n;
    // Divide out each candidate factor completely. A composite candidate finds
    // nothing left to remove, because its prime factors were stripped earlier.
    // `any` stops as soon as only 1 remains.
    (2..=b).any(|factor| {
        while rest % factor == 0 {
            rest /= factor;
        }
        rest == 1
    })
}
#[cfg(test)]
mod tests {
    use super::super::{CooleyTukeySolver, CtVariant};
    use super::*;

    /// Absolute-tolerance float comparison.
    fn approx_eq(a: f64, b: f64, eps: f64) -> bool {
        (a - b).abs() < eps
    }

    /// Component-wise absolute-tolerance complex comparison.
    fn complex_approx_eq(a: Complex<f64>, b: Complex<f64>, eps: f64) -> bool {
        approx_eq(a.re, b.re, eps) && approx_eq(a.im, b.im, eps)
    }

    /// Sentinel written into padding slots to detect stray reads/writes.
    fn pad() -> Complex<f64> {
        Complex::new(99.0, 99.0)
    }

    /// `[0, pad, 1, pad, 2, pad, 3, pad]`: a 4-point ramp at stride 2.
    fn interleaved_ramp() -> Vec<Complex<f64>> {
        (0..8)
            .map(|i| {
                if i % 2 == 0 {
                    Complex::new((i / 2) as f64, 0.0)
                } else {
                    pad()
                }
            })
            .collect()
    }

    /// Collects `n` elements of `data` starting at `base` with a positive `stride`.
    fn strided(data: &[Complex<f64>], base: usize, stride: usize, n: usize) -> Vec<Complex<f64>> {
        (0..n).map(|k| data[base + k * stride]).collect()
    }

    /// Contiguous Cooley-Tukey transform used as the reference result; the
    /// length of `input` must be a power of two.
    fn reference_ct(input: &[Complex<f64>], sign: Sign) -> Vec<Complex<f64>> {
        let mut out = vec![Complex::zero(); input.len()];
        let solver = CooleyTukeySolver::<f64>::new(CtVariant::Dit);
        solver.execute(input, &mut out, sign);
        out
    }

    /// Asserts that `got` and `want` agree element-wise within `1e-10`.
    fn assert_all_close(got: &[Complex<f64>], want: &[Complex<f64>]) {
        assert_eq!(got.len(), want.len());
        for (k, (&g, &w)) in got.iter().zip(want).enumerate() {
            assert!(
                complex_approx_eq(g, w, 1e-10),
                "element {}: got ({}, {}), want ({}, {})",
                k,
                g.re,
                g.im,
                w.re,
                w.im
            );
        }
    }

    /// Asserts that `value` is exactly the zero the output buffers start with.
    fn assert_untouched_zero(value: Complex<f64>) {
        assert!(complex_approx_eq(value, Complex::zero(), 1e-12));
    }

    #[test]
    fn test_indirect_contiguous() {
        let n = 8;
        let solver = IndirectSolver::<f64>::new_contiguous(n);
        assert!(solver.is_contiguous());
        let input: Vec<Complex<f64>> = (0..n).map(|i| Complex::new(i as f64, 0.0)).collect();
        let mut output = vec![Complex::zero(); n];
        let ct_solver = CooleyTukeySolver::new(CtVariant::Dit);
        solver.execute_simple(&input, &mut output, Sign::Forward, |i, o, s| {
            ct_solver.execute(i, o, s);
        });
        assert!(complex_approx_eq(output[0], Complex::new(28.0, 0.0), 1e-10));
        assert_all_close(&output, &reference_ct(&input, Sign::Forward));
    }

    #[test]
    fn test_indirect_uniform_stride() {
        let n = 4;
        let solver = IndirectSolver::<f64>::new_uniform(n, 2);
        assert!(solver.has_uniform_stride());
        let input = interleaved_ramp();
        let mut output = vec![Complex::zero(); 8];
        solver.execute_ct(&input, 0, &mut output, 0, Sign::Forward);
        assert!(complex_approx_eq(output[0], Complex::new(6.0, 0.0), 1e-10));
        let expected = reference_ct(&strided(&input, 0, 2, n), Sign::Forward);
        assert_all_close(&strided(&output, 0, 2, n), &expected);
        // Slots between the strided outputs must not be written.
        for k in 0..n {
            assert_untouched_zero(output[2 * k + 1]);
        }
    }

    #[test]
    fn test_indirect_with_offset() {
        let n = 4;
        let solver = IndirectSolver::<f64>::new_contiguous(n);
        let input: Vec<Complex<f64>> = vec![
            pad(),
            pad(),
            Complex::new(1.0, 0.0),
            Complex::new(2.0, 0.0),
            Complex::new(3.0, 0.0),
            Complex::new(4.0, 0.0),
        ];
        let mut output = vec![Complex::zero(); 6];
        let ct_solver = CooleyTukeySolver::new(CtVariant::Dit);
        solver.execute(&input, 2, &mut output, 2, Sign::Forward, |i, o, s| {
            ct_solver.execute(i, o, s);
        });
        assert!(complex_approx_eq(output[2], Complex::new(10.0, 0.0), 1e-10));
        assert_all_close(&output[2..], &reference_ct(&input[2..], Sign::Forward));
        assert_untouched_zero(output[0]);
        assert_untouched_zero(output[1]);
    }

    #[test]
    fn test_indirect_negative_stride() {
        let n = 4;
        let solver = IndirectSolver::<f64>::new_uniform(n, -2);
        let input = interleaved_ramp();
        let mut output = vec![Complex::zero(); 8];
        // Base 6 with stride -2 visits indices 6, 4, 2, 0: the ramp reversed.
        solver.execute_ct(&input, 6, &mut output, 6, Sign::Forward);
        let reversed: Vec<Complex<f64>> = (0..n).map(|k| input[6 - 2 * k]).collect();
        let got: Vec<Complex<f64>> = (0..n).map(|k| output[6 - 2 * k]).collect();
        assert_all_close(&got, &reference_ct(&reversed, Sign::Forward));
        for k in 0..n {
            assert_untouched_zero(output[2 * k + 1]);
        }
    }

    #[test]
    fn test_indirect_roundtrip() {
        let n = 8;
        let solver = IndirectSolver::<f64>::new_uniform(n, 2);
        let mut data: Vec<Complex<f64>> = vec![Complex::zero(); 16];
        for i in 0..n {
            data[i * 2] = Complex::new((i as f64).sin(), (i as f64).cos());
        }
        let original: Vec<Complex<f64>> = (0..n).map(|i| data[i * 2]).collect();
        let mut transformed = vec![Complex::zero(); 16];
        solver.execute_ct(&data, 0, &mut transformed, 0, Sign::Forward);
        let mut recovered = vec![Complex::zero(); 16];
        solver.execute_ct(&transformed, 0, &mut recovered, 0, Sign::Backward);
        let scale = n as f64;
        for i in 0..n {
            let idx = i * 2;
            let normalized = Complex::new(recovered[idx].re / scale, recovered[idx].im / scale);
            assert!(complex_approx_eq(original[i], normalized, 1e-10));
        }
    }

    #[test]
    fn test_indirect_inplace() {
        let n = 4;
        let solver = IndirectSolver::<f64>::new_uniform(n, 2);
        let mut data = interleaved_ramp();
        let input = data.clone();
        let mut expected = vec![Complex::zero(); 8];
        solver.execute_ct(&input, 0, &mut expected, 0, Sign::Forward);
        let ct_solver = CooleyTukeySolver::new(CtVariant::Dit);
        solver.execute_inplace(&mut data, 0, Sign::Forward, |d, s| {
            ct_solver.execute_inplace(d, s);
        });
        for i in 0..n {
            assert!(complex_approx_eq(data[i * 2], expected[i * 2], 1e-10));
            // Padding between the strided elements must survive untouched.
            assert!(complex_approx_eq(data[i * 2 + 1], pad(), 1e-12));
        }
    }

    #[test]
    fn test_indirect_inplace_contiguous_with_offset() {
        let n = 4;
        let solver = IndirectSolver::<f64>::new_contiguous(n);
        let mut data: Vec<Complex<f64>> = vec![
            pad(),
            Complex::new(1.0, 0.5),
            Complex::new(-2.0, 0.0),
            Complex::new(3.0, -1.0),
            Complex::new(0.25, 4.0),
            pad(),
        ];
        let expected = reference_ct(&data[1..5], Sign::Forward);
        let ct_solver = CooleyTukeySolver::<f64>::new(CtVariant::Dit);
        solver.execute_inplace(&mut data, 1, Sign::Forward, |d, s| {
            ct_solver.execute_inplace(d, s);
        });
        assert_all_close(&data[1..5], &expected);
        assert!(complex_approx_eq(data[0], pad(), 1e-12));
        assert!(complex_approx_eq(data[5], pad(), 1e-12));
    }

    #[test]
    fn test_indirect_different_io_strides() {
        let n = 4;
        let solver = IndirectSolver::<f64>::new(n, 2, 3);
        assert!(!solver.has_uniform_stride());
        let input = interleaved_ramp();
        let mut output = vec![Complex::zero(); 12];
        solver.execute_ct(&input, 0, &mut output, 0, Sign::Forward);
        assert!(complex_approx_eq(output[0], Complex::new(6.0, 0.0), 1e-10));
        let expected = reference_ct(&strided(&input, 0, 2, n), Sign::Forward);
        assert_all_close(&strided(&output, 0, 3, n), &expected);
        // Only indices 0, 3, 6, 9 may be written.
        for (idx, &value) in output.iter().enumerate() {
            if idx % 3 != 0 {
                assert_untouched_zero(value);
            }
        }
    }

    #[test]
    fn test_indirect_execute_auto() {
        for &n in &[4usize, 8, 16, 5, 7, 12, 15] {
            let solver = IndirectSolver::<f64>::new_contiguous(n);
            let input: Vec<Complex<f64>> = (0..n)
                .map(|i| Complex::new(i as f64, (i as f64) * 0.5))
                .collect();
            let mut output = vec![Complex::zero(); n];
            solver.execute_auto(&input, 0, &mut output, 0, Sign::Forward);
            let expected_dc: f64 = (0..n).map(|i| i as f64).sum();
            assert!(
                approx_eq(output[0].re, expected_dc, 1e-9),
                "Failed for n={}: expected {}, got {}",
                n,
                expected_dc,
                output[0].re
            );
        }
    }

    #[test]
    fn test_indirect_execute_auto_roundtrip() {
        // Covers the Cooley-Tukey, Rader, generic and power-of-two dispatch paths.
        for &n in &[4usize, 5, 7, 12, 15, 16] {
            let solver = IndirectSolver::<f64>::new_contiguous(n);
            let input: Vec<Complex<f64>> = (0..n)
                .map(|i| Complex::new((i as f64).sin(), (i as f64 * 0.7).cos()))
                .collect();
            let mut spectrum = vec![Complex::zero(); n];
            solver.execute_auto(&input, 0, &mut spectrum, 0, Sign::Forward);
            let mut recovered = vec![Complex::zero(); n];
            solver.execute_auto(&spectrum, 0, &mut recovered, 0, Sign::Backward);
            let scale = n as f64;
            for (i, (&orig, &back)) in input.iter().zip(&recovered).enumerate() {
                let normalized = Complex::new(back.re / scale, back.im / scale);
                assert!(
                    complex_approx_eq(orig, normalized, 1e-9),
                    "roundtrip failed for n={} at index {}",
                    n,
                    i
                );
            }
        }
    }

    #[test]
    fn test_indirect_execute_auto_single_element() {
        let solver = IndirectSolver::<f64>::new(1, 3, 5);
        let input: Vec<Complex<f64>> = vec![pad(), Complex::new(2.5, -1.0)];
        let mut output = vec![Complex::zero(); 3];
        solver.execute_auto(&input, 1, &mut output, 2, Sign::Forward);
        assert!(complex_approx_eq(output[2], Complex::new(2.5, -1.0), 1e-12));
        assert_untouched_zero(output[0]);
        assert_untouched_zero(output[1]);
    }
}