convolution_dsp/
conv.rs

1use std::sync::Arc;
2
3use num_complex::Complex;
4use rustfft::Fft;
5
6use crate::{ConvMode, ConvNum};
7
8pub struct Conv1d<T: ConvNum> {
9    kernel: Vec<Complex<T>>,
10    kernel_len: usize,
11    fft: Arc<dyn Fft<T>>,
12    ifft: Arc<dyn Fft<T>>,
13    mode: ConvMode,
14    fft_length: T,
15}
16
17impl<T: ConvNum> Conv1d<T> {
18    pub fn new(
19        kernel: Vec<Complex<T>>,
20        kernel_len: usize,
21        fft: Arc<dyn Fft<T>>,
22        ifft: Arc<dyn Fft<T>>,
23        mode: ConvMode,
24        fft_length: T,
25    ) -> Self {
26        Self {
27            kernel,
28            kernel_len,
29            fft,
30            ifft,
31            mode,
32            fft_length,
33        }
34    }
35
36    pub fn process(&mut self, input: Vec<Complex<T>>) -> Vec<Complex<T>> {
37        let segment_len = self.fft.len() - self.kernel_len - 1;
38        let segments = ((input.len() as f32) / (segment_len as f32)).ceil() as usize;
39
40        let mut output = vec![Complex::<T>::ZERO; input.len() + self.kernel_len - 1];
41
42        let mut segment = Vec::with_capacity(self.fft.len());
43        for i in 0..segments {
44            let offset = i * segment_len;
45            let end = offset + segment_len;
46            if end > input.len() {
47                segment.extend_from_slice(&input[offset..input.len()]);
48                segment.extend(std::iter::repeat(Complex::<T>::ZERO).take(end - input.len()));
49            } else {
50                segment.extend_from_slice(&input[offset..(offset + segment_len)]);
51            }
52            segment
53                .extend(std::iter::repeat(Complex::<T>::ZERO).take(self.fft.len() - segment_len));
54            assert_eq!(segment.len(), self.fft.len());
55
56            // FFT the segment
57            self.fft.process(&mut segment);
58
59            // Piecewise multiply with kernel.
60            for (j, value) in segment.iter_mut().enumerate() {
61                *value = *value * self.kernel[j];
62            }
63
64            // IFFT back to time domain
65            self.ifft.process(&mut segment);
66
67            // Normalize and accumulate to output
68            for j in 0..segment.len() {
69                if offset + j < output.len() {
70                    output[offset + j] = output[offset + j] + (segment[j] / self.fft_length);
71                } else {
72                    break;
73                }
74            }
75
76            segment.clear();
77        }
78
79        match self.mode {
80            ConvMode::Full => output,
81            ConvMode::Same => {
82                let target_len = input.len().max(self.kernel_len);
83                let left = (output.len() - target_len) / 2;
84                let right = left + target_len;
85
86                output[left..right].to_vec()
87            }
88        }
89    }
90}