convolution_dsp/
planner.rs

1use num_complex::Complex;
2use rustfft::FftPlanner;
3
4use crate::conv::Conv1d;
5use crate::{ConvMode, ConvNum};
6
/// Factory for FFT-based 1-D convolution objects.
///
/// The planner is a stateless unit struct; obtain one with `Conv1dPlanner::new()`
/// or `Conv1dPlanner::default()`.
#[derive(Debug, Default)]
pub struct Conv1dPlanner;
8
9impl Conv1dPlanner {
10    pub fn new() -> Self {
11        Self
12    }
13
14    pub fn plan_conv1d<T: ConvNum>(&self, kernel: &[T], mode: ConvMode) -> Conv1d<T> {
15        let kernel_len = kernel.len();
16        assert!(kernel_len > 1);
17
18        // FFT size must be reasonably large to avoid circular convolution
19        let fft_size = if kernel_len & (kernel_len - 1) != 0 {
20            usize::pow(2, kernel_len.ilog2() + 2)
21        } else {
22            kernel_len * 2
23        };
24
25        let mut fft_planner = FftPlanner::new();
26        let fft = fft_planner.plan_fft_forward(fft_size);
27        let ifft = fft_planner.plan_fft_inverse(fft_size);
28
29        let mut kernel: Vec<_> = kernel
30            .iter()
31            .map(|re| Complex::<T>::new(*re, T::ZERO))
32            .collect();
33        kernel.extend(vec![Complex::<T>::ZERO; fft_size - kernel.len()]);
34        fft.process(&mut kernel);
35
36        let fft_length = match T::from(fft.len()) {
37            Some(len) => len,
38            None => panic!("Failed to convert usize to T."),
39        };
40
41        Conv1d::new(kernel, kernel_len, fft, ifft, mode, fft_length)
42    }
43}