use std::f32::consts::PI;
use gpu_fft::{fft, fft_batch, ifft, ifft_batch};
mod common;
use common::assert_slice_approx;
/// Round-trips three 8-sample signals (an ascending ramp, a mixed-sign
/// sequence, and all zeros) through `fft_batch` followed by `ifft_batch`.
///
/// For each batch entry the first `n` output values (labelled "real") must
/// approximately equal the original signal, and the remaining values
/// (labelled "imag") must approximately equal zero.
#[test]
fn test_ifft_batch_roundtrip_arbitrary() {
    let signals: Vec<Vec<f32>> = vec![
        vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
        vec![-3.0, 1.5, 0.0, -2.0, 4.0, -1.0, 0.5, 2.5],
        vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    ];
    let n = signals[0].len();
    let zeros = vec![0.0f32; n];

    let spectra = fft_batch(&signals);
    let recovered = ifft_batch(&spectra);

    // One output per input; this also makes indexing `signals[b]` below safe.
    assert_eq!(recovered.len(), signals.len());

    for (b, output) in recovered.iter().enumerate() {
        let (real, imag) = output.split_at(n);
        let real_label = format!("batch[{b}] real (round-trip)");
        let imag_label = format!("batch[{b}] imag (~0)");
        assert_slice_approx(real, &signals[b], &real_label);
        assert_slice_approx(imag, &zeros, &imag_label);
    }
}
/// Checks that `ifft_batch` agrees with `ifft` applied to each spectrum on
/// its own.
///
/// The spectra come from `fft_batch` over four 4-sample signals: a unit
/// impulse, a constant, an alternating +1/-1 sequence, and a halving sequence.
#[test]
fn test_ifft_batch_matches_single() {
    let signals: Vec<Vec<f32>> = vec![
        vec![1.0, 0.0, 0.0, 0.0],
        vec![1.0, 1.0, 1.0, 1.0],
        vec![1.0, -1.0, 1.0, -1.0],
        vec![0.5, 0.25, 0.125, 0.0625],
    ];

    let spectra = fft_batch(&signals);
    let batch_out = ifft_batch(&spectra);
    assert_eq!(batch_out.len(), signals.len());

    for (b, spectrum) in spectra.iter().enumerate() {
        // Each spectrum is a (real, imaginary) pair; feed it to the scalar path.
        let scalar_out = ifft(&spectrum.0, &spectrum.1);
        let label = format!("batch[{b}]");
        assert_slice_approx(&batch_out[b], &scalar_out, &label);
    }
}
/// Feeds three identical spectra (real part all ones, imaginary part all
/// zeros, length 8) into `ifft_batch`.
///
/// Each output is expected to hold a unit impulse at index 0 in its first `n`
/// values (labelled "real") and zeros in the remaining values (labelled
/// "imag").
#[test]
fn test_ifft_batch_all_ones_spectrum() {
    let n = 8usize;
    let count = 3;

    let flat_spectrum = (vec![1.0f32; n], vec![0.0f32; n]);
    let spectra: Vec<(Vec<f32>, Vec<f32>)> = vec![flat_spectrum; count];

    let results = ifft_batch(&spectra);
    assert_eq!(results.len(), count);

    // Unit impulse: 1.0 at index 0, 0.0 everywhere else.
    let expected: Vec<f32> = (0..n).map(|i| if i == 0 { 1.0 } else { 0.0 }).collect();
    let zeros = vec![0.0f32; n];

    for (b, output) in results.iter().enumerate() {
        let (real, imag) = output.split_at(n);
        assert_slice_approx(real, &expected, &format!("batch[{b}] real"));
        assert_slice_approx(imag, &zeros, &format!("batch[{b}] imag"));
    }
}
/// Round-trips four copies of a single-cycle, 8-sample sine wave through
/// `fft_batch` and `ifft_batch`.
///
/// Every recovered entry must approximately reproduce the sine in its first
/// `n` values (labelled "real") and be approximately zero in the remaining
/// values (labelled "imag").
#[test]
fn test_ifft_batch_roundtrip_sine() {
    let n = 8usize;
    let signal: Vec<f32> = (0..n)
        .map(|i| (2.0 * PI * i as f32 / n as f32).sin())
        .collect();
    let batch = vec![signal.clone(); 4];

    let spectra = fft_batch(&batch);
    let recovered = ifft_batch(&spectra);

    // Without this check an empty result would skip the loop below and the
    // test would pass without verifying anything.
    assert_eq!(recovered.len(), batch.len());

    let zeros = vec![0.0f32; n];
    for (b, output) in recovered.iter().enumerate() {
        assert_slice_approx(
            &output[..n],
            &signal,
            &format!("batch[{b}] real (sine round-trip)"),
        );
        assert_slice_approx(&output[n..], &zeros, &format!("batch[{b}] imag (~0)"));
    }
}
/// An empty batch passed to `ifft_batch` must yield an empty result.
#[test]
fn test_ifft_batch_empty() {
    assert!(ifft_batch(&[]).is_empty());
}
/// A batch holding exactly one spectrum must produce one output that
/// approximately matches what `ifft` returns for that same spectrum.
#[test]
fn test_ifft_batch_single_element() {
    let signal = vec![1.0f32, 2.0, 3.0, 4.0];
    let spectrum = fft(&signal);

    // Compute the scalar reference first; the spectrum is moved into the batch next.
    let scalar_out = ifft(&spectrum.0, &spectrum.1);
    let batch_out = ifft_batch(&[spectrum]);

    assert_eq!(batch_out.len(), 1);
    assert_slice_approx(&batch_out[0], &scalar_out, "single-element batch");
}
/// Round-trips a larger batch: 64 copies of a 64-sample ramp (0, 1, ..., 63)
/// through `fft_batch` and `ifft_batch`.
///
/// Each recovered entry must approximately match the ramp in its first `n`
/// values (labelled "real") and be approximately zero afterwards (labelled
/// "imag").
#[test]
fn test_ifft_batch_large_roundtrip() {
    let n = 64usize;
    let count = 64;

    let signal: Vec<f32> = (0..n).map(|i| i as f32).collect();
    let batch = vec![signal.clone(); count];

    let spectra = fft_batch(&batch);
    let recovered = ifft_batch(&spectra);
    assert_eq!(recovered.len(), count);

    let zeros = vec![0.0f32; n];
    for (b, output) in recovered.iter().enumerate() {
        let (real, imag) = output.split_at(n);
        let real_label = format!("batch[{b}] real (large round-trip)");
        let imag_label = format!("batch[{b}] imag (~0)");
        assert_slice_approx(real, &signal, &real_label);
        assert_slice_approx(imag, &zeros, &imag_label);
    }
}
/// Verifies that entries of a batch do not influence each other.
///
/// Two different 8-sample signals (a unit impulse and an ascending ramp) are
/// interleaved as A, B, A, B, round-tripped through `fft_batch` and
/// `ifft_batch`, and each output is compared against its own input:
/// even-indexed entries against `signal_a`, odd-indexed against `signal_b`.
/// The values after the first `n` (labelled "imag") must be approximately zero.
#[test]
fn test_ifft_batch_signal_independence() {
    let signal_a = vec![1.0f32, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
    let signal_b = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let batch = vec![
        signal_a.clone(),
        signal_b.clone(),
        signal_a.clone(),
        signal_b.clone(),
    ];

    let spectra = fft_batch(&batch);
    let recovered = ifft_batch(&spectra);

    // Without this check an empty or truncated result would let the loop
    // below skip entries and the test would pass without checking them.
    assert_eq!(recovered.len(), batch.len());

    let n = signal_a.len();
    let zeros = vec![0.0f32; n];
    for (b, output) in recovered.iter().enumerate() {
        let expected = if b % 2 == 0 { &signal_a } else { &signal_b };
        assert_slice_approx(&output[..n], expected, &format!("batch[{b}] real"));
        assert_slice_approx(&output[n..], &zeros, &format!("batch[{b}] imag"));
    }
}
/// Round-trips constant (DC) 8-sample signals of several amplitudes through
/// `fft_batch` and `ifft_batch`, one single-entry batch per amplitude.
///
/// The first `n` values of the recovered output (labelled "real") must
/// approximately equal the original constant signal for every amplitude.
#[test]
fn test_ifft_batch_scaling() {
    // The loop variable must be named `amplitude`: the body and the assertion
    // label both refer to it by that name.
    for &amplitude in &[1.0f32, 2.0, 0.5, 10.0] {
        let signal = vec![amplitude; 8];
        let spectra = fft_batch(&[signal.clone()]);
        let result = ifft_batch(&spectra);
        let n = signal.len();
        assert_slice_approx(
            &result[0][..n],
            &signal,
            &format!("DC amplitude={amplitude} real"),
        );
    }
}