convolution_dsp/
planner.rs1use num_complex::Complex;
2use rustfft::FftPlanner;
3
4use crate::conv::Conv1d;
5use crate::{ConvMode, ConvNum};
6
7pub struct Conv1dPlanner;
8
9impl Conv1dPlanner {
10 pub fn new() -> Self {
11 Self
12 }
13
14 pub fn plan_conv1d<T: ConvNum>(&self, kernel: &[T], mode: ConvMode) -> Conv1d<T> {
15 let kernel_len = kernel.len();
16 assert!(kernel_len > 1);
17
18 let fft_size = if kernel_len & (kernel_len - 1) != 0 {
20 usize::pow(2, kernel_len.ilog2() + 2)
21 } else {
22 kernel_len * 2
23 };
24
25 let mut fft_planner = FftPlanner::new();
26 let fft = fft_planner.plan_fft_forward(fft_size);
27 let ifft = fft_planner.plan_fft_inverse(fft_size);
28
29 let mut kernel: Vec<_> = kernel
30 .iter()
31 .map(|re| Complex::<T>::new(*re, T::ZERO))
32 .collect();
33 kernel.extend(vec![Complex::<T>::ZERO; fft_size - kernel.len()]);
34 fft.process(&mut kernel);
35
36 let fft_length = match T::from(fft.len()) {
37 Some(len) => len,
38 None => panic!("Failed to convert usize to T."),
39 };
40
41 Conv1d::new(kernel, kernel_len, fft, ifft, mode, fft_length)
42 }
43}