gpu_fft/
fft.rs

1use cubecl::prelude::*;
2use std::f32::consts::PI;
3
4use crate::WORKGROUP_SIZE;
5
6/// Performs a Fast Fourier Transform (FFT) on the input data.
7///
8/// This kernel computes the FFT of a given input array of complex numbers represented as
9/// separate real and imaginary parts. The FFT is computed using the Cooley-Tukey algorithm,
10/// which is efficient for large datasets.
11///
12/// # Parameters
13///
14/// - `input`: An array of complex numbers represented as lines of type `Line<F>`, where `F`
15///   is a floating-point type. The input array should contain `n` complex numbers.
16/// - `real_output`: A mutable reference to an array of lines where the real parts of the
17///   FFT output will be stored.
18/// - `imag_output`: A mutable reference to an array of lines where the imaginary parts of
19///   the FFT output will be stored.
20/// - `n`: The number of complex samples in the input array. This value is provided at compile-time.
21///
22/// # Safety
23///
24/// This function is marked as `unsafe` because it performs raw pointer operations and assumes
25/// that the input and output arrays are correctly sized and aligned. The caller must ensure
26/// that the input data is valid and that the output arrays have sufficient space to store
27/// the results.
28///
29/// # Example
30///
31/// ```rust
32/// let input = vec![1.0, 0.0, 0.0, 0.0]; // Example input
33/// let (real, imag) = fft::<YourRuntimeType>(device, input);
34/// ```
35///
36/// # Returns
37///
38/// This function does not return a value directly. Instead, it populates the `output` array
39/// with the real and imaginary parts of the FFT result interleaved.
40#[cube(launch)]
41fn fft_kernel<F: Float>(input: &Array<Line<F>>, output: &mut Array<Line<F>>, #[comptime] n: u32) {
42    let idx = ABSOLUTE_POS;
43    if idx < n {
44        let mut real = Line::<F>::new(F::new(0.0));
45        let mut imag = Line::<F>::new(F::new(0.0));
46
47        // Precompute the angle increment
48        let angle_increment = -2.0 * PI / n as f32;
49
50        // #[unroll(true)]
51        for k in 0..n {
52            let angle = F::cast_from(angle_increment) * F::cast_from(k) * F::cast_from(idx);
53            let (cos_angle, sin_angle) = (F::cos(angle), F::sin(angle));
54
55            // Combine the multiplication and addition
56            real += input[k] * Line::new(cos_angle);
57            imag += input[k] * Line::new(sin_angle);
58        }
59
60        // Store the real and imaginary parts in an interleaved manner
61        output[idx * 2] = Line::new(F::cast_from(real)); // Real part
62        output[idx * 2 + 1] = Line::new(F::cast_from(imag)); // Imaginary part
63    }
64}
65
66/// Computes the Fast Fourier Transform (FFT) of a vector of f32 input data.
67///
68/// This function initializes the FFT computation on the provided input vector, launching
69/// the FFT kernel to perform the transformation. The input data is expected to be in the
70/// form of real numbers, which are treated as complex numbers with zero imaginary parts.
71///
72/// # Parameters
73///
74/// - `device`: A reference to the device on which the FFT computation will be performed.
75/// - `input`: A vector of `f32` values representing the real parts of the input data.
76///
77/// # Returns
78///
79/// A tuple containing two vectors:
80/// - A vector of `f32` values representing the real parts of the FFT output.
81/// - A vector of `f32` values representing the imaginary parts of the FFT output.
82///
83/// # Example
84///
85/// ```rust
86/// let input = vec![1.0, 0.0, 0.0, 0.0]; // Example input
87/// let (real, imag) = fft::<YourRuntimeType>(device, input);
88/// ```
89///
90/// # Safety
91///
92/// This function uses unsafe operations to interact with the underlying runtime and device.
93/// The caller must ensure that the input data is valid and that the device is properly set up
94/// for computation.
95pub fn fft<R: Runtime>(device: &R::Device, input: Vec<f32>) -> (Vec<f32>, Vec<f32>) {
96    let client = R::client(device);
97    let n = input.len();
98
99    let input_handle = client.create(f32::as_bytes(&input));
100    let output_handle = client.empty(n * 2 * core::mem::size_of::<f32>()); // Adjust for interleaved output
101
102    let num_workgroups = (n as u32 + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
103
104    unsafe {
105        fft_kernel::launch::<f32, R>(
106            &client,
107            CubeCount::Static(num_workgroups, 1, 1),
108            CubeDim::new(WORKGROUP_SIZE, 1, 1),
109            ArrayArg::from_raw_parts::<f32>(&input_handle, n, 1),
110            ArrayArg::from_raw_parts::<f32>(&output_handle, n * 2, 1), // Adjust for interleaved output
111            n as u32,
112        )
113    };
114
115    let output_bytes = client.read_one(output_handle.binding());
116    let output = f32::from_bytes(&output_bytes);
117
118    // Split the interleaved output into real and imaginary parts
119    let real: Vec<f32> = output.iter().step_by(2).cloned().collect();
120    let imag: Vec<f32> = output.iter().skip(1).step_by(2).cloned().collect();
121
122    // println!(
123    //     "real {:?}..{:?}",
124    //     &real[0..10],
125    //     &real[real.len() - 10..real.len() - 1]
126    // );
127    // println!(
128    //     "imag {:?}..{:?}",
129    //     &imag[0..10],
130    //     &imag[imag.len() - 10..imag.len() - 1]
131    // );
132
133    (real, imag)
134}