1use std::sync::Arc;
2
3use num_complex::Complex;
4use rustfft::Fft;
5
6use crate::{ConvMode, ConvNum};
7
8pub struct Conv1d<T: ConvNum> {
9 kernel: Vec<Complex<T>>,
10 kernel_len: usize,
11 fft: Arc<dyn Fft<T>>,
12 ifft: Arc<dyn Fft<T>>,
13 mode: ConvMode,
14 fft_length: T,
15}
16
17impl<T: ConvNum> Conv1d<T> {
18 pub fn new(
19 kernel: Vec<Complex<T>>,
20 kernel_len: usize,
21 fft: Arc<dyn Fft<T>>,
22 ifft: Arc<dyn Fft<T>>,
23 mode: ConvMode,
24 fft_length: T,
25 ) -> Self {
26 Self {
27 kernel,
28 kernel_len,
29 fft,
30 ifft,
31 mode,
32 fft_length,
33 }
34 }
35
36 pub fn process(&mut self, input: Vec<Complex<T>>) -> Vec<Complex<T>> {
37 let segment_len = self.fft.len() - self.kernel_len - 1;
38 let segments = ((input.len() as f32) / (segment_len as f32)).ceil() as usize;
39
40 let mut output = vec![Complex::<T>::ZERO; input.len() + self.kernel_len - 1];
41
42 let mut segment = Vec::with_capacity(self.fft.len());
43 for i in 0..segments {
44 let offset = i * segment_len;
45 let end = offset + segment_len;
46 if end > input.len() {
47 segment.extend_from_slice(&input[offset..input.len()]);
48 segment.extend(std::iter::repeat(Complex::<T>::ZERO).take(end - input.len()));
49 } else {
50 segment.extend_from_slice(&input[offset..(offset + segment_len)]);
51 }
52 segment
53 .extend(std::iter::repeat(Complex::<T>::ZERO).take(self.fft.len() - segment_len));
54 assert_eq!(segment.len(), self.fft.len());
55
56 self.fft.process(&mut segment);
58
59 for (j, value) in segment.iter_mut().enumerate() {
61 *value = *value * self.kernel[j];
62 }
63
64 self.ifft.process(&mut segment);
66
67 for j in 0..segment.len() {
69 if offset + j < output.len() {
70 output[offset + j] = output[offset + j] + (segment[j] / self.fft_length);
71 } else {
72 break;
73 }
74 }
75
76 segment.clear();
77 }
78
79 match self.mode {
80 ConvMode::Full => output,
81 ConvMode::Same => {
82 let target_len = input.len().max(self.kernel_len);
83 let left = (output.len() - target_len) / 2;
84 let right = left + target_len;
85
86 output[left..right].to_vec()
87 }
88 }
89 }
90}