gpu_fft/ifft.rs
1use cubecl::prelude::*;
2use std::f32::consts::PI;
3
4use crate::WORKGROUP_SIZE;
5
6/// Performs an Inverse Fast Fourier Transform (IFFT) on the input data.
7///
8/// This kernel computes the IFFT of a given input array of complex numbers represented as
9/// separate real and imaginary parts. The IFFT is computed using the inverse of the Cooley-Tukey
10/// algorithm, which is efficient for large datasets.
11///
12/// # Parameters
13///
14/// - `input_real`: An array of real parts of complex numbers represented as lines of type
15/// `Line<F>`, where `F` is a floating-point type. The input array should contain `n` real
16/// components.
17/// - `input_imag`: An array of imaginary parts of complex numbers represented as lines of type
18/// `Line<F>`. The input array should contain `n` imaginary components.
19/// - `output`: A mutable reference to an array of lines where the output of the IFFT will be
20/// stored. The output will contain interleaved real and imaginary parts.
21/// - `n`: The number of complex samples in the input arrays. This value is provided at compile-time.
22///
23/// # Safety
24///
25/// This function is marked as `unsafe` because it performs raw pointer operations and assumes
26/// that the input and output arrays are correctly sized and aligned. The caller must ensure
27/// that the input data is valid and that the output array has sufficient space to store
28/// the results.
29///
30/// # Example
31///
32/// ```rust
33/// let input_real = vec![1.0, 0.0, 0.0, 0.0]; // Example real input
34/// let input_imag = vec![0.0, 0.0, 0.0, 0.0]; // Example imaginary input
35/// let output = ifft::<YourRuntimeType>(device, input_real, input_imag);
36/// ```
37///
38/// # Returns
39///
40/// This function does not return a value directly. Instead, it populates the `output` array
41/// with the interleaved real and imaginary parts of the IFFT result.
42#[cube(launch_unchecked)]
43fn ifft_kernel<F: Float>(
44 input_real: &Array<Line<F>>,
45 input_imag: &Array<Line<F>>,
46 output: &mut Array<Line<F>>,
47 #[comptime] n: u32,
48) {
49 let idx = ABSOLUTE_POS;
50 if idx < n {
51 let mut sum_real = Line::<F>::new(F::new(0.0));
52 let mut sum_imag = Line::<F>::new(F::new(0.0));
53
54 for k in 0..n {
55 let angle =
56 F::new(2.0) * F::new(PI) * F::cast_from(k) * F::cast_from(idx) / F::cast_from(n);
57 let (cos_angle, sin_angle) = (F::cos(angle), F::sin(angle));
58 sum_real += input_real[k] * Line::new(cos_angle) - input_imag[k] * Line::new(sin_angle);
59 sum_imag += input_real[k] * Line::new(sin_angle) + input_imag[k] * Line::new(cos_angle);
60 }
61
62 let n_line = Line::<F>::new(F::cast_from(n));
63
64 // Scale the output by 1/n
65 output[idx] = sum_real / n_line;
66 output[idx + n] = sum_imag / n_line;
67 }
68}
69
70/// Computes the Inverse Fast Fourier Transform (IFFT) of a vector of real and imaginary input data.
71///
72/// This function initializes the IFFT computation on the provided input vectors, launching
73/// the IFFT kernel to perform the transformation. The input data is expected to be in the
74/// form of separate real and imaginary components of complex numbers.
75///
76/// # Parameters
77///
78/// - `device`: A reference to the device on which the IFFT computation will be performed.
79/// - `input_real`: A vector of `f32` values representing the real parts of the input data.
80/// - `input_imag`: A vector of `f32` values representing the imaginary parts of the input data.
81///
82/// # Returns
83///
84/// A vector of `f32` values representing the interleaved output of the IFFT, which contains
85/// both the real and imaginary parts of the result.
86///
87/// # Example
88///
89/// ```rust
90/// let input_real = vec![1.0, 0.0, 0.0, 0.0]; // Example real input
91/// let input_imag = vec![0.0, 0.0, 0.0, 0.0]; // Example imaginary input
92/// let output = ifft::<YourRuntimeType>(device, input_real, input_imag);
93/// ```
94///
95/// # Safety
96///
97/// This function uses unsafe operations to interact with the underlying runtime and device.
98/// The caller must ensure that the input data is valid and that the device is properly set up
99/// for computation.
100pub fn ifft<R: Runtime>(
101 device: &R::Device,
102 input_real: Vec<f32>,
103 input_imag: Vec<f32>,
104) -> Vec<f32> {
105 let client = R::client(device);
106 let n = input_real.len();
107
108 let real_handle = client.create(f32::as_bytes(&input_real));
109 let imag_handle = client.create(f32::as_bytes(&input_imag));
110 let output_handle = client.empty(n * 2 * core::mem::size_of::<f32>()); // Assuming output is interleaved
111
112 let num_workgroups = (n as u32 + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
113
114 unsafe {
115 ifft_kernel::launch_unchecked::<f32, R>(
116 &client,
117 CubeCount::Static(num_workgroups, 1, 1),
118 CubeDim::new(WORKGROUP_SIZE, 1, 1),
119 ArrayArg::from_raw_parts::<f32>(&real_handle, n, 1),
120 ArrayArg::from_raw_parts::<f32>(&imag_handle, n, 1),
121 ArrayArg::from_raw_parts::<f32>(&output_handle, n * 2, 1),
122 n as u32,
123 )
124 };
125
126 let output_bytes = client.read_one(output_handle.binding());
127 let output = f32::from_bytes(&output_bytes);
128
129 output.into()
130}