gpu_fft/fft.rs
1use cubecl::prelude::*;
2use std::f32::consts::PI;
3
4use crate::WORKGROUP_SIZE;
5
6/// Performs a Fast Fourier Transform (FFT) on the input data.
7///
8/// This kernel computes the FFT of a given input array of complex numbers represented as
9/// separate real and imaginary parts. The FFT is computed using the Cooley-Tukey algorithm,
10/// which is efficient for large datasets.
11///
12/// # Parameters
13///
14/// - `input`: An array of complex numbers represented as lines of type `Line<F>`, where `F`
15/// is a floating-point type. The input array should contain `n` complex numbers.
16/// - `real_output`: A mutable reference to an array of lines where the real parts of the
17/// FFT output will be stored.
18/// - `imag_output`: A mutable reference to an array of lines where the imaginary parts of
19/// the FFT output will be stored.
20/// - `n`: The number of complex samples in the input array. This value is provided at compile-time.
21///
22/// # Safety
23///
24/// This function is marked as `unsafe` because it performs raw pointer operations and assumes
25/// that the input and output arrays are correctly sized and aligned. The caller must ensure
26/// that the input data is valid and that the output arrays have sufficient space to store
27/// the results.
28///
29/// # Example
30///
31/// ```rust
32/// let input = vec![1.0, 0.0, 0.0, 0.0]; // Example input
33/// let (real, imag) = fft::<YourRuntimeType>(device, input);
34/// ```
35///
36/// # Returns
37///
38/// This function does not return a value directly. Instead, it populates the `output` array
39/// with the real and imaginary parts of the FFT result interleaved.
40#[cube(launch)]
41fn fft_kernel<F: Float>(input: &Array<Line<F>>, output: &mut Array<Line<F>>, #[comptime] n: u32) {
42 let idx = ABSOLUTE_POS;
43 if idx < n {
44 let mut real = Line::<F>::new(F::new(0.0));
45 let mut imag = Line::<F>::new(F::new(0.0));
46
47 // Precompute the angle increment
48 let angle_increment = -2.0 * PI / n as f32;
49
50 // #[unroll(true)]
51 for k in 0..n {
52 let angle = F::cast_from(angle_increment) * F::cast_from(k) * F::cast_from(idx);
53 let (cos_angle, sin_angle) = (F::cos(angle), F::sin(angle));
54
55 // Combine the multiplication and addition
56 real += input[k] * Line::new(cos_angle);
57 imag += input[k] * Line::new(sin_angle);
58 }
59
60 // Store the real and imaginary parts in an interleaved manner
61 output[idx * 2] = Line::new(F::cast_from(real)); // Real part
62 output[idx * 2 + 1] = Line::new(F::cast_from(imag)); // Imaginary part
63 }
64}
65
66/// Computes the Fast Fourier Transform (FFT) of a vector of f32 input data.
67///
68/// This function initializes the FFT computation on the provided input vector, launching
69/// the FFT kernel to perform the transformation. The input data is expected to be in the
70/// form of real numbers, which are treated as complex numbers with zero imaginary parts.
71///
72/// # Parameters
73///
74/// - `device`: A reference to the device on which the FFT computation will be performed.
75/// - `input`: A vector of `f32` values representing the real parts of the input data.
76///
77/// # Returns
78///
79/// A tuple containing two vectors:
80/// - A vector of `f32` values representing the real parts of the FFT output.
81/// - A vector of `f32` values representing the imaginary parts of the FFT output.
82///
83/// # Example
84///
85/// ```rust
86/// let input = vec![1.0, 0.0, 0.0, 0.0]; // Example input
87/// let (real, imag) = fft::<YourRuntimeType>(device, input);
88/// ```
89///
90/// # Safety
91///
92/// This function uses unsafe operations to interact with the underlying runtime and device.
93/// The caller must ensure that the input data is valid and that the device is properly set up
94/// for computation.
95pub fn fft<R: Runtime>(device: &R::Device, input: Vec<f32>) -> (Vec<f32>, Vec<f32>) {
96 let client = R::client(device);
97 let n = input.len();
98
99 let input_handle = client.create(f32::as_bytes(&input));
100 let output_handle = client.empty(n * 2 * core::mem::size_of::<f32>()); // Adjust for interleaved output
101
102 let num_workgroups = (n as u32 + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
103
104 unsafe {
105 fft_kernel::launch::<f32, R>(
106 &client,
107 CubeCount::Static(num_workgroups, 1, 1),
108 CubeDim::new(WORKGROUP_SIZE, 1, 1),
109 ArrayArg::from_raw_parts::<f32>(&input_handle, n, 1),
110 ArrayArg::from_raw_parts::<f32>(&output_handle, n * 2, 1), // Adjust for interleaved output
111 n as u32,
112 )
113 };
114
115 let output_bytes = client.read_one(output_handle.binding());
116 let output = f32::from_bytes(&output_bytes);
117
118 // Split the interleaved output into real and imaginary parts
119 let real: Vec<f32> = output.iter().step_by(2).cloned().collect();
120 let imag: Vec<f32> = output.iter().skip(1).step_by(2).cloned().collect();
121
122 // println!(
123 // "real {:?}..{:?}",
124 // &real[0..10],
125 // &real[real.len() - 10..real.len() - 1]
126 // );
127 // println!(
128 // "imag {:?}..{:?}",
129 // &imag[0..10],
130 // &imag[imag.len() - 10..imag.len() - 1]
131 // );
132
133 (real, imag)
134}