gpu_fft/
twiddles.rs

1use cubecl::prelude::*;
2use std::f32::consts::PI;
3
4use crate::WORKGROUP_SIZE;
5
6#[cube(launch_unchecked)]
7fn precompute_twiddles<F: Float>(twiddles: &mut Array<Line<F>>, #[comptime] n: u32) {
8    let idx = ABSOLUTE_POS;
9    if idx < n {
10        for k in 0..n {
11            let angle =
12                F::new(-2.0) * F::new(PI) * F::cast_from(k) * F::cast_from(idx) / F::cast_from(n);
13            let (cos_angle, sin_angle) = (F::cos(angle), F::sin(angle));
14
15            // Store the twiddle factors in the array
16            twiddles[(k * n + idx) * 2] = Line::new(cos_angle);
17            twiddles[(k * n + idx) * 2 + 1] = Line::new(sin_angle);
18        }
19    }
20}
21
22#[cube(launch_unchecked)]
23fn fft_kernel<F: Float>(
24    input: &Array<Line<F>>,
25    twiddles: &Array<Line<F>>,
26    real_output: &mut Array<Line<F>>,
27    imag_output: &mut Array<Line<F>>,
28    #[comptime] n: u32,
29) {
30    let idx = ABSOLUTE_POS;
31    if idx < n {
32        let mut real = Line::<F>::new(F::new(0.0));
33        let mut imag = Line::<F>::new(F::new(0.0));
34
35        for k in 0..n {
36            let cos_angle = twiddles[(k * n + idx) * 2]; // Get the cosine from the even index
37            let sin_angle = twiddles[(k * n + idx) * 2 + 1]; // Get the sine from the odd index
38
39            real += input[k] * Line::cast_from(cos_angle);
40            imag += input[k] * Line::cast_from(sin_angle);
41        }
42
43        real_output[idx] = real;
44        imag_output[idx] = imag;
45    }
46}
47
48pub fn fft<R: Runtime>(device: &R::Device, input: Vec<f32>) -> (Vec<f32>, Vec<f32>) {
49    let client = R::client(device);
50    let n = input.len();
51
52    let input_handle = client.create(f32::as_bytes(&input));
53    let real_handle = client.empty(n * core::mem::size_of::<f32>());
54    let imag_handle = client.empty(n * core::mem::size_of::<f32>());
55    let twiddles_size = 2 * n * n;
56    let twiddles_handle = client.empty(twiddles_size * core::mem::size_of::<f32>());
57
58    let num_workgroups = (n as u32 + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
59
60    unsafe {
61        precompute_twiddles::launch_unchecked::<f32, R>(
62            &client,
63            CubeCount::Static(num_workgroups, 1, 1),
64            CubeDim::new(WORKGROUP_SIZE, 1, 1),
65            ArrayArg::from_raw_parts::<f32>(&twiddles_handle, twiddles_size, 1),
66            n as u32,
67        )
68    };
69
70    let twiddles_bytes = client.read_one(twiddles_handle.clone().binding());
71    let twiddles = f32::from_bytes(&twiddles_bytes);
72    println!("twiddles[{}] - {:#?}", twiddles.len(), &twiddles[0..10]);
73    // println!("twiddles[{}] - {:?}", twiddles.len(), twiddles);
74    // (vec![], vec![])
75
76    unsafe {
77        fft_kernel::launch_unchecked::<f32, R>(
78            &client,
79            CubeCount::Static(num_workgroups, 1, 1),
80            CubeDim::new(WORKGROUP_SIZE, 1, 1),
81            ArrayArg::from_raw_parts::<f32>(&input_handle, n, 1),
82            ArrayArg::from_raw_parts::<f32>(&twiddles_handle, twiddles_size, 1),
83            ArrayArg::from_raw_parts::<f32>(&real_handle, n, 1),
84            ArrayArg::from_raw_parts::<f32>(&imag_handle, n, 1),
85            n as u32,
86        )
87    };
88
89    let real_bytes = client.read_one(real_handle.binding());
90    let real = f32::from_bytes(&real_bytes);
91
92    let imag_bytes = client.read_one(imag_handle.binding());
93    let imag = f32::from_bytes(&imag_bytes);
94
95    println!(
96        "real {:?}..{:?}",
97        &real[0..10],
98        &real[real.len() - 10..real.len() - 1]
99    );
100    println!(
101        "imag {:?}..{:?}",
102        &imag[0..10],
103        &imag[imag.len() - 10..imag.len() - 1]
104    );
105
106    (real.into(), imag.into())
107}