use anyhow::Result;
use rustfft::num_complex::Complex;
use rustfft::{Fft, FftPlanner};
use std::f32::consts::PI;
use std::sync::Arc;
use tunes::gpu::{GpuDevice, GpuFft};
/// Entry point: validates `GpuFft` against rustfft's CPU FFT.
///
/// Creates one shared `GpuDevice`, then runs the `test_fft_size` suite for each
/// transform length in turn. The first failing check propagates its error out of
/// `main`, so the process exits unsuccessfully. If every size passes, a success
/// banner is printed.
fn main() -> Result<()> {
    println!("GPU FFT Validation");
    println!("==================\n");

    let device = Arc::new(GpuDevice::new()?);
    println!();

    // Fixed list of transform lengths. A plain array is enough; `vec!` would
    // heap-allocate a list that is only iterated once (clippy::useless_vec).
    let sizes = [256, 512, 1024, 2048];
    for size in sizes {
        println!("Testing FFT size: {}", size);
        println!("{}", "-".repeat(40));
        test_fft_size(&device, size)?;
        println!();
    }

    println!("✓ All FFT validations passed!");
    println!("\nGPU FFT is producing correct results! 🎉");
    Ok(())
}
/// Runs every validation check for a single transform length.
///
/// Creates a `GpuFft` for `size` on the shared device. It also builds rustfft
/// forward and inverse plans of the same length to serve as the CPU reference.
/// The five checks then run in order, and the first failing check's error is
/// returned immediately via `?`.
fn test_fft_size(device: &Arc<GpuDevice>, size: usize) -> Result<()> {
    let mut gpu = GpuFft::new(Arc::clone(device), size)?;

    // CPU reference plans (rustfft), same length as the GPU transform.
    let mut reference = FftPlanner::new();
    let forward_ref = reference.plan_fft_forward(size);
    let inverse_ref = reference.plan_fft_inverse(size);

    println!(" Test 1: Impulse response");
    test_impulse(&mut gpu, &forward_ref, size)?;

    println!(" Test 2: Sine wave (10 Hz)");
    test_sine_wave(&mut gpu, &forward_ref, size, 10.0)?;

    println!(" Test 3: Complex signal (multiple frequencies)");
    test_complex_signal(&mut gpu, &forward_ref, size)?;

    println!(" Test 4: Round-trip (FFT → IFFT)");
    test_round_trip(&mut gpu, &inverse_ref, size)?;

    println!(" Test 5: Random noise");
    test_random_noise(&mut gpu, &forward_ref, size)?;

    Ok(())
}
fn test_impulse(gpu_fft: &mut GpuFft, cpu_fft: &Arc<dyn Fft<f32>>, size: usize) -> Result<()> {
let mut gpu_data = vec![Complex::new(0.0, 0.0); size];
gpu_data[0] = Complex::new(1.0, 0.0);
let mut cpu_data = gpu_data.clone();
gpu_fft.forward(&mut gpu_data)?;
cpu_fft.process(&mut cpu_data);
let max_error = compare_results(&gpu_data, &cpu_data);
println!(" Max error: {:.6e}", max_error);
if max_error > 0.001 {
anyhow::bail!("FFT error too large: {}", max_error);
}
Ok(())
}
fn test_sine_wave(
gpu_fft: &mut GpuFft,
cpu_fft: &Arc<dyn Fft<f32>>,
size: usize,
freq: f32,
) -> Result<()> {
let mut gpu_data = vec![Complex::new(0.0, 0.0); size];
for i in 0..size {
let t = i as f32 / size as f32;
gpu_data[i] = Complex::new((2.0 * PI * freq * t).sin(), 0.0);
}
let mut cpu_data = gpu_data.clone();
gpu_fft.forward(&mut gpu_data)?;
cpu_fft.process(&mut cpu_data);
let max_error = compare_results(&gpu_data, &cpu_data);
println!(" Max error: {:.6e}", max_error);
if max_error > 0.001 {
anyhow::bail!("FFT error too large: {}", max_error);
}
Ok(())
}
fn test_complex_signal(
gpu_fft: &mut GpuFft,
cpu_fft: &Arc<dyn Fft<f32>>,
size: usize,
) -> Result<()> {
let mut gpu_data = vec![Complex::new(0.0, 0.0); size];
for i in 0..size {
let t = i as f32 / size as f32;
let signal = (2.0 * PI * 5.0 * t).sin() * 0.5
+ (2.0 * PI * 12.0 * t).sin() * 0.3
+ (2.0 * PI * 25.0 * t).sin() * 0.2;
gpu_data[i] = Complex::new(signal, 0.0);
}
let mut cpu_data = gpu_data.clone();
gpu_fft.forward(&mut gpu_data)?;
cpu_fft.process(&mut cpu_data);
let max_error = compare_results(&gpu_data, &cpu_data);
println!(" Max error: {:.6e}", max_error);
if max_error > 0.001 {
anyhow::bail!("FFT error too large: {}", max_error);
}
Ok(())
}
fn test_round_trip(
gpu_fft: &mut GpuFft,
_cpu_ifft: &Arc<dyn Fft<f32>>,
size: usize,
) -> Result<()> {
let original = (0..size)
.map(|i| {
let t = i as f32 / size as f32;
Complex::new((2.0 * PI * 7.0 * t).sin(), 0.0)
})
.collect::<Vec<_>>();
let mut gpu_data = original.clone();
gpu_fft.forward(&mut gpu_data)?;
gpu_fft.inverse(&mut gpu_data)?;
let max_error = compare_results(&gpu_data, &original);
println!(" Max error: {:.6e}", max_error);
if max_error > 0.01 {
anyhow::bail!("Round-trip error too large: {}", max_error);
}
Ok(())
}
fn test_random_noise(
gpu_fft: &mut GpuFft,
cpu_fft: &Arc<dyn Fft<f32>>,
size: usize,
) -> Result<()> {
use rand::Rng;
let mut rng = rand::rng();
let mut gpu_data = (0..size)
.map(|_| {
Complex::new(
rng.random_range(-1.0..1.0),
rng.random_range(-1.0..1.0),
)
})
.collect::<Vec<_>>();
let mut cpu_data = gpu_data.clone();
gpu_fft.forward(&mut gpu_data)?;
cpu_fft.process(&mut cpu_data);
let max_error = compare_results(&gpu_data, &cpu_data);
println!(" Max error: {:.6e}", max_error);
if max_error > 0.001 {
anyhow::bail!("FFT error too large: {}", max_error);
}
Ok(())
}
/// Returns the largest absolute difference between corresponding components of
/// two complex slices.
///
/// Real and imaginary parts are compared separately. The result is the maximum
/// over every `|a.re - b.re|` and `|a.im - b.im|`, and is 0.0 for two empty slices.
///
/// Callers reject a result with `max_error > tolerance`. Two failure modes are
/// reported as `f32::INFINITY` so that this check is guaranteed to reject them:
///
/// * **Length mismatch.** `zip` would otherwise silently compare only the
///   common prefix.
/// * **NaN differences,** which arise from NaN values or same-signed infinities
///   in the inputs. `f32::max` returns the non-NaN operand, so a NaN would
///   otherwise be skipped. An all-NaN GPU spectrum would then report an error
///   of 0. Returning NaN would not help either, because `NaN > tolerance` is false.
fn compare_results(a: &[Complex<f32>], b: &[Complex<f32>]) -> f32 {
    if a.len() != b.len() {
        return f32::INFINITY;
    }

    let mut max_error = 0.0f32;
    for (x, y) in a.iter().zip(b) {
        for error in [(x.re - y.re).abs(), (x.im - y.im).abs()] {
            if error.is_nan() {
                return f32::INFINITY;
            }
            max_error = max_error.max(error);
        }
    }
    max_error
}