1use cubecl::prelude::*;
2use std::f32::consts::PI;
3
4use crate::WORKGROUP_SIZE;
5
6#[cube(launch_unchecked)]
7fn precompute_twiddles<F: Float>(twiddles: &mut Array<Line<F>>, #[comptime] n: u32) {
8 let idx = ABSOLUTE_POS;
9 if idx < n {
10 for k in 0..n {
11 let angle =
12 F::new(-2.0) * F::new(PI) * F::cast_from(k) * F::cast_from(idx) / F::cast_from(n);
13 let (cos_angle, sin_angle) = (F::cos(angle), F::sin(angle));
14
15 twiddles[(k * n + idx) * 2] = Line::new(cos_angle);
17 twiddles[(k * n + idx) * 2 + 1] = Line::new(sin_angle);
18 }
19 }
20}
21
22#[cube(launch_unchecked)]
23fn fft_kernel<F: Float>(
24 input: &Array<Line<F>>,
25 twiddles: &Array<Line<F>>,
26 real_output: &mut Array<Line<F>>,
27 imag_output: &mut Array<Line<F>>,
28 #[comptime] n: u32,
29) {
30 let idx = ABSOLUTE_POS;
31 if idx < n {
32 let mut real = Line::<F>::new(F::new(0.0));
33 let mut imag = Line::<F>::new(F::new(0.0));
34
35 for k in 0..n {
36 let cos_angle = twiddles[(k * n + idx) * 2]; let sin_angle = twiddles[(k * n + idx) * 2 + 1]; real += input[k] * Line::cast_from(cos_angle);
40 imag += input[k] * Line::cast_from(sin_angle);
41 }
42
43 real_output[idx] = real;
44 imag_output[idx] = imag;
45 }
46}
47
48pub fn fft<R: Runtime>(device: &R::Device, input: Vec<f32>) -> (Vec<f32>, Vec<f32>) {
49 let client = R::client(device);
50 let n = input.len();
51
52 let input_handle = client.create(f32::as_bytes(&input));
53 let real_handle = client.empty(n * core::mem::size_of::<f32>());
54 let imag_handle = client.empty(n * core::mem::size_of::<f32>());
55 let twiddles_size = 2 * n * n;
56 let twiddles_handle = client.empty(twiddles_size * core::mem::size_of::<f32>());
57
58 let num_workgroups = (n as u32 + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
59
60 unsafe {
61 precompute_twiddles::launch_unchecked::<f32, R>(
62 &client,
63 CubeCount::Static(num_workgroups, 1, 1),
64 CubeDim::new(WORKGROUP_SIZE, 1, 1),
65 ArrayArg::from_raw_parts::<f32>(&twiddles_handle, twiddles_size, 1),
66 n as u32,
67 )
68 };
69
70 let twiddles_bytes = client.read_one(twiddles_handle.clone().binding());
71 let twiddles = f32::from_bytes(&twiddles_bytes);
72 println!("twiddles[{}] - {:#?}", twiddles.len(), &twiddles[0..10]);
73 unsafe {
77 fft_kernel::launch_unchecked::<f32, R>(
78 &client,
79 CubeCount::Static(num_workgroups, 1, 1),
80 CubeDim::new(WORKGROUP_SIZE, 1, 1),
81 ArrayArg::from_raw_parts::<f32>(&input_handle, n, 1),
82 ArrayArg::from_raw_parts::<f32>(&twiddles_handle, twiddles_size, 1),
83 ArrayArg::from_raw_parts::<f32>(&real_handle, n, 1),
84 ArrayArg::from_raw_parts::<f32>(&imag_handle, n, 1),
85 n as u32,
86 )
87 };
88
89 let real_bytes = client.read_one(real_handle.binding());
90 let real = f32::from_bytes(&real_bytes);
91
92 let imag_bytes = client.read_one(imag_handle.binding());
93 let imag = f32::from_bytes(&imag_bytes);
94
95 println!(
96 "real {:?}..{:?}",
97 &real[0..10],
98 &real[real.len() - 10..real.len() - 1]
99 );
100 println!(
101 "imag {:?}..{:?}",
102 &imag[0..10],
103 &imag[imag.len() - 10..imag.len() - 1]
104 );
105
106 (real.into(), imag.into())
107}