gpu_fft/
lib.rs

1pub mod fft;
2pub mod ifft;
3pub mod psd;
4pub mod twiddles;
5pub mod utils;
6
/// Number of threads per compute workgroup used by the GPU kernels.
///
/// Common workgroup sizes are 32, 64, 128, 256 or 512 threads, and 64 is the
/// usual general-purpose recommendation for WebGPU. This crate instead uses
/// 1024, which is the maximum workgroup size supported by Apple Metal (CUDA
/// likewise caps a thread block at 1024 threads).
///
/// NOTE(review): the WebGPU specification's default limit for
/// `maxComputeInvocationsPerWorkgroup` is 256, so 1024 threads per workgroup
/// only works on adapters/devices requested with a raised limit. Confirm the
/// device is created with a sufficient limit, and check any code that derives
/// dispatch sizes from this constant before changing it.
pub(crate) const WORKGROUP_SIZE: u32 = 1024;
11
/// GPU runtime used by the public `fft` and `ifft` entry points.
///
/// The runtime is selected at compile time from Cargo features. When both
/// `wgpu` and `cuda` are enabled (for example under `--all-features`), `wgpu`
/// takes precedence rather than defining `Runtime` twice. At least one of the
/// two features must be enabled; otherwise `Runtime` is undefined and `fft` /
/// `ifft` fail to compile.
#[cfg(feature = "wgpu")]
type Runtime = cubecl::wgpu::WgpuRuntime;

/// CUDA runtime, selected only when the `wgpu` feature is not also enabled.
#[cfg(all(feature = "cuda", not(feature = "wgpu")))]
type Runtime = cubecl::cuda::CudaRuntime;
17
18/// Computes the Fast Fourier Transform (FFT) of the given input vector.
19///
20/// This function takes a vector of real numbers as input and returns a tuple
21/// containing two vectors: the real and imaginary parts of the FFT result.
22///
23/// # Parameters
24///
25/// - `input`: A vector of `f32` representing the input signal in the time domain.
26///
27/// # Returns
28///
29/// A tuple containing two vectors:
30/// - A vector of `f32` representing the real part of the FFT output.
31/// - A vector of `f32` representing the imaginary part of the FFT output.
32///
33/// # Example
34///
35/// ```
36/// let input = vec![0.0, 1.0, 0.0, 0.0];
37/// let (real, imag) = fft(input);
38/// ```
39pub fn fft(input: Vec<f32>) -> (Vec<f32>, Vec<f32>) {
40    fft::fft::<Runtime>(&Default::default(), input)
41}
42
43/// Computes the Inverse Fast Fourier Transform (IFFT) of the given real and imaginary parts.
44///
45/// This function takes the real and imaginary parts of a frequency domain signal
46/// and returns the corresponding time domain signal as a vector of real numbers.
47///
48/// # Parameters
49///
50/// - `input_real`: A vector of `f32` representing the real part of the frequency domain signal.
51/// - `input_imag`: A vector of `f32` representing the imaginary part of the frequency domain signal.
52///
53/// # Returns
54///
55/// A vector of `f32` representing the reconstructed time domain signal.
56///
57/// # Example
58///
59/// ```
60/// let real = vec![0.0, 1.0, 0.0, 0.0];
61/// let imag = vec![0.0, 0.0, 0.0, 0.0];
62/// let time_domain = ifft(real, imag);
63/// ```
64pub fn ifft(input_real: Vec<f32>, input_imag: Vec<f32>) -> Vec<f32> {
65    ifft::ifft::<Runtime>(&Default::default(), input_real, input_imag)
66}