use scirs2_core::numeric::Complex64;
use super::pipeline::GpuFftPipeline;
use super::types::{FftDirection, GpuFftConfig, GpuFftError, NormalizationMode};
use crate::error::FFTError;
/// Configuration controlling whether an FFT is dispatched to the GPU
/// backend or executed on the CPU fallback pipeline.
#[derive(Debug, Clone)]
pub struct AutoDispatchConfig {
    // Minimum padded signal length at which the GPU backend is attempted.
    pub gpu_threshold: usize,
    // When true, an inverse FFT is performed instead of a forward FFT.
    pub inverse: bool,
}
impl Default for AutoDispatchConfig {
fn default() -> Self {
Self {
gpu_threshold: 4096,
inverse: false,
}
}
}
/// Result of an auto-dispatched FFT.
#[derive(Debug)]
pub struct DispatchFftOutput {
    // Transformed data; length is the zero-padded power-of-two size.
    pub data: Vec<Complex64>,
    // True when the wgpu backend produced the result, false for the CPU path.
    pub used_gpu: bool,
    // Number of butterfly stages, i.e. log2 of the padded length.
    pub n_stages: u32,
}
/// Round `n` up to the next power of two (identity when `n` already is one).
///
/// Delegates to the standard library's [`usize::next_power_of_two`], which
/// matches the previous hand-rolled bit arithmetic exactly, including
/// returning 1 for an input of 0.
fn next_power_of_two(n: usize) -> usize {
    n.next_power_of_two()
}
/// Adapt a GPU backend error into the crate-level FFT error type.
fn gpu_err_to_fft(e: GpuFftError) -> FFTError {
    let detail = e.to_string();
    FFTError::BackendError(detail)
}
/// Construct the CPU fallback FFT pipeline with normalization disabled
/// (the dispatcher applies no scaling of its own).
fn build_pipeline() -> GpuFftPipeline {
    let config = GpuFftConfig {
        normalization: NormalizationMode::None,
        ..GpuFftConfig::default()
    };
    GpuFftPipeline::new(config)
}
/// Run an FFT, preferring the GPU backend for large inputs (when the
/// `wgpu_fft` feature is enabled) and falling back to the CPU pipeline.
///
/// The input is zero-padded to the next power of two (minimum 2). The
/// returned `n_stages` is log2 of the padded length and `used_gpu` reports
/// which backend actually produced the data.
pub fn fft_auto_dispatch(
    input: &[Complex64],
    config: &AutoDispatchConfig,
) -> Result<DispatchFftOutput, FFTError> {
    let padded_len = next_power_of_two(input.len().max(2));
    let stage_count = padded_len.trailing_zeros();
    let direction = match config.inverse {
        true => FftDirection::Inverse,
        false => FftDirection::Forward,
    };

    // Copy the signal, then zero-pad up to the power-of-two length.
    let mut samples: Vec<Complex64> = Vec::with_capacity(padded_len);
    samples.extend_from_slice(input);
    while samples.len() < padded_len {
        samples.push(Complex64::new(0.0, 0.0));
    }

    #[cfg(feature = "wgpu_fft")]
    {
        // Large signals first try the wgpu backend; any backend error is
        // swallowed deliberately so we fall through to the CPU path below.
        if padded_len >= config.gpu_threshold {
            if let Ok(spectrum) = super::wgpu_backend::fft_wgpu(&samples, config.inverse) {
                return Ok(DispatchFftOutput {
                    data: spectrum,
                    used_gpu: true,
                    n_stages: stage_count,
                });
            }
        }
    }

    // CPU fallback: execute in place on the padded buffer.
    let pipeline = build_pipeline();
    pipeline
        .execute(&mut samples, padded_len, direction)
        .map_err(gpu_err_to_fft)?;
    Ok(DispatchFftOutput {
        data: samples,
        used_gpu: false,
        n_stages: stage_count,
    })
}
/// Forward-transform a batch of signals on the shared pipeline.
///
/// Every signal is zero-padded to one common power-of-two length (the next
/// power of two of the longest signal, minimum 2). Returns a `ValueError`
/// when the batch is empty.
pub fn fft_batch_gpu(inputs: &[Vec<Complex64>]) -> Result<Vec<Vec<Complex64>>, FFTError> {
    if inputs.is_empty() {
        return Err(FFTError::ValueError(
            "batch input must contain at least one signal".into(),
        ));
    }

    // Common padded length shared by every signal in the batch.
    let longest = inputs.iter().map(Vec::len).max().unwrap_or(0);
    let padded_len = next_power_of_two(longest.max(2));

    // Zero-pad each signal to the common length.
    let mut padded: Vec<Vec<Complex64>> = inputs
        .iter()
        .map(|signal| {
            let mut row = vec![Complex64::new(0.0, 0.0); padded_len];
            row[..signal.len()].clone_from_slice(signal);
            row
        })
        .collect();

    let pipeline = build_pipeline();
    let batch_result = pipeline
        .execute_batch(&mut padded, FftDirection::Forward)
        .map_err(gpu_err_to_fft)?;
    Ok(batch_result.outputs)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    // Absolute tolerance for comparing FFT bin values against exact answers.
    const EPS: f64 = 1e-7;

    #[test]
    fn gpu_fft_config_default_threshold_4096() {
        // Default config must mean: CPU below 4096 points, forward transform.
        let cfg = AutoDispatchConfig::default();
        assert_eq!(cfg.gpu_threshold, 4096);
        assert!(!cfg.inverse);
    }

    #[test]
    fn gpu_fft_auto_dispatch_cpu_path_correct() {
        // A unit impulse at index 0 transforms to an all-ones spectrum,
        // which checks every output bin against a known exact value.
        let input: Vec<Complex64> = {
            let mut v = vec![Complex64::new(0.0, 0.0); 8];
            v[0] = Complex64::new(1.0, 0.0);
            v
        };
        let config = AutoDispatchConfig {
            gpu_threshold: 4096, inverse: false,
        };
        let out = fft_auto_dispatch(&input, &config).expect("dispatch failed");
        // 8 points is below the 4096 threshold, so the CPU path must be taken.
        assert!(!out.used_gpu, "small input must use CPU");
        // 8 points => log2(8) = 3 butterfly stages.
        assert_eq!(out.n_stages, 3);
        for (k, c) in out.data.iter().enumerate() {
            assert!(
                (c.re - 1.0).abs() < EPS,
                "bin {k} re = {} (expected 1.0)",
                c.re
            );
            assert!(c.im.abs() < EPS, "bin {k} im = {} (expected 0.0)", c.im);
        }
    }

    #[test]
    fn fft_power_of_two_padding_correct() {
        // 6 samples must be zero-padded up to the next power of two (8).
        let input: Vec<Complex64> = (0..6).map(|i| Complex64::new(i as f64, 0.0)).collect();
        let config = AutoDispatchConfig::default();
        let out = fft_auto_dispatch(&input, &config).expect("dispatch failed");
        assert_eq!(out.data.len(), 8, "padded length must be 8");
    }

    #[test]
    fn gpu_fft_auto_dispatch_roundtrip() {
        // Forward followed by inverse must reproduce the original signal.
        // Only the real parts are compared since the input signal is real.
        let n = 16;
        let original: Vec<Complex64> = (0..n)
            .map(|i| Complex64::new((i as f64 * PI / 8.0).sin(), 0.0))
            .collect();
        let config_fwd = AutoDispatchConfig {
            gpu_threshold: 4096,
            inverse: false,
        };
        let config_inv = AutoDispatchConfig {
            gpu_threshold: 4096,
            inverse: true,
        };
        let forward = fft_auto_dispatch(&original, &config_fwd).expect("forward");
        let recovered = fft_auto_dispatch(&forward.data, &config_inv).expect("inverse");
        for (i, (orig, rec)) in original.iter().zip(recovered.data.iter()).enumerate() {
            assert!(
                (orig.re - rec.re).abs() < 1e-6,
                "index {i}: {:.6} vs {:.6}",
                orig.re,
                rec.re
            );
        }
    }

    #[test]
    fn gpu_fft_batch_results_match_individual() {
        // Batch execution must agree with per-signal dispatch bin-for-bin.
        let n = 16;
        let signals: Vec<Vec<Complex64>> = (0..8_u64)
            .map(|k| {
                (0..n)
                    .map(|i| Complex64::new(i as f64 + k as f64, 0.0))
                    .collect()
            })
            .collect();
        let config = AutoDispatchConfig::default();
        let individual: Vec<Vec<Complex64>> = signals
            .iter()
            .map(|s| fft_auto_dispatch(s, &config).expect("individual").data)
            .collect();
        let batch = fft_batch_gpu(&signals).expect("batch");
        assert_eq!(batch.len(), signals.len());
        for (sig_idx, (ind, bat)) in individual.iter().zip(batch.iter()).enumerate() {
            assert_eq!(ind.len(), bat.len(), "signal {sig_idx} length mismatch");
            for (bin, (a, b)) in ind.iter().zip(bat.iter()).enumerate() {
                assert!(
                    (a.re - b.re).abs() < 1e-6,
                    "signal {sig_idx} bin {bin} re: {:.8} vs {:.8}",
                    a.re,
                    b.re
                );
                assert!(
                    (a.im - b.im).abs() < 1e-6,
                    "signal {sig_idx} bin {bin} im: {:.8} vs {:.8}",
                    a.im,
                    b.im
                );
            }
        }
    }

    #[test]
    fn gpu_fft_batch_rejects_empty() {
        // An empty batch is a usage error, not a silently-empty result.
        let result = fft_batch_gpu(&[]);
        assert!(result.is_err());
    }
}