// gpu_fft/ifft.rs

1use cubecl::prelude::*;
2use std::f32::consts::PI;
3
4use crate::WORKGROUP_SIZE;
5
6/// Performs an Inverse Fast Fourier Transform (IFFT) on the input data.
7///
8/// This kernel computes the IFFT of a given input array of complex numbers represented as
9/// separate real and imaginary parts. The IFFT is computed using the inverse of the Cooley-Tukey
10/// algorithm, which is efficient for large datasets.
11///
12/// # Parameters
13///
14/// - `input_real`: An array of real parts of complex numbers represented as lines of type
15///   `Line<F>`, where `F` is a floating-point type. The input array should contain `n` real
16///   components.
17/// - `input_imag`: An array of imaginary parts of complex numbers represented as lines of type
18///   `Line<F>`. The input array should contain `n` imaginary components.
19/// - `output`: A mutable reference to an array of lines where the output of the IFFT will be
20///   stored. The output will contain interleaved real and imaginary parts.
21/// - `n`: The number of complex samples in the input arrays. This value is provided at compile-time.
22///
23/// # Safety
24///
25/// This function is marked as `unsafe` because it performs raw pointer operations and assumes
26/// that the input and output arrays are correctly sized and aligned. The caller must ensure
27/// that the input data is valid and that the output array has sufficient space to store
28/// the results.
29///
30/// # Example
31///
32/// ```rust
33/// let input_real = vec![1.0, 0.0, 0.0, 0.0]; // Example real input
34/// let input_imag = vec![0.0, 0.0, 0.0, 0.0]; // Example imaginary input
35/// let output = ifft::<YourRuntimeType>(device, input_real, input_imag);
36/// ```
37///
38/// # Returns
39///
40/// This function does not return a value directly. Instead, it populates the `output` array
41/// with the interleaved real and imaginary parts of the IFFT result.
42#[cube(launch_unchecked)]
43fn ifft_kernel<F: Float>(
44    input_real: &Array<Line<F>>,
45    input_imag: &Array<Line<F>>,
46    output: &mut Array<Line<F>>,
47    #[comptime] n: u32,
48) {
49    let idx = ABSOLUTE_POS;
50    if idx < n {
51        let mut sum_real = Line::<F>::new(F::new(0.0));
52        let mut sum_imag = Line::<F>::new(F::new(0.0));
53
54        for k in 0..n {
55            let angle =
56                F::new(2.0) * F::new(PI) * F::cast_from(k) * F::cast_from(idx) / F::cast_from(n);
57            let (cos_angle, sin_angle) = (F::cos(angle), F::sin(angle));
58            sum_real += input_real[k] * Line::new(cos_angle) - input_imag[k] * Line::new(sin_angle);
59            sum_imag += input_real[k] * Line::new(sin_angle) + input_imag[k] * Line::new(cos_angle);
60        }
61
62        let n_line = Line::<F>::new(F::cast_from(n));
63
64        // Scale the output by 1/n
65        output[idx] = sum_real / n_line;
66        output[idx + n] = sum_imag / n_line;
67    }
68}
69
70/// Computes the Inverse Fast Fourier Transform (IFFT) of a vector of real and imaginary input data.
71///
72/// This function initializes the IFFT computation on the provided input vectors, launching
73/// the IFFT kernel to perform the transformation. The input data is expected to be in the
74/// form of separate real and imaginary components of complex numbers.
75///
76/// # Parameters
77///
78/// - `device`: A reference to the device on which the IFFT computation will be performed.
79/// - `input_real`: A vector of `f32` values representing the real parts of the input data.
80/// - `input_imag`: A vector of `f32` values representing the imaginary parts of the input data.
81///
82/// # Returns
83///
84/// A vector of `f32` values representing the interleaved output of the IFFT, which contains
85/// both the real and imaginary parts of the result.
86///
87/// # Example
88///
89/// ```rust
90/// let input_real = vec![1.0, 0.0, 0.0, 0.0]; // Example real input
91/// let input_imag = vec![0.0, 0.0, 0.0, 0.0]; // Example imaginary input
92/// let output = ifft::<YourRuntimeType>(device, input_real, input_imag);
93/// ```
94///
95/// # Safety
96///
97/// This function uses unsafe operations to interact with the underlying runtime and device.
98/// The caller must ensure that the input data is valid and that the device is properly set up
99/// for computation.
100pub fn ifft<R: Runtime>(
101    device: &R::Device,
102    input_real: Vec<f32>,
103    input_imag: Vec<f32>,
104) -> Vec<f32> {
105    let client = R::client(device);
106    let n = input_real.len();
107
108    let real_handle = client.create(f32::as_bytes(&input_real));
109    let imag_handle = client.create(f32::as_bytes(&input_imag));
110    let output_handle = client.empty(n * 2 * core::mem::size_of::<f32>()); // Assuming output is interleaved
111
112    let num_workgroups = (n as u32 + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
113
114    unsafe {
115        ifft_kernel::launch_unchecked::<f32, R>(
116            &client,
117            CubeCount::Static(num_workgroups, 1, 1),
118            CubeDim::new(WORKGROUP_SIZE, 1, 1),
119            ArrayArg::from_raw_parts::<f32>(&real_handle, n, 1),
120            ArrayArg::from_raw_parts::<f32>(&imag_handle, n, 1),
121            ArrayArg::from_raw_parts::<f32>(&output_handle, n * 2, 1),
122            n as u32,
123        )
124    };
125
126    let output_bytes = client.read_one(output_handle.binding());
127    let output = f32::from_bytes(&output_bytes);
128
129    output.into()
130}