use crate::dft::problem::Sign;
use crate::kernel::{Complex, Float};
use crate::prelude::*;
use super::ct::CooleyTukeySolver;
/// Bluestein (chirp-z) DFT solver for an arbitrary transform length `n`.
///
/// The length-`n` DFT is re-expressed as a circular convolution of length `m`
/// (a power of two `>= 2n - 1`), which is evaluated with [`CooleyTukeySolver`].
pub struct BluesteinSolver<T: Float> {
    /// Transform length.
    n: usize,
    /// Convolution length: smallest power of two `>= 2n - 1` (0 for an empty solver).
    m: usize,
    /// Chirp samples `exp(-iπ (k² mod 2n) / n)` for `k in 0..n`.
    chirp: Vec<Complex<T>>,
    /// Forward FFT (length `m`) of the conjugated chirp wrapped circularly onto `m` points.
    chirp_conj_fft: Vec<Complex<T>>,
    /// Cached length-`m` scratch for the chirp-multiplied, zero-padded input.
    #[cfg(feature = "std")]
    work_y: Mutex<Vec<Complex<T>>>,
    /// Cached length-`m` scratch receiving the forward FFT of `work_y`.
    #[cfg(feature = "std")]
    work_y_fft: Mutex<Vec<Complex<T>>>,
    /// Cached length-`m` scratch receiving the inverse FFT (the convolution result).
    #[cfg(feature = "std")]
    work_conv: Mutex<Vec<Complex<T>>>,
}
impl<T: Float> BluesteinSolver<T> {
#[must_use]
pub fn new(n: usize) -> Self {
if n == 0 {
return Self {
n: 0,
m: 0,
chirp: Vec::new(),
chirp_conj_fft: Vec::new(),
#[cfg(feature = "std")]
work_y: Mutex::new(Vec::new()),
#[cfg(feature = "std")]
work_y_fft: Mutex::new(Vec::new()),
#[cfg(feature = "std")]
work_conv: Mutex::new(Vec::new()),
};
}
let m = (2 * n - 1).next_power_of_two();
let mut chirp = Vec::with_capacity(n);
for i in 0..n {
let i_sq = (i * i) % (2 * n); let angle = -<T as Float>::PI * T::from_usize(i_sq) / T::from_usize(n);
chirp.push(Complex::cis(angle));
}
let mut chirp_conj = vec![Complex::zero(); m];
for i in 0..n {
chirp_conj[i] = chirp[i].conj();
}
for i in 1..n {
chirp_conj[m - i] = chirp[i].conj();
}
let mut chirp_conj_fft = vec![Complex::zero(); m];
CooleyTukeySolver::<T>::default().execute(&chirp_conj, &mut chirp_conj_fft, Sign::Forward);
Self {
n,
m,
chirp,
chirp_conj_fft,
#[cfg(feature = "std")]
work_y: Mutex::new(vec![Complex::zero(); m]),
#[cfg(feature = "std")]
work_y_fft: Mutex::new(vec![Complex::zero(); m]),
#[cfg(feature = "std")]
work_conv: Mutex::new(vec![Complex::zero(); m]),
}
}
#[must_use]
pub fn name(&self) -> &'static str {
"dft-bluestein"
}
#[must_use]
pub fn size(&self) -> usize {
self.n
}
#[must_use]
pub fn applicable(n: usize) -> bool {
n > 0
}
fn execute_with_buffers(
&self,
input: &[Complex<T>],
output: &mut [Complex<T>],
sign: Sign,
y: &mut [Complex<T>],
y_fft: &mut [Complex<T>],
conv: &mut [Complex<T>],
) {
let n = self.n;
let m = self.m;
let ct = CooleyTukeySolver::<T>::default();
for i in 0..m {
y[i] = Complex::zero();
}
match sign {
Sign::Forward => {
for i in 0..n {
y[i] = input[i] * self.chirp[i];
}
}
Sign::Backward => {
for i in 0..n {
y[i] = input[i] * self.chirp[i].conj();
}
}
}
ct.execute(y, y_fft, Sign::Forward);
match sign {
Sign::Forward => {
for i in 0..m {
y_fft[i] = y_fft[i] * self.chirp_conj_fft[i];
}
}
Sign::Backward => {
for i in 0..m {
y_fft[i] = y_fft[i] * self.chirp_conj_fft[i].conj();
}
}
}
ct.execute(y_fft, conv, Sign::Backward);
let m_inv = T::ONE / T::from_usize(m);
match sign {
Sign::Forward => {
for i in 0..n {
output[i] = conv[i] * m_inv * self.chirp[i];
}
}
Sign::Backward => {
for i in 0..n {
output[i] = conv[i] * m_inv * self.chirp[i].conj();
}
}
}
}
#[cfg(feature = "std")]
pub fn execute(&self, input: &[Complex<T>], output: &mut [Complex<T>], sign: Sign) {
let n = self.n;
debug_assert_eq!(input.len(), n);
debug_assert_eq!(output.len(), n);
if n == 0 {
return;
}
if n == 1 {
output[0] = input[0];
return;
}
let m = self.m;
let y_guard = self.work_y.try_lock();
let y_fft_guard = self.work_y_fft.try_lock();
let conv_guard = self.work_conv.try_lock();
if let (Ok(mut y), Ok(mut y_fft), Ok(mut conv)) = (y_guard, y_fft_guard, conv_guard) {
self.execute_with_buffers(input, output, sign, &mut y, &mut y_fft, &mut conv);
} else {
let mut y = vec![Complex::zero(); m];
let mut y_fft = vec![Complex::zero(); m];
let mut conv = vec![Complex::zero(); m];
self.execute_with_buffers(input, output, sign, &mut y, &mut y_fft, &mut conv);
}
}
#[cfg(not(feature = "std"))]
pub fn execute(&self, input: &[Complex<T>], output: &mut [Complex<T>], sign: Sign) {
let n = self.n;
debug_assert_eq!(input.len(), n);
debug_assert_eq!(output.len(), n);
if n == 0 {
return;
}
if n == 1 {
output[0] = input[0];
return;
}
let m = self.m;
let mut y = vec![Complex::zero(); m];
let mut y_fft = vec![Complex::zero(); m];
let mut conv = vec![Complex::zero(); m];
self.execute_with_buffers(input, output, sign, &mut y, &mut y_fft, &mut conv);
}
pub fn execute_inplace(&self, data: &mut [Complex<T>], sign: Sign) {
let n = self.n;
debug_assert_eq!(data.len(), n);
if n <= 1 {
return;
}
let input: Vec<Complex<T>> = data.to_vec();
self.execute(&input, data, sign);
}
}
impl<T: Float> Default for BluesteinSolver<T> {
fn default() -> Self {
Self::new(0)
}
}
/// One-shot forward DFT of `input` into `output`, building a solver for `input.len()`.
///
/// `output` is expected to be the same length as `input`.
// Convenience entry point; the allow suggests it may have no in-crate callers.
#[allow(dead_code)]
pub fn fft_bluestein<T: Float>(input: &[Complex<T>], output: &mut [Complex<T>]) {
    let solver = BluesteinSolver::<T>::new(input.len());
    solver.execute(input, output, Sign::Forward);
}
/// One-shot backward DFT of `input` into `output`, building a solver for `input.len()`.
///
/// No `1/n` scaling is applied; divide by the length to invert a forward transform.
// Convenience entry point; the allow suggests it may have no in-crate callers.
#[allow(dead_code)]
pub fn ifft_bluestein<T: Float>(input: &[Complex<T>], output: &mut [Complex<T>]) {
    let solver = BluesteinSolver::<T>::new(input.len());
    solver.execute(input, output, Sign::Backward);
}
/// One-shot in-place forward DFT of `data`, building a solver for `data.len()`.
// Convenience entry point; the allow suggests it may have no in-crate callers.
#[allow(dead_code)]
pub fn fft_bluestein_inplace<T: Float>(data: &mut [Complex<T>]) {
    let solver = BluesteinSolver::<T>::new(data.len());
    solver.execute_inplace(data, Sign::Forward);
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dft::solvers::direct::DirectSolver;

    /// Scalar comparison with an absolute tolerance.
    fn approx_eq(a: f64, b: f64, eps: f64) -> bool {
        (a - b).abs() < eps
    }

    /// Component-wise complex comparison with an absolute tolerance.
    fn complex_approx_eq(a: Complex<f64>, b: Complex<f64>, eps: f64) -> bool {
        approx_eq(a.re, b.re, eps) && approx_eq(a.im, b.im, eps)
    }

    /// Asserts that the forward Bluestein transform of `input` matches the direct DFT.
    fn assert_matches_direct(input: &[Complex<f64>], eps: f64) {
        let len = input.len();
        let mut via_bluestein = vec![Complex::zero(); len];
        let mut via_direct = vec![Complex::zero(); len];
        BluesteinSolver::new(len).execute(input, &mut via_bluestein, Sign::Forward);
        DirectSolver::new().execute(input, &mut via_direct, Sign::Forward);
        for (got, want) in via_bluestein.iter().zip(via_direct.iter()) {
            assert!(complex_approx_eq(*got, *want, eps));
        }
    }

    #[test]
    fn test_bluestein_size_1() {
        let input = [Complex::new(3.0_f64, 4.0)];
        let mut output = [Complex::zero()];
        BluesteinSolver::new(1).execute(&input, &mut output, Sign::Forward);
        assert!(complex_approx_eq(output[0], input[0], 1e-10));
    }

    #[test]
    fn test_bluestein_size_5() {
        let input: Vec<Complex<f64>> = (0..5).map(|i| Complex::new(f64::from(i), 0.0)).collect();
        assert_matches_direct(&input, 1e-9);
    }

    #[test]
    fn test_bluestein_size_7() {
        let input: Vec<Complex<f64>> = (0..7)
            .map(|i| Complex::new(f64::from(i).sin(), f64::from(i).cos()))
            .collect();
        assert_matches_direct(&input, 1e-9);
    }

    #[test]
    fn test_bluestein_size_12() {
        let input: Vec<Complex<f64>> = (0..12)
            .map(|i| Complex::new(f64::from(i), f64::from(i) * 0.5))
            .collect();
        assert_matches_direct(&input, 1e-9);
    }

    #[test]
    fn test_bluestein_inverse_recovers_input() {
        let original: Vec<Complex<f64>> = (0..11)
            .map(|i| Complex::new(f64::from(i).sin(), f64::from(i).cos()))
            .collect();
        let solver = BluesteinSolver::new(11);
        let mut spectrum = vec![Complex::zero(); 11];
        solver.execute(&original, &mut spectrum, Sign::Forward);
        let mut recovered = vec![Complex::zero(); 11];
        solver.execute(&spectrum, &mut recovered, Sign::Backward);
        // The backward transform is unnormalised, so divide by the length.
        let n = original.len() as f64;
        for x in &mut recovered {
            *x = *x / n;
        }
        for (expected, actual) in original.iter().zip(recovered.iter()) {
            assert!(complex_approx_eq(*expected, *actual, 1e-9));
        }
    }

    #[test]
    fn test_bluestein_inplace() {
        let original: Vec<Complex<f64>> = (0..9).map(|i| Complex::new(f64::from(i), 0.0)).collect();
        let solver = BluesteinSolver::new(9);
        let mut out_of_place = vec![Complex::zero(); 9];
        solver.execute(&original, &mut out_of_place, Sign::Forward);
        let mut in_place = original;
        solver.execute_inplace(&mut in_place, Sign::Forward);
        for (a, b) in out_of_place.iter().zip(in_place.iter()) {
            assert!(complex_approx_eq(*a, *b, 1e-10));
        }
    }
}