gpu_fft/fft.rs
1use cubecl::prelude::*;
2use std::f32::consts::PI;
3
4use crate::WORKGROUP_SIZE;
5
6/// Performs a Fast Fourier Transform (FFT) on the input data.
7///
8/// This kernel computes the FFT of a given input array of complex numbers represented as
9/// separate real and imaginary parts. The FFT is computed using the Cooley-Tukey algorithm,
10/// which is efficient for large datasets.
11///
12/// # Parameters
13///
14/// - `input`: An array of complex numbers represented as lines of type `Line<F>`, where `F`
15/// is a floating-point type. The input array should contain `n` complex numbers.
16/// - `real_output`: A mutable reference to an array of lines where the real parts of the
17/// FFT output will be stored.
18/// - `imag_output`: A mutable reference to an array of lines where the imaginary parts of
19/// the FFT output will be stored.
20/// - `n`: The number of complex samples in the input array. This value is provided at compile-time.
21///
22/// # Safety
23///
24/// This function is marked as `unsafe` because it performs raw pointer operations and assumes
25/// that the input and output arrays are correctly sized and aligned. The caller must ensure
26/// that the input data is valid and that the output arrays have sufficient space to store
27/// the results.
28///
29/// # Example
30///
31/// ```rust
32/// let input = vec![1.0, 0.0, 0.0, 0.0]; // Example input
33/// let (real, imag) = fft::<YourRuntimeType>(device, input);
34/// ```
35///
36/// # Returns
37///
38/// This function does not return a value directly. Instead, it populates the `real_output`
39/// and `imag_output` arrays with the real and imaginary parts of the FFT result, respectively.
40#[cube(launch_unchecked)]
41fn fft_kernel<F: Float>(
42 input: &Array<Line<F>>,
43 real_output: &mut Array<Line<F>>,
44 imag_output: &mut Array<Line<F>>,
45 #[comptime] n: u32,
46) {
47 let idx = ABSOLUTE_POS;
48 if idx < n {
49 let mut real = Line::<F>::new(F::new(0.0));
50 let mut imag = Line::<F>::new(F::new(0.0));
51
52 for k in 0..n {
53 let angle =
54 F::new(-2.0) * F::new(PI) * F::cast_from(k) * F::cast_from(idx) / F::cast_from(n);
55
56 let (cos_angle, sin_angle) = (F::cos(angle), F::sin(angle));
57 real += input[k] * Line::new(cos_angle);
58 imag += input[k] * Line::new(sin_angle);
59 }
60
61 real_output[idx] = real;
62 imag_output[idx] = imag;
63 }
64}
65
66/// Computes the Fast Fourier Transform (FFT) of a vector of f32 input data.
67///
68/// This function initializes the FFT computation on the provided input vector, launching
69/// the FFT kernel to perform the transformation. The input data is expected to be in the
70/// form of real numbers, which are treated as complex numbers with zero imaginary parts.
71///
72/// # Parameters
73///
74/// - `device`: A reference to the device on which the FFT computation will be performed.
75/// - `input`: A vector of `f32` values representing the real parts of the input data.
76///
77/// # Returns
78///
79/// A tuple containing two vectors:
80/// - A vector of `f32` values representing the real parts of the FFT output.
81/// - A vector of `f32` values representing the imaginary parts of the FFT output.
82///
83/// # Example
84///
85/// ```rust
86/// let input = vec![1.0, 0.0, 0.0, 0.0]; // Example input
87/// let (real, imag) = fft::<YourRuntimeType>(device, input);
88/// ```
89///
90/// # Safety
91///
92/// This function uses unsafe operations to interact with the underlying runtime and device.
93/// The caller must ensure that the input data is valid and that the device is properly set up
94/// for computation.
95pub fn fft<R: Runtime>(device: &R::Device, input: Vec<f32>) -> (Vec<f32>, Vec<f32>) {
96 let client = R::client(device);
97 let n = input.len();
98
99 let input_handle = client.create(f32::as_bytes(&input));
100
101 let real_handle = client.empty(n * core::mem::size_of::<f32>());
102 let imag_handle = client.empty(n * core::mem::size_of::<f32>());
103
104 let num_workgroups = (n as u32 + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
105
106 unsafe {
107 fft_kernel::launch_unchecked::<f32, R>(
108 &client,
109 CubeCount::Static(num_workgroups, 1, 1),
110 CubeDim::new(WORKGROUP_SIZE, 1, 1),
111 ArrayArg::from_raw_parts::<f32>(&input_handle, n, 1),
112 ArrayArg::from_raw_parts::<f32>(&real_handle, n, 1),
113 ArrayArg::from_raw_parts::<f32>(&imag_handle, n, 1),
114 n as u32,
115 )
116 };
117
118 let real_bytes = client.read_one(real_handle.binding());
119 let real = f32::from_bytes(&real_bytes);
120
121 let imag_bytes = client.read_one(imag_handle.binding());
122 let imag = f32::from_bytes(&imag_bytes);
123
124 (real.into(), imag.into())
125}