Skip to main content

proteus_lib/effects/
convolution.rs

1use std::collections::VecDeque;
2use std::sync::Arc;
3use rustfft::{Fft, FftPlanner, num_complex::Complex};
4
5// Taken from https://github.com/BordenJardine/reverb_vst
6
/*
Setup IR
  - segment length is 1/2 fft_size
  - split the IR into segments (pad each with 0s to fft_size)
  - FFT and hold onto each IR segment
Setup frame history queue
  - queue of previous input frames (frequency domain)
  - length is the same as the number of IR segments
  - starts filled with zeros
Process input
  - segment length is 1/2 fft_size
  - split the input buffer into segments (pad each with 0s to fft_size)
  - keep track of the number of input segments (N)
  - FFT each input segment
  - for each input segment:
    - push it onto the front of the history queue and pop the oldest frame off the back
    - convolve the history with the IR segments (frequency domain, see below)
  - take the N output segments
  - IFFT the output segments
  - combine them into a time-domain output vec
  - overlap-add the output with the tail kept from the previous frame
  - hold on to the new tail for overlap with the next frame
  - return the output vec
Convolution (frequency domain)
  - Y_n = X_n*H_0 + X_{n-1}*H_1 + ... + X_{n-K+1}*H_{K-1}
    where K is the number of IR segments, X_{n-i} is the i-th most recent input frame,
    H_i is the i-th IR segment, and * is bin-wise complex multiplication
*/
33
34#[derive(Clone)]
35pub struct Convolver {
36  pub fft_size: usize,
37  ir_segments: Vec<Vec<Complex<f32>>>, // freq domain impulse response segments
38  previous_frame_q: VecDeque<Vec<Complex<f32>>>, // previous freq domain input signals
39  pub previous_tail: Vec<f32>, // previous output frame (time domain) for overlap add
40  fft_processor: Arc<dyn Fft<f32>>,
41  ifft_processor: Arc<dyn Fft<f32>>, //inverse ff
42}
43
44impl Convolver {
45  // set up saved segmented IR
46  pub fn new(ir_signal: &[f32], fft_size: usize) -> Self {
47    let mut planner = FftPlanner::<f32>::new();
48    let fft_processor = planner.plan_fft_forward(fft_size);
49    let ifft_processor = planner.plan_fft_inverse(fft_size);
50
51    let ir_segments = segment_buffer(ir_signal, fft_size, &fft_processor);
52    let segment_count = ir_segments.len();
53    Self {
54      fft_size,
55      ir_segments,
56      fft_processor,
57      ifft_processor,
58      previous_frame_q: init_previous_frame_q(segment_count, fft_size),
59      previous_tail: init_previous_tail(fft_size/2),
60    }
61  }
62
63  pub fn process(&mut self, input_buffer: &[f32]) -> Vec<f32> {
64    let io_len = input_buffer.len();
65    // segment and convert to freq domain
66    let input_segments = segment_buffer(input_buffer, self.fft_size, &self.fft_processor);
67
68
69    let mut output_segments: Vec<Vec<Complex<f32>>> = Vec::new();
70    // push front/ pop back
71    for segment in input_segments {
72      self.previous_frame_q.push_front(segment);
73      self.previous_frame_q.pop_back();
74      // multiply
75      output_segments.push(self.convolve_frame());
76    }
77
78    // go back to time domain
79    let mut time_domain: Vec<f32> = Vec::new();
80    for mut segment in output_segments {
81      self.ifft_processor.process(&mut segment);
82      for sample in segment {
83        time_domain.push(sample.re);
84      }
85    }
86
87    // overlap add
88    for (i, sample) in self.previous_tail.iter().enumerate() {
89      match time_domain.get_mut(i) {
90        Some(out_sample) => *out_sample += sample,
91        None => break
92      }
93    }
94
95    // everything outside of the buffer length is the tail for the next run
96    self.previous_tail = time_domain[io_len..time_domain.len()].to_vec();
97
98    // return a buffers worth of signal
99    return time_domain[0..io_len].to_vec();
100  }
101 
102  // in freq domain
103  // 𝑌𝑛(𝑧)=𝑋𝑛(𝑧)⋅𝐻0(𝑧)+𝑋𝑛−1(𝑧)⋅𝐻1(𝑧)+...+𝑋𝑛−15(𝑧)⋅𝐻15(𝑧)
104  fn convolve_frame(&mut self) -> Vec<Complex<f32>> {
105    //init output to accumulate onto
106    let mut convolved: Vec<Complex<f32>> = Vec::new();
107    for _ in 0..self.fft_size {
108      convolved.push(Complex {re: 0. , im: 0. });
109    }
110
111    for i in 0..self.ir_segments.len() {
112      add_frames(&mut convolved, mult_frames(
113        &self.previous_frame_q[i],
114        &self.ir_segments[i]
115      ));
116    }
117    convolved
118  }
119}
120
121// mutates the first frame!
122pub fn add_frames(f1: &mut [Complex<f32>], f2: Vec<Complex<f32>>) {
123  for (mut sample1, sample2) in f1.iter_mut().zip(f2) {
124    sample1.re = sample1.re + sample2.re;
125    sample1.im = sample1.im + sample2.im;
126  }
127}
128
129//freq domain multiplication
130//ReY[f] = ReX[f]ReH[f]-ImX[f]ImH[f]
131//ImY[f] = ImX[f]ReH[f] + ReX[f]ImH[f]
132//
133// returns new vec
134pub fn mult_frames(f1: &[Complex<f32>], f2: &[Complex<f32>]) -> Vec<Complex<f32>> {
135  let mut out: Vec<Complex<f32>> = Vec::new();
136  for (sample1, sample2) in f1.iter().zip(f2) {
137    out.push(Complex {
138     re: (sample1.re * sample2.re) - (sample1.im * sample2.im),
139     im: (sample1.im * sample2.re) - (sample1.re * sample2.im)
140    });
141  }
142  out
143}
144
/// Returns `size` zeroed samples, used as the initial (silent) overlap-add tail.
pub fn init_previous_tail(size: usize) -> Vec<f32> {
  vec![0.0; size]
}
152
153// - segment buffer (pad with 0s to be fft_size)
154// - FFT and hold onto each segment
155pub fn segment_buffer(buffer: &[f32], fft_size: usize, fft_processor: &Arc<dyn Fft<f32>>) -> Vec<Vec<Complex<f32>>> {
156  let mut segments = Vec::new();
157  let segment_size = fft_size / 2;
158
159  let mut index = 0;
160  while index < buffer.len() {
161    let mut new_segment: Vec<Complex<f32>> = Vec::new();
162    for i in index..index+segment_size {
163      match buffer.get(i) {
164        Some(sample) => new_segment.push(Complex { re: *sample, im: 0. }),
165        None => continue
166      }
167    }
168    while new_segment.len() < fft_size {
169      new_segment.push(Complex { re: 0., im: 0. });
170    }
171    fft_processor.process(&mut new_segment);
172    segments.push(new_segment);
173    index += segment_size;
174  }
175
176  segments
177}
178
179// queue of previous input segments in the frequency domain (polar notation)
180// init to 0s
181pub fn init_previous_frame_q(segment_count: usize, fft_size: usize) -> VecDeque<Vec<Complex<f32>>> {
182  let mut q = VecDeque::new();
183  for _ in 0..segment_count {
184    let mut empty = Vec::new();
185    for _ in 0..fft_size {
186      empty.push(Complex{ re: 0., im: 0. });
187    }
188    q.push_back(empty);
189  }
190  q
191}