use cubecl::prelude::*;
#[cube(launch_unchecked)]
fn stft_power_kernel<F: Float>(
samples: &Array<F>,
output: &mut Array<F>,
n_fft: u32,
hop_length: u32,
n_freqs: u32,
#[comptime] n_fft_smem: usize,
) {
let freq_k = UNIT_POS;
let frame_idx = CUBE_POS_X;
let frame_start = frame_idx * hop_length;
let mut smem = SharedMemory::<F>::new(n_fft_smem);
let two_pi = F::new(std::f32::consts::TAU);
let n_fft_f = F::cast_from(n_fft);
let mut n = UNIT_POS;
while n < n_fft {
let sample = samples[(frame_start + n) as usize];
let angle = two_pi * F::cast_from(n) / n_fft_f;
let window = (F::new(1.0_f32) - angle.cos()) * F::new(0.5_f32);
smem[n as usize] = sample * window;
n += n_freqs;
}
sync_cube();
let two_pi_k = two_pi * F::cast_from(freq_k) / n_fft_f;
let mut re = F::new(0.0_f32);
let mut im = F::new(0.0_f32);
let mut m = 0u32;
while m < n_fft {
let phase = two_pi_k * F::cast_from(m);
re += smem[m as usize] * phase.cos();
im -= smem[m as usize] * phase.sin();
m += 1u32;
}
output[(frame_idx * n_freqs + freq_k) as usize] = re * re + im * im;
}
pub fn compute_stft_power_gpu<R: Runtime>(
client: &ComputeClient<R>,
padded_samples: &[f32],
n_fft: usize,
hop_length: usize,
n_frames: usize,
) -> Vec<f32> {
let n_freqs = n_fft / 2 + 1;
let samples_bytes: Vec<u8> = padded_samples
.iter()
.flat_map(|f| f.to_le_bytes())
.collect();
let samples_handle = client.create_from_slice(&samples_bytes);
let n_output = n_frames * n_freqs;
let output_handle = client.empty(n_output * std::mem::size_of::<f32>());
let cube_dim = CubeDim::new_1d(n_freqs as u32);
let cube_count = CubeCount::Static(n_frames as u32, 1, 1);
unsafe {
stft_power_kernel::launch_unchecked::<f32, R>(
client,
cube_count,
cube_dim,
ArrayArg::from_raw_parts(samples_handle, padded_samples.len()),
ArrayArg::from_raw_parts(output_handle.clone(), n_output),
n_fft as u32,
hop_length as u32,
n_freqs as u32,
n_fft, );
}
let raw = client.read_one_unchecked(output_handle);
raw.chunks(4)
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
.collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn_wgpu::WgpuRuntime;
    use cubecl::prelude::Runtime;
    use rustfft::{FftPlanner, num_complex::Complex};
    use std::f32::consts::PI;

    /// CPU reference for `compute_stft_power_gpu`: periodic-Hann-windowed frames of `n_fft`
    /// samples every `hop_length` samples, transformed with `rustfft`, keeping the power of
    /// the first `n_fft / 2 + 1` bins in row-major `[frame][bin]` order.
    fn cpu_stft_power(
        samples: &[f32],
        n_fft: usize,
        hop_length: usize,
        n_frames: usize,
    ) -> Vec<f32> {
        let n_freqs = n_fft / 2 + 1;
        let window: Vec<f32> = (0..n_fft)
            .map(|i| 0.5 * (1.0 - (2.0 * PI * i as f32 / n_fft as f32).cos()))
            .collect();
        let fft = FftPlanner::<f32>::new().plan_fft_forward(n_fft);
        let mut out = Vec::with_capacity(n_frames * n_freqs);
        let mut buf: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); n_fft];
        for frame in 0..n_frames {
            let start = frame * hop_length;
            let frame_samples = &samples[start..start + n_fft];
            for ((slot, &x), &w) in buf.iter_mut().zip(frame_samples).zip(&window) {
                *slot = Complex::new(x * w, 0.0);
            }
            fft.process(&mut buf);
            out.extend(buf[..n_freqs].iter().map(|c| c.norm_sqr()));
        }
        out
    }

    #[test]
    fn test_gpu_stft_matches_cpu() {
        const SAMPLE_RATE: f32 = 16000.0;
        const TONE_HZ: f32 = 440.0;
        // Peak power here is ~(n_fft / 4)^2 = 1e4, where a single f32 ulp is already ~1e-3,
        // so a purely absolute bound cannot hold between a GPU DFT and a CPU FFT. Compare
        // each element against `ATOL + RTOL * |cpu|` instead.
        const ATOL: f32 = 1e-3;
        const RTOL: f32 = 1e-3;

        let device = burn_wgpu::WgpuDevice::default();
        let client = WgpuRuntime::client(&device);

        let n_fft = 400;
        let hop = 160;
        let n_freqs = n_fft / 2 + 1;
        let n_samples = 16000 + n_fft;
        let samples: Vec<f32> = (0..n_samples)
            .map(|i| (2.0 * PI * TONE_HZ * i as f32 / SAMPLE_RATE).sin())
            .collect();
        let n_frames = (samples.len() - n_fft) / hop + 1;

        let gpu = compute_stft_power_gpu(&client, &samples, n_fft, hop, n_frames);
        let cpu = cpu_stft_power(&samples, n_fft, hop, n_frames);
        assert_eq!(gpu.len(), cpu.len(), "output length mismatch");

        for (i, (&g, &c)) in gpu.iter().zip(&cpu).enumerate() {
            let err = (g - c).abs();
            let tol = ATOL + RTOL * c.abs();
            assert!(
                err <= tol,
                "GPU/CPU STFT disagree at frame {}, bin {}: gpu={g}, cpu={c}, err={err} > tol={tol}",
                i / n_freqs,
                i % n_freqs
            );
        }

        // 440 Hz falls exactly on a bin (440 * 400 / 16000 = 11), so it must be the strongest
        // bin of any fully-covered frame.
        let bin_440 = (TONE_HZ * n_fft as f32 / SAMPLE_RATE).round() as usize;
        let frame_mid = n_frames / 2;
        let row = &gpu[frame_mid * n_freqs..(frame_mid + 1) * n_freqs];
        let peak_bin = row
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.total_cmp(b.1))
            .map(|(k, _)| k)
            .expect("a frame always has at least one bin");
        assert_eq!(
            peak_bin, bin_440,
            "Expected the spectral peak at 440 Hz bin {bin_440}, got bin {peak_bin}"
        );
        let power_at_440 = row[bin_440];
        assert!(
            power_at_440 > 1.0,
            "Expected significant power at 440 Hz bin {bin_440}, got {power_at_440}"
        );
    }
}