use scirs2_core::numeric::Complex64;
use super::pipeline::GpuFftPipeline;
use super::types::{FftDirection, GpuFftConfig, GpuFftError, NormalizationMode};
use crate::error::FFTError;
/// Returns the smallest power of two that is greater than or equal to `n`.
///
/// `0` maps to `1`, and a value that already is a power of two is returned
/// unchanged.
///
/// Delegates to [`usize::next_power_of_two`]. If the result does not fit in a
/// `usize`, that method panics in debug builds and yields `0` in release
/// builds, so callers handling untrusted sizes should range-check first.
fn next_pow2(n: usize) -> usize {
    n.next_power_of_two()
}
/// Converts a GPU FFT pipeline error into the crate-level [`FFTError`].
///
/// The original error is rendered with `to_string()` and carried as the
/// message of `FFTError::BackendError`.
fn gpu_err(e: GpuFftError) -> FFTError {
    let message = e.to_string();
    FFTError::BackendError(message)
}
/// Builds the FFT pipeline used by the overlap-save routine.
///
/// Starts from `GpuFftConfig::default()` and overrides only the
/// normalization mode, setting it to `NormalizationMode::None`.
fn make_pipeline() -> GpuFftPipeline {
    let config = GpuFftConfig {
        normalization: NormalizationMode::None,
        ..GpuFftConfig::default()
    };
    GpuFftPipeline::new(config)
}
/// Convolves `signal` with `kernel` block by block using the overlap-save
/// method, running every transform through a [`GpuFftPipeline`].
///
/// The returned vector has the same length as `signal`. Each frame is the
/// zero-extended signal multiplied with the kernel spectrum in the frequency
/// domain; the first `kernel.len() - 1` samples of every inverse transform are
/// discarded as circular wrap-around and the rest are appended to the output.
/// With a standard DFT pair this yields
/// `out[i] = sum_j kernel[j] * signal[i - j]`, treating samples before the
/// start of the signal as zero.
///
/// NOTE(review): the inverse transform output is used without any explicit
/// `1 / n_fft` scaling although the pipeline is configured with
/// `NormalizationMode::None` — confirm that `GpuFftPipeline::execute` already
/// returns correctly scaled samples for `FftDirection::Inverse`.
///
/// # Arguments
///
/// * `signal` - input samples; must not be empty.
/// * `kernel` - filter taps; must not be empty.
/// * `block_size` - requested block length. `0` selects
///   `max(4 * kernel.len(), 8)`. The FFT length is the next power of two that
///   is at least `block_size + kernel.len() - 1`.
///
/// # Errors
///
/// * `FFTError::ValueError` if `signal` or `kernel` is empty, or if the frame
///   length derived from `block_size` and `kernel.len()` does not fit in a
///   `usize`.
/// * `FFTError::BackendError` if the pipeline reports an error for any
///   transform.
pub fn overlap_save_gpu(
    signal: &[f32],
    kernel: &[f32],
    block_size: usize,
) -> Result<Vec<f32>, FFTError> {
    if signal.is_empty() {
        return Err(FFTError::ValueError("signal must not be empty".into()));
    }
    if kernel.is_empty() {
        return Err(FFTError::ValueError("kernel must not be empty".into()));
    }

    let m = kernel.len();
    // Samples shared between consecutive frames (and discarded from each result).
    let overlap = m - 1;

    // Sizes come from the caller, so every step that could wrap is checked:
    // a wrapped frame length would otherwise make the block loop below spin
    // forever in release builds (and panic in debug builds).
    let size_overflow = || {
        FFTError::ValueError("block_size and kernel length are too large for an FFT frame".into())
    };
    let b = if block_size == 0 {
        m.checked_mul(4).ok_or_else(size_overflow)?.max(8)
    } else {
        block_size
    };
    let min_frame = b
        .checked_add(overlap)
        .filter(|len| len.checked_next_power_of_two().is_some())
        .ok_or_else(size_overflow)?;
    let n_fft = next_pow2(min_frame);

    let pipeline = make_pipeline();

    // Kernel spectrum: zero-pad the taps to the frame length and transform once.
    let mut h_freq: Vec<Complex64> = kernel
        .iter()
        .map(|&x| Complex64::new(f64::from(x), 0.0))
        .collect();
    h_freq.resize(n_fft, Complex64::new(0.0, 0.0));
    pipeline
        .execute(&mut h_freq, n_fft, FftDirection::Forward)
        .map_err(gpu_err)?;

    let sig_len = signal.len();
    let mut output = vec![0.0_f32; sig_len];
    // Valid (alias-free) samples produced per frame; at least 1 because
    // `n_fft >= b + overlap` and `b >= 1`.
    let effective_b = n_fft - overlap;

    // One frame buffer reused for every block instead of a fresh allocation.
    let mut block: Vec<Complex64> = Vec::with_capacity(n_fft);
    // `out_pos` is both the next output index and the start of the current
    // frame inside the signal prefixed with `overlap` zeros: each iteration
    // advances it by a full `effective_b` unless it is the last one.
    let mut out_pos = 0usize;
    while out_pos < sig_len {
        block.clear();
        block.extend((0..n_fft).map(|k| {
            // Padded index `out_pos + k` maps to `signal[out_pos + k - overlap]`;
            // anything before the start or past the end of the signal is zero.
            let sample = (out_pos + k)
                .checked_sub(overlap)
                .and_then(|i| signal.get(i))
                .copied()
                .unwrap_or(0.0);
            Complex64::new(f64::from(sample), 0.0)
        }));

        pipeline
            .execute(&mut block, n_fft, FftDirection::Forward)
            .map_err(gpu_err)?;
        for (y, h) in block.iter_mut().zip(h_freq.iter()) {
            *y = *y * *h;
        }
        pipeline
            .execute(&mut block, n_fft, FftDirection::Inverse)
            .map_err(gpu_err)?;

        // Skip the first `overlap` samples (circular wrap-around) and copy as
        // many valid samples as the output still needs.
        let produced = effective_b.min(sig_len - out_pos);
        let valid = &block[overlap..overlap + produced];
        for (dst, src) in output[out_pos..out_pos + produced].iter_mut().zip(valid) {
            *dst = src.re as f32;
        }
        out_pos += produced;
    }

    Ok(output)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference implementation: causal linear convolution truncated to the
    /// signal length, `out[i] = sum_j kernel[j] * signal[i - j]`, with samples
    /// before the start of the signal treated as zero.
    fn direct_convolve(signal: &[f32], kernel: &[f32]) -> Vec<f32> {
        (0..signal.len())
            .map(|i| {
                kernel
                    .iter()
                    .enumerate()
                    // Taps with j > i would read before the start of the signal.
                    .filter(|&(j, _)| j <= i)
                    .map(|(j, &k)| signal[i - j] * k)
                    .sum::<f32>()
            })
            .collect()
    }

    /// A symmetric moving-average kernel over a ramp must match the reference.
    #[test]
    fn overlap_save_convolution_matches_direct() {
        let kernel: Vec<f32> = vec![0.2, 0.2, 0.2, 0.2, 0.2];
        let signal: Vec<f32> = (0..64).map(|i| i as f32).collect();
        let ols = overlap_save_gpu(&signal, &kernel, 32).expect("OLS failed");
        let direct = direct_convolve(&signal, &kernel);
        assert_eq!(ols.len(), signal.len());
        for (i, (&o, &d)) in ols.iter().zip(direct.iter()).enumerate() {
            assert!((o - d).abs() < 1e-3, "index {i}: OLS={o:.6} direct={d:.6}");
        }
    }

    /// An asymmetric kernel distinguishes convolution from correlation, and
    /// the small block size forces the signal to span more than one frame.
    #[test]
    fn overlap_save_asymmetric_kernel_matches_direct() {
        let kernel: Vec<f32> = vec![1.0, -0.5, 0.25];
        let signal: Vec<f32> = (0..50).map(|i| ((i * 7) % 11) as f32 - 5.0).collect();
        let ols = overlap_save_gpu(&signal, &kernel, 16).expect("OLS asymmetric");
        let direct = direct_convolve(&signal, &kernel);
        assert_eq!(ols.len(), signal.len());
        for (i, (&o, &d)) in ols.iter().zip(direct.iter()).enumerate() {
            assert!((o - d).abs() < 1e-3, "index {i}: OLS={o:.6} direct={d:.6}");
        }
    }

    /// A single unit tap must reproduce the input (default block size).
    #[test]
    fn overlap_save_impulse_kernel_identity() {
        let kernel = vec![1.0_f32];
        let signal: Vec<f32> = (0..32).map(|i| i as f32).collect();
        let out = overlap_save_gpu(&signal, &kernel, 0).expect("OLS impulse");
        assert_eq!(out.len(), signal.len());
        for (i, (&o, &s)) in out.iter().zip(signal.iter()).enumerate() {
            assert!((o - s).abs() < 1e-4, "index {i}: {o} vs {s}");
        }
    }

    /// Sanity check on a 7-tap moving average applied to a sine wave: the
    /// output length matches and one sample stays within a loose bound.
    #[test]
    fn overlap_save_lowpass_reduces_high_freq() {
        use std::f32::consts::PI;
        let n = 128;
        let signal: Vec<f32> = (0..n).map(|i| (2.0 * PI * i as f32 / 16.0).sin()).collect();
        let kernel: Vec<f32> = vec![1.0 / 7.0; 7];
        let out = overlap_save_gpu(&signal, &kernel, 64).expect("OLS lowpass");
        assert_eq!(out.len(), n);
        let mid = 16;
        assert!(
            out[mid].abs() < 2.0,
            "output magnitude out of expected range at {mid}: {}",
            out[mid]
        );
    }

    /// An empty signal is rejected.
    #[test]
    fn overlap_save_empty_signal_error() {
        let err = overlap_save_gpu(&[], &[1.0], 32);
        assert!(err.is_err());
    }

    /// An empty kernel is rejected.
    #[test]
    fn overlap_save_empty_kernel_error() {
        let err = overlap_save_gpu(&[1.0, 2.0], &[], 32);
        assert!(err.is_err());
    }
}