use core::sync::atomic::{AtomicU64, Ordering};
use crate::dft::problem::Sign;
use crate::kernel::complex_mul::complex_mul_aos;
use crate::kernel::{Complex, Float};
use crate::prelude::*;
use super::ct::CooleyTukeySolver;
/// Process-wide counter; `BluesteinSolver::new` takes the next value from it
/// (via `fetch_add`) and stores it as the solver's `solver_id`.
static BLUESTEIN_ID_COUNTER: AtomicU64 = AtomicU64::new(0);
/// DFT solver that evaluates a length-`n` transform through a length-`m`
/// convolution, with all chirp tables precomputed at construction time.
///
/// With the `std` feature enabled the solver also owns mutex-protected
/// scratch buffers that `execute` / `execute_inplace` reuse whenever
/// `try_lock` succeeds; otherwise temporary buffers are allocated per call.
pub struct BluesteinSolver<T: Float> {
/// Transform length this solver was built for.
n: usize,
/// Convolution length: `(2 * n - 1).next_power_of_two()`, or 0 when `n == 0`.
m: usize,
/// Forward chirp, `cis(-PI * (i^2 mod 2n) / n)` for `i` in `0..n`.
chirp_fwd: Vec<Complex<T>>,
/// Element-wise conjugate of `chirp_fwd`, used for `Sign::Backward`.
chirp_bwd: Vec<Complex<T>>,
/// Forward FFT (length `m`) of the zero-padded sequence holding the
/// conjugated forward chirp at indices `i` and `m - i`.
chirp_conj_fft_fwd: Vec<Complex<T>>,
/// Element-wise conjugate of `chirp_conj_fft_fwd`, used for `Sign::Backward`.
chirp_conj_fft_bwd: Vec<Complex<T>>,
/// Identifier drawn from `BLUESTEIN_ID_COUNTER` when the solver was created.
pub(crate) solver_id: u64,
/// Scratch of length `m` for the chirp-modulated, zero-padded input.
#[cfg(feature = "std")]
work_y: Mutex<Vec<Complex<T>>>,
/// Scratch of length `m` receiving FFT results.
#[cfg(feature = "std")]
work_y_fft: Mutex<Vec<Complex<T>>>,
/// Scratch of length `m` for the frequency-domain product.
#[cfg(feature = "std")]
work_conv: Mutex<Vec<Complex<T>>>,
/// Scratch of length `n` used by `execute_inplace` to stage a copy of the input.
#[cfg(feature = "std")]
work_inplace: Mutex<Vec<Complex<T>>>,
}
impl<T: Float> BluesteinSolver<T> {
/// Builds a solver for transforms of length `n`, precomputing every table
/// the convolution needs.
///
/// * `n == 0` yields an empty solver (all tables empty, `m == 0`).
/// * Otherwise the convolution length is `m = (2 * n - 1).next_power_of_two()`,
///   the forward/backward chirps are tabulated, and the spectrum of the
///   conjugated chirp kernel is computed once with `CooleyTukeySolver`.
///
/// Every call takes a fresh `solver_id` from `BLUESTEIN_ID_COUNTER`.
#[must_use]
pub fn new(n: usize) -> Self {
    let solver_id = BLUESTEIN_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
    if n == 0 {
        // `2 * n - 1` below would underflow, so the empty solver is built directly.
        return Self {
            n: 0,
            m: 0,
            chirp_fwd: Vec::new(),
            chirp_bwd: Vec::new(),
            chirp_conj_fft_fwd: Vec::new(),
            chirp_conj_fft_bwd: Vec::new(),
            solver_id,
            #[cfg(feature = "std")]
            work_y: Mutex::new(Vec::new()),
            #[cfg(feature = "std")]
            work_y_fft: Mutex::new(Vec::new()),
            #[cfg(feature = "std")]
            work_conv: Mutex::new(Vec::new()),
            #[cfg(feature = "std")]
            work_inplace: Mutex::new(Vec::new()),
        };
    }
    let m = (2 * n - 1).next_power_of_two();

    // Forward chirp: cis(-PI * (i^2 mod 2n) / n).
    //
    // `i^2 mod 2n` is tracked incrementally through (i + 1)^2 = i^2 + 2i + 1
    // instead of evaluating `i * i`, which overflows `usize` once `i` reaches
    // 2^(bits / 2) (e.g. n > 65536 on 32-bit targets). Every intermediate here
    // stays below `4 * n`.
    let two_n = 2 * n;
    let mut chirp_fwd = Vec::with_capacity(n);
    let mut i_sq = 0_usize;
    for i in 0..n {
        let angle = -<T as Float>::PI * T::from_usize(i_sq) / T::from_usize(n);
        chirp_fwd.push(Complex::cis(angle));
        i_sq = (i_sq + 2 * i + 1) % two_n;
    }
    let chirp_bwd: Vec<Complex<T>> = chirp_fwd.iter().map(|c| c.conj()).collect();

    // Zero-padded convolution kernel: the conjugated chirp at index `i` and,
    // for `i >= 1`, mirrored at `m - i`.
    let mut chirp_conj = vec![Complex::zero(); m];
    for (i, c) in chirp_fwd.iter().enumerate() {
        let conjugated = c.conj();
        chirp_conj[i] = conjugated;
        if i != 0 {
            chirp_conj[m - i] = conjugated;
        }
    }

    // Kernel spectrum for the forward transform; the backward table is its
    // element-wise conjugate.
    let mut chirp_conj_fft_fwd = vec![Complex::zero(); m];
    CooleyTukeySolver::<T>::default().execute(
        &chirp_conj,
        &mut chirp_conj_fft_fwd,
        Sign::Forward,
    );
    let chirp_conj_fft_bwd: Vec<Complex<T>> =
        chirp_conj_fft_fwd.iter().map(|c| c.conj()).collect();

    Self {
        n,
        m,
        chirp_fwd,
        chirp_bwd,
        chirp_conj_fft_fwd,
        chirp_conj_fft_bwd,
        solver_id,
        #[cfg(feature = "std")]
        work_y: Mutex::new(vec![Complex::zero(); m]),
        #[cfg(feature = "std")]
        work_y_fft: Mutex::new(vec![Complex::zero(); m]),
        #[cfg(feature = "std")]
        work_conv: Mutex::new(vec![Complex::zero(); m]),
        #[cfg(feature = "std")]
        work_inplace: Mutex::new(vec![Complex::zero(); n]),
    }
}
/// Returns the fixed identifier string of this solver kind, `"dft-bluestein"`.
#[must_use]
pub fn name(&self) -> &'static str {
"dft-bluestein"
}
/// Returns the `solver_id` assigned to this instance when it was constructed.
#[must_use]
pub fn id(&self) -> u64 {
self.solver_id
}
/// Returns the transform length `n` this solver was built for.
#[must_use]
pub fn size(&self) -> usize {
self.n
}
/// Reports whether this solver can handle a transform of length `n`:
/// every non-zero length is accepted.
#[must_use]
pub fn applicable(n: usize) -> bool {
    n != 0
}
/// Core of the transform: evaluates the length-`n` DFT of `input` into
/// `output` using the caller-supplied scratch buffers.
///
/// `y`, `y_fft` and `conv` are working storage; every caller in this file
/// passes buffers of length `m`. Their previous contents are irrelevant and
/// they are overwritten.
///
/// The result carries no `1/n` normalisation (the tests in this file divide
/// by `n` after a forward + backward round trip).
fn execute_with_buffers(
    &self,
    input: &[Complex<T>],
    output: &mut [Complex<T>],
    sign: Sign,
    y: &mut [Complex<T>],
    y_fft: &mut [Complex<T>],
    conv: &mut [Complex<T>],
) {
    let (len, padded) = (self.n, self.m);
    let fft = CooleyTukeySolver::<T>::default();

    // Pick the chirp and kernel spectrum matching the requested direction.
    let (modulator, kernel_spectrum) = match sign {
        Sign::Backward => (&self.chirp_bwd, &self.chirp_conj_fft_bwd),
        Sign::Forward => (&self.chirp_fwd, &self.chirp_conj_fft_fwd),
    };

    // Step 1: modulate the input by the chirp into a zero-padded buffer.
    // (`complex_mul_aos` presumably writes the element-wise product of its
    // two sources into the destination -- confirm against the kernel module.)
    y[..padded].fill(Complex::zero());
    complex_mul_aos(&mut y[..len], input, modulator);

    // Step 2: convolve with the chirp kernel in the frequency domain.
    fft.execute(y, y_fft, Sign::Forward);
    complex_mul_aos(&mut conv[..padded], y_fft, kernel_spectrum);
    fft.execute(conv, y_fft, Sign::Backward);

    // Step 3: apply the 1/m factor for the forward/backward FFT pair (the
    // backward transform is presumably unnormalised), then demodulate.
    let scale = T::ONE / T::from_usize(padded);
    y_fft[..len].iter_mut().for_each(|bin| *bin = *bin * scale);
    complex_mul_aos(output, &y_fft[..len], modulator);
}
/// Computes the DFT of `input` into `output` in the direction given by `sign`
/// (`std` build).
///
/// Both slices are expected to have length `n`; this is checked in debug
/// builds only. Length 0 is a no-op and length 1 copies the single sample.
///
/// The solver's own scratch buffers are used when all three can be locked
/// without blocking; otherwise fresh buffers are allocated for this call, so
/// concurrent callers never wait on each other.
#[cfg(feature = "std")]
pub fn execute(&self, input: &[Complex<T>], output: &mut [Complex<T>], sign: Sign) {
    let n = self.n;
    debug_assert_eq!(input.len(), n);
    debug_assert_eq!(output.len(), n);
    match n {
        0 => return,
        1 => {
            output[0] = input[0];
            return;
        }
        _ => {}
    }
    let padded_len = self.m;
    if let (Ok(mut y), Ok(mut y_fft), Ok(mut conv)) = (
        self.work_y.try_lock(),
        self.work_y_fft.try_lock(),
        self.work_conv.try_lock(),
    ) {
        // Fast path: reuse the preallocated scratch owned by the solver.
        self.execute_with_buffers(input, output, sign, &mut y, &mut y_fft, &mut conv);
    } else {
        // At least one buffer is unavailable: fall back to per-call scratch.
        let scratch = || -> Vec<Complex<T>> { vec![Complex::zero(); padded_len] };
        let (mut y, mut y_fft, mut conv) = (scratch(), scratch(), scratch());
        self.execute_with_buffers(input, output, sign, &mut y, &mut y_fft, &mut conv);
    }
}
/// Computes the DFT of `input` into `output` in the direction given by `sign`
/// (build without the `std` feature).
///
/// Both slices are expected to have length `n`; this is checked in debug
/// builds only. Length 0 is a no-op and length 1 copies the single sample.
/// Three scratch buffers of length `m` are allocated on every call.
#[cfg(not(feature = "std"))]
pub fn execute(&self, input: &[Complex<T>], output: &mut [Complex<T>], sign: Sign) {
    let n = self.n;
    debug_assert_eq!(input.len(), n);
    debug_assert_eq!(output.len(), n);
    match n {
        0 => return,
        1 => {
            output[0] = input[0];
            return;
        }
        _ => {}
    }
    let padded_len = self.m;
    let scratch = || -> Vec<Complex<T>> { vec![Complex::zero(); padded_len] };
    let (mut y, mut y_fft, mut conv) = (scratch(), scratch(), scratch());
    self.execute_with_buffers(input, output, sign, &mut y, &mut y_fft, &mut conv);
}
/// Transforms `data` in place in the direction given by `sign`.
///
/// `data` is expected to have length `n` (checked in debug builds only).
/// Lengths 0 and 1 are returned unchanged. For longer inputs the samples are
/// first copied to a staging buffer, which then serves as the input of
/// [`Self::execute`] while `data` receives the output.
///
/// With the `std` feature the solver's own staging buffer is reused when it
/// can be locked without blocking; otherwise (or without `std`) a temporary
/// copy is allocated.
pub fn execute_inplace(&self, data: &mut [Complex<T>], sign: Sign) {
    let n = self.n;
    debug_assert_eq!(data.len(), n);
    if n <= 1 {
        return;
    }
    #[cfg(feature = "std")]
    {
        if let Ok(mut staging) = self.work_inplace.try_lock() {
            // Refill the reusable buffer with a snapshot of the input. It was
            // created with capacity for `n` elements, so this does not
            // reallocate for correctly sized data.
            staging.clear();
            staging.extend_from_slice(data);
            // The guard stays alive across the call, so borrowing the
            // snapshot directly is sound -- no raw pointers required.
            self.execute(&staging, data, sign);
            return;
        }
    }
    // Staging buffer busy/poisoned, or built without `std`: use a temporary copy.
    let input: Vec<Complex<T>> = data.to_vec();
    self.execute(&input, data, sign);
}
}
impl<T: Float> Default for BluesteinSolver<T> {
fn default() -> Self {
Self::new(0)
}
}
/// One-shot forward transform: builds a temporary solver sized to `input`
/// and writes the result to `output`.
///
/// The solver (and its precomputed tables) is rebuilt on every call; keep a
/// `BluesteinSolver` around when transforming many inputs of the same length.
pub fn fft_bluestein<T: Float>(input: &[Complex<T>], output: &mut [Complex<T>]) {
    let solver = BluesteinSolver::<T>::new(input.len());
    solver.execute(input, output, Sign::Forward);
}
/// One-shot backward transform: builds a temporary solver sized to `input`
/// and writes the result to `output`.
///
/// No `1/n` scaling is applied here; the tests in this file divide by `n`
/// to recover the original signal after a forward + backward pair.
pub fn ifft_bluestein<T: Float>(input: &[Complex<T>], output: &mut [Complex<T>]) {
    let solver = BluesteinSolver::<T>::new(input.len());
    solver.execute(input, output, Sign::Backward);
}
/// One-shot forward transform that overwrites `data` with its DFT, using a
/// temporary solver sized to `data`.
pub fn fft_bluestein_inplace<T: Float>(data: &mut [Complex<T>]) {
    let solver = BluesteinSolver::<T>::new(data.len());
    solver.execute_inplace(data, Sign::Forward);
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dft::solvers::direct::DirectSolver;

    /// True when `a` and `b` differ by strictly less than `eps`.
    fn approx_eq(a: f64, b: f64, eps: f64) -> bool {
        let gap = (a - b).abs();
        gap < eps
    }

    /// Component-wise [`approx_eq`] on real and imaginary parts.
    fn complex_approx_eq(a: Complex<f64>, b: Complex<f64>, eps: f64) -> bool {
        [(a.re, b.re), (a.im, b.im)]
            .iter()
            .all(|&(lhs, rhs)| approx_eq(lhs, rhs, eps))
    }

    /// Runs the forward transform of `input` through both the Bluestein and
    /// the direct solver and asserts the outputs agree to within 1e-9.
    fn assert_matches_direct(input: &[Complex<f64>]) {
        let len = input.len();
        let mut output_bluestein = vec![Complex::zero(); len];
        let mut output_direct = vec![Complex::zero(); len];
        BluesteinSolver::new(len).execute(input, &mut output_bluestein, Sign::Forward);
        DirectSolver::new().execute(input, &mut output_direct, Sign::Forward);
        for (a, b) in output_bluestein.iter().zip(output_direct.iter()) {
            assert!(complex_approx_eq(*a, *b, 1e-9));
        }
    }

    #[test]
    fn test_bluestein_size_1() {
        // A length-1 transform is the identity.
        let sample = Complex::new(3.0_f64, 4.0);
        let input = [sample];
        let mut output = [Complex::zero()];
        let solver = BluesteinSolver::new(1);
        solver.execute(&input, &mut output, Sign::Forward);
        assert!(complex_approx_eq(output[0], input[0], 1e-10));
    }

    #[test]
    fn test_bluestein_size_5() {
        let input: Vec<Complex<f64>> = (0..5).map(|i| Complex::new(f64::from(i), 0.0)).collect();
        assert_matches_direct(&input);
    }

    #[test]
    fn test_bluestein_size_7() {
        let input: Vec<Complex<f64>> = (0..7)
            .map(|i| Complex::new(f64::from(i).sin(), f64::from(i).cos()))
            .collect();
        assert_matches_direct(&input);
    }

    #[test]
    fn test_bluestein_size_12() {
        let input: Vec<Complex<f64>> = (0..12)
            .map(|i| Complex::new(f64::from(i), f64::from(i) * 0.5))
            .collect();
        assert_matches_direct(&input);
    }

    #[test]
    fn test_bluestein_inverse_recovers_input() {
        let original: Vec<Complex<f64>> = (0..11)
            .map(|i| Complex::new(f64::from(i).sin(), f64::from(i).cos()))
            .collect();
        let solver = BluesteinSolver::new(11);
        let mut transformed = vec![Complex::zero(); 11];
        solver.execute(&original, &mut transformed, Sign::Forward);
        let mut recovered = vec![Complex::zero(); 11];
        solver.execute(&transformed, &mut recovered, Sign::Backward);
        // The backward transform is unnormalised, so divide by the length.
        let n = original.len() as f64;
        for (orig, rec) in original.iter().zip(recovered.iter()) {
            assert!(complex_approx_eq(*orig, *rec / n, 1e-9));
        }
    }

    #[test]
    fn test_bluestein_inplace() {
        let original: Vec<Complex<f64>> = (0..9).map(|i| Complex::new(f64::from(i), 0.0)).collect();
        let solver = BluesteinSolver::new(9);
        let mut out_of_place = vec![Complex::zero(); 9];
        solver.execute(&original, &mut out_of_place, Sign::Forward);
        let mut in_place = original;
        solver.execute_inplace(&mut in_place, Sign::Forward);
        for (expected, actual) in out_of_place.iter().zip(in_place.iter()) {
            assert!(complex_approx_eq(*expected, *actual, 1e-10));
        }
    }

    #[test]
    fn test_fft_bluestein_convenience() {
        let input: Vec<Complex<f64>> = (0..7).map(|i| Complex::new(f64::from(i), 0.0)).collect();
        let mut output = vec![Complex::zero(); 7];
        fft_bluestein(&input, &mut output);
        // Smoke check only: a non-zero input must give a non-zero spectrum.
        let energy: f64 = output.iter().map(|c| c.norm_sqr()).sum();
        assert!(energy > 0.0);
    }

    #[test]
    fn test_ifft_bluestein_convenience() {
        let input: Vec<Complex<f64>> = (0..7).map(|i| Complex::new(f64::from(i), 0.0)).collect();
        let mut forward = vec![Complex::zero(); 7];
        fft_bluestein(&input, &mut forward);
        let mut backward = vec![Complex::zero(); 7];
        ifft_bluestein(&forward, &mut backward);
        let n = 7.0_f64;
        for (orig, recov) in input.iter().zip(backward.iter()) {
            assert!(complex_approx_eq(*orig, *recov / n, 1e-9));
        }
    }

    #[test]
    fn test_fft_bluestein_inplace_convenience() {
        let original: Vec<Complex<f64>> = (0..6).map(|i| Complex::new(f64::from(i), 0.0)).collect();
        let mut ref_output = vec![Complex::zero(); 6];
        fft_bluestein(&original, &mut ref_output);
        let mut inplace = original.clone();
        fft_bluestein_inplace(&mut inplace);
        for (expected, actual) in ref_output.iter().zip(inplace.iter()) {
            assert!(complex_approx_eq(*expected, *actual, 1e-10));
        }
    }

    /// Forward + backward round trip in `f64`; the worst per-sample relative
    /// error (after dividing by `n`) must stay below 1e-13.
    fn roundtrip_f64(n: usize) {
        let original: Vec<Complex<f64>> = (0..n)
            .map(|i| {
                let x = i as f64;
                Complex::new(x.sin(), (x * 0.7).cos())
            })
            .collect();
        let solver = BluesteinSolver::new(n);
        let mut transformed = vec![Complex::zero(); n];
        solver.execute(&original, &mut transformed, Sign::Forward);
        let mut recovered = vec![Complex::zero(); n];
        solver.execute(&transformed, &mut recovered, Sign::Backward);
        let n_f = n as f64;
        let max_rel = original
            .iter()
            .zip(recovered.iter())
            .fold(0.0_f64, |worst, (orig, rec)| {
                let rec_scaled = *rec / n_f;
                let abs_err = (orig.re - rec_scaled.re).abs() + (orig.im - rec_scaled.im).abs();
                let norm = (orig.re * orig.re + orig.im * orig.im).sqrt().max(1e-30);
                worst.max(abs_err / norm)
            });
        assert!(
            max_rel < 1e-13,
            "bluestein f64 round-trip n={n}: max_rel={max_rel} (must be < 1e-13)"
        );
    }

    /// Forward + backward round trip in `f32`; the worst per-sample relative
    /// error (after dividing by `n`) must stay below 5e-4.
    fn roundtrip_f32(n: usize) {
        let original: Vec<Complex<f32>> = (0..n)
            .map(|i| {
                let x = i as f32;
                Complex::new(x.sin(), (x * 0.7).cos())
            })
            .collect();
        let solver = BluesteinSolver::<f32>::new(n);
        let mut transformed = vec![Complex::new(0.0_f32, 0.0); n];
        solver.execute(&original, &mut transformed, Sign::Forward);
        let mut recovered = vec![Complex::new(0.0_f32, 0.0); n];
        solver.execute(&transformed, &mut recovered, Sign::Backward);
        let n_f = n as f32;
        let max_rel = original
            .iter()
            .zip(recovered.iter())
            .fold(0.0_f32, |worst, (orig, rec)| {
                let rec_scaled = *rec / n_f;
                let abs_err = (orig.re - rec_scaled.re).abs() + (orig.im - rec_scaled.im).abs();
                let norm = (orig.re * orig.re + orig.im * orig.im)
                    .sqrt()
                    .max(1e-10_f32);
                worst.max(abs_err / norm)
            });
        assert!(
            max_rel < 5e-4,
            "bluestein f32 round-trip n={n}: max_rel={max_rel} (must be < 5e-4)"
        );
    }

    /// Declares one `#[test]` that runs the given round-trip checker at a
    /// fixed (prime) length.
    macro_rules! roundtrip_case {
        ($name:ident, $check:ident, $len:expr) => {
            #[test]
            fn $name() {
                $check($len);
            }
        };
    }

    roundtrip_case!(roundtrip_prime_17_f64, roundtrip_f64, 17);
    roundtrip_case!(roundtrip_prime_61_f64, roundtrip_f64, 61);
    roundtrip_case!(roundtrip_prime_127_f64, roundtrip_f64, 127);
    roundtrip_case!(roundtrip_prime_257_f64, roundtrip_f64, 257);
    roundtrip_case!(roundtrip_prime_509_f64, roundtrip_f64, 509);
    roundtrip_case!(roundtrip_prime_1009_f64, roundtrip_f64, 1009);
    roundtrip_case!(roundtrip_prime_17_f32, roundtrip_f32, 17);
    roundtrip_case!(roundtrip_prime_61_f32, roundtrip_f32, 61);
    roundtrip_case!(roundtrip_prime_127_f32, roundtrip_f32, 127);
    roundtrip_case!(roundtrip_prime_257_f32, roundtrip_f32, 257);
    roundtrip_case!(roundtrip_prime_509_f32, roundtrip_f32, 509);
    roundtrip_case!(roundtrip_prime_1009_f32, roundtrip_f32, 1009);

    /// Sixteen concurrent transforms through one shared solver must all match
    /// a reference result computed up front on the calling thread.
    #[cfg(feature = "threading")]
    #[test]
    fn parallel_shared_solver_correctness() {
        use rayon::prelude::*;
        let n = 61_usize;
        let solver = std::sync::Arc::new(BluesteinSolver::new(n));
        let input: Vec<Complex<f64>> = (0..n)
            .map(|i| {
                let x = i as f64;
                Complex::new(x.sin(), x.cos())
            })
            .collect();
        let mut reference = vec![Complex::zero(); n];
        solver.execute(&input, &mut reference, Sign::Forward);
        let results: Vec<Vec<Complex<f64>>> = (0..16_usize)
            .into_par_iter()
            .map(|_| {
                let mut out = vec![Complex::zero(); n];
                solver.execute(&input, &mut out, Sign::Forward);
                out
            })
            .collect();
        for (thread_idx, result) in results.iter().enumerate() {
            for (k, (r, rr)) in result.iter().zip(reference.iter()).enumerate() {
                let abs_err = (r.re - rr.re).abs() + (r.im - rr.im).abs();
                let magnitude = (rr.re * rr.re + rr.im * rr.im).sqrt().max(1e-30);
                let err = abs_err / magnitude;
                assert!(
                    err < 1e-12,
                    "parallel thread {thread_idx} output[{k}] diverged: err={err}"
                );
            }
        }
    }
}