1use std::any::Any;
10use std::cell::RefCell;
11use std::num::NonZeroU64;
12
13use num_complex::Complex;
14
15use crate::error::{FftError, Result};
16use crate::shaders;
17
/// Number of `f32` scalars used to store one complex sample (real part, imaginary part).
const COMPLEX_COMPONENT_COUNT: usize = 2;

/// Size in bytes of one `f32` scalar as laid out in the GPU buffers.
const F32_BYTE_SIZE: usize = core::mem::size_of::<f32>();
23
/// Common interface for FFT backends that transform a batch of complex `f32` signals.
///
/// `as_any` lets callers holding a `&dyn FftExecutor` downcast to the concrete backend.
pub trait FftExecutor {
    /// Human-readable backend name.
    fn name(&self) -> &str;
    /// Forward transform of every signal in `inputs`.
    fn fft(&self, inputs: &[Vec<Complex<f32>>]) -> Result<Vec<Vec<Complex<f32>>>>;
    /// Inverse transform of every signal in `inputs`.
    fn ifft(&self, inputs: &[Vec<Complex<f32>>]) -> Result<Vec<Vec<Complex<f32>>>>;

    /// Returns `self` as `&dyn Any` for downcasting.
    fn as_any(&self) -> &dyn Any;
}
33
/// Lower-level GPU hooks used for benchmarking and for driving the FFT pipeline by hand.
pub trait GpuFftTrait {
    /// Times the GPU compute pass for `batch_size` transforms of length `n` using the
    /// pre-built resources in `sc`, after `warmup_iters` warm-up runs.
    ///
    /// In the `GpuFft` implementation the returned value is the mean wall-clock seconds
    /// per benchmarked iteration.
    fn benchmark_gpu_only(
        &self,
        sc: &SizeCache,
        batch_size: u32,
        n: usize,
        warmup_iters: usize,
        bench_iters: usize,
    ) -> Result<f64>;

    /// Returns the GPU resources for transforms of length `n` (`log_n` is `log2(n)`),
    /// building them on first use.
    fn get_or_build_size_cache(&self, n: usize, log_n: u32) -> SizeCache;

    /// Packs complex samples into interleaved `[re, im]` `f32` pairs for upload;
    /// in the `GpuFft` implementation the imaginary part is negated when `inverse` is set.
    fn prepare_input_data(&self, input: &[Complex<f32>], inverse: bool) -> Vec<f32>;

    /// The queue used to upload data and submit work.
    fn queue(&self) -> &wgpu::Queue;
}
56
/// GPU resources prepared for one transform length `n` (built by `GpuFft::build_size_cache`).
///
/// `GpuFft::get_or_build_size_cache` stores one of these per `n` and hands out clones.
#[derive(Clone, Debug)]
pub struct SizeCache {
    /// Ping-pong storage buffer; input is uploaded here and even stages read from it.
    pub buf_a: wgpu::Buffer,
    /// Ping-pong storage buffer; even stages write to it, odd stages read from it.
    pub buf_b: wgpu::Buffer,
    /// `MAP_READ | COPY_DST` buffer the result is copied into for CPU readback.
    pub staging_buf: wgpu::Buffer,
    /// Precomputed twiddle factors as interleaved `f32` pairs.
    pub twiddle_buf: wgpu::Buffer,
    /// Size in bytes of each of `buf_a`, `buf_b` and `staging_buf`.
    pub data_bytes: u64,
    /// One bind group per (radix-4 or legacy) stage, in execution order.
    pub stage_bgs: Vec<wgpu::BindGroup>,
    /// Bind group for the trailing radix-2 stage (radix-4 mode with odd `log2(n)` only).
    pub stage_bg_r2: Option<wgpu::BindGroup>,
    /// `true` when the final stage writes into `buf_b` rather than `buf_a`.
    pub result_in_b: bool,
    /// Workgroups along X for radix-2 / legacy dispatches: `(n / 2).div_ceil(256)`.
    pub wg_n2: u32,
    /// Workgroups along X for radix-4 dispatches: `(n / 4).div_ceil(256)`;
    /// 0 in legacy mode (and also 0 when `n < 4`).
    pub wg_r4: u32,
}
75
/// Per-stage uniform block written with `bytemuck::bytes_of` into the uniform buffer.
///
/// `#[repr(C)]` + `Pod` keep the layout fixed; presumably this mirrors a uniform struct
/// declared in the WGSL shaders (TODO confirm against `shaders`), so fields must not be
/// reordered.
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
pub struct FftUniforms {
    /// Transform length.
    pub n: u32,
    /// Radix-4 mode: the stage stride `p = 4^s`; legacy mode: the stage index.
    pub stage: u32,
    /// `log2(n)`.
    pub log_n: u32,
    /// Padding that brings the struct to 16 bytes; always written as 0.
    pub _pad: u32,
}
85
/// wgpu-backed batched FFT.
///
/// Power-of-two lengths run as Stockham stages on the GPU; other lengths use Bluestein's
/// algorithm on top of power-of-two GPU transforms. Caches live in `RefCell`s, so this type
/// is not `Sync`.
#[derive(Debug)]
pub struct GpuFft {
    pub device: wgpu::Device,
    pub queue: wgpu::Queue,
    /// Main stage pipeline: the built-in radix-4 shader, or a custom shader
    /// (`with_shader*` constructors).
    pub pipeline: wgpu::ComputePipeline,
    /// Radix-2 pipeline for the odd trailing stage. `None` for custom-shader instances,
    /// which makes `build_size_cache` use legacy one-dispatch-per-stage mode.
    pub pipeline_r2: Option<wgpu::ComputePipeline>,
    /// Per-length GPU resources, keyed by transform length `n`.
    pub cache: RefCell<std::collections::HashMap<usize, SizeCache>>,
    /// Bluestein helper pipelines, compiled by every constructor.
    /// NOTE(review): not dispatched by any code visible in this file.
    pub pipeline_bluestein_chirp: wgpu::ComputePipeline,
    pub pipeline_bluestein_inv_chirp: wgpu::ComputePipeline,
    pub pipeline_bluestein_zero_pad: wgpu::ComputePipeline,
    /// Forward FFT of the Bluestein convolution kernel, keyed by `(n, inverse)`.
    pub bluestein_cache: RefCell<std::collections::HashMap<(usize, bool), Vec<Complex<f32>>>>,
}
108
109impl FftExecutor for GpuFft {
110 fn name(&self) -> &str {
111 "Baseline (Stockham Radix-4/2)"
112 }
113
114 fn fft(&self, inputs: &[Vec<Complex<f32>>]) -> Result<Vec<Vec<Complex<f32>>>> {
115 self.transform_batch_internal(inputs, false)
116 }
117
118 fn ifft(&self, inputs: &[Vec<Complex<f32>>]) -> Result<Vec<Vec<Complex<f32>>>> {
119 self.transform_batch_internal(inputs, true)
120 }
121
122 fn as_any(&self) -> &dyn Any {
123 self
124 }
125}
126
127impl GpuFftTrait for GpuFft {
128 fn benchmark_gpu_only(
129 &self,
130 sc: &SizeCache,
131 batch_size: u32,
132 n: usize,
133 warmup_iters: usize,
134 bench_iters: usize,
135 ) -> Result<f64> {
136 use std::time::Instant;
137
138 for _ in 0..warmup_iters {
140 self.execute_compute_pass(sc, batch_size, n);
141 self.device.poll(wgpu::PollType::Wait {
142 submission_index: None,
143 timeout: None,
144 })?;
145 }
146
147 let start = Instant::now();
149 for _ in 0..bench_iters {
150 self.execute_compute_pass(sc, batch_size, n);
151 }
152
153 self.device.poll(wgpu::PollType::Wait {
154 submission_index: None,
155 timeout: None,
156 })?;
157
158 let duration = start.elapsed();
159 Ok(duration.as_secs_f64() / bench_iters as f64)
160 }
161
162 fn get_or_build_size_cache(&self, n: usize, log_n: u32) -> SizeCache {
163 self.get_or_build_size_cache(n, log_n)
164 }
165
166 fn prepare_input_data(&self, input: &[Complex<f32>], inverse: bool) -> Vec<f32> {
167 self.prepare_input_data(input, inverse)
168 }
169
170 fn queue(&self) -> &wgpu::Queue {
171 &self.queue
172 }
173}
174
175impl GpuFft {
176 pub fn device(&self) -> &wgpu::Device {
178 &self.device
179 }
180
181 pub fn compute_pipeline(&self) -> &wgpu::ComputePipeline {
183 &self.pipeline
184 }
185
186 pub fn new() -> Result<Self> {
202 let instance = wgpu::Instance::default();
203 let adapter = pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
204 power_preference: wgpu::PowerPreference::HighPerformance,
205 compatible_surface: None,
206 force_fallback_adapter: false,
207 }))
208 .or_else(|_| {
209 pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
210 power_preference: wgpu::PowerPreference::HighPerformance,
211 compatible_surface: None,
212 force_fallback_adapter: true,
213 }))
214 })?;
215
216 let (device, queue) =
217 pollster::block_on(adapter.request_device(&wgpu::DeviceDescriptor {
218 ..Default::default()
219 }))?;
220 Self::from_device_queue(device, queue)
221 }
222
223 pub fn from_device_queue(device: wgpu::Device, queue: wgpu::Queue) -> Result<Self> {
233 let compile = |src: &str, label: &str| {
234 let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
235 label: Some(label),
236 source: wgpu::ShaderSource::Wgsl(src.into()),
237 });
238 device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
239 label: Some(&format!("{label}_pipeline")),
240 layout: None,
241 module: &shader,
242 entry_point: Some("main"),
243 compilation_options: Default::default(),
244 cache: None,
245 })
246 };
247
248 let pipeline = compile(shaders::R4_WGSL, "stockham_r4");
249 let pipeline_r2 = Some(compile(shaders::R2_WGSL, "stockham_r2"));
250
251 let pipeline_bluestein_chirp = compile(shaders::BLUESTEIN_CHIRP_WGSL, "bluestein_chirp");
253 let pipeline_bluestein_inv_chirp =
254 compile(shaders::BLUESTEIN_INV_CHIRP_WGSL, "bluestein_inv_chirp");
255 let pipeline_bluestein_zero_pad =
256 compile(shaders::BLUESTEIN_ZERO_PAD_WGSL, "bluestein_zero_pad");
257
258 Ok(Self {
259 device,
260 queue,
261 pipeline,
262 pipeline_r2,
263 cache: RefCell::new(std::collections::HashMap::new()),
264 pipeline_bluestein_chirp,
265 pipeline_bluestein_inv_chirp,
266 pipeline_bluestein_zero_pad,
267 bluestein_cache: RefCell::new(std::collections::HashMap::new()),
268 })
269 }
270
271 pub fn with_shader(wgsl_source: String, label: &str) -> Result<Self> {
274 let instance = wgpu::Instance::default();
275 let adapter = pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
276 power_preference: wgpu::PowerPreference::HighPerformance,
277 compatible_surface: None,
278 force_fallback_adapter: false,
279 }))
280 .or_else(|_| {
281 pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
282 power_preference: wgpu::PowerPreference::HighPerformance,
283 compatible_surface: None,
284 force_fallback_adapter: true,
285 }))
286 })?;
287
288 let (device, queue) =
289 pollster::block_on(adapter.request_device(&wgpu::DeviceDescriptor {
290 ..Default::default()
291 }))?;
292 Self::with_shader_and_device(device, queue, wgsl_source, label)
293 }
294
295 pub fn with_shader_and_device(
306 device: wgpu::Device,
307 queue: wgpu::Queue,
308 wgsl_source: String,
309 label: &str,
310 ) -> Result<Self> {
311 let shader_mod = device.create_shader_module(wgpu::ShaderModuleDescriptor {
312 label: Some(label),
313 source: wgpu::ShaderSource::Wgsl(wgsl_source.into()),
314 });
315
316 let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
317 label: Some(&format!("{}_pipeline", label)),
318 layout: None,
319 module: &shader_mod,
320 entry_point: Some("main"),
321 compilation_options: Default::default(),
322 cache: None,
323 });
324
325 let compile_bluestein = |src: &str, label: &str| {
327 let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
328 label: Some(label),
329 source: wgpu::ShaderSource::Wgsl(src.into()),
330 });
331 device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
332 label: Some(&format!("{label}_pipeline")),
333 layout: None,
334 module: &shader,
335 entry_point: Some("main"),
336 compilation_options: Default::default(),
337 cache: None,
338 })
339 };
340 let pipeline_bluestein_chirp =
341 compile_bluestein(shaders::BLUESTEIN_CHIRP_WGSL, "bluestein_chirp");
342 let pipeline_bluestein_inv_chirp =
343 compile_bluestein(shaders::BLUESTEIN_INV_CHIRP_WGSL, "bluestein_inv_chirp");
344 let pipeline_bluestein_zero_pad =
345 compile_bluestein(shaders::BLUESTEIN_ZERO_PAD_WGSL, "bluestein_zero_pad");
346
347 Ok(Self {
348 device,
349 queue,
350 pipeline,
351 pipeline_r2: None, cache: RefCell::new(std::collections::HashMap::new()),
353 pipeline_bluestein_chirp,
354 pipeline_bluestein_inv_chirp,
355 pipeline_bluestein_zero_pad,
356 bluestein_cache: RefCell::new(std::collections::HashMap::new()),
357 })
358 }
359
360 pub fn is_gpu_available() -> bool {
362 let instance = wgpu::Instance::default();
363 pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
364 power_preference: wgpu::PowerPreference::HighPerformance,
365 compatible_surface: None,
366 force_fallback_adapter: false,
367 }))
368 .is_ok()
369 }
370
371 pub fn fft(&self, inputs: &[Vec<Complex<f32>>]) -> Result<Vec<Vec<Complex<f32>>>> {
420 self.transform_batch_internal(inputs, false)
421 }
422
423 pub fn ifft(&self, inputs: &[Vec<Complex<f32>>]) -> Result<Vec<Vec<Complex<f32>>>> {
473 self.transform_batch_internal(inputs, true)
474 }
475
476 pub fn validate_input_size(&self, n: usize) -> Result<()> {
479 if n == 0 {
480 return Err(FftError::ValidationError(
481 "Transform length must be non-zero".to_string(),
482 ));
483 }
484 Ok(())
485 }
486
487 pub fn is_power_of_two(n: usize) -> bool {
489 n > 0 && (n & (n - 1)) == 0
490 }
491
492 pub fn transform_batch_internal(
500 &self,
501 inputs: &[Vec<Complex<f32>>],
502 inverse: bool,
503 ) -> Result<Vec<Vec<Complex<f32>>>> {
504 if inputs.is_empty() {
505 return Ok(Vec::new());
506 }
507
508 self.validate_batch_inputs(inputs)?;
509
510 let n = inputs[0].len();
511 let batch_size = inputs.len() as u32;
512
513 if Self::is_power_of_two(n) {
514 return self.transform_power_of_two(inputs, inverse, n, batch_size);
515 }
516
517 self.transform_batch_bluestein(inputs, inverse)
518 }
519
520 fn validate_batch_inputs(&self, inputs: &[Vec<Complex<f32>>]) -> Result<()> {
522 let n = inputs[0].len();
523
524 for input in inputs {
525 if input.len() != n {
526 return Err(FftError::BatchError(
527 "All input vectors in a batch must have the same length".to_string(),
528 ));
529 }
530 self.validate_input_size(input.len())?;
531 }
532
533 Ok(())
534 }
535
536 fn transform_power_of_two(
538 &self,
539 inputs: &[Vec<Complex<f32>>],
540 inverse: bool,
541 n: usize,
542 batch_size: u32,
543 ) -> Result<Vec<Vec<Complex<f32>>>> {
544 let log_n = n.trailing_zeros();
545 let sc = self.get_or_build_size_cache(n, log_n);
546
547 let all_raw_data = self.prepare_batch_input_data(inputs, inverse);
548
549 self.upload_batch_data(&sc, &all_raw_data);
550 self.execute_compute_pass(&sc, batch_size, n);
551
552 let mut output = self.readback_results(&sc, batch_size, n)?;
553
554 if inverse {
555 self.apply_inverse_postprocessing(&mut output, n);
556 }
557
558 Ok(self.split_results(output, n))
559 }
560
561 fn prepare_batch_input_data(&self, inputs: &[Vec<Complex<f32>>], inverse: bool) -> Vec<f32> {
563 let batch_size = inputs.len();
564 let n = inputs[0].len();
565
566 let mut all_raw_data = Vec::with_capacity(n * COMPLEX_COMPONENT_COUNT * batch_size);
567
568 for input in inputs {
569 let raw = self.prepare_input_data(input, inverse);
570 all_raw_data.extend_from_slice(&raw);
571 }
572
573 all_raw_data
574 }
575
576 fn upload_batch_data(&self, sc: &SizeCache, data: &[f32]) {
578 self.queue
579 .write_buffer(&sc.buf_a, 0, bytemuck::cast_slice(data));
580 }
581
582 fn apply_inverse_postprocessing(&self, output: &mut [Complex<f32>], n: usize) {
584 for chunk in output.chunks_mut(n) {
585 self.apply_inverse_transform_postprocessing(chunk, n);
586 }
587 }
588
589 fn split_results(&self, output: Vec<Complex<f32>>, n: usize) -> Vec<Vec<Complex<f32>>> {
591 output.chunks(n).map(|chunk| chunk.to_vec()).collect()
592 }
593
594 fn get_result_buffer<'a>(&self, sc: &'a SizeCache) -> &'a wgpu::Buffer {
596 if sc.result_in_b {
597 return &sc.buf_b;
598 }
599 &sc.buf_a
600 }
601
602 fn calculate_num_r4_stages(&self, is_r4_mode: bool, log_n: u32) -> usize {
604 if is_r4_mode {
605 return (log_n / 2) as usize;
606 }
607 0
608 }
609
610 fn calculate_total_stages(
612 &self,
613 is_r4_mode: bool,
614 num_r4: usize,
615 has_r2: bool,
616 log_n: u32,
617 ) -> usize {
618 if is_r4_mode {
619 return num_r4 + has_r2 as usize;
620 }
621 log_n as usize
622 }
623
624 fn calculate_twiddle_count(&self, is_r4_mode: bool, n: usize) -> usize {
626 if is_r4_mode {
627 return n;
628 }
629 n / 2
630 }
631
632 fn transform_batch_bluestein(
639 &self,
640 inputs: &[Vec<Complex<f32>>],
641 inverse: bool,
642 ) -> Result<Vec<Vec<Complex<f32>>>> {
643 if inputs.is_empty() {
644 return Ok(Vec::new());
645 }
646
647 let n = inputs[0].len();
648 let batch_size = inputs.len();
649 let m = self.next_power_of_two(2 * n - 1);
650
651 let a_angle_sign = if inverse { 1.0 } else { -1.0 };
655 let b_angle_sign = -a_angle_sign;
656
657 let b_fft = {
659 let mut cache = self.bluestein_cache.borrow_mut();
660 if let Some(cached) = cache.get(&(n, inverse)) {
661 cached.clone()
662 } else {
663 let mut b = vec![Complex::new(0.0, 0.0); m];
664 for i in 0..n {
665 let angle =
666 b_angle_sign * std::f64::consts::PI * (i as f64 * i as f64) / n as f64;
667 let chirp = Complex::new(angle.cos() as f32, angle.sin() as f32);
668 b[i] = chirp;
669 if i > 0 {
670 b[m - i] = chirp;
671 }
672 }
673 let b_fft_res = self.transform_power_of_two(&[b], false, m, 1)?[0].clone();
674 cache.insert((n, inverse), b_fft_res.clone());
675 b_fft_res
676 }
677 };
678
679 let mut a_batch = Vec::with_capacity(batch_size);
681 for input in inputs {
682 let mut a = vec![Complex::new(0.0, 0.0); m];
683 for i in 0..n {
684 let angle = a_angle_sign * std::f64::consts::PI * (i as f64 * i as f64) / n as f64;
685 let chirp = Complex::new(angle.cos() as f32, angle.sin() as f32);
686 a[i] = input[i] * chirp;
687 }
688 a_batch.push(a);
689 }
690
691 let a_fft_batch = self.transform_power_of_two(&a_batch, false, m, batch_size as u32)?;
693
694 let mut c_fft_batch = Vec::with_capacity(batch_size);
696 for a_fft in a_fft_batch {
697 let mut c_fft = vec![Complex::new(0.0, 0.0); m];
698 for i in 0..m {
699 c_fft[i] = a_fft[i] * b_fft[i];
700 }
701 c_fft_batch.push(c_fft);
702 }
703
704 let c_batch = self.transform_power_of_two(&c_fft_batch, true, m, batch_size as u32)?;
706
707 let mut results = Vec::with_capacity(batch_size);
709 let scale = if inverse { 1.0 / n as f32 } else { 1.0 };
710 for c in c_batch {
711 let mut result = vec![Complex::new(0.0, 0.0); n];
712 for i in 0..n {
713 let angle = a_angle_sign * std::f64::consts::PI * (i as f64 * i as f64) / n as f64;
714 let chirp = Complex::new(angle.cos() as f32, angle.sin() as f32);
715 result[i] = c[i] * chirp * scale;
716 }
717 results.push(result);
718 }
719
720 Ok(results)
721 }
722
723 fn next_power_of_two(&self, n: usize) -> usize {
725 if n <= 1 {
726 return 1;
727 }
728 let mut p = 1usize;
729 while p < n {
730 p *= 2;
731 }
732 p
733 }
734
735 pub fn prepare_input_data(&self, input: &[Complex<f32>], inverse: bool) -> Vec<f32> {
737 if inverse {
738 return input.iter().flat_map(|c| [c.re, -c.im]).collect();
739 }
740 input.iter().flat_map(|c| [c.re, c.im]).collect()
741 }
742
743 pub fn execute_compute_pass(&self, sc: &SizeCache, batch_size: u32, n: usize) {
745 let mut enc = self
746 .device
747 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
748 label: Some("FFT Pass"),
749 });
750
751 self.run_compute_pass(&mut enc, sc, batch_size);
752
753 let result_buf = self.get_result_buffer(sc);
754 let single_fft_bytes = (n * COMPLEX_COMPONENT_COUNT * F32_BYTE_SIZE) as u64;
755
756 enc.copy_buffer_to_buffer(
757 result_buf,
758 0,
759 &sc.staging_buf,
760 0,
761 single_fft_bytes * batch_size as u64,
762 );
763
764 self.queue.submit(std::iter::once(enc.finish()));
765 }
766
767 fn run_compute_pass(&self, enc: &mut wgpu::CommandEncoder, sc: &SizeCache, batch_size: u32) {
769 let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
770 label: Some("FFT Compute"),
771 timestamp_writes: None,
772 });
773
774 if sc.wg_r4 > 0 {
775 self.dispatch_r4_mode_pass(&mut pass, sc, batch_size);
776 return;
777 }
778
779 self.dispatch_legacy_mode_pass(&mut pass, sc, batch_size);
780 }
781
782 fn dispatch_r4_mode_pass(&self, pass: &mut wgpu::ComputePass, sc: &SizeCache, batch_size: u32) {
784 pass.set_pipeline(&self.pipeline);
785
786 for bg in &sc.stage_bgs {
787 pass.set_bind_group(0, bg, &[]);
788 pass.dispatch_workgroups(sc.wg_r4, batch_size, 1);
789 }
790
791 if let Some(r2_bg) = &sc.stage_bg_r2 {
792 self.dispatch_r2_stage_pass(pass, r2_bg, sc, batch_size);
793 }
794 }
795
796 fn dispatch_r2_stage_pass(
798 &self,
799 pass: &mut wgpu::ComputePass,
800 r2_bg: &wgpu::BindGroup,
801 sc: &SizeCache,
802 batch_size: u32,
803 ) {
804 pass.set_pipeline(self.pipeline_r2.as_ref().unwrap());
805 pass.set_bind_group(0, r2_bg, &[]);
806 pass.dispatch_workgroups(sc.wg_n2, batch_size, 1);
807 }
808
809 fn dispatch_legacy_mode_pass(
811 &self,
812 pass: &mut wgpu::ComputePass,
813 sc: &SizeCache,
814 batch_size: u32,
815 ) {
816 pass.set_pipeline(&self.pipeline);
817
818 for bg in &sc.stage_bgs {
819 pass.set_bind_group(0, bg, &[]);
820 pass.dispatch_workgroups(sc.wg_n2, batch_size, 1);
821 }
822 }
823
824 pub fn readback_results(
826 &self,
827 sc: &SizeCache,
828 batch_size: u32,
829 n: usize,
830 ) -> Result<Vec<Complex<f32>>> {
831 let single_fft_bytes = (n * COMPLEX_COMPONENT_COUNT * F32_BYTE_SIZE) as u64;
833 let total_bytes = single_fft_bytes * batch_size as u64;
834 let slice = sc.staging_buf.slice(0..total_bytes);
835 slice.map_async(wgpu::MapMode::Read, |_| {});
836 self.device.poll(wgpu::PollType::Wait {
837 submission_index: None,
838 timeout: None,
839 })?;
840
841 let mapped = slice.get_mapped_range();
842 let floats: &[f32] = bytemuck::cast_slice(&mapped);
843 let output: Vec<Complex<f32>> = floats
844 .chunks_exact(2)
845 .map(|p| Complex { re: p[0], im: p[1] })
846 .collect();
847
848 drop(mapped);
849 sc.staging_buf.unmap();
850
851 Ok(output)
852 }
853
854 pub fn apply_inverse_transform_postprocessing(&self, output: &mut [Complex<f32>], n: usize) {
856 let scale = 1.0 / n as f32;
857 for c in output {
858 *c = Complex {
859 re: c.re * scale,
860 im: -c.im * scale,
861 };
862 }
863 }
864
865 pub fn get_or_build_size_cache(&self, n: usize, log_n: u32) -> SizeCache {
867 let mut cache = self.cache.borrow_mut();
868 if let Some(sc) = cache.get(&n) {
869 return sc.clone();
870 }
871
872 let sc = self.build_size_cache(n, log_n);
873 cache.insert(n, sc.clone());
874 sc
875 }
876
    /// Allocates every GPU resource needed for length-`n` transforms and records the
    /// per-stage uniforms and bind groups.
    ///
    /// Mode selection: with `pipeline_r2` present ("radix-4 mode") the transform runs
    /// `log_n / 2` radix-4 stages plus one radix-2 stage if `log_n` is odd. Otherwise
    /// ("legacy mode") it runs `log_n` stages on the single `pipeline`.
    ///
    /// Resources created:
    /// - `buf_a` / `buf_b`: `data_bytes` each, ping-ponged between stages (stage `s` reads
    ///   `buf_a` when `s` is even).
    /// - `staging_buf`: `MAP_READ | COPY_DST`, `data_bytes`, for CPU readback.
    /// - `twiddle_buf`: `exp(-2πi·j/n)` as interleaved `f32` pairs (`n` entries in radix-4
    ///   mode, `n / 2` in legacy mode).
    /// - a uniform buffer with one `FftUniforms` per stage at `stride`-aligned offsets.
    ///
    /// `log_n` is expected to be `log2(n)` (the caller in this file passes
    /// `n.trailing_zeros()`).
    pub fn build_size_cache(&self, n: usize, log_n: u32) -> SizeCache {
        let is_r4_mode = self.pipeline_r2.is_some();

        let num_r4 = self.calculate_num_r4_stages(is_r4_mode, log_n);
        let has_r2 = is_r4_mode && log_n % 2 == 1;
        let total_stages = self.calculate_total_stages(is_r4_mode, num_r4, has_r2, log_n);

        // Capacity: as many length-n transforms as fit in one storage binding, capped at 1024.
        // NOTE(review): if one transform alone exceeds the limit, max_batch_size is 0 and the
        // buffers below are created with size 0.
        let single_fft_bytes = n as u64 * 2 * std::mem::size_of::<f32>() as u64;
        let max_batch_size = (self.device.limits().max_storage_buffer_binding_size
            / single_fft_bytes)
            .min(1024) as u32;
        let data_bytes = single_fft_bytes * max_batch_size as u64;

        let make_buf = |label| {
            self.device.create_buffer(&wgpu::BufferDescriptor {
                label: Some(label),
                size: data_bytes,
                usage: wgpu::BufferUsages::STORAGE
                    | wgpu::BufferUsages::COPY_SRC
                    | wgpu::BufferUsages::COPY_DST,
                mapped_at_creation: false,
            })
        };

        let buf_a = make_buf("fft_buf_a");
        let buf_b = make_buf("fft_buf_b");
        let staging_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("fft_staging"),
            size: data_bytes,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        // Twiddles w^j = exp(-2πi·j/n), computed in f64 and stored as (cos, sin) f32 pairs.
        let twiddle_count = self.calculate_twiddle_count(is_r4_mode, n);
        let twiddles: Vec<f32> = (0..twiddle_count)
            .flat_map(|j| {
                let angle = -std::f64::consts::TAU * (j as f64) / (n as f64);
                [angle.cos() as f32, angle.sin() as f32]
            })
            .collect();
        let twiddle_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("fft_twiddles"),
            size: (twiddles.len() * std::mem::size_of::<f32>()) as u64,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        self.queue
            .write_buffer(&twiddle_buf, 0, bytemuck::cast_slice(&twiddles));

        // Each stage's uniforms sit at a multiple of the device's uniform offset alignment
        // so every bind group can point at its own slice of one buffer.
        let alignment = self.device.limits().min_uniform_buffer_offset_alignment as u64;
        let entry_bytes = std::mem::size_of::<FftUniforms>() as u64;
        let stride = entry_bytes.div_ceil(alignment) * alignment;

        // `max(1)` keeps the buffer non-empty when there are no stages (n == 1).
        let uniform_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("fft_uniforms"),
            size: stride * total_stages.max(1) as u64,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let uniform_size = NonZeroU64::new(entry_bytes);
        let layout_r4 = self.pipeline.get_bind_group_layout(0);
        let layout_r2_opt = self
            .pipeline_r2
            .as_ref()
            .map(|p| p.get_bind_group_layout(0));

        // Binding layout shared by all stage shaders:
        // 0 = uniforms slice, 1 = source buffer, 2 = destination buffer, 3 = twiddles.
        let make_bg_with_layout = |layout: &wgpu::BindGroupLayout,
                                   src: &wgpu::Buffer,
                                   dst: &wgpu::Buffer,
                                   uniform_offset: u64| {
            self.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: None,
                layout,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
                            buffer: &uniform_buf,
                            offset: uniform_offset,
                            size: uniform_size,
                        }),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: src.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: dst.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 3,
                        resource: twiddle_buf.as_entire_binding(),
                    },
                ],
            })
        };

        let make_bg = |src: &wgpu::Buffer, dst: &wgpu::Buffer, uniform_offset: u64| {
            make_bg_with_layout(&layout_r4, src, dst, uniform_offset)
        };

        if is_r4_mode {
            // Radix-4 stage s uses stride p = 4^s, stored in `FftUniforms::stage`.
            for s in 0..num_r4 {
                let p = 1u32 << (s as u32 * 2);
                self.queue.write_buffer(
                    &uniform_buf,
                    stride * s as u64,
                    bytemuck::bytes_of(&FftUniforms {
                        n: n as u32,
                        stage: p,
                        log_n,
                        _pad: 0,
                    }),
                );
            }
            // The trailing radix-2 stage continues with stride 4^num_r4.
            if has_r2 {
                let p = 1u32 << (num_r4 as u32 * 2);
                self.queue.write_buffer(
                    &uniform_buf,
                    stride * num_r4 as u64,
                    bytemuck::bytes_of(&FftUniforms {
                        n: n as u32,
                        stage: p,
                        log_n,
                        _pad: 0,
                    }),
                );
            }

            // Ping-pong: even stages read buf_a and write buf_b; odd stages do the reverse.
            let stage_bgs: Vec<wgpu::BindGroup> = (0..num_r4)
                .map(|s| {
                    let (src, dst) = if s % 2 == 0 {
                        (&buf_a, &buf_b)
                    } else {
                        (&buf_b, &buf_a)
                    };
                    make_bg(src, dst, stride * s as u64)
                })
                .collect();

            // The radix-2 stage continues the ping-pong sequence as stage index num_r4.
            let stage_bg_r2 = if has_r2 {
                let (src, dst) = if num_r4 % 2 == 0 {
                    (&buf_a, &buf_b)
                } else {
                    (&buf_b, &buf_a)
                };
                let layout_r2 = layout_r2_opt.as_ref().unwrap();
                Some(make_bg_with_layout(
                    layout_r2,
                    src,
                    dst,
                    stride * num_r4 as u64,
                ))
            } else {
                None
            };

            // An odd number of stages leaves the result in buf_b.
            // wg_r4 is 0 when n < 4.
            SizeCache {
                buf_a,
                buf_b,
                staging_buf,
                twiddle_buf,
                data_bytes,
                stage_bgs,
                stage_bg_r2,
                result_in_b: total_stages % 2 == 1,
                wg_n2: (n as u32 / 2).div_ceil(256),
                wg_r4: (n as u32 / 4).div_ceil(256),
            }
        } else {
            // Legacy mode: `FftUniforms::stage` holds the stage index itself.
            for stage in 0..log_n {
                self.queue.write_buffer(
                    &uniform_buf,
                    stride * stage as u64,
                    bytemuck::bytes_of(&FftUniforms {
                        n: n as u32,
                        stage,
                        log_n,
                        _pad: 0,
                    }),
                );
            }

            let stage_bgs = (0..log_n as usize)
                .map(|s| {
                    let (src, dst) = if s % 2 == 0 {
                        (&buf_a, &buf_b)
                    } else {
                        (&buf_b, &buf_a)
                    };
                    make_bg(src, dst, stride * s as u64)
                })
                .collect();

            // wg_r4 == 0 marks legacy mode for the dispatcher.
            SizeCache {
                buf_a,
                buf_b,
                staging_buf,
                twiddle_buf,
                data_bytes,
                stage_bgs,
                stage_bg_r2: None,
                result_in_b: log_n % 2 == 1,
                wg_n2: (n as u32 / 2).div_ceil(256),
                wg_r4: 0,
            }
        }
    }
1093}
1094
1095impl Default for GpuFft {
1096 fn default() -> Self {
1097 Self::new().expect("No GPU available for default GpuFft instance")
1098 }
1099}
1100
#[cfg(test)]
mod tests {
    use super::*;
    use num_complex::Complex;

    /// Builds a GPU-backed instance; these tests need an adapter to be available.
    fn make_fft() -> GpuFft {
        GpuFft::new().expect("Failed to create FFT instance")
    }

    #[test]
    fn test_prepare_input_data_fft() {
        let samples = [Complex::new(1.0, 2.0), Complex::new(3.0, 4.0)];
        let packed = make_fft().prepare_input_data(&samples, false);
        assert_eq!(packed, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_prepare_input_data_ifft() {
        let samples = [Complex::new(1.0, 2.0), Complex::new(3.0, 4.0)];
        let packed = make_fft().prepare_input_data(&samples, true);
        assert_eq!(packed, vec![1.0, -2.0, 3.0, -4.0]);
    }

    #[test]
    fn test_apply_inverse_transform_postprocessing() {
        let fft = make_fft();
        let mut values = vec![Complex::new(2.0, 4.0), Complex::new(6.0, 8.0)];
        fft.apply_inverse_transform_postprocessing(&mut values, 2);

        let expected = [(1.0, -2.0), (3.0, -4.0)];
        for (got, (re, im)) in values.iter().zip(expected) {
            assert_eq!(got.re, re);
            assert_eq!(got.im, im);
        }
    }
}