nice_plug_core/util/stft.rs
1//! Utilities for buffering audio, likely used as part of a short-term Fourier transform.
2
3use std::cmp;
4
5use crate::buffer::{Block, Buffer};
6
/// A read-only, multi-channel audio source that can be fed into a [`StftHelper`].
pub trait StftInput {
    /// How many samples this input holds.
    fn num_samples(&self) -> usize;

    /// How many channels this input holds.
    fn num_channels(&self) -> usize;

    /// Read the sample at `sample_idx` of `channel` without performing any bounds checks.
    ///
    /// # Safety
    ///
    /// No bounds checks are performed. The caller is expected to keep `channel` below
    /// [`num_channels()`][Self::num_channels()] and `sample_idx` below
    /// [`num_samples()`][Self::num_samples()].
    unsafe fn get_sample_unchecked(&self, channel: usize, sample_idx: usize) -> f32;
}
18
19/// The same as [`StftInput`], but with support for writing results back to the buffer
20pub trait StftInputMut: StftInput {
21 /// Get a mutable reference to a sample in the buffer without any bounds checks.
22 unsafe fn get_sample_unchecked_mut(&mut self, channel: usize, sample_idx: usize) -> &mut f32;
23}
24
/// Splits incoming audio into equally sized blocks, hands every block to a callback, and writes
/// the results of the previously processed block back to the buffer. Because output always lags
/// one block behind the input, this adds latency equal to the block size.
///
/// Setting the `NUM_SIDECHAIN_INPUTS` constant allows additional inputs to be processed alongside
/// the main one. Those buffers are only read for analysis and never written to, and they share
/// the channel count of the main input.
///
/// TODO: Better name?
/// TODO: We may need something like this purely for analysis, e.g. for showing spectrums in a GUI.
///       Figure out the cleanest way to adapt this for the non-processing use case.
pub struct StftHelper<const NUM_SIDECHAIN_INPUTS: usize = 0> {
    // Ring buffers, one `Vec<f32>` per channel. The input ring buffers hold the incoming samples,
    // the output ring buffer holds the already processed output that was built up by adding
    // overlapping windows. Each time a new window starts, the finished output is written to the
    // buffer passed to the process function and a new block is processed.
    main_input_ring_buffers: Vec<Vec<f32>>,
    main_output_ring_buffers: Vec<Vec<f32>>,
    sidechain_ring_buffers: [Vec<Vec<f32>>; NUM_SIDECHAIN_INPUTS],

    /// A ring buffer's contents are unwrapped into this scratch buffer before they are handed to
    /// the plugin. This is what makes overlapping windows possible.
    scratch_buffer: Vec<f32>,
    /// Only relevant when padding is used. Holds what the previous iteration left in the padding
    /// area of `scratch_buffer` (`scratch_buffer[block_size..]`) so it can be added to the output
    /// during the next iteration.
    padding_buffers: Vec<Vec<f32>>,

    /// Our current position within the ring buffers. A block is processed every time this
    /// reaches the start of a new window.
    current_pos: usize,
    /// The amount of extra space that was appended to the buffers for padding, if any.
    padding: usize,
}
59
/// Placeholder input type that stands in for the sidechain buffers when no sidechaining is used.
struct NoSidechain;
62
63impl StftInput for Buffer<'_> {
64 #[inline]
65 fn num_samples(&self) -> usize {
66 self.samples()
67 }
68
69 #[inline]
70 fn num_channels(&self) -> usize {
71 self.channels()
72 }
73
74 #[inline]
75 unsafe fn get_sample_unchecked(&self, channel: usize, sample_idx: usize) -> f32 {
76 unsafe {
77 *self
78 .as_slice_immutable()
79 .get_unchecked(channel)
80 .get_unchecked(sample_idx)
81 }
82 }
83}
84
85impl StftInputMut for Buffer<'_> {
86 #[inline]
87 unsafe fn get_sample_unchecked_mut(&mut self, channel: usize, sample_idx: usize) -> &mut f32 {
88 unsafe {
89 self.as_slice()
90 .get_unchecked_mut(channel)
91 .get_unchecked_mut(sample_idx)
92 }
93 }
94}
95
96impl StftInput for Block<'_, '_> {
97 #[inline]
98 fn num_samples(&self) -> usize {
99 self.samples()
100 }
101
102 #[inline]
103 fn num_channels(&self) -> usize {
104 self.channels()
105 }
106
107 #[inline]
108 unsafe fn get_sample_unchecked(&self, channel: usize, sample_idx: usize) -> f32 {
109 unsafe { *self.get_unchecked(channel).get_unchecked(sample_idx) }
110 }
111}
112
113impl StftInputMut for Block<'_, '_> {
114 #[inline]
115 unsafe fn get_sample_unchecked_mut(&mut self, channel: usize, sample_idx: usize) -> &mut f32 {
116 unsafe {
117 self.get_unchecked_mut(channel)
118 .get_unchecked_mut(sample_idx)
119 }
120 }
121}
122
123impl StftInput for [&[f32]] {
124 #[inline]
125 fn num_samples(&self) -> usize {
126 if self.is_empty() { 0 } else { self[0].len() }
127 }
128
129 #[inline]
130 fn num_channels(&self) -> usize {
131 self.len()
132 }
133
134 #[inline]
135 unsafe fn get_sample_unchecked(&self, channel: usize, sample_idx: usize) -> f32 {
136 unsafe { *self.get_unchecked(channel).get_unchecked(sample_idx) }
137 }
138}
139
140impl StftInput for [&mut [f32]] {
141 #[inline]
142 fn num_samples(&self) -> usize {
143 if self.is_empty() { 0 } else { self[0].len() }
144 }
145
146 #[inline]
147 fn num_channels(&self) -> usize {
148 self.len()
149 }
150
151 #[inline]
152 unsafe fn get_sample_unchecked(&self, channel: usize, sample_idx: usize) -> f32 {
153 unsafe { *self.get_unchecked(channel).get_unchecked(sample_idx) }
154 }
155}
156
157impl StftInputMut for [&mut [f32]] {
158 #[inline]
159 unsafe fn get_sample_unchecked_mut(&mut self, channel: usize, sample_idx: usize) -> &mut f32 {
160 unsafe {
161 self.get_unchecked_mut(channel)
162 .get_unchecked_mut(sample_idx)
163 }
164 }
165}
166
167impl StftInput for NoSidechain {
168 fn num_samples(&self) -> usize {
169 0
170 }
171
172 fn num_channels(&self) -> usize {
173 0
174 }
175
176 unsafe fn get_sample_unchecked(&self, _channel: usize, _sample_idx: usize) -> f32 {
177 0.0
178 }
179}
180
181impl<const NUM_SIDECHAIN_INPUTS: usize> StftHelper<NUM_SIDECHAIN_INPUTS> {
182 /// Initialize the [`StftHelper`] for [`Buffer`]s with the specified number of channels and the
183 /// given maximum block size. When the option is set, then every yielded sample buffer will have
184 /// this many zero samples appended at the end of the block. Call
185 /// [`set_block_size()`][Self::set_block_size()] afterwards if you do not need the full capacity
186 /// upfront. If the padding option is non zero, then all yielded blocks will have that many
187 /// zeroes added to the end of it and the results stored in the padding area will be added to
188 /// the outputs in the next iteration(s). You may also change how much padding is added with
189 /// [`set_padding()`][Self::set_padding()].
190 ///
191 /// # Panics
192 ///
193 /// Panics if `num_channels == 0 || max_block_size == 0`.
194 pub fn new(num_channels: usize, max_block_size: usize, max_padding: usize) -> Self {
195 assert_ne!(num_channels, 0);
196 assert_ne!(max_block_size, 0);
197
198 Self {
199 main_input_ring_buffers: vec![vec![0.0; max_block_size]; num_channels],
200 main_output_ring_buffers: vec![vec![0.0; max_block_size]; num_channels],
201 // Kinda hacky way to initialize an array of non-copy types
202 sidechain_ring_buffers: [(); NUM_SIDECHAIN_INPUTS]
203 .map(|_| vec![vec![0.0; max_block_size]; num_channels]),
204
205 // When padding is used this scratch buffer will have a bunch of zeroes added to it
206 // after copying a block of audio to it
207 scratch_buffer: vec![0.0; max_block_size + max_padding],
208 padding_buffers: vec![vec![0.0; max_padding]; num_channels],
209
210 current_pos: 0,
211 padding: max_padding,
212 }
213 }
214
215 /// Change the current block size. This will clear the buffers, causing the next block to output
216 /// silence.
217 ///
218 /// # Panics
219 ///
220 /// Will panic if `block_size > max_block_size`.
221 pub fn set_block_size(&mut self, block_size: usize) {
222 assert!(block_size <= self.main_input_ring_buffers[0].capacity());
223
224 self.update_buffers(block_size);
225 }
226
227 /// Change the current padding amount. This will clear the buffers, causing the next block to
228 /// output silence.
229 ///
230 /// # Panics
231 ///
232 /// Will panic if `padding > max_padding`.
233 pub fn set_padding(&mut self, padding: usize) {
234 assert!(padding <= self.padding_buffers[0].capacity());
235
236 self.padding = padding;
237 self.update_buffers(self.main_input_ring_buffers[0].len());
238 }
239
240 /// The number of channels this `StftHelper` was configured for
241 pub fn num_channels(&self) -> usize {
242 self.main_input_ring_buffers.len()
243 }
244
245 /// The maximum block size supported by this instance.
246 pub fn max_block_size(&self) -> usize {
247 self.main_input_ring_buffers.capacity()
248 }
249
250 /// The maximum amount of padding supported by this instance.
251 pub fn max_padding(&self) -> usize {
252 self.padding_buffers[0].capacity()
253 }
254
255 /// The amount of latency introduced when processing audio through this [`StftHelper`].
256 pub fn latency_samples(&self) -> u32 {
257 self.main_input_ring_buffers[0].len() as u32
258 }
259
260 /// Process the audio in `main_buffer` in small overlapping blocks, adding up the results for
261 /// the main buffer so they can eventually be written back to the host one block later. This
262 /// means that this function will introduce one block of latency. This can be compensated by
263 /// calling [`InitContext::set_latency()`][`crate::context::init::InitContext::set_latency_samples()`]
264 /// in your plugin's initialization function.
265 ///
266 /// If a padding value was specified in [`new()`][Self::new()], then the yielded blocks will
267 /// have that many zeroes appended at the end of them. The padding values will be added to the
268 /// next block before `process_cb()` is called.
269 ///
270 /// Since there are a couple different ways to do it, any window functions needs to be applied
271 /// in the callbacks. Check the [`nice_plug_core::util::window`][crate::util::window] module for more
272 /// information.
273 ///
274 /// For efficiency's sake this function will reuse the same vector for all calls to
275 /// `process_cb`. This means you can only access a single channel's worth of windowed data at a
276 /// time. The arguments to that function are `process_cb(channel_idx, real_fft_buffer)`.
277 /// `real_fft_buffer` will be a slice of `block_size` real valued samples. This can be passed
278 /// directly to an FFT algorithm.
279 ///
280 /// # Panics
281 ///
282 /// Panics if `main_buffer` or the buffers in `sidechain_buffers` do not have the same number of
283 /// channels as this [`StftHelper`], or if the sidechain buffers do not contain the same number of
284 /// samples as the main buffer.
285 ///
286 /// TODO: Add more useful ways to do STFT and other buffered operations. I just went with this
287 /// approach because it's what I needed myself, but generic combinators like this could
288 /// also be useful for other operations.
289 pub fn process_overlap_add<M, F>(
290 &mut self,
291 main_buffer: &mut M,
292 overlap_times: usize,
293 mut process_cb: F,
294 ) where
295 M: StftInputMut,
296 F: FnMut(usize, &mut [f32]),
297 {
298 self.process_overlap_add_sidechain(
299 main_buffer,
300 [&NoSidechain; NUM_SIDECHAIN_INPUTS],
301 overlap_times,
302 |channel_idx, sidechain_idx, real_fft_scratch_buffer| {
303 if sidechain_idx.is_none() {
304 process_cb(channel_idx, real_fft_scratch_buffer);
305 }
306 },
307 );
308 }
309
310 /// The same as [`process_overlap_add()`][Self::process_overlap_add()], but with sidechain
311 /// inputs that can be analyzed before the main input gets processed.
312 ///
313 /// The extra argument in the process function is `sidechain_buffer_idx`, which will be `None`
314 /// for the main buffer.
315 pub fn process_overlap_add_sidechain<M, S, F>(
316 &mut self,
317 main_buffer: &mut M,
318 sidechain_buffers: [&S; NUM_SIDECHAIN_INPUTS],
319 overlap_times: usize,
320 mut process_cb: F,
321 ) where
322 M: StftInputMut,
323 S: StftInput,
324 F: FnMut(usize, Option<usize>, &mut [f32]),
325 {
326 assert_eq!(
327 main_buffer.num_channels(),
328 self.main_input_ring_buffers.len()
329 );
330 assert!(overlap_times > 0);
331
332 // We'll copy samples from `*_buffer` into `*_ring_buffers` while simultaneously copying
333 // already processed samples from `main_ring_buffers` in into `main_buffer`
334 let main_buffer_len = main_buffer.num_samples();
335 let num_channels = main_buffer.num_channels();
336 let block_size = self.main_input_ring_buffers[0].len();
337 let window_interval = (block_size / overlap_times) as i32;
338 let mut already_processed_samples = 0;
339 while already_processed_samples < main_buffer_len {
340 let remaining_samples = main_buffer_len - already_processed_samples;
341 let samples_until_next_window = ((window_interval - self.current_pos as i32 - 1)
342 .rem_euclid(window_interval)
343 + 1) as usize;
344 let samples_to_process = samples_until_next_window.min(remaining_samples);
345
346 // Copy the input from `main_buffer` to the ring buffer while copying last block's
347 // result from the buffer to `main_buffer`
348 // TODO: This might be able to be sped up a bit with SIMD
349
350 // For the main buffer
351 for sample_offset in 0..samples_to_process {
352 for channel_idx in 0..num_channels {
353 let sample = unsafe {
354 main_buffer.get_sample_unchecked_mut(
355 channel_idx,
356 already_processed_samples + sample_offset,
357 )
358 };
359 let input_ring_buffer_sample = unsafe {
360 self.main_input_ring_buffers
361 .get_unchecked_mut(channel_idx)
362 .get_unchecked_mut(self.current_pos + sample_offset)
363 };
364 let output_ring_buffer_sample = unsafe {
365 self.main_output_ring_buffers
366 .get_unchecked_mut(channel_idx)
367 .get_unchecked_mut(self.current_pos + sample_offset)
368 };
369 *input_ring_buffer_sample = *sample;
370 *sample = *output_ring_buffer_sample;
371 // Very important, or else we'll overlap-add ourselves into a feedback hell
372 *output_ring_buffer_sample = 0.0;
373 }
374 }
375
376 // And for the sidechain buffers we only need to copy the inputs
377 for (sidechain_buffer, sidechain_ring_buffers) in sidechain_buffers
378 .iter()
379 .zip(self.sidechain_ring_buffers.iter_mut())
380 {
381 for sample_offset in 0..samples_to_process {
382 for channel_idx in 0..num_channels {
383 let sample = unsafe {
384 sidechain_buffer.get_sample_unchecked(
385 channel_idx,
386 already_processed_samples + sample_offset,
387 )
388 };
389 let ring_buffer_sample = unsafe {
390 sidechain_ring_buffers
391 .get_unchecked_mut(channel_idx)
392 .get_unchecked_mut(self.current_pos + sample_offset)
393 };
394 *ring_buffer_sample = sample;
395 }
396 }
397 }
398
399 already_processed_samples += samples_to_process;
400 self.current_pos = (self.current_pos + samples_to_process) % block_size;
401
402 // At this point we either have `already_processed_samples == main_buffer_len`, or
403 // `self.current_pos % window_interval == 0`. If it's the latter, then we can process a
404 // new block.
405 if samples_to_process == samples_until_next_window {
406 // Because we're processing in smaller windows, the input ring buffers sadly does
407 // not always contain the full contiguous range we're interested in because they map
408 // wrap around. Because premade FFT algorithms typically can't handle this, we'll
409 // start with copying the wrapped ranges from our ring buffers to the scratch
410 // buffer. Then we apply the windowing function and this it along to
411 for (sidechain_idx, sidechain_ring_buffers) in
412 self.sidechain_ring_buffers.iter().enumerate()
413 {
414 for (channel_idx, sidechain_ring_buffer) in
415 sidechain_ring_buffers.iter().enumerate()
416 {
417 copy_ring_to_scratch_buffer(
418 &mut self.scratch_buffer,
419 self.current_pos,
420 sidechain_ring_buffer,
421 );
422 if self.padding > 0 {
423 self.scratch_buffer[block_size..].fill(0.0);
424 }
425
426 process_cb(channel_idx, Some(sidechain_idx), &mut self.scratch_buffer);
427 }
428 }
429
430 for (channel_idx, ((input_ring_buffer, output_ring_buffer), padding_buffer)) in self
431 .main_input_ring_buffers
432 .iter()
433 .zip(self.main_output_ring_buffers.iter_mut())
434 .zip(self.padding_buffers.iter_mut())
435 .enumerate()
436 {
437 copy_ring_to_scratch_buffer(
438 &mut self.scratch_buffer,
439 self.current_pos,
440 input_ring_buffer,
441 );
442 if self.padding > 0 {
443 self.scratch_buffer[block_size..].fill(0.0);
444 }
445
446 process_cb(channel_idx, None, &mut self.scratch_buffer);
447
448 // Add the padding from the last iteration (for this channel) to the scratch
449 // buffer before it is copied to the output ring buffer. In case the padding is
450 // longer than the block size, then this will cause everything else to be
451 // shifted to the left so it can be added in the iteration after this.
452 if self.padding > 0 {
453 let padding_to_copy = cmp::min(self.padding, block_size);
454 for (scratch_sample, padding_sample) in self.scratch_buffer
455 [..padding_to_copy]
456 .iter_mut()
457 .zip(&mut padding_buffer[..padding_to_copy])
458 {
459 *scratch_sample += *padding_sample;
460 }
461
462 // Any remaining padding tail should be moved towards the start of the
463 // buffer
464 padding_buffer.copy_within(padding_to_copy.., 0);
465
466 // And we obviously don't want this to feedback
467 padding_buffer[self.padding - padding_to_copy..].fill(0.0);
468 }
469
470 // The actual overlap-add part of the equation
471 add_scratch_to_ring_buffer(
472 &self.scratch_buffer,
473 self.current_pos,
474 output_ring_buffer,
475 );
476
477 // And the data from the padding area should be saved so it can be added to next
478 // iteration's scratch buffer. Like mentioned above, the padding can be larger
479 // than the block size so we also need to do overlap-add here.
480 if self.padding > 0 {
481 for (padding_sample, scratch_sample) in padding_buffer
482 .iter_mut()
483 .zip(&mut self.scratch_buffer[block_size..])
484 {
485 *padding_sample += *scratch_sample;
486 }
487 }
488 }
489 }
490 }
491 }
492
493 /// Similar to [`process_overlap_add()`][Self::process_overlap_add()], but without the inverse
494 /// STFT part. `buffer` will only ever be read from. This can be useful for providing FFT data
495 /// for a spectrum analyzer in a plugin GUI. These is still a delay to the analysis equal to the
496 /// block size.
497 pub fn process_analyze_only<B, F>(
498 &mut self,
499 buffer: &B,
500 overlap_times: usize,
501 mut analyze_cb: F,
502 ) where
503 B: StftInput,
504 F: FnMut(usize, &mut [f32]),
505 {
506 assert_eq!(buffer.num_channels(), self.main_input_ring_buffers.len());
507 assert!(overlap_times > 0);
508
509 // See `process_overlap_add_sidechain` for an annotated version
510 let main_buffer_len = buffer.num_samples();
511 let num_channels = buffer.num_channels();
512 let block_size = self.main_input_ring_buffers[0].len();
513 let window_interval = (block_size / overlap_times) as i32;
514 let mut already_processed_samples = 0;
515 while already_processed_samples < main_buffer_len {
516 let remaining_samples = main_buffer_len - already_processed_samples;
517 let samples_until_next_window = ((window_interval - self.current_pos as i32 - 1)
518 .rem_euclid(window_interval)
519 + 1) as usize;
520 let samples_to_process = samples_until_next_window.min(remaining_samples);
521
522 for sample_offset in 0..samples_to_process {
523 for channel_idx in 0..num_channels {
524 let sample = unsafe {
525 buffer.get_sample_unchecked(
526 channel_idx,
527 already_processed_samples + sample_offset,
528 )
529 };
530 let input_ring_buffer_sample = unsafe {
531 self.main_input_ring_buffers
532 .get_unchecked_mut(channel_idx)
533 .get_unchecked_mut(self.current_pos + sample_offset)
534 };
535 *input_ring_buffer_sample = sample;
536 }
537 }
538
539 already_processed_samples += samples_to_process;
540 self.current_pos = (self.current_pos + samples_to_process) % block_size;
541
542 if samples_to_process == samples_until_next_window {
543 for (channel_idx, input_ring_buffer) in
544 self.main_input_ring_buffers.iter().enumerate()
545 {
546 copy_ring_to_scratch_buffer(
547 &mut self.scratch_buffer,
548 self.current_pos,
549 input_ring_buffer,
550 );
551 if self.padding > 0 {
552 self.scratch_buffer[block_size..].fill(0.0);
553 }
554
555 analyze_cb(channel_idx, &mut self.scratch_buffer);
556 }
557 }
558 }
559 }
560
561 fn update_buffers(&mut self, block_size: usize) {
562 for main_ring_buffer in &mut self.main_input_ring_buffers {
563 main_ring_buffer.resize(block_size, 0.0);
564 main_ring_buffer.fill(0.0);
565 }
566 for main_ring_buffer in &mut self.main_output_ring_buffers {
567 main_ring_buffer.resize(block_size, 0.0);
568 main_ring_buffer.fill(0.0);
569 }
570 for sidechain_ring_buffers in &mut self.sidechain_ring_buffers {
571 for sidechain_ring_buffer in sidechain_ring_buffers {
572 sidechain_ring_buffer.resize(block_size, 0.0);
573 sidechain_ring_buffer.fill(0.0);
574 }
575 }
576 self.scratch_buffer.resize(block_size + self.padding, 0.0);
577 self.scratch_buffer.fill(0.0);
578
579 for padding_buffer in &mut self.padding_buffers {
580 // In case this changed since the last call, like in `set_padding()`
581 padding_buffer.resize(self.padding, 0.0);
582 padding_buffer.fill(0.0);
583 }
584
585 self.current_pos = 0;
586 }
587}
588
/// Unwrap `ring_buffer` into the start of `scratch_buffer` so that the sample at `current_pos`
/// comes first and the sample right before `current_pos` comes last. Anything in the scratch
/// buffer beyond the ring buffer's length is left untouched. This takes the individual buffers
/// instead of being a method so it can be called while other fields of the helper are borrowed.
#[inline]
fn copy_ring_to_scratch_buffer(
    scratch_buffer: &mut [f32],
    current_pos: usize,
    ring_buffer: &[f32],
) {
    // The part in front of `current_pos` has already wrapped around, so it goes last
    let (wrapped_part, part_from_pos) = ring_buffer.split_at(current_pos);

    let (scratch_front, scratch_rest) = scratch_buffer.split_at_mut(part_from_pos.len());
    scratch_front.copy_from_slice(part_from_pos);
    scratch_rest[..wrapped_part.len()].copy_from_slice(wrapped_part);
}
603
/// Accumulate the first `ring_buffer.len()` samples of `scratch_buffer` onto `ring_buffer`,
/// starting at `current_pos` and wrapping around to the start of the ring buffer. Because this
/// adds to what is already there, samples must be zeroed once they have been written back to the
/// host's outputs or they would keep feeding back.
#[inline]
fn add_scratch_to_ring_buffer(scratch_buffer: &[f32], current_pos: usize, ring_buffer: &mut [f32]) {
    // TODO: This could also use some SIMD
    let (ring_wrapped, ring_from_pos) = ring_buffer.split_at_mut(current_pos);
    let wrap_at = ring_from_pos.len();

    // First the part that fits before the end of the ring buffer...
    for (ring_sample, scratch_sample) in ring_from_pos.iter_mut().zip(&scratch_buffer[..wrap_at]) {
        *ring_sample += *scratch_sample;
    }

    // ...then whatever wraps around to the start
    let wrapped_scratch = &scratch_buffer[wrap_at..wrap_at + current_pos];
    for (ring_sample, scratch_sample) in ring_wrapped.iter_mut().zip(wrapped_scratch) {
        *ring_sample += *scratch_sample;
    }
}