Skip to main content

wgsl_fft/
fft.rs

1//! GPU-accelerated FFT implementation using wgpu compute shaders.
2//!
3//! Implements the **Stockham autosort** Radix-4/2 FFT — a two-buffer ping-pong formulation
4//! where each stage reads from one buffer and writes to the other. This eliminates the separate
5//! bit-reversal pass and removes all inter-stage memory hazards.
6//!
7//! Also implements **Bluestein's algorithm** for arbitrary FFT sizes (not just powers of 2).
8
9use std::any::Any;
10use std::cell::RefCell;
11use std::num::NonZeroU64;
12
13use num_complex::Complex;
14
15use crate::error::{FftError, Result};
16use crate::shaders;
17
/// Scalars stored per complex sample: the real part followed by the imaginary part.
const COMPLEX_COMPONENT_COUNT: usize = 2;

/// Width in bytes of one `f32` scalar as laid out in the GPU buffers.
const F32_BYTE_SIZE: usize = core::mem::size_of::<f32>();
23
24/// Trait for FFT implementations that can be benchmarked.
25pub trait FftExecutor {
26    fn name(&self) -> &str;
27    fn fft(&self, inputs: &[Vec<Complex<f32>>]) -> Result<Vec<Vec<Complex<f32>>>>;
28    fn ifft(&self, inputs: &[Vec<Complex<f32>>]) -> Result<Vec<Vec<Complex<f32>>>>;
29
30    /// Get a reference to the underlying type for downcasting.
31    fn as_any(&self) -> &dyn Any;
32}
33
34/// Trait for GPU FFT implementations that support GPU-only benchmarking.
35pub trait GpuFftTrait {
36    /// Benchmark only the GPU compute pass and DMA operations (isolated from CPU overhead).
37    /// Returns duration in seconds for the GPU operations only.
38    fn benchmark_gpu_only(
39        &self,
40        sc: &SizeCache,
41        batch_size: u32,
42        n: usize,
43        warmup_iters: usize,
44        bench_iters: usize,
45    ) -> Result<f64>;
46
47    /// Get or build size-specific GPU resources.
48    fn get_or_build_size_cache(&self, n: usize, log_n: u32) -> SizeCache;
49
50    /// Prepare input data for GPU processing, applying conjugation for IFFT if needed.
51    fn prepare_input_data(&self, input: &[Complex<f32>], inverse: bool) -> Vec<f32>;
52
53    /// Get the queue for GPU operations.
54    fn queue(&self) -> &wgpu::Queue;
55}
56
57/// Pre-allocated GPU resources for a specific FFT size.
58#[derive(Clone, Debug)]
59pub struct SizeCache {
60    pub buf_a: wgpu::Buffer,
61    pub buf_b: wgpu::Buffer,
62    pub staging_buf: wgpu::Buffer,
63    pub twiddle_buf: wgpu::Buffer,
64    pub data_bytes: u64,
65    /// R4 stages (R4 mode) or R2 stages (legacy with_shader mode).
66    pub stage_bgs: Vec<wgpu::BindGroup>,
67    /// Final R2 stage when log₂N is odd (R4 mode only).
68    pub stage_bg_r2: Option<wgpu::BindGroup>,
69    pub result_in_b: bool,
70    /// Workgroup count for the main-stage dispatch (N/4 in R4 mode, N/2 in legacy mode).
71    pub wg_n2: u32,
72    /// Workgroup count for R4 dispatch (N/4). 0 in legacy mode.
73    pub wg_r4: u32,
74}
75
76/// Uniforms passed to the compute shader (16-byte aligned).
77#[repr(C)]
78#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
79pub struct FftUniforms {
80    pub n: u32,
81    pub stage: u32,
82    pub log_n: u32,
83    pub _pad: u32,
84}
85
86/// GPU-accelerated FFT engine backed by wgpu compute shaders.
87///
88/// Implements the Stockham autosort Radix-4 algorithm with an optional Radix-2
89/// final stage for odd log₂N sizes. Use [`GpuFft::new`] for the default R4
90/// pipeline or [`GpuFft::with_shader`] to supply a custom WGSL kernel.
91///
92/// For arbitrary FFT sizes (not powers of 2), Bluestein's algorithm is used automatically.
93#[derive(Debug)]
94pub struct GpuFft {
95    pub device: wgpu::Device,
96    pub queue: wgpu::Queue,
97    pub pipeline: wgpu::ComputePipeline,
98    /// Present only when created via `new()` (R4 mode). `None` in legacy `with_shader` mode.
99    pub pipeline_r2: Option<wgpu::ComputePipeline>,
100    pub cache: RefCell<std::collections::HashMap<usize, SizeCache>>,
101    /// Bluestein algorithm pipelines for GPU-accelerated arbitrary size FFT
102    pub pipeline_bluestein_chirp: wgpu::ComputePipeline,
103    pub pipeline_bluestein_inv_chirp: wgpu::ComputePipeline,
104    pub pipeline_bluestein_zero_pad: wgpu::ComputePipeline,
105    /// Cache for precomputed Bluestein chirp FFTs: (n, is_inverse) -> B_fft
106    pub bluestein_cache: RefCell<std::collections::HashMap<(usize, bool), Vec<Complex<f32>>>>,
107}
108
109impl FftExecutor for GpuFft {
110    fn name(&self) -> &str {
111        "Baseline (Stockham Radix-4/2)"
112    }
113
114    fn fft(&self, inputs: &[Vec<Complex<f32>>]) -> Result<Vec<Vec<Complex<f32>>>> {
115        self.transform_batch_internal(inputs, false)
116    }
117
118    fn ifft(&self, inputs: &[Vec<Complex<f32>>]) -> Result<Vec<Vec<Complex<f32>>>> {
119        self.transform_batch_internal(inputs, true)
120    }
121
122    fn as_any(&self) -> &dyn Any {
123        self
124    }
125}
126
127impl GpuFftTrait for GpuFft {
128    fn benchmark_gpu_only(
129        &self,
130        sc: &SizeCache,
131        batch_size: u32,
132        n: usize,
133        warmup_iters: usize,
134        bench_iters: usize,
135    ) -> Result<f64> {
136        use std::time::Instant;
137
138        // Warmup
139        for _ in 0..warmup_iters {
140            self.execute_compute_pass(sc, batch_size, n);
141            self.device.poll(wgpu::PollType::Wait {
142                submission_index: None,
143                timeout: None,
144            })?;
145        }
146
147        // Benchmark
148        let start = Instant::now();
149        for _ in 0..bench_iters {
150            self.execute_compute_pass(sc, batch_size, n);
151        }
152
153        self.device.poll(wgpu::PollType::Wait {
154            submission_index: None,
155            timeout: None,
156        })?;
157
158        let duration = start.elapsed();
159        Ok(duration.as_secs_f64() / bench_iters as f64)
160    }
161
162    fn get_or_build_size_cache(&self, n: usize, log_n: u32) -> SizeCache {
163        self.get_or_build_size_cache(n, log_n)
164    }
165
166    fn prepare_input_data(&self, input: &[Complex<f32>], inverse: bool) -> Vec<f32> {
167        self.prepare_input_data(input, inverse)
168    }
169
170    fn queue(&self) -> &wgpu::Queue {
171        &self.queue
172    }
173}
174
175impl GpuFft {
176    /// Access the underlying wgpu device.
177    pub fn device(&self) -> &wgpu::Device {
178        &self.device
179    }
180
181    /// Access the compiled compute pipeline.
182    pub fn compute_pipeline(&self) -> &wgpu::ComputePipeline {
183        &self.pipeline
184    }
185
186    /// Create a new [`GpuFft`] using the Radix-4/2 Stockham baseline.
187    ///
188    /// Dispatches ⌊log₄N⌋ Radix-4 passes (+ one Radix-2 pass when log₂N is odd),
189    /// halving the pass count vs the old Radix-2 baseline.
190    ///
191    /// For arbitrary FFT sizes (not powers of 2), Bluestein's algorithm is used automatically.
192    ///
193    /// # Examples
194    ///
195    /// ```no_run
196    /// use wgsl_fft::GpuFft;
197    ///
198    /// let fft = GpuFft::new().expect("GPU required");
199    /// // Now use fft.fft() and fft.ifft()
200    /// ```
201    pub fn new() -> Result<Self> {
202        let instance = wgpu::Instance::default();
203        let adapter = pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
204            power_preference: wgpu::PowerPreference::HighPerformance,
205            compatible_surface: None,
206            force_fallback_adapter: false,
207        }))
208        .or_else(|_| {
209            pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
210                power_preference: wgpu::PowerPreference::HighPerformance,
211                compatible_surface: None,
212                force_fallback_adapter: true,
213            }))
214        })?;
215
216        let (device, queue) =
217            pollster::block_on(adapter.request_device(&wgpu::DeviceDescriptor {
218                ..Default::default()
219            }))?;
220        Self::from_device_queue(device, queue)
221    }
222
223    /// Create a new [`GpuFft`] using the Radix-4/2 Stockham baseline with an existing device and queue.
224    ///
225    /// This constructor allows you to provide your own wgpu device and queue, which is useful
226    /// when you want to share a single GPU context across multiple resources.
227    ///
228    /// # Arguments
229    ///
230    /// * `device` - A wgpu device to use for creating resources.
231    /// * `queue` - A wgpu queue to use for submitting commands.
232    pub fn from_device_queue(device: wgpu::Device, queue: wgpu::Queue) -> Result<Self> {
233        let compile = |src: &str, label: &str| {
234            let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
235                label: Some(label),
236                source: wgpu::ShaderSource::Wgsl(src.into()),
237            });
238            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
239                label: Some(&format!("{label}_pipeline")),
240                layout: None,
241                module: &shader,
242                entry_point: Some("main"),
243                compilation_options: Default::default(),
244                cache: None,
245            })
246        };
247
248        let pipeline = compile(shaders::R4_WGSL, "stockham_r4");
249        let pipeline_r2 = Some(compile(shaders::R2_WGSL, "stockham_r2"));
250
251        // Bluestein algorithm pipelines for arbitrary size FFT (fully GPU-accelerated)
252        let pipeline_bluestein_chirp = compile(shaders::BLUESTEIN_CHIRP_WGSL, "bluestein_chirp");
253        let pipeline_bluestein_inv_chirp =
254            compile(shaders::BLUESTEIN_INV_CHIRP_WGSL, "bluestein_inv_chirp");
255        let pipeline_bluestein_zero_pad =
256            compile(shaders::BLUESTEIN_ZERO_PAD_WGSL, "bluestein_zero_pad");
257
258        Ok(Self {
259            device,
260            queue,
261            pipeline,
262            pipeline_r2,
263            cache: RefCell::new(std::collections::HashMap::new()),
264            pipeline_bluestein_chirp,
265            pipeline_bluestein_inv_chirp,
266            pipeline_bluestein_zero_pad,
267            bluestein_cache: RefCell::new(std::collections::HashMap::new()),
268        })
269    }
270
271    /// Create a new [`GpuFft`] with a custom WGSL shader.
272    /// This allows AI rivals to swap kernels easily.
273    pub fn with_shader(wgsl_source: String, label: &str) -> Result<Self> {
274        let instance = wgpu::Instance::default();
275        let adapter = pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
276            power_preference: wgpu::PowerPreference::HighPerformance,
277            compatible_surface: None,
278            force_fallback_adapter: false,
279        }))
280        .or_else(|_| {
281            pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
282                power_preference: wgpu::PowerPreference::HighPerformance,
283                compatible_surface: None,
284                force_fallback_adapter: true,
285            }))
286        })?;
287
288        let (device, queue) =
289            pollster::block_on(adapter.request_device(&wgpu::DeviceDescriptor {
290                ..Default::default()
291            }))?;
292        Self::with_shader_and_device(device, queue, wgsl_source, label)
293    }
294
295    /// Create a new [`GpuFft`] with a custom WGSL shader using an existing device and queue.
296    ///
297    /// This allows AI rivals to swap kernels easily while sharing a GPU context.
298    ///
299    /// # Arguments
300    ///
301    /// * `device` - A wgpu device to use for creating resources.
302    /// * `queue` - A wgpu queue to use for submitting commands.
303    /// * `wgsl_source` - The WGSL shader source code.
304    /// * `label` - A label for the shader and pipeline.
305    pub fn with_shader_and_device(
306        device: wgpu::Device,
307        queue: wgpu::Queue,
308        wgsl_source: String,
309        label: &str,
310    ) -> Result<Self> {
311        let shader_mod = device.create_shader_module(wgpu::ShaderModuleDescriptor {
312            label: Some(label),
313            source: wgpu::ShaderSource::Wgsl(wgsl_source.into()),
314        });
315
316        let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
317            label: Some(&format!("{}_pipeline", label)),
318            layout: None,
319            module: &shader_mod,
320            entry_point: Some("main"),
321            compilation_options: Default::default(),
322            cache: None,
323        });
324
325        // Bluestein algorithm pipelines for arbitrary size FFT (fully GPU-accelerated)
326        let compile_bluestein = |src: &str, label: &str| {
327            let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
328                label: Some(label),
329                source: wgpu::ShaderSource::Wgsl(src.into()),
330            });
331            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
332                label: Some(&format!("{label}_pipeline")),
333                layout: None,
334                module: &shader,
335                entry_point: Some("main"),
336                compilation_options: Default::default(),
337                cache: None,
338            })
339        };
340        let pipeline_bluestein_chirp =
341            compile_bluestein(shaders::BLUESTEIN_CHIRP_WGSL, "bluestein_chirp");
342        let pipeline_bluestein_inv_chirp =
343            compile_bluestein(shaders::BLUESTEIN_INV_CHIRP_WGSL, "bluestein_inv_chirp");
344        let pipeline_bluestein_zero_pad =
345            compile_bluestein(shaders::BLUESTEIN_ZERO_PAD_WGSL, "bluestein_zero_pad");
346
347        Ok(Self {
348            device,
349            queue,
350            pipeline,
351            pipeline_r2: None, // legacy single-pipeline mode
352            cache: RefCell::new(std::collections::HashMap::new()),
353            pipeline_bluestein_chirp,
354            pipeline_bluestein_inv_chirp,
355            pipeline_bluestein_zero_pad,
356            bluestein_cache: RefCell::new(std::collections::HashMap::new()),
357        })
358    }
359
360    /// Check if a GPU is available without creating an instance.
361    pub fn is_gpu_available() -> bool {
362        let instance = wgpu::Instance::default();
363        pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
364            power_preference: wgpu::PowerPreference::HighPerformance,
365            compatible_surface: None,
366            force_fallback_adapter: false,
367        }))
368        .is_ok()
369    }
370
371    /// Compute the forward FFT for a batch of input vectors.
372    ///
373    /// Processes multiple FFTs efficiently. For single vector processing,
374    /// pass a vector containing one input vector.
375    /// All input vectors must have the same length.
376    ///
377    /// For power-of-two sizes, uses the fast Stockham Radix-4/2 algorithm.
378    /// For arbitrary sizes, uses Bluestein's algorithm.
379    ///
380    /// # Arguments
381    ///
382    /// * `inputs` - A vector of input vectors, each containing complex samples.
383    ///
384    /// # Returns
385    ///
386    /// A vector of FFT results, one for each input vector.
387    ///
388    /// # Panics
389    ///
390    /// Panics if any input vector is empty or has a different length than others.
391    ///
392    /// # Errors
393    ///
394    /// Returns an error if a GPU operation fails (buffer mapping, device lost, etc.).
395    ///
396    /// # Examples
397    ///
398    /// ```no_run
399    /// use wgsl_fft::GpuFft;
400    /// use num_complex::Complex;
401    ///
402    /// let fft = GpuFft::new().expect("GPU or CPU fallback required");
403    ///
404    /// // Single FFT (pass vector with one element)
405    /// let single_input = vec![vec![Complex::new(1.0, 0.0); 1024]];
406    /// let single_spectrum = fft.fft(&single_input).expect("FFT failed");
407    ///
408    /// // Batch FFT
409    /// let batch_inputs = vec![
410    ///     vec![Complex::new(1.0, 0.0); 1024],
411    ///     vec![Complex::new(0.5, 0.0); 1024],
412    /// ];
413    /// let batch_spectra = fft.fft(&batch_inputs).expect("Batch FFT failed");
414    ///
415    /// // Arbitrary size FFT (not power of two)
416    /// let arbitrary_input = vec![vec![Complex::new(1.0, 0.0); 150]];
417    /// let arbitrary_spectrum = fft.fft(&arbitrary_input).expect("Arbitrary size FFT failed");
418    /// ```
419    pub fn fft(&self, inputs: &[Vec<Complex<f32>>]) -> Result<Vec<Vec<Complex<f32>>>> {
420        self.transform_batch_internal(inputs, false)
421    }
422
423    /// Compute the inverse FFT for a batch of input vectors.
424    ///
425    /// Processes multiple IFFTs efficiently. For single vector processing,
426    /// pass a vector containing one input vector.
427    /// All input vectors must have the same length.
428    /// The output is automatically scaled by `1/N` to maintain the unitary transform property.
429    ///
430    /// For power-of-two sizes, uses the fast Stockham Radix-4/2 algorithm.
431    /// For arbitrary sizes, uses Bluestein's algorithm.
432    ///
433    /// # Arguments
434    ///
435    /// * `inputs` - A vector of input vectors, each containing complex samples.
436    ///
437    /// # Returns
438    ///
439    /// A vector of IFFT results, one for each input vector.
440    ///
441    /// # Panics
442    ///
443    /// Panics if any input vector is empty or has a different length than others.
444    ///
445    /// # Errors
446    ///
447    /// Returns an error if a GPU operation fails (buffer mapping, device lost, etc.).
448    ///
449    /// # Examples
450    ///
451    /// ```no_run
452    /// use wgsl_fft::GpuFft;
453    /// use num_complex::Complex;
454    ///
455    /// let fft = GpuFft::new().expect("GPU or CPU fallback required");
456    ///
457    /// // Single IFFT (pass vector with one element)
458    /// let single_spectrum = vec![vec![Complex::new(1.0, 0.0); 1024]];
459    /// let single_reconstructed = fft.ifft(&single_spectrum).expect("IFFT failed");
460    ///
461    /// // Batch IFFT
462    /// let batch_spectra = vec![
463    ///     vec![Complex::new(1.0, 0.0); 1024],
464    ///     vec![Complex::new(0.5, 0.0); 1024],
465    /// ];
466    /// let batch_reconstructed = fft.ifft(&batch_spectra).expect("Batch IFFT failed");
467    ///
468    /// // Arbitrary size IFFT (not power of two)
469    /// let arbitrary_spectrum = vec![vec![Complex::new(1.0, 0.0); 150]];
470    /// let arbitrary_reconstructed = fft.ifft(&arbitrary_spectrum).expect("Arbitrary size IFFT failed");
471    /// ```
472    pub fn ifft(&self, inputs: &[Vec<Complex<f32>>]) -> Result<Vec<Vec<Complex<f32>>>> {
473        self.transform_batch_internal(inputs, true)
474    }
475
476    /// Validate that the input size is non-zero.
477    /// Arbitrary sizes are now supported via Bluestein's algorithm.
478    pub fn validate_input_size(&self, n: usize) -> Result<()> {
479        if n == 0 {
480            return Err(FftError::ValidationError(
481                "Transform length must be non-zero".to_string(),
482            ));
483        }
484        Ok(())
485    }
486
487    /// Check if a size is a power of two.
488    pub fn is_power_of_two(n: usize) -> bool {
489        n > 0 && (n & (n - 1)) == 0
490    }
491
492    /// Internal batch transform implementation that handles both FFT and IFFT for multiple inputs.
493    ///
494    /// When `inverse` is true, computes IFFT (with conjugation and 1/N scaling).
495    /// When `inverse` is false, computes standard FFT.
496    ///
497    /// For power-of-two sizes, uses the Stockham Radix-4/2 algorithm.
498    /// For arbitrary sizes, uses Bluestein's algorithm.
499    pub fn transform_batch_internal(
500        &self,
501        inputs: &[Vec<Complex<f32>>],
502        inverse: bool,
503    ) -> Result<Vec<Vec<Complex<f32>>>> {
504        if inputs.is_empty() {
505            return Ok(Vec::new());
506        }
507
508        self.validate_batch_inputs(inputs)?;
509
510        let n = inputs[0].len();
511        let batch_size = inputs.len() as u32;
512
513        if Self::is_power_of_two(n) {
514            return self.transform_power_of_two(inputs, inverse, n, batch_size);
515        }
516
517        self.transform_batch_bluestein(inputs, inverse)
518    }
519
520    /// Validate that all inputs in a batch have the same size.
521    fn validate_batch_inputs(&self, inputs: &[Vec<Complex<f32>>]) -> Result<()> {
522        let n = inputs[0].len();
523
524        for input in inputs {
525            if input.len() != n {
526                return Err(FftError::BatchError(
527                    "All input vectors in a batch must have the same length".to_string(),
528                ));
529            }
530            self.validate_input_size(input.len())?;
531        }
532
533        Ok(())
534    }
535
536    /// Transform batch for power-of-two sizes using Stockham Radix-4/2.
537    fn transform_power_of_two(
538        &self,
539        inputs: &[Vec<Complex<f32>>],
540        inverse: bool,
541        n: usize,
542        batch_size: u32,
543    ) -> Result<Vec<Vec<Complex<f32>>>> {
544        let log_n = n.trailing_zeros();
545        let sc = self.get_or_build_size_cache(n, log_n);
546
547        let all_raw_data = self.prepare_batch_input_data(inputs, inverse);
548
549        self.upload_batch_data(&sc, &all_raw_data);
550        self.execute_compute_pass(&sc, batch_size, n);
551
552        let mut output = self.readback_results(&sc, batch_size, n)?;
553
554        if inverse {
555            self.apply_inverse_postprocessing(&mut output, n);
556        }
557
558        Ok(self.split_results(output, n))
559    }
560
561    /// Prepare input data for all inputs in a batch.
562    fn prepare_batch_input_data(&self, inputs: &[Vec<Complex<f32>>], inverse: bool) -> Vec<f32> {
563        let batch_size = inputs.len();
564        let n = inputs[0].len();
565
566        let mut all_raw_data = Vec::with_capacity(n * COMPLEX_COMPONENT_COUNT * batch_size);
567
568        for input in inputs {
569            let raw = self.prepare_input_data(input, inverse);
570            all_raw_data.extend_from_slice(&raw);
571        }
572
573        all_raw_data
574    }
575
576    /// Upload batch data to GPU buffer.
577    fn upload_batch_data(&self, sc: &SizeCache, data: &[f32]) {
578        self.queue
579            .write_buffer(&sc.buf_a, 0, bytemuck::cast_slice(data));
580    }
581
582    /// Apply inverse transform postprocessing to all chunks.
583    fn apply_inverse_postprocessing(&self, output: &mut [Complex<f32>], n: usize) {
584        for chunk in output.chunks_mut(n) {
585            self.apply_inverse_transform_postprocessing(chunk, n);
586        }
587    }
588
589    /// Split output into individual results.
590    fn split_results(&self, output: Vec<Complex<f32>>, n: usize) -> Vec<Vec<Complex<f32>>> {
591        output.chunks(n).map(|chunk| chunk.to_vec()).collect()
592    }
593
594    /// Get the result buffer based on whether result is in buffer B.
595    fn get_result_buffer<'a>(&self, sc: &'a SizeCache) -> &'a wgpu::Buffer {
596        if sc.result_in_b {
597            return &sc.buf_b;
598        }
599        &sc.buf_a
600    }
601
602    /// Calculate number of R4 stages.
603    fn calculate_num_r4_stages(&self, is_r4_mode: bool, log_n: u32) -> usize {
604        if is_r4_mode {
605            return (log_n / 2) as usize;
606        }
607        0
608    }
609
610    /// Calculate total number of stages.
611    fn calculate_total_stages(
612        &self,
613        is_r4_mode: bool,
614        num_r4: usize,
615        has_r2: bool,
616        log_n: u32,
617    ) -> usize {
618        if is_r4_mode {
619            return num_r4 + has_r2 as usize;
620        }
621        log_n as usize
622    }
623
624    /// Calculate twiddle table count.
625    fn calculate_twiddle_count(&self, is_r4_mode: bool, n: usize) -> usize {
626        if is_r4_mode {
627            return n;
628        }
629        n / 2
630    }
631
632    /// Transform using Bluestein's algorithm for arbitrary FFT sizes.
633    ///
634    /// Bluestein's algorithm converts a non-power-of-two FFT of size N into a
635    /// convolution of size M >= 2N-1, where M is a power of two.
636    ///
637    /// Formula: X[k] = exp(-πi*k²/N) * Σ_m (x[m]*exp(-πi*m²/N)) * exp(πi*(k-m)²/N)
638    fn transform_batch_bluestein(
639        &self,
640        inputs: &[Vec<Complex<f32>>],
641        inverse: bool,
642    ) -> Result<Vec<Vec<Complex<f32>>>> {
643        if inputs.is_empty() {
644            return Ok(Vec::new());
645        }
646
647        let n = inputs[0].len();
648        let batch_size = inputs.len();
649        let m = self.next_power_of_two(2 * n - 1);
650
651        // a_angle = +/- pi * i^2 / n
652        // b_angle = -a_angle
653        // post_angle = a_angle
654        let a_angle_sign = if inverse { 1.0 } else { -1.0 };
655        let b_angle_sign = -a_angle_sign;
656
657        // 1. Get or compute B_fft = FFT(b_pad)
658        let b_fft = {
659            let mut cache = self.bluestein_cache.borrow_mut();
660            if let Some(cached) = cache.get(&(n, inverse)) {
661                cached.clone()
662            } else {
663                let mut b = vec![Complex::new(0.0, 0.0); m];
664                for i in 0..n {
665                    let angle =
666                        b_angle_sign * std::f64::consts::PI * (i as f64 * i as f64) / n as f64;
667                    let chirp = Complex::new(angle.cos() as f32, angle.sin() as f32);
668                    b[i] = chirp;
669                    if i > 0 {
670                        b[m - i] = chirp;
671                    }
672                }
673                let b_fft_res = self.transform_power_of_two(&[b], false, m, 1)?[0].clone();
674                cache.insert((n, inverse), b_fft_res.clone());
675                b_fft_res
676            }
677        };
678
679        // 2. Prepare batch of a_pad: a[i] = input[i] * exp(a_angle_sign * pi * i^2 / n)
680        let mut a_batch = Vec::with_capacity(batch_size);
681        for input in inputs {
682            let mut a = vec![Complex::new(0.0, 0.0); m];
683            for i in 0..n {
684                let angle = a_angle_sign * std::f64::consts::PI * (i as f64 * i as f64) / n as f64;
685                let chirp = Complex::new(angle.cos() as f32, angle.sin() as f32);
686                a[i] = input[i] * chirp;
687            }
688            a_batch.push(a);
689        }
690
691        // 3. A_fft = FFT(a_batch)
692        let a_fft_batch = self.transform_power_of_two(&a_batch, false, m, batch_size as u32)?;
693
694        // 4. Multiply by B_fft and prepare for IFFT
695        let mut c_fft_batch = Vec::with_capacity(batch_size);
696        for a_fft in a_fft_batch {
697            let mut c_fft = vec![Complex::new(0.0, 0.0); m];
698            for i in 0..m {
699                c_fft[i] = a_fft[i] * b_fft[i];
700            }
701            c_fft_batch.push(c_fft);
702        }
703
704        // 5. c = IFFT(c_fft_batch)
705        let c_batch = self.transform_power_of_two(&c_fft_batch, true, m, batch_size as u32)?;
706
707        // 6. Post-process: result[k] = c[k] * exp(a_angle_sign * pi * k^2 / n)
708        let mut results = Vec::with_capacity(batch_size);
709        let scale = if inverse { 1.0 / n as f32 } else { 1.0 };
710        for c in c_batch {
711            let mut result = vec![Complex::new(0.0, 0.0); n];
712            for i in 0..n {
713                let angle = a_angle_sign * std::f64::consts::PI * (i as f64 * i as f64) / n as f64;
714                let chirp = Complex::new(angle.cos() as f32, angle.sin() as f32);
715                result[i] = c[i] * chirp * scale;
716            }
717            results.push(result);
718        }
719
720        Ok(results)
721    }
722
723    /// Find the next power of two >= n.
724    fn next_power_of_two(&self, n: usize) -> usize {
725        if n <= 1 {
726            return 1;
727        }
728        let mut p = 1usize;
729        while p < n {
730            p *= 2;
731        }
732        p
733    }
734
735    /// Prepare input data for GPU processing, applying conjugation for IFFT if needed.
736    pub fn prepare_input_data(&self, input: &[Complex<f32>], inverse: bool) -> Vec<f32> {
737        if inverse {
738            return input.iter().flat_map(|c| [c.re, -c.im]).collect();
739        }
740        input.iter().flat_map(|c| [c.re, c.im]).collect()
741    }
742
743    /// Execute the compute shader pass.
744    pub fn execute_compute_pass(&self, sc: &SizeCache, batch_size: u32, n: usize) {
745        let mut enc = self
746            .device
747            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
748                label: Some("FFT Pass"),
749            });
750
751        self.run_compute_pass(&mut enc, sc, batch_size);
752
753        let result_buf = self.get_result_buffer(sc);
754        let single_fft_bytes = (n * COMPLEX_COMPONENT_COUNT * F32_BYTE_SIZE) as u64;
755
756        enc.copy_buffer_to_buffer(
757            result_buf,
758            0,
759            &sc.staging_buf,
760            0,
761            single_fft_bytes * batch_size as u64,
762        );
763
764        self.queue.submit(std::iter::once(enc.finish()));
765    }
766
767    /// Run compute pass on encoder.
768    fn run_compute_pass(&self, enc: &mut wgpu::CommandEncoder, sc: &SizeCache, batch_size: u32) {
769        let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
770            label: Some("FFT Compute"),
771            timestamp_writes: None,
772        });
773
774        if sc.wg_r4 > 0 {
775            self.dispatch_r4_mode_pass(&mut pass, sc, batch_size);
776            return;
777        }
778
779        self.dispatch_legacy_mode_pass(&mut pass, sc, batch_size);
780    }
781
782    /// Dispatch R4 mode compute pass.
783    fn dispatch_r4_mode_pass(&self, pass: &mut wgpu::ComputePass, sc: &SizeCache, batch_size: u32) {
784        pass.set_pipeline(&self.pipeline);
785
786        for bg in &sc.stage_bgs {
787            pass.set_bind_group(0, bg, &[]);
788            pass.dispatch_workgroups(sc.wg_r4, batch_size, 1);
789        }
790
791        if let Some(r2_bg) = &sc.stage_bg_r2 {
792            self.dispatch_r2_stage_pass(pass, r2_bg, sc, batch_size);
793        }
794    }
795
796    /// Dispatch R2 stage pass.
797    fn dispatch_r2_stage_pass(
798        &self,
799        pass: &mut wgpu::ComputePass,
800        r2_bg: &wgpu::BindGroup,
801        sc: &SizeCache,
802        batch_size: u32,
803    ) {
804        pass.set_pipeline(self.pipeline_r2.as_ref().unwrap());
805        pass.set_bind_group(0, r2_bg, &[]);
806        pass.dispatch_workgroups(sc.wg_n2, batch_size, 1);
807    }
808
809    /// Dispatch legacy mode compute pass.
810    fn dispatch_legacy_mode_pass(
811        &self,
812        pass: &mut wgpu::ComputePass,
813        sc: &SizeCache,
814        batch_size: u32,
815    ) {
816        pass.set_pipeline(&self.pipeline);
817
818        for bg in &sc.stage_bgs {
819            pass.set_bind_group(0, bg, &[]);
820            pass.dispatch_workgroups(sc.wg_n2, batch_size, 1);
821        }
822    }
823
824    /// Read back results from GPU and convert to complex numbers.
825    pub fn readback_results(
826        &self,
827        sc: &SizeCache,
828        batch_size: u32,
829        n: usize,
830    ) -> Result<Vec<Complex<f32>>> {
831        // Readback
832        let single_fft_bytes = (n * COMPLEX_COMPONENT_COUNT * F32_BYTE_SIZE) as u64;
833        let total_bytes = single_fft_bytes * batch_size as u64;
834        let slice = sc.staging_buf.slice(0..total_bytes);
835        slice.map_async(wgpu::MapMode::Read, |_| {});
836        self.device.poll(wgpu::PollType::Wait {
837            submission_index: None,
838            timeout: None,
839        })?;
840
841        let mapped = slice.get_mapped_range();
842        let floats: &[f32] = bytemuck::cast_slice(&mapped);
843        let output: Vec<Complex<f32>> = floats
844            .chunks_exact(2)
845            .map(|p| Complex { re: p[0], im: p[1] })
846            .collect();
847
848        drop(mapped);
849        sc.staging_buf.unmap();
850
851        Ok(output)
852    }
853
854    /// Apply postprocessing for inverse transform (conjugation and 1/N scaling).
855    pub fn apply_inverse_transform_postprocessing(&self, output: &mut [Complex<f32>], n: usize) {
856        let scale = 1.0 / n as f32;
857        for c in output {
858            *c = Complex {
859                re: c.re * scale,
860                im: -c.im * scale,
861            };
862        }
863    }
864
865    /// Get or build size-specific GPU resources.
866    pub fn get_or_build_size_cache(&self, n: usize, log_n: u32) -> SizeCache {
867        let mut cache = self.cache.borrow_mut();
868        if let Some(sc) = cache.get(&n) {
869            return sc.clone();
870        }
871
872        let sc = self.build_size_cache(n, log_n);
873        cache.insert(n, sc.clone());
874        sc
875    }
876
877    /// Build GPU buffers and bind groups for a specific FFT size.
878    pub fn build_size_cache(&self, n: usize, log_n: u32) -> SizeCache {
879        let is_r4_mode = self.pipeline_r2.is_some();
880
881        let num_r4 = self.calculate_num_r4_stages(is_r4_mode, log_n);
882        let has_r2 = is_r4_mode && log_n % 2 == 1;
883        let total_stages = self.calculate_total_stages(is_r4_mode, num_r4, has_r2, log_n);
884
885        let single_fft_bytes = n as u64 * 2 * std::mem::size_of::<f32>() as u64;
886        // Cap at 1024 to avoid excessive pre-allocation; hardware limits are often much larger.
887        let max_batch_size = (self.device.limits().max_storage_buffer_binding_size
888            / single_fft_bytes)
889            .min(1024) as u32;
890        let data_bytes = single_fft_bytes * max_batch_size as u64;
891
892        let make_buf = |label| {
893            self.device.create_buffer(&wgpu::BufferDescriptor {
894                label: Some(label),
895                size: data_bytes,
896                usage: wgpu::BufferUsages::STORAGE
897                    | wgpu::BufferUsages::COPY_SRC
898                    | wgpu::BufferUsages::COPY_DST,
899                mapped_at_creation: false,
900            })
901        };
902
903        let buf_a = make_buf("fft_buf_a");
904        let buf_b = make_buf("fft_buf_b");
905        let staging_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
906            label: Some("fft_staging"),
907            size: data_bytes,
908            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
909            mapped_at_creation: false,
910        });
911
912        // Twiddle table: N entries for R4 mode (max accessed index = 3N/2−5 < 2N),
913        // N/2 entries for legacy R2 mode (max accessed index = N−2 < N).
914        let twiddle_count = self.calculate_twiddle_count(is_r4_mode, n);
915        let twiddles: Vec<f32> = (0..twiddle_count)
916            .flat_map(|j| {
917                let angle = -std::f64::consts::TAU * (j as f64) / (n as f64);
918                [angle.cos() as f32, angle.sin() as f32]
919            })
920            .collect();
921        let twiddle_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
922            label: Some("fft_twiddles"),
923            size: (twiddles.len() * std::mem::size_of::<f32>()) as u64,
924            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
925            mapped_at_creation: false,
926        });
927        self.queue
928            .write_buffer(&twiddle_buf, 0, bytemuck::cast_slice(&twiddles));
929
930        let alignment = self.device.limits().min_uniform_buffer_offset_alignment as u64;
931        let entry_bytes = std::mem::size_of::<FftUniforms>() as u64;
932        let stride = entry_bytes.div_ceil(alignment) * alignment;
933
934        let uniform_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
935            label: Some("fft_uniforms"),
936            size: stride * total_stages.max(1) as u64,
937            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
938            mapped_at_creation: false,
939        });
940
941        let uniform_size = NonZeroU64::new(entry_bytes);
942        let layout_r4 = self.pipeline.get_bind_group_layout(0);
943        let layout_r2_opt = self
944            .pipeline_r2
945            .as_ref()
946            .map(|p| p.get_bind_group_layout(0));
947
948        let make_bg_with_layout = |layout: &wgpu::BindGroupLayout,
949                                   src: &wgpu::Buffer,
950                                   dst: &wgpu::Buffer,
951                                   uniform_offset: u64| {
952            self.device.create_bind_group(&wgpu::BindGroupDescriptor {
953                label: None,
954                layout,
955                entries: &[
956                    wgpu::BindGroupEntry {
957                        binding: 0,
958                        resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
959                            buffer: &uniform_buf,
960                            offset: uniform_offset,
961                            size: uniform_size,
962                        }),
963                    },
964                    wgpu::BindGroupEntry {
965                        binding: 1,
966                        resource: src.as_entire_binding(),
967                    },
968                    wgpu::BindGroupEntry {
969                        binding: 2,
970                        resource: dst.as_entire_binding(),
971                    },
972                    wgpu::BindGroupEntry {
973                        binding: 3,
974                        resource: twiddle_buf.as_entire_binding(),
975                    },
976                ],
977            })
978        };
979
980        let make_bg = |src: &wgpu::Buffer, dst: &wgpu::Buffer, uniform_offset: u64| {
981            make_bg_with_layout(&layout_r4, src, dst, uniform_offset)
982        };
983
984        if is_r4_mode {
985            // R4 mode: ⌊log₄N⌋ Radix-4 stages + optional Radix-2
986            for s in 0..num_r4 {
987                let p = 1u32 << (s as u32 * 2);
988                self.queue.write_buffer(
989                    &uniform_buf,
990                    stride * s as u64,
991                    bytemuck::bytes_of(&FftUniforms {
992                        n: n as u32,
993                        stage: p,
994                        log_n,
995                        _pad: 0,
996                    }),
997                );
998            }
999            if has_r2 {
1000                let p = 1u32 << (num_r4 as u32 * 2);
1001                self.queue.write_buffer(
1002                    &uniform_buf,
1003                    stride * num_r4 as u64,
1004                    bytemuck::bytes_of(&FftUniforms {
1005                        n: n as u32,
1006                        stage: p,
1007                        log_n,
1008                        _pad: 0,
1009                    }),
1010                );
1011            }
1012
1013            let stage_bgs: Vec<wgpu::BindGroup> = (0..num_r4)
1014                .map(|s| {
1015                    let (src, dst) = if s % 2 == 0 {
1016                        (&buf_a, &buf_b)
1017                    } else {
1018                        (&buf_b, &buf_a)
1019                    };
1020                    make_bg(src, dst, stride * s as u64)
1021                })
1022                .collect();
1023
1024            let stage_bg_r2 = if has_r2 {
1025                let (src, dst) = if num_r4 % 2 == 0 {
1026                    (&buf_a, &buf_b)
1027                } else {
1028                    (&buf_b, &buf_a)
1029                };
1030                let layout_r2 = layout_r2_opt.as_ref().unwrap();
1031                Some(make_bg_with_layout(
1032                    layout_r2,
1033                    src,
1034                    dst,
1035                    stride * num_r4 as u64,
1036                ))
1037            } else {
1038                None
1039            };
1040
1041            SizeCache {
1042                buf_a,
1043                buf_b,
1044                staging_buf,
1045                twiddle_buf,
1046                data_bytes,
1047                stage_bgs,
1048                stage_bg_r2,
1049                result_in_b: total_stages % 2 == 1,
1050                wg_n2: (n as u32 / 2).div_ceil(256),
1051                wg_r4: (n as u32 / 4).div_ceil(256),
1052            }
1053        } else {
1054            // Legacy mode (with_shader): log₂N Radix-2 stages, stage-index uniforms
1055            for stage in 0..log_n {
1056                self.queue.write_buffer(
1057                    &uniform_buf,
1058                    stride * stage as u64,
1059                    bytemuck::bytes_of(&FftUniforms {
1060                        n: n as u32,
1061                        stage,
1062                        log_n,
1063                        _pad: 0,
1064                    }),
1065                );
1066            }
1067
1068            let stage_bgs = (0..log_n as usize)
1069                .map(|s| {
1070                    let (src, dst) = if s % 2 == 0 {
1071                        (&buf_a, &buf_b)
1072                    } else {
1073                        (&buf_b, &buf_a)
1074                    };
1075                    make_bg(src, dst, stride * s as u64)
1076                })
1077                .collect();
1078
1079            SizeCache {
1080                buf_a,
1081                buf_b,
1082                staging_buf,
1083                twiddle_buf,
1084                data_bytes,
1085                stage_bgs,
1086                stage_bg_r2: None,
1087                result_in_b: log_n % 2 == 1,
1088                wg_n2: (n as u32 / 2).div_ceil(256),
1089                wg_r4: 0,
1090            }
1091        }
1092    }
1093}
1094
1095impl Default for GpuFft {
1096    fn default() -> Self {
1097        Self::new().expect("No GPU available for default GpuFft instance")
1098    }
1099}
1100
#[cfg(test)]
mod tests {
    use super::*;
    use num_complex::Complex;

    /// Build a `GpuFft` for tests of its CPU-side helpers.
    ///
    /// Returns `None` when `GpuFft::new()` fails, e.g. on a machine without a
    /// usable GPU. The helpers exercised here appear to be plain CPU math, but
    /// they are methods on `GpuFft`, so an instance is still required. Skipping
    /// instead of panicking keeps these tests from failing on GPU-less CI runners.
    /// The reason is printed so a skip is visible with `--nocapture`.
    fn fft_or_skip() -> Option<GpuFft> {
        match GpuFft::new() {
            Ok(fft) => Some(fft),
            Err(err) => {
                eprintln!("skipping: GpuFft::new() failed: {err:?}");
                None
            }
        }
    }

    #[test]
    fn test_prepare_input_data_fft() {
        let Some(fft) = fft_or_skip() else { return };
        let input = vec![Complex::new(1.0, 2.0), Complex::new(3.0, 4.0)];
        let result = fft.prepare_input_data(&input, false);
        assert_eq!(result, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_prepare_input_data_ifft() {
        let Some(fft) = fft_or_skip() else { return };
        let input = vec![Complex::new(1.0, 2.0), Complex::new(3.0, 4.0)];
        let result = fft.prepare_input_data(&input, true);
        assert_eq!(result, vec![1.0, -2.0, 3.0, -4.0]);
    }

    #[test]
    fn test_apply_inverse_transform_postprocessing() {
        let Some(fft) = fft_or_skip() else { return };
        let mut output = vec![Complex::new(2.0, 4.0), Complex::new(6.0, 8.0)];
        fft.apply_inverse_transform_postprocessing(&mut output, 2);
        assert_eq!(output[0].re, 1.0);
        assert_eq!(output[0].im, -2.0);
        assert_eq!(output[1].re, 3.0);
        assert_eq!(output[1].im, -4.0);
    }

    #[test]
    fn test_apply_inverse_transform_postprocessing_empty() {
        let Some(fft) = fft_or_skip() else { return };
        let mut output: Vec<Complex<f32>> = Vec::new();
        fft.apply_inverse_transform_postprocessing(&mut output, 0);
        assert!(output.is_empty());
    }
}