use rustfft::{num_complex::Complex, FftNum, FftPlanner};
/// Causal (truncated) linear convolution of `signal` with `kernel`, computed via FFT.
///
/// Returns a vector of length `signal.len()` whose element `k` is
/// `sum_{j=0..=k} kernel[j] * signal[k - j]`, i.e. the first `n` samples of the full
/// linear convolution.
///
/// Both inputs are zero-padded to a power-of-two length of at least `2 * n - 1`. The
/// circular convolution computed by the FFT therefore has no wrap-around aliasing in the
/// retained samples.
///
/// # Panics
///
/// Panics if `signal` and `kernel` have different lengths.
pub fn fft_conv<T: FftNum>(signal: &[T], kernel: &[T]) -> Vec<T> {
    let n = signal.len();
    assert_eq!(kernel.len(), n, "signal and kernel must be the same length");
    if n == 0 {
        return Vec::new();
    }
    // The full linear convolution has 2n - 1 samples. A circular convolution at least
    // that long reproduces it exactly, so the first n outputs are alias-free.
    let conv_len = 2 * n - 1;
    let fft_len = conv_len.next_power_of_two();

    let mut planner = FftPlanner::<T>::new();
    let fft = planner.plan_fft_forward(fft_len);
    let ifft = planner.plan_fft_inverse(fft_len);

    // `Fft::process` allocates a fresh scratch buffer on every call. Allocate one buffer
    // large enough for both plans and reuse it for all three transforms instead.
    let scratch_len = fft
        .get_inplace_scratch_len()
        .max(ifft.get_inplace_scratch_len());
    let mut scratch = vec![Complex::new(T::zero(), T::zero()); scratch_len];

    let mut signal_buf = zero_padded_complex(signal, fft_len);
    let mut kernel_buf = zero_padded_complex(kernel, fft_len);
    fft.process_with_scratch(&mut signal_buf, &mut scratch);
    fft.process_with_scratch(&mut kernel_buf, &mut scratch);

    // Pointwise product in the frequency domain == circular convolution in time.
    for (s, k) in signal_buf.iter_mut().zip(&kernel_buf) {
        *s = *s * *k;
    }
    ifft.process_with_scratch(&mut signal_buf, &mut scratch);

    // rustfft does not normalise, so forward + inverse scales every sample by fft_len.
    let scale = T::from_usize(fft_len).expect("fft_len fits in T");
    signal_buf.iter().take(n).map(|c| c.re / scale).collect()
}

/// Lifts real `values` into complex numbers (zero imaginary part) and zero-pads the
/// result to exactly `len` elements. Callers must pass `len >= values.len()`.
fn zero_padded_complex<T: FftNum>(values: &[T], len: usize) -> Vec<Complex<T>> {
    let mut buf = Vec::with_capacity(len);
    buf.extend(values.iter().map(|&v| Complex::new(v, T::zero())));
    buf.resize(len, Complex::new(T::zero(), T::zero()));
    buf
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::structured::toeplitz_matvec;

    /// Asserts that `actual` has the same length as `expected` and agrees element-wise
    /// within `tol`. Checking the length first prevents `zip` from silently truncating.
    fn assert_close(expected: &[f64], actual: &[f64], tol: f64) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (e, got)) in expected.iter().zip(actual).enumerate() {
            assert!((e - got).abs() < tol, "index {i}: expected {e}, got {got}");
        }
    }

    #[test]
    fn matches_toeplitz_matvec_reference() {
        let n = 17;
        let signal: Vec<f64> = (0..n).map(|i| (i as f64 * 0.23).sin()).collect();
        let kernel: Vec<f64> = (0..n)
            .map(|i| (i as f64 * 0.41).cos() * 0.8_f64.powi(i as i32))
            .collect();
        let expected = toeplitz_matvec(&kernel, &signal);
        let actual = fft_conv(&signal, &kernel);
        assert_close(&expected, &actual, 1e-9);
    }

    #[test]
    fn impulse_kernel_passes_signal_through() {
        let n = 5;
        let mut kernel = vec![0.0; n];
        kernel[0] = 1.0;
        let signal: Vec<f64> = (0..n).map(|i| i as f64 + 1.0).collect();
        let result = fft_conv(&signal, &kernel);
        assert_close(&signal, &result, 1e-9);
    }

    #[test]
    fn small_hand_computed_convolution() {
        // y[k] = sum_{j<=k} kernel[j] * signal[k - j]
        //      = [1*1, 1*2 + 1*1, 1*3 + 1*2 + 0*1] = [1, 3, 5]
        let result = fft_conv::<f64>(&[1.0, 2.0, 3.0], &[1.0, 1.0, 0.0]);
        assert_close(&[1.0, 3.0, 5.0], &result, 1e-9);
    }

    #[test]
    fn single_element_is_scalar_product() {
        let result = fft_conv::<f64>(&[3.0], &[-2.0]);
        assert_close(&[-6.0], &result, 1e-9);
    }

    #[test]
    #[should_panic(expected = "signal and kernel must be the same length")]
    fn mismatched_lengths_panic() {
        let _ = fft_conv::<f64>(&[1.0, 2.0], &[1.0]);
    }

    #[test]
    fn empty_input() {
        let result: Vec<f64> = fft_conv(&[], &[]);
        assert!(result.is_empty());
    }
}