use num_complex::Complex;
use wgsl_fft::GpuFft;
fn main() {
println!("🚀 Testing Batch Processing Implementation");
println!("=========================================\n");
let fft = GpuFft::new().expect("GPU required");
println!("✅ Created GPU FFT instance\n");
println!("Test 1: Single FFT (backward compatibility)");
let single_input = vec![vec![Complex::new(1.0, 0.0); 16]];
let single_result_batch = fft.fft(&single_input).expect("Single FFT failed");
let single_result = &single_result_batch[0];
println!(" Input length: {}", single_input[0].len());
println!(" Output length: {}", single_result.len());
println!(" ✅ Single FFT works\n");
println!("Test 2: Batch FFT Processing");
let batch_inputs = vec![
vec![Complex::new(1.0, 0.0); 16],
vec![Complex::new(0.5, 0.0); 16],
vec![Complex::new(0.25, 0.0); 16],
];
let batch_results = fft.fft(&batch_inputs).expect("Batch FFT failed");
println!(" Batch size: {}", batch_results.len());
println!(" Each FFT size: {}", batch_results[0].len());
println!(" ✅ Batch FFT works\n");
println!("Test 3: Batch IFFT Processing");
let batch_ifft_results = fft.ifft(&batch_results).expect("Batch IFFT failed");
println!(" Batch size: {}", batch_ifft_results.len());
println!(" Each IFFT size: {}", batch_ifft_results[0].len());
println!(" ✅ Batch IFFT works\n");
println!("Test 4: Roundtrip Consistency Check");
let mut max_total_error: f32 = 0.0;
for (i, (original, reconstructed)) in batch_inputs
.iter()
.zip(batch_ifft_results.iter())
.enumerate()
{
let max_error: f32 = original
.iter()
.zip(reconstructed.iter())
.map(|(orig, recon)| {
((orig.re - recon.re).powi(2) + (orig.im - recon.im).powi(2)).sqrt()
})
.fold(0.0, f32::max);
println!(" Signal {} max error: {:.2e}", i, max_error);
max_total_error = max_total_error.max(max_error);
}
println!(" Overall max error: {:.2e}", max_total_error);
assert!(
max_total_error < 1e-4,
"Roundtrip error too large: {}",
max_total_error
);
println!(" ✅ Roundtrip consistency verified\n");
println!("Test 5: Performance Comparison");
use std::time::Instant;
let large_batch_size = 8;
let large_fft_size = 1024;
let large_batch: Vec<_> = (0..large_batch_size)
.map(|i| vec![Complex::new(i as f32 * 0.1, 0.0); large_fft_size])
.collect();
let start = Instant::now();
let individual_results: Vec<_> = large_batch
.iter()
.map(|input| {
let batch = vec![input.clone()];
fft.fft(&batch).expect("FFT failed")[0].clone()
})
.collect();
let individual_time = start.elapsed();
let start = Instant::now();
let batch_results = fft.fft(&large_batch).expect("Batch FFT failed");
let batch_time = start.elapsed();
println!(" Individual processing time: {:?}", individual_time);
println!(" Batch processing time: {:?}", batch_time);
if batch_time < individual_time {
let speedup = individual_time.as_secs_f32() / batch_time.as_secs_f32();
println!(" Speedup: {:.2}x", speedup);
} else {
println!(" Note: Batch processing may not be faster for small batches");
}
println!(" ✅ Performance test completed\n");
println!("🎉 All batch processing tests passed!");
println!("\nSummary:");
println!(" ✅ Single FFT (backward compatibility)");
println!(" ✅ Batch FFT processing");
println!(" ✅ Batch IFFT processing");
println!(" ✅ Roundtrip consistency");
println!(" ✅ Performance comparison");
println!("\nThe batch processing implementation is working correctly!");
}
#[cfg(test)]
mod tests {
    use num_complex::Complex;
    use wgsl_fft::GpuFft;

    /// Maximum per-sample error accepted after an FFT followed by an IFFT.
    const ROUNDTRIP_TOLERANCE: f32 = 1e-4;

    /// Creates the FFT instance used by a test, panicking with
    /// "GPU required" if construction fails.
    fn make_fft() -> GpuFft {
        GpuFft::new().expect("GPU required")
    }

    /// Running maximum that keeps NaN once it appears.
    ///
    /// `f32::max` returns the non-NaN operand, which would hide NaN output
    /// from the error checks below.
    fn nan_aware_max(current: f32, candidate: f32) -> f32 {
        if candidate.is_nan() || candidate > current {
            candidate
        } else {
            current
        }
    }

    /// Largest element-wise distance between two complex signals.
    ///
    /// Panics if the lengths differ, because `zip` would otherwise stop at
    /// the shorter signal and leave the remaining samples unchecked.
    fn max_abs_error(expected: &[Complex<f32>], actual: &[Complex<f32>]) -> f32 {
        assert_eq!(
            expected.len(),
            actual.len(),
            "signal length mismatch between expected and actual"
        );
        expected
            .iter()
            .zip(actual)
            .map(|(e, a)| (e - a).norm())
            .fold(0.0f32, nan_aware_max)
    }

    /// A one-signal batch yields one output of the same length.
    #[test]
    fn single_fft_output_length() {
        let fft = make_fft();
        let input = vec![vec![Complex::new(1.0f32, 0.0); 256]];
        let out = fft.fft(&input).expect("fft failed");
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].len(), 256);
    }

    /// An eight-signal batch yields eight outputs, each of the input length.
    #[test]
    fn batch_fft_output_count() {
        let fft = make_fft();
        let n = 64;
        let batch: Vec<Vec<Complex<f32>>> = (0..8)
            .map(|i| vec![Complex::new(i as f32 * 0.1, 0.0); n])
            .collect();
        let out = fft.fft(&batch).expect("batch fft failed");
        assert_eq!(out.len(), 8, "output batch size should match input");
        for o in &out {
            assert_eq!(o.len(), n, "each output should have length n");
        }
    }

    /// FFT followed by IFFT reproduces a constant signal within tolerance.
    #[test]
    fn roundtrip_constant_signal() {
        let fft = make_fft();
        let n = 64;
        let input = vec![vec![Complex::new(0.5f32, 0.0); n]];
        let freq = fft.fft(&input).expect("fft failed");
        let recon = fft.ifft(&freq).expect("ifft failed");
        assert_eq!(recon.len(), 1, "roundtrip should return exactly one signal");
        let max_err = max_abs_error(&input[0], &recon[0]);
        // A NaN error compares false and therefore fails the assertion.
        assert!(
            max_err < ROUNDTRIP_TOLERANCE,
            "roundtrip error {max_err:.2e}"
        );
    }

    /// FFT followed by IFFT reproduces every ramp signal in a batch.
    #[test]
    fn roundtrip_batch_error_below_threshold() {
        let fft = make_fft();
        let n = 128;
        let batch: Vec<Vec<Complex<f32>>> = (0..4)
            .map(|i| {
                (0..n)
                    .map(|j| Complex::new((i * n + j) as f32 * 0.001, 0.0))
                    .collect()
            })
            .collect();
        let freq = fft.fft(&batch).expect("fft failed");
        let recon = fft.ifft(&freq).expect("ifft failed");
        // Without this check a shorter result would leave signals uncompared.
        assert_eq!(
            recon.len(),
            batch.len(),
            "roundtrip should return one signal per input"
        );
        for (orig, rec) in batch.iter().zip(recon.iter()) {
            let max_err = max_abs_error(orig, rec);
            assert!(
                max_err < ROUNDTRIP_TOLERANCE,
                "batch roundtrip error {max_err:.2e}"
            );
        }
    }

    /// A unit impulse at index 0 transforms to bins of magnitude ~1.
    #[test]
    fn impulse_has_flat_spectrum() {
        let fft = make_fft();
        let n = 256;
        let mut impulse = vec![Complex::<f32>::new(0.0, 0.0); n];
        impulse[0] = Complex::new(1.0, 0.0);
        let out = fft.fft(&[impulse]).expect("fft failed");
        assert_eq!(out.len(), 1, "one input signal should give one spectrum");
        // An empty spectrum would otherwise pass: the loop would never run.
        assert_eq!(out[0].len(), n, "spectrum should have one bin per sample");
        for (k, c) in out[0].iter().enumerate() {
            let mag = c.norm();
            assert!(
                (mag - 1.0).abs() < 1e-3,
                "bin {k} magnitude {mag:.4} should be ~1 for impulse input"
            );
        }
    }
}